triflix commited on
Commit
7a1eda4
·
verified ·
1 Parent(s): 40c0576

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +56 -50
main.py CHANGED
@@ -1,11 +1,20 @@
1
  import os
2
  import pickle
3
  import pandas as pd
4
- from fastapi import FastAPI, Request, Form
5
- from fastapi.responses import HTMLResponse
6
  from fastapi.templating import Jinja2Templates
 
7
 
8
  app = FastAPI(title="Student Score Predictor")
 
 
 
 
 
 
 
 
9
  templates = Jinja2Templates(directory="templates")
10
 
11
  # Load the trained model at startup
@@ -16,59 +25,56 @@ try:
16
  except Exception as e:
17
  raise RuntimeError(f"Failed to load model: {e}")
18
 
19
- # List of input fields and types
20
  FIELDS = [
21
- ('Age', int),
22
- ('Gender', str),
23
- ('HoursOfStudyPerDay', float),
24
- ('SchoolAttendanceRate', float),
25
- ('TuitionAccess', str),
26
- ('AveragePreviousScores', float),
27
- ('HoursOfSleep', float),
28
- ('BreakfastDaily', str),
29
- ('ScreenTimeHours', float),
30
- ('PhysicalActivityHours', float),
31
- ('PlaysSport', str),
32
- ('MentalHealthScore', float),
33
- ('ParentalEducationLevel', str),
34
- ('HouseholdIncomeLevel', str),
35
- ('StudyEnvironmentRating', float),
36
- ('FriendSupportScore', float),
37
- ('ParticipatesInClubs', str),
38
- ('PartTimeWork', str),
39
  ]
40
 
41
  @app.get('/', response_class=HTMLResponse)
42
- async def get_form(request: Request):
43
- return templates.TemplateResponse('index.html', {
 
44
  'request': request,
45
- 'predicted': None,
46
- 'values': {}
47
  })
48
 
49
- @app.post('/predict', response_class=HTMLResponse)
50
- async def post_predict(request: Request, **form_data):
51
- # Convert form inputs to correct types
52
- record = {}
53
- for name, dtype in FIELDS:
54
- raw = form_data.get(name)
55
- if raw is None:
56
- # missing field
57
- record[name] = None
58
- continue
59
- record[name] = dtype(raw)
60
-
61
- # Prepare DataFrame
62
- df = pd.DataFrame([record])
63
- # Predict
64
  try:
65
- score = model.predict(df)[0]
66
- predicted = round(score, 2)
67
- except Exception as e:
68
- predicted = f"Error: {e}"
69
-
70
- return templates.TemplateResponse('index.html', {
71
- 'request': request,
72
- 'predicted': predicted,
73
- 'values': record
74
- })
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import pickle
3
  import pandas as pd
4
+ from fastapi import FastAPI, Request, Form, HTTPException
5
+ from fastapi.responses import HTMLResponse, JSONResponse
6
  from fastapi.templating import Jinja2Templates
7
+ from fastapi.middleware.cors import CORSMiddleware
8
 
9
  app = FastAPI(title="Student Score Predictor")
10
+ # Allow CORS for frontend JS
11
+ app.add_middleware(
12
+ CORSMiddleware,
13
+ allow_origins=["*"],
14
+ allow_credentials=True,
15
+ allow_methods=["*"],
16
+ allow_headers=["*"],
17
+ )
18
  templates = Jinja2Templates(directory="templates")
19
 
20
  # Load the trained model at startup
 
25
  except Exception as e:
26
  raise RuntimeError(f"Failed to load model: {e}")
27
 
