# app.py — PLOS Gradio app: interactive comparison of classifier performance
# (DT / RF / SVM / KNN / LR) across the PLOS benchmark datasets.
import pandas as pd
import numpy as np

# --- Load per-dataset model results and reshape into one row per dataset ---
df = pd.read_csv('PLOS_model_results.csv')
# Drop the stray CSV index column if present; errors='ignore' keeps this
# robust to a clean export that has no such column.
df = df.drop(columns=['Unnamed: 0'], errors='ignore')

# Wide table: row index = dataset id, one column per (metric, model) pair.
df_metrics = df.pivot_table(index='dataset', columns='model',
                            values=['accuracy', 'precision', 'recall', 'f1_score'])

# Impose a deterministic column order: every model for each metric, in order.
metrics = ['accuracy', 'precision', 'recall', 'f1_score']
models = ['DecisionTree', 'RandomForest', 'SVM', 'KNN', 'LogisticRegression']
cols_order = [(metric, model) for metric in metrics for model in models]
df_metrics = df_metrics[cols_order]

# Flatten the MultiIndex into readable "<MODEL> <Metric>" labels,
# e.g. ("f1_score", "RandomForest") -> "RF F1 Score".
new_cols = []
for metric, model in df_metrics.columns:
    model_short = (model.replace("DecisionTree", "DT")
                        .replace("RandomForest", "RF")
                        .replace("LogisticRegression", "LR"))
    metric_name = "F1 Score" if metric == "f1_score" else metric.capitalize()
    new_cols.append(f"{model_short} {metric_name}")
df_metrics.columns = new_cols

# Sort rows numerically by dataset id ("D1", "D2", ..., "D10"), not lexically.
df_metrics = df_metrics.sort_index(key=lambda x: x.str[1:].astype(int))

