Hugging Face Spaces export (Space status: Sleeping).
# --- Load model results and reshape into a per-dataset metrics table ---
import pandas as pd
import numpy as np

df = pd.read_csv('PLOS_model_results.csv')
# Drop the serialized row-index column left behind by a previous to_csv().
df = df.drop('Unnamed: 0', axis=1)

# Wide table: one row per dataset, one (metric, model) MultiIndex column pair.
df_metrics = df.pivot_table(
    index='dataset',
    columns='model',
    values=['accuracy', 'precision', 'recall', 'f1_score'],
)

# Impose a deterministic column order: metric outer, model inner.
metrics = ['accuracy', 'precision', 'recall', 'f1_score']
models = ['DecisionTree', 'RandomForest', 'SVM', 'KNN', 'LogisticRegression']
cols_order = [(metric, model) for metric in metrics for model in models]
df_metrics = df_metrics[cols_order]

# Flatten the MultiIndex into readable "<model-abbrev> <Metric>" labels,
# e.g. ("f1_score", "RandomForest") -> "RF F1 Score".
_ABBREV = {"DecisionTree": "DT", "RandomForest": "RF", "LogisticRegression": "LR"}
new_cols = []
for metric, model in df_metrics.columns:
    model_short = _ABBREV.get(model, model)
    metric_name = "F1 Score" if metric == "f1_score" else metric.capitalize()
    new_cols.append(f"{model_short} {metric_name}")
df_metrics.columns = new_cols

# Sort rows numerically by the number after the leading "D" (D1, D2, ..., D200),
# instead of lexicographically (D1, D10, D100, ...).
df_metrics = df_metrics.sort_index(key=lambda x: x.str[1:].astype(int))

