# efectos / app.py
# Carley1234's picture
# Update app.py
# f9e7461 verified
import os
import io
import asyncio
import base64
import datetime
import torch
import numpy as np
import scipy.io.wavfile
from fastapi import FastAPI, HTTPException, Form
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline, AutoProcessor, AudioGenForConditionalGeneration
from supabase import create_client, Client
app = FastAPI()

# --- CORS Configuration ---
# Every origin, method and header is accepted, and credentialed requests are
# allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is a very
# permissive combination -- confirm it is intended for this worker.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# --- Supabase Configuration ---
# Connection settings and the worker's identity are read from the environment;
# each one falls back to the literal default below when the variable is unset.
SUPABASE_URL = os.getenv("SUPABASE_URL", "https://tladrluezsmmhjbhupgb.supabase.co")
SUPABASE_KEY = os.getenv("SUPABASE_KEY", "sb_publishable_zb8TGeURLnafHWDffG9DMg_PtFO_kmv")
SERVER_ID = os.getenv("SERVER_ID", "efectos-worker")
SERVER_URL = os.getenv("SERVER_URL", "https://carley1234-efectos.hf.space")

# Value written to the "service_type" column of this worker's status row.
SERVICE_TYPE = "effect"

# Single client instance shared by every handler in this module.
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
# --- Model Loading ---
# Checkpoint to load and the device the pipeline is placed on.
model_id = "facebook/audiogen-medium"
device = "cpu"

# Shared runtime state, mutated by load_models() and update_status():
#   audio_pipe    -- the ready pipeline, or None until loading succeeds
#   load_error    -- text of the last loading failure, or None
#   is_processing -- True while the worker reports itself as "busy"
audio_pipe = load_error = None
is_processing = False
def load_models():
    """Build the text-to-audio pipeline and publish it in ``audio_pipe``.

    On success ``audio_pipe`` holds the pipeline and ``load_error`` is reset
    to ``None``. On any failure the exception text is stored in ``load_error``
    and printed; the exception itself is not propagated.
    """
    global audio_pipe, load_error
    try:
        # Restrict torch to one CPU thread before the weights are loaded so
        # loading does not cause memory/CPU spikes.
        torch.set_num_threads(1)
        print(f"Loading model {model_id} via explicit classes...")

        # The processor and model are instantiated explicitly (rather than
        # letting pipeline() resolve them) to avoid 'Unrecognized model'
        # errors.
        sound_processor = AutoProcessor.from_pretrained(model_id)
        sound_model = AudioGenForConditionalGeneration.from_pretrained(model_id)

        # Wrap both in a pipeline so generation is a single call.
        audio_pipe = pipeline(
            "text-to-audio",
            model=sound_model,
            tokenizer=sound_processor,
            device=device,
        )
        print("Model loaded successfully.")
        load_error = None
    except Exception as exc:
        load_error = str(exc)
        print(f"Error loading model: {exc}")
async def update_status(status: str = None):
    """Upsert this worker's row in the ``server_status`` table.

    Args:
        status: ``"busy"`` marks the worker as processing; any other truthy
            value marks it as free. ``None`` (or an empty string) leaves the
            current flag untouched and only refreshes the heartbeat timestamp.

    Errors are printed rather than raised, so a failed status write never
    breaks the caller (heartbeat loop or request handler).
    """
    global is_processing
    try:
        if status:
            is_processing = (status == "busy")
        current_status = "busy" if is_processing else "free"
        data = {
            "id": SERVER_ID,
            "url": SERVER_URL,
            "status": current_status,
            "service_type": SERVICE_TYPE,
            "last_heartbeat": datetime.datetime.now(datetime.timezone.utc).isoformat(),
        }
        # The Supabase client used here is synchronous: execute() performs the
        # HTTP round trip. Run it in a worker thread so the event loop (and
        # with it other requests and heartbeats) is not stalled meanwhile.
        request = supabase.table("server_status").upsert(data)
        await asyncio.to_thread(request.execute)
    except Exception as e:
        print(f"Error updating status: {e}")
async def heartbeat_loop(interval: float = 20):
    """Refresh this worker's status row forever, once per ``interval`` seconds.

    Args:
        interval: Pause between two heartbeats, in seconds. Defaults to the
            previously hard-coded 20 seconds.

    The loop never returns on its own; it ends only when its task is cancelled.
    """
    while True:
        # No argument: keep the current busy/free flag, refresh the timestamp.
        await update_status()
        await asyncio.sleep(interval)
# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so a task nobody else references may be garbage
# collected before it finishes.
_background_tasks = set()


