# mgg / app.py
# fomext's picture
# Update app.py
# 22e4707 verified
# Raw
# History Blame Contribute Delete
# 4.07 kB
from fastapi import FastAPI, BackgroundTasks, HTTPException
from fastapi.responses import FileResponse
import threading
import queue
import uuid
import os
import time
import audiocraft
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
# Report the installed audiocraft version at startup so it shows up in logs.
print(f"Audiocraft version: {audiocraft.__version__}")

app = FastAPI()

MODEL_NAME = "facebook/musicgen-small"
OUTPUT_DIR = "outputs"

# Generated audio files are written here; create the directory up front.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# The model is loaded once at import time and shared by every job.
print("Loading MusicGen model...")
model = MusicGen.get_pretrained(MODEL_NAME)

# Generation settings are fixed here and therefore apply to all requests.
model.set_generation_params(duration=10, temperature=1.0, top_k=250, top_p=0.0)
# -------------------------
# Job system (in-memory)
# -------------------------
# IDs of jobs waiting to be processed: filled by the /generate endpoint and
# drained by worker().
job_queue = queue.Queue()
# Registry of every job seen by this process. Each value is a dict holding
# "status", "prompt", "created_at" and "file_path" (plus "error" on failure).
# Nothing in the code shown here removes entries, so state lives (and grows)
# for the lifetime of the process and is lost on restart.
jobs = {} # job_id -> metadata dict
class JobStatus:
    """String constants naming the lifecycle states of a generation job.

    The values are plain strings, so they are returned unchanged in API
    responses and can be compared directly against stored job metadata.
    """

    QUEUED = "queued"  # accepted by /generate, waiting in the queue
    PROCESSING = "processing"  # picked up by the worker
    COMPLETED = "completed"  # audio written; a download URL is offered
    FAILED = "failed"  # the worker caught an exception for this job
import signal
class GenerationTimeout(Exception):
    """Raised by ``timeout_handler`` to abort a generation that ran too long."""
def timeout_handler(signum, frame):
    """Abort the interrupted code by raising ``GenerationTimeout``.

    The ``(signum, frame)`` signature matches what ``signal.signal`` expects
    of a handler; both arguments are ignored.

    NOTE(review): nothing in the visible part of this file registers this
    handler -- confirm whether it is still needed.

    Raises:
        GenerationTimeout: always.
    """
    raise GenerationTimeout
# Longest a single generation may take, in seconds of wall-clock time.
# worker() compares the elapsed time against this value after generation
# has returned and marks the job as failed when it is exceeded.
MAX_RUNTIME = 600 # 10 minutes (CPU-safe)
def worker():
    """Process queued generation jobs forever, one at a time.

    Blocks on ``job_queue`` for the next job ID, looks the job up in
    ``jobs`` and runs the shared ``model`` on its prompt. The job's metadata
    dict is updated in place:

    * ``status`` becomes ``PROCESSING`` while the job runs;
    * on success ``status`` becomes ``COMPLETED`` and ``file_path`` is set to
      the audio file that was written under ``OUTPUT_DIR``;
    * on any ``Exception`` ``status`` becomes ``FAILED`` and ``error`` holds
      a non-empty description of the failure.

    The function never returns; it is meant to be the target of a daemon
    thread. ``job_queue.task_done()`` is called exactly once per dequeued ID.
    """
    while True:
        job_id = job_queue.get()
        job = jobs.get(job_id)
        if job is None:
            # Unknown ID: nothing to run, but the queue item must still be
            # acknowledged so that job_queue.join() cannot hang.
            job_queue.task_done()
            continue

        started_at = time.time()
        # Mutate the dict we already hold instead of re-indexing ``jobs``:
        # a missing key inside the error handler would otherwise raise and
        # kill the only worker thread.
        job["status"] = JobStatus.PROCESSING
        try:
            wav = model.generate([job["prompt"]])[0]

            # NOTE: this is a post-hoc check. generate() cannot be interrupted
            # from here, so an over-long run is only detected (and its result
            # discarded) once it has finished.
            if time.time() - started_at > MAX_RUNTIME:
                raise TimeoutError("Generation exceeded max runtime")

            # audio_write appends the format suffix itself and returns the
            # path it actually wrote; use that rather than rebuilding it.
            written_path = audio_write(
                os.path.join(OUTPUT_DIR, job_id),
                wav.cpu(),
                model.sample_rate,
                strategy="loudness",
                loudness_compressor=True,
            )
            job.update({
                "status": JobStatus.COMPLETED,
                "file_path": str(written_path),
            })
        except Exception as e:
            # Top-level boundary of the worker: record the failure on the job
            # and keep the loop alive. Some exceptions stringify to "", so
            # fall back to the class name to keep the message meaningful.
            job.update({
                "status": JobStatus.FAILED,
                "error": str(e) or type(e).__name__,
            })
        finally:
            job_queue.task_done()
