Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import torch | |
| try: | |
| import spaces | |
| ZERO_GPU = True | |
| except ImportError: | |
| ZERO_GPU = False | |
| import numpy as np | |
| from transformers import ASTForAudioClassification, AutoFeatureExtractor | |
| from pydub import AudioSegment | |
| import tempfile | |
| import logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
# Checkpoints compared by the app, keyed by the internal model id used
# throughout this file ("fine_tuned" and "base").
MODELS = dict(
    fine_tuned=dict(
        name="Vyvo-Research/AST-Music-Classifier-1K",
        display_name="AST-Music-Classifier-1K (Fine-tuned)",
        description="Music sınıflandırması için özelleştirilmiş model",
        badge="Fine-tuned",
    ),
    base=dict(
        name="MIT/ast-finetuned-audioset-10-10-0.4593",
        display_name="MIT AST (Base Model)",
        description="AudioSet üzerinde eğitilmiş orijinal AST modeli",
        badge="Base",
    ),
)

# Minimum top-class probability for a "music" window to count as detected.
DETECTION_THRESHOLD = 0.50
# Analysis window length in seconds (multiplied by the sample rate when used).
WINDOW_SIZE = 5.0
# Step between window starts in seconds; equal to WINDOW_SIZE, so consecutive
# windows do not overlap.
HOP_SIZE = 5.0
# Load every configured checkpoint once at import time. Both registries are
# keyed by the same ids as MODELS.
logger.info("Loading models...")
models = {}
feature_extractors = {}
for model_key, spec in MODELS.items():
    checkpoint = spec["name"]
    logger.info(f"Loading {spec['display_name']}...")
    models[model_key] = ASTForAudioClassification.from_pretrained(checkpoint)
    feature_extractors[model_key] = AutoFeatureExtractor.from_pretrained(checkpoint)
    # The models are only used for inference, so put them in eval mode.
    models[model_key].eval()
logger.info("All models loaded")
def load_audio(file_path: str, target_sr: int):
    """Decode an audio file to mono at ``target_sr`` and return it in two forms.

    Args:
        file_path: Path of the audio file to decode.
        target_sr: Frame rate, in Hz, the audio is converted to.

    Returns:
        A ``(samples, audio)`` tuple: ``samples`` is a float32 NumPy array of
        the sample values divided by the maximum of the integer type reported
        by ``audio.array_type``; ``audio`` is the converted ``AudioSegment``.
    """
    audio = AudioSegment.from_file(file_path).set_channels(1).set_frame_rate(target_sr)
    # Largest positive value of the underlying integer sample type.
    full_scale = np.iinfo(audio.array_type).max
    raw = np.array(audio.get_array_of_samples()).astype(np.float32)
    return raw / full_scale, audio
