Peter Shi committed on
Commit
341403e
·
1 Parent(s): 76cf598

Fix: lazy load model inside GPU context for ZeroGPU

Browse files
Files changed (1) hide show
  1. app.py +19 -14
app.py CHANGED
@@ -17,18 +17,20 @@ from sam_audio import SAMAudio, SAMAudioProcessor
17
 
18
  # Configuration
19
  MODEL_NAME = "facebook/sam-audio-small"
20
- device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
- print(f"Loading {MODEL_NAME} on {device}...")
23
 
24
- # Load Model and Processor
25
- try:
26
- model = SAMAudio.from_pretrained(MODEL_NAME).to(device).eval()
27
- processor = SAMAudioProcessor.from_pretrained(MODEL_NAME)
28
- print("Model loaded successfully.")
29
- except Exception as e:
30
- print(f"Error loading model. Did you set HF_TOKEN in secrets? Error: {e}")
31
- raise e
 
 
 
32
 
33
  def save_audio(tensor, sample_rate):
34
  """Helper to save torch tensor to a temp file for Gradio output."""
@@ -45,19 +47,22 @@ def separate_audio(audio_path, text_prompt):
45
  if not audio_path:
46
  return None, None
47
 
 
 
 
48
  # Process Inputs
49
  inputs = processor(
50
  audios=[audio_path],
51
  descriptions=[text_prompt]
52
- ).to(device)
53
 
54
  # Inference
55
  with torch.no_grad():
56
  result = model.separate(inputs)
57
 
58
  # Extract Outputs
59
- target_audio = result.target[0] # The sound you asked for
60
- residual_audio = result.residual[0] # Everything else
61
 
62
  # Get sampling rate from the processor config
63
  sr = processor.feature_extractor.sampling_rate
@@ -84,7 +89,7 @@ with gr.Blocks(title="SAM-Audio Demo") as demo:
84
  input_audio = gr.Audio(label="Upload Input Audio", type="filepath")
85
  text_prompt = gr.Textbox(
86
  label="Text Prompt",
87
- placeholder="e.g., 'dog barking', 'man speaking', 'typing keyboard'",
88
  info="Describe the sound you want to isolate."
89
  )
90
  run_btn = gr.Button("Separate Audio", variant="primary")
 
17
 
18
  # Configuration
19
  MODEL_NAME = "facebook/sam-audio-small"
 
20
 
21
+ print(f"Loading {MODEL_NAME} processor...")
22
 
23
+ # Load Processor only (model will be loaded on GPU when needed)
24
+ processor = SAMAudioProcessor.from_pretrained(MODEL_NAME)
25
+ model = None # Will be loaded lazily
26
+
27
+ def get_model():
28
+ global model
29
+ if model is None:
30
+ print(f"Loading model to CUDA...")
31
+ model = SAMAudio.from_pretrained(MODEL_NAME).to("cuda").eval()
32
+ print("Model loaded successfully.")
33
+ return model
34
 
35
  def save_audio(tensor, sample_rate):
36
  """Helper to save torch tensor to a temp file for Gradio output."""
 
47
  if not audio_path:
48
  return None, None
49
 
50
+ # Load model inside GPU context
51
+ model = get_model()
52
+
53
  # Process Inputs
54
  inputs = processor(
55
  audios=[audio_path],
56
  descriptions=[text_prompt]
57
+ ).to("cuda")
58
 
59
  # Inference
60
  with torch.no_grad():
61
  result = model.separate(inputs)
62
 
63
  # Extract Outputs
64
+ target_audio = result.target[0]
65
+ residual_audio = result.residual[0]
66
 
67
  # Get sampling rate from the processor config
68
  sr = processor.feature_extractor.sampling_rate
 
89
  input_audio = gr.Audio(label="Upload Input Audio", type="filepath")
90
  text_prompt = gr.Textbox(
91
  label="Text Prompt",
92
+ placeholder="e.g., 'drums', 'vocals', 'speech', 'piano'",
93
  info="Describe the sound you want to isolate."
94
  )
95
  run_btn = gr.Button("Separate Audio", variant="primary")