Spaces:
Runtime error
Runtime error
File size: 6,297 Bytes
b642640 8f336f4 b642640 90900a9 b642640 f9e7461 b642640 ce84bf1 b642640 8adfead f9e7461 a818820 b642640 a7c8314 f9e7461 a7c8314 f9e7461 a434ec5 f9e7461 14825a2 f9e7461 426a631 a7c8314 a818820 a7c8314 a818820 a7c8314 b642640 a7c8314 b642640 660552c b642640 f9e7461 a818820 b642640 f9e7461 8727b92 c8ec810 f9e7461 a0240fe 8727b92 f9e7461 8f336f4 a7c8314 a0240fe 8f336f4 a7c8314 f9e7461 8adfead a7c8314 a0240fe a7c8314 a0240fe a7c8314 f9e7461 a7c8314 f9e7461 9fae321 a7c8314 9fae321 f9e7461 a7c8314 9fae321 b642640 8f336f4 b642640 8727b92 b642640 a818820 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 | import os
import io
import asyncio
import base64
import datetime
import torch
import numpy as np
import scipy.io.wavfile
from fastapi import FastAPI, HTTPException, Form
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline, AutoProcessor, AudioGenForConditionalGeneration
from supabase import create_client, Client
app = FastAPI()

# --- CORS Configuration ---
# Fully open policy: every origin, method and header is accepted and
# credentials are allowed.
# NOTE(review): wildcard origins together with allow_credentials=True opens
# credentialed cross-origin access to any site; confirm this is intended.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# --- Supabase Configuration ---
# Each setting can be overridden by an environment variable of the same name;
# the literals are this deployment's fallbacks.
# NOTE(review): a fallback API key is committed in source; confirm it is meant
# to be public.
SUPABASE_URL = os.getenv("SUPABASE_URL", "https://tladrluezsmmhjbhupgb.supabase.co")
SUPABASE_KEY = os.getenv("SUPABASE_KEY", "sb_publishable_zb8TGeURLnafHWDffG9DMg_PtFO_kmv")
SERVER_ID = os.getenv("SERVER_ID", "efectos-worker")
SERVER_URL = os.getenv("SERVER_URL", "https://carley1234-efectos.hf.space")
SERVICE_TYPE = "effect"  # written to server_status.service_type

# Single Supabase client shared by all status / queue updates in this module.
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)

# --- Model Loading ---
model_id = "facebook/audiogen-medium"
device = "cpu"

# Mutable module state.
audio_pipe = None  # text-to-audio pipeline, set by load_models() on success
load_error = None  # str() of the last model-loading exception, or None
is_processing = False  # busy/free flag published by update_status()
def load_models():
    """Build the text-to-audio pipeline for ``model_id`` on ``device``.

    Startup runs this in a worker thread. On success the global
    ``audio_pipe`` is set and ``load_error`` is cleared; on any ``Exception``
    the message is stored in ``load_error`` (and printed) instead of being
    raised, leaving ``audio_pipe`` unchanged.

    NOTE(review): confirm the installed ``transformers`` really exports
    ``AudioGenForConditionalGeneration``; if it does not, the module-level
    import fails and the app never starts, so this function is never reached.
    """
    global audio_pipe, load_error
    try:
        # Cap torch's thread pool *before* loading to avoid CPU/memory spikes.
        torch.set_num_threads(1)
        print(f"Loading model {model_id} via explicit classes...")
        # Instantiate the classes ourselves (processor first, then model) rather
        # than letting pipeline() resolve the id, to avoid 'Unrecognized model'.
        pipe_kwargs = {"tokenizer": AutoProcessor.from_pretrained(model_id)}
        pipe_kwargs["model"] = AudioGenForConditionalGeneration.from_pretrained(model_id)
        audio_pipe = pipeline("text-to-audio", device=device, **pipe_kwargs)
        print("Model loaded successfully.")
        load_error = None
    except Exception as exc:
        load_error = str(exc)
        print(f"Error loading model: {exc}")