28
+ # Input fields configuration
29
  FIELDS = [
30
+ {'name': 'Age', 'type': 'number', 'question': 'What is your age?', 'validation': {'min': 5, 'max': 100}},
31
+ {'name': 'Gender', 'type': 'select', 'question': 'What is your gender?', 'options': ['Male','Female']},
32
+ {'name': 'HoursOfStudyPerDay', 'type': 'number', 'question': 'How many hours do you study per day?', 'validation': {'min': 0, 'max': 24}},
33
+ {'name': 'SchoolAttendanceRate', 'type': 'number', 'question': 'What is your school attendance rate in %?', 'validation': {'min': 0, 'max': 100}},
34
+ {'name': 'TuitionAccess', 'type': 'select', 'question': 'Do you have access to tuition (Yes/No)?', 'options': ['Yes','No']},
35
+ {'name': 'AveragePreviousScores', 'type': 'number', 'question': 'What is your average previous score?', 'validation': {'min': 0, 'max': 100}},
36
+ {'name': 'HoursOfSleep', 'type': 'number', 'question': 'How many hours of sleep do you get per day?', 'validation': {'min': 0, 'max': 24}},
37
+ {'name': 'BreakfastDaily', 'type': 'select', 'question': 'Do you have breakfast daily? (Yes/No)', 'options': ['Yes','No']},
38
+ {'name': 'ScreenTimeHours', 'type': 'number', 'question': 'How many hours of screen time per day?', 'validation': {'min': 0, 'max': 24}},
39
+ {'name': 'PhysicalActivityHours', 'type': 'number', 'question': 'How many hours of physical activity per day?', 'validation': {'min': 0, 'max': 24}},
40
+ {'name': 'PlaysSport', 'type': 'select', 'question': 'Do you play sports? (Yes/No)', 'options': ['Yes','No']},
41
+ {'name': 'MentalHealthScore', 'type': 'number', 'question': 'Rate your mental health (1-10)', 'validation': {'min': 1, 'max': 10}},
42
+ {'name': 'ParentalEducationLevel', 'type': 'select', 'question': 'Parental education level?', 'options': ['High school','Graduate','Postgrad']},
43
+ {'name': 'HouseholdIncomeLevel', 'type': 'select', 'question': 'Household income level?', 'options': ['Low','Medium','High']},
44
+ {'name': 'StudyEnvironmentRating', 'type': 'number', 'question': 'Rate your study environment (1-5)', 'validation': {'min': 1, 'max': 5}},
45
+ {'name': 'FriendSupportScore', 'type': 'number', 'question': 'Rate friend support (1-10)', 'validation': {'min': 1, 'max': 10}},
46
+ {'name': 'ParticipatesInClubs', 'type': 'select', 'question': 'Do you participate in clubs? (Yes/No)', 'options': ['Yes','No']},
47
+ {'name': 'PartTimeWork', 'type': 'select', 'question': 'Do you do part-time work? (Yes/No)', 'options': ['Yes','No']},
48
  ]
49
 
50
  @app.get('/', response_class=HTMLResponse)
51
+ async def get_chat(request: Request):
52
+ # Serve the chat UI, embedding the FIELDS config
53
+ return templates.TemplateResponse('chat.html', {
54
  'request': request,
55
+ 'fields': FIELDS
 
56
  })
57
 
58
+ @app.post('/predict_json')
59
+ async def predict_json(payload: dict):
60
+ # Ensure all fields present
 
 
 
 
 
 
 
 
 
 
 
 
61
  try:
62
+ data = {f['name']: payload[f['name']] for f in FIELDS}
63
+ except KeyError as e:
64
+ raise HTTPException(status_code=400, detail=f"Missing field: {e}")
65
+ # Cast types
66
+ for f in FIELDS:
67
+ name = f['name']
68
+ if f['type']=='number':
69
+ data[name] = float(data[name])
70
+ # Validate
71
+ for f in FIELDS:
72
+ if f['type']=='number' and 'validation' in f:
73
+ val = data[f['name']]
74
+ v = f['validation']
75
+ if not (v['min'] <= val <= v['max']):
76
+ raise HTTPException(status_code=400, detail=f"{f['name']} must be between {v['min']} and {v['max']}")
77
+ # Predict
78
+ df = pd.DataFrame([data])
79
+ score = model.predict(df)[0]
80
+ return JSONResponse({'predicted': round(float(score),2)})