add device specification
Browse files- handler.py +3 -2
handler.py
CHANGED
|
@@ -27,7 +27,8 @@ class EndpointHandler:
|
|
| 27 |
cfg = OmegaConf.create(config_dict)
|
| 28 |
|
| 29 |
self.generator = Generator(cfg)
|
| 30 |
-
self.
|
|
|
|
| 31 |
self.generator.eval()
|
| 32 |
|
| 33 |
|
|
@@ -35,7 +36,7 @@ class EndpointHandler:
|
|
| 35 |
base64_image = data.get('inputs')
|
| 36 |
input_tensor = self._decode_base64_image(base64_image)
|
| 37 |
print('Input tensor shape: ' + str(input_tensor.shape))
|
| 38 |
-
output_tensor = self.generator(input_tensor)
|
| 39 |
output_tensor = output_tensor.squeeze(0)
|
| 40 |
output_image = transforms.ToPILImage()(output_tensor)
|
| 41 |
output_image = output_image.convert('RGB')
|
|
|
|
| 27 |
cfg = OmegaConf.create(config_dict)
|
| 28 |
|
| 29 |
self.generator = Generator(cfg)
|
| 30 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 31 |
+
self.generator.load_state_dict(torch.load(generator_path, map_location=self.device))
|
| 32 |
self.generator.eval()
|
| 33 |
|
| 34 |
|
|
|
|
| 36 |
base64_image = data.get('inputs')
|
| 37 |
input_tensor = self._decode_base64_image(base64_image)
|
| 38 |
print('Input tensor shape: ' + str(input_tensor.shape))
|
| 39 |
+
output_tensor = self.generator(input_tensor.to(self.device))
|
| 40 |
output_tensor = output_tensor.squeeze(0)
|
| 41 |
output_image = transforms.ToPILImage()(output_tensor)
|
| 42 |
output_image = output_image.convert('RGB')
|