async def update_status(status: "str | None" = None):
    """Publish this worker's availability to the Supabase ``server_status`` table.

    Args:
        status: ``"busy"`` marks the worker busy; any other non-empty string
            marks it free. ``None`` (or ``""``) keeps the current
            ``is_processing`` value and only refreshes ``last_heartbeat``.

    Supabase errors are printed and swallowed: status reporting is
    best-effort and must never break request handling.
    """
    global is_processing
    try:
        if status:
            is_processing = status == "busy"
        row = {
            "id": SERVER_ID,
            "url": SERVER_URL,
            "status": "busy" if is_processing else "free",
            "service_type": SERVICE_TYPE,
            "last_heartbeat": datetime.datetime.now(datetime.timezone.utc).isoformat(),
        }
        # The Supabase client call is synchronous (.execute() is not awaited);
        # run the HTTP round-trip in a worker thread so the event loop, and
        # therefore heartbeats and other requests, is not blocked.
        await asyncio.to_thread(
            lambda: supabase.table("server_status").upsert(row).execute()
        )
    except Exception as e:
        print(f"Error updating status: {e}")
async def heartbeat_loop():
    """Re-publish the current busy/free state every 20 seconds, forever."""
    interval_seconds = 20
    while True:
        # No argument: keep the current state, just refresh last_heartbeat.
        await update_status()
        await asyncio.sleep(interval_seconds)
# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected before
# it finishes.
_background_tasks: set[asyncio.Task] = set()


