|
|
import spaces |
|
|
import gradio as gr |
|
|
import torch |
|
|
import torchaudio |
|
|
import tempfile |
|
|
import warnings |
|
|
import os |
|
|
|
|
|
import logging |
|
|
import sys |
|
|
import time |
|
|
from sam_audio import SAMAudio, SAMAudioProcessor |
|
|
|
|
|
|
|
|
import os, uuid |
|
|
from fastapi import FastAPI, UploadFile, File |
|
|
from fastapi.responses import JSONResponse |
|
|
import gradio as gr |
|
|
|
|
|
# FastAPI application hosting the upload endpoint; the Gradio UI is
# mounted onto it at the bottom of the file via gr.mount_gradio_app.
api = FastAPI()

# Directory where uploaded audio files are persisted; files here are
# served back read-only through the /files static mount below.
UPLOAD_DIR = "/tmp/uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
|
|
|
|
@api.post("/upload_audio")
async def upload_audio(file: UploadFile = File(...)):
    """Accept an uploaded audio file and persist it under UPLOAD_DIR.

    The file is stored under a random (uuid4) name so concurrent uploads
    cannot collide; the original extension is kept when present, falling
    back to ".wav".

    Returns:
        JSONResponse with the server-side "path" and the public "url"
        under the /files static mount.
    """
    # FIX: file.filename may be None (clients are not required to send a
    # filename) -- guard before splitext so we don't raise a TypeError.
    ext = os.path.splitext(file.filename or "")[1] or ".wav"
    out_name = f"{uuid.uuid4().hex}{ext}"
    out_path = os.path.join(UPLOAD_DIR, out_name)

    data = await file.read()
    with open(out_path, "wb") as f:
        f.write(data)

    return JSONResponse({"path": out_path, "url": f"/files/{out_name}"})
|
|
|
|
|
from fastapi.staticfiles import StaticFiles

# Expose the upload directory at /files so the "url" returned by
# /upload_audio resolves to the stored file.
api.mount("/files", StaticFiles(directory=UPLOAD_DIR), name="files")
|
|
|
|
|
|
|
|
|
|
|
# Silence library warnings (model/audio libraries are noisy on load).
warnings.filterwarnings("ignore")

# Dedicated logger writing to stdout so messages show up in the Space logs.
logger = logging.getLogger("sam_space")
logger.setLevel(logging.INFO)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s %(message)s"))
# Drop any handlers a previous (re)import may have attached so log lines
# are not duplicated.
logger.handlers.clear()
logger.addHandler(handler)
|
|
|
|
|
def log(msg: str) -> None:
    """Emit *msg* at INFO level and force the stdout buffer to flush.

    The explicit flush makes log lines appear promptly in the hosted
    Space's streamed logs.
    """
    logger.info(msg)
    sys.stdout.flush()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Model registry: UI label -> Hugging Face checkpoint id. The "-tv"
# variants additionally accept visual conditioning.
MODELS = {
    "sam-audio-small": "facebook/sam-audio-small",
    "sam-audio-base": "facebook/sam-audio-base",
    "sam-audio-large": "facebook/sam-audio-large",
    "sam-audio-small-tv (Visual)": "facebook/sam-audio-small-tv",
    "sam-audio-base-tv (Visual)": "facebook/sam-audio-base-tv",
    "sam-audio-large-tv (Visual)": "facebook/sam-audio-large-tv",
}

DEFAULT_MODEL = "sam-audio-small"

# Bundled example clip used by the demo buttons in the UI.
EXAMPLES_DIR = "audio"
EXAMPLE_FILE = os.path.join(EXAMPLES_DIR, "PromoterClipMono.wav")

# Chunked-processing parameters, all in seconds.
DEFAULT_CHUNK_DURATION = 5
OVERLAP_DURATION = 1
# Inputs longer than this are split into overlapping chunks before
# separation and crossfaded back together afterwards.
MAX_DURATION_WITHOUT_CHUNKING = 10
|
|
|
|
|
|
|
|
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Lazily (re)loaded model state shared by load_model()/separate_audio().
current_model_name = None
model = None
processor = None
|
|
|
|
|
def load_model(model_name):
    """Load *model_name* into the module-level model/processor globals.

    No-op when the requested model is already loaded. Unknown names fall
    back to DEFAULT_MODEL's checkpoint.
    """
    # NOTE(review): this "App import complete" line is logged on every
    # call, not just at import time -- the message is misleading.
    log(f"App import complete. device={device} default_model={DEFAULT_MODEL} cwd={os.getcwd()}")
    global current_model_name, model, processor
    # Unknown labels silently fall back to the default checkpoint.
    model_id = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
    if current_model_name == model_name and model is not None:
        return
    print(f"Loading {model_id}...")
    model = SAMAudio.from_pretrained(model_id).to(device).eval()
    processor = SAMAudioProcessor.from_pretrained(model_id)
    current_model_name = model_name
    print(f"Model {model_id} loaded on {device}.")


