"""Hugging Face Space entry point: SAM-Audio text-prompted sound separation.

Exposes a FastAPI app (``api``) with an upload endpoint and a static file
mount, plus the module-level configuration (model registry, chunking
constants, logger) shared by the Gradio UI defined further down this file.
"""

# NOTE(review): deduplicated imports — the original imported `os` twice and
# `gradio` twice, and interleaved import groups with code.
import logging
import os
import sys
import tempfile
import time
import uuid
import warnings

import gradio as gr
import spaces
import torch
import torchaudio
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles

from sam_audio import SAMAudio, SAMAudioProcessor

api = FastAPI()

# Directory where uploaded files are persisted and served back from.
UPLOAD_DIR = "/tmp/uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)


@api.post("/upload_audio")
async def upload_audio(file: UploadFile = File(...)):
    """Persist an uploaded file and return its local path plus a serving URL.

    The URL is relative to this Space and is handled by the ``/files``
    static mount below.
    """
    # Save uploaded bytes; keep the original extension so downstream
    # loaders can sniff the format (default to .wav when absent).
    ext = os.path.splitext(file.filename)[1] or ".wav"
    out_name = f"{uuid.uuid4().hex}{ext}"
    out_path = os.path.join(UPLOAD_DIR, out_name)
    data = await file.read()
    with open(out_path, "wb") as f:
        f.write(data)
    # Serve it back via a URL on this same Space (simple file-serving route).
    return JSONResponse({"path": out_path, "url": f"/files/{out_name}"})


# Static route that serves previously uploaded files.
api.mount("/files", StaticFiles(directory=UPLOAD_DIR), name="files")

warnings.filterwarnings("ignore")

# Stdout logger so messages reliably show up in the Space container logs.
logger = logging.getLogger("sam_space")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s %(message)s"))
logger.handlers.clear()
logger.addHandler(handler)


def log(msg: str):
    """Log *msg* at INFO and flush stdout immediately (Spaces buffers output)."""
    logger.info(msg)
    sys.stdout.flush()


# Available models
MODELS = {
    "sam-audio-small": "facebook/sam-audio-small",
    "sam-audio-base": "facebook/sam-audio-base",
    "sam-audio-large": "facebook/sam-audio-large",
    "sam-audio-small-tv (Visual)": "facebook/sam-audio-small-tv",
    "sam-audio-base-tv (Visual)": "facebook/sam-audio-base-tv",
    "sam-audio-large-tv (Visual)": "facebook/sam-audio-large-tv",
}
DEFAULT_MODEL = "sam-audio-small"

EXAMPLES_DIR = "audio"
EXAMPLE_FILE = os.path.join(EXAMPLES_DIR, "PromoterClipMono.wav")

# Chunk processing settings
DEFAULT_CHUNK_DURATION = 5  # seconds per chunk
OVERLAP_DURATION = 1  # seconds of overlap between chunks
MAX_DURATION_WITHOUT_CHUNKING = 10  # auto-chunk if longer than this
# Global model cache
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
current_model_name = None
model = None
processor = None


def load_model(model_name):
    """Load the requested SAM-Audio checkpoint into the module-level cache.

    No-op when *model_name* is already the active model. Unknown names fall
    back to ``DEFAULT_MODEL``. Updates the ``model`` / ``processor`` /
    ``current_model_name`` globals as a side effect.
    """
    log(f"App import complete. device={device} default_model={DEFAULT_MODEL} cwd={os.getcwd()}")
    global current_model_name, model, processor
    model_id = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
    if current_model_name == model_name and model is not None:
        return
    print(f"Loading {model_id}...")
    model = SAMAudio.from_pretrained(model_id).to(device).eval()
    processor = SAMAudioProcessor.from_pretrained(model_id)
    current_model_name = model_name
    print(f"Model {model_id} loaded on {device}.")


# Warm the cache at import time so the first request doesn't pay the load cost.
load_model(DEFAULT_MODEL)


