Peter Shi committed on
Commit
a11009d
·
1 Parent(s): cebdac8

Add progress bar, example only fills data, switch to video tab

Browse files
Files changed (1) hide show
  1. app.py +35 -40
app.py CHANGED
@@ -29,11 +29,13 @@ current_model_name = None
29
  model = None
30
  processor = None
31
 
32
- def load_model(model_name):
33
  global current_model_name, model, processor
34
  model_id = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
35
  if current_model_name == model_name and model is not None:
36
  return
 
 
37
  print(f"Loading {model_id}...")
38
  model = SAMAudio.from_pretrained(model_id).to(device).eval()
39
  processor = SAMAudioProcessor.from_pretrained(model_id)
@@ -48,9 +50,10 @@ def save_audio(tensor, sample_rate):
48
  return tmp.name
49
 
50
  @spaces.GPU(duration=300)
51
- def separate_audio(model_name, file_path, text_prompt):
52
  global model, processor
53
- load_model(model_name)
 
54
 
55
  if not file_path:
56
  return None, None, "❌ Please upload an audio or video file."
@@ -58,35 +61,37 @@ def separate_audio(model_name, file_path, text_prompt):
58
  return None, None, "❌ Please enter a text prompt."
59
 
60
  try:
 
 
 
 
61
  inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)
 
 
62
  with torch.inference_mode():
63
  result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
 
 
64
  sample_rate = processor.audio_sampling_rate
65
  target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
66
  residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)
 
 
67
  return target_path, residual_path, f"βœ… Isolated '{text_prompt}' using {model_name}"
68
  except Exception as e:
69
  import traceback
70
  traceback.print_exc()
71
  return None, None, f"❌ Error: {str(e)}"
72
 
73
def process_audio(model_name, audio_path, prompt):
    """Validate the uploaded audio path, then delegate to separate_audio.

    Returns the (target, residual, status) triple expected by the Gradio outputs;
    on a missing upload, both audio slots are None and status carries the error.
    """
    if audio_path:
        return separate_audio(model_name, audio_path, prompt)
    return None, None, "❌ Please upload an audio file."
77
 
78
def process_video(model_name, video_path, prompt):
    """Validate the uploaded video path, then delegate to separate_audio.

    Mirrors process_audio: returns (target, residual, status), with an error
    status and None outputs when no file was provided.
    """
    if video_path:
        return separate_audio(model_name, video_path, prompt)
    return None, None, "❌ Please upload a video file."
82
-
83
def process_example(model_name, prompt):
    """Separate the bundled example clip using *prompt*.

    Falls back to an error status (with None audio outputs) when the example
    media file is not present on disk.
    """
    if os.path.exists(EXAMPLE_FILE):
        return separate_audio(model_name, EXAMPLE_FILE, prompt)
    return None, None, "❌ Example file not found."
87
-
88
def load_example(prompt):
    """Pair the bundled example file path with *prompt* (used to prefill the UI)."""
    example_path = EXAMPLE_FILE
    return example_path, prompt
90
 
91
  # Build Interface
92
  with gr.Blocks(title="SAM-Audio Demo") as demo:
@@ -105,10 +110,11 @@ with gr.Blocks(title="SAM-Audio Demo") as demo:
105
  label="Model"
106
  )
107
 
108
- with gr.Tabs():
109
- with gr.TabItem("🎡 Audio"):
 
110
  input_audio = gr.Audio(label="Upload Audio", type="filepath")
111
- with gr.TabItem("🎬 Video"):
112
  input_video = gr.Video(label="Upload Video")
113
 
114
  text_prompt = gr.Textbox(
@@ -116,7 +122,7 @@ with gr.Blocks(title="SAM-Audio Demo") as demo:
116
  placeholder="e.g., 'A man speaking', 'Piano', 'Dog barking'"
117
  )
118
 
119
- run_btn = gr.Button("🎯 Isolate Sound", variant="primary")
120
  status_output = gr.Markdown("")
121
 
122
  with gr.Column(scale=1):
@@ -125,7 +131,8 @@ with gr.Blocks(title="SAM-Audio Demo") as demo:
125
  output_residual = gr.Audio(label="Background (Residual)")
126
 
127
  gr.Markdown("---")
128
- gr.Markdown("### 🎬 Demo Examples (click to auto-process)")
 
129
 
130
  with gr.Row():
131
  if os.path.exists(EXAMPLE_FILE):
@@ -133,40 +140,28 @@ with gr.Blocks(title="SAM-Audio Demo") as demo:
133
  example_btn2 = gr.Button("🎀 Woman Speaking")
134
  example_btn3 = gr.Button("🎡 Background Music")
135
 
136
- # Audio processing
137
  run_btn.click(
138
  fn=lambda m, a, v, p: process_audio(m, a, p) if a else process_video(m, v, p),
139
  inputs=[model_selector, input_audio, input_video, text_prompt],
140
  outputs=[output_target, output_residual, status_output]
141
  )
142
 
143
- # Example buttons
144
  if os.path.exists(EXAMPLE_FILE):
145
  example_btn1.click(
146
- fn=lambda: (EXAMPLE_FILE, "A man speaking"),
147
- outputs=[input_video, text_prompt]
148
- ).then(
149
- fn=lambda m: process_example(m, "A man speaking"),
150
- inputs=[model_selector],
151
- outputs=[output_target, output_residual, status_output]
152
  )
153
 
154
  example_btn2.click(
155
- fn=lambda: (EXAMPLE_FILE, "A woman speaking"),
156
- outputs=[input_video, text_prompt]
157
- ).then(
158
- fn=lambda m: process_example(m, "A woman speaking"),
159
- inputs=[model_selector],
160
- outputs=[output_target, output_residual, status_output]
161
  )
