"""Queue and single background worker for stem-separation jobs.

Jobs are enqueued with :func:`enqueue_job` and processed one at a time by
:func:`worker_loop`, which forwards the audio to a remote Gradio inference
Space and copies the returned stem files into the job's output directory.
"""

import asyncio
import os
import shutil
from dataclasses import dataclass

from gradio_client import Client, handle_file

from backend import file_manager

# Output order of the inference Space's "/separate" endpoint: one output
# slot per stem, in exactly this order.
STEMS = ["Vocals", "Drums", "Bass", "Guitar", "Piano", "Other"]
INFERENCE_SPACE_ID = "sayakpaul/stem-separator-inference"


@dataclass
class JobProgress:
    """Mutable, in-memory status record for one separation job."""

    # One of "queued", "separating", "done", "error" (values set in this module).
    state: str = "queued"
    # Fraction of work done, 0.0 .. 1.0.
    progress: float = 0.0
    # Human-readable status line.
    message: str = "Waiting in queue..."
    # On success: stem name -> filename inside the job's output directory.
    stems: dict[str, str] | None = None
    # On failure: the error text.
    error: str | None = None


# Shared state. NOTE(review): nothing in this module ever removes entries
# from `jobs`; confirm whether a caller prunes finished jobs.
jobs: dict[str, JobProgress] = {}
_queue: asyncio.Queue | None = None


def get_queue() -> asyncio.Queue:
    """Return the module-wide job queue, creating it (capacity 5) on first use."""
    global _queue
    if _queue is None:
        _queue = asyncio.Queue(maxsize=5)
    return _queue


def get_job_progress(job_id: str) -> JobProgress | None:
    """Return the progress record for *job_id*, or ``None`` if unknown."""
    return jobs.get(job_id)


async def enqueue_job(job_id: str, stems: list[str], output_format: str) -> bool:
    """Enqueue a separation job. Returns False if queue is full.

    The job is registered in ``jobs`` (state "queued") only when it was
    actually accepted by the queue.
    """
    try:
        get_queue().put_nowait((job_id, stems, output_format))
    except asyncio.QueueFull:
        return False
    # No await between the put and this line, so the worker cannot pick the
    # job up before its progress record exists.
    jobs[job_id] = JobProgress()
    return True


def _mark_error(progress: JobProgress, error: str) -> None:
    """Put *progress* into the error state with *error* as its reason."""
    progress.state = "error"
    progress.error = error
    progress.message = f"Error: {error}"


def _collect_outputs(response, output_dir: str) -> dict[str, str]:
    """Copy the Space's downloaded stem files into *output_dir*.

    Runs in a worker thread because it performs blocking file copies.

    Args:
        response: Result of ``client.predict``; expected to be a tuple/list
            with one entry per stem in ``STEMS`` order, each a local temp
            file path or ``None`` for stems that were not requested.
        output_dir: Destination directory for the copied files.

    Returns:
        Mapping of stem name -> filename (relative to *output_dir*).

    Raises:
        RuntimeError: If *response* does not have one entry per stem, which
            would otherwise silently mis-map or drop stems.
    """
    if not isinstance(response, (list, tuple)) or len(response) != len(STEMS):
        raise RuntimeError(
            f"Unexpected response from inference Space: expected {len(STEMS)} "
            f"outputs, got {response!r}"
        )

    os.makedirs(output_dir, exist_ok=True)
    result: dict[str, str] = {}
    used_names: set[str] = set()
    for stem_name, file_path in zip(STEMS, response):
        if not file_path:
            continue
        filename = os.path.basename(file_path)
        if filename in used_names:
            # Downloads live in separate temp dirs and may share a basename;
            # prefix the stem so one stem does not overwrite another.
            filename = f"{stem_name.lower()}_{filename}"
        shutil.copy2(file_path, os.path.join(output_dir, filename))
        used_names.add(filename)
        result[stem_name] = filename
    return result


async def _process_job(
    client: Client,
    job_id: str,
    stems: list[str],
    output_format: str,
    progress: JobProgress,
) -> None:
    """Run one job end-to-end, updating *progress* in place.

    Exceptions propagate to the caller, which records them on *progress*.
    """
    input_file = file_manager.get_input_file(job_id)
    if input_file is None:
        progress.state = "error"
        progress.error = "Input file not found"
        progress.message = "Error: input file not found"
        return

    output_dir = str(file_manager.get_output_dir(job_id))
    progress.state = "separating"
    progress.progress = 0.2
    progress.message = "Separating stems..."

    # The remote call blocks for the whole inference; keep it off the loop.
    response = await asyncio.to_thread(
        client.predict,
        audio_file=handle_file(str(input_file)),
        stems=stems,
        output_format=output_format,
        api_name="/separate",
    )
    result = await asyncio.to_thread(_collect_outputs, response, output_dir)

    progress.state = "done"
    progress.progress = 1.0
    progress.message = "Separation complete!"
    progress.stems = result


async def worker_loop():
    """Single worker that processes separation jobs sequentially via the inference Space.

    Runs forever. Every failure of a job (including failure to connect to
    the Space) is recorded on that job's ``JobProgress`` instead of killing
    the worker. Cancellation marks the in-flight job as errored and then
    propagates.
    """
    q = get_queue()
    # Created lazily: constructing a Client contacts the Space over the
    # network, so it must not block the event loop, and a transient failure
    # should only fail the current job (the next job retries).
    client: Client | None = None

    while True:
        job_id, stems, output_format = await q.get()
        progress = jobs.setdefault(job_id, JobProgress())
        try:
            if client is None:
                client = await asyncio.to_thread(Client, INFERENCE_SPACE_ID)
            await _process_job(client, job_id, stems, output_format, progress)
        except asyncio.CancelledError:
            # Not an Exception subclass; without this the job would stay
            # "separating" forever after shutdown.
            _mark_error(progress, "Job cancelled")
            raise
        except Exception as e:
            _mark_error(progress, str(e))
        finally:
            q.task_done()