multimodalart HF Staff commited on
Commit
66efa02
·
verified ·
1 Parent(s): bc8eac2

Preload upsampler + pipeline on GPU via @spaces.GPU startup warmup

Browse files
Files changed (1) hide show
  1. app.py +67 -22
app.py CHANGED
@@ -92,36 +92,56 @@ _SEC = _load_sections(os.path.join(_HERE, "v6.txt"))
92
  SYSTEM_PROMPT = _SEC["system"]
93
  USER_TEMPLATE = _SEC.get("user", "User idea: {{original_prompt}}")
94
 
95
- _enhancer = None # Qwen3VLForConditionalGeneration sharing the pipe's encoder body
96
- _logits_processor = None # Outlines structural constraint (built once)
97
-
98
 
99
  def _build_enhancer():
100
- """Graft the hosted lm_head onto pipe.text_encoder -> a generative model (no second body)."""
101
- global _enhancer, _logits_processor
102
- if _enhancer is not None:
103
- return _enhancer
104
- device = pipe.text_encoder.device
105
  head = load_file(LM_HEAD_PATH)["lm_head.weight"] # [vocab, hidden] bf16
106
  with init_empty_weights():
107
  gen = Qwen3VLForConditionalGeneration(pipe.text_encoder.config)
108
  gen.model = pipe.text_encoder # reuse the loaded (nf4) encoder body — no extra body in VRAM
109
- lm = nn.Linear(head.shape[1], head.shape[0], bias=False).to(device=device, dtype=torch.bfloat16)
110
  with torch.no_grad():
111
- lm.weight.copy_(head.to(device=device, dtype=torch.bfloat16))
112
- gen.lm_head = lm
 
113
  gen.eval()
114
- _enhancer = gen
115
  if OUTLINES_AVAILABLE:
116
  ol_model = outlines.from_transformers(gen, upsampler_proc.tokenizer)
117
- _logits_processor = outlines.Generator(ol_model, Caption).logits_processor
118
- return gen
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
 
121
  def upsample_prompt(prompt: str, width: int, height: int) -> str:
122
  from math import gcd
123
 
124
- gen = _build_enhancer()
125
  d = gcd(width, height) or 1
126
  aspect_ratio = f"{width // d}:{height // d}"
127
  user = USER_TEMPLATE.replace("{{aspect_ratio}}", aspect_ratio).replace("{{original_prompt}}", prompt)
@@ -130,16 +150,16 @@ def upsample_prompt(prompt: str, width: int, height: int) -> str:
130
  messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
131
  ).to(gen.device)
132
  gen_kwargs = dict(max_new_tokens=1024, do_sample=True, temperature=1.0, use_cache=True)
133
- if _logits_processor is not None:
134
- _logits_processor.reset()
135
- gen_kwargs["logits_processor"] = [_logits_processor]
136
  out = gen.generate(**inputs, **gen_kwargs)
