Upload folder using huggingface_hub
- app.py +163 -51
- src/embeddings/audio_embedder.py +13 -5
app.py
CHANGED
@@ -5,7 +5,7 @@ Live demonstration of multimodal generation + coherence evaluation.
Enter a scene description and the system produces coherent text, image,
and audio with real-time MSCI scoring.

-Pipeline: HF Inference API (text + planning
Planning modes: direct, planner, council (3-way), extended_prompt (3x tokens)
"""

@@ -15,6 +15,7 @@ import json
import logging
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict, Optional
@@ -406,6 +407,13 @@ def plan_extended(prompt: str) -> Optional[Any]:
# Generation / retrieval functions
# ---------------------------------------------------------------------------

def gen_text(prompt: str, mode: str) -> dict:
    """Generate text and optional plan using HF Inference API."""
    # Step 1: Plan (if not direct mode)
@@ -457,6 +465,50 @@ def gen_text(prompt: str, mode: str) -> dict:
    }


def retrieve_image(prompt: str) -> dict:
    r = load_image_retriever().retrieve(prompt)
    return {
@@ -540,6 +592,15 @@ def main():
    with st.sidebar:
        st.markdown("#### Configuration")

        mode = st.selectbox(
            "Planning Mode",
            ["direct", "planner", "council", "extended_prompt"],
@@ -567,16 +628,22 @@ def main():
            "council": "3 LLM calls merged for richer planning",
            "extended_prompt": "Single LLM call with 3x token budget",
        }
        st.markdown(
            f'<div class="sidebar-info">'
            f'<b>Text</b> HF Inference API<br>'
            f'<b>Planning</b> {mode_desc[mode]}<br>'
-f'<b>Image</b>
-f'<b>Audio</b>
            f'<b>Metric</b> MSCI = 0.45 × s<sub>t,i</sub> + 0.45 × s<sub>t,a</sub><br><br>'
            f'<b>Models</b><br>'
-f'CLIP ViT-B/32 (
-f'CLAP HTSAT-unfused (
            f'</div>', unsafe_allow_html=True)

        # Prompt input
@@ -595,9 +662,13 @@ def main():
    mlbl = {"direct": "Direct", "planner": "Planner", "council": "Council", "extended_prompt": "Extended"}[mode]
    mcls = "chip-amber" if mode != "direct" else "chip-purple"
    mdot = "chip-dot-amber" if mode != "direct" else "chip-dot-purple"
    st.markdown(
        f'<div class="chip-row">'
-f'
        f'<span class="chip {mcls}"><span class="chip-dot {mdot}"></span>{mlbl}</span>'
        f'<span class="chip chip-green"><span class="chip-dot chip-dot-green"></span>CLIP + CLAP</span>'
        f'</div>', unsafe_allow_html=True)
@@ -613,7 +684,7 @@ def main():
        return

    if go and prompt.strip():
-st.session_state["last_result"] = run_pipeline(prompt.strip(), mode)

    if "last_result" in st.session_state:
        show_results(st.session_state["last_result"])
@@ -623,8 +694,8 @@ def main():
# Pipeline
# ---------------------------------------------------------------------------

-def run_pipeline(prompt: str, mode: str) -> dict:
-    R: dict = {"mode": mode}
    t_all = time.time()

    # 1) Text + Planning
@@ -647,33 +718,53 @@ def run_pipeline(prompt: str, mode: str) -> dict:
    ip = R["text"].get("image_prompt", prompt)
    ap = R["text"].get("audio_prompt", prompt)

-# 2) Image
-
        t0 = time.time()
        try:
-
            R["t_img"] = time.time() - t0
-
-
-if
-lbl
-
        except Exception as e:
            s.update(label=f"Image failed: {e}", state="error")
            R["image"] = None
            R["t_img"] = time.time() - t0

-# 3) Audio
-
        t0 = time.time()
        try:
-
            R["t_aud"] = time.time() - t0
-
-
-if
-lbl
-
        except Exception as e:
            s.update(label=f"Audio failed: {e}", state="error")
            R["audio"] = None
@@ -743,45 +834,54 @@ def show_results(R: dict):
        st.markdown(f'<div class="text-card">{txt}</div>', unsafe_allow_html=True)

    with ci:
-st.markdown('<div class="sec-label">Image</div>', unsafe_allow_html=True)
        ii = R.get("image")
        if ii and ii.get("path"):
            ip = Path(ii["path"])
-
-sim = ii.get("similarity")

-if failed:
                st.markdown(
-f'<div class="warn-banner"><b>
-f'(sim={sim:.3f}
-f'\u2014 best match from index.</div>',
                    unsafe_allow_html=True)

            if ip.exists():
                st.image(str(ip), use_container_width=True)
-
-
-
        else:
            st.info("No image.")

    with ca:
-st.markdown('<div class="sec-label">Audio</div>', unsafe_allow_html=True)
        ai = R.get("audio")
        if ai and ai.get("path"):
            ap = Path(ai["path"])
-
-failed = ai.get("failed", False)

-if failed:
                st.markdown(
-f'<div class="warn-banner"><b>
-f'(sim={sim:.3f}
                    unsafe_allow_html=True)

            if ap.exists():
                st.audio(str(ap))
-
        else:
            st.info("No audio.")
@@ -819,22 +919,34 @@ def show_results(R: dict):
    else:
        st.write(f"Planning ({mode}) did not produce a valid plan. Fell back to direct mode.")

-with st.expander("
        r1, r2 = st.columns(2)
        with r1:
            ii = R.get("image")
-if ii
-
-
-
            else:
                st.write("No image data.")
        with r2:
            ai = R.get("audio")
-if ai
-
-
-
            else:
                st.write("No audio data.")

@@ -5,7 +5,7 @@ Live demonstration of multimodal generation + coherence evaluation.
Enter a scene description and the system produces coherent text, image,
and audio with real-time MSCI scoring.

+Pipeline: HF Inference API (text + planning + image + audio) with CLIP/CLAP retrieval fallback
Planning modes: direct, planner, council (3-way), extended_prompt (3x tokens)
"""

@@ -15,6 +15,7 @@ import json
import logging
import os
import sys
+import tempfile
import time
from pathlib import Path
from typing import Any, Dict, Optional

@@ -406,6 +407,13 @@ def plan_extended(prompt: str) -> Optional[Any]:
# Generation / retrieval functions
# ---------------------------------------------------------------------------

+# HF Inference API model IDs
+IMAGE_GEN_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
+AUDIO_GEN_MODELS = [
+    "cvssp/audioldm2",
+    "facebook/musicgen-small",
+]
+
def gen_text(prompt: str, mode: str) -> dict:
    """Generate text and optional plan using HF Inference API."""
    # Step 1: Plan (if not direct mode)

@@ -457,6 +465,50 @@ def gen_text(prompt: str, mode: str) -> dict:
    }


+def generate_image(prompt: str) -> dict:
+    """Generate image via HF Inference API (SDXL), fallback to retrieval."""
+    client = get_inference_client()
+    try:
+        image = client.text_to_image(prompt, model=IMAGE_GEN_MODEL)
+        tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir="/tmp")
+        image.save(tmp.name)
+        return {
+            "path": tmp.name, "backend": "generative",
+            "model": "SDXL", "failed": False,
+        }
+    except Exception as e:
+        logger.warning("Image generation failed: %s — falling back to retrieval", e)
+        return retrieve_image(prompt)
+
+
+def generate_audio(prompt: str) -> dict:
+    """Generate audio via HF Inference API, fallback to retrieval."""
+    client = get_inference_client()
+    for model_id in AUDIO_GEN_MODELS:
+        try:
+            audio_bytes = client.text_to_audio(prompt, model=model_id)
+            suffix = ".flac" if "musicgen" in model_id else ".wav"
+            tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False, dir="/tmp")
+            if isinstance(audio_bytes, bytes):
+                tmp.write(audio_bytes)
+                tmp.flush()
+            else:
+                # Some API versions return object with .read() or similar
+                tmp.write(bytes(audio_bytes))
+                tmp.flush()
+            model_name = model_id.split("/")[-1]
+            return {
+                "path": tmp.name, "backend": "generative",
+                "model": model_name, "failed": False,
+            }
+        except Exception as e:
+            logger.warning("Audio gen with %s failed: %s", model_id, e)
+            continue
+    # All generative models failed — fall back to retrieval
+    logger.warning("All audio generation models failed — falling back to retrieval")
+    return retrieve_audio(prompt)
+
+
def retrieve_image(prompt: str) -> dict:
    r = load_image_retriever().retrieve(prompt)
    return {
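A note on the contract these two helpers establish: callers only rely on the returned dict keys ("path", "backend", "model", "failed", and, on the retrieval path, "similarity"). A hypothetical call, with a made-up prompt, might look like:

# Sketch only - relies solely on the keys set by generate_audio()/retrieve_audio() above.
result = generate_audio("rain on a tin roof with distant thunder")
if result["backend"] == "generative":
    print("generated by", result["model"], "->", result["path"])
else:
    print("retrieval fallback, sim =", result.get("similarity"))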
@@ -540,6 +592,15 @@ def main():
    with st.sidebar:
        st.markdown("#### Configuration")

+        backend = st.selectbox(
+            "Backend",
+            ["generative", "retrieval"],
+            format_func=lambda x: {
+                "generative": "Generative (SDXL + AudioLDM2)",
+                "retrieval": "Retrieval (CLIP + CLAP index)",
+            }[x],
+        )
+
        mode = st.selectbox(
            "Planning Mode",
            ["direct", "planner", "council", "extended_prompt"],

@@ -567,16 +628,22 @@ def main():
            "council": "3 LLM calls merged for richer planning",
            "extended_prompt": "Single LLM call with 3x token budget",
        }
+        if backend == "generative":
+            img_info = "SDXL via HF API"
+            aud_info = "AudioLDM2 / MusicGen via HF API"
+        else:
+            img_info = "CLIP retrieval (57 images)"
+            aud_info = "CLAP retrieval (104 clips)"
        st.markdown(
            f'<div class="sidebar-info">'
            f'<b>Text</b> HF Inference API<br>'
            f'<b>Planning</b> {mode_desc[mode]}<br>'
+            f'<b>Image</b> {img_info}<br>'
+            f'<b>Audio</b> {aud_info}<br><br>'
            f'<b>Metric</b> MSCI = 0.45 × s<sub>t,i</sub> + 0.45 × s<sub>t,a</sub><br><br>'
            f'<b>Models</b><br>'
+            f'CLIP ViT-B/32 (coherence eval)<br>'
+            f'CLAP HTSAT-unfused (coherence eval)'
            f'</div>', unsafe_allow_html=True)

        # Prompt input

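For readers skimming the sidebar markup above, the displayed metric is a weighted sum of the two cross-modal similarities. A minimal sketch (not code from this commit, assuming s_ti and s_ta are the CLIP text-image and CLAP text-audio similarities):

# Illustration only - the app's actual MSCI computation lives elsewhere in the repo.
def msci(s_ti: float, s_ta: float) -> float:
    # MSCI = 0.45 * s(text, image) + 0.45 * s(text, audio), as shown in the sidebar
    return 0.45 * s_ti + 0.45 * s_ta

msci(0.31, 0.27)  # -> 0.261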
@@ -595,9 +662,13 @@ def main():
    mlbl = {"direct": "Direct", "planner": "Planner", "council": "Council", "extended_prompt": "Extended"}[mode]
    mcls = "chip-amber" if mode != "direct" else "chip-purple"
    mdot = "chip-dot-amber" if mode != "direct" else "chip-dot-purple"
+    if backend == "generative":
+        bchip = '<span class="chip chip-pink"><span class="chip-dot chip-dot-pink"></span>Generative</span>'
+    else:
+        bchip = '<span class="chip chip-purple"><span class="chip-dot chip-dot-purple"></span>Retrieval</span>'
    st.markdown(
        f'<div class="chip-row">'
+        f'{bchip}'
        f'<span class="chip {mcls}"><span class="chip-dot {mdot}"></span>{mlbl}</span>'
        f'<span class="chip chip-green"><span class="chip-dot chip-dot-green"></span>CLIP + CLAP</span>'
        f'</div>', unsafe_allow_html=True)

@@ -613,7 +684,7 @@ def main():
        return

    if go and prompt.strip():
+        st.session_state["last_result"] = run_pipeline(prompt.strip(), mode, backend)

    if "last_result" in st.session_state:
        show_results(st.session_state["last_result"])

@@ -623,8 +694,8 @@ def main():
# Pipeline
# ---------------------------------------------------------------------------

+def run_pipeline(prompt: str, mode: str, backend: str = "generative") -> dict:
+    R: dict = {"mode": mode, "backend": backend}
    t_all = time.time()

    # 1) Text + Planning

@@ -647,33 +718,53 @@ def run_pipeline(prompt: str, mode: str) -> dict:
    ip = R["text"].get("image_prompt", prompt)
    ap = R["text"].get("audio_prompt", prompt)

+    # 2) Image
+    img_label = "Generating image (SDXL)..." if backend == "generative" else "Retrieving image..."
+    with st.status(img_label, expanded=True) as s:
        t0 = time.time()
        try:
+            if backend == "generative":
+                R["image"] = generate_image(ip)
+            else:
+                R["image"] = retrieve_image(ip)
            R["t_img"] = time.time() - t0
+            img_backend = R["image"].get("backend", "unknown")
+            model = R["image"].get("model", "")
+            if img_backend == "generative":
+                lbl = f"Image generated via {model} ({R['t_img']:.1f}s)"
+            else:
+                sim = R["image"].get("similarity", 0)
+                failed = R["image"].get("failed", False)
+                lbl = f"Image retrieved (sim={sim:.3f}, {R['t_img']:.1f}s)"
+                if failed:
+                    lbl += " \u2014 below threshold"
+            s.update(label=lbl, state="complete")
        except Exception as e:
            s.update(label=f"Image failed: {e}", state="error")
            R["image"] = None
            R["t_img"] = time.time() - t0

+    # 3) Audio
+    aud_label = "Generating audio..." if backend == "generative" else "Retrieving audio..."
+    with st.status(aud_label, expanded=True) as s:
        t0 = time.time()
        try:
+            if backend == "generative":
+                R["audio"] = generate_audio(ap)
+            else:
+                R["audio"] = retrieve_audio(ap)
            R["t_aud"] = time.time() - t0
+            aud_backend = R["audio"].get("backend", "unknown")
+            model = R["audio"].get("model", "")
+            if aud_backend == "generative":
+                lbl = f"Audio generated via {model} ({R['t_aud']:.1f}s)"
+            else:
+                sim = R["audio"].get("similarity", 0)
+                failed = R["audio"].get("failed", False)
+                lbl = f"Audio retrieved (sim={sim:.3f}, {R['t_aud']:.1f}s)"
+                if failed:
+                    lbl += " \u2014 below threshold"
+            s.update(label=lbl, state="complete")
        except Exception as e:
            s.update(label=f"Audio failed: {e}", state="error")
            R["audio"] = None

@@ -743,45 +834,54 @@ def show_results(R: dict):
        st.markdown(f'<div class="text-card">{txt}</div>', unsafe_allow_html=True)

    with ci:
+        st.markdown('<div class="sec-label">Generated Image</div>', unsafe_allow_html=True)
        ii = R.get("image")
        if ii and ii.get("path"):
            ip = Path(ii["path"])
+            backend = ii.get("backend", "unknown")

+            if backend == "retrieval" and ii.get("failed", False):
+                sim = ii.get("similarity", 0)
                st.markdown(
+                    f'<div class="warn-banner"><b>Retrieval fallback</b> '
+                    f'(sim={sim:.3f}) \u2014 generation unavailable.</div>',
                    unsafe_allow_html=True)

            if ip.exists():
                st.image(str(ip), use_container_width=True)
+                model = ii.get("model", "")
+                if backend == "generative":
+                    st.caption(f"Generated via **{model}**")
+                else:
+                    sim = ii.get("similarity", 0)
+                    dom = ii.get("domain", "other")
+                    ic = DOMAIN_ICONS.get(dom, "\U0001f4cd")
+                    st.caption(f"{ic} {dom} \u00b7 sim **{sim:.3f}** \u00b7 Retrieved")
        else:
            st.info("No image.")

    with ca:
+        st.markdown('<div class="sec-label">Generated Audio</div>', unsafe_allow_html=True)
        ai = R.get("audio")
        if ai and ai.get("path"):
            ap = Path(ai["path"])
+            backend = ai.get("backend", "unknown")

+            if backend == "retrieval" and ai.get("failed", False):
+                sim = ai.get("similarity", 0)
                st.markdown(
+                    f'<div class="warn-banner"><b>Retrieval fallback</b> '
+                    f'(sim={sim:.3f}) \u2014 generation unavailable.</div>',
                    unsafe_allow_html=True)

            if ap.exists():
                st.audio(str(ap))
+                model = ai.get("model", "")
+                if backend == "generative":
+                    st.caption(f"Generated via **{model}**")
+                else:
+                    sim = ai.get("similarity", 0)
+                    st.caption(f"sim **{sim:.3f}** \u00b7 Retrieved")
        else:
            st.info("No audio.")

@@ -819,22 +919,34 @@ def show_results(R: dict):
    else:
        st.write(f"Planning ({mode}) did not produce a valid plan. Fell back to direct mode.")

+    with st.expander("Generation Details"):
        r1, r2 = st.columns(2)
        with r1:
            ii = R.get("image")
+            if ii:
+                backend = ii.get("backend", "unknown")
+                model = ii.get("model", "")
+                if backend == "generative":
+                    st.markdown(f"**Image** generated via **{model}**")
+                    st.markdown(f"Prompt: *{R.get('text', {}).get('image_prompt', '')}*")
+                elif ii.get("top_5"):
+                    st.markdown("**Image** (retrieval fallback)")
+                    bars = "".join(sim_bar_html(n, s) for n, s in ii["top_5"])
+                    st.markdown(bars, unsafe_allow_html=True)
            else:
                st.write("No image data.")
        with r2:
            ai = R.get("audio")
+            if ai:
+                backend = ai.get("backend", "unknown")
+                model = ai.get("model", "")
+                if backend == "generative":
+                    st.markdown(f"**Audio** generated via **{model}**")
+                    st.markdown(f"Prompt: *{R.get('text', {}).get('audio_prompt', '')}*")
+                elif ai.get("top_5"):
+                    st.markdown("**Audio** (retrieval fallback)")
+                    bars = "".join(sim_bar_html(n, s) for n, s in ai["top_5"])
+                    st.markdown(bars, unsafe_allow_html=True)
            else:
                st.write("No audio data.")

src/embeddings/audio_embedder.py
CHANGED
@@ -56,11 +56,19 @@ class AudioEmbedder:
def embed(self, audio_path: str) -> np.ndarray:
waveform, _ = librosa.load(audio_path, sr=self.target_sr, mono=True)

-
-
-
-
-

outputs = self.model.get_audio_features(**inputs)
emb = self._extract_features(outputs, "audio_projection")
@@ -56,11 +56,19 @@ class AudioEmbedder:
    def embed(self, audio_path: str) -> np.ndarray:
        waveform, _ = librosa.load(audio_path, sr=self.target_sr, mono=True)

+        # Use 'audio' (newer transformers) with fallback to 'audios' (older)
+        try:
+            inputs = self.processor(
+                audio=waveform,
+                sampling_rate=self.target_sr,
+                return_tensors="pt",
+            ).to(self.device)
+        except TypeError:
+            inputs = self.processor(
+                audios=waveform,
+                sampling_rate=self.target_sr,
+                return_tensors="pt",
+            ).to(self.device)

        outputs = self.model.get_audio_features(**inputs)
        emb = self._extract_features(outputs, "audio_projection")
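As a design note, the same keyword selection could be resolved once instead of catching the exception on every call. A hedged alternative sketch (an assumption, not the committed code; it only works if the processor's __call__ declares an explicit audio/audios parameter rather than swallowing it via **kwargs):

# Hypothetical alternative - resolve the keyword once from the processor signature.
import inspect

def audio_kwarg(processor) -> str:
    params = inspect.signature(processor.__call__).parameters
    return "audio" if "audio" in params else "audios"

# inputs = self.processor(**{audio_kwarg(self.processor): waveform},
#                         sampling_rate=self.target_sr, return_tensors="pt").to(self.device)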