| import os |
| import uuid |
| import threading |
| import queue |
| from typing import Dict, Optional |
|
|
| import torch |
| import soundfile as sf |
| from fastapi import FastAPI, BackgroundTasks, HTTPException |
| from fastapi.responses import FileResponse |
| from diffusers import DiffusionPipeline |
| from huggingface_hub import login |
|
|
| |
|
|
# ASGI application object; the route handlers in this module register on it.
app = FastAPI()

# Directory that receives the rendered WAV files; created at import time.
OUTPUT_DIR = "outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Use the GPU when torch can see one; half precision is only selected on CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Holds the loaded diffusion pipeline; stays None until load_pipeline() runs.
pipe = None

# Authenticate against the Hugging Face Hub when a token is provided.
# A missing HF_TOKEN no longer aborts the import with a bare KeyError, so a
# machine with previously cached credentials can still start the service.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
else:
    print("HF_TOKEN is not set; skipping Hugging Face login")
|
|
| |
|
|
def load_pipeline():
    """Load the Stable Audio Open pipeline into the module-level ``pipe`` once.

    Returns immediately when ``pipe`` is already set. Otherwise the pipeline
    is instantiated with ``DTYPE``, attention slicing is enabled and it is
    moved to ``DEVICE``.

    The global is assigned only after every setup step has succeeded. If any
    step raises, ``pipe`` stays ``None`` and the next call retries from
    scratch instead of reusing a half-configured pipeline.
    """
    global pipe
    if pipe is not None:
        return

    print("Loading Stable Audio Open pipeline...")
    loaded = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-audio-open-1.0",
        torch_dtype=DTYPE,
        # NOTE(review): trust_remote_code permits running code shipped with
        # the model repository -- confirm it is actually required here.
        trust_remote_code=True,
        revision="main",
        low_cpu_mem_usage=False,
    )
    loaded.enable_attention_slicing()
    loaded.to(DEVICE)

    # Publish only the fully configured pipeline.
    pipe = loaded
    print("Pipeline loaded on", DEVICE)
|
|
@app.on_event("startup")
def startup():
    """Run the application's startup steps in order.

    The pipeline is loaded first, then the background worker is started.
    """
    for step in (load_pipeline, start_worker):
        step()
|
|
| |
|
|
class Job:
    """State of a single audio-generation request.

    A new job starts in the ``"queued"`` status with zero progress, no output
    file and no error, and is identified by a random 32-character hex id.
    """

    def __init__(self, prompt: str, duration: float):
        # Request parameters supplied by the caller.
        self.prompt = prompt
        self.duration = duration

        # Random identifier used to look the job up later.
        self.id = uuid.uuid4().hex

        # Lifecycle tracking, updated as the job is processed.
        self.status = "queued"
        self.progress = 0

        # Outcome fields; both stay None until a result or failure is recorded.
        self.filepath: Optional[str] = None
        self.error: Optional[str] = None
|
|
# Registry of every submitted job, keyed by job id.
jobs: Dict[str, Job] = {}

# FIFO of jobs waiting to be processed. The annotation is a string because
# module-level annotations are evaluated at import time and ``queue.Queue``
# is not subscriptable on older interpreters.
job_queue: "queue.Queue[Job]" = queue.Queue()
|
|
| |
|
|
def _to_soundfile_array(audio):
    """Convert one generated waveform into a clipped float32 array."""
    if isinstance(audio, torch.Tensor):
        audio = audio.detach().cpu().numpy()

    # soundfile writes (frames, channels); a 2-D array whose first axis is the
    # shorter one is treated as (channels, frames) and transposed.
    if audio.ndim == 2 and audio.shape[0] < audio.shape[1]:
        audio = audio.T

    return audio.astype("float32").clip(-1.0, 1.0)