# Eagerly load the default model at import time so the first request does
# not pay the download/initialisation cost.
load_model(DEFAULT_MODEL)
|
|
|
|
|
def load_audio(file_path):
    """Load audio from file (supports both audio and video files).

    Returns a (waveform, sample_rate) pair where the waveform is a
    (1, samples) mono tensor; multi-channel input is averaged down.
    """
    audio, rate = torchaudio.load(file_path)
    # Collapse any multi-channel signal to mono so downstream chunking
    # and separation always see a single channel.
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)
    return audio, rate
|
|
|
|
|
def split_audio_into_chunks(waveform, sample_rate, chunk_duration, overlap_duration):
    """Split an audio waveform into overlapping chunks.

    Args:
        waveform: 2-D tensor of shape (channels, samples).
        sample_rate: samples per second of *waveform*.
        chunk_duration: length of each chunk in seconds.
        overlap_duration: overlap between consecutive chunks in seconds.

    Returns:
        List of (channels, <=chunk_samples) tensors covering the whole
        input; consecutive chunks share overlap_duration seconds.

    Raises:
        ValueError: if the overlap is not smaller than the chunk length --
            the stride would be <= 0 and the loop below would never
            advance (previously an infinite loop).
    """
    chunk_samples = int(chunk_duration * sample_rate)
    overlap_samples = int(overlap_duration * sample_rate)
    stride = chunk_samples - overlap_samples
    # FIX: guard against a non-positive stride, which made the original
    # while-loop spin forever on `start += 0`.
    if stride <= 0:
        raise ValueError(
            "overlap_duration must be smaller than chunk_duration "
            f"(chunk={chunk_samples} samples, overlap={overlap_samples} samples)"
        )

    chunks = []
    total_samples = waveform.shape[1]

    start = 0
    while start < total_samples:
        end = min(start + chunk_samples, total_samples)
        chunks.append(waveform[:, start:end])
        # The final chunk may be shorter than chunk_samples; stop once
        # the end of the waveform has been reached.
        if end >= total_samples:
            break
        start += stride

    return chunks
|
|
|
|
|
def merge_chunks_with_crossfade(chunks, sample_rate, overlap_duration):
    """Stitch processed chunks back together, crossfading the overlaps.

    Each chunk may be 1-D (samples,) or 2-D (channels, samples); the
    result is always 2-D. Consecutive chunks are blended over the last /
    first ``overlap_duration`` seconds with complementary linear ramps.
    """
    # Normalise every chunk to 2-D (channels, samples).
    stitched = [c.unsqueeze(0) if c.dim() == 1 else c for c in chunks]

    merged = stitched[0]
    if len(stitched) == 1:
        return merged

    overlap_samples = int(overlap_duration * sample_rate)

    for nxt in stitched[1:]:
        # The usable overlap can be shorter than requested when either
        # side is a short (e.g. final) chunk.
        fade_len = min(overlap_samples, merged.shape[1], nxt.shape[1])

        if fade_len <= 0:
            # Nothing to blend -- plain concatenation.
            merged = torch.cat([merged, nxt], dim=1)
            continue

        # Complementary linear ramps: the weights sum to 1.0 at every
        # sample, so a constant signal passes the seam unchanged.
        fade_out = torch.linspace(1.0, 0.0, fade_len).to(merged.device)
        fade_in = torch.linspace(0.0, 1.0, fade_len).to(nxt.device)

        seam = merged[:, -fade_len:] * fade_out + nxt[:, :fade_len] * fade_in
        merged = torch.cat([merged[:, :-fade_len], seam, nxt[:, fade_len:]], dim=1)

    return merged
|
|
|
|
|
def save_audio(tensor, sample_rate):
    """Write *tensor* to a fresh temporary .wav file and return its path.

    delete=False keeps the file alive after the handle closes so Gradio
    can serve it; cleanup is left to the OS temp-dir policy.
    """
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    with tmp:
        torchaudio.save(tmp.name, tensor, sample_rate)
    return tmp.name
