pratik-250620 committed on
Commit
fb5a7e6
·
verified ·
1 Parent(s): 960dff6

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. app.py +163 -51
  2. src/embeddings/audio_embedder.py +13 -5
app.py CHANGED
@@ -5,7 +5,7 @@ Live demonstration of multimodal generation + coherence evaluation.
5
  Enter a scene description and the system produces coherent text, image,
6
  and audio with real-time MSCI scoring.
7
 
8
- Pipeline: HF Inference API (text + planning) + CLIP retrieval (image) + CLAP retrieval (audio)
9
  Planning modes: direct, planner, council (3-way), extended_prompt (3x tokens)
10
  """
11
 
@@ -15,6 +15,7 @@ import json
15
  import logging
16
  import os
17
  import sys
 
18
  import time
19
  from pathlib import Path
20
  from typing import Any, Dict, Optional
@@ -406,6 +407,13 @@ def plan_extended(prompt: str) -> Optional[Any]:
406
  # Generation / retrieval functions
407
  # ---------------------------------------------------------------------------
408
 
 
 
 
 
 
 
 
409
  def gen_text(prompt: str, mode: str) -> dict:
410
  """Generate text and optional plan using HF Inference API."""
411
  # Step 1: Plan (if not direct mode)
@@ -457,6 +465,50 @@ def gen_text(prompt: str, mode: str) -> dict:
457
  }
458
 
459
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
  def retrieve_image(prompt: str) -> dict:
461
  r = load_image_retriever().retrieve(prompt)
462
  return {
@@ -540,6 +592,15 @@ def main():
540
  with st.sidebar:
541
  st.markdown("#### Configuration")
542
 
 
 
 
 
 
 
 
 
 
543
  mode = st.selectbox(
544
  "Planning Mode",
545
  ["direct", "planner", "council", "extended_prompt"],
@@ -567,16 +628,22 @@ def main():
567
  "council": "3 LLM calls merged for richer planning",
568
  "extended_prompt": "Single LLM call with 3x token budget",
569
  }
 
 
 
 
 
 
570
  st.markdown(
571
  f'<div class="sidebar-info">'
572
  f'<b>Text</b> HF Inference API<br>'
573
  f'<b>Planning</b> {mode_desc[mode]}<br>'
574
- f'<b>Image</b> CLIP retrieval (57 images)<br>'
575
- f'<b>Audio</b> CLAP retrieval (104 clips)<br><br>'
576
  f'<b>Metric</b> MSCI = 0.45 &times; s<sub>t,i</sub> + 0.45 &times; s<sub>t,a</sub><br><br>'
577
  f'<b>Models</b><br>'
578
- f'CLIP ViT-B/32 (text-image)<br>'
579
- f'CLAP HTSAT-unfused (text-audio)'
580
  f'</div>', unsafe_allow_html=True)
581
 
582
  # Prompt input
@@ -595,9 +662,13 @@ def main():
595
  mlbl = {"direct": "Direct", "planner": "Planner", "council": "Council", "extended_prompt": "Extended"}[mode]
596
  mcls = "chip-amber" if mode != "direct" else "chip-purple"
597
  mdot = "chip-dot-amber" if mode != "direct" else "chip-dot-purple"
 
 
 
 
598
  st.markdown(
599
  f'<div class="chip-row">'
600
- f'<span class="chip chip-pink"><span class="chip-dot chip-dot-pink"></span>Generative</span>'
601
  f'<span class="chip {mcls}"><span class="chip-dot {mdot}"></span>{mlbl}</span>'
602
  f'<span class="chip chip-green"><span class="chip-dot chip-dot-green"></span>CLIP + CLAP</span>'
603
  f'</div>', unsafe_allow_html=True)
@@ -613,7 +684,7 @@ def main():
613
  return
614
 
615
  if go and prompt.strip():
616
- st.session_state["last_result"] = run_pipeline(prompt.strip(), mode)
617
 
618
  if "last_result" in st.session_state:
619
  show_results(st.session_state["last_result"])
@@ -623,8 +694,8 @@ def main():
623
  # Pipeline
624
  # ---------------------------------------------------------------------------
625
 
626
- def run_pipeline(prompt: str, mode: str) -> dict:
627
- R: dict = {"mode": mode}
628
  t_all = time.time()
629
 
630
  # 1) Text + Planning
@@ -647,33 +718,53 @@ def run_pipeline(prompt: str, mode: str) -> dict:
647
  ip = R["text"].get("image_prompt", prompt)
648
  ap = R["text"].get("audio_prompt", prompt)
649
 
650
- # 2) Image retrieval
651
- with st.status("Retrieving image...", expanded=True) as s:
 
652
  t0 = time.time()
653
  try:
654
- R["image"] = retrieve_image(ip)
 
 
 
655
  R["t_img"] = time.time() - t0
656
- f = R["image"].get("failed", False)
657
- lbl = f"Image retrieved (sim={R['image']['similarity']:.3f}, {R['t_img']:.1f}s)"
658
- if f:
659
- lbl += " \u2014 below threshold"
660
- s.update(label=lbl, state="complete" if not f else "error")
 
 
 
 
 
 
661
  except Exception as e:
662
  s.update(label=f"Image failed: {e}", state="error")
663
  R["image"] = None
664
  R["t_img"] = time.time() - t0
665
 
666
- # 3) Audio retrieval
667
- with st.status("Retrieving audio...", expanded=True) as s:
 
668
  t0 = time.time()
669
  try:
670
- R["audio"] = retrieve_audio(ap)
 
 
 
671
  R["t_aud"] = time.time() - t0
672
- f = R["audio"].get("failed", False)
673
- lbl = f"Audio retrieved (sim={R['audio']['similarity']:.3f}, {R['t_aud']:.1f}s)"
674
- if f:
675
- lbl += " \u2014 below threshold"
676
- s.update(label=lbl, state="complete" if not f else "error")
 
 
 
 
 
 
677
  except Exception as e:
678
  s.update(label=f"Audio failed: {e}", state="error")
679
  R["audio"] = None
@@ -743,45 +834,54 @@ def show_results(R: dict):
743
  st.markdown(f'<div class="text-card">{txt}</div>', unsafe_allow_html=True)
744
 
745
  with ci:
746
- st.markdown('<div class="sec-label">Image</div>', unsafe_allow_html=True)
747
  ii = R.get("image")
748
  if ii and ii.get("path"):
749
  ip = Path(ii["path"])
750
- failed = ii.get("failed", False)
751
- sim = ii.get("similarity")
752
 
753
- if failed:
 
754
  st.markdown(
755
- f'<div class="warn-banner"><b>Below threshold</b> '
756
- f'(sim={sim:.3f} &lt; {IMAGE_SIM_THRESHOLD}) '
757
- f'\u2014 best match from index.</div>',
758
  unsafe_allow_html=True)
759
 
760
  if ip.exists():
761
  st.image(str(ip), use_container_width=True)
762
- dom = ii.get("domain", "other")
763
- ic = DOMAIN_ICONS.get(dom, "\U0001f4cd")
764
- st.caption(f"{ic} {dom} \u00b7 sim **{sim:.3f}** \u00b7 {ip.name}")
 
 
 
 
 
765
  else:
766
  st.info("No image.")
767
 
768
  with ca:
769
- st.markdown('<div class="sec-label">Audio</div>', unsafe_allow_html=True)
770
  ai = R.get("audio")
771
  if ai and ai.get("path"):
772
  ap = Path(ai["path"])
773
- sim = ai.get("similarity")
774
- failed = ai.get("failed", False)
775
 
776
- if failed:
 
777
  st.markdown(
778
- f'<div class="warn-banner"><b>Below threshold</b> '
779
- f'(sim={sim:.3f} &lt; {AUDIO_SIM_THRESHOLD}).</div>',
780
  unsafe_allow_html=True)
781
 
782
  if ap.exists():
783
  st.audio(str(ap))
784
- st.caption(f"sim **{sim:.3f}** \u00b7 {ap.name}")
 
 
 
 
 
785
  else:
786
  st.info("No audio.")
787
 
@@ -819,22 +919,34 @@ def show_results(R: dict):
819
  else:
820
  st.write(f"Planning ({mode}) did not produce a valid plan. Fell back to direct mode.")
821
 
822
- with st.expander("Retrieval Details"):
823
  r1, r2 = st.columns(2)
824
  with r1:
825
  ii = R.get("image")
826
- if ii and ii.get("top_5"):
827
- st.markdown("**Image \u2014 Top 5 candidates**")
828
- bars = "".join(sim_bar_html(n, s) for n, s in ii["top_5"])
829
- st.markdown(bars, unsafe_allow_html=True)
 
 
 
 
 
 
830
  else:
831
  st.write("No image data.")
832
  with r2:
833
  ai = R.get("audio")
834
- if ai and ai.get("top_5"):
835
- st.markdown("**Audio \u2014 Top 5 candidates**")
836
- bars = "".join(sim_bar_html(n, s) for n, s in ai["top_5"])
837
- st.markdown(bars, unsafe_allow_html=True)
 
 
 
 
 
 
838
  else:
839
  st.write("No audio data.")
840
 
 
5
  Enter a scene description and the system produces coherent text, image,
6
  and audio with real-time MSCI scoring.
7
 
8
+ Pipeline: HF Inference API (text + planning + image + audio) with CLIP/CLAP retrieval fallback
9
  Planning modes: direct, planner, council (3-way), extended_prompt (3x tokens)
10
  """