df_compare = pd.read_csv("pairwise_comparison_results.csv")
# Catalogue mapping each dataset ID (e.g. "D42") to a human-readable title,
# grouped by domain category. Used to build dropdown labels and plot hover
# text via get_dataset_info()/filter_datasets_by_category().
# NOTE(review): several titles repeat under different IDs (e.g. D9/D25,
# D49/D93, D115/D116) — presumably distinct source files sharing a name;
# confirm against the underlying CSVs.
DATASET_CATEGORIES = {
    "Medical & Healthcare": {
        "D1": "Heart Disease (Comprehensive)",
        "D2": "Heart attack possibility",
        "D3": "Heart Disease Dataset",
        "D4": "Liver Disorders",
        "D5": "Diabetes Prediction",
        "D9": "Chronic Kidney Disease",
        "D10": "Breast Cancer Prediction",
        "D11": "Stroke Prediction",
        "D12": "Lung Cancer Prediction",
        "D13": "Hepatitis",
        "D15": "Thyroid Disease",
        "D16": "Heart Failure Prediction",
        "D17": "Parkinson's",
        "D18": "Indian Liver Patient",
        "D19": "COVID-19 Effect on Liver Cancer",
        "D20": "Liver Dataset",
        "D21": "Specht Heart",
        "D22": "Early-stage Diabetes",
        "D23": "Diabetic Retinopathy",
        "D24": "Breast Cancer Coimbra",
        "D25": "Chronic Kidney Disease",
        "D26": "Kidney Stone",
        "D28": "Echocardiogram",
        "D29": "Bladder Cancer Recurrence",
        "D31": "Prostate Cancer",
        "D46": "Real Breast Cancer Data",
        "D47": "Breast Cancer (Royston)",
        "D48": "Lung Cancer Dataset",
        "D52": "Cervical Cancer Risk",
        "D53": "Breast Cancer Wisconsin",
        "D61": "Breast Cancer Prediction",
        "D62": "Thyroid Disease",
        "D68": "Lung Cancer",
        "D69": "Cancer Patients Data",
        "D70": "Labor Relations",
        "D71": "Glioma Grading",
        "D74": "Post-Operative Patient",
        "D80": "Heart Rate Stress Monitoring",
        "D82": "Diabetes 2019",
        "D87": "Personal Heart Disease Indicators",
        "D92": "Heart Disease (Logistic)",
        "D95": "Diabetes Prediction",
        "D97": "Cardiovascular Disease",
        "D98": "Diabetes 130 US Hospitals",
        "D99": "Heart Disease Dataset",
        "D181": "HCV Data",
        "D184": "Cardiotocography",
        "D189": "Mammographic Mass",
        "D199": "Easiest Diabetes",
        "D200": "Monkey-Pox Patients",
        # IDs below break the otherwise-ascending order; appended later?
        # TODO confirm intentional.
        "D54": "Breast Cancer Wisconsin",
        "D63": "Sick-euthyroid",
        "D64": "Ann-test",
        "D65": "Ann-train",
        "D66": "Hypothyroid",
        "D67": "New-thyroid",
        "D72": "Glioma Grading",
    },
    "Gaming & Sports": {
        "D27": "Chess King-Rook",
        "D36": "Tic-Tac-Toe",
        "D40": "IPL 2022 Matches",
        "D41": "League of Legends",
        "D55": "League of Legends Diamond",
        "D56": "Chess Game Dataset",
        "D57": "Game of Thrones",
        "D73": "Connect-4",
        "D75": "FIFA 2018",
        "D76": "Dota 2 Matches",
        "D77": "IPL Match Analysis",
        "D78": "CS:GO Professional",
        "D79": "IPL 2008-2022",
        "D114": "Video Games",
        "D115": "Video Games Sales",
        "D117": "Sacred Games",
        "D118": "PC Games Sales",
        "D119": "Popular Video Games",
        "D120": "Olympic Games 2021",
        "D121": "Video Games ESRB",
        "D122": "Top Play Store Games",
        "D123": "Steam Games",
        "D124": "PS4 Games",
        "D116": "Video Games Sales",
    },
    "Education & Students": {
        "D43": "Student Marks",
        "D44": "Student 2nd Year Result",
        "D45": "Student Mat Pass/Fail",
        "D103": "Academic Performance",
        "D104": "Student Academic Analysis",
        "D105": "Student Dropout Prediction",
        "D106": "Electronic Gadgets Impact",
        "D107": "Campus Recruitment",
        "D108": "End-Semester Performance",
        "D109": "Fitbits and Grades",
        "D110": "Student Time Management",
        "D111": "Student Feedback",
        "D112": "Depression & Performance",
        "D113": "University Rankings",
        # University-ranking series: one dataset per source and year range.
        "D126": "University Ranking CWUR",
        "D127": "University Ranking CWUR 2013-2014",
        "D128": "University Ranking CWUR 2014-2015",
        "D129": "University Ranking CWUR 2015-2016",
        "D130": "University Ranking CWUR 2016-2017",
        "D131": "University Ranking CWUR 2017-2018",
        "D132": "University Ranking CWUR 2018-2019",
        "D133": "University Ranking CWUR 2019-2020",
        "D134": "University Ranking CWUR 2020-2021",
        "D135": "University Ranking CWUR 2021-2022",
        "D136": "University Ranking CWUR 2022-2023",
        "D137": "University Ranking GM 2016",
        "D138": "University Ranking GM 2017",
        "D139": "University Ranking GM 2018",
        "D140": "University Ranking GM 2019",
        "D141": "University Ranking GM 2020",
        "D142": "University Ranking GM 2021",
        "D143": "University Ranking GM 2022",
        "D144": "University Ranking Webometric 2012",
        "D145": "University Ranking Webometric 2013",
        "D146": "University Ranking Webometric 2014",
        "D147": "University Ranking Webometric 2015",
        "D148": "University Ranking Webometric 2016",
        "D149": "University Ranking Webometric 2017",
        "D150": "University Ranking Webometric 2018",
        "D151": "University Ranking Webometric 2019",
        "D152": "University Ranking Webometric 2020",
        "D153": "University Ranking Webometric 2021",
        "D154": "University Ranking Webometric 2022",
        "D155": "University Ranking Webometric 2023",
        "D156": "University Ranking URAP 2018-2019",
        "D157": "University Ranking URAP 2019-2020",
        "D158": "University Ranking URAP 2020-2021",
        "D159": "University Ranking URAP 2021-2022",
        "D160": "University Ranking URAP 2022-2023",
        "D161": "University Ranking THE 2011",
        "D162": "University Ranking THE 2012",
        "D163": "University Ranking THE 2013",
        "D164": "University Ranking THE 2014",
        "D165": "University Ranking THE 2015",
        "D166": "University Ranking THE 2016",
        "D167": "University Ranking THE 2017",
        "D168": "University Ranking THE 2018",
        "D169": "University Ranking THE 2019",
        "D170": "University Ranking THE 2020",
        "D171": "University Ranking THE 2021",
        "D172": "University Ranking THE 2022",
        "D173": "University Ranking THE 2023",
        "D174": "University Ranking QS 2022",
        "D190": "Student Academics Performance"
    },
    "Banking & Finance": {
        "D6": "Bank Marketing 1",
        "D7": "Bank Marketing 2",
        "D30": "Adult Income",
        "D32": "Telco Customer Churn",
        "D35": "Credit Approval",
        "D50": "Term Deposit Prediction",
        "D96": "Credit Card Fraud",
        "D188": "South German Credit",
        "D193": "Credit Risk Classification",
        "D195": "Credit Score Classification",
        "D196": "Banking Classification"
    },
    "Science & Engineering": {
        "D8": "Mushroom",
        "D14": "Ionosphere",
        "D33": "EEG Eye State",
        "D37": "Steel Plates Faults",
        "D39": "Fertility",
        "D51": "Darwin",
        "D58": "EEG Emotions",
        "D81": "Predictive Maintenance",
        "D84": "Oranges vs Grapefruit",
        "D90": "Crystal System Li-ion",
        "D183": "Drug Consumption",
        "D49": "Air Pressure System Failures",
        "D93": "Air Pressure System Failures",
        "D185": "Toxicity",
        "D186": "Toxicity",
    },
    "Social & Lifestyle": {
        "D38": "Online Shoppers",
        "D59": "Red Wine Quality",
        "D60": "White Wine Quality",
        "D88": "Airline Passenger Satisfaction",
        "D94": "Go Emotions Google",
        "D100": "Spotify East Asian",
        "D125": "Suicide Rates",
        "D182": "Obesity Levels",
        "D187": "Blood Transfusion",
        "D191": "Obesity Classification",
        "D192": "Gender Classification",
        "D194": "Happiness Classification",
        "D42": "Airline customer Holiday Booking dataset"
    },
    "ML Benchmarks & Synthetic": {
        "D34": "Spambase",
        "D85": "Synthetic Binary",
        "D89": "Naive Bayes Data",
        "D175": "Monk's Problems 1",
        "D176": "Monk's Problems 2",
        "D177": "Monk's Problems 3",
        "D178": "Monk's Problems 4",
        "D179": "Monk's Problems 5",
        "D180": "Monk's Problems 6"
    },
    "Other": {
        "D83": "Paris Housing",
        "D91": "Fake Bills",
        "D197": "Star Classification"
    }
}
| import gradio as gr | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| import numpy as np | |
def get_dataset_info(dataset_id):
    """Look up the category and human-readable name for a dataset ID.

    Returns ("Unknown", "Unknown Dataset") when the ID is not catalogued
    in DATASET_CATEGORIES.
    """
    for category_name, id_to_name in DATASET_CATEGORIES.items():
        try:
            return category_name, id_to_name[dataset_id]
        except KeyError:
            continue
    return "Unknown", "Unknown Dataset"
