Update app.py
app.py CHANGED
@@ -13,10 +13,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-def log_gpu_memory():
-    print(subprocess.check_output('nvidia-smi').decode('utf-8'))
-
-
 class RelationModuleMultiScale(torch.nn.Module):
 
     def __init__(self, img_feature_dim, num_bottleneck, num_frames):
@@ -129,42 +125,19 @@ class TransferVAE_Video(nn.Module):
 
 
     def encode_and_sample_post(self, x):
-        if isinstance(x, list):
-            conv_x = self.encoder_frame(x[0])
-        else:
-            conv_x = self.encoder_frame(x)
-
+        conv_x = self.encoder_frame(x)
         lstm_out, _ = self.z_lstm(conv_x)
-
         backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
         frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
         lstm_out_f = torch.cat((frontal, backward), dim=1)
         f_mean = self.f_mean(lstm_out_f)
         f_logvar = self.f_logvar(lstm_out_f)
         f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)
-
         features, _ = self.z_rnn(lstm_out)
         z_mean = self.z_mean(features)
         z_logvar = self.z_logvar(features)
         z_post = self.reparameterize(z_mean, z_logvar, random_sampling=False)
-
-        if isinstance(x, list):
-            f_mean_list = [f_mean]
-            f_post_list = [f_post]
-            for t in range(1,3,1):
-                conv_x = self.encoder_frame(x[t])
-                lstm_out, _ = self.z_lstm(conv_x)
-                backward = lstm_out[:, 0, self.hidden_dim:2 * self.hidden_dim]
-                frontal = lstm_out[:, self.frames - 1, 0:self.hidden_dim]
-                lstm_out_f = torch.cat((frontal, backward), dim=1)
-                f_mean = self.f_mean(lstm_out_f)
-                f_logvar = self.f_logvar(lstm_out_f)
-                f_post = self.reparameterize(f_mean, f_logvar, random_sampling=False)
-                f_mean_list.append(f_mean)
-                f_post_list.append(f_post)
-            f_mean = f_mean_list
-            f_post = f_post_list
-        return f_mean, f_logvar, f_post, z_mean, z_logvar, z_post
+        return f_post, z_post
 
 
     def decoder_frame(self,zf):
@@ -190,7 +163,7 @@ class TransferVAE_Video(nn.Module):
 
 
     def forward(self, x, beta):
-        f_mean, f_logvar, f_post, z_mean, z_logvar, z_post = self.encode_and_sample_post(x)
+        f_post, z_post = self.encode_and_sample_post(x)
         if isinstance(f_post, list):
             f_expand = f_post[0].unsqueeze(1).expand(-1, self.frames, self.f_dim)
         else:
@@ -269,15 +242,11 @@ def MyPlot(frame_id, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, s
     plt.savefig(save_name, dpi=200, format='png', bbox_inches='tight', pad_inches=0.0)
 
 
-log_gpu_memory()
-
 # == Load Model ==
 model = TransferVAE_Video()
 model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
 model.eval()
 
-log_gpu_memory()
-
 
 def run(source, action_source, hair_source, top_source, bottom_source, target, action_target, hair_target, top_target, bottom_target):
 
@@ -338,12 +307,11 @@ def run(source, action_source, hair_source, top_source, bottom_source, target, a
     images_target = name2seq(file_name_target)
     x = torch.cat((images_source, images_target), dim=0)
 
-
-    log_gpu_memory()
+
     # == Forward ==
     with torch.no_grad():
         f_post, z_post, recon_x = model(x, [0]*3)
-
+
 
     src_orig_sample = x[0, :, :, :, :]
     src_recon_sample = recon_x[0, :, :, :, :]
@@ -389,9 +357,7 @@ def run(source, action_source, hair_source, top_source, bottom_source, target, a
     recon_x_tarZf_srcZt = model.decoder_frame(zf_tarZf_srcZt)
     tar_Zf_src_Zt = recon_x_tarZf_srcZt.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
 
-    log_gpu_memory()
     MyPlot(frame, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, src_Zf_tar_Zt, tar_Zf_src_Zt)
-    log_gpu_memory()
 
     a = concat('MyPlot_')
 
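Both posterior heads in encode_and_sample_post sample through self.reparameterize(..., random_sampling=False), whose body lies outside this diff. A minimal sketch of the standard VAE reparameterization trick such a method conventionally implements (the body below is an assumption, not code from this commit):

    # Sketch of a reparameterize method on TransferVAE_Video; assumed, not from this diff.
    def reparameterize(self, mean, logvar, random_sampling=True):
        if random_sampling:
            # Standard trick: z = mu + sigma * eps with eps ~ N(0, I),
            # keeping the sample differentiable w.r.t. mean and logvar.
            eps = torch.randn_like(logvar)
            return mean + torch.exp(0.5 * logvar) * eps
        # With random_sampling=False, as called above, return the posterior mean.
        return mean

Calling it with random_sampling=False makes both f_post and z_post deterministic, which is consistent with the inference-only use of the model (model.eval() and torch.no_grad()) later in the file.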
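In forward, a single per-sequence content code f_post is broadcast across all frames via unsqueeze(1).expand(-1, self.frames, self.f_dim). A standalone shape check of that idiom (the sizes here are illustrative, not the model's actual hyperparameters):

    import torch

    batch, frames, f_dim = 2, 8, 256        # illustrative sizes only
    f_post = torch.zeros(batch, f_dim)      # one content code per sequence
    # unsqueeze(1) inserts a time axis; expand() views it once per frame
    # without copying memory (stride 0 along the new axis).
    f_expand = f_post.unsqueeze(1).expand(-1, frames, f_dim)
    assert f_expand.shape == (batch, frames, f_dim)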
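The deleted log_gpu_memory helper shelled out to nvidia-smi, a binary that does not exist on CPU-only hosts; the same commit loads the checkpoint with map_location=torch.device('cpu'), which suggests the app now targets a host where every nvidia-smi call would raise FileNotFoundError. If GPU logging were ever reinstated, a guarded variant (a sketch assuming graceful degradation is acceptable, not part of this commit) could check for the binary first:

    import shutil
    import subprocess

    def log_gpu_memory():
        # Only invoke nvidia-smi when the binary is on PATH,
        # so the app keeps running on CPU-only hosts.
        if shutil.which('nvidia-smi') is not None:
            print(subprocess.check_output(['nvidia-smi']).decode('utf-8'))
        else:
            print('nvidia-smi not found; assuming a CPU-only environment.')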