File size: 6,297 Bytes
b642640
 
 
 
 
 
8f336f4
b642640
90900a9
b642640
f9e7461
b642640
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce84bf1
b642640
 
 
 
 
 
8adfead
f9e7461
a818820
b642640
 
a7c8314
f9e7461
a7c8314
f9e7461
a434ec5
f9e7461
14825a2
f9e7461
 
 
 
 
 
426a631
a7c8314
a818820
a7c8314
a818820
a7c8314
 
b642640
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7c8314
 
b642640
 
 
 
 
 
 
 
660552c
b642640
 
 
 
f9e7461
 
a818820
b642640
f9e7461
 
 
 
8727b92
 
c8ec810
f9e7461
 
 
 
 
 
 
 
 
 
a0240fe
8727b92
f9e7461
 
 
 
 
 
 
 
 
8f336f4
a7c8314
a0240fe
8f336f4
a7c8314
 
 
 
f9e7461
8adfead
 
a7c8314
 
a0240fe
a7c8314
 
a0240fe
 
a7c8314
 
f9e7461
a7c8314
 
 
 
 
f9e7461
9fae321
 
a7c8314
9fae321
f9e7461
a7c8314
9fae321
b642640
8f336f4
b642640
 
 
 
 
 
 
 
 
8727b92
b642640
 
a818820
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import asyncio
import base64
import datetime
import io
import os
import threading

import numpy as np
import scipy.io.wavfile
import torch
from fastapi import FastAPI, HTTPException, Form
from fastapi.middleware.cors import CORSMiddleware
from supabase import create_client, Client
from transformers import pipeline, AutoProcessor, AudioGenForConditionalGeneration

# ASGI application object; the startup hook and the routes below register on it.
app = FastAPI()

# --- CORS Configuration ---
# Cross-origin requests are accepted from any origin, with any method and
# any header.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive policy -- confirm whether the origins should instead be
# pinned to the known frontend(s).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# --- Supabase Configuration ---
# Each setting is read from the environment and falls back to the literal
# default when the variable is unset.
# NOTE(review): the fallbacks embed a concrete project URL and API key in
# source. The key's prefix suggests it is a publishable (client-side) key --
# confirm, and consider requiring the environment variables instead.
SUPABASE_URL = os.environ.get("SUPABASE_URL", "https://tladrluezsmmhjbhupgb.supabase.co")
SUPABASE_KEY = os.environ.get("SUPABASE_KEY", "sb_publishable_zb8TGeURLnafHWDffG9DMg_PtFO_kmv")
# Identity and public URL this worker writes into the "server_status" table.
SERVER_ID = os.environ.get("SERVER_ID", "efectos-worker")
SERVER_URL = os.environ.get("SERVER_URL", "https://carley1234-efectos.hf.space")
# Written to the "service_type" column of every status row.
SERVICE_TYPE = "effect"

# Shared client, created once at import time and used by the status
# reporter and the generation endpoint.
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)

# --- Model Loading ---
# Device passed to the text-to-audio pipeline.
device = "cpu"
# Checkpoint identifier handed to from_pretrained() in load_models().
model_id = "facebook/audiogen-medium"
# Text-to-audio pipeline; stays None until load_models() succeeds.
audio_pipe = None
# str() of the exception raised by the last failed load_models() call,
# or None when no load has failed.
load_error = None
# True while the worker reports itself as "busy"; toggled by update_status().
is_processing = False

def load_models():
    """Build the text-to-audio pipeline and publish it in ``audio_pipe``.

    On success ``audio_pipe`` holds the pipeline and ``load_error`` is reset
    to ``None``. On any failure the exception is not propagated: its text is
    stored in ``load_error`` and printed, and ``audio_pipe`` is left as it was.
    """
    global audio_pipe, load_error
    try:
        # Cap CPU threads before any weights are loaded to avoid memory/CPU spikes.
        torch.set_num_threads(1)
        print(f"Loading model {model_id} via explicit classes...")

        # The processor and model classes are instantiated by hand because
        # letting pipeline() resolve them led to 'Unrecognized model' errors.
        text_processor = AutoProcessor.from_pretrained(model_id)
        generator = AudioGenForConditionalGeneration.from_pretrained(model_id)

        # The pipeline wrapper is kept only for its convenient call interface.
        audio_pipe = pipeline(
            "text-to-audio",
            model=generator,
            tokenizer=text_processor,
            device=device,
        )

        print("Model loaded successfully.")
        load_error = None
    except Exception as exc:
        load_error = str(exc)
        print(f"Error loading model: {exc}")