# Pairwise model-comparison results, loaded for downstream use.
df_compare = pd.read_csv("pairwise_comparison_results.csv")
# Catalogue of benchmark datasets: maps dataset id ("D<n>") -> human-readable
# name, grouped by domain category. Consumed by get_dataset_info(),
# filter_datasets_by_category() and the dropdown label builders.
# NOTE(review): several names repeat under different ids (e.g. D9/D25
# "Chronic Kidney Disease", D53/D54 "Breast Cancer Wisconsin", D15/D62
# "Thyroid Disease") — presumably distinct source files; confirm upstream.
# NOTE(review): ids within a category are not fully sorted (D54, D63-D67, D72
# appended after D200 below) — ordering only affects dropdown listing.
DATASET_CATEGORIES = {
"Medical & Healthcare": {
"D1": "Heart Disease (Comprehensive)",
"D2": "Heart attack possibility",
"D3": "Heart Disease Dataset",
"D4": "Liver Disorders",
"D5": "Diabetes Prediction",
"D9": "Chronic Kidney Disease",
"D10": "Breast Cancer Prediction",
"D11": "Stroke Prediction",
"D12": "Lung Cancer Prediction",
"D13": "Hepatitis",
"D15": "Thyroid Disease",
"D16": "Heart Failure Prediction",
"D17": "Parkinson's",
"D18": "Indian Liver Patient",
"D19": "COVID-19 Effect on Liver Cancer",
"D20": "Liver Dataset",
"D21": "Specht Heart",
"D22": "Early-stage Diabetes",
"D23": "Diabetic Retinopathy",
"D24": "Breast Cancer Coimbra",
"D25": "Chronic Kidney Disease",
"D26": "Kidney Stone",
"D28": "Echocardiogram",
"D29": "Bladder Cancer Recurrence",
"D31": "Prostate Cancer",
"D46": "Real Breast Cancer Data",
"D47": "Breast Cancer (Royston)",
"D48": "Lung Cancer Dataset",
"D52": "Cervical Cancer Risk",
"D53": "Breast Cancer Wisconsin",
"D61": "Breast Cancer Prediction",
"D62": "Thyroid Disease",
"D68": "Lung Cancer",
"D69": "Cancer Patients Data",
# NOTE(review): "Labor Relations" looks misfiled under Medical & Healthcare —
# verify the intended category for D70.
"D70": "Labor Relations",
"D71": "Glioma Grading",
"D74": "Post-Operative Patient",
"D80": "Heart Rate Stress Monitoring",
"D82": "Diabetes 2019",
"D87": "Personal Heart Disease Indicators",
"D92": "Heart Disease (Logistic)",
"D95": "Diabetes Prediction",
"D97": "Cardiovascular Disease",
"D98": "Diabetes 130 US Hospitals",
"D99": "Heart Disease Dataset",
"D181": "HCV Data",
"D184": "Cardiotocography",
"D189": "Mammographic Mass",
"D199": "Easiest Diabetes",
"D200": "Monkey-Pox Patients",
"D54": "Breast Cancer Wisconsin",
"D63": "Sick-euthyroid",
"D64": "Ann-test",
"D65": "Ann-train",
"D66": "Hypothyroid",
"D67": "New-thyroid",
"D72": "Glioma Grading",
},
"Gaming & Sports": {
"D27": "Chess King-Rook",
"D36": "Tic-Tac-Toe",
"D40": "IPL 2022 Matches",
"D41": "League of Legends",
"D55": "League of Legends Diamond",
"D56": "Chess Game Dataset",
"D57": "Game of Thrones",
"D73": "Connect-4",
"D75": "FIFA 2018",
"D76": "Dota 2 Matches",
"D77": "IPL Match Analysis",
"D78": "CS:GO Professional",
"D79": "IPL 2008-2022",
"D114": "Video Games",
"D115": "Video Games Sales",
"D117": "Sacred Games",
"D118": "PC Games Sales",
"D119": "Popular Video Games",
"D120": "Olympic Games 2021",
"D121": "Video Games ESRB",
"D122": "Top Play Store Games",
"D123": "Steam Games",
"D124": "PS4 Games",
"D116": "Video Games Sales",
},
"Education & Students": {
"D43": "Student Marks",
"D44": "Student 2nd Year Result",
"D45": "Student Mat Pass/Fail",
"D103": "Academic Performance",
"D104": "Student Academic Analysis",
"D105": "Student Dropout Prediction",
"D106": "Electronic Gadgets Impact",
"D107": "Campus Recruitment",
"D108": "End-Semester Performance",
"D109": "Fitbits and Grades",
"D110": "Student Time Management",
"D111": "Student Feedback",
"D112": "Depression & Performance",
"D113": "University Rankings",
# University ranking snapshots: one dataset per ranking system and year
# (CWUR, GM, Webometric, URAP, THE, QS).
"D126": "University Ranking CWUR",
"D127": "University Ranking CWUR 2013-2014",
"D128": "University Ranking CWUR 2014-2015",
"D129": "University Ranking CWUR 2015-2016",
"D130": "University Ranking CWUR 2016-2017",
"D131": "University Ranking CWUR 2017-2018",
"D132": "University Ranking CWUR 2018-2019",
"D133": "University Ranking CWUR 2019-2020",
"D134": "University Ranking CWUR 2020-2021",
"D135": "University Ranking CWUR 2021-2022",
"D136": "University Ranking CWUR 2022-2023",
"D137": "University Ranking GM 2016",
"D138": "University Ranking GM 2017",
"D139": "University Ranking GM 2018",
"D140": "University Ranking GM 2019",
"D141": "University Ranking GM 2020",
"D142": "University Ranking GM 2021",
"D143": "University Ranking GM 2022",
"D144": "University Ranking Webometric 2012",
"D145": "University Ranking Webometric 2013",
"D146": "University Ranking Webometric 2014",
"D147": "University Ranking Webometric 2015",
"D148": "University Ranking Webometric 2016",
"D149": "University Ranking Webometric 2017",
"D150": "University Ranking Webometric 2018",
"D151": "University Ranking Webometric 2019",
"D152": "University Ranking Webometric 2020",
"D153": "University Ranking Webometric 2021",
"D154": "University Ranking Webometric 2022",
"D155": "University Ranking Webometric 2023",
"D156": "University Ranking URAP 2018-2019",
"D157": "University Ranking URAP 2019-2020",
"D158": "University Ranking URAP 2020-2021",
"D159": "University Ranking URAP 2021-2022",
"D160": "University Ranking URAP 2022-2023",
"D161": "University Ranking THE 2011",
"D162": "University Ranking THE 2012",
"D163": "University Ranking THE 2013",
"D164": "University Ranking THE 2014",
"D165": "University Ranking THE 2015",
"D166": "University Ranking THE 2016",
"D167": "University Ranking THE 2017",
"D168": "University Ranking THE 2018",
"D169": "University Ranking THE 2019",
"D170": "University Ranking THE 2020",
"D171": "University Ranking THE 2021",
"D172": "University Ranking THE 2022",
"D173": "University Ranking THE 2023",
"D174": "University Ranking QS 2022",
"D190": "Student Academics Performance"
},
"Banking & Finance": {
"D6": "Bank Marketing 1",
"D7": "Bank Marketing 2",
"D30": "Adult Income",
"D32": "Telco Customer Churn",
"D35": "Credit Approval",
"D50": "Term Deposit Prediction",
"D96": "Credit Card Fraud",
"D188": "South German Credit",
"D193": "Credit Risk Classification",
"D195": "Credit Score Classification",
"D196": "Banking Classification"
},
"Science & Engineering": {
"D8": "Mushroom",
"D14": "Ionosphere",
"D33": "EEG Eye State",
"D37": "Steel Plates Faults",
"D39": "Fertility",
"D51": "Darwin",
"D58": "EEG Emotions",
"D81": "Predictive Maintenance",
"D84": "Oranges vs Grapefruit",
"D90": "Crystal System Li-ion",
"D183": "Drug Consumption",
"D49": "Air Pressure System Failures",
"D93": "Air Pressure System Failures",
"D185": "Toxicity",
"D186": "Toxicity",
},
"Social & Lifestyle": {
"D38": "Online Shoppers",
"D59": "Red Wine Quality",
"D60": "White Wine Quality",
"D88": "Airline Passenger Satisfaction",
"D94": "Go Emotions Google",
"D100": "Spotify East Asian",
"D125": "Suicide Rates",
"D182": "Obesity Levels",
"D187": "Blood Transfusion",
"D191": "Obesity Classification",
"D192": "Gender Classification",
"D194": "Happiness Classification",
"D42": "Airline customer Holiday Booking dataset"
},
"ML Benchmarks & Synthetic": {
"D34": "Spambase",
"D85": "Synthetic Binary",
"D89": "Naive Bayes Data",
"D175": "Monk's Problems 1",
"D176": "Monk's Problems 2",
"D177": "Monk's Problems 3",
"D178": "Monk's Problems 4",
"D179": "Monk's Problems 5",
"D180": "Monk's Problems 6"
},
# Catch-all for datasets that fit no category above.
"Other": {
"D83": "Paris Housing",
"D91": "Fake Bills",
"D197": "Star Classification"
}
}
import gradio as gr
import pandas as pd
import plotly.graph_objects as go
import numpy as np
def get_dataset_info(dataset_id):
    """Look up the (category, human-readable name) pair for a dataset id.

    Scans DATASET_CATEGORIES and returns ("Unknown", "Unknown Dataset")
    when the id is not catalogued.
    """
    for category_name, members in DATASET_CATEGORIES.items():
        name = members.get(dataset_id)
        if name is not None:
            return category_name, name
    return "Unknown", "Unknown Dataset"
