telcom committed on
Commit
ef8496d
·
verified ·
1 Parent(s): 59bda41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +224 -279
app.py CHANGED
@@ -1,8 +1,6 @@
1
  # app.py
2
  # ============================================================
3
- # Automatic clothing replacement (no paint, no manual mask)
4
- # GroundingDINO -> SAM -> SDXL Inpaint
5
- # Fixes: Spaces requires @spaces.GPU function at startup
6
  # ============================================================
7
 
8
  import os
@@ -10,401 +8,348 @@ import gc
10
  import random
11
  import warnings
12
  import logging
 
 
 
 
 
 
 
 
13
 
14
- import numpy as np
15
  import gradio as gr
 
16
  from PIL import Image
17
 
18
  import torch
19
- from huggingface_hub import login, hf_hub_download
20
-
21
- from diffusers import StableDiffusionXLInpaintPipeline
22
- from groundingdino.util.inference import load_model, predict
23
- from segment_anything import sam_model_registry, SamPredictor
24
-
25
 
26
  # ============================================================
27
- # Spaces import (do not hide the decorated function)
28
  # ============================================================
29
- try:
30
- import spaces
31
- SPACES_AVAILABLE = True
32
- except Exception:
33
- spaces = None
34
- SPACES_AVAILABLE = False
35
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  # ============================================================
38
  # Config
39
  # ============================================================
40
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
41
- warnings.filterwarnings("ignore")
42
- logging.getLogger("transformers").setLevel(logging.ERROR)
 
 
43
 
44
  HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
45
  if HF_TOKEN:
46
  login(token=HF_TOKEN)
47
 
48
- MAX_SEED = np.iinfo(np.int32).max
 
 
49
 
50
- CUDA_OK = torch.cuda.is_available()
51
- DEVICE = "cuda" if CUDA_OK else "cpu"
52
 
53
- if CUDA_OK and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
54
- DTYPE = torch.bfloat16
55
- elif CUDA_OK:
56
- DTYPE = torch.float16
57
- else:
58
- DTYPE = torch.float32
59
 
60
- MAX_IMAGE_SIZE = 1024 if CUDA_OK else 768
 
61
 
62
- DEFAULT_CLOTHING_QUERY = "shirt, t-shirt, jacket, coat, hoodie, sweater, dress, pants, jeans, skirt, clothing"
63
- DEFAULT_BOX_THRESHOLD = 0.35
64
- DEFAULT_TEXT_THRESHOLD = 0.25
 
 
 
65
 
66
- INPAINT_MODEL = os.environ.get(
67
- "INPAINT_MODEL",
68
- "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
69
- ).strip()
70
 
 
 
 
71
 
72
  # ============================================================
73
- # Load models (download from HF Hub)
74
  # ============================================================
 
 
 
75
  model_loaded = False
76
  load_error = None
77
 
78
- dino = None
79
- sam_predictor = None
80
- pipe = None
81
-
82
- def download_and_load_models():
83
- global dino, sam_predictor, pipe
84
-
85
- # ---- GroundingDINO ----
86
- DINO_REPO = "IDEA-Research/GroundingDINO"
87
- dino_cfg_path = hf_hub_download(
88
- repo_id=DINO_REPO,
89
- filename="groundingdino/config/GroundingDINO_SwinT_OGC.py",
90
- token=HF_TOKEN if HF_TOKEN else None,
91
- )
92
- dino_ckpt_path = hf_hub_download(
93
- repo_id=DINO_REPO,
94
- filename="groundingdino_swint_ogc.pth",
95
- token=HF_TOKEN if HF_TOKEN else None,
96
- )
97
- dino = load_model(dino_cfg_path, dino_ckpt_path)
98
-
99
- # ---- SAM ----
100
- SAM_REPO = "facebook/sam-vit-huge"
101
- sam_ckpt_path = hf_hub_download(
102
- repo_id=SAM_REPO,
103
- filename="sam_vit_h_4b8939.pth",
104
- token=HF_TOKEN if HF_TOKEN else None,
105
- )
106
- sam = sam_model_registry["vit_h"](checkpoint=sam_ckpt_path)
107
- sam.to(DEVICE)
108
- sam_predictor = SamPredictor(sam)
109
-
110
- # ---- SDXL Inpaint ----
111
- fp_kwargs = {"torch_dtype": DTYPE, "use_safetensors": True}
112
- if HF_TOKEN:
113
- fp_kwargs["token"] = HF_TOKEN
114
 
