Spaces:
Sleeping
Sleeping
import json
import os
import threading
from pathlib import Path
from typing import Any, Dict, List, Optional

import joblib
import numpy as np
import pandas as pd
import xgboost as xgb
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field, validator
# -------------------------------
# Hugging Face repo config
# -------------------------------
# Each setting can be overridden through the environment; the second argument
# is the fallback used when the variable is unset.
HF_CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/models/hf")  # local snapshot dir
HF_REPO_ID = os.environ.get("HF_REPO_ID", "ethnmcl/test-score-predictor-xgb")
# Access token for the Hub; only required when the model repo is private.
HF_TOKEN = os.environ.get("HF_TOKEN")
# -------------------------------
# Global state
# -------------------------------
# Model artifacts start empty and are populated by load_model().
_pre = None      # preprocessor read from preprocessor.joblib
_weights = None  # array read from weights.npy (multiplies the numeric columns)
_schema = None   # dict read from schema.json ("numeric" / "categorical" lists)
_model = None    # xgb.XGBRegressor read from xgb_model.json

# Guards the one-time load; _loaded flips to True only after all of the above
# have been populated.
_loaded_lock = threading.Lock()
_loaded = False
# -------------------------------
# Loader functions
# -------------------------------
def repo_snapshot(repo_id: Optional[str] = None) -> str:
    """Download the model repo snapshot into ``HF_CACHE_DIR`` (if not cached).

    Args:
        repo_id: Hugging Face model repo to fetch. Falls back to
            ``HF_REPO_ID`` when ``None`` or empty.

    Returns:
        The local snapshot directory path returned by ``snapshot_download``.
    """
    repo_id = repo_id or HF_REPO_ID
    # NOTE(review): ``local_dir_use_symlinks`` is deprecated (and ignored) in
    # recent huggingface_hub releases; it is kept so older versions still copy
    # real files into ``local_dir`` instead of symlinking into the cache.
    return snapshot_download(
        repo_id=repo_id,
        local_dir=HF_CACHE_DIR,
        local_dir_use_symlinks=False,
        token=HF_TOKEN,
        repo_type="model",
    )
def load_model() -> None:
    """Load preprocessor, weights, schema, and XGB model into module globals.

    Checks ``_loaded`` before and after taking ``_loaded_lock`` (double-checked
    locking), so concurrent first callers load the artifacts only once and
    later calls return immediately.
    """
    global _loaded, _pre, _weights, _schema, _model
    if _loaded:
        return
    with _loaded_lock:
        if _loaded:  # another thread finished loading while we waited
            return
        base = Path(repo_snapshot(HF_REPO_ID))
        # NOTE(review): joblib.load unpickles the file, which can execute
        # arbitrary code; only point HF_REPO_ID at a repo you trust.
        pre = joblib.load(base / "preprocessor.joblib")
        weights = np.load(base / "weights.npy")
        # Explicit encoding: JSON is UTF-8, independent of the host locale.
        with open(base / "schema.json", encoding="utf-8") as f:
            schema = json.load(f)
        model = xgb.XGBRegressor()
        model.load_model(str(base / "xgb_model.json"))
        # Publish only once every artifact loaded, so a failure part-way
        # through leaves the globals exactly as they were.
        _pre, _weights, _schema, _model = pre, weights, schema, model
        _loaded = True
def _transform(records):
    """Turn raw request records into the model's feature matrix.

    Columns are selected in schema order (numeric first, then categorical),
    passed through the loaded preprocessor and cast to float. The first
    ``len(numeric)`` output columns are then scaled element-wise by the loaded
    weights.
    """
    numeric_cols = _schema["numeric"]
    categorical_cols = _schema["categorical"]
    frame = pd.DataFrame(records, columns=numeric_cols + categorical_cols)
    features = _pre.transform(frame).astype(float, copy=False)
    # NOTE(review): assumes the preprocessor emits the numeric columns first,
    # as a dense array, and that _weights has one entry per numeric column —
    # confirm against the training pipeline.
    n_numeric = len(numeric_cols)
    features[:, :n_numeric] *= _weights
    return features
def predict_one(record: dict) -> float:
    """Predict the score for a single record, clamped to [50.0, 100.0].

    Delegates to ``predict_batch`` so single and batch predictions share one
    clamping path (``np.clip``), which also lazily loads the model. A NaN model
    output is therefore propagated instead of silently becoming 100.0, which is
    what a ``max``/``min`` clamp does with NaN.
    """
    return float(predict_batch([record])[0])
