adejumobi committed on
Commit
6ef275d
·
verified ·
1 Parent(s): 2976edd

added content to my app

Files changed (1)
  1. app.py +478 -0
app.py ADDED
@@ -0,0 +1,478 @@
import gdown
import pickle
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from scipy.stats import ttest_rel
import seaborn as sns
import matplotlib.pyplot as plt
import gradio as gr

# Download the pickled models and held-out test splits from Google Drive
gdown.download(id="1_CzPJBkTMZ_xPnoHFAmzvxkioMnL99y7", output="all_models.pkl", quiet=False)
gdown.download(id="1dVQ0gF4tdv_-5yny2FbAXIY2ftzc8s-9", output="all_tests.pkl", quiet=False)

# Load pickles: all_models maps dataset ID -> {model name -> fitted estimator},
# all_tests maps dataset ID -> {"X_test": ..., "y_test": ...}
with open("all_models.pkl", "rb") as f:
    all_models = pickle.load(f)
with open("all_tests.pkl", "rb") as f:
    all_tests = pickle.load(f)

# Define model groups
TREE_MODELS = ["RandomForest", "DecisionTree"]
NON_TREE_MODELS = ["KNN", "SVM", "LogisticRegression"]
ALL_MODELS = TREE_MODELS + NON_TREE_MODELS
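# Note: only ALL_MODELS is referenced below; the tree vs. non-tree split is
# kept purely as documentation of the two model families.
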
# Dataset categorization
DATASET_CATEGORIES = {
    "Medical & Healthcare": {
        "D1": "Heart Disease (Comprehensive)",
        "D2": "Heart attack possibility",
        "D3": "Heart Disease Dataset",
        "D4": "Liver Disorders",
        "D5": "Diabetes Prediction",
        "D9": "Chronic Kidney Disease",
        "D10": "Breast Cancer Prediction",
        "D11": "Stroke Prediction",
        "D12": "Lung Cancer Prediction",
        "D13": "Hepatitis",
        "D15": "Thyroid Disease",
        "D16": "Heart Failure Prediction",
        "D17": "Parkinson's",
        "D18": "Indian Liver Patient",
        "D19": "COVID-19 Effect on Liver Cancer",
        "D20": "Liver Dataset",
        "D21": "Specht Heart",
        "D22": "Early-stage Diabetes",
        "D23": "Diabetic Retinopathy",
        "D24": "Breast Cancer Coimbra",
        "D25": "Chronic Kidney Disease",
        "D26": "Kidney Stone",
        "D28": "Echocardiogram",
        "D29": "Bladder Cancer Recurrence",
        "D31": "Prostate Cancer",
        "D46": "Real Breast Cancer Data",
        "D47": "Breast Cancer (Royston)",
        "D48": "Lung Cancer Dataset",
        "D52": "Cervical Cancer Risk",
        "D53": "Breast Cancer Wisconsin",
        "D61": "Breast Cancer Prediction",
        "D62": "Thyroid Disease",
        "D68": "Lung Cancer",
        "D69": "Cancer Patients Data",
        "D70": "Labor Relations",
        "D71": "Glioma Grading",
        "D74": "Post-Operative Patient",
        "D80": "Heart Rate Stress Monitoring",
        "D82": "Diabetes 2019",
        "D87": "Personal Heart Disease Indicators",
        "D92": "Heart Disease (Logistic)",
        "D95": "Diabetes Prediction",
        "D97": "Cardiovascular Disease",
        "D98": "Diabetes 130 US Hospitals",
        "D99": "Heart Disease Dataset",
        "D181": "HCV Data",
        "D184": "Cardiotocography",
        "D189": "Mammographic Mass",
        "D199": "Easiest Diabetes",
        "D200": "Monkey-Pox Patients",
        "D54": "Breast Cancer Wisconsin",
        "D63": "Sick-euthyroid",
        "D64": "Ann-test",
        "D65": "Ann-train",
        "D66": "Hypothyroid",
        "D67": "New-thyroid",
        "D72": "Glioma Grading",
    },

    "Gaming & Sports": {
        "D27": "Chess King-Rook",
        "D36": "Tic-Tac-Toe",
        "D40": "IPL 2022 Matches",
        "D41": "League of Legends",
        "D55": "League of Legends Diamond",
        "D56": "Chess Game Dataset",
        "D57": "Game of Thrones",
        "D73": "Connect-4",
        "D75": "FIFA 2018",
        "D76": "Dota 2 Matches",
        "D77": "IPL Match Analysis",
        "D78": "CS:GO Professional",
        "D79": "IPL 2008-2022",
        "D114": "Video Games",
        "D115": "Video Games Sales",
        "D117": "Sacred Games",
        "D118": "PC Games Sales",
        "D119": "Popular Video Games",
        "D120": "Olympic Games 2021",
        "D121": "Video Games ESRB",
        "D122": "Top Play Store Games",
        "D123": "Steam Games",
        "D124": "PS4 Games",
        "D116": "Video Games Sales",
    },

    "Education & Students": {
        "D43": "Student Marks",
        "D44": "Student 2nd Year Result",
        "D45": "Student Mat Pass/Fail",
        "D103": "Academic Performance",
        "D104": "Student Academic Analysis",
        "D105": "Student Dropout Prediction",
        "D106": "Electronic Gadgets Impact",
        "D107": "Campus Recruitment",
        "D108": "End-Semester Performance",
        "D109": "Fitbits and Grades",
        "D110": "Student Time Management",
        "D111": "Student Feedback",
        "D112": "Depression & Performance",
        "D113": "University Rankings",
        "D126": "University Ranking CWUR",
        "D127": "University Ranking CWUR 2013-2014",
        "D128": "University Ranking CWUR 2014-2015",
        "D129": "University Ranking CWUR 2015-2016",
        "D130": "University Ranking CWUR 2016-2017",
        "D131": "University Ranking CWUR 2017-2018",
        "D132": "University Ranking CWUR 2018-2019",
        "D133": "University Ranking CWUR 2019-2020",
        "D134": "University Ranking CWUR 2020-2021",
        "D135": "University Ranking CWUR 2021-2022",
        "D136": "University Ranking CWUR 2022-2023",
        "D137": "University Ranking GM 2016",
        "D138": "University Ranking GM 2017",
        "D139": "University Ranking GM 2018",
        "D140": "University Ranking GM 2019",
        "D141": "University Ranking GM 2020",
        "D142": "University Ranking GM 2021",
        "D143": "University Ranking GM 2022",
        "D144": "University Ranking Webometric 2012",
        "D145": "University Ranking Webometric 2013",
        "D146": "University Ranking Webometric 2014",
        "D147": "University Ranking Webometric 2015",
        "D148": "University Ranking Webometric 2016",
        "D149": "University Ranking Webometric 2017",
        "D150": "University Ranking Webometric 2018",
        "D151": "University Ranking Webometric 2019",
        "D152": "University Ranking Webometric 2020",
        "D153": "University Ranking Webometric 2021",
        "D154": "University Ranking Webometric 2022",
        "D155": "University Ranking Webometric 2023",
        "D156": "University Ranking URAP 2018-2019",
        "D157": "University Ranking URAP 2019-2020",
        "D158": "University Ranking URAP 2020-2021",
        "D159": "University Ranking URAP 2021-2022",
        "D160": "University Ranking URAP 2022-2023",
        "D161": "University Ranking THE 2011",
        "D162": "University Ranking THE 2012",
        "D163": "University Ranking THE 2013",
        "D164": "University Ranking THE 2014",
        "D165": "University Ranking THE 2015",
        "D166": "University Ranking THE 2016",
        "D167": "University Ranking THE 2017",
        "D168": "University Ranking THE 2018",
        "D169": "University Ranking THE 2019",
        "D170": "University Ranking THE 2020",
        "D171": "University Ranking THE 2021",
        "D172": "University Ranking THE 2022",
        "D173": "University Ranking THE 2023",
        "D174": "University Ranking QS 2022",
        "D190": "Student Academics Performance",
    },

    "Banking & Finance": {
        "D6": "Bank Marketing 1",
        "D7": "Bank Marketing 2",
        "D30": "Adult Income",
        "D32": "Telco Customer Churn",
        "D35": "Credit Approval",
        "D50": "Term Deposit Prediction",
        "D96": "Credit Card Fraud",
        "D188": "South German Credit",
        "D193": "Credit Risk Classification",
        "D195": "Credit Score Classification",
        "D196": "Banking Classification",
    },

    "Science & Engineering": {
        "D8": "Mushroom",
        "D14": "Ionosphere",
        "D33": "EEG Eye State",
        "D37": "Steel Plates Faults",
        "D39": "Fertility",
        "D51": "Darwin",
        "D58": "EEG Emotions",
        "D81": "Predictive Maintenance",
        "D84": "Oranges vs Grapefruit",
        "D90": "Crystal System Li-ion",
        "D183": "Drug Consumption",
        "D49": "Air Pressure System Failures",
        "D93": "Air Pressure System Failures",
        "D185": "Toxicity",
        "D186": "Toxicity",
    },

    "Social & Lifestyle": {
        "D38": "Online Shoppers",
        "D59": "Red Wine Quality",
        "D60": "White Wine Quality",
        "D88": "Airline Passenger Satisfaction",
        "D94": "Go Emotions Google",
        "D100": "Spotify East Asian",
        "D125": "Suicide Rates",
        "D182": "Obesity Levels",
        "D187": "Blood Transfusion",
        "D191": "Obesity Classification",
        "D192": "Gender Classification",
        "D194": "Happiness Classification",
        "D42": "Airline customer Holiday Booking dataset",
    },

    "ML Benchmarks & Synthetic": {
        "D34": "Spambase",
        "D85": "Synthetic Binary",
        "D89": "Naive Bayes Data",
        "D175": "Monk's Problems 1",
        "D176": "Monk's Problems 2",
        "D177": "Monk's Problems 3",
        "D178": "Monk's Problems 4",
        "D179": "Monk's Problems 5",
        "D180": "Monk's Problems 6",
    },

    "Other": {
        "D83": "Paris Housing",
        "D91": "Fake Bills",
        "D197": "Star Classification",
    }
}

def compute_metrics(datasets_list, selected_models, metric_for_comparison):
    """Compute metrics for each selected dataset/model pair and run pairwise paired t-tests."""

    # Expand the "All models" shortcut
    if "All models" in selected_models:
        selected_models = ALL_MODELS

    records = []

    # Compute metrics for each dataset-model combination
    for ds in datasets_list:
        if ds not in all_tests or ds not in all_models:
            continue

        X_test = all_tests[ds]["X_test"]
        y_test = all_tests[ds]["y_test"]

        for model_name in selected_models:
            if model_name not in all_models[ds]:
                continue

            model = all_models[ds][model_name]
            y_pred = model.predict(X_test)

            records.append({
                "dataset": ds,
                "model": model_name,
                "accuracy": accuracy_score(y_test, y_pred),
                "precision": precision_score(y_test, y_pred, average='weighted', zero_division=0),
                "recall": recall_score(y_test, y_pred, average='weighted', zero_division=0),
                "f1_score": f1_score(y_test, y_pred, average='weighted', zero_division=0)
            })

    df = pd.DataFrame(records)

    if df.empty:
        return df, pd.DataFrame(), None

    # Statistical comparisons: paired t-test across datasets for every model pair
    stat_records = []
    models_list = df['model'].unique().tolist()

    for i, m1 in enumerate(models_list):
        for m2 in models_list[i+1:]:
            m1_vals = df[df['model'] == m1].set_index('dataset')[metric_for_comparison]
            m2_vals = df[df['model'] == m2].set_index('dataset')[metric_for_comparison]

            # Align the two models on the datasets where both have scores
            combined = pd.concat([m1_vals, m2_vals], axis=1, keys=['m1', 'm2']).dropna()

            if len(combined) < 2:
                continue

            t_stat, p_val = ttest_rel(combined['m1'], combined['m2'])

            stat_records.append({
                "model1": m1,
                "model2": m2,
                "mean_diff": combined['m1'].mean() - combined['m2'].mean(),
                "t_stat": t_stat,
                "p_value": p_val,
                "significant": "Yes" if p_val < 0.05 else "No"
            })

    stat_df = pd.DataFrame(stat_records)

    # Heatmap of the comparison metric
    fig = create_heatmap(df, metric_for_comparison)

    return df, stat_df, fig
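
# A note on the paired test above, with illustrative (invented) numbers:
#
#     ttest_rel([0.91, 0.88, 0.95], [0.89, 0.85, 0.94])
#
# pairs the scores dataset-by-dataset and tests whether the mean of the
# per-dataset differences is zero, which is why both models must share at
# least two datasets before the test is run.
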
def create_heatmap(df, metric):
    """Create a dataset-by-model heatmap for the given metric."""

    pivot = df.pivot_table(values=metric, index='dataset', columns='model')

    # Scale the figure height with the number of datasets
    fig, ax = plt.subplots(figsize=(12, max(8, len(pivot) * 0.4)))
    sns.heatmap(pivot, annot=True, fmt='.3f', cmap='viridis', ax=ax,
                cbar_kws={'label': metric.capitalize()})
    ax.set_title(f'{metric.capitalize()} by Dataset and Model', fontsize=14, fontweight='bold')
    ax.set_xlabel('Model', fontsize=12)
    ax.set_ylabel('Dataset', fontsize=12)

    plt.tight_layout()
    return fig

def run_evaluation(selected_datasets, selected_models, metric_comparison):
    """Main evaluation function wired to the Run Evaluation button."""

    if not selected_datasets:
        empty = gr.update(value=None, visible=False)
        return "Please select datasets", empty, empty, empty, empty

    if not selected_models:
        selected_models = ["All models"]

    # A single selection may arrive as a plain string; normalize to a list
    if isinstance(metric_comparison, str):
        metric_comparison = [metric_comparison]

    if not metric_comparison:
        empty = gr.update(value=None, visible=False)
        return "Please select at least one metric", empty, empty, empty, empty

    # Compute metrics once to verify there are any results at all
    df, _, _ = compute_metrics(selected_datasets, selected_models, metric_comparison[0])

    if df.empty:
        empty = gr.update(value=None, visible=False)
        return "No results found", empty, empty, empty, empty

    # Build the stats table and a heatmap for EACH selected metric (max four plot slots)
    all_stats_html = ""
    outputs = []

    for i, metric in enumerate(metric_comparison):
        if i >= 4:
            break

        _, stat_df, fig = compute_metrics(selected_datasets, selected_models, metric)

        if not stat_df.empty:
            stats_html = f"""
            <h3>Statistical Tests ({metric})</h3>
            <p>Paired t-tests comparing model performance ("significant" = Yes when p < 0.05)</p>
            {stat_df.to_html(index=False, float_format='%.4f')}
            <hr>
            """
            all_stats_html += stats_html

        outputs.append(gr.update(value=fig, visible=True))

    # Fill the remaining plot slots with hidden empty plots
    while len(outputs) < 4:
        outputs.append(gr.update(value=None, visible=False))

    if not all_stats_html:
        all_stats_html = "<p>Not enough data for statistical comparisons</p>"

    return all_stats_html, outputs[0], outputs[1], outputs[2], outputs[3]
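
# run_evaluation returns gr.update(...) objects rather than raw figures so that
# unused plot slots stay hidden: gr.update(value=None, visible=False) hides a
# slot, while gr.update(value=fig, visible=True) reveals it with a new figure.
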
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # Model Evaluation Platform
    ### Compare model performance across different datasets
    """)

    selected_datasets = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Select Datasets")

            # Only list datasets for which trained models are available
            available = list(all_models.keys())

            # One multiselect dropdown per dataset category
            dropdowns = []
            for category, datasets in DATASET_CATEGORIES.items():
                choices = [f"{did}: {name}" for did, name in datasets.items() if did in available]
                if choices:
                    dd = gr.Dropdown(
                        choices=choices,
                        label=f"{category} ({len(choices)})",
                        multiselect=True,
                        value=[]
                    )
                    dropdowns.append(dd)

        with gr.Column(scale=1):
            gr.Markdown("### Evaluation Settings")

            summary = gr.Markdown("**No datasets selected**")

            model_input = gr.Dropdown(
                choices=["All models"] + ALL_MODELS,
                label="Models",
                value=["All models"],
                multiselect=True
            )

            metric_comparison = gr.Dropdown(
                choices=["accuracy", "precision", "recall", "f1_score"],
                label="Metrics (up to 4 plotted)",
                value=["accuracy"],
                multiselect=True
            )

            run_btn = gr.Button("Run Evaluation", variant="primary", size="lg")

    def update_selection(*dropdown_values):
        # Collect dataset IDs ("D12: Lung Cancer Prediction" -> "D12") from every category dropdown
        ids = []
        for vals in dropdown_values:
            if vals:
                ids.extend([v.split(":")[0] for v in vals])
        # Sort numerically so that, e.g., D10 follows D9 rather than D1
        ids = sorted(set(ids), key=lambda d: int(d[1:]))

        if ids:
            summary_text = f"**✓ {len(ids)} dataset{'s' if len(ids) != 1 else ''} selected:** {', '.join(ids)}"
        else:
            summary_text = "**No datasets selected**"

        return summary_text, ids

    for dd in dropdowns:
        dd.change(update_selection, inputs=dropdowns, outputs=[summary, selected_datasets])

    gr.Markdown("---")
    gr.Markdown("## Evaluation Results")

    output_stats = gr.HTML(label="Statistical Tests")
    # Up to four heatmaps, one per selected metric
    with gr.Column():
        heatmap_output_1 = gr.Plot(label="Heatmap 1")
        heatmap_output_2 = gr.Plot(label="Heatmap 2")
        heatmap_output_3 = gr.Plot(label="Heatmap 3")
        heatmap_output_4 = gr.Plot(label="Heatmap 4")

    run_btn.click(
        run_evaluation,
        inputs=[selected_datasets, model_input, metric_comparison],
        outputs=[
            output_stats,
            heatmap_output_1, heatmap_output_2, heatmap_output_3, heatmap_output_4
        ]
    )

if __name__ == "__main__":
    demo.launch()
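
# For a quick local sanity check outside the UI, something like this could be
# run once the pickles have downloaded (the dataset IDs here are illustrative
# and must be keys in both all_models and all_tests):
#
#     df, stat_df, fig = compute_metrics(["D1", "D3"], ["RandomForest", "KNN"], "accuracy")
#     print(df)
#     print(stat_df)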