# triflix — "Update main.py" (commit 5b6cb96, verified)
import os
import pickle
import pandas as pd
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI(title="Student Score Predictor Chatbot")

# Fully permissive CORS: every origin, method and header is accepted.
# NOTE(review): allow_credentials=True together with a "*" origin is the most
# permissive combination; Starlette's CORSMiddleware reportedly reflects the
# caller's Origin in that case — confirm and narrow allow_origins before any
# deployment where cookies/credentials matter.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Templates are resolved from the "templates" directory (relative to the CWD).
templates = Jinja2Templates(directory="templates")
# Load the model once at import time so every request reuses the same object.
# MODEL_PATH can be overridden via the environment; the default is a relative
# path, resolved against the process's working directory.
MODEL_PATH = os.getenv('MODEL_PATH', 'student_performance_model.pkl')
try:
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point MODEL_PATH at a model artifact you produced or otherwise trust.
    with open(MODEL_PATH, 'rb') as model_file:
        model = pickle.load(model_file)
except Exception as e:
    # Fail fast at startup; chain the original error so the traceback shows the
    # real cause (missing file, unpickling error, missing dependency, ...).
    raise RuntimeError(f"Could not load model: {e}") from e
# Chat‑fields configuration
def _number_field(name, question, low, high):
    """Return a numeric field spec whose answer must lie in [low, high]."""
    return {
        'name': name,
        'type': 'number',
        'question': question,
        'validation': {'min': low, 'max': high},
    }


def _select_field(name, question, *choices):
    """Return a multiple-choice field spec limited to *choices*."""
    return {
        'name': name,
        'type': 'select',
        'question': question,
        'options': list(choices),
    }


# Ordered list of question specs. It is passed to the chat.html template and
# used by /predict_json to read, cast and validate the request payload.
FIELDS = [
    _number_field('Age', 'What is your age?', 5, 100),
    _select_field('Gender', 'What is your gender?', 'Male', 'Female', 'Other'),
    _number_field('HoursOfStudyPerDay',
                  'How many hours do you study per day?', 0, 24),
    _number_field('SchoolAttendanceRate',
                  'What is your school attendance rate (%)?', 0, 100),
    _select_field('TuitionAccess',
                  'Do you have access to extra tuition?', 'Yes', 'No'),
    _number_field('AveragePreviousScores',
                  'What was your average previous score?', 0, 100),
    _number_field('HoursOfSleep',
                  'How many hours of sleep do you get per night?', 0, 24),
    _select_field('BreakfastDaily',
                  'Do you eat breakfast every day?', 'Yes', 'No'),
    _number_field('ScreenTimeHours',
                  'How many hours of screen time per day?', 0, 24),
    _number_field('PhysicalActivityHours',
                  'How many hours of physical activity per day?', 0, 24),
    _select_field('PlaysSport', 'Do you play any sports?', 'Yes', 'No'),
    _number_field('MentalHealthScore',
                  'Rate your mental health on a scale of 1–10.', 1, 10),
    _select_field('ParentalEducationLevel',
                  'What is your parental education level?',
                  'High school', 'Graduate', 'Postgrad'),
    _select_field('HouseholdIncomeLevel',
                  'What is your household income level?',
                  'Low', 'Medium', 'High'),
    _number_field('StudyEnvironmentRating',
                  'Rate your study environment (1–5).', 1, 5),
    _number_field('FriendSupportScore',
                  'Rate the emotional support from friends (1–10).', 1, 10),
    _select_field('ParticipatesInClubs',
                  'Do you participate in any clubs?', 'Yes', 'No'),
    _select_field('PartTimeWork',
                  'Do you do any part‑time work?', 'Yes', 'No'),
]
@app.get("/", response_class=HTMLResponse)
async def chat_ui(request: Request):
    """Render chat.html with the request and the FIELDS question specs."""
    context = {"request": request, "fields": FIELDS}
    return templates.TemplateResponse("chat.html", context)
@app.post("/predict_json")
async def predict_json(payload: dict):
    """Validate a JSON answer set and return the model's predicted score.

    ``payload`` maps each ``FIELDS`` name to the user's answer. Numeric
    answers are cast to ``float`` and range-checked; select answers must be
    one of the field's ``options``. Keys not listed in ``FIELDS`` are ignored.

    Returns:
        JSONResponse ``{"predicted": <prediction rounded to 2 decimals>}``.

    Raises:
        HTTPException(400): a field is missing; a numeric answer is not a
            number (JSON booleans included) or is out of range; or a select
            answer is not one of its options.
    """
    # Phase 1 reports missing/non-numeric fields; phase 2 reports range and
    # option violations — both in FIELDS order.
    data = _collect_fields(payload)
    _validate_fields(data)
    # One-row frame; columns follow FIELDS order (dict insertion order).
    # NOTE(review): assumes the model was trained on exactly these column names
    # and category labels — confirm against the training pipeline.
    df = pd.DataFrame([data])
    score = model.predict(df)[0]
    return JSONResponse({"predicted": round(float(score), 2)})


def _collect_fields(payload):
    """Extract every FIELDS entry from *payload*, casting numeric answers to float."""
    data = {}
    for field in FIELDS:
        key = field["name"]
        if key not in payload:
            raise HTTPException(400, f"Missing field: {key}")
        val = payload[key]
        if field["type"] == "number":
            # float(True) == 1.0, so JSON true/false would otherwise slip through.
            if isinstance(val, bool):
                raise HTTPException(400, f"{key} must be numeric")
            try:
                val = float(val)
            except (TypeError, ValueError, OverflowError):
                # TypeError: None/list/dict; ValueError: non-numeric string;
                # OverflowError: JSON integer too large to represent as float.
                raise HTTPException(400, f"{key} must be numeric") from None
        data[key] = val
    return data


def _validate_fields(data):
    """Check numeric ranges and select options for the collected answers."""
    for field in FIELDS:
        name = field["name"]
        value = data[name]
        if field["type"] == "number" and "validation" in field:
            bounds = field["validation"]
            # NaN fails both comparisons, so it is rejected here as well.
            if not (bounds["min"] <= value <= bounds["max"]):
                raise HTTPException(
                    400,
                    f"{name} must be between {bounds['min']} and {bounds['max']}")
        elif field["type"] == "select" and value not in field["options"]:
            # Unknown categories would otherwise reach the model unchecked.
            raise HTTPException(
                400, f"{name} must be one of: {', '.join(field['options'])}")