# dc / app.py — last commit 31536a1 ("Update app.py") by fomext, 4.92 kB
import os
import shutil
import subprocess
import sys
import tempfile
import threading
import uuid
import zipfile
from typing import Dict

import torch
from fastapi import FastAPI, UploadFile, File, BackgroundTasks, HTTPException
from fastapi.responses import FileResponse

app = FastAPI()

# Working directories (relative to the process CWD); created eagerly at import.
UPLOAD_DIR = "uploads"
OUTPUT_DIR = "outputs"
for _dir in (UPLOAD_DIR, OUTPUT_DIR):
    os.makedirs(_dir, exist_ok=True)
del _dir

# Run Demucs on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Per-quality values passed to the Demucs CLI as --shifts / --overlap.
QUALITY_PRESETS = {
    "low": dict(shifts=0, overlap=0.25),
    "medium": dict(shifts=1, overlap=0.25),
    "high": dict(shifts=2, overlap=0.5),
}

# -----------------------------
# In-memory job store
# -----------------------------
# job_id -> job record. Lives only in this process's memory, so it is lost
# on restart.
jobs: Dict[str, dict] = {}
# -----------------------------
# Worker function
# -----------------------------
def run_demucs(job_id: str, input_path: str, quality: str):
    """Split ``input_path`` into stems with the Demucs CLI and record the outcome.

    Intended to run in a background thread. Progress and results are written
    into ``jobs[job_id]``:

    * on success, ``status`` becomes ``"completed"``, ``progress`` 100, and
      ``stems`` maps stem name -> WAV path;
    * on any error, ``status`` becomes ``"failed"`` and ``error`` holds a
      human-readable message (including the tail of Demucs' output when the
      subprocess itself fails).

    Args:
        job_id: Key of an existing entry in ``jobs``.
        input_path: Path of the uploaded audio file.
        quality: Key into ``QUALITY_PRESETS``.
    """
    # Demucs writes into <out>/<model>/<track>/, so the model is passed
    # explicitly instead of relying on the CLI's default model name.
    model = "htdemucs"
    stem_names = ("vocals", "drums", "bass", "other")

    job = jobs[job_id]
    try:
        job["status"] = "processing"
        job["progress"] = 10

        preset = QUALITY_PRESETS[quality]
        output_path = os.path.join(OUTPUT_DIR, job_id)
        cmd = [
            # Use the interpreter running this app, not whatever "python3" is
            # first on PATH, so the demucs installed alongside it is used.
            sys.executable, "-m", "demucs",
            "-n", model,
            "--device", DEVICE,
            "--shifts", str(preset["shifts"]),
            "--overlap", str(preset["overlap"]),
            "--out", output_path,
            input_path,
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, errors="replace")
        if result.returncode != 0:
            # Keep only the tail: that is where the actual error is reported.
            detail = (result.stderr or result.stdout or "").strip()[-2000:]
            raise RuntimeError(
                f"demucs exited with code {result.returncode}: {detail}"
            )

        base = os.path.splitext(os.path.basename(input_path))[0]
        stems_dir = os.path.join(output_path, model, base)
        stems = {name: os.path.join(stems_dir, f"{name}.wav") for name in stem_names}

        # Don't report success for stems that were never written.
        missing = [name for name, path in stems.items() if not os.path.exists(path)]
        if missing:
            raise FileNotFoundError(
                f"demucs produced no output for: {', '.join(missing)}"
            )

        job["stems"] = stems
        job["progress"] = 100
        job["status"] = "completed"
    except Exception as e:
        job["status"] = "failed"
        job["error"] = str(e)
# -----------------------------
# Cleanup helper
# -----------------------------
def cleanup_job(job_id: str):
    """Delete every file belonging to ``job_id`` and forget the job.

    Removes the recorded stem files, the uploaded input file, and the job's
    output directory under ``OUTPUT_DIR``, then drops the entry from ``jobs``.
    Unknown job ids are ignored, and files that are already gone are skipped.
    """
    job = jobs.get(job_id)
    if not job:
        return

    # "stems" is explicitly None until the worker succeeds, so a plain
    # `.get("stems", {})` would return None and crash on `.values()`.
    paths = list((job.get("stems") or {}).values())
    input_file = job.get("input_path")
    if input_file:
        paths.append(input_file)

    for path in paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    # Remove the whole output tree in-process rather than shelling out to `rm -rf`.
    shutil.rmtree(os.path.join(OUTPUT_DIR, job_id), ignore_errors=True)
    jobs.pop(job_id, None)