def worker_loop():
    """Process queued jobs forever, one at a time.

    For each job taken from ``job_queue`` the pipeline is run with the job's
    prompt and duration, the first generated waveform is written to
    ``OUTPUT_DIR/<job id>.wav``, and the job is marked ``"done"``. Any
    exception marks the job ``"error"`` and records a message on it; the loop
    then continues with the next job. ``task_done()`` is always called.
    """
    while True:
        job: Job = job_queue.get()
        try:
            job.status = "running"
            job.progress = 5

            load_pipeline()

            with torch.no_grad():
                output = pipe(
                    prompt=job.prompt,
                    # The requested length must be passed to the call itself;
                    # setting an attribute on the pipeline object has no
                    # effect on generation.
                    audio_end_in_s=float(job.duration),
                    guidance_scale=7.5,
                    num_inference_steps=150,
                )
            job.progress = 90

            audio = _to_soundfile_array(output.audios[0])

            # Prefer the sample rate reported by the model's autoencoder and
            # fall back to the previously hard-coded 44100 Hz.
            sample_rate = int(
                getattr(getattr(pipe, "vae", None), "sampling_rate", 44100)
            )

            path = os.path.join(OUTPUT_DIR, f"{job.id}.wav")
            sf.write(path, audio, samplerate=sample_rate)

            job.filepath = path
            job.progress = 100
            job.status = "done"
        except Exception as e:
            job.status = "error"
            # Some exceptions carry no message; never report an empty error.
            job.error = str(e) or repr(e)
            print(f"Job {job.id} failed: {e!r}")
        finally:
            job_queue.task_done()
|
|
def start_worker():
    """Start ``worker_loop`` on a daemon thread."""
    threading.Thread(target=worker_loop, daemon=True).start()
|
|
| |
|
|
@app.post("/generate")
def generate(prompt: str, duration: float = 10.0):
    """Queue a new audio-generation job and return its id.

    Args:
        prompt: Text description of the audio to generate.
        duration: Requested length in seconds.

    Returns:
        A dict with the new ``job_id`` and its initial ``status``.

    Raises:
        HTTPException: 422 when the prompt is blank or the duration is not a
            positive, finite number.
    """
    if not prompt.strip():
        raise HTTPException(status_code=422, detail="prompt must not be empty")

    # The chained comparison is False for NaN, infinities, zero and negatives.
    if not (0 < duration < float("inf")):
        raise HTTPException(
            status_code=422,
            detail="duration must be a positive, finite number of seconds",
        )

    job = Job(prompt=prompt, duration=duration)
    jobs[job.id] = job
    job_queue.put(job)

    return {
        "job_id": job.id,
        "status": job.status,
    }
|
|
@app.get("/status/{job_id}")
def status(job_id: str):
    """Report the current state of a job.

    Returns a dict with the job's id, status, progress, error and a ``ready``
    flag that is true once the status is ``"done"``. Responds with 404 when
    the id is unknown.
    """
    found = jobs.get(job_id)
    if found:
        return {
            "job_id": found.id,
            "status": found.status,
            "progress": found.progress,
            "error": found.error,
            "ready": found.status == "done",
        }

    raise HTTPException(status_code=404, detail="Job not found")
|
|
def cleanup_file(path: str):
    """Delete ``path`` on a best-effort basis.

    A file that is already gone is ignored silently. Any other failure is
    reported on stdout instead of being swallowed, but is never propagated.

    Args:
        path: Filesystem path of the file to remove.
    """
    try:
        os.remove(path)
    except FileNotFoundError:
        # Already removed; nothing left to do.
        return
    except Exception as exc:
        # Still best-effort, but leave a trace so leftover files are noticed.
        print(f"Could not remove {path!r}: {exc!r}")
|
|
@app.get("/download/{job_id}")
def download(job_id: str, background_tasks: BackgroundTasks):
    """Send the WAV file of a finished job, then schedule its deletion.

    Args:
        job_id: Identifier of the job whose audio is requested.
        background_tasks: Used to schedule removal of the file.

    Returns:
        A ``FileResponse`` for the generated WAV file.

    Raises:
        HTTPException: 404 for an unknown job, 400 while the job has no
            finished file, 410 when the file is no longer on disk.
    """
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    if job.status != "done" or not job.filepath:
        raise HTTPException(status_code=400, detail="File not ready")

    # The file is deleted after a download while the job stays "done"; check
    # for it here so a repeat request gets a clear answer instead of a
    # server-side failure while sending a missing file.
    if not os.path.isfile(job.filepath):
        raise HTTPException(status_code=410, detail="File no longer available")

    background_tasks.add_task(cleanup_file, job.filepath)

    return FileResponse(
        path=job.filepath,
        media_type="audio/wav",
        filename=os.path.basename(job.filepath),
    )
|
|