triflix commited on
Commit
5b6cb96
·
verified ·
1 Parent(s): 4763629

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +88 -52
main.py CHANGED
@@ -1,80 +1,116 @@
1
  import os
2
  import pickle
3
  import pandas as pd
4
- from fastapi import FastAPI, Request, Form, HTTPException
5
  from fastapi.responses import HTMLResponse, JSONResponse
6
  from fastapi.templating import Jinja2Templates
7
  from fastapi.middleware.cors import CORSMiddleware
8
 
9
- app = FastAPI(title="Student Score Predictor")
10
- # Allow CORS for frontend JS
11
  app.add_middleware(
12
  CORSMiddleware,
13
- allow_origins=["*"],
14
- allow_credentials=True,
15
- allow_methods=["*"],
16
- allow_headers=["*"],
17
  )
18
  templates = Jinja2Templates(directory="templates")
19
 
20
- # Load the trained model at startup
21
  MODEL_PATH = os.getenv('MODEL_PATH', 'student_performance_model.pkl')
22
  try:
23
  with open(MODEL_PATH, 'rb') as f:
24
  model = pickle.load(f)
25
  except Exception as e:
26
- raise RuntimeError(f"Failed to load model: {e}")
27
 
28
- # Input fields configuration
29
  FIELDS = [
30
- {'name': 'Age', 'type': 'number', 'question': 'What is your age?', 'validation': {'min': 5, 'max': 100}},
31
- {'name': 'Gender', 'type': 'select', 'question': 'What is your gender?', 'options': ['Male','Female']},
32
- {'name': 'HoursOfStudyPerDay', 'type': 'number', 'question': 'How many hours do you study per day?', 'validation': {'min': 0, 'max': 24}},
33
- {'name': 'SchoolAttendanceRate', 'type': 'number', 'question': 'What is your school attendance rate in %?', 'validation': {'min': 0, 'max': 100}},
34
- {'name': 'TuitionAccess', 'type': 'select', 'question': 'Do you have access to tuition (Yes/No)?', 'options': ['Yes','No']},
35
- {'name': 'AveragePreviousScores', 'type': 'number', 'question': 'What is your average previous score?', 'validation': {'min': 0, 'max': 100}},
36
- {'name': 'HoursOfSleep', 'type': 'number', 'question': 'How many hours of sleep do you get per day?', 'validation': {'min': 0, 'max': 24}},
37
- {'name': 'BreakfastDaily', 'type': 'select', 'question': 'Do you have breakfast daily? (Yes/No)', 'options': ['Yes','No']},
38
- {'name': 'ScreenTimeHours', 'type': 'number', 'question': 'How many hours of screen time per day?', 'validation': {'min': 0, 'max': 24}},
39
- {'name': 'PhysicalActivityHours', 'type': 'number', 'question': 'How many hours of physical activity per day?', 'validation': {'min': 0, 'max': 24}},
40
- {'name': 'PlaysSport', 'type': 'select', 'question': 'Do you play sports? (Yes/No)', 'options': ['Yes','No']},
41
- {'name': 'MentalHealthScore', 'type': 'number', 'question': 'Rate your mental health (1-10)', 'validation': {'min': 1, 'max': 10}},
42
- {'name': 'ParentalEducationLevel', 'type': 'select', 'question': 'Parental education level?', 'options': ['High school','Graduate','Postgrad']},
43
- {'name': 'HouseholdIncomeLevel', 'type': 'select', 'question': 'Household income level?', 'options': ['Low','Medium','High']},
44
- {'name': 'StudyEnvironmentRating', 'type': 'number', 'question': 'Rate your study environment (1-5)', 'validation': {'min': 1, 'max': 5}},
45
- {'name': 'FriendSupportScore', 'type': 'number', 'question': 'Rate friend support (1-10)', 'validation': {'min': 1, 'max': 10}},
46
- {'name': 'ParticipatesInClubs', 'type': 'select', 'question': 'Do you participate in clubs? (Yes/No)', 'options': ['Yes','No']},
47
- {'name': 'PartTimeWork', 'type': 'select', 'question': 'Do you do part-time work? (Yes/No)', 'options': ['Yes','No']},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  ]
49
 
50
- @app.get('/', response_class=HTMLResponse)
51
- async def get_chat(request: Request):
52
- # Serve the chat UI, embedding the FIELDS config
53
- return templates.TemplateResponse('chat.html', {
54
- 'request': request,
55
- 'fields': FIELDS
56
  })
57
 
58
- @app.post('/predict_json')
59
  async def predict_json(payload: dict):
60
- # Ensure all fields present
61
- try:
62
- data = {f['name']: payload[f['name']] for f in FIELDS}
63
- except KeyError as e:
64
- raise HTTPException(status_code=400, detail=f"Missing field: {e}")
65
- # Cast types
66
  for f in FIELDS:
67
- name = f['name']
68
- if f['type']=='number':
69
- data[name] = float(data[name])
70
- # Validate
 
 
 
 
 
 
 
 
71
  for f in FIELDS:
72
- if f['type']=='number' and 'validation' in f:
73
- val = data[f['name']]
74
- v = f['validation']
75
- if not (v['min'] <= val <= v['max']):
76
- raise HTTPException(status_code=400, detail=f"{f['name']} must be between {v['min']} and {v['max']}")
77
- # Predict
78
  df = pd.DataFrame([data])
79
  score = model.predict(df)[0]
80
- return JSONResponse({'predicted': round(float(score),2)})
 
1
  import os
2
  import pickle
3
  import pandas as pd
4
+ from fastapi import FastAPI, Request, HTTPException
5
  from fastapi.responses import HTMLResponse, JSONResponse
