Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import pickle | |
| import pandas as pd | |
# Load the trained model from the pickle file bundled with the app.
# The filename suggests a KNN classifier; only its ``predict`` method is used
# below. NOTE: unpickling can execute arbitrary code -- only ever load this
# trusted, bundled file, never user-supplied data.
filename = 'knn_model.sav'
# Use a context manager so the file handle is closed once loading finishes
# (the previous ``pickle.load(open(...))`` leaked the open handle).
with open(filename, 'rb') as model_file:
    loaded_model = pickle.load(model_file)
# Helper functions for input transformations
def hptension(hp):
    """Encode the hypertension choice: ``'yes'`` maps to 1, anything else to 0."""
    if hp == 'yes':
        return 1
    return 0
def ht_dis(ht):
    """Encode the heart-disease choice: ``'yes'`` maps to 1, anything else to 0."""
    if ht != 'yes':
        return 0
    return 1
def gender_select(gen):
    """Encode gender: ``'male'`` maps to 1; every other value (incl. ``'female'``) to 0."""
    if gen == 'male':
        encoded = 1
    else:
        encoded = 0
    return encoded
def age_group_selector(age_grp):
    """Map an age-band label to its ordinal code (0-4); unknown labels map to 0."""
    bands = ('0-16', '17-32', '33-48', '49-64', '64+')
    # Code is the band's position in the ordered tuple above.
    code_for = {band: idx for idx, band in enumerate(bands)}
    return code_for.get(age_grp, 0)
def smoker_cat(smoke):
    """Map a smoking-status label to its code (0-3); unknown labels map to 3."""
    categories = ('formerly smoked', 'never smoked', 'smokes', 'Prefer not to say')
    # Code is the category's position; unrecognised values share the code of
    # 'Prefer not to say'.
    code_for = {label: idx for idx, label in enumerate(categories)}
    return code_for.get(smoke, 3)
# Prediction function
def predict_insurance(input_gender, input_age_group, input_hypertension,
                      input_heart_disease, input_avg_glucose_level, input_bmi,
                      input_smoking_status):
    """Predict stroke risk from the form inputs and return a readable verdict.

    Despite its name, this function predicts stroke risk; the name is kept
    because the Gradio click handler references it.

    Categorical inputs are encoded with the module's helper functions
    (``gender_select``, ``age_group_selector``, ``hptension``, ``ht_dis``,
    ``smoker_cat``). The numeric inputs are scaled by fixed divisors
    (272 for glucose, 49 for BMI) -- presumably the maxima used when the
    model was trained; TODO confirm against the training pipeline.

    Returns:
        A message string for the output textbox. If any input is missing
        (``None``, e.g. a radio button or dropdown left unselected), a prompt
        asking the user to complete the form is returned instead of a
        prediction.
    """
    inputs = (
        input_gender,
        input_age_group,
        input_hypertension,
        input_heart_disease,
        input_avg_glucose_level,
        input_bmi,
        input_smoking_status,
    )
    # Without this guard an unselected choice would silently be encoded as a
    # default category (e.g. female, age 0-16, no hypertension), producing a
    # misleading prediction, and a missing numeric value would raise TypeError.
    if any(value is None for value in inputs):
        return "Please provide a value for every field."

    # Build a single-row frame; column names must match the model's features.
    series = {
        'gender': [gender_select(input_gender)],
        'age_band': [age_group_selector(input_age_group)],
        'hypertension': [hptension(input_hypertension)],
        'heart_disease': [ht_dis(input_heart_disease)],
        'avg_glucose_level': [input_avg_glucose_level / 272],
        'bmi': [input_bmi / 49],
        'smoking_status': [smoker_cat(input_smoking_status)],
    }
    vector = pd.DataFrame(series)

    result = loaded_model.predict(vector)
    # A predicted label of 1 is treated as "high risk"; anything else as low.
    return "Risk of having stroke is high" if result[0] == 1 else "Risk of having stroke is low"
# Custom CSS: hides the Gradio footer and markdown output elements, and styles
# large buttons (.gr-button-lg) as dark rounded buttons that turn blue with a
# drop shadow on hover.
# NOTE(review): selectors such as .gr-button-lg / .output-markdown come from
# older Gradio releases; confirm they still exist in the installed version.
# The :hover rule lists only the properties that change; all other
# declarations from the base .gr-button-lg rule still apply while hovering.
css = """
footer {display:none !important}
.output-markdown{display:none !important}
.gr-button-lg {
    z-index: 14;
    width: 113px;
    height: 30px;
    left: 0px;
    top: 0px;
    padding: 0px;
    cursor: pointer !important;
    background: none rgb(17, 20, 45) !important;
    border: none !important;
    text-align: center !important;
    font-size: 14px !important;
    font-weight: 500 !important;
    color: rgb(255, 255, 255) !important;
    line-height: 1 !important;
    border-radius: 6px !important;
    transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
    box-shadow: none !important;
}
.gr-button-lg:hover {
    background: none rgb(66, 133, 244) !important;
    box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
}
"""
# Gradio app layout: input widgets, an output textbox and a Submit button
# that runs predict_insurance on the current input values.
with gr.Blocks(title="Brain Stroke Prediction | Data Science Dojo", css=css) as demo:
    with gr.Row():
        input_gender = gr.Radio(["male", "female"], label="Gender")
        input_hypertension = gr.Radio(["yes", "no"], label="Hypertension")
        input_heart_disease = gr.Radio(["yes", "no"], label="Heart Disease")
    with gr.Row():
        input_age_group = gr.Dropdown(['0-16', '17-32', '33-48', '49-64', '64+'], label='Age Group')
        input_smoking_status = gr.Dropdown(['formerly smoked', 'never smoked', 'smokes', 'Prefer not to say'], label='Smoker')
    # NOTE(review): slider maxima (270 glucose, 45 BMI) differ from the scaling
    # divisors in predict_insurance (272, 49); confirm this is intended.
    with gr.Row():
        input_avg_glucose_level = gr.Slider(0, 270, label='Average Glucose Level')
    with gr.Row():
        input_bmi = gr.Slider(0, 45, label='BMI Range')
    with gr.Row():
        stroke = gr.Textbox(label='Chances of Stroke', interactive=False)
    btn_ins = gr.Button(value="Submit")
    # Gradio requires event listeners to be attached inside the Blocks context.
    # The input order must match predict_insurance's parameter order.
    btn_ins.click(
        fn=predict_insurance,
        inputs=[
            input_gender,
            input_age_group,
            input_hypertension,
            input_heart_disease,
            input_avg_glucose_level,
            input_bmi,
            input_smoking_status,
        ],
        outputs=[stroke],
    )

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch(debug=True)