def filter_datasets_by_category(df_metrics, category="All Categories"):
    """Return the dataset ids from *df_metrics* that belong to *category*.

    "All Categories" yields every row index; an unrecognised category
    yields an empty list. Order follows the catalogue, restricted to ids
    actually present in the metrics table.
    """
    if category == "All Categories":
        return list(df_metrics.index)
    members = DATASET_CATEGORIES.get(category)
    if members is None:
        return []
    present = set(df_metrics.index)
    return [did for did in members if did in present]
def format_dataset_choices(dataset_ids, filter_category="All Categories"):
    """Build dropdown labels for the given dataset ids.

    With no active category filter the label embeds the category
    ("D1 - Category | Name"); under a filter the category is implied, so
    the label is just "D1 - Name".
    """
    show_category = filter_category == "All Categories"
    labels = []
    for did in dataset_ids:
        category, name = get_dataset_info(did)
        suffix = f"{category} | {name}" if show_category else name
        labels.append(f"{did} - {suffix}")
    return labels
def extract_dataset_id(formatted_string):
    """Recover the dataset id: the text before the first " - " in a dropdown label."""
    dataset_id, _, _ = formatted_string.partition(" - ")
    return dataset_id
def create_comparison_plot(df_metrics, metric, model1, model2):
    """Scatter plot of model1 vs model2 scores on one metric, one point per dataset.

    Points are coloured by the per-dataset score difference (model1 - model2)
    on a symmetric red-green scale; the dashed diagonal marks equal
    performance. Returns None when either model's column is missing.

    Parameters
    ----------
    df_metrics : pandas.DataFrame
        Row index = dataset ids; columns like "RF Accuracy" / "DT F1 Score".
    metric : str
        One of "accuracy", "precision", "recall", "f1_score".
    model1, model2 : str
        Short model names ("DT", "RF", "SVM", "KNN", "LR").
    """
    # Translate the raw metric key into the column suffix used by df_metrics.
    # BUGFIX: plain metric.capitalize() turned "f1_score" into "F1_score",
    # which matches no column, so every F1-score comparison silently
    # returned None. Use the same mapping applied when columns were named.
    metric_label = "F1 Score" if metric == "f1_score" else metric.capitalize()
    col1 = f"{model1} {metric_label}"
    col2 = f"{model2} {metric_label}"
    if col1 not in df_metrics.columns or col2 not in df_metrics.columns:
        return None

    # Per-dataset score difference drives marker colour and win/tie counts.
    diff = df_metrics[col1] - df_metrics[col2]

    # Hover labels: dataset id plus category and human-readable name.
    hover_texts = []
    for dataset_id in df_metrics.index:
        category, name = get_dataset_info(dataset_id)
        hover_texts.append(f"<b>{dataset_id}</b><br>{category} | {name}")

    # Symmetric colour range centred at 0 so equal scores sit mid-scale.
    max_abs_diff = max(abs(diff.min()), abs(diff.max()))

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=df_metrics[col1],
        y=df_metrics[col2],
        mode='markers',
        marker=dict(
            size=8,
            color=diff,
            colorscale='RdYlGn',
            showscale=True,
            colorbar=dict(title=f"{model1} - {model2}"),
            line=dict(width=0.5, color='white'),
            cmin=-max_abs_diff,
            cmax=max_abs_diff
        ),
        text=hover_texts,
        hovertemplate='%{text}<br><br>' +
                      f'{model1}: %{{x:.3f}}<br>' +
                      f'{model2}: %{{y:.3f}}<br>' +
                      'Difference: %{marker.color:.3f}<br>'
    ))

    # Dashed y = x reference line: points below favour model1, above model2.
    min_val = min(df_metrics[col1].min(), df_metrics[col2].min())
    max_val = max(df_metrics[col1].max(), df_metrics[col2].max())
    fig.add_trace(go.Scatter(
        x=[min_val, max_val],
        y=[min_val, max_val],
        mode='lines',
        line=dict(dash='dash', color='gray', width=2),
        showlegend=False,
        hoverinfo='skip'
    ))

    # Win/tie counts for the subtitle.
    wins_model1 = (diff > 0).sum()
    wins_model2 = (diff < 0).sum()
    ties = (diff == 0).sum()

    fig.update_layout(
        title=f"{model1} vs {model2} - {metric_label}<br>" +
              f"<sub>{model1} wins: {wins_model1} | {model2} wins: {wins_model2} | Ties: {ties}</sub>",
        xaxis_title=f"{model1} {metric_label}",
        yaxis_title=f"{model2} {metric_label}",
        height=550,
        template='plotly_white'
    )
    return fig