def load_audio(file_path):
    """Load audio from file (supports both audio and video files).

    Returns ``(waveform, sample_rate)`` with the waveform downmixed to mono,
    shape ``[1, samples]``.
    """
    waveform, sample_rate = torchaudio.load(file_path)
    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    return waveform, sample_rate


def split_audio_into_chunks(waveform, sample_rate, chunk_duration, overlap_duration):
    """Split an audio waveform into overlapping chunks.

    Consecutive chunks share *overlap_duration* seconds so they can later be
    crossfaded back together without audible seams.
    """
    chunk_samples = int(chunk_duration * sample_rate)
    overlap_samples = int(overlap_duration * sample_rate)
    stride = chunk_samples - overlap_samples

    chunks = []
    total_samples = waveform.shape[1]
    start = 0
    while start < total_samples:
        end = min(start + chunk_samples, total_samples)
        chunks.append(waveform[:, start:end])
        if end >= total_samples:
            break
        start += stride
    return chunks


def merge_chunks_with_crossfade(chunks, sample_rate, overlap_duration):
    """Merge audio chunks with a linear crossfade over the overlapping regions.

    Accepts 1-D ``[samples]`` or 2-D ``[channels, samples]`` chunks and
    always returns a 2-D tensor.
    """
    if len(chunks) == 1:
        chunk = chunks[0]
        # Ensure 2D tensor
        if chunk.dim() == 1:
            chunk = chunk.unsqueeze(0)
        return chunk

    overlap_samples = int(overlap_duration * sample_rate)

    # Ensure all chunks are 2D [channels, samples]
    processed_chunks = []
    for chunk in chunks:
        if chunk.dim() == 1:
            chunk = chunk.unsqueeze(0)
        processed_chunks.append(chunk)

    result = processed_chunks[0]
    for i in range(1, len(processed_chunks)):
        prev_chunk = result
        next_chunk = processed_chunks[i]

        # Handle case where chunks are shorter than the nominal overlap.
        actual_overlap = min(overlap_samples, prev_chunk.shape[1], next_chunk.shape[1])
        if actual_overlap <= 0:
            # No overlap possible, just concatenate
            result = torch.cat([prev_chunk, next_chunk], dim=1)
            continue

        # Linear fade curves over the shared region.
        fade_out = torch.linspace(1.0, 0.0, actual_overlap).to(prev_chunk.device)
        fade_in = torch.linspace(0.0, 1.0, actual_overlap).to(next_chunk.device)

        # Get overlapping regions and crossfade-mix them.
        prev_overlap = prev_chunk[:, -actual_overlap:]
        next_overlap = next_chunk[:, :actual_overlap]
        crossfaded = prev_overlap * fade_out + next_overlap * fade_in

        # Concatenate: non-overlap of prev + crossfaded + non-overlap of next
        result = torch.cat(
            [prev_chunk[:, :-actual_overlap], crossfaded, next_chunk[:, actual_overlap:]],
            dim=1,
        )
    return result


def save_audio(tensor, sample_rate):
    """Write a ``[channels, samples]`` tensor to a temp .wav and return its path."""
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        torchaudio.save(tmp.name, tensor, sample_rate)
        return tmp.name


