(≤70 words)."},
{"role":"user","content": f"{joined_issues}\n\n----\n\n{joined_refined}"}],
temperature=0.2, max_tokens=420
)
final_raw = merged.choices[0].message.content
final_refined = clip77_strict(_extract_tag(final_raw, "refined", original_prompt), 77)
issues_merged = _summarize_issues_lines(_extract_tag(final_raw, "issues", ""), 5)
return {"refined": final_refined, "issues_merged": issues_merged}
finally:
await _maybe_close_async_together(client)
@staticmethod
def merge_vlm_multi_text(vlm_refined_77: str, tags_77: str) -> str:
vlm_tags = _split_tags(vlm_refined_77)
moa_tags = _split_tags(tags_77)
merged = _dedup_keep_order(_order_tags([vlm_tags[0] if vlm_tags else ""], (vlm_tags[1:] + moa_tags)))
merged = [t for t in merged if t]
text = _cleanup_commas(", ".join(merged))
if _count_tokens(text) > 77:
text = clip77_strict(text, 77)
return text
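# Illustrative sketch only (not called by the pipeline): shows the intended
# behavior of CritiCore.merge_vlm_multi_text — the VLM-refined tags lead, the
# MoA tags follow, duplicates are dropped, and the result is clipped to 77 CLIP
# tokens. The tag strings below are made-up examples.
def _example_merge_vlm_multi_text() -> str:
    vlm_refined = "orange tabby cat, lying on a window ledge, soft rim light"
    moa_tags = "stylized 3d, front-facing, window ledge, soft indoor lighting"
    # "window ledge" appears in both inputs and should survive only once.
    return CritiCore.merge_vlm_multi_text(vlm_refined, moa_tags)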
# =========================
# 5) SpecFusion (latent FFT gate)
# =========================
@torch.no_grad()
def frequency_fusion(
x_hi_latent: torch.Tensor,
x_lo_latent: torch.Tensor,
base_c: float = 0.5,
rho_t: float = 0.85,
device=None,
) -> torch.Tensor:
if device is None:
device = x_hi_latent.device
B, C, H, W = x_hi_latent.shape
x_h = x_hi_latent.to(torch.float32).to(device)
x_l = x_lo_latent.to(torch.float32).to(device)
Xh = torch.fft.fftshift(torch.fft.fftn(x_h, dim=(-2, -1)), dim=(-2, -1))
Xl = torch.fft.fftshift(torch.fft.fftn(x_l, dim=(-2, -1)), dim=(-2, -1))
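# Half-width of the centered low-frequency window; larger base_c widens it, larger rho_t shrinks it.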
tau_h = int(H * base_c * (1 - rho_t))
tau_w = int(W * base_c * (1 - rho_t))
mask = torch.ones((B, C, H, W), device=device, dtype=torch.float32)
cy, cx = H // 2, W // 2
if tau_h > 0 and tau_w > 0:
mask[..., cy - tau_h : cy + tau_h, cx - tau_w : cx + tau_w] = rho_t
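# Outside the window the spectrum comes entirely from x_hi_latent (mask == 1); inside it, the low frequencies are a rho_t : (1 - rho_t) blend of the hi and lo latents.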
Xf = Xh * mask + Xl * (1 - mask)
x = torch.fft.ifftn(torch.fft.ifftshift(Xf, dim=(-2, -1)), dim=(-2, -1)).real
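# Tiny latent-space dither (std 1e-3) applied before returning the fused latent.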
x = x + torch.randn_like(x) * 0.001
return x.to(dtype=x_hi_latent.dtype)
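# Minimal smoke test for frequency_fusion (illustrative sketch, not wired into
# the app). Shapes mimic an SDXL latent for a 1024x1024 image (4 x 128 x 128);
# call it manually on CPU to sanity-check the spectral gate.
def _selftest_frequency_fusion() -> torch.Tensor:
    hi = torch.randn(1, 4, 128, 128)
    lo = torch.randn(1, 4, 128, 128)
    fused = frequency_fusion(hi, lo, base_c=0.5, rho_t=RHO_T_DEFAULT, device="cpu")
    # Shape and dtype follow the hi latent; only the spectrum is blended.
    assert fused.shape == hi.shape and fused.dtype == hi.dtype
    return fused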
def _decode_to_pil(latents):
out = decode_image_sdxl(latents, SDXL_i2i)
if isinstance(out, Image.Image):
return out
if hasattr(out, "images"):
return out.images[0]
return out
def _guidance_for_k(k: int) -> float:
if k >= 20: return 12.0
if k >= 10: return 7.5
return 5.2
# =========================
# 6) ONE-variant generator (because UI enforces single selection)
# =========================
async def generate_one_variant(
user_prompt: str,
seed: int,
H: int,
W: int,
total_steps_refine: int,
last_k: int,
guidance: float,
preset: str,
variant_key: str,
out_dir: Optional[Path] = None,
) -> Tuple[Image.Image, str, Dict[str, object]]:
"""
Returns:
img, display_name, meta_dict
"""
meta: Dict[str, object] = {
"user_prompt": user_prompt,
"variant_key": variant_key,
}
def _save(im: Image.Image, display_name: str):
if out_dir is None:
return
out_dir.mkdir(parents=True, exist_ok=True)
safe = re.sub(r"[^a-zA-Z0-9_-]+", "_", display_name)[:120]
im.save(out_dir / f"{safe}.png")
# ----------------------------------------------------------
# Variant 1: Base (Original Prompt) [NO Together needed]
# ----------------------------------------------------------
if variant_key == "base_original":
z0_og, base_og = base_sample_latent(user_prompt, seed=seed, H=H, W=W, neg=DEFAULT_NEG)
meta.update({"note": "SDXL base generation from original prompt."})
_save(base_og, VARIANT_LABELS[variant_key])
return base_og, VARIANT_LABELS[variant_key], meta
# All remaining variants require the Together API
if not TOGETHER_API_KEY:
raise RuntimeError("TOGETHER_API_KEY not set, but selected variant requires Together.")
critic = CritiCore(preset=preset)
# Common refine params
lk = int(last_k)
strength = float(strength_for_last_k(lk, total_steps_refine))
use_guidance = float(guidance) if float(guidance) > 0 else float(_guidance_for_k(lk))
steps = int(total_steps_refine)
meta.update({"strength": strength, "guidance": use_guidance, "steps": steps, "last_k": lk})
# ----------------------------------------------------------
# Variant 2: Base (MoA Tags)
# ----------------------------------------------------------
if variant_key == "base_multi_llm":
pos_tags_77, neg_tags = await critic.make_tags(user_prompt, clip77=True)
z0_enh, base_enh = base_sample_latent(pos_tags_77, seed=seed, H=H, W=W, neg=neg_tags)
meta.update({
"pos_tags_77": pos_tags_77,
"neg_tags": neg_tags,
"note": "SDXL base generation from MoA-generated tags."
})
_save(base_enh, VARIANT_LABELS[variant_key])
return base_enh, VARIANT_LABELS[variant_key], meta
# ----------------------------------------------------------
# Variant 3: CritiFusion (MoA+VLM+SpecFusion)
# ----------------------------------------------------------
if variant_key == "CritiFusion":
pos_tags_77, neg_tags = await critic.make_tags(user_prompt, clip77=True)
comps = await critic.decompose_components(user_prompt)
z0_enh, base_enh = base_sample_latent(pos_tags_77, seed=seed, H=H, W=W, neg=neg_tags)
vlm_out = await critic.vlm_refine(base_enh, pos_tags_77, comps or [])
vlm_agg_77 = vlm_out.get("refined") or pos_tags_77
refined_on_enh = CritiCore.merge_vlm_multi_text(vlm_agg_77, pos_tags_77)
z_ref = img2img_latent(
refined_on_enh, z0_enh,
strength=strength, guidance=use_guidance, steps=steps,
seed=seed + 2100 + lk,
neg=DEFAULT_NEG
)
fused_lat = frequency_fusion(z_ref, z0_enh, base_c=0.5, rho_t=RHO_T_DEFAULT, device=DEVICE)
img_sf = _decode_to_pil(fused_lat)
meta.update({
"pos_tags_77": pos_tags_77,
"neg_tags": neg_tags,
"components": comps,
"vlm_refined_77": vlm_agg_77,
"enhanced_prompt_77": refined_on_enh,
"vlm_issues": vlm_out.get("issues_merged", ""),
"note": "MoA tags + VLM critique prompt + img2img + SpecFusion."
})
_save(img_sf, VARIANT_LABELS[variant_key])
return img_sf, VARIANT_LABELS[variant_key], meta
# ----------------------------------------------------------
# Variant 4: CritiFusion (Original+VLM+SpecFusion)
# ----------------------------------------------------------
if variant_key == "criticore_on_original__specfusion":
pos_tags_77, neg_tags = await critic.make_tags(user_prompt, clip77=True)
comps = await critic.decompose_components(user_prompt)
z0_og, base_og = base_sample_latent(user_prompt, seed=seed, H=H, W=W, neg=DEFAULT_NEG)
vlm_on_og = await critic.vlm_refine(base_og, user_prompt, comps or [])
refined_og_77 = clip77_strict(vlm_on_og.get("refined") or user_prompt, 77)
refined_merge = CritiCore.merge_vlm_multi_text(refined_og_77, pos_tags_77)
z_ref = img2img_latent(
refined_merge, z0_og,
strength=strength, guidance=use_guidance, steps=steps,
seed=seed + 2400 + lk,
neg=DEFAULT_NEG
)
fused_lat = frequency_fusion(z_ref, z0_og, base_c=0.5, rho_t=RHO_T_DEFAULT, device=DEVICE)
img_sf = _decode_to_pil(fused_lat)
meta.update({
"pos_tags_77": pos_tags_77,
"neg_tags": neg_tags,
"components": comps,
"vlm_refined_77": refined_og_77,
"enhanced_prompt_77": refined_merge,
"vlm_issues": vlm_on_og.get("issues_merged", ""),
"note": "Original prompt + VLM critique prompt + img2img + SpecFusion."
})
_save(img_sf, VARIANT_LABELS[variant_key])
return img_sf, VARIANT_LABELS[variant_key], meta
raise ValueError(f"Unknown variant_key: {variant_key}")
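# Illustrative sketch of driving generate_one_variant outside the Gradio UI
# (assumes the SDXL pipelines are loaded; the "base_original" variant does not
# need TOGETHER_API_KEY). Drive it with the module's _run_async helper, e.g.
# _run_async(_example_generate_base_variant()).
async def _example_generate_base_variant():
    img, display_name, meta = await generate_one_variant(
        user_prompt="A fluffy orange cat lying on a window ledge",
        seed=2026, H=1024, W=1024,
        total_steps_refine=50, last_k=37, guidance=0.0,
        preset="hq_preference", variant_key="base_original",
        out_dir=None,
    )
    return img, display_name, meta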
# =========================
# 7) UI callbacks
# =========================
def ui_run_once(
user_prompt: str,
seed: int,
H: int,
W: int,
preset: str,
total_steps_refine: int,
last_k: int,
guidance: float,
enabled_variants_display: List[str],
save_outputs: bool,
out_dir: str,
):
t0 = time.time()
try:
if not user_prompt or not user_prompt.strip():
return [], "Empty prompt."
# display -> internal
display_to_internal = {v: k for k, v in VARIANT_LABELS.items()}
chosen_display = (enabled_variants_display or [])[-1:] # enforce single here too
if not chosen_display:
return [], "Please select ONE variant."
chosen_display = chosen_display[0]
variant_key = display_to_internal.get(chosen_display)
if variant_key is None:
return [], f"Unknown selected variant: {chosen_display}"
out_path = Path(out_dir) if (save_outputs and out_dir) else None
img, disp_name, meta = _run_async(generate_one_variant(
user_prompt=user_prompt.strip(),
seed=int(seed),
H=int(H), W=int(W),
total_steps_refine=int(total_steps_refine),
last_k=int(last_k),
guidance=float(guidance),
preset=preset,
variant_key=variant_key,
out_dir=out_path,
))
meta["ui"] = {
"seed": int(seed),
"H": int(H),
"W": int(W),
"preset": preset,
"total_steps_refine": int(total_steps_refine),
"last_k": int(last_k),
"guidance": float(guidance),
"selected_variant": chosen_display,
"save_outputs": bool(save_outputs),
"out_dir": out_dir if save_outputs else None,
}
meta["elapsed_sec"] = round(time.time() - t0, 3)
gallery = [(img, disp_name)]
return gallery, json.dumps(meta, ensure_ascii=False, indent=2)
except Exception:
return [], traceback.format_exc()
@spaces.GPU
def ui_run_once_gpu(*args, **kwargs):
return ui_run_once(*args, **kwargs)
# =========================
# 8) Single-select enforcement for CheckboxGroup
# =========================
def enforce_single_variant(new_list: List[str], prev_list: List[str]):
new_list = new_list or []
prev_list = prev_list or []
new_set = set(new_list)
prev_set = set(prev_list)
added = list(new_set - prev_set)
if added:
# keep the newly added one
chosen = added[-1]
out = [chosen]
else:
# nothing was newly added (a removal or no change); if several items remain, keep only the last
out = new_list[-1:] if len(new_list) > 1 else new_list
return out, out # update checkbox value + state
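# Illustrative sketch of the single-select behavior: ticking a second pill keeps
# only the newly added label (the strings below are hypothetical display labels).
def _example_enforce_single_variant():
    prev = ["Base (Original Prompt)"]
    new = ["Base (Original Prompt)", "CritiFusion (MoA+VLM+SpecFusion)"]
    value, state = enforce_single_variant(new, prev)
    # Both outputs collapse to ["CritiFusion (MoA+VLM+SpecFusion)"].
    return value, state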
# =========================
# 9) Gradio UI
# =========================
with gr.Blocks(title="CritiFusion (SDXL) Demo") as demo:
gr.Markdown(
f"""
# CritiFusion Demo (SDXL)
The Enabled Variants pills UI is kept, but only one variant can be selected at a time.
Device: {DEVICE_STR}, DType: {DTYPE}
Together API: {'✅ set' if TOGETHER_API_KEY else '❌ missing (set TOGETHER_API_KEY)'}
"""
)
with gr.Row():
with gr.Column(scale=7):
user_prompt = gr.Textbox(
label="Prompt",
value="A fluffy orange cat lying on a window ledge, front-facing, stylized 3D, soft indoor lighting",
lines=3,
)
with gr.Row():
seed = gr.Number(label="Seed", value=2026, precision=0)
preset = gr.Dropdown(label="Preset", choices=["hq_preference"], value="hq_preference")
with gr.Row():
H = gr.Number(label="H", value=1024, precision=0)
W = gr.Number(label="W", value=1024, precision=0)
with gr.Row():
total_steps_refine = gr.Slider(label="total_steps_refine", minimum=10, maximum=80, step=1, value=50)
last_k = gr.Slider(label="last_k", minimum=1, maximum=50, step=1, value=37)
guidance = gr.Slider(
label="Guidance (0 => fallback rule)",
minimum=0.0, maximum=15.0, step=0.1, value=0.0
)
# --- pills UI, but single-select enforced ---
selected_state = gr.State([VARIANT_LABELS["base_original"]])
enabled_variants = gr.CheckboxGroup(
label="Enabled Variants (select ONE)",
choices=[VARIANT_LABELS[k] for k in VARIANT_LABELS.keys()],
value=[VARIANT_LABELS["base_original"]],
)
# enforce single selection on change
enabled_variants.change(
fn=enforce_single_variant,
inputs=[enabled_variants, selected_state],
outputs=[enabled_variants, selected_state],
)
with gr.Row():
save_outputs = gr.Checkbox(label="Save output to disk", value=False)
out_dir = gr.Textbox(label="Output dir (only if save enabled)", value="./variants_demo_gradio")
run_btn = gr.Button("Run", variant="primary")
with gr.Column(scale=8):
gallery = gr.Gallery(label="Result", columns=1, height=600)
meta_json = gr.Code(label="Meta / Debug (JSON)", language="json")
run_btn.click(
fn=ui_run_once_gpu,
inputs=[user_prompt, seed, H, W, preset, total_steps_refine, last_k, guidance, enabled_variants, save_outputs, out_dir],
outputs=[gallery, meta_json],
api_name=False, # gradio-safe (avoid schema issues)
)
demo.queue().launch(
debug=True,
share=True, # optional; helps if you run outside Spaces
)