pratik-250620 commited on
Commit
358d3bc
·
verified ·
1 Parent(s): 6da5a84

Upload folder using huggingface_hub

Browse files
app.py CHANGED
@@ -180,29 +180,743 @@ section[data-testid="stSidebar"] > div:first-child { padding-top: 1.2rem; }
180
  # Example prompts
181
  # ---------------------------------------------------------------------------
182
  EXAMPLE_PROMPTS = {
183
- "Nature": [
184
- "A peaceful forest at dawn with birdsong and morning mist",
185
- "A field of golden wheat under a warm summer sunset",
186
- "A dense jungle with exotic birds calling from the canopy",
187
- ],
188
- "Urban": [
189
- "A bustling city street at night with neon lights and traffic",
190
- "A quiet alley in an old town with distant footsteps echoing",
191
- "A cafe terrace on a busy boulevard with clinking glasses",
192
- ],
193
- "Water": [
194
- "Ocean waves crashing on a sandy beach at sunset",
195
- "Rain falling on a pond with ripples spreading across the surface",
196
- "A mountain stream flowing over rocks through a pine forest",
197
- ],
198
- "Mixed": [
199
- "A lighthouse on a cliff during a thunderstorm at night",
200
- "A bonfire on a beach with waves and guitar music at night",
201
- "A train passing through countryside with distant church bells",
202
- ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  }
204
  DOMAIN_ICONS = {"nature": "\U0001f33f", "urban": "\U0001f3d9\ufe0f", "water": "\U0001f30a", "mixed": "\U0001f310", "other": "\U0001f4cd"}
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  # ---------------------------------------------------------------------------
207
  # Planning prompt template (same as src/planner/prompts/unified.txt)
208
  # ---------------------------------------------------------------------------
@@ -306,22 +1020,128 @@ def get_inference_client():
306
  return InferenceClient(token=token)
307
 
308
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
309
  # ---------------------------------------------------------------------------
310
  # HF Inference API helpers
311
  # ---------------------------------------------------------------------------
312
 
313
- TEXT_GEN_MODELS = [
 
314
  "mistralai/Mistral-7B-Instruct-v0.3",
 
 
 
 
315
  "HuggingFaceH4/zephyr-7b-beta",
316
  "microsoft/Phi-3-mini-4k-instruct",
317
- "meta-llama/Llama-3.2-3B-Instruct",
318
  ]
 
 
 
 
 
 
 
 
319
 
320
  def _hf_chat(system: str, user: str, max_tokens: int = 500, temperature: float = 0.3) -> str:
321
- """Call HF Inference API chat completion, trying multiple models."""
 
 
 
 
322
  client = get_inference_client()
323
  last_error = None
 
 
324
  for model_id in TEXT_GEN_MODELS:
 
 
 
 
325
  try:
326
  response = client.chat_completion(
327
  model=model_id,
@@ -337,9 +1157,15 @@ def _hf_chat(system: str, user: str, max_tokens: int = 500, temperature: float =
337
  return text
338
  except Exception as e:
339
  last_error = e
340
- logger.warning("Chat model %s failed: %s", model_id, e)
 
 
 
 
341
  continue
342
- raise RuntimeError(f"All text models failed. Last error: {last_error}")
 
 
343
 
344
 
345
  def _parse_plan_json(raw: str) -> Optional[Dict[str, Any]]:
@@ -425,11 +1251,14 @@ def plan_extended(prompt: str) -> Optional[Any]:
425
  # Generation / retrieval functions
426
  # ---------------------------------------------------------------------------
427
 
428
- # HF Inference API model IDs
429
- IMAGE_GEN_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
 
 
 
430
  AUDIO_GEN_MODELS = [
431
- "cvssp/audioldm2",
432
- "facebook/musicgen-small",
433
  ]
434
 
435
  def gen_text(prompt: str, mode: str) -> dict:
@@ -487,25 +1316,44 @@ def gen_text(prompt: str, mode: str) -> dict:
487
 
488
 
489
  def generate_image(prompt: str) -> dict:
490
- """Generate image via HF Inference API (SDXL), fallback to retrieval."""
491
  client = get_inference_client()
492
- try:
493
- image = client.text_to_image(prompt, model=IMAGE_GEN_MODEL)
494
- tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir="/tmp")
495
- image.save(tmp.name)
496
- return {
497
- "path": tmp.name, "backend": "generative",
498
- "model": "SDXL", "failed": False,
499
- }
500
- except Exception as e:
501
- logger.warning("Image generation failed: %s — falling back to retrieval", e)
502
- return retrieve_image(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
503
 
504
 
505
  def generate_audio(prompt: str) -> dict:
506
- """Generate audio via HF Inference API, fallback to retrieval."""
507
  client = get_inference_client()
 
508
  for model_id in AUDIO_GEN_MODELS:
 
 
 
509
  try:
510
  audio_bytes = client.text_to_audio(prompt, model=model_id)
511
  suffix = ".flac" if "musicgen" in model_id else ".wav"
@@ -514,7 +1362,6 @@ def generate_audio(prompt: str) -> dict:
514
  tmp.write(audio_bytes)
515
  tmp.flush()
516
  else:
517
- # Some API versions return object with .read() or similar
518
  tmp.write(bytes(audio_bytes))
519
  tmp.flush()
520
  model_name = model_id.split("/")[-1]
@@ -523,11 +1370,17 @@ def generate_audio(prompt: str) -> dict:
523
  "model": model_name, "failed": False,
524
  }
525
  except Exception as e:
526
- logger.warning("Audio gen with %s failed: %s", model_id, e)
 
 
 
 
527
  continue
528
- # All generative models failed — fall back to retrieval
529
  logger.warning("All audio generation models failed — falling back to retrieval")
530
- return retrieve_audio(prompt)
 
 
 
531
 
532
 
533
  def retrieve_image(prompt: str) -> dict:
@@ -599,31 +1452,36 @@ def main():
599
  layout="wide",
600
  initial_sidebar_state="expanded",
601
  )
602
- st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
603
 
604
- # Hero
605
- st.markdown(
606
- '<div class="hero-wrap">'
607
- '<div class="hero-title">Multimodal Coherence AI</div>'
608
- '<div class="hero-sub">Generate semantically coherent <b>text + image + audio</b> bundles '
609
- 'and evaluate cross-modal alignment with the <b>MSCI</b> metric.</div>'
610
- '</div>', unsafe_allow_html=True)
611
-
612
- # Sidebar
613
  with st.sidebar:
614
  st.markdown("#### Configuration")
615
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
616
  backend = st.selectbox(
617
- "Backend",
618
  ["generative", "retrieval"],
619
  format_func=lambda x: {
620
- "generative": "Generative (SDXL + AudioLDM2)",
621
  "retrieval": "Retrieval (CLIP + CLAP index)",
622
  }[x],
623
  )
624
 
625
  mode = st.selectbox(
626
- "Planning Mode",
627
  ["direct", "planner", "council", "extended_prompt"],
628
  format_func=lambda x: {
629
  "direct": "Direct",
@@ -634,13 +1492,25 @@ def main():
634
  )
635
 
636
  st.divider()
637
- st.markdown("#### Examples")
638
- for dname, prompts in EXAMPLE_PROMPTS.items():
639
- icon = DOMAIN_ICONS.get(dname.lower(), "\U0001f4cd")
640
- with st.expander(f"{icon} {dname}"):
641
- for p in prompts:
642
- if st.button(p, key=f"ex_{hash(p)}", use_container_width=True):
643
- st.session_state["prompt_input"] = p
 
 
 
 
 
 
 
 
 
 
 
 
644
 
645
  st.divider()
646
  mode_desc = {
@@ -650,35 +1520,57 @@ def main():
650
  "extended_prompt": "Single LLM call with 3x token budget",
651
  }
652
  if backend == "generative":
653
- img_info = "SDXL via HF API"
654
- aud_info = "AudioLDM2 / MusicGen via HF API"
655
  else:
656
  img_info = "CLIP retrieval (57 images)"
657
  aud_info = "CLAP retrieval (104 clips)"
 
658
  st.markdown(
659
  f'<div class="sidebar-info">'
660
  f'<b>Text</b> HF Inference API<br>'
661
  f'<b>Planning</b> {mode_desc[mode]}<br>'
662
  f'<b>Image</b> {img_info}<br>'
663
- f'<b>Audio</b> {aud_info}<br><br>'
664
  f'<b>Metric</b> MSCI = 0.45 &times; s<sub>t,i</sub> + 0.45 &times; s<sub>t,a</sub><br><br>'
665
  f'<b>Models</b><br>'
666
  f'CLIP ViT-B/32 (coherence eval)<br>'
667
  f'CLAP HTSAT-unfused (coherence eval)'
668
  f'</div>', unsafe_allow_html=True)
669
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
670
  # Prompt input
671
  default_prompt = st.session_state.get("prompt_input", "")
672
  prompt = st.text_area(
673
  "Scene", value=default_prompt, height=80,
674
- placeholder="Describe a scene... e.g., 'A peaceful forest at dawn with birdsong and morning mist'",
675
  label_visibility="collapsed",
676
  )
677
 
678
  # Button + chips
679
  bc1, bc2 = st.columns([1, 3])
680
  with bc1:
681
- go = st.button("Generate Bundle", type="primary", use_container_width=True, disabled=not prompt.strip())
682
  with bc2:
683
  mlbl = {"direct": "Direct", "planner": "Planner", "council": "Council", "extended_prompt": "Extended"}[mode]
684
  mcls = "chip-amber" if mode != "direct" else "chip-purple"
@@ -687,27 +1579,45 @@ def main():
687
  bchip = '<span class="chip chip-pink"><span class="chip-dot chip-dot-pink"></span>Generative</span>'
688
  else:
689
  bchip = '<span class="chip chip-purple"><span class="chip-dot chip-dot-purple"></span>Retrieval</span>'
 
 
 
 
 
 
690
  st.markdown(
691
  f'<div class="chip-row">'
692
  f'{bchip}'
693
  f'<span class="chip {mcls}"><span class="chip-dot {mdot}"></span>{mlbl}</span>'
694
  f'<span class="chip chip-green"><span class="chip-dot chip-dot-green"></span>CLIP + CLAP</span>'
 
695
  f'</div>', unsafe_allow_html=True)
696
 
697
  # Welcome state
698
  if not go and "last_result" not in st.session_state:
699
- st.markdown(
700
- '<div class="welcome">'
701
- '<div class="welcome-icons">\U0001f3a8 \U0001f5bc\ufe0f \U0001f50a</div>'
702
- '<div class="welcome-text">Enter a scene description and click <b>Generate Bundle</b></div>'
703
- '<div class="welcome-hint">or pick an example from the sidebar</div>'
704
- '</div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
705
  return
706
 
707
  if go and prompt.strip():
708
- st.session_state["last_result"] = run_pipeline(prompt.strip(), mode, backend)
 
709
 
710
  if "last_result" in st.session_state:
 
 
711
  show_results(st.session_state["last_result"])
712
 
713
 
@@ -715,16 +1625,29 @@ def main():
715
  # Pipeline
716
  # ---------------------------------------------------------------------------
717
 
718
- def run_pipeline(prompt: str, mode: str, backend: str = "generative") -> dict:
719
- R: dict = {"mode": mode, "backend": backend}
720
  t_all = time.time()
721
 
722
- # 1) Text + Planning
 
 
 
 
 
 
 
 
 
 
 
 
 
723
  plan_label = "Generating text..." if mode == "direct" else f"Planning ({mode}) + generating text..."
724
  with st.status(plan_label, expanded=True) as s:
725
  t0 = time.time()
726
  try:
727
- R["text"] = gen_text(prompt, mode)
728
  R["t_text"] = time.time() - t0
729
  has_plan = R["text"].get("plan") is not None
730
  lbl = f"Text ready ({R['t_text']:.1f}s)"
@@ -733,14 +1656,20 @@ def run_pipeline(prompt: str, mode: str, backend: str = "generative") -> dict:
733
  s.update(label=lbl, state="complete")
734
  except Exception as e:
735
  s.update(label=f"Text failed: {e}", state="error")
736
- R["text"] = {"text": prompt, "image_prompt": prompt, "audio_prompt": prompt}
737
  R["t_text"] = time.time() - t0
738
 
739
- ip = R["text"].get("image_prompt", prompt)
740
- ap = R["text"].get("audio_prompt", prompt)
 
 
 
 
 
 
741
 
742
  # 2) Image
743
- img_label = "Generating image (SDXL)..." if backend == "generative" else "Retrieving image..."
744
  with st.status(img_label, expanded=True) as s:
745
  t0 = time.time()
746
  try:
@@ -791,13 +1720,14 @@ def run_pipeline(prompt: str, mode: str, backend: str = "generative") -> dict:
791
  R["audio"] = None
792
  R["t_aud"] = time.time() - t0
793
 
794
- # 4) Coherence evaluation
795
  with st.status("Evaluating coherence...", expanded=True) as s:
796
  t0 = time.time()
797
  try:
798
  imgp = R.get("image", {}).get("path") if R.get("image") else None
799
  audp = R.get("audio", {}).get("path") if R.get("audio") else None
800
- R["coherence"] = eval_coherence(R["text"]["text"], imgp, audp)
 
801
  R["t_eval"] = time.time() - t0
802
  msci = R["coherence"].get("scores", {}).get("msci")
803
  s.update(label=f"MSCI = {msci:.4f} ({R['t_eval']:.1f}s)", state="complete")
@@ -821,23 +1751,52 @@ def show_results(R: dict):
821
  msci = sc.get("msci")
822
  st_i = sc.get("st_i")
823
  st_a = sc.get("st_a")
 
 
824
 
825
- # Score cards
826
- st.markdown('<div class="sec-label">Coherence Scores</div>', unsafe_allow_html=True)
827
- cards = (
828
- score_card_html("MSCI (Overall)", msci)
829
- + score_card_html("Text \u2192 Image", st_i)
830
- + score_card_html("Text \u2192 Audio", st_a)
831
- + score_card_html("Classification", msci, is_class=True)
832
- )
833
- st.markdown(f'<div class="scores-grid">{cards}</div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
834
 
835
  # Timing strip
836
  tt = R.get("t_total", 0)
837
  sep = '<span class="t-sep">|</span>'
 
 
838
  st.markdown(
839
- f'<div class="timing">'
840
  f'<span class="t-total">Total {tt:.1f}s</span>{sep}'
 
841
  f'<span>Text {R.get("t_text", 0):.1f}s</span>{sep}'
842
  f'<span>Image {R.get("t_img", 0):.1f}s</span>{sep}'
843
  f'<span>Audio {R.get("t_aud", 0):.1f}s</span>{sep}'
@@ -846,141 +1805,194 @@ def show_results(R: dict):
846
 
847
  st.markdown("---")
848
 
 
 
 
 
849
  # Three columns: text | image | audio
850
  ct, ci, ca = st.columns([1.15, 1, 0.85])
851
 
852
  with ct:
853
- st.markdown('<div class="sec-label">Generated Text</div>', unsafe_allow_html=True)
854
  txt = R.get("text", {}).get("text", "")
855
  text_err = R.get("text", {}).get("text_error")
856
  if text_err:
857
- st.markdown(
858
- f'<div class="warn-banner"><b>Text gen failed</b> — {text_err}</div>',
859
- unsafe_allow_html=True)
860
- st.markdown(f'<div class="text-card">{txt}</div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
861
 
862
  with ci:
863
- st.markdown('<div class="sec-label">Generated Image</div>', unsafe_allow_html=True)
864
  ii = R.get("image")
865
  if ii and ii.get("path"):
866
  ip = Path(ii["path"])
867
  backend = ii.get("backend", "unknown")
868
 
869
- if backend == "retrieval" and ii.get("failed", False):
870
- sim = ii.get("similarity", 0)
871
- st.markdown(
872
- f'<div class="warn-banner"><b>Retrieval fallback</b> '
873
- f'(sim={sim:.3f}) \u2014 generation unavailable.</div>',
874
- unsafe_allow_html=True)
 
 
 
 
 
 
875
 
876
  if ip.exists():
877
  st.image(str(ip), use_container_width=True)
878
  model = ii.get("model", "")
879
  if backend == "generative":
880
- st.caption(f"Generated via **{model}**")
 
 
881
  else:
882
  sim = ii.get("similarity", 0)
883
  dom = ii.get("domain", "other")
884
  ic = DOMAIN_ICONS.get(dom, "\U0001f4cd")
885
  st.caption(f"{ic} {dom} \u00b7 sim **{sim:.3f}** \u00b7 Retrieved")
886
  else:
887
- st.info("No image.")
888
 
889
  with ca:
890
- st.markdown('<div class="sec-label">Generated Audio</div>', unsafe_allow_html=True)
891
  ai = R.get("audio")
892
  if ai and ai.get("path"):
893
  ap = Path(ai["path"])
894
  backend = ai.get("backend", "unknown")
895
 
896
- if backend == "retrieval" and ai.get("failed", False):
897
- sim = ai.get("similarity", 0)
898
- st.markdown(
899
- f'<div class="warn-banner"><b>Retrieval fallback</b> '
900
- f'(sim={sim:.3f}) \u2014 generation unavailable.</div>',
901
- unsafe_allow_html=True)
 
 
 
 
 
 
902
 
903
  if ap.exists():
904
  st.audio(str(ap))
905
  model = ai.get("model", "")
906
  if backend == "generative":
907
- st.caption(f"Generated via **{model}**")
 
 
908
  else:
909
  sim = ai.get("similarity", 0)
910
  st.caption(f"sim **{sim:.3f}** \u00b7 Retrieved")
911
  else:
912
- st.info("No audio.")
913
 
914
  st.markdown("---")
915
 
916
- # Expandable details
917
- with st.expander("Semantic Plan"):
918
- td = R.get("text", {})
919
- plan = td.get("plan")
920
- if plan:
921
- p1, p2 = st.columns(2)
922
- with p1:
923
- dash = "\u2014"
924
- dot = "\u00b7"
925
- scene = plan.get("scene_summary", dash)
926
- domain = plan.get("domain", dash)
927
- core = plan.get("core_semantics", {})
928
- setting = core.get("setting", dash)
929
- tod = core.get("time_of_day", dash)
930
- weather = core.get("weather", dash)
931
- subjects = ", ".join(core.get("main_subjects", []))
932
- st.markdown(f"**Scene** {scene}")
933
- st.markdown(f"**Domain** {domain}")
934
- st.markdown(f"**Setting** {setting} {dot} **Time** {tod} {dot} **Weather** {weather}")
935
- st.markdown(f"**Subjects** {subjects}")
936
- with p2:
937
- st.markdown("**Image prompt**")
938
- st.code(td.get("image_prompt", ""), language=None)
939
- st.markdown("**Audio prompt**")
940
- st.code(td.get("audio_prompt", ""), language=None)
941
- else:
942
- mode = R.get("mode", "direct")
943
- if mode == "direct":
944
- st.write("Direct mode \u2014 no semantic plan. Prompt used as-is for all modalities.")
945
  else:
946
- st.write(f"Planning ({mode}) did not produce a valid plan. Fell back to direct mode.")
947
-
948
- with st.expander("Generation Details"):
949
- r1, r2 = st.columns(2)
950
- with r1:
951
- ii = R.get("image")
952
- if ii:
953
- backend = ii.get("backend", "unknown")
954
- model = ii.get("model", "")
955
- if backend == "generative":
956
- st.markdown(f"**Image** generated via **{model}**")
957
- st.markdown(f"Prompt: *{R.get('text', {}).get('image_prompt', '')}*")
958
- elif ii.get("top_5"):
959
- st.markdown("**Image** (retrieval fallback)")
960
- bars = "".join(sim_bar_html(n, s) for n, s in ii["top_5"])
961
- st.markdown(bars, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
962
  else:
963
- st.write("No image data.")
964
- with r2:
965
- ai = R.get("audio")
966
- if ai:
967
- backend = ai.get("backend", "unknown")
968
- model = ai.get("model", "")
969
- if backend == "generative":
970
- st.markdown(f"**Audio** generated via **{model}**")
971
- st.markdown(f"Prompt: *{R.get('text', {}).get('audio_prompt', '')}*")
972
- elif ai.get("top_5"):
973
- st.markdown("**Audio** (retrieval fallback)")
974
- bars = "".join(sim_bar_html(n, s) for n, s in ai["top_5"])
975
- st.markdown(bars, unsafe_allow_html=True)
976
  else:
977
- st.write("No audio data.")
978
-
979
- with st.expander("Full Coherence Report"):
980
- if coh:
981
- st.json(coh)
982
- else:
983
- st.write("No data.")
984
 
985
 
986
  if __name__ == "__main__":
 
180
  # Example prompts
181
  # ---------------------------------------------------------------------------
182
  EXAMPLE_PROMPTS = {
183
+ "en": {
184
+ "Nature": [
185
+ "A peaceful forest at dawn with birdsong and morning mist",
186
+ "A field of golden wheat under a warm summer sunset",
187
+ "A dense jungle with exotic birds calling from the canopy",
188
+ ],
189
+ "Urban": [
190
+ "A bustling city street at night with neon lights and traffic",
191
+ "A quiet alley in an old town with distant footsteps echoing",
192
+ "A cafe terrace on a busy boulevard with clinking glasses",
193
+ ],
194
+ "Water": [
195
+ "Ocean waves crashing on a sandy beach at sunset",
196
+ "Rain falling on a pond with ripples spreading across the surface",
197
+ "A mountain stream flowing over rocks through a pine forest",
198
+ ],
199
+ "Mixed": [
200
+ "A lighthouse on a cliff during a thunderstorm at night",
201
+ "A bonfire on a beach with waves and guitar music at night",
202
+ "A train passing through countryside with distant church bells",
203
+ ],
204
+ },
205
+ "de": {
206
+ "Natur": [
207
+ "Ein friedlicher Wald bei Sonnenaufgang mit Vogelgesang und Morgennebel",
208
+ "Ein goldenes Weizenfeld unter einem warmen Sommerabend",
209
+ "Ein dichter Dschungel mit exotischen V\u00f6geln im Bl\u00e4tterdach",
210
+ ],
211
+ "Stadt": [
212
+ "Eine belebte Stra\u00dfe bei Nacht mit Neonlichtern und Verkehr",
213
+ "Eine ruhige Gasse in einer Altstadt mit fernen Schritten",
214
+ "Eine Caf\u00e9-Terrasse an einem belebten Boulevard mit klinkenden Gl\u00e4sern",
215
+ ],
216
+ "Wasser": [
217
+ "Meereswellen am Sandstrand bei Sonnenuntergang",
218
+ "Regen f\u00e4llt auf einen Teich mit sich ausbreitenden Wellen",
219
+ "Ein Bergbach flie\u00dft \u00fcber Felsen durch einen Kiefernwald",
220
+ ],
221
+ "Gemischt": [
222
+ "Ein Leuchtturm auf einer Klippe w\u00e4hrend eines Gewitters bei Nacht",
223
+ "Ein Lagerfeuer am Strand mit Wellen und Gitarrenmusik bei Nacht",
224
+ "Ein Zug f\u00e4hrt durch die Landschaft mit fernen Kirchenglocken",
225
+ ],
226
+ },
227
  }
228
  DOMAIN_ICONS = {"nature": "\U0001f33f", "urban": "\U0001f3d9\ufe0f", "water": "\U0001f30a", "mixed": "\U0001f310", "other": "\U0001f4cd"}
229
 
230
+ # ---------------------------------------------------------------------------
231
+ # Kid Mode — example prompts (German, fun themes for children)
232
+ # ---------------------------------------------------------------------------
233
+ KID_EXAMPLE_PROMPTS = {
234
+ "de": {
235
+ "\U0001f47e Abenteuer": [
236
+ "Pikachu in einem magischen Wald bei Sonnenuntergang",
237
+ "Ein Minecraft-Dorf auf einer Insel mitten im Ozean",
238
+ "Ein kleiner Drache fliegt \u00fcber eine Burg bei Nacht",
239
+ "Ein Weltraumabenteuer mit Raketen und bunten Planeten",
240
+ ],
241
+ "\U0001f43e Tiere": [
242
+ "Ein freundlicher Hund rettet ein K\u00e4tzchen im Regen",
243
+ "Dinosaurier spielen Fu\u00dfball auf einer sonnigen Wiese",
244
+ "Ein Einhorn galoppiert \u00fcber einen leuchtenden Regenbogen",
245
+ "Pinguine machen eine Schneeballschlacht am S\u00fcdpol",
246
+ "Ein kleiner Fuchs entdeckt einen geheimen Garten",
247
+ ],
248
+ "\u2728 Fantasie": [
249
+ "Ein Zauberer braut einen glitzernden Trank in einem Schloss",
250
+ "Eine Fee fliegt durch einen Wald voller leuchtender Pilze",
251
+ "Ein verzaubertes Baumhaus in den Wolken mit Regenbogenbr\u00fccke",
252
+ "Ein Roboter und ein Teddy gehen zusammen auf Schatzsuche",
253
+ "Ein magischer Unterwasserpalast mit sprechenden Fischen",
254
+ ],
255
+ "\U0001f602 Lustig": [
256
+ "Eine Katze f\u00e4hrt Skateboard durch eine bunte Stadt",
257
+ "Aliens landen im Schulgarten und spielen Verstecken",
258
+ "Ein Elefant versucht sich auf einem Trampolin",
259
+ "Ein Schneemann isst Eis am Strand im Sommer",
260
+ "Monster unter dem Bett machen eine Pyjamaparty",
261
+ ],
262
+ "\U0001f3ae Spielwelt": [
263
+ "Super Mario springt durch eine Welt aus S\u00fc\u00dfigkeiten",
264
+ "Ein Ritter k\u00e4mpft gegen einen freundlichen Drachen",
265
+ "Eine Unterwasser-Rennstrecke mit U-Booten und Delfinen",
266
+ "Ein Baumhaus-Dorf im Dschungel mit H\u00e4ngebr\u00fccken",
267
+ "Tiere bauen zusammen eine riesige Sandburg am Meer",
268
+ ],
269
+ },
270
+ "en": {
271
+ "\U0001f47e Adventure": [
272
+ "Pikachu in a magical forest at sunset",
273
+ "A Minecraft village on an island in the middle of the ocean",
274
+ "A little dragon flying over a castle at night",
275
+ "A space adventure with rockets and colorful planets",
276
+ ],
277
+ "\U0001f43e Animals": [
278
+ "A friendly dog rescuing a kitten in the rain",
279
+ "Dinosaurs playing football on a sunny meadow",
280
+ "A unicorn galloping over a glowing rainbow",
281
+ "Penguins having a snowball fight at the South Pole",
282
+ "A little fox discovering a secret garden",
283
+ ],
284
+ "\u2728 Fantasy": [
285
+ "A wizard brewing a sparkling potion in a castle",
286
+ "A fairy flying through a forest of glowing mushrooms",
287
+ "An enchanted treehouse in the clouds with a rainbow bridge",
288
+ "A robot and a teddy bear going on a treasure hunt together",
289
+ "A magical underwater palace with talking fish",
290
+ ],
291
+ "\U0001f602 Funny": [
292
+ "A cat riding a skateboard through a colorful city",
293
+ "Aliens landing in the school garden and playing hide and seek",
294
+ "An elephant trying to jump on a trampoline",
295
+ "A snowman eating ice cream at the beach in summer",
296
+ "Monsters under the bed having a pajama party",
297
+ ],
298
+ "\U0001f3ae Game World": [
299
+ "Super Mario jumping through a world made of candy",
300
+ "A knight fighting a friendly dragon",
301
+ "An underwater race track with submarines and dolphins",
302
+ "A treehouse village in the jungle with rope bridges",
303
+ "Animals building a giant sandcastle at the beach",
304
+ ],
305
+ },
306
+ }
307
+
308
+ # ---------------------------------------------------------------------------
309
+ # Kid Mode — CSS theme (bright, bubbly, playful)
310
+ # ---------------------------------------------------------------------------
311
+ KID_CSS = """
312
+ <style>
313
+ /* ============================================================
314
+ KID MODE — Full theme override
315
+ ============================================================ */
316
+
317
+ /* Kill the top gap */
318
+ .block-container { padding-top: 0.5rem !important; }
319
+ header[data-testid="stHeader"] { display: none !important; }
320
+
321
+ /* Force light colorful background on EVERYTHING */
322
+ .stApp, .stApp > div, .main, .main .block-container,
323
+ [data-testid="stAppViewContainer"], [data-testid="stAppViewBlockContainer"],
324
+ section.main, section.main > div {
325
+ background: linear-gradient(170deg, #dbeafe 0%, #fce7f3 35%, #fef3c7 65%, #dcfce7 100%) !important;
326
+ color: #1e293b !important;
327
+ }
328
+ /* Sidebar light theme */
329
+ section[data-testid="stSidebar"], section[data-testid="stSidebar"] > div {
330
+ background: linear-gradient(180deg, #ede9fe 0%, #fce7f3 100%) !important;
331
+ color: #1e293b !important;
332
+ }
333
+ section[data-testid="stSidebar"] label,
334
+ section[data-testid="stSidebar"] .stMarkdown,
335
+ section[data-testid="stSidebar"] span,
336
+ section[data-testid="stSidebar"] p {
337
+ color: #334155 !important;
338
+ }
339
+ /* Force dark text everywhere */
340
+ .stMarkdown, .stMarkdown p, .stMarkdown span, .stMarkdown div,
341
+ .stTextArea textarea, label, .stSelectbox label {
342
+ color: #1e293b !important;
343
+ }
344
+ .stTextArea textarea {
345
+ background: rgba(255,255,255,0.85) !important;
346
+ border: 2px solid #c4b5fd !important;
347
+ border-radius: 18px !important;
348
+ font-size: 1rem !important;
349
+ }
350
+ .stTextArea textarea:focus {
351
+ border-color: #8b5cf6 !important;
352
+ box-shadow: 0 0 0 4px rgba(139,92,246,0.15) !important;
353
+ }
354
+ /* Status containers */
355
+ [data-testid="stStatusWidget"] {
356
+ background: rgba(255,255,255,0.6) !important;
357
+ border-radius: 14px !important;
358
+ }
359
+
360
+ /* Floating background elements */
361
+ .kid-bg {
362
+ position: fixed; top: 0; left: 0; width: 100%; height: 100%;
363
+ pointer-events: none; z-index: 0; overflow: hidden;
364
+ }
365
+ .kid-bg-item {
366
+ position: absolute; opacity: 0.15;
367
+ animation: kid-float linear infinite;
368
+ }
369
+ @keyframes kid-float {
370
+ 0% { transform: translateY(105vh) rotate(0deg) scale(0.8); opacity: 0; }
371
+ 8% { opacity: 0.35; }
372
+ 92% { opacity: 0.35; }
373
+ 100% { transform: translateY(-10vh) rotate(360deg) scale(1.1); opacity: 0; }
374
+ }
375
+ /* Twinkle for stars */
376
+ @keyframes kid-twinkle {
377
+ 0%, 100% { opacity: 0.15; transform: scale(0.8); }
378
+ 50% { opacity: 0.5; transform: scale(1.2); }
379
+ }
380
+ .kid-star-fixed {
381
+ position: absolute; pointer-events: none;
382
+ animation: kid-twinkle ease-in-out infinite;
383
+ }
384
+ /* Clouds */
385
+ .kid-cloud {
386
+ position: absolute; pointer-events: none; opacity: 0.18;
387
+ width: 120px; height: 50px; background: white;
388
+ border-radius: 50px; animation: kid-drift linear infinite;
389
+ }
390
+ .kid-cloud::before {
391
+ content: ''; position: absolute; background: white; border-radius: 50%;
392
+ width: 55px; height: 55px; top: -25px; left: 20px;
393
+ }
394
+ .kid-cloud::after {
395
+ content: ''; position: absolute; background: white; border-radius: 50%;
396
+ width: 40px; height: 40px; top: -18px; left: 55px;
397
+ }
398
+ @keyframes kid-drift {
399
+ 0% { transform: translateX(-150px); }
400
+ 100% { transform: translateX(calc(100vw + 150px)); }
401
+ }
402
+
403
+ /* Hero — big colorful title */
404
+ .kid-hero {
405
+ text-align: center; padding: 0.8rem 0 0.3rem; position: relative; z-index: 1;
406
+ }
407
+ .kid-hero-title {
408
+ font-size: 3.2rem; font-weight: 900; letter-spacing: -0.02em;
409
+ background: linear-gradient(135deg, #ec4899, #f97316, #eab308, #22c55e, #3b82f6, #8b5cf6);
410
+ background-size: 300% 300%;
411
+ -webkit-background-clip: text; -webkit-text-fill-color: transparent;
412
+ animation: kid-gradient 4s ease infinite;
413
+ text-shadow: none;
414
+ }
415
+ @keyframes kid-gradient {
416
+ 0% { background-position: 0% 50%; }
417
+ 50% { background-position: 100% 50%; }
418
+ 100% { background-position: 0% 50%; }
419
+ }
420
+ .kid-hero-sub {
421
+ font-size: 1.15rem; color: #475569; margin-top: 0.2rem; font-weight: 500;
422
+ }
423
+ .kid-hero-sub b { color: #7c3aed; }
424
+
425
+ /* Mascots — bigger, animated, with speech bubbles */
426
+ .kid-mascot-row {
427
+ display: flex; justify-content: center; gap: 2rem; margin: 0.8rem 0 0.5rem;
428
+ position: relative; z-index: 1;
429
+ }
430
+ .kid-mascot {
431
+ display: flex; flex-direction: column; align-items: center;
432
+ padding: 0.8rem 1.2rem 0.5rem; border-radius: 24px;
433
+ background: rgba(255,255,255,0.9);
434
+ border: 3px solid rgba(255,255,255,1);
435
+ box-shadow: 0 8px 30px rgba(0,0,0,0.08), 0 2px 8px rgba(139,92,246,0.1);
436
+ transition: transform 0.3s cubic-bezier(0.34, 1.56, 0.64, 1);
437
+ cursor: default; position: relative;
438
+ min-width: 105px;
439
+ }
440
+ .kid-mascot:hover {
441
+ transform: scale(1.12) rotate(-3deg);
442
+ box-shadow: 0 12px 40px rgba(139,92,246,0.25);
443
+ }
444
+ .kid-mascot svg { display: block; margin: 0 auto; }
445
+ .kid-mascot-name {
446
+ font-size: 0.9rem; font-weight: 800; margin-top: 0.15rem;
447
+ letter-spacing: 0.04em;
448
+ }
449
+ .kid-mascot:nth-child(1) .kid-mascot-name { color: #3b82f6; }
450
+ .kid-mascot:nth-child(2) .kid-mascot-name { color: #ec4899; }
451
+ .kid-mascot:nth-child(3) .kid-mascot-name { color: #f97316; }
452
+ /* Continuous gentle bounce */
453
+ .kid-mascot:nth-child(1) { animation: kid-bob 2s ease-in-out infinite; }
454
+ .kid-mascot:nth-child(2) { animation: kid-bob 2s ease-in-out 0.3s infinite; }
455
+ .kid-mascot:nth-child(3) { animation: kid-bob 2s ease-in-out 0.6s infinite; }
456
+ @keyframes kid-bob {
457
+ 0%, 100% { transform: translateY(0); }
458
+ 50% { transform: translateY(-6px); }
459
+ }
460
+ .kid-mascot:hover { animation: none; }
461
+ /* Speech bubble */
462
+ .kid-speech {
463
+ position: absolute; top: -32px; left: 50%; transform: translateX(-50%);
464
+ background: #fef3c7; color: #92400e; font-size: 0.65rem; font-weight: 700;
465
+ padding: 3px 10px; border-radius: 12px; white-space: nowrap;
466
+ box-shadow: 0 2px 8px rgba(0,0,0,0.08);
467
+ opacity: 0; transition: opacity 0.2s;
468
+ }
469
+ .kid-speech::after {
470
+ content: ''; position: absolute; bottom: -5px; left: 50%; margin-left: -5px;
471
+ border-left: 5px solid transparent; border-right: 5px solid transparent;
472
+ border-top: 5px solid #fef3c7;
473
+ }
474
+ .kid-mascot:hover .kid-speech { opacity: 1; }
475
+
476
+ /* Score cards — kid version */
477
+ .kid-scores {
478
+ display: grid; grid-template-columns: repeat(4, 1fr);
479
+ gap: 0.8rem; margin: 0.6rem 0; position: relative; z-index: 1;
480
+ }
481
+ @media (max-width: 768px) { .kid-scores { grid-template-columns: repeat(2, 1fr); } }
482
+ .kid-sc {
483
+ border-radius: 22px; padding: 1.1rem 0.8rem; text-align: center;
484
+ background: rgba(255,255,255,0.85);
485
+ border: 2.5px solid rgba(255,255,255,1);
486
+ box-shadow: 0 6px 24px rgba(0,0,0,0.06);
487
+ position: relative; overflow: hidden;
488
+ animation: kid-pop 0.4s cubic-bezier(0.34, 1.56, 0.64, 1) both;
489
+ }
490
+ .kid-sc:nth-child(1) { animation-delay: 0s; }
491
+ .kid-sc:nth-child(2) { animation-delay: 0.1s; }
492
+ .kid-sc:nth-child(3) { animation-delay: 0.2s; }
493
+ .kid-sc:nth-child(4) { animation-delay: 0.3s; }
494
+ @keyframes kid-pop {
495
+ 0% { transform: scale(0.7); opacity: 0; }
496
+ 100% { transform: scale(1); opacity: 1; }
497
+ }
498
+ .kid-sc::before {
499
+ content: ''; position: absolute; top: 0; left: 0; right: 0; height: 5px;
500
+ border-radius: 22px 22px 0 0;
501
+ }
502
+ .kid-sc-great::before { background: linear-gradient(90deg, #22c55e, #06b6d4); }
503
+ .kid-sc-ok::before { background: linear-gradient(90deg, #f59e0b, #f97316); }
504
+ .kid-sc-low::before { background: linear-gradient(90deg, #ef4444, #ec4899); }
505
+ .kid-sc-main::before { background: linear-gradient(90deg, #8b5cf6, #ec4899, #f97316, #eab308); background-size: 200%; animation: kid-gradient 3s ease infinite; }
506
+ .kid-sc-lbl {
507
+ font-size: 0.72rem; font-weight: 800; color: #64748b;
508
+ text-transform: uppercase; letter-spacing: 0.06em;
509
+ }
510
+ .kid-sc-stars { font-size: 1.8rem; margin: 0.3rem 0; line-height: 1.1; }
511
+ .kid-sc-emoji { font-size: 2.4rem; margin: 0.15rem 0; }
512
+ .kid-sc-val {
513
+ font-size: 0.7rem; color: #94a3b8; font-family: 'JetBrains Mono', monospace;
514
+ }
515
+
516
+ /* Verdict banner */
517
+ .kid-verdict {
518
+ text-align: center; font-size: 1.4rem; font-weight: 800;
519
+ color: #334155; margin: 0.4rem 0 0.6rem;
520
+ animation: kid-pop 0.5s cubic-bezier(0.34, 1.56, 0.64, 1) both;
521
+ }
522
+
523
+ /* Section labels */
524
+ .kid-sec-label {
525
+ font-size: 0.85rem; font-weight: 900; letter-spacing: 0.06em;
526
+ text-transform: uppercase; color: #7c3aed !important;
527
+ padding-bottom: 0.35rem; border-bottom: 3px solid #c4b5fd;
528
+ margin-bottom: 0.6rem;
529
+ }
530
+ .kid-text-card {
531
+ border-radius: 20px; padding: 1.2rem 1.3rem;
532
+ background: rgba(255,255,255,0.8);
533
+ border: 2px solid rgba(255,255,255,1);
534
+ box-shadow: 0 4px 20px rgba(0,0,0,0.05);
535
+ font-size: 0.95rem; line-height: 1.8; color: #334155 !important;
536
+ }
537
+
538
+ .kid-timing {
539
+ display: flex; gap: 0.5rem; flex-wrap: wrap; align-items: center;
540
+ padding: 0.45rem 0.9rem; border-radius: 16px;
541
+ background: rgba(255,255,255,0.6);
542
+ border: 2px solid rgba(255,255,255,0.9);
543
+ font-size: 0.72rem; color: #64748b !important; margin: 0.4rem 0;
544
+ }
545
+ .kid-timing span { color: #64748b !important; }
546
+ .kid-timing .t-total { color: #7c3aed !important; font-weight: 700; }
547
+ .kid-timing .t-sep { color: #cbd5e1 !important; }
548
+
549
+ /* Warn banner */
550
+ .kid-warn {
551
+ border-radius: 16px; padding: 0.8rem 1.1rem; margin-bottom: 0.6rem;
552
+ border-left: 4px solid #f97316; font-size: 0.85rem; color: #9a3412 !important;
553
+ background: rgba(255,237,213,0.7);
554
+ }
555
+
556
+ /* Button override */
557
+ .stButton > button[kind="primary"] {
558
+ background: linear-gradient(135deg, #8b5cf6, #ec4899) !important;
559
+ border: none !important; border-radius: 16px !important;
560
+ font-weight: 800 !important; font-size: 1.05rem !important;
561
+ padding: 0.6rem 1.5rem !important;
562
+ box-shadow: 0 4px 15px rgba(139,92,246,0.3) !important;
563
+ transition: transform 0.2s, box-shadow 0.2s !important;
564
+ }
565
+ .stButton > button[kind="primary"]:hover {
566
+ transform: scale(1.03) !important;
567
+ box-shadow: 0 6px 25px rgba(139,92,246,0.4) !important;
568
+ }
569
+
570
+ /* Divider */
571
+ hr { border-color: rgba(139,92,246,0.15) !important; }
572
+ </style>
573
+ """
574
+
575
+ # ---------------------------------------------------------------------------
576
+ # Kid Mode — mascot HTML, star ratings, emoji feedback
577
+ # ---------------------------------------------------------------------------
578
+
579
# Decorative HTML/SVG injected into the Streamlit page in kid mode:
# floating emoji background (three offset waves), twinkling stars, drifting
# clouds, corner characters (cat/dog/unicorn/rocket), and the three named
# mascots Textino, Pixela and Soundo with hover speech bubbles.
# FIX: the unicorn's cheek <circle> elements used rx="3", which is invalid
# for <circle> (only <ellipse> takes rx/ry), so the cheeks never rendered —
# changed to r="3".
MASCOT_HTML = """
<!-- Rich floating background -->
<div class="kid-bg">
<!-- Wave 1: floating emoji rising (spread across page) -->
<div class="kid-bg-item" style="font-size:30px;left:2%;animation-duration:14s;">\u2b50</div>
<div class="kid-bg-item" style="font-size:24px;left:8%;animation-duration:18s;animation-delay:2s;">\U0001f98b</div>
<div class="kid-bg-item" style="font-size:26px;left:14%;animation-duration:16s;animation-delay:5s;">\U0001f49c</div>
<div class="kid-bg-item" style="font-size:20px;left:20%;animation-duration:22s;animation-delay:1s;">\U0001f680</div>
<div class="kid-bg-item" style="font-size:32px;left:26%;animation-duration:13s;animation-delay:3s;">\u2728</div>
<div class="kid-bg-item" style="font-size:22px;left:32%;animation-duration:19s;animation-delay:7s;">\U0001f338</div>
<div class="kid-bg-item" style="font-size:28px;left:38%;animation-duration:15s;animation-delay:4s;">\U0001f31f</div>
<div class="kid-bg-item" style="font-size:18px;left:44%;animation-duration:20s;animation-delay:0s;">\U0001f984</div>
<div class="kid-bg-item" style="font-size:26px;left:50%;animation-duration:17s;animation-delay:6s;">\U0001f308</div>
<div class="kid-bg-item" style="font-size:24px;left:56%;animation-duration:14s;animation-delay:2s;">\U0001f49b</div>
<div class="kid-bg-item" style="font-size:20px;left:62%;animation-duration:21s;animation-delay:8s;">\U0001f33c</div>
<div class="kid-bg-item" style="font-size:30px;left:68%;animation-duration:16s;animation-delay:1s;">\u2b50</div>
<div class="kid-bg-item" style="font-size:22px;left:74%;animation-duration:18s;animation-delay:5s;">\U0001f98b</div>
<div class="kid-bg-item" style="font-size:28px;left:80%;animation-duration:13s;animation-delay:3s;">\u2728</div>
<div class="kid-bg-item" style="font-size:24px;left:86%;animation-duration:20s;animation-delay:9s;">\U0001f49a</div>
<div class="kid-bg-item" style="font-size:18px;left:92%;animation-duration:15s;animation-delay:4s;">\U0001f30d</div>
<div class="kid-bg-item" style="font-size:26px;left:97%;animation-duration:17s;animation-delay:0s;">\U0001f680</div>
<!-- Wave 2: offset for constant density -->
<div class="kid-bg-item" style="font-size:22px;left:5%;animation-duration:19s;animation-delay:10s;">\U0001f33c</div>
<div class="kid-bg-item" style="font-size:28px;left:15%;animation-duration:15s;animation-delay:11s;">\U0001f49b</div>
<div class="kid-bg-item" style="font-size:18px;left:25%;animation-duration:21s;animation-delay:9s;">\U0001f984</div>
<div class="kid-bg-item" style="font-size:26px;left:35%;animation-duration:16s;animation-delay:12s;">\u2b50</div>
<div class="kid-bg-item" style="font-size:24px;left:45%;animation-duration:18s;animation-delay:8s;">\U0001f98b</div>
<div class="kid-bg-item" style="font-size:20px;left:55%;animation-duration:14s;animation-delay:13s;">\U0001f308</div>
<div class="kid-bg-item" style="font-size:30px;left:65%;animation-duration:20s;animation-delay:10s;">\u2728</div>
<div class="kid-bg-item" style="font-size:22px;left:75%;animation-duration:17s;animation-delay:11s;">\U0001f338</div>
<div class="kid-bg-item" style="font-size:26px;left:85%;animation-duration:13s;animation-delay:14s;">\U0001f49a</div>
<div class="kid-bg-item" style="font-size:24px;left:95%;animation-duration:19s;animation-delay:9s;">\U0001f31f</div>
<!-- Wave 3: more for richness -->
<div class="kid-bg-item" style="font-size:20px;left:10%;animation-duration:17s;animation-delay:15s;">\U0001f680</div>
<div class="kid-bg-item" style="font-size:26px;left:30%;animation-duration:14s;animation-delay:16s;">\U0001f338</div>
<div class="kid-bg-item" style="font-size:22px;left:50%;animation-duration:19s;animation-delay:14s;">\U0001f984</div>
<div class="kid-bg-item" style="font-size:28px;left:70%;animation-duration:15s;animation-delay:17s;">\U0001f49c</div>
<div class="kid-bg-item" style="font-size:24px;left:90%;animation-duration:18s;animation-delay:15s;">\U0001f33c</div>
<!-- Twinkling stars (fixed) -->
<div class="kid-star-fixed" style="font-size:18px;top:5%;left:8%;animation-duration:2.5s;">\u2b50</div>
<div class="kid-star-fixed" style="font-size:14px;top:12%;left:30%;animation-duration:3s;animation-delay:0.5s;">\u2b50</div>
<div class="kid-star-fixed" style="font-size:16px;top:8%;left:55%;animation-duration:2.8s;animation-delay:1s;">\u2b50</div>
<div class="kid-star-fixed" style="font-size:12px;top:15%;left:80%;animation-duration:3.5s;animation-delay:0.3s;">\u2b50</div>
<div class="kid-star-fixed" style="font-size:15px;top:35%;left:5%;animation-duration:4s;animation-delay:0.8s;">\u2b50</div>
<div class="kid-star-fixed" style="font-size:11px;top:50%;left:92%;animation-duration:3.2s;animation-delay:1.5s;">\u2b50</div>
<div class="kid-star-fixed" style="font-size:17px;top:65%;left:15%;animation-duration:2.6s;animation-delay:0.2s;">\u2b50</div>
<div class="kid-star-fixed" style="font-size:13px;top:75%;left:70%;animation-duration:3.8s;animation-delay:2s;">\u2b50</div>
<div class="kid-star-fixed" style="font-size:10px;top:88%;left:45%;animation-duration:3s;animation-delay:1.2s;">\u2b50</div>
<div class="kid-star-fixed" style="font-size:14px;top:42%;left:88%;animation-duration:2.4s;animation-delay:0.7s;">\u2b50</div>
<!-- Clouds -->
<div class="kid-cloud" style="top:3%;animation-duration:40s;"></div>
<div class="kid-cloud" style="top:20%;animation-duration:55s;animation-delay:12s;width:90px;height:38px;"></div>
<div class="kid-cloud" style="top:45%;animation-duration:48s;animation-delay:25s;width:100px;height:42px;"></div>
<div class="kid-cloud" style="top:65%;animation-duration:52s;animation-delay:8s;width:80px;height:34px;"></div>
<div class="kid-cloud" style="top:85%;animation-duration:44s;animation-delay:20s;"></div>
</div>
<!-- Corner characters: cute SVG creatures -->
<!-- Cat (bottom-left) -->
<div style="position:fixed;bottom:15px;left:260px;z-index:2;opacity:0.4;pointer-events:none;animation:kid-bob 3s ease-in-out infinite;">
<svg width="55" height="50" viewBox="0 0 55 50">
<polygon points="9,16 4,2 17,12" fill="#f97316"/>
<polygon points="46,16 51,2 39,12" fill="#f97316"/>
<ellipse cx="27" cy="27" rx="20" ry="16" fill="#fb923c"/>
<ellipse cx="20" cy="25" rx="2.5" ry="3" fill="#1e293b"/>
<ellipse cx="34" cy="25" rx="2.5" ry="3" fill="#1e293b"/>
<circle cx="21" cy="24" r="0.8" fill="white"/>
<circle cx="35" cy="24" r="0.8" fill="white"/>
<ellipse cx="27" cy="30" rx="2" ry="1.2" fill="#f472b6"/>
<path d="M24 32 Q27 35 30 32" stroke="#ea580c" stroke-width="1" fill="none"/>
<line x1="7" y1="27" x2="0" y2="25" stroke="#fdba74" stroke-width="1.2"/>
<line x1="7" y1="29" x2="0" y2="30" stroke="#fdba74" stroke-width="1.2"/>
<line x1="47" y1="27" x2="55" y2="25" stroke="#fdba74" stroke-width="1.2"/>
<line x1="47" y1="29" x2="55" y2="30" stroke="#fdba74" stroke-width="1.2"/>
<path d="M13 43 Q7 47 10 50" stroke="#fb923c" stroke-width="3.5" fill="none" stroke-linecap="round"/>
</svg></div>
<!-- Dog (bottom-right) -->
<div style="position:fixed;bottom:15px;right:25px;z-index:2;opacity:0.4;pointer-events:none;animation:kid-bob 3.5s ease-in-out 0.5s infinite;">
<svg width="55" height="50" viewBox="0 0 55 50">
<ellipse cx="10" cy="10" rx="9" ry="13" fill="#a16207" transform="rotate(-20,10,10)"/>
<ellipse cx="45" cy="10" rx="9" ry="13" fill="#a16207" transform="rotate(20,45,10)"/>
<circle cx="27" cy="25" r="18" fill="#d97706"/>
<ellipse cx="20" cy="22" rx="2.5" ry="3" fill="#1e293b"/>
<ellipse cx="34" cy="22" rx="2.5" ry="3" fill="#1e293b"/>
<circle cx="21" cy="21" r="0.8" fill="white"/>
<circle cx="35" cy="21" r="0.8" fill="white"/>
<ellipse cx="27" cy="29" rx="3.5" ry="2.5" fill="#1e293b"/>
<ellipse cx="27" cy="28" rx="2" ry="1.2" fill="#f472b6"/>
<path d="M22 33 Q27 38 32 33" stroke="#92400e" stroke-width="1.2" fill="none"/>
</svg></div>
<!-- Unicorn (top-right) -->
<div style="position:fixed;top:75px;right:25px;z-index:2;opacity:0.35;pointer-events:none;animation:kid-bob 4s ease-in-out 1s infinite;">
<svg width="50" height="55" viewBox="0 0 50 55">
<polygon points="25,0 22,15 28,15" fill="#fbbf24"/>
<circle cx="25" cy="25" r="14" fill="white" stroke="#e9d5ff" stroke-width="1"/>
<ellipse cx="19" cy="23" rx="2.5" ry="3" fill="#1e293b"/>
<ellipse cx="31" cy="23" rx="2.5" ry="3" fill="#1e293b"/>
<circle cx="20" cy="22" r="0.8" fill="white"/>
<circle cx="32" cy="22" r="0.8" fill="white"/>
<circle cx="14" cy="28" r="3" fill="#fecdd3" opacity="0.5"/>
<circle cx="36" cy="28" r="3" fill="#fecdd3" opacity="0.5"/>
<path d="M20 30 Q25 34 30 30" stroke="#ec4899" stroke-width="1.2" fill="none"/>
<path d="M11 16 Q5 10 7 18" stroke="#c4b5fd" stroke-width="2.5" fill="none" stroke-linecap="round"/>
<path d="M13 14 Q8 6 9 15" stroke="#fbcfe8" stroke-width="2" fill="none" stroke-linecap="round"/>
<path d="M39 16 Q45 10 43 18" stroke="#bfdbfe" stroke-width="2.5" fill="none" stroke-linecap="round"/>
<path d="M37 14 Q42 6 41 15" stroke="#fde68a" stroke-width="2" fill="none" stroke-linecap="round"/>
</svg></div>
<!-- Rocket (top-left past sidebar) -->
<div style="position:fixed;top:65px;left:260px;z-index:2;opacity:0.35;pointer-events:none;animation:kid-bob 3.2s ease-in-out 0.8s infinite;">
<svg width="35" height="55" viewBox="0 0 35 55">
<ellipse cx="17" cy="22" rx="10" ry="18" fill="#ef4444"/>
<ellipse cx="17" cy="22" rx="6.5" ry="12" fill="#fca5a5"/>
<circle cx="17" cy="19" r="4.5" fill="#dbeafe"/>
<circle cx="17" cy="19" r="2.5" fill="#3b82f6"/>
<polygon points="17,1 14,10 20,10" fill="#ef4444"/>
<polygon points="7,34 2,43 12,36" fill="#f97316"/>
<polygon points="27,34 32,43 22,36" fill="#f97316"/>
<ellipse cx="17" cy="40" rx="4" ry="3.5" fill="#fbbf24"/>
<ellipse cx="17" cy="44" rx="2.5" ry="5" fill="#fb923c" opacity="0.7"/>
<ellipse cx="17" cy="49" rx="1.5" ry="3.5" fill="#fbbf24" opacity="0.4"/>
</svg></div>
<!-- SVG Mascots -->
<div class="kid-mascot-row">
<div class="kid-mascot">
<div class="kid-speech">Ich schreibe!</div>
<svg width="70" height="75" viewBox="0 0 70 75">
<!-- Textino: cute blue robot -->
<!-- Antenna -->
<line x1="35" y1="8" x2="35" y2="0" stroke="#60a5fa" stroke-width="2.5" stroke-linecap="round"/>
<circle cx="35" cy="0" r="4" fill="#fbbf24"/>
<!-- Head -->
<rect x="10" y="8" width="50" height="32" rx="12" fill="#3b82f6"/>
<!-- Face screen -->
<rect x="15" y="13" width="40" height="22" rx="8" fill="#dbeafe"/>
<!-- Eyes -->
<circle cx="27" cy="23" r="5" fill="white"/>
<circle cx="43" cy="23" r="5" fill="white"/>
<circle cx="28" cy="23" r="3" fill="#1e293b"/>
<circle cx="44" cy="23" r="3" fill="#1e293b"/>
<!-- Eye shine -->
<circle cx="29" cy="22" r="1" fill="white"/>
<circle cx="45" cy="22" r="1" fill="white"/>
<!-- Smile -->
<path d="M25 29 Q35 35 45 29" stroke="#3b82f6" stroke-width="2" fill="none" stroke-linecap="round"/>
<!-- Body -->
<rect x="18" y="40" width="34" height="22" rx="8" fill="#60a5fa"/>
<!-- Arms -->
<rect x="5" y="42" width="13" height="8" rx="4" fill="#93c5fd"/>
<rect x="52" y="42" width="13" height="8" rx="4" fill="#93c5fd"/>
<!-- Pencil in right hand -->
<line x1="65" y1="42" x2="69" y2="32" stroke="#f97316" stroke-width="3" stroke-linecap="round"/>
<polygon points="69,32 67,28 71,28" fill="#fbbf24"/>
<!-- Belly button -->
<circle cx="35" cy="51" r="3" fill="#3b82f6"/>
<!-- Feet -->
<rect x="20" y="62" width="12" height="8" rx="4" fill="#3b82f6"/>
<rect x="38" y="62" width="12" height="8" rx="4" fill="#3b82f6"/>
</svg>
<div class="kid-mascot-name">Textino</div>
</div>
<div class="kid-mascot">
<div class="kid-speech">Ich male!</div>
<svg width="70" height="75" viewBox="0 0 70 75">
<!-- Pixela: cute pink artist character -->
<!-- Beret -->
<ellipse cx="35" cy="10" rx="22" ry="8" fill="#ec4899"/>
<circle cx="35" cy="5" r="5" fill="#f472b6"/>
<!-- Head -->
<circle cx="35" cy="25" r="20" fill="#fda4af"/>
<!-- Rosy cheeks -->
<circle cx="22" cy="29" r="5" fill="#fecdd3" opacity="0.7"/>
<circle cx="48" cy="29" r="5" fill="#fecdd3" opacity="0.7"/>
<!-- Eyes -->
<ellipse cx="27" cy="23" rx="4.5" ry="5" fill="white"/>
<ellipse cx="43" cy="23" rx="4.5" ry="5" fill="white"/>
<circle cx="28" cy="23" r="3" fill="#1e293b"/>
<circle cx="44" cy="23" r="3" fill="#1e293b"/>
<circle cx="29" cy="22" r="1" fill="white"/>
<circle cx="45" cy="22" r="1" fill="white"/>
<!-- Cat mouth -->
<path d="M30 31 L35 34 L40 31" stroke="#e11d48" stroke-width="1.5" fill="none" stroke-linecap="round"/>
<!-- Body -->
<rect x="20" y="45" width="30" height="18" rx="10" fill="#fb7185"/>
<!-- Arms -->
<rect x="7" y="47" width="13" height="7" rx="3.5" fill="#fda4af"/>
<rect x="50" y="47" width="13" height="7" rx="3.5" fill="#fda4af"/>
<!-- Paintbrush in right hand -->
<line x1="63" y1="47" x2="68" y2="35" stroke="#a16207" stroke-width="2.5" stroke-linecap="round"/>
<ellipse cx="68" cy="33" rx="4" ry="5" fill="#8b5cf6" transform="rotate(-15,68,33)"/>
<!-- Paint palette in left hand -->
<ellipse cx="4" cy="50" rx="8" ry="5" fill="#fde68a" transform="rotate(10,4,50)"/>
<circle cx="2" cy="48" r="2" fill="#ef4444"/>
<circle cx="6" cy="47" r="2" fill="#3b82f6"/>
<circle cx="4" cy="52" r="2" fill="#22c55e"/>
<!-- Feet -->
<ellipse cx="28" cy="67" rx="7" ry="5" fill="#ec4899"/>
<ellipse cx="42" cy="67" rx="7" ry="5" fill="#ec4899"/>
</svg>
<div class="kid-mascot-name">Pixela</div>
</div>
<div class="kid-mascot">
<div class="kid-speech">Ich spiele!</div>
<svg width="70" height="75" viewBox="0 0 70 75">
<!-- Soundo: cute orange music character -->
<!-- Headphones band -->
<path d="M12 25 Q12 5 35 5 Q58 5 58 25" stroke="#f97316" stroke-width="4" fill="none" stroke-linecap="round"/>
<!-- Headphone pads -->
<rect x="6" y="20" width="12" height="16" rx="6" fill="#f97316"/>
<rect x="52" y="20" width="12" height="16" rx="6" fill="#f97316"/>
<rect x="8" y="22" width="8" height="12" rx="4" fill="#fdba74"/>
<rect x="54" y="22" width="8" height="12" rx="4" fill="#fdba74"/>
<!-- Head -->
<circle cx="35" cy="28" r="18" fill="#fed7aa"/>
<!-- Eyes - happy closed -->
<path d="M24 26 Q28 22 32 26" stroke="#1e293b" stroke-width="2.5" fill="none" stroke-linecap="round"/>
<path d="M38 26 Q42 22 46 26" stroke="#1e293b" stroke-width="2.5" fill="none" stroke-linecap="round"/>
<!-- Big open smile -->
<path d="M25 33 Q35 42 45 33" stroke="#ea580c" stroke-width="2" fill="#fef3c7" stroke-linecap="round"/>
<!-- Body -->
<rect x="22" y="46" width="26" height="16" rx="8" fill="#fb923c"/>
<!-- Arms -->
<rect x="9" y="48" width="13" height="7" rx="3.5" fill="#fdba74"/>
<rect x="48" y="48" width="13" height="7" rx="3.5" fill="#fdba74"/>
<!-- Music notes floating -->
<text x="60" y="15" font-size="14" fill="#8b5cf6" opacity="0.8">\u266a</text>
<text x="4" y="12" font-size="11" fill="#ec4899" opacity="0.7">\u266b</text>
<text x="55" y="45" font-size="10" fill="#f97316" opacity="0.6">\u266a</text>
<!-- Feet -->
<ellipse cx="29" cy="66" rx="7" ry="5" fill="#f97316"/>
<ellipse cx="41" cy="66" rx="7" ry="5" fill="#f97316"/>
</svg>
<div class="kid-mascot-name">Soundo</div>
</div>
</div>
"""
813
+
814
+
815
+ def _kid_stars(v: Optional[float]) -> str:
816
+ """Convert a 0-1 score to 1-5 star rating HTML."""
817
+ if v is None:
818
+ return "\u2b50" * 0
819
+ n = max(1, min(5, round(v * 10))) # 0.1→1 star, 0.5→5 stars
820
+ return "\u2b50" * n + "\u2606" * (5 - n) # filled + empty
821
+
822
+
823
+ def _kid_emoji(v: Optional[float]) -> str:
824
+ """Return emoji face based on coherence score."""
825
+ if v is None:
826
+ return "\U0001f914"
827
+ if v >= 0.45:
828
+ return "\U0001f929" # star-struck
829
+ if v >= 0.35:
830
+ return "\U0001f60a" # happy
831
+ if v >= 0.25:
832
+ return "\U0001f642" # slightly smiling
833
+ return "\U0001f61f" # worried
834
+
835
+
836
+ def _kid_verdict(v: Optional[float], lang: str = "de") -> str:
837
+ """Return kid-friendly verdict text."""
838
+ if v is None:
839
+ return "Hmm..." if lang == "de" else "Hmm..."
840
+ if lang == "de":
841
+ if v >= 0.45:
842
+ return "Super! Alles passt perfekt zusammen! \U0001f389"
843
+ if v >= 0.35:
844
+ return "Gut gemacht! Das passt ziemlich gut! \U0001f44d"
845
+ if v >= 0.25:
846
+ return "Geht so \u2014 ein bisschen passt es! \U0001f914"
847
+ return "Hmm, das passt noch nicht so gut \U0001f61e"
848
+ else:
849
+ if v >= 0.45:
850
+ return "Amazing! Everything fits perfectly together! \U0001f389"
851
+ if v >= 0.35:
852
+ return "Well done! That fits pretty well! \U0001f44d"
853
+ if v >= 0.25:
854
+ return "So-so \u2014 it fits a little bit! \U0001f914"
855
+ return "Hmm, that doesn't quite fit yet \U0001f61e"
856
+
857
+
858
def kid_score_card(label: str, value: Optional[float], is_main: bool = False) -> str:
    """Render a kid-friendly score card (stars + optional emoji) as HTML.

    Args:
        label: Card title shown above the stars.
        value: Score in [0, 1], or None when unavailable.
        is_main: When True, use the animated "main" accent and add an
            emoji face via _kid_emoji().

    Returns:
        An HTML snippet for one score card.
    """
    # Pick the accent class. Use explicit 'is not None' checks (rather than
    # the original's truthiness test 'value and ...') so a legitimate score
    # of exactly 0.0 cannot be confused with a missing score.
    if is_main:
        cls = "kid-sc-main"
    elif value is not None and value >= 0.45:
        cls = "kid-sc-great"
    elif value is not None and value >= 0.30:
        cls = "kid-sc-ok"
    else:
        cls = "kid-sc-low"
    stars = _kid_stars(value)
    emoji = _kid_emoji(value) if is_main else ""
    val_str = f"{value:.3f}" if value is not None else "\u2014"
    emoji_html = f'<div class="kid-sc-emoji">{emoji}</div>' if emoji else ""
    return (
        f'<div class="kid-sc {cls} kid-confetti">'
        f'<div class="kid-sc-lbl">{label}</div>'
        f'{emoji_html}'
        f'<div class="kid-sc-stars">{stars}</div>'
        f'<div class="kid-sc-val">{val_str}</div>'
        f'</div>'
    )
876
+
877
+
878
# Kid-mode UI labels
# Playful i18n strings used when the app runs in kid mode. Keys mirror
# UI_LABELS so rendering code can swap the dictionaries wholesale.
# Values may contain inline HTML (<b>...</b>) — rendered, not escaped.
UI_LABELS_KID = {
    "de": {
        "hero_title": "Multimodale KI f\u00fcr Kids",
        "hero_sub": "Beschreibe eine Szene und die KI erzeugt <b>Text + Bild + Audio</b> dazu!",
        "config": "Einstellungen",
        "backend": "Wie soll es erstellt werden?",
        "planning": "Planungsmodus",
        "language": "Sprache",
        "examples": "Ideen zum Ausprobieren",
        "scene_placeholder": "Beschreibe deine Szene hier... z.B. 'Ein Einhorn fliegt \u00fcber einen Regenbogen' \U0001f308",
        "generate_btn": "\u2728 Los geht's!",
        "welcome_text": "Beschreibe eine Szene und klicke auf <b>\u2728 Los geht's!</b>",
        "welcome_hint": "oder w\u00e4hle eine Idee aus der Seitenleiste \U0001f449",
        "scores_label": "\U0001f3af Wie gut passt alles zusammen?",
        # Generation panel headers name the three mascots from MASCOT_HTML.
        "gen_text_label": "\U0001f916 Textino schreibt...",
        "gen_image_label": "\U0001f3a8 Pixela malt...",
        "gen_audio_label": "\U0001f3b5 Soundo spielt...",
        "translated_note": "Aus dem Deutschen \u00fcbersetzt",
        "original_label": "Original (Deutsch)",
    },
    "en": {
        "hero_title": "Multimodal AI for Kids",
        "hero_sub": "Describe a scene and the AI creates <b>text + image + audio</b> for it!",
        "config": "Settings",
        "backend": "How should it be created?",
        "planning": "Planning Mode",
        "language": "Language",
        "examples": "Ideas to Try",
        "scene_placeholder": "Describe your scene here... e.g., 'A unicorn flying over a rainbow' \U0001f308",
        "generate_btn": "\u2728 Let's Go!",
        "welcome_text": "Describe a scene and click <b>\u2728 Let's Go!</b>",
        "welcome_hint": "or pick an idea from the sidebar \U0001f449",
        "scores_label": "\U0001f3af How well does everything fit together?",
        "gen_text_label": "\U0001f916 Textino writes...",
        "gen_image_label": "\U0001f3a8 Pixela paints...",
        "gen_audio_label": "\U0001f3b5 Soundo plays...",
        "translated_note": "Translated from German",
        "original_label": "Original (German)",
    },
}
919
+
920
  # ---------------------------------------------------------------------------
921
  # Planning prompt template (same as src/planner/prompts/unified.txt)
922
  # ---------------------------------------------------------------------------
 
1020
  return InferenceClient(token=token)
1021
 
1022
 
1023
+ # ---------------------------------------------------------------------------
1024
+ # Translation (German <-> English)
1025
+ # ---------------------------------------------------------------------------
1026
+
1027
# Helsinki-NLP Marian MT checkpoints on the HF Inference API,
# keyed by translation direction as used by translate().
TRANSLATION_MODELS = {
    "de-en": "Helsinki-NLP/opus-mt-de-en",
    "en-de": "Helsinki-NLP/opus-mt-en-de",
}
1031
+
1032
+
1033
def translate(text: str, direction: str) -> str:
    """Translate text using HF Inference API. direction: 'de-en' or 'en-de'."""
    # Empty or whitespace-only input: nothing to translate.
    if not text or not text.strip():
        return text
    model_id = TRANSLATION_MODELS[direction]
    client = get_inference_client()
    try:
        raw = client.translation(text, model=model_id)
        if isinstance(raw, str):
            return raw
        # huggingface_hub may return a TranslationOutput object instead.
        if hasattr(raw, "translation_text"):
            return raw.translation_text
        return str(raw)
    except Exception as exc:
        # Best-effort: any API failure falls back to the untranslated text.
        logger.warning("Translation (%s) failed: %s — returning original", direction, exc)
        return text
1048
+
1049
+
1050
def translate_de_to_en(text: str) -> str:
    """Convenience wrapper: translate German text to English."""
    return translate(text, direction="de-en")
1052
+
1053
+
1054
def translate_en_to_de(text: str) -> str:
    """Convenience wrapper: translate English text to German."""
    return translate(text, direction="en-de")
1056
+
1057
+
1058
+ # ---------------------------------------------------------------------------
1059
+ # UI labels (i18n)
1060
+ # ---------------------------------------------------------------------------
1061
+
1062
# Standard-mode i18n strings, keyed by language code ("en" / "de").
# Values may contain inline HTML (<b>...</b>) — rendered, not escaped.
# UI_LABELS_KID mirrors this key set for kid mode.
UI_LABELS = {
    "en": {
        "hero_title": "Multimodal Coherence AI",
        "hero_sub": 'Generate semantically coherent <b>text + image + audio</b> bundles '
                    'and evaluate cross-modal alignment with the <b>MSCI</b> metric.',
        "config": "Configuration",
        "backend": "Backend",
        "planning": "Planning Mode",
        "language": "Language",
        "examples": "Examples",
        "scene_placeholder": "Describe a scene... e.g., 'A peaceful forest at dawn with birdsong and morning mist'",
        "generate_btn": "Generate Bundle",
        "welcome_text": 'Enter a scene description and click <b>Generate Bundle</b>',
        "welcome_hint": "or pick an example from the sidebar",
        "scores_label": "Coherence Scores",
        "gen_text_label": "Generated Text",
        "gen_image_label": "Generated Image",
        "gen_audio_label": "Generated Audio",
        "translated_note": "Translated from German",
        "original_label": "Original (German)",
    },
    "de": {
        "hero_title": "Multimodale Koh\u00e4renz-KI",
        "hero_sub": 'Erzeuge semantisch koh\u00e4rente <b>Text + Bild + Audio</b> B\u00fcndel '
                    'und bewerte die modale \u00dcbereinstimmung mit der <b>MSCI</b>-Metrik.',
        "config": "Einstellungen",
        "backend": "Verfahren",
        "planning": "Planungsmodus",
        "language": "Sprache",
        "examples": "Beispiele",
        "scene_placeholder": "Beschreibe eine Szene... z.B. 'Ein friedlicher Wald bei Sonnenaufgang mit Vogelgesang'",
        "generate_btn": "B\u00fcndel erzeugen",
        "welcome_text": 'Beschreibe eine Szene und klicke auf <b>B\u00fcndel erzeugen</b>',
        "welcome_hint": "oder w\u00e4hle ein Beispiel aus der Seitenleiste",
        "scores_label": "Koh\u00e4renz-Bewertung",
        "gen_text_label": "Erzeugter Text",
        "gen_image_label": "Erzeugtes Bild",
        "gen_audio_label": "Erzeugtes Audio",
        "translated_note": "Aus dem Deutschen \u00fcbersetzt",
        "original_label": "Original (Deutsch)",
    },
}
1104
+
1105
+
1106
  # ---------------------------------------------------------------------------
1107
  # HF Inference API helpers
1108
  # ---------------------------------------------------------------------------
1109
 
1110
# Primary models (may consume credits via Inference Providers)
TEXT_GEN_MODELS_PAID = [
    "mistralai/Mistral-7B-Instruct-v0.3",
    "meta-llama/Llama-3.2-3B-Instruct",
]
# Free serverless models (rate-limited but no credit cost)
TEXT_GEN_MODELS_FREE = [
    "HuggingFaceH4/zephyr-7b-beta",
    "microsoft/Phi-3-mini-4k-instruct",
    "google/gemma-2-2b-it",
]
# Combined fallback order used by _hf_chat(): try free first, then paid.
# Paid entries are skipped entirely once a 402 credit error has been seen.
TEXT_GEN_MODELS = TEXT_GEN_MODELS_FREE + TEXT_GEN_MODELS_PAID
1123
+
1124
+ def _is_credit_error(e: Exception) -> bool:
1125
+ """Check if an exception is a 402 Payment Required (credits depleted)."""
1126
+ msg = str(e).lower()
1127
+ return "402" in msg or "payment required" in msg or "credit" in msg
1128
+
1129
 
1130
  def _hf_chat(system: str, user: str, max_tokens: int = 500, temperature: float = 0.3) -> str:
1131
+ """Call HF Inference API chat completion, trying multiple models.
1132
+
1133
+ Tries free serverless models first, then paid models.
1134
+ Skips paid models entirely if a 402 credit error is detected.
1135
+ """
1136
  client = get_inference_client()
1137
  last_error = None
1138
+ credits_depleted = False
1139
+
1140
  for model_id in TEXT_GEN_MODELS:
1141
+ # Skip paid models if we already know credits are gone
1142
+ if credits_depleted and model_id in TEXT_GEN_MODELS_PAID:
1143
+ logger.info("Skipping paid model %s (credits depleted)", model_id)
1144
+ continue
1145
  try:
1146
  response = client.chat_completion(
1147
  model=model_id,
 
1157
  return text
1158
  except Exception as e:
1159
  last_error = e
1160
+ if _is_credit_error(e):
1161
+ credits_depleted = True
1162
+ logger.warning("Chat model %s: credits depleted (402)", model_id)
1163
+ else:
1164
+ logger.warning("Chat model %s failed: %s", model_id, e)
1165
  continue
1166
+
1167
+ detail = "Credit balance is depleted." if credits_depleted else f"Last error: {last_error}"
1168
+ raise RuntimeError(f"All text models failed. {detail}")
1169
 
1170
 
1171
  def _parse_plan_json(raw: str) -> Optional[Dict[str, Any]]:
 
1251
  # Generation / retrieval functions
1252
  # ---------------------------------------------------------------------------
1253
 
1254
# HF Inference API model IDs — free models first, paid fallback
# Each list is tried in order by generate_image() / generate_audio();
# paid entries are skipped once a 402 credit error has been observed.
IMAGE_GEN_MODELS = [
    "black-forest-labs/FLUX.1-schnell",  # Free serverless
    "stabilityai/stable-diffusion-xl-base-1.0",  # May need credits
]
AUDIO_GEN_MODELS = [
    "facebook/musicgen-small",  # Free serverless
    "cvssp/audioldm2",  # May need credits
]
1263
 
1264
  def gen_text(prompt: str, mode: str) -> dict:
 
1316
 
1317
 
1318
def generate_image(prompt: str) -> dict:
    """Generate an image via the HF Inference API, trying free models first.

    Iterates IMAGE_GEN_MODELS in order; when a 402 credit error is seen, any
    remaining paid model is skipped. If every generative model fails, falls
    back to CLIP retrieval (retrieve_image) and, when the failure was due to
    depleted credits, tags the result with ``credit_error=True``.

    Returns:
        dict with at least "path" and "backend" keys; generative results also
        carry "model" and "failed"; retrieval fallbacks carry whatever
        retrieve_image returns.
    """
    client = get_inference_client()
    credits_depleted = False
    for model_id in IMAGE_GEN_MODELS:
        # Once a 402 is seen, don't waste a call on the model that needs credits.
        if credits_depleted and model_id == "stabilityai/stable-diffusion-xl-base-1.0":
            logger.info("Skipping paid image model (credits depleted)")
            continue
        try:
            image = client.text_to_image(prompt, model=model_id)
            # Bug fix: close the temp-file handle before PIL writes to the
            # path — the original never closed it, leaking one file
            # descriptor per generated image.
            tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir="/tmp")
            tmp.close()
            image.save(tmp.name)
            return {
                "path": tmp.name, "backend": "generative",
                "model": model_id.split("/")[-1], "failed": False,
            }
        except Exception as e:
            if _is_credit_error(e):
                credits_depleted = True
                logger.warning("Image model %s: credits depleted (402)", model_id)
            else:
                logger.warning("Image gen with %s failed: %s", model_id, e)
            continue
    logger.warning("All image generation models failed — falling back to retrieval")
    result = retrieve_image(prompt)
    if credits_depleted:
        result["credit_error"] = True
    return result
1347
 
1348
 
1349
  def generate_audio(prompt: str) -> dict:
1350
+ """Generate audio via HF Inference API, trying free models first. Falls back to retrieval."""
1351
  client = get_inference_client()
1352
+ credits_depleted = False
1353
  for model_id in AUDIO_GEN_MODELS:
1354
+ if credits_depleted and model_id == "cvssp/audioldm2":
1355
+ logger.info("Skipping paid audio model (credits depleted)")
1356
+ continue
1357
  try:
1358
  audio_bytes = client.text_to_audio(prompt, model=model_id)
1359
  suffix = ".flac" if "musicgen" in model_id else ".wav"
 
1362
  tmp.write(audio_bytes)
1363
  tmp.flush()
1364
  else:
 
1365
  tmp.write(bytes(audio_bytes))
1366
  tmp.flush()
1367
  model_name = model_id.split("/")[-1]
 
1370
  "model": model_name, "failed": False,
1371
  }
1372
  except Exception as e:
1373
+ if _is_credit_error(e):
1374
+ credits_depleted = True
1375
+ logger.warning("Audio model %s: credits depleted (402)", model_id)
1376
+ else:
1377
+ logger.warning("Audio gen with %s failed: %s", model_id, e)
1378
  continue
 
1379
  logger.warning("All audio generation models failed — falling back to retrieval")
1380
+ result = retrieve_audio(prompt)
1381
+ if credits_depleted:
1382
+ result["credit_error"] = True
1383
+ return result
1384
 
1385
 
1386
  def retrieve_image(prompt: str) -> dict:
 
1452
  layout="wide",
1453
  initial_sidebar_state="expanded",
1454
  )
 
1455
 
1456
+ # Sidebar — settings first (needed for CSS choice)
 
 
 
 
 
 
 
 
1457
  with st.sidebar:
1458
  st.markdown("#### Configuration")
1459
 
1460
+ kid_mode = st.toggle("\U0001f476 Kid Mode", value=False)
1461
+
1462
+ lang = st.selectbox(
1463
+ "Language / Sprache",
1464
+ ["en", "de"],
1465
+ format_func=lambda x: {"en": "English", "de": "Deutsch"}[x],
1466
+ )
1467
+
1468
+ # Select labels based on kid mode and language
1469
+ if kid_mode:
1470
+ L = UI_LABELS_KID.get(lang, UI_LABELS_KID["en"])
1471
+ else:
1472
+ L = UI_LABELS[lang]
1473
+
1474
  backend = st.selectbox(
1475
+ L["backend"],
1476
  ["generative", "retrieval"],
1477
  format_func=lambda x: {
1478
+ "generative": "Generative (FLUX/SDXL + MusicGen)",
1479
  "retrieval": "Retrieval (CLIP + CLAP index)",
1480
  }[x],
1481
  )
1482
 
1483
  mode = st.selectbox(
1484
+ L["planning"],
1485
  ["direct", "planner", "council", "extended_prompt"],
1486
  format_func=lambda x: {
1487
  "direct": "Direct",
 
1492
  )
1493
 
1494
  st.divider()
1495
+ st.markdown(f"#### {L['examples']}")
1496
+
1497
+ # Kid mode uses fun themed prompts; normal mode uses domain prompts
1498
+ if kid_mode:
1499
+ lang_examples = KID_EXAMPLE_PROMPTS.get(lang, KID_EXAMPLE_PROMPTS["en"])
1500
+ for dname, prompts in lang_examples.items():
1501
+ with st.expander(dname): # already has emoji in key
1502
+ for p in prompts:
1503
+ if st.button(p, key=f"ex_{hash(p)}", use_container_width=True):
1504
+ st.session_state["prompt_input"] = p
1505
+ else:
1506
+ lang_examples = EXAMPLE_PROMPTS.get(lang, EXAMPLE_PROMPTS["en"])
1507
+ domain_icons_de = {"natur": "\U0001f33f", "stadt": "\U0001f3d9\ufe0f", "wasser": "\U0001f30a", "gemischt": "\U0001f310"}
1508
+ for dname, prompts in lang_examples.items():
1509
+ icon = DOMAIN_ICONS.get(dname.lower(), domain_icons_de.get(dname.lower(), "\U0001f4cd"))
1510
+ with st.expander(f"{icon} {dname}"):
1511
+ for p in prompts:
1512
+ if st.button(p, key=f"ex_{hash(p)}", use_container_width=True):
1513
+ st.session_state["prompt_input"] = p
1514
 
1515
  st.divider()
1516
  mode_desc = {
 
1520
  "extended_prompt": "Single LLM call with 3x token budget",
1521
  }
1522
  if backend == "generative":
1523
+ img_info = "FLUX.1-schnell / SDXL via HF API"
1524
+ aud_info = "MusicGen / AudioLDM2 via HF API"
1525
  else:
1526
  img_info = "CLIP retrieval (57 images)"
1527
  aud_info = "CLAP retrieval (104 clips)"
1528
+ trans_info = "<br><b>Translation</b> opus-mt-de-en / en-de" if lang == "de" else ""
1529
  st.markdown(
1530
  f'<div class="sidebar-info">'
1531
  f'<b>Text</b> HF Inference API<br>'
1532
  f'<b>Planning</b> {mode_desc[mode]}<br>'
1533
  f'<b>Image</b> {img_info}<br>'
1534
+ f'<b>Audio</b> {aud_info}{trans_info}<br><br>'
1535
  f'<b>Metric</b> MSCI = 0.45 &times; s<sub>t,i</sub> + 0.45 &times; s<sub>t,a</sub><br><br>'
1536
  f'<b>Models</b><br>'
1537
  f'CLIP ViT-B/32 (coherence eval)<br>'
1538
  f'CLAP HTSAT-unfused (coherence eval)'
1539
  f'</div>', unsafe_allow_html=True)
1540
 
1541
+ # Apply CSS based on mode
1542
+ if kid_mode:
1543
+ st.markdown(KID_CSS, unsafe_allow_html=True) # kid theme (includes all needed overrides)
1544
+ else:
1545
+ st.markdown(CUSTOM_CSS, unsafe_allow_html=True) # professional dark theme
1546
+
1547
+ # Hero
1548
+ if kid_mode:
1549
+ st.markdown(
1550
+ f'<div class="kid-hero">'
1551
+ f'<div class="kid-hero-title">{L["hero_title"]}</div>'
1552
+ f'<div class="kid-hero-sub">{L["hero_sub"]}</div>'
1553
+ f'</div>', unsafe_allow_html=True)
1554
+ st.markdown(MASCOT_HTML, unsafe_allow_html=True)
1555
+ else:
1556
+ st.markdown(
1557
+ f'<div class="hero-wrap">'
1558
+ f'<div class="hero-title">{L["hero_title"]}</div>'
1559
+ f'<div class="hero-sub">{L["hero_sub"]}</div>'
1560
+ f'</div>', unsafe_allow_html=True)
1561
+
1562
  # Prompt input
1563
  default_prompt = st.session_state.get("prompt_input", "")
1564
  prompt = st.text_area(
1565
  "Scene", value=default_prompt, height=80,
1566
+ placeholder=L["scene_placeholder"],
1567
  label_visibility="collapsed",
1568
  )
1569
 
1570
  # Button + chips
1571
  bc1, bc2 = st.columns([1, 3])
1572
  with bc1:
1573
+ go = st.button(L["generate_btn"], type="primary", use_container_width=True, disabled=not prompt.strip())
1574
  with bc2:
1575
  mlbl = {"direct": "Direct", "planner": "Planner", "council": "Council", "extended_prompt": "Extended"}[mode]
1576
  mcls = "chip-amber" if mode != "direct" else "chip-purple"
 
1579
  bchip = '<span class="chip chip-pink"><span class="chip-dot chip-dot-pink"></span>Generative</span>'
1580
  else:
1581
  bchip = '<span class="chip chip-purple"><span class="chip-dot chip-dot-purple"></span>Retrieval</span>'
1582
+ lang_chip = ""
1583
+ if lang == "de":
1584
+ lang_chip = '<span class="chip chip-amber"><span class="chip-dot chip-dot-amber"></span>DE \u2192 EN</span>'
1585
+ kid_chip = ""
1586
+ if kid_mode:
1587
+ kid_chip = '<span class="chip chip-green"><span class="chip-dot chip-dot-green"></span>\U0001f476 Kid</span>'
1588
  st.markdown(
1589
  f'<div class="chip-row">'
1590
  f'{bchip}'
1591
  f'<span class="chip {mcls}"><span class="chip-dot {mdot}"></span>{mlbl}</span>'
1592
  f'<span class="chip chip-green"><span class="chip-dot chip-dot-green"></span>CLIP + CLAP</span>'
1593
+ f'{lang_chip}{kid_chip}'
1594
  f'</div>', unsafe_allow_html=True)
1595
 
1596
  # Welcome state
1597
  if not go and "last_result" not in st.session_state:
1598
+ if kid_mode:
1599
+ st.markdown(
1600
+ f'<div class="welcome" style="background:rgba(255,255,255,0.5);border-radius:24px;padding:3rem 2rem;">'
1601
+ f'<div class="welcome-icons">\U0001f916\u2728\U0001f3a8\u2728\U0001f3b5</div>'
1602
+ f'<div class="welcome-text" style="color:#334155;">{L["welcome_text"]}</div>'
1603
+ f'<div class="welcome-hint" style="color:#64748b;">{L["welcome_hint"]}</div>'
1604
+ f'</div>', unsafe_allow_html=True)
1605
+ else:
1606
+ st.markdown(
1607
+ f'<div class="welcome">'
1608
+ f'<div class="welcome-icons">\U0001f3a8 \U0001f5bc\ufe0f \U0001f50a</div>'
1609
+ f'<div class="welcome-text">{L["welcome_text"]}</div>'
1610
+ f'<div class="welcome-hint">{L["welcome_hint"]}</div>'
1611
+ f'</div>', unsafe_allow_html=True)
1612
  return
1613
 
1614
  if go and prompt.strip():
1615
+ st.session_state["last_result"] = run_pipeline(prompt.strip(), mode, backend, lang)
1616
+ st.session_state["last_result"]["kid_mode"] = kid_mode
1617
 
1618
  if "last_result" in st.session_state:
1619
+ # Update kid_mode in case user toggled it after generation
1620
+ st.session_state["last_result"]["kid_mode"] = kid_mode
1621
  show_results(st.session_state["last_result"])
1622
 
1623
 
 
1625
  # Pipeline
1626
  # ---------------------------------------------------------------------------
1627
 
1628
+ def run_pipeline(prompt: str, mode: str, backend: str = "generative", lang: str = "en") -> dict:
1629
+ R: dict = {"mode": mode, "backend": backend, "lang": lang, "original_prompt": prompt}
1630
  t_all = time.time()
1631
 
1632
+ # 0) Translate German → English if needed
1633
+ en_prompt = prompt
1634
+ if lang == "de":
1635
+ with st.status("\u00dcbersetze ins Englische...", expanded=True) as s:
1636
+ t0 = time.time()
1637
+ en_prompt = translate_de_to_en(prompt)
1638
+ t_trans = time.time() - t0
1639
+ R["t_translate"] = t_trans
1640
+ R["en_prompt"] = en_prompt
1641
+ s.update(label=f"Translated ({t_trans:.1f}s): {en_prompt[:80]}...", state="complete")
1642
+ else:
1643
+ R["en_prompt"] = prompt
1644
+
1645
+ # 1) Text + Planning (always in English for CLIP/CLAP)
1646
  plan_label = "Generating text..." if mode == "direct" else f"Planning ({mode}) + generating text..."
1647
  with st.status(plan_label, expanded=True) as s:
1648
  t0 = time.time()
1649
  try:
1650
+ R["text"] = gen_text(en_prompt, mode)
1651
  R["t_text"] = time.time() - t0
1652
  has_plan = R["text"].get("plan") is not None
1653
  lbl = f"Text ready ({R['t_text']:.1f}s)"
 
1656
  s.update(label=lbl, state="complete")
1657
  except Exception as e:
1658
  s.update(label=f"Text failed: {e}", state="error")
1659
+ R["text"] = {"text": en_prompt, "image_prompt": en_prompt, "audio_prompt": en_prompt}
1660
  R["t_text"] = time.time() - t0
1661
 
1662
+ # Translate generated text back to German for display
1663
+ if lang == "de":
1664
+ en_text = R["text"].get("text", "")
1665
+ R["text"]["text_en"] = en_text
1666
+ R["text"]["text"] = translate_en_to_de(en_text)
1667
+
1668
+ ip = R["text"].get("image_prompt", en_prompt)
1669
+ ap = R["text"].get("audio_prompt", en_prompt)
1670
 
1671
  # 2) Image
1672
+ img_label = "Generating image..." if backend == "generative" else "Retrieving image..."
1673
  with st.status(img_label, expanded=True) as s:
1674
  t0 = time.time()
1675
  try:
 
1720
  R["audio"] = None
1721
  R["t_aud"] = time.time() - t0
1722
 
1723
+ # 4) Coherence evaluation (always use English text for CLIP/CLAP)
1724
  with st.status("Evaluating coherence...", expanded=True) as s:
1725
  t0 = time.time()
1726
  try:
1727
  imgp = R.get("image", {}).get("path") if R.get("image") else None
1728
  audp = R.get("audio", {}).get("path") if R.get("audio") else None
1729
+ eval_text = R["text"].get("text_en", R["text"]["text"]) # English for CLIP/CLAP
1730
+ R["coherence"] = eval_coherence(eval_text, imgp, audp)
1731
  R["t_eval"] = time.time() - t0
1732
  msci = R["coherence"].get("scores", {}).get("msci")
1733
  s.update(label=f"MSCI = {msci:.4f} ({R['t_eval']:.1f}s)", state="complete")
 
1751
  msci = sc.get("msci")
1752
  st_i = sc.get("st_i")
1753
  st_a = sc.get("st_a")
1754
+ lang = R.get("lang", "en")
1755
+ kid_mode = R.get("kid_mode", False)
1756
 
1757
+ if kid_mode:
1758
+ L = UI_LABELS_KID.get(lang, UI_LABELS_KID["en"])
1759
+ else:
1760
+ L = UI_LABELS.get(lang, UI_LABELS["en"])
1761
+
1762
+ # Warn banner CSS class
1763
+ warn_cls = "kid-warn" if kid_mode else "warn-banner"
1764
+
1765
+ # --- Score cards ---
1766
+ if kid_mode:
1767
+ st.markdown(f'<div class="kid-sec-label">{L["scores_label"]}</div>', unsafe_allow_html=True)
1768
+ # Kid verdict banner
1769
+ verdict = _kid_verdict(msci, lang)
1770
+ st.markdown(f'<div class="kid-verdict">{verdict}</div>', unsafe_allow_html=True)
1771
+ # Balloons for high coherence!
1772
+ if msci is not None and msci >= 0.40:
1773
+ st.balloons()
1774
+ cards = (
1775
+ kid_score_card("\U0001f3af Gesamt" if lang == "de" else "\U0001f3af Overall", msci, is_main=True)
1776
+ + kid_score_card("\U0001f5bc\ufe0f Text \u2192 Bild" if lang == "de" else "\U0001f5bc\ufe0f Text \u2192 Image", st_i)
1777
+ + kid_score_card("\U0001f50a Text \u2192 Ton" if lang == "de" else "\U0001f50a Text \u2192 Audio", st_a)
1778
+ + kid_score_card("\U0001f31f Sterne" if lang == "de" else "\U0001f31f Stars", msci)
1779
+ )
1780
+ st.markdown(f'<div class="kid-scores">{cards}</div>', unsafe_allow_html=True)
1781
+ else:
1782
+ st.markdown(f'<div class="sec-label">{L["scores_label"]}</div>', unsafe_allow_html=True)
1783
+ cards = (
1784
+ score_card_html("MSCI (Overall)", msci)
1785
+ + score_card_html("Text \u2192 Image", st_i)
1786
+ + score_card_html("Text \u2192 Audio", st_a)
1787
+ + score_card_html("Classification", msci, is_class=True)
1788
+ )
1789
+ st.markdown(f'<div class="scores-grid">{cards}</div>', unsafe_allow_html=True)
1790
 
1791
  # Timing strip
1792
  tt = R.get("t_total", 0)
1793
  sep = '<span class="t-sep">|</span>'
1794
+ trans_timing = f'{sep}<span>Translate {R.get("t_translate", 0):.1f}s</span>' if lang == "de" else ""
1795
+ timing_cls = "kid-timing" if kid_mode else "timing"
1796
  st.markdown(
1797
+ f'<div class="{timing_cls}">'
1798
  f'<span class="t-total">Total {tt:.1f}s</span>{sep}'
1799
+ f'{trans_timing}'
1800
  f'<span>Text {R.get("t_text", 0):.1f}s</span>{sep}'
1801
  f'<span>Image {R.get("t_img", 0):.1f}s</span>{sep}'
1802
  f'<span>Audio {R.get("t_aud", 0):.1f}s</span>{sep}'
 
1805
 
1806
  st.markdown("---")
1807
 
1808
+ # CSS class helpers for kid/normal mode
1809
+ sec_cls = "kid-sec-label" if kid_mode else "sec-label"
1810
+ text_cls = "kid-text-card" if kid_mode else "text-card"
1811
+
1812
  # Three columns: text | image | audio
1813
  ct, ci, ca = st.columns([1.15, 1, 0.85])
1814
 
1815
  with ct:
1816
+ st.markdown(f'<div class="{sec_cls}">{L["gen_text_label"]}</div>', unsafe_allow_html=True)
1817
  txt = R.get("text", {}).get("text", "")
1818
  text_err = R.get("text", {}).get("text_error")
1819
  if text_err:
1820
+ if "credit" in text_err.lower() or "402" in text_err:
1821
+ st.markdown(
1822
+ f'<div class="{warn_cls}"><b>Text gen failed</b> — '
1823
+ f'HF credits depleted. Add credits at huggingface.co/settings/billing '
1824
+ f'or wait for free-tier reset.</div>',
1825
+ unsafe_allow_html=True)
1826
+ else:
1827
+ st.markdown(
1828
+ f'<div class="{warn_cls}"><b>Text gen failed</b> — {text_err}</div>',
1829
+ unsafe_allow_html=True)
1830
+ st.markdown(f'<div class="{text_cls}">{txt}</div>', unsafe_allow_html=True)
1831
+ # Show English original when in German mode
1832
+ if lang == "de":
1833
+ text_en = R.get("text", {}).get("text_en", "")
1834
+ if text_en and text_en != txt:
1835
+ with st.expander("English (original)"):
1836
+ st.markdown(f'<div class="{text_cls}" style="opacity:0.7">{text_en}</div>',
1837
+ unsafe_allow_html=True)
1838
 
1839
  with ci:
1840
+ st.markdown(f'<div class="{sec_cls}">{L["gen_image_label"]}</div>', unsafe_allow_html=True)
1841
  ii = R.get("image")
1842
  if ii and ii.get("path"):
1843
  ip = Path(ii["path"])
1844
  backend = ii.get("backend", "unknown")
1845
 
1846
+ if backend == "retrieval" and R.get("backend") == "generative":
1847
+ if ii.get("credit_error"):
1848
+ st.markdown(
1849
+ f'<div class="{warn_cls}"><b>HF credits depleted</b> \u2014 '
1850
+ f'using retrieval fallback.</div>',
1851
+ unsafe_allow_html=True)
1852
+ else:
1853
+ sim = ii.get("similarity", 0)
1854
+ st.markdown(
1855
+ f'<div class="{warn_cls}"><b>Retrieval fallback</b> '
1856
+ f'(sim={sim:.3f}) \u2014 generation unavailable.</div>',
1857
+ unsafe_allow_html=True)
1858
 
1859
  if ip.exists():
1860
  st.image(str(ip), use_container_width=True)
1861
  model = ii.get("model", "")
1862
  if backend == "generative":
1863
+ cap = f"\U0001f3a8 Pixela hat gemalt mit **{model}**" if kid_mode and lang == "de" else (
1864
+ f"\U0001f3a8 Pixela painted with **{model}**" if kid_mode else f"Generated via **{model}**")
1865
+ st.caption(cap)
1866
  else:
1867
  sim = ii.get("similarity", 0)
1868
  dom = ii.get("domain", "other")
1869
  ic = DOMAIN_ICONS.get(dom, "\U0001f4cd")
1870
  st.caption(f"{ic} {dom} \u00b7 sim **{sim:.3f}** \u00b7 Retrieved")
1871
  else:
1872
+ st.info("No image." if not kid_mode else "\U0001f3a8 Kein Bild." if lang == "de" else "\U0001f3a8 No image.")
1873
 
1874
  with ca:
1875
+ st.markdown(f'<div class="{sec_cls}">{L["gen_audio_label"]}</div>', unsafe_allow_html=True)
1876
  ai = R.get("audio")
1877
  if ai and ai.get("path"):
1878
  ap = Path(ai["path"])
1879
  backend = ai.get("backend", "unknown")
1880
 
1881
+ if backend == "retrieval" and R.get("backend") == "generative":
1882
+ if ai.get("credit_error"):
1883
+ st.markdown(
1884
+ f'<div class="{warn_cls}"><b>HF credits depleted</b> \u2014 '
1885
+ f'using retrieval fallback.</div>',
1886
+ unsafe_allow_html=True)
1887
+ else:
1888
+ sim = ai.get("similarity", 0)
1889
+ st.markdown(
1890
+ f'<div class="{warn_cls}"><b>Retrieval fallback</b> '
1891
+ f'(sim={sim:.3f}) \u2014 generation unavailable.</div>',
1892
+ unsafe_allow_html=True)
1893
 
1894
  if ap.exists():
1895
  st.audio(str(ap))
1896
  model = ai.get("model", "")
1897
  if backend == "generative":
1898
+ cap = f"\U0001f3b5 Soundo spielt mit **{model}**" if kid_mode and lang == "de" else (
1899
+ f"\U0001f3b5 Soundo plays with **{model}**" if kid_mode else f"Generated via **{model}**")
1900
+ st.caption(cap)
1901
  else:
1902
  sim = ai.get("similarity", 0)
1903
  st.caption(f"sim **{sim:.3f}** \u00b7 Retrieved")
1904
  else:
1905
+ st.info("No audio." if not kid_mode else "\U0001f3b5 Kein Audio." if lang == "de" else "\U0001f3b5 No audio.")
1906
 
1907
  st.markdown("---")
1908
 
1909
+ # Expandable details (hidden in kid mode to keep it simple)
1910
+ if not kid_mode:
1911
+ with st.expander("Semantic Plan"):
1912
+ td = R.get("text", {})
1913
+ plan = td.get("plan")
1914
+ if plan:
1915
+ p1, p2 = st.columns(2)
1916
+ with p1:
1917
+ dash = "\u2014"
1918
+ dot = "\u00b7"
1919
+ scene = plan.get("scene_summary", dash)
1920
+ domain = plan.get("domain", dash)
1921
+ core = plan.get("core_semantics", {})
1922
+ setting = core.get("setting", dash)
1923
+ tod = core.get("time_of_day", dash)
1924
+ weather = core.get("weather", dash)
1925
+ subjects = ", ".join(core.get("main_subjects", []))
1926
+ st.markdown(f"**Scene** {scene}")
1927
+ st.markdown(f"**Domain** {domain}")
1928
+ st.markdown(f"**Setting** {setting} {dot} **Time** {tod} {dot} **Weather** {weather}")
1929
+ st.markdown(f"**Subjects** {subjects}")
1930
+ with p2:
1931
+ st.markdown("**Image prompt**")
1932
+ st.code(td.get("image_prompt", ""), language=None)
1933
+ st.markdown("**Audio prompt**")
1934
+ st.code(td.get("audio_prompt", ""), language=None)
 
 
 
1935
  else:
1936
+ mode = R.get("mode", "direct")
1937
+ if mode == "direct":
1938
+ st.write("Direct mode \u2014 no semantic plan. Prompt used as-is for all modalities.")
1939
+ else:
1940
+ st.write(f"Planning ({mode}) did not produce a valid plan. Fell back to direct mode.")
1941
+
1942
+ with st.expander("Generation Details"):
1943
+ r1, r2 = st.columns(2)
1944
+ with r1:
1945
+ ii = R.get("image")
1946
+ if ii:
1947
+ backend = ii.get("backend", "unknown")
1948
+ model = ii.get("model", "")
1949
+ if backend == "generative":
1950
+ st.markdown(f"**Image** generated via **{model}**")
1951
+ st.markdown(f"Prompt: *{R.get('text', {}).get('image_prompt', '')}*")
1952
+ elif ii.get("top_5"):
1953
+ st.markdown("**Image** (retrieval fallback)")
1954
+ bars = "".join(sim_bar_html(n, s) for n, s in ii["top_5"])
1955
+ st.markdown(bars, unsafe_allow_html=True)
1956
+ else:
1957
+ st.write("No image data.")
1958
+ with r2:
1959
+ ai = R.get("audio")
1960
+ if ai:
1961
+ backend = ai.get("backend", "unknown")
1962
+ model = ai.get("model", "")
1963
+ if backend == "generative":
1964
+ st.markdown(f"**Audio** generated via **{model}**")
1965
+ st.markdown(f"Prompt: *{R.get('text', {}).get('audio_prompt', '')}*")
1966
+ elif ai.get("top_5"):
1967
+ st.markdown("**Audio** (retrieval fallback)")
1968
+ bars = "".join(sim_bar_html(n, s) for n, s in ai["top_5"])
1969
+ st.markdown(bars, unsafe_allow_html=True)
1970
+ else:
1971
+ st.write("No audio data.")
1972
+
1973
+ with st.expander("Full Coherence Report"):
1974
+ if coh:
1975
+ st.json(coh)
1976
  else:
1977
+ st.write("No data.")
1978
+ else:
1979
+ # Kid mode: simple "how it works" expander instead of technical details
1980
+ label_how = "\U0001f914 Wie funktioniert das?" if lang == "de" else "\U0001f914 How does it work?"
1981
+ with st.expander(label_how):
1982
+ if lang == "de":
1983
+ st.markdown(
1984
+ "1. **Textino** \U0001f916 liest deine Beschreibung und schreibt eine Geschichte\n"
1985
+ "2. **Pixela** \U0001f3a8 malt ein Bild, das zur Geschichte passt\n"
1986
+ "3. **Soundo** \U0001f3b5 erzeugt Ger\u00e4usche und Musik dazu\n"
1987
+ "4. Dann pr\u00fcfen wir, ob alles gut zusammenpasst! \u2b50"
1988
+ )
 
1989
  else:
1990
+ st.markdown(
1991
+ "1. **Textino** \U0001f916 reads your description and writes a story\n"
1992
+ "2. **Pixela** \U0001f3a8 paints a picture that matches the story\n"
1993
+ "3. **Soundo** \U0001f3b5 creates sounds and music for it\n"
1994
+ "4. Then we check if everything fits together! \u2b50"
1995
+ )
 
1996
 
1997
 
1998
  if __name__ == "__main__":
src/coherence/calibration.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Distribution Normalization for cMSCI.
3
+
4
+ Scores from different embedding spaces (CLIP vs CLAP) and different
5
+ pairwise channels (st_i, st_a, gram_volume) have different natural
6
+ distributions. Z-score normalization makes them comparable.
7
+
8
+ The ReferenceDistribution class fits mean/std from existing experiment
9
+ data and normalizes new scores to z-scores or percentile ranks.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import json
15
+ import logging
16
+ from pathlib import Path
17
+ from typing import Dict, List, Optional
18
+
19
+ import numpy as np
20
+ from scipy import stats as sp_stats
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
class ReferenceDistribution:
    """
    Stores mean/std for a single score channel and provides normalization.

    Usage:
        ref = ReferenceDistribution()
        ref.fit(list_of_scores)
        z = ref.normalize(new_score)   # z-score
        p = ref.percentile(new_score)  # percentile rank [0, 1]
    """

    def __init__(self, name: str = ""):
        self.name = name
        self.mean: float = 0.0
        self.std: float = 1.0
        self.n: int = 0
        # Sorted copy of the fitted scores; used only by percentile() and
        # deliberately not serialized by to_dict().
        self._sorted_values: Optional[np.ndarray] = None

    def fit(self, scores: List[float]) -> None:
        """Fit the distribution from a list of observed scores.

        An empty list keeps the neutral defaults (mean=0, std=1) so that
        normalize() is the identity and percentile() returns 0.5.
        """
        arr = np.array(scores, dtype=np.float64)
        self.n = len(arr)
        if self.n == 0:
            # Bug fix: np.mean([]) would return NaN (with a RuntimeWarning)
            # and poison every subsequent normalize() call.
            self.mean = 0.0
            self.std = 1.0
            self._sorted_values = arr
            return
        self.mean = float(np.mean(arr))
        # Sample std (ddof=1) needs at least 2 points; fall back to 1.0.
        self.std = float(np.std(arr, ddof=1)) if self.n > 1 else 1.0
        if self.std < 1e-10:
            # Degenerate (near-constant) channel: avoid division by ~0.
            self.std = 1.0
        self._sorted_values = np.sort(arr)

    def normalize(self, score: float) -> float:
        """Z-score normalization: (score - mean) / std."""
        return float((score - self.mean) / self.std)

    def percentile(self, score: float) -> float:
        """
        Percentile rank of score within the reference distribution.

        Returns a value in [0, 1] where 0.5 = median of reference. Returns
        0.5 when no reference values are held (e.g. after from_dict()).
        """
        if self._sorted_values is None or len(self._sorted_values) == 0:
            return 0.5
        rank = np.searchsorted(self._sorted_values, score, side="right")
        return float(rank / len(self._sorted_values))

    def to_dict(self) -> Dict:
        """Serialize summary statistics only; raw fitted values are dropped."""
        return {
            "name": self.name,
            "mean": self.mean,
            "std": self.std,
            "n": self.n,
        }

    @classmethod
    def from_dict(cls, d: Dict) -> "ReferenceDistribution":
        """Rebuild from to_dict() output (percentile() will return 0.5)."""
        obj = cls(name=d.get("name", ""))
        obj.mean = d["mean"]
        obj.std = d["std"]
        obj.n = d.get("n", 0)
        return obj

    def save(self, path: str) -> None:
        """Write the serialized stats as JSON to *path*."""
        with open(path, "w") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load(cls, path: str) -> "ReferenceDistribution":
        """Load a distribution previously written by save()."""
        with open(path) as f:
            return cls.from_dict(json.load(f))
92
+
93
+
94
class CalibrationStore:
    """
    Collection of ReferenceDistributions for all score channels.

    Provides save/load for the full calibration state.
    """

    def __init__(self):
        self.distributions: Dict[str, ReferenceDistribution] = {}

    def add(self, name: str, scores: List[float]) -> ReferenceDistribution:
        """Fit a new channel from *scores*, register it, and return it."""
        dist = ReferenceDistribution(name=name)
        dist.fit(scores)
        self.distributions[name] = dist
        logger.info(
            "Calibration[%s]: mean=%.4f, std=%.4f, n=%d",
            name, dist.mean, dist.std, dist.n,
        )
        return dist

    def normalize(self, name: str, score: float) -> float:
        """Z-score *score* against channel *name*; pass through if unknown."""
        dist = self.distributions.get(name)
        return score if dist is None else dist.normalize(score)

    def percentile(self, name: str, score: float) -> float:
        """Percentile rank within channel *name*; 0.5 if channel is unknown."""
        dist = self.distributions.get(name)
        return 0.5 if dist is None else dist.percentile(score)

    def save(self, path: str) -> None:
        """Dump every channel's summary stats to *path* as JSON."""
        payload = {key: dist.to_dict() for key, dist in self.distributions.items()}
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w") as fh:
            json.dump(payload, fh, indent=2)
        logger.info("Calibration saved to %s", path)

    @classmethod
    def load(cls, path: str) -> "CalibrationStore":
        """Rebuild a store from a file previously written by save()."""
        store = cls()
        with open(path) as fh:
            for key, entry in json.load(fh).items():
                store.distributions[key] = ReferenceDistribution.from_dict(entry)
        logger.info("Calibration loaded from %s (%d channels)", path, len(store.distributions))
        return store
140
+
141
+
142
def has_channel(store: CalibrationStore, name: str) -> bool:
    """Tell whether *store* already carries a calibration channel *name*."""
    known = store.distributions
    return name in known
145
+
146
+
147
def extend_calibration_with_exmcr(
    store: CalibrationStore,
    gram_coh_ia_scores: List[float],
    gram_coh_tia_scores: Optional[List[float]] = None,
) -> CalibrationStore:
    """
    Extend calibration store with ExMCR-derived channels.

    Args:
        store: Existing CalibrationStore to extend.
        gram_coh_ia_scores: Gram coherence of (image_clip, ExMCR(audio_clap)) pairs.
        gram_coh_tia_scores: Optional 3-way gram coherence of (text, image, ExMCR(audio)).

    Returns:
        Extended CalibrationStore (same object, modified in place).
    """
    # Register each non-empty channel; None/empty inputs are skipped.
    for channel, scores in (
        ("gram_coh_ia_exmcr", gram_coh_ia_scores),
        ("gram_coh_tia", gram_coh_tia_scores),
    ):
        if scores:
            store.add(channel, scores)
    return store
168
+
169
+
170
def extend_calibration_with_uncertainty(
    store: CalibrationStore,
    uncertainty_ti_scores: List[float],
    uncertainty_ta_scores: Optional[List[float]] = None,
) -> CalibrationStore:
    """
    Extend calibration store with ProbVLM uncertainty channels.

    Args:
        store: Existing CalibrationStore to extend.
        uncertainty_ti_scores: Per-sample mean uncertainty for text-image (CLIP adapter).
        uncertainty_ta_scores: Per-sample mean uncertainty for text-audio (CLAP adapter).

    Returns:
        Extended CalibrationStore (same object, modified in place).
    """
    if uncertainty_ti_scores:
        store.add("uncertainty_ti", uncertainty_ti_scores)
    if uncertainty_ta_scores:
        store.add("uncertainty_ta", uncertainty_ta_scores)
    # When both channels exist, also register their element-wise mean.
    if uncertainty_ti_scores and uncertainty_ta_scores:
        paired = zip(uncertainty_ti_scores, uncertainty_ta_scores)
        store.add("uncertainty_mean", [(ti + ta) / 2.0 for ti, ta in paired])
    return store
198
+
199
+
200
def build_reference_distributions(
    rq1_results_path: str,
) -> CalibrationStore:
    """
    Build reference distributions from existing RQ1 baseline results.

    Extracts st_i, st_a, and msci scores from baseline condition only
    (matched image + audio), fitting a distribution for each channel.

    Args:
        rq1_results_path: Path to rq1_results.json

    Returns:
        CalibrationStore with fitted distributions for st_i, st_a, msci
    """
    with open(rq1_results_path) as fh:
        payload = json.load(fh)

    # Collect per-channel score lists from baseline records only.
    channels: Dict[str, list] = {"st_i": [], "st_a": [], "msci": []}
    for record in payload["results"]:
        if record.get("condition") != "baseline":
            continue
        for key, bucket in channels.items():
            value = record.get(key)
            if value is not None:
                bucket.append(value)

    store = CalibrationStore()
    for key in ("st_i", "st_a", "msci"):
        if channels[key]:
            store.add(key, channels[key])

    # GRAM coherence distributions (1 - gram_volume) for gram calibration mode
    # gram_volume = sqrt(1 - cos^2), so gram_coherence = 1 - sqrt(1 - cos^2)
    for src, dst in (("st_i", "gram_coh_ti"), ("st_a", "gram_coh_ta")):
        if channels[src]:
            store.add(dst, [1.0 - np.sqrt(max(0, 1 - s ** 2)) for s in channels[src]])

    return store
src/coherence/cmsci_engine.py ADDED
@@ -0,0 +1,536 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Calibrated Multimodal Semantic Coherence Index (cMSCI) Engine.
3
+
4
+ Replaces fixed weighted averaging (MSCI) with a principled pipeline:
5
+ 1. Gramian Volume: geometric coherence of embedding vectors
6
+ 2. Distribution Normalization: z-score calibration per channel
7
+ 3. Contrastive Margin: comparison against hard negatives
8
+ 4. Cross-Space Alignment: Ex-MCR projects CLAP→CLIP for 3-way GRAM
9
+ 5. Probabilistic Uncertainty: MC sampling for confidence intervals
10
+
11
+ The CalibratedCoherenceEngine runs alongside CoherenceEngine (not replacing
12
+ it) and returns both legacy MSCI and new cMSCI scores for comparison.
13
+
14
+ Variant progression:
15
+ A: MSCI (baseline, weighted cosine average)
16
+ B: GRAM-only (geometric, no calibration)
17
+ C: GRAM + z-norm (normalized geometric)
18
+ D: GRAM + z-norm + contrastive (calibrated geometric)
19
+ E: GRAM + z-norm + contrastive + Ex-MCR (3-way calibrated)
20
+ F: Full cMSCI (probabilistic + calibrated + 3-way)
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import logging
26
+ from pathlib import Path
27
+ from typing import Any, Dict, List, Optional
28
+
29
+ import numpy as np
30
+
31
+ from src.coherence.gram_volume import (
32
+ gram_volume_2d,
33
+ gram_volume_3d,
34
+ gram_volume_nd,
35
+ normalized_gram_coherence,
36
+ )
37
+ from src.config.settings import (
38
+ CMSCI_MARGIN_ALPHA,
39
+ CMSCI_CHANNEL_WEIGHT_TI,
40
+ CMSCI_CALIBRATION_MODE,
41
+ CMSCI_W_3D,
42
+ CMSCI_GAMMA,
43
+ )
44
+ from src.embeddings.aligned_embeddings import AlignedEmbedder
45
+ from src.embeddings.similarity import cosine_similarity
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+
50
class CalibratedCoherenceEngine:
    """
    Uncertainty-aware, geometrically-grounded tri-modal coherence engine.

    Computes cMSCI alongside legacy MSCI for comparison.

    Usage:
        engine = CalibratedCoherenceEngine()
        result = engine.evaluate("A beach at sunset", "beach.jpg", "waves.wav")
        print(result["cmsci"])  # Calibrated score
        print(result["msci"])  # Legacy score (for comparison)
        print(result["variant_scores"])  # Scores for each variant A-F
    """

    def __init__(
        self,
        target_dim: int = 512,
        calibration_path: Optional[str] = None,
        exmcr_weights_path: Optional[str] = None,
        bridge_path: Optional[str] = None,
        prob_clip_adapter_path: Optional[str] = None,
        prob_clap_adapter_path: Optional[str] = None,
        negative_bank_enabled: bool = True,
    ):
        # Shared embedder for text (CLIP + CLAP), image, and audio inputs.
        self.embedder = AlignedEmbedder(target_dim=target_dim)

        # Calibration store (Phase 2) — each optional component below is
        # loaded lazily and left as None when its artifact is unavailable,
        # so the engine degrades gracefully to earlier pipeline variants.
        self._calibration = None
        if calibration_path and Path(calibration_path).exists():
            from src.coherence.calibration import CalibrationStore
            self._calibration = CalibrationStore.load(calibration_path)
            logger.info("Calibration loaded from %s", calibration_path)

        # Negative bank (Phase 2)
        self._negative_bank = None
        if negative_bank_enabled:
            try:
                from src.coherence.negative_bank import NegativeBank
                self._negative_bank = NegativeBank()
            except Exception as e:
                # Best-effort: negatives are optional, so log and continue.
                logger.warning("Negative bank disabled: %s", e)

        # Ex-MCR projector (Phase 3 — projects CLAP into CLIP space)
        self._exmcr = None
        if exmcr_weights_path:
            from src.embeddings.space_alignment import ExMCRProjector
            self._exmcr = ExMCRProjector(weights_path=exmcr_weights_path)
            if self._exmcr.is_identity:
                logger.info("Ex-MCR in identity mode (no weights)")
            else:
                logger.info("Ex-MCR projector active")

        # Cross-Space Bridge (projects CLIP image + CLAP audio → shared 256-d)
        self._bridge = None
        if bridge_path and Path(bridge_path).exists():
            from src.embeddings.cross_space_bridge import CrossSpaceBridge
            self._bridge = CrossSpaceBridge.load(bridge_path)
            logger.info("CrossSpaceBridge loaded from %s", bridge_path)

        # Probabilistic adapters (Phase 4)
        self._prob_clip = None
        self._prob_clap = None
        if prob_clip_adapter_path and Path(prob_clip_adapter_path).exists():
            from src.embeddings.probabilistic_adapter import ProbabilisticAdapter
            self._prob_clip = ProbabilisticAdapter.load(prob_clip_adapter_path)
            logger.info("CLIP probabilistic adapter loaded")
        if prob_clap_adapter_path and Path(prob_clap_adapter_path).exists():
            from src.embeddings.probabilistic_adapter import ProbabilisticAdapter
            self._prob_clap = ProbabilisticAdapter.load(prob_clap_adapter_path)
            logger.info("CLAP probabilistic adapter loaded")

    def evaluate(
        self,
        text: str,
        image_path: Optional[str] = None,
        audio_path: Optional[str] = None,
        domain: str = "",
        n_mc_samples: int = 100,
    ) -> Dict[str, Any]:
        """
        Evaluate multimodal coherence with full cMSCI pipeline.

        Returns both legacy MSCI and cMSCI scores along with all
        intermediate computations for ablation analysis.

        Args:
            text: Text prompt.
            image_path: Path to image file.
            audio_path: Path to audio file.
            domain: Domain hint for negative bank (e.g., "nature").
            n_mc_samples: Number of MC samples for uncertainty.

        Returns:
            Dict with keys:
                msci: Legacy MSCI score (weighted cosine average)
                cmsci: Calibrated cMSCI score
                scores: Raw pairwise scores (st_i, st_a, si_a)
                gram: Gramian volume scores
                calibration: Z-normalized scores
                contrastive: Contrastive margin results
                uncertainty: MC sampling uncertainty (if adapters loaded)
                variant_scores: Scores for each variant A-F
        """
        # ── Embed ──────────────────────────────────────────────
        emb_text_clip = self.embedder.embed_text(text)
        emb_text_clap = self.embedder.embed_text_for_audio(text) if audio_path else None
        emb_image = self.embedder.embed_image(image_path) if image_path else None
        emb_audio = self.embedder.embed_audio(audio_path) if audio_path else None

        # ── Legacy MSCI (Variant A) ────────────────────────────
        st_i = None
        st_a = None
        si_a = None

        if emb_text_clip is not None and emb_image is not None:
            st_i = float(round(cosine_similarity(emb_text_clip, emb_image), 4))
        if emb_text_clap is not None and emb_audio is not None:
            st_a = float(round(cosine_similarity(emb_text_clap, emb_audio), 4))

        available = {}
        if st_i is not None:
            available["st_i"] = st_i
        if st_a is not None:
            available["st_a"] = st_a

        # NOTE(review): si_a is still None here (it is only computed later,
        # in the Variant E block), so the 0.10 si_a weight never participates
        # in legacy MSCI — the st_i/st_a weights are simply renormalized.
        weights = {"st_i": 0.45, "st_a": 0.45, "si_a": 0.10}
        if len(available) >= 2:
            total_w = sum(weights[k] for k in available if k in weights)
            msci = sum(available[k] * weights[k] for k in available if k in weights) / max(total_w, 1e-6)
        elif len(available) == 1:
            msci = list(available.values())[0]
        else:
            msci = None

        variant_a = msci

        # ── Gramian Volume (Variant B) ─────────────────────────
        gram_ti = None
        gram_ta = None
        gram_tia = None
        gram_coherence_2way = None

        if emb_text_clip is not None and emb_image is not None:
            gram_ti = gram_volume_2d(emb_text_clip, emb_image)

        if emb_text_clap is not None and emb_audio is not None:
            gram_ta = gram_volume_2d(emb_text_clap, emb_audio)

        # 2-way GRAM coherence (average of text-image and text-audio gram coherences)
        gram_coherences = []
        if gram_ti is not None:
            gram_coherences.append(normalized_gram_coherence(gram_ti))
        if gram_ta is not None:
            gram_coherences.append(normalized_gram_coherence(gram_ta))

        if gram_coherences:
            gram_coherence_2way = float(np.mean(gram_coherences))

        variant_b = gram_coherence_2way

        # ── Z-Score Normalization (Variant C) ──────────────────
        z_st_i = None
        z_st_a = None
        z_gram_ti = None
        z_gram_ta = None
        variant_c = variant_b  # default to B if no calibration

        # Channel weight from settings (optimized via LOO-CV)
        w_ti = CMSCI_CHANNEL_WEIGHT_TI
        cal_mode = CMSCI_CALIBRATION_MODE

        if self._calibration is not None:
            if st_i is not None:
                z_st_i = self._calibration.normalize("st_i", st_i)
            if st_a is not None:
                z_st_a = self._calibration.normalize("st_a", st_a)

            # GRAM coherence z-scores (for gram calibration mode)
            if gram_ti is not None:
                gram_coh_ti = normalized_gram_coherence(gram_ti)
                z_gram_ti = self._calibration.normalize("gram_coh_ti", gram_coh_ti)
            if gram_ta is not None:
                gram_coh_ta = normalized_gram_coherence(gram_ta)
                z_gram_ta = self._calibration.normalize("gram_coh_ta", gram_coh_ta)

            # Select calibration mode: cosine z-scores or gram coherence z-scores
            if cal_mode == "gram" and z_gram_ti is not None and z_gram_ta is not None:
                z_mean = w_ti * z_gram_ti + (1.0 - w_ti) * z_gram_ta
            else:
                # Cosine mode (original behavior) with weighted channels
                z_coherences = []
                z_weights = []
                if z_st_i is not None:
                    z_coherences.append(z_st_i)
                    z_weights.append(w_ti)
                if z_st_a is not None:
                    z_coherences.append(z_st_a)
                    z_weights.append(1.0 - w_ti)

                if z_coherences:
                    total_w = sum(z_weights)
                    z_mean = sum(z * wt for z, wt in zip(z_coherences, z_weights)) / total_w
                else:
                    z_mean = None

            if z_mean is not None:
                # Map z-scores back to [0,1] via sigmoid for interpretability
                variant_c = float(1.0 / (1.0 + np.exp(-z_mean)))

        # ── Contrastive Margin (Variant D) ─────────────────────
        contrastive_result = None
        variant_d = variant_c  # default to C if no negatives
        margin_alpha = CMSCI_MARGIN_ALPHA

        if self._negative_bank is not None and gram_coherence_2way is not None:
            matched_volume = float(np.mean([v for v in [gram_ti, gram_ta] if v is not None]))
            contrastive_result = self._negative_bank.compute_contrastive_margin(
                matched_volume=matched_volume,
                text_clip_emb=emb_text_clip,
                image_emb=emb_image,
                text_clap_emb=emb_text_clap,
                audio_emb=emb_audio,
                domain=domain,
                k=5,
            )

            if contrastive_result["n_negatives"] > 0:
                # cMSCI_D = sigmoid(z_mean + alpha * margin)
                # alpha amplifies the contrastive signal at the sigmoid operating point
                margin = contrastive_result["margin"]

                # Use the same calibration mode and weighting as Variant C
                if cal_mode == "gram" and z_gram_ti is not None and z_gram_ta is not None:
                    z_mean_d = w_ti * z_gram_ti + (1.0 - w_ti) * z_gram_ta
                else:
                    # Fall back to raw cosine scores per channel when the
                    # calibrated z-score is unavailable.
                    z_coherences_d = []
                    z_weights_d = []
                    if z_st_i is not None:
                        z_coherences_d.append(z_st_i)
                        z_weights_d.append(w_ti)
                    elif st_i is not None:
                        z_coherences_d.append(st_i)
                        z_weights_d.append(w_ti)
                    if z_st_a is not None:
                        z_coherences_d.append(z_st_a)
                        z_weights_d.append(1.0 - w_ti)
                    elif st_a is not None:
                        z_coherences_d.append(st_a)
                        z_weights_d.append(1.0 - w_ti)

                    if z_coherences_d:
                        total_wd = sum(z_weights_d)
                        z_mean_d = sum(z * wt for z, wt in zip(z_coherences_d, z_weights_d)) / total_wd
                    else:
                        z_mean_d = None

                if z_mean_d is not None:
                    variant_d = float(1.0 / (1.0 + np.exp(-(z_mean_d + margin_alpha * margin))))
                else:
                    variant_d = variant_c

        # ── Cross-Space Complementarity — Variant E ──────────
        # COMPLEMENTARITY: E = sigmoid(z_2d + w_3d * z_compl + alpha * margin)
        # ExMCR projects CLAP audio → CLIP space, enabling measurement of
        # image-audio complementarity (Gramian dispersion in unified space).
        # High complementarity = image and audio contribute unique perspectives.
        # Low complementarity = redundant cross-modal information.
        # z_compl = z_normalize(gram_volume_ia) — positive z = more complementary.
        # w_3d=0 recovers D exactly (safety guarantee).
        audio_projected = None
        variant_e = variant_d  # default to D if no projector
        z_compl = None  # z-normalized complementarity (exported for optimizer)
        gram_ia_volume = None  # raw image-audio Gramian volume
        w_3d = CMSCI_W_3D

        # Reconstruct D's pre-margin z-score (z_2d) for composition
        z_2d = None
        margin = 0.0
        if contrastive_result is not None and contrastive_result["n_negatives"] > 0:
            margin = contrastive_result["margin"]
        if cal_mode == "gram" and z_gram_ti is not None and z_gram_ta is not None:
            z_2d = w_ti * z_gram_ti + (1.0 - w_ti) * z_gram_ta
        elif z_st_i is not None and z_st_a is not None:
            z_2d = w_ti * z_st_i + (1.0 - w_ti) * z_st_a

        # Project audio into CLIP space via ExMCR and compute complementarity
        if self._exmcr is not None and not self._exmcr.is_identity:
            if emb_audio is not None:
                audio_projected = self._exmcr.project_audio(emb_audio)
                if emb_image is not None:
                    si_a = float(round(cosine_similarity(emb_image, audio_projected), 4))
                    # Image-audio Gramian volume = dispersion = complementarity
                    gram_ia_volume = gram_volume_2d(emb_image, audio_projected)
            if emb_text_clip is not None and emb_image is not None and audio_projected is not None:
                gram_tia = gram_volume_3d(emb_text_clip, emb_image, audio_projected)

            # Z-normalize complementarity (volume, NOT coherence)
            # z_compl = -z_gram_ia_coherence (flipped: high volume = high complementarity)
            if gram_ia_volume is not None and self._calibration is not None:
                gram_ia_coherence = normalized_gram_coherence(gram_ia_volume)
                z_gram_ia_coh = self._calibration.normalize("gram_coh_ia_exmcr", gram_ia_coherence)
                z_compl = -z_gram_ia_coh  # flip: positive = more complementary

        # Compose: E = sigmoid(z_2d + w_3d * z_compl + alpha * margin)
        if z_2d is not None:
            logit_e = z_2d + margin_alpha * margin
            if z_compl is not None:
                logit_e += w_3d * z_compl
            variant_e = float(1.0 / (1.0 + np.exp(-logit_e)))

        # ── Probabilistic Adaptive Weighting (Variant F) ──────
        # ProbVLM drives per-sample channel weights instead of fixed w_ti.
        # adaptive_w = (1/u_ti) / (1/u_ti + 1/u_ta) — trust more confident channel
        # w_ti_final = (1 - gamma) * base_w + gamma * adaptive_w
        # gamma=0 → w_ti_final = base_w → recovers E exactly (safety guarantee)
        # MC sampling remains metadata only (confidence intervals, not scoring).
        uncertainty_result = None
        variant_f = variant_e  # default to E
        u_ti = None  # per-channel uncertainty (exported for optimizer)
        u_ta = None
        adaptive_w_ti = None
        gamma = CMSCI_GAMMA

        if self._prob_clip is not None or self._prob_clap is not None:
            mc_volumes = []

            # Per-channel uncertainty from ProbVLM adapters
            if self._prob_clip is not None and emb_text_clip is not None and emb_image is not None:
                u_text_clip = self._prob_clip.uncertainty(emb_text_clip)
                u_image_clip = self._prob_clip.uncertainty(emb_image)
                u_ti = float(np.mean([u_text_clip, u_image_clip]))

                # MC samples for confidence interval metadata
                text_samples = self._prob_clip.sample(emb_text_clip, n_mc_samples)
                image_samples = self._prob_clip.sample(emb_image, n_mc_samples)
                for t_s, i_s in zip(text_samples, image_samples):
                    mc_volumes.append(gram_volume_2d(t_s, i_s))

            if self._prob_clap is not None and emb_text_clap is not None and emb_audio is not None:
                u_text_clap = self._prob_clap.uncertainty(emb_text_clap)
                u_audio_clap = self._prob_clap.uncertainty(emb_audio)
                u_ta = float(np.mean([u_text_clap, u_audio_clap]))

                text_samples = self._prob_clap.sample(emb_text_clap, n_mc_samples)
                audio_samples = self._prob_clap.sample(emb_audio, n_mc_samples)
                for t_s, a_s in zip(text_samples, audio_samples):
                    mc_volumes.append(gram_volume_2d(t_s, a_s))

            # Compute adaptive channel weight from uncertainty
            if u_ti is not None and u_ta is not None and u_ti > 0 and u_ta > 0 and gamma > 0:
                inv_ti = 1.0 / u_ti
                inv_ta = 1.0 / u_ta
                adaptive_w = inv_ti / (inv_ti + inv_ta)
                w_ti_final = (1.0 - gamma) * w_ti + gamma * adaptive_w
                adaptive_w_ti = float(w_ti_final)

                # Recompute z_2d with adaptive weights
                if cal_mode == "gram" and z_gram_ti is not None and z_gram_ta is not None:
                    z_2d_adaptive = w_ti_final * z_gram_ti + (1.0 - w_ti_final) * z_gram_ta
                elif z_st_i is not None and z_st_a is not None:
                    z_2d_adaptive = w_ti_final * z_st_i + (1.0 - w_ti_final) * z_st_a
                else:
                    z_2d_adaptive = None

                if z_2d_adaptive is not None:
                    logit_f = z_2d_adaptive + margin_alpha * margin
                    if z_compl is not None:
                        logit_f += w_3d * z_compl
                    variant_f = float(1.0 / (1.0 + np.exp(-logit_f)))

            # MC sampling for confidence intervals (metadata, NOT scoring)
            if mc_volumes:
                mc_coherences = [normalized_gram_coherence(v) for v in mc_volumes]
                mc_mean = float(np.mean(mc_coherences))
                mc_std = float(np.std(mc_coherences))
                mc_ci_lower = float(np.percentile(mc_coherences, 2.5))
                mc_ci_upper = float(np.percentile(mc_coherences, 97.5))
            else:
                mc_mean = mc_std = mc_ci_lower = mc_ci_upper = None

            uncertainty_result = {
                "mc_mean": round(mc_mean, 4) if mc_mean is not None else None,
                "mc_std": round(mc_std, 4) if mc_std is not None else None,
                "mc_ci_lower": round(mc_ci_lower, 4) if mc_ci_lower is not None else None,
                "mc_ci_upper": round(mc_ci_upper, 4) if mc_ci_upper is not None else None,
                "u_ti": round(u_ti, 6) if u_ti is not None else None,
                "u_ta": round(u_ta, 6) if u_ta is not None else None,
                "adaptive_w_ti": round(adaptive_w_ti, 4) if adaptive_w_ti is not None else None,
                "gamma": gamma,
                "n_samples": n_mc_samples,
            }

        # ── Assemble cMSCI ─────────────────────────────────────
        # cMSCI is the highest available variant
        # NOTE(review): the cascade below uses exact float equality as a
        # "this stage had no effect" heuristic (each variant defaults to the
        # previous one); it would misreport if a stage coincidentally produced
        # an identical score — confirm this is acceptable.
        cmsci = variant_f
        active_variant = "F"

        if variant_f == variant_e:
            active_variant = "E" if variant_e != variant_d else "D"
            if variant_e == variant_d:
                active_variant = "D" if variant_d != variant_c else "C"
                if variant_d == variant_c:
                    active_variant = "C" if variant_c != variant_b else "B"
                    if variant_c == variant_b:
                        active_variant = "B" if variant_b is not None else "A"

        # Final cMSCI: use the most sophisticated available variant
        if cmsci is None:
            cmsci = msci  # fallback to legacy
            active_variant = "A"

        logger.info(
            "cMSCI = %.4f (variant %s) | MSCI = %s",
            cmsci if cmsci is not None else 0.0,
            active_variant,
            msci,
        )

        return {
            "cmsci": round(cmsci, 4) if cmsci is not None else None,
            "msci": round(msci, 4) if msci is not None else None,
            "active_variant": active_variant,
            "scores": {
                "st_i": st_i,
                "st_a": st_a,
                "si_a": si_a,
            },
            "gram": {
                "text_image": round(gram_ti, 4) if gram_ti is not None else None,
                "text_audio": round(gram_ta, 4) if gram_ta is not None else None,
                "text_image_audio": round(gram_tia, 4) if gram_tia is not None else None,
                "coherence_2way": round(gram_coherence_2way, 4) if gram_coherence_2way is not None else None,
            },
            "calibration": {
                "z_st_i": round(z_st_i, 4) if z_st_i is not None else None,
                "z_st_a": round(z_st_a, 4) if z_st_a is not None else None,
                "z_gram_ti": round(z_gram_ti, 4) if z_gram_ti is not None else None,
                "z_gram_ta": round(z_gram_ta, 4) if z_gram_ta is not None else None,
                "z_compl": round(z_compl, 4) if z_compl is not None else None,
                "gram_ia_volume": round(gram_ia_volume, 4) if gram_ia_volume is not None else None,
                "u_ti": round(u_ti, 6) if u_ti is not None else None,
                "u_ta": round(u_ta, 6) if u_ta is not None else None,
                "adaptive_w_ti": round(adaptive_w_ti, 4) if adaptive_w_ti is not None else None,
                "cal_mode": cal_mode if self._calibration is not None else None,
                "w_ti": w_ti,
                "w_3d": w_3d,
                "gamma": gamma,
                "margin_alpha": CMSCI_MARGIN_ALPHA if contrastive_result else None,
            },
            "contrastive": contrastive_result,
            "uncertainty": uncertainty_result,
            "variant_scores": {
                "A_msci": round(variant_a, 4) if variant_a is not None else None,
                "B_gram": round(variant_b, 4) if variant_b is not None else None,
                "C_gram_znorm": round(variant_c, 4) if variant_c is not None else None,
                "D_gram_znorm_contrastive": round(variant_d, 4) if variant_d is not None else None,
                "E_gram_znorm_contrastive_exmcr": round(variant_e, 4) if variant_e is not None else None,
                "F_full_cmsci": round(variant_f, 4) if variant_f is not None else None,
            },
        }

    def evaluate_batch(
        self,
        items: List[Dict[str, str]],
        n_mc_samples: int = 100,
    ) -> List[Dict[str, Any]]:
        """
        Evaluate a batch of (text, image_path, audio_path) triples.

        Args:
            items: List of dicts with keys "text", "image_path", "audio_path", "domain".
            n_mc_samples: MC samples per item.

        Returns:
            List of result dicts from evaluate().
        """
        results = []
        for item in items:
            result = self.evaluate(
                text=item.get("text", ""),
                image_path=item.get("image_path"),
                audio_path=item.get("audio_path"),
                domain=item.get("domain", ""),
                n_mc_samples=n_mc_samples,
            )
            results.append(result)
        return results
src/coherence/gram_volume.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gramian Volume Scoring for Multimodal Coherence.
3
+
4
+ The Gramian volume measures the geometric dispersion of embedding vectors.
5
+ For n L2-normalized vectors, the Gramian matrix G has G_ij = <vi, vj>.
6
+
7
+ volume = sqrt(det(G))
8
+
9
+ Properties:
10
+ - Identical vectors → det(G) = 0 → volume = 0 (perfect alignment)
11
+ - Mutually orthogonal unit vectors → det(G) = 1 → volume = 1 (max dispersion)
12
+ - Coherence = 1 - volume → [0, 1] where 1 = perfect alignment
13
+
14
+ For 2 unit vectors:
15
+ det(G) = 1 - cos²(θ) = sin²(θ)
16
+ volume = |sin(θ)|
17
+ coherence = 1 - |sin(θ)| ≈ cos(θ) for small angles
18
+
19
+ For 3 unit vectors:
20
+ det(G) = 1 - cos²(a) - cos²(b) - cos²(c) + 2·cos(a)·cos(b)·cos(c)
21
+ where a, b, c are pairwise angles
22
+ This captures the full tri-modal geometric relationship in one number.
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import numpy as np
28
+
29
+
30
+ def _normalize(v: np.ndarray, eps: float = 1e-12) -> np.ndarray:
31
+ """L2-normalize a vector."""
32
+ v = v.astype(np.float64).squeeze()
33
+ norm = np.linalg.norm(v) + eps
34
+ return v / norm
35
+
36
+
37
+ def gram_volume_2d(v1: np.ndarray, v2: np.ndarray) -> float:
38
+ """
39
+ Gramian volume for 2 vectors (area of parallelogram).
40
+
41
+ For unit vectors: volume = |sin(θ)| where θ is the angle between them.
42
+ Range: [0, 1] — 0 when identical, 1 when orthogonal.
43
+ """
44
+ v1_n = _normalize(v1)
45
+ v2_n = _normalize(v2)
46
+ cos_sim = np.clip(np.dot(v1_n, v2_n), -1.0, 1.0)
47
+ # det(G) = 1 - cos²(θ)
48
+ det_g = 1.0 - cos_sim ** 2
49
+ return float(np.sqrt(max(det_g, 0.0)))
50
+
51
+
52
+ def gram_volume_3d(
53
+ v1: np.ndarray, v2: np.ndarray, v3: np.ndarray,
54
+ ) -> float:
55
+ """
56
+ Gramian volume for 3 vectors (volume of parallelepiped).
57
+
58
+ For unit vectors with pairwise cosines a, b, c:
59
+ det(G) = 1 - a² - b² - c² + 2abc
60
+
61
+ Range: [0, 1] — 0 when all collinear, 1 when mutually orthogonal.
62
+ """
63
+ v1_n = _normalize(v1)
64
+ v2_n = _normalize(v2)
65
+ v3_n = _normalize(v3)
66
+
67
+ a = np.dot(v1_n, v2_n)
68
+ b = np.dot(v1_n, v3_n)
69
+ c = np.dot(v2_n, v3_n)
70
+
71
+ det_g = 1.0 - a**2 - b**2 - c**2 + 2.0 * a * b * c
72
+ return float(np.sqrt(max(det_g, 0.0)))
73
+
74
+
75
def gram_volume_nd(*vectors: np.ndarray) -> float:
    """
    Gramian volume for n vectors (general case).

    Builds the Gram matrix G_ij = <vi, vj> from L2-normalized vectors
    and returns sqrt(det(G)).

    Improvements over the previous version: the O(n²) Python double loop
    is replaced by a single vectorized ``M @ M.T``, and the special-cased
    delegation for n == 2 / n == 3 is removed — the general Gram-matrix
    formula yields the same values (up to float rounding), so one code
    path now covers every n.

    Args:
        *vectors: Variable number of numpy arrays (embeddings).

    Returns:
        Gramian volume in [0, 1] for unit vectors; 0.0 for n < 2
        (a single vector or none spans no area/volume).
    """
    if len(vectors) < 2:
        return 0.0

    eps = 1e-12
    # Stack L2-normalized rows: M is (n, dim).
    rows = []
    for v in vectors:
        w = v.astype(np.float64).squeeze()
        rows.append(w / (np.linalg.norm(w) + eps))
    M = np.stack(rows)

    # Gram matrix of pairwise inner products, computed in one matmul.
    G = M @ M.T
    det_g = float(np.linalg.det(G))
    # Clamp: near-singular G can produce a tiny negative determinant.
    return float(np.sqrt(max(det_g, 0.0)))
108
+
109
+
110
def normalized_gram_coherence(volume: float, n_vectors: int = 2) -> float:
    """
    Map Gramian volume to coherence score in [0, 1].

    1 = perfect alignment (volume = 0, all vectors identical)
    0 = maximum dispersion (volume = 1, mutually orthogonal)

    Args:
        volume: Gramian volume (output of gram_volume_* functions).
        n_vectors: Number of vectors used (for documentation; mapping is the same).

    Returns:
        Coherence score in [0, 1].
    """
    coherence = 1.0 - volume
    if coherence < 0.0:
        coherence = 0.0
    elif coherence > 1.0:
        coherence = 1.0
    return float(coherence)
src/coherence/negative_bank.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Contrastive Negative Bank for cMSCI Calibration.
3
+
4
+ Computes contrastive margins by comparing a matched (text, image, audio)
5
+ triple against hard-negative alternatives from the embedding indexes.
6
+
7
+ A positive contrastive margin means the matched triple has tighter
8
+ geometric coherence than mismatched alternatives — the defining
9
+ property of a well-calibrated metric.
10
+
11
+ Contrastive margin:
12
+ margin = mean(neg_volumes) - matched_volume
13
+ > 0 → matched triple is more coherent than negatives (good)
14
+ ≤ 0 → metric cannot distinguish matched from mismatched (bad)
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import logging
20
+ from pathlib import Path
21
+ from typing import Dict, List, Optional, Tuple
22
+
23
+ import numpy as np
24
+
25
+ from src.coherence.gram_volume import gram_volume_2d, gram_volume_3d, normalized_gram_coherence
26
+ from src.embeddings.similarity import l2_normalize
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
class NegativeBank:
    """
    Loads pre-computed embedding indexes and provides hard negatives.

    Hard negatives are embeddings with high individual similarity to the
    query but from a different domain — the most challenging cases for
    the coherence metric.
    """

    def __init__(
        self,
        image_index_path: str = "data/embeddings/image_index.npz",
        audio_index_path: str = "data/embeddings/audio_index.npz",
    ):
        # Index arrays stay None until the corresponding .npz file is
        # successfully loaded; a missing file simply disables that modality.
        self._image_ids: Optional[np.ndarray] = None
        self._image_embs: Optional[np.ndarray] = None
        self._image_domains: Optional[np.ndarray] = None
        self._audio_ids: Optional[np.ndarray] = None
        self._audio_embs: Optional[np.ndarray] = None
        self._audio_domains: Optional[np.ndarray] = None

        self._load_index(image_index_path, "image")
        self._load_index(audio_index_path, "audio")

    def _load_index(self, path: str, modality: str) -> None:
        """Load one .npz embedding index for *modality* ("image" or "audio")."""
        p = Path(path)
        if not p.exists():
            logger.warning("Index not found: %s — %s negatives disabled", path, modality)
            return

        # allow_pickle=True because the ids/domains arrays may hold objects.
        # NOTE(review): allow_pickle on untrusted files is unsafe — assumed
        # these indexes are produced locally by the project itself.
        data = np.load(path, allow_pickle=True)
        # Accept either key naming convention: ids/paths, embs/embeddings.
        ids = data["ids"] if "ids" in data else data.get("paths", np.array([]))
        embs = data["embs"] if "embs" in data else data.get("embeddings", np.array([]))
        domains = data["domains"] if "domains" in data else np.array(["other"] * len(ids))

        if modality == "image":
            self._image_ids = ids
            self._image_embs = embs.astype(np.float32)
            self._image_domains = domains
            logger.info("Loaded image index: %d entries", len(ids))
        else:
            self._audio_ids = ids
            self._audio_embs = embs.astype(np.float32)
            self._audio_domains = domains
            logger.info("Loaded audio index: %d entries", len(ids))

    @property
    def has_images(self) -> bool:
        # True when the image index loaded and is non-empty.
        return self._image_embs is not None and len(self._image_embs) > 0

    @property
    def has_audio(self) -> bool:
        # True when the audio index loaded and is non-empty.
        return self._audio_embs is not None and len(self._audio_embs) > 0

    def get_hard_negative_images(
        self,
        text_emb: np.ndarray,
        exclude_domain: str = "",
        k: int = 5,
    ) -> List[np.ndarray]:
        """
        Get top-k hardest negative images (high text similarity but wrong domain).

        Args:
            text_emb: CLIP text embedding for the query.
            exclude_domain: Domain to exclude (the correct domain).
            k: Number of negatives to return.

        Returns:
            List of image embeddings (hard negatives).
        """
        if not self.has_images:
            return []

        # Cosine similarity of every indexed image to the query text
        # (index embeddings are assumed pre-normalized — TODO confirm).
        text_n = l2_normalize(text_emb.squeeze())
        sims = self._image_embs @ text_n

        # Filter by domain: exclude the matched domain
        if exclude_domain:
            mask = np.array([d != exclude_domain for d in self._image_domains])
        else:
            mask = np.ones(len(sims), dtype=bool)

        # Masked-out entries get -inf so they can never rank in the top-k.
        sims_masked = np.where(mask, sims, -np.inf)
        top_k_idx = np.argsort(sims_masked)[-k:][::-1]

        return [self._image_embs[i] for i in top_k_idx if sims_masked[i] > -np.inf]

    def get_hard_negative_audio(
        self,
        text_emb: np.ndarray,
        exclude_domain: str = "",
        k: int = 5,
    ) -> List[np.ndarray]:
        """
        Get top-k hardest negative audio (high text similarity but wrong domain).

        Args:
            text_emb: CLAP text embedding for the query.
            exclude_domain: Domain to exclude.
            k: Number of negatives to return.

        Returns:
            List of audio embeddings (hard negatives).
        """
        if not self.has_audio:
            return []

        text_n = l2_normalize(text_emb.squeeze())
        sims = self._audio_embs @ text_n

        if exclude_domain:
            mask = np.array([d != exclude_domain for d in self._audio_domains])
        else:
            mask = np.ones(len(sims), dtype=bool)

        sims_masked = np.where(mask, sims, -np.inf)
        top_k_idx = np.argsort(sims_masked)[-k:][::-1]

        return [self._audio_embs[i] for i in top_k_idx if sims_masked[i] > -np.inf]

    def compute_contrastive_margin(
        self,
        matched_volume: float,
        text_clip_emb: np.ndarray,
        image_emb: np.ndarray,
        text_clap_emb: Optional[np.ndarray] = None,
        audio_emb: Optional[np.ndarray] = None,
        domain: str = "",
        k: int = 5,
    ) -> Dict[str, float]:
        """
        Compute contrastive margin against hard negatives.

        For each hard negative, computes the gram volume of the negative
        triple and averages. Margin = mean(neg_volumes) - matched_volume.

        A positive margin means the matched triple is geometrically tighter
        than hard-negative alternatives.

        NOTE(review): image_emb and audio_emb are accepted for interface
        symmetry but are not referenced in this body — negatives are scored
        only against the text embeddings (gram_volume_2d of text vs negative).

        Args:
            matched_volume: Gram volume of the matched (text, image, audio) triple.
            text_clip_emb: CLIP text embedding (for finding negative images).
            image_emb: CLIP image embedding of the matched image.
            text_clap_emb: CLAP text embedding (for finding negative audio).
            audio_emb: CLAP audio embedding of the matched audio.
            domain: Domain of the matched prompt (excluded from negatives).
            k: Number of hard negatives per modality.

        Returns:
            Dict with margin, mean_neg_volume, n_negatives.
        """
        neg_volumes = []

        # Image negatives: replace matched image with hard negative
        neg_images = self.get_hard_negative_images(text_clip_emb, domain, k)
        for neg_img in neg_images:
            vol = gram_volume_2d(text_clip_emb, neg_img)
            neg_volumes.append(vol)

        # Audio negatives: replace matched audio with hard negative
        if text_clap_emb is not None:
            neg_audios = self.get_hard_negative_audio(text_clap_emb, domain, k)
            for neg_aud in neg_audios:
                vol = gram_volume_2d(text_clap_emb, neg_aud)
                neg_volumes.append(vol)

        if not neg_volumes:
            # No usable negatives: report a neutral margin of zero.
            return {
                "margin": 0.0,
                "mean_neg_volume": matched_volume,
                "n_negatives": 0,
            }

        mean_neg = float(np.mean(neg_volumes))
        margin = mean_neg - matched_volume

        return {
            "margin": float(margin),
            "mean_neg_volume": mean_neg,
            "n_negatives": len(neg_volumes),
        }
src/config/settings.py CHANGED
@@ -106,3 +106,47 @@ DRIFT_ASYMMETRY_THRESHOLD = 0.15 # |st_i - st_a| gap to flag drift
106
  RERATING_FRACTION = 0.20
107
  KAPPA_ACCEPTABLE_THRESHOLD = 0.70
108
  ALPHA_ACCEPTABLE_THRESHOLD = 0.667
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  RERATING_FRACTION = 0.20
107
  KAPPA_ACCEPTABLE_THRESHOLD = 0.70
108
  ALPHA_ACCEPTABLE_THRESHOLD = 0.667
109
# ---------------------------------------------------------------------------
# cMSCI (Calibrated Multimodal Semantic Coherence Index)
# ---------------------------------------------------------------------------

# Calibration store (fitted from RQ1 baseline data).
CMSCI_CALIBRATION_PATH = PROJECT_ROOT / "artifacts" / "cmsci_calibration.json"

# Ex-MCR cross-space alignment (CLAP → CLIP projection head weights).
EXMCR_WEIGHTS_PATH = PROJECT_ROOT / "models" / "exmcr" / "ex_clap.pt"

# Cross-Space Bridge (CLIP image + CLAP audio → shared 256-d bridge space).
BRIDGE_WEIGHTS_PATH = PROJECT_ROOT / "models" / "bridge" / "bridge_best.pt"

# Probabilistic adapters (ProbVLM-style uncertainty), one per embedding space.
PROB_CLIP_ADAPTER_PATH = PROJECT_ROOT / "models" / "prob_adapters" / "clip_adapter.pt"
PROB_CLAP_ADAPTER_PATH = PROJECT_ROOT / "models" / "prob_adapters" / "clap_adapter.pt"

# Full pipeline optimized parameters (via LOO-CV on RQ3 human ratings).
# Full-sample rho=0.608 (p=0.0004), LOO-CV rho=0.546 (p=0.0018), overfit gap=0.001
# Selected in 87% of LOO folds (26/30) — highly stable.
CMSCI_MARGIN_ALPHA = 16  # Margin scaling factor (amplifies contrastive signal)
CMSCI_CHANNEL_WEIGHT_TI = 0.90  # Text-image channel weight (1 - w for text-audio)
CMSCI_CALIBRATION_MODE = "gram"  # "cosine" (z-norm cosine sims) or "gram" (z-norm gram coherences)

# Variant E: ExMCR cross-modal complementarity (w_3d=0 recovers D exactly).
# ExMCR projects CLAP audio → CLIP space; complementarity = Gramian dispersion.
# High complementarity = image and audio contribute unique perspectives (rewarded).
CMSCI_W_3D = 0.45  # Weight for z-normalized IA complementarity
# Variant F: ProbVLM adaptive channel weighting (gamma=0 recovers E exactly).
CMSCI_GAMMA = 0.10  # Mixing ratio: w_final = (1-gamma)*base_w + gamma*adaptive_w

# Contrastive negative bank.
CMSCI_NEGATIVE_K = 5  # Number of hard negatives per modality
CMSCI_NEGATIVE_BANK_ENABLED = True  # Enable/disable contrastive calibration

# MC sampling for uncertainty estimation.
CMSCI_MC_SAMPLES = 100  # Number of Monte Carlo samples for Variant F

# Probabilistic adapter training hyperparameters
# (consumed by src/embeddings/prob_adapter_trainer.py).
PROB_ADAPTER_EPOCHS = 100
PROB_ADAPTER_LR = 1e-4
PROB_ADAPTER_BATCH_SIZE = 32
PROB_ADAPTER_PATIENCE = 15
src/embeddings/prob_adapter_trainer.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Training Loop for ProbVLM-Style Probabilistic Adapters.
3
+
4
+ Trains lightweight post-hoc adapters on top of frozen CLIP/CLAP encoders.
5
+ Each adapter learns to predict uncertainty (Generalized Gaussian parameters)
6
+ for a single embedding space.
7
+
8
+ Two adapters to train:
9
+ 1. CLIP adapter: trained on (image_embedding, text_embedding) pairs
10
+ 2. CLAP adapter: trained on (audio_embedding, text_embedding) pairs
11
+
12
+ Training data:
13
+ - Our 57 images paired with text descriptions (CLIP pairs)
14
+ - Our 104 audio files paired with text descriptions (CLAP pairs)
15
+ - All 30 RQ1 prompts × matched media as additional pairs
16
+
17
+ Loss:
18
+ L = L1(mu, target) + GenGaussLoss(mu, alpha, beta, target)
19
+
20
+ GenGaussLoss:
21
+ -log p(target | mu, alpha, beta) ∝ log(alpha) - log(beta) + (|target - mu| / alpha)^beta
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import logging
27
+ from pathlib import Path
28
+ from typing import Dict, List, Optional, Tuple
29
+
30
+ import numpy as np
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ try:
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+ from torch.utils.data import DataLoader, Dataset, random_split
39
+ TORCH_AVAILABLE = True
40
+ except ImportError:
41
+ TORCH_AVAILABLE = False
42
+
43
+ from src.embeddings.probabilistic_adapter import ProbabilisticAdapter
44
+
45
+
46
class EmbeddingPairDataset(Dataset):
    """Dataset of (input_embedding, target_embedding) pairs.

    Wraps two parallel numpy arrays as float32 tensors so they can be
    served to a ``DataLoader`` for adapter training.
    """

    def __init__(self, inputs: np.ndarray, targets: np.ndarray):
        """
        Args:
            inputs: Source embeddings, shape [N, D].
            targets: Target embeddings, shape [N, D], aligned row-for-row
                with ``inputs``.

        Raises:
            ImportError: If PyTorch is not installed.
            ValueError: If ``inputs`` and ``targets`` differ in length.
        """
        if not TORCH_AVAILABLE:
            raise ImportError("PyTorch required")
        # A bare `assert` would be stripped under `python -O`; raise an
        # explicit error so the length invariant always holds.
        if len(inputs) != len(targets):
            raise ValueError(
                f"inputs and targets must have the same length "
                f"({len(inputs)} != {len(targets)})"
            )
        self.inputs = torch.tensor(inputs, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.float32)

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.inputs[idx], self.targets[idx]
61
+
62
+
63
class GenGaussNLL(nn.Module):
    """Simplified negative log-likelihood of a Generalized Gaussian.

    Full NLL:
        -log p(x | mu, alpha, beta)
            = log(2*alpha) + log(Gamma(1/beta)/beta) + (|x - mu| / alpha)^beta

    Constant terms are dropped, leaving the per-dimension loss:
        L = log(alpha) + (|target - mu| / alpha)^beta
    averaged over every element of the batch.
    """

    def forward(
        self,
        mu: torch.Tensor,
        alpha: torch.Tensor,
        beta: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        # Guard the scale away from zero before dividing by it.
        safe_scale = alpha.clamp(min=1e-6)
        abs_err = (target - mu).abs()
        per_dim_nll = safe_scale.log() + (abs_err / safe_scale) ** beta
        return per_dim_nll.mean()
85
+
86
+
87
def train_prob_adapter(
    input_embeddings: np.ndarray,
    target_embeddings: np.ndarray,
    epochs: int = 100,
    lr: float = 1e-4,
    batch_size: int = 32,
    val_split: float = 0.15,
    patience: int = 15,
    output_path: Optional[str] = None,
    adapter_name: str = "adapter",
) -> ProbabilisticAdapter:
    """
    Train a ProbabilisticAdapter on paired embeddings.

    Loss per batch is L1(mu, target) + GenGaussNLL(mu, alpha, beta, target);
    the best validation checkpoint is kept (and reloaded at the end when
    ``output_path`` is given).

    Args:
        input_embeddings: Source embeddings [N, 512] (e.g. image CLIP or audio CLAP).
        target_embeddings: Target embeddings [N, 512] (e.g. text CLIP or text CLAP).
        epochs: Maximum training epochs.
        lr: Learning rate.
        batch_size: Batch size.
        val_split: Fraction for validation.
        patience: Early stopping patience (epochs without val improvement).
        output_path: If set, save best model here.
        adapter_name: Name for logging.

    Returns:
        Trained ProbabilisticAdapter (best-val weights if ``output_path`` was
        set, otherwise the last-epoch weights moved to CPU).

    Raises:
        ImportError: If PyTorch is not installed.
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch required for training")

    # Prefer CUDA, then Apple-Silicon MPS, then CPU.
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

    # Build dataset and a fixed-seed train/val split (reproducible folds).
    dataset = EmbeddingPairDataset(input_embeddings, target_embeddings)
    n_val = max(1, int(len(dataset) * val_split))
    n_train = len(dataset) - n_val
    # NOTE(review): a 1-item dataset yields n_train == 0 and an empty train
    # loader (avg_train becomes NaN) — confirm callers always pass >= 2 pairs.
    train_ds, val_ds = random_split(
        dataset, [n_train, n_val],
        generator=torch.Generator().manual_seed(42),
    )

    # Drop a ragged final batch only when there is more than one batch.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=len(train_ds) > batch_size)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    # Build model + optimizer; cosine LR decay over the full epoch budget.
    input_dim = input_embeddings.shape[1]
    adapter = ProbabilisticAdapter(input_dim=input_dim).to(device)
    optimizer = torch.optim.AdamW(adapter.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    l1_loss = nn.L1Loss()
    gg_loss = GenGaussNLL()

    best_val_loss = float("inf")
    patience_counter = 0

    logger.info(
        "Training %s adapter: %d train, %d val, %d epochs, device=%s",
        adapter_name, n_train, n_val, epochs, device,
    )

    for epoch in range(epochs):
        # --- Train pass ---
        adapter.train()
        train_losses = []
        for inp, tgt in train_loader:
            inp, tgt = inp.to(device), tgt.to(device)
            optimizer.zero_grad()

            mu, alpha, beta = adapter(inp)
            # Combined objective: L1 anchors mu; NLL shapes alpha/beta.
            loss = l1_loss(mu, tgt) + gg_loss(mu, alpha, beta, tgt)
            loss.backward()
            # Gradient clipping keeps the NLL's pow() term from exploding.
            torch.nn.utils.clip_grad_norm_(adapter.parameters(), max_norm=1.0)
            optimizer.step()
            train_losses.append(loss.item())

        scheduler.step()

        # --- Validation pass (no gradients) ---
        adapter.eval()
        val_losses = []
        with torch.no_grad():
            for inp, tgt in val_loader:
                inp, tgt = inp.to(device), tgt.to(device)
                mu, alpha, beta = adapter(inp)
                loss = l1_loss(mu, tgt) + gg_loss(mu, alpha, beta, tgt)
                val_losses.append(loss.item())

        avg_train = np.mean(train_losses)
        avg_val = np.mean(val_losses) if val_losses else float("inf")

        # Log the first epoch and every tenth thereafter.
        if (epoch + 1) % 10 == 0 or epoch == 0:
            logger.info(
                " [%s] Epoch %d/%d: train=%.4f, val=%.4f",
                adapter_name, epoch + 1, epochs, avg_train, avg_val,
            )

        # Early stopping: checkpoint on improvement, count stalls otherwise.
        if avg_val < best_val_loss:
            best_val_loss = avg_val
            patience_counter = 0
            if output_path:
                adapter.save(output_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(" [%s] Early stopping at epoch %d", adapter_name, epoch + 1)
                break

    # Restore the best-val checkpoint if one was written; otherwise return
    # the final weights on CPU.
    if output_path and Path(output_path).exists():
        adapter = ProbabilisticAdapter.load(output_path)
        adapter = adapter.to(device)
    else:
        adapter = adapter.cpu()

    adapter.eval()
    logger.info(" [%s] Training complete. Best val_loss=%.4f", adapter_name, best_val_loss)
    return adapter
207
+
208
+
209
def build_training_pairs_from_index(
    embedding_index_path: str,
    text_embedder_fn,
    modality: str = "image",
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build (media_embedding, text_embedding) pairs from an embedding index.

    Each media entry gets a caption derived from its filename (underscores
    and hyphens become spaces, known prefixes are stripped, and the domain
    is prepended when it is not "other"); the caption is then embedded with
    ``text_embedder_fn``. Entries whose embedding fails are logged and
    skipped.

    Args:
        embedding_index_path: Path to image_index.npz or audio_index.npz.
        text_embedder_fn: Callable mapping a caption string to an np.ndarray.
        modality: "image" for CLIP text, "audio" for CLAP text. Currently
            informational only — both modalities share the caption logic.

    Returns:
        (media_embeddings, text_embeddings) both shape [N, 512].
    """
    data = np.load(embedding_index_path, allow_pickle=True)

    # Indexes have shipped under two key conventions; accept both.
    ids = data["ids"] if "ids" in data else data.get("paths", np.array([]))
    embs = data["embs"] if "embs" in data else data.get("embeddings", np.array([]))
    domains = data["domains"] if "domains" in data else np.array(["other"] * len(ids))

    media_rows = []
    text_rows = []

    for row, (file_id, domain) in enumerate(zip(ids, domains)):
        # Turn the filename stem into a rough caption.
        caption = Path(str(file_id)).stem.replace("_", " ").replace("-", " ")
        # Strip recording/processing prefixes that carry no semantics.
        for known_prefix in ("fs ", "wm ", "proc "):
            if caption.lower().startswith(known_prefix):
                caption = caption[len(known_prefix):]
        # Prepend the domain for extra context, unless it is the catch-all.
        if domain != "other":
            caption = f"{domain}: {caption}"

        try:
            embedded_text = text_embedder_fn(caption)
            media_rows.append(embs[row])
            text_rows.append(embedded_text)
        except Exception as err:
            logger.warning("Skipping %s: %s", file_id, err)

    return np.array(media_rows), np.array(text_rows)
src/embeddings/probabilistic_adapter.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ProbVLM-Style Probabilistic Adapter for Uncertainty Estimation.
3
+
4
+ Converts point embeddings into distributions (Generalized Gaussian)
5
+ following the BayesCap approach from ProbVLM.
6
+
7
+ Each adapter takes a frozen embedding and predicts:
8
+ mu: Shift from the input embedding (residual)
9
+ alpha: Scale parameter (controls spread)
10
+ beta: Shape parameter (controls tail behavior)
11
+
12
+ These define a Generalized Gaussian distribution:
13
+ p(x) ∝ exp(-(|x - mu| / alpha)^beta)
14
+
15
+ MC sampling from this distribution produces N embedding samples,
16
+ which propagate uncertainty through the Gramian volume computation.
17
+
18
+ Architecture: BayesCap_MLP
19
+ input → Linear(d, hidden) → ReLU → Dropout
20
+ → Linear(hidden, hidden) → ReLU → Dropout
21
+ → Three heads: mu_head, alpha_head, beta_head
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import logging
27
+ from pathlib import Path
28
+ from typing import Dict, Optional, Tuple
29
+
30
+ import numpy as np
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ try:
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+ TORCH_AVAILABLE = True
39
+ except ImportError:
40
+ TORCH_AVAILABLE = False
41
+
42
+
43
def _check_torch():
    """Raise ImportError unless the optional torch dependency was imported."""
    if TORCH_AVAILABLE:
        return
    raise ImportError("PyTorch required for ProbabilisticAdapter")
46
+
47
+
48
class ProbabilisticAdapter(nn.Module):
    """
    BayesCap-style adapter that maps point embeddings to distributions.

    Takes a frozen embedding (from CLIP or CLAP) and predicts
    Generalized Gaussian parameters: (mu, alpha, beta), where
        p(x) ∝ exp(-(|x - mu| / alpha)^beta)

    The adapter is lightweight (~0.5M params) and trains in minutes
    on small datasets.
    """

    def __init__(
        self,
        input_dim: int = 512,
        hidden_dim: int = 256,
        num_layers: int = 3,
        dropout: float = 0.1,
    ):
        """
        Args:
            input_dim: Dimensionality of the input (and output) embedding.
            hidden_dim: Width of the shared backbone layers.
            num_layers: Total layer count; the backbone gets num_layers - 1
                Linear+ReLU+Dropout stages, the heads form the last layer.
            dropout: Dropout probability between backbone stages.
        """
        _check_torch()
        super().__init__()

        self.input_dim = input_dim

        # Shared backbone: (num_layers - 1) Linear → ReLU → Dropout stages.
        layers = []
        in_d = input_dim
        for _ in range(num_layers - 1):
            layers.extend([
                nn.Linear(in_d, hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            in_d = hidden_dim
        self.backbone = nn.Sequential(*layers)

        # Three output heads, one per distribution parameter.
        self.mu_head = nn.Linear(hidden_dim, input_dim)
        self.alpha_head = nn.Linear(hidden_dim, input_dim)
        self.beta_head = nn.Linear(hidden_dim, input_dim)

        # Constructor kwargs, persisted alongside the weights by save()
        # so load() can rebuild an identical architecture.
        self.config = {
            "input_dim": input_dim,
            "hidden_dim": hidden_dim,
            "num_layers": num_layers,
            "dropout": dropout,
        }

    def forward(
        self, embedding: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Predict distribution parameters from a point embedding.

        Args:
            embedding: Input embedding [batch, input_dim].

        Returns:
            mu: Location parameter [batch, input_dim] (embedding + residual)
            alpha: Scale parameter [batch, input_dim] (> 0, via softplus)
            beta: Shape parameter [batch, input_dim] (> 0, via softplus)
        """
        h = self.backbone(embedding)

        # mu is a residual on top of the input, anchoring the distribution
        # to the original (frozen) embedding.
        mu = embedding + self.mu_head(h)

        # alpha, beta: strictly positive via softplus + epsilon floor.
        alpha = F.softplus(self.alpha_head(h)) + 1e-6
        beta = F.softplus(self.beta_head(h)) + 1e-6

        return mu, alpha, beta

    def sample(
        self,
        embedding: np.ndarray,
        n_samples: int = 100,
    ) -> np.ndarray:
        """
        Draw Monte Carlo samples from the predicted distribution.

        Uses the reparameterization trick for Generalized Gaussian:
            x = mu + alpha * sign(u) * |u|^(1/beta)
        where u ~ Uniform(-1, 1)

        Args:
            embedding: Input embedding, shape (dim,) or (1, dim).
            n_samples: Number of MC samples.

        Returns:
            Samples array, shape (n_samples, dim), each row L2-normalized.
        """
        _check_torch()
        self.eval()

        # Accept (dim,) or (1, dim); canonicalize to a [1, dim] batch.
        emb = embedding.squeeze()
        if emb.ndim == 1:
            emb = emb[np.newaxis, :]

        with torch.no_grad():
            # NOTE(review): tensor is created on the default device; assumes
            # the adapter lives on CPU at inference time — confirm.
            x = torch.tensor(emb, dtype=torch.float32)
            mu, alpha, beta = self.forward(x)

            # Expand for sampling: [1, dim] -> [n_samples, dim] (views, no copy).
            mu = mu.expand(n_samples, -1)
            alpha = alpha.expand(n_samples, -1)
            beta = beta.expand(n_samples, -1)

            # Reparameterized sampling from the Generalized Gaussian;
            # the 1e-8 floor keeps |u|^(1/beta) finite at u == 0.
            u = torch.rand_like(mu) * 2 - 1  # Uniform(-1, 1)
            sign = torch.sign(u)
            samples = mu + alpha * sign * (torch.abs(u) + 1e-8).pow(1.0 / beta)

            # L2 normalize samples (stay on unit sphere, like CLIP/CLAP outputs)
            samples = F.normalize(samples, p=2, dim=-1)

        return samples.cpu().numpy()

    def uncertainty(self, embedding: np.ndarray) -> float:
        """
        Compute scalar aleatoric uncertainty for an embedding.

        Returns the mean predicted alpha (scale parameter) across dimensions.
        High alpha → high uncertainty → wide distribution.

        Args:
            embedding: Input embedding, shape (dim,) or (1, dim).

        Returns:
            Scalar uncertainty value (mean alpha).
        """
        _check_torch()
        self.eval()

        emb = embedding.squeeze()
        if emb.ndim == 1:
            emb = emb[np.newaxis, :]

        with torch.no_grad():
            x = torch.tensor(emb, dtype=torch.float32)
            _, alpha, _ = self.forward(x)
            return float(alpha.mean().item())

    def save(self, path: str) -> None:
        """Save adapter weights (.pt) + architecture config (sibling .json)."""
        _check_torch()
        import json
        p = Path(path)
        p.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.state_dict(), p)
        # Config travels next to the weights so load() can rebuild the model.
        config_path = p.with_suffix(".json")
        with config_path.open("w") as f:
            json.dump(self.config, f, indent=2)
        logger.info("Saved ProbabilisticAdapter to %s", path)

    @classmethod
    def load(cls, path: str) -> "ProbabilisticAdapter":
        """Load adapter from saved weights (expects the sibling .json config)."""
        _check_torch()
        import json
        p = Path(path)
        config_path = p.with_suffix(".json")
        with config_path.open("r") as f:
            config = json.load(f)
        model = cls(**config)
        # weights_only=True avoids arbitrary-object unpickling on load.
        state_dict = torch.load(p, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        logger.info("Loaded ProbabilisticAdapter from %s", path)
        return model
src/embeddings/space_alignment.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Ex-MCR Cross-Space Alignment: CLAP Audio → CLIP Space.
3
+
4
+ Ex-MCR (Ex-Modal Contrastive Retrieval) projects CLAP audio embeddings
5
+ INTO CLIP space while keeping CLIP embeddings unchanged. This lets us
6
+ compute meaningful image-audio similarity and full 3-way Gramian volume.
7
+
8
+ Architecture decision: Ex-MCR over C-MCR because:
9
+ - Ex-MCR keeps CLIP embeddings frozen (no recomputation needed)
10
+ - C-MCR projects BOTH spaces into a new space (breaks everything)
11
+
12
+ The projector is a lightweight MLP:
13
+ CLAP 512-d → Linear(512, 512) → ReLU → Linear(512, 512) → L2 norm
14
+
15
+ If Ex-MCR weights are not available, falls back to an untrained identity
16
+ projection (which is equivalent to not using the projector).
17
+
18
+ CLAP compatibility note:
19
+ Our project uses `laion/clap-htsat-unfused`.
20
+ Ex-MCR uses `laion_clap_fullset_fusion` (different model).
21
+ If projections are poor with our CLAP, switch to the fusion model.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import logging
27
+ from pathlib import Path
28
+ from typing import Optional
29
+
30
+ import numpy as np
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ try:
35
+ import torch
36
+ import torch.nn as nn
37
+ import torch.nn.functional as F
38
+ TORCH_AVAILABLE = True
39
+ except ImportError:
40
+ TORCH_AVAILABLE = False
41
+
42
+
43
class ExMCRProjector:
    """
    Projects CLAP audio embeddings into CLIP space.

    Usage:
        proj = ExMCRProjector("models/exmcr/ex_clap.pt")
        audio_in_clip = proj.project_audio(clap_embedding)  # now comparable to CLIP

    Without trained weights the projector runs in identity mode: inputs are
    only L2-normalized and passed through.
    """

    def __init__(
        self,
        weights_path: Optional[str] = None,
        device: str = "cpu",
    ):
        """
        Args:
            weights_path: Path to Ex-MCR CLAP→CLIP projection weights (.pt).
                If None or file doesn't exist, uses identity (passthrough).
            device: Torch device for inference.
        """
        self._model = None
        self._device = device
        self._identity_mode = True

        if weights_path and Path(weights_path).exists() and TORCH_AVAILABLE:
            self._load_weights(weights_path)
        elif weights_path and not Path(weights_path).exists():
            logger.warning(
                "Ex-MCR weights not found: %s — using identity projection", weights_path
            )

    def _load_weights(self, path: str) -> None:
        """Load the Ex-MCR projection head from a saved state dict.

        Supports three checkpoint layouts (``layers.N.*`` keys, bare
        Sequential ``N.*`` keys, or a generic two-weight MLP). Falls back
        to identity mode when the format is unrecognized.
        """
        # weights_only=True avoids arbitrary-object unpickling on load.
        state_dict = torch.load(path, map_location=self._device, weights_only=True)
        keys = list(state_dict.keys())

        if any("layers" in k for k in keys):
            # Layout: layers.0.weight / layers.0.bias / layers.2.weight / ...
            in_dim = state_dict["layers.0.weight"].shape[1]
            hidden_dim = state_dict["layers.0.weight"].shape[0]
            out_dim = state_dict["layers.2.weight"].shape[0]
            model = nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, out_dim),
            )
            # Strip the "layers." prefix so keys match the bare Sequential.
            renamed = {k.replace("layers.", ""): v for k, v in state_dict.items()}
            model.load_state_dict(renamed)
        elif any(k.startswith("0.") for k in keys):
            # Layout: 0.weight, 0.bias, 2.weight, 2.bias (plain Sequential).
            in_dim = state_dict["0.weight"].shape[1]
            hidden_dim = state_dict["0.weight"].shape[0]
            out_dim = state_dict["2.weight"].shape[0]
            model = nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, out_dim),
            )
            model.load_state_dict(state_dict)
        else:
            # Generic: infer dimensions from the first and last weight tensors.
            weight_keys = [k for k in keys if "weight" in k]
            if len(weight_keys) >= 2:
                first_w = state_dict[weight_keys[0]]
                last_w = state_dict[weight_keys[-1]]
                in_dim = first_w.shape[1]
                hidden_dim = first_w.shape[0]
                out_dim = last_w.shape[0]
                model = nn.Sequential(
                    nn.Linear(in_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, out_dim),
                )
                model.load_state_dict(state_dict)
            else:
                # Leave _identity_mode True so the projector still works.
                logger.warning("Unrecognized Ex-MCR weight format — using identity")
                return

        model.to(self._device)
        model.eval()
        self._model = model
        self._identity_mode = False
        logger.info(
            "Ex-MCR projector loaded: %d → %d → %d (from %s)",
            in_dim, hidden_dim, out_dim, path,
        )

    @property
    def is_identity(self) -> bool:
        """True if projector is passthrough (no trained weights loaded)."""
        return self._identity_mode

    def project_audio(self, clap_embedding: np.ndarray) -> np.ndarray:
        """
        Project CLAP audio embedding into CLIP space.

        Args:
            clap_embedding: CLAP audio embedding, shape (512,) or (N, 512).

        Returns:
            Projected embedding in CLIP space, L2-normalized. A (512,) or
            (1, 512) input yields shape (512,); an (N, 512) input with
            N > 1 yields shape (N, 512).
        """
        if self._identity_mode:
            emb = np.asarray(clap_embedding, dtype=np.float32)
            if emb.ndim > 1 and emb.shape[0] > 1:
                # BUGFIX: batched input must be normalized per row; a single
                # global norm would mis-scale every embedding in the batch.
                norms = np.linalg.norm(emb, axis=-1, keepdims=True) + 1e-12
                return emb / norms
            flat = emb.squeeze()
            return flat / (np.linalg.norm(flat) + 1e-12)

        if not TORCH_AVAILABLE:
            # Defensive fallback; non-identity mode normally implies torch.
            return clap_embedding.squeeze().astype(np.float32)

        was_1d = clap_embedding.ndim == 1 or (
            clap_embedding.ndim == 2 and clap_embedding.shape[0] == 1
        )
        emb = clap_embedding.squeeze()
        if emb.ndim == 1:
            emb = emb[np.newaxis, :]

        with torch.no_grad():
            x = torch.tensor(emb, dtype=torch.float32, device=self._device)
            projected = self._model(x)
            projected = F.normalize(projected, p=2, dim=-1)
            result = projected.cpu().numpy()

        if was_1d:
            return result.squeeze(0)
        return result

    def project_audio_batch(self, clap_embeddings: np.ndarray) -> np.ndarray:
        """
        Batch projection of CLAP audio embeddings into CLIP space.

        Args:
            clap_embeddings: Shape (N, 512).

        Returns:
            Projected embeddings in CLIP space, shape (N, 512), L2-normalized.
        """
        if self._identity_mode:
            norms = np.linalg.norm(clap_embeddings, axis=1, keepdims=True) + 1e-12
            return (clap_embeddings / norms).astype(np.float32)

        if not TORCH_AVAILABLE:
            norms = np.linalg.norm(clap_embeddings, axis=1, keepdims=True) + 1e-12
            return (clap_embeddings / norms).astype(np.float32)

        with torch.no_grad():
            x = torch.tensor(clap_embeddings, dtype=torch.float32, device=self._device)
            projected = self._model(x)
            projected = F.normalize(projected, p=2, dim=-1)
        return projected.cpu().numpy()