Spaces:
Sleeping
Sleeping
File size: 3,402 Bytes
1b36a79 b86b18c 1b36a79 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 | import asyncio
import os
import shutil
from dataclasses import dataclass
from gradio_client import Client, handle_file
from backend import file_manager
# Stem names in the order the inference Space returns its outputs; worker_loop
# pairs this list positionally with the response tuple from the Space.
STEMS = [
    "Vocals",
    "Drums",
    "Bass",
    "Guitar",
    "Piano",
    "Other",
]
# Space identifier passed to gradio_client.Client to reach the remote separator.
INFERENCE_SPACE_ID = "sayakpaul/stem-separator-inference"
@dataclass
class JobProgress:
state: str = "queued"
progress: float = 0.0
message: str = "Waiting in queue..."
stems: dict[str, str] | None = None
error: str | None = None
# Shared state
# One progress record per job id: written by enqueue_job and worker_loop,
# read through get_job_progress.
# NOTE(review): nothing in this module removes entries, so this dict grows for
# the life of the process — confirm whether cleanup happens elsewhere.
jobs: dict[str, JobProgress] = dict()
# Created on first use by get_queue(); intentionally None at import time.
_queue: asyncio.Queue | None = None
def get_queue() -> asyncio.Queue:
    """Return the process-wide job queue, creating it the first time it is requested.

    The queue is bounded to 5 pending items; enqueue_job refuses new jobs once
    it is full.

    NOTE(review): creation is presumably deferred so the queue is built while an
    event loop exists (on Python < 3.10 asyncio.Queue binds to the current loop
    at construction) — confirm.
    """
    global _queue
    if _queue is not None:
        return _queue
    _queue = asyncio.Queue(maxsize=5)
    return _queue
def get_job_progress(job_id: str) -> JobProgress | None:
    """Look up the progress record for ``job_id``, or ``None`` if the id is unknown."""
    try:
        return jobs[job_id]
    except KeyError:
        return None
async def enqueue_job(job_id: str, stems: list[str], output_format: str) -> bool:
    """Enqueue a separation job. Returns False if queue is full.

    Args:
        job_id: Identifier of the job; its input file must already be stored
            where ``file_manager.get_input_file`` will find it.
        stems: Stem names to request from the inference Space.
        output_format: Output audio format forwarded to the Space.

    Returns:
        True if the job was queued (a fresh ``JobProgress`` is registered in
        ``jobs``), False if the queue was full (nothing is registered).
    """
    # EAFP: put_nowait either accepts the item or raises QueueFull atomically,
    # instead of a separate full() check followed by a put that must not suspend.
    try:
        get_queue().put_nowait((job_id, stems, output_format))
    except asyncio.QueueFull:
        return False
    # No await separates the put from this assignment, so the worker task cannot
    # take the job before its progress entry exists.
    jobs[job_id] = JobProgress()
    return True
def _copy_stem_files(response, output_dir: str) -> dict[str, str]:
    """Copy the stem files returned by the inference Space into ``output_dir``.

    ``response`` is paired positionally with ``STEMS``. Each entry is a path to a
    downloaded temp file or ``None``; ``None`` entries are skipped.

    Returns:
        Mapping of stem name -> basename of the copied file inside ``output_dir``.
    """
    result: dict[str, str] = {}
    for stem_name, file_path in zip(STEMS, response):
        if file_path is None:
            continue
        filename = os.path.basename(file_path)
        # NOTE(review): two stems with the same basename would overwrite each other
        # here; presumably the Space names each output distinctly — confirm.
        shutil.copy2(file_path, os.path.join(output_dir, filename))
        result[stem_name] = filename
    return result


async def worker_loop():
    """Single worker that processes separation jobs sequentially via the inference Space.

    Runs forever, taking ``(job_id, stems, output_format)`` items off the shared
    queue. For each job it:
      1. sends the job's input file to the Space's ``/separate`` endpoint,
      2. copies the returned stem files into the job's output directory, and
      3. updates the job's ``JobProgress`` in ``jobs``.

    Any ``Exception`` raised while handling a job is recorded on that job
    (state ``"error"``), and the worker moves on. ``q.task_done()`` is called for
    every item taken from the queue.

    All blocking work — connecting to the Space, the remote prediction and the
    file copies — runs in the default thread-pool executor so the event loop
    stays responsive.
    """
    loop = asyncio.get_running_loop()
    q = get_queue()
    # Connected lazily on the first job, then reused for every later job.
    client = None
    while True:
        job_id, stems, output_format = await q.get()
        try:
            progress = jobs.get(job_id)
            if progress is None:
                progress = JobProgress()
                jobs[job_id] = progress
            input_file = file_manager.get_input_file(job_id)
            if input_file is None:
                progress.state = "error"
                progress.error = "Input file not found"
                progress.message = "Error: input file not found"
                continue  # the finally clause still calls task_done()
            output_dir = str(file_manager.get_output_dir(job_id))
            progress.state = "separating"
            progress.progress = 0.2
            progress.message = "Separating stems..."
            if client is None:
                # The Client constructor contacts the Space over the network. Doing it
                # here, in a thread and inside the try, means an unreachable or sleeping
                # Space fails only this job (the next job retries the connection).
                # Previously it blocked the event loop and could kill the worker before
                # it processed anything.
                client = await loop.run_in_executor(None, Client, INFERENCE_SPACE_ID)
            # Call the remote inference Space in a thread to avoid blocking.
            response = await loop.run_in_executor(
                None,
                lambda: client.predict(
                    audio_file=handle_file(str(input_file)),
                    stems=stems,
                    output_format=output_format,
                    api_name="/separate",
                ),
            )
            # response holds one entry per stem in STEMS order (a filepath or None).
            # Copying large audio files is blocking disk I/O, so it also runs in a thread.
            result = await loop.run_in_executor(
                None, _copy_stem_files, response, output_dir
            )
            progress.state = "done"
            progress.progress = 1.0
            progress.message = "Separation complete!"
            progress.stems = result
        except Exception as e:
            progress = jobs.get(job_id, JobProgress())
            # Some exceptions stringify to "", which would leave an empty error message.
            detail = str(e) or type(e).__name__
            progress.state = "error"
            progress.error = detail
            progress.message = f"Error: {detail}"
            jobs[job_id] = progress
        finally:
            q.task_done()
|