firobeid commited on
Commit
5019076
Β·
verified Β·
1 Parent(s): a0f7a5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +402 -136
app.py CHANGED
@@ -1,10 +1,8 @@
1
  import gradio as gr
 
2
  import plotly.graph_objects as go
3
  import plotly.express as px
4
- import pandas as pd
5
  from plotly.subplots import make_subplots
6
- import numpy as np
7
- import io
8
 
9
  # Default sample data (will be replaced when CSV is uploaded)
10
  default_data = pd.DataFrame({
@@ -59,7 +57,7 @@ default_data = pd.DataFrame({
59
  def load_csv_data(file):
60
  """Load and validate CSV data"""
61
  if file is None:
62
- return default_data, "Using default sample data"
63
 
64
  try:
65
  df = pd.read_csv(file.name)
@@ -70,18 +68,21 @@ def load_csv_data(file):
70
  missing_cols = [col for col in required_cols if col not in df.columns]
71
 
72
  if missing_cols:
73
- return default_data, f"❌ Missing columns: {missing_cols}. Using default data."
74
 
75
  # Clean data
76
  df = df.dropna()
 
 
 
77
 
78
- return df, f"βœ… Successfully loaded {len(df)} records with {df['model'].nunique()} models"
79
 
80
  except Exception as e:
81
  return default_data, f"❌ Error loading CSV: {str(e)}. Using default data."
82
 
83
  def create_model_leaderboard(df, partition_filter='all', topic_filter='OVERALL'):
84
- """Create leaderboard comparing all models"""
85
  filtered_df = df.copy()
86
 
87
  if partition_filter != 'all':
@@ -96,47 +97,65 @@ def create_model_leaderboard(df, partition_filter='all', topic_filter='OVERALL')
96
 
97
  # Calculate overall score (average of key metrics)
98
  leaderboard['Overall_Score'] = leaderboard[['Precision', 'Recall_Power', 'Accuracy']].mean(axis=1)
99
- leaderboard = leaderboard.sort_values('Overall_Score', ascending=False)
100
-
101
- # Create subplot for each metric
102
  fig = make_subplots(
103
- rows=1, cols=len(metrics) + 1,
104
- subplot_titles=metrics + ['Overall Score']
 
105
  )
106
 
107
- colors = px.colors.qualitative.Set3[:len(leaderboard)]
 
 
 
108
 
109
  for i, metric in enumerate(metrics + ['Overall_Score']):
110
- for j, (_, row) in enumerate(leaderboard.iterrows()):
111
- fig.add_trace(
112
- go.Bar(
113
- x=[row['model']],
114
- y=[row[metric]],
115
- name=row['model'] if i == 0 else "",
116
- marker_color=colors[j],
117
- showlegend=True if i == 0 else False,
118
- text=f"{row[metric]:.3f}",
119
- textposition="outside"
 
 
120
  ),
121
- row=1, col=i+1
122
- )
 
 
 
 
 
123
 
124
  fig.update_layout(
125
- title=f"Model Leaderboard - {partition_filter.title()} | {topic_filter}",
126
- height=500,
127
- showlegend=True
 
 
 
 
 
 
128
  )
129
 
130
- # Update y-axes
131
  for i in range(1, len(metrics) + 2):
132
- fig.update_yaxes(range=[0, 1], row=1, col=i)
 
 
133
 
134
  return fig
135
 
136
  def create_topic_comparison(df, models_selected=None, metric='Accuracy', partition_filter='all'):
137
- """Compare selected models across topics"""
138
  if models_selected is None or len(models_selected) == 0:
139
- models_selected = df['model'].unique()[:3] # Default to first 3 models
140
 
141
  # Filter data
142
  filtered_df = df[df['model'].isin(models_selected)].copy()
@@ -150,34 +169,77 @@ def create_topic_comparison(df, models_selected=None, metric='Accuracy', partiti
150
  # Create grouped bar chart
151
  fig = go.Figure()
152
 
153
- colors = px.colors.qualitative.Set3[:len(models_selected)]
 
 
 
 
154
  topics = sorted(topic_performance['topic'].unique())
155
 
156
  for i, model in enumerate(models_selected):
157
  model_data = topic_performance[topic_performance['model'] == model]
 
 
 
 
 
 
158
  fig.add_trace(go.Bar(
159
- name=model,
160
  x=topics,
161
  y=model_data[metric],
162
- marker_color=colors[i],
163
- text=[f"{val:.3f}" for val in model_data[metric]],
164
- textposition='outside'
 
 
 
 
 
 
165
  ))
166
 
167
  fig.update_layout(
168
- title=f"Model Comparison Across Topics ({metric}) - {partition_filter.title()}",
169
- xaxis_title="Topics",
170
- yaxis_title=metric,
 
 
 
171
  barmode='group',
172
- height=500,
173
- xaxis_tickangle=-45,
174
- yaxis=dict(range=[0, 1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  )
176
 
177
  return fig
178
 
179
  def create_partition_analysis(df, models_selected=None):
180
- """Analyze model performance across partitions"""
181
  if models_selected is None or len(models_selected) == 0:
182
  models_selected = df['model'].unique()[:3]
183
 
@@ -187,97 +249,187 @@ def create_partition_analysis(df, models_selected=None):
187
  metrics = ['FPR', 'Confidence', 'FDR', 'Precision', 'Recall_Power', 'Accuracy', 'G_mean']
188
  partition_performance = filtered_df.groupby(['model', 'partition'])[metrics].mean().reset_index()
189
 
190
- # Create subplots for each metric
191
  fig = make_subplots(
192
- rows=2, cols=4,
193
- subplot_titles=metrics + [''], # Extra empty title for 8th subplot
194
- specs=[[{"colspan": 1}, {"colspan": 1}, {"colspan": 1}, {"colspan": 1}],
195
- [{"colspan": 1}, {"colspan": 1}, {"colspan": 1}, None]] # 7 subplots total
 
 
 
196
  )
197
 
198
- colors = px.colors.qualitative.Set3[:len(models_selected)]
 
 
 
 
199
  partitions = ['train', 'test', 'inference']
200
 
201
  # Plot each metric
202
- for i, metric in enumerate(metrics):
203
- row = 1 if i < 4 else 2
204
- col = (i % 4) + 1
 
205
 
206
  for j, model in enumerate(models_selected):
207
  model_data = partition_performance[partition_performance['model'] == model]
208
- model_data = model_data.sort_values('partition') # Ensure consistent ordering
 
 
 
209
 
210
  fig.add_trace(
211
  go.Bar(
212
- name=model if i == 0 else "",
213
  x=model_data['partition'],
214
  y=model_data[metric],
215
- marker_color=colors[j],
216
- showlegend=True if i == 0 else False,
217
- text=[f"{val:.3f}" for val in model_data[metric]],
218
- textposition='outside'
 
 
 
 
 
 
219
  ),
220
  row=row, col=col
221
  )
222
 
223
  fig.update_layout(
224
- title="Model Performance Across Partitions - All Metrics",
225
- height=800,
226
- barmode='group'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  )
228
 
229
- # Update y-axes for all subplots
230
- for i in range(1, 8): # 7 subplots
231
- row = 1 if i <= 4 else 2
232
- col = i if i <= 4 else i - 4
233
- if i <= 7: # Only update existing subplots
234
- fig.update_yaxes(range=[0, 1], row=row, col=col)
 
 
 
 
235
 
236
  return fig
237
 
238
  def create_performance_summary_table(df):
239
- """Create summary table with key statistics"""
240
- # Calculate summary statistics
241
  summary_stats = []
242
 
243
  for model in df['model'].unique():
244
  model_data = df[df['model'] == model]
245
 
246
  stats = {
247
- 'Model': model,
248
- 'Avg_Accuracy': model_data['Accuracy'].mean(),
249
- 'Avg_Precision': model_data['Precision'].mean(),
250
- 'Avg_Recall': model_data['Recall_Power'].mean(),
251
- 'Avg_G_mean': model_data['G_mean'].mean(),
252
- 'Best_Topic_Accuracy': model_data.loc[model_data['Accuracy'].idxmax(), 'topic'],
253
- 'Best_Topic_Score': model_data['Accuracy'].max(),
254
- 'Worst_Topic_Accuracy': model_data.loc[model_data['Accuracy'].idxmin(), 'topic'],
255
- 'Worst_Topic_Score': model_data['Accuracy'].min(),
256
- 'Performance_Variance': model_data['Accuracy'].var()
257
  }
258
  summary_stats.append(stats)
259
 
260
  summary_df = pd.DataFrame(summary_stats)
261
- summary_df = summary_df.round(4)
262
- summary_df = summary_df.sort_values('Avg_Accuracy', ascending=False)
 
 
 
263
 
264
  return summary_df
265
 
266
- # Create the Gradio interface
267
- with gr.Blocks(title="Course Backbone Project Leaderboard \nMulti-Model Classifier Dashboard", theme=gr.themes.Soft()) as demo:
268
- gr.HTML("<h1 style='text-align: center; color: #2E86AB;'>πŸ† Multi-Model Classifier Dashboard</h1>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
  # Data loading section
271
  with gr.Row():
272
- with gr.Column():
273
  csv_file = gr.File(
274
  label="πŸ“ Upload CSV File",
275
- file_types=['.csv']
 
276
  )
 
277
  data_status = gr.Textbox(
278
- label="Data Status",
279
- value="Using default sample data with 2 models",
280
- interactive=False
 
281
  )
282
 
283
  # Store current data
@@ -285,66 +437,111 @@ with gr.Blocks(title="Course Backbone Project Leaderboard \nMulti-Model Classifi
285
 
286
  with gr.Tabs():
287
  with gr.TabItem("πŸ† Model Leaderboard"):
 
288
  with gr.Row():
289
  with gr.Column(scale=1):
290
  partition_filter = gr.Dropdown(
291
  choices=['all', 'inference', 'test', 'train'],
292
  value='all',
293
- label="Filter by Partition"
294
  )
295
  topic_filter = gr.Dropdown(
296
  choices=['all', 'OVERALL'],
297
  value='OVERALL',
298
- label="Filter by Topic"
299
  )
 
 
 
 
 
 
300
 
301
  with gr.Column(scale=3):
302
  leaderboard_chart = gr.Plot()
303
 
304
  with gr.TabItem("πŸ“Š Topic Comparison"):
 
305
  with gr.Row():
306
  with gr.Column(scale=1):
307
  models_selector = gr.CheckboxGroup(
308
  choices=[],
309
- label="Select Models to Compare",
310
  value=[]
311
  )
312
  metric_selector = gr.Dropdown(
313
  choices=['FPR', 'Confidence', 'FDR', 'Precision', 'Recall_Power', 'Accuracy', 'G_mean'],
314
  value='Accuracy',
315
- label="Select Metric"
316
  )
317
  partition_filter_topic = gr.Dropdown(
318
  choices=['all', 'inference', 'test', 'train'],
319
  value='all',
320
- label="Filter by Partition"
321
  )
 
 
 
 
 
 
322
 
323
  with gr.Column(scale=3):
324
  topic_comparison_chart = gr.Plot()
325
 
326
  with gr.TabItem("πŸ”„ Partition Analysis"):
 
327
  with gr.Row():
328
  with gr.Column(scale=1):
329
  models_selector_partition = gr.CheckboxGroup(
330
  choices=[],
331
- label="Select Models to Analyze",
332
  value=[]
333
  )
 
 
 
 
 
 
 
334
 
335
  with gr.Column(scale=3):
336
  partition_analysis_chart = gr.Plot()
337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  with gr.TabItem("πŸ“ˆ Performance Summary"):
 
339
  summary_table = gr.DataFrame(
340
- label="Model Performance Summary",
341
- interactive=False
 
342
  )
343
 
344
  with gr.TabItem("πŸ“‹ Raw Data"):
 
345
  raw_data_table = gr.DataFrame(
346
- label="Complete Dataset",
347
- interactive=True
 
348
  )
349
 
350
  def update_dashboard(file):
@@ -355,18 +552,23 @@ with gr.Blocks(title="Course Backbone Project Leaderboard \nMulti-Model Classifi
355
  model_choices = sorted(df['model'].unique())
356
  topic_choices = ['all'] + sorted(df['topic'].unique())
357
 
 
 
 
358
  # Create initial plots
359
  leaderboard = create_model_leaderboard(df)
360
- topic_comp = create_topic_comparison(df, model_choices[:3])
361
- partition_analysis = create_partition_analysis(df, model_choices[:3])
 
362
  summary = create_performance_summary_table(df)
363
 
364
  return (
365
  df, status,
366
  gr.update(choices=topic_choices, value='OVERALL'),
367
- gr.update(choices=model_choices, value=model_choices[:3]),
368
- gr.update(choices=model_choices, value=model_choices[:3]),
369
- leaderboard, topic_comp, partition_analysis, summary, df
 
370
  )
371
 
372
  # Event handlers
@@ -375,9 +577,9 @@ with gr.Blocks(title="Course Backbone Project Leaderboard \nMulti-Model Classifi
375
  inputs=[csv_file],
376
  outputs=[
377
  current_data, data_status, topic_filter,
378
- models_selector, models_selector_partition,
379
  leaderboard_chart, topic_comparison_chart,
380
- partition_analysis_chart, summary_table, raw_data_table
381
  ]
382
  )
383
 
@@ -399,6 +601,8 @@ with gr.Blocks(title="Course Backbone Project Leaderboard \nMulti-Model Classifi
399
 
400
  # Update topic comparison when models, metric, or partition change
401
  def update_topic_comparison(data, selected_models, metric, partition):
 
 
402
  return create_topic_comparison(data, selected_models, metric, partition)
403
 
404
  models_selector.change(
@@ -421,6 +625,8 @@ with gr.Blocks(title="Course Backbone Project Leaderboard \nMulti-Model Classifi
421
 
422
  # Update partition analysis when models change
423
  def update_partition_analysis(data, selected_models):
 
 
424
  return create_partition_analysis(data, selected_models)
425
 
426
  models_selector_partition.change(
@@ -429,45 +635,105 @@ with gr.Blocks(title="Course Backbone Project Leaderboard \nMulti-Model Classifi
429
  outputs=partition_analysis_chart
430
  )
431
 
 
 
 
 
 
 
 
 
 
 
 
 
432
  # Initialize dashboard with default data
433
  demo.load(
434
  fn=lambda: update_dashboard(None),
435
  outputs=[
436
  current_data, data_status, topic_filter,
437
- models_selector, models_selector_partition,
438
  leaderboard_chart, topic_comparison_chart,
439
- partition_analysis_chart, summary_table, raw_data_table
440
  ]
441
  )
442
 
443
  gr.Markdown("""
444
- ### πŸ’‘ Dashboard Features
445
-
446
- **πŸ“ Data Loading**: Upload your CSV file with classifier results. The app automatically detects all models and creates comparisons.
447
-
448
- **πŸ† Model Leaderboard**:
449
- - Compare all models side-by-side across key metrics
450
- - Filter by partition and topic for specific comparisons
451
- - Overall score calculated from precision, recall, and accuracy
452
-
453
- **πŸ“Š Topic Comparison**:
454
- - Select specific models to compare across all topics
455
- - Choose any metric (FPR, Confidence, FDR, Precision, Recall_Power, Accuracy, G_mean)
456
- - Filter by partition to focus on specific evaluation splits
457
- - Visual comparison across business categories
458
-
459
- **πŸ”„ Partition Analysis**:
460
- - Analyze all metrics across train/test/inference partitions
461
- - Compare multiple models across different evaluation splits
462
- - Monitor generalization capabilities and detect overfitting
463
- - Comprehensive view of all 7 performance metrics
464
-
465
- **πŸ“ˆ Performance Summary**:
466
- - Statistical overview of each model's performance
467
- - Best and worst performing topics for each model
468
- - Performance variance analysis
469
-
470
- **CSV Format**: Your file should have columns: `model`, `partition`, `topic`, `FPR`, `Confidence`, `FDR`, `Precision`, `Recall_Power`, `Accuracy`, `G_mean`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  """)
472
 
473
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ import pandas as pd
3
  import plotly.graph_objects as go
4
  import plotly.express as px
 
5
  from plotly.subplots import make_subplots
 
 
6
 
7
  # Default sample data (will be replaced when CSV is uploaded)
8
  default_data = pd.DataFrame({
 
57
  def load_csv_data(file):
58
  """Load and validate CSV data"""
59
  if file is None:
60
+ return default_data, "πŸ“Š Using default sample data (2 models, 48 records)"
61
 
62
  try:
63
  df = pd.read_csv(file.name)
 
68
  missing_cols = [col for col in required_cols if col not in df.columns]
69
 
70
  if missing_cols:
71
+ return default_data, f"❌ Missing columns: {', '.join(missing_cols)}. Using default data."
72
 
73
  # Clean data
74
  df = df.dropna()
75
+
76
+ num_models = df['model'].nunique()
77
+ num_records = len(df)
78
 
79
+ return df, f"βœ… Successfully loaded: {num_records} records | {num_models} models | {df['topic'].nunique()} topics"
80
 
81
  except Exception as e:
82
  return default_data, f"❌ Error loading CSV: {str(e)}. Using default data."
83
 
84
  def create_model_leaderboard(df, partition_filter='all', topic_filter='OVERALL'):
85
+ """Create enhanced leaderboard comparing all models"""
86
  filtered_df = df.copy()
87
 
88
  if partition_filter != 'all':
 
97
 
98
  # Calculate overall score (average of key metrics)
99
  leaderboard['Overall_Score'] = leaderboard[['Precision', 'Recall_Power', 'Accuracy']].mean(axis=1)
100
+
101
+ # Create subplot for each metric with horizontal bars
 
102
  fig = make_subplots(
103
+ rows=len(metrics) + 1, cols=1,
104
+ subplot_titles=['<b>' + m.replace('_', ' ') + '</b>' for m in metrics] + ['<b>Overall Score</b>'],
105
+ vertical_spacing=0.08
106
  )
107
 
108
+ # Generate enough colors for all models
109
+ num_models = len(leaderboard)
110
+ color_palette = px.colors.qualitative.Plotly
111
+ colors = (color_palette * ((num_models // len(color_palette)) + 1))[:num_models]
112
 
113
  for i, metric in enumerate(metrics + ['Overall_Score']):
114
+ sorted_data = leaderboard.sort_values(metric, ascending=True)
115
+
116
+ fig.add_trace(
117
+ go.Bar(
118
+ y=sorted_data['model'],
119
+ x=sorted_data[metric],
120
+ orientation='h',
121
+ marker=dict(
122
+ color=sorted_data[metric],
123
+ colorscale='RdYlGn',
124
+ showscale=False,
125
+ line=dict(color='rgb(50,50,50)', width=1.5)
126
  ),
127
+ text=[f"<b>{val:.4f}</b>" for val in sorted_data[metric]],
128
+ textposition='auto',
129
+ textfont=dict(size=12, color='white', family='Arial Black'),
130
+ hovertemplate='<b>%{y}</b><br>' + metric.replace('_', ' ') + ': <b>%{x:.4f}</b><extra></extra>'
131
+ ),
132
+ row=i+1, col=1
133
+ )
134
 
135
  fig.update_layout(
136
+ title=dict(
137
+ text=f"<b>πŸ† Model Leaderboard</b><br><sub>Partition: {partition_filter.title()} | Topic: {topic_filter}</sub>",
138
+ font=dict(size=22, color='#2c3e50')
139
+ ),
140
+ height=300 * (len(metrics) + 1),
141
+ showlegend=False,
142
+ font=dict(size=12, family='Arial'),
143
+ plot_bgcolor='rgba(245,245,245,0.8)',
144
+ paper_bgcolor='white'
145
  )
146
 
147
+ # Update axes
148
  for i in range(1, len(metrics) + 2):
149
+ fig.update_xaxes(range=[0, 1.05], gridcolor='rgba(200,200,200,0.5)',
150
+ showgrid=True, fixedrange=False, row=i, col=1)
151
+ fig.update_yaxes(tickfont=dict(size=11), row=i, col=1)
152
 
153
  return fig
154
 
155
  def create_topic_comparison(df, models_selected=None, metric='Accuracy', partition_filter='all'):
156
+ """Create enhanced topic comparison chart"""
157
  if models_selected is None or len(models_selected) == 0:
158
+ models_selected = df['model'].unique()[:3]
159
 
160
  # Filter data
161
  filtered_df = df[df['model'].isin(models_selected)].copy()
 
169
  # Create grouped bar chart
170
  fig = go.Figure()
171
 
172
+ # Generate enough colors for all models
173
+ num_models = len(models_selected)
174
+ color_palette = px.colors.qualitative.Bold
175
+ colors = (color_palette * ((num_models // len(color_palette)) + 1))[:num_models]
176
+
177
  topics = sorted(topic_performance['topic'].unique())
178
 
179
  for i, model in enumerate(models_selected):
180
  model_data = topic_performance[topic_performance['model'] == model]
181
+ # Sort by topic order
182
+ model_data = model_data.set_index('topic').reindex(topics).reset_index()
183
+
184
+ # Shortened model name for legend
185
+ model_short = model if len(model) <= 30 else model[:27] + '...'
186
+
187
  fig.add_trace(go.Bar(
188
+ name=model_short,
189
  x=topics,
190
  y=model_data[metric],
191
+ marker=dict(
192
+ color=colors[i],
193
+ line=dict(color='rgb(40,40,40)', width=1.5),
194
+ opacity=0.85
195
+ ),
196
+ text=[f"<b>{val:.3f}</b>" if not pd.isna(val) else 'N/A' for val in model_data[metric]],
197
+ textposition='outside',
198
+ textfont=dict(size=11, color='black'),
199
+ hovertemplate='<b>Topic:</b> %{x}<br><b>Model:</b> ' + model + '<br><b>' + metric + ':</b> %{y:.4f}<extra></extra>'
200
  ))
201
 
202
  fig.update_layout(
203
+ title=dict(
204
+ text=f"<b>πŸ“Š Topic Performance Comparison</b><br><sub>Metric: {metric.replace('_', ' ')} | Partition: {partition_filter.title()}</sub>",
205
+ font=dict(size=20, color='#2c3e50')
206
+ ),
207
+ xaxis_title="<b>Topics</b>",
208
+ yaxis_title=f"<b>{metric.replace('_', ' ')}</b>",
209
  barmode='group',
210
+ height=600,
211
+ xaxis=dict(
212
+ tickangle=-45,
213
+ tickfont=dict(size=11),
214
+ fixedrange=False
215
+ ),
216
+ yaxis=dict(
217
+ range=[0, max(1.1, topic_performance[metric].max() * 1.15)],
218
+ gridcolor='rgba(200,200,200,0.4)',
219
+ fixedrange=False,
220
+ showgrid=True
221
+ ),
222
+ legend=dict(
223
+ orientation="h",
224
+ yanchor="bottom",
225
+ y=1.02,
226
+ xanchor="right",
227
+ x=1,
228
+ font=dict(size=12),
229
+ bgcolor='rgba(255,255,255,0.8)',
230
+ bordercolor='gray',
231
+ borderwidth=1
232
+ ),
233
+ plot_bgcolor='rgba(245,245,245,0.6)',
234
+ paper_bgcolor='white',
235
+ bargap=0.15,
236
+ bargroupgap=0.1
237
  )
238
 
239
  return fig
240
 
241
  def create_partition_analysis(df, models_selected=None):
242
+ """Create enhanced partition analysis with all metrics"""
243
  if models_selected is None or len(models_selected) == 0:
244
  models_selected = df['model'].unique()[:3]
245
 
 
249
  metrics = ['FPR', 'Confidence', 'FDR', 'Precision', 'Recall_Power', 'Accuracy', 'G_mean']
250
  partition_performance = filtered_df.groupby(['model', 'partition'])[metrics].mean().reset_index()
251
 
252
+ # Create subplots in a 3x3 grid
253
  fig = make_subplots(
254
+ rows=3, cols=3,
255
+ subplot_titles=['<b>' + m.replace('_', ' ') + '</b>' for m in metrics] + ['', ''],
256
+ specs=[[{"type": "bar"}, {"type": "bar"}, {"type": "bar"}],
257
+ [{"type": "bar"}, {"type": "bar"}, {"type": "bar"}],
258
+ [{"type": "bar"}, None, None]],
259
+ vertical_spacing=0.12,
260
+ horizontal_spacing=0.1
261
  )
262
 
263
+ # Generate enough colors for all models
264
+ num_models = len(models_selected)
265
+ color_palette = px.colors.qualitative.Bold
266
+ colors = (color_palette * ((num_models // len(color_palette)) + 1))[:num_models]
267
+
268
  partitions = ['train', 'test', 'inference']
269
 
270
  # Plot each metric
271
+ positions = [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3), (3,1)]
272
+
273
+ for idx, metric in enumerate(metrics):
274
+ row, col = positions[idx]
275
 
276
  for j, model in enumerate(models_selected):
277
  model_data = partition_performance[partition_performance['model'] == model]
278
+ model_data = model_data.sort_values('partition')
279
+
280
+ # Shortened model name
281
+ model_short = model if len(model) <= 25 else model[:22] + '...'
282
 
283
  fig.add_trace(
284
  go.Bar(
285
+ name=model_short if idx == 0 else "",
286
  x=model_data['partition'],
287
  y=model_data[metric],
288
+ marker=dict(
289
+ color=colors[j],
290
+ line=dict(color='rgb(40,40,40)', width=1.5),
291
+ opacity=0.85
292
+ ),
293
+ showlegend=True if idx == 0 else False,
294
+ text=[f"<b>{val:.3f}</b>" for val in model_data[metric]],
295
+ textposition='outside',
296
+ textfont=dict(size=10, color='black'),
297
+ hovertemplate='<b>Partition:</b> %{x}<br><b>Model:</b> ' + model + '<br><b>' + metric + ':</b> %{y:.4f}<extra></extra>'
298
  ),
299
  row=row, col=col
300
  )
301
 
302
  fig.update_layout(
303
+ title=dict(
304
+ text="<b>πŸ”„ Model Performance Across Partitions</b><br><sub>All Metrics Overview</sub>",
305
+ font=dict(size=20, color='#2c3e50')
306
+ ),
307
+ height=950,
308
+ barmode='group',
309
+ bargap=0.15,
310
+ bargroupgap=0.1,
311
+ legend=dict(
312
+ orientation="h",
313
+ yanchor="bottom",
314
+ y=1.02,
315
+ xanchor="right",
316
+ x=1,
317
+ font=dict(size=11),
318
+ bgcolor='rgba(255,255,255,0.8)',
319
+ bordercolor='gray',
320
+ borderwidth=1
321
+ ),
322
+ plot_bgcolor='rgba(245,245,245,0.6)',
323
+ paper_bgcolor='white'
324
  )
325
 
326
+ # Update axes for all subplots
327
+ for idx in range(len(metrics)):
328
+ row, col = positions[idx]
329
+ fig.update_yaxes(
330
+ range=[0, 1.05],
331
+ gridcolor='rgba(200,200,200,0.4)',
332
+ showgrid=True,
333
+ row=row, col=col, fixedrange=False
334
+ )
335
+ fig.update_xaxes(tickfont=dict(size=10), row=row, col=col)
336
 
337
  return fig
338
 
339
  def create_performance_summary_table(df):
340
+ """Create enhanced summary table with key statistics"""
 
341
  summary_stats = []
342
 
343
  for model in df['model'].unique():
344
  model_data = df[df['model'] == model]
345
 
346
  stats = {
347
+ '🏷️ Model': model,
348
+ 'πŸ“Š Avg Accuracy': f"{model_data['Accuracy'].mean():.4f}",
349
+ '🎯 Avg Precision': f"{model_data['Precision'].mean():.4f}",
350
+ 'πŸ” Avg Recall': f"{model_data['Recall_Power'].mean():.4f}",
351
+ 'πŸ“ˆ Avg G-mean': f"{model_data['G_mean'].mean():.4f}",
352
+ 'βœ… Best Topic': model_data.loc[model_data['Accuracy'].idxmax(), 'topic'],
353
+ '⭐ Best Score': f"{model_data['Accuracy'].max():.4f}",
354
+ '⚠️ Worst Topic': model_data.loc[model_data['Accuracy'].idxmin(), 'topic'],
355
+ 'πŸ“‰ Worst Score': f"{model_data['Accuracy'].min():.4f}",
356
+ 'πŸ“Š Variance': f"{model_data['Accuracy'].var():.6f}"
357
  }
358
  summary_stats.append(stats)
359
 
360
  summary_df = pd.DataFrame(summary_stats)
361
+
362
+ # Sort by average accuracy
363
+ summary_df['_sort_key'] = summary_df['πŸ“Š Avg Accuracy'].astype(float)
364
+ summary_df = summary_df.sort_values('_sort_key', ascending=False)
365
+ summary_df = summary_df.drop('_sort_key', axis=1)
366
 
367
  return summary_df
368
 
369
+ def create_detailed_metrics_heatmap(df, models_selected=None):
370
+ """Create a heatmap showing all metrics for selected models"""
371
+ if models_selected is None or len(models_selected) == 0:
372
+ models_selected = df['model'].unique()[:3]
373
+
374
+ filtered_df = df[df['model'].isin(models_selected)].copy()
375
+
376
+ # Calculate average for each metric by model
377
+ metrics = ['FPR', 'Confidence', 'FDR', 'Precision', 'Recall_Power', 'Accuracy', 'G_mean']
378
+ heatmap_data = filtered_df.groupby('model')[metrics].mean()
379
+
380
+ # Create heatmap
381
+ fig = go.Figure(data=go.Heatmap(
382
+ z=heatmap_data.values,
383
+ x=[m.replace('_', ' ') for m in metrics],
384
+ y=heatmap_data.index,
385
+ colorscale='RdYlGn',
386
+ text=heatmap_data.values.round(4),
387
+ texttemplate='<b>%{text}</b>',
388
+ textfont={"size": 12},
389
+ hovertemplate='<b>Model:</b> %{y}<br><b>Metric:</b> %{x}<br><b>Value:</b> %{z:.4f}<extra></extra>',
390
+ colorbar=dict(title="Score")
391
+ ))
392
+
393
+ fig.update_layout(
394
+ title=dict(
395
+ text="<b>πŸ”₯ Metrics Heatmap</b><br><sub>Average Performance Across All Topics and Partitions</sub>",
396
+ font=dict(size=20, color='#2c3e50')
397
+ ),
398
+ xaxis_title="<b>Metrics</b>",
399
+ yaxis_title="<b>Models</b>",
400
+ height=200 + (len(models_selected) * 60),
401
+ font=dict(size=12),
402
+ plot_bgcolor='white',
403
+ paper_bgcolor='white',
404
+ xaxis=dict(fixedrange=False),
405
+ yaxis=dict(fixedrange=False)
406
+ )
407
+
408
+ return fig
409
+
410
+ # Create the Gradio interface with enhanced styling
411
+ with gr.Blocks(title="Multi-Model Classifier Dashboard", theme=gr.themes.Soft()) as demo:
412
+ gr.HTML("""
413
+ <div style='text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; margin-bottom: 20px;'>
414
+ <h1 style='color: white; margin: 0; font-size: 2.5em;'>πŸ† Multi-Model Classifier Dashboard</h1>
415
+ <p style='color: #f0f0f0; margin: 10px 0 0 0; font-size: 1.1em;'>Comprehensive Performance Analysis & Comparison Tool</p>
416
+ </div>
417
+ """)
418
 
419
  # Data loading section
420
  with gr.Row():
421
+ with gr.Column(scale=2):
422
  csv_file = gr.File(
423
  label="πŸ“ Upload CSV File",
424
+ file_types=['.csv'],
425
+ file_count="single"
426
  )
427
+ with gr.Column(scale=3):
428
  data_status = gr.Textbox(
429
+ label="πŸ“Š Data Status",
430
+ value="Using default sample data (2 models, 48 records)",
431
+ interactive=False,
432
+ lines=2
433
  )
434
 
435
  # Store current data
 
437
 
438
  with gr.Tabs():
439
  with gr.TabItem("πŸ† Model Leaderboard"):
440
+ gr.Markdown("### Compare all models side-by-side across key performance metrics")
441
  with gr.Row():
442
  with gr.Column(scale=1):
443
  partition_filter = gr.Dropdown(
444
  choices=['all', 'inference', 'test', 'train'],
445
  value='all',
446
+ label="πŸ” Filter by Partition"
447
  )
448
  topic_filter = gr.Dropdown(
449
  choices=['all', 'OVERALL'],
450
  value='OVERALL',
451
+ label="🏷️ Filter by Topic"
452
  )
453
+ gr.Markdown("""
454
+ **πŸ“– How to use:**
455
+ - Select partition to view performance on specific data splits
456
+ - Choose topic to focus on particular business domains
457
+ - Bars are color-coded: 🟒 Green = Better, πŸ”΄ Red = Worse
458
+ """)
459
 
460
  with gr.Column(scale=3):
461
  leaderboard_chart = gr.Plot()
462
 
463
  with gr.TabItem("πŸ“Š Topic Comparison"):
464
+ gr.Markdown("### Analyze how selected models perform across different topics")
465
  with gr.Row():
466
  with gr.Column(scale=1):
467
  models_selector = gr.CheckboxGroup(
468
  choices=[],
469
+ label="βœ… Select Models to Compare",
470
  value=[]
471
  )
472
  metric_selector = gr.Dropdown(
473
  choices=['FPR', 'Confidence', 'FDR', 'Precision', 'Recall_Power', 'Accuracy', 'G_mean'],
474
  value='Accuracy',
475
+ label="πŸ“ Select Metric"
476
  )
477
  partition_filter_topic = gr.Dropdown(
478
  choices=['all', 'inference', 'test', 'train'],
479
  value='all',
480
+ label="πŸ” Filter by Partition"
481
  )
482
+ gr.Markdown("""
483
+ **πŸ“– How to use:**
484
+ - Check models you want to compare
485
+ - Choose the metric to analyze
486
+ - Compare strengths/weaknesses across topics
487
+ """)
488
 
489
  with gr.Column(scale=3):
490
  topic_comparison_chart = gr.Plot()
491
 
492
  with gr.TabItem("πŸ”„ Partition Analysis"):
493
+ gr.Markdown("### Examine model performance across train/test/inference splits")
494
  with gr.Row():
495
  with gr.Column(scale=1):
496
  models_selector_partition = gr.CheckboxGroup(
497
  choices=[],
498
+ label="βœ… Select Models to Analyze",
499
  value=[]
500
  )
501
+ gr.Markdown("""
502
+ **πŸ“– How to use:**
503
+ - Select models to analyze
504
+ - View all 7 metrics simultaneously
505
+ - Identify overfitting (train >> test)
506
+ - Check generalization (test vs inference)
507
+ """)
508
 
509
  with gr.Column(scale=3):
510
  partition_analysis_chart = gr.Plot()
511
 
512
+ with gr.TabItem("πŸ”₯ Metrics Heatmap"):
513
+ gr.Markdown("### Visual overview of all metrics for quick comparison")
514
+ with gr.Row():
515
+ with gr.Column(scale=1):
516
+ models_selector_heatmap = gr.CheckboxGroup(
517
+ choices=[],
518
+ label="βœ… Select Models for Heatmap",
519
+ value=[]
520
+ )
521
+ gr.Markdown("""
522
+ **πŸ“– How to use:**
523
+ - Select models to include in heatmap
524
+ - Quickly spot strengths (green) and weaknesses (red)
525
+ - Average across all topics and partitions
526
+ """)
527
+
528
+ with gr.Column(scale=3):
529
+ heatmap_chart = gr.Plot()
530
+
531
  with gr.TabItem("πŸ“ˆ Performance Summary"):
532
+ gr.Markdown("### Statistical overview and key insights for each model")
533
  summary_table = gr.DataFrame(
534
+ label="πŸ“Š Model Performance Summary Table",
535
+ interactive=False,
536
+ wrap=True
537
  )
538
 
539
  with gr.TabItem("πŸ“‹ Raw Data"):
540
+ gr.Markdown("### Complete dataset view - explore all records")
541
  raw_data_table = gr.DataFrame(
542
+ label="πŸ—‚οΈ Complete Dataset",
543
+ interactive=True,
544
+ wrap=True
545
  )
546
 
547
  def update_dashboard(file):
 
552
  model_choices = sorted(df['model'].unique())
553
  topic_choices = ['all'] + sorted(df['topic'].unique())
554
 
555
+ # Select default models (up to 3)
556
+ default_models = model_choices[:min(3, len(model_choices))]
557
+
558
  # Create initial plots
559
  leaderboard = create_model_leaderboard(df)
560
+ topic_comp = create_topic_comparison(df, default_models)
561
+ partition_analysis = create_partition_analysis(df, default_models)
562
+ heatmap = create_detailed_metrics_heatmap(df, default_models)
563
  summary = create_performance_summary_table(df)
564
 
565
  return (
566
  df, status,
567
  gr.update(choices=topic_choices, value='OVERALL'),
568
+ gr.update(choices=model_choices, value=default_models),
569
+ gr.update(choices=model_choices, value=default_models),
570
+ gr.update(choices=model_choices, value=default_models),
571
+ leaderboard, topic_comp, partition_analysis, heatmap, summary, df
572
  )
573
 
574
  # Event handlers
 
577
  inputs=[csv_file],
578
  outputs=[
579
  current_data, data_status, topic_filter,
580
+ models_selector, models_selector_partition, models_selector_heatmap,
581
  leaderboard_chart, topic_comparison_chart,
582
+ partition_analysis_chart, heatmap_chart, summary_table, raw_data_table
583
  ]
584
  )
585
 
 
601
 
602
  # Update topic comparison when models, metric, or partition change
603
  def update_topic_comparison(data, selected_models, metric, partition):
604
+ if not selected_models:
605
+ selected_models = data['model'].unique()[:3]
606
  return create_topic_comparison(data, selected_models, metric, partition)
607
 
608
  models_selector.change(
 
625
 
626
  # Update partition analysis when models change
627
  def update_partition_analysis(data, selected_models):
628
+ if not selected_models:
629
+ selected_models = data['model'].unique()[:3]
630
  return create_partition_analysis(data, selected_models)
631
 
632
  models_selector_partition.change(
 
635
  outputs=partition_analysis_chart
636
  )
637
 
638
+ # Update heatmap when models change
639
+ def update_heatmap(data, selected_models):
640
+ if not selected_models:
641
+ selected_models = data['model'].unique()[:3]
642
+ return create_detailed_metrics_heatmap(data, selected_models)
643
+
644
    # Re-render the heatmap whenever the heatmap-tab model selection changes.
    # inputs: current data State + the checkbox group; output: the heatmap Plot.
    models_selector_heatmap.change(
        fn=update_heatmap,
        inputs=[current_data, models_selector_heatmap],
        outputs=heatmap_chart
    )
649
+
650
    # Initialize dashboard with default data
    # On page load, run the same refresh path as a CSV upload with file=None,
    # which makes update_dashboard fall back to the bundled sample data and
    # populate every component (filters, charts, tables) in one pass.
    demo.load(
        fn=lambda: update_dashboard(None),
        outputs=[
            current_data, data_status, topic_filter,
            models_selector, models_selector_partition, models_selector_heatmap,
            leaderboard_chart, topic_comparison_chart,
            partition_analysis_chart, heatmap_chart, summary_table, raw_data_table
        ]
    )
660
 
661
  gr.Markdown("""
662
+ ---
663
+ ### πŸ’‘ Dashboard Features & Usage Guide
664
+
665
+ #### πŸ“ **Data Loading**
666
+ Upload your CSV file with classifier performance results. The dashboard automatically:
667
+ - Detects all models in your dataset
668
+ - Validates data structure and quality
669
+ - Creates comprehensive comparisons across all dimensions
670
+
671
+ #### πŸ† **Model Leaderboard**
672
+ - **Horizontal bar charts** make it easy to compare models at a glance
673
+ - **Color-coded performance**: Green indicates better scores, red indicates lower scores
674
+ - Filter by **partition** (train/test/inference) and **topic** for targeted analysis
675
+ - **Overall Score** calculated from precision, recall, and accuracy averages
676
+
677
+ #### πŸ“Š **Topic Comparison**
678
+ - Select multiple models to compare side-by-side
679
+ - Choose any metric: FPR, Confidence, FDR, Precision, Recall, Accuracy, or G-mean
680
+ - Identify which topics each model excels at or struggles with
681
+ - Filter by partition to see performance on specific data splits
682
+
683
+ #### πŸ”„ **Partition Analysis**
684
+ - View all 7 metrics simultaneously in a compact grid layout
685
+ - Compare train/test/inference performance to detect overfitting
686
+ - Check model generalization capabilities
687
+ - Grouped bars show direct model-to-model comparisons
688
+
689
+ #### πŸ”₯ **Metrics Heatmap**
690
+ - **Visual overview** of all metrics for quick pattern recognition
691
+ - Color intensity shows performance levels at a glance
692
+ - Average performance across all topics and partitions
693
+ - Perfect for executive summaries and presentations
694
+
695
+ #### πŸ“ˆ **Performance Summary**
696
+ - Statistical overview with key performance indicators
697
+ - Best and worst performing topics identified for each model
698
+ - Performance variance shows consistency across topics
699
+ - Sortable table for custom analysis
700
+
701
+ #### πŸ“‹ **CSV Format Requirements**
702
+ Your CSV file must include these columns:
703
+ - `model`: Model name/identifier
704
+ - `partition`: Data split (train/test/inference)
705
+ - `topic`: Business domain or category
706
+ - `FPR`: False Positive Rate
707
+ - `Confidence`: Model confidence scores
708
+ - `FDR`: False Discovery Rate
709
+ - `Precision`: Positive predictive value
710
+ - `Recall_Power`: True positive rate / Sensitivity
711
+ - `Accuracy`: Overall correctness
712
+ - `G_mean`: Geometric mean of sensitivity and specificity
713
+
714
+ ---
715
+
716
+ ### 🎯 **Tips for Best Results**
717
+
718
+ 1. **Compare 2-4 models** at a time for clearest visualizations
719
+ 2. **Start with the Leaderboard** to identify top performers
720
+ 3. **Use Topic Comparison** to find domain-specific strengths
721
+ 4. **Check Partition Analysis** to ensure model generalization
722
+ 5. **Review the Heatmap** for quick executive summaries
723
+
724
+ ### πŸš€ **Advanced Analysis**
725
+
726
+ - **Overfitting Detection**: If train scores >> test scores in Partition Analysis
727
+ - **Generalization Check**: Compare test vs inference performance
728
+ - **Topic Specialization**: Use Topic Comparison to identify niche strengths
729
+ - **Consistency Analysis**: Check variance in Performance Summary
730
+
731
+ ---
732
+
733
+ <div style='text-align: center; padding: 15px; background-color: #f8f9fa; border-radius: 8px; margin-top: 20px;'>
734
+ <p style='margin: 0; color: #6c757d;'><b>Built with ❀️ using Gradio & Plotly</b></p>
735
+ <p style='margin: 5px 0 0 0; font-size: 0.9em; color: #6c757d;'>Interactive ML Model Performance Dashboard</p>
736
+ </div>
737
  """)
738
 
739
  if __name__ == "__main__":