115
- pipe = StableDiffusionXLInpaintPipeline.from_pretrained(INPAINT_MODEL, **fp_kwargs).to(DEVICE)
 
 
116
  try:
117
- pipe.set_progress_bar_config(disable=True)
 
 
 
 
 
118
  except Exception:
119
  pass
120
 
121
- try:
122
- download_and_load_models()
123
- model_loaded = True
124
- except Exception as e:
125
- model_loaded = False
126
- load_error = repr(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  # ============================================================
130
  # Helpers
131
  # ============================================================
 
132
  def make_error_image(w: int, h: int) -> Image.Image:
133
  return Image.new("RGB", (int(w), int(h)), (18, 18, 22))
134
 
135
- def fit64(w: int, h: int):
136
- w = max(256, (w // 64) * 64)
137
- h = max(256, (h // 64) * 64)
138
- return w, h
139
-
140
- def resize_rgb(img: Image.Image, w: int, h: int) -> Image.Image:
141
- return img.convert("RGB").resize((w, h), Image.LANCZOS)
142
-
143
- def resize_mask(mask: Image.Image, w: int, h: int) -> Image.Image:
144
- return mask.convert("L").resize((w, h), Image.NEAREST)
145
-
146
- def dilate_mask(mask_np: np.ndarray, radius: int) -> np.ndarray:
147
- if radius <= 0:
148
- return mask_np
149
- import cv2
150
- kernel = np.ones((radius * 2 + 1, radius * 2 + 1), np.uint8)
151
- return cv2.dilate(mask_np, kernel, iterations=1)
152
-
153
- def largest_component(mask_np: np.ndarray) -> np.ndarray:
154
- import cv2
155
- num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask_np, connectivity=8)
156
- if num_labels <= 1:
157
- return mask_np
158
- largest = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
159
- out = np.zeros_like(mask_np)
160
- out[labels == largest] = 255
161
- return out
162
-
163
-
164
- def detect_clothing_mask(
165
- image: Image.Image,
166
- clothing_query: str,
167
- box_threshold: float,
168
- text_threshold: float,
169
- dilate_radius: int,
170
- keep_largest: bool,
171
- ):
172
- if image is None:
173
  return None
174
-
175
- img_rgb = image.convert("RGB")
176
- w, h = img_rgb.size
177
- img_np = np.array(img_rgb)
178
-
179
- boxes, _, _ = predict(
180
- model=dino,
181
- image=img_np,
182
- caption=clothing_query,
183
- box_threshold=float(box_threshold),
184
- text_threshold=float(text_threshold),
185
- )
186
-
187
- if boxes is None or len(boxes) == 0:
188
  return None
189
-
190
- boxes_px = []
191
- for b in boxes:
192
- cx, cy, bw, bh = map(float, b)
193
- x1 = int((cx - bw / 2.0) * w)
194
- y1 = int((cy - bh / 2.0) * h)
195
- x2 = int((cx + bw / 2.0) * w)
196
- y2 = int((cy + bh / 2.0) * h)
197
- x1 = max(0, min(w - 1, x1))
198
- y1 = max(0, min(h - 1, y1))
199
- x2 = max(0, min(w - 1, x2))
200
- y2 = max(0, min(h - 1, y2))
201
- if x2 > x1 and y2 > y1:
202
- boxes_px.append([x1, y1, x2, y2])
203
-
204
- if not boxes_px:
205
- return None
206
-
207
- sam_predictor.set_image(img_np)
208
-
209
- full_mask = np.zeros((h, w), dtype=np.uint8)
210
- for box in boxes_px:
211
- box_arr = np.array(box, dtype=np.float32)
212
- masks, _, _ = sam_predictor.predict(box=box_arr, multimask_output=False)
213
- m = masks[0].astype(np.uint8) * 255
214
- full_mask = np.maximum(full_mask, m)
215
-
216
- if keep_largest:
217
- full_mask = largest_component(full_mask)
218
-
219
- full_mask = dilate_mask(full_mask, int(dilate_radius))
220
-
221
- return Image.fromarray(full_mask, mode="L")
222
-
223
 
224
  # ============================================================
225
- # Core inference (no decorator here)
226
  # ============================================================
227
- def infer_core(
228
- image,
229
  prompt,
230
  negative_prompt,
231
- clothing_query,
232
  seed,
233
  randomize_seed,
234
  width,
235
  height,
236
  guidance_scale,
237
  num_inference_steps,
238
- box_threshold,
239
- text_threshold,
240
- dilate_radius,
241
- keep_largest,
242
  ):
243
  width = int(width)
244
  height = int(height)
 
245
 
246
  if not model_loaded:
247
  return make_error_image(width, height), f"Model load failed: {load_error}"
248
 
249
- if image is None:
250
- return make_error_image(width, height), "Error: upload an image."
251
-
252
  prompt = (prompt or "").strip()
253
  if not prompt:
254
  return make_error_image(width, height), "Error: prompt is empty."
255
 
256
- neg = (negative_prompt or "").strip()
257
- if not neg:
258
- neg = None
259
-
260
- clothing_query = (clothing_query or "").strip() or DEFAULT_CLOTHING_QUERY
261
-
262
  if randomize_seed:
263
  seed = random.randint(0, MAX_SEED)
264
- else:
265
- seed = int(seed)
266
-
267
- generator = torch.Generator(device=DEVICE).manual_seed(seed)
268
-
269
- width, height = fit64(width, height)
270
- width = min(width, MAX_IMAGE_SIZE)
271
- height = min(height, MAX_IMAGE_SIZE)
272
-
273
- mask = detect_clothing_mask(
274
- image=image,
275
- clothing_query=clothing_query,
276
- box_threshold=float(box_threshold),
277
- text_threshold=float(text_threshold),
278
- dilate_radius=int(dilate_radius),
279
- keep_largest=bool(keep_largest),
280
- )
281
-
282
- if mask is None:
283
- return image, f"Seed: {seed}. No clothing detected, try lowering thresholds or changing query."
284
 
285
- img_resized = resize_rgb(image, width, height)
286
- mask_resized = resize_mask(mask, width, height)
287
 
288
  status = f"Seed: {seed}"
289
- if not CUDA_OK:
290
- status += " | CPU only (slow)."
291
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  with torch.inference_mode():
294
- if CUDA_OK:
295
- with torch.autocast("cuda", dtype=DTYPE):
296
- out = pipe(
297
- prompt=prompt,
298
- negative_prompt=neg,
299
- image=img_resized,
300
- mask_image=mask_resized,
301
- guidance_scale=float(guidance_scale),
302
- num_inference_steps=int(num_inference_steps),
303
- generator=generator,
304
- )
305
  else:
306
- out = pipe(
307
- prompt=prompt,
308
- negative_prompt=neg,
309
- image=img_resized,
310
- mask_image=mask_resized,
311
- guidance_scale=float(guidance_scale),
312
- num_inference_steps=int(num_inference_steps),
313
- generator=generator,
314
- )
315
-
316
- return out.images[0], status
317
 
318
  except Exception as e:
319
  return make_error_image(width, height), f"Error: {type(e).__name__}: {e}"
320
 
321
  finally:
322
  gc.collect()
323
- if CUDA_OK:
324
  torch.cuda.empty_cache()
325
 
326
-
327
- # ============================================================
328
- # IMPORTANT: Always define a @spaces.GPU function if spaces imports
329
- # (Spaces startup checker requires it)
330
- # ============================================================
331
  if SPACES_AVAILABLE:
332
  @spaces.GPU
333
  def infer(*args, **kwargs):
334
- return infer_core(*args, **kwargs)
335
  else:
336
  def infer(*args, **kwargs):
337
- return infer_core(*args, **kwargs)
338
-
339
 
340
  # ============================================================
341
- # UI
342
  # ============================================================
343
- CSS = "body { background: #000; color: #fff; }"
344
 
345
- with gr.Blocks(title="Auto Clothing Replacement") as demo:
 
 
 
 
 
 
 
346
  gr.HTML(f"<style>{CSS}</style>")
347
- gr.Markdown("## Automatic Clothing Replacement (no paint, no manual mask)")
348
- gr.Markdown("Upload a photo, describe the new clothing. Detection and masking is automatic.")
 
349
 
350
  if not model_loaded:
351
  gr.Markdown(f"⚠️ Model failed to load:\n\n{load_error}")
352
 
353
- image = gr.Image(type="pil", label="Input image")
354
 
355
- prompt = gr.Textbox(
356
- label="Prompt (describe new clothing)",
357
- lines=2,
358
- placeholder="e.g., a navy business suit jacket, realistic fabric folds, studio lighting",
359
- )
360
- negative_prompt = gr.Textbox(
361
- label="Negative prompt (optional)",
362
- lines=2,
363
- placeholder="e.g., blurry, deformed, low quality",
364
- )
365
 
366
- run = gr.Button("Replace Clothing")
367
- out_img = gr.Image(label="Result")
368
  status = gr.Markdown("")
369
 
370
- with gr.Accordion("Advanced settings", open=False):
371
- clothing_query = gr.Textbox(label="Detection query", value=DEFAULT_CLOTHING_QUERY)
372
-
373
  seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
374
  randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
375
 
376
- width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=768 if not CUDA_OK else 1024, label="Width")
377
- height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=768 if not CUDA_OK else 1024, label="Height")
378
-
379
- guidance_scale = gr.Slider(0.0, 15.0, step=0.1, value=7.0, label="Guidance scale")
380
- num_inference_steps = gr.Slider(1, 80, step=1, value=30, label="Steps")
381
 
382
- box_threshold = gr.Slider(0.05, 0.90, step=0.01, value=DEFAULT_BOX_THRESHOLD, label="Box threshold (DINO)")
383
- text_threshold = gr.Slider(0.05, 0.90, step=0.01, value=DEFAULT_TEXT_THRESHOLD, label="Text threshold (DINO)")
 
 
384
 
385
- dilate_radius = gr.Slider(0, 30, step=1, value=8, label="Mask dilation radius")
386
- keep_largest = gr.Checkbox(value=True, label="Keep only largest region")
387
 
388
- run.click(
389
  fn=infer,
390
  inputs=[
391
- image,
392
  prompt,
393
  negative_prompt,
394
- clothing_query,
395
  seed,
396
  randomize_seed,
397
  width,
398
  height,
399
  guidance_scale,
400
  num_inference_steps,
401
- box_threshold,
402
- text_threshold,
403
- dilate_radius,
404
- keep_largest,
405
  ],
406
- outputs=[out_img, status],
407
  )
408
 
409
  if __name__ == "__main__":
410
- demo.queue().launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
 
1
  # app.py
2
  # ============================================================
3
+ # IMPORTANT: imports order matters for Hugging Face Spaces
 
 
4
  # ============================================================
5
 
6
  import os
 
8
  import random
9
  import warnings
10
  import logging
11
+ import inspect
12
+
13
+ # ---- Spaces GPU decorator (must be imported early) ----------
14
+ try:
15
+ import spaces # noqa: F401
16
+ SPACES_AVAILABLE = True
17
+ except Exception:
18
+ SPACES_AVAILABLE = False
19
 
 
20
  import gradio as gr
21
+ import numpy as np
22
  from PIL import Image
23
 
24
  import torch
25
+ from huggingface_hub import login
 
 
 
 
 
26
 
27
  # ============================================================
28
+ # Try importing Z-Image pipelines (requires diffusers>=0.36.0)
29
  # ============================================================
 
 
 
 
 
 
30
 
31
+ ZIMAGE_AVAILABLE = True
32
+ ZIMAGE_IMPORT_ERROR = None
33
+
34
+ try:
35
+ from diffusers import (
36
+ ZImagePipeline,
37
+ ZImageImg2ImgPipeline,
38
+ FlowMatchEulerDiscreteScheduler,
39
+ )
40
+ except Exception as e:
41
+ ZIMAGE_AVAILABLE = False
42
+ ZIMAGE_IMPORT_ERROR = repr(e)
43
 
44
  # ============================================================
45
  # Config
46
  # ============================================================
47
+
48
+ MODEL_PATH = os.environ.get("MODEL_PATH", "telcom/dee-z-image").strip()
49
+
50
+ ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3").strip()
51
+ ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "false").lower() == "true"
52
 
53
  HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
54
  if HF_TOKEN:
55
  login(token=HF_TOKEN)
56
 
57
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
58
+ warnings.filterwarnings("ignore")
59
+ logging.getLogger("transformers").setLevel(logging.ERROR)
60
 
61
+ MAX_SEED = np.iinfo(np.int32).max
 
62
 
63
+ # ============================================================
64
+ # Device & dtype
65
+ # ============================================================
 
 
 
66
 
67
+ cuda_available = torch.cuda.is_available()
68
+ device = torch.device("cuda" if cuda_available else "cpu")
69
 
70
+ if cuda_available and hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported():
71
+ dtype = torch.bfloat16
72
+ elif cuda_available:
73
+ dtype = torch.float16
74
+ else:
75
+ dtype = torch.float32
76
 
77
+ MAX_IMAGE_SIZE = 1536 if cuda_available else 768
 
 
 
78
 
79
+ fallback_msg = ""
80
+ if not cuda_available:
81
+ fallback_msg = "GPU unavailable. Running in CPU fallback mode (slow)."
82
 
83
  # ============================================================
84
+ # Load pipelines
85
  # ============================================================
86
+
87
+ pipe_txt2img = None
88
+ pipe_img2img = None
89
  model_loaded = False
90
  load_error = None
91
 
92
+ def _set_attention_backend_best_effort(p):
93
+ try:
94
+ if hasattr(p, "transformer") and hasattr(p.transformer, "set_attention_backend"):
95
+ p.transformer.set_attention_backend(ATTENTION_BACKEND)
96
+ except Exception:
97
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
+ def _compile_best_effort(p):
100
+ if not (ENABLE_COMPILE and device.type == "cuda"):
101
+ return
102
  try:
103
+ if hasattr(p, "transformer"):
104
+ p.transformer = torch.compile(
105
+ p.transformer,
106
+ mode="max-autotune-no-cudagraphs",
107
+ fullgraph=False,
108
+ )
109
  except Exception:
110
  pass
111
 
112
+ if ZIMAGE_AVAILABLE:
113
+ try:
114
+ fp_kwargs = {
115
+ "torch_dtype": dtype,
116
+ "use_safetensors": True,
117
+ }
118
+ if HF_TOKEN:
119
+ fp_kwargs["token"] = HF_TOKEN
120
+
121
+ pipe_txt2img = ZImagePipeline.from_pretrained(MODEL_PATH, **fp_kwargs).to(device)
122
+ _set_attention_backend_best_effort(pipe_txt2img)
123
+ _compile_best_effort(pipe_txt2img)
124
+
125
+ try:
126
+ pipe_txt2img.set_progress_bar_config(disable=True)
127
+ except Exception:
128
+ pass
129
+
130
+ # Share weights/components with img2img pipeline
131
+ pipe_img2img = ZImageImg2ImgPipeline(**pipe_txt2img.components).to(device)
132
+ _set_attention_backend_best_effort(pipe_img2img)
133
+ try:
134
+ pipe_img2img.set_progress_bar_config(disable=True)
135
+ except Exception:
136
+ pass
137
+
138
+ model_loaded = True
139
 
140
+ except Exception as e:
141
+ load_error = repr(e)
142
+ model_loaded = False
143
+ else:
144
+ load_error = (
145
+ "Z-Image pipelines not available in your diffusers install.\n\n"
146
+ f"Import error:\n{ZIMAGE_IMPORT_ERROR}\n\n"
147
+ "Fix: set requirements.txt to diffusers==0.36.0 (or install Diffusers from source)."
148
+ )
149
+ model_loaded = False
150
 
151
  # ============================================================
152
  # Helpers
153
  # ============================================================
154
+
155
  def make_error_image(w: int, h: int) -> Image.Image:
156
  return Image.new("RGB", (int(w), int(h)), (18, 18, 22))
157
 
158
+ def prep_init_image(img: Image.Image, width: int, height: int) -> Image.Image:
159
+ if img is None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  return None
161
+ if not isinstance(img, Image.Image):
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  return None
163
+ img = img.convert("RGB")
164
+ if img.size != (width, height):
165
+ img = img.resize((width, height), Image.LANCZOS)
166
+ return img
167
+
168
+ def _call_pipeline(pipe, kwargs: dict):
169
+ """
170
+ Robust call: only pass kwargs the pipeline actually accepts.
171
+ This avoids crashes if a particular build does not support negative_prompt, etc.
172
+ """
173
+ try:
174
+ sig = inspect.signature(pipe.__call__)
175
+ allowed = set(sig.parameters.keys())
176
+ filtered = {k: v for k, v in kwargs.items() if k in allowed and v is not None}
177
+ return pipe(**filtered)
178
+ except Exception:
179
+ # Fallback: try raw kwargs (some pipelines use **kwargs internally)
180
+ return pipe(**{k: v for k, v in kwargs.items() if v is not None})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  # ============================================================
183
+ # Inference
184
  # ============================================================
185
+
186
+ def _infer_impl(
187
  prompt,
188
  negative_prompt,
 
189
  seed,
190
  randomize_seed,
191
  width,
192
  height,
193
  guidance_scale,
194
  num_inference_steps,
195
+ shift,
196
+ max_sequence_length,
197
+ init_image,
198
+ strength,
199
  ):
200
  width = int(width)
201
  height = int(height)
202
+ seed = int(seed)
203
 
204
  if not model_loaded:
205
  return make_error_image(width, height), f"Model load failed: {load_error}"
206
 
 
 
 
207
  prompt = (prompt or "").strip()
208
  if not prompt:
209
  return make_error_image(width, height), "Error: prompt is empty."
210
 
 
 
 
 
 
 
211
  if randomize_seed:
212
  seed = random.randint(0, MAX_SEED)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
 
214
+ generator = torch.Generator(device=device).manual_seed(seed)
 
215
 
216
  status = f"Seed: {seed}"
217
+ if fallback_msg:
218
+ status += f" | {fallback_msg}"
219
 
220
+ gs = float(guidance_scale)
221
+ steps = int(num_inference_steps)
222
+ msl = int(max_sequence_length)
223
+ st = float(strength)
224
+
225
+ neg = (negative_prompt or "").strip()
226
+ if not neg:
227
+ neg = None
228
+
229
+ init_image = prep_init_image(init_image, width, height)
230
+
231
+ # Update scheduler (shift) per run
232
  try:
233
+ scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=float(shift))
234
+ pipe_txt2img.scheduler = scheduler
235
+ pipe_img2img.scheduler = scheduler
236
+ except Exception:
237
+ pass
238
+
239
+ try:
240
+ base_kwargs = dict(
241
+ prompt=prompt,
242
+ height=height,
243
+ width=width,
244
+ guidance_scale=gs,
245
+ num_inference_steps=steps,
246
+ generator=generator,
247
+ max_sequence_length=msl,
248
+ )
249
+ # only passed if supported by the pipeline
250
+ if neg is not None:
251
+ base_kwargs["negative_prompt"] = neg
252
+
253
  with torch.inference_mode():
