# Hugging Face Space (ZeroGPU) — SAM-Audio demo app.
| import spaces | |
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| import tempfile | |
| import warnings | |
| import os | |
| import logging | |
| import sys | |
| import time | |
| from sam_audio import SAMAudio, SAMAudioProcessor | |
| import os, uuid | |
| from fastapi import FastAPI, UploadFile, File | |
| from fastapi.responses import JSONResponse | |
| import gradio as gr | |
# FastAPI application hosting the upload endpoint and static file route;
# the Gradio UI is mounted onto it at the bottom of the file.
api = FastAPI()

# Staging directory for uploaded audio files.
UPLOAD_DIR = "/tmp/uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)
@api.post("/upload")
async def upload_audio(file: UploadFile = File(...)):
    """Accept an uploaded audio file, persist it, and return its path and URL.

    BUG FIX: the handler was defined but never registered on the FastAPI
    app, so the endpoint was unreachable; the @api.post decorator fixes that.

    Returns:
        JSONResponse with the saved file's local path and the /files URL
        it is served from (see the StaticFiles mount below).
    """
    # Preserve the original extension so downstream loaders can sniff the
    # container format; default to .wav when the name has none.
    ext = os.path.splitext(file.filename)[1] or ".wav"
    out_name = f"{uuid.uuid4().hex}{ext}"
    out_path = os.path.join(UPLOAD_DIR, out_name)

    data = await file.read()
    with open(out_path, "wb") as f:
        f.write(data)

    # Serve it back via a URL on this same Space (the /files static route).
    return JSONResponse({"path": out_path, "url": f"/files/{out_name}"})
# Serve saved uploads back over HTTP as /files/<name>.
# NOTE(review): this import would conventionally live at the top of the file.
from fastapi.staticfiles import StaticFiles

api.mount("/files", StaticFiles(directory=UPLOAD_DIR), name="files")
warnings.filterwarnings("ignore")

# Dedicated stdout logger so messages show up in the Space's container logs.
logger = logging.getLogger("sam_space")
logger.setLevel(logging.INFO)
_stdout_handler = logging.StreamHandler(sys.stdout)
_stdout_handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s %(message)s"))
logger.handlers.clear()  # avoid duplicate handlers on module re-import / hot reload
logger.addHandler(_stdout_handler)


def log(msg: str):
    """Emit *msg* at INFO level and flush stdout so it appears immediately."""
    logger.info(msg)
    sys.stdout.flush()
# Available models: UI display name -> Hugging Face hub id.
MODELS = {
    "sam-audio-small": "facebook/sam-audio-small",
    "sam-audio-base": "facebook/sam-audio-base",
    "sam-audio-large": "facebook/sam-audio-large",
    "sam-audio-small-tv (Visual)": "facebook/sam-audio-small-tv",
    "sam-audio-base-tv (Visual)": "facebook/sam-audio-base-tv",
    "sam-audio-large-tv (Visual)": "facebook/sam-audio-large-tv",
}
DEFAULT_MODEL = "sam-audio-small"

# Bundled example audio.
EXAMPLES_DIR = "audio"
EXAMPLE_FILE = os.path.join(EXAMPLES_DIR, "PromoterClipMono.wav")

# Chunked-processing settings.
DEFAULT_CHUNK_DURATION = 5  # seconds per chunk
OVERLAP_DURATION = 1  # seconds of overlap between consecutive chunks
MAX_DURATION_WITHOUT_CHUNKING = 10  # auto-chunk inputs longer than this

# Global model cache, populated lazily by load_model().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
current_model_name = None
model = None
processor = None
def load_model(model_name):
    """Load and cache the requested SAM-Audio model and processor.

    No-op when *model_name* is already the active model. Unknown names
    fall back to DEFAULT_MODEL.

    BUG FIX: the original logged a misleading "App import complete" message
    on every call; it also mixed print() with the module logger. Both are
    replaced with log() calls describing what actually happens.
    """
    global current_model_name, model, processor

    # Cache hit: requested model is already loaded, nothing to do.
    if current_model_name == model_name and model is not None:
        return

    model_id = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
    log(f"Loading {model_id} on {device}...")
    model = SAMAudio.from_pretrained(model_id).to(device).eval()
    processor = SAMAudioProcessor.from_pretrained(model_id)
    current_model_name = model_name
    log(f"Model {model_id} loaded on {device}.")
