Spaces:
Runtime error
Runtime error
Peter Shi
fix: Fixed the issue in the `merge_chunks_with_crossfade` function handling one-dimensional audio blocks and blocks shorter than the overlap area, and removed redundant dimension expansion operations in `save_audio`.
f4c6545
| import spaces | |
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| import tempfile | |
| import warnings | |
| import os | |
| warnings.filterwarnings("ignore") | |
| from sam_audio import SAMAudio, SAMAudioProcessor | |
# Available models: UI display label -> Hugging Face model id.
# Labels marked "(Visual)" select the text+video ("-tv") checkpoints.
MODELS = {
    "sam-audio-small": "facebook/sam-audio-small",
    "sam-audio-base": "facebook/sam-audio-base",
    "sam-audio-large": "facebook/sam-audio-large",
    "sam-audio-small-tv (Visual)": "facebook/sam-audio-small-tv",
    "sam-audio-base-tv (Visual)": "facebook/sam-audio-base-tv",
    "sam-audio-large-tv (Visual)": "facebook/sam-audio-large-tv",
}
DEFAULT_MODEL = "sam-audio-small"  # must be a key of MODELS
EXAMPLES_DIR = "examples"
# Demo asset; the example buttons are only created when this file exists.
EXAMPLE_FILE = os.path.join(EXAMPLES_DIR, "office.mp4")
# Chunk processing settings
DEFAULT_CHUNK_DURATION = 30  # seconds per chunk
OVERLAP_DURATION = 2  # seconds of overlap between chunks
MAX_DURATION_WITHOUT_CHUNKING = 60  # auto-chunk if longer than this
# Global model cache — populated/replaced by load_model().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
current_model_name = None  # display label of the currently loaded model
model = None  # SAMAudio instance
processor = None  # SAMAudioProcessor instance
def load_model(model_name):
    """Load the requested SAM-Audio model into the module-level cache.

    Updates the ``model``/``processor``/``current_model_name`` globals.
    Re-requesting the already-loaded model is a no-op; an unrecognized
    name falls back to the default model's checkpoint.
    """
    global current_model_name, model, processor
    model_id = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
    # Cache hit: nothing to do.
    cache_hit = current_model_name == model_name and model is not None
    if cache_hit:
        return
    print(f"Loading {model_id}...")
    model = SAMAudio.from_pretrained(model_id).to(device).eval()
    processor = SAMAudioProcessor.from_pretrained(model_id)
    current_model_name = model_name
    print(f"Model {model_id} loaded on {device}.")
# Eagerly load the default model at import time so the first request is fast.
load_model(DEFAULT_MODEL)
def load_audio(file_path):
    """Read the audio track of an audio or video file as mono.

    Returns a ``(waveform, sample_rate)`` pair where ``waveform`` has
    shape ``[1, samples]`` (multi-channel input is averaged to mono).
    """
    signal, rate = torchaudio.load(file_path)
    # Downmix multi-channel audio by averaging across channels.
    mono = signal.mean(dim=0, keepdim=True) if signal.shape[0] > 1 else signal
    return mono, rate
def split_audio_into_chunks(waveform, sample_rate, chunk_duration, overlap_duration):
    """Split a ``[channels, samples]`` waveform into overlapping chunks.

    Args:
        waveform: 2-D tensor ``[channels, samples]``.
        sample_rate: samples per second of ``waveform``.
        chunk_duration: length of each chunk in seconds.
        overlap_duration: overlap between consecutive chunks in seconds.

    Returns:
        List of 2-D tensor views into ``waveform``; the last chunk may be
        shorter than ``chunk_duration``.

    Raises:
        ValueError: if the overlap is at least as long as the chunk —
            the stride would be <= 0 and the loop below would never
            advance (infinite loop in the original implementation).
    """
    chunk_samples = int(chunk_duration * sample_rate)
    overlap_samples = int(overlap_duration * sample_rate)
    stride = chunk_samples - overlap_samples
    if stride <= 0:
        raise ValueError(
            "overlap_duration must be shorter than chunk_duration "
            f"(got chunk={chunk_duration}s, overlap={overlap_duration}s)"
        )
    chunks = []
    total_samples = waveform.shape[1]
    start = 0
    while start < total_samples:
        end = min(start + chunk_samples, total_samples)
        chunks.append(waveform[:, start:end])
        if end >= total_samples:
            break
        start += stride
    return chunks
