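"""Gradio demo for Meta SAM-Audio: text-prompted audio source separation.

Upload a clip, describe the sound to isolate (e.g. "dog barking"), and the app
returns the isolated target plus the residual background as separate tracks.
"""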
import tempfile

import gradio as gr
import torch
import torchaudio
from sam_audio import SAMAudio, SAMAudioProcessor

# 1. Configuration
MODEL_NAME = "facebook/sam-audio-small"
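# Prefer GPU when available; CPU inference works but is much slower for separation.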
device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading {MODEL_NAME} on {device}...")

# 2. Load Model and Processor
try:
    model = SAMAudio.from_pretrained(MODEL_NAME).to(device).eval()
    processor = SAMAudioProcessor.from_pretrained(MODEL_NAME)
    print("Model loaded successfully.")
except Exception as e:
    # Gated checkpoints need authentication: set HF_TOKEN in the Space secrets.
    print(f"Error loading model (is HF_TOKEN set in your secrets?): {e}")
    raise

def save_audio(tensor, sample_rate):
    """Helper to save torch tensor to a temp file for Gradio output."""
    # Tensor shape is expected to be (C, T) or (T). Ensure it's on CPU.
    if tensor.dim() == 1:
        tensor = tensor.unsqueeze(0)
    tensor = tensor.detach().cpu()
    
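    # delete=False keeps the file on disk after this function returns;
    # Gradio reads the returned path to serve the audio to the browser.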
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        torchaudio.save(tmp.name, tensor, sample_rate)
        return tmp.name

def separate_audio(audio_path, text_prompt):
    if not audio_path or not text_prompt:
        # Need both an input clip and a text prompt to separate anything.
        return None, None
    
    # 3. Process Inputs
    # The processor handles loading audio and text tokenization
    inputs = processor(
        audios=[audio_path], 
        descriptions=[text_prompt]
    ).to(device)

    # 4. Inference
    with torch.no_grad():
        # The model returns a result object with 'target' and 'residual'
        result = model.separate(inputs)

    # 5. Extract Outputs
    # result.target and result.residual are lists with one entry per batch item;
    # we processed a single clip, so take the first entry of each.
    target_audio = result.target[0]      # the sound described by the prompt
    residual_audio = result.residual[0]  # everything else in the mix
    
    # Get sampling rate from the processor config
    sr = processor.feature_extractor.sampling_rate

    # 6. Save to files
    target_path = save_audio(target_audio, sr)
    residual_path = save_audio(residual_audio, sr)

    return target_path, residual_path

# 7. Build Gradio Interface
with gr.Blocks(title="Meta SAM-Audio Demo") as demo:
    gr.Markdown(
        """
        # 🎵 SAM-Audio: Segment Anything for Audio
        Isolate specific sounds from an audio file using natural language prompts.
        
        **Model:** `facebook/sam-audio-small`
        """
    )
    
    with gr.Row():
        with gr.Column():
            input_audio = gr.Audio(label="Upload Input Audio", type="filepath")
            text_prompt = gr.Textbox(
                label="Text Prompt", 
                placeholder="e.g., 'dog barking', 'man speaking', 'keyboard typing'",
                info="Describe the sound you want to isolate."
            )
            run_btn = gr.Button("Separate Audio", variant="primary")
        
        with gr.Column():
            output_target = gr.Audio(label="Isolated Sound (Target)")
            output_residual = gr.Audio(label="Background (Residual)")

    run_btn.click(
        fn=separate_audio,
        inputs=[input_audio, text_prompt],
        outputs=[output_target, output_residual]
    )

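    # NOTE: the example path below is a placeholder; point it at an audio file
    # actually bundled with the app before relying on the examples panel.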
    gr.Examples(
        examples=[
            ["path/to/example.wav", "female vocals"],
        ],
        inputs=[input_audio, text_prompt]
    )

# Launch
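# queue() lets long-running separation requests line up instead of timing out;
# binding 0.0.0.0:7860 matches the standard Hugging Face Spaces container setup.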
demo.queue().launch(server_name="0.0.0.0", server_port=7860)
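
# To run locally (assuming the sam_audio package from the model card is installed):
#   pip install gradio torch torchaudio
#   python app.py
# then open http://localhost:7860 in a browser.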