multimodalart HF Staff committed on
Commit
49e5d58
·
verified ·
1 Parent(s): 60f16ec

Add prompt upsampling (Qwen3-VL+Outlines), 3 modes, v2 checkpoint, Citrus theme: app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -31
app.py CHANGED
@@ -8,61 +8,157 @@ _HERE = os.path.dirname(os.path.abspath(__file__))
8
  sys.path.insert(0, os.path.join(_HERE, "diffusers_src", "src"))
9
 
10
  import random
 
11
 
12
  import spaces
13
  import torch
14
  import gradio as gr
 
15
  from diffusers import Ideogram4Pipeline
 
16
 
17
- MODEL_ID = "diffusers-internal-dev/ideogram-4-nf4"
 
 
 
18
 
 
 
 
 
 
 
 
 
 
 
19
  pipe = Ideogram4Pipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
20
  pipe.to("cuda")
21
 
22
- MAX_SEED = 2**31 - 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
 
 
 
24
 
25
- @spaces.GPU(duration=180)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def generate(
27
  prompt: str,
 
 
28
  width: int,
29
  height: int,
30
- num_inference_steps: int,
31
- guidance_scale: float,
32
  seed: int,
33
  randomize_seed: bool,
34
  progress=gr.Progress(track_tqdm=True),
35
  ):
36
  if randomize_seed or seed < 0:
37
  seed = random.randint(0, MAX_SEED)
38
- generator = torch.Generator(device="cuda").manual_seed(int(seed))
39
 
40
- steps = int(num_inference_steps)
41
- kwargs = dict(
42
- prompt=prompt,
 
 
 
 
 
 
 
43
  width=int(width),
44
  height=int(height),
45
- num_inference_steps=steps,
46
  generator=generator,
47
- )
48
- if guidance_scale > 0:
49
- kwargs["guidance_scale"] = float(guidance_scale)
50
- kwargs["guidance_schedule"] = None
51
- else:
52
- # PR default is len 48 (7.0 x45 + 3.0 x3); rebuild it for any step count.
53
- tail = min(3, max(0, steps - 1))
54
- kwargs["guidance_schedule"] = (7.0,) * (steps - tail) + (3.0,) * tail
55
-
56
- image = pipe(**kwargs).images[0]
57
- return image, seed
58
 
59
 
60
- with gr.Blocks(title="Ideogram 4 (NF4) — diffusers preview") as demo:
61
  gr.Markdown(
62
  "## Ideogram 4 (NF4) — diffusers preview\n"
63
  f"Private demo of [`{MODEL_ID}`](https://huggingface.co/{MODEL_ID}) using the "
64
- "[diffusers PR #2](https://github.com/huggingface/diffusers-new-model-addition-ideogram/pull/2) "
65
- "branch, running on ZeroGPU."
 
66
  )
67
 
68
  with gr.Row():
@@ -72,27 +168,38 @@ with gr.Blocks(title="Ideogram 4 (NF4) — diffusers preview") as demo:
72
  value="A photo of a cat holding a sign that says hello world",
73
  lines=3,
74
  )
 
 
 
 
 
75
  run = gr.Button("Generate", variant="primary")
76
  with gr.Accordion("Advanced", open=False):
 
 
 
 
 
 
77
  with gr.Row():
78
  width = gr.Slider(512, 2048, value=1024, step=64, label="Width")
79
  height = gr.Slider(512, 2048, value=1024, step=64, label="Height")
80
- steps = gr.Slider(8, 64, value=48, step=1, label="Inference steps")
81
- guidance = gr.Slider(
82
- 0.0, 15.0, value=0.0, step=0.1,
83
- label="Guidance scale (0 = recommended schedule: 7.0 → 3.0)",
84
- )
85
  with gr.Row():
86
  seed = gr.Number(label="Seed", value=0, precision=0)
87
  randomize = gr.Checkbox(label="Randomize seed", value=True)
