# techlead_demo / app.py
# bqmolina's picture
# Update app.py
# 11493d6 verified
import torch
import torch.nn as nn
import gradio as gr
import numpy as np
# ===== Load Model =====
class MyModel(nn.Module):
    """Feed-forward binary classifier: 6 inputs -> 4 hidden units (ReLU) -> 1 sigmoid output."""

    def __init__(self):
        super().__init__()
        # The attribute names below become the state_dict key prefixes
        # ("linear1.*", "linear2.*"), so they must match the saved checkpoint.
        self.linear1 = nn.Linear(6, 4)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(4, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid(linear2(relu(linear1(x)))) for the input tensor ``x``."""
        hidden = self.relu(self.linear1(x))
        logit = self.linear2(hidden)
        return self.sigmoid(logit)
# Adjust input size/hidden size/output size according to your model
# Build the network, restore the trained weights onto the CPU, and switch to
# evaluation mode for inference.
# NOTE(review): torch.load unpickles the checkpoint file, so "best_model.pt"
# must come from a trusted source.
model = MyModel()
model.load_state_dict(torch.load("best_model.pt", map_location="cpu"))
model.eval()
# ===== Standardization constants (mean and standard deviation per feature) =====
# Used to z-score the grade level (GRD), class attendance (AATT) and feeding
# attendance (FATT) inputs before they are passed to the model.
GRD_MEAN = 3.078704
GRD_STD = 2.114522
AATT_MEAN = 0.903264
AATT_STD = 0.068064
FATT_MEAN = 0.897477
FATT_STD = 0.089867
# ===== BMI classification function =====
# BMI thresholds keyed by sex ('boys' / 'girls') and then by age in whole
# years (5-18). Each row lists four cut-offs that classify_bmi compares a BMI
# against, in ascending order; a BMI above overweight_max is classified "Obese".
BMI_TABLE = {
    'boys': {
        # age in years: (severe_max, moderate_max, normal_max, overweight_max)
        5: (12.0, 12.9, 18.3, 20.2),
        6: (12.0, 12.9, 18.6, 20.8),
        7: (12.2, 13.0, 19, 21.6),
        8: (12.3, 13.2, 19.7, 22.8),
        9: (12.5, 13.4, 20.5, 24.3),
        10: (12.7, 13.6, 21.4, 26.1),
        11: (13, 14.0, 22.6, 29),
        12: (13.3, 14.4, 23.6, 30),
        13: (13.7, 14.9, 24.8, 31.7),
        14: (14.2, 15.4, 25.9, 33.1),
        15: (14.6, 15.9, 27, 34.1),
        16: (15.1, 16.5, 27.9, 34.8),
        17: (15.3, 16.8, 28.6, 35.2),
        18: (15.6, 17.2, 29.2, 35.4)
    },
    'girls': {
        # age in years: (severe_max, moderate_max, normal_max, overweight_max)
        5: (11.8, 12.6, 18.9, 21.2),
        6: (11.7, 12.6, 19.2, 22.1),
        7: (11.7, 12.6, 19.9, 23.5),
        8: (11.8, 12.8, 20.6, 24.8),
        # NOTE(review): overweight_max of 21.6 is only 0.1 above normal_max and
        # lower than the age-8 value (24.8) while age 10 is 28.4 -- this looks
        # like a transcription typo; verify against the reference table.
        9: (12, 13, 21.5, 21.6),
        10: (12.3, 13.4, 22.6, 28.4),
        11: (12.7, 13.8, 23.7, 30.2),
        12: (13.1, 14.3, 25, 31.9),
        13: (13.5, 14.8, 26.2, 33.4),
        14: (13.9, 15.3, 27.3, 34.7),
        15: (14.3, 15.8, 28.2, 35.5),
        16: (14.5, 16.1, 28.9, 36.1),
        17: (14.6, 16.4, 29.4, 36.3),
        18: (14.6, 16.3, 29.5, 36.3)
    }
}
def classify_bmi(age, sex, bmi, *, bmi_table=None):
    """Classify a BMI value into a nutritional-status category.

    Args:
        age: Age in years. An age without its own row (fractional, or above
            the last tabulated age) uses the nearest lower tabulated age; an
            age below the first row uses the first row.
        sex: "male" in any letter case selects the boys' thresholds; any
            other string selects the girls' thresholds.
        bmi: Body-mass index to classify.
        bmi_table: Optional replacement for the module-level ``BMI_TABLE``,
            using the same ``{'boys': {age: row}, 'girls': {age: row}}``
            layout. Defaults to ``BMI_TABLE``.

    Returns:
        One of "Severely Wasted", "Wasted", "Normal", "Overweight", "Obese".
    """
    if bmi_table is None:
        bmi_table = BMI_TABLE
    thresholds_by_age = bmi_table['boys'] if sex.lower() == 'male' else bmi_table['girls']

    # Use nearest lower age if exact not found; fall back to the youngest row.
    youngest = min(thresholds_by_age)
    row_age = max((a for a in thresholds_by_age if a <= age), default=youngest)
    severe_max, moderate_max, normal_max, overweight_max = thresholds_by_age[row_age]

    # Each threshold is the upper bound ("*_max") of its class, so a BMI equal
    # to the bound still belongs to that class -- all four checks are inclusive.
    if bmi <= severe_max:
        return "Severely Wasted"
    if bmi <= moderate_max:
        return "Wasted"
    if bmi <= normal_max:
        return "Normal"
    if bmi <= overweight_max:
        return "Overweight"
    return "Obese"
# ===== Prediction function =====
def _academic_improvement(grade, absl, aend):
    """Encode the baseline-to-year-end grade change (AIMPR) as an integer 0-3."""
    if grade in (0, 7):  # Kinder or SPED = No Record
        return 0
    # A higher year-end grade only counts as improved when it also reaches 80;
    # a rise that stays below 80 falls through to "Worsened".
    if aend > absl and aend >= 80:
        return 3  # Improved
    if aend == absl:
        return 2  # Maintained
    return 1  # Worsened


def _bmi_improvement(baseline_class, yearend_class):
    """Encode the baseline-to-year-end BMI class change (BIMPR) as an integer 0-2."""
    started_wasted = baseline_class in ("Severely Wasted", "Wasted")
    if started_wasted and baseline_class == yearend_class:
        return 0
    # NOTE(review): a year-end class of "Severely Wasted" that differs from the
    # baseline (i.e. the student got worse) is not matched here and ends up as
    # 2 -- confirm this is the encoding the model was trained with.
    if yearend_class == "Wasted":
        return 1
    return 2


def predict_intervention(sex, grade, aatt, fatt, absl, aend, bbsl, bend, age):
    """Predict whether a student needs intervention.

    Args:
        sex: "Male" or "Female" (letter case is ignored).
        grade: Grade level code; 0 (Kinder) and 7 (SPED) have no grade record.
        aatt: Class attendance as a decimal fraction.
        fatt: Feeding attendance as a decimal fraction.
        absl: Baseline academic grade.
        aend: Year-end academic grade; ``None`` or 0 means "not recorded" and
            the value is estimated from ``absl``.
        bbsl: Baseline BMI.
        bend: Year-end BMI; ``None`` or 0 means "not recorded" and the value
            is estimated from ``bbsl``.
        age: Age in years, used to pick the BMI thresholds.

    Returns:
        An 8-tuple ``(intervention, probability, aimpr, bimpr, aend, bend,
        baseline_class, yearend_class)`` where ``intervention`` is "Yes" when
        the model output is at least 0.5 and "No" otherwise, ``probability``
        is the model output rounded to 3 decimals, and ``aend`` / ``bend`` are
        the (possibly estimated) year-end values rounded to 2 decimals.

    Raises:
        gr.Error: If any input other than ``aend`` / ``bend`` is ``None``
            (e.g. a field left empty in the UI).
    """
    required = {
        "sex": sex, "grade": grade, "aatt": aatt, "fatt": fatt,
        "absl": absl, "bbsl": bbsl, "age": age,
    }
    missing = [name for name, value in required.items() if value is None]
    if missing:
        raise gr.Error("Missing required input(s): " + ", ".join(missing))

    # Predict yearend grade if missing. A baseline of 0 presumably means there
    # is no grade on record, so there is nothing to extrapolate from.
    if aend is None or aend == 0:
        aend = 0 if absl == 0 else 0.9335 * absl + 8.0135

    # Predict yearend BMI if missing
    if bend is None or bend == 0:
        bend = 0.6485 * bbsl + 8.0926

    aimpr = _academic_improvement(grade, absl, aend)

    baseline_class = classify_bmi(age, sex, bbsl)
    yearend_class = classify_bmi(age, sex, bend)
    bimpr = _bmi_improvement(baseline_class, yearend_class)

    # Convert categorical inputs to numeric. The comparison is case-insensitive
    # so it agrees with how classify_bmi picks the boys'/girls' table.
    sex_num = 0 if sex.lower() == "male" else 1
    grd_std = (grade - GRD_MEAN) / GRD_STD
    aatt_std = (aatt - AATT_MEAN) / AATT_STD
    fatt_std = (fatt - FATT_MEAN) / FATT_STD

    # Model input: a single row of six features.
    # NOTE(review): the column order presumably mirrors the training data -- do
    # not reorder without checking the training pipeline.
    features = np.array(
        [[sex_num, grd_std, bimpr, aatt_std, fatt_std, aimpr]],
        dtype=np.float32,
    )
    with torch.no_grad():
        pred = model(torch.tensor(features)).item()

    result = (
        "Yes" if pred >= 0.5 else "No",
        round(pred, 3),
        aimpr,
        bimpr,
        round(aend, 2),
        round(bend, 2),
        baseline_class,
        yearend_class,
    )
    print(result)  # echoed to the server log for debugging
    return result
# ===== Gradio UI =====
# Input widgets, listed in the same order as predict_intervention's parameters
# (sex, grade, aatt, fatt, absl, aend, bbsl, bend, age).
inputs = [
    gr.Radio(choices=["Male", "Female"], label="Sex"),
    gr.Dropdown(choices=list(range(8)), label="Grade Level (0=Kinder, 1-6=Grades, 7=SPED)"),
    gr.Number(value=0.95, label="Class Attendance % (Decimal)"),
    gr.Number(value=0.95, label="Feeding Attendance % (Decimal)"),
    gr.Number(value=85, label="Baseline Grade (ABSL)"),
    gr.Number(value=None, label="Yearend Grade (AEND) (0 if no record yet)"),
    gr.Number(value=14.5, label="Baseline BMI (BBSL)"),
    gr.Number(value=None, label="Yearend BMI (BEND)(0 if no record yet)"),
    gr.Number(value=8, label="Age"),
]
# Output widgets: one label per element of the tuple returned by
# predict_intervention, in the same order.
outputs = [
    gr.Label(label=caption)
    for caption in (
        "Predicted Intervention",
        "Probability",
        "AIMPR",
        "BIMPR",
        "Predicted AEND",
        "Predicted BEND",
        "Baseline BMI Classification",
        "Yearend BMI Classification",
    )
]
# Example rows for the UI; columns follow the input order:
# sex, grade, aatt, fatt, absl, aend, bbsl, bend, age.
examples = [
    # 1. Wasted at year-end
    ["Male", 6, 0.81, 0.98, 80, 81, 12.5, 14, 13],
    # 2. Year-end grade left at 0, so it is estimated from the baseline (< 80)
    ["Male", 4, 0.95, 0.9, 70, 0, 13, 15.86, 11],
    # 3. Kinder (grade level 0) with no grade record; year-end BMI is normal
    ["Female", 0, 0.92, 0.94, 0, 0, 11.6, 16.0, 7]
]
# Wire the prediction function to the widgets and start the web app.
demo = gr.Interface(
    fn=predict_intervention,
    inputs=inputs,
    outputs=outputs,
    examples=examples,
    title="Student Intervention Prediction",
    description="Predict if a student needs intervention based on attendance, grades, and BMI.",
)
demo.launch()