File size: 11,216 Bytes
b9a2daa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
import gradio as gr
import torch

try:
    import spaces
    ZERO_GPU = True
except ImportError:
    ZERO_GPU = False
import numpy as np
from transformers import ASTForAudioClassification, AutoFeatureExtractor
from pydub import AudioSegment
import tempfile
import logging

# Configure logging at import time so the model-loading messages below are emitted at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Registry of the two checkpoints being compared, keyed by the short id used
# throughout this file ("fine_tuned" / "base").
#   name         - identifier passed to `from_pretrained` in the loading loop below
#   display_name - human-readable name used in the loading log message
#   description  - Turkish description string (not read anywhere in the visible code)
#   badge        - short tag (not read anywhere in the visible code)
MODELS = {
    "fine_tuned": {
        "name": "Vyvo-Research/AST-Music-Classifier-1K",
        "display_name": "AST-Music-Classifier-1K (Fine-tuned)",
        # Turkish: "Model customized for music classification"
        "description": "Music sınıflandırması için özelleştirilmiş model",
        "badge": "Fine-tuned"
    },
    "base": {
        "name": "MIT/ast-finetuned-audioset-10-10-0.4593",
        "display_name": "MIT AST (Base Model)",
        # Turkish: "Original AST model trained on AudioSet"
        "description": "AudioSet üzerinde eğitilmiş orijinal AST modeli",
        "badge": "Base"
    }
}

# Minimum top-class probability for a music-labelled window to be accepted as music.
DETECTION_THRESHOLD = 0.50
# Analysis window length and step, in seconds (both are multiplied by the sample
# rate in detect_music_with_model). Equal values give back-to-back, non-overlapping windows.
WINDOW_SIZE = 5.0
HOP_SIZE = 5.0

# Eagerly load every checkpoint and its feature extractor at import time, so the
# request handler only has to look them up by key.
logger.info("Loading models...")
models = {}  # model key -> ASTForAudioClassification instance
feature_extractors = {}  # model key -> matching feature extractor

for key, config in MODELS.items():
    logger.info(f"Loading {config['display_name']}...")
    models[key] = ASTForAudioClassification.from_pretrained(config["name"])
    feature_extractors[key] = AutoFeatureExtractor.from_pretrained(config["name"])
    # Switch to evaluation mode; these models are only used for inference here.
    models[key].eval()

logger.info("All models loaded")


def load_audio(file_path: str, target_sr: int):
    """Decode an audio file into mono samples at ``target_sr``.

    Args:
        file_path: Path of the audio file handed to ``AudioSegment.from_file``.
        target_sr: Frame rate the audio is converted to.

    Returns:
        A ``(samples, audio)`` tuple: ``samples`` is a float32 NumPy array of
        the raw samples divided by the maximum value of the segment's integer
        sample type, and ``audio`` is the mono, resampled ``AudioSegment``.
    """
    mono = AudioSegment.from_file(file_path).set_channels(1)
    audio = mono.set_frame_rate(target_sr)

    # Largest value representable by the segment's integer sample type; used as
    # the normalisation divisor.
    full_scale = np.iinfo(audio.array_type).max
    raw = np.array(audio.get_array_of_samples())
    samples = raw.astype(np.float32) / full_scale
    return samples, audio


