File size: 4,274 Bytes
fe44b0a
 
 
 
440feb0
 
 
 
 
fe44b0a
 
 
440feb0
 
fe44b0a
440feb0
fe44b0a
 
 
440feb0
fe44b0a
440feb0
 
 
 
 
fe44b0a
 
 
440feb0
fe44b0a
440feb0
 
 
 
 
 
 
 
 
 
 
fe44b0a
440feb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe44b0a
440feb0
 
 
 
 
 
 
fe44b0a
440feb0
 
 
 
 
 
 
4d509f3
fe44b0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, validator
from typing import List, Dict, Any
import os, json, joblib, numpy as np, pandas as pd, threading
from huggingface_hub import snapshot_download
import xgboost as xgb
from pathlib import Path

# -------------------------------
# Hugging Face repo config
# -------------------------------
# Local directory handed to the Hub download; override with HF_CACHE_DIR.
HF_CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/models/hf")
# Hub repository holding the model artifacts; override with HF_REPO_ID.
HF_REPO_ID = os.environ.get("HF_REPO_ID", "ethnmcl/test-score-predictor-xgb")
# Access token forwarded to the Hub download; None when HF_TOKEN is unset
# (only needed if repo is private).
HF_TOKEN = os.environ.get("HF_TOKEN")

# -------------------------------
# Global state
# -------------------------------
# Flipped to True by load_model() once every artifact below is populated.
_loaded = False
# Guards the one-time load performed inside load_model().
_loaded_lock = threading.Lock()
# Artifacts filled in by load_model(); all of them start out as None.
_pre = _weights = _schema = _model = None

# -------------------------------
# Loader functions
# -------------------------------
def repo_snapshot(repo_id: str = None) -> str:
    """Fetch a snapshot of a Hub model repository and return its location.

    Args:
        repo_id: Repository to fetch. Any falsy value (including the default
            ``None``) selects ``HF_REPO_ID``.

    Returns:
        The value returned by ``snapshot_download`` for this request; the
        download is pointed at ``HF_CACHE_DIR`` via ``local_dir``.
    """
    download_options = {
        "repo_id": repo_id if repo_id else HF_REPO_ID,
        "local_dir": HF_CACHE_DIR,
        "local_dir_use_symlinks": False,
        "token": HF_TOKEN,
        "repo_type": "model",
    }
    return snapshot_download(**download_options)

def load_model():
    """Load preprocessor, weights, schema, and XGB model into memory.

    The first successful call downloads the snapshot for ``HF_REPO_ID`` and
    fills the module globals ``_pre``, ``_weights``, ``_schema`` and
    ``_model``; later calls return immediately because ``_loaded`` is set.

    Every artifact is read into a local first and the globals are only
    rebound once all of them loaded, so a failure part-way through (missing
    file, corrupt artifact) never leaves a half-populated global state behind
    and the next call simply retries.

    Raises:
        Whatever the download or any of the artifact loaders raises; in that
        case ``_loaded`` stays False.
    """
    global _loaded, _pre, _weights, _schema, _model
    if _loaded:
        return
    with _loaded_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if _loaded:
            return
        base = Path(repo_snapshot(HF_REPO_ID))
        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code -- only point HF_REPO_ID at a repository you trust.
        pre = joblib.load(base / "preprocessor.joblib")
        weights = np.load(base / "weights.npy")
        # JSON is read as UTF-8 explicitly instead of relying on the
        # platform's default text encoding.
        with open(base / "schema.json", encoding="utf-8") as f:
            schema = json.load(f)
        model = xgb.XGBRegressor()
        model.load_model(str(base / "xgb_model.json"))
        # Publish everything together, then flip the flag last.
        _pre, _weights, _schema, _model = pre, weights, schema, model
        _loaded = True

def _transform(records):
    """Turn raw records into the weighted feature matrix fed to the model.

    Reads the module globals ``_schema``, ``_pre`` and ``_weights``, so
    ``load_model()`` must have populated them before this is called.

    Args:
        records: Row data accepted by ``pd.DataFrame``; columns are selected
            and ordered as ``_schema["numeric"] + _schema["categorical"]``.

    Returns:
        The preprocessor output cast to float, with its first
        ``len(_schema["numeric"])`` columns multiplied in place by
        ``_weights``.
    """
    numeric_cols = _schema["numeric"]
    categorical_cols = _schema["categorical"]
    frame = pd.DataFrame(records, columns=numeric_cols + categorical_cols)
    features = _pre.transform(frame).astype(float, copy=False)
    # Assumes the preprocessor emits the numeric columns first -- TODO confirm
    # against how preprocessor.joblib was built.
    n_numeric = len(numeric_cols)
    features[:, :n_numeric] *= _weights
    return features