# Serializes status writes. Each write runs in a worker thread, so without
# this lock a slow heartbeat carrying an outdated status could reach the
# database after a newer "busy"/"free" write and overwrite it.
_status_write_lock = threading.Lock()


def _push_status() -> None:
    """Upsert this worker's row in ``server_status`` (blocking network call).

    The row is built while holding the lock, so every write carries the
    value of ``is_processing`` current at the moment it is sent.
    """
    with _status_write_lock:
        data = {
            "id": SERVER_ID,
            "url": SERVER_URL,
            "status": "busy" if is_processing else "free",
            "service_type": SERVICE_TYPE,
            "last_heartbeat": datetime.datetime.now(datetime.timezone.utc).isoformat()
        }
        supabase.table("server_status").upsert(data).execute()


async def update_status(status: str = None):
    """Record the worker's busy/free state and publish it as a heartbeat.

    Args:
        status: ``"busy"`` marks the worker as processing; any other
            non-empty string marks it as free. ``None`` (or an empty
            string) leaves the flag untouched and only refreshes the
            heartbeat row.

    Failures are printed and swallowed: status reporting is best-effort and
    never raises to the caller.
    """
    global is_processing
    try:
        if status:
            is_processing = (status == "busy")
        # The Supabase client used here is synchronous; calling it directly
        # would stall the event loop for the whole HTTP round-trip.
        await asyncio.to_thread(_push_status)
    except Exception as e:
        print(f"Error updating status: {e}")

async def heartbeat_loop(interval_seconds: float = 20):
    """Refresh this worker's status row forever, pausing between reports.

    Args:
        interval_seconds: Delay between consecutive reports. Defaults to
            20, the value that used to be hard-coded.

    The loop has no exit condition of its own; it runs until the
    surrounding task is cancelled or one of the awaited calls raises.
    """
    while True:
        # No argument: keep the current busy/free flag, only refresh the row.
        await update_status()
        await asyncio.sleep(interval_seconds)

# Strong references to fire-and-forget tasks. The event loop itself keeps
# only weak references, so a task nobody else references may be
# garbage-collected before it finishes.
_background_tasks = set()


