Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| import pickle | |
| import joblib | |
| import os | |
def load_model():
    """Load the trained PCOS classifier, trying each known artifact in turn.

    Candidates are tried in this order and the first one that loads wins:

    1. ``pcos_model.joblib`` via ``joblib.load``
    2. ``random_forest_model`` via ``pickle.load``
    3. ``random_forest_model.pkl`` via ``pickle.load``

    If none of them can be loaded, a ``RandomForestClassifier`` fitted on
    random data is returned so the app can still start. Predictions from
    that fallback are meaningless and exist for demonstration only.

    Returns:
        The first successfully loaded model object, or the demo fallback.
    """
    def _load_pickle(path):
        # NOTE(security): unpickling runs arbitrary code from the file, so
        # only trusted model artifacts must be shipped next to this script.
        with open(path, 'rb') as file:
            return pickle.load(file)

    candidates = [
        ('pcos_model.joblib', joblib.load,
         "Model loaded using joblib"),
        ('random_forest_model', _load_pickle,
         "Model loaded using pickle from random_forest_model"),
        ('random_forest_model.pkl', _load_pickle,
         "Model loaded using pickle from random_forest_model.pkl"),
    ]

    last_error = None
    for path, loader, success_message in candidates:
        try:
            loaded = loader(path)
        except Exception as e:
            # Deserialisation can fail in many ways (missing file, version
            # mismatch, corrupt data), so the catch is broad on purpose --
            # but unlike a bare ``except`` it lets KeyboardInterrupt and
            # SystemExit through, and every failure is reported.
            print(f"Could not load {path}: {e}")
            last_error = e
            continue
        print(success_message)
        return loaded

    print(f"Error loading model: {last_error}")
    # Fallback to a simple model for demo purposes
    from sklearn.ensemble import RandomForestClassifier
    print("Creating a fallback model for demonstration")
    fallback_model = RandomForestClassifier(n_estimators=100, random_state=42)
    # Train with dummy data to initialize; 43 columns matches the number of
    # entries in the module-level ``features`` list.
    X_dummy = np.random.rand(100, 43)
    y_dummy = np.random.choice([0, 1], 100)
    fallback_model.fit(X_dummy, y_dummy)
    return fallback_model


