peterwisu commited on
Commit
c72eeb4
·
1 Parent(s): 00a3766
Files changed (1) hide show
  1. src/main/inference.py +2 -2
src/main/inference.py CHANGED
@@ -58,12 +58,12 @@ class Inference():
58
  self.image2image = ResUnetGenerator(input_nc=6,output_nc=3,num_downs=6,use_dropout=False)
59
 
60
  # Load pretrained weights to image2image model
61
- image2image_weight = torch.load(self.image2image_ckpt)['G']
62
  # Since the checkpoint of model was trained using DataParallel with multiple GPU
63
  # It required to wrap a model with DataParallel wrapper class
64
  self.image2image = DataParallel(self.image2image)
65
  # assgin weight to model
66
- self.image2image.load_state_dict(image2image_weight, map_location=torch.device('cpu'))
67
 
68
 
69
 
 
58
  self.image2image = ResUnetGenerator(input_nc=6,output_nc=3,num_downs=6,use_dropout=False)
59
 
60
  # Load pretrained weights to image2image model
61
+ image2image_weight = torch.load(self.image2image_ckpt, map_location=torch.device(device))['G']
62
  # Since the checkpoint of model was trained using DataParallel with multiple GPU
63
  # It required to wrap a model with DataParallel wrapper class
64
  self.image2image = DataParallel(self.image2image)
65
  # assgin weight to model
66
+ self.image2image.load_state_dict(image2image_weight)
67
 
68
 
69