Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -1,11 +1,20 @@
|
|
| 1 |
import os
|
| 2 |
import pickle
|
| 3 |
import pandas as pd
|
| 4 |
-
from fastapi import FastAPI, Request, Form
|
| 5 |
-
from fastapi.responses import HTMLResponse
|
| 6 |
from fastapi.templating import Jinja2Templates
|
|
|
|
| 7 |
|
| 8 |
app = FastAPI(title="Student Score Predictor")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
templates = Jinja2Templates(directory="templates")
|
| 10 |
|
| 11 |
# Load the trained model at startup
|
|
@@ -16,59 +25,56 @@ try:
|
|
| 16 |
except Exception as e:
|
| 17 |
raise RuntimeError(f"Failed to load model: {e}")
|
| 18 |
|
| 19 |
-
#
|
| 20 |
FIELDS = [
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
]
|
| 40 |
|
| 41 |
@app.get('/', response_class=HTMLResponse)
|
| 42 |
-
async def
|
| 43 |
-
|
|
|
|
| 44 |
'request': request,
|
| 45 |
-
'
|
| 46 |
-
'values': {}
|
| 47 |
})
|
| 48 |
|
| 49 |
-
@app.post('/
|
| 50 |
-
async def
|
| 51 |
-
#
|
| 52 |
-
record = {}
|
| 53 |
-
for name, dtype in FIELDS:
|
| 54 |
-
raw = form_data.get(name)
|
| 55 |
-
if raw is None:
|
| 56 |
-
# missing field
|
| 57 |
-
record[name] = None
|
| 58 |
-
continue
|
| 59 |
-
record[name] = dtype(raw)
|
| 60 |
-
|
| 61 |
-
# Prepare DataFrame
|
| 62 |
-
df = pd.DataFrame([record])
|
| 63 |
-
# Predict
|
| 64 |
try:
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
'
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import pickle
|
| 3 |
import pandas as pd
|
| 4 |
+
from fastapi import FastAPI, Request, Form, HTTPException
|
| 5 |
+
from fastapi.responses import HTMLResponse, JSONResponse
|
| 6 |
from fastapi.templating import Jinja2Templates
|
| 7 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
|
| 9 |
app = FastAPI(title="Student Score Predictor")
|
| 10 |
+
# Allow CORS for frontend JS
|
| 11 |
+
app.add_middleware(
|
| 12 |
+
CORSMiddleware,
|
| 13 |
+
allow_origins=["*"],
|
| 14 |
+
allow_credentials=True,
|
| 15 |
+
allow_methods=["*"],
|
| 16 |
+
allow_headers=["*"],
|
| 17 |
+
)
|
| 18 |
templates = Jinja2Templates(directory="templates")
|
| 19 |
|
| 20 |
# Load the trained model at startup
|
|
|
|
| 25 |
except Exception as e:
|
| 26 |
raise RuntimeError(f"Failed to load model: {e}")
|
| 27 |
|
| 28 |
+
# Input fields configuration
|
| 29 |
FIELDS = [
|
| 30 |
+
{'name': 'Age', 'type': 'number', 'question': 'What is your age?', 'validation': {'min': 5, 'max': 100}},
|
| 31 |
+
{'name': 'Gender', 'type': 'select', 'question': 'What is your gender?', 'options': ['Male','Female']},
|
| 32 |
+
{'name': 'HoursOfStudyPerDay', 'type': 'number', 'question': 'How many hours do you study per day?', 'validation': {'min': 0, 'max': 24}},
|
| 33 |
+
{'name': 'SchoolAttendanceRate', 'type': 'number', 'question': 'What is your school attendance rate in %?', 'validation': {'min': 0, 'max': 100}},
|
| 34 |
+
{'name': 'TuitionAccess', 'type': 'select', 'question': 'Do you have access to tuition (Yes/No)?', 'options': ['Yes','No']},
|
| 35 |
+
{'name': 'AveragePreviousScores', 'type': 'number', 'question': 'What is your average previous score?', 'validation': {'min': 0, 'max': 100}},
|
| 36 |
+
{'name': 'HoursOfSleep', 'type': 'number', 'question': 'How many hours of sleep do you get per day?', 'validation': {'min': 0, 'max': 24}},
|
| 37 |
+
{'name': 'BreakfastDaily', 'type': 'select', 'question': 'Do you have breakfast daily? (Yes/No)', 'options': ['Yes','No']},
|
| 38 |
+
{'name': 'ScreenTimeHours', 'type': 'number', 'question': 'How many hours of screen time per day?', 'validation': {'min': 0, 'max': 24}},
|
| 39 |
+
{'name': 'PhysicalActivityHours', 'type': 'number', 'question': 'How many hours of physical activity per day?', 'validation': {'min': 0, 'max': 24}},
|
| 40 |
+
{'name': 'PlaysSport', 'type': 'select', 'question': 'Do you play sports? (Yes/No)', 'options': ['Yes','No']},
|
| 41 |
+
{'name': 'MentalHealthScore', 'type': 'number', 'question': 'Rate your mental health (1-10)', 'validation': {'min': 1, 'max': 10}},
|
| 42 |
+
{'name': 'ParentalEducationLevel', 'type': 'select', 'question': 'Parental education level?', 'options': ['High school','Graduate','Postgrad']},
|
| 43 |
+
{'name': 'HouseholdIncomeLevel', 'type': 'select', 'question': 'Household income level?', 'options': ['Low','Medium','High']},
|
| 44 |
+
{'name': 'StudyEnvironmentRating', 'type': 'number', 'question': 'Rate your study environment (1-5)', 'validation': {'min': 1, 'max': 5}},
|
| 45 |
+
{'name': 'FriendSupportScore', 'type': 'number', 'question': 'Rate friend support (1-10)', 'validation': {'min': 1, 'max': 10}},
|
| 46 |
+
{'name': 'ParticipatesInClubs', 'type': 'select', 'question': 'Do you participate in clubs? (Yes/No)', 'options': ['Yes','No']},
|
| 47 |
+
{'name': 'PartTimeWork', 'type': 'select', 'question': 'Do you do part-time work? (Yes/No)', 'options': ['Yes','No']},
|
| 48 |
]
|
| 49 |
|
| 50 |
@app.get('/', response_class=HTMLResponse)
|
| 51 |
+
async def get_chat(request: Request):
|
| 52 |
+
# Serve the chat UI, embedding the FIELDS config
|
| 53 |
+
return templates.TemplateResponse('chat.html', {
|
| 54 |
'request': request,
|
| 55 |
+
'fields': FIELDS
|
|
|
|
| 56 |
})
|
| 57 |
|
| 58 |
+
@app.post('/predict_json')
|
| 59 |
+
async def predict_json(payload: dict):
|
| 60 |
+
# Ensure all fields present
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
try:
|
| 62 |
+
data = {f['name']: payload[f['name']] for f in FIELDS}
|
| 63 |
+
except KeyError as e:
|
| 64 |
+
raise HTTPException(status_code=400, detail=f"Missing field: {e}")
|
| 65 |
+
# Cast types
|
| 66 |
+
for f in FIELDS:
|
| 67 |
+
name = f['name']
|
| 68 |
+
if f['type']=='number':
|
| 69 |
+
data[name] = float(data[name])
|
| 70 |
+
# Validate
|
| 71 |
+
for f in FIELDS:
|
| 72 |
+
if f['type']=='number' and 'validation' in f:
|
| 73 |
+
val = data[f['name']]
|
| 74 |
+
v = f['validation']
|
| 75 |
+
if not (v['min'] <= val <= v['max']):
|
| 76 |
+
raise HTTPException(status_code=400, detail=f"{f['name']} must be between {v['min']} and {v['max']}")
|
| 77 |
+
# Predict
|
| 78 |
+
df = pd.DataFrame([data])
|
| 79 |
+
score = model.predict(df)[0]
|
| 80 |
+
return JSONResponse({'predicted': round(float(score),2)})
|