# neuapi/api/src/inference/model_manager.py
# Uploaded with huggingface_hub by grimshaw (commit 35bb6f4, verified; 13.3 kB).
from __future__ import annotations

import asyncio
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, AsyncGenerator, Callable, Coroutine, Iterable

import numpy as np
from loguru import logger

from api.src.core.config import settings
from api.src.core.model_config import (
    BACKBONE_MODELS,
    BackendType,
    get_backbone_info,
)
class ModelLoadStatus(str, Enum):
    """Lifecycle state of a ModelLoadingTask.

    Members subclass ``str``, so each compares equal to its lowercase value
    (e.g. ``ModelLoadStatus.READY == "ready"``).
    """

    PENDING = "pending"  # task created and queued; background load not started yet
    DOWNLOADING = "downloading"  # load started; weights may be downloading or read from cache
    LOADING = "loading"  # load still running after a fixed delay (heuristic, not a real signal)
    READY = "ready"  # model is loaded and registered; treated as completed
    ERROR = "error"  # load failed (see ModelLoadingTask.error_message); treated as completed
@dataclass
class ModelLoadingTask:
    """Mutable progress record for one model-load request, returned to callers for polling."""

    task_id: str  # random uuid4 string identifying this request
    model_id: str  # backbone id being loaded
    status: ModelLoadStatus = ModelLoadStatus.PENDING
    progress_message: str = ""  # human-readable progress text
    error_message: str = ""  # str() of the failure when status is ERROR
    started_at: float = 0.0  # time.time() when the task was created; 0.0 if never set
    completed_at: float = 0.0  # time.time() when READY/ERROR was reached; 0.0 while running
@dataclass
class LoadedModel:
    """A TTS instance registered with the manager, together with the settings it was built with."""

    model_id: str
    codec_id: str  # codec repo the instance was created with
    tts_instance: object  # NeuTTS instance
    # Serializes calls into tts_instance; each LoadedModel gets its own lock.
    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
    backbone_device: str = "cpu"
    codec_device: str = "cpu"
