Husr committed on
Commit
94ea93d
·
1 Parent(s): a3e095a

修复加载bug

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +146 -17
README.md CHANGED
@@ -52,7 +52,7 @@ Place the LoRA file under `lora/` first (or set `LORA_PATH`); otherwise the app
52
  - Prompt
53
  - Resolution category + explicit WxH selection
54
  - Seed (with random toggle)
55
- - Steps, time shift, max sequence length
56
  - LoRA toggle + strength (enabled only if the file is found)
57
 
58
  ## Git LFS note
 
52
  - Prompt
53
  - Resolution category + explicit WxH selection
54
  - Seed (with random toggle)
55
+ - Steps, CFG, scheduler + shift (and extra scheduler params), max sequence length
56
  - LoRA toggle + strength (enabled only if the file is found)
57
 
58
  ## Git LFS note
app.py CHANGED
@@ -3,6 +3,7 @@ import random
3
  import re
4
  import threading
5
  import warnings
 
6
  from typing import List, Tuple
7
 
8
  import gradio as gr
@@ -22,6 +23,7 @@ OFFLOAD_TO_CPU_AFTER_RUN = os.environ.get("OFFLOAD_TO_CPU_AFTER_RUN", "true").lo
22
  ENABLE_AOTI = os.environ.get("ENABLE_AOTI", "false").lower() == "true"
23
  AOTI_REPO = os.environ.get("AOTI_REPO", "zerogpu-aoti/Z-Image")
24
  AOTI_VARIANT = os.environ.get("AOTI_VARIANT", "fa3")
 
25
 
26
  warnings.filterwarnings("ignore")
27
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -101,6 +103,14 @@ pipe_lock = threading.Lock()
101
  pipe_on_gpu: bool = False
102
  aoti_loaded: bool = False
103
 
 
 
 
 
 
 
 
 
104
 
105
  def parse_resolution(resolution: str) -> Tuple[int, int]:
106
  match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
@@ -109,6 +119,30 @@ def parse_resolution(resolution: str) -> Tuple[int, int]:
109
  return 1024, 1024
110
 
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  def attach_lora(pipeline: ZImagePipeline) -> Tuple[bool, str | None]:
113
  if not LORA_PATH or not os.path.isfile(LORA_PATH):
114
  return False, "LoRA file not found"
@@ -132,7 +166,7 @@ def set_lora_scale(pipeline: ZImagePipeline, scale: float) -> None:
132
 
133
  def load_models() -> Tuple[ZImagePipeline, bool, str | None]:
134
  global pipe, lora_loaded, lora_error
135
- if pipe is not None:
136
  return pipe, lora_loaded, lora_error
137
 
138
  use_auth_token = HF_TOKEN if HF_TOKEN else None
@@ -163,7 +197,7 @@ def load_models() -> Tuple[ZImagePipeline, bool, str | None]:
163
 
164
  tokenizer.padding_side = "left"
165
 
166
- pipe = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None)
167
 
168
  if not os.path.exists(MODEL_PATH):
169
  transformer = ZImageTransformer2DModel.from_pretrained(
@@ -178,26 +212,31 @@ def load_models() -> Tuple[ZImagePipeline, bool, str | None]:
178
  torch_dtype=torch.bfloat16,
179
  )
180
 
181
- transformer.set_attention_backend(ATTENTION_BACKEND)
 
182
 
183
- pipe.transformer = transformer
184
 
185
- lora_loaded, lora_error = attach_lora(pipe)
 
186
  if lora_error:
187
  print(lora_error)
188
  else:
189
  print(f"LoRA loaded: {lora_loaded} ({LORA_PATH})")
190
 
 
191
  return pipe, lora_loaded, lora_error
192
 
193
 
194
  def ensure_models_loaded() -> Tuple[ZImagePipeline, bool, str | None]:
195
- global pipe
196
- if pipe is not None:
197
  return pipe, lora_loaded, lora_error
198
  with pipe_lock:
199
- if pipe is not None:
200
  return pipe, lora_loaded, lora_error
 
 
201
  return load_models()
202
 
203
 
@@ -205,6 +244,8 @@ def ensure_on_gpu() -> None:
205
  global pipe_on_gpu, aoti_loaded
206
  if pipe is None:
207
  raise gr.Error("Model not loaded.")
 
 
208
  if not torch.cuda.is_available():
209
  raise gr.Error("CUDA is not available. This Space requires a GPU.")
210
  if pipe_on_gpu:
@@ -241,8 +282,33 @@ def offload_to_cpu() -> None:
241
  torch.cuda.empty_cache()
242
 
243
 
244
- def set_scheduler(pipeline: ZImagePipeline, shift: float) -> None:
245
- scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  pipeline.scheduler = scheduler
247
 
