Title: GFlowOut: Dropout with Generative Flow Networks

URL Source: https://arxiv.org/html/2210.12928

Moksh Jain, Bonaventure F. P. Dossou, Qianli Shen, Salem Lahlou, Anirudh Goyal, Nikolay Malkin, Chris C. Emezue, Dinghuai Zhang, Nadhir Hassen, Xu Ji, Kenji Kawaguchi, Yoshua Bengio

Abstract

Bayesian inference offers principled tools to tackle many critical problems with modern neural networks such as poor calibration and generalization, and data inefficiency. However, scaling Bayesian inference to large architectures is challenging and requires restrictive approximations. Monte Carlo Dropout has been widely used as a relatively cheap way to approximate inference and estimate uncertainty with deep neural networks. Traditionally, the dropout mask is sampled independently from a fixed distribution. Recent research shows that the dropout mask can be seen as a latent variable, which can be inferred with variational inference. These methods face two important challenges: (a) the posterior distribution over masks can be highly multi-modal which can be difficult to approximate with standard variational inference and (b) it is not trivial to fully utilize sample-dependent information and correlation among dropout masks to improve posterior estimation. In this work, we propose GFlowOut to address these issues. GFlowOut leverages the recently proposed probabilistic framework of Generative Flow Networks (GFlowNets) to learn the posterior distribution over dropout masks. We empirically demonstrate that GFlowOut results in predictive distributions that generalize better to out-of-distribution data and provide uncertainty estimates which lead to better performance in downstream tasks.

Machine Learning, ICML

1 Introduction

A key shortcoming of modern deep neural networks is that they are often overconfident about their predictions, especially when there is a distributional shift between the training and test datasets (Daxberger et al., 2021; Nguyen et al., 2015; Guo et al., 2017). In risk-sensitive scenarios such as clinical practice and drug discovery, where mistakes can be extremely costly, it is important that models provide predictions with reliable uncertainty estimates (Bhatt et al., 2021). Bayesian inference offers principled tools to model the parameters of neural networks as random variables, placing a prior on them and inferring their posterior given some observed data (MacKay, 1992; Neal, 2012). The posterior captures the uncertainty in the predictions of the model and also serves as an effective regularization strategy, resulting in improved generalization (Wilson & Izmailov, 2020; Lotfi et al., 2022). In practice, exact Bayesian inference is often intractable, and existing Bayesian deep learning methods rely on assumptions that result in less expressive posteriors and can provide poorly calibrated uncertainty estimates (Ovadia et al., 2019; Fort et al., 2019; Foong et al., 2020; Daxberger et al., 2021). In addition, even with several approximations, Bayesian deep learning methods are often significantly more computationally expensive and slower to train than non-Bayesian methods (Kuleshov et al., 2018; Boluki et al., 2020).

Figure 1: In this work, we propose a Generative Flow Network (GFlowNet) based binary dropout mask generator which we refer to as GFlowOut. Purple squares are GFlowNet-based dropout mask generators parameterized as multi-layer perceptrons. $z_{i,l}$ denotes the dropout mask for the data point indexed by $i$ at layer $l$ of the model. $h_{i,l}$ denotes the activations of the model at layer $l$ given input $x_i$. $q(\cdot)$ are auxiliary variational functions used and adapted only during model training, in which the posterior distribution over dropout masks is conditioned implicitly on the input covariates ($x_i$) and directly on the label ($y_i$) of the data point to make the estimation easier. $p(\cdot)$ are mask generation functions used at test time, which are conditioned only on $x_i$ and trained by minimizing the Kullback–Leibler (KL) divergence with $q(\cdot)$. In addition, both $q(\cdot)$ and $p(\cdot)$ condition explicitly on the dropout masks of all previous layers.

Gal and Ghahramani (2016) show that deep neural networks with dropout perform approximate Bayesian inference and approximate the posterior of a deep Gaussian process (Damianou & Lawrence, 2013). One can obtain samples from this predictive distribution by taking multiple forward passes through the neural network with independently sampled dropout masks. Due to its simplicity and minimal computational overhead, dropout has since been used as a method to estimate uncertainty and improve robustness in neural networks. Different variants of dropout have been proposed and can be interpreted as different variational approximations to the posterior over the neural network parameters (Ba & Frey, 2013; Kingma et al., 2015; Gal et al., 2017; Ghiasi et al., 2018; Fan et al., 2021; Pham & Le, 2021).
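The multiple-forward-pass procedure can be sketched with a toy two-layer network. The weights, layer sizes, and dropout rate below are illustrative assumptions, not values from the paper:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy two-layer MLP weights (illustrative; not from the paper)
W1, b1 = rng.normal(size=(16, 4)), np.zeros(16)
W2, b2 = rng.normal(size=(3, 16)), np.zeros(3)

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

def forward(x, rate=0.5):
    # One stochastic forward pass: sample an i.i.d. Bernoulli mask per unit
    h = np.maximum(W1 @ x + b1, 0.0)
    mask = rng.random(h.shape) > rate   # keep each unit with probability 1 - rate
    h = h * mask / (1.0 - rate)         # inverted-dropout rescaling
    return softmax(W2 @ h + b2)

x = rng.normal(size=4)
samples = np.stack([forward(x) for _ in range(100)])
mean_pred = samples.mean(axis=0)        # MC-dropout predictive mean
uncertainty = samples.var(axis=0)       # per-class predictive variance
```

Averaging over masks approximates the predictive mean, while the spread across passes provides the uncertainty estimate discussed above.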

There are a few major challenges in approximating the Bayesian posterior over model parameters using dropout: (1) the multimodal nature of the posterior distribution makes it difficult to approximate with standard variational inference (Gal & Ghahramani, 2016; Le Folgoc et al., 2021), which assumes factorized priors; (2) dropout masks are discrete objects, making gradient-based optimization difficult (Boluki et al., 2020); (3) variational inference methods can suffer from high gradient variance, resulting in optimization instability (Kingma et al., 2015); (4) modeling the dependence between dropout masks at different layers is non-trivial.

The recently proposed Generative Flow Networks (GFlowNets) (Bengio et al., 2021a, b) frame the generation of discrete objects as a control problem based on the sequential construction of their components. GFlowNets learn stochastic policies that sample objects with probability proportional to a reward function (or $\exp(-\text{energy})$). They have demonstrated better generalization to multimodal distributions (Nica et al., 2022) and lower gradient variance than policy-gradient-based variational methods (Malkin et al., 2023), making them an appealing choice for posterior inference over dropout masks.

Contributions. In this work, to address the limitations of standard variational inference, we develop a GFlowNet-based binary dropout mask generator, which we refer to as GFlowOut, to estimate the posterior distribution over binary dropout masks. GFlowOut generates dropout masks for a layer conditioned on the masks generated for the previous layers, thereby accounting for inter-layer dropout dependence. Furthermore, the GFlowOut estimator can be conditioned on the data point: GFlowOut improves posterior estimation by utilizing both the input covariates and the labels in the training set of supervised learning tasks via an auxiliary variational function. To investigate the quality of the posterior distribution learned by GFlowOut, we design empirical experiments, including evaluating robustness to distribution shift during inference, detecting out-of-distribution examples with uncertainty estimates, and transfer learning, using both benchmark datasets and a real-world clinical dataset.

2 Related work

2.1 Dropout as a Bayesian approximation

Deep learning tools have shown tremendous power in many applications. However, traditional deep learning models lack mechanisms to capture uncertainty, which is of crucial importance in many fields. Uncertainty quantification (UQ) is studied extensively as a fundamental problem in deep learning, and a large number of Bayesian deep learning tools have emerged in recent years. For example, Gal & Ghahramani (2016) showed that training deep models with dropout approximates Bayesian inference in deep Gaussian processes and allows uncertainty estimation without extra computational cost. Kingma et al. (2015) proposed variational dropout, where a dropout posterior over parameters is learned by treating dropout regularization as approximate inference in deep models. Gal et al. (2017) developed a continuous relaxation of discrete dropout masks to improve uncertainty estimation, especially in reinforcement learning settings. Lee et al. (2020) introduced "meta-dropout", which involves an additional global term shared across all data points during inference to improve generalization. Xie et al. (2019) replaced the hard dropout mask following a Bernoulli distribution with a soft mask following a beta distribution and optimized the dropout rate with a stochastic gradient variational Bayesian algorithm. Boluki et al. (2020) combined a model-agnostic dropout scheme with variational auto-encoders (VAEs), resulting in semi-implicit VAE models. Instead of using the mean-field family for variational inference, Nguyen et al. (2021) utilized a structured representation of multiplicative Gaussian noise for better posterior estimation. More recently, Fan et al. (2021) developed contextual dropout, which optimizes variational objectives in a sample-dependent manner and, to the best of our knowledge, is the closest approach to GFlowOut in the literature. GFlowOut differs from contextual dropout in several respects.
First, both methods allow trainable priors, but GFlowOut additionally supports priors that depend on the input covariates of each data point. Second, the variational posterior of contextual dropout depends only on the input covariates ($x$), while in GFlowOut the variational posterior is also conditioned on the label ($y$), which provides more information during training. Third, the mask of contextual dropout at each layer is conditioned on the previous layers' masks only implicitly, while the mask of GFlowOut is conditioned on them explicitly, by directly feeding the previous masks as inputs into the generator, which improves the training process. Finally, instead of the REINFORCE-based gradient estimator used to train contextual dropout, GFlowOut employs GFlowNets for the variational posterior.

2.2 Generative flow networks

Generative flow networks (GFlowNets) (Bengio et al., 2021a, b) are a family of probabilistic models that amortize the sampling of discrete compositional objects proportionally to a given unnormalized density function. GFlowNets learn a stochastic policy to construct objects through a sequence of actions, akin to deep reinforcement learning (Sutton & Barto, 2018). They are trained so as to make the likelihood of reaching a terminating state proportional to its reward. Recent work has shown close connections of GFlowNets to other generative models (Zhang et al., 2022a) and to hierarchical variational inference (Malkin et al., 2023). GFlowNets have achieved empirical success in learning energy-based models (Zhang et al., 2022b), small-molecule generation (Bengio et al., 2021a; Nica et al., 2022; Malkin et al., 2022; Madan et al., 2023; Pan et al., 2023), biological sequence generation (Malkin et al., 2022; Jain et al., 2022; Madan et al., 2023), and structure learning (Deleu et al., 2022). Several training objectives have been proposed for GFlowNets, including Flow Matching (FM) (Bengio et al., 2021a), Detailed Balance (DB) (Bengio et al., 2021b), Trajectory Balance (TB) (Malkin et al., 2022), and the more recent Sub-Trajectory Balance (SubTB) (Madan et al., 2023). In this work, we use the Trajectory Balance (TB) objective.

3 Method

In this section, we define the problem setting and mathematical notations used in this study, as well as describe the proposed method, GFlowOut, for dropout mask generation in detail.

3.1 Background and notation

Dropout. In a vanilla feed-forward neural network (MLP) with $L$ layers, each layer has a weight matrix $w_l$ and a bias vector $b_l$. It takes as input the activations $h_{l-1}$ from the previous layer (with layer index $l-1$) and computes as output $h_l = \sigma(w_l h_{l-1} + b_l)$, where $\sigma$ is a non-linear activation function. Dropout consists of dropping out units from the output of a layer. Formally, this can be described as applying a sampled binary mask $z_l \sim p(z_l)$ to the output of each layer in the model: $h_l = z_l \circ \sigma(w_l h_{l-1} + b_l)$.
In regular random dropout, $z_l$ is a collection of i.i.d. $\mathrm{Bernoulli}(r)$ variables, where $r$ is a fixed parameter for all layers. Recently, several approaches have been proposed to learn $p(z_l)$ along with the model parameters. In these approaches, $z$ is viewed either as a latent variable or as part of the model parameters. We consider two variants of our proposed method: GFlowOut, where the dropout masks $z$ are viewed as sample-dependent latent variables, and ID-GFlowOut, which generates masks in a sample-independent manner, where $z$ is viewed as part of the model parameters shared across all samples. Next, we briefly introduce GFlowNets and describe how they model the dropout masks $z$ given the data.
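To make the two variants concrete, the sketch below contrasts a shared (sample-independent) mask distribution with one whose Bernoulli probabilities depend on the input. The linear map `V` and the sigmoid parameterization are hypothetical stand-ins for the learned generators, chosen only for illustration:

```python
import numpy as np

rng = np.random.default_rng(1)

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

n_units = 8
# Sample-independent (ID-GFlowOut-style): one logit per unit, shared by all samples
shared_logits = np.zeros(n_units)

# Sample-dependent (GFlowOut-style): mask probabilities computed from the input
V = rng.normal(scale=0.1, size=(n_units, 4))   # hypothetical linear generator

def sample_mask(x=None):
    logits = shared_logits if x is None else V @ x
    probs = sigmoid(logits)
    return (rng.random(n_units) < probs).astype(float)

x1 = rng.normal(size=4)
m_shared = sample_mask()    # same distribution for every data point
m_dep = sample_mask(x1)     # distribution differs per data point
```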

GFlowNets. Let $G = (\mathcal{S}, \mathcal{A})$ be a directed acyclic graph (DAG) whose vertices $s \in \mathcal{S}$ are states, including a special initial state $s_0$ with no incoming edges, and whose directed edges $(s \rightarrow s') \in \mathcal{A}$ are actions. $\mathcal{X} \subseteq \mathcal{S}$ denotes the set of terminal states, which have no outgoing edges. A complete trajectory $\tau = (s_0 \rightarrow \dots \rightarrow s_{i-1} \rightarrow s_i \rightarrow \dots \rightarrow z) \in \mathcal{T}$ in $G$ is a sequence of states starting at $s_0$ and terminating at $z \in \mathcal{X}$, where each $(s_{i-1} \rightarrow s_i) \in \mathcal{A}$.
The forward policy $P_F(-\mid s)$ is a collection of distributions over the children of each non-terminal node $s \in \mathcal{S}$ and defines a distribution over complete trajectories, $P_F(\tau) = \prod_{(s_{i-1} \rightarrow s_i) \in \tau} P_F(s_i \mid s_{i-1})$. We can sample terminal states $z \in \mathcal{X}$ by sampling trajectories following $P_F$. Let $\pi(z)$ be the marginal likelihood of sampling terminal state $z$: $\pi(z) = \sum_{\tau = (s_0 \rightarrow \dots \rightarrow z) \in \mathcal{T}} P_F(\tau)$.
Given a non-negative reward function $R: \mathcal{X} \to \mathbb{R}^+$, the learning problem tackled by GFlowNets is to estimate $P_F$ such that $\pi(z) \propto R(z)$ for all $z \in \mathcal{X}$. We refer the reader to Bengio et al. (2021b); Malkin et al. (2022) for a more thorough introduction to GFlowNets.
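The quantities $P_F(\tau)$ and $\pi(z)$ can be computed exactly on a toy DAG by enumeration. The states and transition probabilities below are invented purely for illustration:

```python
# Forward policy over a tiny DAG: states map to {child: probability}
policy = {
    "s0": {"a": 0.6, "b": 0.4},
    "a":  {"z1": 0.5, "z2": 0.5},
    "b":  {"z2": 1.0},
}

def traj_prob(traj):
    # P_F(tau) = product of forward-policy probabilities along the edges
    p = 1.0
    for s, s_next in zip(traj, traj[1:]):
        p *= policy[s][s_next]
    return p

def enumerate_trajs(state="s0"):
    if state not in policy:          # terminal state: no outgoing edges
        return [[state]]
    return [[state] + t for nxt in policy[state] for t in enumerate_trajs(nxt)]

# Marginal likelihood pi(z): sum of P_F(tau) over trajectories ending at z
pi = {}
for t in enumerate_trajs():
    pi[t[-1]] = pi.get(t[-1], 0.0) + traj_prob(t)
# pi == {"z1": 0.3, "z2": 0.7}
```

Note that $\pi(z_2)$ sums contributions from two distinct trajectories, which is exactly why the marginal is a sum over $\mathcal{T}$ rather than a single product.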

We adopt the Trajectory Balance (TB) (Malkin et al., 2022) parameterization, which comprises $P_F(-\mid-;\phi)$, $P_B(-\mid-;\phi)$, and $Z_\gamma$, where $\phi$ and $\gamma$ are the learnable parameters. The backward policy $P_B$ is a distribution over the parents of every non-initial state, and $Z$ is an estimate of the partition function.
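As a minimal sketch of the TB objective (not the paper's full training code), the squared log-ratio below vanishes exactly when $Z \cdot P_F(\tau) = R(z) \cdot P_B(\tau \mid z)$ along a trajectory; the trajectory probabilities and reward are toy values:

```python
import numpy as np

def tb_loss(log_Z, log_pf_terms, log_pb_terms, reward):
    # Trajectory Balance: (log Z + sum log P_F - log R(z) - sum log P_B)^2
    delta = log_Z + sum(log_pf_terms) - np.log(reward) - sum(log_pb_terms)
    return delta ** 2

# A perfectly balanced trajectory gives zero loss: choose log Z so that
# Z * P_F(tau) == R(z) * P_B(tau | z)
log_pf = [np.log(0.5), np.log(0.5)]   # forward-policy log-probs along tau
log_pb = [0.0, 0.0]                   # backward policy (deterministic here)
reward = 2.0
log_Z = np.log(reward) - sum(log_pf)  # makes the trajectory balanced
```

Minimizing this squared discrepancy over sampled trajectories drives $\pi(z)$ toward $R(z)/Z$.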

Within the context of generating dropout masks, a complete dropout mask $z \in \mathcal{X}$ is a binary vector of dimension $M$, where $M$ is the number of units in the neural network, i.e. $\mathcal{X} = \{0,1\}^M$. A partially constructed mask $s \in \mathcal{S}$ is a binary vector of dimension $m < M$ representing the mask for a set of initial layers in the model, and an action consists of appending the mask for the subsequent layer to this vector. That is, each action samples the mask for an entire layer (in parallel), conditioned on the masks for the previous layers.
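The layer-by-layer construction can be sketched as follows, with a constant 0.5 keep-probability standing in for a real conditional policy (an assumption for illustration; in GFlowOut the probabilities would be produced by the mask generator from the context):

```python
import numpy as np

rng = np.random.default_rng(2)
layer_sizes = [5, 3, 4]        # units per dropout layer (illustrative)

# Each action appends the mask for one whole layer, conditioned on earlier masks
state = np.empty(0)            # partially constructed mask, grows layer by layer
for size in layer_sizes:
    context = state.copy()                    # masks of all previous layers
    probs = np.full(size, 0.5)                # a real policy maps context -> probs
    layer_mask = (rng.random(size) < probs).astype(float)
    state = np.concatenate([state, layer_mask])
# state is now a complete mask in {0,1}^M with M = sum(layer_sizes)
```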

In the next section, we formally describe how GFlowNets can be used for generating dropout masks, as well as practical implementation details.

Algorithm 1 GFlowOut

The whole system has the following three components:

  • Backbone model (e.g., a classifier): a neural network $p(y_i|x_i,z_i;\theta)$. This algorithm is written assuming $p(y_i|x_i,z_i;\theta)$ is an MLP with $L$ hidden layers, but it can easily be extended to other architectures.

  • GFlowNet $q(z_i|x_i,y_i;\phi)$, which approximates the posterior distribution over dropout masks $z_i$ conditioned on both $x_i$ and $y_i$ from the data point. Its tempered version $\tilde{q}(z_i|x_i,y_i;\phi)$ is used for dropout mask sampling during training.

  • $p(z_i|x_i;\xi)$, which generates the dropout mask distribution conditioned only on $x_i$, is optimized by minimizing the KL divergence with $q(z_i|x_i,y_i;\phi)$, and is used for dropout mask sampling at test time.

$q(z_i|x_i,y_i;\phi)$ and $p(z_i|x_i;\xi)$ are implemented as groups of MLPs, one MLP for each layer $l$; they do not share parameters with each other or between layers. Next, we explain how dropout masks are generated and how $p(y_i|x_i,z_i;\theta)$, $q(z_i|x_i,y_i;\phi)$, and $p(z_i|x_i;\xi)$ are computed.

for epoch do
  for data point $(x_i, y_i)$ do ▷ batches used in actual training
    $h_0 = x_i$
    $var1 = 0$ ▷ variables to store log-probabilities from each layer
    $var2 = 0$
    for layer $l$ in $1{:}L-1$ do ▷ use the current layer's activations and the masks of all previous layers for mask generation
      $h'_{l,i} = \mathrm{ReLU}(b_l + w_l h_{l-1,i})$
      $z_{i,l} \sim \tilde{q}_l(z_{i,l} \mid h'_{l,i}, y_i, (z_{i,j})_{j=1}^{l-1}; \phi)$ ▷ generate dropout mask
      $h_{l,i} = z_{i,l} \circ h'_{l,i}$ ▷ apply dropout
      $var1 \leftarrow var1 + \log q_l(z_{i,l} \mid x_i, y_i, (z_{i,j})_{j=1}^{l-1}; \phi)$ ▷ accumulate log-probabilities
      $var2 \leftarrow var2 + \log p_l(z_{i,l} \mid x_i, (z_{i,j})_{j=1}^{l-1}; \xi)$
    end for
    $\log q(z_i|x_i,y_i;\phi) = var1$
    $\log p(z_i|x_i;\xi) = var2$
    In the output layer, $\hat{y}_i = f_{\mathrm{out}}(b_L + w_L h_{L-1,i})$, where $f_{\mathrm{out}}$ is the output non-linearity.
    Update $\theta, \phi, \xi, \omega, \gamma$ using equations 4–12.
  end for
end for
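The inner loop of Algorithm 1 can be sketched in plain numpy as follows. The layer sizes, the fixed linear maps standing in for the per-layer generator MLPs, and the omission of the parameter updates (equations 4–12) are all simplifying assumptions for illustration:

```python
import numpy as np

rng = np.random.default_rng(3)
sizes = [4, 6, 6, 3]                       # input, two hidden layers, output
L = len(sizes) - 1

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

# Backbone weights for p(y|x,z; theta)
W = [rng.normal(scale=0.3, size=(sizes[l + 1], sizes[l])) for l in range(L)]
b = [np.zeros(sizes[l + 1]) for l in range(L)]

def mask_logits(inputs, n_out, seed):
    # Hypothetical fixed linear generator standing in for a per-layer MLP
    g = np.random.default_rng(seed)
    A = g.normal(scale=0.1, size=(n_out, inputs.size))
    return A @ inputs

def forward(x, y_onehot):
    h, prev_masks = x, np.empty(0)
    var1 = var2 = 0.0                      # accumulated log q and log p
    for l in range(L - 1):                 # dropout on hidden layers 1..L-1
        h_pre = np.maximum(b[l] + W[l] @ h, 0.0)
        # q_l sees (h'_l, y, previous masks)
        q_in = np.concatenate([h_pre, y_onehot, prev_masks])
        probs_q = sigmoid(mask_logits(q_in, h_pre.size, seed=10 + l))
        z = (rng.random(h_pre.size) < probs_q).astype(float)   # z ~ q_l
        h = z * h_pre                      # apply dropout
        var1 += np.sum(z * np.log(probs_q) + (1 - z) * np.log1p(-probs_q))
        # p_l sees (x, previous masks) only
        p_in = np.concatenate([x, prev_masks])
        probs_p = sigmoid(mask_logits(p_in, h_pre.size, seed=20 + l))
        var2 += np.sum(z * np.log(probs_p) + (1 - z) * np.log1p(-probs_p))
        prev_masks = np.concatenate([prev_masks, z])
    logits = b[L - 1] + W[L - 1] @ h       # output layer, f_out = softmax here
    y_hat = np.exp(logits - logits.max())
    y_hat /= y_hat.sum()
    return y_hat, var1, var2               # var1 = log q(z|x,y), var2 = log p(z|x)

x = rng.normal(size=4)
y = np.array([1.0, 0.0, 0.0])
y_hat, log_q, log_p = forward(x, y)
```

The two accumulators mirror $var1$ and $var2$ in the algorithm; a real implementation would feed them, with the prediction loss, into the TB and KL objectives.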

3.2 GFlowOut

We consider a generative model of the form $p(x,y,z) = p(x)\,p(z|x)\,p(y|x,z)$, where $x$ is the input data with corresponding label $y$ and $z$ is a local discrete latent variable representing the sample-dependent dropout mask, along with a dataset of observations $D = \{(x_i, y_i)\}_{i=1}^{N}$. GFlowOut learns to approximate the posterior $p(z|x,y)$ using the given dataset $D$. In a supervised learning task where the goal is to learn the predictive distribution $p(y|x)$, with the assumed generative model above, the following variational bound can be derived:

$$\begin{aligned}
\log \prod_{i=1}^{N} p(y_i \mid x_i) &= \log \prod_{i=1}^{N} \sum_{z_i \in \mathcal{X}} p(z_i \mid x_i)\, p(y_i \mid x_i, z_i) && (1) \\
&= \log \prod_{i=1}^{N} \sum_{z_i \in \mathcal{X}} p(z_i \mid x_i)\, \frac{q(z_i \mid x_i, y_i)}{q(z_i \mid x_i, y_i)}\, p(y_i \mid x_i, z_i) \\
&= \sum_{i=1}^{N} \log \mathbb{E}_{q(z_i \mid x_i, y_i)}\left[\frac{p(z_i \mid x_i)}{q(z_i \mid x_i, y_i)}\, p(y_i \mid x_i, z_i)\right] \\
&\geq \sum_{i=1}^{N} \mathbb{E}_{q(z_i \mid x_i, y_i)}\left[\log \frac{p(z_i \mid x_i)}{q(z_i \mid x_i, y_i)}\, p(y_i \mid x_i, z_i)\right] \\
&= \sum_{i=1}^{N} \Big[ \mathbb{E}_{q(z_i \mid x_i, y_i)}\big[\log p(y_i \mid x_i, z_i)\big] - \mathrm{KL}\big(q(z_i \mid x_i, y_i) \,\|\, p(z_i \mid x_i)\big) \Big] && (3)
\end{aligned}$$

where $p(z_i|x_i)$ is part of the generative process and $q(z_i|x_i,y_i)$ is the variational distribution used to approximate the posterior over $z_i$. To improve the efficiency of training and fully utilize the information available in each data point, we design $q(z_i|x_i,y_i)$ so that the distribution of $z_i$ is conditioned on both $x_i$ and $y_i$. As a consequence, $q(z_i|x_i,y_i)$ is not accessible during inference, where $y_i$ is unavailable. Instead, $p(z_i|x_i)$, which is trained by minimizing its KL divergence from $q(z_i|x_i,y_i)$, is used for inference.

Parametrizing each of the terms as $p(y_i|x_i,z_i;\theta)$, $q(z_i|x_i,y_i;\phi)$ and $p(z_i|x_i;\xi)$, the goal is to maximize the lower bound derived above, which we denote as:

โ„ฌ(๐’Ÿ;ฮธ,ฯ•,ฮพ)=โˆ‘i=1 N[๐”ผ qโข(z i|x i,y i;ฯ•)[log p(y i|x i,z i;ฮธ)]\displaystyle\mathcal{B}(\mathcal{D};\theta,\phi,\xi)=\sum^{N}{i=1}\bigg{[}% \mathop{\mathbb{E}}{q(z_{i}|x_{i},y_{i};\phi)}[\log p(y_{i}|x_{i},z_{i};% \theta)]caligraphic_B ( caligraphic_D ; italic_ฮธ , italic_ฯ• , italic_ฮพ ) = โˆ‘ start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT [ blackboard_E start_POSTSUBSCRIPT italic_q ( italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_y start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ; italic_ฯ• ) end_POSTSUBSCRIPT [ roman_log italic_p ( italic_y start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ; italic_ฮธ ) ] โˆ’KL(q(z i|x i,y i;ฯ•)โˆฅp(z i|x i;ฮพ))]\displaystyle\quad\quad-{\rm KL}(q(z_{i}|x_{i},y_{i};\phi)|p(z_{i}|x_{i};\xi)% )\bigg{]}- roman_KL ( italic_q ( italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_y start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ; italic_ฯ• ) โˆฅ italic_p ( italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ; italic_ฮพ ) ) ](4)

The gradients of $\mathcal{B}$ with respect to its parameters are:

$$\begin{aligned}
\nabla_{\theta}\,\mathcal{B}(\mathcal{D};\theta,\phi,\xi) &= \sum_{i=1}^{N} \nabla_{\theta}\, \mathbb{E}_{q(z_i \mid x_i, y_i)}\big[\log p(y_i \mid x_i, z_i;\theta)\big] \\
\nabla_{\xi}\,\mathcal{B}(\mathcal{D};\theta,\phi,\xi) &= \sum_{i=1}^{N} \nabla_{\xi}\, \mathbb{E}_{q(z_i \mid x_i, y_i)}\big[\log p(z_i \mid x_i;\xi)\big]
\end{aligned}$$

The gradient of the variational objective $\mathcal{B}$ with respect to $\phi$ requires a score function estimator, which is known to suffer from high gradient variance (Malkin et al., 2023). Instead of directly optimizing $\mathcal{B}$ with respect to $\phi$, we first observe that $\mathcal{B}$ can be written as:

โ„ฌ(๐’Ÿ)=โˆ‘i=1 N(log p(y i|x i)โˆ’KL(q(z i|x i,y i)โˆฅp(z i|x i,y i))),\displaystyle\mathcal{B}(\mathcal{D})=\sum_{i=1}^{N}\left(\log p(y_{i}|x_{i})-% {\rm KL}(q(z_{i}|x_{i},y_{i})|p(z_{i}|x_{i},y_{i}))\right),caligraphic_B ( caligraphic_D ) = โˆ‘ start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT ( roman_log italic_p ( italic_y start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) - roman_KL ( italic_q ( italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_y start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) โˆฅ italic_p ( italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_y start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ) ) ,

making the true posterior $p(z_i|x_i,y_i)$ a target for the variational distribution $q(z_i|x_i,y_i)$. We thus propose to use a GFlowNet with the trajectory balance loss to train $q(z_i|x_i,y_i;\phi)$ to match its target, given by its unnormalized density $R = p(y_i|x_i,z_i)\,p(z_i|x_i)$.
Since binary dropout masks $z$ are high-dimensional discrete objects that can be constructed sequentially, we consider them as the terminating states of a GFlowNet. Instead of learning a distribution over these terminating states directly, we exploit the DAG structure to learn a forward policy $P_F(-|-;\phi)$ whose terminating-state distribution is $q(z_i|x_i,y_i;\phi)$. The trajectory balance loss requires an additional parameter $Z_\gamma$ to train $q(z_i|x_i,y_i;\phi)$ (Bengio et al., 2021b; Malkin et al., 2022).
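The sequential construction can be sketched as follows: the mask is built one layer at a time, and the complete set of per-layer masks is the terminating state of the trajectory. The keep-probabilities below stand in for the forward policy $P_F$; in the actual method they are produced by learned networks, so this is an illustrative sketch only.

```python
import random

def sample_mask_trajectory(keep_probs_per_layer, seed=0):
    """Construct a dropout mask layer by layer, so that the full mask is the
    terminating state of a sequential process. keep_probs_per_layer[l][u] is a
    stand-in for the forward policy's probability of keeping unit u of layer l.
    """
    rng = random.Random(seed)
    states = []   # partial states s_1, ..., s_L (sets of per-layer masks)
    masks = []
    for layer_probs in keep_probs_per_layer:
        layer_mask = [1 if rng.random() < p else 0 for p in layer_probs]
        masks.append(layer_mask)
        states.append([m[:] for m in masks])  # snapshot of the partial state
    return states, masks  # states[-1] is the terminating state z
```

Because each trajectory adds exactly one layer's mask per step, the state graph is a tree, which is why the backward policy drops out of the loss below.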

The corresponding trajectory balance loss for a trajectory $\tau = (s_0, \ldots, s_L)$ with respect to $(Z_\gamma, P_F(-|-;\phi))$ is

$$\mathcal{L}_{TB}(\tau,\mathcal{D};\phi,\gamma) = \left( \log \frac{Z_\gamma \prod_{t=1}^{L} P_F(s_t \mid s_{t-1};\phi)}{R} \right)^{2} \qquad (5)$$

where a state $s_l$ in the GFlowNet graph refers to the set of dropout masks sampled by the GFlowNet for layers $1$ through $l$ of the model, $(z_j)_{j=1}^{l}$. $L$ is the number of layers involving dropout, and $s_L$ indicates the termination of the trajectory sampling process. $P_F(s_t|s_{t-1};\phi)$ refers to the forward policy of the GFlowNet. (As the graph $G$ used in this study has a tree structure, the backward policy $P_B(s_{t-1}|s_t)$ is a constant, so we leave it out of the equations.)
$\log Z_\gamma = f(x_i,y_i;\gamma)$ is the partition function estimator with parameters $\gamma$, conditioned on both $x_i$ and $y_i$ from the data point indexed by $i$; its parameters $\gamma$ are trained together with $\phi$. $R$ is the reward calculated from the likelihood of the data and the prior distribution of the states, which are sets of dropout masks (see equation 12).
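In log-space the trajectory balance objective is a single squared residual once the per-step policy log-probabilities and the log-reward are available. A minimal sketch, with the reward abstracted to a likelihood term plus a prior term (names are illustrative):

```python
import math

def log_reward(log_lik, log_prior):
    # log R = log p(y_i | x_i, z_i; theta) + log p(z_i | x_i; xi)
    return log_lik + log_prior

def trajectory_balance_loss(log_Z, log_pf_steps, log_R):
    # Equation 5 rewritten in log-space:
    #   (log Z_gamma + sum_t log P_F(s_t | s_{t-1}; phi) - log R)^2
    return (log_Z + sum(log_pf_steps) - log_R) ** 2
```

The loss is zero exactly when $Z_\gamma \prod_t P_F(s_t|s_{t-1};\phi) = R$, i.e. when the policy's trajectory probability matches the unnormalized target density.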

The parameters $\phi$ and $\gamma$ are updated by taking gradient steps on $\mathcal{L}_{TB}(\tau,\mathcal{D};\phi,\gamma)$ for $\tau$ sampled from some training policy. We choose the training policy to be a tempered version of $q(z_i|x_i,y_i;\phi)$, denoted $q^{\sim}(z_i|x_i,y_i;\phi)$. The expected gradient update is thus equal to

$$\sum_{i=1}^{N} \mathop{\mathbb{E}}_{z_i \sim q^{\sim}(\cdot \mid x_i, y_i;\phi)} \nabla_{\phi,\gamma}\big( \log Z_\gamma + \log q(z_i \mid x_i, y_i;\phi) - \log R \big)^{2} \qquad (6)$$

where

$$\log R = \log p(y_i \mid x_i, z_i;\theta) + \log p(z_i \mid x_i;\xi).$$

During inference, one estimates the posterior predictive as $p^{pred}(y_i|x_i) = \frac{1}{M}\sum_{j=1}^{M} p(y_i|x_i,z_j;\theta)$, where the $M$ different masks $z_j$ are sampled from the distribution $p(z_i|x_i;\xi)$.
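This Monte Carlo averaging can be sketched as follows, with `predict` and `sample_mask` as hypothetical stand-ins for the trained predictor $p(y|x,z;\theta)$ and the mask prior sampler $p(z|x;\xi)$:

```python
def posterior_predictive(x, predict, sample_mask, M=10):
    """Estimate p(y|x) by averaging the model's predictive distribution over
    M dropout masks drawn from the prior p(z|x; xi). Both callables are
    illustrative placeholders, not the paper's API."""
    total = None
    for _ in range(M):
        probs = predict(x, sample_mask(x))  # class-probability vector for one mask
        total = probs if total is None else [a + b for a, b in zip(total, probs)]
    return [t / M for t in total]
```

Averaging over masks is what distinguishes this predictive distribution from a single deterministic forward pass: disagreement among the $M$ predictions is the source of the uncertainty estimate.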

Implementation details. Algorithm 1 presents a high-level overview of the GFlowOut implementation. $q(z_i|x_i,y_i;\phi)$ is implemented as a set of MLPs, one for each layer of the model that requires dropout mask generation (see Figure 1). At layer $l$ of the model, the dropout probabilities of all units in layer $l$ are estimated in parallel, conditioned on the previous layer's activations $h_{l-1,i}$, the label $y_i$, and all dropout masks in layers before $l$, using $q_l(z_{i,l}|h_{l-1,i},y_i,(z_{i,j})_{j=1}^{l-1};\phi)$, which is parameterized as an MLP. The same process is repeated for each layer of the model.
In this way, the dropout mask probabilities at layer $l$ take into consideration the input $x_i$ (through $h_{l-1,i}$), the label $y_i$, and the joint probability with all dropout masks in previous layers, but are independent of masks within the same layer. $p(z_i|x_i;\xi)$ follows the same implementation, except that it is not conditioned on $y_i$ and hence can be used at test time for prediction. In convolutional neural networks, GFlowOut is implemented in a similar manner, except that units are dropped out channel-wise (Yang et al., 2020; Park & Kwak, 2016). Early stopping based on performance on the validation set is used to prevent overfitting. Details of the computational efficiency of GFlowOut are discussed in the Appendix.
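A minimal sketch of one per-layer generator $q_l$, reduced here to a single linear map with a sigmoid for brevity (the paper uses small MLPs; all argument names are illustrative):

```python
import math

def layer_mask_probs(h_prev, y_onehot, prev_masks, weights, biases):
    """Map the previous layer's activations, the (one-hot) label, and the
    flattened earlier masks to per-unit Bernoulli keep-probabilities for
    layer l. A single linear layer + sigmoid stands in for the MLP."""
    features = list(h_prev) + list(y_onehot) + list(prev_masks)
    probs = []
    for w_row, b in zip(weights, biases):
        logit = b + sum(w * f for w, f in zip(w_row, features))
        probs.append(1.0 / (1.0 + math.exp(-logit)))  # keep-probability per unit
    return probs
```

Note that all units of layer $l$ are computed in parallel from the same feature vector, so masks within a layer are conditionally independent given earlier layers, matching the factorization described above.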

3.3 ID-GFlowOut: GFlowOut without sample-dependent information

To understand whether sample-dependent information is needed, we introduce a variant of GFlowOut that uses only sample-independent information, keeping the rest of the algorithm as close to GFlowOut as possible for comparison.

Consider a generative model of the form $p(x,y,z) = p(x)\,p(z)\,p(y|x,z)$. Given a supervised learning task $p(y|x)$, we generate a dropout mask $z$ that is not conditioned on the data point. We use $q(z)$ to approximate the posterior over $z$, which can be seen as part of the model parameters shared by all data points. The following equations can be derived:

logโขโˆi=1 N pโข(y i|x i)=logโขโˆ‘z pโข(z)โขโˆi=1 N pโข(y i|x i,z)subscript superscript product ๐‘ ๐‘– 1 ๐‘ conditional subscript ๐‘ฆ ๐‘– subscript ๐‘ฅ ๐‘– subscript ๐‘ง ๐‘ ๐‘ง subscript superscript product ๐‘ ๐‘– 1 ๐‘ conditional subscript ๐‘ฆ ๐‘– subscript ๐‘ฅ ๐‘– ๐‘ง\displaystyle\log\prod^{N}{i=1}p(y{i}|x_{i})=\log\sum_{z}p(z)\prod^{N}{i=1}% p(y{i}|x_{i},z)roman_log โˆ start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT italic_p ( italic_y start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) = roman_log โˆ‘ start_POSTSUBSCRIPT italic_z end_POSTSUBSCRIPT italic_p ( italic_z ) โˆ start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT italic_p ( italic_y start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_z ) =logโขโˆ‘z pโข(z)โขqโข(z)qโข(z)โขโˆi=1 N pโข(y i|x i,z)absent subscript ๐‘ง ๐‘ ๐‘ง ๐‘ž ๐‘ง ๐‘ž ๐‘ง subscript superscript product ๐‘ ๐‘– 1 ๐‘ conditional subscript ๐‘ฆ ๐‘– subscript ๐‘ฅ ๐‘– ๐‘ง\displaystyle=\log\sum_{z}p(z)\frac{q(z)}{q(z)}\prod^{N}{i=1}p(y{i}|x_{i},z)= roman_log โˆ‘ start_POSTSUBSCRIPT italic_z end_POSTSUBSCRIPT italic_p ( italic_z ) divide start_ARG italic_q ( italic_z ) end_ARG start_ARG italic_q ( italic_z ) end_ARG โˆ start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT italic_p ( italic_y start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_z ) =logโข๐”ผ zโˆผqโข(z)[pโข(z)qโข(z)โขโˆi=1 N pโข(y i|x i,z)]absent subscript ๐”ผ similar-to ๐‘ง ๐‘ž ๐‘ง delimited-[]๐‘ ๐‘ง ๐‘ž ๐‘ง subscript superscript product ๐‘ ๐‘– 1 ๐‘ conditional subscript ๐‘ฆ ๐‘– subscript ๐‘ฅ ๐‘– ๐‘ง\displaystyle=\log\mathop{\mathbb{E}}{z\sim 
q(z)}[\frac{p(z)}{q(z)}\prod^{N}% {i=1}p(y_{i}|x_{i},z)]= roman_log blackboard_E start_POSTSUBSCRIPT italic_z โˆผ italic_q ( italic_z ) end_POSTSUBSCRIPT [ divide start_ARG italic_p ( italic_z ) end_ARG start_ARG italic_q ( italic_z ) end_ARG โˆ start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT italic_p ( italic_y start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_z ) ] โ‰ฅ๐”ผ zโˆผqโข(z)logโก[pโข(z)qโข(z)โขโˆi=1 N pโข(y i|x i,z)]absent subscript ๐”ผ similar-to ๐‘ง ๐‘ž ๐‘ง ๐‘ ๐‘ง ๐‘ž ๐‘ง subscript superscript product ๐‘ ๐‘– 1 ๐‘ conditional subscript ๐‘ฆ ๐‘– subscript ๐‘ฅ ๐‘– ๐‘ง\displaystyle\geq\mathop{\mathbb{E}}{z\sim q(z)}\log\left[\frac{p(z)}{q(z)}% \prod^{N}{i=1}p(y_{i}|x_{i},z)\right]โ‰ฅ blackboard_E start_POSTSUBSCRIPT italic_z โˆผ italic_q ( italic_z ) end_POSTSUBSCRIPT roman_log [ divide start_ARG italic_p ( italic_z ) end_ARG start_ARG italic_q ( italic_z ) end_ARG โˆ start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT italic_p ( italic_y start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_z ) ] =๐”ผ zโˆผqโข(z)[โˆ‘i=1 N logโกpโข(y i|x i,z)]โˆ’KLโข(qโข(z)โˆฅpโข(z))absent subscript ๐”ผ similar-to ๐‘ง ๐‘ž ๐‘ง delimited-[]subscript superscript ๐‘ ๐‘– 1 ๐‘ conditional subscript ๐‘ฆ ๐‘– subscript ๐‘ฅ ๐‘– ๐‘ง KL conditional ๐‘ž ๐‘ง ๐‘ ๐‘ง\displaystyle=\mathop{\mathbb{E}}{z\sim q(z)}\left[\sum^{N}{i=1}\log p(y_{i}% |x_{i},z)\right]-{\rm KL}(q(z)|p(z))= blackboard_E start_POSTSUBSCRIPT italic_z โˆผ italic_q ( italic_z ) end_POSTSUBSCRIPT [ โˆ‘ start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT roman_log italic_p ( italic_y start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_z 
) ] - roman_KL ( italic_q ( italic_z ) โˆฅ italic_p ( italic_z ) )(7)

where the same distribution $p(z)$ is shared across the whole dataset and $q$ is implicitly conditioned on the whole dataset.

โ„ฌโข(๐’Ÿ;ฮธโ€ฒ,ฯ•โ€ฒ)=๐”ผ qโข(z;ฯ•โ€ฒ)โˆ‘i=1 N logโกpโข(y i|x i,z;ฮธโ€ฒ)โ„ฌ ๐’Ÿ superscript ๐œƒโ€ฒsuperscript italic-ฯ•โ€ฒsubscript ๐”ผ ๐‘ž ๐‘ง superscript italic-ฯ•โ€ฒsubscript superscript ๐‘ ๐‘– 1 ๐‘ conditional subscript ๐‘ฆ ๐‘– subscript ๐‘ฅ ๐‘– ๐‘ง superscript ๐œƒโ€ฒ\displaystyle\mathcal{B}(\mathcal{D};\theta^{\prime},\phi^{\prime})=\mathop{% \mathbb{E}}{q(z;\phi^{\prime})}\sum^{N}{i=1}\log p(y_{i}|x_{i},z;\theta^{% \prime})caligraphic_B ( caligraphic_D ; italic_ฮธ start_POSTSUPERSCRIPT โ€ฒ end_POSTSUPERSCRIPT , italic_ฯ• start_POSTSUPERSCRIPT โ€ฒ end_POSTSUPERSCRIPT ) = blackboard_E start_POSTSUBSCRIPT italic_q ( italic_z ; italic_ฯ• start_POSTSUPERSCRIPT โ€ฒ end_POSTSUPERSCRIPT ) end_POSTSUBSCRIPT โˆ‘ start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT roman_log italic_p ( italic_y start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_z ; italic_ฮธ start_POSTSUPERSCRIPT โ€ฒ end_POSTSUPERSCRIPT ) โˆ’KL(q(z;ฯ•โ€ฒ)||p(z))\displaystyle-{\rm KL}(q(z;\phi^{\prime})||p(z))- roman_KL ( italic_q ( italic_z ; italic_ฯ• start_POSTSUPERSCRIPT โ€ฒ end_POSTSUPERSCRIPT ) | | italic_p ( italic_z ) )(8)

where $\mathcal{B}$ is a lower bound. We parameterize the terms as $p(y \mid x, z;\theta')$ and $q(z;\phi')$, and fix $p(z)$ as a prior with a dropout rate of 0.5 for each unit. The gradient for stochastic optimization of $\theta'$ can be obtained as

$$
\nabla_{\theta'}\mathcal{B}(\mathcal{D};\theta',\phi') = \sum_{i=1}^{N}\mathbb{E}_{z\sim q(z;\phi')}\nabla_{\theta'}\log p(y_i \mid x_i, z;\theta')
\tag{9}
$$
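To make the training objective concrete, the bound $\mathcal{B}$ of Eq. 8 can be estimated with a single Monte Carlo sample of the mask; the sketch below assumes independent Bernoulli mask units and a minibatch log-likelihood rescaled to the full dataset (all function names are illustrative, not taken from the released code):

```python
import numpy as np

def bernoulli_kl(q_p, prior_p=0.5):
    """KL( Bern(q_p) || Bern(prior_p) ), elementwise, summed over mask units."""
    q_p = np.clip(q_p, 1e-7, 1 - 1e-7)
    return np.sum(q_p * np.log(q_p / prior_p)
                  + (1 - q_p) * np.log((1 - q_p) / (1 - prior_p)))

def elbo_estimate(log_lik_fn, q_p, rng, n_total, batch_size):
    """Single-sample Monte Carlo estimate of the bound in Eq. 8.

    log_lik_fn(z) returns the minibatch sum of log p(y_i | x_i, z) under
    mask z; the estimate is rescaled to the full dataset of n_total points.
    """
    z = (rng.random(q_p.shape) < q_p).astype(float)  # sample mask z ~ q(z)
    return log_lik_fn(z) * (n_total / batch_size) - bernoulli_kl(q_p)
```

A gradient step on $\theta'$ (Eq. 9) then follows by differentiating the rescaled log-likelihood term with the sampled mask held fixed.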

The distribution qโข(z;ฯ•โ€ฒ)๐‘ž ๐‘ง superscript italic-ฯ•โ€ฒq(z;\phi^{\prime})italic_q ( italic_z ; italic_ฯ• start_POSTSUPERSCRIPT โ€ฒ end_POSTSUPERSCRIPT ) can be trained as a the policy of a GFlowNet, using a tempered version qโˆผโข(z;ฯ•โ€ฒ)superscript ๐‘ž similar-to ๐‘ง superscript italic-ฯ•โ€ฒq^{\sim}(z;\phi^{\prime})italic_q start_POSTSUPERSCRIPT โˆผ end_POSTSUPERSCRIPT ( italic_z ; italic_ฯ• start_POSTSUPERSCRIPT โ€ฒ end_POSTSUPERSCRIPT ) to sample trajectories for training. The expected update direction for ฯ•โ€ฒsuperscript italic-ฯ•โ€ฒ\phi^{\prime}italic_ฯ• start_POSTSUPERSCRIPT โ€ฒ end_POSTSUPERSCRIPT and ฮณโ€ฒsuperscript ๐›พโ€ฒ\gamma^{\prime}italic_ฮณ start_POSTSUPERSCRIPT โ€ฒ end_POSTSUPERSCRIPT can be shown to equal

$$
\sum_{i=1}^{N}\mathbb{E}_{z_i\sim \tilde{q}(z_i;\phi')}\nabla_{\phi',\gamma'}\big(\log Z_{\gamma'} + \log q(z_i;\phi') - \log R\big)^2,
\tag{10}
$$

where the log-reward is

logโกR=Nโขlogโกpโข(y i|x i,z i;ฮธโ€ฒ)+logโกpโข(z i).๐‘… ๐‘ ๐‘ conditional subscript ๐‘ฆ ๐‘– subscript ๐‘ฅ ๐‘– subscript ๐‘ง ๐‘– superscript ๐œƒโ€ฒ๐‘ subscript ๐‘ง ๐‘–\displaystyle\log R=N\log p(y_{i}|x_{i},z_{i};\theta^{\prime})+\log p(z_{i}).roman_log italic_R = italic_N roman_log italic_p ( italic_y start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT , italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ; italic_ฮธ start_POSTSUPERSCRIPT โ€ฒ end_POSTSUPERSCRIPT ) + roman_log italic_p ( italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) .(11)

In ID-GFlowOut, $\log Z_{\gamma'} = \gamma'$, which does not condition on any inputs.
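A minimal sketch of the squared objective in Eq. 10 and the log-reward in Eq. 11, under the simplifying assumption that the mask is generated in a single step so that $\log q(z_i;\phi')$ reduces to a sum of independent Bernoulli log-probabilities (the actual GFlowNet builds the mask layer by layer; names are illustrative):

```python
import numpy as np

def log_q_mask(z, q_p):
    """log-probability of a binary mask z under independent Bernoulli units."""
    q_p = np.clip(q_p, 1e-7, 1 - 1e-7)
    return np.sum(z * np.log(q_p) + (1 - z) * np.log(1 - q_p))

def log_reward(log_lik_i, z, n_total, prior_p=0.5):
    """log R = N * log p(y_i | x_i, z_i) + log p(z_i), as in Eq. 11."""
    return n_total * log_lik_i + log_q_mask(z, np.full(z.shape, prior_p))

def tb_loss(z, q_p, log_Z, log_r):
    """Per-sample squared objective from Eq. 10:
    (log Z + log q(z; phi) - log R)^2."""
    return (log_Z + log_q_mask(z, q_p) - log_r) ** 2
```

The update in Eq. 10 averages the gradient of this squared term over masks sampled from the tempered policy.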

During inference, the posterior predictive estimate is $p^{pred}(y_i \mid x_i) = \frac{1}{M}\sum_{j=1}^{M} p(y_i \mid x_i, z_j;\theta')$, where $M$ different $z_j$ are sampled from the distribution $q(z;\phi')$.
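This Monte Carlo average can be sketched as follows; the `predict_proba` and `sample_mask` callables are assumed stand-ins for the trained model and the learned mask policy:

```python
import numpy as np

def posterior_predictive(predict_proba, sample_mask, x, n_samples=20):
    """Average p(y | x, z_j) over M masks z_j ~ q(z; phi), as used at inference."""
    probs = [predict_proba(x, sample_mask()) for _ in range(n_samples)]
    return np.mean(probs, axis=0)
```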

4 Experiments

In this section, we empirically evaluate GFlowOut (code is available at https://github.com/kaiyuanmifen/GFNDropout) on a variety of tasks to understand its ability to generalize across different distributions and to estimate uncertainty in its predictions. We first evaluate the generalization performance of the posterior predictive approximated by GFlowOut on an image classification task. We also evaluate the efficacy of GFlowOut in the context of transfer learning. To understand the performance of GFlowOut with larger models and datasets, we conduct Visual Question Answering (VQA) experiments using Transformer architectures. Next, we evaluate the uncertainty captured by the posterior for detecting out-of-distribution examples. Finally, we study a potential application of GFlowOut in a real-world clinical use case: cross-hospital prediction of mortality in intensive care units (ICUs). We supplement these results with further analysis and additional experimental details in the Appendix.

Robustness to data distribution shift. To evaluate the robustness of GFlowOut to distribution shifts between the train and test data, we study its predictive performance on OOD examples. We conduct experiments on the MNIST, CIFAR-10, and CIFAR-100 datasets with different types and levels of deformation. For MNIST, we train a two-layer MLP with 300 and 100 units respectively and evaluate predictions on MNIST images rotated by a uniformly sampled angle (0–360°). Similarly, we use ResNet-18 (He et al., 2016) models for the CIFAR-10/CIFAR-100 datasets and evaluate their robustness to distribution shifts induced by random rotations. Additionally, we consider "Snow", "Frost", and Gaussian noise image corruptions (Hendrycks & Dietterich, 2019), and analyze the robustness of models to each type of deformation applied with varying intensities. We consider both GFlowOut and ID-GFlowOut variants, and as baselines use Random Dropout (standard Bernoulli dropout) (Hinton et al., 2012), Contextual Dropout (Fan et al., 2021), and Concrete Dropout (Gal et al., 2017). The results, summarized in Table 1, show that models trained using GFlowOut are in general more robust to random rotations, and GFlowOut outperforms (or at least matches the performance of) the baselines in five out of six experiments with different levels of corruption (Figure 2, Appendix Figures 3 and 4). These observations suggest that models trained with GFlowOut are more robust to distribution shifts than the baselines. Better generalization under distribution shift indicates that GFlowOut potentially learns a better approximation of the Bayesian posterior over the model parameters.

Table 1: Performance on clean and corrupted data to evaluate the robustness of models trained with different dropout methods to random image rotations at test time.

Visual Question Answering task using a transformer architecture. To evaluate GFlowOut on large-scale tasks with larger models, we consider a transformer-based multi-modal architecture, MCAN (Yu et al., 2019), for the Visual Question Answering (VQA) task, following Fan et al. (2021). The task involves answering a textual question about the content of a given image. There are three types of questions: binary yes/no questions, numerical questions, and other questions. Dropout is applied to the cross-modal attention between images and texts, to the self-attention within each modality, and to the feed-forward layers after attention. Our experimental results in Table 2 suggest that GFlowOut either outperforms or matches the performance of Contextual and Concrete Dropout when tested on generalization to a noisy dataset in which Gaussian noise is added to the visual inputs (Fan et al., 2021).

Table 2: Performance on different question types in Visual Question Answering task with a Transformer-based model trained with different methods.

Uncertainty estimation for out-of-distribution detection. Another way to evaluate the quality of the learned posterior is to analyze the uncertainty estimates on a downstream task. We consider the standard task of using uncertainty estimates to detect out-of-distribution (OOD) examples (Nado et al., 2021). The intuition is that a well-calibrated model should produce high uncertainty in its predictions on OOD examples. This can be useful in settings where difficult OOD examples can be delegated to humans for more careful consideration. As in the previous experiments, we consider ResNet-18 models for CIFAR-10/CIFAR-100 classification and compute uncertainty estimates on the CIFAR-10/CIFAR-100 (in-distribution) and SVHN (OOD) test sets. Uncertainty for the prediction on each example is calculated using the Dempster-Shafer metric (Sensoy et al., 2018). For baselines, we consider Contextual Dropout and Concrete Dropout, along with standard MC Dropout and Deep Ensembles, which are strong baselines for this task. We run the experiment with 5 seeds and report the mean and standard error. We study both GFlowOut, which uses sample-dependent information, and ID-GFlowOut, which uses only sample-independent information. In Table 3, we present AUPR and AUROC for distinguishing in-distribution examples (CIFAR-10 and CIFAR-100) from OOD examples (SVHN) using the uncertainty estimates from each method. We observe that GFlowOut outperforms the other dropout baselines with both CIFAR-10 and CIFAR-100 as the training dataset, indicating that the sample-dependent information used in GFlowOut results in better-calibrated uncertainty estimates. ID-GFlowOut performs well on CIFAR-100 but poorly on CIFAR-10. Results for Deep Ensembles, a widely used state-of-the-art uncertainty estimation method, are also reported in Table 3 for comparison.
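One common form of the Dempster-Shafer uncertainty treats exponentiated logits as evidence for each of the $K$ classes, giving $u = K/(K+\sum_k \exp(\mathrm{logit}_k))$; the sketch below assumes this form, which may differ in detail from the exact variant used in the experiments:

```python
import numpy as np

def dempster_shafer_uncertainty(logits):
    """Evidence-based uncertainty: u = K / (K + sum_k exp(logit_k)).

    Near-zero evidence (all logits very negative) gives u close to 1;
    large evidence for any class drives u toward 0.
    """
    k = logits.shape[-1]
    evidence = np.sum(np.exp(logits), axis=-1)
    return k / (k + evidence)
```

Thresholding this score then separates in-distribution from OOD inputs, with OOD inputs expected to score higher.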


Figure 2: Top: Evaluating the robustness of ResNet-18 models, trained with different dropout methods, to different amounts of deformation including SNOW deformation, FROST deformation or Gaussian noise on CIFAR-10/CIFAR-100 at test time. See Figure 3 and 4 in Appendix for detailed results. Bottom: Evaluating the transfer learning performance of ResNet-18 models trained on CIFAR-10/CIFAR-100 with label noise and fine-tuned on varying amounts of clean data (i.e., without any label noise).

Adaptation after training on noisy data. Next, we evaluate the ability of models trained with GFlowOut to adapt quickly after being trained on noisy data. Concretely, during training we add label noise, i.e., we randomly reassign the labels of a fraction (30%) of points in the training set, and then re-train the classifier on a small fraction of the dataset with clean labels. We adopt the same experimental setup as in the previous experiments, with ResNet-18 models trained on the CIFAR-10/CIFAR-100 datasets. As baselines, we again use Contextual Dropout and Concrete Dropout. Figure 2 shows that models trained with GFlowOut adapt faster than the baseline dropout methods. We also observe that the sample-dependent and sample-independent variants of GFlowOut achieve similar performance.

Table 3: Performance on the out-of-distribution (OOD) detection task indicates the posterior approximated by GFlowOut provides better uncertainty estimates.

Application on real-world clinical data. We explore the utility of GFlowOut as a probabilistic tool for solving real-world problems. In intensive care units (ICUs), the ability to forecast patient mortality can help clinicians allocate limited resources to the individuals at highest risk. However, to respect patient privacy, most hospitals have access only to anonymized data for a limited number of patients for training predictive models, and there are stringent regulations on the exchange of medical records among hospitals. To enable data-driven decision-making in these critical scenarios, we consider learning probabilistic classifiers for the problem of mortality prediction. We use patients' medication usage in the first 48 hours of an ICU stay to make a binary prediction of mortality during the stay. We emphasize that the predictions are meant to assist decision-makers (doctors) in making clinical decisions rather than being used directly. We use a 3-layer MLP trained on the ICU records of 4500 patients, including 158 deaths, from one hospital, and test on data from another hospital (4018 patients, 147 deaths). Records of both hospitals are obtained from the eICU database (Pollard et al., 2018). This challenging cross-hospital setup encompasses several important difficulties in applying machine learning tools to real-world tasks: (1) the underlying prediction task is extremely difficult due to limited information and complex case-specific clinical details, (2) the data distribution is severely imbalanced, as deaths are rare events, (3) training data is limited, resulting in a complex posterior, and (4) there are large distribution shifts between hospitals. Our results in Table 4 show that GFlowOut significantly outperforms the baselines on all metrics. While the margins may appear small, in the context of real-world decision-making they can have a significant impact. These findings demonstrate GFlowOut's effectiveness in addressing risk-sensitive real-world problems.

Table 4: Performance of methods on cross-hospital mortality prediction on real-world clinical data from ICUs demonstrates superior performance of GFlowOut.

5 Conclusion

In this work, we propose GFlowOut to learn the posterior distribution over dropout masks in a neural network. We evaluate GFlowOut on various downstream tasks, such as uncertainty estimation, robustness to distribution shift, and transfer learning, using both benchmark datasets and real-world clinical data. Our empirical results show the favorable performance of GFlowOut over related methods such as Concrete and Contextual Dropout. Future work includes combining top-down and bottom-up dropout strategies, applying GFlowOut to larger models with complex architectures, and using it to promote exploration in RL problems, where accurate posterior estimation has been shown to improve sample efficiency (Osband et al., 2013).

Author contributions

D.L., Y.B. and K.K. initialized the project. D.L., M.J., B.D., C.E., Q.S. and S.L. contributed to implementation and experiments of the project. D.L., M.J. and A.G. designed the experimental studies. D.L., Y.B., A.G., N.M., M.J., X.J. and S.L. contributed to conceptualization of the project. N.M., S.L., K.K., D.Z., D.L., M.J., Q.S., N.H. and Y.B. contributed to the mathematical parts of the project. D.L., A.G. and M.J. coordinated the project. Y.B. supervised the whole project and designed the whole framework. All authors contributed to the writing of the manuscript.

Acknowledgments

The authors thank CIFAR, Samsung and IVADO for funding and NVIDIA for equipment.

References

  • Ba & Frey (2013) Ba, J. and Frey, B. Adaptive dropout for training deep neural networks. Neural Information Processing Systems (NIPS), 2013.
  • Bengio et al. (2021a) Bengio, E., Jain, M., Korablyov, M., Precup, D., and Bengio, Y. Flow network based generative models for non-iterative diverse candidate generation. Neural Information Processing Systems (NeurIPS), 2021a.
  • Bengio et al. (2021b) Bengio, Y., Deleu, T., Hu, E.J., Lahlou, S., Tiwari, M., and Bengio, E. Gflownet foundations. arXiv preprint arXiv:2111.09266, 2021b.
  • Bhatt et al. (2021) Bhatt, U., Antorรกn, J., Zhang, Y., Liao, Q.V., Sattigeri, P., Fogliato, R., Melanรงon, G., Krishnan, R., Stanley, J., Tickoo, O., et al. Uncertainty as a form of transparency: Measuring, communicating, and using uncertainty. In Proceedings of the 2021 AAAI/ACM Conference on AI, Ethics, and Society, pp. 401โ€“413, 2021.
  • Boluki et al. (2020) Boluki, S., Ardywibowo, R., Dadaneh, S.Z., Zhou, M., and Qian, X. Learnable Bernoulli dropout for Bayesian deep learning. Artificial Intelligence and Statistics (AISTATS), 2020.
  • Damianou & Lawrence (2013) Damianou, A. and Lawrence, N.D. Deep Gaussian processes. Artificial Intelligence and Statistics (AISTATS), 2013.
  • Daxberger et al. (2021) Daxberger, E., Nalisnick, E., Allingham, J.U., Antorรกn, J., and Hernรกndez-Lobato, J.M. Bayesian deep learning via subnetwork inference. International Conference on Machine Learning (ICML), 2021.
  • Deleu et al. (2022) Deleu, T., Gรณis, A., Emezue, C., Rankawat, M., Lacoste-Julien, S., Bauer, S., and Bengio, Y. Bayesian structure learning with generative flow networks. Uncertainty in Artificial Intelligence (UAI), 2022.
  • Fan et al. (2021) Fan, X., Zhang, S., Tanwisuth, K., Qian, X., and Zhou, M. Contextual dropout: An efficient sample-dependent dropout module. International Conference on Learning Representations (ICLR), 2021.
  • Foong et al. (2020) Foong, A., Burt, D., Li, Y., and Turner, R. On the expressiveness of approximate inference in Bayesian neural networks. Neural Information Processing Systems (NeurIPS), 2020.
  • Fort et al. (2019) Fort, S., Hu, H., and Lakshminarayanan, B. Deep ensembles: A loss landscape perspective. arXiv preprint arXiv:1912.02757, 2019.
  • Gal & Ghahramani (2016) Gal, Y. and Ghahramani, Z. Dropout as a Bayesian approximation: Representing model uncertainty in deep learning. International Conference on Machine Learning (ICML), 2016.
  • Gal et al. (2017) Gal, Y., Hron, J., and Kendall, A. Concrete dropout. Neural Information Processing Systems (NIPS), 30, 2017.
  • Ghiasi et al. (2018) Ghiasi, G., Lin, T.-Y., and Le, Q.V. Dropblock: A regularization method for convolutional networks. Neural Information Processing Systems (NeurIPS), 2018.
  • Guo et al. (2017) Guo, C., Pleiss, G., Sun, Y., and Weinberger, K.Q. On calibration of modern neural networks. International Conference on Machine Learning (ICML), 2017.
  • He et al. (2016) He, K., Zhang, X., Ren, S., and Sun, J. Deep residual learning for image recognition. Computer Vision and Pattern Recognition (CVPR), 2016.
  • Hendrycks & Dietterich (2019) Hendrycks, D. and Dietterich, T. Benchmarking neural network robustness to common corruptions and perturbations. International Conference on Learning Representations (ICLR), 2019.
  • Hinton et al. (2012) Hinton, G.E., Srivastava, N., Krizhevsky, A., Sutskever, I., and Salakhutdinov, R.R. Improving neural networks by preventing co-adaptation of feature detectors. arXiv preprint arXiv:1207.0580, 2012.
  • Jain et al. (2022) Jain, M., Bengio, E., Hernandez-Garcia, A., Rector-Brooks, J., Dossou, B.F., Ekbote, C.A., Fu, J., Zhang, T., Kilgour, M., Zhang, D., Simine, L., Das, P., and Bengio, Y. Biological sequence design with gflownets. International Conference on Machine Learning (ICML), 2022.
  • Jain et al. (2023) Jain, M., Lahlou, S., Nekoei, H., Butoi, V., Bertin, P., Rector-Brooks, J., Korablyov, M., and Bengio, Y. DEUP: Direct epistemic uncertainty prediction. Transactions on Machine Learning Research (TMLR), 2023.
  • Kingma et al. (2015) Kingma, D.P., Salimans, T., and Welling, M. Variational dropout and the local reparameterization trick. Neural Information Processing Systems (NIPS), 2015.
  • Kuleshov et al. (2018) Kuleshov, V., Fenner, N., and Ermon, S. Accurate uncertainties for deep learning using calibrated regression. International Conference on Machine Learning (ICML), 2018.
  • Le Folgoc et al. (2021) Le Folgoc, L., Baltatzis, V., Desai, S., Devaraj, A., Ellis, S., Manzanera, O. E.M., Nair, A., Qiu, H., Schnabel, J., and Glocker, B. Is MC dropout bayesian? arXiv preprint arXiv:2110.04286, 2021.
  • Lee et al. (2020) Lee, H.B., Nam, T., Yang, E., and Hwang, S.J. Meta dropout: Learning to perturb latent features for generalization. International Conference on Learning Representations (ICLR), 2020.
  • Lotfi et al. (2022) Lotfi, S., Izmailov, P., Benton, G., Goldblum, M., and Wilson, A.G. Bayesian model selection, the marginal likelihood, and generalization. International Conference on Machine Learning (ICML), 2022.
  • MacKay (1992) MacKay, D.J. A practical Bayesian framework for backpropagation networks. Neural Computation, 4(3):448โ€“472, 1992.
  • Madan et al. (2023) Madan, K., Rector-Brooks, J., Korablyov, M., Bengio, E., Jain, M., Nica, A., Bosc, T., Bengio, Y., and Malkin, N. Learning GFlowNets from partial episodes for improved convergence and stability. International Conference on Machine Learning (ICML), 2023.
  • Malkin et al. (2022) Malkin, N., Jain, M., Bengio, E., Sun, C., and Bengio, Y. Trajectory balance: Improved credit assignment in GFlowNets. Neural Information Processing Systems (NeurIPS), 2022.
  • Malkin et al. (2023) Malkin, N., Lahlou, S., Deleu, T., Ji, X., Hu, E., Everett, K., Zhang, D., and Bengio, Y. GFlowNets and variational inference. International Conference on Learning Representations (ICLR), 2023.
  • Nado et al. (2021) Nado, Z., Band, N., Collier, M., Djolonga, J., Dusenberry, M., Farquhar, S., Filos, A., Havasi, M., Jenatton, R., Jerfel, G., Liu, J., Mariet, Z., Nixon, J., Padhy, S., Ren, J., Rudner, T., Wen, Y., Wenzel, F., Murphy, K., Sculley, D., Lakshminarayanan, B., Snoek, J., Gal, Y., and Tran, D. Uncertainty Baselines: Benchmarks for uncertainty & robustness in deep learning. arXiv preprint arXiv:2106.04015, 2021.
  • Neal (2012) Neal, R.M. Bayesian learning for neural networks, volume 118. Springer Science & Business Media, 2012.
  • Nguyen et al. (2015) Nguyen, A., Yosinski, J., and Clune, J. Deep neural networks are easily fooled: High confidence predictions for unrecognizable images. Computer Vision and Pattern Recognition (CVPR), 2015.
  • Nguyen et al. (2021) Nguyen, S., Nguyen, D., Nguyen, K., Than, K., Bui, H., and Ho, N. Structured dropout variational inference for Bayesian neural networks. Neural Information Processing Systems (NeurIPS), 2021.
  • Nica et al. (2022) Nica, A.C., Jain, M., Bengio, E., Liu, C.-H., Korablyov, M., Bronstein, M.M., and Bengio, Y. Evaluating generalization in gflownets for molecule design. ICLR 2022 Machine Learning for Drug Discovery workshop, 2022.
  • Osband et al. (2013) Osband, I., Russo, D., and Van Roy, B. (more) efficient reinforcement learning via posterior sampling. Neural Information Processing Systems (NIPS), 2013.
  • Ovadia et al. (2019) Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J., Lakshminarayanan, B., and Snoek, J. Can you trust your modelโ€™s uncertainty? evaluating predictive uncertainty under dataset shift. Neural Information Processing Systems (NeurIPS), 2019.
  • Pan et al. (2023) Pan, L., Zhang, D., Courville, A.C., Huang, L., and Bengio, Y. Generative augmented flow networks. International Conference on Learning Representations (ICLR), 2023.
  • Park & Kwak (2016) Park, S. and Kwak, N. Analysis on the dropout effect in convolutional neural networks. Asian Conference on Computer Vision, 2016.
  • Pham & Le (2021) Pham, H. and Le, Q. Autodropout: Learning dropout patterns to regularize deep networks. Association for the Advancement of Artificial Intelligence (AAAI), 2021.
  • Pollard et al. (2018) Pollard, T.J., Johnson, A.E., Raffa, J.D., Celi, L.A., Mark, R.G., and Badawi, O. The eICU Collaborative Research Database, a freely available multi-center database for critical care research. Scientific data, 5(1):1โ€“13, 2018.
  • Sensoy et al. (2018) Sensoy, M., Kaplan, L., and Kandemir, M. Evidential deep learning to quantify classification uncertainty. Neural Information Processing Systems (NeurIPS), 2018.
  • Sutton & Barto (2018) Sutton, R.S. and Barto, A.G. Reinforcement learning: An introduction. MIT press, 2018.
  • Wilson & Izmailov (2020) Wilson, A.G. and Izmailov, P. Bayesian deep learning and a probabilistic perspective of generalization. Neural Information Processing Systems (NeurIPS), 2020.
  • Xie et al. (2019) Xie, J., Ma, Z., Zhang, G., Xue, J.-H., Tan, Z.-H., and Guo, J. Soft dropout and its variational bayes approximation. Machine Learning for Signal Processing (MLSP), 2019.
  • Yang et al. (2020) Yang, X., Tang, J., Torun, H.M., Becker, W.D., Hejase, J.A., and Swaminathan, M. Rx equalization for a high-speed channel based on bayesian active learning using dropout. Electrical Performance of Electronic Packaging and Systems (EPEPS), 2020.
  • Yu et al. (2019) Yu, Z., Yu, J., Cui, Y., Tao, D., and Tian, Q. Deep modular co-attention networks for visual question answering. Computer Vision and Pattern Recognition (CVPR), 2019.
  • Zhang et al. (2022a) Zhang, D., Chen, R.T., Malkin, N., and Bengio, Y. Unifying generative models with gflownets. arXiv preprint arXiv:2209.02606, 2022a.
  • Zhang et al. (2022b) Zhang, D., Malkin, N., Liu, Z., Volokhova, A., Courville, A., and Bengio, Y. Generative flow networks for discrete probabilistic modeling. International Conference on Machine Learning (ICML), 2022b.

Appendix A Appendix

A.0.1 Hyperparameters

Hyperparameters of the "backbone" ResNet and Transformer models were obtained from published baselines or architectures (He et al., 2016; Yu et al., 2019; Fan et al., 2021; Gal & Ghahramani, 2016; Gal et al., 2017). Several GFlowNet-specific hyperparameters are considered in this study, including the architecture of the variational function $q(\cdot)$ and its associated hyperparameters, and the temperature of $\tilde{q}(\cdot)$. For ID-GFlowOut, there is an additional hyperparameter, the prior $p(z)$. These hyperparameters are selected via grid search on the validation set. The temperature of $\tilde{q}(\cdot)$ is set to 2. In addition, with probability 0.1, the forward policy chooses a random mask in each layer.
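A sketch of how the tempered policy $\tilde{q}(\cdot)$ (temperature 2) and the 0.1-probability random mask might compose when sampling a layer mask; the exact composition is an assumption for illustration, and the helper name is hypothetical:

```python
import numpy as np

def sample_mask_tempered(q_p, temperature=2.0, eps=0.1, rng=None):
    """Sample a binary layer mask from a tempered policy q~(z; phi),
    mixed with an eps-probability fully random mask for exploration."""
    rng = rng if rng is not None else np.random.default_rng()
    if rng.random() < eps:
        p = np.full(q_p.shape, 0.5)  # exploratory uniform-random mask
    else:
        q_p = np.clip(q_p, 1e-7, 1 - 1e-7)
        logits = np.log(q_p / (1 - q_p)) / temperature  # temper the logits
        p = 1.0 / (1.0 + np.exp(-logits))
    return (rng.random(q_p.shape) < p).astype(float)
```

With temperature above 1, the per-unit probabilities are pulled toward 0.5, flattening the sampling distribution used for training trajectories.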

A.0.2 Computational efficiency

On a single RTX8000 GPU, training models with GFlowOut takes around the same time as Contextual Dropout and Concrete Dropout, and around twice the time (7 hours for ResNet, 16 hours for the MCAN Transformer) of a model with the same architecture and random dropout. The three learned dropout methods have similar efficiency during inference.

A.1 Experimental details

A.1.1 Sampling dropout masks

During inference, 20 forward-pass samples are used for each data point. In the ResNet experiments, dropout masks are generated for each ResNet block. In the Transformer VQA experiment, dropout is applied in each layer to both the self-attention and the feed-forward sublayers.
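A minimal sketch of averaging over stochastic forward passes (20 per data point in our setup); `model` is a stand-in for a forward pass that samples a fresh dropout mask internally:

```python
import numpy as np

def mc_predict(model, x, n_samples=20, rng=None):
    """Average predictions over `n_samples` stochastic forward passes.

    Each call to `model(x, rng)` represents one forward pass with a
    freshly sampled dropout mask; the returned prediction is the mean
    over all samples.
    """
    rng = np.random.default_rng() if rng is None else rng
    preds = np.stack([model(x, rng) for _ in range(n_samples)])
    return preds.mean(axis=0)
```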

A.1.2 Robustness to distribution shift

The performance of each method was obtained from 9 training runs with different random seeds. Early stopping on the validation set was used to prevent overfitting. The VQA Transformer experiments follow the design of Yu et al. (2019).

A.1.3 OOD detection

For each data point, we take 20 forward passes and compute the uncertainty of the prediction using the Dempster-Shafer metric (Sensoy et al., 2018) and the algorithm from Jain et al. (2023). This uncertainty score is used to classify data points as in-distribution vs. out-of-distribution, under the assumption that the latter should have higher uncertainty.
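For reference, the Dempster-Shafer total uncertainty of Sensoy et al. (2018) can be written as u = K / S, where K is the number of classes and S = Σ_k (e_k + 1) sums the per-class evidence. The sketch below assumes the evidence vector is already available; the exact evidence transform used in our pipeline may differ:

```python
import numpy as np

def dempster_shafer_uncertainty(evidence):
    """Dempster-Shafer total uncertainty u = K / S with
    S = sum_k (e_k + 1), following Sensoy et al. (2018).

    Higher u means more total uncertainty; it is used here to rank
    in-distribution vs. out-of-distribution points.
    """
    evidence = np.asarray(evidence, dtype=float)
    n_classes = evidence.shape[-1]
    dirichlet_strength = (evidence + 1.0).sum(axis=-1)
    return n_classes / dirichlet_strength
```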

A.1.4 Adaptation after training on noisy data

When training the model on noisy CIFAR-10/100 data, a randomly chosen 30% of the data points in the training set are assigned random labels. The resulting model is then fine-tuned on a small number of clean data points, all with correct labels. We conducted experiments with 1000, 2000, 4000, and 8000 data points used for fine-tuning.
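The label-noise injection can be sketched as follows: a randomly chosen fraction of labels is replaced by uniformly random classes (which may occasionally coincide with the true label):

```python
import numpy as np

def corrupt_labels(labels, n_classes, noise_frac=0.3, rng=None):
    """Replace a randomly chosen `noise_frac` of labels with uniformly
    random classes, mimicking the 30% label-noise setup."""
    rng = np.random.default_rng() if rng is None else rng
    labels = np.asarray(labels).copy()
    n_noisy = int(noise_frac * len(labels))
    # Pick which examples to corrupt, without replacement.
    idx = rng.choice(len(labels), size=n_noisy, replace=False)
    labels[idx] = rng.integers(0, n_classes, size=n_noisy)
    return labels
```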

A.1.5 Real-world clinical data

The ICU dataset is a real-world dataset containing the death or survival outcomes of 126489 patients across 58 different hospitals, given a set of administered drugs. The goal of this experiment is to evaluate how well our approach generalizes in real-world settings. To this end, we built two sets:

  • a training set that contains data points for patients from all hospitals except the hospital with the most patients (hospital ID 167). This results in a dataset of 120945 entries, which is partitioned (70:30 ratio) into the actual training and validation sets.

  • a test set that contains information about the 5544 patients of the held-out hospital. As each hospital follows its own distribution, the test set is designed to measure the OOD performance of GFlowOut on the widest possible set of patients, reflecting a real-world scenario.
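The split above can be sketched at the index level as follows; the array and function names are illustrative:

```python
import numpy as np

def hospital_split(hospital_ids, held_out=167, val_frac=0.3, rng=None):
    """Hold out one hospital as the test set and partition the remaining
    entries 70:30 into train/validation indices (ID 167 is the held-out
    hospital from the paper)."""
    rng = np.random.default_rng() if rng is None else rng
    hospital_ids = np.asarray(hospital_ids)
    test_idx = np.flatnonzero(hospital_ids == held_out)
    rest = rng.permutation(np.flatnonzero(hospital_ids != held_out))
    n_val = int(val_frac * len(rest))
    return rest[n_val:], rest[:n_val], test_idx
```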

We used a 3-layer MLP with the dropout options presented in Table 4. For evaluation, we perform 20 forward passes and take the mean of the predictions.

A.2 Analysing dropout masks

Here, we analyze the behavior and dynamics of the binary masks generated by GFlowOut for data points corresponding to different labels and different augmentations. First, we want to verify that GFlowOut generates masks with probability proportional to the reward R as defined in equation (7). Our analysis shows a statistically significant correlation between the probabilities of a set of masks being generated by the GFlowNet and the corresponding rewards, with correlation ≥ 0.4 and p-values ≤ 0.05. Next, we explore whether GFlowOut generates diverse dropout masks. We take the mean dropout mask generated during inference for each data point and compute Manhattan distances between different samples in the data set. The results are shown in Figure 6.
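The Manhattan-distance computation can be sketched as follows, with an illustrative `(n_examples, n_samples, n_units)` layout for the sampled masks:

```python
import numpy as np

def mean_mask_distances(masks):
    """Pairwise Manhattan (L1) distances between per-example mean masks.

    `masks` has shape (n_examples, n_samples, n_units): the dropout mask
    samples drawn at inference for each data point. Returns an
    (n_examples, n_examples) distance matrix.
    """
    mean_masks = np.asarray(masks, dtype=float).mean(axis=1)
    # Broadcast to all pairs and sum absolute differences over units.
    diffs = np.abs(mean_masks[:, None, :] - mean_masks[None, :, :])
    return diffs.sum(axis=-1)
```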


Figure 3: Robustness of ResNet-18 models trained with different dropout methods to different amounts of SNOW deformation, FROST deformation or Gaussian noise in CIFAR-10 during test time


Figure 4: Robustness of ResNet-18 models trained with different dropout methods to different amounts of SNOW deformation, FROST deformation or Gaussian noise in CIFAR-100 during test time


Figure 5: Robustness of ResNet-18 models trained with different dropout methods to different amounts of deformation, using ResNet blocks with components in a slightly different order: BatchNorm-Conv-BatchNorm-Conv


Figure 6: Diversity of binary dropout masks among different data points measured by Manhattan distance
