File size: 3,383 Bytes
import time
from urllib.parse import unquote
from fastapi import APIRouter, HTTPException
from src.api.schemas import ModelInfo, ModelStatusEntry, ModelsStatusResponse, SelectModelRequest
from src.api.services import get_service
from src.api.state import PROJECT_ROOT, get_state
from src.service.model_service import AVAILABLE_MODELS, ModelService, check_model_availability
router = APIRouter(tags=["Model"])
def _switch_model_impl(model_name: str) -> dict[str, str]:
    """Validate *model_name*, warm the model up, and make it the active one.

    The name must be a key of ``AVAILABLE_MODELS`` and pass
    ``check_model_availability``. A new ``ModelService`` is then built and
    asked for a single ``"warmup"`` prediction before it is stored in the
    shared state returned by ``get_state()``.

    Args:
        model_name: Name of the model to activate.

    Returns:
        A dict with a human-readable ``"message"`` and the activated
        ``"model"`` name.

    Raises:
        HTTPException: With status 400 when the name is unknown, when the
            availability check fails, or when the warmup prediction result
            carries a truthy ``"error"`` entry.
    """
    if model_name not in AVAILABLE_MODELS:
        known_names = list(AVAILABLE_MODELS.keys())
        raise HTTPException(
            status_code=400,
            detail=f"Model '{model_name}' not available. Options: {known_names}",
        )

    is_available, why_not = check_model_availability(model_name, PROJECT_ROOT)
    if not is_available:
        raise HTTPException(status_code=400, detail=why_not or "Model unavailable")

    app_state = get_state()
    # Snapshot of the currently active service and its name.
    previous = (app_state["service"], app_state["model_name"])

    candidate = ModelService(model_name, PROJECT_ROOT)
    warmup_result = candidate.predict("warmup")

    if not warmup_result.get("error"):
        # Warmup succeeded: publish the new service under its name.
        app_state["service"] = candidate
        app_state["model_name"] = model_name
        return {
            "message": f"Active model set to '{model_name}'",
            "model": model_name,
        }

    # Nothing in this function has written to the state yet, so this simply
    # stores back the values read above before the failure is reported.
    app_state["service"], app_state["model_name"] = previous
    raise HTTPException(status_code=400, detail=str(warmup_result["error"]))
# GET /model-info: describe the active model together with runtime counters.
# Documented with comments rather than a docstring so the generated OpenAPI
# description of this route stays exactly as it was.
@router.get("/model-info", response_model=ModelInfo)
async def get_model_info():
    details = get_service().get_model_info()
    app_state = get_state()

    # A falsy startup time is reported as zero uptime; otherwise the elapsed
    # seconds since startup are rounded to one decimal place.
    if app_state["startup_time"]:
        uptime_seconds = round(time.time() - app_state["startup_time"], 1)
    else:
        uptime_seconds = 0.0

    # The train/test gap is optional; it is coerced to float only when present.
    raw_gap = details.get("train_test_gap_pp")

    return ModelInfo(
        name=app_state["model_name"],
        type=details.get("type", "unknown"),
        description=details.get("description", ""),
        speed=details.get("speed", ""),
        accuracy=details.get("accuracy", ""),
        uptime_s=uptime_seconds,
        predictions_served=app_state.get("predictions_served", 0),
        display_banner=details.get("display_banner"),
        train_test_gap_pp=None if raw_gap is None else float(raw_gap),
        recommended_threshold=details.get("recommended_threshold"),
    )
# GET /models/status: one availability entry per configured model, plus the
# name of the active model (an empty string when that name is falsy).
# Documented with comments rather than a docstring so the generated OpenAPI
# description of this route stays exactly as it was.
@router.get("/models/status", response_model=ModelsStatusResponse)
async def models_status():
    app_state = get_state()

    def build_entry(model_name, model_cfg):
        # Run the availability check for one model and wrap the outcome.
        is_available, why_not = check_model_availability(model_name, PROJECT_ROOT)
        return ModelStatusEntry(
            name=model_name,
            available=is_available,
            reason=why_not,
            type=model_cfg.get("type", "unknown"),
        )

    statuses = [
        build_entry(model_name, model_cfg)
        for model_name, model_cfg in AVAILABLE_MODELS.items()
    ]
    active_name = app_state["model_name"] or ""
    return ModelsStatusResponse(models=statuses, active=active_name)
# GET /models: names of every configured model and the currently active one.
# Documented with comments rather than a docstring so the generated OpenAPI
# description of this route stays exactly as it was.
@router.get("/models")
async def list_models():
    active_name = get_state()["model_name"]
    # Iterating the mapping yields its keys, i.e. the model names.
    configured_names = list(AVAILABLE_MODELS)
    return {"available": configured_names, "active": active_name}
@router.post("/models/select")
async def select_model(body: SelectModelRequest):
    """Switch active model (preferred — avoids URL-encoding issues in model names)."""
    # Surrounding whitespace is trimmed before the name is handed to the
    # shared switch routine, which performs the validation and activation.
    requested_name = body.model_name.strip()
    return _switch_model_impl(requested_name)
@router.put("/model/{model_name:path}")
async def switch_model(model_name: str):
    """Legacy path-based switch (decoded path segment)."""
    # Percent-decode the captured path value, trim surrounding whitespace,
    # then delegate to the shared switch routine.
    decoded_name = unquote(model_name)
    return _switch_model_impl(decoded_name.strip())
|