multimodalart HF Staff commited on
Commit
6ed28ce
·
verified ·
1 Parent(s): 013c185

Dequant nf4->bf16 in parent/CPU context (persists across ZeroGPU forks) + streamed upsampler progress

Browse files
Files changed (1) hide show
  1. app.py +36 -39
app.py CHANGED
@@ -39,6 +39,11 @@ MODES = {
39
 
40
  # --- Pipeline ---
41
  pipe = Ideogram4Pipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
 
 
 
 
 
42
  pipe.to("cuda")
43
 
44
  # --- Upsampler tokenizer + pre-fetched LM head (graft done lazily on GPU) ---
@@ -123,24 +128,10 @@ except Exception as e: # don't let a graft hiccup block the demo / the bf16 OOM
123
  ENHANCER, LOGITS_PROCESSOR = None, None
124
 
125
 
126
- # --- bf16 path: dequantize both transformers nf4 -> bf16 (kept resident; YOLO OOM test) ---
127
- # diffusers ModelMixin.dequantize() replaces the bnb 4-bit layers with bf16 and drops the quantizer.
128
- # Done lazily on first GPU request (ZeroGPU: no CUDA at import, and bnb dequant needs the weights on GPU).
129
- _BF16_DONE = False
130
-
131
-
132
- def _ensure_bf16_transformers():
133
- global _BF16_DONE
134
- if _BF16_DONE:
135
- return
136
- pipe.transformer.dequantize()
137
- pipe.unconditional_transformer.dequantize()
138
- torch.cuda.empty_cache()
139
- _BF16_DONE = True
140
-
141
-
142
- def upsample_prompt(prompt: str, width: int, height: int) -> str:
143
  from math import gcd
 
 
144
 
145
  gen = ENHANCER
146
  d = gcd(width, height) or 1
@@ -150,14 +141,28 @@ def upsample_prompt(prompt: str, width: int, height: int) -> str:
150
  inputs = upsampler_proc.apply_chat_template(
151
  messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
152
  ).to(gen.device)
153
- gen_kwargs = dict(max_new_tokens=1024, do_sample=True, temperature=1.0, use_cache=True)
 
154
  if LOGITS_PROCESSOR is not None:
155
  LOGITS_PROCESSOR.reset()
156
  gen_kwargs["logits_processor"] = [LOGITS_PROCESSOR]
157
- out = gen.generate(**inputs, **gen_kwargs)
158
- return upsampler_proc.batch_decode(
159
- out[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
160
- )[0].strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
 
163
  @spaces.GPU(duration=240, size="xlarge")
@@ -178,8 +183,9 @@ def generate(
178
  if enhance:
179
  if not OUTLINES_AVAILABLE:
180
  gr.Warning("`outlines` is not installed — upsampling without structural constraints.")
181
- final_prompt = upsample_prompt(prompt, int(width), int(height))
182
 
 
183
  generator = torch.Generator(device="cuda").manual_seed(int(seed))
184
  preset = MODES.get(mode, MODES["Default · 20 steps"])
185
  image = pipe(
@@ -198,22 +204,13 @@ def generate(
198
 
199
  @spaces.GPU(size="xlarge")
200
  def _warmup():
201
- """Force the upsampler + pipeline onto GPU and warm their kernels at STARTUP, so request #1
202
- isn't slow. On ZeroGPU, module-level loading is CPU-only; GPU placement + JIT warmup otherwise
203
- happen on the first request."""
204
- _ensure_bf16_transformers()
205
- try:
206
- if ENHANCER is not None:
207
- upsample_prompt("a red apple on a wooden table", 1024, 1024)
208
- print("[warmup] upsampler ready on GPU", flush=True)
209
- except Exception as e:
210
- print(f"[warmup] upsampler warmup skipped: {e!r}", flush=True)
211
- try:
212
- g = torch.Generator(device="cuda").manual_seed(0)
213
- pipe(prompt="a red apple", width=1024, height=1024, generator=g, **MODES["Turbo · 12 steps"])
214
- print("[warmup] pipeline ready on GPU", flush=True)
215
- except Exception as e:
216
- print(f"[warmup] pipeline warmup skipped: {e!r}", flush=True)
217
 
218
 
219
  try:
 
39
 
40
  # --- Pipeline ---
41
  pipe = Ideogram4Pipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
42
+ # Dequantize nf4 -> bf16 in the PARENT (CPU) context, BEFORE ZeroGPU forks/packs the model, so every
43
+ # fork inherits bf16 (a fork-local dequant doesn't persist). bitsandbytes supports CPU 4-bit dequant.
44
+ # This also gives AOTI real bf16 weights to bind to its (weight-less) compiled graph.
45
+ pipe.transformer.dequantize()
46
+ pipe.unconditional_transformer.dequantize()
47
  pipe.to("cuda")
48
 
49
  # --- Upsampler tokenizer + pre-fetched LM head (graft done lazily on GPU) ---
 
128
  ENHANCER, LOGITS_PROCESSOR = None, None
129
 
130
 
131
+ def upsample_prompt(prompt: str, width: int, height: int, progress=None) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  from math import gcd
133
+ from threading import Thread
134
+ from transformers import TextIteratorStreamer
135
 
136
  gen = ENHANCER
137
  d = gcd(width, height) or 1
 
141
  inputs = upsampler_proc.apply_chat_template(
142
  messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
143
  ).to(gen.device)
144
+ max_new = 1024
145
+ gen_kwargs = dict(max_new_tokens=max_new, do_sample=True, temperature=1.0, use_cache=True)
146
  if LOGITS_PROCESSOR is not None:
147
  LOGITS_PROCESSOR.reset()
148
  gen_kwargs["logits_processor"] = [LOGITS_PROCESSOR]
149
+
150
+ if progress is None: # warmup path, no UI
151
+ out = gen.generate(**inputs, **gen_kwargs)
152
+ return upsampler_proc.batch_decode(out[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0].strip()
153
+
154
+ # stream tokens so the UI shows the upsampler working
155
+ streamer = TextIteratorStreamer(upsampler_proc.tokenizer, skip_prompt=True, skip_special_tokens=True)
156
+ gen_kwargs["streamer"] = streamer
157
+ thread = Thread(target=gen.generate, kwargs={**inputs, **gen_kwargs})
158
+ thread.start()
159
+ text, n = "", 0
160
+ for chunk in streamer:
161
+ text += chunk
162
+ n += 1
163
+ progress(min(n / max_new, 0.99), desc="✍️ Upsampling prompt…")
164
+ thread.join()
165
+ return text.strip()
166
 
167
 
168
  @spaces.GPU(duration=240, size="xlarge")
 
183
  if enhance:
184
  if not OUTLINES_AVAILABLE:
185
  gr.Warning("`outlines` is not installed — upsampling without structural constraints.")
186
+ final_prompt = upsample_prompt(prompt, int(width), int(height), progress=progress)
187
 
188
+ progress(0.0, desc="🎨 Generating image…")
189
  generator = torch.Generator(device="cuda").manual_seed(int(seed))
190
  preset = MODES.get(mode, MODES["Default · 20 steps"])
191
  image = pipe(
 
204
 
205
  @spaces.GPU(size="xlarge")
206
  def _warmup():
207
+ """Preload the upsampler onto GPU and warm it at STARTUP (graft move + Outlines FSM + first-token JIT).
208
+ NOTE: runtime nf4->bf16 dequant is intentionally NOT done here it does not persist across ZeroGPU
209
+ forks (each request re-forks from the nf4 parent process), so bf16+speed will come from a precompiled
210
+ AOTI artifact instead."""
211
+ if ENHANCER is not None:
212
+ upsample_prompt("a red apple on a wooden table", 1024, 1024)
213
+ print("[warmup] upsampler ready on GPU", flush=True)
 
 
 
 
 
 
 
 
 
214
 
215
 
216
  try: