import gradio as gr
import torch
import torchaudio
import numpy as np
from transformers import ASTForAudioClassification, AutoFeatureExtractor
from pydub import AudioSegment
import tempfile
import logging
from datetime import datetime
from typing import Tuple, List, Optional
# NOTE(review): SOURCE imported `space` but referenced `spaces.GPU()` below,
# which raises NameError at class-definition time. `spaces` is presumably the
# Hugging Face Spaces helper package -- confirm it is listed in requirements.
import spaces

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MusicRemover:
    """Detect music in an audio file with an AST classifier and cut it out.

    The audio is scanned in fixed-size windows; each window is classified and
    windows labelled as music are removed from the pydub ``AudioSegment``.
    """

    def __init__(self, model_name: str = "Vyvo-Research/AST-Music-Classifier-1K"):
        """Load the classifier and its feature extractor.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                model and the feature extractor.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Initializing on {self.device}")

        self.model = ASTForAudioClassification.from_pretrained(model_name).to(self.device)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
        self.model.eval()

        if self.device.type == "cuda":
            # Half precision on GPU; inputs are cast to half to match in
            # detect_music_segments().
            self.model = self.model.half()
            torch.backends.cudnn.benchmark = True

        logger.info("Model loaded successfully")

    def load_audio(self, file_path: str):
        """Load a file as mono audio at the feature extractor's sampling rate.

        Args:
            file_path: Path of the audio file to read.

        Returns:
            ``(samples, sample_rate, audio)`` where ``samples`` is a float32
            numpy array scaled by the maximum of the integer sample type,
            ``sample_rate`` is the feature extractor's sampling rate and
            ``audio`` is the converted pydub ``AudioSegment``.
        """
        sample_rate = self.feature_extractor.sampling_rate
        audio = AudioSegment.from_file(file_path)
        audio = audio.set_channels(1)
        audio = audio.set_frame_rate(sample_rate)

        samples = np.array(audio.get_array_of_samples()).astype(np.float32)
        samples = samples / np.iinfo(audio.array_type).max
        return samples, sample_rate, audio

    @torch.no_grad()
    @spaces.GPU()
    def detect_music_segments(self, audio_array, sample_rate, threshold, window_size, hop_size):
        """Classify the audio window by window and collect the music windows.

        Args:
            audio_array: 1-D array of samples (as returned by ``load_audio``).
            sample_rate: Samples per second of ``audio_array``.
            threshold: Minimum score of a "music" prediction for the window to
                be reported as music.
            window_size: Window length in seconds.
            hop_size: Distance between window starts in seconds.

        Returns:
            List of ``(start_ms, end_ms, score)`` tuples, one per window
            treated as music, in scan order (not merged).

        Raises:
            ValueError: If the window or hop is not positive once converted
                to samples.
        """
        window_samples = int(window_size * sample_rate)
        hop_samples = int(hop_size * sample_rate)
        if window_samples <= 0 or hop_samples <= 0:
            # range() with a zero step raises an opaque error; fail clearly.
            raise ValueError(
                f"window_size and hop_size must be positive (got {window_size}, {hop_size})"
            )

        music_segments = []
        total_samples = len(audio_array)
        total_duration = total_samples / sample_rate

        logger.info(f"Audio: {total_duration:.1f}s, Window: {window_size}s, Hop: {hop_size}s")
        logger.info(
            f"Total samples: {total_samples}, Window samples: {window_samples}, "
            f"Hop samples: {hop_samples}"
        )

        # The label map does not change between windows; look it up once.
        labels = self.model.config.id2label
        on_cuda = self.device.type == "cuda"

        segment_count = 0
        last_was_music = False

        for start in range(0, total_samples, hop_samples):
            end = min(start + window_samples, total_samples)
            segment = audio_array[start:end]
            segment_duration = len(segment) / sample_rate

            # Skip very short segments (less than 1 second). Every later
            # window starts further right and is at least as short, so the
            # scan can stop here.
            if len(segment) < sample_rate:
                logger.info(f"Skipping final segment (too short): {segment_duration:.2f}s")
                break

            segment_count += 1
            start_sec = start / sample_rate
            end_sec = end / sample_rate
            start_ms = int(start_sec * 1000)
            end_ms = int(end_sec * 1000)

            # Zero-pad short segments up to the full window length.
            needs_padding = len(segment) < window_samples
            if needs_padding:
                segment = np.pad(segment, (0, window_samples - len(segment)), mode='constant')
                logger.info(
                    f"Processing segment {segment_count}: {start_sec:.1f}s - {end_sec:.1f}s (padded)"
                )
            else:
                logger.info(f"Processing segment {segment_count}: {start_sec:.1f}s - {end_sec:.1f}s")

            inputs = self.feature_extractor(
                segment,
                sampling_rate=sample_rate,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=1024
            )

            if on_cuda:
                # The model was converted to half precision in __init__.
                inputs = {k: v.to(self.device).half() for k, v in inputs.items()}
            else:
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

            outputs = self.model(**inputs)
            probs = torch.softmax(outputs.logits, dim=-1)

            # Pick the highest-scoring label (argmax).
            pred_idx = torch.argmax(probs[0]).item()
            pred_label = labels.get(pred_idx, f'idx{pred_idx}')
            pred_score = probs[0][pred_idx].item()

            logger.info(f"  -> Prediction: {pred_label} ({pred_score:.2%})")

            # A window counts as music when the predicted label contains
            # "music" and the confidence reaches the threshold.
            is_music = 'music' in pred_label.lower()

            # Ambiguous result check (score between 40% and 60%).
            is_uncertain = 0.40 <= pred_score <= 0.60

            if is_uncertain and needs_padding:
                # Padded segment + ambiguous score: reuse the previous
                # window's decision (last_was_music is left unchanged).
                if last_was_music:
                    music_segments.append((start_ms, end_ms, pred_score))
                    logger.info(f"  -> MUSIC (uncertain {pred_score:.0%}, using previous)")
                else:
                    logger.info(f"  -> SPEECH (uncertain {pred_score:.0%}, using previous)")
            elif is_music and pred_score >= threshold:
                music_segments.append((start_ms, end_ms, pred_score))
                last_was_music = True
                logger.info("  -> MUSIC DETECTED!")
            else:
                last_was_music = False
                if is_music:
                    logger.info(
                        f"  -> Low confidence music ({pred_score:.1%} < {threshold:.0%}), "
                        f"treating as speech"
                    )

        logger.info(f"Processed {segment_count} segments, found {len(music_segments)} music segments")
        return music_segments

    def merge_overlapping_segments(self, segments):
        """Merge segments that overlap or touch.

        Args:
            segments: Iterable of ``(start_ms, end_ms, score)`` tuples.

        Returns:
            New list sorted by start; merged entries keep the earliest start,
            the latest end and the highest score. Empty input gives ``[]``.
        """
        if not segments:
            return []

        ordered = sorted(segments, key=lambda x: x[0])
        merged = [ordered[0]]

        for current in ordered[1:]:
            last = merged[-1]
            if current[0] <= last[1]:
                merged[-1] = (
                    last[0],
                    max(last[1], current[1]),
                    max(last[2], current[2])
                )
            else:
                merged.append(current)

        return merged

    def remove_music(self, audio, music_segments):
        """Cut the given segments out of ``audio``.

        Args:
            audio: pydub ``AudioSegment`` (sliced in milliseconds).
            music_segments: ``(start_ms, end_ms, score)`` tuples, expected
                sorted by start (as produced by ``merge_overlapping_segments``).

        Returns:
            ``(clean_audio, kept_ranges)`` where ``kept_ranges`` lists the
            retained ``(start_s, end_s)`` spans in seconds. When nothing is
            left, a zero-length silent segment and ``[]`` are returned.
        """
        if not music_segments:
            return audio, [(0, len(audio) / 1000)]

        clean_segments = []
        kept_ranges = []
        last_end = 0

        for start_ms, end_ms, _ in music_segments:
            if start_ms > last_end:
                clean_segments.append(audio[last_end:start_ms])
                kept_ranges.append((last_end / 1000, start_ms / 1000))
            # Never move backwards: a segment nested inside an earlier one
            # must not re-expose audio that was already cut.
            last_end = max(last_end, end_ms)

        if last_end < len(audio):
            clean_segments.append(audio[last_end:])
            kept_ranges.append((last_end / 1000, len(audio) / 1000))

        if not clean_segments:
            return AudioSegment.silent(duration=0), []

        return sum(clean_segments), kept_ranges

    def process(self, input_file, output_format="wav", threshold=0.50, window_size=5.0,
                hop_size=5.0, progress=None):
        """Run the full pipeline: load, detect, merge, cut, export, report.

        Args:
            input_file: Path of the audio file to clean.
            output_format: ``"wav"``, ``"mp3"`` or ``"ogg"``; any other value
                falls back to WAV.
            threshold: Minimum "music" score, forwarded to
                ``detect_music_segments``.
            window_size: Window length in seconds.
            hop_size: Hop between windows in seconds.
            progress: Optional callable ``progress(fraction, desc=...)``.

        Returns:
            ``(output_path, report_markdown)`` on success, or
            ``(None, "Error: ...")`` if any step raised.
        """
        try:
            if progress:
                progress(0, desc="Loading audio...")

            audio_array, sample_rate, audio = self.load_audio(input_file)
            original_duration = len(audio) / 1000

            if progress:
                progress(0.2, desc="Detecting music...")

            music_segments = self.detect_music_segments(
                audio_array, sample_rate, threshold, window_size, hop_size
            )

            if progress:
                progress(0.6, desc="Processing...")

            music_segments = self.merge_overlapping_segments(music_segments)

            if progress:
                progress(0.8, desc="Removing music...")

            clean_audio, kept_ranges = self.remove_music(audio, music_segments)
            clean_duration = len(clean_audio) / 1000
            removed_duration = original_duration - clean_duration
            # Guard against zero-length input, which would divide by zero.
            if original_duration > 0:
                removed_pct = (removed_duration / original_duration) * 100
            else:
                removed_pct = 0.0

            if progress:
                progress(0.9, desc="Saving...")

            format_settings = {
                "wav": {"format": "wav"},
                "mp3": {"format": "mp3", "bitrate": "192k"},
                "ogg": {"format": "ogg", "codec": "libvorbis"}
            }
            # Resolve the effective format once so the file suffix and the
            # report agree with the encoding actually used.
            effective_format = output_format if output_format in format_settings else "wav"
            settings = format_settings[effective_format]

            # Only reserve the path here; close the handle before pydub
            # writes to it so the file is not opened twice.
            with tempfile.NamedTemporaryFile(delete=False, suffix=f".{effective_format}") as tmp_file:
                output_path = tmp_file.name
            clean_audio.export(output_path, **settings)

            if progress:
                progress(1.0, desc="Complete!")

            segments_detail = ""
            if music_segments:
                segments_detail = (
                    "\n### 🎵 Detected Music Segments:\n"
                    "| # | Start | End | Confidence |\n"
                    "|---|-------|-----|------------|\n"
                )
                for i, (start_ms, end_ms, score) in enumerate(music_segments, 1):
                    confidence = "🟢 High" if score > 0.7 else "🟡 Medium" if score > 0.5 else "🟠 Low"
                    segments_detail += (
                        f"| {i} | {start_ms/1000:.1f}s | {end_ms/1000:.1f}s | "
                        f"{score:.0%} {confidence} |\n"
                    )
            else:
                segments_detail = "\n### ✅ No music detected!\n"

            # Built line by line so the Markdown table keeps one row per line
            # and carries no source-code indentation.
            report = "\n".join([
                "",
                "## 📊 Processing Report",
                "",
                "| Metric | Value |",
                "|--------|-------|",
                f"| Original Duration | {original_duration:.2f}s |",
                f"| Clean Duration | {clean_duration:.2f}s |",
                f"| Removed Duration | {removed_duration:.2f}s ({removed_pct:.1f}%) |",
                f"| Music Segments | {len(music_segments)} |",
                f"| Output Format | {effective_format.upper()} |",
                "",
                segments_detail,
            ])

            logger.info(f"Complete: {original_duration:.1f}s -> {clean_duration:.1f}s")
            return output_path, report

        except Exception as e:
            # Deliberate catch-all: the UI shows the message instead of
            # crashing. logger.exception keeps the traceback in the log.
            logger.exception(f"Failed: {str(e)}")
            return None, f"Error: {str(e)}"


logger.info("Starting CleanSpeech AI...")
remover = MusicRemover()


def process_audio(audio_file, output_format, progress=gr.Progress()):
    """Gradio click handler: validate the upload and run the shared remover.

    Args:
        audio_file: File path supplied by the audio input, or ``None``.
        output_format: Value of the output-format dropdown.
        progress: Gradio progress tracker, forwarded to ``MusicRemover.process``.

    Returns:
        ``(output_path_or_None, markdown_message)``.
    """
    if audio_file is None:
        return None, "Please upload an audio file."
    return remover.process(audio_file, output_format, progress=progress)


with gr.Blocks(title="CleanSpeech AI") as demo:
    gr.Markdown(
        "\n"
        "# 🎯 CleanSpeech AI\n"
        "### Remove Background Music from Audio\n"
        "\n"
        "Upload your audio file and automatically detect and remove background music.\n"
    )

    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(label="🎤 Upload Audio", type="filepath")
            output_format = gr.Dropdown(
                choices=["wav", "mp3", "ogg"],
                value="wav",
                label="📁 Output Format"
            )
            process_btn = gr.Button("🚀 Remove Music", variant="primary", size="lg")

        with gr.Column(scale=1):
            audio_output = gr.Audio(label="🔊 Cleaned Audio")
            report_output = gr.Markdown()

    process_btn.click(
        fn=process_audio,
        inputs=[audio_input, output_format],
        outputs=[audio_output, report_output]
    )

demo.queue()
demo.launch()