修复加载bug
Browse files
README.md
CHANGED
|
@@ -52,7 +52,7 @@ Place the LoRA file under `lora/` first (or set `LORA_PATH`); otherwise the app
|
|
| 52 |
- Prompt
|
| 53 |
- Resolution category + explicit WxH selection
|
| 54 |
- Seed (with random toggle)
|
| 55 |
-
- Steps,
|
| 56 |
- LoRA toggle + strength (enabled only if the file is found)
|
| 57 |
|
| 58 |
## Git LFS note
|
|
|
|
| 52 |
- Prompt
|
| 53 |
- Resolution category + explicit WxH selection
|
| 54 |
- Seed (with random toggle)
|
| 55 |
+
- Steps, CFG, scheduler + shift (and extra scheduler params), max sequence length
|
| 56 |
- LoRA toggle + strength (enabled only if the file is found)
|
| 57 |
|
| 58 |
## Git LFS note
|
app.py
CHANGED
|
@@ -3,6 +3,7 @@ import random
|
|
| 3 |
import re
|
| 4 |
import threading
|
| 5 |
import warnings
|
|
|
|
| 6 |
from typing import List, Tuple
|
| 7 |
|
| 8 |
import gradio as gr
|
|
@@ -22,6 +23,7 @@ OFFLOAD_TO_CPU_AFTER_RUN = os.environ.get("OFFLOAD_TO_CPU_AFTER_RUN", "true").lo
|
|
| 22 |
ENABLE_AOTI = os.environ.get("ENABLE_AOTI", "false").lower() == "true"
|
| 23 |
AOTI_REPO = os.environ.get("AOTI_REPO", "zerogpu-aoti/Z-Image")
|
| 24 |
AOTI_VARIANT = os.environ.get("AOTI_VARIANT", "fa3")
|
|
|
|
| 25 |
|
| 26 |
warnings.filterwarnings("ignore")
|
| 27 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
@@ -101,6 +103,14 @@ pipe_lock = threading.Lock()
|
|
| 101 |
pipe_on_gpu: bool = False
|
| 102 |
aoti_loaded: bool = False
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
def parse_resolution(resolution: str) -> Tuple[int, int]:
|
| 106 |
match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
|
|
@@ -109,6 +119,30 @@ def parse_resolution(resolution: str) -> Tuple[int, int]:
|
|
| 109 |
return 1024, 1024
|
| 110 |
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
def attach_lora(pipeline: ZImagePipeline) -> Tuple[bool, str | None]:
|
| 113 |
if not LORA_PATH or not os.path.isfile(LORA_PATH):
|
| 114 |
return False, "LoRA file not found"
|
|
@@ -132,7 +166,7 @@ def set_lora_scale(pipeline: ZImagePipeline, scale: float) -> None:
|
|
| 132 |
|
| 133 |
def load_models() -> Tuple[ZImagePipeline, bool, str | None]:
|
| 134 |
global pipe, lora_loaded, lora_error
|
| 135 |
-
if pipe is not None:
|
| 136 |
return pipe, lora_loaded, lora_error
|
| 137 |
|
| 138 |
use_auth_token = HF_TOKEN if HF_TOKEN else None
|
|
@@ -163,7 +197,7 @@ def load_models() -> Tuple[ZImagePipeline, bool, str | None]:
|
|
| 163 |
|
| 164 |
tokenizer.padding_side = "left"
|
| 165 |
|
| 166 |
-
|
| 167 |
|
| 168 |
if not os.path.exists(MODEL_PATH):
|
| 169 |
transformer = ZImageTransformer2DModel.from_pretrained(
|
|
@@ -178,26 +212,31 @@ def load_models() -> Tuple[ZImagePipeline, bool, str | None]:
|
|
| 178 |
torch_dtype=torch.bfloat16,
|
| 179 |
)
|
| 180 |
|
| 181 |
-
transformer
|
|
|
|
| 182 |
|
| 183 |
-
|
| 184 |
|
| 185 |
-
|
|
|
|
| 186 |
if lora_error:
|
| 187 |
print(lora_error)
|
| 188 |
else:
|
| 189 |
print(f"LoRA loaded: {lora_loaded} ({LORA_PATH})")
|
| 190 |
|
|
|
|
| 191 |
return pipe, lora_loaded, lora_error
|
| 192 |
|
| 193 |
|
| 194 |
def ensure_models_loaded() -> Tuple[ZImagePipeline, bool, str | None]:
|
| 195 |
-
global pipe
|
| 196 |
-
if pipe is not None:
|
| 197 |
return pipe, lora_loaded, lora_error
|
| 198 |
with pipe_lock:
|
| 199 |
-
if pipe is not None:
|
| 200 |
return pipe, lora_loaded, lora_error
|
|
|
|
|
|
|
| 201 |
return load_models()
|
| 202 |
|
| 203 |
|
|
@@ -205,6 +244,8 @@ def ensure_on_gpu() -> None:
|
|
| 205 |
global pipe_on_gpu, aoti_loaded
|
| 206 |
if pipe is None:
|
| 207 |
raise gr.Error("Model not loaded.")
|
|
|
|
|
|
|
| 208 |
if not torch.cuda.is_available():
|
| 209 |
raise gr.Error("CUDA is not available. This Space requires a GPU.")
|
| 210 |
if pipe_on_gpu:
|
|
@@ -241,8 +282,33 @@ def offload_to_cpu() -> None:
|
|
| 241 |
torch.cuda.empty_cache()
|
| 242 |
|
| 243 |
|
| 244 |
-
def
|
| 245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
pipeline.scheduler = scheduler
|
| 247 |
|
| 248 |
|
|
@@ -257,10 +323,23 @@ def generate_image(
|
|
| 257 |
max_sequence_length: int,
|
| 258 |
use_lora: bool,
|
| 259 |
lora_scale: float,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
) -> Tuple[torch.Tensor, int]:
|
| 261 |
width, height = parse_resolution(resolution)
|
| 262 |
generator = torch.Generator("cuda").manual_seed(seed)
|
| 263 |
-
set_scheduler(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
if lora_loaded:
|
| 266 |
if use_lora:
|
|
@@ -327,6 +406,12 @@ def generate(
|
|
| 327 |
seed: int = 42,
|
| 328 |
steps: int = 9,
|
| 329 |
shift: float = 3.0,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
random_seed: bool = True,
|
| 331 |
use_lora: bool = True,
|
| 332 |
lora_scale: float = 1.0,
|
|
@@ -347,10 +432,15 @@ def generate(
|
|
| 347 |
seed=new_seed,
|
| 348 |
steps=int(steps),
|
| 349 |
shift=float(shift),
|
| 350 |
-
guidance_scale=
|
| 351 |
max_sequence_length=int(max_sequence_length),
|
| 352 |
use_lora=use_lora,
|
| 353 |
lora_scale=float(lora_scale),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
)[0]
|
| 355 |
finally:
|
| 356 |
if OFFLOAD_TO_CPU_AFTER_RUN:
|
|
@@ -397,11 +487,33 @@ Model: `{MODEL_PATH}` | {pipe_status}
|
|
| 397 |
seed = gr.Number(label="Seed", value=42, precision=0)
|
| 398 |
random_seed = gr.Checkbox(label="Random Seed", value=True)
|
| 399 |
|
| 400 |
-
with gr.
|
| 401 |
-
|
| 402 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
|
| 404 |
-
with gr.Row():
|
| 405 |
max_seq = gr.Slider(label="Max Sequence Length", minimum=256, maximum=1024, value=512, step=16)
|
| 406 |
|
| 407 |
with gr.Row():
|
|
@@ -443,7 +555,24 @@ Model: `{MODEL_PATH}` | {pipe_status}
|
|
| 443 |
|
| 444 |
generate_btn.click(
|
| 445 |
generate,
|
| 446 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
outputs=[output_gallery, used_seed, seed],
|
| 448 |
api_visibility="public",
|
| 449 |
)
|
|
|
|
| 3 |
import re
|
| 4 |
import threading
|
| 5 |
import warnings
|
| 6 |
+
import inspect
|
| 7 |
from typing import List, Tuple
|
| 8 |
|
| 9 |
import gradio as gr
|
|
|
|
| 23 |
ENABLE_AOTI = os.environ.get("ENABLE_AOTI", "false").lower() == "true"
|
| 24 |
AOTI_REPO = os.environ.get("AOTI_REPO", "zerogpu-aoti/Z-Image")
|
| 25 |
AOTI_VARIANT = os.environ.get("AOTI_VARIANT", "fa3")
|
| 26 |
+
DEFAULT_CFG = float(os.environ.get("DEFAULT_CFG", "0.0"))
|
| 27 |
|
| 28 |
warnings.filterwarnings("ignore")
|
| 29 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
| 103 |
pipe_on_gpu: bool = False
|
| 104 |
aoti_loaded: bool = False
|
| 105 |
|
| 106 |
+
SCHEDULERS = {"FlowMatch Euler": FlowMatchEulerDiscreteScheduler}
|
| 107 |
+
try:
|
| 108 |
+
from diffusers import FlowMatchHeunDiscreteScheduler # type: ignore
|
| 109 |
+
|
| 110 |
+
SCHEDULERS["FlowMatch Heun"] = FlowMatchHeunDiscreteScheduler
|
| 111 |
+
except Exception:
|
| 112 |
+
pass
|
| 113 |
+
|
| 114 |
|
| 115 |
def parse_resolution(resolution: str) -> Tuple[int, int]:
|
| 116 |
match = re.search(r"(\d+)\s*[×x]\s*(\d+)", resolution)
|
|
|
|
| 119 |
return 1024, 1024
|
| 120 |
|
| 121 |
|
| 122 |
+
def set_attention_backend_safe(transformer, backend: str) -> str:
|
| 123 |
+
candidates: List[str] = []
|
| 124 |
+
if backend:
|
| 125 |
+
candidates.append(backend)
|
| 126 |
+
if backend.startswith("_"):
|
| 127 |
+
candidates.append(backend.lstrip("_"))
|
| 128 |
+
else:
|
| 129 |
+
candidates.append(f"_{backend}")
|
| 130 |
+
candidates.extend(["flash", "xformers", "native"])
|
| 131 |
+
|
| 132 |
+
last_exc: Exception | None = None
|
| 133 |
+
for name in candidates:
|
| 134 |
+
if not name:
|
| 135 |
+
continue
|
| 136 |
+
try:
|
| 137 |
+
transformer.set_attention_backend(name)
|
| 138 |
+
return name
|
| 139 |
+
except Exception as exc: # noqa: BLE001
|
| 140 |
+
last_exc = exc
|
| 141 |
+
continue
|
| 142 |
+
|
| 143 |
+
raise RuntimeError(f"Failed to set attention backend (tried {candidates}): {last_exc}")
|
| 144 |
+
|
| 145 |
+
|
| 146 |
def attach_lora(pipeline: ZImagePipeline) -> Tuple[bool, str | None]:
|
| 147 |
if not LORA_PATH or not os.path.isfile(LORA_PATH):
|
| 148 |
return False, "LoRA file not found"
|
|
|
|
| 166 |
|
| 167 |
def load_models() -> Tuple[ZImagePipeline, bool, str | None]:
|
| 168 |
global pipe, lora_loaded, lora_error
|
| 169 |
+
if pipe is not None and getattr(pipe, "transformer", None) is not None:
|
| 170 |
return pipe, lora_loaded, lora_error
|
| 171 |
|
| 172 |
use_auth_token = HF_TOKEN if HF_TOKEN else None
|
|
|
|
| 197 |
|
| 198 |
tokenizer.padding_side = "left"
|
| 199 |
|
| 200 |
+
pipeline = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=None)
|
| 201 |
|
| 202 |
if not os.path.exists(MODEL_PATH):
|
| 203 |
transformer = ZImageTransformer2DModel.from_pretrained(
|
|
|
|
| 212 |
torch_dtype=torch.bfloat16,
|
| 213 |
)
|
| 214 |
|
| 215 |
+
applied_backend = set_attention_backend_safe(transformer, ATTENTION_BACKEND)
|
| 216 |
+
print(f"Attention backend: {applied_backend}")
|
| 217 |
|
| 218 |
+
pipeline.transformer = transformer
|
| 219 |
|
| 220 |
+
loaded, error = attach_lora(pipeline)
|
| 221 |
+
lora_loaded, lora_error = loaded, error
|
| 222 |
if lora_error:
|
| 223 |
print(lora_error)
|
| 224 |
else:
|
| 225 |
print(f"LoRA loaded: {lora_loaded} ({LORA_PATH})")
|
| 226 |
|
| 227 |
+
pipe = pipeline
|
| 228 |
return pipe, lora_loaded, lora_error
|
| 229 |
|
| 230 |
|
| 231 |
def ensure_models_loaded() -> Tuple[ZImagePipeline, bool, str | None]:
|
| 232 |
+
global pipe, pipe_on_gpu
|
| 233 |
+
if pipe is not None and getattr(pipe, "transformer", None) is not None:
|
| 234 |
return pipe, lora_loaded, lora_error
|
| 235 |
with pipe_lock:
|
| 236 |
+
if pipe is not None and getattr(pipe, "transformer", None) is not None:
|
| 237 |
return pipe, lora_loaded, lora_error
|
| 238 |
+
pipe = None
|
| 239 |
+
pipe_on_gpu = False
|
| 240 |
return load_models()
|
| 241 |
|
| 242 |
|
|
|
|
| 244 |
global pipe_on_gpu, aoti_loaded
|
| 245 |
if pipe is None:
|
| 246 |
raise gr.Error("Model not loaded.")
|
| 247 |
+
if getattr(pipe, "transformer", None) is None:
|
| 248 |
+
raise gr.Error("Model init failed (transformer missing). Check startup logs.")
|
| 249 |
if not torch.cuda.is_available():
|
| 250 |
raise gr.Error("CUDA is not available. This Space requires a GPU.")
|
| 251 |
if pipe_on_gpu:
|
|
|
|
| 282 |
torch.cuda.empty_cache()
|
| 283 |
|
| 284 |
|
| 285 |
+
def make_scheduler(scheduler_cls, **kwargs):
|
| 286 |
+
sig = inspect.signature(scheduler_cls.__init__)
|
| 287 |
+
accepted = set(sig.parameters.keys())
|
| 288 |
+
accepted.discard("self")
|
| 289 |
+
filtered = {k: v for k, v in kwargs.items() if k in accepted and v is not None}
|
| 290 |
+
return scheduler_cls(**filtered)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def set_scheduler(
|
| 294 |
+
pipeline: ZImagePipeline,
|
| 295 |
+
scheduler_name: str,
|
| 296 |
+
*,
|
| 297 |
+
num_train_timesteps: int,
|
| 298 |
+
shift: float,
|
| 299 |
+
use_dynamic_shifting: bool,
|
| 300 |
+
base_shift: float,
|
| 301 |
+
max_shift: float,
|
| 302 |
+
) -> None:
|
| 303 |
+
scheduler_cls = SCHEDULERS.get(scheduler_name, FlowMatchEulerDiscreteScheduler)
|
| 304 |
+
scheduler = make_scheduler(
|
| 305 |
+
scheduler_cls,
|
| 306 |
+
num_train_timesteps=int(num_train_timesteps),
|
| 307 |
+
shift=float(shift),
|
| 308 |
+
use_dynamic_shifting=bool(use_dynamic_shifting),
|
| 309 |
+
base_shift=float(base_shift),
|
| 310 |
+
max_shift=float(max_shift),
|
| 311 |
+
)
|
| 312 |
pipeline.scheduler = scheduler
|
| 313 |
|
| 314 |
|
|
|
|
| 323 |
max_sequence_length: int,
|
| 324 |
use_lora: bool,
|
| 325 |
lora_scale: float,
|
| 326 |
+
scheduler_name: str,
|
| 327 |
+
num_train_timesteps: int,
|
| 328 |
+
use_dynamic_shifting: bool,
|
| 329 |
+
base_shift: float,
|
| 330 |
+
max_shift: float,
|
| 331 |
) -> Tuple[torch.Tensor, int]:
|
| 332 |
width, height = parse_resolution(resolution)
|
| 333 |
generator = torch.Generator("cuda").manual_seed(seed)
|
| 334 |
+
set_scheduler(
|
| 335 |
+
pipeline,
|
| 336 |
+
scheduler_name,
|
| 337 |
+
num_train_timesteps=num_train_timesteps,
|
| 338 |
+
shift=shift,
|
| 339 |
+
use_dynamic_shifting=use_dynamic_shifting,
|
| 340 |
+
base_shift=base_shift,
|
| 341 |
+
max_shift=max_shift,
|
| 342 |
+
)
|
| 343 |
|
| 344 |
if lora_loaded:
|
| 345 |
if use_lora:
|
|
|
|
| 406 |
seed: int = 42,
|
| 407 |
steps: int = 9,
|
| 408 |
shift: float = 3.0,
|
| 409 |
+
cfg: float = DEFAULT_CFG,
|
| 410 |
+
scheduler_name: str = "FlowMatch Euler",
|
| 411 |
+
num_train_timesteps: int = 1000,
|
| 412 |
+
use_dynamic_shifting: bool = False,
|
| 413 |
+
base_shift: float = 0.5,
|
| 414 |
+
max_shift: float = 3.0,
|
| 415 |
random_seed: bool = True,
|
| 416 |
use_lora: bool = True,
|
| 417 |
lora_scale: float = 1.0,
|
|
|
|
| 432 |
seed=new_seed,
|
| 433 |
steps=int(steps),
|
| 434 |
shift=float(shift),
|
| 435 |
+
guidance_scale=float(cfg),
|
| 436 |
max_sequence_length=int(max_sequence_length),
|
| 437 |
use_lora=use_lora,
|
| 438 |
lora_scale=float(lora_scale),
|
| 439 |
+
scheduler_name=str(scheduler_name),
|
| 440 |
+
num_train_timesteps=int(num_train_timesteps),
|
| 441 |
+
use_dynamic_shifting=bool(use_dynamic_shifting),
|
| 442 |
+
base_shift=float(base_shift),
|
| 443 |
+
max_shift=float(max_shift),
|
| 444 |
)[0]
|
| 445 |
finally:
|
| 446 |
if OFFLOAD_TO_CPU_AFTER_RUN:
|
|
|
|
| 487 |
seed = gr.Number(label="Seed", value=42, precision=0)
|
| 488 |
random_seed = gr.Checkbox(label="Random Seed", value=True)
|
| 489 |
|
| 490 |
+
with gr.Accordion("KSampler / Advanced", open=False):
|
| 491 |
+
with gr.Row():
|
| 492 |
+
steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=9, step=1)
|
| 493 |
+
cfg = gr.Slider(label="CFG", minimum=0.0, maximum=10.0, value=DEFAULT_CFG, step=0.1)
|
| 494 |
+
|
| 495 |
+
with gr.Row():
|
| 496 |
+
scheduler_name = gr.Dropdown(
|
| 497 |
+
label="Scheduler",
|
| 498 |
+
choices=list(SCHEDULERS.keys()),
|
| 499 |
+
value="FlowMatch Euler",
|
| 500 |
+
)
|
| 501 |
+
num_train_timesteps = gr.Slider(
|
| 502 |
+
label="num_train_timesteps",
|
| 503 |
+
minimum=100,
|
| 504 |
+
maximum=2000,
|
| 505 |
+
value=1000,
|
| 506 |
+
step=10,
|
| 507 |
+
)
|
| 508 |
+
|
| 509 |
+
with gr.Row():
|
| 510 |
+
shift = gr.Slider(label="Shift", minimum=0.0, maximum=10.0, value=3.0, step=0.1)
|
| 511 |
+
use_dynamic_shifting = gr.Checkbox(label="use_dynamic_shifting", value=False)
|
| 512 |
+
|
| 513 |
+
with gr.Row():
|
| 514 |
+
base_shift = gr.Slider(label="base_shift", minimum=0.0, maximum=10.0, value=0.5, step=0.1)
|
| 515 |
+
max_shift = gr.Slider(label="max_shift", minimum=0.0, maximum=10.0, value=3.0, step=0.1)
|
| 516 |
|
|
|
|
| 517 |
max_seq = gr.Slider(label="Max Sequence Length", minimum=256, maximum=1024, value=512, step=16)
|
| 518 |
|
| 519 |
with gr.Row():
|
|
|
|
| 555 |
|
| 556 |
generate_btn.click(
|
| 557 |
generate,
|
| 558 |
+
inputs=[
|
| 559 |
+
prompt_input,
|
| 560 |
+
resolution,
|
| 561 |
+
seed,
|
| 562 |
+
steps,
|
| 563 |
+
shift,
|
| 564 |
+
cfg,
|
| 565 |
+
scheduler_name,
|
| 566 |
+
num_train_timesteps,
|
| 567 |
+
use_dynamic_shifting,
|
| 568 |
+
base_shift,
|
| 569 |
+
max_shift,
|
| 570 |
+
random_seed,
|
| 571 |
+
use_lora,
|
| 572 |
+
lora_strength,
|
| 573 |
+
max_seq,
|
| 574 |
+
output_gallery,
|
| 575 |
+
],
|
| 576 |
outputs=[output_gallery, used_seed, seed],
|
| 577 |
api_visibility="public",
|
| 578 |
)
|