File size: 3,383 Bytes
e317d56
0f0ce9b
e317d56
 
 
0f0ce9b
e317d56
 
 
 
 
 
 
0f0ce9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e317d56
 
 
 
 
 
46cc63a
e317d56
 
 
 
 
 
 
 
46cc63a
 
 
e317d56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f0ce9b
 
 
 
e317d56
 
0f0ce9b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import time
from urllib.parse import unquote

from fastapi import APIRouter, HTTPException

from src.api.schemas import ModelInfo, ModelStatusEntry, ModelsStatusResponse, SelectModelRequest
from src.api.services import get_service
from src.api.state import PROJECT_ROOT, get_state
from src.service.model_service import AVAILABLE_MODELS, ModelService, check_model_availability

# Router collecting the model-management endpoints defined in this module;
# the "Model" tag is attached to every route registered on it.
router = APIRouter(tags=["Model"])


def _switch_model_impl(model_name: str) -> dict[str, str]:
    """Validate, warm up and activate the model called ``model_name``.

    Shared by both model-switching endpoints. The candidate service is built
    and exercised with a throwaway prediction before the shared application
    state is pointed at it.

    Args:
        model_name: Key that must be present in ``AVAILABLE_MODELS``.

    Returns:
        A dict holding a human-readable ``"message"`` and the activated
        ``"model"`` name.

    Raises:
        HTTPException: Status 400 when the name is unknown, when
            ``check_model_availability`` reports the model as unavailable,
            or when the warmup prediction carries a truthy ``"error"`` entry.
    """
    if model_name not in AVAILABLE_MODELS:
        options = list(AVAILABLE_MODELS.keys())
        raise HTTPException(
            status_code=400,
            detail=f"Model '{model_name}' not available. Options: {options}",
        )

    is_available, why_not = check_model_availability(model_name, PROJECT_ROOT)
    if not is_available:
        raise HTTPException(status_code=400, detail=why_not or "Model unavailable")

    app_state = get_state()
    # Snapshot the current selection before the candidate is built.
    previous = (app_state["service"], app_state["model_name"])

    candidate = ModelService(model_name, PROJECT_ROOT)
    warmup_error = candidate.predict("warmup").get("error")
    if warmup_error:
        # This function has not written to the state yet; the snapshot is
        # re-assigned explicitly before the failure is reported.
        app_state["service"], app_state["model_name"] = previous
        raise HTTPException(status_code=400, detail=str(warmup_error))

    # Warmup succeeded: publish the new service as the active one.
    app_state["service"] = candidate
    app_state["model_name"] = model_name
    return {"message": f"Active model set to '{model_name}'", "model": model_name}


@router.get("/model-info", response_model=ModelInfo)
async def get_model_info():
    """Describe the active model together with basic serving statistics."""
    metadata = get_service().get_model_info()
    app_state = get_state()

    # A falsy startup timestamp is reported as zero uptime.
    started_at = app_state["startup_time"]
    if started_at:
        uptime_seconds = round(time.time() - started_at, 1)
    else:
        uptime_seconds = 0.0

    # The gap is optional in the metadata; it is coerced to float only when present.
    raw_gap = metadata.get("train_test_gap_pp")

    return ModelInfo(
        name=app_state["model_name"],
        type=metadata.get("type", "unknown"),
        description=metadata.get("description", ""),
        speed=metadata.get("speed", ""),
        accuracy=metadata.get("accuracy", ""),
        uptime_s=uptime_seconds,
        predictions_served=app_state.get("predictions_served", 0),
        display_banner=metadata.get("display_banner"),
        train_test_gap_pp=None if raw_gap is None else float(raw_gap),
        recommended_threshold=metadata.get("recommended_threshold"),
    )


@router.get("/models/status", response_model=ModelsStatusResponse)
async def models_status():
    """Report the availability of every configured model and the active one."""
    app_state = get_state()

    def describe(model_key, config):
        # One availability probe per configured model.
        usable, explanation = check_model_availability(model_key, PROJECT_ROOT)
        return ModelStatusEntry(
            name=model_key,
            available=usable,
            reason=explanation,
            type=config.get("type", "unknown"),
        )

    entries: list[ModelStatusEntry] = [
        describe(model_key, config) for model_key, config in AVAILABLE_MODELS.items()
    ]
    # A falsy active model name is reported as an empty string.
    return ModelsStatusResponse(models=entries, active=app_state["model_name"] or "")


@router.get("/models")
async def list_models():
    """List every configured model name alongside the currently active one."""
    app_state = get_state()
    configured_names = list(AVAILABLE_MODELS.keys())
    return {"available": configured_names, "active": app_state["model_name"]}


@router.post("/models/select")
async def select_model(body: SelectModelRequest):
    """Switch active model (preferred — avoids URL-encoding issues in model names)."""
    # Surrounding whitespace is trimmed before the name is validated.
    requested_name = body.model_name.strip()
    return _switch_model_impl(requested_name)


@router.put("/model/{model_name:path}")
async def switch_model(model_name: str):
    """Legacy path-based switch (decoded path segment)."""
    # Percent-decode the captured path first, then trim surrounding whitespace.
    decoded_name = unquote(model_name)
    return _switch_model_impl(decoded_name.strip())