"""FastAPI service exposing an XGBoost test-score regressor.

Model artifacts (preprocessor, feature weights, schema and the XGBoost model)
are downloaded from a Hugging Face model repo and loaded once per process.

Environment variables read at import time:
    HF_CACHE_DIR         local directory the repo snapshot is written to
    HF_REPO_ID           Hugging Face model repo holding the artifacts
    HF_TOKEN             access token (only needed if the repo is private)
    CORS_ALLOW_ORIGINS   comma-separated list of allowed origins (default "*")
"""
import json
import os
import threading
from pathlib import Path
from typing import List, Dict, Any, Optional

import joblib
import numpy as np
import pandas as pd
import xgboost as xgb
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field, validator

# -------------------------------
# Hugging Face repo config
# -------------------------------
HF_CACHE_DIR = os.getenv("HF_CACHE_DIR", "/models/hf")
HF_REPO_ID = os.getenv("HF_REPO_ID", "ethnmcl/test-score-predictor-xgb")
HF_TOKEN = os.getenv("HF_TOKEN", None)  # only needed if repo is private

# Predictions are clamped into this closed interval (the dataset range).
SCORE_MIN = 50.0
SCORE_MAX = 100.0

# -------------------------------
# Global state
# -------------------------------
_loaded = False
_loaded_lock = threading.Lock()
_pre = None
_weights = None
_schema = None
_model = None


# -------------------------------
# Loader functions
# -------------------------------
def repo_snapshot(repo_id: Optional[str] = None) -> str:
    """Download model repo snapshot (if not cached).

    Args:
        repo_id: Hugging Face model repo to fetch; falls back to
            ``HF_REPO_ID`` when omitted or empty.

    Returns:
        The local directory reported by ``snapshot_download``.
    """
    repo_id = repo_id or HF_REPO_ID
    local_dir = snapshot_download(
        repo_id=repo_id,
        local_dir=HF_CACHE_DIR,
        local_dir_use_symlinks=False,
        token=HF_TOKEN,
        repo_type="model",
    )
    return local_dir


def load_model() -> None:
    """Load preprocessor, weights, schema, and XGB model into memory.

    Safe to call repeatedly: once the artifacts are loaded the call returns
    immediately. The lock makes sure only one thread performs the load.
    """
    global _loaded, _pre, _weights, _schema, _model
    if _loaded:
        return
    with _loaded_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if _loaded:
            return
        base = Path(repo_snapshot(HF_REPO_ID))

        # SECURITY NOTE: joblib.load unpickles the file, which can execute
        # arbitrary code. Only point HF_REPO_ID at a repository you trust.
        pre = joblib.load(base / "preprocessor.joblib")
        weights = np.load(base / "weights.npy")
        with open(base / "schema.json", encoding="utf-8") as f:
            schema = json.load(f)
        model = xgb.XGBRegressor()
        model.load_model(str(base / "xgb_model.json"))

        # Publish everything together, and flip the flag last, so a failure
        # part-way through never leaves a mix of old and new artifacts.
        _pre, _weights, _schema, _model = pre, weights, schema, model
        _loaded = True


def _transform(records: List[Dict[str, Any]]) -> np.ndarray:
    """Turn raw record dicts into the feature matrix fed to the model.

    Columns are selected in ``schema["numeric"] + schema["categorical"]``
    order, passed through the preprocessor, cast to float, and the first
    ``len(numeric)`` output columns are scaled in place by the weights.
    """
    num = _schema["numeric"]
    cat = _schema["categorical"]
    df = pd.DataFrame(records, columns=num + cat)
    Xt = _pre.transform(df)
    Xt = Xt.astype(float, copy=False)
    # NOTE(review): assumes the preprocessor emits the numeric features as the
    # leading columns and that _weights broadcasts over them - confirm against
    # the training pipeline.
    Xt[:, :len(num)] *= _weights
    return Xt


def predict_one(record: dict) -> float:
    """Predict the score for a single record, clamped to the score range."""
    if not _loaded:
        load_model()
    Xt = _transform([record])
    pred = float(_model.predict(Xt)[0])
    return max(SCORE_MIN, min(SCORE_MAX, pred))  # clamp to dataset range


def predict_batch(records: list) -> np.ndarray:
    """Predict scores for many records, clamped to the score range.

    Args:
        records: list of record dicts; may be empty.

    Returns:
        1-D array of clamped predictions, one per record (empty for an
        empty input).
    """
    if not records:
        # Nothing to score: skip the preprocessor/model entirely rather than
        # handing them a zero-row frame.
        return np.empty(0, dtype=float)
    if not _loaded:
        load_model()
    Xt = _transform(records)
    preds = _model.predict(Xt)
    return np.clip(preds, SCORE_MIN, SCORE_MAX)


# -------------------------------
# FastAPI app
# -------------------------------
app = FastAPI(title="Test Score Predictor API", version="1.0.0")

# Allowed origins are configurable; the default stays permissive ("*").
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # A wildcard origin must not be combined with credentialed requests, so
    # credentials are only enabled when explicit origins are configured via
    # CORS_ALLOW_ORIGINS.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
def _startup() -> None:
    """Warm the model at startup so the first request does not pay for it."""
    # load_model() resolves the repo snapshot itself, so no separate
    # repo_snapshot() call is needed here.
    load_model()


# -------------------------------
# Request schemas
# -------------------------------
class Record(BaseModel):
    """One student/test observation submitted for scoring."""

    Subject: str = Field(..., examples=["Mathematics"])
    Current_Grade: int = Field(..., ge=60, le=98)
    Max_Test_Percentage: int = Field(..., ge=65, le=100)
    Days_Preparing: int = Field(..., ge=1, le=14)
    Hours_Studied: int = Field(..., ge=2, le=50)
    Study_Session_Average: float = Field(..., ge=0.1, le=10.0)
    Avg_Previous_Tests: int = Field(..., ge=55, le=95)
    Test_Difficulty: str = Field(..., examples=["Easy (20)", "Medium (30)", "Hard (50)"])

    @validator("Study_Session_Average", always=True)
    def recompute_session_avg(cls, v, values):
        """Derive the session average from hours studied and days preparing.

        When both source fields validated, the client-supplied value is
        replaced by ``Hours_Studied / Days_Preparing`` rounded to one decimal;
        otherwise the supplied value is kept.
        """
        if "Hours_Studied" in values and "Days_Preparing" in values:
            return round(values["Hours_Studied"] / values["Days_Preparing"], 1)
        return v


class PredictRequest(BaseModel):
    """Request body for batch prediction."""

    data: List[Record]


# -------------------------------
# Routes
# -------------------------------
@app.get("/health")
def health() -> Dict[str, Any]:
    """Liveness probe; also reports which model repo is configured."""
    return {"status": "ok", "repo": HF_REPO_ID}


@app.post("/predict")
def predict(req: Record) -> Dict[str, Any]:
    """Score a single record."""
    return {"predicted_score": predict_one(req.dict())}


@app.post("/predict-batch")
def predict_many(req: PredictRequest) -> Dict[str, Any]:
    """Score a list of records and return the scores with their count."""
    recs = [r.dict() for r in req.data]
    return {"predicted_scores": predict_batch(recs).tolist(), "count": len(recs)}