Spaces:
Sleeping
Sleeping
Upload predict.py with huggingface_hub
Browse files- predict.py +79 -0
predict.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from model import CardiovascularRNN
|
| 5 |
+
from sklearn.preprocessing import StandardScaler, LabelEncoder
|
| 6 |
+
|
| 7 |
+
class RiskPredictor:
    """Load a trained ``CardiovascularRNN`` and classify single patient records.

    The label encoders and the feature scaler are not loaded from disk; they
    are re-fitted from the CSV every time an instance is created.

    NOTE(review): predictions are only meaningful if the encoders and scaler
    fitted here match the ones used at training time (same CSV, same rows) --
    confirm against the training script.
    """

    def __init__(self, model_path='cardiovascular_rnn_model.pth', csv_path='cardiovascular_risk_dataset.csv'):
        """Fit the preprocessing objects from ``csv_path`` and load the model weights.

        Args:
            model_path: Path to the saved ``state_dict`` of the network.
            csv_path: Path to the dataset CSV. It must contain the columns
                ``Patient_ID``, ``risk_category``, ``smoking_status`` and
                ``family_history_heart_disease``; every other column is
                treated as a model input feature.
        """
        self.df = pd.read_csv(csv_path)

        # Everything except the row identifier and the target is a feature.
        # The column order of this frame defines the input order for the model.
        features = self.df.drop(['Patient_ID', 'risk_category'], axis=1)
        self.feature_names = features.columns.tolist()

        self.le_risk = LabelEncoder()
        self.le_risk.fit(self.df['risk_category'])

        self.le_smoking = LabelEncoder()
        self.le_smoking.fit(self.df['smoking_status'])

        self.le_family = LabelEncoder()
        self.le_family.fit(self.df['family_history_heart_disease'])

        # The scaler is fitted on the fully numeric feature matrix, i.e. after
        # the two categorical columns have been replaced by their integer codes.
        features['smoking_status'] = self.le_smoking.transform(features['smoking_status'])
        features['family_history_heart_disease'] = self.le_family.transform(features['family_history_heart_disease'])
        self.scaler = StandardScaler()
        self.scaler.fit(features.values)

        # Network hyper-parameters; presumably these must match the values
        # used when the checkpoint was trained -- TODO confirm.
        self.input_size = 1
        self.hidden_size = 64
        self.num_layers = 2
        self.num_classes = len(self.le_risk.classes_)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = CardiovascularRNN(self.input_size, self.hidden_size, self.num_layers, self.num_classes).to(self.device)
        # NOTE: torch.load unpickles the file. Only load checkpoints from a
        # trusted source (or pass weights_only=True on torch versions that
        # support it).
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.eval()
        print(f"Model initialized with features: {self.feature_names}")

    def predict_single(self, data_dict):
        """Predict the risk category for one patient.

        Args:
            data_dict: Mapping from feature name to raw (unencoded, unscaled)
                value. It must contain every name in ``self.feature_names``;
                extra keys are ignored and key order does not matter.

        Returns:
            The predicted risk category, decoded back to its original label.

        Raises:
            KeyError: If one or more required features are missing.
        """
        # Fail early with a readable message instead of pandas' generic
        # "not in index" error.
        missing = [name for name in self.feature_names if name not in data_dict]
        if missing:
            raise KeyError(f"Missing feature(s) in input: {missing}")

        # Select the features in training order. The explicit copy makes the
        # in-place column assignments below independent of the temporary frame.
        input_df = pd.DataFrame([data_dict])[self.feature_names].copy()

        # Encode the categorical columns with the encoders fitted in __init__.
        input_df['smoking_status'] = self.le_smoking.transform(input_df['smoking_status'])
        input_df['family_history_heart_disease'] = self.le_family.transform(input_df['family_history_heart_disease'])

        # Scale, then reshape to (1, n_features, 1): one sample whose features
        # are laid out along the middle axis with a single value per position.
        input_scaled = self.scaler.transform(input_df.values)
        input_tensor = torch.tensor(input_scaled, dtype=torch.float32).reshape(1, -1, 1).to(self.device)

        with torch.no_grad():
            output = self.model(input_tensor)
            probs = torch.softmax(output, dim=1)
            predicted = torch.argmax(output, dim=1)
            predicted_label = self.le_risk.inverse_transform([predicted.item()])[0]

        print(f"Prediction: {predicted_label} | Probabilities: {probs.cpu().numpy()}")
        return predicted_label
|
| 61 |
+
|
| 62 |
+
def predict():
    """Run a smoke test: classify the first dataset row and print the result.

    Builds a ``RiskPredictor`` with its default paths, predicts the risk
    category for the first row of the dataset and prints the input, the true
    label and the predicted label. Returns nothing.
    """
    predictor = RiskPredictor()
    print("Model loaded successfully.")

    # Take a sample from the dataset for prediction. The frame the predictor
    # already loaded is reused, so the CSV is read only once and the sample
    # always comes from the same file the encoders and scaler were fitted on.
    df = predictor.df
    sample_row = df.drop(['Patient_ID', 'risk_category'], axis=1).iloc[0]
    sample_dict = sample_row.to_dict()
    true_label = df['risk_category'].iloc[0]

    predicted_label = predictor.predict_single(sample_dict)

    print(f"\nSample Data: {sample_dict}")
    print(f"True Risk Category: {true_label}")
    print(f"Predicted Risk Category: {predicted_label}")
|
| 77 |
+
|
| 78 |
+
if __name__ == "__main__":
    # Run the demo prediction only when executed as a script, not on import.
    predict()
|