# app.py — CLAP audio-similarity proof-of-concept (Gradio demo).
import os
import gradio as gr
import librosa
import numpy as np
import torch
from transformers import ClapModel, ClapProcessor
# Hugging Face model id of the fused LAION CLAP checkpoint.
MODEL_ID = "laion/clap-htsat-fused"
# Sample rate fed to both librosa decoding and the CLAP processor.
TARGET_SR = 48000
# Prefer GPU when available; everything below is inference-only.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = ClapProcessor.from_pretrained(MODEL_ID)
model = ClapModel.from_pretrained(MODEL_ID).to(device)
model.eval()  # disable dropout/batch-norm updates for inference
# In-memory state
index_embeddings = None  # np.ndarray of shape (n_files, dim), unit-norm rows, or None
index_metadata = []  # list of {"filename": ..., "path": ...}, parallel to index_embeddings rows
def load_audio(path, target_sr=TARGET_SR):
    """Decode an audio file into a mono float waveform resampled to *target_sr* Hz."""
    waveform, _ = librosa.load(path, sr=target_sr, mono=True)
    return waveform
def embed_audio(path):
    """Compute a unit-normalized CLAP audio embedding for one file.

    Returns a 1-D float32 numpy vector; when the raw embedding has zero norm
    it is returned as-is to avoid division by zero.
    """
    waveform = load_audio(path)
    features = processor(
        audio=waveform,
        sampling_rate=TARGET_SR,
        return_tensors="pt",
        padding=True,
    )
    # Move every tensor onto the same device as the model.
    features = {name: tensor.to(device) for name, tensor in features.items()}
    with torch.no_grad():
        audio_features = model.get_audio_features(**features)
    # Some output objects wrap the embedding in .pooler_output; unwrap if present.
    vector = getattr(audio_features, "pooler_output", audio_features)
    vector = vector.detach().cpu().numpy().astype(np.float32)[0]
    magnitude = np.linalg.norm(vector)
    # Zero-norm guard: silent/degenerate inputs are returned unnormalized.
    return vector if magnitude == 0 else vector / magnitude
def index_audios(files):
    """Embed every uploaded audio file and store the vectors in the in-memory index.

    Parameters:
        files: list of uploaded files from a gr.File component. Depending on the
            Gradio version/configuration each entry is either a plain filepath
            string or a tempfile-like object exposing ``.name``.

    Returns:
        (rows, status): ``rows`` is a list of [filename, embedding_dim] pairs for
        display, ``status`` a human-readable summary string.
    """
    global index_embeddings, index_metadata
    if not files:
        return [], "Upload at least one audio file."
    embeddings = []
    metadata = []
    for file_obj in files:
        # Fix: gr.File with the default type="filepath" hands back plain str
        # paths; the original `.name` access would raise AttributeError there.
        path = file_obj if isinstance(file_obj, str) else file_obj.name
        filename = os.path.basename(path)
        embeddings.append(embed_audio(path))
        metadata.append(
            {
                "filename": filename,
                "path": path,
            }
        )
    index_embeddings = np.vstack(embeddings).astype(np.float32)
    index_metadata = metadata
    rows = [
        [item["filename"], index_embeddings.shape[1]]
        for item in index_metadata
    ]
    return rows, f"Indexed {len(index_metadata)} audio files."
