# sam / app.py
# peoplepilot
# add file upload
# e09762f
import spaces
import gradio as gr
import torch
import torchaudio
import tempfile
import warnings
import os
import logging
import sys
import time
from sam_audio import SAMAudio, SAMAudioProcessor
import os, uuid
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import gradio as gr
# Directory that /upload_audio writes into; created up front so both the upload route
# and the StaticFiles mount can rely on it existing.
UPLOAD_DIR = "/tmp/uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)
# FastAPI application carrying the custom routes; the Gradio UI is later mounted onto it
# with gr.mount_gradio_app.
api = FastAPI()
@api.post("/upload_audio")
async def upload_audio(file: UploadFile = File(...)):
    """Store an uploaded audio file under UPLOAD_DIR and report where it lives.

    The on-disk name is a random UUID plus the client's file extension (".wav" when the
    client sent no usable extension), so the client never chooses the stored path.

    Returns:
        JSONResponse with ``path`` (server-side file path) and ``url``
        (``/files/<name>``, served by the StaticFiles mount on ``/files``).
    """
    # UploadFile.filename is Optional; treat a missing name like a name with no extension.
    ext = os.path.splitext(file.filename or "")[1]
    # Only keep short plain-ASCII alphanumeric extensions: anything else (e.g. "?", "#",
    # "%", spaces) would make the returned /files/<name> URL resolve to the wrong path.
    if not (1 < len(ext) <= 10 and ext[1:].isascii() and ext[1:].isalnum()):
        ext = ".wav"
    out_name = f"{uuid.uuid4().hex}{ext}"
    out_path = os.path.join(UPLOAD_DIR, out_name)
    try:
        # Stream to disk in 1 MiB pieces instead of holding the whole upload in memory.
        with open(out_path, "wb") as f:
            while True:
                piece = await file.read(1024 * 1024)
                if not piece:
                    break
                f.write(piece)
    except BaseException:
        # Don't leave a truncated file behind (also covers request cancellation).
        if os.path.exists(out_path):
            os.remove(out_path)
        raise
    return JSONResponse({"path": out_path, "url": f"/files/{out_name}"})
from fastapi.staticfiles import StaticFiles

# Serve files saved by /upload_audio at /files/<name>, the "url" that route returns.
api.mount(path="/files", app=StaticFiles(directory=UPLOAD_DIR), name="files")
# Silence all Python warnings for the whole process.
warnings.filterwarnings("ignore")

# Single stdout handler with a timestamped format for the "sam_space" logger.
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s %(message)s"))
logger = logging.getLogger("sam_space")
logger.setLevel(logging.INFO)
# Drop any handlers attached earlier (e.g. on module re-import) before adding ours.
logger.handlers.clear()
logger.addHandler(handler)
def log(msg: str):
    """Write *msg* at INFO level to the app logger, then flush stdout immediately."""
    logger.log(logging.INFO, msg)
    # Flush so the line is not held back in the stdout buffer.
    sys.stdout.flush()
# Available models: dropdown label -> Hugging Face checkpoint id.
# The "-tv" checkpoints are labelled "(Visual)" in the UI.
MODELS = {
    **{
        f"sam-audio-{size}": f"facebook/sam-audio-{size}"
        for size in ("small", "base", "large")
    },
    **{
        f"sam-audio-{size}-tv (Visual)": f"facebook/sam-audio-{size}-tv"
        for size in ("small", "base", "large")
    },
}
DEFAULT_MODEL = "sam-audio-small"

# Demo clip loaded by the example buttons.
EXAMPLES_DIR = "audio"
EXAMPLE_FILE = os.path.join(EXAMPLES_DIR, "PromoterClipMono.wav")

# Chunk processing settings (all values in seconds).
DEFAULT_CHUNK_DURATION = 5          # length of each chunk
OVERLAP_DURATION = 1                # overlap crossfaded between neighbouring chunks
MAX_DURATION_WITHOUT_CHUNKING = 10  # inputs longer than this are processed in chunks
# Global model cache, shared by load_model() (writer) and separate_audio() (reader).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
current_model_name, model, processor = None, None, None
def load_model(model_name):
    """Make ``model_name`` the active checkpoint, loading it only when it isn't already.

    Unknown names fall back to DEFAULT_MODEL. On success the module-level ``model``
    (moved to ``device`` and put in eval mode), ``processor`` and ``current_model_name``
    are updated together.
    """
    global current_model_name, model, processor
    if model_name not in MODELS:
        log(f"Unknown model {model_name!r}; falling back to {DEFAULT_MODEL}")
        model_name = DEFAULT_MODEL
    if current_model_name == model_name and model is not None:
        return
    model_id = MODELS[model_name]
    log(f"Loading {model_id}...")
    if model is not None:
        # Release the previous checkpoint before loading the next one so the two never
        # have to fit in (GPU) memory at the same time.
        model = processor = current_model_name = None
        if device.type == "cuda":
            torch.cuda.empty_cache()
    new_model = SAMAudio.from_pretrained(model_id).to(device).eval()
    new_processor = SAMAudioProcessor.from_pretrained(model_id)
    # Publish all three at once so a failed load never leaves a half-updated cache.
    model, processor, current_model_name = new_model, new_processor, model_name
    log(f"Model {model_id} loaded on {device}.")