137
  return upsampler_proc.batch_decode(
138
  out[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
139
  )[0].strip()
140
 
141
 
142
- @spaces.GPU(duration=180, size="xlarge")
143
  def generate(
144
  prompt: str,
145
  mode: str,
@@ -171,6 +191,28 @@ def generate(
171
  return image, seed, final_prompt
172
 
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  with gr.Blocks(theme=gr.themes.Citrus(), title="Ideogram 4 (NF4) — diffusers preview") as demo:
175
  gr.Markdown(
176
  "## Ideogram 4 (NF4) — diffusers preview\n"
@@ -195,8 +237,8 @@ with gr.Blocks(theme=gr.themes.Citrus(), title="Ideogram 4 (NF4) — diffusers p
195
  run = gr.Button("Generate", variant="primary")
196
  with gr.Accordion("Advanced", open=False):
197
  enhance = gr.Checkbox(
198
- label="Prompt upsampling",
199
- value=True,
200
  info="Rewrite the prompt into Ideogram's native JSON caption before generating."
201
  + ("" if OUTLINES_AVAILABLE else " ⚠ outlines not installed — runs unconstrained."),
202
  )
@@ -208,14 +250,17 @@ with gr.Blocks(theme=gr.themes.Citrus(), title="Ideogram 4 (NF4) — diffusers p
208
  randomize = gr.Checkbox(label="Randomize seed", value=True)
209
  with gr.Column():
210
  out_image = gr.Image(label="Output", type="pil")
 
211
  out_caption = gr.Textbox(
212
  label="Caption fed to the model (upsampled when enabled)",
 
 
213
  )
214
 
215
  run.click(
216
  generate,
217
  inputs=[prompt, mode, enhance, width, height, seed, randomize],
218
- outputs=[out_image, seed, out_caption],
219
  )
220
 
221
  demo.queue().launch()
 
92
  SYSTEM_PROMPT = _SEC["system"]
93
  USER_TEMPLATE = _SEC.get("user", "User idea: {{original_prompt}}")
94
 
 
 
 
95
 
96
  def _build_enhancer():
97
+ """Graft the hosted lm_head onto pipe.text_encoder -> a generative model (no second body).
98
+ Done ONCE at import time so nothing heavy happens on the first GPU request. Only the new
99
+ bf16 lm_head is `.to('cuda')` (ZeroGPU defers it); the shared nf4 body is already moved by `pipe`."""
 
 
100
  head = load_file(LM_HEAD_PATH)["lm_head.weight"] # [vocab, hidden] bf16
101
  with init_empty_weights():
102
  gen = Qwen3VLForConditionalGeneration(pipe.text_encoder.config)
103
  gen.model = pipe.text_encoder # reuse the loaded (nf4) encoder body — no extra body in VRAM
104
+ lm = nn.Linear(head.shape[1], head.shape[0], bias=False)
105
  with torch.no_grad():
106
+ lm.weight.copy_(head.to(torch.bfloat16))
107
+ gen.lm_head = lm.to(torch.bfloat16)
108
+ gen.lm_head.to("cuda") # ZeroGPU-deferred move of just the head
109
  gen.eval()
110
+ lp = None
111
  if OUTLINES_AVAILABLE:
112
  ol_model = outlines.from_transformers(gen, upsampler_proc.tokenizer)
113
+ lp = outlines.Generator(ol_model, Caption).logits_processor # compiles schema->FSM now
114
+ return gen, lp
115
+
116
+
117
+ # Assemble the generative enhancer + structural constraint at STARTUP (not on first request).
118
+ try:
119
+ ENHANCER, LOGITS_PROCESSOR = _build_enhancer()
120
+ except Exception as e: # don't let a graft hiccup block the demo / the bf16 OOM test
121
+ print(f"[enhancer] graft failed, prompt upsampling disabled: {e!r}")
122
+ ENHANCER, LOGITS_PROCESSOR = None, None
123
+
124
+
125
+ # --- bf16 path: dequantize both transformers nf4 -> bf16 (kept resident; YOLO OOM test) ---
126
+ # diffusers ModelMixin.dequantize() replaces the bnb 4-bit layers with bf16 and drops the quantizer.
127
+ # Done lazily on first GPU request (ZeroGPU: no CUDA at import, and bnb dequant needs the weights on GPU).
128
+ _BF16_DONE = False
129
+
130
+
131
+ def _ensure_bf16_transformers():
132
+ global _BF16_DONE
133
+ if _BF16_DONE:
134
+ return
135
+ pipe.transformer.dequantize()
136
+ pipe.unconditional_transformer.dequantize()
137
+ torch.cuda.empty_cache()
138
+ _BF16_DONE = True
139
 
140
 
141
  def upsample_prompt(prompt: str, width: int, height: int) -> str:
142
  from math import gcd
143
 
144
+ gen = ENHANCER
145
  d = gcd(width, height) or 1
146
  aspect_ratio = f"{width // d}:{height // d}"
147
  user = USER_TEMPLATE.replace("{{aspect_ratio}}", aspect_ratio).replace("{{original_prompt}}", prompt)
 
150
  messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
151
  ).to(gen.device)
152
  gen_kwargs = dict(max_new_tokens=1024, do_sample=True, temperature=1.0, use_cache=True)
153
+ if LOGITS_PROCESSOR is not None:
154
+ LOGITS_PROCESSOR.reset()
155
+ gen_kwargs["logits_processor"] = [LOGITS_PROCESSOR]
156
  out = gen.generate(**inputs, **gen_kwargs)
157
  return upsampler_proc.batch_decode(
158
  out[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
159
  )[0].strip()
160
 
161
 
162
+ @spaces.GPU(duration=240)
163
  def generate(
164
  prompt: str,
165
  mode: str,
 
191
  return image, seed, final_prompt
192
 
193
 
194
+ @spaces.GPU
195
+ def _warmup():
196
+ """Force the upsampler + pipeline onto GPU and warm their kernels at STARTUP, so request #1
197
+ isn't slow. On ZeroGPU, module-level loading is CPU-only; GPU placement + JIT warmup otherwise
198
+ happen on the first request."""
199
+ try:
200
+ if ENHANCER is not None:
201
+ upsample_prompt("a red apple on a wooden table", 1024, 1024)
202
+ print("[warmup] upsampler ready on GPU", flush=True)
203
+ except Exception as e:
204
+ print(f"[warmup] upsampler warmup skipped: {e!r}", flush=True)
205
+ try:
206
+ g = torch.Generator(device="cuda").manual_seed(0)
207
+ pipe(prompt="a red apple", width=1024, height=1024, generator=g, **MODES["Turbo · 12 steps"])
208
+ print("[warmup] pipeline ready on GPU", flush=True)
209
+ except Exception as e:
210
+ print(f"[warmup] pipeline warmup skipped: {e!r}", flush=True)
211
+
212
+
213
+ _warmup()
214
+
215
+
216
  with gr.Blocks(theme=gr.themes.Citrus(), title="Ideogram 4 (NF4) — diffusers preview") as demo:
217
  gr.Markdown(
218
  "## Ideogram 4 (NF4) — diffusers preview\n"
 
237
  run = gr.Button("Generate", variant="primary")
238
  with gr.Accordion("Advanced", open=False):
239
  enhance = gr.Checkbox(
240
+ label="Prompt upsampling (Outlines)",
241
+ value=False,
242
  info="Rewrite the prompt into Ideogram's native JSON caption before generating."
243
  + ("" if OUTLINES_AVAILABLE else " ⚠ outlines not installed — runs unconstrained."),
244
  )
 
250
  randomize = gr.Checkbox(label="Randomize seed", value=True)
251
  with gr.Column():
252
  out_image = gr.Image(label="Output", type="pil")
253
+ out_seed = gr.Number(label="Seed used", interactive=False, precision=0)
254
  out_caption = gr.Textbox(
255
  label="Caption fed to the model (upsampled when enabled)",
256
+ lines=4,
257
+ show_copy_button=True,
258
  )
259
 
260
  run.click(
261
  generate,
262
  inputs=[prompt, mode, enhance, width, height, seed, randomize],
263
+ outputs=[out_image, out_seed, out_caption],
264
  )
265
 
266
  demo.queue().launch()