| import os |
|
|
| import gradio as gr |
| import librosa |
| import numpy as np |
| import torch |
|
|
| from transformers import ClapModel, ClapProcessor |
|
|
|
|
# Pretrained fused CLAP checkpoint used for all audio embeddings.
MODEL_ID = "laion/clap-htsat-fused"
# Target sample rate every input file is resampled to before embedding.
TARGET_SR = 48000


# Prefer GPU when one is available; everything below moves to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Processor + model are loaded once at import time (downloads on first run).
processor = ClapProcessor.from_pretrained(MODEL_ID)
model = ClapModel.from_pretrained(MODEL_ID).to(device)
# Inference only: disable dropout / training-mode behavior.
model.eval()


# In-memory index shared (mutated) by the Gradio callbacks below:
# index_embeddings: (n, dim) float32 matrix of L2-normalized embeddings, or None.
index_embeddings = None
# index_metadata: one {"filename": ..., "path": ...} dict per embedding row.
index_metadata = []
|
|
|
|
def load_audio(path, target_sr=TARGET_SR):
    """Decode the file at *path* into a mono float waveform at *target_sr* Hz."""
    waveform, _sr = librosa.load(path, sr=target_sr, mono=True)
    return waveform
|
|
|
|
def embed_audio(path):
    """Compute a unit-length CLAP embedding for the audio file at *path*.

    Returns a 1-D float32 numpy vector, L2-normalized so that dot products
    between embeddings are cosine similarities. A degenerate all-zero
    embedding is returned as-is to avoid dividing by zero.
    """
    waveform = load_audio(path)

    features = processor(
        audio=waveform,
        sampling_rate=TARGET_SR,
        return_tensors="pt",
        padding=True,
    )
    features = {name: tensor.to(device) for name, tensor in features.items()}

    # Inference only: no gradients needed.
    with torch.no_grad():
        result = model.get_audio_features(**features)

    # Some transformers versions wrap the features in a model-output object;
    # unwrap to a plain tensor either way.
    vector = result.pooler_output if hasattr(result, "pooler_output") else result

    # Batch of one -> take row 0 as a float32 numpy vector.
    vector = vector.detach().cpu().numpy().astype(np.float32)[0]

    length = np.linalg.norm(vector)
    return vector if length == 0 else vector / length
|
|
|
|
def index_audios(files):
    """Embed every uploaded file and rebuild the in-memory index.

    Replaces the module-level ``index_embeddings`` / ``index_metadata``
    wholesale. Returns ``(table_rows, status_message)`` for the two
    Gradio outputs.
    """
    global index_embeddings, index_metadata

    if not files:
        return [], "Upload at least one audio file."

    embeddings = []
    metadata = []
    for file_obj in files:
        src = file_obj.name
        embeddings.append(embed_audio(src))
        metadata.append({"filename": os.path.basename(src), "path": src})

    # Stack into an (n, dim) float32 matrix for fast batched dot products.
    index_embeddings = np.vstack(embeddings).astype(np.float32)
    index_metadata = metadata

    dim = index_embeddings.shape[1]
    rows = [[entry["filename"], dim] for entry in index_metadata]

    return rows, f"Indexed {len(index_metadata)} audio files."
|
|
|
|
def search_similar(query_file, top_k):
    """Rank indexed audios by cosine similarity to the uploaded query file.

    Returns table rows of ``[filename, score, path]``; placeholder rows are
    returned when there is no query or no index yet.
    """
    global index_embeddings, index_metadata

    if query_file is None:
        return [["Upload a query audio first.", "", ""]]
    if index_embeddings is None or len(index_metadata) == 0:
        return [["Index audios first.", "", ""]]

    query_vec = embed_audio(query_file.name)

    # Embeddings are unit-normalized, so the dot product is cosine similarity.
    scores = index_embeddings @ query_vec

    # Clamp k to the index size, then take indices of the k best scores.
    k = min(int(top_k), len(scores))
    ranked = np.argsort(scores)[::-1][:k]

    return [
        [
            index_metadata[i]["filename"],
            round(float(scores[i]), 4),
            index_metadata[i]["path"],
        ]
        for i in ranked
    ]
|
|
|
|
def similarity_matrix():
    """Build the pairwise cosine-similarity table over all indexed audios.

    Returns a ``gr.Dataframe`` whose first column is the filename and whose
    remaining columns hold similarities against every indexed file; a plain
    placeholder row is returned when nothing has been indexed yet.
    """
    global index_embeddings, index_metadata

    if index_embeddings is None or len(index_metadata) == 0:
        return [["Index audios first."]]

    # Unit-normalized rows make this Gram matrix a cosine-similarity matrix.
    sims = index_embeddings @ index_embeddings.T
    names = [entry["filename"] for entry in index_metadata]

    table = []
    for row_idx, name in enumerate(names):
        table.append([name] + [round(float(v), 4) for v in sims[row_idx]])

    return gr.Dataframe(
        value=table,
        headers=["audio"] + names,
        label="Cosine similarity matrix",
    )
|
|
|
|
def reset_index():
    """Drop all stored embeddings and metadata, returning a status message."""
    global index_embeddings, index_metadata
    index_embeddings, index_metadata = None, []
    return "Index reset."
|
|
|
|
# Gradio UI: four tabs — build the index, query it, view the full pairwise
# similarity matrix, and reset the in-memory state.
with gr.Blocks(title="CLAP Audio Similarity PoC") as demo:
    gr.Markdown(
        """
        # CLAP Audio Similarity PoC

        Generate LAION CLAP embeddings, and compare them with cosine similarity.
        """
    )

    # Tab 1: upload multiple files and embed them into the module-level index.
    with gr.Tab("1. Index audios"):
        files = gr.File(
            label="Audio files to index",
            file_count="multiple",
            file_types=["audio"],
        )

        index_btn = gr.Button("Index audios")

        index_output = gr.Dataframe(
            headers=["filename", "embedding_dim"],
            label="Indexed files",
        )

        index_status = gr.Textbox(label="Status")

        index_btn.click(
            fn=index_audios,
            inputs=[files],
            outputs=[index_output, index_status],
        )

    # Tab 2: embed one query file and rank indexed audios against it.
    with gr.Tab("2. Search similar"):
        query_file = gr.File(
            label="Query audio",
            file_count="single",
            file_types=["audio"],
        )

        top_k = gr.Slider(
            minimum=1,
            maximum=20,
            value=10,
            step=1,
            label="Top K",
        )

        search_btn = gr.Button("Search")

        search_output = gr.Dataframe(
            headers=["filename", "score", "path"],
            label="Similar audios",
        )

        search_btn.click(
            fn=search_similar,
            inputs=[query_file, top_k],
            outputs=[search_output],
        )

    # Tab 3: all-pairs cosine similarities between the indexed audios.
    with gr.Tab("3. Similarity matrix"):
        matrix_btn = gr.Button("Generate matrix")
        matrix_output = gr.Dataframe(label="Cosine similarity matrix")

        matrix_btn.click(
            fn=similarity_matrix,
            inputs=[],
            outputs=[matrix_output],
        )

    # Tab 4: clear the global index so a fresh one can be built.
    with gr.Tab("Reset"):
        reset_btn = gr.Button("Reset index")
        reset_output = gr.Textbox(label="Status")

        reset_btn.click(
            fn=reset_index,
            inputs=[],
            outputs=[reset_output],
        )
|
|
|
|
# Launch the web app only when executed as a script (not when imported).
if __name__ == "__main__":
    demo.launch()