|
|
|
|
|
|
|
|
@spaces.GPU(duration=10)
def separate_audio(model_name, file_path, text_prompt, chunk_duration=DEFAULT_CHUNK_DURATION, progress=gr.Progress()):
    """Isolate the sound described by *text_prompt* from the audio at *file_path*.

    Inputs longer than MAX_DURATION_WITHOUT_CHUNKING seconds are processed
    in overlapping chunks and crossfaded back together.

    Returns:
        (target_path, residual_path, status_message) -- the paths are
        temporary wav files, or None when validation fails / an exception
        occurs (the message then carries the error).
    """
    global model, processor

    t0 = time.time()
    log(f"[separate_audio] ENTER model={model_name} file_path={file_path} prompt='{(text_prompt or '')[:80]}' chunk_duration={chunk_duration}")

    # Diagnostic logging only -- actual validation happens below.
    if isinstance(file_path, str):
        exists = os.path.exists(file_path)
        size = os.path.getsize(file_path) if exists else -1
        log(f"[separate_audio] file exists={exists} size={size}")
    else:
        log(f"[separate_audio] unexpected file_path type: {type(file_path)}")

    progress(0.05, desc="Checking inputs...")

    if not file_path:
        return None, None, "❌ Please upload an audio file."
    if not text_prompt or not text_prompt.strip():
        return None, None, "❌ Please enter a text prompt."

    try:
        progress(0.1, desc="Loading model...")
        load_model(model_name)

        progress(0.15, desc="Loading audio...")

        log(f"[separate_audio] loading audio...")
        waveform, sample_rate = load_audio(file_path)
        duration = waveform.shape[1] / sample_rate
        log(f"[separate_audio] audio loaded sr={sample_rate} duration={duration:.2f}s shape={tuple(waveform.shape)}")

        use_chunking = duration > MAX_DURATION_WITHOUT_CHUNKING

        if use_chunking:
            progress(0.2, desc=f"Audio is {duration:.1f}s, splitting into chunks...")
            chunks = split_audio_into_chunks(waveform, sample_rate, chunk_duration, OVERLAP_DURATION)
            num_chunks = len(chunks)

            target_chunks = []
            residual_chunks = []

            for i, chunk in enumerate(chunks):
                # Map the chunk index onto the 0.2-0.8 progress range.
                chunk_progress = 0.2 + (i / num_chunks) * 0.6
                progress(chunk_progress, desc=f"Processing chunk {i+1}/{num_chunks}...")

                # The processor consumes file paths, so each chunk is
                # round-tripped through a temporary wav file.
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                    torchaudio.save(tmp.name, chunk, sample_rate)
                    chunk_path = tmp.name

                try:
                    log(f"[separate_audio] building inputs on device={device} ...")
                    inputs = processor(audios=[chunk_path], descriptions=[text_prompt.strip()]).to(device)

                    log("[separate_audio] running model.separate() ...")
                    with torch.inference_mode():
                        result = model.separate(inputs, predict_spans=False, reranking_candidates=1)

                    log("[separate_audio] model.separate() done")
                    # Move results off the GPU immediately so chunk tensors
                    # don't accumulate in device memory.
                    target_chunks.append(result.target[0].cpu())
                    residual_chunks.append(result.residual[0].cpu())
                finally:
                    # Always remove the temporary chunk file, even on error.
                    os.unlink(chunk_path)

            progress(0.85, desc="Merging chunks...")
            target_merged = merge_chunks_with_crossfade(target_chunks, sample_rate, OVERLAP_DURATION)
            residual_merged = merge_chunks_with_crossfade(residual_chunks, sample_rate, OVERLAP_DURATION)

            progress(0.95, desc="Saving results...")

            # NOTE(review): this branch saves at the input file's sample
            # rate, while the non-chunked branch below uses
            # processor.audio_sampling_rate -- confirm the model's output
            # rate matches the input rate, otherwise chunked output plays
            # at the wrong speed.
            target_path = save_audio(target_merged, sample_rate)
            residual_path = save_audio(residual_merged, sample_rate)

            progress(1.0, desc="Done!")
            return target_path, residual_path, f"✅ Isolated '{text_prompt}' using {model_name} ({num_chunks} chunks)"
        else:
            progress(0.3, desc="Processing audio...")
            log(f"[separate_audio] building inputs on device={device} ...")
            inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)

            progress(0.6, desc="Separating sounds...")
            log("[separate_audio] running model.separate() ...")
            with torch.inference_mode():
                result = model.separate(inputs, predict_spans=False, reranking_candidates=1)

            progress(0.9, desc="Saving results...")
            log("[separate_audio] model.separate() done")
            # Save at the processor's native rate; outputs are 1-D and are
            # unsqueezed to (1, samples) for torchaudio.save.
            sample_rate = processor.audio_sampling_rate
            target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
            residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)

            progress(1.0, desc="Done!")
            return target_path, residual_path, f"✅ Isolated '{text_prompt}' using {model_name}"
    except Exception as e:
        import traceback
        log(f"[separate_audio] EXCEPTION: {e}")
        traceback.print_exc()
        sys.stdout.flush()
        # Surface the error in the UI status field instead of crashing.
        return None, None, f"❌ Error: {str(e)}"
    finally:
        log(f"[separate_audio] EXIT after {time.time() - t0:.2f}s")