254
+ if device.type == "cuda":
255
+ with torch.autocast("cuda", dtype=dtype):
256
+ if init_image is not None:
257
+ out = _call_pipeline(
258
+ pipe_img2img,
259
+ {**base_kwargs, "image": init_image, "strength": st},
260
+ )
261
+ else:
262
+ out = _call_pipeline(pipe_txt2img, base_kwargs)
 
 
263
  else:
264
+ if init_image is not None:
265
+ out = _call_pipeline(
266
+ pipe_img2img,
267
+ {**base_kwargs, "image": init_image, "strength": st},
268
+ )
269
+ else:
270
+ out = _call_pipeline(pipe_txt2img, base_kwargs)
271
+
272
+ img = out.images[0]
273
+ return img, status
 
274
 
275
  except Exception as e:
276
  return make_error_image(width, height), f"Error: {type(e).__name__}: {e}"
277
 
278
  finally:
279
  gc.collect()
280
+ if device.type == "cuda":
281
  torch.cuda.empty_cache()
282
 
 
 
 
 
 
283
  if SPACES_AVAILABLE:
284
  @spaces.GPU
285
  def infer(*args, **kwargs):
286
+ return _infer_impl(*args, **kwargs)
287
  else:
288
  def infer(*args, **kwargs):
289
+ return _infer_impl(*args, **kwargs)
 