def merge_chunks_with_crossfade(chunks, sample_rate, overlap_duration):
    """Stitch processed chunks back into one waveform.

    Consecutive chunks are blended over their shared overlap with a
    linear crossfade (fade-out + fade-in weights summing to 1).  Chunks
    may arrive as 1-D ``[samples]`` or 2-D ``[channels, samples]``
    tensors; the result is always 2-D ``[channels, samples]``.
    """
    def as_2d(t):
        # Promote mono 1-D tensors to [1, samples].
        return t.unsqueeze(0) if t.dim() == 1 else t

    if len(chunks) == 1:
        return as_2d(chunks[0])

    overlap_samples = int(overlap_duration * sample_rate)
    normalized = [as_2d(c) for c in chunks]

    merged = normalized[0]
    for nxt in normalized[1:]:
        # A chunk can be shorter than the nominal overlap; clamp the fade
        # window to what actually exists on both sides.
        fade_len = min(overlap_samples, merged.shape[1], nxt.shape[1])
        if fade_len <= 0:
            # No usable overlap — plain concatenation.
            merged = torch.cat([merged, nxt], dim=1)
            continue
        ramp_down = torch.linspace(1.0, 0.0, fade_len).to(merged.device)
        ramp_up = torch.linspace(0.0, 1.0, fade_len).to(nxt.device)
        blended = merged[:, -fade_len:] * ramp_down + nxt[:, :fade_len] * ramp_up
        merged = torch.cat(
            [merged[:, :-fade_len], blended, nxt[:, fade_len:]],
            dim=1,
        )
    return merged
def save_audio(tensor, sample_rate):
    """Write a ``[channels, samples]`` tensor to a temporary .wav file.

    Returns the file path; the caller (Gradio) owns the file's lifetime.

    The file descriptor from ``mkstemp`` is closed *before* torchaudio
    writes: the original code wrote to ``tmp.name`` while the
    NamedTemporaryFile handle was still open, which fails on platforms
    that lock open files (Windows) and briefly held two open handles.
    """
    fd, path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    torchaudio.save(path, tensor, sample_rate)
    return path
def separate_audio(model_name, file_path, text_prompt, chunk_duration=DEFAULT_CHUNK_DURATION, progress=gr.Progress()):
    """Isolate the sound described by ``text_prompt`` from ``file_path``.

    Inputs longer than MAX_DURATION_WITHOUT_CHUNKING seconds are split into
    overlapping chunks, separated chunk-by-chunk, and re-assembled with a
    crossfade; shorter inputs go through the model in one pass.

    Returns a ``(target_path, residual_path, status_message)`` tuple; both
    paths are None when validation fails or an exception occurs.

    Note: ``progress=gr.Progress()`` as a default argument is Gradio's
    documented idiom for progress tracking, not an accidental mutable default.
    """
    global model, processor
    progress(0.05, desc="Checking inputs...")
    # Input validation: both a media file and a non-empty prompt are required.
    if not file_path:
        return None, None, "β Please upload an audio or video file."
    if not text_prompt or not text_prompt.strip():
        return None, None, "β Please enter a text prompt."
    try:
        progress(0.1, desc="Loading model...")
        load_model(model_name)
        progress(0.15, desc="Loading audio...")
        waveform, sample_rate = load_audio(file_path)
        duration = waveform.shape[1] / sample_rate
        # Decide whether to use chunking
        use_chunking = duration > MAX_DURATION_WITHOUT_CHUNKING
        if use_chunking:
            progress(0.2, desc=f"Audio is {duration:.1f}s, splitting into chunks...")
            chunks = split_audio_into_chunks(waveform, sample_rate, chunk_duration, OVERLAP_DURATION)
            num_chunks = len(chunks)
            target_chunks = []
            residual_chunks = []
            for i, chunk in enumerate(chunks):
                # Spread per-chunk progress across the 0.2-0.8 range.
                chunk_progress = 0.2 + (i / num_chunks) * 0.6
                progress(chunk_progress, desc=f"Processing chunk {i+1}/{num_chunks}...")
                # Save chunk to temp file for processor
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    torchaudio.save(tmp.name, chunk, sample_rate)
                    chunk_path = tmp.name
                try:
                    inputs = processor(audios=[chunk_path], descriptions=[text_prompt.strip()]).to(device)
                    with torch.inference_mode():
                        result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
                    # Move results to CPU immediately to bound GPU memory use.
                    target_chunks.append(result.target[0].cpu())
                    residual_chunks.append(result.residual[0].cpu())
                finally:
                    # Always delete the temporary chunk file, even on failure.
                    os.unlink(chunk_path)
            progress(0.85, desc="Merging chunks...")
            # NOTE(review): this branch merges and saves at the *input* file's
            # sample_rate, while the non-chunked branch below saves at
            # processor.audio_sampling_rate. If the model resamples
            # internally, the chunked output's rate/overlap math could be
            # off — confirm against the processor's behavior.
            target_merged = merge_chunks_with_crossfade(target_chunks, sample_rate, OVERLAP_DURATION)
            residual_merged = merge_chunks_with_crossfade(residual_chunks, sample_rate, OVERLAP_DURATION)
            progress(0.95, desc="Saving results...")
            # merged tensors are already 2D [channels, samples]
            target_path = save_audio(target_merged, sample_rate)
            residual_path = save_audio(residual_merged, sample_rate)
            progress(1.0, desc="Done!")
            return target_path, residual_path, f"β Isolated '{text_prompt}' using {model_name} ({num_chunks} chunks)"
        else:
            # Process without chunking
            progress(0.3, desc="Processing audio...")
            inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)
            progress(0.6, desc="Separating sounds...")
            with torch.inference_mode():
                result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
            progress(0.9, desc="Saving results...")
            # Outputs are saved at the processor's sampling rate, not the input's.
            sample_rate = processor.audio_sampling_rate
            target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
            residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)
            progress(1.0, desc="Done!")
            return target_path, residual_path, f"β Isolated '{text_prompt}' using {model_name}"
    except Exception as e:
        # Top-level boundary: report the error in the UI instead of crashing.
        import traceback
        traceback.print_exc()
        return None, None, f"β Error: {str(e)}"
