telcom committed
Commit d59b481 · verified · 1 Parent(s): 72ae055

Update app.py

Files changed (1): app.py +128 -100
app.py CHANGED
@@ -1,3 +1,4 @@
+# app.py
 # ============================================================
 # IMPORTANT: imports order matters for Hugging Face Spaces
 # ============================================================
@@ -22,12 +23,20 @@ from PIL import Image
 import torch
 from huggingface_hub import login
 
-from diffusers import (
-    ZImagePipeline,
-    ZImageImg2ImgPipeline,
-    AutoencoderKL,
-    FlowMatchEulerDiscreteScheduler,
-)
+# ---- Diffusers imports (with safe fallbacks) ----------------
+try:
+    from diffusers import ZImagePipeline
+except Exception:
+    # Older/newer diffusers sometimes do not export ZImagePipeline at top-level
+    from diffusers.pipelines.z_image.pipeline_z_image import ZImagePipeline
+
+try:
+    from diffusers import AutoPipelineForImage2Image
+except Exception:
+    # Rare fallback if top-level export is missing
+    from diffusers.pipelines.auto_pipeline import AutoPipelineForImage2Image
+
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 # ============================================================
@@ -35,7 +44,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 # ============================================================
 
 MODEL_PATH = os.environ.get("MODEL_PATH", "telcom/dee-z-image").strip()
-ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3").strip()  # try: flash_3, flash, sdpa
+
+# Optional knobs
+ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3").strip()  # flash_3, flash, sdpa, native
 ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "false").lower() == "true"
 
 HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
@@ -62,7 +73,6 @@ elif cuda_available:
 else:
     dtype = torch.float32
 
-# A conservative max for most Spaces GPUs. Increase if you know you have headroom.
 MAX_IMAGE_SIZE = 1536 if cuda_available else 768
 
 fallback_msg = ""
@@ -78,199 +88,219 @@ pipe_img2img = None
 model_loaded = False
 load_error = None
 
-def _try_load_with_from_pretrained():
-    """
-    Preferred path: load everything via Diffusers from_pretrained.
-    Works when the repo is structured as a standard Diffusers pipeline repo.
+def _manual_load_zimage(model_path: str):
     """
-    kwargs = {
-        "torch_dtype": dtype,
-        "use_safetensors": True,
-    }
-    if HF_TOKEN:
-        kwargs["token"] = HF_TOKEN
-
-    p_txt = ZImagePipeline.from_pretrained(MODEL_PATH, **kwargs)
-    p_img = ZImageImg2ImgPipeline(**p_txt.components)
-    return p_txt, p_img
-
-def _fallback_manual_load():
-    """
-    Fallback path: load subfolders manually, similar to many Z-Image demos.
-    Works when MODEL_PATH points to a repo with subfolders:
+    Manual loader (matches common Z-Image repo layout with subfolders):
         vae/, transformer/, text_encoder/, tokenizer/
+    Works for both remote HF repos and local paths.
     """
     use_auth_token = HF_TOKEN if HF_TOKEN else True
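+    # NOTE: use_auth_token is deprecated in newer transformers/diffusers releases; token= is the current kwarg.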
 
+    # Load VAE
     vae = AutoencoderKL.from_pretrained(
-        MODEL_PATH,
+        model_path,
         subfolder="vae",
         torch_dtype=dtype,
         use_auth_token=use_auth_token,
     )
+
+    # Load text encoder + tokenizer
     text_encoder = AutoModelForCausalLM.from_pretrained(
-        MODEL_PATH,
+        model_path,
         subfolder="text_encoder",
         torch_dtype=dtype,
         use_auth_token=use_auth_token,
     ).eval()
+
     tokenizer = AutoTokenizer.from_pretrained(
-        MODEL_PATH,
+        model_path,
         subfolder="tokenizer",
         use_auth_token=use_auth_token,
     )
     tokenizer.padding_side = "left"
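+    # Left padding keeps real prompt tokens at the end of the sequence, as decoder-only text encoders expect.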
 
-    # ZImageTransformer2DModel lives inside diffusers; importing lazily avoids import issues on older versions.
+    # Load transformer
     from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel
 
     transformer = ZImageTransformer2DModel.from_pretrained(
-        MODEL_PATH,
+        model_path,
         subfolder="transformer",
         torch_dtype=dtype,
         use_auth_token=use_auth_token,
     )
 
-    p_txt = ZImagePipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer)
-    p_img = ZImageImg2ImgPipeline(scheduler=None, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer)
-    return p_txt, p_img
+    # Build base txt2img pipeline
+    p_txt = ZImagePipeline(
+        scheduler=None,
+        vae=vae,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        transformer=transformer,
+    )
+    return p_txt
 
 try:
-    pipe_txt2img, pipe_img2img = _try_load_with_from_pretrained()
-    model_loaded = True
-except Exception as e1:
+    # Try standard from_pretrained first (if repo is fully diffusers-compatible)
+    kwargs = {"torch_dtype": dtype, "use_safetensors": True}
+    if HF_TOKEN:
+        kwargs["token"] = HF_TOKEN
+
     try:
-        pipe_txt2img, pipe_img2img = _fallback_manual_load()
-        model_loaded = True
-    except Exception as e2:
-        load_error = f"from_pretrained error: {repr(e1)}\nmanual_load error: {repr(e2)}"
-        model_loaded = False
+        pipe_txt2img = ZImagePipeline.from_pretrained(MODEL_PATH, **kwargs)
+    except Exception:
+        pipe_txt2img = _manual_load_zimage(MODEL_PATH)
 
-if model_loaded:
     pipe_txt2img = pipe_txt2img.to(device)
-    pipe_img2img = pipe_img2img.to(device)
 
-    # Try attention backend (best-effort)
+    # Optional attention backend (best-effort)
     try:
         if hasattr(pipe_txt2img, "transformer") and hasattr(pipe_txt2img.transformer, "set_attention_backend"):
             pipe_txt2img.transformer.set_attention_backend(ATTENTION_BACKEND)
-            pipe_img2img.transformer.set_attention_backend(ATTENTION_BACKEND)
     except Exception:
         pass
 
-    # Optional compile (best-effort, can break on some setups)
+    # Optional compile (best-effort)
    if ENABLE_COMPILE and device.type == "cuda":
        try:
-            pipe_txt2img.transformer = torch.compile(pipe_txt2img.transformer, mode="max-autotune-no-cudagraphs", fullgraph=False)
-            pipe_img2img.transformer = pipe_txt2img.transformer
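+            # NOTE: the first generation after torch.compile pays a one-time compilation warm-up cost.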
+            pipe_txt2img.transformer = torch.compile(
+                pipe_txt2img.transformer,
+                mode="max-autotune-no-cudagraphs",
+                fullgraph=False,
+            )
         except Exception:
             pass
 
-    # Disable diffusers progress bars
+    # Disable progress bars
     try:
         pipe_txt2img.set_progress_bar_config(disable=True)
+    except Exception:
+        pass
+
+    # Build img2img pipeline using AutoPipelineForImage2Image
+    # Preferred: from_pipe reuses components without loading a second copy.
+    try:
+        pipe_img2img = AutoPipelineForImage2Image.from_pipe(pipe_txt2img).to(device)
+    except Exception:
+        # Fallback: load an image2image pipeline from the repo (may use more VRAM)
+        pipe_img2img = AutoPipelineForImage2Image.from_pretrained(MODEL_PATH, **kwargs).to(device)
+
+    try:
         pipe_img2img.set_progress_bar_config(disable=True)
     except Exception:
         pass
 
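+    # Mark the model as loaded only after every stage above succeeded.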
+    model_loaded = True
+
+except Exception as e:
+    load_error = repr(e)
+    model_loaded = False
+
 # ============================================================
-# Utility: error image
+# Utility helpers
 # ============================================================
 
-def make_error_image(w, h):
-    return Image.new("RGB", (w, h), (18, 18, 22))
+def make_error_image(w: int, h: int) -> Image.Image:
+    return Image.new("RGB", (int(w), int(h)), (18, 18, 22))
 
-def _prep_init_image(init_image, width, height):
+def _prep_init_image(init_image, width: int, height: int):
     if init_image is None:
         return None
     if not isinstance(init_image, Image.Image):
         return None
-    init_image = init_image.convert("RGB")
-    if init_image.size != (width, height):
-        init_image = init_image.resize((width, height), Image.LANCZOS)
-    return init_image
+    img = init_image.convert("RGB")
+    if img.size != (width, height):
+        img = img.resize((width, height), Image.LANCZOS)
+    return img
 
 # ============================================================
 # Inference
 # ============================================================
 
 def _infer_impl(
-    prompt: str,
-    negative_prompt: str,
-    seed: int,
-    randomize_seed: bool,
-    width: int,
-    height: int,
-    guidance_scale: float,
-    num_inference_steps: int,
-    shift: float,
-    max_sequence_length: int,
+    prompt,
+    negative_prompt,
+    seed,
+    randomize_seed,
+    width,
+    height,
+    guidance_scale,
+    num_inference_steps,
+    shift,
+    max_sequence_length,
     init_image,
-    strength: float,
+    strength,
 ):
     width = int(width)
     height = int(height)
     seed = int(seed)
 
     if not model_loaded:
-        return make_error_image(width, height), f"Model load failed:\n\n{load_error}"
+        return make_error_image(width, height), f"Model load failed: {load_error}"
 
     prompt = (prompt or "").strip()
     if not prompt:
-        return make_error_image(width, height), "Error: Prompt is empty."
+        return make_error_image(width, height), "Error: prompt is empty."
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
 
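+    # Seed a generator on the target device so each run is reproducible from the reported seed.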
+    generator = torch.Generator(device=device).manual_seed(seed)
     init_image = _prep_init_image(init_image, width, height)
 
-    generator = torch.Generator(device=device)
-    generator = generator.manual_seed(seed)
-
     status = f"Seed: {seed}"
     if fallback_msg:
         status += f" | {fallback_msg}"
 
-    # Set scheduler per-run because shift can change
+    # Set scheduler per-run so shift changes take effect
     scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=float(shift))
-    pipe_txt2img.scheduler = scheduler
-    pipe_img2img.scheduler = scheduler
+    try:
+        pipe_txt2img.scheduler = scheduler
+    except Exception:
+        pass
+    try:
+        pipe_img2img.scheduler = scheduler
+    except Exception:
+        pass
 
     try:
+        gs = float(guidance_scale)
+        steps = int(num_inference_steps)
+        msl = int(max_sequence_length)
+        st = float(strength)
+
+        # Some pipelines only accept negative_prompt when guidance is used.
+        neg = (negative_prompt or "").strip()
+        if not neg:
+            neg = None
+
         common_kwargs = dict(
             prompt=prompt,
-            negative_prompt=(negative_prompt or "").strip() if (guidance_scale and float(guidance_scale) > 1.0) else None,
-            guidance_scale=float(guidance_scale),
-            num_inference_steps=int(num_inference_steps),
+            guidance_scale=gs,
+            num_inference_steps=steps,
             generator=generator,
             height=height,
             width=width,
-            max_sequence_length=int(max_sequence_length),
+            max_sequence_length=msl,
         )
 
+        if neg is not None:
+            common_kwargs["negative_prompt"] = neg
+
         with torch.inference_mode():
             if device.type == "cuda":
                 with torch.autocast("cuda", dtype=dtype):
                     if init_image is not None:
-                        out = pipe_img2img(
-                            image=init_image,
-                            strength=float(strength),
-                            **common_kwargs,
-                        )
+                        out = pipe_img2img(image=init_image, strength=st, **common_kwargs)
                     else:
                         out = pipe_txt2img(**common_kwargs)
             else:
                 if init_image is not None:
-                    out = pipe_img2img(
-                        image=init_image,
-                        strength=float(strength),
-                        **common_kwargs,
-                    )
+                    out = pipe_img2img(image=init_image, strength=st, **common_kwargs)
                 else:
                     out = pipe_txt2img(**common_kwargs)
 
-        image = out.images[0]
-        return image, status
+        img = out.images[0]
+        return img, status
 
     except Exception as e:
         return make_error_image(width, height), f"Error: {type(e).__name__}: {e}"
@@ -280,7 +310,7 @@ def _infer_impl(
         if device.type == "cuda":
             torch.cuda.empty_cache()
 
-# IMPORTANT: decorator must be explicit
+# Decorated entrypoint for Spaces
 if SPACES_AVAILABLE:
     @spaces.GPU
     def infer(*args, **kwargs):
@@ -290,7 +320,7 @@ else:
         return _infer_impl(*args, **kwargs)
 
 # ============================================================
-# UI
+# UI (your first style)
 # ============================================================
 
 CSS = """
@@ -319,7 +349,7 @@ with gr.Blocks(title="Z-Image txt2img + img2img") as demo:
     status = gr.Markdown("")
 
     with gr.Accordion("Advanced Settings", open=False):
-        negative_prompt = gr.Textbox(label="Negative prompt (only used if Guidance > 1)")
+        negative_prompt = gr.Textbox(label="Negative prompt (optional)")
         seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
         randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
 
@@ -327,9 +357,8 @@ with gr.Blocks(title="Z-Image txt2img + img2img") as demo:
         height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=1024, label="Height")
 
         guidance_scale = gr.Slider(0.0, 10.0, step=0.1, value=0.0, label="Guidance scale")
-        num_inference_steps = gr.Slider(1, 50, step=1, value=8, label="Steps")
+        num_inference_steps = gr.Slider(1, 100, step=1, value=8, label="Steps")
         shift = gr.Slider(1.0, 10.0, step=0.1, value=3.0, label="Time shift")
-
         max_sequence_length = gr.Slider(64, 512, step=64, value=512, label="Max sequence length")
 
         strength = gr.Slider(0.0, 1.0, step=0.05, value=0.6, label="Image strength (img2img)")
@@ -354,5 +383,4 @@ with gr.Blocks(title="Z-Image txt2img + img2img") as demo:
     )
 
 if __name__ == "__main__":
-    # Keep the same launch feel as your first script
     demo.queue().launch(ssr_mode=False)
 