kael558 commited on
Commit
fc15b05
·
1 Parent(s): d955083

change device to cpu

Browse files
Files changed (1) hide show
  1. sd.py +4 -5
sd.py CHANGED
@@ -47,7 +47,7 @@ from ldm.models.diffusion.ddim import DDIMSampler
47
  from ldm.models.diffusion.plms import PLMSSampler
48
 
49
  # 2. Set model download config
50
- def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=False):
51
  map_location = "cuda" if torch.cuda.is_available() else "cpu"
52
  print(f"Loading model from {ckpt}")
53
  pl_sd = torch.load(ckpt, map_location=map_location)
@@ -64,9 +64,7 @@ def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_prec
64
  print(u)
65
 
66
  if half_precision:
67
- model = model.half().to(device)
68
- else:
69
- model = model.to(device)
70
  model.eval()
71
  return model
72
 
@@ -93,7 +91,8 @@ ckpt_path = hf_hub_download(repo_id="CompVis/stable-diffusion-v-1-4-original", f
93
  ckpt_config_path = "./stable-diffusion/configs/stable-diffusion/v1-inference.yaml"
94
 
95
  local_config = OmegaConf.load(f"{ckpt_config_path}")
96
- model = load_model_from_config(local_config, f"{ckpt_path}")
 
97
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
98
  model = model.to(device)
99
  print('Model saved.')
 
47
  from ldm.models.diffusion.plms import PLMSSampler
48
 
49
  # 2. Set model download config
50
+ def load_model_from_config(config, ckpt, verbose=False, half_precision=True):
51
  map_location = "cuda" if torch.cuda.is_available() else "cpu"
52
  print(f"Loading model from {ckpt}")
53
  pl_sd = torch.load(ckpt, map_location=map_location)
 
64
  print(u)
65
 
66
  if half_precision:
67
+ model = model.half()
 
 
68
  model.eval()
69
  return model
70
 
 
91
  ckpt_config_path = "./stable-diffusion/configs/stable-diffusion/v1-inference.yaml"
92
 
93
  local_config = OmegaConf.load(f"{ckpt_config_path}")
94
+
95
+ model = load_model_from_config(local_config, f"{ckpt_path}" )
96
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
97
  model = model.to(device)
98
  print('Model saved.')