162
 
163
  example_btn3.click(
164
- fn=lambda: (EXAMPLE_FILE, "Background music"),
165
- outputs=[input_video, text_prompt]
166
- ).then(
167
- fn=lambda m: process_example(m, "Background music"),
168
- inputs=[model_selector],
169
- outputs=[output_target, output_residual, status_output]
170
  )
171
 
172
  if __name__ == "__main__":
 
29
  model = None
30
  processor = None
31
 
32
+ def load_model(model_name, progress=None):
33
  global current_model_name, model, processor
34
  model_id = MODELS.get(model_name, MODELS[DEFAULT_MODEL])
35
  if current_model_name == model_name and model is not None:
36
  return
37
+ if progress:
38
+ progress(0.1, desc="Loading model...")
39
  print(f"Loading {model_id}...")
40
  model = SAMAudio.from_pretrained(model_id).to(device).eval()
41
  processor = SAMAudioProcessor.from_pretrained(model_id)
 
50
  return tmp.name
51
 
52
  @spaces.GPU(duration=300)
53
+ def separate_audio(model_name, file_path, text_prompt, progress=gr.Progress()):
54
  global model, processor
55
+
56
+ progress(0.1, desc="Checking inputs...")
57
 
58
  if not file_path:
59
  return None, None, "❌ Please upload an audio or video file."
 
61
  return None, None, "❌ Please enter a text prompt."
62
 
63
  try:
64
+ progress(0.2, desc="Loading model...")
65
+ load_model(model_name)
66
+
67
+ progress(0.4, desc="Processing audio...")
68
  inputs = processor(audios=[file_path], descriptions=[text_prompt.strip()]).to(device)
69
+
70
+ progress(0.6, desc="Separating sounds...")
71
  with torch.inference_mode():
72
  result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
73
+
74
+ progress(0.8, desc="Saving results...")
75
  sample_rate = processor.audio_sampling_rate
76
  target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
77
  residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)
78
+
79
+ progress(1.0, desc="Done!")
80
  return target_path, residual_path, f"βœ… Isolated '{text_prompt}' using {model_name}"
81
  except Exception as e:
82
  import traceback
83
  traceback.print_exc()
84
  return None, None, f"❌ Error: {str(e)}"
85
 
86
def process_audio(model_name, audio_path, prompt, progress=gr.Progress()):
    """Validate the uploaded audio path, then run separation with progress.

    The gr.Progress() default is the Gradio idiom for injecting a progress
    tracker; it is forwarded unchanged to separate_audio.
    """
    if audio_path:
        return separate_audio(model_name, audio_path, prompt, progress)
    return None, None, "❌ Please upload an audio file."
90
 
91
def process_video(model_name, video_path, prompt, progress=gr.Progress()):
    """Validate the uploaded video path, then run separation with progress.

    Mirrors process_audio; the progress tracker is passed straight through
    to separate_audio.
    """
    if video_path:
        return separate_audio(model_name, video_path, prompt, progress)
    return None, None, "❌ Please upload a video file."
 
 
 
 
 
 
 
 
95
 
96
  # Build Interface
97
  with gr.Blocks(title="SAM-Audio Demo") as demo:
 
110
  label="Model"
111
  )
112
 
113
+ tabs = gr.Tabs()
114
+ with tabs:
115
+ with gr.TabItem("🎡 Audio", id=0):
116
  input_audio = gr.Audio(label="Upload Audio", type="filepath")
117
+ with gr.TabItem("🎬 Video", id=1):
118
  input_video = gr.Video(label="Upload Video")
119
 
120
  text_prompt = gr.Textbox(
 
122
  placeholder="e.g., 'A man speaking', 'Piano', 'Dog barking'"
123
  )
124
 
125
+ run_btn = gr.Button("🎯 Isolate Sound", variant="primary", size="lg")
126
  status_output = gr.Markdown("")
127
 
128
  with gr.Column(scale=1):
 
131
  output_residual = gr.Audio(label="Background (Residual)")
132
 
133
  gr.Markdown("---")
134
+ gr.Markdown("### 🎬 Demo Examples")
135
+ gr.Markdown("Click to load example, then click 'Isolate Sound' to process:")
136
 
137
  with gr.Row():
138
  if os.path.exists(EXAMPLE_FILE):
 
140
  example_btn2 = gr.Button("🎀 Woman Speaking")
141
  example_btn3 = gr.Button("🎡 Background Music")
142
 
143
+ # Main process button - check which tab has content
144
  run_btn.click(
145
  fn=lambda m, a, v, p: process_audio(m, a, p) if a else process_video(m, v, p),
146
  inputs=[model_selector, input_audio, input_video, text_prompt],
147
  outputs=[output_target, output_residual, status_output]
148
  )
149
 
150
+ # Example buttons - only fill in data, switch to video tab
151
  if os.path.exists(EXAMPLE_FILE):
152
  example_btn1.click(
153
+ fn=lambda: (EXAMPLE_FILE, "A man speaking", gr.Tabs(selected=1)),
154
+ outputs=[input_video, text_prompt, tabs]
 
 
 
 
155
  )
156
 
157
  example_btn2.click(
158
+ fn=lambda: (EXAMPLE_FILE, "A woman speaking", gr.Tabs(selected=1)),
159
+ outputs=[input_video, text_prompt, tabs]
 
 
 
 
160
  )
161
 
162
  example_btn3.click(
163
+ fn=lambda: (EXAMPLE_FILE, "Background music", gr.Tabs(selected=1)),
164
+ outputs=[input_video, text_prompt, tabs]
 
 
 
 
165
  )
166
 
167
  if __name__ == "__main__":