neuapi / api /src /routers /model_management.py
grimshaw's picture
Upload folder using huggingface_hub
35bb6f4 verified
Raw
History Blame Contribute Delete
5.69 kB
from __future__ import annotations
import time
from fastapi import APIRouter, HTTPException
from loguru import logger
from api.src.core.model_config import BACKBONE_MODELS, CODEC_MODELS, get_backbone_info
from api.src.inference.model_manager import ModelLoadingTask, ModelManager
from api.src.structures.schemas import (
LoadedModelInfo,
LoadedModelsResponse,
LoadModelRequest,
ModelLoadTaskResponse,
ModelRegistryResponse,
RegistryModelInfo,
SwitchDeviceRequest,
UnloadModelResponse,
make_error,
)
# Shared router for this module: every endpoint below is mounted under the
# /v1/models prefix and tagged "Model Management".
router = APIRouter(
    tags=["Model Management"],
    prefix="/v1/models",
)
def _task_to_response(task: ModelLoadingTask) -> ModelLoadTaskResponse:
    """Build the public response model for a model-loading task.

    ``elapsed_seconds`` is 0.0 until ``task.started_at`` is positive. Once the
    task has started, the elapsed time runs up to ``task.completed_at`` when that
    is positive, otherwise up to ``time.time()``, rounded to 2 decimals.
    """
    # NOTE(review): subtracting from time.time() assumes started_at/completed_at
    # are epoch-second timestamps from the same clock — confirm in ModelLoadingTask.
    started = task.started_at
    elapsed_seconds = 0.0
    if started > 0:
        finished = task.completed_at
        stop = finished if finished > 0 else time.time()
        elapsed_seconds = round(stop - started, 2)

    return ModelLoadTaskResponse(
        task_id=task.task_id,
        model_id=task.model_id,
        status=task.status.value,
        progress_message=task.progress_message,
        error_message=task.error_message,
        elapsed_seconds=elapsed_seconds,
    )
@router.post("/load", response_model=ModelLoadTaskResponse)
async def load_model(request: LoadModelRequest) -> ModelLoadTaskResponse:
    """Start loading a model (non-blocking). Returns a task for polling.

    Raises:
        HTTPException: 400 if ``request.model_id`` is not in the backbone
            registry; 500 if the model manager fails to start the load.
    """
    model_manager = ModelManager.get_instance()

    if get_backbone_info(request.model_id) is None:
        raise HTTPException(
            status_code=400,
            detail=make_error(
                f"Unknown model '{request.model_id}'. "
                f"Available: {list(BACKBONE_MODELS.keys())}"
            ),
        )

    # Keep the try around the one call that starts the load.
    try:
        task = await model_manager.load_model_async(
            model_id=request.model_id,
            codec_id=request.codec,
            backbone_device=request.backbone_device,
            codec_device=request.codec_device,
        )
    except Exception as e:
        # logger.exception records the traceback; logger.error did not.
        logger.exception(f"Failed to start loading model {request.model_id}: {e}")
        raise HTTPException(
            status_code=500, detail=make_error(str(e), "server_error", 500)
        ) from e

    # Cleanup old tasks opportunistically. load_model_async has already returned
    # a task, so a cleanup failure is logged rather than turned into a 500. A 500
    # here would hide the task id from the client.
    try:
        model_manager.cleanup_old_tasks()
    except Exception as e:
        logger.warning(f"Old-task cleanup failed (ignored): {e}")

    return _task_to_response(task)
@router.get("/load/{task_id}", response_model=ModelLoadTaskResponse)
async def get_load_status(task_id: str) -> ModelLoadTaskResponse:
    """Return the current state of the loading task identified by ``task_id``.

    Raises:
        HTTPException: 404 when the model manager has no task with this id.
    """
    found = ModelManager.get_instance().get_task(task_id)
    if found is not None:
        return _task_to_response(found)
    raise HTTPException(
        status_code=404,
        detail=make_error(f"Task '{task_id}' not found"),
    )