248
 
@@ -257,10 +323,23 @@ def generate_image(
257
  max_sequence_length: int,
258
  use_lora: bool,
259
  lora_scale: float,
 
 
 
 
 
260
  ) -> Tuple[torch.Tensor, int]:
261
  width, height = parse_resolution(resolution)
262
  generator = torch.Generator("cuda").manual_seed(seed)
263
- set_scheduler(pipeline, shift)
 
 
 
 
 
 
 
 
264
 
265
  if lora_loaded:
266
  if use_lora:
@@ -327,6 +406,12 @@ def generate(
327
  seed: int = 42,
328
  steps: int = 9,
329
  shift: float = 3.0,
 
 
 
 
 
 
330
  random_seed: bool = True,
331
  use_lora: bool = True,
332
  lora_scale: float = 1.0,
@@ -347,10 +432,15 @@ def generate(
347
  seed=new_seed,
348
  steps=int(steps),
349
  shift=float(shift),
350
- guidance_scale=0.0,
351
  max_sequence_length=int(max_sequence_length),
352
  use_lora=use_lora,
353
  lora_scale=float(lora_scale),
 
 
 
 
 
354
  )[0]
355
  finally:
356
  if OFFLOAD_TO_CPU_AFTER_RUN:
@@ -397,11 +487,33 @@ Model: `{MODEL_PATH}` | {pipe_status}
397
  seed = gr.Number(label="Seed", value=42, precision=0)
398
  random_seed = gr.Checkbox(label="Random Seed", value=True)
399
 
400
- with gr.Row():
401
- steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=9, step=1)
402
- shift = gr.Slider(label="Time Shift", minimum=1.0, maximum=10.0, value=3.0, step=0.1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
403
 
404
- with gr.Row():
405
  max_seq = gr.Slider(label="Max Sequence Length", minimum=256, maximum=1024, value=512, step=16)
406
 
407
  with gr.Row():
@@ -443,7 +555,24 @@ Model: `{MODEL_PATH}` | {pipe_status}
443
 
444
  generate_btn.click(
445
  generate,
446
- inputs=[prompt_input, resolution, seed, steps, shift, random_seed, use_lora, lora_strength, max_seq, output_gallery],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
  outputs=[output_gallery, used_seed, seed],
448
  api_visibility="public",
449
  )
 
3
  import re
4
  import threading
5
  import warnings
6
+ import inspect
7
  from typing import List, Tuple
8
 
9
  import gradio as gr
 
23
  ENABLE_AOTI = os.environ.get("ENABLE_AOTI", "false").lower() == "true"
24
  AOTI_REPO = os.environ.get("AOTI_REPO", "zerogpu-aoti/Z-Image")
25
  AOTI_VARIANT = os.environ.get("AOTI_VARIANT", "fa3")
26
+ DEFAULT_CFG = float(os.environ.get("DEFAULT_CFG", "0.0"))
27
 
28
  warnings.filterwarnings("ignore")
29
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
103
  pipe_on_gpu: bool = False
104
  aoti_loaded: bool = False
105
 
106
+ SCHEDULERS = {"FlowMatch Euler": FlowMatchEulerDiscreteScheduler}
107
+ try:
108
+ from diffusers import FlowMatchHeunDiscreteScheduler # type: ignore
109
+
110
+ SCHEDULERS["FlowMatch Heun"] = FlowMatchHeunDiscreteScheduler
111
+ except Exception:
112
+ pass
113
+
114
 
115
  def parse_resolution(resolution: str) -> Tuple[int, int]:
116
  match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
 
119
  return 1024, 1024
120
 
121
 
122
def set_attention_backend_safe(transformer, backend: str) -> str:
    """Enable an attention backend on *transformer*, falling back gracefully.

    Tries the requested *backend* first, then its underscore-prefixed /
    unprefixed twin (diffusers versions differ on the naming convention),
    then a list of common fallbacks. Returns the name that was accepted.

    Raises:
        RuntimeError: if every candidate backend is rejected.
    """
    candidates: List[str] = []
    if backend:
        candidates.append(backend)
        # diffusers has used both "_flash"- and "flash"-style names; try the
        # sibling spelling of whatever the user configured.
        if backend.startswith("_"):
            candidates.append(backend.lstrip("_"))
        else:
            candidates.append(f"_{backend}")
    candidates.extend(["flash", "xformers", "native"])

    # De-duplicate while preserving order so no backend is attempted twice
    # (e.g. backend="flash" would otherwise retry "flash" from the fallbacks).
    candidates = list(dict.fromkeys(name for name in candidates if name))

    last_exc: Exception | None = None
    for name in candidates:
        try:
            transformer.set_attention_backend(name)
            return name
        except Exception as exc:  # noqa: BLE001 — any failure means "try the next one"
            last_exc = exc
            continue

    raise RuntimeError(f"Failed to set attention backend (tried {candidates}): {last_exc}")
144
+
145
+
146
  def attach_lora(pipeline: ZImagePipeline) -> Tuple[bool, str | None]:
147
  if not LORA_PATH or not os.path.isfile(LORA_PATH):
148
  return False, "LoRA file not found"
 
166
 
167
  def load_models() -> Tuple[ZImagePipeline, bool, str | None]:
168
  global pipe, lora_loaded, lora_error
169
+ if pipe is not None and getattr(pipe, "transformer", None) is not None:
170
  return pipe, lora_loaded, lora_error
171
 
172
  use_auth_token = HF_TOKEN if HF_TOKEN else None
 
197
 
198
  tokenizer.padding_side = "left"
199
 
200
+ pipeline = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None)
201
 
