"""Interactive Gradio dashboard for comparing classifier performance.

Loads per-dataset/per-model metrics from ``PLOS_model_results.csv``, pivots
them into one row per dataset with columns like ``"RF Accuracy"`` /
``"DT F1 Score"``, and serves two views:

* Model Comparison  — scatter of model A vs model B on one metric.
* Dataset Performance — grouped bars for one dataset or a category average.
"""

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go

# ---------------------------------------------------------------------------
# Data loading and reshaping
# ---------------------------------------------------------------------------

df = pd.read_csv('PLOS_model_results.csv')
df = df.drop('Unnamed: 0', axis=1)  # drop the stray index column written by to_csv

# One row per dataset, one column per (metric, model) pair.
df_metrics = df.pivot_table(
    index='dataset',
    columns='model',
    values=['accuracy', 'precision', 'recall', 'f1_score'],
)

# Reorganize columns: all models grouped under each metric, in a fixed order.
metrics = ['accuracy', 'precision', 'recall', 'f1_score']
models = ['DecisionTree', 'RandomForest', 'SVM', 'KNN', 'LogisticRegression']
cols_order = [(metric, model) for metric in metrics for model in models]
df_metrics = df_metrics[cols_order]

# Maps raw metric keys to the display labels used in the flattened column
# names (and in plot titles). NOTE: "f1_score".capitalize() would give
# "F1_score", so lookups must go through this map, not capitalize().
METRIC_LABELS = {
    'accuracy': 'Accuracy',
    'precision': 'Precision',
    'recall': 'Recall',
    'f1_score': 'F1 Score',
}

# Abbreviations used in the flattened column names and the UI.
MODEL_ABBREV = {
    'DecisionTree': 'DT',
    'RandomForest': 'RF',
    'LogisticRegression': 'LR',
}

# Flatten the MultiIndex columns to strings like "RF Accuracy", "DT F1 Score".
df_metrics.columns = [
    f"{MODEL_ABBREV.get(model, model)} {METRIC_LABELS[metric]}"
    for metric, model in df_metrics.columns
]

# Sort dataset ids numerically (D2 before D10): strip the leading "D".
df_metrics = df_metrics.sort_index(key=lambda x: x.str[1:].astype(int))

df_compare = pd.read_csv("pairwise_comparison_results.csv")

# ---------------------------------------------------------------------------
# Dataset metadata: category -> {dataset id -> human-readable name}
# ---------------------------------------------------------------------------