@torch.no_grad()
def detect_music_with_model(audio_array, sample_rate, model_key):
    """Classify consecutive windows of ``audio_array`` and collect the music ones.

    The audio is walked in steps of ``HOP_SIZE`` seconds using windows of
    ``WINDOW_SIZE`` seconds. Each window is classified by the model registered
    under ``model_key``; the top-scoring label decides whether it is music.

    Args:
        audio_array: 1-D sample array (indexed by sample position).
        sample_rate: Samples per second of ``audio_array``.
        model_key: Key into the module-level ``models`` / ``feature_extractors``.

    Returns:
        ``(music_segments, all_predictions)`` where ``music_segments`` is a list
        of ``(start_ms, end_ms, score)`` tuples and ``all_predictions`` is a
        list of per-window dicts with keys ``start``, ``end`` (seconds),
        ``label``, ``score`` and ``is_music``.
    """
    classifier = models[model_key]
    extractor = feature_extractors[model_key]
    id2label = classifier.config.id2label

    window_len = int(WINDOW_SIZE * sample_rate)
    hop_len = int(HOP_SIZE * sample_rate)
    n_samples = len(audio_array)

    # Inputs must live on the model's device; on CUDA they are also cast to half.
    device = next(classifier.parameters()).device
    on_cuda = device.type == "cuda"

    music_segments = []
    all_predictions = []
    prev_window_was_music = False

    for start in range(0, n_samples, hop_len):
        end = min(start + window_len, n_samples)
        chunk = audio_array[start:end]

        # Fragments shorter than one second are not classified at all.
        if len(chunk) < sample_rate:
            continue

        # A short final window is zero-padded on the right up to the full length.
        shortfall = window_len - len(chunk)
        was_padded = shortfall > 0
        if was_padded:
            chunk = np.pad(chunk, (0, shortfall), mode='constant')

        encoded = extractor(
            chunk,
            sampling_rate=sample_rate,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=1024
        )

        tensors = {}
        for field, tensor in encoded.items():
            tensor = tensor.to(device)
            tensors[field] = tensor.half() if on_cuda else tensor

        logits = classifier(**tensors).logits
        window_probs = torch.softmax(logits, dim=-1)[0]

        top_idx = torch.argmax(window_probs).item()
        top_label = id2label.get(top_idx, "")
        top_score = window_probs[top_idx].item()

        labelled_music = "music" in top_label.lower()

        start_sec = start / sample_rate
        end_sec = end / sample_rate

        all_predictions.append({
            "start": start_sec,
            "end": end_sec,
            "label": top_label,
            "score": top_score,
            "is_music": labelled_music
        })

        span_ms = (int(start_sec * 1000), int(end_sec * 1000), top_score)

        if was_padded and 0.40 <= top_score <= 0.60:
            # Low-confidence padded tail: inherit the verdict of the previous
            # window instead of trusting this prediction, and leave the
            # "previous was music" state untouched.
            if prev_window_was_music:
                music_segments.append(span_ms)
        elif labelled_music and top_score >= DETECTION_THRESHOLD:
            music_segments.append(span_ms)
            prev_window_was_music = True
        else:
            prev_window_was_music = False

    return music_segments, all_predictions


def merge_segments(segments):
    """Coalesce overlapping or touching ``(start, end, score)`` segments.

    Segments are ordered by start; any segment that begins at or before the end
    of the previous merged segment is folded into it. A merged segment keeps
    the earliest start, the latest end and the highest score of its members.

    Args:
        segments: Iterable of ``(start, end, score)`` items; may be empty/None.

    Returns:
        A new list of merged segments ordered by start (``[]`` for no input).
    """
    if not segments:
        return []

    ordered = iter(sorted(segments, key=lambda seg: seg[0]))
    merged = [next(ordered)]

    for seg in ordered:
        tail = merged[-1]
        if seg[0] > tail[1]:
            # Gap between the previous segment and this one: start a new run.
            merged.append(seg)
            continue
        merged[-1] = (tail[0], max(tail[1], seg[1]), max(tail[2], seg[2]))

    return merged


def remove_music_segments(audio, segments):
    """Return ``audio`` with every ``(start_ms, end_ms, score)`` span cut out.

    The pieces between the given spans are sliced out of ``audio`` and joined
    back together in order. Spans are processed in the order given.

    Args:
        audio: Sliceable audio object (indexed in milliseconds by the caller).
        segments: Spans to drop, as ``(start_ms, end_ms, score)`` tuples.

    Returns:
        ``audio`` itself when there is nothing to remove, the concatenation of
        the kept pieces otherwise, or a zero-length silent segment when nothing
        is left.
    """
    if not segments:
        return audio

    kept = []
    cursor = 0  # end of the last removed span, i.e. where the next kept piece starts

    for music_start, music_end, _score in segments:
        if cursor < music_start:
            kept.append(audio[cursor:music_start])
        cursor = music_end

    # Whatever follows the final removed span is kept too.
    if cursor < len(audio):
        kept.append(audio[cursor:])

    return sum(kept) if kept else AudioSegment.silent(duration=0)


