telcom committed on
Commit
21668ae
·
verified ·
1 Parent(s): 56ac11d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -153
app.py CHANGED
@@ -1,6 +1,7 @@
1
  # app.py
2
  # ============================================================
3
- # IMPORTANT: imports order matters for Hugging Face Spaces
 
4
  # ============================================================
5
 
6
  import os
@@ -8,7 +9,6 @@ import gc
8
  import random
9
  import warnings
10
  import logging
11
- import inspect
12
 
13
  # ---- Spaces GPU decorator (must be imported early) ----------
14
  try:
@@ -24,31 +24,17 @@ from PIL import Image
24
  import torch
25
  from huggingface_hub import login
26
 
27
- # ============================================================
28
- # Try importing Z-Image pipelines (requires diffusers>=0.36.0)
29
- # ============================================================
30
-
31
- ZIMAGE_AVAILABLE = True
32
- ZIMAGE_IMPORT_ERROR = None
33
-
34
- try:
35
- from diffusers import (
36
- ZImagePipeline,
37
- ZImageImg2ImgPipeline,
38
- FlowMatchEulerDiscreteScheduler,
39
- )
40
- except Exception as e:
41
- ZIMAGE_AVAILABLE = False
42
- ZIMAGE_IMPORT_ERROR = repr(e)
43
 
44
  # ============================================================
45
  # Config
46
  # ============================================================
47
 
48
- MODEL_PATH = os.environ.get("MODEL_PATH", "telcom/dee-z-image").strip()
49
-
50
- ATTENTION_BACKEND = os.environ.get("ATTENTION_BACKEND", "flash_3").strip()
51
- ENABLE_COMPILE = os.environ.get("ENABLE_COMPILE", "false").lower() == "true"
 
52
 
53
  HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
54
  if HF_TOKEN:
@@ -81,71 +67,39 @@ if not cuda_available:
81
  fallback_msg = "GPU unavailable. Running in CPU fallback mode (slow)."
82
 
83
  # ============================================================
84
- # Load pipelines
85
  # ============================================================
86
 
87
- pipe_txt2img = None
88
- pipe_img2img = None
89
  model_loaded = False
90
  load_error = None
91
 
92
- def _set_attention_backend_best_effort(p):
93
- try:
94
- if hasattr(p, "transformer") and hasattr(p.transformer, "set_attention_backend"):
95
- p.transformer.set_attention_backend(ATTENTION_BACKEND)
96
- except Exception:
97
- pass
 
 
 
 
 
 
 
 
 
 
98
 
99
- def _compile_best_effort(p):
100
- if not (ENABLE_COMPILE and device.type == "cuda"):
101
- return
102
  try:
103
- if hasattr(p, "transformer"):
104
- p.transformer = torch.compile(
105
- p.transformer,
106
- mode="max-autotune-no-cudagraphs",
107
- fullgraph=False,
108
- )
109
  except Exception:
110
  pass
111
 
112
- if ZIMAGE_AVAILABLE:
113
- try:
114
- fp_kwargs = {
115
- "torch_dtype": dtype,
116
- "use_safetensors": True,
117
- }
118
- if HF_TOKEN:
119
- fp_kwargs["token"] = HF_TOKEN
120
-
121
- pipe_txt2img = ZImagePipeline.from_pretrained(MODEL_PATH, **fp_kwargs).to(device)
122
- _set_attention_backend_best_effort(pipe_txt2img)
123
- _compile_best_effort(pipe_txt2img)
124
-
125
- try:
126
- pipe_txt2img.set_progress_bar_config(disable=True)
127
- except Exception:
128
- pass
129
-
130
- # Share weights/components with img2img pipeline
131
- pipe_img2img = ZImageImg2ImgPipeline(**pipe_txt2img.components).to(device)
132
- _set_attention_backend_best_effort(pipe_img2img)
133
- try:
134
- pipe_img2img.set_progress_bar_config(disable=True)
135
- except Exception:
136
- pass
137
-
138
- model_loaded = True
139
 
