Peter Shi commited on
Commit
cebdac8
Β·
1 Parent(s): c12dff8

Restore audio/video preview with tabs

Browse files
Files changed (1) hide show
  1. app.py +43 -65
app.py CHANGED
@@ -19,10 +19,7 @@ MODELS = {
19
  "sam-audio-large-tv (Visual)": "facebook/sam-audio-large-tv",
20
  }
21
 
22
- # Default model
23
  DEFAULT_MODEL = "sam-audio-small"
24
-
25
- # Example files
26
  EXAMPLES_DIR = "examples"
27
  EXAMPLE_FILE = os.path.join(EXAMPLES_DIR, "office.mp4")
28
 
@@ -33,25 +30,19 @@ model = None
33
  processor = None
34
 
35
  def load_model(model_name):
36
- """Load or switch model."""
37
  global current_model_name, model, processor
38
-
39
  model_id = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
40
-
41
  if current_model_name == model_name and model is not None:
42
  return
43
-
44
  print(f"Loading {model_id}...")
45
  model = SAMAudio.from_pretrained(model_id).to(device).eval()
46
  processor = SAMAudioProcessor.from_pretrained(model_id)
47
  current_model_name = model_name
48
  print(f"Model {model_id} loaded on {device}.")
49
 
50
- # Load default model at startup
51
  load_model(DEFAULT_MODEL)
52
 
53
  def save_audio(tensor, sample_rate):
54
- """Helper to save torch tensor to a temp file for Gradio output."""
55
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
56
  torchaudio.save(tmp.name, tensor, sample_rate)
57
  return tmp.name
@@ -59,134 +50,121 @@ def save_audio(tensor, sample_rate):
59
  @spaces.GPU(duration=300)
60
  def separate_audio(model_name, file_path, text_prompt):
61
  global model, processor
62
-
63
- # Load selected model if different
64
  load_model(model_name)
65
 
66
  if not file_path:
67
  return None, None, "❌ Please upload an audio or video file."
68
-
69
  if not text_prompt or not text_prompt.strip():
70
- return None, None, "❌ Please enter a text prompt describing the sound to isolate."
71
 
72
  try:
73
- inputs = processor(
74
- audios=[file_path],
75
- descriptions=[text_prompt.strip()]
76
- ).to(device)
77
-
78
  with torch.inference_mode():
79
  result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
80
-
81
  sample_rate = processor.audio_sampling_rate
82
  target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
83
  residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)
84
-
85
- return target_path, residual_path, f"βœ… Successfully isolated '{text_prompt}' using {model_name}"
86
-
87
  except Exception as e:
88
  import traceback
89
  traceback.print_exc()
90
  return None, None, f"❌ Error: {str(e)}"
91
 
92
- def process_file(model_name, file, prompt):
93
- if file is None:
94
- return None, None, "❌ Please upload a file."
95
- # Handle both file object and file path
96
- file_path = file.name if hasattr(file, 'name') else file
97
- return separate_audio(model_name, file_path, prompt)
 
 
 
98
 
99
- def process_example(model_name, file_path, prompt):
100
- """Process directly from example - file_path is already a string."""
101
- if not file_path or not os.path.exists(file_path):
102
  return None, None, "❌ Example file not found."
103
- return separate_audio(model_name, file_path, prompt)
104
 
105
- # Build Gradio Interface
 
 
 
106
  with gr.Blocks(title="SAM-Audio Demo") as demo:
107
  gr.Markdown(
108
  """
109
  # 🎡 SAM-Audio: Segment Anything for Audio
110
-
111
- Isolate specific sounds from an audio or video file using natural language prompts.
112
-
113
- **Models:** [facebook/sam-audio](https://huggingface.co/collections/facebook/sam-audio-67608edbf75ad66bf5e8cb3a)
114
  """
115
  )
116
 
117
  with gr.Row():
118
- with gr.Column():
119
  model_selector = gr.Dropdown(
120
  choices=list(MODELS.keys()),
121
  value=DEFAULT_MODEL,
122
- label="Model",
123
- info="Larger = better quality but slower. TV variants for visual prompting."
124
  )
125
 
126
- input_file = gr.File(
127
- label="Upload Audio or Video",
128
- file_types=[".mp3", ".wav", ".flac", ".ogg", ".m4a", ".mp4", ".mkv", ".avi", ".mov", ".webm"],
129
- )
 
130
 
