cgeorgiaw HF Staff committed on
Commit
2484fe9
·
1 Parent(s): bc5d507

adding filter for training set

Browse files
Files changed (2) hide show
  1. about.py +3 -0
  2. app.py +54 -8
about.py CHANGED
@@ -61,6 +61,8 @@ COLUMN_DISPLAY_NAMES = {
61
  'hhi_production_mean': 'HHI Production',
62
  'hhi_reserve_mean': 'HHI Reserve',
63
  'hhi_combined_mean': 'HHI Combined',
 
 
64
  }
65
 
66
  # Metrics that can be shown as percentages (count-based metrics)
@@ -152,6 +154,7 @@ COLUMN_TO_GROUP = get_column_to_group_mapping()
152
  # Compact view columns (most important metrics visible without scrolling)
153
  COMPACT_VIEW_COLUMNS = [
154
  'model_name',
 
155
  'overall_valid_count',
156
  'unique_count',
157
  'novel_count',
 
61
  'hhi_production_mean': 'HHI Production',
62
  'hhi_reserve_mean': 'HHI Reserve',
63
  'hhi_combined_mean': 'HHI Combined',
64
+ # Metadata columns
65
+ 'training_set': 'Training Set',
66
  }
67
 
68
  # Metrics that can be shown as percentages (count-based metrics)
 
154
  # Compact view columns (most important metrics visible without scrolling)