DATASET_CATEGORIES = {
    "Medical & Healthcare": {
        "D1": "Heart Disease (Comprehensive)", "D2": "Heart attack possibility",
        "D3": "Heart Disease Dataset", "D4": "Liver Disorders",
        "D5": "Diabetes Prediction", "D9": "Chronic Kidney Disease",
        "D10": "Breast Cancer Prediction", "D11": "Stroke Prediction",
        "D12": "Lung Cancer Prediction", "D13": "Hepatitis",
        "D15": "Thyroid Disease", "D16": "Heart Failure Prediction",
        "D17": "Parkinson's", "D18": "Indian Liver Patient",
        "D19": "COVID-19 Effect on Liver Cancer", "D20": "Liver Dataset",
        "D21": "Specht Heart", "D22": "Early-stage Diabetes",
        "D23": "Diabetic Retinopathy", "D24": "Breast Cancer Coimbra",
        "D25": "Chronic Kidney Disease", "D26": "Kidney Stone",
        "D28": "Echocardiogram", "D29": "Bladder Cancer Recurrence",
        "D31": "Prostate Cancer", "D46": "Real Breast Cancer Data",
        "D47": "Breast Cancer (Royston)", "D48": "Lung Cancer Dataset",
        "D52": "Cervical Cancer Risk", "D53": "Breast Cancer Wisconsin",
        "D61": "Breast Cancer Prediction", "D62": "Thyroid Disease",
        "D68": "Lung Cancer", "D69": "Cancer Patients Data",
        "D70": "Labor Relations", "D71": "Glioma Grading",
        "D74": "Post-Operative Patient", "D80": "Heart Rate Stress Monitoring",
        "D82": "Diabetes 2019", "D87": "Personal Heart Disease Indicators",
        "D92": "Heart Disease (Logistic)", "D95": "Diabetes Prediction",
        "D97": "Cardiovascular Disease", "D98": "Diabetes 130 US Hospitals",
        "D99": "Heart Disease Dataset", "D181": "HCV Data",
        "D184": "Cardiotocography", "D189": "Mammographic Mass",
        "D199": "Easiest Diabetes", "D200": "Monkey-Pox Patients",
        "D54": "Breast Cancer Wisconsin", "D63": "Sick-euthyroid",
        "D64": "Ann-test", "D65": "Ann-train",
        "D66": "Hypothyroid", "D67": "New-thyroid",
        "D72": "Glioma Grading",
    },
    "Gaming & Sports": {
        "D27": "Chess King-Rook", "D36": "Tic-Tac-Toe",
        "D40": "IPL 2022 Matches", "D41": "League of Legends",
        "D55": "League of Legends Diamond", "D56": "Chess Game Dataset",
        "D57": "Game of Thrones", "D73": "Connect-4",
        "D75": "FIFA 2018", "D76": "Dota 2 Matches",
        "D77": "IPL Match Analysis", "D78": "CS:GO Professional",
        "D79": "IPL 2008-2022", "D114": "Video Games",
        "D115": "Video Games Sales", "D117": "Sacred Games",
        "D118": "PC Games Sales", "D119": "Popular Video Games",
        "D120": "Olympic Games 2021", "D121": "Video Games ESRB",
        "D122": "Top Play Store Games", "D123": "Steam Games",
        "D124": "PS4 Games", "D116": "Video Games Sales",
    },
    "Education & Students": {
        "D43": "Student Marks", "D44": "Student 2nd Year Result",
        "D45": "Student Mat Pass/Fail", "D103": "Academic Performance",
        "D104": "Student Academic Analysis", "D105": "Student Dropout Prediction",
        "D106": "Electronic Gadgets Impact", "D107": "Campus Recruitment",
        "D108": "End-Semester Performance", "D109": "Fitbits and Grades",
        "D110": "Student Time Management", "D111": "Student Feedback",
        "D112": "Depression & Performance", "D113": "University Rankings",
        "D126": "University Ranking CWUR",
        "D127": "University Ranking CWUR 2013-2014",
        "D128": "University Ranking CWUR 2014-2015",
        "D129": "University Ranking CWUR 2015-2016",
        "D130": "University Ranking CWUR 2016-2017",
        "D131": "University Ranking CWUR 2017-2018",
        "D132": "University Ranking CWUR 2018-2019",
        "D133": "University Ranking CWUR 2019-2020",
        "D134": "University Ranking CWUR 2020-2021",
        "D135": "University Ranking CWUR 2021-2022",
        "D136": "University Ranking CWUR 2022-2023",
        "D137": "University Ranking GM 2016",
        "D138": "University Ranking GM 2017",
        "D139": "University Ranking GM 2018",
        "D140": "University Ranking GM 2019",
        "D141": "University Ranking GM 2020",
        "D142": "University Ranking GM 2021",
        "D143": "University Ranking GM 2022",
        "D144": "University Ranking Webometric 2012",
        "D145": "University Ranking Webometric 2013",
        "D146": "University Ranking Webometric 2014",
        "D147": "University Ranking Webometric 2015",
        "D148": "University Ranking Webometric 2016",
        "D149": "University Ranking Webometric 2017",
        "D150": "University Ranking Webometric 2018",
        "D151": "University Ranking Webometric 2019",
        "D152": "University Ranking Webometric 2020",
        "D153": "University Ranking Webometric 2021",
        "D154": "University Ranking Webometric 2022",
        "D155": "University Ranking Webometric 2023",
        "D156": "University Ranking URAP 2018-2019",
        "D157": "University Ranking URAP 2019-2020",
        "D158": "University Ranking URAP 2020-2021",
        "D159": "University Ranking URAP 2021-2022",
        "D160": "University Ranking URAP 2022-2023",
        "D161": "University Ranking THE 2011",
        "D162": "University Ranking THE 2012",
        "D163": "University Ranking THE 2013",
        "D164": "University Ranking THE 2014",
        "D165": "University Ranking THE 2015",
        "D166": "University Ranking THE 2016",
        "D167": "University Ranking THE 2017",
        "D168": "University Ranking THE 2018",
        "D169": "University Ranking THE 2019",
        "D170": "University Ranking THE 2020",
        "D171": "University Ranking THE 2021",
        "D172": "University Ranking THE 2022",
        "D173": "University Ranking THE 2023",
        "D174": "University Ranking QS 2022",
        "D190": "Student Academics Performance",
    },
    "Banking & Finance": {
        "D6": "Bank Marketing 1", "D7": "Bank Marketing 2",
        "D30": "Adult Income", "D32": "Telco Customer Churn",
        "D35": "Credit Approval", "D50": "Term Deposit Prediction",
        "D96": "Credit Card Fraud", "D188": "South German Credit",
        "D193": "Credit Risk Classification",
        "D195": "Credit Score Classification",
        "D196": "Banking Classification",
    },
    "Science & Engineering": {
        "D8": "Mushroom", "D14": "Ionosphere",
        "D33": "EEG Eye State", "D37": "Steel Plates Faults",
        "D39": "Fertility", "D51": "Darwin",
        "D58": "EEG Emotions", "D81": "Predictive Maintenance",
        "D84": "Oranges vs Grapefruit", "D90": "Crystal System Li-ion",
        "D183": "Drug Consumption", "D49": "Air Pressure System Failures",
        "D93": "Air Pressure System Failures", "D185": "Toxicity",
        "D186": "Toxicity",
    },
    "Social & Lifestyle": {
        "D38": "Online Shoppers", "D59": "Red Wine Quality",
        "D60": "White Wine Quality", "D88": "Airline Passenger Satisfaction",
        "D94": "Go Emotions Google", "D100": "Spotify East Asian",
        "D125": "Suicide Rates", "D182": "Obesity Levels",
        "D187": "Blood Transfusion", "D191": "Obesity Classification",
        "D192": "Gender Classification", "D194": "Happiness Classification",
        "D42": "Airline customer Holiday Booking dataset",
    },
    "ML Benchmarks & Synthetic": {
        "D34": "Spambase", "D85": "Synthetic Binary",
        "D89": "Naive Bayes Data", "D175": "Monk's Problems 1",
        "D176": "Monk's Problems 2", "D177": "Monk's Problems 3",
        "D178": "Monk's Problems 4", "D179": "Monk's Problems 5",
        "D180": "Monk's Problems 6",
    },
    "Other": {
        "D83": "Paris Housing", "D91": "Fake Bills",
        "D197": "Star Classification",
    },
}

