Shalmoni commited on
Commit
623b9fe
·
verified ·
1 Parent(s): 96406a7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -234,7 +234,7 @@ _sd_t2i = None
234
  _sd_i2i = None
235
 
236
  def _lazy_sd_pipes():
237
- """Load SD once, disable safety checker to avoid offload_state_dict issues; reuse modules for img2img."""
238
  global _sd_t2i, _sd_i2i
239
  if _sd_t2i is not None and _sd_i2i is not None:
240
  return _sd_t2i, _sd_i2i
@@ -243,14 +243,16 @@ def _lazy_sd_pipes():
243
 
244
  _sd_t2i = StableDiffusionPipeline.from_pretrained(
245
  SD_MODEL,
246
- torch_dtype=dtype,
247
  safety_checker=None,
248
  feature_extractor=None,
249
- use_safetensors=True
 
250
  )
251
  if torch.cuda.is_available():
252
  _sd_t2i = _sd_t2i.to("cuda")
253
 
 
254
  _sd_i2i = StableDiffusionImg2ImgPipeline(
255
  vae=_sd_t2i.vae,
256
  text_encoder=_sd_t2i.text_encoder,
@@ -265,6 +267,7 @@ def _lazy_sd_pipes():
265
 
266
  return _sd_t2i, _sd_i2i
267
 
 
268
  def _save_keyframe(pid: str, shot_id: int, img: Image.Image) -> str:
269
  pdir = project_dir(pid)
270
  out = os.path.join(pdir, "keyframes", f"shot_{shot_id:02d}.png")
 
234
  _sd_i2i = None
235
 
236
  def _lazy_sd_pipes():
237
+ """Load SD once without low_cpu_mem_usage to avoid offload_state_dict kwarg; reuse modules for img2img."""
238
  global _sd_t2i, _sd_i2i
239
  if _sd_t2i is not None and _sd_i2i is not None:
240
  return _sd_t2i, _sd_i2i
 
243
 
244
  _sd_t2i = StableDiffusionPipeline.from_pretrained(
245
  SD_MODEL,
246
+ dtype=dtype, # (`torch_dtype` is deprecated, use `dtype`)
247
  safety_checker=None,
248
  feature_extractor=None,
249
+ use_safetensors=True,
250
+ low_cpu_mem_usage=False # <-- critical: prevents passing offload_state_dict
251
  )
252
  if torch.cuda.is_available():
253
  _sd_t2i = _sd_t2i.to("cuda")
254
 
255
+ # Build img2img from already-loaded modules (avoids another from_pretrained call)
256
  _sd_i2i = StableDiffusionImg2ImgPipeline(
257
  vae=_sd_t2i.vae,
258
  text_encoder=_sd_t2i.text_encoder,
 
267
 
268
  return _sd_t2i, _sd_i2i
269
 
270
+
271
  def _save_keyframe(pid: str, shot_id: int, img: Image.Image) -> str:
272
  pdir = project_dir(pid)
273
  out = os.path.join(pdir, "keyframes", f"shot_{shot_id:02d}.png")