def filter_datasets_by_category(df_metrics, category="All Categories"):
    """Return the dataset IDs present in `df_metrics` for one category.

    "All Categories" returns every index entry; an unknown category
    yields an empty list.
    """
    if category == "All Categories":
        return list(df_metrics.index)
    catalogued = DATASET_CATEGORIES.get(category)
    if catalogued is None:
        return []
    # Keep catalogue order, but drop IDs missing from the metrics table.
    present = set(df_metrics.index)
    return [dataset_id for dataset_id in catalogued if dataset_id in present]
def format_dataset_choices(dataset_ids, filter_category="All Categories"):
    """Build dropdown labels for the given dataset IDs.

    The label embeds the category only in the "All Categories" view; when a
    single category is already selected the category would be redundant.
    """
    include_category = filter_category == "All Categories"
    labels = []
    for dataset_id in dataset_ids:
        category, name = get_dataset_info(dataset_id)
        if include_category:
            labels.append(f"{dataset_id} - {category} | {name}")
        else:
            labels.append(f"{dataset_id} - {name}")
    return labels
def extract_dataset_id(formatted_string):
    """Recover the dataset ID — the text before the first " - " — from a
    label produced by format_dataset_choices()."""
    dataset_id, _, _ = formatted_string.partition(" - ")
    return dataset_id
