Spaces:
Runtime error
Runtime error
| import os | |
| import io | |
| import asyncio | |
| import base64 | |
| import datetime | |
| import torch | |
| import numpy as np | |
| import scipy.io.wavfile | |
| from fastapi import FastAPI, HTTPException, Form | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from transformers import pipeline, AutoProcessor, AudioGenForConditionalGeneration | |
| from supabase import create_client, Client | |
app = FastAPI()

# --- CORS Configuration ---
# Accept cross-origin calls from any origin, with any HTTP method and header.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; confirm this is intended for this worker.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# --- Supabase Configuration ---
# Every setting may be overridden through the environment; the literals below
# are the fallbacks for this deployment.
# NOTE(review): a Supabase key fallback is committed in source; consider
# requiring SUPABASE_KEY to come from the environment instead.
SUPABASE_URL = os.getenv("SUPABASE_URL", "https://tladrluezsmmhjbhupgb.supabase.co")
SUPABASE_KEY = os.getenv("SUPABASE_KEY", "sb_publishable_zb8TGeURLnafHWDffG9DMg_PtFO_kmv")

# Identity this worker publishes in the server_status table.
SERVER_ID = os.getenv("SERVER_ID", "efectos-worker")
SERVER_URL = os.getenv("SERVER_URL", "https://carley1234-efectos.hf.space")
SERVICE_TYPE = "effect"

# Shared client used for all status and job-queue writes.
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)
# --- Model Loading ---
# Inference always runs on the CPU.
device = "cpu"
model_id = "facebook/audiogen-medium"

# Mutable module state shared by load_models(), update_status() and the
# request handler:
audio_pipe = None      # text-to-audio pipeline, set once loading succeeds
load_error = None      # message from the last failed load, else None
is_processing = False  # True while the worker reports itself "busy"
def load_models():
    """Load the AudioGen checkpoint and wrap it in a text-to-audio pipeline.

    On success the module-level ``audio_pipe`` is set and ``load_error`` is
    cleared. Any exception is caught: its text is stored in ``load_error`` and
    printed rather than propagated, so request handlers can report it later.
    """
    global audio_pipe, load_error
    try:
        # Cap torch threads before the weights load to limit CPU/memory spikes.
        torch.set_num_threads(1)
        print(f"Loading model {model_id} via explicit classes...")
        # Load processor and model explicitly (instead of letting pipeline()
        # resolve them) to avoid 'Unrecognized model' errors.
        # NOTE(review): facebook/audiogen-medium is published as an audiocraft
        # checkpoint; confirm the installed transformers release really
        # provides AudioGenForConditionalGeneration, otherwise the module-level
        # import fails before this function is ever reached.
        loaded_processor = AutoProcessor.from_pretrained(model_id)
        loaded_model = AudioGenForConditionalGeneration.from_pretrained(model_id)
        audio_pipe = pipeline(
            "text-to-audio",
            model=loaded_model,
            tokenizer=loaded_processor,
            device=device,
        )
        print("Model loaded successfully.")
        load_error = None
    except Exception as exc:
        load_error = str(exc)
        print(f"Error loading model: {exc}")
async def update_status(status: "str | None" = None):
    """Publish this worker's availability to the ``server_status`` table.

    Args:
        status: ``"busy"`` marks the worker busy; any other non-empty value
            marks it free. ``None`` (or ``""``) keeps the current state and
            only refreshes ``last_heartbeat``.

    Any exception while building or sending the upsert is printed and
    swallowed. The report is best-effort and must not break a request or
    the heartbeat loop.
    """
    global is_processing
    try:
        if status:
            is_processing = status == "busy"
        current_status = "busy" if is_processing else "free"
        data = {
            "id": SERVER_ID,
            "url": SERVER_URL,
            "status": current_status,
            "service_type": SERVICE_TYPE,
            "last_heartbeat": datetime.datetime.now(datetime.timezone.utc).isoformat(),
        }
        supabase.table("server_status").upsert(data).execute()
    except Exception as e:
        print(f"Error updating status: {e}")
async def heartbeat_loop(interval: float = 20):
    """Refresh this worker's ``server_status`` row forever.

    Args:
        interval: Seconds to sleep between heartbeats (default 20, the
            original hard-coded period).

    Never returns. ``update_status`` prints and swallows its own errors, so a
    failed heartbeat does not end the loop.
    """
    while True:
        await update_status()
        await asyncio.sleep(interval)
async def startup_event():
    """Start model loading and the heartbeat without delaying server startup.

    ``load_models`` runs in a worker thread so startup is not blocked by the
    model download. The worker announces itself as "free" and then begins the
    periodic heartbeat.
    """
    # The event loop keeps only weak references to tasks, so hold strong ones
    # or the long-running heartbeat may be garbage-collected mid-execution.
    loader = asyncio.create_task(asyncio.to_thread(load_models))
    _background_tasks.add(loader)
    loader.add_done_callback(_background_tasks.discard)

    await update_status("free")

    heartbeat = asyncio.create_task(heartbeat_loop())
    _background_tasks.add(heartbeat)
    heartbeat.add_done_callback(_background_tasks.discard)