# Load the default checkpoint eagerly at import time.
load_model(DEFAULT_MODEL)
log(f"App import complete. device={device} default_model={DEFAULT_MODEL} cwd={os.getcwd()}")
def load_audio(file_path):
    """Read *file_path* with ``torchaudio.load`` and downmix it to one channel.

    Returns ``(waveform, sample_rate)``. When the file has more than one channel, the
    channels are averaged and the channel axis is kept with size 1. Whether video
    containers can be read depends on the installed torchaudio backend (TODO confirm).
    """
    audio, sr = torchaudio.load(file_path)
    mono = audio.mean(dim=0, keepdim=True) if audio.shape[0] > 1 else audio
    return mono, sr
def split_audio_into_chunks(waveform, sample_rate, chunk_duration, overlap_duration):
    """Split a ``[channels, samples]`` waveform into overlapping chunks along the last axis.

    Chunks are ``chunk_duration`` seconds long and start ``chunk_duration -
    overlap_duration`` seconds apart, so neighbours share ``overlap_duration`` seconds.
    The final chunk may be shorter, and an empty waveform yields ``[]``.

    Raises:
        ValueError: if the chunk length in samples is not greater than the overlap;
            the start position would never advance and the loop would not terminate.
    """
    chunk_samples = int(chunk_duration * sample_rate)
    overlap_samples = int(overlap_duration * sample_rate)
    stride = chunk_samples - overlap_samples
    if stride <= 0:
        raise ValueError(
            f"chunk length ({chunk_samples} samples) must exceed overlap "
            f"({overlap_samples} samples)"
        )
    total_samples = waveform.shape[1]
    chunks = []
    for start in range(0, total_samples, stride):
        end = min(start + chunk_samples, total_samples)
        chunks.append(waveform[:, start:end])
        # Stop once a chunk reaches the end; further starts would only repeat its tail.
        if end >= total_samples:
            break
    return chunks
def merge_chunks_with_crossfade(chunks, sample_rate, overlap_duration):
    """Join chunks end to end, linearly crossfading each seam.

    Every chunk is treated as ``[channels, samples]``; 1-D chunks get a leading channel
    axis. At each seam the crossfade length is ``overlap_duration * sample_rate``
    samples, clipped to the lengths of both sides. When that comes out as zero the
    chunks are simply concatenated. An empty ``chunks`` list raises IndexError.
    """
    parts = [c.unsqueeze(0) if c.dim() == 1 else c for c in chunks]
    if len(parts) == 1:
        return parts[0]
    overlap_samples = int(overlap_duration * sample_rate)
    merged = parts[0]
    for nxt in parts[1:]:
        n = min(overlap_samples, merged.shape[1], nxt.shape[1])
        if n > 0:
            ramp_down = torch.linspace(1.0, 0.0, n).to(merged.device)
            ramp_up = torch.linspace(0.0, 1.0, n).to(nxt.device)
            # Blend the tail of what we have so far with the head of the next chunk.
            blended = merged[:, -n:] * ramp_down + nxt[:, :n] * ramp_up
            merged = torch.cat([merged[:, :-n], blended, nxt[:, n:]], dim=1)
        else:
            merged = torch.cat([merged, nxt], dim=1)
    return merged
