Spaces:
Build error
Build error
readme
Browse files- src/main/inference.py +2 -2
src/main/inference.py
CHANGED
|
@@ -58,12 +58,12 @@ class Inference():
|
|
| 58 |
self.image2image = ResUnetGenerator(input_nc=6,output_nc=3,num_downs=6,use_dropout=False)
|
| 59 |
|
| 60 |
# Load pretrained weights to image2image model
|
| 61 |
-
image2image_weight = torch.load(self.image2image_ckpt)['G']
|
| 62 |
# Since the checkpoint of model was trained using DataParallel with multiple GPU
|
| 63 |
# It required to wrap a model with DataParallel wrapper class
|
| 64 |
self.image2image = DataParallel(self.image2image)
|
| 65 |
# assgin weight to model
|
| 66 |
-
self.image2image.load_state_dict(image2image_weight
|
| 67 |
|
| 68 |
|
| 69 |
|
|
|
|
| 58 |
self.image2image = ResUnetGenerator(input_nc=6,output_nc=3,num_downs=6,use_dropout=False)
|
| 59 |
|
| 60 |
# Load pretrained weights to image2image model
|
| 61 |
+
image2image_weight = torch.load(self.image2image_ckpt, map_location=torch.device(device))['G']
|
| 62 |
# Since the checkpoint of model was trained using DataParallel with multiple GPU
|
| 63 |
# It required to wrap a model with DataParallel wrapper class
|
| 64 |
self.image2image = DataParallel(self.image2image)
|
| 65 |
# assgin weight to model
|
| 66 |
+
self.image2image.load_state_dict(image2image_weight)
|
| 67 |
|
| 68 |
|
| 69 |
|