pratik-250620 committed on
Commit
5f2e51b
·
verified ·
1 Parent(s): 6835659

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +300 -36
app.py CHANGED
@@ -5,17 +5,19 @@ Live demonstration of multimodal generation + coherence evaluation.
5
  Enter a scene description and the system produces coherent text, image,
6
  and audio with real-time MSCI scoring.
7
 
8
- Pipeline: HF Inference API (text) + CLIP retrieval (image) + CLAP retrieval (audio)
 
9
  """
10
 
11
  from __future__ import annotations
12
 
 
13
  import logging
14
  import os
15
  import sys
16
  import time
17
  from pathlib import Path
18
- from typing import Optional
19
 
20
  import streamlit as st
21
 
@@ -71,10 +73,14 @@ html, body, [class*="css"] { font-family: 'Inter', -apple-system, sans-serif; }
71
  font-size: 0.7rem; font-weight: 600; letter-spacing: 0.03em;
72
  }
73
  .chip-purple { background: rgba(129,140,248,0.14); color: #a5b4fc; }
 
74
  .chip-green { background: rgba(52,211,153,0.14); color: #6ee7b7; }
 
75
  .chip-dot { width: 6px; height: 6px; border-radius: 50%; }
76
  .chip-dot-purple { background: #818cf8; }
 
77
  .chip-dot-green { background: #34d399; }
 
78
 
79
  .scores-grid {
80
  display: grid; grid-template-columns: repeat(4, 1fr);
@@ -196,6 +202,82 @@ EXAMPLE_PROMPTS = {
196
  }
197
  DOMAIN_ICONS = {"nature": "\U0001f33f", "urban": "\U0001f3d9\ufe0f", "water": "\U0001f30a", "mixed": "\U0001f310", "other": "\U0001f4cd"}
198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
  # ---------------------------------------------------------------------------
201
  # Cached model loading
@@ -223,12 +305,136 @@ def get_inference_client():
223
  return InferenceClient(token=token)
224
 
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  # ---------------------------------------------------------------------------
227
  # Generation / retrieval functions
228
  # ---------------------------------------------------------------------------
229
 
230
- def gen_text_hf(prompt: str) -> dict:
231
- """Generate descriptive text using HF Inference API."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  system_prompt = (
233
  "You are a concise descriptive writer. "
234
  "Write a literal description of the scene in 3 to 5 natural sentences. "
@@ -236,21 +442,19 @@ def gen_text_hf(prompt: str) -> dict:
236
  "Focus on concrete visual details AND the likely audio ambience."
237
  )
238
  try:
239
- client = get_inference_client()
240
- response = client.chat_completion(
241
- messages=[
242
- {"role": "system", "content": system_prompt},
243
- {"role": "user", "content": f"Describe this scene: {prompt}"},
244
- ],
245
- max_tokens=250,
246
- )
247
- text = response.choices[0].message.content.strip()
248
  if not text:
249
  raise ValueError("Empty response")
250
- return {"text": text, "image_prompt": prompt, "audio_prompt": prompt, "plan": None}
251
  except Exception as e:
252
- logger.warning("HF Inference API failed: %s — using prompt as text", e)
253
- return {"text": prompt, "image_prompt": prompt, "audio_prompt": prompt, "plan": None}
 
 
 
 
 
 
 
254
 
255
 
256
  def retrieve_image(prompt: str) -> dict:
@@ -334,6 +538,20 @@ def main():
334
 
335
  # Sidebar
336
  with st.sidebar:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  st.markdown("#### Examples")
338
  for dname, prompts in EXAMPLE_PROMPTS.items():
339
  icon = DOMAIN_ICONS.get(dname.lower(), "\U0001f4cd")
@@ -343,16 +561,23 @@ def main():
343
  st.session_state["prompt_input"] = p
344
 
345
  st.divider()
 
 
 
 
 
 