def save_audio(tensor, sample_rate):
    """Write *tensor* to a new temporary ``.wav`` file and return that file's path.

    The file is created with ``delete=False`` so it outlives this call.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as handle:
        out_path = handle.name
        torchaudio.save(out_path, tensor, sample_rate)
    return out_path
@spaces.GPU(duration=10)
def separate_audio(model_name, file_path, text_prompt, chunk_duration=DEFAULT_CHUNK_DURATION, progress=gr.Progress()):
    """Isolate the sound described by ``text_prompt`` from an audio file.

    Args:
        model_name: key into MODELS, passed to load_model().
        file_path: path of the uploaded audio, or a falsy value when nothing was uploaded.
        text_prompt: natural-language description of the sound to isolate.
        chunk_duration: chunk length in seconds. Only used when the audio is longer than
            MAX_DURATION_WITHOUT_CHUNKING.
        progress: Gradio progress tracker.

    Returns:
        ``(target_path, residual_path, status_markdown)``. On invalid input or on any
        exception, both paths are None and the status carries a "❌" message. Errors are
        logged and returned rather than raised.
    """
    # NOTE(review): the GPU allocation is capped at 10 s; chunked processing of long clips
    # (plus a possible model switch) may need longer — confirm against ZeroGPU limits.
    t0 = time.time()
    log(f"[separate_audio] ENTER model={model_name} file_path={file_path} prompt='{(text_prompt or '')[:80]}' chunk_duration={chunk_duration}")
    # Validate file existence *and log it*
    if isinstance(file_path, str):
        exists = os.path.exists(file_path)
        size = os.path.getsize(file_path) if exists else -1
        log(f"[separate_audio] file exists={exists} size={size}")
    else:
        log(f"[separate_audio] unexpected file_path type: {type(file_path)}")
    progress(0.05, desc="Checking inputs...")
    if not file_path:
        return None, None, "❌ Please upload an audio file."
    if not text_prompt or not text_prompt.strip():
        return None, None, "❌ Please enter a text prompt."
    prompt = text_prompt.strip()
    try:
        progress(0.1, desc="Loading model...")
        load_model(model_name)
        progress(0.15, desc="Loading audio...")
        log("[separate_audio] loading audio...")
        waveform, sample_rate = load_audio(file_path)
        duration = waveform.shape[1] / sample_rate
        log(f"[separate_audio] audio loaded sr={sample_rate} duration={duration:.2f}s shape={tuple(waveform.shape)}")
        # Separated audio comes out at the processor's sampling rate, which may differ
        # from the input file's rate. Both paths must merge and save at this rate.
        out_sr = processor.audio_sampling_rate
        if duration > MAX_DURATION_WITHOUT_CHUNKING:
            progress(0.2, desc=f"Audio is {duration:.1f}s, splitting into chunks...")
            # Splitting operates on the input, so it uses the input's sample rate.
            chunks = split_audio_into_chunks(waveform, sample_rate, chunk_duration, OVERLAP_DURATION)
            num_chunks = len(chunks)
            target_chunks = []
            residual_chunks = []
            for i, chunk in enumerate(chunks):
                progress(0.2 + (i / num_chunks) * 0.6, desc=f"Processing chunk {i+1}/{num_chunks}...")
                # The processor takes file paths, so each chunk goes through a temp WAV.
                # The path is reserved first so it is removed even if writing it fails.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    chunk_path = tmp.name
                try:
                    torchaudio.save(chunk_path, chunk, sample_rate)
                    log(f"[separate_audio] building inputs on device={device} ...")
                    inputs = processor(audios=[chunk_path], descriptions=[prompt]).to(device)
                    log("[separate_audio] running model.separate() ...")
                    with torch.inference_mode():
                        result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
                    log("[separate_audio] model.separate() done")
                    target_chunks.append(result.target[0].cpu())
                    residual_chunks.append(result.residual[0].cpu())
                finally:
                    os.unlink(chunk_path)
            progress(0.85, desc="Merging chunks...")
            target_merged = merge_chunks_with_crossfade(target_chunks, out_sr, OVERLAP_DURATION)
            residual_merged = merge_chunks_with_crossfade(residual_chunks, out_sr, OVERLAP_DURATION)
            progress(0.95, desc="Saving results...")
            # merged tensors are already 2D [channels, samples]
            target_path = save_audio(target_merged, out_sr)
            residual_path = save_audio(residual_merged, out_sr)
            progress(1.0, desc="Done!")
            return target_path, residual_path, f"✅ Isolated '{text_prompt}' using {model_name} ({num_chunks} chunks)"
        else:
            # Short enough to process in a single pass.
            progress(0.3, desc="Processing audio...")
            log(f"[separate_audio] building inputs on device={device} ...")
            inputs = processor(audios=[file_path], descriptions=[prompt]).to(device)
            progress(0.6, desc="Separating sounds...")
            log("[separate_audio] running model.separate() ...")
            with torch.inference_mode():
                result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
            progress(0.9, desc="Saving results...")
            log("[separate_audio] model.separate() done")
            target_path = save_audio(result.target[0].unsqueeze(0).cpu(), out_sr)
            residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), out_sr)
            progress(1.0, desc="Done!")
            return target_path, residual_path, f"✅ Isolated '{text_prompt}' using {model_name}"
    except Exception as e:
        # UI boundary: report the failure in the status box instead of crashing the event.
        import traceback
        log(f"[separate_audio] EXCEPTION: {e}")
        traceback.print_exc()
        sys.stdout.flush()
        return None, None, f"❌ Error: {str(e)}"
    finally:
        log(f"[separate_audio] EXIT after {time.time() - t0:.2f}s")
# Build Interface
with gr.Blocks(title="SAM-Audio Test") as demo:
    gr.Markdown(
        """
