root commited on
Commit
2a94fc4
·
1 Parent(s): f751834

trying device type

Browse files
Files changed (1) hide show
  1. handler.py +2 -2
handler.py CHANGED
@@ -51,8 +51,8 @@ class EndpointHandler():
51
 
52
  print("model path is " + pretrained_base_model_path_unet)
53
  reference_unet = UNet2DConditionModel.from_pretrained(
54
- "./pretrained_weights/stable-diffusion-v1-5/unet",
55
- ).to(device, dtype=self.weight_dtype)
56
 
57
  inference_config_path = os.path.join(base_dir, 'configs', 'inference', 'inference_v2.yaml')
58
  motion_module_path = os.path.join(base_dir, 'pretrained_weights', 'motion_module.pth')
 
51
 
52
  print("model path is " + pretrained_base_model_path_unet)
53
  reference_unet = UNet2DConditionModel.from_pretrained(
54
+ pretrained_base_model_path_unet,
55
+ ).to(dtype=self.weight_dtype, device="cuda")
56
 
57
  inference_config_path = os.path.join(base_dir, 'configs', 'inference', 'inference_v2.yaml')
58
  motion_module_path = os.path.join(base_dir, 'pretrained_weights', 'motion_module.pth')