Peter Shi commited on
Commit
2922fa7
Β·
1 Parent(s): b02c18a

Add MP4 and video file support

Browse files
Files changed (1) hide show
  1. app.py +168 -47
app.py CHANGED
@@ -4,6 +4,7 @@ import torch
4
  import torchaudio
5
  import tempfile
6
  import warnings
 
7
  warnings.filterwarnings("ignore")
8
 
9
  from sam_audio import SAMAudio, SAMAudioProcessor
@@ -11,96 +12,216 @@ from sam_audio import SAMAudio, SAMAudioProcessor
11
  # Configuration
12
  MODEL_NAME = "facebook/sam-audio-small"
13
 
14
- # Load model and processor (following official HuggingFace example)
15
  print(f"Loading {MODEL_NAME}...")
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  model = SAMAudio.from_pretrained(MODEL_NAME).to(device).eval()
18
  processor = SAMAudioProcessor.from_pretrained(MODEL_NAME)
19
  print(f"Model loaded on {device}.")
20
 
 
 
 
21
  def save_audio(tensor, sample_rate):
22
  """Helper to save torch tensor to a temp file for Gradio output."""
23
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
24
  torchaudio.save(tmp.name, tensor, sample_rate)
25
  return tmp.name
26
 
 
 
 
 
 
 
 
 
 
27
  @spaces.GPU(duration=300)
28
- def separate_audio(audio_path, text_prompt):
29
- if not audio_path:
30
- return None, None, "❌ Please upload an audio file."
 
 
 
 
 
31
 
32
  if not text_prompt or not text_prompt.strip():
33
- text_prompt = "vocals"
34
 
35
  try:
36
- # Process and separate (following official example)
37
  inputs = processor(
38
- audios=[audio_path],
39
  descriptions=[text_prompt.strip()]
40
  ).to(device)
41
 
42
  with torch.inference_mode():
43
  result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
44
 
45
- # Save results (following official example: result.target[0].unsqueeze(0).cpu())
46
  sample_rate = processor.audio_sampling_rate
47
  target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
48
  residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)
49
 
50
- return target_path, residual_path, f"βœ… Successfully separated '{text_prompt}' from the audio."
51
 
52
  except Exception as e:
53
  import traceback
54
  traceback.print_exc()
55
  return None, None, f"❌ Error: {str(e)}"
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # Build Gradio Interface
58
  with gr.Blocks(
59
- theme=gr.themes.Soft(),
60
- title="SAM-Audio - Segment Anything for Audio"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  ) as demo:
62
- gr.Markdown(
63
- """
64
- # 🎡 SAM-Audio: Segment Anything for Audio
65
-
66
- Isolate specific sounds from an audio file using natural language prompts.
67
-
68
- **Model:** [facebook/sam-audio-small](https://huggingface.co/facebook/sam-audio-small)
69
- """
70
- )
71
 
72
  with gr.Row():
73
- with gr.Column():
74
- input_audio = gr.Audio(label="Upload Input Audio", type="filepath")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  text_prompt = gr.Textbox(
76
- label="Text Prompt",
77
- placeholder="e.g., 'A man speaking', 'Piano playing', 'Dog barking'",
78
- value="A man speaking",
79
- info="Describe the sound you want to isolate."
80
  )
81
- run_btn = gr.Button("🎯 Separate Audio", variant="primary", size="lg")
82
-
83
- with gr.Column():
84
- output_target = gr.Audio(label="Isolated Sound (Target)")
85
- output_residual = gr.Audio(label="Background (Residual)")
86
-
87
- info_output = gr.Markdown(value="πŸ“ Upload an audio file and enter a prompt to start.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  run_btn.click(
90
- fn=separate_audio,
91
- inputs=[input_audio, text_prompt],
92
- outputs=[output_target, output_residual, info_output]
93
- )
94
-
95
- gr.Markdown(
96
- """
97
- ### Example Prompts
98
- - "A person coughing"
99
- - "Piano playing a melody"
100
- - "Dog barking"
101
- - "Car engine revving"
102
- - "Raindrops falling"
103
- """
104
  )
105
 
106
  if __name__ == "__main__":
 
4
  import torchaudio
5
  import tempfile
6
  import warnings
7
+ import os
8
  warnings.filterwarnings("ignore")
9
 
10
  from sam_audio import SAMAudio, SAMAudioProcessor
 
12
  # Configuration
