Sunbread commited on
Commit
d17b8f3
·
1 Parent(s): db0dcb9

fix bn stat

Browse files
Files changed (1) hide show
  1. model.py +4 -3
model.py CHANGED
@@ -38,12 +38,13 @@ class BatchNormVAE(nn.Module): # https://spaces.ac.cn/archives/7381/
38
  super(BatchNormVAE, self).__init__()
39
  kwargs['affine'] = False
40
  self.TAU = 0.5
41
- self.bn = nn.BatchNorm1d(num_features, **kwargs)
 
42
  self.theta = nn.Parameter(torch.zeros(1))
43
 
44
  def forward(self, mu, sigma):
45
- mu = self.bn(mu)
46
- sigma = self.bn(sigma)
47
  scale_mu = torch.sqrt(self.TAU + (1 - self.TAU) * F.sigmoid(self.theta))
48
  scale_sigma = torch.sqrt((1 - self.TAU) * F.sigmoid(-self.theta))
49
  return mu*scale_mu, sigma*scale_sigma
 
38
  super(BatchNormVAE, self).__init__()
39
  kwargs['affine'] = False
40
  self.TAU = 0.5
41
+ self.bn_mu = nn.BatchNorm1d(num_features, **kwargs)
42
+ self.bn_sigma = nn.BatchNorm1d(num_features, **kwargs)
43
  self.theta = nn.Parameter(torch.zeros(1))
44
 
45
  def forward(self, mu, sigma):
46
+ mu = self.bn_mu(mu)
47
+ sigma = self.bn_sigma(sigma)
48
  scale_mu = torch.sqrt(self.TAU + (1 - self.TAU) * F.sigmoid(self.theta))
49
  scale_sigma = torch.sqrt((1 - self.TAU) * F.sigmoid(-self.theta))
50
  return mu*scale_mu, sigma*scale_sigma