346
  st.markdown(
347
- '<div class="sidebar-info">'
348
- '<b>Text</b> HF Inference API<br>'
349
- '<b>Image</b> CLIP retrieval (57 images)<br>'
350
- '<b>Audio</b> CLAP retrieval (104 clips)<br><br>'
351
- '<b>Metric</b> MSCI = 0.45 &times; s<sub>t,i</sub> + 0.45 &times; s<sub>t,a</sub><br><br>'
352
- '<b>Models</b><br>'
353
- 'CLIP ViT-B/32 (text-image)<br>'
354
- 'CLAP HTSAT-unfused (text-audio)'
355
- '</div>', unsafe_allow_html=True)
 
356
 
357
  # Prompt input
358
  default_prompt = st.session_state.get("prompt_input", "")
@@ -367,11 +592,14 @@ def main():
367
  with bc1:
368
  go = st.button("Generate Bundle", type="primary", use_container_width=True, disabled=not prompt.strip())
369
  with bc2:
 
 
 
370
  st.markdown(
371
- '<div class="chip-row">'
372
- '<span class="chip chip-purple"><span class="chip-dot chip-dot-purple"></span>Retrieval</span>'
373
- '<span class="chip chip-green"><span class="chip-dot chip-dot-green"></span>CLIP + CLAP</span>'
374
- '</div>', unsafe_allow_html=True)
375
 
376
  # Welcome state
377
  if not go and "last_result" not in st.session_state:
@@ -384,7 +612,7 @@ def main():
384
  return
385
 
386
  if go and prompt.strip():
387
- st.session_state["last_result"] = run_pipeline(prompt.strip())
388
 
389
  if "last_result" in st.session_state:
390
  show_results(st.session_state["last_result"])
@@ -394,17 +622,22 @@ def main():
394
  # Pipeline
395
  # ---------------------------------------------------------------------------
396
 
397
- def run_pipeline(prompt: str) -> dict:
398
- R: dict = {}
399
  t_all = time.time()
400
 
401
- # 1) Text
402
- with st.status("Generating text...", expanded=True) as s:
 
403
  t0 = time.time()
404
  try:
405
- R["text"] = gen_text_hf(prompt)
406
  R["t_text"] = time.time() - t0
407
- s.update(label=f"Text ready ({R['t_text']:.1f}s)", state="complete")
 
 
 
 
408
  except Exception as e:
409
  s.update(label=f"Text failed: {e}", state="error")
410
  R["text"] = {"text": prompt, "image_prompt": prompt, "audio_prompt": prompt}
@@ -554,6 +787,37 @@ def show_results(R: dict):
554
  st.markdown("---")
555
 
556
  # Expandable details
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
557
  with st.expander("Retrieval Details"):
558
  r1, r2 = st.columns(2)
559
  with r1:
 
5
  Enter a scene description and the system produces coherent text, image,
6
  and audio with real-time MSCI scoring.
7
 
