import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import spaces
import torch
import json
import re
import time
import tempfile
import io
from typing import Optional, List, Dict, Any
from PIL import Image
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from diffusers import Flux2KleinPipeline
# ---------- Model IDs ----------
# Repository IDs handed to the `from_pretrained` calls in the model-loading
# section of this file.
# Planner: Qwen3-VL checkpoint that turns the user request into an execution plan.
PLANNER_ID = "InterleaveThinker/InterleaveThinker-Planner-8B"
# Critic: Qwen3-VL checkpoint that judges each generated image and may rewrite the prompt.
CRITIC_ID = "InterleaveThinker/Critic-SFT-8B"
# Image model used for both the initial generation and the follow-up edits.
FLUX_KLEIN_ID = "black-forest-labs/FLUX.2-klein-9B"
# ---------- System Prompts (from the original repo) ----------
# Prompt template sent to the planner model. The only placeholder is
# `{text_input}`, which `generate_interleaved` fills in with `str.replace`.
# The template also contains literal braces (the JSON example at the end), so
# it is not suitable for `str.format` as written.
# The text is part of the model's input: do not "fix" wording or typos here
# without checking how the planner checkpoint was trained.
NARRATIVE_PROMPT_JSON = """
# Task Planner, Orchestrator, and Prompt Engineer System
You are an expert **Task Planner, Orchestrator, and Prompt Engineer**.
Your goal is to analyze a user's request, generate a structured execution plan, and optimize EVERY step's instruction into a highly effective Text-to-Image (T2I) prompt or Image Editing instruction.
## Input Information
Here are the instructions that were involved in this process:
Original User Instruction (user's request): "{text_input}"
## Execution Plan Instructions
1. **Dynamic Step Count (Image Operations Only)**: Determine the necessary number of steps. Every step in your execution plan MUST represent an actual image generation or image editing action. **DO NOT** create separate steps solely for generating text, captions, or summaries.
2. **Complete & Polished Output**: Always aim for a fully realized final product. For visual or creative tasks, the final step MUST result in a fully colored, detailed, and polished output. Do not stop at a draft, outline, or uncolored sketch unless the user explicitly requests it.
3. **Text Generation & Auxiliary Text Rule**:
- If the user specifically asks to render or draw text *inside* the image, include this requirement within the `instruction` field.
- If the user explicitly asks for a *separate* text response (e.g., a caption, summary, explanation, or knowledge grounding) to accompany the image, generate this text and place it in the `auxiliary_text` field of the corresponding image generation step.
- If the user does not explicitly request any separate text or caption, you MUST set `auxiliary_text` to `null`.
## Optimize Prompt Instructions
1. **Prompt Optimization for All Steps**: Convert the `instruction` of EVERY step into a highly effective prompt in the `prompt` field.
- **Step 1 (Generation)**: Create a highly detailed T2I prompt representing the foundational stage. Focus *only* on the Step 1 instruction. Do NOT hallucinate unmentioned details or future elements.
- **Subsequent Steps (Editing)**: Create clear, actionable image editing instructions (e.g., "add a red hat", "change the background to a cyberpunk city") based on the current step's goal.
2. **CRITICAL**: The `prompt` field MUST contain ONLY the pure text prompt or editing instruction. DO NOT include meta-text, prefixes (such as "Step 1:", "Prompt:", "Edit:"), or conversational filler. It must be directly usable by the generation/editing API.
## Output
The output consists of two parts:
1. A Statement - Analysis process and reasoning;
2. A JSON — Planing each step and rewrite the instruction to prompt suitable for generation/editing.
Here is a output example
Part 1: Planning analysis explaining the execution plan. Part 2: Analysis of how the instructions were translated into visual keywords for the T2I prompt and editing instructions.
{
'execution_plan':
[
{'step_number': 1, 'step_name': 'Short name for the step', 'instruction': 'Detailed instruction for this image generation step.', 'prompt': "The optimized, pure T2I prompt suitable for the image generation model. (No 'Step 1:' prefix)", 'auxiliary_text': 'The required caption, summary, or text explanation. Output null if no separate text is explicitly requested.'},
{'step_number': 2, 'step_name': 'Short name for the step', 'instruction': 'Detailed instruction for this image editing step.', 'prompt': "The optimized, pure instruction suitable for the image editing model. (No 'Step 2:' prefix)", 'auxiliary_text': None}
]
}
"""
# Prompt template sent to the critic model together with a before/after image
# pair. It has two placeholders, `{original_instruction}` and
# `{rewritten_prompt}`, both filled in with `str.replace` inside
# `generate_interleaved`. The expected reply is free-form analysis followed by
# a dict with the keys `previous_step_success` and `refine_prompt`.
# The text is part of the model's input: do not "fix" wording or typos here
# without checking how the critic checkpoint was trained.
Iterative_T2I_PROMPT_QWEN = """
# Generation/Edit Evaluation and Prompt Refinement System
You are an expert image editing evaluator and prompt engineer. Your task is to:
1. Evaluate the edited image and output the result in boolean format (True/False).
2. If you think the edited image is not good enough (False), generate an optimized rewritten prompt that addresses the original shortcomings; if you think it is good enough (True), output the [Original Rewritten Prompt].
## Input Information
You have been presented with two images in sequence:
- Original Image: The input image before editing. (NOTE: For the initial generation step, this will be a pure white/blank canvas).
- Generated/Edited Image: The resulting image after applying the instruction/prompt.
Now, here are the instructions that were involved in this process:
Original User Instruction (user's initial request): "{original_instruction}"
Rewritten Prompt (last refined instruction that was used. **NOTE: If this is empty, you must base your evaluation and refinement entirely on the Original User Instruction**): "{rewritten_prompt}"
## Evaluation Instructions
**Evaluate Previous Step (Strict 2-Part Check)**: Carefully compare the **Before Image** and the **After Image**. You must evaluate based on two strict criteria. If the image fails *either* criteria, the step is a FAILURE.
1. **Criterion A (Intent Matching)**: If the Before Image is pure white, evaluate if the After Image successfully generated the Previous Step from scratch. Otherwise, observe the delta (differences). Did the changes match the key meaning and necessary details of the Previous Step?
2. **Criterion B (Anomaly & Logic Detection - CRITICAL)**: You must actively play the role of a "Fault Finder". Do NOT just check if the requested object exists; you MUST check HOW it exists. Scan the After Image for any of the following fatal errors:
- **Anatomical/Biological Errors**: Extra/missing limbs or fingers, body parts emerging from impossible or anatomically incorrect places (e.g., a hand growing out of a chest, stomach, or a wall), distorted faces.
- **Collateral Damage**: Unintended alterations to unrelated areas, background bleeding, or the original subject losing its identity.
## Prompt Refinement Strategy (if NOT GOOD ENOUGH, False)
When generating a new rewritten prompt, analyze:
1. **What went wrong?**
- Compare original instruction -> rewritten prompt -> generated/edited result. *(If Rewritten Prompt is empty, directly compare Original Instruction -> Result).*
- Identify gaps between intent and execution
- Determine if the issue is clarity, specificity, or contradiction
2. **Refinement Approaches:**
**If this is an Initial Generation task (Before image was blank):**
- **Establish Foundation:** Translate the raw user instruction into a comprehensive Text-to-Image prompt.
- **Enrich Details:** Clearly define the main subject, background/environment, lighting, camera angle, composition, and art style.
- **Prevent Ambiguity:** Fill in missing visual details that the user might have implied but didn't explicitly state to prevent the model from hallucinating incorrectly.
- **Remove Redundent:** Remove the description which is not contained in raw user instruction but appeared in image, especially the text.
**If the rewritten prompt was too vague:**
- Add more specific descriptors (exact colors, positions, sizes)
- Include spatial relationships and context
- Specify interaction with existing elements
**If the rewritten prompt was contradictory:**
- Resolve conflicts between requirements
- Prioritize core intent over secondary details
- Simplify complex multi-part instructions
**If important details were lost:**
- Explicitly state preservation requirements
- Add "maintain [aspect]" or "preserve [feature]" clauses
- Reference specific elements from the original image
**If positioning/scale was wrong:**
- Use more precise spatial descriptors
- Add relative size/scale indicators
- Specify foreground/midground/background placement
**If style/appearance was incorrect:**
- Use more specific visual vocabulary
- Add reference to original image's style elements
- Include material/texture/lighting specifications
**If the edit was over/under-processed:**
- Add modifiers like "subtle", "gentle", "dramatic", "significant"
- Specify degree of change more clearly
- Balance enhancement with naturalness
3. **Leverage All Information:**
- Reference what's visible in the original image
- Learn from what the previous rewritten prompt missed
- Use the edited image as feedback on what went wrong
- Maintain what worked, fix what didn't
## Output
The output consists of three parts:
1. A Statement - Analysis process and reasoning;
2. A Boolean - Judge whether the edited images is good enough;
3. A prompt — either the optimized rewritten prompt or the original rewritten prompt.
Here is a output example:
Detailed explanation of evaluation and new rewritten prompt. If edited image is good enough, explain why it meets requirements. If not good enough, explain specific shortcomings.
{
'previous_step_success': 'boolean (True ONLY IF the Intent Check is successful AND the Anomaly Check finds ZERO errors. If ANY anomaly is detected, this MUST be False.)',
'refine_prompt': '[Improved rewritten prompt that addresses identified issues and enhances clarity, specificity, and preservation requirements] if NOT GOOD ENOUGH (False), [original rewritten prompt] if GOOD ENOUGH (True)'
}
"""
# ---------- Utility functions ----------
def parse_llm_json(response_text):
    """Parse the JSON payload of an LLM response.

    The payload is looked for inside an ``<answer>...</answer>`` section
    (matched case-insensitively; a missing closing tag is tolerated so a
    truncated generation can still be salvaged). When no such section exists,
    or it is empty, the whole response is handed to the parser instead.

    Args:
        response_text: Raw text produced by the model.

    Returns:
        Whatever ``json_repair.loads`` builds from the payload, or ``None``
        when extraction or parsing raises.

    Raises:
        ImportError: if the ``json_repair`` package is not installed.
    """
    import json_repair

    try:
        # NOTE(review): the tag name is inferred from the `answer_match`
        # variable this function used; confirm it against the checkpoints'
        # actual output format. An empty pattern here matches the empty string
        # at position 0, which would discard the entire response.
        answer_match = re.search(
            r"<answer>(.*?)(?:</answer>|\Z)",
            response_text,
            re.DOTALL | re.IGNORECASE,
        )
        json_text = answer_match.group(1).strip() if answer_match else ""
        if not json_text:
            # No tagged section (or an empty one): let the parser look at
            # everything the model said.
            json_text = response_text
        return json_repair.loads(json_text)
    except Exception:
        # Best-effort parser: callers treat None as "could not parse".
        return None