def calculate_metrics(segments, total_duration_ms):
    """Summarise detected music segments.

    Args:
        segments: ``(start_ms, end_ms, score)`` tuples.
        total_duration_ms: Length of the whole recording in milliseconds.

    Returns:
        Dict with ``total_music_ms`` (summed segment lengths),
        ``segment_count``, ``avg_confidence`` (mean score) and
        ``coverage_percent`` (music share of the recording; 0 when the
        duration is not positive). All values are 0 when there are no segments.
    """
    if not segments:
        return {
            "total_music_ms": 0,
            "segment_count": 0,
            "avg_confidence": 0,
            "coverage_percent": 0
        }

    music_ms = 0
    score_total = 0
    for seg_start, seg_end, seg_score in segments:
        music_ms += seg_end - seg_start
        score_total += seg_score

    count = len(segments)

    if total_duration_ms > 0:
        coverage = (music_ms / total_duration_ms) * 100
    else:
        coverage = 0

    return {
        "total_music_ms": music_ms,
        "segment_count": count,
        "avg_confidence": score_total / count,
        "coverage_percent": coverage
    }


def _format_segment_lines(segments):
    """Render segments as markdown bullet lines, or a placeholder when empty."""
    if not segments:
        return "No music detected\n"
    return "".join(
        f"- {start_ms/1000:.1f}s - {end_ms/1000:.1f}s ({score:.0%})\n"
        for start_ms, end_ms, score in segments
    )


def build_comparison_report(original_dur, ft_segments, base_segments, ft_metrics, base_metrics):
    """Build the markdown report comparing the fine-tuned and base models.

    The winner is decided on two criteria: strictly higher average confidence,
    and at least as many detected segments. The fine-tuned model wins only if
    it takes both; a 1-1 split goes to the base model.

    The percentage in the headline is the winner's average-confidence
    difference, always expressed relative to the base model's confidence, and
    printed with its real sign (so a winner that is *less* confident than the
    loser is shown with a negative figure rather than a misleading ``+``).

    Args:
        original_dur: Duration of the original audio. Currently unused; kept so
            existing callers keep working.
        ft_segments: Fine-tuned model segments as ``(start_ms, end_ms, score)``.
        base_segments: Base model segments in the same format.
        ft_metrics: Metrics dict for the fine-tuned model; the keys
            ``total_music_ms``, ``segment_count`` and ``avg_confidence`` are read.
        base_metrics: Metrics dict for the base model, same keys.

    Returns:
        The report as a markdown string.
    """
    ft_conf = ft_metrics["avg_confidence"]
    base_conf = base_metrics["avg_confidence"]
    ft_count = ft_metrics["segment_count"]
    base_count = base_metrics["segment_count"]

    ft_detected = ft_metrics["total_music_ms"] / 1000
    base_detected = base_metrics["total_music_ms"] / 1000

    # Fine-tuned confidence change relative to the base model, in percent.
    if base_conf > 0:
        conf_improvement = ((ft_conf - base_conf) / base_conf) * 100
    else:
        conf_improvement = 100 if ft_conf > 0 else 0

    # Two criteria, one point each; whatever the fine-tuned model does not
    # earn goes to the base model.
    ft_score = int(ft_conf > base_conf) + int(ft_count >= base_count)
    base_score = 2 - ft_score

    if ft_score > base_score:
        winner = "Fine-tuned"
        winner_delta = conf_improvement
    else:
        winner = "Base"
        # Flip the sign to view the difference from the base model's side;
        # adding 0.0 turns a negative zero into +0.0 so it prints as "+0.0".
        winner_delta = -conf_improvement + 0.0

    header = f"""
## Result: **{winner}** model wins! ({winner_delta:+.1f}% confidence)

| Metric | Fine-tuned | Base |
|--------|-----------|------|
| Segments | **{ft_count}** | {base_count} |
| Duration | **{ft_detected:.1f}s** | {base_detected:.1f}s |
| Confidence | **{ft_conf:.0%}** | {base_conf:.0%} |

---
**Fine-tuned segments:**
"""
    return "".join((
        header,
        _format_segment_lines(ft_segments),
        "\n**Base segments:**\n",
        _format_segment_lines(base_segments),
    ))