88
  with gr.Column():
89
  out_image = gr.Image(label="Output", type="pil")
90
  out_seed = gr.Number(label="Seed used", interactive=False, precision=0)
 
 
 
 
 
91
 
92
  run.click(
93
  generate,
94
- inputs=[prompt, width, height, steps, guidance, seed, randomize],
95
- outputs=[out_image, out_seed],
96
  )
97
 
98
  demo.queue().launch()
 
8
  sys.path.insert(0, os.path.join(_HERE, "diffusers_src", "src"))
9
 
10
  import random
11
+ from typing import List, Literal, Union
12
 
13
  import spaces
14
  import torch
15
  import gradio as gr
16
+ from pydantic import BaseModel, Field
17
  from diffusers import Ideogram4Pipeline
18
+ from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
19
 
20
+ # --- New (safety-fixed) checkpoint ---
21
+ MODEL_ID = "diffusers-internal-dev/ideogram-4-nf4-v2"
22
+ # Generative sibling of the text encoder, used for prompt upsampling.
23
+ UPSAMPLER_ID = "Qwen/Qwen3-VL-8B-Instruct"
24
 
25
+ MAX_SEED = 2**31 - 1
26
+
27
+ # --- Sampler modes (V4 presets, forward step-order: main CFG 7.0 -> polish 3.0) ---
28
+ MODES = {
29
+ "Turbo · 12 steps": dict(num_inference_steps=12, guidance_schedule=(7.0,) * 11 + (3.0,) * 1, mu=0.5, std=1.75),
30
+ "Default · 20 steps": dict(num_inference_steps=20, guidance_schedule=(7.0,) * 18 + (3.0,) * 2, mu=0.0, std=1.75),
31
+ "Quality · 48 steps": dict(num_inference_steps=48, guidance_schedule=(7.0,) * 45 + (3.0,) * 3, mu=0.0, std=1.5),
32
+ }
33
+
34
+ # --- Pipeline ---
35
  pipe = Ideogram4Pipeline.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)
36
  pipe.to("cuda")
37
 