def create_dataset_performance(df_metrics, dataset_name):
    """Grouped bar chart of every model's metrics for a single dataset.

    Parameters
    ----------
    df_metrics : pandas.DataFrame
        Row index = dataset ids; columns like "RF Accuracy" / "DT F1 Score".
    dataset_name : str
        Dataset id (e.g. "D1"); returns None when not present in the table.
    """
    if dataset_name not in df_metrics.index:
        return None
    dataset_row = df_metrics.loc[dataset_name]
    metrics = ['Accuracy', 'Precision', 'Recall', 'F1 Score']
    models = ['DT', 'RF', 'SVM', 'KNN', 'LR']
    # One fixed colour per model, shared with create_category_performance.
    colors = {'DT': '#8b5cf6', 'RF': '#10b981', 'SVM': '#f59e0b',
              'KNN': '#ef4444', 'LR': '#3b82f6'}
    fig = go.Figure()
    for model in models:
        # Missing columns contribute a 0-height bar instead of raising.
        values = [dataset_row[f"{model} {m}"] if f"{model} {m}" in dataset_row.index else 0
                  for m in metrics]
        fig.add_trace(go.Bar(
            name=model,
            x=metrics,
            y=values,
            marker_color=colors.get(model, '#666'),
            # CONSISTENCY FIX: fixed 3-decimal labels, matching
            # create_category_performance (previously f'{round(v, 3)}'
            # dropped trailing zeros, e.g. "0.9" vs "0.900").
            text=[f'{v:.3f}' for v in values],
            textposition='outside'
        ))
    # Category and readable name for the subtitle.
    category, name = get_dataset_info(dataset_name)
    fig.update_layout(
        title=f"Performance Metrics for {dataset_name}<br><sub>{category} | {name}</sub>",
        xaxis_title="Metric",
        yaxis_title="Score",
        barmode='group',
        height=500,
        template='plotly_white',
        # Headroom above 1.0 so outside-positioned labels stay visible.
        yaxis=dict(range=[0, 1.1])
    )
    return fig
