jlynxdev commited on
Commit
780b329
·
1 Parent(s): f6ab43f

add device specification

Browse files
Files changed (1) hide show
  1. handler.py +3 -2
handler.py CHANGED
@@ -27,7 +27,8 @@ class EndpointHandler:
27
  cfg = OmegaConf.create(config_dict)
28
 
29
  self.generator = Generator(cfg)
30
- self.generator.load_state_dict(torch.load(generator_path))
 
31
  self.generator.eval()
32
 
33
 
@@ -35,7 +36,7 @@ class EndpointHandler:
35
  base64_image = data.get('inputs')
36
  input_tensor = self._decode_base64_image(base64_image)
37
  print('Input tensor shape: ' + str(input_tensor.shape))
38
- output_tensor = self.generator(input_tensor)
39
  output_tensor = output_tensor.squeeze(0)
40
  output_image = transforms.ToPILImage()(output_tensor)
41
  output_image = output_image.convert('RGB')
 
27
  cfg = OmegaConf.create(config_dict)
28
 
29
  self.generator = Generator(cfg)
30
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+ self.generator.load_state_dict(torch.load(generator_path, map_location=self.device))
32
  self.generator.eval()
33
 
34
 
 
36
  base64_image = data.get('inputs')
37
  input_tensor = self._decode_base64_image(base64_image)
38
  print('Input tensor shape: ' + str(input_tensor.shape))
39
+ output_tensor = self.generator(input_tensor.to(self.device))
40
  output_tensor = output_tensor.squeeze(0)
41
  output_image = transforms.ToPILImage()(output_tensor)
42
  output_image = output_image.convert('RGB')