Upload folder using huggingface_hub
Browse files
app.py
CHANGED
|
@@ -5,17 +5,19 @@ Live demonstration of multimodal generation + coherence evaluation.
|
|
| 5 |
Enter a scene description and the system produces coherent text, image,
|
| 6 |
and audio with real-time MSCI scoring.
|
| 7 |
|
| 8 |
-
Pipeline: HF Inference API (text) + CLIP retrieval (image) + CLAP retrieval (audio)
|
|
|
|
| 9 |
"""
|
| 10 |
|
| 11 |
from __future__ import annotations
|
| 12 |
|
|
|
|
| 13 |
import logging
|
| 14 |
import os
|
| 15 |
import sys
|
| 16 |
import time
|
| 17 |
from pathlib import Path
|
| 18 |
-
from typing import Optional
|
| 19 |
|
| 20 |
import streamlit as st
|
| 21 |
|
|
@@ -71,10 +73,14 @@ html, body, [class*="css"] { font-family: 'Inter', -apple-system, sans-serif; }
|
|
| 71 |
font-size: 0.7rem; font-weight: 600; letter-spacing: 0.03em;
|
| 72 |
}
|
| 73 |
.chip-purple { background: rgba(129,140,248,0.14); color: #a5b4fc; }
|
|
|
|
| 74 |
.chip-green { background: rgba(52,211,153,0.14); color: #6ee7b7; }
|
|
|
|
| 75 |
.chip-dot { width: 6px; height: 6px; border-radius: 50%; }
|
| 76 |
.chip-dot-purple { background: #818cf8; }
|
|
|
|
| 77 |
.chip-dot-green { background: #34d399; }
|
|
|
|
| 78 |
|
| 79 |
.scores-grid {
|
| 80 |
display: grid; grid-template-columns: repeat(4, 1fr);
|
|
@@ -196,6 +202,82 @@ EXAMPLE_PROMPTS = {
|
|
| 196 |
}
|
| 197 |
DOMAIN_ICONS = {"nature": "\U0001f33f", "urban": "\U0001f3d9\ufe0f", "water": "\U0001f30a", "mixed": "\U0001f310", "other": "\U0001f4cd"}
|
| 198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
|
| 200 |
# ---------------------------------------------------------------------------
|
| 201 |
# Cached model loading
|
|
@@ -223,12 +305,136 @@ def get_inference_client():
|
|
| 223 |
return InferenceClient(token=token)
|
| 224 |
|
| 225 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
# ---------------------------------------------------------------------------
|
| 227 |
# Generation / retrieval functions
|
| 228 |
# ---------------------------------------------------------------------------
|
| 229 |
|
| 230 |
-
def
|
| 231 |
-
"""Generate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
system_prompt = (
|
| 233 |
"You are a concise descriptive writer. "
|
| 234 |
"Write a literal description of the scene in 3 to 5 natural sentences. "
|
|
@@ -236,21 +442,19 @@ def gen_text_hf(prompt: str) -> dict:
|
|
| 236 |
"Focus on concrete visual details AND the likely audio ambience."
|
| 237 |
)
|
| 238 |
try:
|
| 239 |
-
|
| 240 |
-
response = client.chat_completion(
|
| 241 |
-
messages=[
|
| 242 |
-
{"role": "system", "content": system_prompt},
|
| 243 |
-
{"role": "user", "content": f"Describe this scene: {prompt}"},
|
| 244 |
-
],
|
| 245 |
-
max_tokens=250,
|
| 246 |
-
)
|
| 247 |
-
text = response.choices[0].message.content.strip()
|
| 248 |
if not text:
|
| 249 |
raise ValueError("Empty response")
|
| 250 |
-
return {"text": text, "image_prompt": prompt, "audio_prompt": prompt, "plan": None}
|
| 251 |
except Exception as e:
|
| 252 |
-
logger.warning("HF
|
| 253 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 254 |
|
| 255 |
|
| 256 |
def retrieve_image(prompt: str) -> dict:
|
|
@@ -334,6 +538,20 @@ def main():
|
|
| 334 |
|
| 335 |
# Sidebar
|
| 336 |
with st.sidebar:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
st.markdown("#### Examples")
|
| 338 |
for dname, prompts in EXAMPLE_PROMPTS.items():
|
| 339 |
icon = DOMAIN_ICONS.get(dname.lower(), "\U0001f4cd")
|
|
@@ -343,16 +561,23 @@ def main():
|
|
| 343 |
st.session_state["prompt_input"] = p
|
| 344 |
|
| 345 |
st.divider()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
st.markdown(
|
| 347 |
-
'<div class="sidebar-info">'
|
| 348 |
-
'<b>Text</b> HF Inference API<br>'
|
| 349 |
-
'<b>
|
| 350 |
-
'<b>
|
| 351 |
-
'<b>
|
| 352 |
-
'<b>
|
| 353 |
-
'
|
| 354 |
-
'
|
| 355 |
-
'
|
|
|
|
| 356 |
|
| 357 |
# Prompt input
|
| 358 |
default_prompt = st.session_state.get("prompt_input", "")
|
|
@@ -367,11 +592,14 @@ def main():
|
|
| 367 |
with bc1:
|
| 368 |
go = st.button("Generate Bundle", type="primary", use_container_width=True, disabled=not prompt.strip())
|
| 369 |
with bc2:
|
|
|
|
|
|
|
|
|
|
| 370 |
st.markdown(
|
| 371 |
-
'<div class="chip-row">'
|
| 372 |
-
'<span class="chip chip-purple"><span class="chip-dot chip-dot-purple"></span>Retrieval</span>'
|
| 373 |
-
'<span class="chip
|
| 374 |
-
'</div>', unsafe_allow_html=True)
|
| 375 |
|
| 376 |
# Welcome state
|
| 377 |
if not go and "last_result" not in st.session_state:
|
|
@@ -384,7 +612,7 @@ def main():
|
|
| 384 |
return
|
| 385 |
|
| 386 |
if go and prompt.strip():
|
| 387 |
-
st.session_state["last_result"] = run_pipeline(prompt.strip())
|
| 388 |
|
| 389 |
if "last_result" in st.session_state:
|
| 390 |
show_results(st.session_state["last_result"])
|
|
@@ -394,17 +622,22 @@ def main():
|
|
| 394 |
# Pipeline
|
| 395 |
# ---------------------------------------------------------------------------
|
| 396 |
|
| 397 |
-
def run_pipeline(prompt: str) -> dict:
|
| 398 |
-
R: dict = {}
|
| 399 |
t_all = time.time()
|
| 400 |
|
| 401 |
-
# 1) Text
|
| 402 |
-
|
|
|
|
| 403 |
t0 = time.time()
|
| 404 |
try:
|
| 405 |
-
R["text"] =
|
| 406 |
R["t_text"] = time.time() - t0
|
| 407 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
except Exception as e:
|
| 409 |
s.update(label=f"Text failed: {e}", state="error")
|
| 410 |
R["text"] = {"text": prompt, "image_prompt": prompt, "audio_prompt": prompt}
|
|
@@ -554,6 +787,37 @@ def show_results(R: dict):
|
|
| 554 |
st.markdown("---")
|
| 555 |
|
| 556 |
# Expandable details
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 557 |
with st.expander("Retrieval Details"):
|
| 558 |
r1, r2 = st.columns(2)
|
| 559 |
with r1:
|
|
|
|
| 5 |
Enter a scene description and the system produces coherent text, image,
|
| 6 |
and audio with real-time MSCI scoring.
|
| 7 |
|
| 8 |
+
Pipeline: HF Inference API (text + planning) + CLIP retrieval (image) + CLAP retrieval (audio)
|
| 9 |
+
Planning modes: direct, planner, council (3-way), extended_prompt (3x tokens)
|
| 10 |
"""
|
| 11 |
|
| 12 |
from __future__ import annotations
|
| 13 |
|
| 14 |
+
import json
|
| 15 |
import logging
|
| 16 |
import os
|
| 17 |
import sys
|
| 18 |
import time
|
| 19 |
from pathlib import Path
|
| 20 |
+
from typing import Any, Dict, Optional
|
| 21 |
|
| 22 |
import streamlit as st
|
| 23 |
|
|
|
|
| 73 |
font-size: 0.7rem; font-weight: 600; letter-spacing: 0.03em;
|
| 74 |
}
|
| 75 |
.chip-purple { background: rgba(129,140,248,0.14); color: #a5b4fc; }
|
| 76 |
+
.chip-pink { background: rgba(244,114,182,0.14); color: #f9a8d4; }
|
| 77 |
.chip-green { background: rgba(52,211,153,0.14); color: #6ee7b7; }
|
| 78 |
+
.chip-amber { background: rgba(251,191,36,0.12); color: #fcd34d; }
|
| 79 |
.chip-dot { width: 6px; height: 6px; border-radius: 50%; }
|
| 80 |
.chip-dot-purple { background: #818cf8; }
|
| 81 |
+
.chip-dot-pink { background: #f472b6; }
|
| 82 |
.chip-dot-green { background: #34d399; }
|
| 83 |
+
.chip-dot-amber { background: #fbbf24; }
|
| 84 |
|
| 85 |
.scores-grid {
|
| 86 |
display: grid; grid-template-columns: repeat(4, 1fr);
|
|
|
|
| 202 |
}
|
| 203 |
DOMAIN_ICONS = {"nature": "\U0001f33f", "urban": "\U0001f3d9\ufe0f", "water": "\U0001f30a", "mixed": "\U0001f310", "other": "\U0001f4cd"}
|
| 204 |
|
| 205 |
+
# ---------------------------------------------------------------------------
|
| 206 |
+
# Planning prompt template (same as src/planner/prompts/unified.txt)
|
| 207 |
+
# ---------------------------------------------------------------------------
|
| 208 |
+
PLAN_PROMPT_TEMPLATE = """You must produce a SINGLE valid JSON object.
|
| 209 |
+
|
| 210 |
+
RULES:
|
| 211 |
+
- Every field MUST exist
|
| 212 |
+
- Fields that represent lists MUST be arrays
|
| 213 |
+
- Strings must never be arrays
|
| 214 |
+
- Use short phrases, not long paragraphs
|
| 215 |
+
- Do NOT include explanations
|
| 216 |
+
- Do NOT include markdown
|
| 217 |
+
- Do NOT truncate
|
| 218 |
+
|
| 219 |
+
Schema:
|
| 220 |
+
{
|
| 221 |
+
"scene_summary": string,
|
| 222 |
+
"domain": string,
|
| 223 |
+
|
| 224 |
+
"core_semantics": {
|
| 225 |
+
"setting": string,
|
| 226 |
+
"time_of_day": string,
|
| 227 |
+
"weather": string,
|
| 228 |
+
"main_subjects": [string],
|
| 229 |
+
"actions": [string]
|
| 230 |
+
},
|
| 231 |
+
|
| 232 |
+
"style_controls": {
|
| 233 |
+
"visual_style": [string],
|
| 234 |
+
"color_palette": [string],
|
| 235 |
+
"lighting": [string],
|
| 236 |
+
"camera": [string],
|
| 237 |
+
"mood_emotion": [string],
|
| 238 |
+
"narrative_tone": [string]
|
| 239 |
+
},
|
| 240 |
+
|
| 241 |
+
"image_constraints": {
|
| 242 |
+
"must_include": [string],
|
| 243 |
+
"must_avoid": [string],
|
| 244 |
+
"objects": [string],
|
| 245 |
+
"environment_details": [string],
|
| 246 |
+
"composition": [string]
|
| 247 |
+
},
|
| 248 |
+
|
| 249 |
+
"audio_constraints": {
|
| 250 |
+
"audio_intent": [string],
|
| 251 |
+
"sound_sources": [string],
|
| 252 |
+
"ambience": [string],
|
| 253 |
+
"tempo": string,
|
| 254 |
+
"must_include": [string],
|
| 255 |
+
"must_avoid": [string]
|
| 256 |
+
},
|
| 257 |
+
|
| 258 |
+
"text_constraints": {
|
| 259 |
+
"must_include": [string],
|
| 260 |
+
"must_avoid": [string],
|
| 261 |
+
"keywords": [string],
|
| 262 |
+
"length": string
|
| 263 |
+
}
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
User request:
|
| 267 |
+
"""
|
| 268 |
+
|
| 269 |
+
EXTENDED_PLAN_SYSTEM = """You are an expert multimodal content planner. Create a detailed,
|
| 270 |
+
comprehensive semantic plan for generating coherent multimodal content (text, image, audio).
|
| 271 |
+
|
| 272 |
+
You have an extended budget. Take your time to:
|
| 273 |
+
1. Deeply analyze the user's request
|
| 274 |
+
2. Consider multiple perspectives and interpretations
|
| 275 |
+
3. Ensure semantic consistency across all modalities
|
| 276 |
+
4. Provide rich, detailed specifications
|
| 277 |
+
|
| 278 |
+
Think step by step about what visual elements, sounds, and descriptive text would best represent the scene.
|
| 279 |
+
After your analysis, produce a SINGLE valid JSON object matching the schema."""
|
| 280 |
+
|
| 281 |
|
| 282 |
# ---------------------------------------------------------------------------
|
| 283 |
# Cached model loading
|
|
|
|
| 305 |
return InferenceClient(token=token)
|
| 306 |
|
| 307 |
|
| 308 |
+
# ---------------------------------------------------------------------------
|
| 309 |
+
# HF Inference API helpers
|
| 310 |
+
# ---------------------------------------------------------------------------
|
| 311 |
+
|
| 312 |
+
def _hf_chat(system: str, user: str, max_tokens: int = 500, temperature: float = 0.3) -> str:
    """Send one system + user exchange to the HF Inference API chat endpoint.

    Args:
        system: Content of the system-role message.
        user: Content of the user-role message.
        max_tokens: Forwarded unchanged as ``max_tokens`` to ``chat_completion``.
        temperature: Forwarded unchanged as ``temperature`` to ``chat_completion``.

    Returns:
        The first choice's message content with surrounding whitespace
        stripped, or ``""`` when the response carries no content.

    Raises:
        Any exception raised by ``get_inference_client()`` or by
        ``chat_completion`` propagates to the caller unchanged.
    """
    client = get_inference_client()
    response = client.chat_completion(
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ],
        max_tokens=max_tokens,
        temperature=temperature,
    )
    content = response.choices[0].message.content
    # The message content is optional in the chat-completion response type;
    # treat a missing value as an empty reply instead of failing on
    # ``None.strip()`` with an unhelpful AttributeError.
    return (content or "").strip()
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def _parse_plan_json(raw: str) -> Optional[Dict[str, Any]]:
    """Turn raw LLM output into a plan dict by delegating to the JSON repair helper.

    The project helper is imported lazily, at call time, and its result is
    returned as-is.
    """
    from src.utils import json_repair

    repaired = json_repair.try_repair_json(raw)
    return repaired
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def _validate_and_build_plan(data: Dict[str, Any]):
    """Run the project validator on ``data``, then construct a SemanticPlan from it.

    ``validate_semantic_plan_dict`` is called first; the plan object is only
    built (``SemanticPlan(**data)``) if that call returns normally.
    """
    from src.planner.schema import SemanticPlan
    from src.planner.validation import validate_semantic_plan_dict

    validate_semantic_plan_dict(data)
    plan = SemanticPlan(**data)
    return plan
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
# ---------------------------------------------------------------------------
|
| 341 |
+
# Planning functions (HF Inference API)
|
| 342 |
+
# ---------------------------------------------------------------------------
|
| 343 |
+
|
| 344 |
+
def plan_single(prompt: str) -> Optional[Any]:
    """Make one planner request through the HF API.

    Returns the built plan object, or ``None`` when the reply could not be
    parsed into a non-empty dict or any step raised (the error is logged as
    a warning).
    """
    try:
        reply = _hf_chat(
            "You are a multimodal content planner. Output ONLY valid JSON, no explanations.",
            PLAN_PROMPT_TEMPLATE + prompt,
            max_tokens=1200,
            temperature=0.3,
        )
        parsed = _parse_plan_json(reply)
        if not parsed:
            # Nothing usable came back; the caller treats None as "no plan".
            return None
        return _validate_and_build_plan(parsed)
    except Exception as e:
        logger.warning("Planner call failed: %s", e)
        return None
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def plan_council(prompt: str) -> Optional[Any]:
    """Council mode: request up to three plans and merge them.

    The same planner prompt is sent at three temperatures. Each reply that
    parses and validates contributes one plan; failures are logged and
    skipped. Returns ``None`` if no plan was obtained, the lone plan if only
    one succeeded, otherwise the merged plan (falling back to the first plan
    if merging raises).
    """
    system_msg = "You are a multimodal content planner. Output ONLY valid JSON, no explanations."
    user_msg = PLAN_PROMPT_TEMPLATE + prompt

    candidates = []
    # Slightly different temperatures for diversity
    for temperature in (0.2, 0.4, 0.5):
        try:
            reply = _hf_chat(system_msg, user_msg, max_tokens=1200, temperature=temperature)
            parsed = _parse_plan_json(reply)
            if parsed:
                candidates.append(_validate_and_build_plan(parsed))
        except Exception as e:
            logger.warning("Council call failed (temp=%.1f): %s", temperature, e)

    if not candidates:
        return None

    first = candidates[0]
    if len(candidates) == 1:
        return first

    # Merge using the existing merge logic, which takes exactly three plans;
    # repeat the first plan to fill any slot left empty by a failed call.
    try:
        from src.planner.merge_logic import merge_council_plans

        padded = candidates + [first] * (3 - len(candidates))
        plan_a, plan_b, plan_c = padded
        merged, _ = merge_council_plans(plan_a, plan_b, plan_c)
    except Exception as e:
        logger.warning("Merge failed: %s — using first plan", e)
        return first
    return merged
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def plan_extended(prompt: str) -> Optional[Any]:
    """Extended prompt mode: one planner call with the extended system prompt.

    Uses ``EXTENDED_PLAN_SYSTEM`` as the system message and a 2000-token
    budget. Returns the built plan object, or ``None`` when the reply could
    not be parsed into a non-empty dict or any step raised (logged as a
    warning).
    """
    try:
        reply = _hf_chat(
            EXTENDED_PLAN_SYSTEM,
            PLAN_PROMPT_TEMPLATE + prompt,
            max_tokens=2000,
            temperature=0.35,
        )
        parsed = _parse_plan_json(reply)
        if not parsed:
            # Nothing usable came back; the caller treats None as "no plan".
            return None
        return _validate_and_build_plan(parsed)
    except Exception as e:
        logger.warning("Extended planner failed: %s", e)
        return None
|
| 403 |
+
|
| 404 |
+
|
| 405 |
# ---------------------------------------------------------------------------
|
| 406 |
# Generation / retrieval functions
|
| 407 |
# ---------------------------------------------------------------------------
|
| 408 |
|
| 409 |
+
def gen_text(prompt: str, mode: str) -> dict:
|
| 410 |
+
"""Generate text and optional plan using HF Inference API."""
|
| 411 |
+
# Step 1: Plan (if not direct mode)
|
| 412 |
+
plan = None
|
| 413 |
+
image_prompt = prompt
|
| 414 |
+
audio_prompt = prompt
|
| 415 |
+
|
| 416 |
+
if mode == "planner":
|
| 417 |
+
plan = plan_single(prompt)
|
| 418 |
+
elif mode == "council":
|
| 419 |
+
plan = plan_council(prompt)
|
| 420 |
+
elif mode == "extended_prompt":
|
| 421 |
+
plan = plan_extended(prompt)
|
| 422 |
+
|
| 423 |
+
# Extract modality-specific prompts from plan
|
| 424 |
+
if plan is not None:
|
| 425 |
+
try:
|
| 426 |
+
from src.planner.schema_to_text import plan_to_prompts
|
| 427 |
+
prompts = plan_to_prompts(plan)
|
| 428 |
+
image_prompt = prompts["image_prompt"]
|
| 429 |
+
audio_prompt = prompts["audio_prompt"]
|
| 430 |
+
text_input = prompts["text_prompt"]
|
| 431 |
+
except Exception as e:
|
| 432 |
+
logger.warning("plan_to_prompts failed: %s", e)
|
| 433 |
+
text_input = prompt
|
| 434 |
+
else:
|
| 435 |
+
text_input = prompt
|
| 436 |
+
|
| 437 |
+
# Step 2: Generate text via HF API
|
| 438 |
system_prompt = (
|
| 439 |
"You are a concise descriptive writer. "
|
| 440 |
"Write a literal description of the scene in 3 to 5 natural sentences. "
|
|
|
|
| 442 |
"Focus on concrete visual details AND the likely audio ambience."
|
| 443 |
)
|
| 444 |
try:
|
| 445 |
+
text = _hf_chat(system_prompt, f"Describe this scene: {text_input}", max_tokens=250, temperature=0.7)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
if not text:
|
| 447 |
raise ValueError("Empty response")
|
|
|
|
| 448 |
except Exception as e:
|
| 449 |
+
logger.warning("HF text gen failed: %s — using prompt", e)
|
| 450 |
+
text = prompt
|
| 451 |
+
|
| 452 |
+
return {
|
| 453 |
+
"text": text,
|
| 454 |
+
"image_prompt": image_prompt,
|
| 455 |
+
"audio_prompt": audio_prompt,
|
| 456 |
+
"plan": plan.model_dump() if plan and hasattr(plan, "model_dump") else None,
|
| 457 |
+
}
|
| 458 |
|
| 459 |
|
| 460 |
def retrieve_image(prompt: str) -> dict:
|
|
|
|
| 538 |
|
| 539 |
# Sidebar
|
| 540 |
with st.sidebar:
|
| 541 |
+
st.markdown("#### Configuration")
|
| 542 |
+
|
| 543 |
+
mode = st.selectbox(
|
| 544 |
+
"Planning Mode",
|
| 545 |
+
["direct", "planner", "council", "extended_prompt"],
|
| 546 |
+
format_func=lambda x: {
|
| 547 |
+
"direct": "Direct",
|
| 548 |
+
"planner": "Planner (single LLM call)",
|
| 549 |
+
"council": "Council (3-way merge)",
|
| 550 |
+
"extended_prompt": "Extended (3x tokens)",
|
| 551 |
+
}[x],
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
st.divider()
|
| 555 |
st.markdown("#### Examples")
|
| 556 |
for dname, prompts in EXAMPLE_PROMPTS.items():
|
| 557 |
icon = DOMAIN_ICONS.get(dname.lower(), "\U0001f4cd")
|
|
|
|
| 561 |
st.session_state["prompt_input"] = p
|
| 562 |
|
| 563 |
st.divider()
|
| 564 |
+
mode_desc = {
|
| 565 |
+
"direct": "Prompt used directly for all modalities",
|
| 566 |
+
"planner": "LLM creates a semantic plan with image/audio prompts",
|
| 567 |
+
"council": "3 LLM calls merged for richer planning",
|
| 568 |
+
"extended_prompt": "Single LLM call with 3x token budget",
|
| 569 |
+
}
|
| 570 |
st.markdown(
|
| 571 |
+
f'<div class="sidebar-info">'
|
| 572 |
+
f'<b>Text</b> HF Inference API<br>'
|
| 573 |
+
f'<b>Planning</b> {mode_desc[mode]}<br>'
|
| 574 |
+
f'<b>Image</b> CLIP retrieval (57 images)<br>'
|
| 575 |
+
f'<b>Audio</b> CLAP retrieval (104 clips)<br><br>'
|
| 576 |
+
f'<b>Metric</b> MSCI = 0.45 × s<sub>t,i</sub> + 0.45 × s<sub>t,a</sub><br><br>'
|
| 577 |
+
f'<b>Models</b><br>'
|
| 578 |
+
f'CLIP ViT-B/32 (text-image)<br>'
|
| 579 |
+
f'CLAP HTSAT-unfused (text-audio)'
|
| 580 |
+
f'</div>', unsafe_allow_html=True)
|
| 581 |
|
| 582 |
# Prompt input
|
| 583 |
default_prompt = st.session_state.get("prompt_input", "")
|
|
|
|
| 592 |
with bc1:
|
| 593 |
go = st.button("Generate Bundle", type="primary", use_container_width=True, disabled=not prompt.strip())
|
| 594 |
with bc2:
|
| 595 |
+
mlbl = {"direct": "Direct", "planner": "Planner", "council": "Council", "extended_prompt": "Extended"}[mode]
|
| 596 |
+
mcls = "chip-amber" if mode != "direct" else "chip-purple"
|
| 597 |
+
mdot = "chip-dot-amber" if mode != "direct" else "chip-dot-purple"
|
| 598 |
st.markdown(
|
| 599 |
+
f'<div class="chip-row">'
|
| 600 |
+
f'<span class="chip chip-purple"><span class="chip-dot chip-dot-purple"></span>Retrieval</span>'
|
| 601 |
+
f'<span class="chip {mcls}"><span class="chip-dot {mdot}"></span>{mlbl}</span>'
|
| 602 |
+
f'</div>', unsafe_allow_html=True)
|
| 603 |
|
| 604 |
# Welcome state
|
| 605 |
if not go and "last_result" not in st.session_state:
|
|
|
|
| 612 |
return
|
| 613 |
|
| 614 |
if go and prompt.strip():
|
| 615 |
+
st.session_state["last_result"] = run_pipeline(prompt.strip(), mode)
|
| 616 |
|
| 617 |
if "last_result" in st.session_state:
|
| 618 |
show_results(st.session_state["last_result"])
|
|
|
|
| 622 |
# Pipeline
|
| 623 |
# ---------------------------------------------------------------------------
|
| 624 |
|
| 625 |
+
def run_pipeline(prompt: str, mode: str) -> dict:
|
| 626 |
+
R: dict = {"mode": mode}
|
| 627 |
t_all = time.time()
|
| 628 |
|
| 629 |
+
# 1) Text + Planning
|
| 630 |
+
plan_label = "Generating text..." if mode == "direct" else f"Planning ({mode}) + generating text..."
|
| 631 |
+
with st.status(plan_label, expanded=True) as s:
|
| 632 |
t0 = time.time()
|
| 633 |
try:
|
| 634 |
+
R["text"] = gen_text(prompt, mode)
|
| 635 |
R["t_text"] = time.time() - t0
|
| 636 |
+
has_plan = R["text"].get("plan") is not None
|
| 637 |
+
lbl = f"Text ready ({R['t_text']:.1f}s)"
|
| 638 |
+
if has_plan:
|
| 639 |
+
lbl = f"Plan + text ready ({R['t_text']:.1f}s)"
|
| 640 |
+
s.update(label=lbl, state="complete")
|
| 641 |
except Exception as e:
|
| 642 |
s.update(label=f"Text failed: {e}", state="error")
|
| 643 |
R["text"] = {"text": prompt, "image_prompt": prompt, "audio_prompt": prompt}
|
|
|
|
| 787 |
st.markdown("---")
|
| 788 |
|
| 789 |
# Expandable details
|
| 790 |
+
with st.expander("Semantic Plan"):
|
| 791 |
+
td = R.get("text", {})
|
| 792 |
+
plan = td.get("plan")
|
| 793 |
+
if plan:
|
| 794 |
+
p1, p2 = st.columns(2)
|
| 795 |
+
with p1:
|
| 796 |
+
dash = "\u2014"
|
| 797 |
+
dot = "\u00b7"
|
| 798 |
+
scene = plan.get("scene_summary", dash)
|
| 799 |
+
domain = plan.get("domain", dash)
|
| 800 |
+
core = plan.get("core_semantics", {})
|
| 801 |
+
setting = core.get("setting", dash)
|
| 802 |
+
tod = core.get("time_of_day", dash)
|
| 803 |
+
weather = core.get("weather", dash)
|
| 804 |
+
subjects = ", ".join(core.get("main_subjects", []))
|
| 805 |
+
st.markdown(f"**Scene** {scene}")
|
| 806 |
+
st.markdown(f"**Domain** {domain}")
|
| 807 |
+
st.markdown(f"**Setting** {setting} {dot} **Time** {tod} {dot} **Weather** {weather}")
|
| 808 |
+
st.markdown(f"**Subjects** {subjects}")
|
| 809 |
+
with p2:
|
| 810 |
+
st.markdown("**Image prompt**")
|
| 811 |
+
st.code(td.get("image_prompt", ""), language=None)
|
| 812 |
+
st.markdown("**Audio prompt**")
|
| 813 |
+
st.code(td.get("audio_prompt", ""), language=None)
|
| 814 |
+
else:
|
| 815 |
+
mode = R.get("mode", "direct")
|
| 816 |
+
if mode == "direct":
|
| 817 |
+
st.write("Direct mode \u2014 no semantic plan. Prompt used as-is for all modalities.")
|
| 818 |
+
else:
|
| 819 |
+
st.write(f"Planning ({mode}) did not produce a valid plan. Fell back to direct mode.")
|
| 820 |
+
|
| 821 |
with st.expander("Retrieval Details"):
|
| 822 |
r1, r2 = st.columns(2)
|
| 823 |
with r1:
|