# Load the model once at import time; the prediction and debug callbacks
# read this module-level name.
model = load_model()
# Feature names used as DataFrame columns for prediction. predict_pcos()
# pairs these positionally with the values coming from the form, so the
# order here has to stay in step with the UI's input component list.
# NOTE(review): spellings such as "Marraige" and "BP _Systolic" are kept
# exactly as they were; presumably they mirror the training data's column
# names -- confirm before "fixing" any of them.
features = [
    "Age (yrs)",
    "Weight (Kg)",
    "Height(Cm)",
    "BMI",
    "Blood Group",
    "Pulse rate(bpm)",
    "RR (breaths/min)",
    "Hb(g/dl)",
    "Cycle length(days)",
    "Cycle(R/I)",
    "Marraige Status (Yrs)",
    "Pregnant(Y/N)",
    "No. of abortions",
    "Hip(inch)",
    "Waist(inch)",
    "Waist:Hip Ratio",
    "Weight gain(Y/N)",
    "hair growth(Y/N)",
    "Skin darkening (Y/N)",
    "Hair loss(Y/N)",
    "Pimples(Y/N)",
    "Fast food (Y/N)",
    "Reg.Exercise(Y/N)",
    "BP _Systolic (mmHg)",
    "BP _Diastolic (mmHg)",
    "Follicle No. (L)",
    "Follicle No. (R)",
    "Avg. F size (L) (mm)",
    "Avg. F size (R) (mm)",
    "Endometrium (mm)",
    "FSH(mIU/mL)",
    "LH(mIU/mL)",
    "FSH/LH",
    "Hip:Waist Ratio",
    "TSH (mIU/L)",
    "AMH(ng/mL)",
    "PRL(ng/mL)",
    "Vit D3 (ng/mL)",
    "PRG(ng/mL)",
    "RBS(mg/dl)",
    "Weight gain",
    "I beta-HCG(mIU/mL)",
    "II beta-HCG(mIU/mL)",
]
def create_visualizations():
    """Build the five demo figures for the dashboard from synthetic data.

    The data is a generated 100-row sample, not a real dataset. Rows
    flagged as PCOS have their BMI, cycle length and follicle counts
    shifted upwards and their cycle code set to 4, so the plots show a
    visible difference between the two groups.

    Every figure is created through ``plt.subplots``; closing figures that
    are no longer needed is left to the caller.

    Returns:
        list: Five matplotlib figures, in this order: BMI vs age scatter,
        cycle length vs age scatter, left vs right follicle count scatter,
        left follicle count boxplot, endometrium thickness boxplot.
    """
    # A private RandomState seeded with 42 produces the same draws that
    # reseeding the global generator did, without resetting numpy's global
    # random state for the rest of the process on every call.
    rng = np.random.RandomState(42)
    n_samples = 100

    # Entries are evaluated top to bottom, so the order below determines
    # which random draws each column receives.
    sample_data = {
        "Age (yrs)": rng.normal(25, 5, n_samples),
        "PCOS (Y/N)": rng.choice([0, 1], n_samples, p=[0.6, 0.4]),
        "BMI": rng.normal(25, 5, n_samples),
        "Cycle length(days)": rng.normal(28, 5, n_samples),
        "Follicle No. (L)": rng.normal(12, 5, n_samples),
        "Follicle No. (R)": rng.normal(12, 5, n_samples),
        "Endometrium (mm)": rng.normal(8, 2, n_samples),
        "Cycle(R/I)": rng.choice([2, 4], n_samples),
        "Weight (Kg)": rng.normal(65, 10, n_samples),
        "Hb(g/dl)": rng.normal(12, 1.5, n_samples),
    }
    df = pd.DataFrame(sample_data)

    # For PCOS cases, adjust the values to show differences
    pcos_indices = df["PCOS (Y/N)"] == 1
    df.loc[pcos_indices, "BMI"] += 2
    df.loc[pcos_indices, "Cycle length(days)"] += 5
    df.loc[pcos_indices, "Follicle No. (L)"] += 8
    df.loc[pcos_indices, "Follicle No. (R)"] += 7
    df.loc[pcos_indices, "Cycle(R/I)"] = 4

    status_column = "PCOS (Y/N)"
    palette = ["teal", "plum"]

    # (x column, y column, title) for each scatter plot
    scatter_specs = [
        ("Age (yrs)", "BMI", "BMI vs Age by PCOS Status"),
        ("Age (yrs)", "Cycle length(days)",
         "Menstrual Cycle Length vs Age by PCOS Status"),
        ("Follicle No. (L)", "Follicle No. (R)",
         "Follicle Distribution (Left vs Right Ovary)"),
    ]
    # (y column, title) for each boxplot, grouped by PCOS status
    box_specs = [
        ("Follicle No. (L)", "Follicle Count (Left Ovary) by PCOS Status"),
        ("Endometrium (mm)", "Endometrium Thickness by PCOS Status"),
    ]

    visualizations = []
    for x_column, y_column, title in scatter_specs:
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.scatterplot(x=x_column, y=y_column, hue=status_column,
                        data=df, palette=palette, ax=ax)
        ax.set_title(title)
        visualizations.append(fig)

    for y_column, title in box_specs:
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.boxplot(x=status_column, y=y_column, data=df,
                    palette=palette, ax=ax)
        ax.set_title(title)
        visualizations.append(fig)

    return visualizations
