Carley1234 commited on
Commit
14825a2
·
verified ·
1 Parent(s): 23c9257

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -36
app.py CHANGED
@@ -8,7 +8,7 @@ import numpy as np
8
  import scipy.io.wavfile
9
  from fastapi import FastAPI, HTTPException, Form
10
  from fastapi.middleware.cors import CORSMiddleware
11
- from transformers import pipeline
12
  from supabase import create_client, Client
13
 
14
  app = FastAPI()
@@ -34,17 +34,19 @@ supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
34
  # --- Model Loading ---
35
  device = "cpu"
36
  model_id = "facebook/audiogen-medium"
37
- audio_pipe = None
38
  load_error = None
39
  is_processing = False
40
 
41
  def load_models():
42
- global audio_pipe, load_error
43
  try:
44
- # Limit CPU threads BEFORE loading to avoid memory/CPU spikes
45
  torch.set_num_threads(1)
46
- print(f"Loading model {model_id} via pipeline...")
47
- audio_pipe = pipeline("text-to-audio", model=model_id, device=device)
 
 
48
 
49
  print("Model loaded successfully.")
50
  load_error = None
@@ -91,39 +93,26 @@ async def generate_effect(job_id: str, prompt: str = Form(...), duration: int =
91
  supabase.table("processing_queue").update({"status": "processing"}).eq("id", job_id).execute()
92
 
93
  try:
94
- if not audio_pipe:
95
- msg = f"Model pipeline not loaded. Error during startup: {load_error}" if load_error else "Model is still starting up..."
96
  raise Exception(msg)
97
 
98
- # AudioGen: 50 tokens ~ 1 second of audio
99
- max_tokens = min(int(duration) * 50, 250) # Max 5 seconds (250 tokens)
100
-
101
- # Run inference in a separate thread to avoid blocking heartbeats
102
  def run_inference():
103
- # Force no_grad and limit threads again just in case
104
  with torch.no_grad():
105
  torch.set_num_threads(1)
106
- return audio_pipe(
107
- prompt,
108
- generate_kwargs={
109
- "max_new_tokens": max_tokens,
110
- "do_sample": True,
111
- "temperature": 1.0,
112
- "top_k": 250,
113
- "top_p": 0.99,
114
- "guidance_scale": 3.0
115
- }
116
  )
 
 
117
 
118
- result = await asyncio.to_thread(run_inference)
119
-
120
- # Convert to WAV in memory
121
- sampling_rate = result["sampling_rate"]
122
- audio_data = result["audio"]
123
-
124
- # Ensure audio_data is a numpy array and has correct type for scipy
125
- if isinstance(audio_data, torch.Tensor):
126
- audio_data = audio_data.cpu().numpy()
127
 
128
  # Clean data and ensure CPU numpy array
129
  audio_data = np.nan_to_num(audio_data)
@@ -132,7 +121,7 @@ async def generate_effect(job_id: str, prompt: str = Form(...), duration: int =
132
  if audio_data.size > 0:
133
  audio_data = audio_data - np.mean(audio_data)
134
 
135
- # 2. Soft-clipping to prevent digital artifacts on saturation
136
  audio_data = np.tanh(audio_data * 1.2)
137
 
138
  # Standardize shape
@@ -144,18 +133,18 @@ async def generate_effect(job_id: str, prompt: str = Form(...), duration: int =
144
 
145
  audio_data = audio_data.flatten()
146
 
147
- # Fade out end of clip (0.2s for effects)
148
  fade_len = int(sampling_rate * 0.2)
149
  if len(audio_data) > fade_len:
150
  fade_window = np.linspace(1.0, 0.0, fade_len)
151
  audio_data[-fade_len:] *= fade_window
152
 
153
- # Normalize audio with headroom
154
  max_val = np.abs(audio_data).max()
155
  if max_val > 0:
156
  audio_data = (audio_data / (max_val + 1e-6)) * 0.9
157
 
158
- # Convert to 16-bit PCM with safety clamp
159
  audio_data = np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
160
 
161
  wav_buf = io.BytesIO()
 
8
  import scipy.io.wavfile
9
  from fastapi import FastAPI, HTTPException, Form
10
  from fastapi.middleware.cors import CORSMiddleware
11
+ from audiocraft.models import AudioGen
12
  from supabase import create_client, Client
13
 
14
  app = FastAPI()
 
34
  # --- Model Loading ---
35
  device = "cpu"
36
  model_id = "facebook/audiogen-medium"
37
+ model = None
38
  load_error = None
39
  is_processing = False
40
 
41
  def load_models():
42
+ global model, load_error
43
  try:
44
+ # Limit CPU threads BEFORE loading to avoid killing the container
45
  torch.set_num_threads(1)
46
+ print(f"Loading model {model_id} via Audiocraft...")
47
+
48
+ # Native Audiocraft loading
49
+ model = AudioGen.get_pretrained(model_id)
50
 
51
  print("Model loaded successfully.")
52
  load_error = None
 
93
  supabase.table("processing_queue").update({"status": "processing"}).eq("id", job_id).execute()
94
 
95
  try:
96
+ if model is None:
97
+ msg = f"Model not loaded. Error during startup: {load_error}" if load_error else "Model is still starting up..."
98
  raise Exception(msg)
99
 
 
 
 
 
100
  def run_inference():
 
101
  with torch.no_grad():
102
  torch.set_num_threads(1)
103
+ model.set_generation_params(
104
+ duration=min(int(duration), 5),
105
+ use_sampling=True,
106
+ temp=1.0,
107
+ top_k=250,
108
+ top_p=0.99,
109
+ cfg_coef=3.0
 
 
 
110
  )
111
+ wav = model.generate([prompt])
112
+ return wav[0].cpu().numpy()
113
 
114
+ audio_data = await asyncio.to_thread(run_inference)
115
+ sampling_rate = model.sample_rate
 
 
 
 
 
 
 
116
 
117
  # Clean data and ensure CPU numpy array
118
  audio_data = np.nan_to_num(audio_data)
 
121
  if audio_data.size > 0:
122
  audio_data = audio_data - np.mean(audio_data)
123
 
124
+ # Soft-clipping/Limiter
125
  audio_data = np.tanh(audio_data * 1.2)
126
 
127
  # Standardize shape
 
133
 
134
  audio_data = audio_data.flatten()
135
 
136
+ # Fade out (0.2s)
137
  fade_len = int(sampling_rate * 0.2)
138
  if len(audio_data) > fade_len:
139
  fade_window = np.linspace(1.0, 0.0, fade_len)
140
  audio_data[-fade_len:] *= fade_window
141
 
142
+ # Normalize with headroom
143
  max_val = np.abs(audio_data).max()
144
  if max_val > 0:
145
  audio_data = (audio_data / (max_val + 1e-6)) * 0.9
146
 
147
+ # Convert to 16-bit PCM
148
  audio_data = np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)
149
 
150
  wav_buf = io.BytesIO()