# Start the single background worker thread as soon as the module is imported.
# daemon=True means the thread does not keep the process alive at shutdown.
threading.Thread(target=worker, daemon=True).start()
# -------------------------
# API endpoints
# -------------------------
@app.post("/generate")
async def generate(prompt: str):
    """Register a new generation job for ``prompt`` and enqueue it.

    The job is stored in ``jobs`` in the ``QUEUED`` state before its ID is
    pushed onto ``job_queue``, so the worker always finds the metadata.

    Returns:
        A dict with the new ``job_id`` and the ``status_url`` to poll.
    """
    new_id = uuid.uuid4().hex
    record = {
        "status": JobStatus.QUEUED,
        "prompt": prompt,
        "created_at": time.time(),
        "file_path": None,
    }
    jobs[new_id] = record
    job_queue.put(new_id)

    return {"job_id": new_id, "status_url": f"/status/{new_id}"}
@app.get("/status/{job_id}")
async def status(job_id: str):
    """Report the current state of a job.

    Returns:
        A dict with ``job_id``, ``status`` and ``elapsed_seconds`` (whole
        seconds since the job was created, measured against the current
        time). Completed jobs also carry ``download_url``; failed jobs also
        carry ``error``.

    Raises:
        HTTPException: 404 when ``job_id`` is unknown.
    """
    entry = jobs.get(job_id)
    if not entry:
        raise HTTPException(status_code=404, detail="Job not found")

    state = entry["status"]
    payload = {
        "job_id": job_id,
        "status": state,
        "elapsed_seconds": int(time.time() - entry["created_at"]),
    }

    # The two terminal states each add one extra field.
    if state == JobStatus.COMPLETED:
        payload["download_url"] = f"/download/{job_id}"
    elif state == JobStatus.FAILED:
        payload["error"] = entry.get("error")

    return payload
def delete_file(path: str, delay: int = 10):
    """Delete ``path`` after waiting ``delay`` seconds.

    Used as a background task so a served file is cleaned up once the
    response has had time to be sent. A file that is already gone is not an
    error: several downloads of the same job each schedule a deletion, and
    only the first one finds the file.

    Args:
        path: File to remove.
        delay: Seconds to sleep before removing it.
    """
    time.sleep(delay)
    try:
        os.remove(path)
    except FileNotFoundError:
        # Removed in the meantime (e.g. by a concurrent cleanup task).
        # Checking os.path.exists() first would leave a window in which the
        # file can vanish between the check and the removal.
        pass
@app.get("/download/{job_id}")
async def download(job_id: str, background_tasks: BackgroundTasks):
    """Serve the generated WAV file of a completed job.

    After the response, ``delete_file`` is run as a background task to
    remove the file from disk.

    Raises:
        HTTPException: 404 when the job is unknown, 400 when it has not
            completed, 404 when its file is missing on disk.
    """
    entry = jobs.get(job_id)
    if not entry:
        raise HTTPException(status_code=404, detail="Job not found")

    if entry["status"] != JobStatus.COMPLETED:
        raise HTTPException(status_code=400, detail="Job not completed")

    wav_path = entry["file_path"]
    file_available = bool(wav_path) and os.path.exists(wav_path)
    if not file_available:
        raise HTTPException(status_code=404, detail="File not found")

    # Schedule the cleanup; it runs once the response has been handed off.
    background_tasks.add_task(delete_file, wav_path)

    download_name = os.path.basename(wav_path)
    return FileResponse(wav_path, media_type="audio/wav", filename=download_name)