def construct_messages(text: str, imgs: Optional[List[str]]) -> List[Dict]:
    """Build a single-turn chat message with interleaved text and images.

    ``text`` may contain ``<image>`` placeholders (case-insensitive). The text
    is cut at each placeholder and the images are slotted in, in order,
    between the pieces; one further image may follow the final piece. Images
    that still have no slot after that are appended at the end, so every
    entry of ``imgs`` reaches the model even when the text carries no
    placeholders at all.

    Args:
        text: Prompt text, optionally containing ``<image>`` placeholders.
        imgs: Image references (paths/URLs) or ``None``/empty for text only.

    Returns:
        A one-element list holding the user message dict expected by the
        processor's chat template.
    """
    user_content: List[Dict[str, Any]] = []
    if not imgs:
        user_content.append({"type": "text", "text": text})
        return [{"role": "user", "content": user_content}]

    # With five or more images each one is capped in resolution.
    cap_resolution = len(imgs) >= 5

    def image_entry(ref: str) -> Dict[str, Any]:
        entry: Dict[str, Any] = {"type": "image", "image": ref}
        if cap_resolution:
            entry["max_pixels"] = 384 * 384
        return entry

    # NOTE(review): `<image>` is the presumed placeholder token; splitting on
    # an empty pattern would instead break the text into single characters.
    parts = re.split(r'(?i)<image>', text)
    for i, part in enumerate(parts):
        if part:
            user_content.append({"type": "text", "text": part})
        if i < len(imgs):
            user_content.append(image_entry(imgs[i]))

    # Images beyond the available slots are still delivered, after the text.
    for ref in imgs[len(parts):]:
        user_content.append(image_entry(ref))

    return [{"role": "user", "content": user_content}]