# Eagerly load the default model at import time so the first request is fast.
load_model(DEFAULT_MODEL)
def load_audio(file_path):
    """Load audio from *file_path* (audio or video container).

    Returns:
        (waveform, sample_rate) where waveform is a mono [1, samples]
        tensor; multi-channel input is averaged down to a single channel.
    """
    signal, sr = torchaudio.load(file_path)
    # Downmix to mono when more than one channel is present.
    mono = signal.mean(dim=0, keepdim=True) if signal.shape[0] > 1 else signal
    return mono, sr
def split_audio_into_chunks(waveform, sample_rate, chunk_duration, overlap_duration):
    """Split a [channels, samples] waveform into overlapping chunks.

    Args:
        waveform: 2-D tensor [channels, samples].
        sample_rate: sampling rate in Hz.
        chunk_duration: length of each chunk, in seconds.
        overlap_duration: overlap between consecutive chunks, in seconds.

    Returns:
        List of [channels, <=chunk_samples] tensor views covering the input.

    Raises:
        ValueError: if overlap_duration >= chunk_duration.
    """
    chunk_samples = int(chunk_duration * sample_rate)
    overlap_samples = int(overlap_duration * sample_rate)
    stride = chunk_samples - overlap_samples
    # BUG FIX: a non-positive stride (overlap >= chunk length) made the
    # original loop never advance, hanging forever. Fail fast instead.
    if stride <= 0:
        raise ValueError(
            f"overlap_duration ({overlap_duration}s) must be smaller than "
            f"chunk_duration ({chunk_duration}s)"
        )

    chunks = []
    total_samples = waveform.shape[1]
    start = 0
    while start < total_samples:
        end = min(start + chunk_samples, total_samples)
        chunks.append(waveform[:, start:end])
        if end >= total_samples:
            break
        start += stride
    return chunks
def merge_chunks_with_crossfade(chunks, sample_rate, overlap_duration):
    """Merge audio chunks back into one waveform, crossfading the overlaps.

    Consecutive chunks are assumed to share overlap_duration seconds of
    audio; the shared region is blended with a linear fade-out/fade-in
    (the two ramps sum to 1, preserving amplitude of identical signals).

    Args:
        chunks: list of 1-D [samples] or 2-D [channels, samples] tensors.
        sample_rate: sampling rate of the chunks, in Hz.
        overlap_duration: overlap between consecutive chunks, in seconds.

    Returns:
        A 2-D [channels, samples] tensor.

    Raises:
        ValueError: if *chunks* is empty (was a bare IndexError before).
    """
    if not chunks:
        raise ValueError("merge_chunks_with_crossfade() requires at least one chunk")

    # Normalize every chunk to 2-D [channels, samples].
    normalized = [c.unsqueeze(0) if c.dim() == 1 else c for c in chunks]

    if len(normalized) == 1:
        return normalized[0]

    overlap_samples = int(overlap_duration * sample_rate)
    result = normalized[0]
    for next_chunk in normalized[1:]:
        # Clamp the overlap when either side is shorter than expected.
        actual_overlap = min(overlap_samples, result.shape[1], next_chunk.shape[1])
        if actual_overlap <= 0:
            # No overlap possible; append verbatim.
            result = torch.cat([result, next_chunk], dim=1)
            continue

        # Linear crossfade over the shared region.
        fade_out = torch.linspace(1.0, 0.0, actual_overlap).to(result.device)
        fade_in = torch.linspace(0.0, 1.0, actual_overlap).to(next_chunk.device)
        crossfaded = (
            result[:, -actual_overlap:] * fade_out
            + next_chunk[:, :actual_overlap] * fade_in
        )
        # Non-overlapping prefix + blended region + non-overlapping suffix.
        result = torch.cat(
            [result[:, :-actual_overlap], crossfaded, next_chunk[:, actual_overlap:]],
            dim=1,
        )
    return result