def search_similar(query_file, top_k):
    """Return the top-K indexed audios most similar to the query clip.

    Parameters:
        query_file: upload from a gr.File component — either a filepath string
            or a tempfile-like object with ``.name`` (both are accepted), or
            ``None`` when nothing was uploaded.
        top_k: maximum number of results (clamped to the index size).

    Returns:
        List of [filename, score, path] rows sorted by descending cosine
        similarity, or a single-row message when preconditions are not met.
    """
    global index_embeddings, index_metadata
    if query_file is None:
        return [["Upload a query audio first.", "", ""]]
    if index_embeddings is None or len(index_metadata) == 0:
        return [["Index audios first.", "", ""]]
    # Fix: gr.File may pass a plain filepath string (default type="filepath");
    # the original `.name` access would raise AttributeError in that case.
    query_path = query_file if isinstance(query_file, str) else query_file.name
    query_emb = embed_audio(query_path)
    # Since all vectors are normalized, the dot product is cosine similarity.
    scores = index_embeddings @ query_emb
    top_k = min(int(top_k), len(scores))
    top_indices = np.argsort(scores)[::-1][:top_k]
    rows = []
    for idx in top_indices:
        rows.append(
            [
                index_metadata[idx]["filename"],
                round(float(scores[idx]), 4),
                index_metadata[idx]["path"],
            ]
        )
    return rows
def similarity_matrix():
    """Build the full pairwise cosine-similarity table for all indexed audios.

    Returns a gr.Dataframe whose first column holds filenames and whose
    remaining columns hold rounded similarity scores, or a one-cell message
    when the index is empty.
    """
    global index_embeddings, index_metadata
    if index_embeddings is None or not index_metadata:
        return [["Index audios first."]]
    names = [entry["filename"] for entry in index_metadata]
    # Rows are unit-norm, so the Gram matrix is exactly cosine similarity.
    sims = index_embeddings @ index_embeddings.T
    table = [
        [name] + [round(float(score), 4) for score in sims[row_idx]]
        for row_idx, name in enumerate(names)
    ]
    return gr.Dataframe(
        value=table,
        headers=["audio"] + names,
        label="Cosine similarity matrix",
    )
def reset_index():
    """Drop all stored embeddings and metadata, returning a status message."""
    global index_embeddings, index_metadata
    index_metadata = []
    index_embeddings = None
    return "Index reset."
with gr.Blocks(title="CLAP Audio Similarity PoC") as demo:
    gr.Markdown(
        """
# CLAP Audio Similarity PoC
Generate LAION CLAP embeddings, and compare them with cosine similarity.
"""
    )

    # --- Tab 1: build the in-memory embedding index from uploads ---
    with gr.Tab("1. Index audios"):
        upload_box = gr.File(
            file_types=["audio"],
            file_count="multiple",
            label="Audio files to index",
        )
        build_btn = gr.Button("Index audios")
        indexed_table = gr.Dataframe(
            headers=["filename", "embedding_dim"],
            label="Indexed files",
        )
        build_status = gr.Textbox(label="Status")
        build_btn.click(
            fn=index_audios,
            inputs=[upload_box],
            outputs=[indexed_table, build_status],
        )

    # --- Tab 2: nearest-neighbour search against the index ---
    with gr.Tab("2. Search similar"):
        query_box = gr.File(
            file_types=["audio"],
            file_count="single",
            label="Query audio",
        )
        k_slider = gr.Slider(
            label="Top K",
            minimum=1,
            maximum=20,
            step=1,
            value=10,
        )
        run_search_btn = gr.Button("Search")
        results_table = gr.Dataframe(
            headers=["filename", "score", "path"],
            label="Similar audios",
        )
        run_search_btn.click(
            fn=search_similar,
            inputs=[query_box, k_slider],
            outputs=[results_table],
        )

    # --- Tab 3: all-pairs cosine-similarity matrix ---
    with gr.Tab("3. Similarity matrix"):
        gen_matrix_btn = gr.Button("Generate matrix")
        matrix_table = gr.Dataframe(label="Cosine similarity matrix")
        gen_matrix_btn.click(
            fn=similarity_matrix,
            inputs=[],
            outputs=[matrix_table],
        )

    # --- Tab 4: wipe the in-memory index ---
    with gr.Tab("Reset"):
        wipe_btn = gr.Button("Reset index")
        wipe_status = gr.Textbox(label="Status")
        wipe_btn.click(
            fn=reset_index,
            inputs=[],
            outputs=[wipe_status],
        )

if __name__ == "__main__":
    demo.launch()