# Shared plot styling: one fixed color per model abbreviation.
MODEL_COLORS = {
    'DT': '#8b5cf6', 'RF': '#10b981', 'SVM': '#f59e0b',
    'KNN': '#ef4444', 'LR': '#3b82f6',
}
DISPLAY_METRICS = ['Accuracy', 'Precision', 'Recall', 'F1 Score']
DISPLAY_MODELS = ['DT', 'RF', 'SVM', 'KNN', 'LR']


def get_dataset_info(dataset_id):
    """Return ``(category, display_name)`` for a dataset id.

    Falls back to ``("Unknown", "Unknown Dataset")`` when the id is not
    present in any category.
    """
    for category, datasets in DATASET_CATEGORIES.items():
        if dataset_id in datasets:
            return category, datasets[dataset_id]
    return "Unknown", "Unknown Dataset"


def filter_datasets_by_category(df_metrics, category="All Categories"):
    """Return the dataset ids in ``category`` that exist in ``df_metrics``.

    ``"All Categories"`` returns every index entry; an unknown category
    returns an empty list.
    """
    if category == "All Categories":
        return list(df_metrics.index)
    if category in DATASET_CATEGORIES:
        dataset_ids = list(DATASET_CATEGORIES[category].keys())
        return [did for did in dataset_ids if did in df_metrics.index]
    return []


def format_dataset_choices(dataset_ids, filter_category="All Categories"):
    """Format dataset ids for dropdown display.

    When no category filter is active the category is included in the label
    so the user can tell entries apart; otherwise just the dataset name.
    """
    formatted = []
    for dataset_id in dataset_ids:
        category, name = get_dataset_info(dataset_id)
        if filter_category == "All Categories":
            formatted.append(f"{dataset_id} - {category} | {name}")
        else:
            formatted.append(f"{dataset_id} - {name}")
    return formatted


def extract_dataset_id(formatted_string):
    """Recover the dataset id from a label built by format_dataset_choices."""
    return formatted_string.split(" - ")[0]


