Spaces:
Sleeping
Sleeping
| import asyncio | |
| import os | |
| import shutil | |
| from dataclasses import dataclass | |
| from gradio_client import Client, handle_file | |
| from backend import file_manager | |
# Stem names in the exact order the inference Space returns its outputs;
# worker_loop zips the prediction response against this list, so the order
# here must match the remote "/separate" endpoint.
STEMS = [
    "Vocals",
    "Drums",
    "Bass",
    "Guitar",
    "Piano",
    "Other",
]

# Hugging Face Space that performs the separation (reached via gradio_client).
INFERENCE_SPACE_ID = "sayakpaul/stem-separator-inference"
| class JobProgress: | |
| state: str = "queued" | |
| progress: float = 0.0 | |
| message: str = "Waiting in queue..." | |
| stems: dict[str, str] | None = None | |
| error: str | None = None | |
# ---------------------------------------------------------------------------
# Shared module-level state
# ---------------------------------------------------------------------------

# job_id -> progress record. Written by enqueue_job and worker_loop, read via
# get_job_progress.
# NOTE(review): no code in this file removes entries, so this grows for the
# lifetime of the process -- confirm cleanup happens elsewhere.
jobs: dict[str, JobProgress] = {}

# The bounded job queue; stays None until get_queue() first creates it.
_queue: asyncio.Queue | None = None
def get_queue() -> asyncio.Queue:
    """Return the module's job queue, creating it on the first call.

    The queue holds at most 5 pending items; enqueue_job checks ``full()``
    before adding to it.
    """
    global _queue
    if _queue is not None:
        return _queue
    _queue = asyncio.Queue(maxsize=5)
    return _queue
def get_job_progress(job_id: str) -> JobProgress | None:
    """Look up the progress record for *job_id*; ``None`` if it is unknown."""
    try:
        return jobs[job_id]
    except KeyError:
        return None
async def enqueue_job(job_id: str, stems: list[str], output_format: str) -> bool:
    """Register *job_id* with a fresh JobProgress and hand it to the worker queue.

    Returns:
        True when the job was queued. False when the queue is already full,
        in which case nothing is registered or queued.

    Re-enqueueing an existing job_id replaces its progress record.
    """
    queue = get_queue()
    accepted = not queue.full()
    if accepted:
        # The record is created before the item is queued, so it already
        # exists when the worker dequeues the job.
        jobs[job_id] = JobProgress()
        await queue.put((job_id, stems, output_format))
    return accepted
def _separate_and_collect(client, input_file, stems, output_format, output_dir):
    """Run one separation on the inference Space and copy the stems into *output_dir*.

    Blocking (remote call plus file copies); worker_loop runs it in a thread.

    Returns:
        Mapping of stem name (from STEMS) to the copied file's basename, for
        every stem the Space returned a file for.
    """
    response = client.predict(
        audio_file=handle_file(str(input_file)),
        stems=stems,
        output_format=output_format,
        api_name="/separate",
    )
    # response is a tuple of 6 elements (one per stem in STEMS order).
    # Each is a filepath to a downloaded temp file, or None.
    result: dict[str, str] = {}
    for stem_name, file_path in zip(STEMS, response):
        if file_path is None:
            continue
        filename = os.path.basename(file_path)
        shutil.copy2(file_path, os.path.join(output_dir, filename))
        result[stem_name] = filename
    return result


async def worker_loop():
    """Process separation jobs one at a time, forever, via the inference Space.

    For each queued ``(job_id, stems, output_format)`` item, the job's
    JobProgress record is moved to "separating". It then ends in one of two
    states:

    - "done", with ``stems`` set to stem name -> output filename;
    - "error", with ``error``/``message`` describing the failure.

    A failing job never stops the loop, and ``task_done()`` is called for
    every dequeued item.
    """
    q = get_queue()
    # gradio_client.Client, connected lazily inside the per-job try below.
    # Constructing it contacts the Space over the network, so it can block
    # and can raise (e.g. while the Space is asleep). Building it before the
    # loop would block the event loop, and one failure would kill the worker.
    client = None
    while True:
        job_id, stems, output_format = await q.get()
        try:
            progress = jobs.get(job_id)
            if progress is None:
                progress = JobProgress()
                jobs[job_id] = progress

            input_file = file_manager.get_input_file(job_id)
            if input_file is None:
                progress.state = "error"
                progress.error = "Input file not found"
                progress.message = "Error: input file not found"
                continue
            output_dir = str(file_manager.get_output_dir(job_id))

            progress.state = "separating"
            progress.progress = 0.2
            progress.message = "Separating stems..."

            if client is None:
                # A connection failure fails only this job; the next job
                # retries the connection.
                client = await asyncio.to_thread(Client, INFERENCE_SPACE_ID)
            # Remote inference and the stem copies are blocking; run them in
            # a thread so other coroutines keep running meanwhile.
            result = await asyncio.to_thread(
                _separate_and_collect,
                client,
                input_file,
                stems,
                output_format,
                output_dir,
            )

            progress.state = "done"
            progress.progress = 1.0
            progress.message = "Separation complete!"
            progress.stems = result
        except Exception as e:
            # Some exceptions stringify to ""; fall back to the type name so
            # the reported error is never blank.
            detail = str(e) or type(e).__name__
            progress = jobs.setdefault(job_id, JobProgress())
            progress.state = "error"
            progress.error = detail
            progress.message = f"Error: {detail}"
        finally:
            q.task_done()