| from __future__ import annotations |
|
|
| import time |
|
|
| from fastapi import APIRouter, HTTPException |
| from loguru import logger |
|
|
| from api.src.core.model_config import BACKBONE_MODELS, CODEC_MODELS, get_backbone_info |
| from api.src.inference.model_manager import ModelLoadingTask, ModelManager |
| from api.src.structures.schemas import ( |
| LoadedModelInfo, |
| LoadedModelsResponse, |
| LoadModelRequest, |
| ModelLoadTaskResponse, |
| ModelRegistryResponse, |
| RegistryModelInfo, |
| SwitchDeviceRequest, |
| UnloadModelResponse, |
| make_error, |
| ) |
|
|
| router = APIRouter(prefix="/v1/models", tags=["Model Management"]) |
|
|
|
|
| def _task_to_response(task: ModelLoadingTask) -> ModelLoadTaskResponse: |
| elapsed = 0.0 |
| if task.started_at > 0: |
| end = task.completed_at if task.completed_at > 0 else time.time() |
| elapsed = round(end - task.started_at, 2) |
| return ModelLoadTaskResponse( |
| task_id=task.task_id, |
| model_id=task.model_id, |
| status=task.status.value, |
| progress_message=task.progress_message, |
| error_message=task.error_message, |
| elapsed_seconds=elapsed, |
| ) |
|
|
|
|
| @router.post("/load", response_model=ModelLoadTaskResponse) |
| async def load_model(request: LoadModelRequest) -> ModelLoadTaskResponse: |
| """Start loading a model (non-blocking). Returns a task for polling.""" |
| model_manager = ModelManager.get_instance() |
| info = get_backbone_info(request.model_id) |
|
|
| if info is None: |
| raise HTTPException( |
| status_code=400, |
| detail=make_error( |
| f"Unknown model '{request.model_id}'. " |
| f"Available: {list(BACKBONE_MODELS.keys())}" |
| ), |
| ) |
|
|
| try: |
| task = await model_manager.load_model_async( |
| model_id=request.model_id, |
| codec_id=request.codec, |
| backbone_device=request.backbone_device, |
| codec_device=request.codec_device, |
| ) |
| |
| model_manager.cleanup_old_tasks() |
| return _task_to_response(task) |
| except Exception as e: |
| logger.error(f"Failed to start loading model {request.model_id}: {e}") |
| raise HTTPException(status_code=500, detail=make_error(str(e), "server_error", 500)) |
|
|
|
|
| @router.get("/load/{task_id}", response_model=ModelLoadTaskResponse) |
| async def get_load_status(task_id: str) -> ModelLoadTaskResponse: |
| """Poll the status of a model loading task.""" |
| model_manager = ModelManager.get_instance() |
| task = model_manager.get_task(task_id) |
|
|
| if task is None: |
| raise HTTPException( |
| status_code=404, |
| detail=make_error(f"Task '{task_id}' not found"), |
| ) |
|
|
| return _task_to_response(task) |
|
|
|
|
| @router.get("/loaded", response_model=LoadedModelsResponse) |
| async def get_loaded_models() -> LoadedModelsResponse: |
| """List all loaded models with device details.""" |
| model_manager = ModelManager.get_instance() |
| models = [] |
|
|
| for model_id, loaded in model_manager.loaded_models.items(): |
| info = get_backbone_info(model_id) |
| models.append(LoadedModelInfo( |
| model_id=model_id, |
| codec=loaded.codec_id, |
| backbone_device=loaded.backbone_device, |
| codec_device=loaded.codec_device, |
| language=info.language if info else None, |
| backend=info.backend.value if info else None, |
| supports_streaming=info.supports_streaming if info else False, |
| )) |
|
|
| return LoadedModelsResponse(models=models) |
|
|
|
|
| @router.post("/{model_id}/switch-device", response_model=ModelLoadTaskResponse) |
| async def switch_device(model_id: str, request: SwitchDeviceRequest) -> ModelLoadTaskResponse: |
| """Switch a loaded model to a different device (CPU <-> GPU).""" |
| model_manager = ModelManager.get_instance() |
|
|
| try: |
| task = await model_manager.switch_device( |
| model_id=model_id, |
| backbone_device=request.backbone_device, |
| codec_device=request.codec_device, |
| ) |
| return _task_to_response(task) |
| except ValueError as e: |
| raise HTTPException(status_code=400, detail=make_error(str(e))) |
| except Exception as e: |
| logger.error(f"Failed to switch device for {model_id}: {e}") |
| raise HTTPException(status_code=500, detail=make_error(str(e), "server_error", 500)) |
|
|
|
|
| @router.delete("/{model_id}", response_model=UnloadModelResponse) |
| async def unload_model(model_id: str) -> UnloadModelResponse: |
| """Unload a model from memory.""" |
| model_manager = ModelManager.get_instance() |
|
|
| try: |
| await model_manager.unload_model(model_id) |
| except ValueError as e: |
| raise HTTPException(status_code=400, detail=make_error(str(e))) |
|
|
| return UnloadModelResponse(model_id=model_id, status="unloaded") |
|
|
|
|
| @router.get("/registry", response_model=ModelRegistryResponse) |
| async def get_registry() -> ModelRegistryResponse: |
| """List all available models (not just loaded ones).""" |
| model_manager = ModelManager.get_instance() |
|
|
| backbones = [ |
| RegistryModelInfo( |
| model_id=info.model_id, |
| repo=info.repo, |
| language=info.language, |
| backend=info.backend.value, |
| supports_streaming=info.supports_streaming, |
| description=info.description, |
| loaded=model_manager.is_loaded(info.model_id), |
| ) |
| for info in BACKBONE_MODELS.values() |
| ] |
|
|
| codecs = [ |
| { |
| "codec_id": c.codec_id, |
| "repo": c.repo, |
| "type": c.codec_type, |
| "description": c.description, |
| } |
| for c in CODEC_MODELS.values() |
| ] |
|
|
| return ModelRegistryResponse(backbones=backbones, codecs=codecs) |
|
|