"""CLAP audio similarity proof-of-concept.

Embeds audio files with LAION CLAP and compares them via cosine similarity
of L2-normalized embeddings.
"""

import os

import gradio as gr
import librosa
import numpy as np
import torch
from transformers import ClapModel, ClapProcessor

MODEL_ID = "laion/clap-htsat-fused"
TARGET_SR = 48000  # sampling rate the CLAP processor is fed below

device = "cuda" if torch.cuda.is_available() else "cpu"
processor = ClapProcessor.from_pretrained(MODEL_ID)
model = ClapModel.from_pretrained(MODEL_ID).to(device)
model.eval()

# In-memory index state (module-level; single-user PoC, not thread-safe).
# index_embeddings: np.ndarray (n_indexed, dim) float32, rows L2-normalized.
# index_metadata: list of {"filename", "path"} dicts aligned with the rows.
index_embeddings = None
index_metadata = []


def _file_path(file_obj):
    """Return a filesystem path for a Gradio file input.

    Depending on Gradio version/configuration, file inputs arrive either as
    a plain ``str`` path (Gradio 4 default) or as a tempfile-like object
    exposing ``.name``. Accept both so indexing works across versions.
    """
    return file_obj if isinstance(file_obj, str) else file_obj.name


def load_audio(path, target_sr=TARGET_SR):
    """Load *path* as a mono waveform resampled to *target_sr*."""
    audio, _ = librosa.load(path, sr=target_sr, mono=True)
    return audio


def embed_audio(path):
    """Return an L2-normalized CLAP embedding (1-D float32) for one audio file."""
    audio = load_audio(path)
    inputs = processor(
        audio=audio,
        sampling_rate=TARGET_SR,
        return_tensors="pt",
        padding=True,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model.get_audio_features(**inputs)
    # get_audio_features returns a plain tensor on current transformers;
    # keep the pooler_output fallback for versions returning an output object.
    if hasattr(output, "pooler_output"):
        embedding = output.pooler_output
    else:
        embedding = output
    embedding = embedding.detach().cpu().numpy().astype(np.float32)[0]
    norm = np.linalg.norm(embedding)
    if norm == 0:
        return embedding  # degenerate all-zero embedding; cannot normalize
    return embedding / norm


def index_audios(files):
    """Embed the uploaded *files* and (re)build the in-memory index.

    Replaces any previous index entirely. Returns ``(rows, status)`` for the
    Gradio outputs: one ``[filename, embedding_dim]`` row per file, plus a
    status message.
    """
    global index_embeddings, index_metadata
    if not files:
        return [], "Upload at least one audio file."
    embeddings = []
    metadata = []
    for file_obj in files:
        path = _file_path(file_obj)
        embeddings.append(embed_audio(path))
        metadata.append({"filename": os.path.basename(path), "path": path})
    index_embeddings = np.vstack(embeddings).astype(np.float32)
    index_metadata = metadata
    rows = [
        [item["filename"], index_embeddings.shape[1]] for item in index_metadata
    ]
    return rows, f"Indexed {len(index_metadata)} audio files."
def search_similar(query_file, top_k):
    """Return the *top_k* indexed audios most similar to *query_file*.

    Each result row is ``[filename, cosine score, path]``. Falls back to a
    single message row when there is no query or no index yet.
    """
    if query_file is None:
        return [["Upload a query audio first.", "", ""]]
    if index_embeddings is None or len(index_metadata) == 0:
        return [["Index audios first.", "", ""]]
    # Gradio passes either a str path or a tempfile-like object with .name,
    # depending on version/configuration — accept both.
    query_path = query_file if isinstance(query_file, str) else query_file.name
    query_emb = embed_audio(query_path)
    # All vectors are L2-normalized, so the dot product IS cosine similarity.
    scores = index_embeddings @ query_emb
    top_k = min(int(top_k), len(scores))
    top_indices = np.argsort(scores)[::-1][:top_k]
    return [
        [
            index_metadata[idx]["filename"],
            round(float(scores[idx]), 4),
            index_metadata[idx]["path"],
        ]
        for idx in top_indices
    ]


def similarity_matrix():
    """Build the pairwise cosine-similarity matrix for all indexed audios.

    Always returns a ``gr.Dataframe`` update so the output component is
    handled consistently on both the error and success paths.
    """
    if index_embeddings is None or len(index_metadata) == 0:
        return gr.Dataframe(value=[["Index audios first."]], headers=["audio"])
    matrix = index_embeddings @ index_embeddings.T
    filenames = [item["filename"] for item in index_metadata]
    rows = [
        [name] + [round(float(v), 4) for v in matrix[i]]
        for i, name in enumerate(filenames)
    ]
    return gr.Dataframe(
        value=rows,
        headers=["audio"] + filenames,
        label="Cosine similarity matrix",
    )


def reset_index():
    """Clear the in-memory index and return a status message."""
    global index_embeddings, index_metadata
    index_embeddings = None
    index_metadata = []
    return "Index reset."


with gr.Blocks(title="CLAP Audio Similarity PoC") as demo:
    gr.Markdown(
        """
        # CLAP Audio Similarity PoC

        Generate LAION CLAP embeddings, and compare them with cosine similarity.
        """
    )
    with gr.Tab("1. Index audios"):
        files = gr.File(
            label="Audio files to index",
            file_count="multiple",
            file_types=["audio"],
        )
        index_btn = gr.Button("Index audios")
        index_output = gr.Dataframe(
            headers=["filename", "embedding_dim"],
            label="Indexed files",
        )
        index_status = gr.Textbox(label="Status")
        index_btn.click(
            fn=index_audios,
            inputs=[files],
            outputs=[index_output, index_status],
        )
    with gr.Tab("2. Search similar"):
        query_file = gr.File(
            label="Query audio",
            file_count="single",
            file_types=["audio"],
        )
        top_k = gr.Slider(
            minimum=1,
            maximum=20,
            value=10,
            step=1,
            label="Top K",
        )
        search_btn = gr.Button("Search")
        search_output = gr.Dataframe(
            headers=["filename", "score", "path"],
            label="Similar audios",
        )
        search_btn.click(
            fn=search_similar,
            inputs=[query_file, top_k],
            outputs=[search_output],
        )
    with gr.Tab("3. Similarity matrix"):
        matrix_btn = gr.Button("Generate matrix")
        matrix_output = gr.Dataframe(label="Cosine similarity matrix")
        matrix_btn.click(
            fn=similarity_matrix,
            inputs=[],
            outputs=[matrix_output],
        )
    with gr.Tab("Reset"):
        reset_btn = gr.Button("Reset index")
        reset_output = gr.Textbox(label="Status")
        reset_btn.click(
            fn=reset_index,
            inputs=[],
            outputs=[reset_output],
        )


if __name__ == "__main__":
    demo.launch()