Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -1,80 +1,116 @@
|
|
| 1 |
import os
|
| 2 |
import pickle
|
| 3 |
import pandas as pd
|
| 4 |
-
from fastapi import FastAPI, Request,
|
| 5 |
from fastapi.responses import HTMLResponse, JSONResponse
|
| 6 |
from fastapi.templating import Jinja2Templates
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
|
| 9 |
-
app = FastAPI(title="Student Score Predictor")
|
| 10 |
-
# Allow CORS for frontend JS
|
| 11 |
app.add_middleware(
|
| 12 |
CORSMiddleware,
|
| 13 |
-
allow_origins=["*"],
|
| 14 |
-
|
| 15 |
-
allow_methods=["*"],
|
| 16 |
-
allow_headers=["*"],
|
| 17 |
)
|
| 18 |
templates = Jinja2Templates(directory="templates")
|
| 19 |
|
| 20 |
-
# Load
|
| 21 |
MODEL_PATH = os.getenv('MODEL_PATH', 'student_performance_model.pkl')
|
| 22 |
try:
|
| 23 |
with open(MODEL_PATH, 'rb') as f:
|
| 24 |
model = pickle.load(f)
|
| 25 |
except Exception as e:
|
| 26 |
-
raise RuntimeError(f"
|
| 27 |
|
| 28 |
-
#
|
| 29 |
FIELDS = [
|
| 30 |
-
{'name': 'Age', 'type': 'number',
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
{'name': '
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
{'name': '
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
{'name': '
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
{'name': '
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
{'name': '
|
| 46 |
-
|
| 47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
]
|
| 49 |
|
| 50 |
-
@app.get(
|
| 51 |
-
async def
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
'fields': FIELDS
|
| 56 |
})
|
| 57 |
|
| 58 |
-
@app.post(
|
| 59 |
async def predict_json(payload: dict):
|
| 60 |
-
#
|
| 61 |
-
|
| 62 |
-
data = {f['name']: payload[f['name']] for f in FIELDS}
|
| 63 |
-
except KeyError as e:
|
| 64 |
-
raise HTTPException(status_code=400, detail=f"Missing field: {e}")
|
| 65 |
-
# Cast types
|
| 66 |
for f in FIELDS:
|
| 67 |
-
|
| 68 |
-
if
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
for f in FIELDS:
|
| 72 |
-
if f[
|
| 73 |
-
|
| 74 |
-
v = f[
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
df = pd.DataFrame([data])
|
| 79 |
score = model.predict(df)[0]
|
| 80 |
-
return JSONResponse({
|
|
|
|
| 1 |
import os
|
| 2 |
import pickle
|
| 3 |
import pandas as pd
|
| 4 |
+
from fastapi import FastAPI, Request, HTTPException
|
| 5 |
from fastapi.responses import HTMLResponse, JSONResponse
|
| 6 |
from fastapi.templating import Jinja2Templates
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
|
| 9 |
+
app = FastAPI(title="Student Score Predictor Chatbot")
|
|
|
|
| 10 |
app.add_middleware(
|
| 11 |
CORSMiddleware,
|
| 12 |
+
allow_origins=["*"], allow_credentials=True,
|
| 13 |
+
allow_methods=["*"], allow_headers=["*"],
|
|
|
|
|
|
|
| 14 |
)
|
| 15 |
templates = Jinja2Templates(directory="templates")
|
| 16 |
|
| 17 |
+
# Load model once at startup
|
| 18 |
MODEL_PATH = os.getenv('MODEL_PATH', 'student_performance_model.pkl')
|
| 19 |
try:
|
| 20 |
with open(MODEL_PATH, 'rb') as f:
|
| 21 |
model = pickle.load(f)
|
| 22 |
except Exception as e:
|
| 23 |
+
raise RuntimeError(f"Could not load model: {e}")
|
| 24 |
|
| 25 |
+
# Chat‑fields configuration
|
| 26 |
FIELDS = [
|
| 27 |
+
{'name': 'Age', 'type': 'number',
|
| 28 |
+
'question': 'What is your age?',
|
| 29 |
+
'validation': {'min': 5, 'max': 100}},
|
| 30 |
+
{'name': 'Gender', 'type': 'select',
|
| 31 |
+
'question': 'What is your gender?',
|
| 32 |
+
'options': ['Male', 'Female', 'Other']},
|
| 33 |
+
{'name': 'HoursOfStudyPerDay', 'type': 'number',
|
| 34 |
+
'question': 'How many hours do you study per day?',
|
| 35 |
+
'validation': {'min': 0, 'max': 24}},
|
| 36 |
+
{'name': 'SchoolAttendanceRate', 'type': 'number',
|
| 37 |
+
'question': 'What is your school attendance rate (%)?',
|
| 38 |
+
'validation': {'min': 0, 'max': 100}},
|
| 39 |
+
{'name': 'TuitionAccess', 'type': 'select',
|
| 40 |
+
'question': 'Do you have access to extra tuition?',
|
| 41 |
+
'options': ['Yes', 'No']},
|
| 42 |
+
{'name': 'AveragePreviousScores', 'type': 'number',
|
| 43 |
+
'question': 'What was your average previous score?',
|
| 44 |
+
'validation': {'min': 0, 'max': 100}},
|
| 45 |
+
{'name': 'HoursOfSleep', 'type': 'number',
|
| 46 |
+
'question': 'How many hours of sleep do you get per night?',
|
| 47 |
+
'validation': {'min': 0, 'max': 24}},
|
| 48 |
+
{'name': 'BreakfastDaily', 'type': 'select',
|
| 49 |
+
'question': 'Do you eat breakfast every day?',
|
| 50 |
+
'options': ['Yes', 'No']},
|
| 51 |
+
{'name': 'ScreenTimeHours', 'type': 'number',
|
| 52 |
+
'question': 'How many hours of screen time per day?',
|
| 53 |
+
'validation': {'min': 0, 'max': 24}},
|
| 54 |
+
{'name': 'PhysicalActivityHours', 'type': 'number',
|
| 55 |
+
'question': 'How many hours of physical activity per day?',
|
| 56 |
+
'validation': {'min': 0, 'max': 24}},
|
| 57 |
+
{'name': 'PlaysSport', 'type': 'select',
|
| 58 |
+
'question': 'Do you play any sports?',
|
| 59 |
+
'options': ['Yes', 'No']},
|
| 60 |
+
{'name': 'MentalHealthScore', 'type': 'number',
|
| 61 |
+
'question': 'Rate your mental health on a scale of 1–10.',
|
| 62 |
+
'validation': {'min': 1, 'max': 10}},
|
| 63 |
+
{'name': 'ParentalEducationLevel', 'type': 'select',
|
| 64 |
+
'question': 'What is your parental education level?',
|
| 65 |
+
'options': ['High school', 'Graduate', 'Postgrad']},
|
| 66 |
+
{'name': 'HouseholdIncomeLevel', 'type': 'select',
|
| 67 |
+
'question': 'What is your household income level?',
|
| 68 |
+
'options': ['Low', 'Medium', 'High']},
|
| 69 |
+
{'name': 'StudyEnvironmentRating', 'type': 'number',
|
| 70 |
+
'question': 'Rate your study environment (1–5).',
|
| 71 |
+
'validation': {'min': 1, 'max': 5}},
|
| 72 |
+
{'name': 'FriendSupportScore', 'type': 'number',
|
| 73 |
+
'question': 'Rate the emotional support from friends (1–10).',
|
| 74 |
+
'validation': {'min': 1, 'max': 10}},
|
| 75 |
+
{'name': 'ParticipatesInClubs', 'type': 'select',
|
| 76 |
+
'question': 'Do you participate in any clubs?',
|
| 77 |
+
'options': ['Yes', 'No']},
|
| 78 |
+
{'name': 'PartTimeWork', 'type': 'select',
|
| 79 |
+
'question': 'Do you do any part‑time work?',
|
| 80 |
+
'options': ['Yes', 'No']},
|
| 81 |
]
|
| 82 |
|
| 83 |
+
@app.get("/", response_class=HTMLResponse)
|
| 84 |
+
async def chat_ui(request: Request):
|
| 85 |
+
return templates.TemplateResponse("chat.html", {
|
| 86 |
+
"request": request,
|
| 87 |
+
"fields": FIELDS
|
|
|
|
| 88 |
})
|
| 89 |
|
| 90 |
+
@app.post("/predict_json")
|
| 91 |
async def predict_json(payload: dict):
|
| 92 |
+
# collect + cast
|
| 93 |
+
data = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
for f in FIELDS:
|
| 95 |
+
key = f["name"]
|
| 96 |
+
if key not in payload:
|
| 97 |
+
raise HTTPException(400, f"Missing field: {key}")
|
| 98 |
+
val = payload[key]
|
| 99 |
+
if f["type"] == "number":
|
| 100 |
+
try:
|
| 101 |
+
val = float(val)
|
| 102 |
+
except:
|
| 103 |
+
raise HTTPException(400, f"{key} must be numeric")
|
| 104 |
+
data[key] = val
|
| 105 |
+
|
| 106 |
+
# validate ranges
|
| 107 |
for f in FIELDS:
|
| 108 |
+
if f["type"] == "number" and "validation" in f:
|
| 109 |
+
v = f["validation"]
|
| 110 |
+
if not (v["min"] <= data[f["name"]] <= v["max"]):
|
| 111 |
+
raise HTTPException(400,
|
| 112 |
+
f"{f['name']} must be between {v['min']} and {v['max']}")
|
| 113 |
+
|
| 114 |
df = pd.DataFrame([data])
|
| 115 |
score = model.predict(df)[0]
|
| 116 |
+
return JSONResponse({"predicted": round(float(score), 2)})
|