Spaces:
Running
Running
File size: 4,810 Bytes
0dfbd72 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 | """
Serial generation queue — ensures only one TTS inference runs at a time
to avoid GPU contention.
"""
import asyncio
import traceback
from dataclasses import dataclass
from typing import Coroutine, Literal
# Strong references to fire-and-forget background tasks. The event loop only
# keeps weak references to tasks, so without this set a task that nobody else
# holds could be garbage-collected before it finishes.
_background_tasks: set[asyncio.Task] = set()
@dataclass
class GenerationJob:
    """One unit of work sitting in the generation queue.

    Pairs a not-yet-started generation coroutine with the generation ID that
    is used to track, cancel and clean it up.
    """

    # ID used as the key in the queued / running / cancelled registries.
    generation_id: str
    # The generation coroutine; the worker either wraps it in a task or closes it.
    coro: Coroutine
# Generation queue — serializes TTS inference to avoid GPU contention
_generation_queue: asyncio.Queue = None # type: ignore # initialized at startup
_generation_worker_task: asyncio.Task | None = None
_queued_generation_ids: set[str] = set()
_running_generation_tasks: dict[str, asyncio.Task] = {}
_cancelled_generation_ids: set[str] = set()
def create_background_task(coro) -> asyncio.Task:
    """Schedule *coro* as a task and hold a strong reference until it is done.

    The event loop keeps only weak references to tasks, so a fire-and-forget
    task with no other referrer could otherwise be collected mid-flight.
    """
    background = asyncio.create_task(coro)
    _background_tasks.add(background)
    # Release our reference once the task completes so the set stays bounded.
    background.add_done_callback(_background_tasks.discard)
    return background
async def _generation_worker():
    """Worker that processes generation tasks one at a time.

    Pulls jobs off the generation queue, skips (and closes) jobs that were
    cancelled while still queued, and runs every other job to completion
    before taking the next one. After each job, ``_force_fail_if_active`` is
    awaited so a generation left in an active status is flipped to failed.

    The queue is captured once when the worker starts. ``init_queue(force=True)``
    replaces the module-level queue and cancels this worker, and a cancelled
    worker must neither consume from nor call ``task_done()`` on the
    replacement queue.
    """
    queue = _generation_queue
    while True:
        job = await queue.get()
        try:
            if job.generation_id in _cancelled_generation_ids:
                # Cancelled while still queued: never start it, just release it.
                _cancelled_generation_ids.discard(job.generation_id)
                job.coro.close()
                continue
            task = asyncio.create_task(job.coro)
            _running_generation_tasks[job.generation_id] = task
            _queued_generation_ids.discard(job.generation_id)
            try:
                # Unlike ``await task``, asyncio.wait() neither re-raises the
                # job's exception nor forwards a cancellation of this worker
                # into the job. That lets us tell "job cancelled via
                # cancel_generation()" apart from "worker being cancelled".
                await asyncio.wait({task})
            except asyncio.CancelledError:
                # The worker itself is being cancelled (shutdown or
                # init_queue(force=True)): take the running job down too,
                # then stop instead of looping back to the queue.
                task.cancel()
                raise
            if not task.cancelled():  # cancelled => user-requested, tolerated
                try:
                    task.result()
                except Exception:
                    traceback.print_exc()
            await _force_fail_if_active(
                job.generation_id,
                "Worker exited without writing terminal status",
            )
        finally:
            _running_generation_tasks.pop(job.generation_id, None)
            _queued_generation_ids.discard(job.generation_id)
            queue.task_done()
async def _force_fail_if_active(generation_id: str, error: str) -> None:
    """Mark a generation as failed if it is still in an active status.

    Best-effort safety net for when a generation coroutine finished without
    writing a terminal status, e.g. because its own status write raised (such
    as on SQLite lock contention). A missing row, or a row whose status is
    anything other than "loading_model" / "generating", is left untouched; a
    NULL status is treated as completed. Any exception is printed, never raised.
    """
    try:
        from ..database import Generation as DBGeneration, get_db
        from . import history

        session = next(get_db())
        try:
            row = session.query(DBGeneration).filter_by(id=generation_id).first()
            still_active = row is not None and (row.status or "completed") in (
                "loading_model",
                "generating",
            )
            if still_active:
                await history.update_generation_status(
                    generation_id=generation_id,
                    status="failed",
                    db=session,
                    error=error,
                )
        finally:
            session.close()
    except Exception:
        traceback.print_exc()
def enqueue_generation(generation_id: str, coro):
    """Append a generation coroutine to the serial queue.

    The ID is recorded as queued before the job is put on the queue, so
    ``cancel_generation`` can find it immediately.

    Raises:
        RuntimeError: if ``init_queue()`` has not been called yet.
    """
    queue = _generation_queue
    if queue is None:
        raise RuntimeError("Generation queue has not been initialized")
    _queued_generation_ids.add(generation_id)
    job = GenerationJob(generation_id=generation_id, coro=coro)
    queue.put_nowait(job)
def cancel_generation(generation_id: str) -> Literal["queued", "running"] | None:
"""Cancel a queued or running generation if it is still active."""
running_task = _running_generation_tasks.get(generation_id)
if running_task is not None:
running_task.cancel()
return "running"
if generation_id in _queued_generation_ids:
_queued_generation_ids.discard(generation_id)
_cancelled_generation_ids.add(generation_id)
return "queued"
return None
def init_queue(force: bool = False):
    """Initialize the generation queue and start the worker.

    Must be called once during application startup (inside a running event loop).

    Args:
        force: If a worker is already alive, it is normally left alone and
            this call is a no-op. With ``force=True`` the live worker and every
            running generation task are cancelled and fresh state is created.

    Jobs still waiting in a queue that is being replaced can never be picked
    up by the new worker. Their coroutines are closed so they are not leaked
    as "never awaited".
    """
    global _generation_queue, _generation_worker_task
    global _queued_generation_ids, _running_generation_tasks, _cancelled_generation_ids
    if _generation_worker_task is not None and not _generation_worker_task.done():
        if not force:
            return
        _generation_worker_task.cancel()
        for task in list(_running_generation_tasks.values()):
            task.cancel()
    old_queue = _generation_queue
    if old_queue is not None:
        while not old_queue.empty():
            old_queue.get_nowait().coro.close()
            old_queue.task_done()
    _generation_queue = asyncio.Queue()
    _queued_generation_ids = set()
    _running_generation_tasks = {}
    _cancelled_generation_ids = set()
    _generation_worker_task = create_background_task(_generation_worker())
|