root commited on
Commit
7083a27
·
1 Parent(s): 50f301d
Files changed (1) hide show
  1. handler.py +1 -0
handler.py CHANGED
@@ -51,6 +51,7 @@ class EndpointHandler():
51
 
52
  reference_unet = UNet2DConditionModel.from_pretrained(
53
  pretrained_base_model_path_unet,
 
54
  ).to(device, dtype=self.weight_dtype)
55
 
56
  inference_config_path = os.path.join(base_dir, 'configs', 'inference', 'inference_v2.yaml')
 
51
 
52
  reference_unet = UNet2DConditionModel.from_pretrained(
53
  pretrained_base_model_path_unet,
54
+ variant="fp16",
55
  ).to(device, dtype=self.weight_dtype)
56
 
57
  inference_config_path = os.path.join(base_dir, 'configs', 'inference', 'inference_v2.yaml')