@spaces.GPU(duration=10)
def separate_audio(model_name, file_path, text_prompt,
                   chunk_duration=DEFAULT_CHUNK_DURATION, progress=gr.Progress()):
    """Isolate the sound described by *text_prompt* from the audio at *file_path*.

    Audio longer than MAX_DURATION_WITHOUT_CHUNKING seconds is processed in
    overlapping chunks and crossfade-merged. Returns
    ``(target_path, residual_path, status_message)``; the paths are None on
    validation failure or error.
    """
    global model, processor
    t0 = time.time()
    log(f"[separate_audio] ENTER model={model_name} file_path={file_path} prompt='{(text_prompt or '')[:80]}' chunk_duration={chunk_duration}")

    # Validate file existence *and log it*
    if isinstance(file_path, str):
        exists = os.path.exists(file_path)
        size = os.path.getsize(file_path) if exists else -1
        log(f"[separate_audio] file exists={exists} size={size}")
    else:
        log(f"[separate_audio] unexpected file_path type: {type(file_path)}")

    progress(0.05, desc="Checking inputs...")
    if not file_path:
        return None, None, "❌ Please upload an audio file."
    if not text_prompt or not text_prompt.strip():
        return None, None, "❌ Please enter a text prompt."

    try:
        progress(0.1, desc="Loading model...")
        load_model(model_name)

        progress(0.15, desc="Loading audio...")
        log(f"[separate_audio] loading audio...")
        waveform, sample_rate = load_audio(file_path)
        duration = waveform.shape[1] / sample_rate
        log(f"[separate_audio] audio loaded sr={sample_rate} duration={duration:.2f}s shape={tuple(waveform.shape)}")

        # Decide whether to use chunking
        use_chunking = duration > MAX_DURATION_WITHOUT_CHUNKING

        if use_chunking:
            progress(0.2, desc=f"Audio is {duration:.1f}s, splitting into chunks...")
            chunks = split_audio_into_chunks(waveform, sample_rate, chunk_duration, OVERLAP_DURATION)
            num_chunks = len(chunks)

            target_chunks = []
            residual_chunks = []
            for i, chunk in enumerate(chunks):
                chunk_progress = 0.2 + (i / num_chunks) * 0.6
                progress(chunk_progress, desc=f"Processing chunk {i+1}/{num_chunks}...")

                # Save chunk to temp file for processor
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    torchaudio.save(tmp.name, chunk, sample_rate)
                    chunk_path = tmp.name
                try:
                    log(f"[separate_audio] building inputs on device={device} ...")
                    inputs = processor(audios=[chunk_path], descriptions=[text_prompt.strip()]).to(device)
                    log("[separate_audio] running model.separate() ...")
                    with torch.inference_mode():
                        result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
                    log("[separate_audio] model.separate() done")
                    target_chunks.append(result.target[0].cpu())
                    residual_chunks.append(result.residual[0].cpu())
                finally:
                    os.unlink(chunk_path)

            progress(0.85, desc="Merging chunks...")
            # BUG FIX: model outputs are at the processor's sampling rate (the
            # non-chunked branch already saves with it), which may differ from
            # the input file's rate. Merge and save at the output rate so the
            # crossfade window length and playback speed are correct.
            output_sr = processor.audio_sampling_rate
            target_merged = merge_chunks_with_crossfade(target_chunks, output_sr, OVERLAP_DURATION)
            residual_merged = merge_chunks_with_crossfade(residual_chunks, output_sr, OVERLAP_DURATION)

            progress(0.95, desc="Saving results...")
            # merged tensors are already 2D [channels, samples]
            target_path = save_audio(target_merged, output_sr)
            residual_path = save_audio(residual_merged, output_sr)

            progress(1.0, desc="Done!")
            return target_path, residual_path, f"✅ Isolated '{text_prompt}' using {model_name} ({num_chunks} chunks)"
        else:
            # Process without chunking
            progress(0.3, desc="Processing audio...")
            log(f"[separate_audio] building inputs on device={device} ...")
            inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)

            progress(0.6, desc="Separating sounds...")
            log("[separate_audio] running model.separate() ...")
            with torch.inference_mode():
                result = model.separate(inputs, predict_spans=False, reranking_candidates=1)

            progress(0.9, desc="Saving results...")
            log("[separate_audio] model.separate() done")
            sample_rate = processor.audio_sampling_rate
            target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
            residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)

            progress(1.0, desc="Done!")
            return target_path, residual_path, f"✅ Isolated '{text_prompt}' using {model_name}"
    except Exception as e:
        # Top-level boundary: surface the error in the UI instead of crashing.
        import traceback
        log(f"[separate_audio] EXCEPTION: {e}")
        traceback.print_exc()
        sys.stdout.flush()
        return None, None, f"❌ Error: {str(e)}"
    finally:
        log(f"[separate_audio] EXIT after {time.time() - t0:.2f}s")