def get_numerical_value(value, options):
    """Return the position of ``value`` in ``options``, or 0 if unavailable.

    Args:
        value: The selected categorical value.
        options: A sequence exposing ``.index`` (for example a list or
            tuple of the allowed values).

    Returns:
        int: The index of ``value`` in ``options``. 0 is returned when the
        value is not present or ``options`` cannot be searched, which means
        the result is indistinguishable from "first option selected".
    """
    try:
        return options.index(value)
    except (ValueError, TypeError, AttributeError):
        # ValueError: value not present; TypeError: value of the wrong type
        # for this container; AttributeError: options has no ``.index``
        # (e.g. None). Anything else, including KeyboardInterrupt, is no
        # longer swallowed.
        return 0
def preprocess_inputs(input_dict):
    """Turn raw form values into numeric codes, modifying the dict in place.

    Two conversions are applied:

    * every ``bool`` value becomes the integer 1 (True) or 0 (False);
    * a recognised ``"Blood Group"`` label is replaced by its position in
      the fixed list A+, A-, B+, B-, AB+, AB-, O+, O- (0 to 7). Labels
      outside that list are left untouched.

    Args:
        input_dict: Mapping of feature name to raw value.

    Returns:
        The same ``input_dict`` object that was passed in.
    """
    # NOTE(review): confirm these positions match the blood-group encoding
    # the model was trained with; nothing in this file establishes that.
    blood_group_labels = ("A+", "A-", "B+", "B-", "AB+", "AB-", "O+", "O-")

    # Only existing keys are reassigned, so the dict never changes size
    # while it is being iterated.
    for name, raw_value in input_dict.items():
        if isinstance(raw_value, bool):
            input_dict[name] = int(raw_value)

    if "Blood Group" in input_dict:
        label = input_dict["Blood Group"]
        if label in blood_group_labels:
            input_dict["Blood Group"] = blood_group_labels.index(label)

    return input_dict
def predict_pcos(*args):
    """Run the loaded model on one set of form values and describe the result.

    Args:
        *args: One value per entry of ``features``, in the same order.

    Returns:
        str: One of
        * ``"Positive/Negative for PCOS (Confidence: x.xx)"`` when the model
          offers both ``predict`` and ``predict_proba``;
        * ``"Positive/Negative for PCOS (Confidence: n/a)"`` when the model
          can predict but exposes no ``predict_proba``;
        * ``"Positive/Negative for PCOS (Risk Score: x.xx)"`` from a simple
          heuristic when the loaded object has no ``predict`` at all;
        * an explanatory error message if anything goes wrong.
        This function never raises; it is the boundary to the UI.
    """
    if model is None:
        return "Model not loaded correctly. Please check if model files are available."

    # zip() would silently drop trailing features or values on a mismatch
    # and produce a frame with the wrong columns, so fail loudly instead.
    if len(args) != len(features):
        return (
            f"Error making prediction: expected {len(features)} input values, "
            f"got {len(args)}"
        )

    try:
        input_dict = preprocess_inputs(dict(zip(features, args)))
        input_df = pd.DataFrame([input_dict])

        # Print for debugging
        print("Input shape:", input_df.shape)
        print("Input data types:", input_df.dtypes)

        # Decide up front whether the loaded object can predict. Catching
        # AttributeError around the call instead would also hide
        # AttributeErrors raised *inside* a real model and silently swap in
        # the heuristic below.
        if not callable(getattr(model, "predict", None)):
            # Fallback for a loaded object that is not a classifier (for
            # example a bare array of coefficients).
            print("Model is not a classifier object, using fallback prediction")
            risk_score = np.mean([
                input_df["BMI"].values[0] / 30,
                input_df["Follicle No. (L)"].values[0] / 15,
                input_df["Follicle No. (R)"].values[0] / 15,
                (1 if input_df["Cycle(R/I)"].values[0] > 3 else 0),
            ])
            prediction = 1 if risk_score > 0.6 else 0
            result = "Positive for PCOS" if prediction == 1 else "Negative for PCOS"
            return f"{result} (Risk Score: {risk_score:.2f})"

        prediction = model.predict(input_df)[0]
        result = "Positive for PCOS" if prediction == 1 else "Negative for PCOS"

        # Keep the model's own answer even when it cannot report
        # probabilities, rather than discarding it for the heuristic.
        if not callable(getattr(model, "predict_proba", None)):
            return f"{result} (Confidence: n/a)"

        probability = model.predict_proba(input_df)[0]
        # Report the probability of whichever class was predicted; index 1
        # is taken to be the positive class.
        conf = probability[1] if prediction == 1 else probability[0]
        return f"{result} (Confidence: {conf:.2f})"
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error making prediction: {str(e)}"
def show_visualization(visualization_index):
    """Return the demo figure at ``visualization_index``.

    All figures are rebuilt on every call via ``create_visualizations()``.
    The ones that are not returned are closed straight away; otherwise each
    call would leave extra figures registered with pyplot and memory would
    grow with every selection in the UI.

    Args:
        visualization_index: Zero-based position in the list returned by
            ``create_visualizations()``.

    Returns:
        The selected matplotlib figure, or ``None`` when the index is out
        of range.
    """
    visualizations = create_visualizations()

    selected = None
    if 0 <= visualization_index < len(visualizations):
        selected = visualizations[visualization_index]

    for fig in visualizations:
        if fig is not selected:
            plt.close(fig)

    # NOTE: the returned figure is deliberately left open because the
    # caller still has to render it.
    return selected
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------

