Husr commited on
Commit
edfb0d8
·
1 Parent(s): 8f553ed

Fix LoRA set_adapters when transformer compiled

Browse files
Files changed (1) hide show
  1. app.py +37 -7
app.py CHANGED
@@ -273,10 +273,37 @@ def attach_lora(pipeline: ZImagePipeline) -> Tuple[bool, str | None]:
273
  def set_lora_scale(pipeline: ZImagePipeline, scale: float) -> None:
274
  weight = max(float(scale), 0.0)
275
  adapter = lora_adapter_name or "default"
 
276
  try:
277
- pipeline.set_adapters([adapter], adapter_weights=[weight])
278
- except TypeError:
279
- pipeline.set_adapters([adapter], weights=[weight])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
 
281
 
282
  def load_models() -> Tuple[ZImagePipeline, bool, str | None]:
@@ -498,10 +525,13 @@ def generate_image(
498
  )
499
 
500
  if lora_loaded:
501
- if use_lora:
502
- set_lora_scale(pipeline, float(lora_scale))
503
- else:
504
- set_lora_scale(pipeline, 0.0)
 
 
 
505
 
506
  with torch.inference_mode():
507
  image = pipeline(
 
273
  def set_lora_scale(pipeline: ZImagePipeline, scale: float) -> None:
274
  weight = max(float(scale), 0.0)
275
  adapter = lora_adapter_name or "default"
276
+ compiled_transformer = None
277
  try:
278
+ transformer = getattr(pipeline, "transformer", None)
279
+ if transformer is not None and hasattr(transformer, "_orig_mod"):
280
+ compiled_transformer = transformer
281
+ pipeline.transformer = transformer._orig_mod
282
+
283
+ try:
284
+ pipeline.set_adapters([adapter], adapter_weights=[weight])
285
+ except TypeError:
286
+ pipeline.set_adapters([adapter], weights=[weight])
287
+ except ValueError as exc:
288
+ msg = str(exc)
289
+ present_match = re.search(r"present adapters:\s*(\{[^}]*\}|set\([^)]*\))", msg)
290
+ if present_match:
291
+ present_names = re.findall(r"'([^']+)'", present_match.group(1))
292
+ else:
293
+ present_names = []
294
+ if present_names:
295
+ global lora_adapter_name
296
+ lora_adapter_name = present_names[0]
297
+ adapter = lora_adapter_name
298
+ try:
299
+ pipeline.set_adapters([adapter], adapter_weights=[weight])
300
+ except TypeError:
301
+ pipeline.set_adapters([adapter], weights=[weight])
302
+ else:
303
+ raise
304
+ finally:
305
+ if compiled_transformer is not None:
306
+ pipeline.transformer = compiled_transformer
307
 
308
 
309
  def load_models() -> Tuple[ZImagePipeline, bool, str | None]:
 
525
  )
526
 
527
  if lora_loaded:
528
+ try:
529
+ if use_lora:
530
+ set_lora_scale(pipeline, float(lora_scale))
531
+ else:
532
+ set_lora_scale(pipeline, 0.0)
533
+ except Exception as exc: # noqa: BLE001
534
+ print(f"LoRA scale update failed (continuing without changing LoRA state): {exc}")
535
 
536
  with torch.inference_mode():
537
  image = pipeline(