class ModelManager:
    """Registry of loaded NeuTTS models plus the background tasks that load them.

    Blocking work (constructing a ``NeuTTS`` instance, inference, reference
    encoding) runs on a shared ``ThreadPoolExecutor`` so the event loop stays
    responsive. Every call into a loaded instance goes through that model's
    ``asyncio.Lock``, and the lock is held until the executor job has really
    finished, so at most one job touches a given instance at a time.
    """

    _instance: ModelManager | None = None

    def __init__(self) -> None:
        self._models: dict[str, LoadedModel] = {}
        self._loading_tasks: dict[str, ModelLoadingTask] = {}
        self._executor = ThreadPoolExecutor(max_workers=settings.max_inference_workers)
        # The event loop keeps only weak references to tasks, so fire-and-forget
        # background loads must be referenced here until they finish.
        self._background_tasks: set[asyncio.Future[None]] = set()

    @classmethod
    def get_instance(cls) -> ModelManager:
        """Return the process-wide manager, creating it on first use."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @property
    def loaded_models(self) -> dict[str, LoadedModel]:
        """Live mapping of model id -> loaded model (not a copy)."""
        return self._models

    @property
    def loading_tasks(self) -> dict[str, ModelLoadingTask]:
        """Live mapping of task id -> loading task (not a copy)."""
        return self._loading_tasks

    def is_loaded(self, model_id: str) -> bool:
        """Return True if ``model_id`` is currently registered as loaded."""
        return model_id in self._models

    def get_task(self, task_id: str) -> ModelLoadingTask | None:
        """Return the loading task with ``task_id``, or None if unknown or cleaned up."""
        return self._loading_tasks.get(task_id)

    # ------------------------------------------------------------------ loading

    async def load_model_async(
        self,
        model_id: str,
        codec_id: str | None = None,
        backbone_device: str | None = None,
        codec_device: str | None = None,
    ) -> ModelLoadingTask:
        """Start loading a model in the background. Returns a task for polling.

        If the model is already loaded, a new task already in the READY state is
        returned. If a load for ``model_id`` is pending or in progress, that existing
        task is returned and the codec/device arguments of this call are ignored.

        Raises:
            ValueError: if ``model_id`` is not a known backbone.
        """
        # Already loaded -> return READY task immediately
        if model_id in self._models:
            now = time.time()
            task = ModelLoadingTask(
                task_id=str(uuid.uuid4()),
                model_id=model_id,
                status=ModelLoadStatus.READY,
                progress_message="Already loaded",
                started_at=now,
                completed_at=now,
            )
            self._loading_tasks[task.task_id] = task
            return task

        # Already loading -> return existing task
        in_flight = (
            ModelLoadStatus.PENDING,
            ModelLoadStatus.DOWNLOADING,
            ModelLoadStatus.LOADING,
        )
        for existing in self._loading_tasks.values():
            if existing.model_id == model_id and existing.status in in_flight:
                return existing

        info = get_backbone_info(model_id)
        if info is None:
            raise ValueError(f"Unknown model: {model_id}. Available: {list(BACKBONE_MODELS.keys())}")

        task = ModelLoadingTask(
            task_id=str(uuid.uuid4()),
            model_id=model_id,
            status=ModelLoadStatus.PENDING,
            progress_message="Queued",
            started_at=time.time(),
        )
        self._loading_tasks[task.task_id] = task
        self._spawn_background(
            self._background_load(task, codec_id, backbone_device, codec_device)
        )
        return task

    def _spawn_background(self, coro: Coroutine[Any, Any, None]) -> asyncio.Future[None]:
        """Schedule ``coro`` on the running loop and keep a strong reference until it is done."""
        future = asyncio.ensure_future(coro)
        self._background_tasks.add(future)
        future.add_done_callback(self._background_tasks.discard)
        return future

    async def _background_load(
        self,
        task: ModelLoadingTask,
        codec_id: str | None,
        backbone_device: str | None,
        codec_device: str | None,
    ) -> None:
        """Background coroutine that loads a model and updates task status.

        Any ``Exception`` is recorded on ``task`` (status ERROR) and logged rather
        than propagated.
        """
        timer_task: asyncio.Future[None] | None = None
        try:
            task.status = ModelLoadStatus.DOWNLOADING
            task.progress_message = "Downloading / checking cache..."
            info = get_backbone_info(task.model_id)
            if info is None:
                raise ValueError(f"Unknown model: {task.model_id}")
            codec, bb_device, cc_device = self._resolve_load_config(
                info, codec_id, backbone_device, codec_device
            )
            logger.info(
                f"[Task {task.task_id[:8]}] Loading {task.model_id} "
                f"(backbone_device={bb_device}, codec_device={cc_device})"
            )

            # Schedule status transition after 3s (heuristic for download vs load):
            # the loader gives no progress signal, so this only approximates it.
            async def _mark_loading() -> None:
                await asyncio.sleep(3)
                if task.status == ModelLoadStatus.DOWNLOADING:
                    task.status = ModelLoadStatus.LOADING
                    task.progress_message = "Initializing model..."

            timer_task = asyncio.ensure_future(_mark_loading())
            self._models[task.model_id] = await self._build_loaded_model(
                task.model_id, info.repo, codec, bb_device, cc_device
            )
            task.status = ModelLoadStatus.READY
            task.progress_message = "Model ready"
            task.completed_at = time.time()
            logger.info(f"[Task {task.task_id[:8]}] {task.model_id} loaded successfully")
        except Exception as e:
            task.status = ModelLoadStatus.ERROR
            task.error_message = str(e)
            task.progress_message = "Failed"
            task.completed_at = time.time()
            logger.error(f"[Task {task.task_id[:8]}] Failed to load {task.model_id}: {e}")
        finally:
            # Stop the heuristic timer on every exit path, not only on success.
            if timer_task is not None:
                timer_task.cancel()

    async def load_model(
        self,
        model_id: str,
        codec_id: str | None = None,
        backbone_device: str | None = None,
        codec_device: str | None = None,
    ) -> LoadedModel:
        """Load ``model_id`` and return it once construction has finished. Used by startup.

        Unlike ``load_model_async`` this awaits the load and propagates failures.

        Raises:
            ValueError: if ``model_id`` is not a known backbone.
        """
        if model_id in self._models:
            logger.info(f"Model {model_id} already loaded")
            return self._models[model_id]

        info = get_backbone_info(model_id)
        if info is None:
            raise ValueError(f"Unknown model: {model_id}. Available: {list(BACKBONE_MODELS.keys())}")

        codec, bb_device, cc_device = self._resolve_load_config(
            info, codec_id, backbone_device, codec_device
        )
        logger.info(
            f"Loading model {model_id} (repo={info.repo}, codec={codec}, "
            f"backbone_device={bb_device}, codec_device={cc_device})"
        )
        loaded = await self._build_loaded_model(model_id, info.repo, codec, bb_device, cc_device)
        self._models[model_id] = loaded
        logger.info(f"Model {model_id} loaded successfully")
        return loaded

    @staticmethod
    def _resolve_load_config(
        info: Any,
        codec_id: str | None,
        backbone_device: str | None,
        codec_device: str | None,
    ) -> tuple[str, str, str]:
        """Fill unset codec/device arguments from ``settings`` and apply backend limits.

        Returns ``(codec, backbone_device, codec_device)``.
        """
        codec = codec_id or settings.default_codec
        bb_device = backbone_device or settings.resolved_backbone_device
        cc_device = codec_device or settings.default_codec_device
        # GGUF models only support CPU (llama.cpp limitation)
        if info.backend == BackendType.GGUF:
            bb_device = "cpu"
        return codec, bb_device, cc_device

    async def _build_loaded_model(
        self,
        model_id: str,
        repo: str,
        codec: str,
        bb_device: str,
        cc_device: str,
    ) -> LoadedModel:
        """Construct the TTS instance on the executor and wrap it (does not register it)."""
        loop = asyncio.get_running_loop()
        tts = await loop.run_in_executor(
            self._executor,
            self._create_tts_instance,
            repo,
            codec,
            bb_device,
            cc_device,
        )
        return LoadedModel(
            model_id=model_id,
            codec_id=codec,
            tts_instance=tts,
            backbone_device=bb_device,
            codec_device=cc_device,
        )

    @staticmethod
    def _create_tts_instance(
        backbone_repo: str,
        codec_repo: str,
        backbone_device: str,
        codec_device: str,
    ) -> object:
        """Blocking constructor call; runs on the executor thread."""
        # Imported lazily so importing this module does not pull in the TTS stack.
        from neutts import NeuTTS

        return NeuTTS(
            backbone_repo=backbone_repo,
            backbone_device=backbone_device,
            codec_repo=codec_repo,
            codec_device=codec_device,
        )

    async def unload_model(self, model_id: str) -> None:
        """Unregister ``model_id`` and drop its TTS instance.

        The registry entry is removed first so new calls fail fast; the instance is
        deleted only after acquiring the model lock, i.e. after any job currently
        holding it has finished.

        Raises:
            ValueError: if ``model_id`` is not loaded.
        """
        if model_id not in self._models:
            raise ValueError(f"Model {model_id} is not loaded")
        loaded = self._models.pop(model_id)
        async with loaded.lock:
            del loaded.tts_instance
        logger.info(f"Model {model_id} unloaded")

    async def switch_device(
        self,
        model_id: str,
        backbone_device: str | None = None,
        codec_device: str | None = None,
    ) -> ModelLoadingTask:
        """Unload model and reload on a different device.

        Unset devices keep the model's current ones. The reload runs in the
        background; if it fails, the model stays unloaded and the returned task
        reports ERROR.

        Raises:
            ValueError: if the model is not loaded or is a GGUF (CPU-only) model.
        """
        if model_id not in self._models:
            raise ValueError(f"Model {model_id} is not loaded")
        loaded = self._models[model_id]
        info = get_backbone_info(model_id)
        if info and info.backend == BackendType.GGUF:
            raise ValueError(
                f"Model {model_id} is GGUF (llama.cpp) and only supports CPU. "
                "Device switching is not available for GGUF models."
            )
        codec_id = loaded.codec_id
        bb_device = backbone_device or loaded.backbone_device
        cc_device = codec_device or loaded.codec_device
        logger.info(f"Switching {model_id} device to backbone={bb_device}, codec={cc_device}")
        await self.unload_model(model_id)
        return await self.load_model_async(
            model_id=model_id,
            codec_id=codec_id,
            backbone_device=bb_device,
            codec_device=cc_device,
        )

    def cleanup_old_tasks(self, max_age_seconds: float = 3600) -> int:
        """Remove completed/errored tasks older than max_age_seconds. Returns how many were removed."""
        now = time.time()
        to_remove = [
            tid
            for tid, t in self._loading_tasks.items()
            if t.status in (ModelLoadStatus.READY, ModelLoadStatus.ERROR)
            and t.completed_at > 0
            and (now - t.completed_at) > max_age_seconds
        ]
        for tid in to_remove:
            del self._loading_tasks[tid]
        return len(to_remove)

    # ---------------------------------------------------------------- inference

    async def infer(
        self,
        model_id: str,
        text: str,
        ref_codes: object,
        ref_text: str,
    ) -> np.ndarray:
        """Run ``tts_instance.infer(text, ref_codes, ref_text)`` under the model lock.

        Raises:
            ValueError: if the model is not loaded (or was unloaded while waiting).
        """
        loaded = self._get_loaded(model_id)
        return await self._run_exclusive(
            loaded, lambda tts: tts.infer(text, ref_codes, ref_text)
        )

    async def infer_stream(
        self,
        model_id: str,
        text: str,
        ref_codes: object,
        ref_text: str,
    ) -> AsyncGenerator[np.ndarray, None]:
        """Yield chunks from ``tts_instance.infer_stream`` as they are produced.

        The model lock is held for the whole stream, including after the consumer
        stops early, until the worker thread exits. Errors raised by the underlying
        stream are logged and end the stream rather than propagating.

        Raises:
            ValueError: if the model is not loaded or does not support streaming.
        """
        loaded = self._get_loaded(model_id)
        info = get_backbone_info(model_id)
        if info is None or not info.supports_streaming:
            raise ValueError(
                f"Model {model_id} does not support streaming. "
                "Only GGUF models support infer_stream()."
            )
        stream = self._stream_exclusive(
            loaded, lambda tts: tts.infer_stream(text, ref_codes, ref_text)
        )
        try:
            async for chunk in stream:
                yield chunk
        finally:
            # Close the inner generator deterministically so the lock is released now,
            # not whenever the abandoned generator happens to be finalized.
            await stream.aclose()

    async def encode_reference(self, model_id: str, audio_path: str) -> object:
        """Run ``tts_instance.encode_reference(audio_path)`` under the model lock.

        Raises:
            ValueError: if the model is not loaded (or was unloaded while waiting).
        """
        loaded = self._get_loaded(model_id)
        return await self._run_exclusive(
            loaded, lambda tts: tts.encode_reference(audio_path)
        )

    def _get_loaded(self, model_id: str) -> LoadedModel:
        """Return the registered model or raise ValueError listing what is loaded."""
        loaded = self._models.get(model_id)
        if loaded is None:
            raise ValueError(
                f"Model {model_id} is not loaded. "
                f"Loaded models: {list(self._models.keys())}"
            )
        return loaded

    def _ensure_registered(self, loaded: LoadedModel) -> None:
        """Raise ValueError if ``loaded`` is no longer the registered instance.

        ``unload_model`` removes the registry entry before taking the lock, so a
        caller that fetched ``loaded`` earlier can acquire the lock after its
        ``tts_instance`` has been deleted.
        """
        if self._models.get(loaded.model_id) is not loaded:
            raise ValueError(
                f"Model {loaded.model_id} is not loaded. "
                f"Loaded models: {list(self._models.keys())}"
            )

    async def _run_exclusive(
        self,
        loaded: LoadedModel,
        call: Callable[[Any], Any],
    ) -> Any:
        """Run ``call(loaded.tts_instance)`` on the executor while holding the model lock.

        A worker thread cannot be interrupted, so if the awaiting task is cancelled
        the lock is kept until the thread actually finishes; otherwise another call
        could enter the same instance concurrently.
        """
        async with loaded.lock:
            self._ensure_registered(loaded)
            loop = asyncio.get_running_loop()
            job = loop.run_in_executor(self._executor, call, loaded.tts_instance)
            try:
                return await asyncio.shield(job)
            except asyncio.CancelledError:
                await asyncio.wait({job})
                raise

    async def _stream_exclusive(
        self,
        loaded: LoadedModel,
        make_stream: Callable[[Any], Iterable[Any]],
    ) -> AsyncGenerator[Any, None]:
        """Iterate ``make_stream(loaded.tts_instance)`` on the executor and yield items on the loop.

        The model lock is held until the worker thread has exited. If the consumer
        stops early, a ``threading.Event`` asks the worker to stop pulling items.
        """
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue[Any] = asyncio.Queue()
        stop = threading.Event()
        end_of_stream = object()

        def _worker(tts: Any) -> None:
            try:
                for item in make_stream(tts):
                    if stop.is_set():
                        break
                    # asyncio.Queue is not thread-safe: hand each item to the loop thread.
                    loop.call_soon_threadsafe(queue.put_nowait, item)
            except Exception as e:
                logger.error(f"Streaming error for {loaded.model_id}: {e}")
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, end_of_stream)

        async with loaded.lock:
            self._ensure_registered(loaded)
            worker = loop.run_in_executor(self._executor, _worker, loaded.tts_instance)
            try:
                while True:
                    item = await queue.get()
                    if item is end_of_stream:
                        break
                    yield item
            finally:
                stop.set()
                # Keep the lock until the worker thread has really exited.
                await asyncio.wait({worker})

    # ---------------------------------------------------------------- lifecycle

    async def startup(self) -> None:
        """Load every model in ``settings.default_models_list``; failures are logged, not raised."""
        for model_id in settings.default_models_list:
            try:
                await self.load_model(model_id)
            except Exception as e:
                logger.error(f"Failed to load default model {model_id}: {e}")

    async def shutdown(self) -> None:
        """Cancel pending background loads, unload every model, and stop the executor."""
        for pending in list(self._background_tasks):
            pending.cancel()
        for model_id in list(self._models.keys()):
            try:
                await self.unload_model(model_id)
            except Exception as e:
                # Best effort: one failure must not block unloading the rest.
                logger.warning(f"Failed to unload {model_id} during shutdown: {e}")
        self._executor.shutdown(wait=False)