"""
Audio retrieval generator — retrieves best-matching audio from indexed pool.
Mirrors the image retrieval approach:
- Uses CLAP text encoder to embed query
- Compares against CLAP audio embeddings in the audio index
- Returns best matching real audio file
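Raises IndexError_ if the audio index is missing or empty; no synthetic
fallback is performed in this module. If no candidate clears the similarity
threshold, the best candidate is still returned but flagged as a failed
retrieval.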
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
from src.embeddings.aligned_embeddings import AlignedEmbedder
from src.embeddings.similarity import cosine_similarity
from src.exceptions import IndexError_
logger = logging.getLogger(__name__)
@dataclass
class AudioRetrievalResult:
    """Outcome of one audio retrieval, recorded in experiment bundles.

    Fields appear in constructor (positional) order.
    """

    # Path of the chosen audio file, exactly as stored in the index ``ids``.
    audio_path: str
    # Similarity between the query embedding and the chosen file's embedding.
    similarity: float
    # True when the chosen file did not clear the similarity threshold.
    retrieval_failed: bool
    # Number of indexed candidates that were scored.
    candidates_considered: int
    # Number of scored candidates whose similarity met the threshold.
    candidates_above_threshold: int
    # Highest-scoring (file name, similarity) pairs, best first.
    top_5: list[tuple[str, float]]
class AudioRetrievalGenerator:
    """
    Audio retrieval using CLAP shared space.

    Query: CLAP text embedding of the prompt
    Index: CLAP audio embeddings of real audio files

    This ensures st_a (text-audio similarity) is meaningful because
    both query and candidates live in CLAP's shared space.

    Args:
        index_path: ``.npz`` file containing an ``ids`` array (audio file
            paths) and an ``embs`` array with one embedding row per id.
        min_similarity: Default similarity threshold. A retrieval whose best
            candidate scores below it is flagged as failed. Can be overridden
            per call in :meth:`retrieve`.
        top_k: Number of best candidates reported in
            ``AudioRetrievalResult.top_5`` (the field keeps its historical
            name even when ``top_k != 5``).

    Raises:
        IndexError_: If the index file is missing, contains no ids, or its
            ``ids`` and ``embs`` arrays have different lengths.
    """

    def __init__(
        self,
        index_path: str = "data/embeddings/audio_index.npz",
        min_similarity: float = 0.10,
        top_k: int = 5,
    ):
        self.index_path = Path(index_path)
        self.min_similarity = min_similarity
        self.top_k = top_k
        if not self.index_path.exists():
            raise IndexError_(
                f"Missing audio index at {self.index_path}. "
                "Run: python scripts/build_embedding_indexes.py",
                index_path=str(self.index_path),
            )
        # allow_pickle is presumably needed because ``ids`` may be stored as
        # an object array. NOTE(review): pickled data can execute code on
        # load, so only point this at index files built by the project script.
        # The NpzFile is used as a context manager so its file handle is
        # closed; tolist()/astype() copy the arrays out before it closes.
        with np.load(self.index_path, allow_pickle=True) as data:
            self.ids = data["ids"].tolist()
            self.embs = data["embs"].astype("float32")
        if len(self.ids) == 0:
            raise IndexError_(
                "Audio index is empty. "
                "Add audio files and run: python scripts/build_embedding_indexes.py",
                index_path=str(self.index_path),
            )
        # zip() in retrieve() would silently drop unmatched entries, so a
        # corrupt index must be rejected here instead.
        if len(self.ids) != len(self.embs):
            raise IndexError_(
                f"Audio index is inconsistent: {len(self.ids)} ids but "
                f"{len(self.embs)} embeddings. "
                "Rebuild with: python scripts/build_embedding_indexes.py",
                index_path=str(self.index_path),
            )
        self.embedder = AlignedEmbedder(target_dim=512)

    def retrieve(
        self,
        query_text: str,
        min_similarity: Optional[float] = None,
    ) -> AudioRetrievalResult:
        """
        Retrieve best matching audio for a text query.

        Uses CLAP text encoder for the query (same space as indexed audio
        embeddings). The best-scoring candidate is always returned. When it
        does not reach the threshold, the result has
        ``retrieval_failed=True`` and ``candidates_above_threshold=0``.

        Args:
            query_text: Prompt to embed with the CLAP text encoder.
            min_similarity: Per-call threshold override; defaults to the
                instance's ``min_similarity``.

        Returns:
            AudioRetrievalResult with the chosen file, its similarity, the
            candidate counts, and the ``top_k`` best (file name, similarity)
            pairs.
        """
        if min_similarity is None:
            min_similarity = self.min_similarity
        # Use CLAP text encoder (not CLIP) for audio retrieval
        query_emb = self.embedder.embed_text_for_audio(query_text)

        # float() normalizes any NumPy scalar the helper may return into the
        # plain ``float`` declared by AudioRetrievalResult (NumPy float32 is
        # not JSON-serializable).
        scored = [
            (audio_path, float(cosine_similarity(query_emb, audio_emb)))
            for audio_path, audio_emb in zip(self.ids, self.embs)
        ]
        # Best first. NaN scores (possible if the similarity helper does not
        # guard zero-norm embeddings) are pushed to the end so they can never
        # be ranked ahead of a real match.
        scored.sort(key=lambda item: (not np.isnan(item[1]), item[1]), reverse=True)

        # Honor the configured top_k; previously this was fixed at 5.
        top_5 = [(Path(path).name, sim) for path, sim in scored[: self.top_k]]
        above_threshold = [(path, sim) for path, sim in scored if sim >= min_similarity]

        if above_threshold:
            best_path, best_sim = above_threshold[0]
            logger.debug(
                "Audio retrieval: %s (sim=%.4f, %d/%d above threshold)",
                Path(best_path).name,
                best_sim,
                len(above_threshold),
                len(scored),
            )
            return AudioRetrievalResult(
                audio_path=best_path,
                similarity=best_sim,
                retrieval_failed=False,
                candidates_considered=len(scored),
                candidates_above_threshold=len(above_threshold),
                top_5=top_5,
            )

        # Fallback: return best candidate even if below threshold, so callers
        # always get a usable file. No candidate cleared the threshold, so this
        # is always a failure. A `best_sim < min_similarity` comparison would
        # wrongly report success for a NaN score.
        best_path, best_sim = scored[0]
        logger.warning(
            "Audio retrieval below threshold: %s (sim=%.4f < %.2f)",
            Path(best_path).name,
            best_sim,
            min_similarity,
        )
        return AudioRetrievalResult(
            audio_path=best_path,
            similarity=best_sim,
            retrieval_failed=True,
            candidates_considered=len(scored),
            candidates_above_threshold=0,
            top_5=top_5,
        )
def retrieve_audio_with_metadata(
    prompt: str,
    index_path: str = "data/embeddings/audio_index.npz",
    min_similarity: float = 0.10,
) -> AudioRetrievalResult:
    """
    Load the audio index and retrieve the best match for a single prompt.

    Intended for experiment pipelines that need the full retrieval metadata
    (similarity, failure flag, candidate counts, top candidates), not just a
    file path.

    Each call builds a fresh ``AudioRetrievalGenerator``. That reloads the
    index from ``index_path`` and constructs a new embedder, so when handling
    many prompts, prefer one generator and repeated ``retrieve`` calls.

    Raises:
        IndexError_: Propagated from ``AudioRetrievalGenerator``, e.g. when
            the index is missing or empty.
    """
    retriever = AudioRetrievalGenerator(
        index_path=index_path,
        min_similarity=min_similarity,
    )
    # Threshold is also passed per call so the result never depends on the
    # generator's stored default.
    result = retriever.retrieve(prompt, min_similarity=min_similarity)
    return result
|