def _spawn_background(coro):
    """Schedule *coro* as a task and keep it referenced until it completes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    # Drop the reference once the task is done so the set does not grow.
    task.add_done_callback(_background_tasks.discard)
    return task


@app.on_event("startup")
async def startup_event():
    """Start model loading and the heartbeat when the application boots.

    Model loading runs in a worker thread and is not awaited, so the
    server starts accepting requests before the model is ready. The worker
    is announced as "free" before the periodic heartbeat task starts.
    """
    # Load models in background to avoid startup timeouts
    _spawn_background(asyncio.to_thread(load_models))
    await update_status("free")
    _spawn_background(heartbeat_loop())

@app.get("/")
async def root():
    # Static response showing that the service is up. It does not reflect
    # whether the model has finished loading or whether a job is running.
    payload = {"message": "VidSpri Effects Worker is running", "status": "ok"}
    return payload

def _set_job_status(job_id: str, status: str) -> None:
    """Write *status* to the ``processing_queue`` row with id *job_id* (blocking call)."""
    supabase.table("processing_queue").update({"status": status}).eq("id", job_id).execute()


def _audio_to_wav_base64(audio_data, sampling_rate) -> str:
    """Clean up generated audio and return it as a base64-encoded 16-bit mono WAV.

    Args:
        audio_data: Generated samples as a ``torch.Tensor`` or array-like.
            A 3-D input is reduced to its first entry, and a 2-D input is
            averaged over axis 0 (assumed to be the channel axis -- TODO
            confirm against the pipeline's output layout).
        sampling_rate: Sample rate written into the WAV header.

    Returns:
        The WAV file contents, base64-encoded as a ``str``.
    """
    if isinstance(audio_data, torch.Tensor):
        audio_data = audio_data.cpu().numpy()

    # Replace NaN/inf so later arithmetic cannot poison the whole clip.
    audio_data = np.nan_to_num(audio_data)

    # Remove DC offset to eliminate "click" and constant hum
    if audio_data.size > 0:
        audio_data = audio_data - np.mean(audio_data)

    # Soft-clipping to prevent digital artifacts on saturation
    audio_data = np.tanh(audio_data * 1.2)

    # Standardize shape to a flat mono signal
    if audio_data.ndim == 3:
        audio_data = audio_data[0]
    if audio_data.ndim == 2:
        audio_data = np.mean(audio_data, axis=0)
    audio_data = audio_data.flatten()

    # Fade out end of clip (0.2s for effects). A zero-length fade is skipped:
    # the slice [-0:] would select the whole clip instead of nothing.
    fade_len = int(sampling_rate * 0.2)
    if 0 < fade_len < len(audio_data):
        fade_window = np.linspace(1.0, 0.0, fade_len)
        audio_data[-fade_len:] *= fade_window

    # Normalize audio with headroom; an empty clip has no peak to measure.
    max_val = np.abs(audio_data).max() if audio_data.size > 0 else 0.0
    if max_val > 0:
        audio_data = (audio_data / (max_val + 1e-6)) * 0.9

    # Convert to 16-bit PCM with safety clamp
    audio_data = np.clip(audio_data * 32767, -32768, 32767).astype(np.int16)

    wav_buf = io.BytesIO()
    scipy.io.wavfile.write(wav_buf, rate=sampling_rate, data=audio_data)
    wav_buf.seek(0)
    return base64.b64encode(wav_buf.read()).decode('utf-8')


@app.post("/generate/{job_id}")
async def generate_effect(job_id: str, prompt: str = Form(...), duration: int = Form(3)):
    """Generate a sound effect for a queued job and return it as base64 WAV.

    Args:
        job_id: Id of the ``processing_queue`` row tracking this job.
        prompt: Text description of the sound to generate.
        duration: Requested length in seconds, clamped to the range 1-5.

    Returns:
        ``{"status": "success", "audio": <base64 WAV>}``.

    Raises:
        HTTPException: Status 500 with the error text as detail when the
            model is not loaded or any step fails. The job is marked
            "failed" and the worker is reported "free" first.
    """
    try:
        # Everything that can fail sits inside the try block, so the worker
        # is always reported "free" again afterwards. Previously a failing
        # queue update left it advertised as "busy" indefinitely.
        await update_status("busy")
        # The Supabase client is synchronous: run it off the event loop.
        await asyncio.to_thread(_set_job_status, job_id, "processing")

        if not audio_pipe:
            msg = f"Model pipeline not loaded. Error during startup: {load_error}" if load_error else "Model is still starting up..."
            raise Exception(msg)

        # AudioGen: 50 tokens ~ 1 second of audio. Clamp to 1-5 seconds so a
        # zero or negative duration cannot yield a non-positive token budget.
        seconds = min(max(int(duration), 1), 5)
        max_tokens = seconds * 50

        # Run inference in a separate thread to avoid blocking heartbeats
        def run_inference():
            # no_grad is entered here because it must be active in the
            # thread that actually runs the model.
            with torch.no_grad():
                torch.set_num_threads(1)
                return audio_pipe(
                    prompt,
                    generate_kwargs={
                        "max_new_tokens": max_tokens,
                        "do_sample": True,
                        "temperature": 1.0,
                        "top_k": 250,
                        "top_p": 0.99,
                        "guidance_scale": 3.0
                    }
                )

        result = await asyncio.to_thread(run_inference)

        # Convert to WAV in memory
        audio_base64 = _audio_to_wav_base64(result["audio"], result["sampling_rate"])

        await asyncio.to_thread(_set_job_status, job_id, "completed")
        await update_status("free")
        return {"status": "success", "audio": audio_base64}

    except Exception as e:
        print(f"Generation error: {e}")
        await update_status("free")
        try:
            await asyncio.to_thread(_set_job_status, job_id, "failed")
        except Exception as status_err:
            # Report the original failure to the client even when the queue
            # row cannot be updated.
            print(f"Error marking job as failed: {status_err}")
        raise HTTPException(status_code=500, detail=str(e)) from e