peterwisu commited on
Commit
5036b2a
·
1 Parent(s): c72eeb4
__pycache__/hparams.cpython-37.pyc ADDED
Binary file (2.15 kB). View file
 
src/main/__pycache__/inference.cpython-37.pyc CHANGED
Binary files a/src/main/__pycache__/inference.cpython-37.pyc and b/src/main/__pycache__/inference.cpython-37.pyc differ
 
src/main/inference.py CHANGED
@@ -54,16 +54,23 @@ class Inference():
54
  self.all_frames = self.all_frames[:len(self.mel_chunk)]
55
 
56
 
57
- # Image2Image translation model
58
- self.image2image = ResUnetGenerator(input_nc=6,output_nc=3,num_downs=6,use_dropout=False)
59
 
60
  # Load pretrained weights to image2image model
61
  image2image_weight = torch.load(self.image2image_ckpt, map_location=torch.device(device))['G']
 
62
  # Since the checkpoint of model was trained using DataParallel with multiple GPU
63
  # It required to wrap a model with DataParallel wrapper class
64
- self.image2image = DataParallel(self.image2image)
65
  # assgin weight to model
66
  self.image2image.load_state_dict(image2image_weight)
 
 
 
 
 
 
67
 
68
 
69
 
@@ -96,6 +103,8 @@ class Inference():
96
  reset_optimizer=True,
97
  pretrain=True)
98
 
 
 
99
 
100
  def __landmark_detection__(self,images, batch_size):
101
  """
@@ -401,6 +410,7 @@ class Inference():
401
 
402
  with torch.no_grad():
403
  self.image2image.eval()
 
404
  trans_out = self.image2image(trans_in)
405
  trans_out = torch.tanh(trans_out)
406
 
 
54
  self.all_frames = self.all_frames[:len(self.mel_chunk)]
55
 
56
 
57
+ # Image2Image translation model
58
+ self.image2image = ResUnetGenerator(input_nc=6,output_nc=3,num_downs=6,use_dropout=False).to(device)
59
 
60
  # Load pretrained weights to image2image model
61
  image2image_weight = torch.load(self.image2image_ckpt, map_location=torch.device(device))['G']
62
+
63
  # Since the checkpoint of model was trained using DataParallel with multiple GPU
64
  # It required to wrap a model with DataParallel wrapper class
65
+ self.image2image = DataParallel(self.image2image).to(device)
66
  # assgin weight to model
67
  self.image2image.load_state_dict(image2image_weight)
68
+
69
+ self.image2image = self.image2image.module # access model (remove DataParallel)
70
+
71
+
72
+
73
+
74
 
75
 
76
 
 
103
  reset_optimizer=True,
104
  pretrain=True)
105
 
106
+ print("Generator",next(self.generator.parameters()).is_cuda )
107
+ print("Img2Img",next(self.image2image.parameters()).is_cuda )
108
 
109
  def __landmark_detection__(self,images, batch_size):
110
  """
 
410
 
411
  with torch.no_grad():
412
  self.image2image.eval()
413
+ print("trans in", trans_in.is_cuda)
414
  trans_out = self.image2image(trans_in)
415
  trans_out = torch.tanh(trans_out)
416