def create_category_performance(df_metrics, category):
    """Grouped bar chart of each model's average metrics over a dataset category.

    "All Categories" averages over every dataset; an empty selection
    returns None.
    """
    if category == "All Categories":
        dataset_ids = list(df_metrics.index)
        title_suffix = "All Datasets"
    else:
        dataset_ids = filter_datasets_by_category(df_metrics, category)
        title_suffix = category
    if not dataset_ids:
        return None

    df_filtered = df_metrics.loc[dataset_ids]
    models = ['DT', 'RF', 'SVM', 'KNN', 'LR']
    metrics = ['Accuracy', 'Precision', 'Recall', 'F1 Score']
    # One fixed colour per model, shared with create_dataset_performance.
    colors = {'DT': '#8b5cf6', 'RF': '#10b981', 'SVM': '#f59e0b',
              'KNN': '#ef4444', 'LR': '#3b82f6'}

    fig = go.Figure()
    for model in models:
        # Mean score per metric across the category; absent columns
        # contribute a 0 bar instead of raising.
        values = [
            df_filtered[f"{model} {m}"].mean() if f"{model} {m}" in df_filtered.columns else 0
            for m in metrics
        ]
        fig.add_trace(go.Bar(
            name=model,
            x=metrics,
            y=values,
            marker_color=colors.get(model, '#666'),
            text=[f'{v:.3f}' for v in values],
            textposition='outside'
        ))

    summary_text = f"Number of datasets: {len(dataset_ids)}"
    fig.update_layout(
        title=f"Average Model Performance - {title_suffix}<br>" +
              f"<sub>{summary_text}</sub>",
        xaxis_title="Metric",
        yaxis_title="Average Score",
        barmode='group',
        height=550,
        template='plotly_white',
        yaxis=dict(range=[0, 1.1])
    )
    return fig