def _spawn_background(coro):
    """Schedule ``coro`` as a task and hold a reference until it completes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


@app.on_event("startup")
async def startup_event():
    """Kick off model loading, announce the worker as free, start heartbeats."""
    # Load models in a background thread to avoid startup timeouts.
    _spawn_background(asyncio.to_thread(load_models))
    await update_status("free")
    _spawn_background(heartbeat_loop())
@app.get("/")
async def root():
    """Return a fixed payload showing that the worker process is up."""
    payload = {
        "message": "VidSpri Effects Worker is running",
        "status": "ok",
    }
    return payload
def _audio_to_pcm16(audio_data, sampling_rate):
    """Clean raw generated audio and convert it to a mono int16 PCM array.

    Steps, in order: tensor -> numpy, NaN/inf scrubbing, DC-offset removal,
    tanh soft clipping, channel down-mix, 0.2 s linear fade-out, peak
    normalisation to 0.9 full scale, and int16 quantisation.
    """
    if isinstance(audio_data, torch.Tensor):
        audio_data = audio_data.cpu().numpy()
    audio_data = np.nan_to_num(audio_data)

    # Remove DC offset to eliminate "click" and constant hum.
    if audio_data.size > 0:
        audio_data = audio_data - np.mean(audio_data)

    # Soft-clipping to prevent digital artifacts on saturation.
    audio_data = np.tanh(audio_data * 1.2)

    # Standardize shape: drop a leading batch axis, average channels to mono.
    if audio_data.ndim == 3:
        audio_data = audio_data[0]
    if audio_data.ndim == 2:
        audio_data = np.mean(audio_data, axis=0)
    audio_data = audio_data.flatten()

    # Fade out end of clip (0.2s for effects). A zero-length fade is skipped
    # because audio_data[-0:] would select the whole array.
    fade_len = int(sampling_rate * 0.2)
    if 0 < fade_len < len(audio_data):
        audio_data[-fade_len:] *= np.linspace(1.0, 0.0, fade_len)

    # Normalize audio with headroom; an empty clip has no peak to measure.
    max_val = np.abs(audio_data).max() if audio_data.size > 0 else 0
    if max_val > 0:
        audio_data = (audio_data / (max_val + 1e-6)) * 0.9

    # Convert to 16-bit PCM with safety clamp.
    return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)


@app.post("/generate/{job_id}")
async def generate_effect(job_id: str, prompt: str = Form(...), duration: int = Form(3)):
    """Generate a sound effect for ``prompt`` and return it as base64 WAV.

    Args:
        job_id: Row id in ``processing_queue`` whose status is tracked.
        prompt: Text description passed to the audio pipeline.
        duration: Requested length in seconds; clamped to the range 1-5.

    Returns:
        ``{"status": "success", "audio": <base64-encoded WAV>}``.

    Raises:
        HTTPException: status 500 with the error text when the model is not
            loaded or any step of generation fails. The job is marked
            ``"failed"`` on a best-effort basis.

    The worker is reported as busy for the duration of the call and is always
    reported as free again afterwards, whatever the outcome.
    """

    def set_job_status(state):
        # Synchronous Supabase call; always invoked through asyncio.to_thread
        # so the event loop is not blocked by the HTTP round trip.
        supabase.table("processing_queue").update({"status": state}).eq("id", job_id).execute()

    await update_status("busy")
    try:
        # Inside the try so that a failed queue update cannot leave the
        # worker permanently reported as busy.
        await asyncio.to_thread(set_job_status, "processing")

        if not audio_pipe:
            msg = f"Model pipeline not loaded. Error during startup: {load_error}" if load_error else "Model is still starting up..."
            raise Exception(msg)

        # AudioGen: 50 tokens ~ 1 second of audio. Cap at 5 seconds
        # (250 tokens) and floor at 1 second so a zero or negative duration
        # cannot yield a non-positive token budget.
        seconds = min(max(int(duration), 1), 5)
        max_tokens = seconds * 50

        # Run inference in a separate thread to avoid blocking heartbeats.
        def run_inference():
            with torch.no_grad():
                torch.set_num_threads(1)
                return audio_pipe(
                    prompt,
                    generate_kwargs={
                        "max_new_tokens": max_tokens,
                        "do_sample": True,
                        "temperature": 1.0,
                        "top_k": 250,
                        "top_p": 0.99,
                        "guidance_scale": 3.0,
                    },
                )

        result = await asyncio.to_thread(run_inference)

        # Convert to WAV in memory.
        sampling_rate = result["sampling_rate"]
        pcm = _audio_to_pcm16(result["audio"], sampling_rate)
        wav_buf = io.BytesIO()
        scipy.io.wavfile.write(wav_buf, rate=sampling_rate, data=pcm)
        audio_base64 = base64.b64encode(wav_buf.getvalue()).decode('utf-8')

        await asyncio.to_thread(set_job_status, "completed")
        return {"status": "success", "audio": audio_base64}
    except Exception as e:
        print(f"Generation error: {e}")
        try:
            await asyncio.to_thread(set_job_status, "failed")
        except Exception as db_error:
            # Do not let a failed bookkeeping write hide the original error.
            print(f"Could not mark job {job_id} as failed: {db_error}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        await update_status("free")