triflix commited on
Commit
2c9c234
·
verified ·
1 Parent(s): 42261ce

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +53 -49
main.py CHANGED
@@ -5,66 +5,70 @@ from fastapi import FastAPI, Request, Form
5
  from fastapi.responses import HTMLResponse
6
  from fastapi.templating import Jinja2Templates
7
 
8
- app = FastAPI()
9
- # Templates folder
10
  templates = Jinja2Templates(directory="templates")
11
 
12
- # Load trained model (ensure student_performance_model.pkl is in /app)
13
  MODEL_PATH = os.getenv('MODEL_PATH', 'student_performance_model.pkl')
14
- with open(MODEL_PATH, 'rb') as f:
15
- model = pickle.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  @app.get('/', response_class=HTMLResponse)
18
- async def read_form(request: Request):
19
  return templates.TemplateResponse('index.html', {
20
  'request': request,
21
- 'predicted': None
 
22
  })
23
 
24
  @app.post('/predict', response_class=HTMLResponse)
25
- async def predict(request: Request,
26
- Age: int = Form(...),
27
- Gender: str = Form(...),
28
- HoursOfStudyPerDay: float = Form(...),
29
- SchoolAttendanceRate: float = Form(...),
30
- TuitionAccess: str = Form(...),
31
- AveragePreviousScores: float = Form(...),
32
- HoursOfSleep: float = Form(...),
33
- BreakfastDaily: str = Form(...),
34
- ScreenTimeHours: float = Form(...),
35
- PhysicalActivityHours: float = Form(...),
36
- PlaysSport: str = Form(...),
37
- MentalHealthScore: float = Form(...),
38
- ParentalEducationLevel: str = Form(...),
39
- HouseholdIncomeLevel: str = Form(...),
40
- StudyEnvironmentRating: float = Form(...),
41
- FriendSupportScore: float = Form(...),
42
- ParticipatesInClubs: str = Form(...),
43
- PartTimeWork: str = Form(...)):
44
- # Organize input into DataFrame
45
- data = pd.DataFrame([{
46
- 'Age': Age,
47
- 'Gender': Gender,
48
- 'HoursOfStudyPerDay': HoursOfStudyPerDay,
49
- 'SchoolAttendanceRate': SchoolAttendanceRate,
50
- 'TuitionAccess': TuitionAccess,
51
- 'AveragePreviousScores': AveragePreviousScores,
52
- 'HoursOfSleep': HoursOfSleep,
53
- 'BreakfastDaily': BreakfastDaily,
54
- 'ScreenTimeHours': ScreenTimeHours,
55
- 'PhysicalActivityHours': PhysicalActivityHours,
56
- 'PlaysSport': PlaysSport,
57
- 'MentalHealthScore': MentalHealthScore,
58
- 'ParentalEducationLevel': ParentalEducationLevel,
59
- 'HouseholdIncomeLevel': HouseholdIncomeLevel,
60
- 'StudyEnvironmentRating': StudyEnvironmentRating,
61
- 'FriendSupportScore': FriendSupportScore,
62
- 'ParticipatesInClubs': ParticipatesInClubs,
63
- 'PartTimeWork': PartTimeWork
64
- }])
65
  # Predict
66
- pred = model.predict(data)[0]
 
 
 
 
 
67
  return templates.TemplateResponse('index.html', {
68
  'request': request,
69
- 'predicted': round(pred, 2)
 
70
  })
 
5
  from fastapi.responses import HTMLResponse
6
  from fastapi.templating import Jinja2Templates
7
 
8
+ app = FastAPI(title="Student Score Predictor")
 
9
  templates = Jinja2Templates(directory="templates")
10
 
11
+ # Load the trained model at startup
12
  MODEL_PATH = os.getenv('MODEL_PATH', 'student_performance_model.pkl')
13
+ try:
14
+ with open(MODEL_PATH, 'rb') as f:
15
+ model = pickle.load(f)
16
+ except Exception as e:
17
+ raise RuntimeError(f"Failed to load model: {e}")
18
+
19
+ # List of input fields and types
20
+ FIELDS = [
21
+ ('Age', int),
22
+ ('Gender', str),
23
+ ('HoursOfStudyPerDay', float),
24
+ ('SchoolAttendanceRate', float),
25
+ ('TuitionAccess', str),
26
+ ('AveragePreviousScores', float),
27
+ ('HoursOfSleep', float),
28
+ ('BreakfastDaily', str),
29
+ ('ScreenTimeHours', float),
30
+ ('PhysicalActivityHours', float),
31
+ ('PlaysSport', str),
32
+ ('MentalHealthScore', float),
33
+ ('ParentalEducationLevel', str),
34
+ ('HouseholdIncomeLevel', str),
35
+ ('StudyEnvironmentRating', float),
36
+ ('FriendSupportScore', float),
37
+ ('ParticipatesInClubs', str),
38
+ ('PartTimeWork', str),
39
+ ]
40
 
41
  @app.get('/', response_class=HTMLResponse)
42
+ async def get_form(request: Request):
43
  return templates.TemplateResponse('index.html', {
44
  'request': request,
45
+ 'predicted': None,
46
+ 'values': {}
47
  })
48
 
49
  @app.post('/predict', response_class=HTMLResponse)
50
+ async def post_predict(request: Request, **form_data):
51
+ # Convert form inputs to correct types
52
+ record = {}
53
+ for name, dtype in FIELDS:
54
+ raw = form_data.get(name)
55
+ if raw is None:
56
+ # missing field
57
+ record[name] = None
58
+ continue
59
+ record[name] = dtype(raw)
60
+
61
+ # Prepare DataFrame
62
+ df = pd.DataFrame([record])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  # Predict
64
+ try:
65
+ score = model.predict(df)[0]
66
+ predicted = round(score, 2)
67
+ except Exception as e:
68
+ predicted = f"Error: {e}"
69
+
70
  return templates.TemplateResponse('index.html', {
71
  'request': request,
72
+ 'predicted': predicted,
73
+ 'values': record
74
  })