@router.get("/loaded", response_model=LoadedModelsResponse)
async def get_loaded_models() -> LoadedModelsResponse:
    """Describe every model currently held by the model manager.

    Each entry combines runtime placement (codec id, backbone/codec devices)
    with backbone-registry metadata. A loaded model with no registry entry is
    still listed: ``language`` and ``backend`` are None and
    ``supports_streaming`` is False.
    """
    manager = ModelManager.get_instance()

    def _describe(model_id, loaded):
        # Registry metadata is optional, so fall back to neutral values.
        meta = get_backbone_info(model_id)
        if meta:
            language = meta.language
            backend = meta.backend.value
            streaming = meta.supports_streaming
        else:
            language, backend, streaming = None, None, False
        return LoadedModelInfo(
            model_id=model_id,
            codec=loaded.codec_id,
            backbone_device=loaded.backbone_device,
            codec_device=loaded.codec_device,
            language=language,
            backend=backend,
            supports_streaming=streaming,
        )

    return LoadedModelsResponse(
        models=[_describe(mid, entry) for mid, entry in manager.loaded_models.items()]
    )
@router.post("/{model_id}/switch-device", response_model=ModelLoadTaskResponse)
async def switch_device(model_id: str, request: SwitchDeviceRequest) -> ModelLoadTaskResponse:
    """Switch a loaded model to a different device (CPU <-> GPU).

    Returns the task produced by the model manager so the caller can poll it.

    Raises:
        HTTPException: 400 when the model manager rejects the request with a
            ValueError; 500 for any other failure while starting the switch.
    """
    model_manager = ModelManager.get_instance()
    try:
        task = await model_manager.switch_device(
            model_id=model_id,
            backbone_device=request.backbone_device,
            codec_device=request.codec_device,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=make_error(str(e))) from e
    except Exception as e:
        logger.exception(f"Failed to switch device for {model_id}: {e}")
        raise HTTPException(
            status_code=500, detail=make_error(str(e), "server_error", 500)
        ) from e
    # Build the response outside the try. The schema models are presumably
    # pydantic (confirm), and pydantic's ValidationError subclasses ValueError,
    # so a response-construction bug would otherwise be reported to the client
    # as a 400.
    return _task_to_response(task)
@router.delete("/{model_id}", response_model=UnloadModelResponse)
async def unload_model(model_id: str) -> UnloadModelResponse:
    """Unload a model from memory.

    Raises:
        HTTPException: 400 when the model manager rejects the request with a
            ValueError; 500 for any other failure, logged and shaped with
            ``make_error`` like the other endpoints in this router.
    """
    model_manager = ModelManager.get_instance()
    try:
        await model_manager.unload_model(model_id)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=make_error(str(e))) from e
    except Exception as e:
        # Previously these escaped unlogged, without the structured error body.
        logger.exception(f"Failed to unload model {model_id}: {e}")
        raise HTTPException(
            status_code=500, detail=make_error(str(e), "server_error", 500)
        ) from e
    return UnloadModelResponse(model_id=model_id, status="unloaded")
@router.get("/registry", response_model=ModelRegistryResponse)
async def get_registry() -> ModelRegistryResponse:
    """List every backbone and codec in the registry, loaded or not.

    Each backbone carries a ``loaded`` flag from the model manager. Codecs are
    returned as plain dicts with ``codec_id``, ``repo``, ``type`` and
    ``description`` keys.
    """
    manager = ModelManager.get_instance()

    backbone_entries = []
    for spec in BACKBONE_MODELS.values():
        backbone_entries.append(
            RegistryModelInfo(
                model_id=spec.model_id,
                repo=spec.repo,
                language=spec.language,
                backend=spec.backend.value,
                supports_streaming=spec.supports_streaming,
                description=spec.description,
                loaded=manager.is_loaded(spec.model_id),
            )
        )

    codec_entries = [
        dict(
            codec_id=codec.codec_id,
            repo=codec.repo,
            type=codec.codec_type,
            description=codec.description,
        )
        for codec in CODEC_MODELS.values()
    ]

    return ModelRegistryResponse(backbones=backbone_entries, codecs=codec_entries)