import contextlib
import os
import shutil
import subprocess
import sys
import tempfile
import threading
import uuid
import zipfile
from typing import Dict

import torch
from fastapi import BackgroundTasks, FastAPI, File, HTTPException, UploadFile
from fastapi.responses import FileResponse
|
|
|
|
# ASGI application object; the route handlers below register on it.
app = FastAPI()

# Working directories, relative to the server's current working directory:
# raw uploads go in the first, Demucs results in the second (one
# sub-directory per job id).
UPLOAD_DIR, OUTPUT_DIR = "uploads", "outputs"

# Create both at import time so request handlers can write into them directly.
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
# Run Demucs on the GPU when PyTorch can see one, otherwise fall back to CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Demucs `--shifts` / `--overlap` values for each user-facing quality level.
# Higher levels use more shifts and/or a larger overlap (presumably slower but
# better separation).
QUALITY_PRESETS = dict(
    low=dict(shifts=0, overlap=0.25),
    medium=dict(shifts=1, overlap=0.25),
    high=dict(shifts=2, overlap=0.5),
)

# In-memory job registry keyed by job id (uuid4 hex). Nothing is persisted,
# so jobs vanish on restart.
# NOTE(review): mutated from worker threads without a lock.
jobs: Dict[str, dict] = dict()
|
|
| |
| |
| |
def run_demucs(job_id: str, input_path: str, quality: str):
    """Run Demucs on ``input_path`` and record the outcome in ``jobs[job_id]``.

    ``separate_audio`` runs this on a worker thread. Progress goes 10 -> 100.

    On success, ``stems`` maps each stem name to its WAV path and ``status``
    becomes ``"completed"``. Any failure sets ``status`` to ``"failed"`` and
    stores a message under ``error`` instead of raising. Failures include an
    unknown quality, a Demucs error, or stems missing after Demucs finishes.

    Args:
        job_id: Key of an existing entry in ``jobs``.
        input_path: Path to the uploaded audio file.
        quality: Key into ``QUALITY_PRESETS`` (selects shifts/overlap).
    """
    # Pin the model explicitly. The stem directory below is derived from its
    # name, so the two must agree even if Demucs's default model changes.
    model = "htdemucs"
    job = jobs[job_id]
    try:
        job["status"] = "processing"
        job["progress"] = 10

        preset = QUALITY_PRESETS[quality]
        output_path = os.path.join(OUTPUT_DIR, job_id)

        cmd = [
            # Use the interpreter running this server so demucs/torch are
            # resolved from the same environment, not whatever "python3" is
            # first on PATH.
            sys.executable, "-m", "demucs",
            "-n", model,
            "--device", DEVICE,
            "--shifts", str(preset["shifts"]),
            "--overlap", str(preset["overlap"]),
            "--out", output_path,
            input_path,
        ]

        # Capture stderr so a failure records Demucs's own message instead of
        # only "returned non-zero exit status 1".
        completed = subprocess.run(
            cmd, stderr=subprocess.PIPE, text=True, errors="replace"
        )
        if completed.returncode != 0:
            tail = " | ".join(completed.stderr.strip().splitlines()[-5:])
            raise RuntimeError(
                f"demucs exited with status {completed.returncode}: {tail}"
            )

        # Output layout: <out>/<model>/<input name without extension>/<stem>.wav
        track = os.path.splitext(os.path.basename(input_path))[0]
        stems_dir = os.path.join(output_path, model, track)
        stems = {
            name: os.path.join(stems_dir, f"{name}.wav")
            for name in ("vocals", "drums", "bass", "other")
        }

        # Don't report success for files that were never produced.
        missing = [path for path in stems.values() if not os.path.exists(path)]
        if missing:
            raise FileNotFoundError(
                f"demucs finished but stems are missing: {missing}"
            )

        job["stems"] = stems
        job["progress"] = 100
        job["status"] = "completed"

    except Exception as e:
        # Top-level boundary of the worker thread: record the failure for
        # the status endpoint rather than letting the thread die silently.
        job["status"] = "failed"
        job["error"] = str(e)
|
|
| |
| |
| |
def cleanup_job(job_id: str):
    """Delete every file belonging to ``job_id`` and forget the job.

    Removes the stem files, the uploaded input file and the job's directory
    under ``OUTPUT_DIR``, then drops the entry from ``jobs``. Unknown job ids
    are ignored and already-deleted files are skipped, so repeated calls are
    harmless.
    """
    job = jobs.get(job_id)
    if not job:
        return

    # "stems" is stored as None until separation succeeds, so fall back to an
    # empty mapping for queued or failed jobs. The previous
    # `job.get("stems", {})` returned that None and crashed on `.values()`.
    for path in (job.get("stems") or {}).values():
        with contextlib.suppress(FileNotFoundError):
            os.remove(path)

    input_file = job.get("input_path")
    if input_file:
        with contextlib.suppress(FileNotFoundError):
            os.remove(input_file)

    # Portable in-process equivalent of `rm -rf`. Errors are ignored, as the
    # unchecked subprocess call ignored them before.
    shutil.rmtree(os.path.join(OUTPUT_DIR, job_id), ignore_errors=True)

    jobs.pop(job_id, None)
