"""Model management routes: inspect, list, and switch the active model."""

import time
from urllib.parse import unquote

from fastapi import APIRouter, HTTPException

from src.api.schemas import (
    ModelInfo,
    ModelStatusEntry,
    ModelsStatusResponse,
    SelectModelRequest,
)
from src.api.services import get_service
from src.api.state import PROJECT_ROOT, get_state
from src.service.model_service import (
    AVAILABLE_MODELS,
    ModelService,
    check_model_availability,
)

router = APIRouter(tags=["Model"])


def _switch_model_impl(model_name: str) -> dict[str, str]:
    """Validate, warm up, and activate ``model_name`` as the served model.

    Shared by both switch endpoints. The new ``ModelService`` is only stored
    in the application state after a warm-up prediction comes back without an
    ``"error"`` entry.

    Args:
        model_name: Key expected to be present in ``AVAILABLE_MODELS``.

    Returns:
        A dict with a human-readable ``"message"`` and the activated
        ``"model"`` name.

    Raises:
        HTTPException: Status 400 when the name is not in
            ``AVAILABLE_MODELS``, when ``check_model_availability`` reports it
            as unavailable, or when the warm-up prediction reports an error.
    """
    if model_name not in AVAILABLE_MODELS:
        raise HTTPException(
            status_code=400,
            detail=f"Model '{model_name}' not available. Options: {list(AVAILABLE_MODELS.keys())}",
        )

    is_available, why_not = check_model_availability(model_name, PROJECT_ROOT)
    if not is_available:
        raise HTTPException(status_code=400, detail=why_not or "Model unavailable")

    state = get_state()
    # Snapshot the current selection before building the replacement.
    previous = (state["service"], state["model_name"])

    candidate = ModelService(model_name, PROJECT_ROOT)
    warmup_error = candidate.predict("warmup").get("error")
    if warmup_error:
        # Re-assert the snapshot before reporting failure.
        # NOTE(review): nothing in this function has written to `state` yet;
        # presumably this guards against side effects of model construction
        # or warm-up -- confirm.
        state["service"], state["model_name"] = previous
        raise HTTPException(status_code=400, detail=str(warmup_error))

    state["service"], state["model_name"] = candidate, model_name
    return {"message": f"Active model set to '{model_name}'", "model": model_name}


@router.get("/model-info", response_model=ModelInfo)
async def get_model_info():
    """Return metadata and runtime counters for the active model."""
    info = get_service().get_model_info()
    state = get_state()

    # A falsy startup time means uptime is reported as 0.0.
    started_at = state["startup_time"]
    if started_at:
        uptime_s = round(time.time() - started_at, 1)
    else:
        uptime_s = 0.0

    # Only coerce the train/test gap when the service actually reports one.
    raw_gap = info.get("train_test_gap_pp")
    gap_pp = None if raw_gap is None else float(raw_gap)

    return ModelInfo(
        name=state["model_name"],
        type=info.get("type", "unknown"),
        description=info.get("description", ""),
        speed=info.get("speed", ""),
        accuracy=info.get("accuracy", ""),
        uptime_s=uptime_s,
        predictions_served=state.get("predictions_served", 0),
        display_banner=info.get("display_banner"),
        train_test_gap_pp=gap_pp,
        recommended_threshold=info.get("recommended_threshold"),
    )


def _status_entry(name, cfg) -> ModelStatusEntry:
    """Build the status entry for one ``AVAILABLE_MODELS`` item."""
    is_available, why_not = check_model_availability(name, PROJECT_ROOT)
    return ModelStatusEntry(
        name=name,
        available=is_available,
        reason=why_not,
        type=cfg.get("type", "unknown"),
    )


@router.get("/models/status", response_model=ModelsStatusResponse)
async def models_status():
    """Report availability of every configured model and the active model name."""
    state = get_state()
    entries: list[ModelStatusEntry] = [
        _status_entry(name, cfg) for name, cfg in AVAILABLE_MODELS.items()
    ]
    # A falsy active model name is reported as an empty string.
    active_name = state["model_name"] or ""
    return ModelsStatusResponse(models=entries, active=active_name)


@router.get("/models")
async def list_models():
    """List the configured model names and the currently active one."""
    state = get_state()
    configured = list(AVAILABLE_MODELS.keys())
    return {"available": configured, "active": state["model_name"]}


@router.post("/models/select")
async def select_model(body: SelectModelRequest):
    """Switch active model (preferred — avoids URL-encoding issues in model names)."""
    requested = body.model_name.strip()
    return _switch_model_impl(requested)


@router.put("/model/{model_name:path}")
async def switch_model(model_name: str):
    """Legacy path-based switch (decoded path segment)."""
    # Percent-decode the captured path, then trim surrounding whitespace.
    requested = unquote(model_name).strip()
    return _switch_model_impl(requested)