def detect_music_with_model(audio_array, sample_rate, model_key):
    """Classify consecutive windows of a waveform and collect the music windows.

    The waveform is scanned in windows of ``WINDOW_SIZE`` seconds whose starts
    are ``HOP_SIZE`` seconds apart. A trailing window shorter than one second
    is skipped; a longer but incomplete trailing window is zero-padded to the
    full window length before classification.

    Args:
        audio_array: 1-D NumPy array of samples.
        sample_rate: Sampling rate of ``audio_array`` in Hz.
        model_key: Key into the module-level ``models`` and
            ``feature_extractors`` registries.

    Returns:
        A ``(music_segments, all_predictions)`` tuple. ``music_segments`` is a
        list of ``(start_ms, end_ms, score)`` tuples for windows treated as
        music. ``all_predictions`` holds one dict per classified window with
        the keys ``start``, ``end`` (seconds), ``label``, ``score`` and
        ``is_music``.
    """
    model = models[model_key]
    feature_extractor = feature_extractors[model_key]
    window_samples = int(WINDOW_SIZE * sample_rate)
    hop_samples = int(HOP_SIZE * sample_rate)
    total_samples = len(audio_array)
    music_segments = []
    all_predictions = []
    last_was_music = False

    # Feed the model tensors on its own device and in its own floating-point
    # precision, instead of assuming "CUDA means half precision".
    first_param = next(model.parameters())
    device = first_param.device
    model_dtype = first_param.dtype

    for start in range(0, total_samples, hop_samples):
        end = min(start + window_samples, total_samples)
        segment = audio_array[start:end]
        # Less than one second of audio is not classified at all.
        if len(segment) < sample_rate:
            continue
        needs_padding = len(segment) < window_samples
        if needs_padding:
            segment = np.pad(segment, (0, window_samples - len(segment)), mode='constant')

        inputs = feature_extractor(
            segment,
            sampling_rate=sample_rate,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=1024,
        )
        inputs = {
            name: (
                tensor.to(device=device, dtype=model_dtype)
                if tensor.is_floating_point()
                else tensor.to(device)
            )
            for name, tensor in inputs.items()
        }

        # Inference only: without no_grad() every window would record
        # autograd state that is never used.
        with torch.no_grad():
            outputs = model(**inputs)
        probs = torch.softmax(outputs.logits, dim=-1)
        pred_idx = torch.argmax(probs[0]).item()
        pred_label = model.config.id2label.get(pred_idx, "")
        pred_score = probs[0][pred_idx].item()

        is_music = "music" in pred_label.lower()
        is_uncertain = 0.40 <= pred_score <= 0.60
        start_sec = start / sample_rate
        end_sec = end / sample_rate
        all_predictions.append({
            "start": start_sec,
            "end": end_sec,
            "label": pred_label,
            "score": pred_score,
            "is_music": is_music,
        })

        segment_ms = (int(start_sec * 1000), int(end_sec * 1000), pred_score)
        if is_uncertain and needs_padding:
            # A padded tail with a borderline score inherits the decision of
            # the preceding window; last_was_music is left unchanged.
            if last_was_music:
                music_segments.append(segment_ms)
        elif is_music and pred_score >= DETECTION_THRESHOLD:
            music_segments.append(segment_ms)
            last_was_music = True
        else:
            last_was_music = False

    return music_segments, all_predictions
def merge_segments(segments):
    """Merge overlapping or touching segments into single spans.

    Args:
        segments: Iterable of ``(start, end, score)`` items.

    Returns:
        A new list ordered by start. Items whose start is not after the end of
        the previous span are folded into it, keeping the larger end and the
        larger score. An empty input yields an empty list.
    """
    if not segments:
        return []

    ordered = iter(sorted(segments, key=lambda seg: seg[0]))
    merged = [next(ordered)]
    for candidate in ordered:
        previous = merged[-1]
        touches_previous = candidate[0] <= previous[1]
        if touches_previous:
            merged[-1] = (
                previous[0],
                max(previous[1], candidate[1]),
                max(previous[2], candidate[2]),
            )
        else:
            merged.append(candidate)
    return merged
def remove_music_segments(audio, segments):
    """Return ``audio`` with the given segments cut out.

    Args:
        audio: Sliceable audio object indexed in the same unit as the segment
            bounds (the caller passes an ``AudioSegment`` and millisecond
            bounds).
        segments: ``(start_ms, end_ms, score)`` items, processed in the order
            given.

    Returns:
        ``audio`` itself when ``segments`` is empty; otherwise the kept pieces
        added together, or a zero-length silent ``AudioSegment`` when nothing
        is left.
    """
    if not segments:
        return audio

    kept = []
    cursor = 0  # end of the most recently removed segment
    for seg_start, seg_end, _score in segments:
        if cursor < seg_start:
            kept.append(audio[cursor:seg_start])
        cursor = seg_end

    # Whatever follows the last removed segment is kept as well.
    if cursor < len(audio):
        kept.append(audio[cursor:])

    return sum(kept) if kept else AudioSegment.silent(duration=0)
