Spaces:
Sleeping
Sleeping
| import os | |
| import pickle | |
| import pandas as pd | |
| from fastapi import FastAPI, Request, HTTPException | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi.middleware.cors import CORSMiddleware | |
app = FastAPI(title="Student Score Predictor Chatbot")

# Cross-origin policy: every origin, method and header is accepted.
# NOTE(review): allow_credentials=True together with a wildcard origin means
# credentialed cross-site requests are permitted; nothing visible in this
# file uses cookies or auth, so confirm credentials are actually needed.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Jinja templates (e.g. chat.html). The directory is a relative path, so it is
# looked up relative to the working directory the server is started from.
templates = Jinja2Templates(directory="templates")
# Load the trained model once, at import time, so every request reuses it.
# The path can be overridden with the MODEL_PATH environment variable.
MODEL_PATH = os.getenv('MODEL_PATH', 'student_performance_model.pkl')

# SECURITY: pickle.load can execute arbitrary code embedded in the file.
# MODEL_PATH must only ever point at a trusted, locally produced artifact,
# never at anything supplied by users.
try:
    with open(MODEL_PATH, 'rb') as fh:
        model = pickle.load(fh)
except Exception as e:
    # Fail fast: the app is useless without a model. Include the path, since
    # unpickling errors (e.g. a missing library module) do not mention it, and
    # chain the original exception so its traceback is reported as the cause.
    raise RuntimeError(f"Could not load model from {MODEL_PATH!r}: {e}") from e
# Questionnaire the chatbot walks through. Each field name doubles as a
# column name in the DataFrame handed to the model. Every entry carries:
#   'name'      - field / column name
#   'type'      - 'number' or 'select'
#   'question'  - prompt shown to the user
#   'validation' (numbers): inclusive {'min': .., 'max': ..} bounds
#   'options'    (selects): list of allowed answers


def _number_field(name, question, lo, hi):
    """Return a numeric question spec whose answers must lie in [lo, hi]."""
    return {
        'name': name,
        'type': 'number',
        'question': question,
        'validation': {'min': lo, 'max': hi},
    }


def _select_field(name, question, choices):
    """Return a multiple-choice question spec with its own copy of *choices*."""
    return {
        'name': name,
        'type': 'select',
        'question': question,
        'options': list(choices),
    }


_YES_NO = ('Yes', 'No')

FIELDS = [
    _number_field('Age', 'What is your age?', 5, 100),
    _select_field('Gender', 'What is your gender?',
                  ('Male', 'Female', 'Other')),
    _number_field('HoursOfStudyPerDay',
                  'How many hours do you study per day?', 0, 24),
    _number_field('SchoolAttendanceRate',
                  'What is your school attendance rate (%)?', 0, 100),
    _select_field('TuitionAccess',
                  'Do you have access to extra tuition?', _YES_NO),
    _number_field('AveragePreviousScores',
                  'What was your average previous score?', 0, 100),
    _number_field('HoursOfSleep',
                  'How many hours of sleep do you get per night?', 0, 24),
    _select_field('BreakfastDaily',
                  'Do you eat breakfast every day?', _YES_NO),
    _number_field('ScreenTimeHours',
                  'How many hours of screen time per day?', 0, 24),
    _number_field('PhysicalActivityHours',
                  'How many hours of physical activity per day?', 0, 24),
    _select_field('PlaysSport', 'Do you play any sports?', _YES_NO),
    _number_field('MentalHealthScore',
                  'Rate your mental health on a scale of 1–10.', 1, 10),
    _select_field('ParentalEducationLevel',
                  'What is your parental education level?',
                  ('High school', 'Graduate', 'Postgrad')),
    _select_field('HouseholdIncomeLevel',
                  'What is your household income level?',
                  ('Low', 'Medium', 'High')),
    _number_field('StudyEnvironmentRating',
                  'Rate your study environment (1–5).', 1, 5),
    _number_field('FriendSupportScore',
                  'Rate the emotional support from friends (1–10).', 1, 10),
    _select_field('ParticipatesInClubs',
                  'Do you participate in any clubs?', _YES_NO),
    _select_field('PartTimeWork',
                  'Do you do any part‑time work?', _YES_NO),
]
# NOTE(review): no route decorator is visible on this handler in this listing,
# so FastAPI would never expose it. It presumably belongs at something like
# @app.get("/", response_class=HTMLResponse) — confirm against the deployed app.
async def chat_ui(request: Request):
    """Render chat.html, passing the questionnaire definitions as ``fields``."""
    context = {"request": request, "fields": FIELDS}
    return templates.TemplateResponse("chat.html", context)
| async def predict_json(payload: dict): | |
| # collect + cast | |
| data = {} | |
| for f in FIELDS: | |
| key = f["name"] | |
| if key not in payload: | |
| raise HTTPException(400, f"Missing field: {key}") | |
| val = payload[key] | |
| if f["type"] == "number": | |
| try: | |
| val = float(val) | |
| except: | |
| raise HTTPException(400, f"{key} must be numeric") | |
| data[key] = val | |
| # validate ranges | |
| for f in FIELDS: | |
| if f["type"] == "number" and "validation" in f: | |
| v = f["validation"] | |
| if not (v["min"] <= data[f["name"]] <= v["max"]): | |
| raise HTTPException(400, | |
| f"{f['name']} must be between {v['min']} and {v['max']}") | |
| df = pd.DataFrame([data]) | |
| score = model.predict(df)[0] | |
| return JSONResponse({"predicted": round(float(score), 2)}) | |