def save_audio(tensor, sample_rate):
    """Write *tensor* to a temporary .wav file and return its path.

    The file is created with delete=False on purpose: Gradio serves the
    returned path after this function exits.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    with tmp:
        torchaudio.save(tmp.name, tensor, sample_rate)
    return tmp.name
# NOTE(review): on ZeroGPU Spaces the GPU-using entry point usually needs the
# @spaces.GPU decorator (the `spaces` module is imported above) — confirm.
def separate_audio(model_name, file_path, text_prompt, chunk_duration=DEFAULT_CHUNK_DURATION, progress=gr.Progress()):
    """Isolate the sound described by *text_prompt* from the audio at *file_path*.

    Inputs longer than MAX_DURATION_WITHOUT_CHUNKING seconds are processed
    in overlapping chunks and crossfaded back together.

    Returns:
        (target_path, residual_path, status_message); the paths are None
        when validation fails or an error occurs.
    """
    global model, processor
    t0 = time.time()
    log(f"[separate_audio] ENTER model={model_name} file_path={file_path} prompt='{(text_prompt or '')[:80]}' chunk_duration={chunk_duration}")

    # Log file existence up front to make Space debugging easier.
    if isinstance(file_path, str):
        exists = os.path.exists(file_path)
        size = os.path.getsize(file_path) if exists else -1
        log(f"[separate_audio] file exists={exists} size={size}")
    else:
        log(f"[separate_audio] unexpected file_path type: {type(file_path)}")

    progress(0.05, desc="Checking inputs...")
    if not file_path:
        return None, None, "β Please upload an audio file."
    if not text_prompt or not text_prompt.strip():
        return None, None, "β Please enter a text prompt."

    try:
        progress(0.1, desc="Loading model...")
        load_model(model_name)

        progress(0.15, desc="Loading audio...")
        log("[separate_audio] loading audio...")
        waveform, sample_rate = load_audio(file_path)
        duration = waveform.shape[1] / sample_rate
        log(f"[separate_audio] audio loaded sr={sample_rate} duration={duration:.2f}s shape={tuple(waveform.shape)}")

        # BUG FIX: the model emits audio at the processor's sampling rate,
        # which may differ from the input file's rate. The chunked path
        # previously merged/saved the output at the INPUT rate, producing
        # wrong-speed audio whenever the two rates differ.
        output_sample_rate = processor.audio_sampling_rate

        use_chunking = duration > MAX_DURATION_WITHOUT_CHUNKING
        if use_chunking:
            progress(0.2, desc=f"Audio is {duration:.1f}s, splitting into chunks...")
            chunks = split_audio_into_chunks(waveform, sample_rate, chunk_duration, OVERLAP_DURATION)
            num_chunks = len(chunks)

            target_chunks = []
            residual_chunks = []
            for i, chunk in enumerate(chunks):
                chunk_progress = 0.2 + (i / num_chunks) * 0.6
                progress(chunk_progress, desc=f"Processing chunk {i+1}/{num_chunks}...")

                # The processor consumes file paths, so round-trip each
                # chunk through a temporary wav file (deleted afterwards).
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    torchaudio.save(tmp.name, chunk, sample_rate)
                    chunk_path = tmp.name
                try:
                    log(f"[separate_audio] building inputs on device={device} ...")
                    inputs = processor(audios=[chunk_path], descriptions=[text_prompt.strip()]).to(device)
                    log("[separate_audio] running model.separate() ...")
                    with torch.inference_mode():
                        result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
                    log("[separate_audio] model.separate() done")
                    target_chunks.append(result.target[0].cpu())
                    residual_chunks.append(result.residual[0].cpu())
                finally:
                    os.unlink(chunk_path)

            progress(0.85, desc="Merging chunks...")
            # Merge at the OUTPUT rate so the overlap window is correct.
            target_merged = merge_chunks_with_crossfade(target_chunks, output_sample_rate, OVERLAP_DURATION)
            residual_merged = merge_chunks_with_crossfade(residual_chunks, output_sample_rate, OVERLAP_DURATION)

            progress(0.95, desc="Saving results...")
            # Merged tensors are already 2-D [channels, samples].
            target_path = save_audio(target_merged, output_sample_rate)
            residual_path = save_audio(residual_merged, output_sample_rate)
            progress(1.0, desc="Done!")
            return target_path, residual_path, f"β Isolated '{text_prompt}' using {model_name} ({num_chunks} chunks)"
        else:
            # Short input: process in a single pass.
            progress(0.3, desc="Processing audio...")
            log(f"[separate_audio] building inputs on device={device} ...")
            inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)

            progress(0.6, desc="Separating sounds...")
            log("[separate_audio] running model.separate() ...")
            with torch.inference_mode():
                result = model.separate(inputs, predict_spans=False, reranking_candidates=1)

            progress(0.9, desc="Saving results...")
            log("[separate_audio] model.separate() done")
            target_path = save_audio(result.target[0].unsqueeze(0).cpu(), output_sample_rate)
            residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), output_sample_rate)
            progress(1.0, desc="Done!")
            return target_path, residual_path, f"β Isolated '{text_prompt}' using {model_name}"
    except Exception as e:
        import traceback
        log(f"[separate_audio] EXCEPTION: {e}")
        traceback.print_exc()
        sys.stdout.flush()
        return None, None, f"β Error: {str(e)}"
    finally:
        log(f"[separate_audio] EXIT after {time.time() - t0:.2f}s")
# Build Interface
with gr.Blocks(title="SAM-Audio Test") as demo:
    gr.Markdown(
        """
    # π΅ SAM-Audio: Segment Anything for Audio
    Isolate specific sounds from audio or video using natural language prompts.
    """
    )

    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=DEFAULT_MODEL,
                label="Model",
            )
            with gr.Accordion("βοΈ Advanced Options", open=False):
                # BUG FIX: minimum was 10 while the initial value is
                # DEFAULT_CHUNK_DURATION (5), putting the slider's default
                # below its own minimum.
                chunk_duration_slider = gr.Slider(
                    minimum=5,
                    maximum=60,
                    value=DEFAULT_CHUNK_DURATION,
                    step=5,
                    label="Chunk Duration (seconds)",
                    info=f"Audio longer than {MAX_DURATION_WITHOUT_CHUNKING}s will be automatically split",
                )

            gr.Markdown("#### Upload Audio")
            input_audio = gr.Audio(label="Audio File", type="filepath")
            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'guitar', 'voice'",
            )
            run_btn = gr.Button("π― Isolate Sound", variant="primary", size="lg")
            status_output = gr.Markdown("")

        with gr.Column(scale=1):
            gr.Markdown("### Results")
            output_target = gr.Audio(label="Isolated Sound (Target)")
            output_residual = gr.Audio(label="Background (Residual)")

    gr.Markdown("---")
    gr.Markdown("### π¬ Demo Examples")
    gr.Markdown("Click to load example audio and prompt:")
    with gr.Row():
        if os.path.exists(EXAMPLE_FILE):
            example_btn1 = gr.Button("π€ Man Speaking")
            example_btn2 = gr.Button("π€ Woman Speaking")
            example_btn3 = gr.Button("π΅ Background Music")

    def process(model_name, audio_path, prompt, chunk_duration, progress=gr.Progress()):
        """Logging wrapper around separate_audio for the run button."""
        t0 = time.time()
        log(f"[process] called model={model_name} chunk_duration={chunk_duration} prompt_len={len(prompt) if prompt else 0}")
        # audio_path can be None or a string filepath depending on gradio
        log(f"[process] audio_path type={type(audio_path)} value={audio_path}")
        try:
            out = separate_audio(model_name, audio_path, prompt, chunk_duration, progress)
            log(f"[process] finished in {time.time() - t0:.2f}s")
            return out
        except Exception as e:
            log(f"[process] EXCEPTION: {e}")
            raise

    run_btn.click(
        fn=process,
        inputs=[model_selector, input_audio, text_prompt, chunk_duration_slider],
        outputs=[output_target, output_residual, status_output],
    )

    # Example buttons just fill the audio/prompt fields; the user then runs.
    # BUG FIX: prompts now match the button labels (they previously filled
    # "Guitar"/"Voice" under speech-labelled buttons), and the previously
    # dead "Background Music" button is wired up.
    if os.path.exists(EXAMPLE_FILE):
        example_btn1.click(
            fn=lambda: (EXAMPLE_FILE, "Man speaking"),
            outputs=[input_audio, text_prompt],
        )
        example_btn2.click(
            fn=lambda: (EXAMPLE_FILE, "Woman speaking"),
            outputs=[input_audio, text_prompt],
        )
        example_btn3.click(
            fn=lambda: (EXAMPLE_FILE, "Background music"),
            outputs=[input_audio, text_prompt],
        )
# Mount the Gradio UI onto the FastAPI app at MODULE level so ASGI servers
# (uvicorn on HF Spaces) can import `app`. BUG FIX: the mount previously
# appeared after the blocking demo.launch() call under the __main__ guard,
# so it never executed and `app` was unavailable to the server.
app = gr.mount_gradio_app(api, demo, path="/")

if __name__ == "__main__":
    # Local/dev entry point: run the Gradio server directly.
    demo.launch(show_error=True, share=True)