stem-separator / backend /task_queue.py
sayakpaul's picture
sayakpaul HF Staff
Upload backend/task_queue.py with huggingface_hub
b86b18c verified
import asyncio
import os
import shutil
from dataclasses import dataclass
from gradio_client import Client, handle_file
from backend import file_manager
# Stem names, in the same order as the per-stem outputs of the inference
# Space's "/separate" endpoint.
STEMS = [
    "Vocals",
    "Drums",
    "Bass",
    "Guitar",
    "Piano",
    "Other",
]
# Space id handed to gradio_client.Client by the worker.
INFERENCE_SPACE_ID = "sayakpaul/stem-separator-inference"
@dataclass
class JobProgress:
state: str = "queued"
progress: float = 0.0
message: str = "Waiting in queue..."
stems: dict[str, str] | None = None
error: str | None = None
# Module-wide state shared by the enqueue side and the worker.
# Progress records keyed by job id.
jobs: dict[str, JobProgress] = dict()
# The bounded job queue; stays None until get_queue() first creates it.
_queue: asyncio.Queue | None = None
def get_queue() -> asyncio.Queue:
    """Return the shared job queue, creating it on the first call.

    The queue holds at most 5 pending jobs.
    """
    global _queue
    if _queue is not None:
        return _queue
    _queue = asyncio.Queue(maxsize=5)
    return _queue
def get_job_progress(job_id: str) -> JobProgress | None:
    """Return the progress record for ``job_id``, or ``None`` if it is unknown."""
    try:
        return jobs[job_id]
    except KeyError:
        return None
async def enqueue_job(job_id: str, stems: list[str], output_format: str) -> bool:
    """Submit a separation job to the worker queue.

    On success a fresh ``JobProgress`` is registered under ``job_id``.

    Returns:
        ``True`` if the job was queued; ``False`` if the queue was already
        full, in which case nothing is recorded.
    """
    try:
        get_queue().put_nowait((job_id, stems, output_format))
    except asyncio.QueueFull:
        return False
    # Nothing between the put and this assignment yields to the event loop,
    # so the worker cannot start the job before its record exists.
    jobs[job_id] = JobProgress()
    return True
def _copy_stems(response, output_dir: str) -> dict[str, str]:
    """Copy the stem files returned by the inference Space into ``output_dir``.

    Args:
        response: The "/separate" result: one entry per stem in ``STEMS``
            order, each a path to a downloaded temp file or ``None``.
        output_dir: Directory the copies are written into.

    Returns:
        Mapping of stem name to the basename of its copied file. Stems whose
        entry is ``None`` are omitted.
    """
    result: dict[str, str] = {}
    for stem_name, file_path in zip(STEMS, response):
        if file_path is None:
            continue
        filename = os.path.basename(file_path)
        shutil.copy2(file_path, os.path.join(output_dir, filename))
        result[stem_name] = filename
    return result


async def worker_loop():
    """Single worker that processes separation jobs sequentially via the inference Space.

    Runs forever. For each queued job it:
    1. sends the job's input file to the Space's "/separate" endpoint;
    2. copies the returned stem files into the job's output directory;
    3. updates the job's ``JobProgress`` entry in ``jobs``.

    Any exception while handling a job is recorded on that job (state
    "error") and the worker moves on to the next one.
    """
    # The client is created lazily, and creation is retried on the next job if
    # it fails. Connecting does network I/O, and an eager failure here would
    # kill the only worker while enqueued jobs sat in "queued" forever.
    client = None
    q = get_queue()
    while True:
        job_id, stems, output_format = await q.get()
        try:
            progress = jobs.get(job_id)
            if progress is None:
                progress = JobProgress()
                jobs[job_id] = progress
            input_file = file_manager.get_input_file(job_id)
            if input_file is None:
                progress.state = "error"
                progress.error = "Input file not found"
                progress.message = "Error: input file not found"
                continue
            output_dir = str(file_manager.get_output_dir(job_id))
            progress.state = "separating"
            progress.progress = 0.2
            progress.message = "Separating stems..."
            # All blocking work below (connecting, the remote call, file
            # copies) runs in worker threads so the event loop stays free to
            # serve status requests.
            if client is None:
                client = await asyncio.to_thread(Client, INFERENCE_SPACE_ID)
            response = await asyncio.to_thread(
                client.predict,
                audio_file=handle_file(str(input_file)),
                stems=stems,
                output_format=output_format,
                api_name="/separate",
            )
            result = await asyncio.to_thread(_copy_stems, response, output_dir)
            progress.state = "done"
            progress.progress = 1.0
            progress.message = "Separation complete!"
            progress.stems = result
        except Exception as e:
            # Surface the failure on the job itself instead of crashing the worker.
            progress = jobs.get(job_id, JobProgress())
            progress.state = "error"
            progress.error = str(e)
            progress.message = f"Error: {e}"
            jobs[job_id] = progress
        finally:
            # Balances q.get() on every path, including the early `continue`.
            q.task_done()