# 🎡 SAM-Audio: Segment Anything for Audio
Isolate specific sounds from audio or video using natural language prompts.
"""
    )
    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=DEFAULT_MODEL,
                label="Model",
            )
            with gr.Accordion("⚙️ Advanced Options", open=False):
                chunk_duration_slider = gr.Slider(
                    # The lower bound must not exceed DEFAULT_CHUNK_DURATION (5 s), otherwise
                    # the slider's own default value falls outside its range.
                    minimum=5,
                    maximum=60,
                    value=DEFAULT_CHUNK_DURATION,
                    step=5,
                    label="Chunk Duration (seconds)",
                    info=f"Audio longer than {MAX_DURATION_WITHOUT_CHUNKING}s will be automatically split",
                )
            gr.Markdown("#### Upload Audio")
            input_audio = gr.Audio(label="Audio File", type="filepath")
            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'guitar', 'voice'",
            )
            run_btn = gr.Button("🎯 Isolate Sound", variant="primary", size="lg")
            status_output = gr.Markdown("")
        with gr.Column(scale=1):
            gr.Markdown("### Results")
            output_target = gr.Audio(label="Isolated Sound (Target)")
            output_residual = gr.Audio(label="Background (Residual)")

    gr.Markdown("---")
    gr.Markdown("### 🎬 Demo Examples")
    gr.Markdown("Click to load example audio and prompt:")
    with gr.Row():
        if os.path.exists(EXAMPLE_FILE):
            example_btn1 = gr.Button("🎀 Man Speaking")
            example_btn2 = gr.Button("🎀 Woman Speaking")
            example_btn3 = gr.Button("🎡 Background Music")

    def process(model_name, audio_path, prompt, chunk_duration, progress=gr.Progress()):
        """Click handler for the main button: log the request and delegate to separate_audio."""
        t0 = time.time()
        log(f"[process] called model={model_name} chunk_duration={chunk_duration} prompt_len={len(prompt) if prompt else 0}")
        # audio_path can be None or a string filepath depending on gradio
        log(f"[process] audio_path type={type(audio_path)} value={audio_path}")
        try:
            out = separate_audio(model_name, audio_path, prompt, chunk_duration, progress)
            log(f"[process] finished in {time.time() - t0:.2f}s")
            return out
        except Exception as e:
            log(f"[process] EXCEPTION: {e}")
            raise

    run_btn.click(
        fn=process,
        inputs=[model_selector, input_audio, text_prompt, chunk_duration_slider],
        outputs=[output_target, output_residual, status_output],
    )

    # Example buttons only fill in the audio and prompt; the user still clicks the main button.
    # Each prompt matches its button label, and all three buttons are wired.
    if os.path.exists(EXAMPLE_FILE):
        example_btn1.click(
            fn=lambda: (EXAMPLE_FILE, "man speaking"),
            outputs=[input_audio, text_prompt],
        )
        example_btn2.click(
            fn=lambda: (EXAMPLE_FILE, "woman speaking"),
            outputs=[input_audio, text_prompt],
        )
        example_btn3.click(
            fn=lambda: (EXAMPLE_FILE, "background music"),
            outputs=[input_audio, text_prompt],
        )
if __name__ == "__main__":
    # NOTE(review): demo.launch() serves only the Gradio app, so the FastAPI routes on
    # `api` (/upload_audio, /files) are not reachable in this path. launch() presumably
    # blocks when run as a script, which means the mount below would only run after the
    # server stops. share=True is likely ignored on Hugging Face Spaces. Confirm all three.
    demo.launch(show_error=True, share=True)
# Combined ASGI app: the FastAPI routes on `api` plus the Gradio UI mounted at "/".
# It only takes effect if an external ASGI server imports this module and serves `app`;
# nothing in this file runs it.
app = gr.mount_gradio_app(api, demo, path="/")