def calculate_metrics(segments, total_duration_ms):
    """Summarise detected segments.

    Args:
        segments: ``(start_ms, end_ms, score)`` items.
        total_duration_ms: Length of the whole recording in milliseconds.

    Returns:
        A dict with ``total_music_ms`` (summed segment lengths),
        ``segment_count``, ``avg_confidence`` (mean score) and
        ``coverage_percent`` (share of the recording covered, 0 when the
        duration is not positive). All values are 0 when there are no
        segments.
    """
    if not segments:
        return dict.fromkeys(
            ("total_music_ms", "segment_count", "avg_confidence", "coverage_percent"),
            0,
        )

    music_ms = 0
    score_sum = 0
    for seg_start, seg_end, seg_score in segments:
        music_ms += seg_end - seg_start
        score_sum += seg_score

    count = len(segments)
    coverage = 0
    if total_duration_ms > 0:
        coverage = (music_ms / total_duration_ms) * 100

    return {
        "total_music_ms": music_ms,
        "segment_count": count,
        "avg_confidence": score_sum / count,
        "coverage_percent": coverage,
    }
def _relative_gain(value, reference):
    """Return by how many percent ``value`` exceeds ``reference``.

    With a non-positive reference the gain is reported as 100 when ``value``
    is positive and 0 otherwise.
    """
    if reference > 0:
        return (value - reference) / reference * 100
    return 100 if value > 0 else 0


def _format_segment_lines(segments):
    """Return one Markdown bullet per segment, or a placeholder line if empty."""
    if not segments:
        return ["No music detected"]
    return [
        f"- {start_ms/1000:.1f}s - {end_ms/1000:.1f}s ({score:.0%})"
        for start_ms, end_ms, score in segments
    ]


def build_comparison_report(original_dur, ft_segments, base_segments, ft_metrics, base_metrics):
    """Render the Markdown report comparing the fine-tuned and base models.

    Each model earns one point for a strictly higher average confidence and
    one point for strictly more detected segments. The model with more points
    wins; equal points (including the case where neither model detected any
    music) are reported as a tie rather than as a win for either model. The
    confidence figure in the headline is the winner's gain over the loser.

    Args:
        original_dur: Duration of the input audio in seconds. Not shown in the
            report; kept so existing callers keep working.
        ft_segments: ``(start_ms, end_ms, score)`` items of the fine-tuned model.
        base_segments: ``(start_ms, end_ms, score)`` items of the base model.
        ft_metrics: Metrics dict of the fine-tuned model; the keys
            ``total_music_ms``, ``segment_count`` and ``avg_confidence`` are read.
        base_metrics: Metrics dict of the base model, same keys.

    Returns:
        The report as a Markdown string.
    """
    ft_detected = ft_metrics["total_music_ms"] / 1000
    base_detected = base_metrics["total_music_ms"] / 1000
    ft_conf = ft_metrics["avg_confidence"]
    base_conf = base_metrics["avg_confidence"]
    ft_count = ft_metrics["segment_count"]
    base_count = base_metrics["segment_count"]

    # A point is only awarded for a strict advantage, so equal values never
    # favour one model over the other.
    ft_score = int(ft_conf > base_conf) + int(ft_count > base_count)
    base_score = int(base_conf > ft_conf) + int(base_count > ft_count)

    if ft_score > base_score:
        gain = _relative_gain(ft_conf, base_conf)
        headline = f"## Result: **Fine-tuned** model wins! ({gain:+.1f}% confidence)"
    elif base_score > ft_score:
        gain = _relative_gain(base_conf, ft_conf)
        headline = f"## Result: **Base** model wins! ({gain:+.1f}% confidence)"
    else:
        headline = "## Result: **Tie** (neither model is clearly better)"

    # Blank lines around the table and the rule keep the Markdown blocks
    # unambiguous for the renderer.
    lines = [
        "",
        headline,
        "",
        "| Metric | Fine-tuned | Base |",
        "|--------|-----------|------|",
        f"| Segments | **{ft_count}** | {base_count} |",
        f"| Duration | **{ft_detected:.1f}s** | {base_detected:.1f}s |",
        f"| Confidence | **{ft_conf:.0%}** | {base_conf:.0%} |",
        "",
        "---",
        "",
        "**Fine-tuned segments:**",
        *_format_segment_lines(ft_segments),
        "",
        "**Base segments:**",
        *_format_segment_lines(base_segments),
    ]
    return "\n".join(lines) + "\n"
