Update app.py
Browse files
app.py
CHANGED
|
@@ -12,6 +12,10 @@ import torch.nn as nn
|
|
| 12 |
import torch.nn.functional as F
|
| 13 |
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
class RelationModuleMultiScale(torch.nn.Module):
|
| 16 |
|
| 17 |
def __init__(self, img_feature_dim, num_bottleneck, num_frames):
|
|
@@ -264,11 +268,15 @@ def MyPlot(frame_id, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, s
|
|
| 264 |
plt.savefig(save_name, dpi=200, format='png', bbox_inches='tight', pad_inches=0.0)
|
| 265 |
|
| 266 |
|
|
|
|
|
|
|
| 267 |
# == Load Model ==
|
| 268 |
model = TransferVAE_Video()
|
| 269 |
model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
|
| 270 |
model.eval()
|
| 271 |
-
|
|
|
|
|
|
|
| 272 |
|
| 273 |
def run(source, action_source, hair_source, top_source, bottom_source, target, action_target, hair_target, top_target, bottom_target):
|
| 274 |
|
|
@@ -330,10 +338,12 @@ def run(source, action_source, hair_source, top_source, bottom_source, target, a
|
|
| 330 |
x = torch.cat((images_source, images_target), dim=0)
|
| 331 |
|
| 332 |
|
|
|
|
| 333 |
# == Forward ==
|
| 334 |
with torch.no_grad():
|
| 335 |
f_post, z_post, recon_x = model(x, [0]*3)
|
| 336 |
-
|
|
|
|
| 337 |
src_orig_sample = x[0, :, :, :, :]
|
| 338 |
src_recon_sample = recon_x[0, :, :, :, :]
|
| 339 |
src_f_post = f_post[0, :].unsqueeze(0)
|
|
@@ -378,7 +388,9 @@ def run(source, action_source, hair_source, top_source, bottom_source, target, a
|
|
| 378 |
recon_x_tarZf_srcZt = model.decoder_frame(zf_tarZf_srcZt)
|
| 379 |
tar_Zf_src_Zt = recon_x_tarZf_srcZt.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
| 380 |
|
|
|
|
| 381 |
MyPlot(frame, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, src_Zf_tar_Zt, tar_Zf_src_Zt)
|
|
|
|
| 382 |
|
| 383 |
a = concat('MyPlot_')
|
| 384 |
|
|
|
|
| 12 |
import torch.nn.functional as F
|
| 13 |
|
| 14 |
|
| 15 |
+
def log_gpu_memory():
|
| 16 |
+
print(subprocess.check_output('nvidia-smi').decode('utf-8'))
|
| 17 |
+
|
| 18 |
+
|
| 19 |
class RelationModuleMultiScale(torch.nn.Module):
|
| 20 |
|
| 21 |
def __init__(self, img_feature_dim, num_bottleneck, num_frames):
|
|
|
|
| 268 |
plt.savefig(save_name, dpi=200, format='png', bbox_inches='tight', pad_inches=0.0)
|
| 269 |
|
| 270 |
|
| 271 |
+
log_gpu_memory()
|
| 272 |
+
|
| 273 |
# == Load Model ==
|
| 274 |
model = TransferVAE_Video()
|
| 275 |
model.load_state_dict(torch.load('TransferVAE.pth.tar', map_location=torch.device('cpu'))['state_dict'])
|
| 276 |
model.eval()
|
| 277 |
+
|
| 278 |
+
log_gpu_memory()
|
| 279 |
+
|
| 280 |
|
| 281 |
def run(source, action_source, hair_source, top_source, bottom_source, target, action_target, hair_target, top_target, bottom_target):
|
| 282 |
|
|
|
|
| 338 |
x = torch.cat((images_source, images_target), dim=0)
|
| 339 |
|
| 340 |
|
| 341 |
+
log_gpu_memory()
|
| 342 |
# == Forward ==
|
| 343 |
with torch.no_grad():
|
| 344 |
f_post, z_post, recon_x = model(x, [0]*3)
|
| 345 |
+
log_gpu_memory()
|
| 346 |
+
|
| 347 |
src_orig_sample = x[0, :, :, :, :]
|
| 348 |
src_recon_sample = recon_x[0, :, :, :, :]
|
| 349 |
src_f_post = f_post[0, :].unsqueeze(0)
|
|
|
|
| 388 |
recon_x_tarZf_srcZt = model.decoder_frame(zf_tarZf_srcZt)
|
| 389 |
tar_Zf_src_Zt = recon_x_tarZf_srcZt.squeeze()[frame, :, :, :].detach().numpy().transpose((1, 2, 0))
|
| 390 |
|
| 391 |
+
log_gpu_memory()
|
| 392 |
MyPlot(frame, src_orig, tar_orig, src_recon, tar_recon, src_Zt, tar_Zt, src_Zf_tar_Zt, tar_Zf_src_Zt)
|
| 393 |
+
log_gpu_memory()
|
| 394 |
|
| 395 |
a = concat('MyPlot_')
|
| 396 |
|