155
  COMPACT_VIEW_COLUMNS = [
156
  'model_name',
157
+ 'training_set',
158
  'overall_valid_count',
159
  'unique_count',
160
  'novel_count',
app.py CHANGED
@@ -45,6 +45,8 @@ def format_dataframe(df, show_percentage=False, selected_groups=None, compact_vi
45
  selected_cols = [col for col in COMPACT_VIEW_COLUMNS if col in df.columns]
46
  else:
47
  # Build from selected groups
 
 
48
  if 'n_structures' in df.columns:
49
  selected_cols.append('n_structures')
50
 
@@ -73,6 +75,12 @@ def format_dataframe(df, show_percentage=False, selected_groups=None, compact_vi
73
  name = row['model_name']
74
  symbols = []
75
 
 
 
 
 
 
 
76
  # Add relaxed symbol
77
  if 'relaxed' in df.columns and row.get('relaxed', False):
78
  symbols.append('⚡')
@@ -109,16 +117,28 @@ def format_dataframe(df, show_percentage=False, selected_groups=None, compact_vi
109
  if display_df[col].dtype in ['float64', 'float32']:
110
  display_df[col] = display_df[col].round(4)
111
 
 
 
 
 
 
 
 
 
 
 
112
  # Rename columns for display
113
  display_df = display_df.rename(columns=COLUMN_DISPLAY_NAMES)
114
 
115
  # Apply color coding based on metric groups
116
- styled_df = apply_color_styling(display_df, selected_cols)
117
 
118
  return styled_df
119
 
120
- def apply_color_styling(display_df, original_cols):
121
  """Apply background colors to dataframe based on metric groups using pandas Styler."""
 
 
122
 
123
  def style_by_group(x):
124
  # Create a DataFrame with the same shape filled with empty strings
@@ -136,12 +156,20 @@ def apply_color_styling(display_df, original_cols):
136
  if color:
137
  styles[display_col] = f'background-color: {color}'
138
 
 
 
 
 
 
 
 
 
139
  return styles
140
 
141
  # Apply the styling function
142
  return display_df.style.apply(style_by_group, axis=None)
143
 
144
- def update_leaderboard(show_percentage, selected_groups, compact_view, cached_df, sort_by, sort_direction):
145
  """Update the leaderboard based on user selections.
146
 
147
  Uses cached dataframe to avoid re-downloading data on every change.
@@ -149,6 +177,10 @@ def update_leaderboard(show_percentage, selected_groups, compact_view, cached_df
149
  # Use cached dataframe instead of re-downloading
150
  df_to_format = cached_df.copy()
151
 
 
 
 
 
152
  # Convert display name back to raw column name for sorting
153
  if sort_by and sort_by != "None":
154
  # Create reverse mapping from display names to raw column names
@@ -321,6 +353,12 @@ Generative machine learning models hold great promise for accelerating materials
321
  value="Descending",
322
  label="Sort Direction"
323
  )
 
 
 
 
 
 
324
  with gr.Column(scale=2):
325
  selected_groups = gr.CheckboxGroup(
326
  choices=list(METRIC_GROUPS.keys()),
@@ -353,27 +391,32 @@ Generative machine learning models hold great promise for accelerating materials
353
  # Update dataframe when options change (using cached data)
354
  show_percentage.change(
355
  fn=update_leaderboard,
356
- inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction],
357
  outputs=leaderboard_table
358
  )
359
  selected_groups.change(
360
  fn=update_leaderboard,
361
- inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction],
362
  outputs=leaderboard_table
363
  )
364
  compact_view.change(
365
  fn=update_leaderboard,
366
- inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction],
367
  outputs=leaderboard_table
368
  )
369
  sort_by.change(
370
  fn=update_leaderboard,
371
- inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction],
372
  outputs=leaderboard_table
373
  )
374
  sort_direction.change(
375
  fn=update_leaderboard,
376
- inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction],
 
 
 
 
 
377
  outputs=leaderboard_table
378
  )
379
 
@@ -382,12 +425,15 @@ Generative machine learning models hold great promise for accelerating materials
382
 
383
  gr.Markdown("""
384
  **Symbol Legend:**
 
385
  - ✅ Model output verified
386
  - ⚡ Structures were already relaxed
387
  - ★ Contributes to LeMat-Bulk reference dataset (in-distribution)
388
  - ◆ Out-of-distribution relative to LeMat-Bulk reference dataset
389
 
390
  Verified submissions mean the results came from a model submission rather than a CIF submission.
 
 
391
  """)
392
 
393
  with gr.TabItem("✉️ Submit", elem_id="boundary-benchmark-tab-table"):
 
45
  selected_cols = [col for col in COMPACT_VIEW_COLUMNS if col in df.columns]
46
  else:
47
  # Build from selected groups
48
+ if 'training_set' in df.columns:
49
+ selected_cols.append('training_set')
50
  if 'n_structures' in df.columns:
51
  selected_cols.append('n_structures')
52
 
 
75
  name = row['model_name']
76
  symbols = []
77
 
78
+ # Add paper link emoji
79
+ if 'paper_link' in df.columns:
80
+ paper_val = row.get('paper_link', None)
81
+ if paper_val and isinstance(paper_val, str) and paper_val.strip():
82
+ symbols.append(f'<a href="{paper_val.strip()}" target="_blank">📄</a>')
83
+
84
  # Add relaxed symbol
85
  if 'relaxed' in df.columns and row.get('relaxed', False):
86
  symbols.append('⚡')
 
117
  if display_df[col].dtype in ['float64', 'float32']:
118
  display_df[col] = display_df[col].round(4)
119
 
120
+ # Separate baseline models to the bottom
121
+ baseline_indices = set()
122
+ if 'notes' in df.columns:
123
+ is_baseline = df['notes'].fillna('').str.contains('baseline', case=False, na=False)
124
+ non_baseline_df = display_df[~is_baseline]
125
+ baseline_df = display_df[is_baseline]
126
+ display_df = pd.concat([non_baseline_df, baseline_df]).reset_index(drop=True)
127
+ # Track baseline row indices in the new dataframe
128
+ baseline_indices = set(range(len(non_baseline_df), len(display_df)))
129
+
130
  # Rename columns for display
131
  display_df = display_df.rename(columns=COLUMN_DISPLAY_NAMES)
132
 
133
  # Apply color coding based on metric groups
134
+ styled_df = apply_color_styling(display_df, selected_cols, baseline_indices)
135
 
136
  return styled_df
137
 
138
+ def apply_color_styling(display_df, original_cols, baseline_indices=None):
139
  """Apply background colors to dataframe based on metric groups using pandas Styler."""
140
+ if baseline_indices is None:
141
+ baseline_indices = set()
142
 
143
  def style_by_group(x):
144
  # Create a DataFrame with the same shape filled with empty strings
 
156
  if color:
157
  styles[display_col] = f'background-color: {color}'
158
 
159
+ # Add thick top border to the first baseline row as a separator
160
+ if baseline_indices:
161
+ first_baseline_idx = min(baseline_indices)
162
+ for col in x.columns:
163
+ current = styles.at[first_baseline_idx, col]
164
+ separator_style = 'border-top: 3px solid #555'
165
+ styles.at[first_baseline_idx, col] = f'{current}; {separator_style}' if current else separator_style
166
+
167
  return styles
168
 
169
  # Apply the styling function
170
  return display_df.style.apply(style_by_group, axis=None)
171
 
172
+ def update_leaderboard(show_percentage, selected_groups, compact_view, cached_df, sort_by, sort_direction, training_set_filter):
173
  """Update the leaderboard based on user selections.
174
 
175
  Uses cached dataframe to avoid re-downloading data on every change.
 
177
  # Use cached dataframe instead of re-downloading
178
  df_to_format = cached_df.copy()
179
 
180
+ # Apply training set filter
181
+ if training_set_filter and training_set_filter != "All" and 'training_set' in df_to_format.columns:
182
+ df_to_format = df_to_format[df_to_format['training_set'] == training_set_filter]
183
+
184
  # Convert display name back to raw column name for sorting
185
  if sort_by and sort_by != "None":
186
  # Create reverse mapping from display names to raw column names
 
353
  value="Descending",
354
  label="Sort Direction"
355
  )
356
+ training_set_filter = gr.Dropdown(
357
+ choices=["All"] + TRAINING_DATASETS,
358
+ value="All",
359
+ label="Filter by Training Set",
360
+ info="Show only models trained on a specific dataset"
361
+ )
362
  with gr.Column(scale=2):
363
  selected_groups = gr.CheckboxGroup(
364
  choices=list(METRIC_GROUPS.keys()),
 
391
  # Update dataframe when options change (using cached data)
392
  show_percentage.change(
393
  fn=update_leaderboard,
394
+ inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction, training_set_filter],
395
  outputs=leaderboard_table
396
  )
397
  selected_groups.change(
398
  fn=update_leaderboard,
399
+ inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction, training_set_filter],
400
  outputs=leaderboard_table
401
  )
402
  compact_view.change(
403
  fn=update_leaderboard,
404
+ inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction, training_set_filter],
405
  outputs=leaderboard_table
406
  )
407
  sort_by.change(
408
  fn=update_leaderboard,
409
+ inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction, training_set_filter],
410
  outputs=leaderboard_table
411
  )
412
  sort_direction.change(
413
  fn=update_leaderboard,
414
+ inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction, training_set_filter],
415
+ outputs=leaderboard_table
416
+ )
417
+ training_set_filter.change(
418
+ fn=update_leaderboard,
419
+ inputs=[show_percentage, selected_groups, compact_view, cached_df_state, sort_by, sort_direction, training_set_filter],
420
  outputs=leaderboard_table
421
  )
422
 
 
425
 
426
  gr.Markdown("""
427
  **Symbol Legend:**
428
+ - 📄 Paper available (click to view)
429
  - ✅ Model output verified
430
  - ⚡ Structures were already relaxed
431
  - ★ Contributes to LeMat-Bulk reference dataset (in-distribution)
432
  - ◆ Out-of-distribution relative to LeMat-Bulk reference dataset
433
 
434
  Verified submissions mean the results came from a model submission rather than a CIF submission.
435
+
436
+ Models marked as baselines appear below the separator line at the bottom of the table.
437
  """)
438
 
439
  with gr.TabItem("✉️ Submit", elem_id="boundary-benchmark-tab-table"):