Spaces:
Running on Zero
Running on Zero
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| from transformers import ASTForAudioClassification, AutoFeatureExtractor | |
| from pydub import AudioSegment | |
| import tempfile | |
| import logging | |
| from datetime import datetime | |
| from typing import Tuple, List, Optional | |
| import space | |
# Module-level logger for this file. basicConfig configures the root logger
# (a no-op if it already has handlers) to emit records at INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class MusicRemover:
    """Detect music in an audio file with an AST classifier and cut it out.

    The audio is scanned in fixed-size windows. Every window whose top
    predicted label contains "music" with enough confidence is removed, and
    the remaining audio is stitched back together and exported.
    """

    # Keyword arguments passed to pydub's ``AudioSegment.export`` per format.
    _EXPORT_SETTINGS = {
        "wav": {"format": "wav"},
        "mp3": {"format": "mp3", "bitrate": "192k"},
        "ogg": {"format": "ogg", "codec": "libvorbis"},
    }

    # Top-label probability band treated as "uncertain" for short (padded) windows.
    _UNCERTAIN_LOW = 0.40
    _UNCERTAIN_HIGH = 0.60

    def __init__(self, model_name: str = "Vyvo-Research/AST-Music-Classifier-1K"):
        """Load the classifier and its feature extractor.

        Args:
            model_name: Hugging Face model id of an AST audio-classification
                checkpoint.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing on {self.device}")
        self.model = ASTForAudioClassification.from_pretrained(model_name).to(self.device)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
        self.model.eval()
        if self.device.type == "cuda":
            # Run in half precision on GPU; _predict casts inputs to match.
            self.model = self.model.half()
            torch.backends.cudnn.benchmark = True
        logger.info("Model loaded successfully")

    def load_audio(self, file_path: str):
        """Decode an audio file and prepare a copy for the classifier.

        Args:
            file_path: Path to a file pydub can decode.

        Returns:
            ``(samples, sample_rate, audio)``: ``samples`` is a mono float32
            array at the feature extractor's sampling rate, divided by the
            sample type's maximum value; ``sample_rate`` is that rate; ``audio``
            is the decoded file at its original channel count and frame rate,
            so cutting/exporting keeps the source quality.
        """
        original = AudioSegment.from_file(file_path)
        sample_rate = self.feature_extractor.sampling_rate
        # The classifier needs mono audio at the extractor's rate; derive that
        # copy for analysis only and leave ``original`` untouched for output.
        analysis = original.set_channels(1).set_frame_rate(sample_rate)
        samples = np.array(analysis.get_array_of_samples()).astype(np.float32)
        samples = samples / np.iinfo(analysis.array_type).max
        return samples, sample_rate, original

    def _predict(self, segment, sample_rate):
        """Return ``(label, probability)`` of the top class for one window."""
        inputs = self.feature_extractor(
            segment,
            sampling_rate=sample_rate,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=1024,
        )
        if self.device.type == "cuda":
            inputs = {k: v.to(self.device).half() for k, v in inputs.items()}
        else:
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
        # Pure inference: disable autograd so no graph/activations are retained.
        with torch.inference_mode():
            outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=-1)[0]
            pred_idx = torch.argmax(probs).item()
            pred_score = probs[pred_idx].item()
        pred_label = self.model.config.id2label.get(pred_idx, f'idx{pred_idx}')
        return pred_label, pred_score

    def detect_music_segments(self, audio_array, sample_rate, threshold, window_size, hop_size):
        """Classify sliding windows of ``audio_array`` and collect the music ones.

        Args:
            audio_array: 1-D sample array (as returned by ``load_audio``).
            sample_rate: Sampling rate of ``audio_array`` in Hz.
            threshold: Minimum top-label probability for a "music" window.
            window_size: Window length in seconds.
            hop_size: Distance between window starts in seconds.

        Returns:
            List of ``(start_ms, end_ms, score)`` tuples in window order.

        Raises:
            ValueError: If ``window_size`` or ``hop_size`` is shorter than one sample.
        """
        window_samples = int(window_size * sample_rate)
        hop_samples = int(hop_size * sample_rate)
        # A zero hop would make range() fail with an opaque error and a zero
        # window would silently skip every chunk; fail loudly instead.
        if window_samples <= 0 or hop_samples <= 0:
            raise ValueError(
                "window_size and hop_size must each cover at least one sample "
                f"(got window_size={window_size}, hop_size={hop_size})"
            )
        music_segments = []
        total_samples = len(audio_array)
        total_duration = total_samples / sample_rate
        logger.info(f"Audio: {total_duration:.1f}s, Window: {window_size}s, Hop: {hop_size}s")
        logger.info(f"Total samples: {total_samples}, Window samples: {window_samples}, Hop samples: {hop_samples}")
        segment_count = 0
        last_was_music = False
        for start in range(0, total_samples, hop_samples):
            end = min(start + window_samples, total_samples)
            segment = audio_array[start:end]
            segment_duration = len(segment) / sample_rate
            # Skip very short chunks (under one second).
            if len(segment) < sample_rate:
                logger.info(f"Skipping final segment (too short): {segment_duration:.2f}s")
                continue
            segment_count += 1
            start_sec = start / sample_rate
            end_sec = end / sample_rate
            # Zero-pad a short trailing chunk up to the full window length.
            needs_padding = len(segment) < window_samples
            if needs_padding:
                segment = np.pad(segment, (0, window_samples - len(segment)), mode='constant')
                logger.info(f"Processing segment {segment_count}: {start_sec:.1f}s - {end_sec:.1f}s (padded)")
            else:
                logger.info(f"Processing segment {segment_count}: {start_sec:.1f}s - {end_sec:.1f}s")
            pred_label, pred_score = self._predict(segment, sample_rate)
            logger.info(f" -> Prediction: {pred_label} ({pred_score:.2%})")
            # Music = the top label mentions "music" AND confidence meets the threshold.
            is_music = 'music' in pred_label.lower()
            is_uncertain = self._UNCERTAIN_LOW <= pred_score <= self._UNCERTAIN_HIGH
            if is_uncertain and needs_padding:
                # Short padded chunk with an uncertain verdict: reuse the previous decision.
                if last_was_music:
                    start_ms = int(start_sec * 1000)
                    end_ms = int(end_sec * 1000)
                    music_segments.append((start_ms, end_ms, pred_score))
                    logger.info(f" -> MUSIC (uncertain {pred_score:.0%}, using previous)")
                else:
                    logger.info(f" -> SPEECH (uncertain {pred_score:.0%}, using previous)")
            elif is_music and pred_score >= threshold:
                start_ms = int(start_sec * 1000)
                end_ms = int(end_sec * 1000)
                music_segments.append((start_ms, end_ms, pred_score))
                last_was_music = True
                logger.info(" -> MUSIC DETECTED!")
            else:
                last_was_music = False
                if is_music:
                    logger.info(f" -> Low confidence music ({pred_score:.1%} < {threshold:.0%}), treating as speech")
        logger.info(f"Processed {segment_count} segments, found {len(music_segments)} music segments")
        return music_segments

    def merge_overlapping_segments(self, segments):
        """Merge touching/overlapping ``(start_ms, end_ms, score)`` ranges.

        Ranges are sorted by start; a range starting at or before the previous
        merged end is folded into it, keeping the larger end and score.
        """
        if not segments:
            return []
        segments = sorted(segments, key=lambda x: x[0])
        merged = [segments[0]]
        for current in segments[1:]:
            last = merged[-1]
            if current[0] <= last[1]:
                merged[-1] = (
                    last[0],
                    max(last[1], current[1]),
                    max(last[2], current[2]),
                )
            else:
                merged.append(current)
        return merged

    def remove_music(self, audio, music_segments):
        """Cut the given music ranges out of ``audio``.

        Args:
            audio: Audio to cut (a pydub ``AudioSegment``: ``len`` is in
                milliseconds, slicing is by milliseconds, ``+`` concatenates).
            music_segments: ``(start_ms, end_ms, score)`` tuples sorted by
                start; overlapping or nested ranges are tolerated.

        Returns:
            ``(clean_audio, kept_ranges)`` where ``kept_ranges`` lists the
            retained ``(start_s, end_s)`` spans of the original timeline.
        """
        total_ms = len(audio)
        if not music_segments:
            return audio, [(0, total_ms / 1000)]
        clean_segments = []
        kept_ranges = []
        last_end = 0
        for start_ms, end_ms, _ in music_segments:
            start_ms = min(start_ms, total_ms)
            if start_ms > last_end:
                clean_segments.append(audio[last_end:start_ms])
                kept_ranges.append((last_end / 1000, start_ms / 1000))
            # max(): a nested/overlapping range must never move the cursor back,
            # otherwise already-removed music would be re-included.
            last_end = max(last_end, end_ms)
        if last_end < total_ms:
            clean_segments.append(audio[last_end:])
            kept_ranges.append((last_end / 1000, total_ms / 1000))
        if not clean_segments:
            # Empty slice of the source keeps its frame rate/channels/sample width.
            return audio[:0], []
        clean_audio = clean_segments[0]
        for piece in clean_segments[1:]:
            clean_audio = clean_audio + piece
        return clean_audio, kept_ranges

    @staticmethod
    def _build_report(original_duration, clean_duration, music_segments, output_format):
        """Render the Markdown processing report shown in the UI."""
        removed_duration = original_duration - clean_duration
        # Empty input has zero duration; avoid ZeroDivisionError in the percentage.
        if original_duration > 0:
            removed_pct = (removed_duration / original_duration) * 100
        else:
            removed_pct = 0.0
        if music_segments:
            rows = [
                "",
                "### 🎵 Detected Music Segments:",
                "| # | Start | End | Confidence |",
                "|---|-------|-----|------------|",
            ]
            for i, (start_ms, end_ms, score) in enumerate(music_segments, 1):
                confidence = "🟢 High" if score > 0.7 else "🟡 Medium" if score > 0.5 else "🟠 Low"
                rows.append(f"| {i} | {start_ms/1000:.1f}s | {end_ms/1000:.1f}s | {score:.0%} {confidence} |")
            segments_detail = "\n".join(rows) + "\n"
        else:
            segments_detail = "\n### ✅ No music detected!\n"
        return "\n".join([
            "",
            "## 📊 Processing Report",
            "| Metric | Value |",
            "|--------|-------|",
            f"| Original Duration | {original_duration:.2f}s |",
            f"| Clean Duration | {clean_duration:.2f}s |",
            f"| Removed Duration | {removed_duration:.2f}s ({removed_pct:.1f}%) |",
            f"| Music Segments | {len(music_segments)} |",
            f"| Output Format | {output_format.upper()} |",
            segments_detail,
            "",
        ])

    def process(self, input_file, output_format="wav", threshold=0.50, window_size=5.0, hop_size=5.0, progress=None):
        """Run the full pipeline: load, detect, merge, cut, export and report.

        Args:
            input_file: Path of the audio file to clean.
            output_format: "wav", "mp3" or "ogg"; anything else falls back to WAV.
            threshold: Minimum confidence for a window to count as music.
            window_size: Analysis window length in seconds.
            hop_size: Step between analysis windows in seconds.
            progress: Optional callable ``progress(fraction, desc=...)``
                (e.g. a ``gr.Progress`` instance).

        Returns:
            ``(output_path, markdown_report)`` on success, or
            ``(None, "Error: ...")`` if any step raised.
        """
        try:
            if progress:
                progress(0, desc="Loading audio...")
            audio_array, sample_rate, audio = self.load_audio(input_file)
            original_duration = len(audio) / 1000
            if progress:
                progress(0.2, desc="Detecting music...")
            music_segments = self.detect_music_segments(
                audio_array, sample_rate, threshold, window_size, hop_size
            )
            if progress:
                progress(0.6, desc="Processing...")
            music_segments = self.merge_overlapping_segments(music_segments)
            if progress:
                progress(0.8, desc="Removing music...")
            clean_audio, _ = self.remove_music(audio, music_segments)
            clean_duration = len(clean_audio) / 1000
            if progress:
                progress(0.9, desc="Saving...")
            # Use one format for both the export settings and the file suffix.
            export_format = output_format if output_format in self._EXPORT_SETTINGS else "wav"
            settings = self._EXPORT_SETTINGS[export_format]
            # Only reserve the path here; export after the handle is closed,
            # since reopening an open NamedTemporaryFile by name fails on Windows.
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{export_format}") as tmp_file:
                output_path = tmp_file.name
            # pydub's export returns the open output file object; close it now.
            clean_audio.export(output_path, **settings).close()
            if progress:
                progress(1.0, desc="Complete!")
            report = self._build_report(original_duration, clean_duration, music_segments, export_format)
            logger.info(f"Complete: {original_duration:.1f}s -> {clean_duration:.1f}s")
            return output_path, report
        except Exception as e:
            # UI boundary: keep the full traceback in the logs, show a short message.
            logger.exception(f"Failed: {str(e)}")
            return None, f"Error: {str(e)}"
# Announce startup, then build the single shared model wrapper at import time;
# every request handled by process_audio reuses this instance.
logger.info("Starting CleanSpeech AI...")
remover = MusicRemover()
def process_audio(audio_file, output_format, progress=gr.Progress()):
    """Gradio click handler: validate the upload, then delegate to ``remover``.

    Args:
        audio_file: Filepath from the ``gr.Audio(type="filepath")`` input, or
            None when nothing was uploaded.
        output_format: Format chosen in the dropdown ("wav", "mp3" or "ogg").
        progress: Progress tracker that Gradio supplies to the handler.

    Returns:
        ``(output_path, markdown_report)`` from ``remover.process``, or
        ``(None, message)`` when no file was uploaded.
    """
    # NOTE(review): the page header says this Space runs on ZeroGPU, where GPU
    # work is normally wrapped in `@spaces.GPU`; the file imports `space` (not
    # `spaces`) and applies no decorator — confirm the intended setup.
    if audio_file is not None:
        return remover.process(audio_file, output_format, progress=progress)
    return None, "Please upload an audio file."
# Header shown at the top of the page.
_HEADER_MD = """
# 🎯 CleanSpeech AI
### Remove Background Music from Audio
Upload your audio file and automatically detect and remove background music.
"""

# Two-column layout: upload + format picker on the left, results on the right.
with gr.Blocks(title="CleanSpeech AI") as demo:
    gr.Markdown(_HEADER_MD)
    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(label="🎤 Upload Audio", type="filepath")
            output_format = gr.Dropdown(
                choices=["wav", "mp3", "ogg"],
                value="wav",
                label="📁 Output Format",
            )
            process_btn = gr.Button("🚀 Remove Music", variant="primary", size="lg")
        with gr.Column(scale=1):
            audio_output = gr.Audio(label="🔊 Cleaned Audio")
            report_output = gr.Markdown()
    # Event listeners must be registered inside the Blocks context.
    process_btn.click(
        fn=process_audio,
        inputs=[audio_input, output_format],
        outputs=[audio_output, report_output],
    )

# Blocks.queue() returns the Blocks itself, so enabling the queue and
# launching the server can be chained.
demo.queue().launch()