def predict_one(record: dict) -> float:
    """Predict the score for a single record, limited to [50.0, 100.0].

    Triggers ``load_model()`` first when the artifacts are not loaded yet.

    Args:
        record: One row of feature values, passed to ``_transform`` as a
            single-element list.

    Returns:
        The model's first prediction as a Python float, capped at 100.0 and
        then floored at 50.0.
    """
    if not _loaded:
        load_model()
    features = _transform([record])
    raw_score = float(_model.predict(features)[0])
    capped = min(100.0, raw_score)
    return max(50.0, capped)  # clamp to dataset range

def predict_batch(records: list) -> np.ndarray:
    """Predict scores for several records, each clipped to [50.0, 100.0].

    Triggers ``load_model()`` first when the artifacts are not loaded yet.

    Args:
        records: Rows of feature values, passed to ``_transform`` as-is.

    Returns:
        A 1-D array with one clipped prediction per record. An empty input
        yields an empty array without touching the preprocessor or the model.
    """
    # Short-circuit: there is nothing to transform or predict for an empty
    # batch, so do not push a zero-row frame through the pipeline.
    if len(records) == 0:
        return np.empty(0)
    if not _loaded:
        load_model()
    Xt = _transform(records)
    preds = _model.predict(Xt)
    return np.clip(preds, 50.0, 100.0)

# -------------------------------
# FastAPI app
# -------------------------------
app = FastAPI(
    title="Test Score Predictor API",
    version="1.0.0",
)

# CORS is configured fully open: every origin, method and header is allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# combination the FastAPI CORS documentation advises against -- confirm that
# credentialed cross-origin requests are really required here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.on_event("startup")
def _startup():
    """Load the model artifacts when the application starts.

    ``load_model()`` fetches the snapshot for ``HF_REPO_ID`` itself through
    ``repo_snapshot``, so a separate download call beforehand would only
    repeat the same request.
    """
    load_model()

# -------------------------------
# Request schemas
# -------------------------------
# One prediction input row. Field order matters: the validator below reads
# Hours_Studied and Days_Preparing from `values`, so both are declared before
# Study_Session_Average.
class Record(BaseModel):
    Subject: str = Field(..., examples=["Mathematics"])
    Current_Grade: int = Field(..., ge=60, le=98)
    Max_Test_Percentage: int = Field(..., ge=65, le=100)
    Days_Preparing: int = Field(..., ge=1, le=14)
    Hours_Studied: int = Field(..., ge=2, le=50)
    # Still a required field, even though the validator replaces its value
    # whenever both inputs it is derived from are available.
    Study_Session_Average: float = Field(..., ge=0.1, le=10.0)
    Avg_Previous_Tests: int = Field(..., ge=55, le=95)
    Test_Difficulty: str = Field(..., examples=["Easy (20)", "Medium (30)", "Hard (50)"])

    @validator("Study_Session_Average", always=True)
    def recompute_session_avg(cls, v, values):
        # Keep the submitted value when either input is absent from `values`
        # (presumably because it failed its own validation -- confirm against
        # the pydantic version in use).
        if "Hours_Studied" not in values or "Days_Preparing" not in values:
            return v
        hours = values["Hours_Studied"]
        days = values["Days_Preparing"]
        # Average hours per preparation day, rounded to one decimal place.
        # NOTE(review): the recomputed value is returned as-is; nothing here
        # re-applies the ge/le bounds declared on the field.
        return round(hours / days, 1)

# Request body of the batch endpoint: a list of Record items under "data".
class PredictRequest(BaseModel):
    data: List[Record]  # every element is validated as a Record

# -------------------------------
# Routes
# -------------------------------
@app.get("/health")
def health() -> Dict[str, Any]:
    # Static status payload plus the configured repo id; it does not check
    # whether the model has actually been loaded.
    payload: Dict[str, Any] = {"status": "ok"}
    payload["repo"] = HF_REPO_ID
    return payload

@app.post("/predict")
def predict(req: Record) -> Dict[str, Any]:
    # Score a single validated record and wrap the result in a JSON object.
    features = req.dict()
    score = predict_one(features)
    return {"predicted_score": score}

@app.post("/predict-batch")
def predict_many(req: PredictRequest) -> Dict[str, Any]:
    # Score every record in the request; "count" echoes how many were sent.
    rows = [item.dict() for item in req.data]
    scores = predict_batch(rows).tolist()
    return {"predicted_scores": scores, "count": len(rows)}