# -----------------------------
# Create job
# -----------------------------
@app.post("/separate")
async def separate_audio(
    file: UploadFile = File(...),
    quality: str = "medium"
):
    """Save an uploaded audio file and start a background separation job.

    Args:
        file: The uploaded audio file.
        quality: One of the keys of ``QUALITY_PRESETS``.

    Returns:
        ``{"job_id": ..., "status": "queued"}``; poll ``/status/{job_id}``.

    Raises:
        HTTPException: 400 if ``quality`` is not a known preset.
    """
    if quality not in QUALITY_PRESETS:
        raise HTTPException(400, "quality must be low, medium, or high")

    job_id = uuid.uuid4().hex
    # basename() drops any client-supplied directory parts. filename may be
    # None (or reduce to ""), so fall back to a fixed name.
    safe_name = os.path.basename(file.filename or "") or "upload"
    input_path = os.path.join(UPLOAD_DIR, f"{job_id}_{safe_name}")

    # Stream to disk in 1 MiB chunks instead of holding the whole upload in memory.
    chunk_size = 1024 * 1024
    with open(input_path, "wb") as f:
        while True:
            chunk = await file.read(chunk_size)
            if not chunk:
                break
            f.write(chunk)

    jobs[job_id] = {
        "status": "queued",
        "progress": 0,
        "quality": quality,
        "input_path": input_path,
        "stems": None,
    }

    thread = threading.Thread(
        target=run_demucs,
        args=(job_id, input_path, quality),
        daemon=True,
    )
    thread.start()

    return {
        "job_id": job_id,
        "status": "queued",
    }
# -----------------------------
# Progress polling
# -----------------------------
@app.get("/status/{job_id}")
def job_status(job_id: str):
    """Return the current state of a job.

    The response includes ``error`` (``None`` unless the job failed), so
    clients can see why a job failed instead of only seeing ``"failed"``.

    Raises:
        HTTPException: 404 if the job id is unknown (or already cleaned up).
    """
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(404, "Job not found")
    return {
        "job_id": job_id,
        "status": job["status"],
        "progress": job["progress"],
        "stems": job.get("stems"),
        "error": job.get("error"),
    }
# -----------------------------
# Download all stems as a zip + auto cleanup
# -----------------------------
@app.get("/download/{job_id}/all")
def download_all_stems(
    job_id: str,
    background_tasks: BackgroundTasks
):
    """Return a zip of every stem of a completed job, then delete the job.

    After the response has been sent, background tasks remove the job's files
    (via ``cleanup_job``) and the temporary zip. A second download of the same
    job therefore gets a 404.

    Raises:
        HTTPException: 404 if the job is unknown or not completed, or if none
            of its stem files exist on disk.
    """
    job = jobs.get(job_id)
    if not job or job["status"] != "completed":
        raise HTTPException(404, "Job not completed")

    stems = job.get("stems")
    if not stems:
        raise HTTPException(404, "No stems found")

    # Create the temp file closed, so ZipFile can reopen it by name.
    tmp = tempfile.NamedTemporaryFile(
        delete=False,
        suffix=".zip"
    )
    zip_path = tmp.name
    tmp.close()

    written = 0
    try:
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
            for stem_name, stem_path in stems.items():
                if os.path.exists(stem_path):
                    zipf.write(stem_path, arcname=f"{stem_name}.wav")
                    written += 1
    except Exception:
        # Don't leak the temp zip if archiving fails part-way.
        os.remove(zip_path)
        raise

    if written == 0:
        # Never hand out an empty archive.
        os.remove(zip_path)
        raise HTTPException(404, "No stems found")

    # cleanup after response is sent
    background_tasks.add_task(cleanup_job, job_id)
    background_tasks.add_task(os.remove, zip_path)

    return FileResponse(
        zip_path,
        media_type="application/zip",
        filename=f"{job_id}_stems.zip"
    )