Husr committed on
Commit
286e141
·
1 Parent(s): 5a54a8f

Disable AoTI when LoRA loaded (avoid runtime crash)

Browse files
Files changed (2) hide show
  1. README.md +1 -0
  2. app.py +27 -15
README.md CHANGED
@@ -40,6 +40,7 @@ Gradio Space using the official Z-Image pipeline (`Tongyi-MAI/Z-Image-Turbo`) wi
40
  - `ENABLE_AOTI` (default `true`): Try to load ZeroGPU AoTI blocks via `spaces.aoti_blocks_load` for faster inference.
41
  - `AOTI_REPO` (default `zerogpu-aoti/Z-Image`): AoTI blocks repo.
42
  - `AOTI_VARIANT` (default `fa3`): AoTI variant.
 
43
 
44
  ## Run locally
45
  ```bash
 
40
  - `ENABLE_AOTI` (default `true`): Try to load ZeroGPU AoTI blocks via `spaces.aoti_blocks_load` for faster inference.
41
  - `AOTI_REPO` (default `zerogpu-aoti/Z-Image`): AoTI blocks repo.
42
  - `AOTI_VARIANT` (default `fa3`): AoTI variant.
43
+ - `AOTI_ALLOW_LORA` (default `false`): Allow AoTI to load even if LoRA adapters are loaded (may crash; AoTI blocks generally don’t support LoRA).
44
 
45
  ## Run locally
46
  ```bash
app.py CHANGED
@@ -24,6 +24,7 @@ OFFLOAD_TO_CPU_AFTER_RUN = os.environ.get("OFFLOAD_TO_CPU_AFTER_RUN", "false").l
24
  ENABLE_AOTI = os.environ.get("ENABLE_AOTI", "true").lower() == "true"
25
  AOTI_REPO = os.environ.get("AOTI_REPO", "zerogpu-aoti/Z-Image")
26
  AOTI_VARIANT = os.environ.get("AOTI_VARIANT", "fa3")
 
27
  DEFAULT_CFG = float(os.environ.get("DEFAULT_CFG", "0.0"))
28
 
29
 
@@ -549,21 +550,26 @@ def init_app() -> None:
549
  if ENABLE_COMPILE and pipe is not None:
550
  ensure_on_gpu()
551
  if ENABLE_AOTI and not aoti_loaded and pipe is not None and getattr(pipe, "transformer", None) is not None:
552
- try:
553
- pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
554
- spaces.aoti_blocks_load(pipe.transformer.layers, AOTI_REPO, variant=AOTI_VARIANT)
555
- aoti_loaded = True
556
- aoti_error = None
557
- print(f"AoTI loaded: {AOTI_REPO} (variant={AOTI_VARIANT})")
558
- except Exception as exc: # noqa: BLE001
559
  aoti_loaded = False
560
- aoti_error = str(exc)
561
- print(f"AoTI load failed (continuing without AoTI): {exc}")
562
- try:
563
- applied_attention_backend = set_attention_backend_safe(pipe.transformer, ATTENTION_BACKEND)
564
- print(f"Attention backend (post-AoTI): {applied_attention_backend}")
565
- except Exception as exc: # noqa: BLE001
566
- print(f"Attention backend update failed (continuing): {exc}")
 
 
 
 
 
 
 
 
 
 
 
567
  if ENABLE_WARMUP and pipe is not None:
568
  ensure_on_gpu()
569
  try:
@@ -651,7 +657,13 @@ with gr.Blocks(title="Z-Image + LoRA") as demo:
651
  if aoti_loaded:
652
  aoti_status = "loaded"
653
  elif aoti_error:
654
- label = "unavailable" if "kernels" in aoti_error.lower() else "failed"
 
 
 
 
 
 
655
  aoti_status = f"{label} ({summarize_error(aoti_error)})"
656
  else:
657
  aoti_status = "not loaded"
 
24
  ENABLE_AOTI = os.environ.get("ENABLE_AOTI", "true").lower() == "true"
25
  AOTI_REPO = os.environ.get("AOTI_REPO", "zerogpu-aoti/Z-Image")
26
  AOTI_VARIANT = os.environ.get("AOTI_VARIANT", "fa3")
27
+ AOTI_ALLOW_LORA = os.environ.get("AOTI_ALLOW_LORA", "false").lower() == "true"
28
  DEFAULT_CFG = float(os.environ.get("DEFAULT_CFG", "0.0"))
29
 
30
 
 
550
  if ENABLE_COMPILE and pipe is not None:
551
  ensure_on_gpu()
552
  if ENABLE_AOTI and not aoti_loaded and pipe is not None and getattr(pipe, "transformer", None) is not None:
553
+ if lora_loaded and not AOTI_ALLOW_LORA:
 
 
 
 
 
 
554
  aoti_loaded = False
555
+ aoti_error = "disabled: AoTI blocks are incompatible with LoRA adapters"
556
+ print("AoTI disabled: LoRA adapters are loaded (AoTI blocks are incompatible with LoRA).")
557
+ else:
558
+ try:
559
+ pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
560
+ spaces.aoti_blocks_load(pipe.transformer.layers, AOTI_REPO, variant=AOTI_VARIANT)
561
+ aoti_loaded = True
562
+ aoti_error = None
563
+ print(f"AoTI loaded: {AOTI_REPO} (variant={AOTI_VARIANT})")
564
+ except Exception as exc: # noqa: BLE001
565
+ aoti_loaded = False
566
+ aoti_error = str(exc)
567
+ print(f"AoTI load failed (continuing without AoTI): {exc}")
568
+ try:
569
+ applied_attention_backend = set_attention_backend_safe(pipe.transformer, ATTENTION_BACKEND)
570
+ print(f"Attention backend (post-AoTI): {applied_attention_backend}")
571
+ except Exception as exc: # noqa: BLE001
572
+ print(f"Attention backend update failed (continuing): {exc}")
573
  if ENABLE_WARMUP and pipe is not None:
574
  ensure_on_gpu()
575
  try:
 
657
  if aoti_loaded:
658
  aoti_status = "loaded"
659
  elif aoti_error:
660
+ lower = aoti_error.lower()
661
+ if "disabled" in lower:
662
+ label = "disabled"
663
+ elif "kernels" in lower:
664
+ label = "unavailable"
665
+ else:
666
+ label = "failed"
667
  aoti_status = f"{label} ({summarize_error(aoti_error)})"
668
  else:
669
  aoti_status = "not loaded"