def create_comparison_plot(df_metrics, metric, model1, model2):
    """Scatter plot comparing two models on one metric across datasets.

    Each point is a dataset: x is `model1`'s score, y is `model2`'s score,
    and marker colour encodes the signed difference (model1 - model2).
    A dashed y = x diagonal marks equal performance, and the title reports
    win/tie counts. Returns None when the requested columns are missing.
    """
    # BUG FIX: the metrics table names its columns "<model> F1 Score", but
    # "f1_score".capitalize() yields "F1_score", so selecting the f1_score
    # metric always hit the missing-column branch and returned None. Map the
    # raw key to the same display name used when the columns were built.
    metric_name = "F1 Score" if metric == "f1_score" else metric.capitalize()
    col1 = f"{model1} {metric_name}"
    col2 = f"{model2} {metric_name}"
    if col1 not in df_metrics.columns or col2 not in df_metrics.columns:
        return None
    # Signed per-dataset difference drives both marker colour and win counts.
    diff = df_metrics[col1] - df_metrics[col2]
    # Hover text: dataset ID plus its category and human-readable name.
    hover_texts = []
    for dataset_id in df_metrics.index:
        category, name = get_dataset_info(dataset_id)
        hover_texts.append(f"<b>{dataset_id}</b><br>{category} | {name}")
    # Symmetric colour range centred at 0 so "equal" is always mid-scale.
    max_abs_diff = max(abs(diff.min()), abs(diff.max()))
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=df_metrics[col1],
        y=df_metrics[col2],
        mode='markers',
        marker=dict(
            size=8,
            color=diff,
            colorscale='RdYlGn',
            showscale=True,
            colorbar=dict(title=f"{model1} - {model2}"),
            line=dict(width=0.5, color='white'),
            cmin=-max_abs_diff,
            cmax=max_abs_diff
        ),
        text=hover_texts,
        hovertemplate='%{text}<br><br>' +
                      f'{model1}: %{{x:.3f}}<br>' +
                      f'{model2}: %{{y:.3f}}<br>' +
                      'Difference: %{marker.color:.3f}<br>'
    ))
    # Dashed y = x diagonal: points above favour model2, below favour model1.
    min_val = min(df_metrics[col1].min(), df_metrics[col2].min())
    max_val = max(df_metrics[col1].max(), df_metrics[col2].max())
    fig.add_trace(go.Scatter(
        x=[min_val, max_val],
        y=[min_val, max_val],
        mode='lines',
        line=dict(dash='dash', color='gray', width=2),
        showlegend=False,
        hoverinfo='skip'
    ))
    # Win/tie counts for the subtitle (ties are exact float equality).
    wins_model1 = (diff > 0).sum()
    wins_model2 = (diff < 0).sum()
    ties = (diff == 0).sum()
    fig.update_layout(
        title=f"{model1} vs {model2} - {metric_name}<br>" +
              f"<sub>{model1} wins: {wins_model1} | {model2} wins: {wins_model2} | Ties: {ties}</sub>",
        xaxis_title=f"{model1} {metric_name}",
        yaxis_title=f"{model2} {metric_name}",
        height=550,
        template='plotly_white'
    )
    return fig