# Labels for the "Visualizations" tab. Their order must match the order of
# the figures returned by create_visualizations(), because a label is
# translated to a figure through its position in this list.
VISUALIZATION_CHOICES = [
    "BMI vs Age",
    "Menstrual Cycle Length vs Age",
    "Follicle Distribution",
    "Follicle Count Boxplot",
    "Endometrium Thickness",
]

# Static text for the "About PCOS" tab.
ABOUT_PCOS_MARKDOWN = """
# Polycystic Ovary Syndrome (PCOS)
Polycystic ovary syndrome (PCOS) is a hormonal disorder common among women of reproductive age.
Women with PCOS may have infrequent or prolonged menstrual periods or excess male hormone (androgen) levels.
## Common Symptoms
- Irregular periods
- Excess androgen (elevated levels of male hormones)
- Polycystic ovaries
- Weight gain
- Acne
- Excessive hair growth (hirsutism)
- Thinning hair or hair loss
- Infertility
## Risk Factors
- Having a mother or sister with PCOS
- Insulin resistance
- Obesity
## Complications
- Infertility
- Gestational diabetes or pregnancy-induced high blood pressure
- Miscarriage or premature birth
- Type 2 diabetes or prediabetes
- Depression, anxiety, and eating disorders
- Sleep apnea
- Endometrial cancer
- Cardiovascular disease
## Treatment
Treatment focuses on managing your individual concerns, such as infertility, hirsutism, acne or obesity.
Specific treatment might involve lifestyle changes or medication.
"""


def render_visualization(choice):
    """Return the figure that belongs to the selected radio label."""
    return show_visualization(VISUALIZATION_CHOICES.index(choice))


def check_model():
    """Describe the loaded model for the "Debug Info" tab.

    Returns:
        str: The model's type name plus, where available, its number of
        estimators, the indices of its five most important features (in
        ascending order of importance) and whether it offers ``predict``
        and ``predict_proba``. An error message is returned instead of
        raising.
    """
    try:
        if model is None:
            return "Model not loaded"
        model_info = f"Model type: {type(model).__name__}\n"
        # Try to get additional info based on model type
        if hasattr(model, 'n_estimators'):
            model_info += f"Number of estimators: {model.n_estimators}\n"
        if hasattr(model, 'feature_importances_'):
            top_features = np.argsort(model.feature_importances_)[-5:]
            model_info += "Top 5 important features (indices): " + str(top_features.tolist()) + "\n"
        # Check if the model has predict and predict_proba methods
        has_predict = callable(getattr(model, 'predict', None))
        has_proba = callable(getattr(model, 'predict_proba', None))
        model_info += f"Has predict method: {has_predict}\n"
        model_info += f"Has predict_proba method: {has_proba}\n"
        return model_info
    except Exception as e:
        return f"Error checking model: {str(e)}"