131
  text_prompt = gr.Textbox(
132
- label="Text Prompt",
133
- placeholder="e.g., 'A man speaking', 'Piano melody', 'Dog barking'",
134
- info="Describe the sound you want to isolate."
135
  )
136
 
137
  run_btn = gr.Button("🎯 Isolate Sound", variant="primary")
138
- status_output = gr.Markdown(value="")
139
 
140
- with gr.Column():
141
  gr.Markdown("### Results")
142
  output_target = gr.Audio(label="Isolated Sound (Target)")
143
  output_residual = gr.Audio(label="Background (Residual)")
144
 
145
  gr.Markdown("---")
146
- gr.Markdown("### 🎬 Try Demo Examples")
147
- gr.Markdown("Click an example below to auto-fill and process:")
148
 
149
  with gr.Row():
150
  if os.path.exists(EXAMPLE_FILE):
151
  example_btn1 = gr.Button("🎀 Man Speaking")
152
- example_btn2 = gr.Button("🎀 Woman Speaking")
153
  example_btn3 = gr.Button("🎡 Background Music")
154
 
155
- gr.Markdown("---")
156
- gr.Markdown("**Supported formats:** MP3, WAV, FLAC, OGG, M4A, MP4, MKV, AVI, MOV, WebM")
157
-
158
- # Main run button
159
  run_btn.click(
160
- fn=process_file,
161
- inputs=[model_selector, input_file, text_prompt],
162
  outputs=[output_target, output_residual, status_output]
163
  )
164
 
165
- # Example buttons - auto-fill and process
166
  if os.path.exists(EXAMPLE_FILE):
167
  example_btn1.click(
168
  fn=lambda: (EXAMPLE_FILE, "A man speaking"),
169
- outputs=[input_file, text_prompt]
170
  ).then(
171
- fn=lambda m: process_example(m, EXAMPLE_FILE, "A man speaking"),
172
  inputs=[model_selector],
173
  outputs=[output_target, output_residual, status_output]
174
  )
175
 
176
  example_btn2.click(
177
  fn=lambda: (EXAMPLE_FILE, "A woman speaking"),
178
- outputs=[input_file, text_prompt]
179
  ).then(
180
- fn=lambda m: process_example(m, EXAMPLE_FILE, "A woman speaking"),
181
  inputs=[model_selector],
182
  outputs=[output_target, output_residual, status_output]
183
  )
184
 
185
  example_btn3.click(
186
  fn=lambda: (EXAMPLE_FILE, "Background music"),
187
- outputs=[input_file, text_prompt]
188
  ).then(
189
- fn=lambda m: process_example(m, EXAMPLE_FILE, "Background music"),
190
  inputs=[model_selector],
191
  outputs=[output_target, output_residual, status_output]
192
  )
 
19
  "sam-audio-large-tv (Visual)": "facebook/sam-audio-large-tv",
20
  }
21
 
 
22
  DEFAULT_MODEL = "sam-audio-small"
 
 
23
  EXAMPLES_DIR = "examples"
24
  EXAMPLE_FILE = os.path.join(EXAMPLES_DIR, "office.mp4")
25
 
 
30
  processor = None
31
 
32
  def load_model(model_name):
 
33
  global current_model_name, model, processor
 
34
  model_id = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
 
35
  if current_model_name == model_name and model is not None:
36
  return
 
37
  print(f"Loading {model_id}...")
38
  model = SAMAudio.from_pretrained(model_id).to(device).eval()
39
  processor = SAMAudioProcessor.from_pretrained(model_id)
40
  current_model_name = model_name
41
  print(f"Model {model_id} loaded on {device}.")
42
 
 
43
  load_model(DEFAULT_MODEL)
44
 
45
  def save_audio(tensor, sample_rate):
 
46
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
47
  torchaudio.save(tmp.name, tensor, sample_rate)
48
  return tmp.name
 
50
  @spaces.GPU(duration=300)
51
  def separate_audio(model_name, file_path, text_prompt):
52
  global model, processor
 
 
53
  load_model(model_name)
54
 
55
  if not file_path:
56
  return None, None, "❌ Please upload an audio or video file."
 
57
  if not text_prompt or not text_prompt.strip():
58
+ return None, None, "❌ Please enter a text prompt."
59
 
60
  try:
61
+ inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)
 
 
 
 
62
  with torch.inference_mode():
63
  result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
 
64
  sample_rate = processor.audio_sampling_rate
65
  target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
66
  residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)
