multimodalart HF Staff commited on
Commit
a1f5c28
·
verified ·
1 Parent(s): 2436f41

Graft hosted lm_head onto pipe encoder for upsampling (no second model)

Browse files
Files changed (1) hide show
  1. app.py +40 -21
app.py CHANGED
@@ -12,15 +12,20 @@ from typing import List, Literal, Union
12
 
13
  import spaces
14
  import torch
 
15
  import gradio as gr
16
  from pydantic import BaseModel, Field
 
 
 
17
  from diffusers import Ideogram4Pipeline
18
  from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
19
 
20
  # --- New (safety-fixed) checkpoint ---
21
  MODEL_ID = "diffusers-internal-dev/ideogram-4-nf4-v2"
22
- # Generative sibling of the text encoder, used for prompt upsampling.
23
- UPSAMPLER_ID = "Qwen/Qwen3-VL-8B-Instruct"
 
24
 
25
  MAX_SEED = 2**31 - 1
26
 
@@ -35,10 +40,9 @@ MODES = {
35
  pipe = Ideogram4Pipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
36
  pipe.to("cuda")
37
 
38
- # --- Prompt upsampler (Qwen3-VL-8B-Instruct, generative) ---
39
- upsampler = Qwen3VLForConditionalGeneration.from_pretrained(UPSAMPLER_ID, torch_dtype=torch.bfloat16)
40
- upsampler.to("cuda")
41
- upsampler_proc = AutoProcessor.from_pretrained(UPSAMPLER_ID)
42
 
43
  try:
44
  import outlines
@@ -88,33 +92,48 @@ _SEC = _load_sections(os.path.join(_HERE, "v6.txt"))
88
  SYSTEM_PROMPT = _SEC["system"]
89
  USER_TEMPLATE = _SEC.get("user", "User idea: {{original_prompt}}")
90
 
91
- _logits_processor = None # built lazily (compiles the schema -> FSM once)
92
-
93
-
94
- def _get_logits_processor():
95
- global _logits_processor
96
- if _logits_processor is None and OUTLINES_AVAILABLE:
97
- ol_model = outlines.from_transformers(upsampler, upsampler_proc.tokenizer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  _logits_processor = outlines.Generator(ol_model, Caption).logits_processor
99
- return _logits_processor
100
 
101
 
102
  def upsample_prompt(prompt: str, width: int, height: int) -> str:
103
  from math import gcd
104
 
 
105
  d = gcd(width, height) or 1
106
  aspect_ratio = f"{width // d}:{height // d}"
107
  user = USER_TEMPLATE.replace("{{aspect_ratio}}", aspect_ratio).replace("{{original_prompt}}", prompt)
108
  messages = [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user}]
109
  inputs = upsampler_proc.apply_chat_template(
110
  messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
111
- ).to(upsampler.device)
112
  gen_kwargs = dict(max_new_tokens=1024, do_sample=True, temperature=1.0, use_cache=True)
113
- lp = _get_logits_processor()
114
- if lp is not None:
115
- lp.reset()
116
- gen_kwargs["logits_processor"] = [lp]
117
- out = upsampler.generate(**inputs, **gen_kwargs)
118
  return upsampler_proc.batch_decode(
119
  out[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
120
  )[0].strip()
@@ -158,7 +177,7 @@ with gr.Blocks(theme=gr.themes.Citrus(), title="Ideogram 4 (NF4) — diffusers p
158
  f"Private demo of [`{MODEL_ID}`](https://huggingface.co/{MODEL_ID}) using the "
159
  "[diffusers PR](https://github.com/huggingface/diffusers-new-model-addition-ideogram) branch, on ZeroGPU.\n"
160
  "Toggle **Prompt upsampling** in Advanced to rewrite your idea into Ideogram's native structured caption "
161
- "(Qwen3-VL-8B + Outlines)."
162
  )
163
 
164
  with gr.Row():
 
12
 
13
  import spaces
14
  import torch
15
+ import torch.nn as nn
16
  import gradio as gr
17
  from pydantic import BaseModel, Field
18
+ from accelerate import init_empty_weights
19
+ from huggingface_hub import hf_hub_download
20
+ from safetensors.torch import load_file
21
  from diffusers import Ideogram4Pipeline
22
  from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
23
 
24
  # --- New (safety-fixed) checkpoint ---
25
  MODEL_ID = "diffusers-internal-dev/ideogram-4-nf4-v2"
26
+ # Just the LM head, grafted onto the pipeline's own Qwen3-VL encoder to make it generative.
27
+ LM_HEAD_REPO = "multimodalart/qwen3-vl-8b-instruct-lm-head"
28
+ TOKENIZER_ID = "Qwen/Qwen3-VL-8B-Instruct" # processor/tokenizer only (no weights)
29
 
30
  MAX_SEED = 2**31 - 1
31
 
 
40
  pipe = Ideogram4Pipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
41
  pipe.to("cuda")
42
 
43
+ # --- Upsampler tokenizer + pre-fetched LM head (graft done lazily on GPU) ---
44
+ upsampler_proc = AutoProcessor.from_pretrained(TOKENIZER_ID)
45
+ LM_HEAD_PATH = hf_hub_download(LM_HEAD_REPO, "lm_head.safetensors") # cached at startup
 
46
 
47
  try:
48
  import outlines
 
92
  SYSTEM_PROMPT = _SEC["system"]
93
  USER_TEMPLATE = _SEC.get("user", "User idea: {{original_prompt}}")
94
 
95
+ _enhancer = None # Qwen3VLForConditionalGeneration sharing the pipe's encoder body
96
+ _logits_processor = None # Outlines structural constraint (built once)
97
+
98
+
99
+ def _build_enhancer():
100
+ """Graft the hosted lm_head onto pipe.text_encoder -> a generative model (no second body)."""
101
+ global _enhancer, _logits_processor
102
+ if _enhancer is not None:
103
+ return _enhancer
104
+ device = pipe.text_encoder.device
105
+ head = load_file(LM_HEAD_PATH)["lm_head.weight"] # [vocab, hidden] bf16
106
+ with init_empty_weights():
107
+ gen = Qwen3VLForConditionalGeneration(pipe.text_encoder.config)
108
+ gen.model = pipe.text_encoder # reuse the loaded (nf4) encoder body — no extra body in VRAM
109
+ lm = nn.Linear(head.shape[1], head.shape[0], bias=False).to(device=device, dtype=torch.bfloat16)
110
+ with torch.no_grad():
111
+ lm.weight.copy_(head.to(device=device, dtype=torch.bfloat16))
112
+ gen.lm_head = lm
113
+ gen.eval()
114
+ _enhancer = gen
115
+ if OUTLINES_AVAILABLE:
116
+ ol_model = outlines.from_transformers(gen, upsampler_proc.tokenizer)
117
  _logits_processor = outlines.Generator(ol_model, Caption).logits_processor
118
+ return gen
119
 
120
 
121
  def upsample_prompt(prompt: str, width: int, height: int) -> str:
122
  from math import gcd
123
 
124
+ gen = _build_enhancer()
125
  d = gcd(width, height) or 1
126
  aspect_ratio = f"{width // d}:{height // d}"
127
  user = USER_TEMPLATE.replace("{{aspect_ratio}}", aspect_ratio).replace("{{original_prompt}}", prompt)
128
  messages = [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user}]
129
  inputs = upsampler_proc.apply_chat_template(
130
  messages, add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True
131
+ ).to(gen.device)
132
  gen_kwargs = dict(max_new_tokens=1024, do_sample=True, temperature=1.0, use_cache=True)
133
+ if _logits_processor is not None:
134
+ _logits_processor.reset()
135
+ gen_kwargs["logits_processor"] = [_logits_processor]
136
+ out = gen.generate(**inputs, **gen_kwargs)
 
137
  return upsampler_proc.batch_decode(
138
  out[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
139
  )[0].strip()
 
177
  f"Private demo of [`{MODEL_ID}`](https://huggingface.co/{MODEL_ID}) using the "
178
  "[diffusers PR](https://github.com/huggingface/diffusers-new-model-addition-ideogram) branch, on ZeroGPU.\n"
179
  "Toggle **Prompt upsampling** in Advanced to rewrite your idea into Ideogram's native structured caption "
180
+ "(the pipeline's own Qwen3-VL encoder + a grafted LM head + Outlines)."
181
  )
182
 
183
  with gr.Row():