def qwen3_vl_predict(model, processor, messages, max_new_tokens=4096):
    """Generate a text reply from a Qwen3VL model for chat-style messages.

    The messages are rendered and tokenized with the processor's chat
    template, moved to the model's device, and sampled from with fixed
    sampling settings. Diagnostic ``[DEBUG qwen3_vl_predict]`` lines are
    printed along the way.

    Args:
        model: Model exposing ``device`` and ``generate``.
        processor: Matching processor (chat template, tokenizer, decoding).
        messages: Chat messages as produced by ``construct_messages``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The stripped decoded completion of the first sequence, or ``""`` when
        decoding yields nothing.
    """
    template_options = {
        "tokenize": True,
        "add_generation_prompt": True,
        "return_dict": True,
        "return_tensors": "pt",
    }
    batch = processor.apply_chat_template(messages, **template_options)
    batch = batch.to(model.device)

    key_info = batch.keys() if hasattr(batch, 'keys') else type(batch)
    print(f"[DEBUG qwen3_vl_predict] Input keys: {key_info}")
    shape_info = batch.input_ids.shape if hasattr(batch, 'input_ids') else 'N/A'
    print(f"[DEBUG qwen3_vl_predict] Input IDs shape: {shape_info}")

    # Tail of the prompt, to check that the generation prompt was appended.
    prompt_tail = batch.input_ids[0].tolist()[-20:]
    print(f"[DEBUG qwen3_vl_predict] Last 20 input tokens: {prompt_tail}")
    print(f"[DEBUG qwen3_vl_predict] Decoded last 20: {processor.tokenizer.decode(prompt_tail)}")

    sampling_options = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.8,
        "top_k": 20,
        "use_cache": True,
    }
    with torch.inference_mode():
        sequences = model.generate(**batch, **sampling_options)

    print(f"[DEBUG qwen3_vl_predict] Generated IDs shape: {sequences.shape}")
    print(f"[DEBUG qwen3_vl_predict] Last 5 generated tokens: {sequences[0].tolist()[-5:]}")

    # Drop the prompt tokens so only the newly generated part is decoded.
    completions = []
    for prompt_ids, full_ids in zip(batch.input_ids, sequences):
        completions.append(full_ids[len(prompt_ids):])
    print(f"[DEBUG qwen3_vl_predict] Trimmed lengths: {[len(t) for t in completions]}")

    # First decode keeps special tokens, purely for the diagnostic line.
    with_specials = processor.batch_decode(
        completions, skip_special_tokens=False, clean_up_tokenization_spaces=False
    )
    print(f"[DEBUG qwen3_vl_predict] Raw decode (no skip): {with_specials[0][:200] if with_specials else 'empty'}")

    decoded = processor.batch_decode(
        completions, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    print(f"[DEBUG qwen3_vl_predict] Raw output text length: {[len(t) for t in decoded]}")

    if not decoded:
        return ""
    return decoded[0].strip()
def get_qwen_response(model, processor, prompt: str, image_files=None):
    """Send a prompt, with optional images, to a Qwen3VL model.

    Args:
        model: Model forwarded to ``qwen3_vl_predict``.
        processor: Processor forwarded to ``qwen3_vl_predict``.
        prompt: Prompt text for the user turn.
        image_files: ``None``, a single path string, or an iterable of
            path-like values; each entry is converted with ``str``.

    Returns:
        The text returned by ``qwen3_vl_predict``.
    """
    if image_files is None:
        image_refs = None
    else:
        # A lone string is one image, not an iterable of characters.
        candidates = [image_files] if isinstance(image_files, str) else image_files
        image_refs = [str(entry) for entry in candidates]

    chat = construct_messages(prompt, image_refs)
    return qwen3_vl_predict(model, processor, chat)
# ---------- Model loading at module scope ----------
# Everything below runs at import time: the three models are loaded and moved
# to the GPU once, and the resulting module-level names are used by
# `generate_interleaved`.
print("[Startup] Loading models...")
# 4-bit quantization for the Critic model (less critical, saves VRAM):
# NF4 weights with double quantization, computing in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# Load Planner at bf16 (primary model, needs quality)
print(f"[Startup] Loading Planner (bf16): {PLANNER_ID}")
planner_model = Qwen3VLForConditionalGeneration.from_pretrained(
    PLANNER_ID,
    dtype=torch.bfloat16,
    attn_implementation="sdpa",
).to("cuda").eval()
planner_processor = AutoProcessor.from_pretrained(PLANNER_ID)
print("[Startup] Planner loaded.")
# Load Critic at 4-bit (saves VRAM). Placement is done through `device_map`
# here instead of a `.to("cuda")` call as for the Planner.
print(f"[Startup] Loading Critic (4-bit): {CRITIC_ID}")
critic_model = Qwen3VLForConditionalGeneration.from_pretrained(
    CRITIC_ID,
    quantization_config=bnb_config,
    device_map="cuda",
    attn_implementation="sdpa",
).eval()
critic_processor = AutoProcessor.from_pretrained(CRITIC_ID)
print("[Startup] Critic loaded.")
# Image generation/editing pipeline, loaded in bf16 with its progress bar
# switched off.
print(f"[Startup] Loading FLUX.2-klein: {FLUX_KLEIN_ID}")
flux_pipe = Flux2KleinPipeline.from_pretrained(
    FLUX_KLEIN_ID,
    torch_dtype=torch.bfloat16,
).to("cuda")
flux_pipe.set_progress_bar_config(disable=True)
print("[Startup] FLUX.2-klein loaded.")
# Create a white canvas as initial image for first step; it is what the critic
# sees as the "before" image when judging the first generation.
WHITE_IMAGE = Image.new("RGB", (1024, 1024), color="white")
print("[Startup] All models loaded successfully.")
# ---------- Inference pipeline ----------
# Keys every plan step is expected to carry (see NARRATIVE_PROMPT_JSON).
_REQUIRED_PLAN_KEYS = ('step_number', 'instruction', 'prompt', 'auxiliary_text')


def _usable_plan_steps(plan):
    """Return the plan's non-empty list of step dicts, or None if malformed."""
    if not isinstance(plan, dict):
        return None
    steps = plan.get('execution_plan')
    if not isinstance(steps, list) or not steps:
        return None
    if not all(isinstance(item, dict) for item in steps):
        return None
    return steps


def _request_execution_plan(planner_input, attempts=10):
    """Query the planner until it yields a complete plan; return its steps or None."""
    fallback_steps = None
    for attempt in range(attempts):
        try:
            raw_response = get_qwen_response(planner_model, planner_processor, planner_input)
            print(f"[DEBUG] Planner raw response (attempt {attempt}): {raw_response[:500]}")
            plan = parse_llm_json(raw_response)
            print(f"[DEBUG] Parsed plan: {plan}")
        except Exception as e:
            print(f"Planner attempt {attempt} failed: {e}")
            time.sleep(1)
            continue

        steps = _usable_plan_steps(plan)
        if steps is None:
            continue
        if all(key in item for item in steps for key in _REQUIRED_PLAN_KEYS):
            return steps
        # Structurally usable but missing keys: keep it as a last resort (the
        # step loop reads every field with .get) and ask again.
        print(f"Planner attempt {attempt} returned an incomplete plan, retrying")
        fallback_steps = steps
    return fallback_steps


def _critic_verdict(before_image, after_image, original_prompt, rewritten_prompt):
    """Ask the critic to judge one before/after pair.

    Returns ``(accepted, refined_prompt)``. Raises when the critic cannot be
    run or its answer does not have the expected shape.
    """
    critic_input = Iterative_T2I_PROMPT_QWEN.replace(
        '{original_instruction}', original_prompt
    ).replace('{rewritten_prompt}', rewritten_prompt)
    critic_input = f'\n{critic_input}'

    # The VLM helper takes file paths; a temporary directory guarantees both
    # files are removed on every exit path, including exceptions.
    with tempfile.TemporaryDirectory() as workdir:
        before_path = os.path.join(workdir, "before.png")
        after_path = os.path.join(workdir, "after.png")
        before_image.save(before_path)
        after_image.save(after_path)
        critic_response = get_qwen_response(
            critic_model, critic_processor,
            prompt=critic_input,
            image_files=[before_path, after_path],
        )

    critic_result = parse_llm_json(critic_response)
    if not isinstance(critic_result, dict):
        raise ValueError("Invalid critic response")

    refine_prompt = critic_result.get('refine_prompt', rewritten_prompt)
    judge = critic_result.get('previous_step_success', True)
    if isinstance(judge, str):
        # The prompt's own example quotes the boolean, so accept the textual
        # forms instead of discarding an explicit rejection.
        lowered = judge.strip().lower()
        if lowered in ('true', 'false'):
            judge = lowered == 'true'
    if not isinstance(refine_prompt, str) or not isinstance(judge, bool):
        raise ValueError("Invalid critic response")
    return judge, refine_prompt


@spaces.GPU(duration=300)
def generate_interleaved(
    prompt: str,
    max_steps: int = 4,
    max_retries: int = 2,
    progress=None,
):
    """Generate an interleaved text-image sequence from a visual narrative prompt.

    The planner model turns the prompt into an execution plan; each step is
    rendered with FLUX.2-klein (text-to-image for the first step, editing of
    the previous result afterwards) and checked by the critic, which may
    rewrite the step prompt for another attempt.

    Args:
        prompt: A visual narrative prompt describing what to generate.
        max_steps: Maximum number of plan steps to execute. Integral floats
            (as UI sliders may deliver) are accepted.
        max_retries: Maximum critic-driven retries per step; the attempt made
            after the last retry is accepted without consulting the critic.
        progress: Optional callable invoked as ``progress(fraction, desc=...)``.
            Defaults to a no-op; nothing in this function makes Gradio supply one.

    Returns:
        A list of ``(text, image)`` tuples, one per executed step. When no
        plan could be produced, a single ``(message, None)`` entry.
    """
    # Not used directly; imported so a missing parser dependency fails here,
    # before any GPU time is spent.
    import json_repair  # noqa: F401

    if progress is None:
        def progress(*args, **kwargs):
            pass

    # Slice bounds and range() need real ints; negatives make no sense here.
    max_steps = max(int(max_steps), 0)
    max_retries = max(int(max_retries), 0)

    # Step 1: Planning
    progress(0.05, desc="Planner: analyzing prompt and creating execution plan...")
    user_prompt = prompt.replace(' both visually and textually', '')
    planner_input = NARRATIVE_PROMPT_JSON.replace('{text_input}', user_prompt)
    planned_steps = _request_execution_plan(planner_input)
    if planned_steps is None:
        return [("Planner failed to generate a plan. Please try a different prompt.", None)]

    # Limit steps
    execution_plan = planned_steps[:max_steps]
    total_steps = len(execution_plan)
    print(f"[Inference] Plan has {total_steps} steps")

    results = []
    # The critic compares against this image; it starts as the blank canvas.
    previous_image = WHITE_IMAGE
    for step_idx, plan_item in enumerate(execution_plan):
        step_num = plan_item.get('step_number', step_idx + 1)
        original_prompt = plan_item.get('prompt') or plan_item.get('auxiliary_text') or plan_item.get('instruction', '')
        current_refine_prompt = original_prompt
        step_text = plan_item.get('auxiliary_text') or plan_item.get('instruction') or f"Step {step_num}"
        progress(
            (step_idx + 0.1) / total_steps,
            desc=f"Step {step_num}/{total_steps}: generating image..."
        )

        step_image = None
        for attempt in range(max_retries + 1):
            # Deterministic seed per (step, attempt) so retries differ.
            generator = torch.Generator(device="cuda").manual_seed(int(step_idx * 1000 + attempt * 42))
            flux_kwargs = dict(
                prompt=current_refine_prompt,
                height=1024,
                width=1024,
                guidance_scale=1.0,
                num_inference_steps=4,
                generator=generator,
            )
            if step_idx > 0:
                # Editing step: use last generated image as source. The first
                # step is plain text-to-image and gets no source image.
                flux_kwargs["image"] = previous_image
            step_image = flux_pipe(**flux_kwargs).images[0]

            if attempt >= max_retries:
                print(f"Step {step_num} reached max retries, accepting result.")
                break

            # Critic evaluation
            progress(
                (step_idx + 0.5) / total_steps,
                desc=f"Step {step_num}/{total_steps}: critic evaluating..."
            )
            try:
                accepted, refined_prompt = _critic_verdict(
                    previous_image, step_image, original_prompt, current_refine_prompt
                )
            except Exception as e:
                # The critic is advisory: if it cannot be consulted, keep the image.
                print(f"Critic failed: {e}, accepting image as-is")
                break

            if accepted:
                print(f"Step {step_num} accepted by critic on attempt {attempt + 1}")
                break
            print(f"Step {step_num} rejected by critic, refining prompt...")
            if refined_prompt.strip():
                # An empty rewrite would leave the image model without a prompt.
                current_refine_prompt = refined_prompt

        previous_image = step_image
        results.append((step_text, step_image))
        progress(
            (step_idx + 1) / total_steps,
            desc=f"Step {step_num}/{total_steps}: complete"
        )
    return results
# ---------- Gradio UI ----------
import gradio as gr
CSS = """
#col-container { max-width: 1100px; margin: 0 auto; }
.dark .gradio-container { color: var(--body-text-color); }
"""


def run_and_format(prompt, max_steps=4, max_retries=2):
    """Run the generation pipeline and shape its result for the UI outputs.

    Args:
        prompt: Visual narrative prompt describing what to generate.
        max_steps: Maximum number of plan steps to execute.
        max_retries: Maximum critic retries per step.

    Returns:
        A ``(gallery_items, markdown)`` pair: the gallery items are
        ``(image, caption)`` tuples for every step that produced an image,
        and the markdown lists the text of every step.
    """
    results = generate_interleaved(prompt, max_steps, max_retries)
    # Gallery expects list of (image, caption) tuples
    gallery_items = [
        (img, f"Step {i+1}")
        for i, (text, img) in enumerate(results)
        if img is not None
    ]
    # Markdown summary
    md_lines = []
    for i, (text, _img) in enumerate(results):
        md_lines.append(f"**Step {i+1}:** {text}")
        md_lines.append("")
    return gallery_items, "\n".join(md_lines)


with gr.Blocks(theme=gr.themes.Citrus(), css=CSS) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# 🤖 InterleaveThinker: Reinforcing Agentic Interleaved Generation")
        gr.Markdown(
            "Enter a visual narrative prompt and the system will generate an interleaved "
            "sequence of text and images using a **Planner** → **FLUX.2-klein** → **Critic** pipeline. "
            "[Paper](https://arxiv.org/abs/2606.13679) | "
            "[Planner Model](https://huggingface.co/InterleaveThinker/InterleaveThinker-Planner-8B) | "
            "[Critic Model](https://huggingface.co/InterleaveThinker/Critic-SFT-8B)"
        )
        with gr.Row():
            prompt_input = gr.Textbox(
                label="Visual Narrative Prompt",
                show_label=False,
                placeholder="Describe a visual narrative, e.g., 'How to draw a cat step by step?'",
                container=False,
                scale=4,
                lines=2,
            )
            run_btn = gr.Button("Generate", variant="primary", scale=1)
        with gr.Row():
            max_steps_slider = gr.Slider(
                label="Max Steps",
                minimum=2,
                maximum=6,
                step=1,
                value=3,
            )
            max_retries_slider = gr.Slider(
                label="Critic Retries per Step",
                minimum=0,
                maximum=3,
                step=1,
                value=2,
            )
        gr.Markdown("### Generated Sequence")
        gallery = gr.Gallery(
            label="Interleaved Images",
            columns=3,
            height=500,
            show_label=True,
        )
        output_text = gr.Markdown(label="Step Details")
        # Examples. They must go through run_and_format: it returns the
        # (gallery, markdown) pair that the two outputs expect, whereas
        # generate_interleaved returns a list of (text, image) tuples. Only
        # the prompt is an input, so the wrapper's defaults apply.
        gr.Examples(
            examples=[
                ["How to draw a cat step by step?"],
                ["Create a step-by-step guide on how to make a simple origami crane."],
                ["Illustrate the life cycle of a butterfly from egg to adult."],
                ["Show how to prepare a simple cup of coffee, step by step."],
            ],
            inputs=[prompt_input],
            outputs=[gallery, output_text],
            fn=run_and_format,
            cache_examples=True,
            cache_mode="lazy",
        )

    # The button and pressing Enter in the textbox trigger the same handler.
    run_btn.click(
        fn=run_and_format,
        inputs=[prompt_input, max_steps_slider, max_retries_slider],
        outputs=[gallery, output_text],
        api_name="generate",
    )
    prompt_input.submit(
        fn=run_and_format,
        inputs=[prompt_input, max_steps_slider, max_retries_slider],
        outputs=[gallery, output_text],
        api_name="generate_submit",
    )
demo.launch(mcp_server=True)