ankandrew committed on
Commit
1d7167a
·
verified ·
1 Parent(s): a04c1b1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +235 -0
app.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ import librosa
5
+ import numpy as np
6
+ import torch
7
+
8
+ from transformers import ClapModel, ClapProcessor
9
+
10
+
11
# Checkpoint identifier handed to both ``from_pretrained`` calls below.
MODEL_ID = "laion/clap-htsat-fused"
# Sampling rate audio is loaded at and that is reported to the processor.
TARGET_SR = 48000

# Use the GPU when torch can see one, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loaded once at import time; the functions below use these module-level
# objects directly. The model is moved to ``device`` and put in eval mode.
processor = ClapProcessor.from_pretrained(MODEL_ID)
model = ClapModel.from_pretrained(MODEL_ID).to(device).eval()

# In-memory state, rebuilt by index_audios() and cleared by reset_index().
# index_embeddings: None until an index exists, then a float32 matrix with
#   one row per indexed file.
# index_metadata: one {"filename", "path"} dict per row, in the same order.
index_embeddings = None
index_metadata = []
23
+
24
+
25
def load_audio(path, target_sr=TARGET_SR):
    """Load the audio file at *path* as a mono waveform.

    Args:
        path: Location of the audio file, passed straight to ``librosa.load``.
        target_sr: Sampling rate requested from ``librosa.load``.

    Returns:
        The waveform returned by ``librosa.load``; the sampling rate that
        comes back alongside it is discarded.
    """
    waveform, _sampling_rate = librosa.load(path, sr=target_sr, mono=True)
    return waveform
28
+
29
+
30
def embed_audio(path):
    """Compute a normalized CLAP audio embedding for the file at *path*.

    The audio is loaded with ``load_audio``, run through the module-level
    ``processor`` and ``model.get_audio_features`` without gradient tracking,
    and the first row of the result is converted to a float32 NumPy vector.

    Args:
        path: Location of the audio file to embed.

    Returns:
        The embedding divided by its L2 norm, or the embedding unchanged when
        that norm is exactly zero (avoids a division by zero).
    """
    waveform = load_audio(path)

    batch = processor(
        audios=waveform,
        sampling_rate=TARGET_SR,
        return_tensors="pt",
        padding=True,
    )
    # Every processor output has to live on the same device as the model.
    on_device = {name: tensor.to(device) for name, tensor in batch.items()}

    with torch.no_grad():
        features = model.get_audio_features(**on_device)

    # Only one file was submitted, so keep the first row of the output.
    vector = features.detach().cpu().numpy().astype(np.float32)[0]

    # Unit-length vectors make a plain dot product equal cosine similarity.
    length = np.linalg.norm(vector)
    return vector if length == 0 else vector / length
53
+
54
+
55
def index_audios(files):
    """Embed the uploaded audio files and replace the in-memory index.

    Args:
        files: Iterable of uploads. Each item may be an object exposing its
            path as ``.name`` or a plain path string.

    Returns:
        A ``(rows, status)`` tuple. ``rows`` holds one
        ``[filename, embedding_dim]`` entry per indexed file and ``status`` is
        a human-readable message. When nothing was uploaded, ``rows`` is empty
        and the existing index is left untouched.
    """
    global index_embeddings, index_metadata

    if not files:
        return [], "Upload at least one audio file."

    embeddings = []
    metadata = []

    for file_obj in files:
        # Uploads may arrive as plain path strings rather than file-like
        # wrappers, so fall back on the item itself when ``.name`` is absent.
        path = getattr(file_obj, "name", file_obj)
        embeddings.append(embed_audio(path))
        metadata.append({"filename": os.path.basename(path), "path": path})

    # Both globals are assigned only after every file has been embedded, so
    # an exception raised midway leaves the previous index in place.
    index_embeddings = np.vstack(embeddings).astype(np.float32)
    index_metadata = metadata

    embedding_dim = index_embeddings.shape[1]
    rows = [[item["filename"], embedding_dim] for item in metadata]

    return rows, f"Indexed {len(metadata)} audio files."