11
 
 
15
  import logging
16
  import os
17
  import sys
18
+ import tempfile
19
  import time
20
  from pathlib import Path
21
  from typing import Any, Dict, Optional
 
407
  # Generation / retrieval functions
408
  # ---------------------------------------------------------------------------
409
 
410
+ # HF Inference API model IDs
411
+ IMAGE_GEN_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
412
+ AUDIO_GEN_MODELS = [
413
+ "cvssp/audioldm2",
414
+ "facebook/musicgen-small",
415
+ ]
416
+
417
  def gen_text(prompt: str, mode: str) -> dict:
418
  """Generate text and optional plan using HF Inference API."""
419
  # Step 1: Plan (if not direct mode)
 
465
  }
466
 
467
 
468
+ def generate_image(prompt: str) -> dict:
469
+ """Generate image via HF Inference API (SDXL), fallback to retrieval."""
470
+ client = get_inference_client()
471
+ try:
472
+ image = client.text_to_image(prompt, model=IMAGE_GEN_MODEL)
473
+ tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir="/tmp")
474
+ image.save(tmp.name)
475
+ return {
476
+ "path": tmp.name, "backend": "generative",
477
+ "model": "SDXL", "failed": False,
478
+ }
479
+ except Exception as e:
480
+ logger.warning("Image generation failed: %s — falling back to retrieval", e)
481
+ return retrieve_image(prompt)
482
+
483
+
484
+ def generate_audio(prompt: str) -> dict:
485
+ """Generate audio via HF Inference API, fallback to retrieval."""
486
+ client = get_inference_client()
487
+ for model_id in AUDIO_GEN_MODELS:
488
+ try:
489
+ audio_bytes = client.text_to_audio(prompt, model=model_id)
490
+ suffix = ".flac" if "musicgen" in model_id else ".wav"
491
+ tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False, dir="/tmp")
492
+ if isinstance(audio_bytes, bytes):
493
+ tmp.write(audio_bytes)
494
+ tmp.flush()
495
+ else:
496
+ # Some API versions return object with .read() or similar
497
+ tmp.write(bytes(audio_bytes))
498
+ tmp.flush()
499
+ model_name = model_id.split("/")[-1]
500
+ return {
501
+ "path": tmp.name, "backend": "generative",
502
+ "model": model_name, "failed": False,
503
+ }
504
+ except Exception as e:
505
+ logger.warning("Audio gen with %s failed: %s", model_id, e)
506
+ continue
507
+ # All generative models failed — fall back to retrieval
508
+ logger.warning("All audio generation models failed — falling back to retrieval")
509
+ return retrieve_audio(prompt)
510
+
511
+
512
  def retrieve_image(prompt: str) -> dict:
