Husr committed on
Commit
e963edc
·
1 Parent(s): 46983e8

Align defaults with official example (keep Advanced controls)

Browse files
Files changed (2) hide show
  1. README.md +2 -1
  2. app.py +96 -37
README.md CHANGED
@@ -54,7 +54,8 @@ Place the LoRA file under `lora/` first (or set `LORA_PATH`); otherwise the app
54
  - Prompt
55
  - Resolution category + explicit WxH selection
56
  - Seed (with random toggle)
57
- - Steps, CFG, scheduler + shift (and extra scheduler params), max sequence length
 
58
  - LoRA toggle + strength (enabled only if the file is found)
59
 
60
  ## Git LFS note
 
54
  - Prompt
55
  - Resolution category + explicit WxH selection
56
  - Seed (with random toggle)
57
+ - Steps + Time Shift
58
+ - Advanced: CFG, scheduler + extra scheduler params, max sequence length
59
  - LoRA toggle + strength (enabled only if the file is found)
60
 
61
  ## Git LFS note
app.py CHANGED
@@ -150,10 +150,12 @@ EXAMPLE_PROMPTS = [
150
  pipe: ZImagePipeline | None = None
151
  lora_loaded: bool = False
152
  lora_error: str | None = None
 
153
  pipe_lock = threading.Lock()
154
  pipe_on_gpu: bool = False
155
  aoti_loaded: bool = False
156
  applied_attention_backend: str | None = None
 
157
  aoti_error: str | None = None
158
  transformer_compiled: bool = False
159
  transformer_compile_attempted: bool = False
@@ -167,7 +169,6 @@ try:
167
  except Exception:
168
  pass
169
 
170
-
171
  def module_available(module_name: str) -> bool:
172
  try:
173
  return importlib.util.find_spec(module_name) is not None
@@ -175,6 +176,13 @@ def module_available(module_name: str) -> bool:
175
  return False
176
 
177
 
 
 
 
 
 
 
 
178
  def parse_resolution(resolution: str) -> Tuple[int, int]:
179
  match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
180
  if match:
@@ -183,6 +191,7 @@ def parse_resolution(resolution: str) -> Tuple[int, int]:
183
 
184
 
185
  def set_attention_backend_safe(transformer, backend: str) -> str:
 
186
  candidates: List[str] = []
187
  if backend:
188
  candidates.append(backend)
@@ -192,41 +201,76 @@ def set_attention_backend_safe(transformer, backend: str) -> str:
192
  candidates.append(f"_{backend}")
193
  candidates.extend(["flash", "xformers", "native"])
194
 
 
 
195
  last_exc: Exception | None = None
196
  for name in candidates:
197
  if not name:
198
  continue
199
  try:
200
  transformer.set_attention_backend(name)
 
 
 
 
 
 
 
201
  return name
202
  except Exception as exc: # noqa: BLE001
203
  last_exc = exc
 
204
  continue
205
 
206
  raise RuntimeError(f"Failed to set attention backend (tried {candidates}): {last_exc}")
207
 
208
 
209
  def attach_lora(pipeline: ZImagePipeline) -> Tuple[bool, str | None]:
 
210
  if not LORA_PATH or not os.path.isfile(LORA_PATH):
211
  return False, "LoRA file not found"
212
  if not module_available("peft"):
213
  return False, "PEFT backend is required for LoRA. Install `peft` and restart."
 
 
 
 
 
 
 
 
214
  try:
215
  folder, weight_name = os.path.split(LORA_PATH)
216
  folder = folder or "."
217
- pipeline.load_lora_weights(folder, weight_name=weight_name)
218
- set_lora_scale(pipeline, 1.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  return True, None
220
  except Exception as exc: # noqa: BLE001
 
221
  return False, f"Failed to load LoRA: {exc}"
222
 
223
 
224
  def set_lora_scale(pipeline: ZImagePipeline, scale: float) -> None:
225
  weight = max(float(scale), 0.0)
 
226
  try:
227
- pipeline.set_adapters(["default"], adapter_weights=[weight])
228
  except TypeError:
229
- pipeline.set_adapters(["default"], weights=[weight])
230
 
231
 
232
  def load_models() -> Tuple[ZImagePipeline, bool, str | None]:
@@ -426,9 +470,9 @@ def generate_image(
426
  steps: int,
427
  shift: float,
428
  guidance_scale: float,
429
- max_sequence_length: int,
430
  use_lora: bool,
431
  lora_scale: float,
 
432
  scheduler_name: str,
433
  num_train_timesteps: int,
434
  use_dynamic_shifting: bool,
@@ -439,17 +483,17 @@ def generate_image(
439
  generator = torch.Generator("cuda").manual_seed(seed)
440
  set_scheduler(
441
  pipeline,
442
- scheduler_name,
443
- num_train_timesteps=num_train_timesteps,
444
- shift=shift,
445
- use_dynamic_shifting=use_dynamic_shifting,
446
- base_shift=base_shift,
447
- max_shift=max_shift,
448
  )
449
 
450
  if lora_loaded:
451
  if use_lora:
452
- set_lora_scale(pipeline, lora_scale)
453
  else:
454
  set_lora_scale(pipeline, 0.0)
455
 
@@ -458,10 +502,10 @@ def generate_image(
458
  prompt=prompt,
459
  height=height,
460
  width=width,
461
- guidance_scale=guidance_scale,
462
- num_inference_steps=steps,
463
  generator=generator,
464
- max_sequence_length=max_sequence_length,
465
  ).images[0]
466
  return image, seed
467
 
@@ -479,9 +523,9 @@ def warmup_model(pipeline: ZImagePipeline, resolutions: List[str]) -> None:
479
  steps=9,
480
  shift=3.0,
481
  guidance_scale=0.0,
482
- max_sequence_length=512,
483
  use_lora=False,
484
  lora_scale=0.0,
 
485
  scheduler_name="FlowMatch Euler",
486
  num_train_timesteps=1000,
487
  use_dynamic_shifting=False,
@@ -500,15 +544,20 @@ def init_app() -> None:
500
  if ENABLE_COMPILE and pipe is not None:
501
  ensure_on_gpu()
502
  if ENABLE_AOTI and not aoti_loaded and pipe is not None and getattr(pipe, "transformer", None) is not None:
503
- try:
504
- pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
505
- spaces.aoti_blocks_load(pipe.transformer.layers, AOTI_REPO, variant=AOTI_VARIANT)
506
- aoti_loaded = True
507
- aoti_error = None
508
- print(f"AoTI loaded: {AOTI_REPO} (variant={AOTI_VARIANT})")
509
- except Exception as exc: # noqa: BLE001
510
- aoti_error = str(exc)
511
- print(f"AoTI load failed (continuing without AoTI): {exc}")
 
 
 
 
 
512
  if ENABLE_WARMUP and pipe is not None:
513
  ensure_on_gpu()
514
  try:
@@ -551,15 +600,15 @@ def generate(
551
  try:
552
  image = generate_image(
553
  pipeline=pipe,
554
- prompt=prompt,
555
- resolution=resolution.split(" ")[0] if " " in resolution else resolution,
556
  seed=new_seed,
557
  steps=int(steps) + 1,
558
  shift=float(shift),
559
  guidance_scale=float(cfg),
560
- max_sequence_length=int(max_sequence_length),
561
  use_lora=use_lora,
562
  lora_scale=float(lora_scale),
 
563
  scheduler_name=str(scheduler_name),
564
  num_train_timesteps=int(num_train_timesteps),
565
  use_dynamic_shifting=bool(use_dynamic_shifting),
@@ -582,14 +631,24 @@ with gr.Blocks(title="Z-Image + LoRA") as demo:
582
  pipe_status = "loaded (GPU)" if pipe and pipe_on_gpu else "loaded (CPU)" if pipe else "not loaded"
583
  lora_file_status = "found" if os.path.isfile(LORA_PATH) else "missing"
584
  if lora_loaded:
585
- lora_status = f"LoRA: loaded ({LORA_PATH})"
 
586
  elif lora_error:
587
  lora_status = f"LoRA: not loaded ({lora_error})"
588
  else:
589
  lora_status = f"LoRA file: {LORA_PATH} ({lora_file_status})"
590
 
591
  attention_status = applied_attention_backend or "unknown"
592
- aoti_status = "loaded" if aoti_loaded else f"failed ({aoti_error})" if aoti_error else "not loaded"
 
 
 
 
 
 
 
 
 
593
  if not ENABLE_COMPILE:
594
  compile_status = "off"
595
  elif transformer_compiled:
@@ -629,10 +688,12 @@ Attention: `{attention_status}` | AoTI: `{aoti_status}` | torch.compile: `{compi
629
  seed = gr.Number(label="Seed", value=42, precision=0)
630
  random_seed = gr.Checkbox(label="Random Seed", value=True)
631
 
 
 
 
 
632
  with gr.Accordion("KSampler / Advanced", open=False):
633
- with gr.Row():
634
- steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=8, step=1)
635
- cfg = gr.Slider(label="CFG", minimum=0.0, maximum=10.0, value=DEFAULT_CFG, step=0.1)
636
 
637
  with gr.Row():
638
  scheduler_name = gr.Dropdown(
@@ -649,15 +710,13 @@ Attention: `{attention_status}` | AoTI: `{aoti_status}` | torch.compile: `{compi
649
  )
650
 
651
  with gr.Row():
652
- shift = gr.Slider(label="Time Shift", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
653
  use_dynamic_shifting = gr.Checkbox(label="use_dynamic_shifting", value=False)
 
654
 
655
  with gr.Row():
656
  base_shift = gr.Slider(label="base_shift", minimum=0.0, maximum=10.0, value=0.5, step=0.1)
657
  max_shift = gr.Slider(label="max_shift", minimum=0.0, maximum=10.0, value=3.0, step=0.1)
658
 
659
- max_seq = gr.Slider(label="Max Sequence Length", minimum=256, maximum=1024, value=512, step=16)
660
-
661
  with gr.Row():
662
  lora_controls_enabled = bool(lora_loaded)
663
  use_lora = gr.Checkbox(label="Use LoRA", value=lora_controls_enabled, interactive=lora_controls_enabled)
 
150
  pipe: ZImagePipeline | None = None
151
  lora_loaded: bool = False
152
  lora_error: str | None = None
153
+ lora_adapter_name: str | None = None
154
  pipe_lock = threading.Lock()
155
  pipe_on_gpu: bool = False
156
  aoti_loaded: bool = False
157
  applied_attention_backend: str | None = None
158
+ attention_backend_error: str | None = None
159
  aoti_error: str | None = None
160
  transformer_compiled: bool = False
161
  transformer_compile_attempted: bool = False
 
169
  except Exception:
170
  pass
171
 
 
172
  def module_available(module_name: str) -> bool:
173
  try:
174
  return importlib.util.find_spec(module_name) is not None
 
176
  return False
177
 
178
 
179
+ def summarize_error(message: str, *, max_len: int = 120) -> str:
180
+ one_line = " ".join(str(message).split())
181
+ if len(one_line) <= max_len:
182
+ return one_line
183
+ return one_line[: max_len - 1] + "…"
184
+
185
+
186
  def parse_resolution(resolution: str) -> Tuple[int, int]:
187
  match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
188
  if match:
 
191
 
192
 
193
  def set_attention_backend_safe(transformer, backend: str) -> str:
194
+ global attention_backend_error
195
  candidates: List[str] = []
196
  if backend:
197
  candidates.append(backend)
 
201
  candidates.append(f"_{backend}")
202
  candidates.extend(["flash", "xformers", "native"])
203
 
204
+ attention_backend_error = None
205
+ errors: dict[str, Exception] = {}
206
  last_exc: Exception | None = None
207
  for name in candidates:
208
  if not name:
209
  continue
210
  try:
211
  transformer.set_attention_backend(name)
212
+ if backend and name != backend:
213
+ for key in (backend, backend.lstrip("_"), f"_{backend}"):
214
+ if key in errors:
215
+ attention_backend_error = str(errors[key])
216
+ break
217
+ if attention_backend_error is None and last_exc is not None:
218
+ attention_backend_error = str(last_exc)
219
  return name
220
  except Exception as exc: # noqa: BLE001
221
  last_exc = exc
222
+ errors[name] = exc
223
  continue
224
 
225
  raise RuntimeError(f"Failed to set attention backend (tried {candidates}): {last_exc}")
226
 
227
 
228
  def attach_lora(pipeline: ZImagePipeline) -> Tuple[bool, str | None]:
229
+ global lora_adapter_name
230
  if not LORA_PATH or not os.path.isfile(LORA_PATH):
231
  return False, "LoRA file not found"
232
  if not module_available("peft"):
233
  return False, "PEFT backend is required for LoRA. Install `peft` and restart."
234
+
235
+ def extract_present_adapter_names(exc: Exception) -> List[str]:
236
+ msg = str(exc)
237
+ match = re.search(r"present adapters:\s*(\{[^}]*\})", msg)
238
+ if not match:
239
+ return []
240
+ return re.findall(r"'([^']+)'", match.group(1))
241
+
242
  try:
243
  folder, weight_name = os.path.split(LORA_PATH)
244
  folder = folder or "."
245
+ preferred_adapter = os.environ.get("LORA_ADAPTER_NAME", "default")
246
+ lora_adapter_name = preferred_adapter
247
+ try:
248
+ pipeline.load_lora_weights(folder, weight_name=weight_name, adapter_name=preferred_adapter)
249
+ except TypeError:
250
+ pipeline.load_lora_weights(folder, weight_name=weight_name)
251
+
252
+ try:
253
+ set_lora_scale(pipeline, 1.0)
254
+ except Exception as exc: # noqa: BLE001
255
+ adapter_names = extract_present_adapter_names(exc)
256
+ if adapter_names:
257
+ lora_adapter_name = adapter_names[0]
258
+ set_lora_scale(pipeline, 1.0)
259
+ else:
260
+ raise
261
  return True, None
262
  except Exception as exc: # noqa: BLE001
263
+ lora_adapter_name = None
264
  return False, f"Failed to load LoRA: {exc}"
265
 
266
 
267
  def set_lora_scale(pipeline: ZImagePipeline, scale: float) -> None:
268
  weight = max(float(scale), 0.0)
269
+ adapter = lora_adapter_name or "default"
270
  try:
271
+ pipeline.set_adapters([adapter], adapter_weights=[weight])
272
  except TypeError:
273
+ pipeline.set_adapters([adapter], weights=[weight])
274
 
275
 
276
  def load_models() -> Tuple[ZImagePipeline, bool, str | None]:
 
470
  steps: int,
471
  shift: float,
472
  guidance_scale: float,
 
473
  use_lora: bool,
474
  lora_scale: float,
475
+ max_sequence_length: int,
476
  scheduler_name: str,
477
  num_train_timesteps: int,
478
  use_dynamic_shifting: bool,
 
483
  generator = torch.Generator("cuda").manual_seed(seed)
484
  set_scheduler(
485
  pipeline,
486
+ str(scheduler_name),
487
+ num_train_timesteps=int(num_train_timesteps),
488
+ shift=float(shift),
489
+ use_dynamic_shifting=bool(use_dynamic_shifting),
490
+ base_shift=float(base_shift),
491
+ max_shift=float(max_shift),
492
  )
493
 
494
  if lora_loaded:
495
  if use_lora:
496
+ set_lora_scale(pipeline, float(lora_scale))
497
  else:
498
  set_lora_scale(pipeline, 0.0)
499
 
 
502
  prompt=prompt,
503
  height=height,
504
  width=width,
505
+ guidance_scale=float(guidance_scale),
506
+ num_inference_steps=int(steps),
507
  generator=generator,
508
+ max_sequence_length=int(max_sequence_length),
509
  ).images[0]
510
  return image, seed
511
 
 
523
  steps=9,
524
  shift=3.0,
525
  guidance_scale=0.0,
 
526
  use_lora=False,
527
  lora_scale=0.0,
528
+ max_sequence_length=512,
529
  scheduler_name="FlowMatch Euler",
530
  num_train_timesteps=1000,
531
  use_dynamic_shifting=False,
 
544
  if ENABLE_COMPILE and pipe is not None:
545
  ensure_on_gpu()
546
  if ENABLE_AOTI and not aoti_loaded and pipe is not None and getattr(pipe, "transformer", None) is not None:
547
+ if not module_available("kernels"):
548
+ aoti_loaded = False
549
+ aoti_error = "kernels module not available"
550
+ print("AoTI unavailable (kernels module not available).")
551
+ else:
552
+ try:
553
+ pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"]
554
+ spaces.aoti_blocks_load(pipe.transformer.layers, AOTI_REPO, variant=AOTI_VARIANT)
555
+ aoti_loaded = True
556
+ aoti_error = None
557
+ print(f"AoTI loaded: {AOTI_REPO} (variant={AOTI_VARIANT})")
558
+ except Exception as exc: # noqa: BLE001
559
+ aoti_error = str(exc)
560
+ print(f"AoTI load failed (continuing without AoTI): {exc}")
561
  if ENABLE_WARMUP and pipe is not None:
562
  ensure_on_gpu()
563
  try:
 
600
  try:
601
  image = generate_image(
602
  pipeline=pipe,
603
+ prompt=str(prompt),
604
+ resolution=str(resolution),
605
  seed=new_seed,
606
  steps=int(steps) + 1,
607
  shift=float(shift),
608
  guidance_scale=float(cfg),
 
609
  use_lora=use_lora,
610
  lora_scale=float(lora_scale),
611
+ max_sequence_length=int(max_sequence_length),
612
  scheduler_name=str(scheduler_name),
613
  num_train_timesteps=int(num_train_timesteps),
614
  use_dynamic_shifting=bool(use_dynamic_shifting),
 
631
  pipe_status = "loaded (GPU)" if pipe and pipe_on_gpu else "loaded (CPU)" if pipe else "not loaded"
632
  lora_file_status = "found" if os.path.isfile(LORA_PATH) else "missing"
633
  if lora_loaded:
634
+ adapter = lora_adapter_name or "default"
635
+ lora_status = f"LoRA: loaded ({LORA_PATH}, adapter={adapter})"
636
  elif lora_error:
637
  lora_status = f"LoRA: not loaded ({lora_error})"
638
  else:
639
  lora_status = f"LoRA file: {LORA_PATH} ({lora_file_status})"
640
 
641
  attention_status = applied_attention_backend or "unknown"
642
+ if attention_backend_error and ATTENTION_BACKEND and attention_status != ATTENTION_BACKEND:
643
+ attention_status = f"{attention_status} ({ATTENTION_BACKEND} unavailable: {summarize_error(attention_backend_error)})"
644
+
645
+ if aoti_loaded:
646
+ aoti_status = "loaded"
647
+ elif aoti_error:
648
+ label = "unavailable" if "kernels" in aoti_error.lower() else "failed"
649
+ aoti_status = f"{label} ({summarize_error(aoti_error)})"
650
+ else:
651
+ aoti_status = "not loaded"
652
  if not ENABLE_COMPILE:
653
  compile_status = "off"
654
  elif transformer_compiled:
 
688
  seed = gr.Number(label="Seed", value=42, precision=0)
689
  random_seed = gr.Checkbox(label="Random Seed", value=True)
690
 
691
+ with gr.Row():
692
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=8, step=1)
693
+ shift = gr.Slider(label="Time Shift", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
694
+
695
  with gr.Accordion("KSampler / Advanced", open=False):
696
+ cfg = gr.Slider(label="CFG", minimum=0.0, maximum=10.0, value=DEFAULT_CFG, step=0.1)
 
 
697
 
698
  with gr.Row():
699
  scheduler_name = gr.Dropdown(
 
710
  )
711
 
712
  with gr.Row():
 
713
  use_dynamic_shifting = gr.Checkbox(label="use_dynamic_shifting", value=False)
714
+ max_seq = gr.Slider(label="Max Sequence Length", minimum=256, maximum=1024, value=512, step=16)
715
 
716
  with gr.Row():
717
  base_shift = gr.Slider(label="base_shift", minimum=0.0, maximum=10.0, value=0.5, step=0.1)
718
  max_shift = gr.Slider(label="max_shift", minimum=0.0, maximum=10.0, value=3.0, step=0.1)
719
 
 
 
720
  with gr.Row():
721
  lora_controls_enabled = bool(lora_loaded)
722
  use_lora = gr.Checkbox(label="Use LoRA", value=lora_controls_enabled, interactive=lora_controls_enabled)