ray-006 commited on
Commit
4556e15
·
verified ·
1 Parent(s): 9987fb9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -4
app.py CHANGED
@@ -1,7 +1,73 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import torchaudio
4
+ import os
5
+ import tempfile
6
+ from sam_audio import SAMAudio, SAMAudioProcessor
7
 
8
+ # --- Initialization ---
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
+ # Load model and processor once when the app starts
12
+ model = SAMAudio.from_pretrained("facebook/sam-audio-large").to(device).eval()
13
+ processor = SAMAudioProcessor.from_pretrained("facebook/sam-audio-large")
14
+
15
+ def separate_audio(audio_path, description, reranking_candidates):
16
+ if audio_path is None or not description:
17
+ return None, None
18
+
19
+ # Process inputs
20
+ inputs = processor(audios=[audio_path], descriptions=[description]).to(device)
21
+
22
+ with torch.inference_mode():
23
+ # Using reranking if candidates > 1
24
+ result = model.separate(
25
+ inputs,
26
+ predict_spans=True,
27
+ reranking_candidates=int(reranking_candidates)
28
+ )
29
+
30
+ # Use temporary files to store the results for Gradio
31
+ target_path = os.path.join(tempfile.gettempdir(), "target.wav")
32
+ residual_path = os.path.join(tempfile.gettempdir(), "residual.wav")
33
+
34
+ # Save target and residual
35
+ torchaudio.save(target_path, result.target[0].unsqueeze(0).cpu(), processor.audio_sampling_rate)
36
+ torchaudio.save(residual_path, result.residual[0].unsqueeze(0).cpu(), processor.audio_sampling_rate)
37
+
38
+ return target_path, residual_path
39
+
40
+ # --- UI Design ---
41
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
42
+ gr.Markdown("# 🎵 SAM-Audio Separation")
43
+ gr.Markdown("Upload an audio file and describe the specific sound you want to isolate (e.g., 'A dog barking' or 'A man speaking').")
44
+
45
+ with gr.Row():
46
+ with gr.Column():
47
+ input_audio = gr.Audio(label="Input Audio", type="filepath")
48
+ description = gr.Textbox(
49
+ label="What do you want to isolate?",
50
+ placeholder="e.g. A person laughing"
51
+ )
52
+ rerank_slider = gr.Slider(
53
+ minimum=1,
54
+ maximum=16,
55
+ value=1,
56
+ step=1,
57
+ label="Reranking Candidates",
58
+ info="Higher values improve quality but increase processing time."
59
+ )
60
+ btn = gr.Button("Separate Sound", variant="primary")
61
+
62
+ with gr.Column():
63
+ output_target = gr.Audio(label="Isolated (Target) Audio")
64
+ output_residual = gr.Audio(label="Residual Audio")
65
+
66
+ btn.click(
67
+ fn=separate_audio,
68
+ inputs=[input_audio, description, rerank_slider],
69
+ outputs=[output_target, output_residual]
70
+ )
71
+
72
+ if __name__ == "__main__":
73
+ demo.launch()