# Build Interface
with gr.Blocks(title="SAM-Audio Demo") as demo:
    gr.Markdown(
        """
        # π΅ SAM-Audio: Segment Anything for Audio
        Isolate specific sounds from audio or video using natural language prompts.
        """
    )
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=1):
            # MODELS keys double as the dropdown's display labels.
            model_selector = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=DEFAULT_MODEL,
                label="Model"
            )
            with gr.Accordion("βοΈ Advanced Options", open=False):
                chunk_duration_slider = gr.Slider(
                    minimum=10,
                    maximum=60,
                    value=DEFAULT_CHUNK_DURATION,
                    step=5,
                    label="Chunk Duration (seconds)",
                    info=f"Audio longer than {MAX_DURATION_WITHOUT_CHUNKING}s will be automatically split"
                )
            gr.Markdown("#### Upload Audio")
            input_audio = gr.Audio(label="Audio File", type="filepath")
            gr.Markdown("#### Or Upload Video")
            input_video = gr.Video(label="Video File")
            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'A man speaking', 'Piano', 'Dog barking'"
            )
            run_btn = gr.Button("π― Isolate Sound", variant="primary", size="lg")
            status_output = gr.Markdown("")
        # Right column: separated outputs.
        with gr.Column(scale=1):
            gr.Markdown("### Results")
            output_target = gr.Audio(label="Isolated Sound (Target)")
            output_residual = gr.Audio(label="Background (Residual)")
    gr.Markdown("---")
    gr.Markdown("### π¬ Demo Examples")
    gr.Markdown("Click to load example video and prompt:")
    with gr.Row():
        # Example buttons are created only when the demo asset exists on disk.
        if os.path.exists(EXAMPLE_FILE):
            example_btn1 = gr.Button("π€ Man Speaking")
            example_btn2 = gr.Button("π€ Woman Speaking")
            example_btn3 = gr.Button("π΅ Background Music")
    # Main process button
    def process(model_name, audio_path, video_path, prompt, chunk_duration, progress=gr.Progress()):
        # When both a video and an audio file are provided, the video wins.
        file_path = video_path if video_path else audio_path
        return separate_audio(model_name, file_path, prompt, chunk_duration, progress)
    run_btn.click(
        fn=process,
        inputs=[model_selector, input_audio, input_video, text_prompt, chunk_duration_slider],
        outputs=[output_target, output_residual, status_output]
    )
    # Example buttons - just fill the prompt, user clicks button to process
    # (guarded by the same existence check that created the buttons above).
    if os.path.exists(EXAMPLE_FILE):
        example_btn1.click(
            fn=lambda: (EXAMPLE_FILE, "A man speaking"),
            outputs=[input_video, text_prompt]
        )
        example_btn2.click(
            fn=lambda: (EXAMPLE_FILE, "A woman speaking"),
            outputs=[input_video, text_prompt]
        )
        example_btn3.click(
            fn=lambda: (EXAMPLE_FILE, "Background music"),
            outputs=[input_video, text_prompt]
        )
if __name__ == "__main__":
    demo.launch()