triflix commited on
Commit
60d856a
Β·
verified Β·
1 Parent(s): c439510

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +160 -0
main.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import pandas as pd
4
+ from fastapi import FastAPI, Request, HTTPException
5
+ from fastapi.responses import HTMLResponse, JSONResponse
6
+ from fastapi.templating import Jinja2Templates
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from groq import Groq
9
+
10
+ app = FastAPI(title="Student Score Predictor Chatbot + Groq")
11
+ app.add_middleware(
12
+ CORSMiddleware,
13
+ allow_origins=["*"], allow_credentials=True,
14
+ allow_methods=["*"], allow_headers=["*"],
15
+ )
16
+ templates = Jinja2Templates(directory="templates")
17
+
18
+ # β€”β€”β€” Load model at startup β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
19
+ MODEL_PATH = os.getenv('MODEL_PATH', 'student_performance_model.pkl')
20
+ try:
21
+ with open(MODEL_PATH, 'rb') as f:
22
+ model = pickle.load(f)
23
+ except Exception as e:
24
+ raise RuntimeError(f"Could not load model: {e}")
25
+
26
+ # β€”β€”β€” Load Groq API key β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
27
+ GROQ_API_KEY = os.getenv('GROQ_API_KEY')
28
+ if not GROQ_API_KEY:
29
+ raise RuntimeError("Missing Groq API key. Set env var GROQ_API_KEY.")
30
+
31
+ groq_client = Groq(api_key=GROQ_API_KEY)
32
+
33
+ # β€”β€”β€” Chat‑fields configuration β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
34
+ FIELDS = [
35
+ {'name': 'Age', 'type': 'number',
36
+ 'question': 'What is your age?',
37
+ 'validation': {'min': 5, 'max': 100}},
38
+ {'name': 'Gender', 'type': 'select',
39
+ 'question': 'What is your gender?',
40
+ 'options': ['Male', 'Female', 'Other']},
41
+ {'name': 'HoursOfStudyPerDay', 'type': 'number',
42
+ 'question': 'Hours of study per day?',
43
+ 'validation': {'min': 0, 'max': 24}},
44
+ {'name': 'SchoolAttendanceRate', 'type': 'number',
45
+ 'question': 'School attendance rate (%)?',
46
+ 'validation': {'min': 0, 'max': 100}},
47
+ {'name': 'TuitionAccess', 'type': 'select',
48
+ 'question': 'Access to extra tuition?',
49
+ 'options': ['Yes', 'No']},
50
+ {'name': 'AveragePreviousScores', 'type': 'number',
51
+ 'question': 'Average previous score?',
52
+ 'validation': {'min': 0, 'max': 100}},
53
+ {'name': 'HoursOfSleep', 'type': 'number',
54
+ 'question': 'Hours of sleep per night?',
55
+ 'validation': {'min': 0, 'max': 24}},
56
+ {'name': 'BreakfastDaily', 'type': 'select',
57
+ 'question': 'Do you eat breakfast daily?',
58
+ 'options': ['Yes', 'No']},
59
+ {'name': 'ScreenTimeHours', 'type': 'number',
60
+ 'question': 'Screen time hours per day?',
61
+ 'validation': {'min': 0, 'max': 24}},
62
+ {'name': 'PhysicalActivityHours', 'type': 'number',
63
+ 'question': 'Physical activity hours per day?',
64
+ 'validation': {'min': 0, 'max': 24}},
65
+ {'name': 'PlaysSport', 'type': 'select',
66
+ 'question': 'Do you play sports?',
67
+ 'options': ['Yes', 'No']},
68
+ {'name': 'MentalHealthScore', 'type': 'number',
69
+ 'question': 'Rate your mental health (1–10).',
70
+ 'validation': {'min': 1, 'max': 10}},
71
+ {'name': 'ParentalEducationLevel', 'type': 'select',
72
+ 'question': 'Parental education level?',
73
+ 'options': ['High school', 'Graduate', 'Postgrad']},
74
+ {'name': 'HouseholdIncomeLevel', 'type': 'select',
75
+ 'question': 'Household income level?',
76
+ 'options': ['Low', 'Medium', 'High']},
77
+ {'name': 'StudyEnvironmentRating', 'type': 'number',
78
+ 'question': 'Rate your study environment (1–5).',
79
+ 'validation': {'min': 1, 'max': 5}},
80
+ {'name': 'FriendSupportScore', 'type': 'number',
81
+ 'question': 'Friend support score (1–10).',
82
+ 'validation': {'min': 1, 'max': 10}},
83
+ {'name': 'ParticipatesInClubs', 'type': 'select',
84
+ 'question': 'Do you participate in clubs?',
85
+ 'options': ['Yes', 'No']},
86
+ {'name': 'PartTimeWork', 'type': 'select',
87
+ 'question': 'Do you do part‑time work?',
88
+ 'options': ['Yes', 'No']},
89
+ ]
90
+
91
+ @app.get("/", response_class=HTMLResponse)
92
+ async def chat_ui(request: Request):
93
+ return templates.TemplateResponse("chat.html", {
94
+ "request": request,
95
+ "fields": FIELDS
96
+ })
97
+
98
+ @app.post("/predict_json")
99
+ async def predict_and_advise(payload: dict):
100
+ # β€” validate & cast β€”
101
+ data = {}
102
+ for f in FIELDS:
103
+ key = f["name"]
104
+ if key not in payload:
105
+ raise HTTPException(400, f"Missing field: {key}")
106
+ val = payload[key]
107
+ if f["type"] == "number":
108
+ try:
109
+ val = float(val)
110
+ except:
111
+ raise HTTPException(400, f"{key} must be numeric")
112
+ data[key] = val
113
+
114
+ # β€” range checks β€”
115
+ for f in FIELDS:
116
+ if f["type"] == "number" and "validation" in f:
117
+ mn, mx = f["validation"]["min"], f["validation"]["max"]
118
+ if not (mn <= data[f["name"]] <= mx):
119
+ raise HTTPException(400,
120
+ f"{f['name']} must be between {mn} and {mx}")
121
+
122
+ # β€” predict score β€”
123
+ df = pd.DataFrame([data])
124
+ score = float(model.predict(df)[0])
125
+ data["PredictedScore"] = round(score, 2)
126
+
127
+ # β€” build Groq chat messages β€”
128
+ system_msg = {
129
+ "role": "system",
130
+ "content": (
131
+ "You are an expert academic coach. "
132
+ "Given a student’s profile data and their predicted final exam score, "
133
+ "provide a concise performance analysis and actionable improvement suggestions."
134
+ )
135
+ }
136
+ lines = [f"{k}: {v}" for k, v in data.items() if k != "PredictedScore"]
137
+ user_msg = {
138
+ "role": "user",
139
+ "content": (
140
+ "Here is the student data:\n" +
141
+ "\n".join(lines) +
142
+ f"\nPredicted final exam score: {data['PredictedScore']}\n"
143
+ "What targeted advice can you give them to improve their performance?"
144
+ )
145
+ }
146
+
147
+ # β€” call Groq β€”
148
+ resp = groq_client.chat.completions.create(
149
+ model="llama-3.3-70b-versatile",
150
+ messages=[system_msg, user_msg],
151
+ temperature=0.5,
152
+ max_completion_tokens=512,
153
+ top_p=1.0
154
+ )
155
+ advice = resp.choices[0].message.content
156
+
157
+ return JSONResponse({
158
+ "predicted": data["PredictedScore"],
159
+ "advice": advice
160
+ })