"""HTTP service that splits an uploaded audio file into stems with Demucs.

Jobs are tracked in an in-memory dict, processed on a background thread by
invoking the Demucs command-line interface, and cleaned up after the
resulting stems have been downloaded as a zip archive.
"""

import contextlib
import os
import shutil
import subprocess
import sys
import tempfile
import threading
import uuid
import zipfile
from typing import Dict

import torch
from fastapi import FastAPI, UploadFile, File, BackgroundTasks, HTTPException
from fastapi.responses import FileResponse

app = FastAPI()

UPLOAD_DIR = "uploads"
OUTPUT_DIR = "outputs"

os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Demucs writes its results to <out>/<model name>/<track name>/, so the model
# is pinned explicitly to keep the command and the expected path in agreement.
MODEL_NAME = "htdemucs"
STEM_NAMES = ("vocals", "drums", "bass", "other")

# Upload chunk size used when streaming the request body to disk.
_UPLOAD_CHUNK_BYTES = 1024 * 1024

QUALITY_PRESETS = {
    "low": {"shifts": 0, "overlap": 0.25},
    "medium": {"shifts": 1, "overlap": 0.25},
    "high": {"shifts": 2, "overlap": 0.5},
}

# -----------------------------
# In-memory job store
# -----------------------------
# Maps job id -> {"status", "progress", "quality", "input_path", "stems"},
# plus "error" once a job has failed. Contents are lost on process restart.
jobs: Dict[str, dict] = {}


# -----------------------------
# Worker function
# -----------------------------
def run_demucs(job_id: str, input_path: str, quality: str) -> None:
    """Run Demucs on ``input_path`` and record the outcome in ``jobs``.

    Intended to run on a worker thread. Any failure is caught and stored on
    the job (``status == "failed"`` plus an ``error`` message) rather than
    propagated, because nothing is waiting on the thread to observe it.

    Args:
        job_id: Key of an existing entry in ``jobs``.
        input_path: Path of the uploaded audio file to separate.
        quality: Key into ``QUALITY_PRESETS``.
    """
    try:
        jobs[job_id]["status"] = "processing"
        jobs[job_id]["progress"] = 10

        preset = QUALITY_PRESETS[quality]
        output_path = os.path.join(OUTPUT_DIR, job_id)

        cmd = [
            # Use the interpreter running this service so that the demucs
            # package is resolved from the same environment.
            sys.executable or "python3",
            "-m", "demucs",
            "-n", MODEL_NAME,
            "--device", DEVICE,
            "--shifts", str(preset["shifts"]),
            "--overlap", str(preset["overlap"]),
            "--out", output_path,
            input_path,
        ]
        subprocess.run(cmd, check=True)

        base = os.path.splitext(os.path.basename(input_path))[0]
        stems_dir = os.path.join(output_path, MODEL_NAME, base)
        if not os.path.isdir(stems_dir):
            # Demucs exited successfully but the expected directory is
            # missing; report a failure instead of a "completed" job whose
            # download would contain nothing.
            raise FileNotFoundError(
                f"Demucs output directory not found: {stems_dir}"
            )

        jobs[job_id]["stems"] = {
            name: os.path.join(stems_dir, f"{name}.wav")
            for name in STEM_NAMES
        }
        jobs[job_id]["progress"] = 100
        jobs[job_id]["status"] = "completed"

    except Exception as e:  # worker-thread boundary: record, don't raise
        jobs[job_id]["status"] = "failed"
        jobs[job_id]["error"] = str(e)


# -----------------------------
# Cleanup helper
# -----------------------------
def cleanup_job(job_id: str) -> None:
    """Delete a job's files from disk and drop it from ``jobs``.

    Does nothing when ``job_id`` is unknown. Files that are already gone are
    ignored.

    Args:
        job_id: Key of the job to remove.
    """
    job = jobs.get(job_id)
    if not job:
        return

    # "stems" is stored as None until the job completes, so a plain
    # .get("stems", {}) would hand back None here.
    for path in (job.get("stems") or {}).values():
        with contextlib.suppress(FileNotFoundError):
            os.remove(path)

    input_file = job.get("input_path")
    if input_file:
        with contextlib.suppress(FileNotFoundError):
            os.remove(input_file)

    # Portable replacement for shelling out to `rm -rf`.
    shutil.rmtree(os.path.join(OUTPUT_DIR, job_id), ignore_errors=True)

    jobs.pop(job_id, None)


