root commited on
Commit
38a73c3
·
1 Parent(s): 9f5c9c8

trying cuda

Browse files
Files changed (1) hide show
  1. handler.py +2 -3
handler.py CHANGED
@@ -50,9 +50,8 @@ class EndpointHandler():
50
  pretrained_base_model_path_unet = os.path.join(base_dir, 'pretrained_weights', 'stable-diffusion-v1-5', 'unet')
51
  print("model path is " + pretrained_base_model_path_unet)
52
  reference_unet = UNet2DConditionModel.from_pretrained(
53
- pretrained_base_model_path_unet,
54
- local_files_only=True
55
- ).to(device, dtype=self.weight_dtype)
56
 
57
  inference_config_path = os.path.join(base_dir, 'configs', 'inference', 'inference_v2.yaml')
58
  motion_module_path = os.path.join(base_dir, 'pretrained_weights', 'motion_module.pth')
 
50
  pretrained_base_model_path_unet = os.path.join(base_dir, 'pretrained_weights', 'stable-diffusion-v1-5', 'unet')
51
  print("model path is " + pretrained_base_model_path_unet)
52
  reference_unet = UNet2DConditionModel.from_pretrained(
53
+ pretrained_base_model_path_unet
54
+ ).to(dtype=self.weight_dtype, device="cuda")
 
55
 
56
  inference_config_path = os.path.join(base_dir, 'configs', 'inference', 'inference_v2.yaml')
57
  motion_module_path = os.path.join(base_dir, 'pretrained_weights', 'motion_module.pth')