Update app.py
Browse files
app.py
CHANGED
|
@@ -130,13 +130,15 @@ class TransferVAE_Video(nn.Module):
|
|
| 130 |
backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
|
| 131 |
frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
|
| 132 |
lstm_out_f = torch.cat((frontal, backward), dim=1)
|
|
|
|
| 133 |
f_mean = self.f_mean(lstm_out_f)
|
| 134 |
-
|
| 135 |
-
|
| 136 |
features, _ = self.z_rnn(lstm_out)
|
|
|
|
| 137 |
z_mean = self.z_mean(features)
|
| 138 |
-
|
| 139 |
-
|
| 140 |
return f_post, z_post
|
| 141 |
|
| 142 |
|
|
@@ -150,16 +152,6 @@ class TransferVAE_Video(nn.Module):
|
|
| 150 |
x = x.view(-1, x_shape[-3], x_shape[-2], x_shape[-1])
|
| 151 |
x_embed = self.encoder(x)[0]
|
| 152 |
return x_embed.view(x_shape[0], x_shape[1], -1)
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
def reparameterize(self, mean, logvar, random_sampling=True):
|
| 156 |
-
if random_sampling is True:
|
| 157 |
-
eps = torch.randn_like(logvar)
|
| 158 |
-
std = torch.exp(0.5 * logvar)
|
| 159 |
-
z = mean + eps * std
|
| 160 |
-
return z
|
| 161 |
-
else:
|
| 162 |
-
return mean
|
| 163 |
|
| 164 |
|
| 165 |
def forward(self, x, beta):
|
|
@@ -171,9 +163,7 @@ class TransferVAE_Video(nn.Module):
|
|
| 171 |
zf = torch.cat((z_post, f_expand), dim=2)
|
| 172 |
recon_x = self.decoder_frame(zf)
|
| 173 |
return f_post, z_post, recon_x
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
|
| 178 |
def name2seq(file_name):
|
| 179 |
images = []
|
|
@@ -235,7 +225,7 @@ def MyPlot(frame_id, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, s
|
|
| 235 |
axs[1, 3].imshow(src_Zf_tar_Zt)
|
| 236 |
axs[1, 3].axis('off')
|
| 237 |
|
| 238 |
-
plt.subplots_adjust(hspace=0.
|
| 239 |
|
| 240 |
save_name = 'MyPlot_{}.png'.format(frame_id)
|
| 241 |
|
|
|
|
| 130 |
backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
|
| 131 |
frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
|
| 132 |
lstm_out_f = torch.cat((frontal, backward), dim=1)
|
| 133 |
+
|
| 134 |
f_mean = self.f_mean(lstm_out_f)
|
| 135 |
+
f_post = f_mean
|
| 136 |
+
|
| 137 |
features, _ = self.z_rnn(lstm_out)
|
| 138 |
+
|
| 139 |
z_mean = self.z_mean(features)
|
| 140 |
+
z_post = z_mean
|
| 141 |
+
|
| 142 |
return f_post, z_post
|
| 143 |
|
| 144 |
|
|
|
|
| 152 |
x = x.view(-1, x_shape[-3], x_shape[-2], x_shape[-1])
|
| 153 |
x_embed = self.encoder(x)[0]
|
| 154 |
return x_embed.view(x_shape[0], x_shape[1], -1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
|
| 156 |
|
| 157 |
def forward(self, x, beta):
|
|
|
|
| 163 |
zf = torch.cat((z_post, f_expand), dim=2)
|
| 164 |
recon_x = self.decoder_frame(zf)
|
| 165 |
return f_post, z_post, recon_x
|
| 166 |
+
|
|
|
|
|
|
|
| 167 |
|
| 168 |
def name2seq(file_name):
|
| 169 |
images = []
|
|
|
|
| 225 |
axs[1, 3].imshow(src_Zf_tar_Zt)
|
| 226 |
axs[1, 3].axis('off')
|
| 227 |
|
| 228 |
+
plt.subplots_adjust(hspace=0.0125, wspace=0.0)
|
| 229 |
|
| 230 |
save_name = 'MyPlot_{}.png'.format(frame_id)
|
| 231 |
|