38
+ # --- Prompt upsampler (Qwen3-VL-8B-Instruct, generative) ---
39
+ upsampler = Qwen3VLForConditionalGeneration.from_pretrained(UPSAMPLER_ID, torch_dtype=torch.bfloat16)
40
+ upsampler.to("cuda")
41
+ upsampler_proc = AutoProcessor.from_pretrained(UPSAMPLER_ID)
42
+
43
+ try:
44
+ import outlines
45
+ OUTLINES_AVAILABLE = True
46
+ except Exception:
47
+ OUTLINES_AVAILABLE = False
48
+
49
+
50
+ # --- Caption schema (matches Ideogram's native caption / caption_verifier) ---
51
+ class ObjElement(BaseModel):
52
+ type: Literal["obj"]
53
+ desc: str
54
+
55
+
56
+ class TextElement(BaseModel):
57
+ type: Literal["text"]
58
+ text: str
59
+ desc: str
60
+
61
+
62
+ class Composition(BaseModel):
63
+ background: str
64
+ elements: List[Union[ObjElement, TextElement]] = Field(min_length=1)
65
+
66
+
67
+ class Caption(BaseModel):
68
+ high_level_description: str
69
+ compositional_deconstruction: Composition
70
+
71
+
72
def _load_sections(path):
    """Parse an INI-like text file into a ``{section_name: body}`` dict.

    A header is any line that, once stripped, starts with ``[``, ends with
    ``]`` and contains no space (e.g. ``[system]``). Every following line up
    to the next header belongs to that section.

    Args:
        path: Path of the UTF-8 text file to read.

    Returns:
        A dict mapping lower-cased section names to their body text. Body
        lines are joined with ``"\\n"`` and stripped of surrounding
        whitespace. Lines before the first header are discarded, and a
        repeated header overwrites the earlier section of the same name.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # Read inside a context manager so the handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(path, encoding="utf-8") as handle:
        lines = handle.read().splitlines()

    sections, cur, buf = {}, None, []
    for line in lines:
        s = line.strip()
        if s.startswith("[") and s.endswith("]") and " " not in s:
            # New header: store the section collected so far, if any.
            if cur is not None:
                sections[cur] = "\n".join(buf).strip()
            cur, buf = s[1:-1].lower(), []
        else:
            buf.append(line)
    # Store the trailing section, which has no header after it.
    if cur is not None:
        sections[cur] = "\n".join(buf).strip()
    return sections
85
+
86
 
87
+ _SEC = _load_sections(os.path.join(_HERE, "v6.txt"))
88
+ SYSTEM_PROMPT = _SEC["system"]
89
+ USER_TEMPLATE = _SEC.get("user", "User idea: {{original_prompt}}")
90
 
91
+ _logits_processor = None # built lazily (compiles the schema -> FSM once)
92
+
93
+
94
def _get_logits_processor():
    """Return the shared Outlines logits processor, building it on first use.

    The processor is created from the ``Caption`` schema and cached in the
    module-level ``_logits_processor``. When ``outlines`` could not be
    imported, nothing is built and the cached value (``None``) is returned.
    """
    global _logits_processor
    # Already built, or nothing to build with: hand back whatever is cached.
    if _logits_processor is not None or not OUTLINES_AVAILABLE:
        return _logits_processor
    wrapped_model = outlines.from_transformers(upsampler, upsampler_proc.tokenizer)
    _logits_processor = outlines.Generator(wrapped_model, Caption).logits_processor
    return _logits_processor
100
+
101
+
102
def upsample_prompt(prompt: str, width: int, height: int) -> str:
    """Rewrite ``prompt`` with the upsampler model and return the generated text.

    The reduced ``width:height`` ratio and the original prompt are substituted
    into ``USER_TEMPLATE``, sent with ``SYSTEM_PROMPT`` through the chat
    template, and sampled from ``upsampler``. When a logits processor is
    available it is reset and passed to ``generate``.

    Args:
        prompt: The user's original prompt.
        width: Target image width, used only to compute the aspect ratio.
        height: Target image height, used only to compute the aspect ratio.

    Returns:
        The newly generated text (prompt tokens excluded), decoded without
        special tokens and stripped of surrounding whitespace.
    """
    from math import gcd

    # gcd(0, 0) is 0; fall back to 1 so the divisions below stay defined.
    divisor = gcd(width, height) or 1
    aspect_ratio = f"{width // divisor}:{height // divisor}"

    user_message = USER_TEMPLATE.replace("{{aspect_ratio}}", aspect_ratio)
    user_message = user_message.replace("{{original_prompt}}", prompt)
    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_message},
    ]
    encoded = upsampler_proc.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(upsampler.device)

    generation_args = {
        "max_new_tokens": 1024,
        "do_sample": True,
        "temperature": 1.0,
        "use_cache": True,
    }
    processor = _get_logits_processor()
    if processor is not None:
        # The processor is shared across calls, so reset it before reuse.
        processor.reset()
        generation_args["logits_processor"] = [processor]

    sequences = upsampler.generate(**encoded, **generation_args)
    # Drop the prompt tokens so only the continuation is decoded.
    prompt_length = encoded["input_ids"].shape[1]
    decoded = upsampler_proc.batch_decode(
        sequences[:, prompt_length:], skip_special_tokens=True
    )
    return decoded[0].strip()
121
+
122
+
123
@spaces.GPU(duration=240)
def generate(
    prompt: str,
    mode: str,
    enhance: bool,
    width: int,
    height: int,
    seed: int,
    randomize_seed: bool,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image, optionally upsampling the prompt first.

    Args:
        prompt: The user's prompt text.
        mode: Key into ``MODES``; an unknown key falls back to
            ``"Default · 20 steps"``.
        enhance: When true, ``prompt`` is rewritten by ``upsample_prompt``
            before being passed to the pipeline.
        width: Output width in pixels (coerced to ``int``).
        height: Output height in pixels (coerced to ``int``).
        seed: Requested seed. ``None`` or a negative value means "pick one".
        randomize_seed: When true, ``seed`` is ignored and a random one drawn.
        progress: Gradio progress tracker created with ``track_tqdm=True``.

    Returns:
        A ``(image, seed, final_prompt)`` tuple: the first generated image,
        the integer seed actually used, and the prompt that was fed to the
        pipeline.
    """
    # A cleared Number field arrives as None; comparing None < 0 would raise
    # TypeError, so treat it the same as a request for a random seed.
    if randomize_seed or seed is None or seed < 0:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    width, height = int(width), int(height)

    final_prompt = prompt
    if enhance:
        if not OUTLINES_AVAILABLE:
            gr.Warning("`outlines` is not installed — upsampling without structural constraints.")
        final_prompt = upsample_prompt(prompt, width, height)

    generator = torch.Generator(device="cuda").manual_seed(seed)
    preset = MODES.get(mode, MODES["Default · 20 steps"])
    image = pipe(
        prompt=final_prompt,
        width=width,
        height=height,
        generator=generator,
        **preset,
    ).images[0]
    return image, seed, final_prompt
 
 
 
 
 
 
 
 