8
+ Pipeline: HF Inference API (text + planning) + CLIP retrieval (image) + CLAP retrieval (audio)
9
+ Planning modes: direct, planner, council (3-way), extended_prompt (3x tokens)
10
  """
11
 
12
  from __future__ import annotations
13
 
14
+ import json
15
  import logging
16
  import os
17
  import sys
18
  import time
19
  from pathlib import Path
20
+ from typing import Any, Dict, Optional
21
 
22
  import streamlit as st
23
 
 
73
  font-size: 0.7rem; font-weight: 600; letter-spacing: 0.03em;
74
  }
75
  .chip-purple { background: rgba(129,140,248,0.14); color: #a5b4fc; }
76
+ .chip-pink { background: rgba(244,114,182,0.14); color: #f9a8d4; }
77
  .chip-green { background: rgba(52,211,153,0.14); color: #6ee7b7; }
78
+ .chip-amber { background: rgba(251,191,36,0.12); color: #fcd34d; }
79
  .chip-dot { width: 6px; height: 6px; border-radius: 50%; }
80
  .chip-dot-purple { background: #818cf8; }
81
+ .chip-dot-pink { background: #f472b6; }
82
  .chip-dot-green { background: #34d399; }
83
+ .chip-dot-amber { background: #fbbf24; }
84
 
85
  .scores-grid {
86
  display: grid; grid-template-columns: repeat(4, 1fr);
 
202
  }
203
  DOMAIN_ICONS = {"nature": "\U0001f33f", "urban": "\U0001f3d9\ufe0f", "water": "\U0001f30a", "mixed": "\U0001f310", "other": "\U0001f4cd"}
204
 
205
+ # ---------------------------------------------------------------------------
206
+ # Planning prompt template (same as src/planner/prompts/unified.txt)
207
+ # ---------------------------------------------------------------------------
208
+ PLAN_PROMPT_TEMPLATE = """You must produce a SINGLE valid JSON object.
209
+
210
+ RULES:
211
+ - Every field MUST exist
212
+ - Fields that represent lists MUST be arrays
213
+ - Strings must never be arrays
214
+ - Use short phrases, not long paragraphs
215
+ - Do NOT include explanations
216
+ - Do NOT include markdown
217
+ - Do NOT truncate
218
+
219
+ Schema:
220
+ {
221
+ "scene_summary": string,
222
+ "domain": string,
223
+
224
+ "core_semantics": {
225
+ "setting": string,
226
+ "time_of_day": string,
227
+ "weather": string,
228
+ "main_subjects": [string],
229
+ "actions": [string]
230
+ },
231
+
232
+ "style_controls": {
233
+ "visual_style": [string],
234
+ "color_palette": [string],
235
+ "lighting": [string],
236
+ "camera": [string],
237
+ "mood_emotion": [string],
238
+ "narrative_tone": [string]
239
+ },
240
+
241
+ "image_constraints": {
242
+ "must_include": [string],
243
+ "must_avoid": [string],
244
+ "objects": [string],
245
+ "environment_details": [string],
246
+ "composition": [string]
247
+ },
248
+
249
+ "audio_constraints": {
250
+ "audio_intent": [string],
251
+ "sound_sources": [string],
252
+ "ambience": [string],
253
+ "tempo": string,
254
+ "must_include": [string],
255
+ "must_avoid": [string]
256
+ },
257
+
258
+ "text_constraints": {
259
+ "must_include": [string],
260
+ "must_avoid": [string],
261
+ "keywords": [string],
262
+ "length": string
263
+ }
264
+ }
265
+
266
+ User request:
267
+ """
268
+
269
+ EXTENDED_PLAN_SYSTEM = """You are an expert multimodal content planner. Create a detailed,
270
+ comprehensive semantic plan for generating coherent multimodal content (text, image, audio).
271
+
272
+ You have an extended budget. Take your time to:
273
+ 1. Deeply analyze the user's request
274
+ 2. Consider multiple perspectives and interpretations
275
+ 3. Ensure semantic consistency across all modalities
276
+ 4. Provide rich, detailed specifications
277
+
278
+ Think step by step about what visual elements, sounds, and descriptive text would best represent the scene.
279
+ After your analysis, produce a SINGLE valid JSON object matching the schema."""
280
+
281
 
282
  # ---------------------------------------------------------------------------
283
  # Cached model loading
 
305
  return InferenceClient(token=token)
306
 
307
 
308
+ # ---------------------------------------------------------------------------
309
+ # HF Inference API helpers
310
+ # ---------------------------------------------------------------------------
311
+
312
def _hf_chat(system: str, user: str, max_tokens: int = 500, temperature: float = 0.3) -> str:
    """Run one chat-completion round-trip against the HF Inference API.

    Sends *system* and *user* as a two-message conversation and returns the
    stripped text of the first returned choice.
    """
    conversation = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    api = get_inference_client()
    completion = api.chat_completion(
        messages=conversation,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    return completion.choices[0].message.content.strip()
324
+
325
+
326
def _parse_plan_json(raw: str) -> Optional[Dict[str, Any]]:
    """Turn raw LLM output into a plan dict, repairing malformed JSON if needed.

    Returns None when no valid JSON object can be recovered.
    """
    from src.utils.json_repair import try_repair_json

    return try_repair_json(raw)
330
+
331
+
332
def _validate_and_build_plan(data: Dict[str, Any]):
    """Check *data* against the plan schema and instantiate a SemanticPlan."""
    from src.planner.schema import SemanticPlan
    from src.planner.validation import validate_semantic_plan_dict

    validate_semantic_plan_dict(data)  # raises on schema violations
    return SemanticPlan(**data)
338
+
339
+
340
+ # ---------------------------------------------------------------------------
341
+ # Planning functions (HF Inference API)
342
+ # ---------------------------------------------------------------------------
343
+
344
def plan_single(prompt: str) -> Optional[Any]:
    """One-shot planning via the HF API; returns a SemanticPlan or None on failure."""
    sys_msg = "You are a multimodal content planner. Output ONLY valid JSON, no explanations."
    request = PLAN_PROMPT_TEMPLATE + prompt
    try:
        reply = _hf_chat(sys_msg, request, max_tokens=1200, temperature=0.3)
        parsed = _parse_plan_json(reply)
        if not parsed:
            return None
        return _validate_and_build_plan(parsed)
    except Exception as e:
        logger.warning("Planner call failed: %s", e)
        return None
356
+
357
+
358
def plan_council(prompt: str) -> Optional[Any]:
    """Council planning: up to 3 LLM calls at varied temperatures, merged into one plan.

    Returns a SemanticPlan, or None when no call produced a valid plan.
    """
    sys_msg = "You are a multimodal content planner. Output ONLY valid JSON, no explanations."
    request = PLAN_PROMPT_TEMPLATE + prompt

    collected = []
    for temperature in (0.2, 0.4, 0.5):  # Slightly different temperatures for diversity
        try:
            reply = _hf_chat(sys_msg, request, max_tokens=1200, temperature=temperature)
            parsed = _parse_plan_json(reply)
            if parsed:
                collected.append(_validate_and_build_plan(parsed))
        except Exception as e:
            logger.warning("Council call failed (temp=%.1f): %s", temperature, e)

    if not collected:
        return None
    if len(collected) == 1:
        return collected[0]

    # Merge using existing merge logic
    try:
        from src.planner.merge_logic import merge_council_plans

        while len(collected) < 3:
            collected.append(collected[0])  # Pad if fewer than 3
        merged, _ = merge_council_plans(collected[0], collected[1], collected[2])
        return merged
    except Exception as e:
        logger.warning("Merge failed: %s — using first plan", e)
        return collected[0]
390
+
391
+
392
def plan_extended(prompt: str) -> Optional[Any]:
    """Extended-budget planning: richer system prompt plus a larger token budget.

    Returns a SemanticPlan or None on failure.
    """
    request = PLAN_PROMPT_TEMPLATE + prompt
    try:
        reply = _hf_chat(EXTENDED_PLAN_SYSTEM, request, max_tokens=2000, temperature=0.35)
        parsed = _parse_plan_json(reply)
        if parsed:
            return _validate_and_build_plan(parsed)
    except Exception as e:
        logger.warning("Extended planner failed: %s", e)
    return None
403
+
404
+
405
  # ---------------------------------------------------------------------------
406
  # Generation / retrieval functions
407
  # ---------------------------------------------------------------------------
408
 
409
+ def gen_text(prompt: str, mode: str) -> dict:
410
+ """Generate text and optional plan using HF Inference API."""
411
+ # Step 1: Plan (if not direct mode)
412
+ plan = None
413
+ image_prompt = prompt
414
+ audio_prompt = prompt
415
+
416
+ if mode == "planner":
417
+ plan = plan_single(prompt)
418
+ elif mode == "council":
419
+ plan = plan_council(prompt)
420
+ elif mode == "extended_prompt":
421
+ plan = plan_extended(prompt)
422
+
423
+ # Extract modality-specific prompts from plan
424
+ if plan is not None:
425
+ try:
426
+ from src.planner.schema_to_text import plan_to_prompts
427
+ prompts = plan_to_prompts(plan)
428
+ image_prompt = prompts["image_prompt"]
429
+ audio_prompt = prompts["audio_prompt"]
430
+ text_input = prompts["text_prompt"]
431
+ except Exception as e:
432
+ logger.warning("plan_to_prompts failed: %s", e)
433
+ text_input = prompt
434
+ else:
435
+ text_input = prompt
436
+
437
+ # Step 2: Generate text via HF API
438
  system_prompt = (
439
  "You are a concise descriptive writer. "
440
  "Write a literal description of the scene in 3 to 5 natural sentences. "
 
442
  "Focus on concrete visual details AND the likely audio ambience."
443
  )
444
  try:
445
+ text = _hf_chat(system_prompt, f"Describe this scene: {text_input}", max_tokens=250, temperature=0.7)
 
 
 
 
 
 
 
 
446
  if not text:
447
  raise ValueError("Empty response")
 
448
  except Exception as e:
449
+ logger.warning("HF text gen failed: %s — using prompt", e)
450
+ text = prompt
451
+
452
+ return {
453
+ "text": text,
454
+ "image_prompt": image_prompt,
455
+ "audio_prompt": audio_prompt,
456
+ "plan": plan.model_dump() if plan and hasattr(plan, "model_dump") else None,
457
+ }
458
 
459
 
460
  def retrieve_image(prompt: str) -> dict:
 
538
 
539
  # Sidebar
540
  with st.sidebar:
541
+ st.markdown("#### Configuration")
542
+
543
+ mode = st.selectbox(
544
+ "Planning Mode",
545
+ ["direct", "planner", "council", "extended_prompt"],
546
+ format_func=lambda x: {
547
+ "direct": "Direct",
548
+ "planner": "Planner (single LLM call)",
549
+ "council": "Council (3-way merge)",
550
+ "extended_prompt": "Extended (3x tokens)",
551
+ }[x],
552
+ )
553
+
554
+ st.divider()
555
  st.markdown("#### Examples")
556
  for dname, prompts in EXAMPLE_PROMPTS.items():
557
  icon = DOMAIN_ICONS.get(dname.lower(), "\U0001f4cd")
 
561
  st.session_state["prompt_input"] = p
562
 
563
  st.divider()
564
+ mode_desc = {
565
+ "direct": "Prompt used directly for all modalities",
566
+ "planner": "LLM creates a semantic plan with image/audio prompts",
567
+ "council": "3 LLM calls merged for richer planning",
568
+ "extended_prompt": "Single LLM call with 3x token budget",
569
+ }
570
  st.markdown(
571
+ f'<div class="sidebar-info">'
572
+ f'<b>Text</b> HF Inference API<br>'
573
+ f'<b>Planning</b> {mode_desc[mode]}<br>'
574
+ f'<b>Image</b> CLIP retrieval (57 images)<br>'
575
+ f'<b>Audio</b> CLAP retrieval (104 clips)<br><br>'
576
+ f'<b>Metric</b> MSCI = 0.45 &times; s<sub>t,i</sub> + 0.45 &times; s<sub>t,a</sub><br><br>'
577
+ f'<b>Models</b><br>'
578
+ f'CLIP ViT-B/32 (text-image)<br>'
579
+ f'CLAP HTSAT-unfused (text-audio)'
580
+ f'</div>', unsafe_allow_html=True)
581
 
582
  # Prompt input
583
  default_prompt = st.session_state.get("prompt_input", "")
 
592
  with bc1:
593
  go = st.button("Generate Bundle", type="primary", use_container_width=True, disabled=not prompt.strip())
594
  with bc2:
595
+ mlbl = {"direct": "Direct", "planner": "Planner", "council": "Council", "extended_prompt": "Extended"}[mode]
596
+ mcls = "chip-amber" if mode != "direct" else "chip-purple"
597
+ mdot = "chip-dot-amber" if mode != "direct" else "chip-dot-purple"
598
  st.markdown(
599
+ f'<div class="chip-row">'
600
+ f'<span class="chip chip-purple"><span class="chip-dot chip-dot-purple"></span>Retrieval</span>'
601
+ f'<span class="chip {mcls}"><span class="chip-dot {mdot}"></span>{mlbl}</span>'
602
+ f'</div>', unsafe_allow_html=True)
603
 
604
  # Welcome state
605
  if not go and "last_result" not in st.session_state:
 
612
  return
613
 
614
  if go and prompt.strip():
615
+ st.session_state["last_result"] = run_pipeline(prompt.strip(), mode)
616
 
617
  if "last_result" in st.session_state:
618
  show_results(st.session_state["last_result"])
 
622
  # Pipeline
623
  # ---------------------------------------------------------------------------
624
 
625
+ def run_pipeline(prompt: str, mode: str) -> dict:
626
+ R: dict = {"mode": mode}
627
  t_all = time.time()
628
 
629
+ # 1) Text + Planning
630
+ plan_label = "Generating text..." if mode == "direct" else f"Planning ({mode}) + generating text..."
631
+ with st.status(plan_label, expanded=True) as s:
632
  t0 = time.time()
633
  try:
634
+ R["text"] = gen_text(prompt, mode)
635
  R["t_text"] = time.time() - t0
636
+ has_plan = R["text"].get("plan") is not None
637
+ lbl = f"Text ready ({R['t_text']:.1f}s)"
638
+ if has_plan:
639
+ lbl = f"Plan + text ready ({R['t_text']:.1f}s)"
640
+ s.update(label=lbl, state="complete")
641
  except Exception as e:
642
  s.update(label=f"Text failed: {e}", state="error")
643
  R["text"] = {"text": prompt, "image_prompt": prompt, "audio_prompt": prompt}
 
787
  st.markdown("---")
788
 
789
  # Expandable details
790
+ with st.expander("Semantic Plan"):
791
+ td = R.get("text", {})
792
+ plan = td.get("plan")
793
+ if plan:
794
+ p1, p2 = st.columns(2)
795
+ with p1:
796
+ dash = "\u2014"
797
+ dot = "\u00b7"
798
+ scene = plan.get("scene_summary", dash)
799
+ domain = plan.get("domain", dash)
800
+ core = plan.get("core_semantics", {})
801
+ setting = core.get("setting", dash)
802
+ tod = core.get("time_of_day", dash)
803
+ weather = core.get("weather", dash)
804
+ subjects = ", ".join(core.get("main_subjects", []))
805
+ st.markdown(f"**Scene** {scene}")
806
+ st.markdown(f"**Domain** {domain}")
807
+ st.markdown(f"**Setting** {setting} {dot} **Time** {tod} {dot} **Weather** {weather}")
808
+ st.markdown(f"**Subjects** {subjects}")
809
+ with p2:
810
+ st.markdown("**Image prompt**")
811
+ st.code(td.get("image_prompt", ""), language=None)
812
+ st.markdown("**Audio prompt**")
813
+ st.code(td.get("audio_prompt", ""), language=None)
814
+ else:
815
+ mode = R.get("mode", "direct")
816
+ if mode == "direct":
817
+ st.write("Direct mode \u2014 no semantic plan. Prompt used as-is for all modalities.")
818
+ else:
819
+ st.write(f"Planning ({mode}) did not produce a valid plan. Fell back to direct mode.")
820
+
821
  with st.expander("Retrieval Details"):
822
  r1, r2 = st.columns(2)
823
  with r1: