Peter Shi commited on
Commit
b02c18a
·
1 Parent(s): 79ced89

Fix: follow official HF example exactly

Browse files
Files changed (1) hide show
  1. app.py +21 -24
app.py CHANGED
@@ -11,19 +11,15 @@ from sam_audio import SAMAudio, SAMAudioProcessor
11
  # Configuration
12
  MODEL_NAME = "facebook/sam-audio-small"
13
 
14
- # Global model and processor
15
  print(f"Loading {MODEL_NAME}...")
16
- model = SAMAudio.from_pretrained(MODEL_NAME)
 
17
  processor = SAMAudioProcessor.from_pretrained(MODEL_NAME)
18
- model = model.eval().cuda()
19
- print("Model loaded on CUDA.")
20
 
21
  def save_audio(tensor, sample_rate):
22
  """Helper to save torch tensor to a temp file for Gradio output."""
23
- if tensor.dim() == 1:
24
- tensor = tensor.unsqueeze(0)
25
- tensor = tensor.detach().cpu()
26
-
27
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
28
  torchaudio.save(tmp.name, tensor, sample_rate)
29
  return tmp.name
@@ -37,22 +33,19 @@ def separate_audio(audio_path, text_prompt):
37
  text_prompt = "vocals"
38
 
39
  try:
40
- # Process Inputs (following official example)
41
- batch = processor(
42
  audios=[audio_path],
43
  descriptions=[text_prompt.strip()]
44
- ).to("cuda")
45
-
46
- # Inference using inference_mode (as per official docs)
47
  with torch.inference_mode():
48
- result = model.separate(batch, predict_spans=False, reranking_candidates=1)
49
-
50
- # Get sampling rate
51
  sample_rate = processor.audio_sampling_rate
52
-
53
- # Save to files
54
- target_path = save_audio(result.target, sample_rate)
55
- residual_path = save_audio(result.residual, sample_rate)
56
 
57
  return target_path, residual_path, f"✅ Successfully separated '{text_prompt}' from the audio."
58
 
@@ -81,8 +74,8 @@ with gr.Blocks(
81
  input_audio = gr.Audio(label="Upload Input Audio", type="filepath")
82
  text_prompt = gr.Textbox(
83
  label="Text Prompt",
84
- placeholder="e.g., 'drums', 'vocals', 'A man speaking'",
85
- value="drums",
86
  info="Describe the sound you want to isolate."
87
  )
88
  run_btn = gr.Button("🎯 Separate Audio", variant="primary", size="lg")
@@ -101,8 +94,12 @@ with gr.Blocks(
101
 
102
  gr.Markdown(
103
  """
104
- ### Tips
105
- - Use prompts like: `drums`, `vocals`, `A man speaking`, `piano`, `guitar`
 
 
 
 
106
  """
107
  )
108
 
 
11
  # Configuration
12
  MODEL_NAME = "facebook/sam-audio-small"
13
 
14
+ # Load model and processor (following official HuggingFace example)
15
  print(f"Loading {MODEL_NAME}...")
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ model = SAMAudio.from_pretrained(MODEL_NAME).to(device).eval()
18
  processor = SAMAudioProcessor.from_pretrained(MODEL_NAME)
19
+ print(f"Model loaded on {device}.")
 
20
 
21
  def save_audio(tensor, sample_rate):
22
  """Helper to save torch tensor to a temp file for Gradio output."""
 
 
 
 
23
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
24
  torchaudio.save(tmp.name, tensor, sample_rate)
25
  return tmp.name
 
33
  text_prompt = "vocals"
34
 
35
  try:
36
+ # Process and separate (following official example)
37
+ inputs = processor(
38
  audios=[audio_path],
39
  descriptions=[text_prompt.strip()]
40
+ ).to(device)
41
+
 
42
  with torch.inference_mode():
43
+ result = model.separate(inputs, predict_spans=False, reranking_candidates=1)
44
+
45
+ # Save results (following official example: result.target[0].unsqueeze(0).cpu())
46
  sample_rate = processor.audio_sampling_rate
47
+ target_path = save_audio(result.target[0].unsqueeze(0).cpu(), sample_rate)
48
+ residual_path = save_audio(result.residual[0].unsqueeze(0).cpu(), sample_rate)
 
 
49
 
50
  return target_path, residual_path, f"✅ Successfully separated '{text_prompt}' from the audio."
51
 
 
74
  input_audio = gr.Audio(label="Upload Input Audio", type="filepath")
75
  text_prompt = gr.Textbox(
76
  label="Text Prompt",
77
+ placeholder="e.g., 'A man speaking', 'Piano playing', 'Dog barking'",
78
+ value="A man speaking",
79
  info="Describe the sound you want to isolate."
80
  )
81
  run_btn = gr.Button("🎯 Separate Audio", variant="primary", size="lg")
 
94
 
95
  gr.Markdown(
96
  """
97
+ ### Example Prompts
98
+ - "A person coughing"
99
+ - "Piano playing a melody"
100
+ - "Dog barking"
101
+ - "Car engine revving"
102
+ - "Raindrops falling"
103
  """
104
  )
105