140
- except Exception as e:
141
- load_error = repr(e)
142
- model_loaded = False
143
- else:
144
- load_error = (
145
- "Z-Image pipelines not available in your diffusers install.\n\n"
146
- f"Import error:\n{ZIMAGE_IMPORT_ERROR}\n\n"
147
- "Fix: set requirements.txt to diffusers==0.36.0 (or install Diffusers from source)."
148
- )
149
  model_loaded = False
150
 
151
  # ============================================================
@@ -155,29 +109,33 @@ else:
155
  def make_error_image(w: int, h: int) -> Image.Image:
156
  return Image.new("RGB", (int(w), int(h)), (18, 18, 22))
157
 
158
- def prep_init_image(img: Image.Image, width: int, height: int) -> Image.Image:
159
  if img is None:
160
  return None
161
  if not isinstance(img, Image.Image):
162
  return None
163
- img = img.convert("RGB")
164
- if img.size != (width, height):
165
- img = img.resize((width, height), Image.LANCZOS)
166
- return img
167
 
168
- def _call_pipeline(pipe, kwargs: dict):
169
  """
170
- Robust call: only pass kwargs the pipeline actually accepts.
171
- This avoids crashes if a particular build does not support negative_prompt, etc.
172
  """
173
- try:
174
- sig = inspect.signature(pipe.__call__)
175
- allowed = set(sig.parameters.keys())
176
- filtered = {k: v for k, v in kwargs.items() if k in allowed and v is not None}
177
- return pipe(**filtered)
178
- except Exception:
179
- # Fallback: try raw kwargs (some pipelines use **kwargs internally)
180
- return pipe(**{k: v for k, v in kwargs.items() if v is not None})
 
 
 
 
 
 
 
181
 
182
  # ============================================================
183
  # Inference
@@ -192,10 +150,8 @@ def _infer_impl(
192
  height,
193
  guidance_scale,
194
  num_inference_steps,
195
- shift,
196
- max_sequence_length,
197
  init_image,
198
- strength,
199
  ):
200
  width = int(width)
201
  height = int(height)
@@ -208,6 +164,12 @@ def _infer_impl(
208
  if not prompt:
209
  return make_error_image(width, height), "Error: prompt is empty."
210
 
 
 
 
 
 
 
211
  if randomize_seed:
212
  seed = random.randint(0, MAX_SEED)
213
 
@@ -215,59 +177,49 @@ def _infer_impl(
215
 
216
  status = f"Seed: {seed}"
217
  if fallback_msg:
218
- status += f" | {fallback_msg}"
219
 
220
  gs = float(guidance_scale)
221
  steps = int(num_inference_steps)
222
- msl = int(max_sequence_length)
223
- st = float(strength)
224
 
225
  neg = (negative_prompt or "").strip()
226
  if not neg:
227
  neg = None
228
 
229
- init_image = prep_init_image(init_image, width, height)
 
230
 
231
- # Update scheduler (shift) per run
232
- try:
233
- scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=float(shift))
234
- pipe_txt2img.scheduler = scheduler
235
- pipe_img2img.scheduler = scheduler
236
- except Exception:
237
- pass
238
 
239
  try:
240
- base_kwargs = dict(
241
- prompt=prompt,
242
- height=height,
243
- width=width,
244
- guidance_scale=gs,
245
- num_inference_steps=steps,
246
- generator=generator,
247
- max_sequence_length=msl,
248
- )
249
- # only passed if supported by the pipeline
250
- if neg is not None:
251
- base_kwargs["negative_prompt"] = neg
252
-
253
  with torch.inference_mode():
254
  if device.type == "cuda":
255
  with torch.autocast("cuda", dtype=dtype):
256
- if init_image is not None:
257
- out = _call_pipeline(
258
- pipe_img2img,
259
- {**base_kwargs, "image": init_image, "strength": st},
260
- )
261
- else:
262
- out = _call_pipeline(pipe_txt2img, base_kwargs)
263
- else:
264
- if init_image is not None:
265
- out = _call_pipeline(
266
- pipe_img2img,
267
- {**base_kwargs, "image": init_image, "strength": st},
268
  )
269
- else:
270
- out = _call_pipeline(pipe_txt2img, base_kwargs)
 
 
 
 
 
 
 
 
 
 
271
 
272
  img = out.images[0]
273
  return img, status
@@ -289,17 +241,14 @@ else:
289
  return _infer_impl(*args, **kwargs)
290
 
291
  # ============================================================
292
- # UI (simple black style like your SDXL example)
293
  # ============================================================
294
 
295
  CSS = """
296
- body {
297
- background: #000;
298
- color: #fff;
299
- }
300
  """
301
 
302
- with gr.Blocks(title="Z-Image txt2img + img2img") as demo:
303
  gr.HTML(f"<style>{CSS}</style>")
304
 
305
  if fallback_msg:
@@ -308,29 +257,29 @@ with gr.Blocks(title="Z-Image txt2img + img2img") as demo:
308
  if not model_loaded:
309
  gr.Markdown(f"⚠️ Model failed to load:\n\n{load_error}")
310
 
311
- gr.Markdown("## Z-Image Generator (txt2img + img2img)")
 
312
 
313
- prompt = gr.Textbox(label="Prompt", lines=2)
314
- init_image = gr.Image(label="Initial image (optional)", type="pil")
315
 
316
- run_button = gr.Button("Generate")
 
 
 
 
317
  result = gr.Image(label="Result")
318
  status = gr.Markdown("")
319
 
320
  with gr.Accordion("Advanced Settings", open=False):
321
- negative_prompt = gr.Textbox(label="Negative prompt (optional)")
322
  seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
323
  randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
324
 
325
  width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=1024, label="Width")