290
 
291
  # ============================================================
292
+ # UI (simple black style like your SDXL example)
293
  # ============================================================
 
294
 
295
+ CSS = """
296
+ body {
297
+ background: #000;
298
+ color: #fff;
299
+ }
300
+ """
301
+
302
+ with gr.Blocks(title="Z-Image txt2img + img2img") as demo:
303
  gr.HTML(f"<style>{CSS}</style>")
304
+
305
+ if fallback_msg:
306
+ gr.Markdown(f"**{fallback_msg}**")
307
 
308
  if not model_loaded:
309
  gr.Markdown(f"⚠️ Model failed to load:\n\n{load_error}")
310
 
311
+ gr.Markdown("## Z-Image Generator (txt2img + img2img)")
312
 
313
+ prompt = gr.Textbox(label="Prompt", lines=2)
314
+ init_image = gr.Image(label="Initial image (optional)", type="pil")
 
 
 
 
 
 
 
 
315
 
316
+ run_button = gr.Button("Generate")
317
+ result = gr.Image(label="Result")
318
  status = gr.Markdown("")
319
 
320
+ with gr.Accordion("Advanced Settings", open=False):
321
+ negative_prompt = gr.Textbox(label="Negative prompt (optional)")
 
322
  seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
323
  randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
324
 
325
+ width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=1024, label="Width")
326
+ height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=1024, label="Height")
 
 
 
327
 
328
+ guidance_scale = gr.Slider(0.0, 10.0, step=0.1, value=0.0, label="Guidance scale")
329
+ num_inference_steps = gr.Slider(1, 100, step=1, value=8, label="Steps")
330
+ shift = gr.Slider(1.0, 10.0, step=0.1, value=3.0, label="Time shift")
331
+ max_sequence_length = gr.Slider(64, 512, step=64, value=512, label="Max sequence length")
332
 
333
+ strength = gr.Slider(0.0, 1.0, step=0.05, value=0.6, label="Image strength (img2img)")
 
334
 
335
+ run_button.click(
336
  fn=infer,
337
  inputs=[
 
338
  prompt,
339
  negative_prompt,
 
340
  seed,
341
  randomize_seed,
342
  width,
343
  height,
344
  guidance_scale,
345
  num_inference_steps,
346
+ shift,
347
+ max_sequence_length,
348
+ init_image,
349
+ strength,
350
  ],
351
+ outputs=[result, status],
352
  )
353
 
354
  if __name__ == "__main__":
355
+ demo.queue().launch(ssr_mode=False)