153
 
154
 
155
+ with gr.Blocks(theme=gr.themes.Citrus(), title="Ideogram 4 (NF4) — diffusers preview") as demo:
156
  gr.Markdown(
157
  "## Ideogram 4 (NF4) — diffusers preview\n"
158
  f"Private demo of [`{MODEL_ID}`](https://huggingface.co/{MODEL_ID}) using the "
159
+ "[diffusers PR](https://github.com/huggingface/diffusers-new-model-addition-ideogram) branch, on ZeroGPU.\n"
160
+ "Toggle **Prompt upsampling** in Advanced to rewrite your idea into Ideogram's native structured caption "
161
+ "(Qwen3-VL-8B + Outlines)."
162
  )
163
 
164
  with gr.Row():
 
168
  value="A photo of a cat holding a sign that says hello world",
169
  lines=3,
170
  )
171
+ mode = gr.Radio(
172
+ choices=list(MODES.keys()),
173
+ value="Default · 20 steps",
174
+ label="Mode (speed ↔ quality)",
175
+ )
176
  run = gr.Button("Generate", variant="primary")
177
  with gr.Accordion("Advanced", open=False):
178
+ enhance = gr.Checkbox(
179
+ label="Prompt upsampling (Outlines)",
180
+ value=False,
181
+ info="Rewrite the prompt into Ideogram's native JSON caption before generating."
182
+ + ("" if OUTLINES_AVAILABLE else " ⚠ outlines not installed — runs unconstrained."),
183
+ )
184
  with gr.Row():
185
  width = gr.Slider(512, 2048, value=1024, step=64, label="Width")
186
  height = gr.Slider(512, 2048, value=1024, step=64, label="Height")
 
 
 
 
 
187
  with gr.Row():
188
  seed = gr.Number(label="Seed", value=0, precision=0)
189
  randomize = gr.Checkbox(label="Randomize seed", value=True)
190
  with gr.Column():
191
  out_image = gr.Image(label="Output", type="pil")
192
  out_seed = gr.Number(label="Seed used", interactive=False, precision=0)
193
+ out_caption = gr.Textbox(
194
+ label="Caption fed to the model (upsampled when enabled)",
195
+ lines=4,
196
+ show_copy_button=True,
197
+ )
198
 
199
  run.click(
200
  generate,
201
+ inputs=[prompt, mode, enhance, width, height, seed, randomize],
202
+ outputs=[out_image, out_seed, out_caption],
203
  )
204
 
205
  demo.queue().launch()