def create_dataset_performance(df_metrics, dataset_name):
    """Grouped bar chart of every model's four metrics on one dataset.

    Returns None when `dataset_name` is not in the table's index; missing
    model/metric columns are drawn as zero-height bars.
    """
    if dataset_name not in df_metrics.index:
        return None
    row = df_metrics.loc[dataset_name]
    metric_labels = ['Accuracy', 'Precision', 'Recall', 'F1 Score']
    model_labels = ['DT', 'RF', 'SVM', 'KNN', 'LR']
    palette = {'DT': '#8b5cf6', 'RF': '#10b981', 'SVM': '#f59e0b',
               'KNN': '#ef4444', 'LR': '#3b82f6'}
    fig = go.Figure()
    # One trace per model; each trace holds that model's four metric bars.
    for model in model_labels:
        scores = [
            row[f"{model} {m}"] if f"{model} {m}" in row.index else 0
            for m in metric_labels
        ]
        fig.add_trace(go.Bar(
            name=model,
            x=metric_labels,
            y=scores,
            marker_color=palette.get(model, '#666'),
            text=[f'{round(v, 3)}' for v in scores],
            textposition='outside'
        ))
    category, name = get_dataset_info(dataset_name)
    fig.update_layout(
        title=f"Performance Metrics for {dataset_name}<br><sub>{category} | {name}</sub>",
        xaxis_title="Metric",
        yaxis_title="Score",
        barmode='group',
        height=500,
        template='plotly_white',
        # Headroom above 1.0 so outside-positioned labels stay visible.
        yaxis=dict(range=[0, 1.1])
    )
    return fig
def create_category_performance(df_metrics, category):
    """Grouped bar chart of mean scores per model over a category's datasets.

    Averages each "<model> <Metric>" column across the datasets in
    `category` ("All Categories" uses every row). Returns None when the
    category has no datasets in the table.
    """
    if category == "All Categories":
        selected = list(df_metrics.index)
        title_suffix = "All Datasets"
    else:
        selected = filter_datasets_by_category(df_metrics, category)
        title_suffix = category
    if not selected:
        return None
    subset = df_metrics.loc[selected]
    model_labels = ['DT', 'RF', 'SVM', 'KNN', 'LR']
    metric_labels = ['Accuracy', 'Precision', 'Recall', 'F1 Score']
    palette = {'DT': '#8b5cf6', 'RF': '#10b981', 'SVM': '#f59e0b',
               'KNN': '#ef4444', 'LR': '#3b82f6'}
    fig = go.Figure()
    # One trace per model; bar height = mean over the selected datasets,
    # missing columns rendered as zero-height bars.
    for model in model_labels:
        means = []
        for m in metric_labels:
            column = f"{model} {m}"
            means.append(subset[column].mean() if column in subset.columns else 0)
        fig.add_trace(go.Bar(
            name=model,
            x=metric_labels,
            y=means,
            marker_color=palette.get(model, '#666'),
            text=[f'{v:.3f}' for v in means],
            textposition='outside'
        ))
    summary_text = f"Number of datasets: {len(selected)}"
    fig.update_layout(
        title=f"Average Model Performance - {title_suffix}<br>" +
              f"<sub>{summary_text}</sub>",
        xaxis_title="Metric",
        yaxis_title="Average Score",
        barmode='group',
        height=550,
        template='plotly_white',
        # Headroom above 1.0 so outside-positioned labels stay visible.
        yaxis=dict(range=[0, 1.1])
    )
    return fig