13
  MODEL_NAME = "facebook/sam-audio-small"
14
 
15
+ # Load model and processor
16
  print(f"Loading {MODEL_NAME}...")
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
  model = SAMAudio.from_pretrained(MODEL_NAME).to(device).eval()
19
  processor = SAMAudioProcessor.from_pretrained(MODEL_NAME)
20
  print(f"Model loaded on {device}.")
21
 
22
+ # Supported file extensions
23
+ SUPPORTED_EXTENSIONS = ['.mp3', '.wav', '.flac', '.ogg', '.m4a', '.mp4', '.mkv', '.avi', '.mov', '.webm']
24
+
25
  def save_audio(tensor, sample_rate):
26
  """Helper to save torch tensor to a temp file for Gradio output."""
27
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
28
  torchaudio.save(tmp.name, tensor, sample_rate)
29
  return tmp.name
30
 
31
+ def validate_file(file_path):
32
+ """Check if file extension is supported."""
33
+ if not file_path:
34
+ return False, "No file uploaded"
35
+ ext = os.path.splitext(file_path)[1].lower()
36
+ if ext not in SUPPORTED_EXTENSIONS:
37
+ return False, f"Unsupported format: {ext}. Supported: {', '.join(SUPPORTED_EXTENSIONS)}"
38
+ return True, "OK"
39
+
40
  @spaces.GPU(duration=300)
41
+ def separate_audio(file_path, text_prompt):
42
+ if not file_path:
43
+ return None, None, "❌ Please upload an audio or video file."
44
+
45
+ # Validate file
46
+ valid, msg = validate_file(file_path)
47
+ if not valid:
48
+ return None, None, f"❌ {msg}"
49
 
50
  if not text_prompt or not text_prompt.strip():
51
+ return None, None, "❌ Please enter a text prompt describing the sound to isolate."
52
 
53
  try:
54
+ # SAM-Audio processor accepts both audio and video files directly
55
  inputs = processor(
56
+ audios=[file_path],
57
  descriptions=[text_prompt.strip()]
58
  ).to(device)
59
 
60
  with torch.inference_mode():
61
  result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
62
 
 
63
  sample_rate = processor.audio_sampling_rate
64
  target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
65
  residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)
66
 
67
+ return target_path, residual_path, f"βœ… Successfully isolated **'{text_prompt}'**"
68
 
69
  except Exception as e:
70
  import traceback
71
  traceback.print_exc()
72
  return None, None, f"❌ Error: {str(e)}"
73
 
74
+ # Custom CSS for dark theme
75
+ custom_css = """
76
+ .gradio-container {
77
+ background: #0a0a0a !important;
78
+ max-width: 1400px !important;
79
+ }
80
+
81
+ .upload-box {
82
+ border: 2px dashed #444 !important;
83
+ border-radius: 12px !important;
84
+ background: #1a1a1a !important;
85
+ min-height: 200px !important;
86
+ transition: border-color 0.3s !important;
87
+ }
88
+
89
+ .upload-box:hover {
90
+ border-color: #e91e8c !important;
91
+ }
92
+
93
+ .result-card {
94
+ background: #1a1a1a !important;
95
+ border: 1px solid #333 !important;
96
+ border-radius: 12px !important;
97
+ padding: 1rem !important;
98
+ }
99
+
100
+ .primary-btn {
101
+ background: linear-gradient(135deg, #e91e8c, #9c27b0) !important;
102
+ border: none !important;
103
+ border-radius: 24px !important;
104
+ }
105
+
106
+ .sidebar-text {
107
+ color: #888 !important;
108
+ font-size: 0.9rem !important;
109
+ }
110
+
111
+ .step-text {
112
+ color: #ccc !important;
113
+ padding: 0.3rem 0 !important;
114
+ }
115
+
116
+ .pink-text {
117
+ color: #e91e8c !important;
118
+ }
119
+ """
120
+
121
  # Build Gradio Interface
