Aynursusuz commited on
Commit
b9a2daa
·
verified ·
1 Parent(s): 85cf8e6

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +348 -0
app.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+
4
+ try:
5
+ import spaces
6
+ ZERO_GPU = True
7
+ except ImportError:
8
+ ZERO_GPU = False
9
+ import numpy as np
10
+ from transformers import ASTForAudioClassification, AutoFeatureExtractor
11
+ from pydub import AudioSegment
12
+ import tempfile
13
+ import logging
14
+
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
+ # Model configurations
19
+ MODELS = {
20
+ "fine_tuned": {
21
+ "name": "Vyvo-Research/AST-Music-Classifier-1K",
22
+ "display_name": "AST-Music-Classifier-1K (Fine-tuned)",
23
+ "description": "Music sınıflandırması için özelleştirilmiş model",
24
+ "badge": "Fine-tuned"
25
+ },
26
+ "base": {
27
+ "name": "MIT/ast-finetuned-audioset-10-10-0.4593",
28
+ "display_name": "MIT AST (Base Model)",
29
+ "description": "AudioSet üzerinde eğitilmiş orijinal AST modeli",
30
+ "badge": "Base"
31
+ }
32
+ }
33
+
34
+ DETECTION_THRESHOLD = 0.50
35
+ WINDOW_SIZE = 5.0
36
+ HOP_SIZE = 5.0
37
+
38
+ # Load both models
39
+ logger.info("Loading models...")
40
+ models = {}
41
+ feature_extractors = {}
42
+
43
+ for key, config in MODELS.items():
44
+ logger.info(f"Loading {config['display_name']}...")
45
+ models[key] = ASTForAudioClassification.from_pretrained(config["name"])
46
+ feature_extractors[key] = AutoFeatureExtractor.from_pretrained(config["name"])
47
+ models[key].eval()
48
+
49
+ logger.info("All models loaded")
50
+
51
+
52
+ def load_audio(file_path: str, target_sr: int):
53
+ audio = AudioSegment.from_file(file_path)
54
+ audio = audio.set_channels(1).set_frame_rate(target_sr)
55
+ samples = np.array(audio.get_array_of_samples()).astype(np.float32)
56
+ samples = samples / np.iinfo(audio.array_type).max
57
+ return samples, audio
58
+
59
+
60
+ @torch.no_grad()
61
+ def detect_music_with_model(audio_array, sample_rate, model_key):
62
+ model = models[model_key]
63
+ feature_extractor = feature_extractors[model_key]
64
+
65
+ window_samples = int(WINDOW_SIZE * sample_rate)
66
+ hop_samples = int(HOP_SIZE * sample_rate)
67
+ total_samples = len(audio_array)
68
+
69
+ music_segments = []
70
+ all_predictions = []
71
+ last_was_music = False
72
+ device = next(model.parameters()).device
73
+ use_half = device.type == "cuda"
74
+
75
+ for start in range(0, total_samples, hop_samples):
76
+ end = min(start + window_samples, total_samples)
77
+ segment = audio_array[start:end]
78
+
79
+ if len(segment) < sample_rate:
80
+ continue
81
+
82
+ needs_padding = len(segment) < window_samples
83
+ if needs_padding:
84
+ segment = np.pad(segment, (0, window_samples - len(segment)), mode='constant')
85
+
86
+ inputs = feature_extractor(
87
+ segment,
88
+ sampling_rate=sample_rate,
89
+ return_tensors="pt",
90
+ padding="max_length",
91
+ truncation=True,
92
+ max_length=1024
93
+ )
94
+
95
+ if use_half:
96
+ inputs = {k: v.to(device).half() for k, v in inputs.items()}
97
+ else:
98
+ inputs = {k: v.to(device) for k, v in inputs.items()}
99
+
100
+ outputs = model(**inputs)
101
+ probs = torch.softmax(outputs.logits, dim=-1)
102
+
103
+ pred_idx = torch.argmax(probs[0]).item()
104
+ pred_label = model.config.id2label.get(pred_idx, "")
105
+ pred_score = probs[0][pred_idx].item()
106
+
107
+ is_music = "music" in pred_label.lower()
108
+ is_uncertain = 0.40 <= pred_score <= 0.60
109
+
110
+ start_sec = start / sample_rate
111
+ end_sec = end / sample_rate
112
+
113
+ all_predictions.append({
114
+ "start": start_sec,
115
+ "end": end_sec,
116
+ "label": pred_label,
117
+ "score": pred_score,
118
+ "is_music": is_music
119
+ })
120
+
121
+ if is_uncertain and needs_padding:
122
+ if last_was_music:
123
+ music_segments.append((int(start_sec * 1000), int(end_sec * 1000), pred_score))
124
+ elif is_music and pred_score >= DETECTION_THRESHOLD:
125
+ music_segments.append((int(start_sec * 1000), int(end_sec * 1000), pred_score))
126
+ last_was_music = True
127
+ else:
128
+ last_was_music = False
129
+
130
+ return music_segments, all_predictions
131
+
132
+
133
+ def merge_segments(segments):
134
+ if not segments:
135
+ return []
136
+
137
+ segments = sorted(segments, key=lambda x: x[0])
138
+ merged = [segments[0]]
139
+
140
+ for current in segments[1:]:
141
+ last = merged[-1]
142
+ if current[0] <= last[1]:
143
+ merged[-1] = (last[0], max(last[1], current[1]), max(last[2], current[2]))
144
+ else:
145
+ merged.append(current)
146
+
147
+ return merged
148
+
149
+
150
+ def remove_music_segments(audio, segments):
151
+ if not segments:
152
+ return audio
153
+
154
+ clean_parts = []
155
+ last_end = 0
156
+
157
+ for start_ms, end_ms, _ in segments:
158
+ if start_ms > last_end:
159
+ clean_parts.append(audio[last_end:start_ms])
160
+ last_end = end_ms
161
+
162
+ if last_end < len(audio):
163
+ clean_parts.append(audio[last_end:])
164
+
165
+ if not clean_parts:
166
+ return AudioSegment.silent(duration=0)
167
+
168
+ return sum(clean_parts)
169
+
170
+
171
+ def calculate_metrics(segments, total_duration_ms):
172
+ if not segments:
173
+ return {
174
+ "total_music_ms": 0,
175
+ "segment_count": 0,
176
+ "avg_confidence": 0,
177
+ "coverage_percent": 0
178
+ }
179
+
180
+ total_music_ms = sum(end - start for start, end, _ in segments)
181
+ avg_confidence = sum(score for _, _, score in segments) / len(segments)
182
+ coverage_percent = (total_music_ms / total_duration_ms) * 100 if total_duration_ms > 0 else 0
183
+
184
+ return {
185
+ "total_music_ms": total_music_ms,
186
+ "segment_count": len(segments),
187
+ "avg_confidence": avg_confidence,
188
+ "coverage_percent": coverage_percent
189
+ }
190
+
191
+
192
+ def build_comparison_report(original_dur, ft_segments, base_segments, ft_metrics, base_metrics):
193
+ ft_detected = ft_metrics["total_music_ms"] / 1000
194
+ base_detected = base_metrics["total_music_ms"] / 1000
195
+
196
+ # Calculate improvement percentages
197
+ if base_metrics["avg_confidence"] > 0:
198
+ conf_improvement = ((ft_metrics["avg_confidence"] - base_metrics["avg_confidence"]) / base_metrics["avg_confidence"]) * 100
199
+ else:
200
+ conf_improvement = 100 if ft_metrics["avg_confidence"] > 0 else 0
201
+
202
+ if base_metrics["segment_count"] > 0:
203
+ segment_improvement = ((ft_metrics["segment_count"] - base_metrics["segment_count"]) / base_metrics["segment_count"]) * 100
204
+ else:
205
+ segment_improvement = 100 if ft_metrics["segment_count"] > 0 else 0
206
+
207
+ # Winner determination
208
+ ft_score = 0
209
+ base_score = 0
210
+ if ft_metrics["avg_confidence"] > base_metrics["avg_confidence"]:
211
+ ft_score += 1
212
+ else:
213
+ base_score += 1
214
+ if ft_metrics["segment_count"] >= base_metrics["segment_count"]:
215
+ ft_score += 1
216
+ else:
217
+ base_score += 1
218
+
219
+ if ft_score > base_score:
220
+ winner = "Fine-tuned"
221
+ winner_pct = abs(conf_improvement)
222
+ else:
223
+ winner = "Base"
224
+ winner_pct = abs(conf_improvement)
225
+
226
+ report = f"""
227
+ ## Result: **{winner}** model wins! (+{winner_pct:.1f}% confidence)
228
+
229
+ | Metric | Fine-tuned | Base |
230
+ |--------|-----------|------|
231
+ | Segments | **{ft_metrics['segment_count']}** | {base_metrics['segment_count']} |
232
+ | Duration | **{ft_detected:.1f}s** | {base_detected:.1f}s |
233
+ | Confidence | **{ft_metrics['avg_confidence']:.0%}** | {base_metrics['avg_confidence']:.0%} |
234
+
235
+ ---
236
+ **Fine-tuned segments:**
237
+ """
238
+ if ft_segments:
239
+ for start_ms, end_ms, score in ft_segments:
240
+ report += f"- {start_ms/1000:.1f}s - {end_ms/1000:.1f}s ({score:.0%})\n"
241
+ else:
242
+ report += "No music detected\n"
243
+
244
+ report += "\n**Base segments:**\n"
245
+ if base_segments:
246
+ for start_ms, end_ms, score in base_segments:
247
+ report += f"- {start_ms/1000:.1f}s - {end_ms/1000:.1f}s ({score:.0%})\n"
248
+ else:
249
+ report += "No music detected\n"
250
+
251
+ return report
252
+
253
+
254
+ @spaces.GPU if ZERO_GPU else lambda f: f
255
+ def process_audio_comparison(audio_file, progress=gr.Progress()):
256
+ if audio_file is None:
257
+ return None, None, "Please upload an audio file."
258
+
259
+ try:
260
+ progress(0.05, desc="Preparing models...")
261
+
262
+ # Move models to GPU if available
263
+ if torch.cuda.is_available():
264
+ for key in models:
265
+ models[key].to("cuda").half()
266
+ torch.backends.cudnn.benchmark = True
267
+
268
+ progress(0.1, desc="Loading audio...")
269
+ sample_rate = feature_extractors["fine_tuned"].sampling_rate
270
+ audio_array, audio = load_audio(audio_file, sample_rate)
271
+ original_duration = len(audio) / 1000
272
+ total_duration_ms = len(audio)
273
+
274
+ # Process with Fine-tuned model
275
+ progress(0.2, desc="Analyzing with Fine-tuned Model...")
276
+ ft_segments, ft_predictions = detect_music_with_model(audio_array, sample_rate, "fine_tuned")
277
+ ft_segments = merge_segments(ft_segments)
278
+ ft_metrics = calculate_metrics(ft_segments, total_duration_ms)
279
+
280
+ # Process with Base model
281
+ progress(0.5, desc="Analyzing with Base Model...")
282
+ base_segments, base_predictions = detect_music_with_model(audio_array, sample_rate, "base")
283
+ base_segments = merge_segments(base_segments)
284
+ base_metrics = calculate_metrics(base_segments, total_duration_ms)
285
+
286
+ # Create outputs for both models
287
+ progress(0.8, desc="Generating outputs...")
288
+
289
+ # Fine-tuned model output
290
+ ft_clean_audio = remove_music_segments(audio, ft_segments)
291
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
292
+ ft_clean_audio.export(f.name, format="wav")
293
+ ft_output_path = f.name
294
+
295
+ # Base model output
296
+ base_clean_audio = remove_music_segments(audio, base_segments)
297
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
298
+ base_clean_audio.export(f.name, format="wav")
299
+ base_output_path = f.name
300
+
301
+ progress(0.95, desc="Building report...")
302
+ report = build_comparison_report(
303
+ original_duration, ft_segments, base_segments, ft_metrics, base_metrics
304
+ )
305
+
306
+ progress(1.0, desc="Done")
307
+ return ft_output_path, base_output_path, report
308
+
309
+ except Exception as e:
310
+ logger.exception("Processing failed")
311
+ return None, None, f"Error: {str(e)}"
312
+
313
+
314
+
315
+ with gr.Blocks(title="CleanSpeech - Model Comparison") as demo:
316
+ gr.Markdown("# CleanSpeech - Model Comparison")
317
+
318
+ # Input section
319
+ with gr.Row():
320
+ with gr.Column(scale=2):
321
+ audio_input = gr.Audio(label="Upload Audio File", type="filepath")
322
+ process_btn = gr.Button("Compare Models", variant="primary", size="lg")
323
+
324
+ # Output section - Side by side
325
+ with gr.Row():
326
+ with gr.Column(scale=1):
327
+ ft_audio_output = gr.Audio(label="Fine-tuned Output")
328
+
329
+ with gr.Column(scale=1):
330
+ base_audio_output = gr.Audio(label="Base Model Output")
331
+
332
+ # Comparison report
333
+ comparison_report = gr.Markdown(label="Comparison Report")
334
+
335
+ process_btn.click(
336
+ fn=process_audio_comparison,
337
+ inputs=[audio_input],
338
+ outputs=[ft_audio_output, base_audio_output, comparison_report]
339
+ )
340
+
341
+ # Footer
342
+ gr.Markdown("""
343
+ ---
344
+ **Models:** [Fine-tuned](https://huggingface.co/Vyvo-Research/AST-Music-Classifier-1K) | [Base](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593)
345
+ """)
346
+
347
+ demo.queue()
348
+ demo.launch(theme=gr.themes.Soft())