TemryL commited on
Commit
b34f2b8
·
1 Parent(s): 9f320e8

filter nb_shots

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -70,13 +70,14 @@ def update_table(
70
  columns: list,
71
  phenotypes: list,
72
  metrics: list,
 
73
  type_query: list,
74
  precision_query: str,
75
  size_query: list,
76
  show_deleted: bool,
77
  query: str,
78
  ):
79
- filtered_df = filter_models(hidden_df, type_query, size_query, precision_query, show_deleted)
80
  filtered_df = filter_queries(query, filtered_df)
81
  df = select_columns(filtered_df, columns, phenotypes, metrics)
82
  return df
@@ -124,8 +125,7 @@ def filter_queries(query: str, filtered_df: pd.DataFrame) -> pd.DataFrame:
124
 
125
 
126
  def filter_models(
127
- df: pd.DataFrame, type_query: list, size_query: list, precision_query: list, show_deleted: bool
128
- ) -> pd.DataFrame:
129
  # Show all models
130
  if show_deleted:
131
  filtered_df = df
@@ -135,7 +135,9 @@ def filter_models(
135
  type_emoji = [t[0] for t in type_query]
136
  filtered_df = filtered_df.loc[df[AutoEvalColumn.model_type_symbol.name].isin(type_emoji)]
137
  filtered_df = filtered_df.loc[df[AutoEvalColumn.precision.name].isin(precision_query + ["None"])]
138
-
 
 
139
  numeric_interval = pd.IntervalIndex(sorted([NUMERIC_INTERVALS[s] for s in size_query]))
140
  params_column = pd.to_numeric(df[AutoEvalColumn.params.name], errors="coerce")
141
  mask = params_column.apply(lambda x: any(numeric_interval.contains(x)))
@@ -201,8 +203,8 @@ with demo:
201
  with gr.Column(min_width=320):
202
  filter_nb_shots = gr.CheckboxGroup(
203
  label="Number of shots",
204
- choices=["Zero-shot", "10-shot", "All"],
205
- value=["Zero-shot", "10-shot", "All"],
206
  interactive=True,
207
  elem_id="filter-nb-shots",
208
  )
@@ -272,7 +274,7 @@ with demo:
272
  ],
273
  leaderboard_table,
274
  )
275
- for selector in [shown_phenotypes, shown_metrics, shown_columns, filter_columns_type, filter_columns_precision, filter_columns_size, deleted_models_visibility]:
276
  selector.change(
277
  update_table,
278
  [
@@ -280,6 +282,7 @@ with demo:
280
  shown_columns,
281
  shown_phenotypes,
282
  shown_metrics,
 
283
  filter_columns_type,
284
  filter_columns_precision,
285
  filter_columns_size,
 
70
  columns: list,
71
  phenotypes: list,
72
  metrics: list,
73
+ nb_shots: list,
74
  type_query: list,
75
  precision_query: str,
76
  size_query: list,
77
  show_deleted: bool,
78
  query: str,
79
  ):
80
+ filtered_df = filter_models(hidden_df, type_query, size_query, precision_query, show_deleted, nb_shots)
81
  filtered_df = filter_queries(query, filtered_df)
82
  df = select_columns(filtered_df, columns, phenotypes, metrics)
83
  return df
 
125
 
126
 
127
  def filter_models(
128
+ df: pd.DataFrame, type_query: list, size_query: list, precision_query: list, show_deleted: bool, nb_shots: list) -> pd.DataFrame:
 
129
  # Show all models
130
  if show_deleted:
131
  filtered_df = df
 
135
  type_emoji = [t[0] for t in type_query]
136
  filtered_df = filtered_df.loc[df[AutoEvalColumn.model_type_symbol.name].isin(type_emoji)]
137
  filtered_df = filtered_df.loc[df[AutoEvalColumn.precision.name].isin(precision_query + ["None"])]
138
+ if -1 not in nb_shots:
139
+ filtered_df = filtered_df.loc[df[AutoEvalColumn.nb_shots.name].isin(nb_shots)]
140
+
141
  numeric_interval = pd.IntervalIndex(sorted([NUMERIC_INTERVALS[s] for s in size_query]))
142
  params_column = pd.to_numeric(df[AutoEvalColumn.params.name], errors="coerce")
143
  mask = params_column.apply(lambda x: any(numeric_interval.contains(x)))
 
203
  with gr.Column(min_width=320):
204
  filter_nb_shots = gr.CheckboxGroup(
205
  label="Number of shots",
206
+ choices=[("Zero-shot", 0), ("10-shot", 10), ("All", -1)],
207
+ value=[-1],
208
  interactive=True,
209
  elem_id="filter-nb-shots",
210
  )
 
274
  ],
275
  leaderboard_table,
276
  )
277
+ for selector in [shown_phenotypes, shown_metrics, shown_columns, filter_columns_type, filter_columns_precision, filter_columns_size, deleted_models_visibility, filter_nb_shots]:
278
  selector.change(
279
  update_table,
280
  [
 
282
  shown_columns,
283
  shown_phenotypes,
284
  shown_metrics,
285
+ filter_nb_shots,
286
  filter_columns_type,
287
  filter_columns_precision,
288
  filter_columns_size,