import os
import queue
import threading
import time
import traceback
import uuid

import audiocraft
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen
from fastapi import FastAPI, BackgroundTasks, HTTPException
from fastapi.responses import FileResponse
|
|
print("Audiocraft version:", audiocraft.__version__)

app = FastAPI()

# Model and output configuration.
MODEL_NAME = "facebook/musicgen-small"
OUTPUT_DIR = "outputs"  # relative path: resolved against the process's working directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Sampling settings passed once to model.set_generation_params at startup.
# NOTE(review): top_p=0.0 presumably disables nucleus sampling so that top_k
# governs sampling -- confirm against the audiocraft documentation.
_GENERATION_PARAMS = {
    "duration": 10,  # presumably seconds of audio -- confirm
    "temperature": 1.0,
    "top_k": 250,
    "top_p": 0.0,
}

print("Loading MusicGen model...")
model = MusicGen.get_pretrained(MODEL_NAME)
model.set_generation_params(**_GENERATION_PARAMS)
|
|
# FIFO of job ids waiting to be picked up by the background worker thread.
job_queue: queue.Queue = queue.Queue()
# In-memory job registry keyed by job id (hex uuid). Nothing in this module
# removes entries, so it grows for the lifetime of the process.
jobs: dict = {}
|
|
|
|
class JobStatus:
    """String constants for the lifecycle states stored in ``jobs[id]["status"]``.

    A job moves QUEUED -> PROCESSING -> COMPLETED or FAILED.
    """

    QUEUED: str = "queued"
    PROCESSING: str = "processing"
    COMPLETED: str = "completed"
    FAILED: str = "failed"
|
|
|
|
import signal


class GenerationTimeout(Exception):
    """Raised by ``timeout_handler`` to abort a generation that ran too long."""


def timeout_handler(signum, frame):
    """Signal-handler-shaped callback that aborts by raising GenerationTimeout.

    NOTE(review): nothing in this module registers this with ``signal.signal``.
    Python only allows installing signal handlers from the main thread, so it
    could not guard the background worker thread directly.
    """
    raise GenerationTimeout


# Maximum generation time in seconds. The worker compares elapsed time against
# this only after model.generate returns, so it does not preempt a long run.
MAX_RUNTIME = 600
|
|
def worker():
    """Consume job ids from ``job_queue`` forever, running one generation at a time.

    For each id, the job is marked PROCESSING and audio is generated for its
    prompt. The audio is then written under OUTPUT_DIR and the job is marked
    COMPLETED with ``file_path``. If anything raises, the job is marked FAILED
    with ``error`` and the traceback is printed. ``job_queue.task_done()`` is
    called exactly once per ``get()``.

    NOTE: MAX_RUNTIME is checked only after ``model.generate`` returns. It does
    not interrupt a long generation; it only discards an over-long result.
    """
    while True:
        job_id = job_queue.get()
        job = jobs.get(job_id)

        if not job:
            # Unknown id: still acknowledge it so the queue's unfinished-task
            # count stays balanced.
            job_queue.task_done()
            continue

        # Monotonic clock: a duration measurement must not jump with
        # wall-clock adjustments.
        start = time.monotonic()
        job["status"] = JobStatus.PROCESSING

        try:
            wav = model.generate([job["prompt"]])[0]

            if time.monotonic() - start > MAX_RUNTIME:
                raise TimeoutError("Generation exceeded max runtime")

            base_path = os.path.join(OUTPUT_DIR, job_id)
            written = audio_write(
                base_path,
                wav.cpu(),
                model.sample_rate,
                strategy="loudness",
                loudness_compressor=True,
            )
            # Trust the path audio_write reports (it adds the format suffix);
            # fall back to the ".wav" name if it returned nothing.
            final_path = str(written) if written else base_path + ".wav"

            job.update({
                "status": JobStatus.COMPLETED,
                "file_path": final_path,
            })

        except Exception as e:
            # Top-level boundary of the worker loop: record the failure on the
            # job, and keep the full traceback visible in the server output.
            traceback.print_exc()
            job.update({
                "status": JobStatus.FAILED,
                "error": str(e),
            })

        finally:
            job_queue.task_done()
|
|
|
|
|
|
# Start the single background consumer. daemon=True means this thread does not
# keep the interpreter alive at shutdown.
_worker_thread = threading.Thread(target=worker, daemon=True)
_worker_thread.start()
|
|
@app.post("/generate")
async def generate(prompt: str):
    """Enqueue a music-generation job for *prompt*.

    Returns the new job id and the URL at which its status can be polled.
    """
    new_id = uuid.uuid4().hex

    # Register the job before queueing it: the worker skips ids it cannot find.
    jobs[new_id] = dict(
        status=JobStatus.QUEUED,
        prompt=prompt,
        created_at=time.time(),
        file_path=None,
    )
    job_queue.put(new_id)

    return {"job_id": new_id, "status_url": f"/status/{new_id}"}
|
|
|
|
@app.get("/status/{job_id}")
async def status(job_id: str):
    """Report the current state of job *job_id*.

    The response always has ``job_id``, ``status`` and ``elapsed_seconds``
    (whole seconds since submission). It adds ``download_url`` when the job is
    completed and ``error`` when it failed.

    Raises:
        HTTPException: 404 if the job id is unknown.
    """
    job = jobs.get(job_id)

    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    # Read the status once. The worker thread mutates the job concurrently, so
    # re-reading it could pair e.g. "processing" with a download_url.
    current = job["status"]

    response = {
        "job_id": job_id,
        "status": current,
        "elapsed_seconds": int(time.time() - job["created_at"]),
    }

    if current == JobStatus.COMPLETED:
        response["download_url"] = f"/download/{job_id}"
    elif current == JobStatus.FAILED:
        response["error"] = job.get("error")

    return response
|
|
|
|
def delete_file(path: str, delay: int = 10):
    """Remove the file at *path* after waiting *delay* seconds.

    The download endpoint schedules this as a background task, so it runs
    after the response; *delay* adds an extra grace period on top of that. A
    non-positive *delay* deletes immediately. A file that is already gone is
    ignored, so scheduling this more than once for the same path (e.g.
    repeated downloads) is harmless.
    """
    if delay > 0:
        time.sleep(delay)
    # EAFP: an exists()-then-remove() pair races with another scheduled deletion.
    try:
        os.remove(path)
    except FileNotFoundError:
        pass  # already removed by a concurrent deletion -- nothing to do
|
|
|
|
@app.get("/download/{job_id}")
async def download(job_id: str, background_tasks: BackgroundTasks):
    """Stream the generated WAV for a completed job, then schedule its deletion.

    Raises:
        HTTPException: 404 if the job is unknown or its file is missing,
            400 if the job has not completed yet.
    """
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    if job["status"] != JobStatus.COMPLETED:
        raise HTTPException(status_code=400, detail="Job not completed")

    file_path = job["file_path"]
    if not (file_path and os.path.exists(file_path)):
        raise HTTPException(status_code=404, detail="File not found")

    # The file is removed (after delete_file's default delay) once the
    # response has been sent, so each output is effectively short-lived.
    background_tasks.add_task(delete_file, file_path)

    return FileResponse(
        file_path,
        media_type="audio/wav",
        filename=os.path.basename(file_path),
    )
|
|