import spaces
import gradio as gr
import torch
import torchaudio
import tempfile
import warnings
import os

warnings.filterwarnings("ignore")

from sam_audio import SAMAudio, SAMAudioProcessor

# Available models
MODELS = {
    "sam-audio-small": "facebook/sam-audio-small",
    "sam-audio-base": "facebook/sam-audio-base",
    "sam-audio-large": "facebook/sam-audio-large",
    "sam-audio-small-tv (Visual)": "facebook/sam-audio-small-tv",
    "sam-audio-base-tv (Visual)": "facebook/sam-audio-base-tv",
    "sam-audio-large-tv (Visual)": "facebook/sam-audio-large-tv",
}
DEFAULT_MODEL = "sam-audio-small"

EXAMPLES_DIR = "examples"
EXAMPLE_FILE = os.path.join(EXAMPLES_DIR, "office.mp4")

# Chunk processing settings
DEFAULT_CHUNK_DURATION = 30  # seconds per chunk
OVERLAP_DURATION = 2  # seconds of overlap between chunks
MAX_DURATION_WITHOUT_CHUNKING = 60  # auto-chunk if longer than this

# Global model cache
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
current_model_name = None
model = None
processor = None


def load_model(model_name):
    global current_model_name, model, processor
    model_id = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
    if current_model_name == model_name and model is not None:
        return
    print(f"Loading {model_id}...")
    model = SAMAudio.from_pretrained(model_id).to(device).eval()
    processor = SAMAudioProcessor.from_pretrained(model_id)
    current_model_name = model_name
    print(f"Model {model_id} loaded on {device}.")


load_model(DEFAULT_MODEL)


def load_audio(file_path):
    """Load audio from file (supports both audio and video files)."""
    waveform, sample_rate = torchaudio.load(file_path)
    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    return waveform, sample_rate


def split_audio_into_chunks(waveform, sample_rate, chunk_duration, overlap_duration):
    """Split audio waveform into overlapping chunks."""
    chunk_samples = int(chunk_duration * sample_rate)
    overlap_samples = int(overlap_duration * sample_rate)
    stride = chunk_samples - overlap_samples

    chunks = []
    total_samples = waveform.shape[1]
    start = 0
    while start < total_samples:
        end = min(start + chunk_samples, total_samples)
        chunk = waveform[:, start:end]
        chunks.append(chunk)
        if end >= total_samples:
            break
        start += stride
    return chunks


def merge_chunks_with_crossfade(chunks, sample_rate, overlap_duration):
    """Merge audio chunks with crossfade on overlapping regions."""
    if len(chunks) == 1:
        chunk = chunks[0]
        # Ensure 2D tensor
        if chunk.dim() == 1:
            chunk = chunk.unsqueeze(0)
        return chunk

    overlap_samples = int(overlap_duration * sample_rate)

    # Ensure all chunks are 2D [channels, samples]
    processed_chunks = []
    for chunk in chunks:
        if chunk.dim() == 1:
            chunk = chunk.unsqueeze(0)
        processed_chunks.append(chunk)

    result = processed_chunks[0]
    for i in range(1, len(processed_chunks)):
        prev_chunk = result
        next_chunk = processed_chunks[i]

        # Handle case where chunks are shorter than overlap
        actual_overlap = min(overlap_samples, prev_chunk.shape[1], next_chunk.shape[1])
        if actual_overlap <= 0:
            # No overlap possible, just concatenate
            result = torch.cat([prev_chunk, next_chunk], dim=1)
            continue

        # Create fade curves
        fade_out = torch.linspace(1.0, 0.0, actual_overlap).to(prev_chunk.device)
        fade_in = torch.linspace(0.0, 1.0, actual_overlap).to(next_chunk.device)

        # Get overlapping regions
        prev_overlap = prev_chunk[:, -actual_overlap:]
        next_overlap = next_chunk[:, :actual_overlap]

        # Crossfade mix
        crossfaded = prev_overlap * fade_out + next_overlap * fade_in

        # Concatenate: non-overlap of prev + crossfaded + non-overlap of next
        result = torch.cat(
            [prev_chunk[:, :-actual_overlap], crossfaded, next_chunk[:, actual_overlap:]],
            dim=1,
        )
    return result


def save_audio(tensor, sample_rate):
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        torchaudio.save(tmp.name, tensor, sample_rate)
        return tmp.name


