"""HTTP service that queues text prompts for MusicGen and serves the audio.

Jobs are tracked in an in-memory dict and processed one at a time by a single
daemon worker thread. Clients POST a prompt, poll the status endpoint, and
download the resulting WAV file, which is deleted shortly after download.
"""
import contextlib
import os
import queue
import signal  # noqa: F401  (imported by the original file; not used below)
import threading
import time
import uuid

import audiocraft
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen
from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.responses import FileResponse

print("Audiocraft version:", audiocraft.__version__)

app = FastAPI()

MODEL_NAME = "facebook/musicgen-small"
OUTPUT_DIR = "outputs"
MAX_RUNTIME = 600  # 10 minutes (CPU-safe)

os.makedirs(OUTPUT_DIR, exist_ok=True)

print("Loading MusicGen model...")
model = MusicGen.get_pretrained(MODEL_NAME)
model.set_generation_params(
    duration=10,
    temperature=1.0,
    top_k=250,
    top_p=0.0,
)

# -------------------------
# Job system (in-memory)
# -------------------------

job_queue = queue.Queue()
jobs = {}  # job_id -> metadata


class JobStatus:
    """String constants for the lifecycle states stored in a job's metadata."""

    QUEUED = "queued"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


class GenerationTimeout(Exception):
    """Raised by ``timeout_handler``; nothing in this module catches it by name."""


def timeout_handler(signum, frame):
    """Raise ``GenerationTimeout``.

    Takes the ``(signum, frame)`` arguments of a signal handler, but this
    module never installs it; it is kept for backward compatibility.
    """
    raise GenerationTimeout()


def _run_job(job_id: str, job: dict) -> None:
    """Generate audio for one job and record the outcome in ``job``.

    On success the job is marked completed and ``file_path`` is set to the
    written WAV file. Any exception marks the job failed and stores a
    non-empty error description.
    """
    # Monotonic clock: the elapsed time must not be affected by wall-clock
    # adjustments made while generation is running.
    started = time.monotonic()
    job["status"] = JobStatus.PROCESSING

    try:
        wav = model.generate([job["prompt"]])[0]

        # NOTE: this check runs only after generation has finished, so it
        # rejects over-long runs but cannot interrupt them.
        if time.monotonic() - started > MAX_RUNTIME:
            raise TimeoutError("Generation exceeded max runtime")

        base_path = os.path.join(OUTPUT_DIR, job_id)
        audio_write(
            base_path,
            wav.cpu(),
            model.sample_rate,
            strategy="loudness",
            loudness_compressor=True,
        )

        job.update({
            "status": JobStatus.COMPLETED,
            "file_path": base_path + ".wav",
        })
    except Exception as exc:  # worker boundary: one bad job must not kill the thread
        job.update({
            "status": JobStatus.FAILED,
            # Some exceptions stringify to "", which would leave the client
            # with no explanation; fall back to the exception class name.
            "error": str(exc) or type(exc).__name__,
        })


def worker() -> None:
    """Process queued job ids forever, one at a time.

    Ids with no entry in ``jobs`` are skipped. ``task_done()`` is called
    exactly once for every id taken from the queue.
    """
    while True:
        job_id = job_queue.get()
        try:
            job = jobs.get(job_id)
            if job:
                _run_job(job_id, job)
        finally:
            job_queue.task_done()


# Start worker thread
threading.Thread(target=worker, daemon=True).start()

# -------------------------
# API endpoints
# -------------------------


@app.post("/generate")
async def generate(prompt: str):
    """Queue a generation job for ``prompt`` and return its id and status URL."""
    job_id = uuid.uuid4().hex

    jobs[job_id] = {
        "status": JobStatus.QUEUED,
        "prompt": prompt,
        "created_at": time.time(),
        "file_path": None,
    }
    job_queue.put(job_id)

    return {
        "job_id": job_id,
        "status_url": f"/status/{job_id}",
    }


@app.get("/status/{job_id}")
async def status(job_id: str):
    """Report a job's state, with a download URL or error when available."""
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    response = {
        "job_id": job_id,
        "status": job["status"],
        "elapsed_seconds": int(time.time() - job["created_at"]),
    }

    if job["status"] == JobStatus.COMPLETED:
        response["download_url"] = f"/download/{job_id}"

    if job["status"] == JobStatus.FAILED:
        response["error"] = job.get("error")

    return response


def delete_file(path: str, delay: int = 10) -> None:
    """Delete ``path`` after waiting ``delay`` seconds.

    A file that is already gone is ignored, so several deletions scheduled
    for the same path (one per download request) cannot fail on each other.
    """
    time.sleep(delay)
    with contextlib.suppress(FileNotFoundError):
        os.remove(path)


@app.get("/download/{job_id}")
async def download(job_id: str, background_tasks: BackgroundTasks):
    """Return a completed job's WAV file and schedule the file for deletion."""
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    if job["status"] != JobStatus.COMPLETED:
        raise HTTPException(status_code=400, detail="Job not completed")

    path = job["file_path"]
    if not path or not os.path.exists(path):
        raise HTTPException(status_code=404, detail="File not found")

    # Deletion runs as a background task, i.e. after the response is sent.
    background_tasks.add_task(delete_file, path)

    return FileResponse(
        path,
        media_type="audio/wav",
        filename=os.path.basename(path),
    )