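"""Gradio demo for SAM-Audio: isolate a described sound from an audio or video file.

Long inputs are split into overlapping chunks, separated independently, and
re-joined with a short crossfade over the overlap region.
"""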
import spaces
import gradio as gr
import torch
import torchaudio
import tempfile
import warnings
import os
warnings.filterwarnings("ignore")
from sam_audio import SAMAudio, SAMAudioProcessor
# Available models
MODELS = {
"sam-audio-small": "facebook/sam-audio-small",
"sam-audio-base": "facebook/sam-audio-base",
"sam-audio-large": "facebook/sam-audio-large",
"sam-audio-small-tv (Visual)": "facebook/sam-audio-small-tv",
"sam-audio-base-tv (Visual)": "facebook/sam-audio-base-tv",
"sam-audio-large-tv (Visual)": "facebook/sam-audio-large-tv",
}
DEFAULT_MODEL = "sam-audio-small"
EXAMPLES_DIR = "examples"
EXAMPLE_FILE = os.path.join(EXAMPLES_DIR, "office.mp4")
# Chunk processing settings
DEFAULT_CHUNK_DURATION = 30 # seconds per chunk
OVERLAP_DURATION = 2 # seconds of overlap between chunks
MAX_DURATION_WITHOUT_CHUNKING = 60 # auto-chunk if longer than this
# Global model cache
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
current_model_name = None
model = None
processor = None

def load_model(model_name):
    """Load the requested model, reusing the cached one if it is already active."""
    global current_model_name, model, processor
    model_id = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
    if current_model_name == model_name and model is not None:
        return
    print(f"Loading {model_id}...")
    model = SAMAudio.from_pretrained(model_id).to(device).eval()
    processor = SAMAudioProcessor.from_pretrained(model_id)
    current_model_name = model_name
    print(f"Model {model_id} loaded on {device}.")
load_model(DEFAULT_MODEL)

def load_audio(file_path):
    """Load audio from file (supports both audio and video files)."""
    waveform, sample_rate = torchaudio.load(file_path)
    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    return waveform, sample_rate

def split_audio_into_chunks(waveform, sample_rate, chunk_duration, overlap_duration):
    """Split audio waveform into overlapping chunks."""
    chunk_samples = int(chunk_duration * sample_rate)
    overlap_samples = int(overlap_duration * sample_rate)
    stride = chunk_samples - overlap_samples
    chunks = []
    total_samples = waveform.shape[1]
    start = 0
    while start < total_samples:
        end = min(start + chunk_samples, total_samples)
        chunk = waveform[:, start:end]
        chunks.append(chunk)
        if end >= total_samples:
            break
        start += stride
    return chunks

def merge_chunks_with_crossfade(chunks, sample_rate, overlap_duration):
    """Merge audio chunks with crossfade on overlapping regions."""
    if len(chunks) == 1:
        chunk = chunks[0]
        # Ensure 2D tensor
        if chunk.dim() == 1:
            chunk = chunk.unsqueeze(0)
        return chunk
    overlap_samples = int(overlap_duration * sample_rate)
    # Ensure all chunks are 2D [channels, samples]
    processed_chunks = []
    for chunk in chunks:
        if chunk.dim() == 1:
            chunk = chunk.unsqueeze(0)
        processed_chunks.append(chunk)
    result = processed_chunks[0]
    for i in range(1, len(processed_chunks)):
        prev_chunk = result
        next_chunk = processed_chunks[i]
        # Handle case where chunks are shorter than overlap
        actual_overlap = min(overlap_samples, prev_chunk.shape[1], next_chunk.shape[1])
        if actual_overlap <= 0:
            # No overlap possible, just concatenate
            result = torch.cat([prev_chunk, next_chunk], dim=1)
            continue
        # Create fade curves
        fade_out = torch.linspace(1.0, 0.0, actual_overlap).to(prev_chunk.device)
        fade_in = torch.linspace(0.0, 1.0, actual_overlap).to(next_chunk.device)
        # Get overlapping regions
        prev_overlap = prev_chunk[:, -actual_overlap:]
        next_overlap = next_chunk[:, :actual_overlap]
        # Crossfade mix
        crossfaded = prev_overlap * fade_out + next_overlap * fade_in
        # Concatenate: non-overlap of prev + crossfaded + non-overlap of next
        result = torch.cat([
            prev_chunk[:, :-actual_overlap],
            crossfaded,
            next_chunk[:, actual_overlap:]
        ], dim=1)
    return result

def save_audio(tensor, sample_rate):
    """Write a [channels, samples] tensor to a temporary WAV file and return its path."""
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        torchaudio.save(tmp.name, tensor, sample_rate)
        return tmp.name
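
# spaces.GPU reserves a GPU for this call (up to 300 s); long inputs are chunked so each model pass stays short.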
@spaces.GPU(duration=300)
def separate_audio(model_name, file_path, text_prompt, chunk_duration=DEFAULT_CHUNK_DURATION, progress=gr.Progress()):
    global model, processor
    progress(0.05, desc="Checking inputs...")
    if not file_path:
        return None, None, "❌ Please upload an audio or video file."
    if not text_prompt or not text_prompt.strip():
        return None, None, "❌ Please enter a text prompt."
    try:
        progress(0.1, desc="Loading model...")
        load_model(model_name)
        progress(0.15, desc="Loading audio...")
        waveform, sample_rate = load_audio(file_path)
        duration = waveform.shape[1] / sample_rate
        # Decide whether to use chunking
        use_chunking = duration > MAX_DURATION_WITHOUT_CHUNKING
        if use_chunking:
            progress(0.2, desc=f"Audio is {duration:.1f}s, splitting into chunks...")
            chunks = split_audio_into_chunks(waveform, sample_rate, chunk_duration, OVERLAP_DURATION)
            num_chunks = len(chunks)
            target_chunks = []
            residual_chunks = []
            for i, chunk in enumerate(chunks):
                chunk_progress = 0.2 + (i / num_chunks) * 0.6
                progress(chunk_progress, desc=f"Processing chunk {i+1}/{num_chunks}...")
                # Save chunk to temp file for processor
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    torchaudio.save(tmp.name, chunk, sample_rate)
                    chunk_path = tmp.name
                try:
                    inputs = processor(audios=[chunk_path], descriptions=[text_prompt.strip()]).to(device)
                    with torch.inference_mode():
                        result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
                    target_chunks.append(result.target[0].cpu())
                    residual_chunks.append(result.residual[0].cpu())
                finally:
                    os.unlink(chunk_path)
            progress(0.85, desc="Merging chunks...")
            # Separated chunks come back at the processor's sampling rate, which may
            # differ from the input file's rate, so merge and save using that rate.
            output_sample_rate = processor.audio_sampling_rate
            target_merged = merge_chunks_with_crossfade(target_chunks, output_sample_rate, OVERLAP_DURATION)
            residual_merged = merge_chunks_with_crossfade(residual_chunks, output_sample_rate, OVERLAP_DURATION)
            progress(0.95, desc="Saving results...")
            # merged tensors are already 2D [channels, samples]
            target_path = save_audio(target_merged, output_sample_rate)
            residual_path = save_audio(residual_merged, output_sample_rate)
            progress(1.0, desc="Done!")
            return target_path, residual_path, f"✅ Isolated '{text_prompt}' using {model_name} ({num_chunks} chunks)"
        else:
            # Process without chunking
            progress(0.3, desc="Processing audio...")
            inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)
            progress(0.6, desc="Separating sounds...")
            with torch.inference_mode():
                result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
            progress(0.9, desc="Saving results...")
            sample_rate = processor.audio_sampling_rate
            target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
            residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)
            progress(1.0, desc="Done!")
            return target_path, residual_path, f"✅ Isolated '{text_prompt}' using {model_name}"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, None, f"❌ Error: {str(e)}"

# Build Interface
with gr.Blocks(title="SAM-Audio Demo") as demo:
    gr.Markdown(
        """
        # 🎵 SAM-Audio: Segment Anything for Audio
        Isolate specific sounds from audio or video using natural language prompts.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=DEFAULT_MODEL,
                label="Model"
            )
            with gr.Accordion("⚙️ Advanced Options", open=False):
                chunk_duration_slider = gr.Slider(
                    minimum=10,
                    maximum=60,
                    value=DEFAULT_CHUNK_DURATION,
                    step=5,
                    label="Chunk Duration (seconds)",
                    info=f"Audio longer than {MAX_DURATION_WITHOUT_CHUNKING}s will be automatically split"
                )
            gr.Markdown("#### Upload Audio")
            input_audio = gr.Audio(label="Audio File", type="filepath")
            gr.Markdown("#### Or Upload Video")
            input_video = gr.Video(label="Video File")
            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'A man speaking', 'Piano', 'Dog barking'"
            )
            run_btn = gr.Button("🎯 Isolate Sound", variant="primary", size="lg")
            status_output = gr.Markdown("")
        with gr.Column(scale=1):
            gr.Markdown("### Results")
            output_target = gr.Audio(label="Isolated Sound (Target)")
            output_residual = gr.Audio(label="Background (Residual)")
    gr.Markdown("---")
    gr.Markdown("### 🎬 Demo Examples")
    gr.Markdown("Click to load example video and prompt:")
    with gr.Row():
        if os.path.exists(EXAMPLE_FILE):
            example_btn1 = gr.Button("🎤 Man Speaking")
            example_btn2 = gr.Button("🎤 Woman Speaking")
            example_btn3 = gr.Button("🎵 Background Music")

    # Main process button
    def process(model_name, audio_path, video_path, prompt, chunk_duration, progress=gr.Progress()):
        # Prefer the uploaded video; fall back to the uploaded audio file.
        file_path = video_path if video_path else audio_path
        return separate_audio(model_name, file_path, prompt, chunk_duration, progress)

    run_btn.click(
        fn=process,
        inputs=[model_selector, input_audio, input_video, text_prompt, chunk_duration_slider],
        outputs=[output_target, output_residual, status_output]
    )
    # Example buttons - just fill the prompt, user clicks button to process
    if os.path.exists(EXAMPLE_FILE):
        example_btn1.click(
            fn=lambda: (EXAMPLE_FILE, "A man speaking"),
            outputs=[input_video, text_prompt]
        )
        example_btn2.click(
            fn=lambda: (EXAMPLE_FILE, "A woman speaking"),
            outputs=[input_video, text_prompt]
        )
        example_btn3.click(
            fn=lambda: (EXAMPLE_FILE, "Background music"),
            outputs=[input_video, text_prompt]
        )

if __name__ == "__main__":
    demo.launch()