6
  from fastapi.templating import Jinja2Templates
7
  from fastapi.middleware.cors import CORSMiddleware
8
 
9
+ app = FastAPI(title="Student Score Predictor Chatbot")
 
10
  app.add_middleware(
11
  CORSMiddleware,
12
+ allow_origins=["*"], allow_credentials=True,
13
+ allow_methods=["*"], allow_headers=["*"],
 
 
14
  )
15
  templates = Jinja2Templates(directory="templates")
16
 
17
+ # Load model once at startup
18
  MODEL_PATH = os.getenv('MODEL_PATH', 'student_performance_model.pkl')
19
  try:
20
  with open(MODEL_PATH, 'rb') as f:
21
  model = pickle.load(f)
22
  except Exception as e:
23
+ raise RuntimeError(f"Could not load model: {e}")
24
 
25
+ # Chat‑fields configuration
26
  FIELDS = [
27
+ {'name': 'Age', 'type': 'number',
28
+ 'question': 'What is your age?',
29
+ 'validation': {'min': 5, 'max': 100}},
30
+ {'name': 'Gender', 'type': 'select',
31
+ 'question': 'What is your gender?',
32
+ 'options': ['Male', 'Female', 'Other']},
33
+ {'name': 'HoursOfStudyPerDay', 'type': 'number',
34
+ 'question': 'How many hours do you study per day?',
35
+ 'validation': {'min': 0, 'max': 24}},
36
+ {'name': 'SchoolAttendanceRate', 'type': 'number',
37
+ 'question': 'What is your school attendance rate (%)?',
38
+ 'validation': {'min': 0, 'max': 100}},
39
+ {'name': 'TuitionAccess', 'type': 'select',
40
+ 'question': 'Do you have access to extra tuition?',
41
+ 'options': ['Yes', 'No']},
42
+ {'name': 'AveragePreviousScores', 'type': 'number',
43
+ 'question': 'What was your average previous score?',
44
+ 'validation': {'min': 0, 'max': 100}},
45
+ {'name': 'HoursOfSleep', 'type': 'number',
46
+ 'question': 'How many hours of sleep do you get per night?',
47
+ 'validation': {'min': 0, 'max': 24}},
48
+ {'name': 'BreakfastDaily', 'type': 'select',
49
+ 'question': 'Do you eat breakfast every day?',
50
+ 'options': ['Yes', 'No']},
51
+ {'name': 'ScreenTimeHours', 'type': 'number',
52
+ 'question': 'How many hours of screen time per day?',
53
+ 'validation': {'min': 0, 'max': 24}},
54
+ {'name': 'PhysicalActivityHours', 'type': 'number',
55
+ 'question': 'How many hours of physical activity per day?',
56
+ 'validation': {'min': 0, 'max': 24}},
57
+ {'name': 'PlaysSport', 'type': 'select',
58
+ 'question': 'Do you play any sports?',
59
+ 'options': ['Yes', 'No']},
60
+ {'name': 'MentalHealthScore', 'type': 'number',
61
+ 'question': 'Rate your mental health on a scale of 1–10.',
62
+ 'validation': {'min': 1, 'max': 10}},
63
+ {'name': 'ParentalEducationLevel', 'type': 'select',
64
+ 'question': 'What is your parental education level?',
65
+ 'options': ['High school', 'Graduate', 'Postgrad']},
66
+ {'name': 'HouseholdIncomeLevel', 'type': 'select',
67
+ 'question': 'What is your household income level?',
68
+ 'options': ['Low', 'Medium', 'High']},
69
+ {'name': 'StudyEnvironmentRating', 'type': 'number',
70
+ 'question': 'Rate your study environment (1–5).',
71
+ 'validation': {'min': 1, 'max': 5}},
72
+ {'name': 'FriendSupportScore', 'type': 'number',
73
+ 'question': 'Rate the emotional support from friends (1–10).',
74
+ 'validation': {'min': 1, 'max': 10}},
75
+ {'name': 'ParticipatesInClubs', 'type': 'select',
76
+ 'question': 'Do you participate in any clubs?',
77
+ 'options': ['Yes', 'No']},
78
+ {'name': 'PartTimeWork', 'type': 'select',
79
+ 'question': 'Do you do any part‑time work?',
80
+ 'options': ['Yes', 'No']},
81
  ]
82
 
83
+ @app.get("/", response_class=HTMLResponse)
84
+ async def chat_ui(request: Request):
85
+ return templates.TemplateResponse("chat.html", {
86
+ "request": request,
87
+ "fields": FIELDS
 
88
  })
89
 
90
+ @app.post("/predict_json")
91
  async def predict_json(payload: dict):
92
+ # collect + cast
93
+ data = {}
 
 
 
 
94
  for f in FIELDS:
95
+ key = f["name"]
96
+ if key not in payload:
97
+ raise HTTPException(400, f"Missing field: {key}")
98
+ val = payload[key]
99
+ if f["type"] == "number":
100
+ try:
101
+ val = float(val)
102
+ except:
103
+ raise HTTPException(400, f"{key} must be numeric")
104
+ data[key] = val
105
+
106
+ # validate ranges
107
  for f in FIELDS:
108
+ if f["type"] == "number" and "validation" in f:
109
+ v = f["validation"]
110
+ if not (v["min"] <= data[f["name"]] <= v["max"]):
111
+ raise HTTPException(400,
112
+ f"{f['name']} must be between {v['min']} and {v['max']}")
113
+
114
  df = pd.DataFrame([data])
115
  score = model.predict(df)[0]
116
+ return JSONResponse({"predicted": round(float(score), 2)})