ethnmcl commited on
Commit
4d509f3
·
verified ·
1 Parent(s): 2fb4c58

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +71 -0
main.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel, Field, validator
4
+ from typing import List, Dict, Any
5
+ import os
6
+
7
+ from inference import load_model, predict_one, predict_batch, repo_snapshot
8
+
9
+ HF_REPO_ID = os.getenv("HF_REPO_ID", "ethnmcl/test-score-predictor-xgb")
10
+
11
+ app = FastAPI(title="Test Score Predictor API",
12
+ version="1.0.0",
13
+ description="FastAPI wrapper for ethnmcl/test-score-predictor-xgb")
14
+
15
+ app.add_middleware(
16
+ CORSMiddleware,
17
+ allow_origins=["*"], allow_credentials=True,
18
+ allow_methods=["*"], allow_headers=["*"],
19
+ )
20
+
21
+ # Load model at startup (downloads snapshot if not already present)
22
+ @app.on_event("startup")
23
+ def _startup():
24
+ repo_snapshot(HF_REPO_ID) # ensures files exist locally
25
+ load_model() # loads artifacts into process
26
+
27
+
28
+ class Record(BaseModel):
29
+ Subject: str = Field(..., examples=["Mathematics"])
30
+ Current_Grade: int = Field(..., ge=60, le=98)
31
+ Max_Test_Percentage: int = Field(..., ge=65, le=100)
32
+ Days_Preparing: int = Field(..., ge=1, le=14)
33
+ Hours_Studied: int = Field(..., ge=2, le=50)
34
+ Study_Session_Average: float = Field(..., ge=0.1, le=10.0)
35
+ Avg_Previous_Tests: int = Field(..., ge=55, le=95)
36
+ Test_Difficulty: str = Field(..., examples=["Easy (20)", "Medium (30)", "Hard (50)"])
37
+
38
+ @validator("Study_Session_Average", always=True)
39
+ def recompute_session_avg(cls, v, values):
40
+ # Keep dataset contract: Hours / Days, rounded to 1 decimal
41
+ if "Hours_Studied" in values and "Days_Preparing" in values:
42
+ h = values["Hours_Studied"]; d = values["Days_Preparing"]
43
+ return round(h / d, 1)
44
+ return v
45
+
46
+
47
+ class PredictRequest(BaseModel):
48
+ data: List[Record]
49
+
50
+
51
+ @app.get("/health")
52
+ def health() -> Dict[str, Any]:
53
+ return {"status": "ok", "repo": HF_REPO_ID}
54
+
55
+
56
+ @app.get("/model-info")
57
+ def model_info() -> Dict[str, Any]:
58
+ return {"repo": HF_REPO_ID, "files": ["preprocessor.joblib", "weights.npy", "xgb_model.json", "schema.json"]}
59
+
60
+
61
+ @app.post("/predict")
62
+ def predict(req: Record) -> Dict[str, Any]:
63
+ score = predict_one(req.dict())
64
+ return {"predicted_score": float(score)}
65
+
66
+
67
+ @app.post("/predict-batch")
68
+ def predict_many(req: PredictRequest) -> Dict[str, Any]:
69
+ records = [r.dict() for r in req.data]
70
+ scores = predict_batch(records)
71
+ return {"predicted_scores": [float(s) for s in scores], "count": len(scores)}