stem-separator / backend /task_queue.py
sourav-das's picture
Upload folder using huggingface_hub
7dfae77 verified
import asyncio
from dataclasses import dataclass
from backend.separator import StemSeparatorService
from backend import file_manager
@dataclass
class JobProgress:
state: str = "queued"
progress: float = 0.0
message: str = "Waiting in queue..."
stems: dict[str, str] | None = None
error: str | None = None
# In-process state shared by enqueue_job(), get_job_progress() and worker_loop().
# Maps job id -> the live progress record for that job.
jobs: dict[str, JobProgress] = dict()
# The job queue; stays None until get_queue() first builds it.
_queue: asyncio.Queue | None = None
def get_queue() -> asyncio.Queue:
    """Return the shared job queue, building it on the first call.

    The queue holds at most 5 pending jobs. Construction is deferred until
    first use instead of happening at import time (presumably so it is made
    once an event loop is around — confirm).
    """
    global _queue
    if _queue is not None:
        return _queue
    _queue = asyncio.Queue(maxsize=5)
    return _queue
def get_job_progress(job_id: str) -> JobProgress | None:
    """Look up the progress record for *job_id*; None when the id is unknown."""
    try:
        return jobs[job_id]
    except KeyError:
        return None
async def enqueue_job(job_id: str, stems: list[str], output_format: str) -> bool:
    """Enqueue a separation job. Returns False if queue is full.

    On success a fresh ``JobProgress`` is stored in ``jobs[job_id]`` and
    ``True`` is returned. When the queue is at capacity, nothing is enqueued,
    ``jobs`` is left untouched, and ``False`` is returned.

    The function never suspends; it stays ``async`` for caller compatibility.
    """
    q = get_queue()
    try:
        # put_nowait() checks capacity and inserts in a single step, so there
        # is no gap between a separate full() check and the put.
        q.put_nowait((job_id, stems, output_format))
    except asyncio.QueueFull:
        return False
    # There is no await between the put and this assignment, so the worker
    # cannot be scheduled before the job's progress record exists.
    jobs[job_id] = JobProgress()
    return True
# Status message shown for each state reported through the progress
# callback; states not listed here fall back to "<state>...".
_STATE_MESSAGES = {
    "loading_model": "Loading BS-RoFormer model...",
    "separating": "Separating stems...",
    "finalizing": "Finalizing output files...",
    "done": "Separation complete!",
}


def _progress_callback(progress: "JobProgress"):
    """Return an ``update(state, pct)`` callback that writes onto *progress*.

    The record is bound when the callback is created, so the callback cannot
    write into a different job's record even if the caller's variable is
    later rebound. It is handed to ``separator.separate`` and is presumably
    invoked from the executor thread (confirm in StemSeparatorService).
    """

    def update_progress(state: str, pct: float):
        progress.state = state
        progress.progress = pct
        progress.message = _STATE_MESSAGES.get(state, f"{state}...")

    return update_progress


async def worker_loop():
    """Single worker that processes separation jobs sequentially.

    Loops forever. Each iteration takes a ``(job_id, stems, output_format)``
    tuple from the shared queue and runs ``StemSeparatorService.separate`` in
    the default thread-pool executor, so the blocking call does not stall the
    event loop. The outcome is recorded in ``jobs[job_id]``:

    * success: state ``"done"``, progress 1.0, ``stems`` set to the return
      value of ``separate``;
    * input file missing: state ``"error"`` with a fixed message;
    * any other ``Exception``: the traceback is logged, and state ``"error"``
      is set with the exception text (or the class name if the text is empty).

    ``q.task_done()`` is called exactly once for every dequeued item.
    """
    import logging

    logger = logging.getLogger(__name__)
    separator = StemSeparatorService()
    q = get_queue()
    while True:
        job_id, stems, output_format = await q.get()
        try:
            # enqueue_job() normally creates the record; recreate it if none
            # exists so status lookups still find this job.
            progress = jobs.setdefault(job_id, JobProgress())

            input_file = file_manager.get_input_file(job_id)
            if input_file is None:
                progress.state = "error"
                progress.error = "Input file not found"
                progress.message = "Error: input file not found"
                continue  # the finally clause still calls q.task_done()

            output_dir = str(file_manager.get_output_dir(job_id))
            # separate() is blocking, so run it in the default executor.
            # get_running_loop() is the supported way to get the loop from
            # inside a coroutine.
            loop = asyncio.get_running_loop()
            result = await loop.run_in_executor(
                None,
                separator.separate,
                str(input_file),
                output_dir,
                stems,
                output_format,
                _progress_callback(progress),
            )
            progress.state = "done"
            progress.progress = 1.0
            progress.message = "Separation complete!"
            progress.stems = result
        except Exception as e:
            # Keep the worker alive for later jobs, but keep the traceback:
            # log it before reducing the failure to a status string.
            logger.exception("Separation job %s failed", job_id)
            # Some exceptions stringify to "" (e.g. RuntimeError()); fall back
            # to the class name so the reported error is never blank.
            detail = str(e) or type(e).__name__
            progress = jobs.setdefault(job_id, JobProgress())
            progress.state = "error"
            progress.error = detail
            progress.message = f"Error: {detail}"
        finally:
            q.task_done()