| """ |
| Audio retrieval generator — retrieves best-matching audio from indexed pool. |
| |
| Mirrors the image retrieval approach: |
| - Uses CLAP text encoder to embed query |
| - Compares against CLAP audio embeddings in the audio index |
| - Returns best matching real audio file |
| |
| If the audio index is missing or empty, this module raises IndexError_; any fallback to synthetic ambient audio has to be handled by the caller. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import List, Optional, Tuple |
|
|
| import numpy as np |
|
|
| from src.embeddings.aligned_embeddings import AlignedEmbedder |
| from src.embeddings.similarity import cosine_similarity |
| from src.exceptions import IndexError_ |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
@dataclass
class AudioRetrievalResult:
    """
    Outcome of a single audio retrieval, with metadata for experiment bundles.

    Field order is part of the interface (positional construction).
    """

    # Path of the selected candidate, exactly as stored in the index ids.
    audio_path: str
    # Similarity score of the selected candidate against the query.
    similarity: float
    # True when the selected candidate did not reach the similarity threshold.
    retrieval_failed: bool
    # Number of index entries that were scored.
    candidates_considered: int
    # Number of scored entries at or above the threshold.
    candidates_above_threshold: int
    # (file name, similarity) pairs for the highest-scoring candidates,
    # best first; at most five entries.
    top_5: List[Tuple[str, float]]
|
|
|
|
class AudioRetrievalGenerator:
    """
    Audio retrieval using CLAP shared space.

    Query: CLAP text embedding of the prompt
    Index: CLAP audio embeddings of real audio files

    This ensures st_a (text-audio similarity) is meaningful because
    both query and candidates live in CLAP's shared space.

    Attributes:
        index_path: Location of the ``.npz`` index, as a ``Path``.
        min_similarity: Default acceptance threshold used by ``retrieve``.
        top_k: Stored from the constructor. NOTE(review): nothing in this
            class reads it; ``retrieve`` always reports five candidates.
        ids: Entries of the index's ``"ids"`` array, used as audio paths.
        embs: The index's ``"embs"`` array converted to ``float32``.
        embedder: ``AlignedEmbedder`` used to embed the query text.
    """

    def __init__(
        self,
        index_path: str = "data/embeddings/audio_index.npz",
        min_similarity: float = 0.10,
        top_k: int = 5,
    ):
        """
        Load the audio index and create the text embedder.

        Args:
            index_path: Path to the ``.npz`` file holding ``"ids"`` and ``"embs"``.
            min_similarity: Default threshold a candidate must reach to count
                as a successful retrieval.
            top_k: Kept on the instance; currently unused by ``retrieve``.

        Raises:
            IndexError_: If the index file does not exist or contains no ids.
        """
        self.index_path = Path(index_path)
        self.min_similarity = min_similarity
        self.top_k = top_k

        if not self.index_path.exists():
            raise IndexError_(
                f"Missing audio index at {self.index_path}. "
                "Run: python scripts/build_embedding_indexes.py",
                index_path=str(self.index_path),
            )

        self.ids, self.embs = self._load_index(self.index_path)

        if not self.ids:
            raise IndexError_(
                "Audio index is empty. "
                "Add audio files and run: python scripts/build_embedding_indexes.py",
                index_path=str(self.index_path),
            )

        self.embedder = AlignedEmbedder(target_dim=512)

    @staticmethod
    def _load_index(index_path) -> Tuple[List, np.ndarray]:
        """Read ``ids`` (as a list) and ``embs`` (as float32) from an ``.npz`` index."""
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
        # which can execute arbitrary code if the index file is untrusted.
        # Kept because the on-disk dtype of "ids" is not visible from here;
        # confirm and drop it if ids are stored as plain string arrays.
        #
        # The archive is opened as a context manager so its file handle is
        # released as soon as both arrays have been materialised in memory.
        with np.load(index_path, allow_pickle=True) as data:
            ids = data["ids"].tolist()
            embs = data["embs"].astype("float32")
        return ids, embs

    def retrieve(
        self,
        query_text: str,
        min_similarity: Optional[float] = None,
    ) -> AudioRetrievalResult:
        """
        Retrieve best matching audio for a text query.

        Uses CLAP text encoder for the query (same space as indexed audio embeddings).

        Args:
            query_text: Prompt to embed and match against the index.
            min_similarity: Per-call threshold; falls back to
                ``self.min_similarity`` when ``None``.

        Returns:
            An ``AudioRetrievalResult`` for the highest-scoring candidate that
            reaches the threshold. If none does, the overall highest-scoring
            candidate is still returned, with ``retrieval_failed=True`` and
            ``candidates_above_threshold=0``.
        """
        threshold = self.min_similarity if min_similarity is None else min_similarity

        query_emb = self.embedder.embed_text_for_audio(query_text)

        # Score every indexed entry, best first. The sort is stable, so ties
        # keep their index order.
        scored = sorted(
            (
                (audio_path, cosine_similarity(query_emb, audio_emb))
                for audio_path, audio_emb in zip(self.ids, self.embs)
            ),
            key=lambda pair: pair[1],
            reverse=True,
        )

        # Only file names (not full paths) are reported for the top candidates.
        # NOTE(review): hard-coded to five entries; self.top_k is not consulted.
        top_5 = [(Path(path).name, sim) for path, sim in scored[:5]]

        above_threshold = [pair for pair in scored if pair[1] >= threshold]

        if above_threshold:
            best_path, best_sim = above_threshold[0]
            logger.debug("Audio retrieval: %s (sim=%.4f, %d/%d above threshold)",
                         Path(best_path).name, best_sim, len(above_threshold), len(scored))
        else:
            best_path, best_sim = scored[0]
            logger.warning("Audio retrieval below threshold: %s (sim=%.4f < %.2f)",
                           Path(best_path).name, best_sim, threshold)

        return AudioRetrievalResult(
            audio_path=best_path,
            similarity=best_sim,
            # Failure is "no candidate passed", not a re-comparison of the best
            # score: a NaN score compares False both ways and would otherwise
            # be reported as a success with zero candidates above threshold.
            retrieval_failed=not above_threshold,
            candidates_considered=len(scored),
            candidates_above_threshold=len(above_threshold),
            top_5=top_5,
        )
|
|
|
|
def retrieve_audio_with_metadata(
    prompt: str,
    index_path: str = "data/embeddings/audio_index.npz",
    min_similarity: float = 0.10,
) -> AudioRetrievalResult:
    """
    One-shot retrieval: build a retriever for ``index_path`` and query it once.

    Intended for experiment pipelines that need the full retrieval metadata.
    A new ``AudioRetrievalGenerator`` is constructed on every call, and any
    exception it raises propagates to the caller.

    Args:
        prompt: Text query to match against the audio index.
        index_path: Path to the ``.npz`` audio index.
        min_similarity: Threshold passed both to the retriever's constructor
            and to the ``retrieve`` call.

    Returns:
        The ``AudioRetrievalResult`` produced by ``retrieve``.
    """
    retriever_options = {
        "index_path": index_path,
        "min_similarity": min_similarity,
    }
    retriever = AudioRetrievalGenerator(**retriever_options)
    result = retriever.retrieve(prompt, min_similarity=min_similarity)
    return result
|
|