ao / app.py
fomext's picture
Update app.py
0be8d78 verified
Raw
History Blame Contribute Delete
4.31 kB
import os
import uuid
import threading
import queue
from typing import Dict, Optional
import torch
import soundfile as sf
from fastapi import FastAPI, BackgroundTasks, HTTPException
from fastapi.responses import FileResponse
from diffusers import DiffusionPipeline
from huggingface_hub import login
# -------------------- App --------------------
# FastAPI application; the HTTP routes are registered further down this module.
app = FastAPI()

# Generated WAV files are written here (relative to the working directory);
# /download schedules their deletion once served.
OUTPUT_DIR = "outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Prefer the GPU with half-precision weights; fall back to full precision on CPU.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.float16
else:
    DEVICE, DTYPE = "cpu", torch.float32

# Shared diffusion pipeline, created lazily by load_pipeline().
pipe = None
# -------------------- HF Login --------------------
# Authenticate only when a token is provided. Indexing os.environ directly would
# abort the whole app at import time with an opaque KeyError; without a token,
# huggingface_hub falls back to any cached credentials (e.g. `huggingface-cli login`).
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
else:
    print("HF_TOKEN is not set; relying on cached Hugging Face credentials, if any.")
# -------------------- Pipeline --------------------
def load_pipeline():
    """Create the shared Stable Audio Open pipeline on first use.

    The pipeline is stored in the module-level ``pipe`` global, loaded with
    ``DTYPE`` weights, has attention slicing enabled and is moved to ``DEVICE``.
    Subsequent calls return immediately without doing anything.
    """
    global pipe
    if pipe is not None:
        return  # already initialised

    print("Loading Stable Audio Open pipeline...")
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-audio-open-1.0",
        revision="main",
        torch_dtype=DTYPE,
        low_cpu_mem_usage=False,
        # NOTE(review): allows executing code shipped in the model repo; confirm
        # this is still needed with the installed diffusers version.
        trust_remote_code=True,
    )
    pipe.enable_attention_slicing()  # trade some speed for lower peak memory
    pipe.to(DEVICE)
    print("Pipeline loaded on", DEVICE)
@app.on_event("startup")
def startup():
    """Warm up the model, then launch the background worker, at server boot.

    Loading happens first so the worker's own ``load_pipeline()`` call is a
    no-op and the first queued job does not pay the model-loading cost.
    NOTE(review): newer FastAPI versions prefer ``lifespan`` handlers over
    ``on_event``.
    """
    for step in (load_pipeline, start_worker):
        step()
# -------------------- Job State --------------------
class Job:
    """Mutable record of one text-to-audio generation request.

    Created by the ``/generate`` endpoint, updated in place by the worker
    thread, and read back by ``/status`` and ``/download``.
    """

    def __init__(self, prompt: str, duration: float):
        # Identity and request parameters.
        self.id = uuid.uuid4().hex  # 32-char hex id; also the stem of the output WAV name
        self.prompt = prompt
        self.duration = duration  # requested clip length in seconds
        # Lifecycle: "queued" -> "running" -> "done" | "error".
        self.status = "queued"
        self.progress = 0  # coarse percentage reported by /status
        # Filled in by the worker once the job finishes (successfully or not).
        self.filepath: Optional[str] = None
        self.error: Optional[str] = None
# Every submitted job, keyed by Job.id (nothing in this module removes entries).
jobs: Dict[str, Job] = dict()
# FIFO of jobs waiting for the single worker thread; queue.Queue is thread-safe.
job_queue: queue.Queue[Job] = queue.Queue()
# -------------------- Worker Thread --------------------
def worker_loop():
    """Process queued generation jobs forever, one at a time.

    Runs on the daemon thread started by ``start_worker``. Each job taken from
    ``job_queue`` is rendered with the shared pipeline and written to
    ``OUTPUT_DIR/<job.id>.wav``. Its ``status``/``progress``/``filepath``/
    ``error`` fields are updated in place for ``/status`` and ``/download``.
    A failing job is marked ``"error"``; the loop itself keeps running.
    """
    import traceback

    while True:
        job: Job = job_queue.get()
        try:
            _run_job(job)
        except Exception as e:
            # Report the failure through the job record rather than killing the
            # worker thread, but keep the traceback in the server log.
            traceback.print_exc()
            job.status = "error"
            job.error = str(e)
        finally:
            job_queue.task_done()