def process_audio_comparison(audio_file, progress=gr.Progress()):
    """Run both models on an uploaded file and return cleaned audio plus a report.

    Args:
        audio_file: Path of the uploaded audio file, or ``None`` when nothing
            was uploaded.
        progress: Gradio progress tracker used to report the current step.

    Returns:
        ``(fine_tuned_wav_path, base_wav_path, report_markdown)``. When no
        file was given, or any step raises, both paths are ``None`` and the
        third item carries the message to display.
    """
    if audio_file is None:
        return None, None, "Please upload an audio file."

    def _export_without_music(source_audio, music_segments):
        # Cut the detected music out and write the remainder to a temporary
        # WAV file that is kept on disk (delete=False) for Gradio to serve.
        cleaned = remove_music_segments(source_audio, music_segments)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as handle:
            cleaned.export(handle.name, format="wav")
            return handle.name

    # (model key, progress fraction, progress message), in execution order.
    analysis_plan = (
        ("fine_tuned", 0.2, "Analyzing with Fine-tuned Model..."),
        ("base", 0.5, "Analyzing with Base Model..."),
    )

    try:
        progress(0.05, desc="Preparing models...")
        # On a CUDA machine the models run on the GPU in half precision.
        if torch.cuda.is_available():
            for loaded_model in models.values():
                loaded_model.to("cuda").half()
            torch.backends.cudnn.benchmark = True

        progress(0.1, desc="Loading audio...")
        sample_rate = feature_extractors["fine_tuned"].sampling_rate
        audio_array, audio = load_audio(audio_file, sample_rate)
        total_duration_ms = len(audio)
        original_duration = total_duration_ms / 1000

        segments_by_model = {}
        metrics_by_model = {}
        for model_key, fraction, message in analysis_plan:
            progress(fraction, desc=message)
            detected, _predictions = detect_music_with_model(audio_array, sample_rate, model_key)
            merged = merge_segments(detected)
            segments_by_model[model_key] = merged
            metrics_by_model[model_key] = calculate_metrics(merged, total_duration_ms)

        progress(0.8, desc="Generating outputs...")
        ft_output_path = _export_without_music(audio, segments_by_model["fine_tuned"])
        base_output_path = _export_without_music(audio, segments_by_model["base"])

        progress(0.95, desc="Building report...")
        report = build_comparison_report(
            original_duration,
            segments_by_model["fine_tuned"],
            segments_by_model["base"],
            metrics_by_model["fine_tuned"],
            metrics_by_model["base"],
        )
        progress(1.0, desc="Done")
        return ft_output_path, base_output_path, report
    except Exception as e:
        # UI boundary: log the traceback and show the message to the user.
        logger.exception("Processing failed")
        return None, None, f"Error: {str(e)}"
# Gradio UI: upload and button on top, both cleaned outputs side by side,
# comparison report and footer below.
with gr.Blocks(title="CleanSpeech - Model Comparison") as demo:
    gr.Markdown("# CleanSpeech - Model Comparison")

    # Input section
    with gr.Row():
        with gr.Column(scale=2):
            audio_input = gr.Audio(type="filepath", label="Upload Audio File")
            process_btn = gr.Button("Compare Models", size="lg", variant="primary")

    # Output section - one column per model
    with gr.Row():
        with gr.Column(scale=1):
            ft_audio_output = gr.Audio(label="Fine-tuned Output")
        with gr.Column(scale=1):
            base_audio_output = gr.Audio(label="Base Model Output")

    # Comparison report
    comparison_report = gr.Markdown(label="Comparison Report")

    # The three outputs match the tuple returned by process_audio_comparison.
    process_btn.click(
        process_audio_comparison,
        inputs=[audio_input],
        outputs=[ft_audio_output, base_audio_output, comparison_report],
    )

    # Footer
    gr.Markdown(
        "\n"
        "---\n"
        "**Models:** [Fine-tuned](https://huggingface.co/Vyvo-Research/AST-Music-Classifier-1K) | [Base](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593)\n"
    )

demo.queue()
demo.launch(theme=gr.themes.Soft())