122
  with gr.Blocks(
123
+ title="SAM-Audio - Isolate Sounds",
124
+ theme=gr.themes.Base(
125
+ primary_hue="pink",
126
+ secondary_hue="purple",
127
+ neutral_hue="gray",
128
+ ).set(
129
+ body_background_fill="#0a0a0a",
130
+ body_background_fill_dark="#0a0a0a",
131
+ block_background_fill="#1a1a1a",
132
+ block_background_fill_dark="#1a1a1a",
133
+ input_background_fill="#1a1a1a",
134
+ input_background_fill_dark="#1a1a1a",
135
+ button_primary_background_fill="linear-gradient(135deg, #e91e8c, #9c27b0)",
136
+ button_primary_background_fill_hover="linear-gradient(135deg, #d1187d, #8a22a0)",
137
+ border_color_primary="#333",
138
+ ),
139
+ css=custom_css
140
  ) as demo:
 
 
 
 
 
 
 
 
 
141
 
142
  with gr.Row():
143
+ # Sidebar
144
+ with gr.Column(scale=1, min_width=250):
145
+ gr.Markdown("## 🎡 Isolate Sounds")
146
+ gr.Markdown("Extract and isolate any sound from audio or video using AI.", elem_classes=["sidebar-text"])
147
+
148
+ gr.Markdown("---")
149
+ gr.Markdown("### How it works")
150
+ gr.Markdown("**1.** Add audio or video", elem_classes=["step-text"])
151
+ gr.Markdown("**2.** Describe the sound", elem_classes=["step-text"])
152
+ gr.Markdown("**3.** Get separated tracks", elem_classes=["step-text"])
153
+
154
+ gr.Markdown("---")
155
+ gr.Markdown("**Model**")
156
+ gr.Markdown("πŸ€– SAM-Audio Small")
157
+
158
+ gr.Markdown("---")
159
+ gr.Markdown("**Supported Formats**")
160
+ gr.Markdown("🎡 MP3, WAV, FLAC, OGG, M4A", elem_classes=["sidebar-text"])
161
+ gr.Markdown("🎬 MP4, MKV, AVI, MOV, WebM", elem_classes=["sidebar-text"])
162
+
163
+ # Main content area
164
+ with gr.Column(scale=4):
165
+ gr.Markdown("### πŸ“€ Upload Audio or Video")
166
+
167
+ # Use File component to accept both audio and video
168
+ input_file = gr.File(
169
+ label="Drop your audio or video file here",
170
+ file_types=SUPPORTED_EXTENSIONS,
171
+ elem_classes=["upload-box"]
172
+ )
173
+
174
+ gr.Markdown("### πŸ’¬ Describe the Sound to Isolate")
175
  text_prompt = gr.Textbox(
176
+ label="",
177
+ placeholder="e.g., 'A man speaking', 'Piano melody', 'Dog barking', 'Background music'",
178
+ lines=1
 
179
  )
180
+
181
+ with gr.Row():
182
+ run_btn = gr.Button(
183
+ "🎯 Isolate Sound",
184
+ variant="primary",
185
+ size="lg",
186
+ elem_classes=["primary-btn"]
187
+ )
188
+
189
+ status_output = gr.Markdown(
190
+ value="*Upload a file and describe what sound you want to isolate.*"
191
+ )
192
+
193
+ gr.Markdown("---")
194
+ gr.Markdown("### 🎧 Results")
195
+
196
+ with gr.Row():
197
+ with gr.Column(elem_classes=["result-card"]):
198
+ gr.Markdown("**🎯 Isolated Sound** (Target)")
199
+ output_target = gr.Audio(label="", show_label=False)
200
+
201
+ with gr.Column(elem_classes=["result-card"]):
202
+ gr.Markdown("**πŸ”‡ Background** (Residual)")
203
+ output_residual = gr.Audio(label="", show_label=False)
204
+
205
+ gr.Markdown("---")
206
+ gr.Markdown("### πŸ’‘ Example Prompts")
207
+ gr.Markdown("Click any example below to use it:")
208
+
209
+ with gr.Row():
210
+ for prompt in ["A man speaking", "A woman singing", "Piano", "Drums", "Guitar", "Dog barking"]:
211
+ gr.Button(prompt, size="sm").click(
212
+ fn=lambda p=prompt: p,
213
+ outputs=[text_prompt]
214
+ )
215
+
216
+ def process_file(file, prompt):
217
+ if file is None:
218
+ return None, None, "❌ Please upload a file."
219
+ return separate_audio(file.name, prompt)
220
 
221
  run_btn.click(
222
+ fn=process_file,
223
+ inputs=[input_file, text_prompt],
224
+ outputs=[output_target, output_residual, status_output]
 
 
 
 
 
 
 
 
 
 
 
225
  )
226
 
227
  if __name__ == "__main__":