def build_app(df_metrics, df_compare):
    """Assemble the Gradio Blocks UI with two tabs.

    Tab 1 compares two models head-to-head on one metric; tab 2 shows
    per-dataset or per-category performance bars.

    Parameters
    ----------
    df_metrics : pandas.DataFrame
        Row index = dataset ids; columns like "RF Accuracy" / "DT F1 Score".
    df_compare : pandas.DataFrame
        Pairwise comparison results. Currently unused by the UI; the
        parameter is kept for interface compatibility.

    Returns
    -------
    gr.Blocks
        The assembled (not yet launched) demo.
    """
    category_choices = ["All Categories"] + list(DATASET_CATEGORIES.keys())
    all_formatted_choices = format_dataset_choices(df_metrics.index.tolist())

    with gr.Blocks() as demo:
        gr.Markdown("""
# Model Performance Comparison
""")
        with gr.Tabs():
            with gr.Tab("Model Comparison"):
                with gr.Row():
                    category_comp = gr.Dropdown(
                        choices=category_choices,
                        value="All Categories",
                        label="Dataset Category"
                    )
                    metric_comp = gr.Dropdown(
                        choices=["accuracy", "precision", "recall", "f1_score"],
                        value="accuracy",
                        label="Select Metric"
                    )
                with gr.Row():
                    model1 = gr.Dropdown(
                        choices=["DT", "RF", "SVM", "KNN", "LR"],
                        value="RF",
                        label="Model 1"
                    )
                    model2 = gr.Dropdown(
                        choices=["DT", "RF", "SVM", "KNN", "LR"],
                        value="DT",
                        label="Model 2"
                    )
                comp_plot = gr.Plot(label="Direct Comparison")
                comp_btn = gr.Button("Compare Models", variant="primary", size="lg")

                def compare_with_filter(category, metric, m1, m2):
                    # Restrict to the chosen category before plotting.
                    filtered_ids = filter_datasets_by_category(df_metrics, category)
                    df_filtered = df_metrics.loc[filtered_ids]
                    return create_comparison_plot(df_filtered, metric, m1, m2)

                comp_btn.click(
                    compare_with_filter,
                    inputs=[category_comp, metric_comp, model1, model2],
                    outputs=comp_plot
                )

            with gr.Tab("Dataset Performance"):
                with gr.Row():
                    view_type = gr.Radio(
                        choices=["Single Dataset", "Category Overview"],
                        value="Single Dataset",
                        label="View Type"
                    )
                with gr.Row():
                    category_dataset = gr.Dropdown(
                        choices=category_choices,
                        value="All Categories",
                        label="Dataset Category"
                    )
                    dataset_selector = gr.Dropdown(
                        choices=all_formatted_choices,
                        # ROBUSTNESS FIX: indexing [0] crashed app startup
                        # when the metrics table was empty.
                        value=all_formatted_choices[0] if all_formatted_choices else None,
                        label="Select Dataset",
                        visible=True
                    )
                dataset_plot = gr.Plot(label="Dataset Performance")
                dataset_btn = gr.Button("Show Performance", variant="primary", size="lg")

                def update_view_controls(view_type):
                    # The single-dataset selector is irrelevant in category view.
                    if view_type == "Single Dataset":
                        return gr.Dropdown(visible=True)
                    else:
                        return gr.Dropdown(visible=False)

                view_type.change(
                    update_view_controls,
                    inputs=view_type,
                    outputs=dataset_selector
                )

                def update_dataset_choices(category):
                    # Re-populate the selector whenever the category changes.
                    filtered_ids = filter_datasets_by_category(df_metrics, category)
                    formatted = format_dataset_choices(filtered_ids, category)
                    return gr.Dropdown(choices=formatted, value=formatted[0] if formatted else None)

                category_dataset.change(
                    update_dataset_choices,
                    inputs=category_dataset,
                    outputs=dataset_selector
                )

                def show_performance(view_type, category, formatted_dataset):
                    # Dispatch to the single-dataset or category-average chart.
                    if view_type == "Single Dataset":
                        dataset_id = extract_dataset_id(formatted_dataset)
                        return create_dataset_performance(df_metrics, dataset_id)
                    else:
                        return create_category_performance(df_metrics, category)

                dataset_btn.click(
                    show_performance,
                    inputs=[view_type, category_dataset, dataset_selector],
                    outputs=dataset_plot
                )

        gr.Markdown("""
---
- **DT**: Decision Tree | **RF**: Random Forest | **SVM**: Support Vector Machine
- **KNN**: K-Nearest Neighbors | **LR**: Logistic Regression
""")
    return demo
# Build the UI at import time (Hugging Face Spaces discovers the module-level
# `demo` object), but only start a server when executed directly.
demo = build_app(df_metrics, df_compare)

if __name__ == "__main__":
    demo.launch()