def build_app(df_metrics, df_compare):
    """Assemble the Gradio Blocks UI for exploring model results.

    Parameters:
        df_metrics: wide table indexed by dataset ID with
            "<model abbrev> <Metric>" score columns.
        df_compare: pairwise-comparison table; accepted but not used by any
            component below — TODO confirm whether a tab should consume it.

    Returns the assembled (not yet launched) gr.Blocks demo.
    """
    category_choices = ["All Categories"] + list(DATASET_CATEGORIES.keys())
    # Labels for the full, unfiltered dataset dropdown.
    all_formatted_choices = format_dataset_choices(df_metrics.index.tolist())
    with gr.Blocks() as demo:
        gr.Markdown(f"""
# Model Performance Comparison
""")
        with gr.Tabs():
            # --- Tab 1: head-to-head scatter of two models on one metric ---
            with gr.Tab("Model Comparison"):
                with gr.Row():
                    category_comp = gr.Dropdown(
                        choices=category_choices,
                        value="All Categories",
                        label="Dataset Category"
                    )
                    metric_comp = gr.Dropdown(
                        choices=["accuracy", "precision", "recall", "f1_score"],
                        value="accuracy",
                        label="Select Metric"
                    )
                with gr.Row():
                    model1 = gr.Dropdown(
                        choices=["DT", "RF", "SVM", "KNN", "LR"],
                        value="RF",
                        label="Model 1"
                    )
                    model2 = gr.Dropdown(
                        choices=["DT", "RF", "SVM", "KNN", "LR"],
                        value="DT",
                        label="Model 2"
                    )
                comp_plot = gr.Plot(label="Direct Comparison")
                comp_btn = gr.Button("Compare Models", variant="primary", size="lg")

                def compare_with_filter(category, metric, m1, m2):
                    # Restrict the table to the chosen category, then plot.
                    filtered_ids = filter_datasets_by_category(df_metrics, category)
                    df_filtered = df_metrics.loc[filtered_ids]
                    return create_comparison_plot(df_filtered, metric, m1, m2)

                comp_btn.click(
                    compare_with_filter,
                    inputs=[category_comp, metric_comp, model1, model2],
                    outputs=comp_plot
                )
            # --- Tab 2: per-dataset or per-category bar charts ---
            with gr.Tab("Dataset Performance"):
                with gr.Row():
                    view_type = gr.Radio(
                        choices=["Single Dataset", "Category Overview"],
                        value="Single Dataset",
                        label="View Type"
                    )
                with gr.Row():
                    category_dataset = gr.Dropdown(
                        choices=category_choices,
                        value="All Categories",
                        label="Dataset Category"
                    )
                    dataset_selector = gr.Dropdown(
                        choices=all_formatted_choices,
                        value=all_formatted_choices[0],
                        label="Select Dataset",
                        visible=True
                    )
                dataset_plot = gr.Plot(label="Dataset Performance")
                dataset_btn = gr.Button("Show Performance", variant="primary", size="lg")

                # The dataset dropdown is only meaningful in single-dataset
                # view; hide it for the category overview.
                def update_view_controls(view_type):
                    if view_type == "Single Dataset":
                        return gr.Dropdown(visible=True)
                    else:
                        return gr.Dropdown(visible=False)

                view_type.change(
                    update_view_controls,
                    inputs=view_type,
                    outputs=dataset_selector
                )

                # Re-populate the dataset dropdown whenever the category
                # filter changes, defaulting to its first entry.
                def update_dataset_choices(category):
                    filtered_ids = filter_datasets_by_category(df_metrics, category)
                    formatted = format_dataset_choices(filtered_ids, category)
                    return gr.Dropdown(choices=formatted, value=formatted[0] if formatted else None)

                category_dataset.change(
                    update_dataset_choices,
                    inputs=category_dataset,
                    outputs=dataset_selector
                )

                # Dispatch to the single-dataset or category-level chart.
                # NOTE(review): in single-dataset view `formatted_dataset`
                # may be None when a category has no datasets — confirm.
                def show_performance(view_type, category, formatted_dataset):
                    if view_type == "Single Dataset":
                        dataset_id = extract_dataset_id(formatted_dataset)
                        return create_dataset_performance(df_metrics, dataset_id)
                    else:
                        return create_category_performance(df_metrics, category)

                dataset_btn.click(
                    show_performance,
                    inputs=[view_type, category_dataset, dataset_selector],
                    outputs=dataset_plot
                )
        # Footer: legend for the model abbreviations used throughout.
        gr.Markdown("""
---
- **DT**: Decision Tree | **RF**: Random Forest | **SVM**: Support Vector Machine
- **KNN**: K-Nearest Neighbors | **LR**: Logistic Regression
""")
    return demo
# Module-level launch: this file is executed as a script (e.g. by a
# Hugging Face Space), so the UI is built and served on execution.
demo = build_app(df_metrics, df_compare)
demo.launch()