202
  if not os.path.exists(MODEL_PATH):
203
  transformer = ZImageTransformer2DModel.from_pretrained(
 
212
  torch_dtype=torch.bfloat16,
213
  )
214
 
215
+ applied_backend = set_attention_backend_safe(transformer, ATTENTION_BACKEND)
216
+ print(f"Attention backend: {applied_backend}")
217
 
218
+ pipeline.transformer = transformer
219
 
220
+ loaded, error = attach_lora(pipeline)
221
+ lora_loaded, lora_error = loaded, error
222
  if lora_error:
223
  print(lora_error)
224
  else:
225
  print(f"LoRA loaded: {lora_loaded} ({LORA_PATH})")
226
 
227
+ pipe = pipeline
228
  return pipe, lora_loaded, lora_error
229
 
230
 
231
def ensure_models_loaded() -> Tuple[ZImagePipeline, bool, str | None]:
    """Return the (pipeline, lora_loaded, lora_error) triple, loading lazily.

    Uses double-checked locking: the fast path returns an already-built
    pipeline without taking the lock; otherwise the pipeline is (re)built
    under ``pipe_lock``. A pipeline whose transformer is missing counts as
    not loaded (a previous init may have failed partway) and is rebuilt.
    """
    global pipe, pipe_on_gpu
    # Fast path — fully-initialized pipeline already present.
    if pipe is not None and getattr(pipe, "transformer", None) is not None:
        return pipe, lora_loaded, lora_error
    with pipe_lock:
        # Re-check under the lock: another thread may have finished loading
        # while we were waiting.
        if pipe is not None and getattr(pipe, "transformer", None) is not None:
            return pipe, lora_loaded, lora_error
        # Discard any half-built state before reloading.
        pipe = None
        pipe_on_gpu = False
        # Load while still holding the lock so concurrent callers cannot
        # both run load_models() and load the model twice.
        return load_models()
241
 
242
 
 
244
  global pipe_on_gpu, aoti_loaded
245
  if pipe is None:
246
  raise gr.Error("Model not loaded.")
247
+ if getattr(pipe, "transformer", None) is None:
248
+ raise gr.Error("Model init failed (transformer missing). Check startup logs.")
249
  if not torch.cuda.is_available():
250
  raise gr.Error("CUDA is not available. This Space requires a GPU.")
251
  if pipe_on_gpu:
 
282
  torch.cuda.empty_cache()
283
 
284
 
285
def make_scheduler(scheduler_cls, **kwargs):
    """Instantiate *scheduler_cls* with only the kwargs its ``__init__`` accepts.

    ``None`` values are always dropped so the scheduler's own defaults apply.
    Constructors that declare ``**kwargs`` accept arbitrary keywords, so for
    them every non-None kwarg is forwarded (signature filtering would
    otherwise drop everything, since the VAR_KEYWORD name is not a kwarg key).
    """
    params = inspect.signature(scheduler_cls.__init__).parameters
    takes_var_kw = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    if takes_var_kw:
        filtered = {k: v for k, v in kwargs.items() if v is not None}
    else:
        accepted = set(params) - {"self"}
        filtered = {k: v for k, v in kwargs.items() if k in accepted and v is not None}
    return scheduler_cls(**filtered)
291
+
292
+
293
def set_scheduler(
    pipeline: ZImagePipeline,
    scheduler_name: str,
    *,
    num_train_timesteps: int,
    shift: float,
    use_dynamic_shifting: bool,
    base_shift: float,
    max_shift: float,
) -> None:
    """Install a freshly-built flow-match scheduler on *pipeline*.

    Unknown *scheduler_name* values fall back to FlowMatch Euler. The
    keyword options are forwarded to the scheduler constructor via
    ``make_scheduler``, which drops any the constructor does not accept.
    """
    cls = SCHEDULERS.get(scheduler_name, FlowMatchEulerDiscreteScheduler)
    options = {
        "num_train_timesteps": int(num_train_timesteps),
        "shift": float(shift),
        "use_dynamic_shifting": bool(use_dynamic_shifting),
        "base_shift": float(base_shift),
        "max_shift": float(max_shift),
    }
    pipeline.scheduler = make_scheduler(cls, **options)