def predict_batch(records: list) -> np.ndarray:
    """Predict scores for many records, clamped to [50.0, 100.0].

    Args:
        records: Feature dicts, one per row to score.

    Returns:
        The model's predictions passed through ``np.clip``, one per record;
        an empty array when ``records`` is empty.
    """
    # An empty batch would reach the preprocessor as a zero-row DataFrame
    # (which it may reject) and would also force a model load; skip both.
    if not records:
        return np.empty(0, dtype=float)
    if not _loaded:
        load_model()
    features = _transform(records)
    return np.clip(_model.predict(features), 50.0, 100.0)  # dataset score range
# -------------------------------
# FastAPI app
# -------------------------------
app = FastAPI(title="Test Score Predictor API", version="1.0.0")

# NOTE(review): the CORS spec forbids a "*" origin on credentialed requests.
# Depending on the Starlette version, this combination is either unusable
# with credentials or handled by echoing back any request Origin, which admits
# every site. Nothing visible here uses cookies or auth; consider
# allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
@app.on_event("startup")
def _startup():
    """Load the model when the server starts instead of on the first request.

    ``load_model`` already fetches the repo snapshot, so no separate
    ``repo_snapshot`` call is made here.
    """
    load_model()
# -------------------------------
# Request schemas
# -------------------------------
class Record(BaseModel):
    """One input record for the score predictor.

    Numeric bounds are enforced by the ``Field`` constraints.
    ``Study_Session_Average`` is recomputed from
    ``Hours_Studied / Days_Preparing`` (rounded to one decimal) by
    ``recompute_session_avg``, replacing the client-supplied value.
    """

    Subject: str = Field(..., examples=["Mathematics"])
    Current_Grade: int = Field(..., ge=60, le=98)
    Max_Test_Percentage: int = Field(..., ge=65, le=100)
    # Days_Preparing and Hours_Studied must stay declared before
    # Study_Session_Average so the validator can read them from ``values``.
    Days_Preparing: int = Field(..., ge=1, le=14)
    Hours_Studied: int = Field(..., ge=2, le=50)
    Study_Session_Average: float = Field(..., ge=0.1, le=10.0)
    Avg_Previous_Tests: int = Field(..., ge=55, le=95)
    Test_Difficulty: str = Field(
        ..., examples=["Easy (20)", "Medium (30)", "Hard (50)"]
    )

    @validator("Study_Session_Average")
    def recompute_session_avg(cls, v, values):
        """Replace the supplied average with Hours_Studied / Days_Preparing.

        ``values`` only holds earlier fields that validated successfully. If
        either input is absent, the supplied value is returned unchanged.
        Days_Preparing is constrained to >= 1, so the division cannot be by 0.
        """
        # NOTE(review): the derived value is not re-checked against this
        # field's ge=0.1/le=10.0 bounds (e.g. 50 h over 1 day gives 50.0);
        # confirm whether those bounds should apply to the recomputed value.
        if "Hours_Studied" in values and "Days_Preparing" in values:
            return round(values["Hours_Studied"] / values["Days_Preparing"], 1)
        return v
class PredictRequest(BaseModel):
    """Batch request body: the records to score, in order."""

    data: List[Record] = Field(...)  # required; each item validated as Record
# -------------------------------
# Routes
# -------------------------------
# NOTE(review): no route decorator is visible on this handler, so as shown it
# is never registered with ``app``. Presumably something like
# @app.get("/health") was intended; confirm the path before adding it.
def health() -> Dict[str, Any]:
    """Report a static ok status together with the configured model repo id."""
    payload: Dict[str, Any] = {"status": "ok"}
    payload["repo"] = HF_REPO_ID
    return payload
# NOTE(review): no route decorator is visible here, so this handler is not
# registered with ``app`` as shown; confirm the intended method and path.
def predict(req: Record) -> Dict[str, Any]:
    """Score a single validated record and return it as ``predicted_score``."""
    record = req.dict()
    score = predict_one(record)
    return {"predicted_score": score}
# NOTE(review): no route decorator is visible here, so this handler is not
# registered with ``app`` as shown; confirm the intended method and path.
def predict_many(req: PredictRequest) -> Dict[str, Any]:
    """Score every record in ``req.data``; return the scores and their count."""
    records = [item.dict() for item in req.data]
    scores = predict_batch(records)
    return {"predicted_scores": scores.tolist(), "count": len(records)}