# -----------------------------
# Create job
# -----------------------------
@app.post("/separate")
async def separate_audio(
    file: UploadFile = File(...),
    quality: str = "medium",
):
    """Store the uploaded file and start a separation job on a new thread.

    Args:
        file: The uploaded audio file.
        quality: One of the ``QUALITY_PRESETS`` keys.

    Returns:
        A dict with the new ``job_id`` and the initial ``"queued"`` status.

    Raises:
        HTTPException: 400 when ``quality`` is not a known preset.
    """
    if quality not in QUALITY_PRESETS:
        raise HTTPException(400, "quality must be low, medium, or high")

    job_id = uuid.uuid4().hex

    # The client-supplied name may be absent and may contain directory parts
    # using either separator; keep only the final component.
    client_name = (file.filename or "").replace("\\", "/")
    safe_name = os.path.basename(client_name) or "upload"
    input_path = os.path.join(UPLOAD_DIR, f"{job_id}_{safe_name}")

    # Stream to disk in chunks instead of holding the whole upload in memory.
    with open(input_path, "wb") as f:
        while True:
            chunk = await file.read(_UPLOAD_CHUNK_BYTES)
            if not chunk:
                break
            f.write(chunk)

    jobs[job_id] = {
        "status": "queued",
        "progress": 0,
        "quality": quality,
        "input_path": input_path,
        "stems": None,
    }

    thread = threading.Thread(
        target=run_demucs,
        args=(job_id, input_path, quality),
        daemon=True,
    )
    thread.start()

    return {
        "job_id": job_id,
        "status": "queued",
    }


# -----------------------------
# Progress polling
# -----------------------------
@app.get("/status/{job_id}")
def job_status(job_id: str):
    """Return the status, progress and stem paths recorded for a job.

    Raises:
        HTTPException: 404 when ``job_id`` is unknown.
    """
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(404, "Job not found")

    return {
        "job_id": job_id,
        "status": job["status"],
        "progress": job["progress"],
        "stems": job.get("stems"),
    }


# -----------------------------
# Download stems + auto cleanup
# -----------------------------
@app.get("/download/{job_id}/all")
def download_all_stems(
    job_id: str,
    background_tasks: BackgroundTasks,
):
    """Send every stem of a completed job as one zip, then clean the job up.

    The job's files and the temporary archive are deleted by background
    tasks registered on this request.

    Raises:
        HTTPException: 404 when the job is unknown or not completed, or when
            no stem file is available to archive.
    """
    job = jobs.get(job_id)
    if not job or job["status"] != "completed":
        raise HTTPException(404, "Job not completed")

    stems = job.get("stems")
    if not stems:
        raise HTTPException(404, "No stems found")

    # Reserve a path for the archive; delete=False because the file must
    # outlive this handler until the response has been streamed.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp:
        zip_path = tmp.name

    try:
        archived = 0
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
            for stem_name, stem_path in stems.items():
                if os.path.exists(stem_path):
                    zipf.write(stem_path, arcname=f"{stem_name}.wav")
                    archived += 1
        if not archived:
            raise HTTPException(404, "No stems found")
    except BaseException:
        # Don't leave the temporary archive behind when nothing will be sent.
        with contextlib.suppress(FileNotFoundError):
            os.remove(zip_path)
        raise

    # cleanup after response is sent
    background_tasks.add_task(cleanup_job, job_id)
    background_tasks.add_task(os.remove, zip_path)

    return FileResponse(
        zip_path,
        media_type="application/zip",
        filename=f"{job_id}_stems.zip",
    )