67
+ return target_path, residual_path, f"βœ… Isolated '{text_prompt}' using {model_name}"
 
 
68
  except Exception as e:
69
  import traceback
70
  traceback.print_exc()
71
  return None, None, f"❌ Error: {str(e)}"
72
 
73
+ def process_audio(model_name, audio_path, prompt):
74
+ if not audio_path:
75
+ return None, None, "❌ Please upload an audio file."
76
+ return separate_audio(model_name, audio_path, prompt)
77
+
78
+ def process_video(model_name, video_path, prompt):
79
+ if not video_path:
80
+ return None, None, "❌ Please upload a video file."
81
+ return separate_audio(model_name, video_path, prompt)
82
 
83
+ def process_example(model_name, prompt):
84
+ if not os.path.exists(EXAMPLE_FILE):
 
85
  return None, None, "❌ Example file not found."
86
+ return separate_audio(model_name, EXAMPLE_FILE, prompt)
87
 
88
+ def load_example(prompt):
89
+ return EXAMPLE_FILE, prompt
90
+
91
+ # Build Interface
92
  with gr.Blocks(title="SAM-Audio Demo") as demo:
93
  gr.Markdown(
94
  """
95
  # 🎡 SAM-Audio: Segment Anything for Audio
96
+ Isolate specific sounds from audio or video using natural language prompts.
 
 
 
97
  """
98
  )
99
 
100
  with gr.Row():
101
+ with gr.Column(scale=1):
102
  model_selector = gr.Dropdown(
103
  choices=list(MODELS.keys()),
104
  value=DEFAULT_MODEL,
105
+ label="Model"
 
106
  )
107
 
108
+ with gr.Tabs():
109
+ with gr.TabItem("🎡 Audio"):
110
+ input_audio = gr.Audio(label="Upload Audio", type="filepath")
111
+ with gr.TabItem("🎬 Video"):
112
+ input_video = gr.Video(label="Upload Video")
113
 
114
  text_prompt = gr.Textbox(
115
+ label="Text Prompt",
116
+ placeholder="e.g., 'A man speaking', 'Piano', 'Dog barking'"
 
117
  )
118
 
119
  run_btn = gr.Button("🎯 Isolate Sound", variant="primary")
120
+ status_output = gr.Markdown("")
121
 
122
+ with gr.Column(scale=1):
123
  gr.Markdown("### Results")
124
  output_target = gr.Audio(label="Isolated Sound (Target)")
125
  output_residual = gr.Audio(label="Background (Residual)")
126
 
127
  gr.Markdown("---")
128
+ gr.Markdown("### 🎬 Demo Examples (click to auto-process)")
 
129
 
130
  with gr.Row():
131
  if os.path.exists(EXAMPLE_FILE):
132
  example_btn1 = gr.Button("🎀 Man Speaking")
133
+ example_btn2 = gr.Button("🎀 Woman Speaking")
134
  example_btn3 = gr.Button("🎡 Background Music")
135
 
136
+ # Audio processing
 
 
 
137
  run_btn.click(
138
+ fn=lambda m, a, v, p: process_audio(m, a, p) if a else process_video(m, v, p),
139
+ inputs=[model_selector, input_audio, input_video, text_prompt],
140
  outputs=[output_target, output_residual, status_output]
141
  )
142
 
143
+ # Example buttons
144
  if os.path.exists(EXAMPLE_FILE):
145
  example_btn1.click(
146
  fn=lambda: (EXAMPLE_FILE, "A man speaking"),
147
+ outputs=[input_video, text_prompt]
148
  ).then(
149
+ fn=lambda m: process_example(m, "A man speaking"),
150
  inputs=[model_selector],
151
  outputs=[output_target, output_residual, status_output]
152
  )
153
 
154
  example_btn2.click(
155
  fn=lambda: (EXAMPLE_FILE, "A woman speaking"),
156
+ outputs=[input_video, text_prompt]
157
  ).then(
158
+ fn=lambda m: process_example(m, "A woman speaking"),
159
  inputs=[model_selector],
160
  outputs=[output_target, output_residual, status_output]
161
  )
162
 
163
  example_btn3.click(
164
  fn=lambda: (EXAMPLE_FILE, "Background music"),
165
+ outputs=[input_video, text_prompt]
166
  ).then(
167
+ fn=lambda m: process_example(m, "Background music"),
168
  inputs=[model_selector],
169
  outputs=[output_target, output_residual, status_output]
170
  )