File size: 9,158 Bytes
f381be8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1552b5a
f381be8
 
 
1552b5a
f381be8
 
 
 
d3996f2
f381be8
 
 
 
 
d3996f2
f381be8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3996f2
 
f381be8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3996f2
f381be8
 
 
 
1552b5a
d3996f2
1552b5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d3996f2
1552b5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f381be8
 
 
d3996f2
f381be8
 
 
 
 
d3996f2
 
 
f381be8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""
api.main
========
FastAPI application entry-point for the AI Battery Lifecycle Predictor.

Architecture
------------
- **v1 (Classical)**    : Ridge, Lasso, ElasticNet, KNN ×3, SVR,
                          Random Forest, XGBoost, LightGBM
- **v2 (Deep)**         : Vanilla LSTM, BiLSTM, GRU, Attention LSTM,
                          BatteryGPT, TFT, iTransformer ×3, VAE-LSTM
- **v2.6 (Ensemble)**   : BestEnsemble — weighted average of RF + XGB + LGB
                          (weights proportional to R²)

Mounted routes
--------------
- ``/api/*``      REST endpoints  (predict, batch, recommend, models, visualize)
- ``/gradio``     Gradio interactive demo  (optional, requires *gradio* package)
- ``/``           React SPA  (served from ``frontend/dist/``)

Key endpoints
-------------
- ``POST /api/predict``          — single-cycle SOH + RUL prediction
- ``POST /api/predict/ensemble`` — always uses BestEnsemble (v2.6)
- ``POST /api/predict/batch``    — batch prediction from JSON array
- ``GET  /api/models``           — list all models with version / R² metadata
- ``GET  /api/models/versions``  — group models by generation (v1/v2)
- ``GET  /api/versions``         — list model versions (v1–v3) with load status
- ``POST /api/versions/{version}/load`` — download + activate a version in the background
- ``GET  /health``               — liveness probe

Run locally
-----------
::

    uvicorn api.main:app --host 0.0.0.0 --port 7860 --reload

Docker
------
::

    docker compose up --build
"""

from __future__ import annotations

import asyncio
from contextlib import asynccontextmanager
from pathlib import Path

from fastapi import BackgroundTasks, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse

from api.model_registry import registry, registry_v1, registry_v2, registry_v3
from api.schemas import HealthResponse
from src.utils.logger import get_logger

log = get_logger(__name__)

# Public API version, reported by FastAPI's OpenAPI metadata and /health.
__version__ = "3.0.0"

# ── Static frontend path ────────────────────────────────────────────────────
# Absolute directory holding this module; its parent is treated as the
# project root.
_HERE = Path(__file__).resolve().parent
# Location of the built React SPA (``<project root>/frontend/dist``).
_FRONTEND_DIST = _HERE.parent.joinpath("frontend", "dist")


# ── Lifespan ─────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    Startup: eagerly calls ``load_all()`` on the v1, v2 and v3 registries (in
    that order) and logs each registry's model count.
    Shutdown: only emits a log line; nothing is released explicitly.
    """
    log.info("Loading model registries …")
    startup_plan = (
        (registry_v1, "v1 registry ready β€” %d models loaded"),
        (registry_v2, "v2 registry ready β€” %d models loaded"),
        (registry_v3, "v3 registry ready β€” %d models loaded"),
    )
    for reg, ready_msg in startup_plan:
        reg.load_all()
        log.info(ready_msg, reg.model_count)
    yield
    log.info("Shutting down battery-lifecycle API")


# ── App ──────────────────────────────────────────────────────────────────────
app = FastAPI(
    title="AI Battery Lifecycle Predictor",
    version=__version__,
    description=(
        "Predict SOH, RUL, and degradation state of Li-ion batteries "
        "using models trained on the NASA PCoE dataset."
    ),
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)

# CORS: every origin, method and header is accepted.
# NOTE(review): browsers refuse a literal "*" origin on credentialed requests,
# so allow_credentials=True together with allow_origins=["*"] is a risky mix
# (see the Starlette CORSMiddleware docs). Nothing in this module uses
# cookies — confirm whether credentials are needed, and list explicit
# origins before enabling any cookie/session auth.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)


# ── Health check ─────────────────────────────────────────────────────────────
@app.get("/health", response_model=HealthResponse, tags=["meta"])
async def health():
    """Liveness probe.

    Always reports ``status="ok"`` together with the API version, the total
    number of models held by the v1, v2 and v3 registries, and the device
    value exposed by the default ``registry``.
    """
    total_models = sum(r.model_count for r in (registry_v1, registry_v2, registry_v3))
    return HealthResponse(
        device=registry.device,
        models_loaded=total_models,
        status="ok",
        version=__version__,
    )


# ── Version management ───────────────────────────────────────────────────────
# Version id -> model registry; the keys are the only versions the
# version-management endpoints accept.
_REGISTRIES = dict(v1=registry_v1, v2=registry_v2, v3=registry_v3)
# On-demand load state per version: "downloading" | "ready" | "error".
# A version is absent until a load has been requested via the API.
_version_status: dict[str, str] = {}


def _artifacts_dir() -> Path:
    """Return ``<project root>/artifacts``, the project root being this file's grandparent."""
    module_path = Path(__file__).resolve()
    return module_path.parent.parent / "artifacts"


def _version_loaded(version: str) -> bool:
    """Report whether *version* has at least one classical model on disk.

    Checks ``artifacts/<version>/models/classical`` for any ``*.joblib``
    file. Despite the name, this reflects what is present on disk, not what
    the in-memory registry has loaded. A missing directory yields ``False``.
    """
    classical_dir = _artifacts_dir() / version / "models" / "classical"
    if not classical_dir.exists():
        return False
    return next(classical_dir.glob("*.joblib"), None) is not None


@app.get("/api/versions", tags=["meta"])
async def list_versions():
    """Return all known versions with loaded / downloading status.

    Versions are listed newest first (v3, v2, v1). ``status`` is the
    on-demand load state when a load was requested through the API;
    otherwise it is derived from whether model files exist on disk
    (``"ready"`` or ``"not_downloaded"``).
    """
    versions = []
    for v in ("v3", "v2", "v1"):
        # Probe the filesystem once per version so "loaded" and the derived
        # "status" cannot disagree (and to avoid a second directory glob).
        on_disk = _version_loaded(v)
        versions.append(
            {
                "id": v,
                # v[1:] keeps multi-digit ids ("v10" -> "Version 10") intact.
                "display": f"Version {v[1:]}",
                "loaded": on_disk,
                "model_count": _REGISTRIES[v].model_count,
                "status": _version_status.get(v, "ready" if on_disk else "not_downloaded"),
            }
        )
    return versions


async def _bg_load_version(version: str) -> None:
    """Download *version*'s artifacts, then load its registry.

    Runs ``scripts/download_models.py --version <version>`` in a child
    process. If it exits with status 0, calls ``load_all()`` on the matching
    registry. The outcome is recorded in ``_version_status[version]``
    (``"ready"`` or ``"error"``). Failures are logged, never raised, because
    this runs as a background task after the HTTP response has been sent.
    """
    import sys

    # Anchor the script path and working directory to the project root (the
    # parent of this package) instead of the server's current directory.
    # Assumes scripts/ sits beside api/ and artifacts/ — the original
    # CWD-relative path implied launching from that root.
    project_root = Path(__file__).resolve().parent.parent
    script = project_root / "scripts" / "download_models.py"
    try:
        proc = await asyncio.create_subprocess_exec(
            sys.executable, str(script), "--version", version,
            cwd=str(project_root),
            stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT,
        )
        # communicate() drains the pipe while waiting for exit; a bare wait()
        # with stdout=PIPE deadlocks once the child fills the OS pipe buffer.
        output, _ = await proc.communicate()
        if proc.returncode != 0:
            _version_status[version] = "error"
            # Keep only the tail of the combined stdout/stderr for diagnostics.
            tail = (output or b"").decode(errors="replace")[-2000:]
            log.error("download_models.py failed for version %s (exit code %s):\n%s",
                      version, proc.returncode, tail)
            return
        # NOTE(review): load_all() is synchronous, so the event loop is blocked
        # while models load (as at startup). Offloading it to a thread would
        # require the registry to tolerate concurrent reads — confirm first.
        _REGISTRIES[version].load_all()
        _version_status[version] = "ready"
        log.info("Version %s loaded on demand β€” %d models", version,
                 _REGISTRIES[version].model_count)
    except Exception as exc:
        # Top-level boundary of a background task: record and log with traceback.
        _version_status[version] = "error"
        log.exception("Failed to load version %s: %s", version, exc)


@app.post("/api/versions/{version}/load", tags=["meta"])
async def load_version(version: str, background_tasks: BackgroundTasks):
    """Download + activate a model version from HF Hub (runs in background).

    Rejects unknown version ids with HTTP 400. If a download for *version* is
    already in flight, no second task is scheduled. Either way the response
    reports ``"downloading"``.
    """
    if version not in _REGISTRIES:
        raise HTTPException(status_code=400, detail=f"Unknown version '{version}'")
    in_flight = _version_status.get(version) == "downloading"
    if not in_flight:
        # No await between the check and the set, so two requests on the
        # event loop cannot both schedule a download.
        _version_status[version] = "downloading"
        background_tasks.add_task(_bg_load_version, version)
    return {"status": "downloading", "version": version}


# ── Include routers ──────────────────────────────────────────────────────────
# Imported here rather than at the top of the file — presumably to sidestep a
# circular import with this module (TODO confirm).
from api.routers.predict import router as predict_router, v1_router
from api.routers.predict_v2 import router as predict_v2_router
from api.routers.predict_v3 import router as predict_v3_router
from api.routers.visualize import router as viz_router
from api.routers.simulate import router as simulate_router

# Registration order is preserved deliberately: Starlette matches routes in
# the order they were added.
for _router in (
    predict_router,     # /api/* (default, uses v2 registry)
    v1_router,          # /api/v1/* (legacy v1 models)
    predict_v2_router,  # /api/v2/* (v2 models)
    predict_v3_router,  # /api/v3/* (v3 models, best accuracy)
    simulate_router,    # /api/v3/simulate (ML-driven simulation)
    viz_router,
):
    app.include_router(_router)
del _router


# ── Mount Gradio ─────────────────────────────────────────────────────────────
# Gradio is optional: if it (or anything api.gradio_app imports) is missing,
# the API still starts, just without the /gradio UI.
try:
    import gradio as gr
    from api.gradio_app import create_gradio_app

    gradio_app = create_gradio_app()
    # mount_gradio_app returns the FastAPI app with the Gradio UI mounted;
    # rebinding keeps ``app`` pointing at what uvicorn will serve.
    app = gr.mount_gradio_app(app, gradio_app, path="/gradio")
    log.info("Gradio UI mounted at /gradio")
except ImportError as exc:
    # The ImportError may come from gradio itself or from a dependency of
    # api.gradio_app, so log the underlying message rather than assuming.
    log.warning("Gradio not installed or failed to import (%s); /gradio endpoint unavailable", exc)


# ── Serve React SPA ──────────────────────────────────────────────────────────
def _resolve_frontend_file(full_path: str) -> Path | None:
    """Map a request path onto a regular file inside ``_FRONTEND_DIST``.

    Returns the resolved file path, or ``None`` when the path does not name a
    regular file, cannot be resolved, or escapes the build directory (``..``
    segments, absolute paths such as ``//etc/passwd``, outward symlinks).
    """
    root = _FRONTEND_DIST.resolve()
    try:
        candidate = (root / full_path).resolve()
        # relative_to raises ValueError when candidate lies outside root.
        candidate.relative_to(root)
    except (OSError, RuntimeError, ValueError):
        return None
    return candidate if candidate.is_file() else None


if _FRONTEND_DIST.exists() and (_FRONTEND_DIST / "index.html").exists():
    # StaticFiles raises at construction when its directory is missing, so
    # only mount /assets when the build actually produced it.
    if (_FRONTEND_DIST / "assets").is_dir():
        app.mount("/assets", StaticFiles(directory=str(_FRONTEND_DIST / "assets")), name="static-assets")

    @app.get("/{full_path:path}", include_in_schema=False)
    async def spa_catch_all(full_path: str):
        """Serve React SPA for any path not matched by API routes.

        Registered after the routers and mounts in this module, so it only
        receives GET paths they did not match. Existing build files are
        served directly; everything else (client-side routes, and attempts to
        escape the build directory) gets ``index.html``.
        """
        file_path = _resolve_frontend_file(full_path)
        if file_path is not None:
            return FileResponse(file_path)
        return FileResponse(_FRONTEND_DIST / "index.html")

    log.info("React SPA served from %s", _FRONTEND_DIST)
else:
    @app.get("/", include_in_schema=False)
    async def root():
        """Fallback landing payload used when no frontend build is present."""
        return {
            "message": "AI Battery Lifecycle Predictor API",
            "docs": "/docs",
            "gradio": "/gradio",
            "health": "/health",
        }