telcom committed on
Commit
56ac11d
·
verified ·
1 Parent(s): a5ff267

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -60
app.py CHANGED
@@ -8,6 +8,7 @@ import gc
8
  import random
9
  import warnings
10
  import logging
 
11
 
12
  # ---- Spaces GPU decorator (must be imported early) ----------
13
  try:
@@ -23,14 +24,22 @@ from PIL import Image
23
  import torch
24
  from huggingface_hub import login
25
 
26
- # ---- Diffusers imports (robust for source installs) ---------
27
- try:
28
- from diffusers import ZImagePipeline, ZImageImg2ImgPipeline
29
- except Exception:
30
- from diffusers.pipelines.z_image.pipeline_z_image import ZImagePipeline
31
- from diffusers.pipelines.z_image.pipeline_z_image_img2img import ZImageImg2ImgPipeline
32
 
33
- from diffusers import FlowMatchEulerDiscreteScheduler
 
 
 
 
 
 
 
 
34
 
35
  # ============================================================
36
  # Config
@@ -72,7 +81,7 @@ if not cuda_available:
72
  fallback_msg = "GPU unavailable. Running in CPU fallback mode (slow)."
73
 
74
  # ============================================================
75
- # Load pipelines (txt2img + img2img share weights)
76
  # ============================================================
77
 
78
  pipe_txt2img = None
@@ -80,60 +89,63 @@ pipe_img2img = None
80
  model_loaded = False
81
  load_error = None
82
 
83
- try:
84
- fp_kwargs = {
85
- "torch_dtype": dtype,
86
- "use_safetensors": True,
87
- }
88
- if HF_TOKEN:
89
- fp_kwargs["token"] = HF_TOKEN
90
-
91
- # Default scheduler (you can change shift per-run)
92
- default_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
93
-
94
- pipe_txt2img = ZImagePipeline.from_pretrained(MODEL_PATH, scheduler=default_scheduler, **fp_kwargs).to(device)
95
-
96
- # Optional attention backend
97
  try:
98
- if hasattr(pipe_txt2img, "transformer") and hasattr(pipe_txt2img.transformer, "set_attention_backend"):
99
- pipe_txt2img.transformer.set_attention_backend(ATTENTION_BACKEND)
100
  except Exception:
101
  pass
102
 
103
- # Optional compile
104
- if ENABLE_COMPILE and device.type == "cuda":
105
- try:
106
- pipe_txt2img.transformer = torch.compile(
107
- pipe_txt2img.transformer,
 
 
108
  mode="max-autotune-no-cudagraphs",
109
  fullgraph=False,
110
  )
111
- except Exception:
112
- pass
113
-
114
- try:
115
- pipe_txt2img.set_progress_bar_config(disable=True)
116
  except Exception:
117
  pass
118
 
119
- # Build img2img pipeline reusing the exact same modules
120
- pipe_img2img = ZImageImg2ImgPipeline(
121
- scheduler=pipe_txt2img.scheduler,
122
- vae=pipe_txt2img.vae,
123
- text_encoder=pipe_txt2img.text_encoder,
124
- tokenizer=pipe_txt2img.tokenizer,
125
- transformer=pipe_txt2img.transformer,
126
- ).to(device)
127
-
128
  try:
129
- pipe_img2img.set_progress_bar_config(disable=True)
130
- except Exception:
131
- pass
 
 
 
132
 
133
- model_loaded = True
 
 
134
 
