Carley1234 commited on
Commit
c8ec810
·
verified ·
1 Parent(s): a434ec5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -27
app.py CHANGED
@@ -8,11 +8,7 @@ import numpy as np
8
  import scipy.io.wavfile
9
  from fastapi import FastAPI, HTTPException, Form
10
  from fastapi.middleware.cors import CORSMiddleware
11
- try:
12
- from transformers import AutoProcessor, AudioGenForConditionalGeneration
13
- except ImportError:
14
- # Fallback for some transformer versions or environment quirks
15
- from transformers import AutoProcessor, AutoModel as AudioGenForConditionalGeneration
16
  from supabase import create_client, Client
17
 
18
  app = FastAPI()
@@ -38,21 +34,18 @@ supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
38
  # --- Model Loading ---
39
  device = "cpu"
40
  model_id = "facebook/audiogen-medium"
41
- processor = None
42
- model = None
43
  load_error = None
44
  is_processing = False
45
 
46
  def load_models():
47
- global processor, model, load_error
48
  try:
49
  # Limit CPU threads BEFORE loading to avoid killing the container
50
  torch.set_num_threads(1)
51
- print(f"Loading model {model_id}...")
52
- # Use explicit classes for better control on free CPU resources
53
- processor = AutoProcessor.from_pretrained(model_id)
54
- model = AudioGenForConditionalGeneration.from_pretrained(model_id)
55
- model.to(device)
56
 
57
  print("Model loaded successfully.")
58
  load_error = None
@@ -99,7 +92,7 @@ async def generate_effect(job_id: str, prompt: str = Form(...), duration: int =
99
  supabase.table("processing_queue").update({"status": "processing"}).eq("id", job_id).execute()
100
 
101
  try:
102
- if model is None or processor is None:
103
  msg = f"Model not loaded. Error during startup: {load_error}" if load_error else "Model is still starting up..."
104
  raise Exception(msg)
105
 
@@ -108,21 +101,22 @@ async def generate_effect(job_id: str, prompt: str = Form(...), duration: int =
108
 
109
  def run_inference():
110
  with torch.no_grad():
111
- # Explicit generation for better control
112
- inputs = processor(text=[prompt], return_tensors="pt")
113
- audio_values = model.generate(
114
- **inputs.to(device),
115
- max_new_tokens=max_tokens,
116
- do_sample=True,
117
- temperature=1.0,
118
- top_k=250,
119
- top_p=0.99,
120
- guidance_scale=3.0
 
121
  )
122
- return audio_values[0].cpu().numpy()
123
 
124
- audio_data = await asyncio.to_thread(run_inference)
125
- sampling_rate = model.config.audio_encoder.sampling_rate
 
126
 
127
  # Ensure audio_data is a numpy array and has correct type for scipy
128
  if isinstance(audio_data, torch.Tensor):
 
8
  import scipy.io.wavfile
9
  from fastapi import FastAPI, HTTPException, Form
10
  from fastapi.middleware.cors import CORSMiddleware
11
+ from transformers import pipeline
 
 
 
 
12
  from supabase import create_client, Client
13
 
14
  app = FastAPI()
 
34
  # --- Model Loading ---
35
  device = "cpu"
36
  model_id = "facebook/audiogen-medium"
37
+ audio_pipe = None
 
38
  load_error = None
39
  is_processing = False
40
 
41
  def load_models():
42
+ global audio_pipe, load_error
43
  try:
44
  # Limit CPU threads BEFORE loading to avoid killing the container
45
  torch.set_num_threads(1)
46
+ print(f"Loading model {model_id} via pipeline...")
47
+ # Using pipeline as it handles processors and models more robustly
48
+ audio_pipe = pipeline("text-to-audio", model=model_id, device=device)
 
 
49
 
50
  print("Model loaded successfully.")
51
  load_error = None
 
92
  supabase.table("processing_queue").update({"status": "processing"}).eq("id", job_id).execute()
93
 
94
  try:
95
+ if audio_pipe is None:
96
  msg = f"Model not loaded. Error during startup: {load_error}" if load_error else "Model is still starting up..."
97
  raise Exception(msg)
98
 
 
101
 
102
  def run_inference():
103
  with torch.no_grad():
104
+ torch.set_num_threads(1)
105
+ return audio_pipe(
106
+ prompt,
107
+ forward_params={
108
+ "max_new_tokens": max_tokens,
109
+ "do_sample": True,
110
+ "temperature": 1.0,
111
+ "top_k": 250,
112
+ "top_p": 0.99,
113
+ "guidance_scale": 3.0
114
+ }
115
  )
 
116
 
117
+ result = await asyncio.to_thread(run_inference)
118
+ sampling_rate = result["sampling_rate"]
119
+ audio_data = result["audio"]
120
 
121
  # Ensure audio_data is a numpy array and has correct type for scipy
122
  if isinstance(audio_data, torch.Tensor):