def create_comparison_plot(df_metrics, metric, model1, model2):
    """Scatter ``model1`` vs ``model2`` scores for ``metric`` across datasets.

    Points are colored by the per-dataset score difference (model1 - model2)
    on a symmetric diverging scale centered at 0; a dashed diagonal marks
    equal performance. Returns a plotly Figure, or None if either column is
    missing.
    """
    # BUGFIX: use the shared label map; "f1_score".capitalize() produced
    # "F1_score", which never matched the "F1 Score" column names, so the
    # F1 Score comparison silently returned None.
    label = METRIC_LABELS.get(metric, metric.capitalize())
    col1 = f"{model1} {label}"
    col2 = f"{model2} {label}"
    if col1 not in df_metrics.columns or col2 not in df_metrics.columns:
        return None

    # Per-dataset score difference drives the marker color.
    diff = df_metrics[col1] - df_metrics[col2]

    # Hover text with dataset id, category and name.
    hover_texts = []
    for dataset_id in df_metrics.index:
        category, name = get_dataset_info(dataset_id)
        hover_texts.append(f"{dataset_id}<br>{category} | {name}")

    # Symmetric color range so 0 (equal performance) sits at the midpoint.
    max_abs_diff = max(abs(diff.min()), abs(diff.max()))

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=df_metrics[col1],
        y=df_metrics[col2],
        mode='markers',
        marker=dict(
            size=8,
            color=diff,
            colorscale='RdYlGn',
            showscale=True,
            colorbar=dict(title=f"{model1} - {model2}"),
            line=dict(width=0.5, color='white'),
            cmin=-max_abs_diff,
            cmax=max_abs_diff,
        ),
        text=hover_texts,
        # Double %{{...}} in the f-strings escapes plotly's %{x} placeholders.
        hovertemplate='%{text}<br><br>'
                      + f'{model1}: %{{x:.3f}}<br>'
                      + f'{model2}: %{{y:.3f}}<br>'
                      + 'Difference: %{marker.color:.3f}<br>',
    ))

    # Dashed diagonal: points above favor model2, below favor model1.
    min_val = min(df_metrics[col1].min(), df_metrics[col2].min())
    max_val = max(df_metrics[col1].max(), df_metrics[col2].max())
    fig.add_trace(go.Scatter(
        x=[min_val, max_val],
        y=[min_val, max_val],
        mode='lines',
        line=dict(dash='dash', color='gray', width=2),
        showlegend=False,
        hoverinfo='skip',
    ))

    # Win/tie counts for the subtitle.
    wins_model1 = (diff > 0).sum()
    wins_model2 = (diff < 0).sum()
    ties = (diff == 0).sum()

    fig.update_layout(
        title=f"{model1} vs {model2} - {label}<br>"
              + f"{model1} wins: {wins_model1} | {model2} wins: {wins_model2} | Ties: {ties}",
        xaxis_title=f"{model1} {label}",
        yaxis_title=f"{model2} {label}",
        height=550,
        template='plotly_white',
    )
    return fig


def create_dataset_performance(df_metrics, dataset_name):
    """Create a grouped-bar plot of all metrics for a single dataset.

    Returns None when ``dataset_name`` is not a row of ``df_metrics``.
    """
    if dataset_name not in df_metrics.index:
        return None

    dataset_row = df_metrics.loc[dataset_name]
    fig = go.Figure()

    for model in DISPLAY_MODELS:
        # Missing columns plot as 0 rather than raising.
        values = [
            dataset_row[f"{model} {metric}"]
            if f"{model} {metric}" in dataset_row.index else 0
            for metric in DISPLAY_METRICS
        ]
        fig.add_trace(go.Bar(
            name=model,
            x=DISPLAY_METRICS,
            y=values,
            marker_color=MODEL_COLORS.get(model, '#666'),
            text=[f'{round(v, 3)}' for v in values],
            textposition='outside',
        ))

    # Dataset metadata for the two-line title.
    category, name = get_dataset_info(dataset_name)
    fig.update_layout(
        title=f"Performance Metrics for {dataset_name}<br>{category} | {name}",
        xaxis_title="Metric",
        yaxis_title="Score",
        barmode='group',
        height=500,
        template='plotly_white',
        yaxis=dict(range=[0, 1.1]),  # headroom for the outside bar labels
    )
    return fig


def create_category_performance(df_metrics, category):
    """Create a grouped-bar plot of average metrics over a dataset category.

    ``"All Categories"`` averages over every dataset. Returns None when the
    category has no datasets present in ``df_metrics``.
    """
    if category == "All Categories":
        dataset_ids = list(df_metrics.index)
        title_suffix = "All Datasets"
    else:
        dataset_ids = filter_datasets_by_category(df_metrics, category)
        title_suffix = category
    if not dataset_ids:
        return None

    df_filtered = df_metrics.loc[dataset_ids]
    fig = go.Figure()

    # Average each model/metric column over the selected datasets.
    for model in DISPLAY_MODELS:
        values = [
            df_filtered[f"{model} {metric}"].mean()
            if f"{model} {metric}" in df_filtered.columns else 0
            for metric in DISPLAY_METRICS
        ]
        fig.add_trace(go.Bar(
            name=model,
            x=DISPLAY_METRICS,
            y=values,
            marker_color=MODEL_COLORS.get(model, '#666'),
            text=[f'{v:.3f}' for v in values],
            textposition='outside',
        ))

    summary_text = f"Number of datasets: {len(dataset_ids)}"
    fig.update_layout(
        title=f"Average Model Performance - {title_suffix}<br>" + f"{summary_text}",
        xaxis_title="Metric",
        yaxis_title="Average Score",
        barmode='group',
        height=550,
        template='plotly_white',
        yaxis=dict(range=[0, 1.1]),  # headroom for the outside bar labels
    )
    return fig