87
+
88
+
89
def search_similar(query_file, top_k):
    """Rank the indexed audios by cosine similarity to a query audio.

    Args:
        query_file: The query upload, either an object exposing its path as
            ``.name`` or a plain path string. ``None`` means nothing was
            uploaded.
        top_k: Maximum number of results. It is converted with ``int`` and
            clamped to the range ``[0, number of indexed files]``.

    Returns:
        A list of ``[filename, score, path]`` rows ordered from most to least
        similar, with scores rounded to four decimals. When the query or the
        index is missing, a single row carrying an explanatory message is
        returned instead.
    """
    if query_file is None:
        return [["Upload a query audio first.", "", ""]]

    if index_embeddings is None or not index_metadata:
        return [["Index audios first.", "", ""]]

    # Uploads may arrive as plain path strings rather than file-like
    # wrappers, so fall back on the item itself when ``.name`` is absent.
    query_path = getattr(query_file, "name", query_file)
    query_emb = embed_audio(query_path)

    # Since all vectors are normalized, this is cosine similarity
    scores = index_embeddings @ query_emb

    # A negative bound would make the slice below drop items from the end
    # instead of limiting the result count, so clamp at zero.
    top_k = max(0, min(int(top_k), len(scores)))
    top_indices = np.argsort(scores)[::-1][:top_k]

    return [
        [
            index_metadata[idx]["filename"],
            round(float(scores[idx]), 4),
            index_metadata[idx]["path"],
        ]
        for idx in top_indices
    ]
118
+
119
+
120
def similarity_matrix():
    """Build the pairwise similarity table for every indexed audio.

    Returns:
        A ``gr.Dataframe`` whose first column is the filename and whose
        remaining columns are the dot products (rounded to four decimals)
        between that file's embedding and each indexed embedding. When no
        index exists, a single-cell table carrying an explanatory message is
        returned instead.
    """
    if index_embeddings is None or len(index_metadata) == 0:
        return [["Index audios first."]]

    names = [entry["filename"] for entry in index_metadata]

    # Rows are stored normalized, so the Gram matrix holds cosine similarities.
    pairwise = index_embeddings @ index_embeddings.T

    table = [
        [name, *(round(float(score), 4) for score in pairwise[i])]
        for i, name in enumerate(names)
    ]

    return gr.Dataframe(
        value=table,
        headers=["audio"] + names,
        label="Cosine similarity matrix",
    )
143
+
144
+
145
def reset_index():
    """Drop the in-memory index.

    Returns:
        The status message shown to the user.
    """
    global index_embeddings, index_metadata

    # Back to the initial state: no embedding matrix and no metadata rows.
    index_embeddings, index_metadata = None, []

    return "Index reset."
152
+
153
+
154
# Gradio UI: one tab per step, each wiring a button to one of the handlers
# defined above.
with gr.Blocks(title="CLAP Audio Similarity PoC") as demo:
    gr.Markdown(
        """
        # CLAP Audio Similarity PoC

        Generate LAION CLAP embeddings, and compare them with cosine similarity.
        """
    )

    # Step 1: upload several files and build the index from them.
    with gr.Tab("1. Index audios"):
        files = gr.File(label="Audio files to index", file_count="multiple", file_types=["audio"])
        index_btn = gr.Button("Index audios")
        index_output = gr.Dataframe(headers=["filename", "embedding_dim"], label="Indexed files")
        index_status = gr.Textbox(label="Status")

        index_btn.click(fn=index_audios, inputs=[files], outputs=[index_output, index_status])

    # Step 2: upload a single query file and list its nearest indexed audios.
    with gr.Tab("2. Search similar"):
        query_file = gr.File(label="Query audio", file_count="single", file_types=["audio"])
        top_k = gr.Slider(minimum=1, maximum=20, value=10, step=1, label="Top K")
        search_btn = gr.Button("Search")
        search_output = gr.Dataframe(headers=["filename", "score", "path"], label="Similar audios")

        search_btn.click(fn=search_similar, inputs=[query_file, top_k], outputs=[search_output])

    # Step 3: show every indexed audio against every other one.
    with gr.Tab("3. Similarity matrix"):
        matrix_btn = gr.Button("Generate matrix")
        matrix_output = gr.Dataframe(label="Cosine similarity matrix")

        matrix_btn.click(fn=similarity_matrix, inputs=[], outputs=[matrix_output])

    # Clears the in-memory index so a fresh set of files can be indexed.
    with gr.Tab("Reset"):
        reset_btn = gr.Button("Reset index")
        reset_output = gr.Textbox(label="Status")

        reset_btn.click(fn=reset_index, inputs=[], outputs=[reset_output])


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()