with gr.Blocks(title="PCOS Detection Tool") as app:
    gr.Markdown("# PCOS Detection and Analysis Tool")
    gr.Markdown("This application uses machine learning to detect Polycystic Ovary Syndrome (PCOS) based on patient data.")

    with gr.Tabs():
        with gr.TabItem("Make Prediction"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Patient Demographics")
                    age = gr.Slider(18, 50, value=25, label="Age (yrs)")
                    weight = gr.Slider(40, 120, value=60, label="Weight (Kg)")
                    height = gr.Slider(140, 190, value=160, label="Height (cm)")
                    blood_group = gr.Dropdown(["A+", "A-", "B+", "B-", "AB+", "AB-", "O+", "O-"], value="A+", label="Blood Group")
                    bmi = gr.Slider(15, 40, value=22, label="BMI")
                with gr.Column():
                    gr.Markdown("### Vital Signs")
                    pulse = gr.Slider(60, 120, value=80, label="Pulse rate (bpm)")
                    rr = gr.Slider(12, 25, value=16, label="Respiratory Rate (breaths/min)")
                    systolic = gr.Slider(90, 180, value=120, label="BP Systolic (mmHg)")
                    diastolic = gr.Slider(60, 120, value=80, label="BP Diastolic (mmHg)")
                    hb = gr.Slider(8, 18, value=12, label="Hemoglobin (g/dl)")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Menstrual History")
                    cycle_length = gr.Slider(21, 45, value=28, label="Cycle length (days)")
                    cycle_regularity = gr.Radio([2, 4], value=2, label="Cycle Regularity (2=Regular, 4=Irregular)")
                with gr.Column():
                    gr.Markdown("### Physical Measurements")
                    hip = gr.Slider(30, 60, value=40, label="Hip (inch)")
                    waist = gr.Slider(20, 50, value=30, label="Waist (inch)")
                    waist_hip_ratio = gr.Slider(0.6, 1.2, value=0.75, label="Waist:Hip Ratio")
                    hip_waist_ratio = gr.Slider(1.0, 2.0, value=1.33, label="Hip:Waist Ratio")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Symptoms")
                    weight_gain = gr.Checkbox(label="Weight gain", value=False)
                    hair_growth = gr.Checkbox(label="Excessive hair growth", value=False)
                    skin_darkening = gr.Checkbox(label="Skin darkening", value=False)
                    hair_loss = gr.Checkbox(label="Hair loss", value=False)
                    pimples = gr.Checkbox(label="Pimples", value=False)
                with gr.Column():
                    gr.Markdown("### Lifestyle")
                    fast_food = gr.Checkbox(label="Fast food consumption", value=False)
                    regular_exercise = gr.Checkbox(label="Regular exercise", value=False)

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Ultrasound Findings")
                    # Follicle numbers are counts, so restrict them to
                    # whole numbers.
                    follicle_l = gr.Slider(0, 30, value=10, step=1, label="Follicle No. (Left)")
                    follicle_r = gr.Slider(0, 30, value=10, step=1, label="Follicle No. (Right)")
                    avg_fsize_l = gr.Slider(0, 25, value=5, label="Avg. Follicle size (Left) (mm)")
                    avg_fsize_r = gr.Slider(0, 25, value=5, label="Avg. Follicle size (Right) (mm)")
                    endometrium = gr.Slider(1, 20, value=8, label="Endometrium (mm)")
                with gr.Column():
                    gr.Markdown("### Hormone Levels")
                    fsh = gr.Slider(0, 20, value=6, label="FSH (mIU/mL)")
                    lh = gr.Slider(0, 20, value=7, label="LH (mIU/mL)")
                    fsh_lh_ratio = gr.Slider(0, 3, value=0.85, label="FSH/LH Ratio")
                    tsh = gr.Slider(0, 10, value=2.5, label="TSH (mIU/L)")
                    amh = gr.Slider(0, 10, value=3, label="AMH (ng/mL)")
                    prl = gr.Slider(0, 30, value=15, label="Prolactin (ng/mL)")
                    vit_d3 = gr.Slider(0, 100, value=30, label="Vitamin D3 (ng/mL)")
                    prg = gr.Slider(0, 20, value=5, label="Progesterone (ng/mL)")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Other Medical")
                    married_years = gr.Slider(0, 20, value=0, label="Marriage Status (Years)")
                    pregnant = gr.Checkbox(label="Currently Pregnant", value=False)
                    # A count, so whole numbers only.
                    abortions = gr.Slider(0, 5, value=0, step=1, label="Number of abortions")
                    rbs = gr.Slider(70, 200, value=90, label="Random Blood Sugar (mg/dl)")
                    beta_hcg1 = gr.Slider(0, 100, value=5, label="I beta-HCG (mIU/mL)")
                    beta_hcg2 = gr.Slider(0, 100, value=5, label="II beta-HCG (mIU/mL)")

            predict_btn = gr.Button("Predict PCOS Status")
            prediction_output = gr.Textbox(label="Prediction Result")

            # Connect inputs to prediction function. The order must match
            # the module-level ``features`` list entry for entry, because
            # predict_pcos() pairs them positionally. ``weight_gain``
            # appears twice on purpose: it feeds both "Weight gain(Y/N)"
            # and "Weight gain".
            input_components = [
                age, weight, height, bmi, blood_group, pulse, rr, hb, cycle_length,
                cycle_regularity, married_years, pregnant, abortions, hip, waist,
                waist_hip_ratio, weight_gain, hair_growth, skin_darkening, hair_loss,
                pimples, fast_food, regular_exercise, systolic, diastolic, follicle_l,
                follicle_r, avg_fsize_l, avg_fsize_r, endometrium, fsh, lh, fsh_lh_ratio,
                hip_waist_ratio, tsh, amh, prl, vit_d3, prg, rbs, weight_gain, beta_hcg1, beta_hcg2,
            ]
            predict_btn.click(
                predict_pcos,
                inputs=input_components,
                outputs=prediction_output,
            )

        with gr.TabItem("Visualizations"):
            gr.Markdown("### PCOS Data Analysis Visualizations")
            visualization_choice = gr.Radio(
                VISUALIZATION_CHOICES,
                value="BMI vs Age",
                label="Select Visualization",
            )
            visualization_output = gr.Plot()
            visualization_choice.change(
                render_visualization,
                inputs=visualization_choice,
                outputs=visualization_output,
            )

        with gr.TabItem("About PCOS"):
            gr.Markdown(ABOUT_PCOS_MARKDOWN)

        with gr.TabItem("Debug Info"):
            gr.Markdown("### Model and System Information")
            debug_output = gr.Textbox(label="Debug Information", value=f"Model type: {type(model).__name__}")
            debug_btn = gr.Button("Check Model Status")
            debug_btn.click(check_model, outputs=debug_output)

    # The radio starts with a selection, but ``change`` only fires when the
    # selection changes, so without this the default plot would never be
    # drawn. Render it once when the page loads.
    app.load(
        render_visualization,
        inputs=visualization_choice,
        outputs=visualization_output,
    )

# Launch the app
if __name__ == "__main__":
    # share=True asks Gradio for a public share link; debug=True keeps the
    # process attached and prints errors to the console.
    app.launch(share=True, debug=True)