@spaces.GPU if ZERO_GPU else lambda f: f
def process_audio_comparison(audio_file, progress=gr.Progress()):
    """Run both models on an uploaded file and produce music-free audio plus a report.

    Decorated with ``spaces.GPU`` when ``ZERO_GPU`` is true, otherwise left
    undecorated.

    Args:
        audio_file: Path of the uploaded audio file, or ``None`` if nothing
            was uploaded.
        progress: Gradio progress tracker, called with a fraction and a label.

    Returns:
        ``(fine_tuned_wav_path, base_wav_path, report_markdown)``. On missing
        input or on any failure the two paths are ``None`` and the third item
        is a message for the user.
    """
    if audio_file is None:
        return None, None, "Please upload an audio file."

    def export_without_music(source_audio, music_segments):
        # Cut the segments out and write the remainder to a temp WAV that
        # outlives this call (delete=False) so its path can be returned.
        cleaned = remove_music_segments(source_audio, music_segments)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            cleaned.export(tmp.name, format="wav")
            return tmp.name

    try:
        progress(0.05, desc="Preparing models...")

        # Move models to GPU (half precision) if available
        if torch.cuda.is_available():
            for loaded_model in models.values():
                loaded_model.to("cuda").half()
            torch.backends.cudnn.benchmark = True

        progress(0.1, desc="Loading audio...")
        # The fine-tuned extractor's sampling rate is used for both models.
        sample_rate = feature_extractors["fine_tuned"].sampling_rate
        audio_array, audio = load_audio(audio_file, sample_rate)
        total_duration_ms = len(audio)
        original_duration = total_duration_ms / 1000

        # Fine-tuned model pass
        progress(0.2, desc="Analyzing with Fine-tuned Model...")
        ft_raw, _ft_predictions = detect_music_with_model(audio_array, sample_rate, "fine_tuned")
        ft_segments = merge_segments(ft_raw)
        ft_metrics = calculate_metrics(ft_segments, total_duration_ms)

        # Base model pass
        progress(0.5, desc="Analyzing with Base Model...")
        base_raw, _base_predictions = detect_music_with_model(audio_array, sample_rate, "base")
        base_segments = merge_segments(base_raw)
        base_metrics = calculate_metrics(base_segments, total_duration_ms)

        # One cleaned file per model, fine-tuned first
        progress(0.8, desc="Generating outputs...")
        ft_output_path = export_without_music(audio, ft_segments)
        base_output_path = export_without_music(audio, base_segments)

        progress(0.95, desc="Building report...")
        report = build_comparison_report(
            original_duration, ft_segments, base_segments, ft_metrics, base_metrics
        )

        progress(1.0, desc="Done")
        return ft_output_path, base_output_path, report

    except Exception as e:
        # UI boundary: log the traceback and surface the message instead of crashing.
        logger.exception("Processing failed")
        return None, None, f"Error: {str(e)}"



# Gradio UI: one upload + button row, two side-by-side audio outputs (one per
# model), and a markdown report underneath. Components are created in the
# order they appear in the layout.
with gr.Blocks(title="CleanSpeech - Model Comparison") as demo:
    gr.Markdown("# CleanSpeech - Model Comparison")

    # Input section: the upload is handed to the handler as a file path.
    with gr.Row():
        with gr.Column(scale=2):
            audio_input = gr.Audio(label="Upload Audio File", type="filepath")
            process_btn = gr.Button("Compare Models", variant="primary", size="lg")

    # Output section: cleaned audio from each model, side by side.
    with gr.Row():
        with gr.Column(scale=1):
            ft_audio_output = gr.Audio(label="Fine-tuned Output")

        with gr.Column(scale=1):
            base_audio_output = gr.Audio(label="Base Model Output")

    # Comparison report (markdown text returned by the handler).
    comparison_report = gr.Markdown(label="Comparison Report")

    # The handler's three return values map, in order, onto the three outputs.
    process_btn.click(
        fn=process_audio_comparison,
        inputs=[audio_input],
        outputs=[ft_audio_output, base_audio_output, comparison_report]
    )

    # Footer with links to both model cards.
    gr.Markdown("""
---
**Models:** [Fine-tuned](https://huggingface.co/Vyvo-Research/AST-Music-Classifier-1K) | [Base](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593)
    """)

# Enable the request queue, then start the app at import time.
demo.queue()
# NOTE(review): `theme` is passed to launch() rather than gr.Blocks(); this relies on
# the installed Gradio version accepting it there — confirm against the pinned version.
demo.launch(theme=gr.themes.Soft())