513
  r = load_image_retriever().retrieve(prompt)
514
  return {
 
592
  with st.sidebar:
593
  st.markdown("#### Configuration")
594
 
595
+ backend = st.selectbox(
596
+ "Backend",
597
+ ["generative", "retrieval"],
598
+ format_func=lambda x: {
599
+ "generative": "Generative (SDXL + AudioLDM2)",
600
+ "retrieval": "Retrieval (CLIP + CLAP index)",
601
+ }[x],
602
+ )
603
+
604
  mode = st.selectbox(
605
  "Planning Mode",
606
  ["direct", "planner", "council", "extended_prompt"],
 
628
  "council": "3 LLM calls merged for richer planning",
629
  "extended_prompt": "Single LLM call with 3x token budget",
630
  }
631
+ if backend == "generative":
632
+ img_info = "SDXL via HF API"
633
+ aud_info = "AudioLDM2 / MusicGen via HF API"
634
+ else:
635
+ img_info = "CLIP retrieval (57 images)"
636
+ aud_info = "CLAP retrieval (104 clips)"
637
  st.markdown(
638
  f'<div class="sidebar-info">'
639
  f'<b>Text</b> HF Inference API<br>'
640
  f'<b>Planning</b> {mode_desc[mode]}<br>'
641
+ f'<b>Image</b> {img_info}<br>'
642
+ f'<b>Audio</b> {aud_info}<br><br>'
643
  f'<b>Metric</b> MSCI = 0.45 &times; s<sub>t,i</sub> + 0.45 &times; s<sub>t,a</sub><br><br>'
644
  f'<b>Models</b><br>'
645
+ f'CLIP ViT-B/32 (coherence eval)<br>'
646
+ f'CLAP HTSAT-unfused (coherence eval)'
647
  f'</div>', unsafe_allow_html=True)
648
 
649
  # Prompt input
 
662
  mlbl = {"direct": "Direct", "planner": "Planner", "council": "Council", "extended_prompt": "Extended"}[mode]
663
  mcls = "chip-amber" if mode != "direct" else "chip-purple"
664
  mdot = "chip-dot-amber" if mode != "direct" else "chip-dot-purple"
665
+ if backend == "generative":
666
+ bchip = '<span class="chip chip-pink"><span class="chip-dot chip-dot-pink"></span>Generative</span>'
667
+ else:
668
+ bchip = '<span class="chip chip-purple"><span class="chip-dot chip-dot-purple"></span>Retrieval</span>'
669
  st.markdown(
670
  f'<div class="chip-row">'
671
+ f'{bchip}'
672
  f'<span class="chip {mcls}"><span class="chip-dot {mdot}"></span>{mlbl}</span>'
673
  f'<span class="chip chip-green"><span class="chip-dot chip-dot-green"></span>CLIP + CLAP</span>'
674
  f'</div>', unsafe_allow_html=True)
 
684
  return
685
 
686
  if go and prompt.strip():
687
+ st.session_state["last_result"] = run_pipeline(prompt.strip(), mode, backend)
688
 
689
  if "last_result" in st.session_state:
690
  show_results(st.session_state["last_result"])
 
694
  # Pipeline
695
  # ---------------------------------------------------------------------------
696
 
697
+ def run_pipeline(prompt: str, mode: str, backend: str = "generative") -> dict:
698
+ R: dict = {"mode": mode, "backend": backend}
699
  t_all = time.time()
700
 
701
  # 1) Text + Planning
 
718
  ip = R["text"].get("image_prompt", prompt)
719
  ap = R["text"].get("audio_prompt", prompt)
720
 
721
+ # 2) Image
722
+ img_label = "Generating image (SDXL)..." if backend == "generative" else "Retrieving image..."
723
+ with st.status(img_label, expanded=True) as s:
724
  t0 = time.time()
725
  try:
726
+ if backend == "generative":
727
+ R["image"] = generate_image(ip)
728
+ else:
729
+ R["image"] = retrieve_image(ip)
730
  R["t_img"] = time.time() - t0
731
+ img_backend = R["image"].get("backend", "unknown")
732
+ model = R["image"].get("model", "")
733
+ if img_backend == "generative":
734
+ lbl = f"Image generated via {model} ({R['t_img']:.1f}s)"
735
+ else:
736
+ sim = R["image"].get("similarity", 0)
737
+ failed = R["image"].get("failed", False)
738
+ lbl = f"Image retrieved (sim={sim:.3f}, {R['t_img']:.1f}s)"
739
+ if failed:
740
+ lbl += " \u2014 below threshold"
741
+ s.update(label=lbl, state="complete")
742
  except Exception as e:
743
  s.update(label=f"Image failed: {e}", state="error")
744
  R["image"] = None
745
  R["t_img"] = time.time() - t0
746
 
747
+ # 3) Audio
748
+ aud_label = "Generating audio..." if backend == "generative" else "Retrieving audio..."
749
+ with st.status(aud_label, expanded=True) as s:
750
  t0 = time.time()
751
  try:
752
+ if backend == "generative":
753
+ R["audio"] = generate_audio(ap)
754
+ else:
755
+ R["audio"] = retrieve_audio(ap)
756
  R["t_aud"] = time.time() - t0
757
+ aud_backend = R["audio"].get("backend", "unknown")
758
+ model = R["audio"].get("model", "")
759
+ if aud_backend == "generative":
760
+ lbl = f"Audio generated via {model} ({R['t_aud']:.1f}s)"
761
+ else:
762
+ sim = R["audio"].get("similarity", 0)
763
+ failed = R["audio"].get("failed", False)
764
+ lbl = f"Audio retrieved (sim={sim:.3f}, {R['t_aud']:.1f}s)"
765
+ if failed:
766
+ lbl += " \u2014 below threshold"
767
+ s.update(label=lbl, state="complete")
768
  except Exception as e:
769
  s.update(label=f"Audio failed: {e}", state="error")
770
  R["audio"] = None
 
834
  st.markdown(f'<div class="text-card">{txt}</div>', unsafe_allow_html=True)
835
 
836
  with ci:
837
+ st.markdown('<div class="sec-label">Generated Image</div>', unsafe_allow_html=True)
838
  ii = R.get("image")
839
  if ii and ii.get("path"):
840
  ip = Path(ii["path"])
841
+ backend = ii.get("backend", "unknown")
 
842
 
843
+ if backend == "retrieval" and ii.get("failed", False):
844
+ sim = ii.get("similarity", 0)
845
  st.markdown(
846
+ f'<div class="warn-banner"><b>Retrieval fallback</b> '
847
+ f'(sim={sim:.3f}) \u2014 generation unavailable.</div>',
 
848
  unsafe_allow_html=True)
849
 
850
  if ip.exists():
851
  st.image(str(ip), use_container_width=True)
852
+ model = ii.get("model", "")
853
+ if backend == "generative":
854
+ st.caption(f"Generated via **{model}**")
855
+ else:
856
+ sim = ii.get("similarity", 0)
857
+ dom = ii.get("domain", "other")
858
+ ic = DOMAIN_ICONS.get(dom, "\U0001f4cd")
859
+ st.caption(f"{ic} {dom} \u00b7 sim **{sim:.3f}** \u00b7 Retrieved")
860
  else:
861
  st.info("No image.")
862
 
863
  with ca:
864
+ st.markdown('<div class="sec-label">Generated Audio</div>', unsafe_allow_html=True)
865
  ai = R.get("audio")
866
  if ai and ai.get("path"):
867
  ap = Path(ai["path"])
868
+ backend = ai.get("backend", "unknown")
 
869
 
870
+ if backend == "retrieval" and ai.get("failed", False):
871
+ sim = ai.get("similarity", 0)
872
  st.markdown(
873
+ f'<div class="warn-banner"><b>Retrieval fallback</b> '
874
+ f'(sim={sim:.3f}) \u2014 generation unavailable.</div>',
875
  unsafe_allow_html=True)
876
 
877
  if ap.exists():
878
  st.audio(str(ap))
879
+ model = ai.get("model", "")
880
+ if backend == "generative":
881
+ st.caption(f"Generated via **{model}**")
882
+ else:
883
+ sim = ai.get("similarity", 0)
884
+ st.caption(f"sim **{sim:.3f}** \u00b7 Retrieved")
885
  else:
886
  st.info("No audio.")
887
 
 
919
  else:
920
  st.write(f"Planning ({mode}) did not produce a valid plan. Fell back to direct mode.")
921
 
922
+ with st.expander("Generation Details"):
923
  r1, r2 = st.columns(2)
924
  with r1:
925
  ii = R.get("image")
926
+ if ii:
927
+ backend = ii.get("backend", "unknown")
928
+ model = ii.get("model", "")
929
+ if backend == "generative":
930
+ st.markdown(f"**Image** generated via **{model}**")
931
+ st.markdown(f"Prompt: *{R.get('text', {}).get('image_prompt', '')}*")
932
+ elif ii.get("top_5"):
933
+ st.markdown("**Image** (retrieval fallback)")
934
+ bars = "".join(sim_bar_html(n, s) for n, s in ii["top_5"])
935
+ st.markdown(bars, unsafe_allow_html=True)
936
  else:
937
  st.write("No image data.")
938
  with r2:
939
  ai = R.get("audio")
940
+ if ai:
941
+ backend = ai.get("backend", "unknown")
942
+ model = ai.get("model", "")
943
+ if backend == "generative":
944
+ st.markdown(f"**Audio** generated via **{model}**")
945
+ st.markdown(f"Prompt: *{R.get('text', {}).get('audio_prompt', '')}*")
946
+ elif ai.get("top_5"):
947
+ st.markdown("**Audio** (retrieval fallback)")
948
+ bars = "".join(sim_bar_html(n, s) for n, s in ai["top_5"])
949
+ st.markdown(bars, unsafe_allow_html=True)
950
  else:
951
  st.write("No audio data.")
952
 
src/embeddings/audio_embedder.py CHANGED
@@ -56,11 +56,19 @@ class AudioEmbedder:
56
  def embed(self, audio_path: str) -> np.ndarray:
57
  waveform, _ = librosa.load(audio_path, sr=self.target_sr, mono=True)
58
 
59
- inputs = self.processor(
60
- audios=waveform,
61
- sampling_rate=self.target_sr,
62
- return_tensors="pt",
63
- ).to(self.device)
 
 
 
 
 
 
 
 
64
 
65
  outputs = self.model.get_audio_features(**inputs)
66
  emb = self._extract_features(outputs, "audio_projection")
 
56
  def embed(self, audio_path: str) -> np.ndarray:
57
  waveform, _ = librosa.load(audio_path, sr=self.target_sr, mono=True)
58
 
59
+ # Use 'audio' (newer transformers) with fallback to 'audios' (older)
60
+ try:
61
+ inputs = self.processor(
62
+ audio=waveform,
63
+ sampling_rate=self.target_sr,
64
+ return_tensors="pt",
65
+ ).to(self.device)
66
+ except TypeError:
67
+ inputs = self.processor(
68
+ audios=waveform,
69
+ sampling_rate=self.target_sr,
70
+ return_tensors="pt",
71
+ ).to(self.device)
72
 
73
  outputs = self.model.get_audio_features(**inputs)
74
  emb = self._extract_features(outputs, "audio_projection")