# Strong references to the fire-and-forget tasks started by startup_event;
# each task removes itself when it finishes.
_background_tasks = set()
async def root():
    """Liveness response: report that the effects worker process is up."""
    return dict(message="VidSpri Effects Worker is running", status="ok")
async def generate_effect(job_id: str, prompt: str = Form(...), duration: int = Form(3)):
    """Generate a short sound effect for ``prompt`` and return it as base64 WAV.

    The worker reports itself "busy" while the request runs. The matching
    ``processing_queue`` row moves through "processing" and then
    "completed" or "failed".

    Args:
        job_id: Id of the ``processing_queue`` row tracking this request.
        prompt: Text description of the sound to generate.
        duration: Requested length in seconds, clamped to 1..5.

    Returns:
        ``{"status": "success", "audio": <base64 of a 16-bit mono WAV>}``.

    Raises:
        HTTPException: status 500 on any failure (model not loaded,
            generation error, empty output, Supabase error).
    """
    await update_status("busy")
    try:
        # Kept inside the try: if this Supabase call fails, the worker must
        # still be marked free and the job failed, not left "busy" forever.
        _set_job_status(job_id, "processing")

        pipe = audio_pipe
        if pipe is None:
            msg = (
                f"Model pipeline not loaded. Error during startup: {load_error}"
                if load_error
                else "Model is still starting up..."
            )
            raise RuntimeError(msg)

        # Clamp to at least one second: a zero/negative duration would ask the
        # model for max_new_tokens <= 0.
        seconds = max(1, min(int(duration), _MAX_EFFECT_SECONDS))
        max_tokens = seconds * _TOKENS_PER_SECOND

        def run_inference():
            # Runs in a worker thread so the event loop (heartbeats) stays live.
            with torch.no_grad():
                torch.set_num_threads(1)
                return pipe(
                    prompt,
                    generate_kwargs={
                        "max_new_tokens": max_tokens,
                        "do_sample": True,
                        "temperature": 1.0,
                        "top_k": 250,
                        "top_p": 0.99,
                        "guidance_scale": 3.0,
                    },
                )

        result = await asyncio.to_thread(run_inference)
        sampling_rate = result["sampling_rate"]
        pcm = _postprocess_effect_audio(result["audio"], sampling_rate)
        audio_base64 = _encode_wav_base64(pcm, sampling_rate)

        _set_job_status(job_id, "completed")
    except Exception as e:
        print(f"Generation error: {e}")
        await update_status("free")
        try:
            _set_job_status(job_id, "failed")
        except Exception as status_err:
            # Best effort: never let this replace the original error.
            print(f"Error marking job {job_id} as failed: {status_err}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    await update_status("free")
    return {"status": "success", "audio": audio_base64}


# AudioGen produces roughly 50 tokens per second of audio.
_TOKENS_PER_SECOND = 50
# Longest clip generate_effect will produce (250 tokens).
_MAX_EFFECT_SECONDS = 5
# Length of the linear fade-out applied to the end of every clip.
_FADE_OUT_SECONDS = 0.2


def _set_job_status(job_id, status):
    """Write ``status`` to the ``processing_queue`` row whose id is ``job_id``."""
    supabase.table("processing_queue").update({"status": status}).eq("id", job_id).execute()


def _postprocess_effect_audio(audio_data, sampling_rate):
    """Clean raw model output and convert it to mono 16-bit PCM samples.

    Steps, in order: replace NaN/inf, remove the DC offset, tanh soft-clip,
    reduce to 1-D, fade out the tail, peak-normalize to 0.9, scale to int16.

    Raises:
        ValueError: if the model returned no samples.
    """
    if isinstance(audio_data, torch.Tensor):
        audio_data = audio_data.cpu().numpy()
    audio_data = np.nan_to_num(audio_data)
    if audio_data.size == 0:
        raise ValueError("Model returned an empty audio clip")

    # Remove DC offset to eliminate the "click" and constant hum.
    audio_data = audio_data - np.mean(audio_data)
    # Soft-clip so saturated peaks round off instead of producing artifacts.
    audio_data = np.tanh(audio_data * 1.2)

    # Reduce to 1-D: first item of a 3-D array, then average over axis 0 of a
    # 2-D array (presumably the channel axis — confirm against pipeline output).
    if audio_data.ndim == 3:
        audio_data = audio_data[0]
    if audio_data.ndim == 2:
        audio_data = np.mean(audio_data, axis=0)
    audio_data = audio_data.flatten()

    # Linear fade-out; the fade_len > 0 guard matters because
    # audio_data[-0:] would select the whole array.
    fade_len = int(sampling_rate * _FADE_OUT_SECONDS)
    if 0 < fade_len < len(audio_data):
        audio_data[-fade_len:] *= np.linspace(1.0, 0.0, fade_len)

    # Peak-normalize with headroom, then convert to 16-bit PCM with a clamp.
    peak = np.abs(audio_data).max()
    if peak > 0:
        audio_data = audio_data / (peak + 1e-6) * 0.9
    return np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)


def _encode_wav_base64(pcm, sampling_rate):
    """Write ``pcm`` as an in-memory WAV file and return it base64-encoded (str)."""
    wav_buf = io.BytesIO()
    scipy.io.wavfile.write(wav_buf, rate=sampling_rate, data=pcm)
    return base64.b64encode(wav_buf.getvalue()).decode("utf-8")