Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,11 +8,7 @@ import numpy as np
|
|
| 8 |
import scipy.io.wavfile
|
| 9 |
from fastapi import FastAPI, HTTPException, Form
|
| 10 |
from fastapi.middleware.cors import CORSMiddleware
|
| 11 |
-
|
| 12 |
-
from transformers import AutoProcessor, AudioGenForConditionalGeneration
|
| 13 |
-
except ImportError:
|
| 14 |
-
# Fallback for some transformer versions or environment quirks
|
| 15 |
-
from transformers import AutoProcessor, AutoModel as AudioGenForConditionalGeneration
|
| 16 |
from supabase import create_client, Client
|
| 17 |
|
| 18 |
app = FastAPI()
|
|
@@ -38,21 +34,18 @@ supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
|
|
| 38 |
# --- Model Loading ---
|
| 39 |
device = "cpu"
|
| 40 |
model_id = "facebook/audiogen-medium"
|
| 41 |
-
|
| 42 |
-
model = None
|
| 43 |
load_error = None
|
| 44 |
is_processing = False
|
| 45 |
|
| 46 |
def load_models():
|
| 47 |
-
global
|
| 48 |
try:
|
| 49 |
# Limit CPU threads BEFORE loading to avoid killing the container
|
| 50 |
torch.set_num_threads(1)
|
| 51 |
-
print(f"Loading model {model_id}...")
|
| 52 |
-
#
|
| 53 |
-
|
| 54 |
-
model = AudioGenForConditionalGeneration.from_pretrained(model_id)
|
| 55 |
-
model.to(device)
|
| 56 |
|
| 57 |
print("Model loaded successfully.")
|
| 58 |
load_error = None
|
|
@@ -99,7 +92,7 @@ async def generate_effect(job_id: str, prompt: str = Form(...), duration: int =
|
|
| 99 |
supabase.table("processing_queue").update({"status": "processing"}).eq("id", job_id).execute()
|
| 100 |
|
| 101 |
try:
|
| 102 |
-
if
|
| 103 |
msg = f"Model not loaded. Error during startup: {load_error}" if load_error else "Model is still starting up..."
|
| 104 |
raise Exception(msg)
|
| 105 |
|
|
@@ -108,21 +101,22 @@ async def generate_effect(job_id: str, prompt: str = Form(...), duration: int =
|
|
| 108 |
|
| 109 |
def run_inference():
|
| 110 |
with torch.no_grad():
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
|
|
|
| 121 |
)
|
| 122 |
-
return audio_values[0].cpu().numpy()
|
| 123 |
|
| 124 |
-
|
| 125 |
-
sampling_rate =
|
|
|
|
| 126 |
|
| 127 |
# Ensure audio_data is a numpy array and has correct type for scipy
|
| 128 |
if isinstance(audio_data, torch.Tensor):
|
|
|
|
| 8 |
import scipy.io.wavfile
|
| 9 |
from fastapi import FastAPI, HTTPException, Form
|
| 10 |
from fastapi.middleware.cors import CORSMiddleware
|
| 11 |
+
from transformers import pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
from supabase import create_client, Client
|
| 13 |
|
| 14 |
app = FastAPI()
|
|
|
|
| 34 |
# --- Model Loading ---
|
| 35 |
device = "cpu"
|
| 36 |
model_id = "facebook/audiogen-medium"
|
| 37 |
+
audio_pipe = None
|
|
|
|
| 38 |
load_error = None
|
| 39 |
is_processing = False
|
| 40 |
|
| 41 |
def load_models():
|
| 42 |
+
global audio_pipe, load_error
|
| 43 |
try:
|
| 44 |
# Limit CPU threads BEFORE loading to avoid killing the container
|
| 45 |
torch.set_num_threads(1)
|
| 46 |
+
print(f"Loading model {model_id} via pipeline...")
|
| 47 |
+
# Using pipeline as it handles processors and models more robustly
|
| 48 |
+
audio_pipe = pipeline("text-to-audio", model=model_id, device=device)
|
|
|
|
|
|
|
| 49 |
|
| 50 |
print("Model loaded successfully.")
|
| 51 |
load_error = None
|
|
|
|
| 92 |
supabase.table("processing_queue").update({"status": "processing"}).eq("id", job_id).execute()
|
| 93 |
|
| 94 |
try:
|
| 95 |
+
if audio_pipe is None:
|
| 96 |
msg = f"Model not loaded. Error during startup: {load_error}" if load_error else "Model is still starting up..."
|
| 97 |
raise Exception(msg)
|
| 98 |
|
|
|
|
| 101 |
|
| 102 |
def run_inference():
|
| 103 |
with torch.no_grad():
|
| 104 |
+
torch.set_num_threads(1)
|
| 105 |
+
return audio_pipe(
|
| 106 |
+
prompt,
|
| 107 |
+
forward_params={
|
| 108 |
+
"max_new_tokens": max_tokens,
|
| 109 |
+
"do_sample": True,
|
| 110 |
+
"temperature": 1.0,
|
| 111 |
+
"top_k": 250,
|
| 112 |
+
"top_p": 0.99,
|
| 113 |
+
"guidance_scale": 3.0
|
| 114 |
+
}
|
| 115 |
)
|
|
|
|
| 116 |
|
| 117 |
+
result = await asyncio.to_thread(run_inference)
|
| 118 |
+
sampling_rate = result["sampling_rate"]
|
| 119 |
+
audio_data = result["audio"]
|
| 120 |
|
| 121 |
# Ensure audio_data is a numpy array and has correct type for scipy
|
| 122 |
if isinstance(audio_data, torch.Tensor):
|