313
 
314
 
 
323
  max_sequence_length: int,
324
  use_lora: bool,
325
  lora_scale: float,
326
+ scheduler_name: str,
327
+ num_train_timesteps: int,
328
+ use_dynamic_shifting: bool,
329
+ base_shift: float,
330
+ max_shift: float,
331
  ) -> Tuple[torch.Tensor, int]:
332
  width, height = parse_resolution(resolution)
333
  generator = torch.Generator("cuda").manual_seed(seed)
334
+ set_scheduler(
335
+ pipeline,
336
+ scheduler_name,
337
+ num_train_timesteps=num_train_timesteps,
338
+ shift=shift,
339
+ use_dynamic_shifting=use_dynamic_shifting,
340
+ base_shift=base_shift,
341
+ max_shift=max_shift,
342
+ )
343
 
344
  if lora_loaded:
345
  if use_lora:
 
406
  seed: int = 42,
407
  steps: int = 9,
408
  shift: float = 3.0,
409
+ cfg: float = DEFAULT_CFG,
410
+ scheduler_name: str = "FlowMatch Euler",
411
+ num_train_timesteps: int = 1000,
412
+ use_dynamic_shifting: bool = False,
413
+ base_shift: float = 0.5,
414
+ max_shift: float = 3.0,
415
  random_seed: bool = True,
416
  use_lora: bool = True,
417
  lora_scale: float = 1.0,
 
432
  seed=new_seed,
433
  steps=int(steps),
434
  shift=float(shift),
435
+ guidance_scale=float(cfg),
436
  max_sequence_length=int(max_sequence_length),
437
  use_lora=use_lora,
438
  lora_scale=float(lora_scale),
439
+ scheduler_name=str(scheduler_name),
440
+ num_train_timesteps=int(num_train_timesteps),
441
+ use_dynamic_shifting=bool(use_dynamic_shifting),
442
+ base_shift=float(base_shift),
443
+ max_shift=float(max_shift),
444
  )[0]
445
  finally:
446
  if OFFLOAD_TO_CPU_AFTER_RUN:
 
487
  seed = gr.Number(label="Seed", value=42, precision=0)
488
  random_seed = gr.Checkbox(label="Random Seed", value=True)
489
 
490
+ with gr.Accordion("KSampler / Advanced", open=False):
491
+ with gr.Row():
492
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=9, step=1)
493
+ cfg = gr.Slider(label="CFG", minimum=0.0, maximum=10.0, value=DEFAULT_CFG, step=0.1)
494
+
495
+ with gr.Row():
496
+ scheduler_name = gr.Dropdown(
497
+ label="Scheduler",
498
+ choices=list(SCHEDULERS.keys()),
499
+ value="FlowMatch Euler",
500
+ )
501
+ num_train_timesteps = gr.Slider(
502
+ label="num_train_timesteps",
503
+ minimum=100,
504
+ maximum=2000,
505
+ value=1000,
506
+ step=10,
507
+ )
508
+
509
+ with gr.Row():
510
+ shift = gr.Slider(label="Shift", minimum=0.0, maximum=10.0, value=3.0, step=0.1)
511
+ use_dynamic_shifting = gr.Checkbox(label="use_dynamic_shifting", value=False)
512
+
513
+ with gr.Row():
514
+ base_shift = gr.Slider(label="base_shift", minimum=0.0, maximum=10.0, value=0.5, step=0.1)
515
+ max_shift = gr.Slider(label="max_shift", minimum=0.0, maximum=10.0, value=3.0, step=0.1)
516
 
 
517
  max_seq = gr.Slider(label="Max Sequence Length", minimum=256, maximum=1024, value=512, step=16)
518
 
519
  with gr.Row():
 
555
 
556
  generate_btn.click(
557
  generate,
558
+ inputs=[
559
+ prompt_input,
560
+ resolution,
561
+ seed,
562
+ steps,
563
+ shift,
564
+ cfg,
565
+ scheduler_name,
566
+ num_train_timesteps,
567
+ use_dynamic_shifting,
568
+ base_shift,
569
+ max_shift,
570
+ random_seed,
571
+ use_lora,
572
+ lora_strength,
573
+ max_seq,
574
+ output_gallery,
575
+ ],
576
  outputs=[output_gallery, used_seed, seed],
577
  api_visibility="public",
578
  )