|
|
| |
| |
| |
@app.post("/separate")
async def separate_audio(
    file: UploadFile = File(...),
    quality: str = "medium"
):
    """Accept an audio upload and start a background Demucs separation job.

    Args:
        file: The uploaded audio file (multipart form field ``file``).
        quality: One of the ``QUALITY_PRESETS`` keys.

    Returns:
        ``{"job_id": <hex id>, "status": "queued"}``; poll ``/status/{job_id}``.

    Raises:
        HTTPException: 400 if ``quality`` is not a known preset.
    """
    if quality not in QUALITY_PRESETS:
        raise HTTPException(400, "quality must be low, medium, or high")

    job_id = uuid.uuid4().hex

    # basename() strips client-supplied directory components so the file stays
    # inside UPLOAD_DIR. The filename can be None or empty, which previously
    # crashed os.path.basename with a TypeError (HTTP 500).
    safe_name = os.path.basename(file.filename or "") or "upload"
    input_path = os.path.join(UPLOAD_DIR, f"{job_id}_{safe_name}")

    # Stream to disk in chunks instead of holding the whole upload in memory.
    # NOTE(review): these are blocking disk writes inside an async handler.
    chunk_size = 1024 * 1024
    try:
        with open(input_path, "wb") as f:
            while True:
                chunk = await file.read(chunk_size)
                if not chunk:
                    break
                f.write(chunk)
    except Exception:
        # Don't leave a partial upload behind (client disconnect, full disk...).
        with contextlib.suppress(OSError):
            os.remove(input_path)
        raise

    jobs[job_id] = {
        "status": "queued",
        "progress": 0,
        "quality": quality,
        "input_path": input_path,
        "stems": None,
    }

    # Daemon thread: the separation does not keep the process alive on shutdown.
    thread = threading.Thread(
        target=run_demucs,
        args=(job_id, input_path, quality),
        daemon=True,
    )
    thread.start()

    return {
        "job_id": job_id,
        "status": "queued",
    }
|
|
| |
| |
| |
@app.get("/status/{job_id}")
def job_status(job_id: str):
    """Report the current state of a separation job.

    The response contains:
        * ``status``: "queued", "processing", "completed" or "failed".
        * ``progress``: 0, 10 or 100, as set by the worker.
        * ``stems``: stem-name -> path mapping, or ``None`` until completed.
        * ``error``: the message the worker recorded for a failed job, or
          ``None``. This key is new and was previously never exposed.

    Raises:
        HTTPException: 404 if ``job_id`` is unknown (never created or
            already cleaned up).
    """
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(404, "Job not found")

    return {
        "job_id": job_id,
        "status": job["status"],
        "progress": job["progress"],
        "stems": job.get("stems"),
        "error": job.get("error"),
    }
|
|
| |
| |
| |
@app.get("/download/{job_id}/all")
def download_all_stems(
    job_id: str,
    background_tasks: BackgroundTasks
):
    """Return all stems of a completed job as one ZIP, then delete the job.

    The archive holds one ``<stem>.wav`` entry per stem file present on disk.
    After the response is sent, background tasks remove the temporary ZIP and
    call ``cleanup_job``. The job's files and record are therefore gone
    afterwards, so this endpoint works once per job.

    Raises:
        HTTPException: 404 if the job is unknown or not completed, or if none
            of its stem files exist on disk.
    """
    job = jobs.get(job_id)
    if not job or job["status"] != "completed":
        raise HTTPException(404, "Job not completed")

    stems = job.get("stems")
    if not stems:
        raise HTTPException(404, "No stems found")

    present = {name: path for name, path in stems.items() if os.path.exists(path)}
    if not present:
        # Previously an empty archive was served and the job deleted anyway.
        raise HTTPException(404, "No stems found")

    # The file must outlive this function so FileResponse can stream it; a
    # background task deletes it after the response has been sent.
    fd, zip_path = tempfile.mkstemp(suffix=".zip")
    os.close(fd)

    try:
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
            for stem_name, stem_path in present.items():
                zipf.write(stem_path, arcname=f"{stem_name}.wav")
    except Exception:
        # Don't leak the temp archive if building it fails.
        with contextlib.suppress(OSError):
            os.remove(zip_path)
        raise

    # Queue the ZIP removal first. Background tasks run in order, so an
    # exception in cleanup_job can no longer skip deleting the archive.
    background_tasks.add_task(os.remove, zip_path)
    background_tasks.add_task(cleanup_job, job_id)

    return FileResponse(
        zip_path,
        media_type="application/zip",
        filename=f"{job_id}_stems.zip",
    )
|
|
|
|