def build_app(df_metrics, df_compare):
    """Assemble the Gradio Blocks UI and return the (unlaunched) demo.

    ``df_compare`` is currently unused by the UI but kept in the signature
    for callers that pass the pairwise-comparison table.
    """
    category_choices = ["All Categories"] + list(DATASET_CATEGORIES.keys())
    all_formatted_choices = format_dataset_choices(df_metrics.index.tolist())

    with gr.Blocks() as demo:
        gr.Markdown("""
# Model Performance Comparison
""")
        with gr.Tabs():
            with gr.Tab("Model Comparison"):
                with gr.Row():
                    category_comp = gr.Dropdown(
                        choices=category_choices,
                        value="All Categories",
                        label="Dataset Category",
                    )
                    metric_comp = gr.Dropdown(
                        choices=["accuracy", "precision", "recall", "f1_score"],
                        value="accuracy",
                        label="Select Metric",
                    )
                with gr.Row():
                    model1 = gr.Dropdown(
                        choices=DISPLAY_MODELS, value="RF", label="Model 1",
                    )
                    model2 = gr.Dropdown(
                        choices=DISPLAY_MODELS, value="DT", label="Model 2",
                    )
                comp_plot = gr.Plot(label="Direct Comparison")
                comp_btn = gr.Button("Compare Models", variant="primary", size="lg")

                def compare_with_filter(category, metric, m1, m2):
                    # Restrict the scatter to the selected category first.
                    filtered_ids = filter_datasets_by_category(df_metrics, category)
                    df_filtered = df_metrics.loc[filtered_ids]
                    return create_comparison_plot(df_filtered, metric, m1, m2)

                comp_btn.click(
                    compare_with_filter,
                    inputs=[category_comp, metric_comp, model1, model2],
                    outputs=comp_plot,
                )

            with gr.Tab("Dataset Performance"):
                with gr.Row():
                    view_type = gr.Radio(
                        choices=["Single Dataset", "Category Overview"],
                        value="Single Dataset",
                        label="View Type",
                    )
                with gr.Row():
                    category_dataset = gr.Dropdown(
                        choices=category_choices,
                        value="All Categories",
                        label="Dataset Category",
                    )
                    dataset_selector = gr.Dropdown(
                        choices=all_formatted_choices,
                        value=all_formatted_choices[0],
                        label="Select Dataset",
                        visible=True,
                    )
                dataset_plot = gr.Plot(label="Dataset Performance")
                dataset_btn = gr.Button("Show Performance", variant="primary", size="lg")

                # The dataset dropdown only applies to the single-dataset view.
                def update_view_controls(view_type):
                    if view_type == "Single Dataset":
                        return gr.Dropdown(visible=True)
                    else:
                        return gr.Dropdown(visible=False)

                view_type.change(
                    update_view_controls,
                    inputs=view_type,
                    outputs=dataset_selector,
                )

                # Re-populate the dataset dropdown when the category changes.
                def update_dataset_choices(category):
                    filtered_ids = filter_datasets_by_category(df_metrics, category)
                    formatted = format_dataset_choices(filtered_ids, category)
                    return gr.Dropdown(
                        choices=formatted,
                        value=formatted[0] if formatted else None,
                    )

                category_dataset.change(
                    update_dataset_choices,
                    inputs=category_dataset,
                    outputs=dataset_selector,
                )

                def show_performance(view_type, category, formatted_dataset):
                    if view_type == "Single Dataset":
                        dataset_id = extract_dataset_id(formatted_dataset)
                        return create_dataset_performance(df_metrics, dataset_id)
                    else:
                        return create_category_performance(df_metrics, category)

                dataset_btn.click(
                    show_performance,
                    inputs=[view_type, category_dataset, dataset_selector],
                    outputs=dataset_plot,
                )

        gr.Markdown("""
---
- **DT**: Decision Tree | **RF**: Random Forest | **SVM**: Support Vector Machine
- **KNN**: K-Nearest Neighbors | **LR**: Logistic Regression
""")
    return demo


demo = build_app(df_metrics, df_compare)
demo.launch()