326
  height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=1024, label="Height")
327
 
328
- guidance_scale = gr.Slider(0.0, 10.0, step=0.1, value=0.0, label="Guidance scale")
329
- num_inference_steps = gr.Slider(1, 100, step=1, value=8, label="Steps")
330
- shift = gr.Slider(1.0, 10.0, step=0.1, value=3.0, label="Time shift")
331
- max_sequence_length = gr.Slider(64, 512, step=64, value=512, label="Max sequence length")
332
-
333
- strength = gr.Slider(0.0, 1.0, step=0.05, value=0.6, label="Image strength (img2img)")
334
 
335
  run_button.click(
336
  fn=infer,
@@ -343,10 +292,8 @@ with gr.Blocks(title="Z-Image txt2img + img2img") as demo:
343
  height,
344
  guidance_scale,
345
  num_inference_steps,
346
- shift,
347
- max_sequence_length,
348
  init_image,
349
- strength,
350
  ],
351
  outputs=[result, status],
352
  )
 
1
  # app.py
2
  # ============================================================
3
+ # SDXL Inpainting (replace clothing area) for Hugging Face Spaces
4
+ # Removes img2img, adds inpainting with mask_image
5
  # ============================================================
6
 
7
  import os
 
9
  import random
10
  import warnings
11
  import logging
 
12
 
13
  # ---- Spaces GPU decorator (must be imported early) ----------
14
  try:
 
24
  import torch
25
  from huggingface_hub import login
26
 
27
+ from diffusers import StableDiffusionXLInpaintPipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # ============================================================
30
  # Config
31
  # ============================================================
32
 
33
+ # SDXL inpainting model repo
34
+ INPAINT_MODEL = os.environ.get(
35
+ "INPAINT_MODEL",
36
+ "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
37
+ ).strip()
38
 
39
  HF_TOKEN = os.getenv("HF_TOKEN", "").strip()
40
  if HF_TOKEN:
 
67
  fallback_msg = "GPU unavailable. Running in CPU fallback mode (slow)."
68
 
69
  # ============================================================
70
+ # Load pipeline
71
  # ============================================================
72
 
73
+ pipe_inpaint = None
 
74
  model_loaded = False
75
  load_error = None
76
 
77
def _maybe_disable_safety_checker(pipe):
    """Safety-checker policy hook; currently hands the pipeline back as-is.

    No change is made here, so if the model includes a checker it stays
    enabled. Disabling one (not recommended) would be done at this point by
    setting it to None on ``pipe`` before returning.
    """
    return pipe