|
|
|
|
|
|
|
|
# Gradio UI: model picker, chunking options, upload + prompt inputs on the
# left; separated target/residual audio on the right.
with gr.Blocks(title="SAM-Audio Test") as demo:
    gr.Markdown(
        """
# 🎵 SAM-Audio: Segment Anything for Audio

Isolate specific sounds from audio or video using natural language prompts.
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=DEFAULT_MODEL,
                label="Model"
            )

            with gr.Accordion("⚙️ Advanced Options", open=False):
                # FIX: minimum was 10 while the initial value
                # (DEFAULT_CHUNK_DURATION) is 5, i.e. the slider's default
                # sat below its own range. Lower the minimum so the default
                # is representable.
                chunk_duration_slider = gr.Slider(
                    minimum=5,
                    maximum=60,
                    value=DEFAULT_CHUNK_DURATION,
                    step=5,
                    label="Chunk Duration (seconds)",
                    info=f"Audio longer than {MAX_DURATION_WITHOUT_CHUNKING}s will be automatically split"
                )

            gr.Markdown("#### Upload Audio")
            input_audio = gr.Audio(label="Audio File", type="filepath")

            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'guitar', 'voice'"
            )

            run_btn = gr.Button("🎯 Isolate Sound", variant="primary", size="lg")
            status_output = gr.Markdown("")

        with gr.Column(scale=1):
            gr.Markdown("### Results")
            output_target = gr.Audio(label="Isolated Sound (Target)")
            output_residual = gr.Audio(label="Background (Residual)")

    gr.Markdown("---")
    gr.Markdown("### 🎬 Demo Examples")
    gr.Markdown("Click to load example audio and prompt:")

    with gr.Row():
        if os.path.exists(EXAMPLE_FILE):
            example_btn1 = gr.Button("🎤 Man Speaking")
            example_btn2 = gr.Button("🎤 Woman Speaking")
            # NOTE(review): example_btn3 is never wired to a click handler
            # below, so this button currently does nothing -- confirm the
            # intended prompt before wiring it.
            example_btn3 = gr.Button("🎵 Background Music")

    def process(model_name, audio_path, prompt, chunk_duration, progress=gr.Progress()):
        """Logging wrapper around separate_audio used as the click handler."""
        t0 = time.time()
        log(f"[process] called model={model_name} chunk_duration={chunk_duration} prompt_len={len(prompt) if prompt else 0}")

        log(f"[process] audio_path type={type(audio_path)} value={audio_path}")

        try:
            out = separate_audio(model_name, audio_path, prompt, chunk_duration, progress)
            log(f"[process] finished in {time.time() - t0:.2f}s")
            return out
        except Exception as e:
            # Log and re-raise so Gradio still surfaces the error.
            log(f"[process] EXCEPTION: {e}")
            raise

    run_btn.click(
        fn=process,
        inputs=[model_selector, input_audio, text_prompt, chunk_duration_slider],
        outputs=[output_target, output_residual, status_output]
    )

    if os.path.exists(EXAMPLE_FILE):
        # NOTE(review): the button labels say "Man/Woman Speaking" but the
        # prompts filled in are "Guitar"/"Voice" -- confirm which pairing
        # is intended.
        example_btn1.click(
            fn=lambda: (EXAMPLE_FILE, "Guitar"),
            outputs=[input_audio, text_prompt]
        )
        example_btn2.click(
            fn=lambda: (EXAMPLE_FILE, "Voice"),
            outputs=[input_audio, text_prompt]
        )
|
|
if __name__ == "__main__":
    # Direct execution: serve the Gradio UI standalone.
    # NOTE(review): demo.launch() blocks, so the mount below only runs
    # after the server stops when executed as a script; under an ASGI
    # server (e.g. `uvicorn app:app`) only the mounted app is used.
    demo.launch(show_error=True, share=True)

# Combined ASGI app: Gradio UI at "/" on top of the FastAPI routes above.
app = gr.mount_gradio_app(api, demo, path="/")
|
|
|
|
|
|