|
|
import gradio as gr |
|
|
import torch |
|
|
import torchaudio |
|
|
import os |
|
|
import tempfile |
|
|
from sam_audio import SAMAudio, SAMAudioProcessor |
|
|
|
|
|
|
|
|
# Checkpoint to load; the small variant keeps startup and inference light.
MODEL_NAME = "facebook/sam-audio-small"
# Prefer GPU when available; model and inputs are both moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading {MODEL_NAME} on {device}...")

try:
    # Load once at import time; .eval() puts the model in inference mode.
    model = SAMAudio.from_pretrained(MODEL_NAME).to(device).eval()
    processor = SAMAudioProcessor.from_pretrained(MODEL_NAME)
    print("Model loaded successfully.")
except Exception as e:
    # Gated checkpoints need Hugging Face auth — surface a hint before failing.
    print(f"Error loading model. Did you set HF_TOKEN in secrets? Error: {e}")
    # Bare `raise` re-raises the active exception with its original traceback
    # (idiomatic; `raise e` is redundant inside an except block).
    raise
|
|
|
|
|
def save_audio(tensor, sample_rate):
    """Write an audio tensor to a temporary WAV file and return its path.

    Args:
        tensor: 1-D ``(samples,)`` or 2-D ``(channels, samples)`` audio tensor.
        sample_rate: Sample rate in Hz, passed through to ``torchaudio.save``.

    Returns:
        Path of the temporary ``.wav`` file. ``delete=False`` keeps it on
        disk so Gradio can serve it; the OS temp dir is the cleanup backstop.
    """
    # torchaudio.save expects a (channels, samples) layout, so promote mono.
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)
    # Drop autograd history and move off the GPU before serialization.
    tensor = tensor.detach().cpu()

    # Create the temp file, then CLOSE the handle before torchaudio reopens
    # the path for writing: on Windows a still-open NamedTemporaryFile cannot
    # be opened a second time, so writing inside the `with` block would fail.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()
    torchaudio.save(tmp.name, tensor, sample_rate)
    return tmp.name
|
|
|
|
|
def separate_audio(audio_path, text_prompt):
    """Isolate the sound described by *text_prompt* from the input recording.

    Returns a ``(target_path, residual_path)`` pair of temporary WAV files,
    or ``(None, None)`` when no audio was provided.
    """
    # Guard clause: nothing to separate without an input file.
    if not audio_path:
        return None, None

    # Bundle the recording and its text description into model-ready tensors,
    # placed on the same device as the model.
    batch = processor(
        audios=[audio_path],
        descriptions=[text_prompt]
    ).to(device)

    # Inference only — gradients are never needed here.
    with torch.no_grad():
        separation = model.separate(batch)

    # Single-item batch, so index 0 of each output stream.
    isolated = separation.target[0]
    background = separation.residual[0]

    sample_rate = processor.feature_extractor.sampling_rate

    # Persist both stems as temp WAVs for the Gradio audio players.
    isolated_path = save_audio(isolated, sample_rate)
    background_path = save_audio(background, sample_rate)

    return isolated_path, background_path
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI: two-column layout — inputs (audio upload + text prompt) on the
# left, separated outputs (target + residual stems) on the right.
# ---------------------------------------------------------------------------
with gr.Blocks(title="Meta SAM-Audio Demo") as demo:
    gr.Markdown(
        """
        # 🎵 SAM-Audio: Segment Anything for Audio
        Isolate specific sounds from an audio file using natural language prompts.

        **Model:** `facebook/sam-audio-small`
        """
    )

    with gr.Row():
        with gr.Column():
            # type="filepath" makes Gradio pass a path string to
            # separate_audio, matching processor(audios=[audio_path]).
            input_audio = gr.Audio(label="Upload Input Audio", type="filepath")
            text_prompt = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'dog barking', 'man speaking', 'typing keyboard'",
                info="Describe the sound you want to isolate."
            )
            run_btn = gr.Button("Separate Audio", variant="primary")

        with gr.Column():
            # Playback widgets for the two stems returned by separate_audio.
            output_target = gr.Audio(label="Isolated Sound (Target)")
            output_residual = gr.Audio(label="Background (Residual)")

    # Wire the button to the inference function defined above.
    run_btn.click(
        fn=separate_audio,
        inputs=[input_audio, text_prompt],
        outputs=[output_target, output_residual]
    )

    gr.Examples(
        examples=[
            # NOTE(review): placeholder path — ship a real sample file here,
            # otherwise the example row is broken in the running app.
            ["path/to/example.wav", "female vocals"],
        ],
        inputs=[input_audio, text_prompt]
    )


# queue() enables request queuing for the long-running separation call;
# 0.0.0.0:7860 exposes the app on all interfaces (typical for containers).
demo.queue().launch(server_name="0.0.0.0", server_port=7860)