82
+
83
# Build the inpainting pipeline once, at module load. A failure is recorded in
# `load_error` / `model_loaded` instead of being raised, so the UI can still
# start and display the reason.
try:
    fp_kwargs = {
        "torch_dtype": dtype,
        "use_safetensors": True,
    }
    if HF_TOKEN:
        # Only forward a token when one is actually configured.
        fp_kwargs["token"] = HF_TOKEN

    pipe_inpaint = StableDiffusionXLInpaintPipeline.from_pretrained(
        INPAINT_MODEL, **fp_kwargs
    ).to(device)
    pipe_inpaint = _maybe_disable_safety_checker(pipe_inpaint)

    try:
        pipe_inpaint.set_progress_bar_config(disable=True)
    except Exception:
        # Best effort: a pipeline without this knob is still usable, but leave
        # a trace instead of swallowing the error silently.
        logging.getLogger(__name__).debug(
            "Could not disable the pipeline progress bar", exc_info=True
        )

    model_loaded = True

except Exception as e:
    # Keep the state consistent: no half-initialised pipeline is left behind.
    pipe_inpaint = None
    load_error = repr(e)
    model_loaded = False
    # repr(e) alone drops the traceback; log it so the failure can be diagnosed.
    logging.getLogger(__name__).exception(
        "Failed to load inpainting pipeline %r", INPAINT_MODEL
    )
104
 
105
  # ============================================================
 
109
  def make_error_image(w: int, h: int) -> Image.Image:
110
  return Image.new("RGB", (int(w), int(h)), (18, 18, 22))
111
 
112
def _ensure_rgb(img: "Image.Image | None") -> "Image.Image | None":
    """Return ``img`` converted to RGB mode, or None when it is unusable.

    Args:
        img: Candidate input image. May be None or a non-PIL object.

    Returns:
        The result of ``img.convert("RGB")``, or None if ``img`` is None or
        is not a PIL image.
    """
    # None is checked first so a missing upload short-circuits before the
    # isinstance test.
    if img is None or not isinstance(img, Image.Image):
        return None
    return img.convert("RGB")
 
 
 
118
 
119
def _ensure_mask(mask: "Image.Image | None") -> "Image.Image | None":
    """Return ``mask`` as a single-channel ("L") image, or None when unusable.

    Mask convention: white marks the area to edit, black the area to keep.

    Args:
        mask: Candidate mask image. May be None or a non-PIL object.

    Returns:
        The result of ``mask.convert("L")``, or None if ``mask`` is None or
        is not a PIL image.
    """
    # None is checked first so a missing upload short-circuits before the
    # isinstance test.
    if mask is None or not isinstance(mask, Image.Image):
        return None
    return mask.convert("L")
130
+
131
def _resize_to(
    img: "Image.Image | None", w: int, h: int, is_mask: bool = False
) -> "Image.Image | None":
    """Resize ``img`` to exactly ``w`` x ``h`` pixels.

    Args:
        img: Image to resize; None is passed through.
        w: Target width in pixels (coerced to int).
        h: Target height in pixels (coerced to int).
        is_mask: When True, resample with ``Image.NEAREST``; otherwise use
            ``Image.LANCZOS``.

    Returns:
        None if ``img`` is None, ``img`` itself when it already has the
        target size, otherwise the resized image.
    """
    if img is None:
        return None

    # Coerce like make_error_image does, so float-valued sizes are accepted.
    target = (int(w), int(h))
    if img.size == target:
        return img

    # Nearest-neighbour picks existing pixel values instead of interpolating,
    # which suits masks; photos get the smoother LANCZOS filter.
    resample = Image.NEAREST if is_mask else Image.LANCZOS
    return img.resize(target, resample)
139
 
140
  # ============================================================
141
  # Inference
 
150
  height,
151
  guidance_scale,
152
  num_inference_steps,
 
 
153
  init_image,
154
+ mask_image,
155
  ):
156
  width = int(width)
157
  height = int(height)
 
164
  if not prompt:
165
  return make_error_image(width, height), "Error: prompt is empty."
166
 