def _spawn_background(coro):
    """Schedule ``coro`` as a task and keep it referenced until it completes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


# NOTE(review): FastAPI deprecates on_event in favour of lifespan handlers.
@app.on_event("startup")
async def startup_event():
    """Start model loading and heartbeats without delaying server startup.

    The model is loaded in a worker thread in the background so the Space does
    not hit a startup timeout. The worker announces itself as "free"
    immediately, and a heartbeat task is then started.
    """
    _spawn_background(asyncio.to_thread(load_models))
    await update_status("free")
    _spawn_background(heartbeat_loop())
@app.get("/")
async def root():
    """Liveness endpoint: report that the worker process is serving HTTP."""
    payload = dict(message="VidSpri Effects Worker is running", status="ok")
    return payload
# AudioGen emits roughly 50 tokens per second of audio; clips are capped at 5 s.
_TOKENS_PER_SECOND = 50
_MAX_EFFECT_SECONDS = 5


def _set_job_status(job_id, status):
    """Set ``status`` on the ``processing_queue`` row whose id is ``job_id``.

    Blocking Supabase call; async code should run it via ``asyncio.to_thread``.
    """
    supabase.table("processing_queue").update({"status": status}).eq("id", job_id).execute()


def _postprocess_audio(audio_data, sampling_rate):
    """Clean raw pipeline audio and convert it to mono 16-bit PCM.

    Pipeline: NaN/inf removal -> DC-offset removal -> tanh soft-clip ->
    reduction to 1-D -> 0.2 s linear fade-out -> peak normalisation to 0.9 ->
    int16 conversion.

    Args:
        audio_data: array or tensor from the pipeline result's ``"audio"`` key.
        sampling_rate: sample rate in Hz (used for the fade length).

    Returns:
        1-D ``np.int16`` array.

    Raises:
        ValueError: if the model produced no samples.
    """
    if isinstance(audio_data, torch.Tensor):
        # .float() first: NumPy has no bfloat16, so .numpy() would fail on it.
        audio_data = audio_data.detach().float().cpu().numpy()
    # Work in float32 regardless of the model's output dtype; NaN/inf would
    # poison the mean and normalisation below.
    audio_data = np.nan_to_num(np.asarray(audio_data, dtype=np.float32))
    if audio_data.size == 0:
        raise ValueError("Model returned no audio samples")

    # Remove DC offset to eliminate the "click" and constant hum.
    audio_data = audio_data - np.mean(audio_data)
    # Soft-clip to prevent digital artifacts on saturation.
    audio_data = np.tanh(audio_data * 1.2)

    # Reduce to a single 1-D channel.
    # NOTE(review): assumes channel-first layout, (batch, channels, samples) or
    # (channels, samples); confirm against the pipeline output.
    if audio_data.ndim == 3:
        audio_data = audio_data[0]
    if audio_data.ndim == 2:
        audio_data = audio_data.mean(axis=0)
    audio_data = audio_data.flatten()  # copy, so in-place ops below are safe

    # Fade out the last 0.2 s so the effect does not end abruptly.
    fade_len = int(sampling_rate * 0.2)
    if 0 < fade_len < len(audio_data):
        audio_data[-fade_len:] *= np.linspace(1.0, 0.0, fade_len, dtype=np.float32)

    # Peak-normalise with headroom (peak ends at ~0.9 of full scale).
    peak = np.abs(audio_data).max()
    if peak > 0:
        audio_data = (audio_data / (peak + 1e-6)) * 0.9

    # Scale to 16-bit PCM; the clip is a safety clamp against overshoot.
    return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)


def _wav_base64(pcm, sampling_rate):
    """Encode PCM samples as an in-memory WAV file and return it base64-encoded."""
    buf = io.BytesIO()
    scipy.io.wavfile.write(buf, rate=int(sampling_rate), data=pcm)
    return base64.b64encode(buf.getvalue()).decode("utf-8")


@app.post("/generate/{job_id}")
async def generate_effect(job_id: str, prompt: str = Form(...), duration: int = Form(3)):
    """Generate a sound effect for ``prompt`` and return it as base64 WAV.

    The worker is published as busy for the duration of the request, and the
    ``processing_queue`` row ``job_id`` moves through "processing" and then
    "completed" or "failed".

    Args:
        job_id: id of the ``processing_queue`` row being served.
        prompt: text description of the effect.
        duration: requested length in seconds, clamped to 1..5.

    Returns:
        ``{"status": "success", "audio": <base64 WAV>}``.

    Raises:
        HTTPException: status 500 with the error message on any failure,
            including the model not being loaded yet.
    """
    await update_status("busy")
    try:
        # Inside the try: a Supabase failure here must still mark the job
        # failed and free the worker instead of leaving it stuck "busy".
        await asyncio.to_thread(_set_job_status, job_id, "processing")

        if audio_pipe is None:
            msg = f"Model pipeline not loaded. Error during startup: {load_error}" if load_error else "Model is still starting up..."
            raise RuntimeError(msg)

        # Clamp so zero/negative durations still request audio and long
        # requests stay within the 5 s budget.
        seconds = max(1, min(int(duration), _MAX_EFFECT_SECONDS))
        max_tokens = seconds * _TOKENS_PER_SECOND

        def run_inference():
            # Runs in a worker thread so heartbeats keep flowing meanwhile.
            with torch.no_grad():
                torch.set_num_threads(1)
                return audio_pipe(
                    prompt,
                    generate_kwargs={
                        "max_new_tokens": max_tokens,
                        "do_sample": True,
                        "temperature": 1.0,
                        "top_k": 250,
                        "top_p": 0.99,
                        "guidance_scale": 3.0,
                    },
                )

        result = await asyncio.to_thread(run_inference)
        sampling_rate = result["sampling_rate"]
        pcm = _postprocess_audio(result["audio"], sampling_rate)
        audio_base64 = _wav_base64(pcm, sampling_rate)

        await asyncio.to_thread(_set_job_status, job_id, "completed")
        return {"status": "success", "audio": audio_base64}
    except Exception as e:
        print(f"Generation error: {e}")
        try:
            await asyncio.to_thread(_set_job_status, job_id, "failed")
        except Exception as db_err:
            # Best effort: still report the original error to the client.
            print(f"Error marking job {job_id} as failed: {db_err}")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always release the worker, on success and on failure.
        await update_status("free")