"""FastAPI service that queues text-to-audio jobs and serves the resulting WAVs.

Jobs are submitted via ``POST /generate``, processed one at a time by a single
background worker thread, polled via ``GET /status/{job_id}`` and fetched via
``GET /download/{job_id}``.
"""

import logging
import math
import os
import queue
import threading
import uuid
from typing import Dict, Optional

import soundfile as sf
import torch
from diffusers import DiffusionPipeline
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import FileResponse
from huggingface_hub import login

logger = logging.getLogger(__name__)

# -------------------- App --------------------
app = FastAPI()

OUTPUT_DIR = "outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Sample rate written to the WAV header when the pipeline does not expose one.
DEFAULT_SAMPLE_RATE = 44100

pipe = None  # lazy-loaded

# -------------------- HF Login --------------------
# Raises KeyError at import time when HF_TOKEN is not set.
login(token=os.environ["HF_TOKEN"])


# -------------------- Pipeline --------------------
def load_pipeline():
    """Load the audio pipeline into the module-level ``pipe`` if not yet loaded.

    The global is only assigned once the pipeline is fully configured and moved
    to ``DEVICE``, so a failure part-way through leaves ``pipe`` as ``None`` and
    the next call retries instead of reusing a half-initialised object.
    """
    global pipe
    if pipe is not None:
        return

    print("Loading Stable Audio Open pipeline...")
    loaded = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-audio-open-1.0",
        torch_dtype=DTYPE,
        trust_remote_code=True,
        revision="main",
        low_cpu_mem_usage=False,
    )
    loaded.enable_attention_slicing()
    loaded.to(DEVICE)
    pipe = loaded
    print("Pipeline loaded on", DEVICE)


@app.on_event("startup")
def startup():
    """Load the pipeline and start the background worker when the app starts."""
    load_pipeline()
    start_worker()


# -------------------- Job State --------------------
class Job:
    """Mutable record describing one generation request and its outcome."""

    def __init__(self, prompt: str, duration: float):
        self.id = uuid.uuid4().hex
        self.prompt = prompt
        self.duration = duration
        self.progress = 0
        self.status = "queued"  # queued | running | done | error
        self.filepath: Optional[str] = None
        self.error: Optional[str] = None


# NOTE(review): entries are never removed from this dict, so it grows for the
# lifetime of the process.
jobs: Dict[str, Job] = {}
job_queue: queue.Queue[Job] = queue.Queue()


# -------------------- Worker Thread --------------------
def _to_wav_array(audio):
    """Convert one generated clip to a float32 array clipped to [-1.0, 1.0].

    A 2-D input whose first axis is the shorter one is transposed, so that the
    longer axis comes first in the array handed to ``sf.write``.
    """
    if isinstance(audio, torch.Tensor):
        # Upcast before leaving torch; the result is float32 either way.
        audio = audio.detach().float().cpu().numpy()
    if audio.ndim == 2 and audio.shape[0] < audio.shape[1]:
        audio = audio.T
    return audio.astype("float32").clip(-1.0, 1.0)


def _run_job(job: Job) -> None:
    """Generate audio for ``job``, write it to ``OUTPUT_DIR`` and mark it done.

    Any exception propagates to the caller, which records it on the job.
    """
    job.status = "running"
    job.progress = 5

    load_pipeline()

    with torch.no_grad():
        output = pipe(
            prompt=job.prompt,
            # The requested length is a call argument of the pipeline; assigning
            # it as an attribute on the pipeline object is not read by the call.
            audio_end_in_s=float(job.duration),
            guidance_scale=7.5,
            num_inference_steps=150,
        )
    job.progress = 90

    audio = _to_wav_array(output.audios[0])
    path = os.path.join(OUTPUT_DIR, f"{job.id}.wav")
    sample_rate = int(
        getattr(getattr(pipe, "vae", None), "sampling_rate", DEFAULT_SAMPLE_RATE)
    )
    sf.write(path, audio, samplerate=sample_rate)

    # Status is flipped last so a reader that sees "done" also sees the path.
    job.filepath = path
    job.progress = 100
    job.status = "done"


def worker_loop():
    """Process queued jobs forever, one at a time.

    A failing job is marked ``"error"`` with the exception text and the loop
    continues with the next job.
    """
    while True:
        job: Job = job_queue.get()
        try:
            _run_job(job)
        except Exception as e:
            # Top-level boundary of the worker: keep the thread alive, but do
            # not lose the traceback.
            logger.exception("Job %s failed", job.id)
            job.status = "error"
            job.error = str(e)
        finally:
            job_queue.task_done()


def start_worker():
    """Start ``worker_loop`` in a daemon thread."""
    t = threading.Thread(target=worker_loop, daemon=True)
    t.start()


# -------------------- API --------------------
@app.post("/generate")
def generate(prompt: str, duration: float = 10.0):
    """Queue a generation job and return its id and initial status.

    Raises:
        HTTPException: 400 when ``duration`` is not a finite positive number.
    """
    if not math.isfinite(duration) or duration <= 0:
        raise HTTPException(
            status_code=400,
            detail="Duration must be a positive number of seconds",
        )

    job = Job(prompt=prompt, duration=duration)
    jobs[job.id] = job
    job_queue.put(job)
    return {
        "job_id": job.id,
        "status": job.status,
    }


@app.get("/status/{job_id}")
def status(job_id: str):
    """Return the current state of a job.

    Raises:
        HTTPException: 404 when the job id is unknown.
    """
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    return {
        "job_id": job.id,
        "status": job.status,
        "progress": job.progress,
        "error": job.error,
        "ready": job.status == "done",
    }


def cleanup_file(path: str):
    """Best-effort removal of ``path``; never raises on filesystem errors."""
    try:
        os.remove(path)
    except FileNotFoundError:
        # Already gone, nothing to clean up.
        pass
    except OSError:
        logger.warning("Could not remove %s", path, exc_info=True)


@app.get("/download/{job_id}")
def download(job_id: str, background_tasks: BackgroundTasks):
    """Send the generated WAV for a finished job and delete it afterwards.

    Raises:
        HTTPException: 404 when the job id is unknown, 400 when the job has not
            finished, 410 when the file is no longer on disk (for example
            because it was already downloaded and cleaned up).
    """
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    if job.status != "done" or not job.filepath:
        raise HTTPException(status_code=400, detail="File not ready")
    if not os.path.isfile(job.filepath):
        # The file is removed after the first successful download.
        raise HTTPException(status_code=410, detail="File no longer available")

    background_tasks.add_task(cleanup_file, job.filepath)
    return FileResponse(
        path=job.filepath,
        media_type="audio/wav",
        filename=os.path.basename(job.filepath),
    )