ranimeree commited on
Commit
c9e72f1
·
verified ·
1 Parent(s): b32c1c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -40
app.py CHANGED
@@ -1,51 +1,95 @@
1
  import gradio as gr
2
  import pandas as pd
3
- import joblib
 
 
 
4
 
5
- # Load the model
6
- model_path = "model.pkl"
7
- model = joblib.load(model_path)
8
 
9
- # Define prediction function
10
- def predict_stroke(age, avg_glucose_level, bmi, gender, hypertension, heart_disease, ever_married, work_type, residence_type, smoking_status):
11
- # Prepare input data
12
- input_data = pd.DataFrame([{
13
- "age": age,
14
- "avg_glucose_level": avg_glucose_level,
15
- "bmi": bmi,
16
- "gender": gender,
17
- "hypertension": hypertension,
18
- "heart_disease": heart_disease,
19
- "ever_married": ever_married,
20
- "work_type": work_type,
21
- "Residence_type": residence_type,
22
- "smoking_status": smoking_status
23
- }])
24
- # Make predictions
25
- prediction = model.predict(input_data)[0]
26
- probability = model.predict_proba(input_data)[0][1]
27
- return f"Prediction: {'Stroke' if prediction == 1 else 'No Stroke'} (Probability of Stroke: {probability:.2f})"
28
 
29
- # Define Gradio interface
30
- demo = gr.Interface(
31
- fn=predict_stroke,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  inputs=[
33
- gr.Number(label="Age"),
 
 
 
 
 
34
  gr.Number(label="Average Glucose Level"),
35
- gr.Number(label="BMI"),
36
- gr.Radio(["Male", "Female", "Other"], label="Gender"),
37
- gr.Radio([0, 1], label="Hypertension (0 = No, 1 = Yes)"),
38
- gr.Radio([0, 1], label="Heart Disease (0 = No, 1 = Yes)"),
39
- gr.Radio(["Yes", "No"], label="Ever Married"),
40
- gr.Dropdown(["Private", "Self-employed", "Govt_job", "children", "Never_worked"], label="Work Type"),
41
- gr.Radio(["Urban", "Rural"], label="Residence Type"),
42
- gr.Dropdown(["never smoked", "formerly smoked", "smokes"], label="Smoking Status"),
43
  ],
44
- outputs="text",
45
- title="Stroke Prediction",
46
- description="Predict the likelihood of a stroke based on health and lifestyle inputs.",
47
  )
48
 
49
- # Launch the app
50
  if __name__ == "__main__":
51
- demo.launch()
 
1
  import gradio as gr
2
  import pandas as pd
3
+ import numpy as np
4
+ import mlflow
5
+ from sklearn.preprocessing import StandardScaler
6
+ import sklearn
7
 
8
+ print(f"Prediction environment scikit-learn version: {sklearn.__version__}")
 
 
9
 
10
+ # Load model from MLflow artifacts
11
+ model_path = "metadata/mlflow/mlartifacts/0951b451e9554321adaebc8f9f15ac8c/artifacts/train/model/artifacts/sk_model/model.pkl"
12
+ loaded_model = mlflow.sklearn.load_model(model_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ def preprocess_input(data_dict):
15
+ """Preprocess input data to match the training format"""
16
+ df = pd.DataFrame([data_dict])
17
+
18
+ # Numeric features
19
+ numeric_features = ['age', 'avg_glucose_level', 'bmi']
20
+
21
+ # Scale numeric features
22
+ scaler = StandardScaler()
23
+ df[numeric_features] = scaler.fit_transform(df[numeric_features])
24
+
25
+ # Create dummy variables for categorical features
26
+ df = pd.get_dummies(df, columns=['gender', 'hypertension', 'heart_disease',
27
+ 'ever_married', 'work_type', 'Residence_type',
28
+ 'smoking_status'])
29
+
30
+ # Ensure all expected columns are present
31
+ expected_columns = [
32
+ 'num__age', 'num__avg_glucose_level', 'num__bmi',
33
+ 'cat__gender_Male', 'cat__gender_Other', 'cat__hypertension_1',
34
+ 'cat__heart_disease_1', 'cat__ever_married_Yes',
35
+ 'cat__work_type_Never_worked', 'cat__work_type_Private',
36
+ 'cat__work_type_Self-employed', 'cat__work_type_children',
37
+ 'cat__Residence_type_Urban', 'cat__smoking_status_formerly smoked',
38
+ 'cat__smoking_status_never smoked', 'cat__smoking_status_smokes'
39
+ ]
40
+
41
+ for col in expected_columns:
42
+ if col not in df.columns:
43
+ df[col] = 0
44
+
45
+ return df[expected_columns]
46
+
47
+ def predict(gender, age, hypertension, ever_married, work_type, heart_disease,
48
+ avg_glucose_level, bmi, smoking_status, Residence_type):
49
+ """Make prediction using the loaded model"""
50
+ # Create input dictionary
51
+ input_data = {
52
+ 'gender': gender,
53
+ 'age': age,
54
+ 'hypertension': 1 if hypertension == 'Yes' else 0,
55
+ 'heart_disease': 1 if heart_disease == 'Yes' else 0,
56
+ 'ever_married': ever_married,
57
+ 'work_type': work_type,
58
+ 'Residence_type': Residence_type,
59
+ 'avg_glucose_level': avg_glucose_level,
60
+ 'bmi': bmi,
61
+ 'smoking_status': smoking_status
62
+ }
63
+
64
+ # Preprocess the input
65
+ processed_input = preprocess_input(input_data)
66
+
67
+ # Use the loaded model
68
+ try:
69
+ prediction = loaded_model.predict_proba(processed_input)[0][1]
70
+ return f"The probability of stroke is {prediction:.2%}"
71
+ except Exception as e:
72
+ return f"Error making prediction: {str(e)}"
73
+
74
+ # Create the Gradio interface
75
+ iface = gr.Interface(
76
+ fn=predict,
77
  inputs=[
78
+ gr.Radio(choices=['Female', 'Male'], label="Gender"),
79
+ gr.Slider(minimum=0, maximum=100, label="Age"),
80
+ gr.Radio(choices=['Yes', 'No'], label="Hypertension"),
81
+ gr.Radio(choices=['Yes', 'No'], label="Ever Married"),
82
+ gr.Radio(choices=['Private', 'Self-employed', 'Govt_job', 'children', 'Never_worked'], label="Work Type"),
83
+ gr.Radio(choices=['Yes', 'No'], label="Heart Disease"),
84
  gr.Number(label="Average Glucose Level"),
85
+ gr.Slider(minimum=10, maximum=50, label="BMI"),
86
+ gr.Radio(choices=['formerly smoked', 'never smoked', 'smokes', 'Unknown'], label="Smoking Status"),
87
+ gr.Radio(choices=['Urban', 'Rural'], label="Residence Type")
 
 
 
 
 
88
  ],
89
+ outputs='text',
90
+ title='Stroke Probability Predictor',
91
+ description='Predicts the probability of having a stroke based on input features.'
92
  )
93
 
 
94
  if __name__ == "__main__":
95
+ iface.launch()