Peter Shi commited on
Commit
79ced89
Β·
1 Parent(s): e299ffc

Follow official sam-audio example exactly

Browse files
Files changed (1) hide show
  1. app.py +19 -23
app.py CHANGED
@@ -12,11 +12,11 @@ from sam_audio import SAMAudio, SAMAudioProcessor
12
  MODEL_NAME = "facebook/sam-audio-small"
13
 
14
  # Global model and processor
15
- device = "cuda" if torch.cuda.is_available() else "cpu"
16
- print(f"Loading {MODEL_NAME} on {device}...")
17
- model = SAMAudio.from_pretrained(MODEL_NAME).to(device).eval()
18
  processor = SAMAudioProcessor.from_pretrained(MODEL_NAME)
19
- print("Model loaded successfully.")
 
20
 
21
  def save_audio(tensor, sample_rate):
22
  """Helper to save torch tensor to a temp file for Gradio output."""
@@ -28,7 +28,7 @@ def save_audio(tensor, sample_rate):
28
  torchaudio.save(tmp.name, tensor, sample_rate)
29
  return tmp.name
30
 
31
- @spaces.GPU(duration=180)
32
  def separate_audio(audio_path, text_prompt):
33
  if not audio_path:
34
  return None, None, "❌ Please upload an audio file."
@@ -37,30 +37,28 @@ def separate_audio(audio_path, text_prompt):
37
  text_prompt = "vocals"
38
 
39
  try:
40
- # Process Inputs
41
- inputs = processor(
42
  audios=[audio_path],
43
  descriptions=[text_prompt.strip()]
44
- ).to(device)
45
 
46
- # Inference
47
- with torch.no_grad():
48
- result = model.separate(inputs)
49
 
50
- # Extract Outputs
51
- target_audio = result.target[0]
52
- residual_audio = result.residual[0]
53
-
54
- # Get sampling rate from the processor config
55
- sr = processor.feature_extractor.sampling_rate
56
 
57
  # Save to files
58
- target_path = save_audio(target_audio, sr)
59
- residual_path = save_audio(residual_audio, sr)
60
 
61
  return target_path, residual_path, f"βœ… Successfully separated '{text_prompt}' from the audio."
62
 
63
  except Exception as e:
 
 
64
  return None, None, f"❌ Error: {str(e)}"
65
 
66
  # Build Gradio Interface
@@ -83,7 +81,7 @@ with gr.Blocks(
83
  input_audio = gr.Audio(label="Upload Input Audio", type="filepath")
84
  text_prompt = gr.Textbox(
85
  label="Text Prompt",
86
- placeholder="e.g., 'drums', 'vocals', 'speech', 'piano'",
87
  value="drums",
88
  info="Describe the sound you want to isolate."
89
  )
@@ -104,9 +102,7 @@ with gr.Blocks(
104
  gr.Markdown(
105
  """
106
  ### Tips
107
- - Use prompts like: `drums`, `vocals`, `speech`, `piano`, `guitar`, `bass`, `synth`
108
- - For mixed audio with speech, try: `man speaking`, `woman singing`
109
- - GPU recommended for faster inference
110
  """
111
  )
112
 
 
12
  MODEL_NAME = "facebook/sam-audio-small"
13
 
14
  # Global model and processor
15
+ print(f"Loading {MODEL_NAME}...")
16
+ model = SAMAudio.from_pretrained(MODEL_NAME)
 
17
  processor = SAMAudioProcessor.from_pretrained(MODEL_NAME)
18
+ model = model.eval().cuda()
19
+ print("Model loaded on CUDA.")
20
 
21
  def save_audio(tensor, sample_rate):
22
  """Helper to save torch tensor to a temp file for Gradio output."""
 
28
  torchaudio.save(tmp.name, tensor, sample_rate)
29
  return tmp.name
30
 
31
+ @spaces.GPU(duration=300)
32
  def separate_audio(audio_path, text_prompt):
33
  if not audio_path:
34
  return None, None, "❌ Please upload an audio file."
 
37
  text_prompt = "vocals"
38
 
39
  try:
40
+ # Process Inputs (following official example)
41
+ batch = processor(
42
  audios=[audio_path],
43
  descriptions=[text_prompt.strip()]
44
+ ).to("cuda")
45
 
46
+ # Inference using inference_mode (as per official docs)
47
+ with torch.inference_mode():
48
+ result = model.separate(batch, predict_spans=False, reranking_candidates=1)
49
 
50
+ # Get sampling rate
51
+ sample_rate = processor.audio_sampling_rate
 
 
 
 
52
 
53
  # Save to files
54
+ target_path = save_audio(result.target, sample_rate)
55
+ residual_path = save_audio(result.residual, sample_rate)
56
 
57
  return target_path, residual_path, f"βœ… Successfully separated '{text_prompt}' from the audio."
58
 
59
  except Exception as e:
60
+ import traceback
61
+ traceback.print_exc()
62
  return None, None, f"❌ Error: {str(e)}"
63
 
64
  # Build Gradio Interface
 
81
  input_audio = gr.Audio(label="Upload Input Audio", type="filepath")
82
  text_prompt = gr.Textbox(
83
  label="Text Prompt",
84
+ placeholder="e.g., 'drums', 'vocals', 'A man speaking'",
85
  value="drums",
86
  info="Describe the sound you want to isolate."
87
  )
 
102
  gr.Markdown(
103
  """
104
  ### Tips
105
+ - Use prompts like: `drums`, `vocals`, `A man speaking`, `piano`, `guitar`
 
 
106
  """
107
  )
108