apolinario commited on
Commit
bcabac3
·
1 Parent(s): 27aa079

Force eager attention on Z-Image text encoder (vmap conflict with ZeroGPU); add Citrus theme + dark-mode color css

Browse files
Files changed (1) hide show
  1. app.py +22 -2
app.py CHANGED
@@ -53,8 +53,26 @@ SR_SCALE = 4
53
  PID_INFERENCE_STEPS = 4
54
 
55
  print("[pid] loading Z-Image pipeline...", flush=True)
56
- pipeline, pipe_cfg = load_pipeline(BACKBONE, dtype=DTYPE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  pipeline.to("cuda")
 
 
58
 
59
  print("[pid] loading PiD decoder...", flush=True)
60
  pid_meta = get_pid_checkpoint(BACKBONE, CKPT_TYPE)
@@ -183,7 +201,9 @@ PiD upsamples 4× during decode, so a 512² Z-Image latent track becomes a
183
  2048² super-resolved image.
184
  """
185
 
186
- with gr.Blocks() as demo:
 
 
187
  gr.Markdown(DESCRIPTION)
188
  with gr.Row():
189
  with gr.Column(scale=1):
 
53
  PID_INFERENCE_STEPS = 4
54
 
55
  print("[pid] loading Z-Image pipeline...", flush=True)
56
+ # transformers 4.57's SDPA causal-mask uses torch.vmap, which clashes with
57
+ # ZeroGPU's __torch_function__ hijack during fake tensor allocation. Force
58
+ # eager attention on the text encoder to skip the vmap codepath.
59
+ from diffusers import ZImagePipeline
60
+ from transformers import Qwen3Model
61
+
62
+ _text_encoder = Qwen3Model.from_pretrained(
63
+ "Tongyi-MAI/Z-Image",
64
+ subfolder="text_encoder",
65
+ torch_dtype=DTYPE,
66
+ attn_implementation="eager",
67
+ )
68
+ pipeline = ZImagePipeline.from_pretrained(
69
+ "Tongyi-MAI/Z-Image",
70
+ torch_dtype=DTYPE,
71
+ text_encoder=_text_encoder,
72
+ )
73
  pipeline.to("cuda")
74
+ from pid._src.inference.pipeline_registry import get_config as _get_pipe_cfg
75
+ pipe_cfg = _get_pipe_cfg(BACKBONE)
76
 
77
  print("[pid] loading PiD decoder...", flush=True)
78
  pid_meta = get_pid_checkpoint(BACKBONE, CKPT_TYPE)
 
201
  2048² super-resolved image.
202
  """
203
 
204
+ CSS = " .dark .gradio-container { color: var(--body-text-color);"
205
+
206
+ with gr.Blocks(theme=gr.themes.Citrus(), css=CSS) as demo:
207
  gr.Markdown(DESCRIPTION)
208
  with gr.Row():
209
  with gr.Column(scale=1):