ethnmcl committed on
Commit
fe44b0a
·
verified ·
1 Parent(s): 1cc7225

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -4
app.py CHANGED
@@ -1,20 +1,34 @@
 
 
 
 
1
  import os, json, joblib, numpy as np, pandas as pd, threading
2
  from huggingface_hub import snapshot_download
3
  import xgboost as xgb
4
  from pathlib import Path
5
 
 
 
 
6
  HF_CACHE_DIR = os.getenv("HF_CACHE_DIR", "/models/hf")
7
  HF_REPO_ID = os.getenv("HF_REPO_ID", "ethnmcl/test-score-predictor-xgb")
8
- HF_TOKEN = os.getenv("HF_TOKEN", None) # set as Space secret for private repos
9
 
10
- _loaded_lock = threading.Lock()
 
 
11
  _loaded = False
 
12
  _pre = None
13
  _weights = None
14
  _schema = None
15
  _model = None
16
 
 
 
 
17
  def repo_snapshot(repo_id: str = None) -> str:
 
18
  repo_id = repo_id or HF_REPO_ID
19
  local_dir = snapshot_download(
20
  repo_id=repo_id,
@@ -26,6 +40,7 @@ def repo_snapshot(repo_id: str = None) -> str:
26
  return local_dir
27
 
28
  def load_model():
 
29
  global _loaded, _pre, _weights, _schema, _model
30
  if _loaded:
31
  return
@@ -46,7 +61,7 @@ def _transform(records):
46
  df = pd.DataFrame(records, columns=num + cat)
47
  Xt = _pre.transform(df)
48
  Xt = Xt.astype(float, copy=False)
49
- Xt[:, :len(num)] *= _weights # post-transform numeric weighting
50
  return Xt
51
 
52
  def predict_one(record: dict) -> float:
@@ -54,7 +69,7 @@ def predict_one(record: dict) -> float:
54
  load_model()
55
  Xt = _transform([record])
56
  pred = float(_model.predict(Xt)[0])
57
- return max(50.0, min(100.0, pred)) # optional clamp to match dataset range
58
 
59
  def predict_batch(records: list) -> np.ndarray:
60
  if not _loaded:
@@ -63,3 +78,56 @@ def predict_batch(records: list) -> np.ndarray:
63
  preds = _model.predict(Xt)
64
  return np.clip(preds, 50.0, 100.0)
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel, Field, validator
4
+ from typing import List, Dict, Any
5
  import os, json, joblib, numpy as np, pandas as pd, threading
6
  from huggingface_hub import snapshot_download
7
  import xgboost as xgb
8
  from pathlib import Path
9
 
10
# -------------------------------
# Hugging Face repo config
# -------------------------------
HF_CACHE_DIR = os.getenv("HF_CACHE_DIR", "/models/hf")  # local snapshot cache directory
HF_REPO_ID = os.getenv("HF_REPO_ID", "ethnmcl/test-score-predictor-xgb")
HF_TOKEN = os.getenv("HF_TOKEN", None) # only needed if repo is private

# -------------------------------
# Global state
# -------------------------------
# Artifacts are loaded lazily by load_model() and cached at module level.
_loaded = False  # True once all artifacts below are in memory
_loaded_lock = threading.Lock()  # NOTE(review): intended to guard first-time loading — confirm load_model() acquires it
_pre = None      # fitted preprocessor; presumably transforms numeric + categorical columns — verify against repo artifacts
_weights = None  # multiplier(s) applied to the numeric columns after preprocessing
_schema = None   # column schema (numeric/categorical feature names) — TODO confirm exact structure
_model = None    # XGBoost model used for prediction
26
 
27
+ # -------------------------------
28
+ # Loader functions
29
+ # -------------------------------
30
  def repo_snapshot(repo_id: str = None) -> str:
31
+ """Download model repo snapshot (if not cached)."""
32
  repo_id = repo_id or HF_REPO_ID
33
  local_dir = snapshot_download(
34
  repo_id=repo_id,
 
40
  return local_dir
41
 
42
  def load_model():
43
+ """Load preprocessor, weights, schema, and XGB model into memory."""
44
  global _loaded, _pre, _weights, _schema, _model
45
  if _loaded:
46
  return
 
61
  df = pd.DataFrame(records, columns=num + cat)
62
  Xt = _pre.transform(df)
63
  Xt = Xt.astype(float, copy=False)
64
+ Xt[:, :len(num)] *= _weights
65
  return Xt
66
 
67
  def predict_one(record: dict) -> float:
 
69
  load_model()
70
  Xt = _transform([record])
71
  pred = float(_model.predict(Xt)[0])
72
+ return max(50.0, min(100.0, pred)) # clamp to dataset range
73
 
74
  def predict_batch(records: list) -> np.ndarray:
75
  if not _loaded:
 
78
  preds = _model.predict(Xt)
79
  return np.clip(preds, 50.0, 100.0)
80
 
81
# -------------------------------
# FastAPI app
# -------------------------------
app = FastAPI(title="Test Score Predictor API", version="1.0.0")

# CORS: this is an open, unauthenticated API, so any web origin may call it.
# The previous config paired allow_origins=["*"] with allow_credentials=True,
# which is dead configuration: the Fetch spec forbids a wildcard
# `Access-Control-Allow-Origin` on credentialed requests, so browsers reject
# that combination. Since the API issues no cookies or auth headers of its
# own, credentials are left disabled (Starlette's default).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
91
+
92
# NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
# lifespan handlers; kept as-is to preserve behavior.
@app.on_event("startup")
def _startup() -> None:
    """Warm the service before the first request: fetch the HF snapshot,
    then load the preprocessor/weights/schema/model into module state."""
    repo_snapshot(HF_REPO_ID)
    load_model()
96
+
97
+ # -------------------------------
98
+ # Request schemas
99
+ # -------------------------------
100
class Record(BaseModel):
    """A single student/test observation.

    Field bounds mirror the value ranges the model was trained on;
    `Study_Session_Average` is recomputed from the other fields (see below),
    so the client-supplied value is normally overridden.
    """

    Subject: str = Field(..., examples=["Mathematics"])
    Current_Grade: int = Field(..., ge=60, le=98)
    Max_Test_Percentage: int = Field(..., ge=65, le=100)
    Days_Preparing: int = Field(..., ge=1, le=14)
    Hours_Studied: int = Field(..., ge=2, le=50)
    Study_Session_Average: float = Field(..., ge=0.1, le=10.0)
    Avg_Previous_Tests: int = Field(..., ge=55, le=95)
    Test_Difficulty: str = Field(..., examples=["Easy (20)", "Medium (30)", "Hard (50)"])

    @validator("Study_Session_Average", always=True)
    def recompute_session_avg(cls, v, values):
        """Derive the session average as hours/days, rounded to one decimal."""
        hours = values.get("Hours_Studied")
        days = values.get("Days_Preparing")
        if hours is None or days is None:
            # An earlier field failed validation; fall back to the given value.
            return v
        return round(hours / days, 1)
115
+
116
class PredictRequest(BaseModel):
    """Batch request body: a list of records to score in one call."""
    data: List[Record]
118
+
119
+ # -------------------------------
120
+ # Routes
121
+ # -------------------------------
122
@app.get("/health")
def health() -> Dict[str, Any]:
    """Liveness probe; also reports which HF repo backs the service."""
    payload: Dict[str, Any] = {"status": "ok"}
    payload["repo"] = HF_REPO_ID
    return payload
125
+
126
@app.post("/predict")
def predict(req: Record) -> Dict[str, Any]:
    """Score one record and return the (clamped) predicted test score."""
    record = req.dict()
    return {"predicted_score": predict_one(record)}
129
+
130
@app.post("/predict-batch")
def predict_many(req: PredictRequest) -> Dict[str, Any]:
    """Score every record in the request and return scores plus a count."""
    records = [item.dict() for item in req.data]
    scores = predict_batch(records)
    return {"predicted_scores": scores.tolist(), "count": len(records)}