"""Gradio demo: isolate specific sounds from audio with Meta SAM-Audio using text prompts."""
import gradio as gr
import torch
import torchaudio
import os
import tempfile
from sam_audio import SAMAudio, SAMAudioProcessor
# 1. Configuration
MODEL_NAME = "facebook/sam-audio-small"
# Prefer GPU when available; everything downstream (model + inputs) follows this.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading {MODEL_NAME} on {device}...")

# 2. Load Model and Processor
try:
    # .eval() disables dropout/batchnorm updates for inference.
    model = SAMAudio.from_pretrained(MODEL_NAME).to(device).eval()
    processor = SAMAudioProcessor.from_pretrained(MODEL_NAME)
    print("Model loaded successfully.")
except Exception as e:
    # Most common failure on hosted Spaces: gated model without HF_TOKEN set.
    print(f"Error loading model. Did you set HF_TOKEN in secrets? Error: {e}")
    # Bare `raise` re-raises the active exception with its original traceback
    # intact (idiomatic; `raise e` would append this handler's frame).
    raise
def save_audio(tensor, sample_rate):
    """Write an audio tensor to a temporary WAV file and return its path.

    Args:
        tensor: waveform of shape (T,) or (C, T); detached and moved to CPU
            before saving.
        sample_rate: sampling rate in Hz, passed through to ``torchaudio.save``.

    Returns:
        Path to the temporary ``.wav`` file (handed to Gradio for playback;
        the file is intentionally not deleted here).
    """
    if tensor.dim() == 1:
        # torchaudio.save expects (channels, time).
        tensor = tensor.unsqueeze(0)
    tensor = tensor.detach().cpu()
    # mkstemp + close instead of an open NamedTemporaryFile: writing to
    # tmp.name while the handle is still open fails on Windows, where the
    # open handle holds an exclusive lock on the file.
    fd, path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    torchaudio.save(path, tensor, sample_rate)
    return path
def separate_audio(audio_path, text_prompt):
    """Isolate the sound described by ``text_prompt`` from ``audio_path``.

    Returns a ``(target_path, residual_path)`` pair of temporary WAV files,
    or ``(None, None)`` when no audio was provided.
    """
    if not audio_path:
        return None, None

    # 3. Process Inputs — the processor loads the audio file and tokenizes
    # the text description in a single call, then we move the batch to the
    # model's device.
    batch = processor(
        audios=[audio_path],
        descriptions=[text_prompt]
    ).to(device)

    # 4. Inference — no gradients needed for a forward-only separation pass.
    with torch.no_grad():
        separation = model.separate(batch)

    # 5. Extract Outputs — results are per-item lists; this demo runs with
    # batch size 1, so index 0 is the only entry.
    isolated = separation.target[0]      # the sound the prompt asked for
    background = separation.residual[0]  # everything else in the mix

    # Sampling rate comes from the processor's feature-extractor config.
    rate = processor.feature_extractor.sampling_rate

    # 6. Save both stems to temp files for the Gradio audio outputs.
    return save_audio(isolated, rate), save_audio(background, rate)
# 7. Build Gradio Interface
with gr.Blocks(title="Meta SAM-Audio Demo") as demo:
    gr.Markdown(
        """
        # 🎵 SAM-Audio: Segment Anything for Audio
        Isolate specific sounds from an audio file using natural language prompts.
        **Model:** `facebook/sam-audio-small`
        """
    )

    # Left column: inputs + trigger. Right column: the two separated stems.
    with gr.Row():
        with gr.Column():
            audio_in = gr.Audio(label="Upload Input Audio", type="filepath")
            prompt_box = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 'dog barking', 'man speaking', 'typing keyboard'",
                info="Describe the sound you want to isolate.",
            )
            separate_btn = gr.Button("Separate Audio", variant="primary")
        with gr.Column():
            target_out = gr.Audio(label="Isolated Sound (Target)")
            residual_out = gr.Audio(label="Background (Residual)")

    separate_btn.click(
        fn=separate_audio,
        inputs=[audio_in, prompt_box],
        outputs=[target_out, residual_out],
    )

    # NOTE(review): placeholder path — Gradio will warn/fail on a missing
    # example file; replace with a real bundled sample before shipping.
    gr.Examples(
        examples=[
            ["path/to/example.wav", "female vocals"],
        ],
        inputs=[audio_in, prompt_box],
    )

# Launch — bind to all interfaces on the standard Spaces port; queue()
# serializes requests so concurrent users don't contend for the GPU.
demo.queue().launch(server_name="0.0.0.0", server_port=7860)