167
+ if init_image is None:
168
+ return make_error_image(width, height), "Error: you must provide an input image."
169
+
170
+ if mask_image is None:
171
+ return make_error_image(width, height), "Error: you must provide a mask image (white=edit, black=keep)."
172
+
173
  if randomize_seed:
174
  seed = random.randint(0, MAX_SEED)
175
 
 
177
 
178
  status = f"Seed: {seed}"
179
  if fallback_msg:
180
+ status = status + " | " + fallback_msg
181
 
182
  gs = float(guidance_scale)
183
  steps = int(num_inference_steps)
 
 
184
 
185
  neg = (negative_prompt or "").strip()
186
  if not neg:
187
  neg = None
188
 
189
+ init_image = _ensure_rgb(init_image)
190
+ mask_image = _ensure_mask(mask_image)
191
 
192
+ # resize both to target resolution
193
+ init_image = _resize_to(init_image, width, height, is_mask=False)
194
+ mask_image = _resize_to(mask_image, width, height, is_mask=True)
 
 
 
 
195
 
196
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  with torch.inference_mode():
198
  if device.type == "cuda":
199
  with torch.autocast("cuda", dtype=dtype):
200
+ out = pipe_inpaint(
201
+ prompt=prompt,
202
+ negative_prompt=neg,
203
+ image=init_image,
204
+ mask_image=mask_image,
205
+ width=width,
206
+ height=height,
207
+ guidance_scale=gs,
208
+ num_inference_steps=steps,
209
+ generator=generator,
 
 
210
  )
211
+ else:
212
+ out = pipe_inpaint(
213
+ prompt=prompt,
214
+ negative_prompt=neg,
215
+ image=init_image,
216
+ mask_image=mask_image,
217
+ width=width,
218
+ height=height,
219
+ guidance_scale=gs,
220
+ num_inference_steps=steps,
221
+ generator=generator,
222
+ )
223
 
224
  img = out.images[0]
225
  return img, status
 
241
  return _infer_impl(*args, **kwargs)
242
 
243
  # ============================================================
244
+ # UI
245
  # ============================================================
246
 
247
  CSS = """
248
+ body { background: #000; color: #fff; }
 
 
 
249
  """
250
 
251
+ with gr.Blocks(title="SDXL Inpainting (Clothing Edit)") as demo:
252
  gr.HTML(f"<style>{CSS}</style>")
253
 
254
  if fallback_msg:
 
257
  if not model_loaded:
258
  gr.Markdown(f"⚠️ Model failed to load:\n\n{load_error}")
259
 
260
+ gr.Markdown("## SDXL Inpainting (image + mask)")
261
+ gr.Markdown("Mask rule: **white = edit**, **black = keep**.")
262
 
263
+ prompt = gr.Textbox(label="Prompt (describe the new clothing)", lines=2)
264
+ negative_prompt = gr.Textbox(label="Negative prompt (optional)", lines=2)
265
 
266
+ with gr.Row():
267
+ init_image = gr.Image(label="Input image", type="pil")
268
+ mask_image = gr.Image(label="Mask image (white edits)", type="pil")
269
+
270
+ run_button = gr.Button("Inpaint")
271
  result = gr.Image(label="Result")
272
  status = gr.Markdown("")
273
 
274
  with gr.Accordion("Advanced Settings", open=False):
 
275
  seed = gr.Slider(0, MAX_SEED, step=1, value=0, label="Seed")
276
  randomize_seed = gr.Checkbox(value=True, label="Randomize seed")
277
 
278
  width = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=1024, label="Width")
279
  height = gr.Slider(256, MAX_IMAGE_SIZE, step=64, value=1024, label="Height")
280
 
281
+ guidance_scale = gr.Slider(0.0, 15.0, step=0.1, value=7.0, label="Guidance scale")
282
+ num_inference_steps = gr.Slider(1, 80, step=1, value=30, label="Steps")
 
 
 
 
283
 
284
  run_button.click(
285
  fn=infer,
 
292
  height,
293
  guidance_scale,
294
  num_inference_steps,
 
 
295
  init_image,
296
+ mask_image,
297
  ],
298
  outputs=[result, status],
299
  )