# Build Interface
with gr.Blocks(title="SAM-Audio Test") as demo:
    gr.Markdown(
        """
        # 🎵 SAM-Audio: Segment Anything for Audio

        Isolate specific sounds from audio or video using natural language prompts.
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=DEFAULT_MODEL,
                label="Model"
            )

            with gr.Accordion("⚙️ Advanced Options", open=False):
                chunk_duration_slider = gr.Slider(
                    # BUG FIX: minimum was 10 while the default value is
                    # DEFAULT_CHUNK_DURATION (5), placing the initial value
                    # outside the slider's valid range.
                    minimum=5,
                    maximum=60,
                    value=DEFAULT_CHUNK_DURATION,
                    step=5,
                    label="Chunk Duration (seconds)",
                    info=f"Audio longer than {MAX_DURATION_WITHOUT_CHUNKING}s will be automatically split"
                )

            gr.Markdown("#### Upload Audio")
            input_audio = gr.Audio(label="Audio File", type="filepath")

            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'guitar', 'voice'"
            )

            run_btn = gr.Button("🎯 Isolate Sound", variant="primary", size="lg")
            status_output = gr.Markdown("")

        with gr.Column(scale=1):
            gr.Markdown("### Results")
            output_target = gr.Audio(label="Isolated Sound (Target)")
            output_residual = gr.Audio(label="Background (Residual)")

    gr.Markdown("---")
    gr.Markdown("### 🎬 Demo Examples")
    gr.Markdown("Click to load example audio and prompt:")

    with gr.Row():
        if os.path.exists(EXAMPLE_FILE):
            example_btn1 = gr.Button("🎤 Man Speaking")
            example_btn2 = gr.Button("🎤 Woman Speaking")
            example_btn3 = gr.Button("🎵 Background Music")

    # Main process button: thin logging wrapper around separate_audio.
    def process(model_name, audio_path, prompt, chunk_duration, progress=gr.Progress()):
        t0 = time.time()
        log(f"[process] called model={model_name} chunk_duration={chunk_duration} prompt_len={len(prompt) if prompt else 0}")
        # audio_path can be None or a string filepath depending on gradio
        log(f"[process] audio_path type={type(audio_path)} value={audio_path}")
        try:
            out = separate_audio(model_name, audio_path, prompt, chunk_duration, progress)
            log(f"[process] finished in {time.time() - t0:.2f}s")
            return out
        except Exception as e:
            log(f"[process] EXCEPTION: {e}")
            raise

    run_btn.click(
        fn=process,
        inputs=[model_selector, input_audio, text_prompt, chunk_duration_slider],
        outputs=[output_target, output_residual, status_output]
    )

    # Example buttons - just fill the prompt, user clicks the run button to
    # process. BUG FIX: the prompts now match the button labels (they were
    # "Guitar"/"Voice" under "Man Speaking"/"Woman Speaking"), and the third
    # button is actually wired up (it previously did nothing).
    if os.path.exists(EXAMPLE_FILE):
        example_btn1.click(
            fn=lambda: (EXAMPLE_FILE, "Man speaking"),
            outputs=[input_audio, text_prompt]
        )
        example_btn2.click(
            fn=lambda: (EXAMPLE_FILE, "Woman speaking"),
            outputs=[input_audio, text_prompt]
        )
        example_btn3.click(
            fn=lambda: (EXAMPLE_FILE, "Background music"),
            outputs=[input_audio, text_prompt]
        )

if __name__ == "__main__":
    demo.launch(show_error=True, share=True)

# Mount the Gradio UI onto the FastAPI app (used when served by uvicorn).
app = gr.mount_gradio_app(api, demo, path="/")