def _run_job(job: "Job") -> None:
    """Generate audio for ``job`` and save it as a WAV file, updating its state."""
    job.status = "running"
    job.progress = 5
    load_pipeline()
    with torch.no_grad():
        output = pipe(
            prompt=job.prompt,
            # The clip length must be passed to the call. Setting an attribute on
            # the pipeline object is ignored, so the requested duration would be
            # silently replaced by the model's default (maximum) length.
            audio_end_in_s=float(job.duration),
            guidance_scale=7.5,
            num_inference_steps=150,
        )
    job.progress = 90
    audio = _to_wav_array(output.audios[0])
    path = os.path.join(OUTPUT_DIR, f"{job.id}.wav")
    # NOTE(review): 44.1 kHz is presumably the model's native rate; confirm
    # against the pipeline's VAE sampling rate if the model changes.
    sf.write(path, audio, samplerate=44100)
    job.filepath = path
    job.progress = 100
    job.status = "done"


def _to_wav_array(audio):
    """Convert one generated waveform into a float32 (frames, channels) array in [-1, 1]."""
    if isinstance(audio, torch.Tensor):
        audio = audio.detach().cpu().numpy()
    # soundfile wants (frames, channels); the pipeline output is presumably
    # (channels, samples), so transpose when the first axis is the shorter one.
    if audio.ndim == 2 and audio.shape[0] < audio.shape[1]:
        audio = audio.T
    return audio.astype("float32").clip(-1.0, 1.0)
def start_worker():
    """Run ``worker_loop`` on a new background thread.

    The thread is a daemon, so it does not keep the process alive at shutdown.
    """
    threading.Thread(target=worker_loop, daemon=True).start()
# -------------------- API --------------------
@app.post("/generate")
def generate(prompt: str, duration: float = 10.0):
    """Queue a text-to-audio job and return its id immediately.

    Generation happens on the worker thread. Poll ``/status/{job_id}`` and
    fetch the result from ``/download/{job_id}``.

    Args:
        prompt: Text description of the audio to generate.
        duration: Requested clip length in seconds.

    Returns:
        ``{"job_id": ..., "status": ...}`` describing the newly queued job.

    Raises:
        HTTPException: 400 if ``duration`` is not a positive, finite number.
    """
    # Reject impossible lengths up front instead of queuing a job that can only
    # fail once it reaches the pipeline. NaN fails both comparisons.
    if not 0 < duration < float("inf"):
        raise HTTPException(
            status_code=400,
            detail="duration must be a positive, finite number of seconds",
        )
    job = Job(prompt=prompt, duration=duration)
    jobs[job.id] = job
    # Capture the status before enqueueing; the worker may already have
    # switched the job to "running" by the time we build the response.
    initial_status = job.status
    job_queue.put(job)
    return {
        "job_id": job.id,
        "status": initial_status,
    }
@app.get("/status/{job_id}")
def status(job_id: str):
    """Report the current state of a job.

    Raises a 404 HTTPException for an unknown id. ``ready`` is true exactly
    when the job finished successfully and its file can be downloaded.
    """
    job = jobs.get(job_id)
    if job is None:
        raise HTTPException(status_code=404, detail="Job not found")
    is_ready = job.status == "done"
    return dict(
        job_id=job.id,
        status=job.status,
        progress=job.progress,
        error=job.error,
        ready=is_ready,
    )
def cleanup_file(path: str):
    """Best-effort removal of a generated file after it has been served.

    Never raises for filesystem problems. A file that is already gone is
    ignored silently; any other OSError is logged so it does not go unnoticed.
    """
    try:
        os.remove(path)
    except FileNotFoundError:
        # Already removed (e.g. by an earlier download); nothing to do.
        pass
    except OSError as exc:
        # Still best-effort, but leave a trace instead of hiding the failure.
        print(f"Could not remove {path}: {exc}")
@app.get("/download/{job_id}")
def download(job_id: str, background_tasks: BackgroundTasks):
    """Serve a finished job's WAV file and delete it once the response is sent.

    Raises:
        HTTPException: 404 for an unknown job id; 400 if the job has not
            finished successfully; 410 if the file has already been removed
            (e.g. by an earlier download of the same job).
    """
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    if job.status != "done" or not job.filepath:
        raise HTTPException(status_code=400, detail="File not ready")
    # The file is deleted after the first download, but the job stays "done".
    # Without this check a repeat request would make FileResponse fail on the
    # missing path and surface as a 500.
    if not os.path.isfile(job.filepath):
        raise HTTPException(status_code=410, detail="File no longer available")
    # Background tasks run after the response has been sent.
    background_tasks.add_task(cleanup_file, job.filepath)
    return FileResponse(
        path=job.filepath,
        media_type="audio/wav",
        filename=os.path.basename(job.filepath),
    )