@spaces.GPU(duration=300)
def separate_audio(model_name, file_path, text_prompt, chunk_duration=DEFAULT_CHUNK_DURATION, progress=gr.Progress()):
    global model, processor

    progress(0.05, desc="Checking inputs...")
    if not file_path:
        return None, None, "❌ Please upload an audio or video file."
    if not text_prompt or not text_prompt.strip():
        return None, None, "❌ Please enter a text prompt."

    try:
        progress(0.1, desc="Loading model...")
        load_model(model_name)

        progress(0.15, desc="Loading audio...")
        waveform, sample_rate = load_audio(file_path)
        duration = waveform.shape[1] / sample_rate

        # Decide whether to use chunking
        use_chunking = duration > MAX_DURATION_WITHOUT_CHUNKING

        if use_chunking:
            progress(0.2, desc=f"Audio is {duration:.1f}s, splitting into chunks...")
            chunks = split_audio_into_chunks(waveform, sample_rate, chunk_duration, OVERLAP_DURATION)
            num_chunks = len(chunks)

            target_chunks = []
            residual_chunks = []
            for i, chunk in enumerate(chunks):
                chunk_progress = 0.2 + (i / num_chunks) * 0.6
                progress(chunk_progress, desc=f"Processing chunk {i+1}/{num_chunks}...")

                # Save chunk to temp file for processor
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    torchaudio.save(tmp.name, chunk, sample_rate)
                    chunk_path = tmp.name

                try:
                    inputs = processor(audios=[chunk_path], descriptions=[text_prompt.strip()]).to(device)
                    with torch.inference_mode():
                        result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
                    target_chunks.append(result.target[0].cpu())
                    residual_chunks.append(result.residual[0].cpu())
                finally:
                    os.unlink(chunk_path)

            progress(0.85, desc="Merging chunks...")
            # The separated chunks come back at the processor's sampling rate
            # (as in the non-chunked path below), which may differ from the
            # input file's rate, so merge and save at that rate.
            output_sample_rate = processor.audio_sampling_rate
            target_merged = merge_chunks_with_crossfade(target_chunks, output_sample_rate, OVERLAP_DURATION)
            residual_merged = merge_chunks_with_crossfade(residual_chunks, output_sample_rate, OVERLAP_DURATION)

            progress(0.95, desc="Saving results...")
            # merged tensors are already 2D [channels, samples]
            target_path = save_audio(target_merged, output_sample_rate)
            residual_path = save_audio(residual_merged, output_sample_rate)

            progress(1.0, desc="Done!")
            return target_path, residual_path, f"✅ Isolated '{text_prompt}' using {model_name} ({num_chunks} chunks)"
        else:
            # Process without chunking
            progress(0.3, desc="Processing audio...")
            inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)

            progress(0.6, desc="Separating sounds...")
            with torch.inference_mode():
                result = model.separate(inputs, predict_spans=False, reranking_candidates=1)

            progress(0.9, desc="Saving results...")
            sample_rate = processor.audio_sampling_rate
            target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
            residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)

            progress(1.0, desc="Done!")
            return target_path, residual_path, f"✅ Isolated '{text_prompt}' using {model_name}"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, None, f"❌ Error: {str(e)}"


# Build interface
with gr.Blocks(title="SAM-Audio Demo") as demo:
    gr.Markdown(
        """
        # 🎵 SAM-Audio: Segment Anything for Audio

        Isolate specific sounds from audio or video using natural language prompts.
""" ) with gr.Row(): with gr.Column(scale=1): model_selector = gr.Dropdown( choices=list(MODELS.keys()), value=DEFAULT_MODEL, label="Model" ) with gr.Accordion("⚙️ Advanced Options", open=False): chunk_duration_slider = gr.Slider( minimum=10, maximum=60, value=DEFAULT_CHUNK_DURATION, step=5, label="Chunk Duration (seconds)", info=f"Audio longer than {MAX_DURATION_WITHOUT_CHUNKING}s will be automatically split" ) gr.Markdown("#### Upload Audio") input_audio = gr.Audio(label="Audio File", type="filepath") gr.Markdown("#### Or Upload Video") input_video = gr.Video(label="Video File") text_prompt = gr.Textbox( label="Text Prompt", placeholder="e.g., 'A man speaking', 'Piano', 'Dog barking'" ) run_btn = gr.Button("🎯 Isolate Sound", variant="primary", size="lg") status_output = gr.Markdown("") with gr.Column(scale=1): gr.Markdown("### Results") output_target = gr.Audio(label="Isolated Sound (Target)") output_residual = gr.Audio(label="Background (Residual)") gr.Markdown("---") gr.Markdown("### 🎬 Demo Examples") gr.Markdown("Click to load example video and prompt:") with gr.Row(): if os.path.exists(EXAMPLE_FILE): example_btn1 = gr.Button("🎤 Man Speaking") example_btn2 = gr.Button("🎤 Woman Speaking") example_btn3 = gr.Button("🎵 Background Music") # Main process button def process(model_name, audio_path, video_path, prompt, chunk_duration, progress=gr.Progress()): file_path = video_path if video_path else audio_path return separate_audio(model_name, file_path, prompt, chunk_duration, progress) run_btn.click( fn=process, inputs=[model_selector, input_audio, input_video, text_prompt, chunk_duration_slider], outputs=[output_target, output_residual, status_output] ) # Example buttons - just fill the prompt, user clicks button to process if os.path.exists(EXAMPLE_FILE): example_btn1.click( fn=lambda: (EXAMPLE_FILE, "A man speaking"), outputs=[input_video, text_prompt] ) example_btn2.click( fn=lambda: (EXAMPLE_FILE, "A woman speaking"), outputs=[input_video, text_prompt] ) example_btn3.click( fn=lambda: (EXAMPLE_FILE, "Background music"), outputs=[input_video, text_prompt] ) if __name__ == "__main__": demo.launch()