135
- except Exception as e:
136
- load_error = repr(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  model_loaded = False
138
 
139
  # ============================================================
@@ -153,6 +165,20 @@ def prep_init_image(img: Image.Image, width: int, height: int) -> Image.Image:
153
  img = img.resize((width, height), Image.LANCZOS)
154
  return img
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  # ============================================================
157
  # Inference
158
  # ============================================================
@@ -202,13 +228,16 @@ def _infer_impl(
202
 
203
  init_image = prep_init_image(init_image, width, height)
204
 
205
- # Update scheduler shift per run
206
- scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=float(shift))
207
- pipe_txt2img.scheduler = scheduler
208
- pipe_img2img.scheduler = scheduler
 
 
 
209
 
210
  try:
211
- common_kwargs = dict(
212
  prompt=prompt,
213
  height=height,
214
  width=width,
@@ -217,21 +246,28 @@ def _infer_impl(
217
  generator=generator,
218
  max_sequence_length=msl,
219
  )
 
220
  if neg is not None:
221
- common_kwargs["negative_prompt"] = neg
222
 
223
  with torch.inference_mode():
224
  if device.type == "cuda":
225
  with torch.autocast("cuda", dtype=dtype):
226
  if init_image is not None:
227
- out = pipe_img2img(image=init_image, strength=st, **common_kwargs)
 
 
 
228
  else:
229
- out = pipe_txt2img(**common_kwargs)
230
  else:
231
  if init_image is not None:
232
- out = pipe_img2img(image=init_image, strength=st, **common_kwargs)
 
 
 
233
  else:
234
- out = pipe_txt2img(**common_kwargs)
235
 
236
  img = out.images[0]
237
  return img, status
@@ -253,7 +289,7 @@ else:
253
  return _infer_impl(*args, **kwargs)
254
 
255
  # ============================================================
256
- # UI
257
  # ============================================================
258
 
259
  CSS = """
 
8
  import random
9
  import warnings
10
  import logging
11
+ import inspect
12
 
13
  # ---- Spaces GPU decorator (must be imported early) ----------
14
  try:
 
24
  import torch
25
  from huggingface_hub import login
26
 
27
+ # ============================================================
28
+ # Try importing Z-Image pipelines (requires diffusers>=0.36.0)
29
+ # ============================================================
30
+
31
+ ZIMAGE_AVAILABLE = True
32
+ ZIMAGE_IMPORT_ERROR = None
33
 
34
+ try:
35
+ from diffusers import (
36
+ ZImagePipeline,
37
+ ZImageImg2ImgPipeline,
38
+ FlowMatchEulerDiscreteScheduler,
39
+ )
40
+ except Exception as e:
41
+ ZIMAGE_AVAILABLE = False
42
+ ZIMAGE_IMPORT_ERROR = repr(e)
43
 
44
  # ============================================================
45
  # Config
 
81
  fallback_msg = "GPU unavailable. Running in CPU fallback mode (slow)."
82
 
83
  # ============================================================
84
+ # Load pipelines
85
  # ============================================================
86
 
87
  pipe_txt2img = None
 
89
  model_loaded = False
90
  load_error = None
91
 
92
+ def _set_attention_backend_best_effort(p):
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  try:
94
+ if hasattr(p, "transformer") and hasattr(p.transformer, "set_attention_backend"):
95
+ p.transformer.set_attention_backend(ATTENTION_BACKEND)
96
  except Exception:
97
  pass
98
 
99
+ def _compile_best_effort(p):
100
+ if not (ENABLE_COMPILE and device.type == "cuda"):
101
+ return
102
+ try:
103
+ if hasattr(p, "transformer"):
104
+ p.transformer = torch.compile(
105
+ p.transformer,
106
  mode="max-autotune-no-cudagraphs",
107
  fullgraph=False,
108
  )
 
 
 
 
 
109
  except Exception:
110
  pass
111
 
112
+ if ZIMAGE_AVAILABLE:
 
 
 
 
 
 
 
 
113
  try:
114
+ fp_kwargs = {
115
+ "torch_dtype": dtype,
116
+ "use_safetensors": True,
117
+ }
118
+ if HF_TOKEN:
119
+ fp_kwargs["token"] = HF_TOKEN
120
 
121
+ pipe_txt2img = ZImagePipeline.from_pretrained(MODEL_PATH, **fp_kwargs).to(device)
122
+ _set_attention_backend_best_effort(pipe_txt2img)
123
+ _compile_best_effort(pipe_txt2img)
124
 
125
+ try:
126
+ pipe_txt2img.set_progress_bar_config(disable=True)
127
+ except Exception:
128
+ pass
129
+
130
+ # Share weights/components with img2img pipeline
131
+ pipe_img2img = ZImageImg2ImgPipeline(**pipe_txt2img.components).to(device)
132
+ _set_attention_backend_best_effort(pipe_img2img)
133
+ try:
134
+ pipe_img2img.set_progress_bar_config(disable=True)
135
+ except Exception:
136
+ pass
137
+
138
+ model_loaded = True
139
+
140
+ except Exception as e:
141
+ load_error = repr(e)
142
+ model_loaded = False
143
+ else:
144
+ load_error = (
145
+ "Z-Image pipelines not available in your diffusers install.\n\n"
146
+ f"Import error:\n{ZIMAGE_IMPORT_ERROR}\n\n"
147
+ "Fix: set requirements.txt to diffusers==0.36.0 (or install Diffusers from source)."
148
+ )
149
  model_loaded = False
150
 
151
  # ============================================================
 
165
  img = img.resize((width, height), Image.LANCZOS)
166
  return img
167
 
168
+ def _call_pipeline(pipe, kwargs: dict):
169
+ """
170
+ Robust call: only pass kwargs the pipeline actually accepts.
171
+ This avoids crashes if a particular build does not support negative_prompt, etc.
172
+ """
173
+ try:
174
+ sig = inspect.signature(pipe.__call__)
175
+ allowed = set(sig.parameters.keys())
176
+ filtered = {k: v for k, v in kwargs.items() if k in allowed and v is not None}
177
+ return pipe(**filtered)
178
+ except Exception:
179
+ # Fallback: try raw kwargs (some pipelines use **kwargs internally)
180
+ return pipe(**{k: v for k, v in kwargs.items() if v is not None})
181
+
182
  # ============================================================
183
  # Inference
184
  # ============================================================
 
228
 
229
  init_image = prep_init_image(init_image, width, height)
230
 
231
+ # Update scheduler (shift) per run
232
+ try:
233
+ scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=float(shift))
234
+ pipe_txt2img.scheduler = scheduler
235
+ pipe_img2img.scheduler = scheduler
236
+ except Exception:
237
+ pass
238
 
239
  try:
240
+ base_kwargs = dict(
241
  prompt=prompt,
242
  height=height,
243
  width=width,
 
246
  generator=generator,
247
  max_sequence_length=msl,
248
  )
249
+ # only passed if supported by the pipeline
250
  if neg is not None:
251
+ base_kwargs["negative_prompt"] = neg
252
 
253
  with torch.inference_mode():
254
  if device.type == "cuda":
255
  with torch.autocast("cuda", dtype=dtype):
256
  if init_image is not None:
257
+ out = _call_pipeline(
258
+ pipe_img2img,
259
+ {**base_kwargs, "image": init_image, "strength": st},
260
+ )
261
  else:
262
+ out = _call_pipeline(pipe_txt2img, base_kwargs)
263
  else:
264
  if init_image is not None:
265
+ out = _call_pipeline(
266
+ pipe_img2img,
267
+ {**base_kwargs, "image": init_image, "strength": st},
268
+ )
269
  else:
270
+ out = _call_pipeline(pipe_txt2img, base_kwargs)
271
 
272
  img = out.images[0]
273
  return img, status
 
289
  return _infer_impl(*args, **kwargs)
290
 
291
  # ============================================================
292
+ # UI (simple black style like your SDXL example)
293
  # ============================================================
294
 
295
  CSS = """