Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import gradio as gr | |
| import numpy as np | |
# ===== Model definition =====
class MyModel(nn.Module):
    """Small feed-forward classifier: 6 inputs -> 4 hidden (ReLU) -> 1 sigmoid output.

    The attribute names below (linear1, relu, linear2, sigmoid) are what the
    weights in best_model.pt are keyed by, so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(6, 4)  # input features -> hidden units
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(4, 1)  # hidden units -> single logit
        self.sigmoid = nn.Sigmoid()  # squashes the logit into (0, 1)

    def forward(self, x):
        """Return a value in (0, 1) for each row of ``x`` (last dim must be 6)."""
        hidden = self.relu(self.linear1(x))
        return self.sigmoid(self.linear2(hidden))
# Build the network (its layer sizes in MyModel must match the checkpoint),
# load the trained weights onto the CPU, and put the module in eval mode.
# NOTE(review): torch.load unpickles the file; consider weights_only=True if
# the installed torch version supports it.
model = MyModel()
model.load_state_dict(
    torch.load("best_model.pt", map_location="cpu")
)
model.eval()
# ===== Feature standardization constants =====
# Mean / standard-deviation pairs used to z-score the grade level, class
# attendance and feeding attendance features before they are fed to the model.
# NOTE(review): presumably computed on the training data -- keep in sync with
# whatever scaler was used when best_model.pt was trained.
GRD_MEAN = 3.078704
GRD_STD = 2.114522
AATT_MEAN = 0.903264
AATT_STD = 0.068064
FATT_MEAN = 0.897477
FATT_STD = 0.089867
# ===== BMI reference table =====
# For each sex ('boys' / 'girls') and whole age in years, the ascending
# BMI cut-offs (severe_max, moderate_max, normal_max, overweight_max) that
# classify_bmi compares a BMI against.
BMI_TABLE = {
    'boys': {
        # age in years: (severe_max, moderate_max, normal_max, overweight_max)
        5: (12.0, 12.9, 18.3, 20.2),
        6: (12.0, 12.9, 18.6, 20.8),
        7: (12.2, 13.0, 19, 21.6),
        8: (12.3, 13.2, 19.7, 22.8),
        9: (12.5, 13.4, 20.5, 24.3),
        10: (12.7, 13.6, 21.4, 26.1),
        11: (13, 14.0, 22.6, 29),
        12: (13.3, 14.4, 23.6, 30),
        13: (13.7, 14.9, 24.8, 31.7),
        14: (14.2, 15.4, 25.9, 33.1),
        15: (14.6, 15.9, 27, 34.1),
        16: (15.1, 16.5, 27.9, 34.8),
        17: (15.3, 16.8, 28.6, 35.2),
        18: (15.6, 17.2, 29.2, 35.4),
    },
    'girls': {
        5: (11.8, 12.6, 18.9, 21.2),
        6: (11.7, 12.6, 19.2, 22.1),
        7: (11.7, 12.6, 19.9, 23.5),
        8: (11.8, 12.8, 20.6, 24.8),
        # NOTE(review): overweight_max was 21.6 -- only 0.1 above normal_max and
        # below the age-8 value (24.8), i.e. a typo. 26.5 is taken to be the
        # WHO 2007 girls' +3 SD BMI-for-age value at 9 years; verify.
        9: (12, 13, 21.5, 26.5),
        10: (12.3, 13.4, 22.6, 28.4),
        11: (12.7, 13.8, 23.7, 30.2),
        12: (13.1, 14.3, 25, 31.9),
        13: (13.5, 14.8, 26.2, 33.4),
        14: (13.9, 15.3, 27.3, 34.7),
        15: (14.3, 15.8, 28.2, 35.5),
        16: (14.5, 16.1, 28.9, 36.1),
        17: (14.6, 16.4, 29.4, 36.3),
        18: (14.6, 16.3, 29.5, 36.3),
    },
}
def classify_bmi(age, sex, bmi, reference=None):
    """Classify a BMI value against the age- and sex-specific cut-off table.

    Args:
        age: Age in years. The nearest tabulated age at or below ``age`` is
            used; ages below the youngest entry use the youngest row, ages
            above the oldest entry use the oldest row.
        sex: "Male" (case-insensitive, surrounding whitespace ignored) selects
            the boys' table; any other value selects the girls' table.
        bmi: Body-mass index to classify.
        reference: Optional table laid out like ``BMI_TABLE``
            ({'boys': {age: cutoffs}, 'girls': {age: cutoffs}}, each cutoffs
            tuple being (severe_max, moderate_max, normal_max, overweight_max)).
            Defaults to ``BMI_TABLE``.

    Returns:
        One of "Severely Wasted", "Wasted", "Normal", "Overweight", "Obese".

    Raises:
        ValueError: if ``age``, ``sex`` or ``bmi`` is None.
    """
    if age is None or sex is None or bmi is None:
        raise ValueError("classify_bmi requires age, sex and bmi (got None)")
    if reference is None:
        reference = BMI_TABLE

    group = 'boys' if str(sex).strip().lower() == 'male' else 'girls'
    table = reference[group]
    ages = sorted(table)
    # Nearest tabulated age at or below `age`; fall back to the youngest row.
    key = max((a for a in ages if a <= age), default=ages[0])
    sev, mod, norm, over = table[key]

    # Lower bands use strict '<', upper bands are inclusive '<='.
    if bmi < sev:
        return "Severely Wasted"
    if bmi < mod:
        return "Wasted"
    if bmi <= norm:
        return "Normal"
    if bmi <= over:
        return "Overweight"
    return "Obese"
# ===== Prediction function =====
_WASTED_CLASSES = ("Severely Wasted", "Wasted")


def _academic_improvement(grade, absl, aend):
    """Return the AIMPR code: 0 no record, 1 worsened, 2 maintained, 3 improved.

    Kinder (0) and SPED (7) have no academic record. "Improved" additionally
    requires the year-end grade to be at least 80.
    """
    if grade in (0, 7):
        return 0
    if aend > absl and aend >= 80:
        return 3
    if aend == absl:
        return 2
    return 1


def _bmi_improvement(baseline_class, yearend_class):
    """Return the BIMPR code from the baseline and year-end BMI classes.

    0: same (severely) wasted class at baseline and year-end;
    1: "Wasted" at year-end (coming from a different baseline class);
    2: anything else.

    NOTE(review): a year-end "Severely Wasted" that differs from the baseline
    class falls into code 2. Confirm this matches the encoding the model was
    trained with before changing it -- the model consumes this feature.
    """
    if baseline_class == yearend_class and baseline_class in _WASTED_CLASSES:
        return 0
    if yearend_class == "Wasted":
        return 1
    return 2


def predict_intervention(sex, grade, aatt, fatt, absl, aend, bbsl, bend, age):
    """Predict whether a student needs an intervention.

    Args:
        sex: "Male" or "Female" (compared case-insensitively).
        grade: Grade level: 0=Kinder, 1-6=Grades, 7=SPED.
        aatt: Class attendance as a decimal fraction.
        fatt: Feeding attendance as a decimal fraction.
        absl: Baseline grade.
        aend: Year-end grade; None or 0 means "no record yet" and it is
            estimated from ``absl``.
        bbsl: Baseline BMI.
        bend: Year-end BMI; None or 0 means "no record yet" and it is
            estimated from ``bbsl``.
        age: Age in years.

    Returns:
        Tuple of (intervention "Yes"/"No", model output rounded to 3 decimals,
        AIMPR, BIMPR, year-end grade rounded to 2 decimals, year-end BMI
        rounded to 2 decimals, baseline BMI class, year-end BMI class).

    Raises:
        ValueError: if any input other than ``aend``/``bend`` is None.
    """
    required = {
        "sex": sex, "grade": grade, "aatt": aatt, "fatt": fatt,
        "absl": absl, "bbsl": bbsl, "age": age,
    }
    missing = [name for name, value in required.items() if value is None]
    if missing:
        raise ValueError(f"Missing required input(s): {', '.join(missing)}")

    # Estimate the year-end grade when it has not been recorded yet; with a
    # zero baseline there is nothing to extrapolate from, so it stays 0.
    # NOTE(review): the regression coefficients are presumably fitted on
    # historical records -- confirm their source.
    if aend is None or aend == 0:
        aend = 0 if absl == 0 else 0.9335 * absl + 8.0135
    # Estimate the year-end BMI when it has not been recorded yet.
    if bend is None or bend == 0:
        bend = 0.6485 * bbsl + 8.0926

    aimpr = _academic_improvement(grade, absl, aend)
    baseline_class = classify_bmi(age, sex, bbsl)
    yearend_class = classify_bmi(age, sex, bend)
    bimpr = _bmi_improvement(baseline_class, yearend_class)

    # Encode sex case-insensitively (so "male" also maps to 0), consistent
    # with how classify_bmi picks the boys' table.
    sex_num = 0 if str(sex).strip().lower() == "male" else 1
    grd_std = (grade - GRD_MEAN) / GRD_STD
    aatt_std = (aatt - AATT_MEAN) / AATT_STD
    fatt_std = (fatt - FATT_MEAN) / FATT_STD

    # Feature order must match the order the model was trained with.
    x = np.array([[sex_num, grd_std, bimpr, aatt_std, fatt_std, aimpr]], dtype=np.float32)
    with torch.no_grad():
        pred = model(torch.tensor(x)).item()

    result = (
        "Yes" if pred >= 0.5 else "No",
        round(pred, 3),
        aimpr,
        bimpr,
        round(aend, 2),
        round(bend, 2),
        baseline_class,
        yearend_class,
    )
    print(result)  # echo the outputs to stdout for debugging
    return result
# ===== Gradio UI =====
# One component per predict_intervention parameter, in the same positional
# order: sex, grade, aatt, fatt, absl, aend, bbsl, bend, age.
inputs = [
    gr.Radio(["Male", "Female"], label="Sex"),  # sex
    gr.Dropdown(
        list(range(8)),
        label="Grade Level (0=Kinder, 1-6=Grades, 7=SPED)",
    ),  # grade
    gr.Number(value=0.95, label="Class Attendance % (Decimal)"),  # aatt
    gr.Number(value=0.95, label="Feeding Attendance % (Decimal)"),  # fatt
    gr.Number(value=85, label="Baseline Grade (ABSL)"),  # absl
    gr.Number(value=None, label="Yearend Grade (AEND) (0 if no record yet)"),  # aend
    gr.Number(value=14.5, label="Baseline BMI (BBSL)"),  # bbsl
    gr.Number(value=None, label="Yearend BMI (BEND)(0 if no record yet)"),  # bend
    gr.Number(value=8, label="Age"),  # age
]
# One label per element of predict_intervention's return tuple, same order.
outputs = [
    gr.Label(label=caption)
    for caption in (
        "Predicted Intervention",
        "Probability",
        "AIMPR",
        "BIMPR",
        "Predicted AEND",
        "Predicted BEND",
        "Baseline BMI Classification",
        "Yearend BMI Classification",
    )
]
# Sample rows in predict_intervention's parameter order:
# [sex, grade, aatt, fatt, absl, aend, bbsl, bend, age]
examples = [
    # 1. Grade 6 boy: grade rises 80 -> 81; BMI goes from severely wasted to
    #    wasted at year-end.
    ["Male", 6, 0.81, 0.98, 80, 81, 12.5, 14, 13],
    # 2. Grade 4 boy with no year-end grade yet (estimated from a baseline of
    #    70, so it stays below 80); wasted at baseline.
    ["Male", 4, 0.95, 0.9, 70, 0, 13, 15.86, 11],
    # 3. Kinder girl (no academic record); severely wasted at baseline,
    #    normal BMI at year-end.
    ["Female", 0, 0.92, 0.94, 0, 0, 11.6, 16.0, 7],
]
# Build the interface once and launch it only when this file is run as a
# script, so importing the module (e.g. from tests) does not start a server.
demo = gr.Interface(
    fn=predict_intervention,
    inputs=inputs,
    outputs=outputs,
    examples=examples,
    title="Student Intervention Prediction",
    description="Predict if a student needs intervention based on attendance, grades, and BMI.",
)

if __name__ == "__main__":
    demo.launch()