Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -68,6 +68,8 @@ def load_leaderboard():
|
|
| 68 |
result.drop('arch_name', axis=1, inplace=True)
|
| 69 |
result.drop('crop_pct', axis=1, inplace=True)
|
| 70 |
result.drop('interpolation', axis=1, inplace=True)
|
|
|
|
|
|
|
| 71 |
|
| 72 |
# Round numerical values
|
| 73 |
result = result.round(2)
|
|
@@ -106,6 +108,7 @@ def filter_leaderboard(df, model_name, sort_by):
|
|
| 106 |
|
| 107 |
return filtered_df
|
| 108 |
|
|
|
|
| 109 |
def create_scatter_plot(df, x_axis, y_axis):
|
| 110 |
fig = px.scatter(
|
| 111 |
df,
|
|
@@ -116,10 +119,15 @@ def create_scatter_plot(df, x_axis, y_axis):
|
|
| 116 |
hover_data=['model'],
|
| 117 |
trendline='ols',
|
| 118 |
trendline_options=dict(log_x=True, log_y=True),
|
|
|
|
|
|
|
| 119 |
title=f'{y_axis} vs {x_axis}'
|
| 120 |
)
|
|
|
|
|
|
|
| 121 |
return fig
|
| 122 |
|
|
|
|
| 123 |
# Load the leaderboard data
|
| 124 |
full_df = load_leaderboard()
|
| 125 |
|
|
@@ -132,14 +140,30 @@ DEFAULT_SORT = "avg_top1"
|
|
| 132 |
DEFAULT_X = "infer_samples_per_sec"
|
| 133 |
DEFAULT_Y = "avg_top1"
|
| 134 |
|
| 135 |
-
def update_leaderboard_and_plot(
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
|
| 145 |
with gr.Blocks(title="The timm Leaderboard") as app:
|
|
@@ -148,8 +172,11 @@ with gr.Blocks(title="The timm Leaderboard") as app:
|
|
| 148 |
gr.HTML("<p>Search tips:<br>- Use wildcards (* or ?) for pattern matching<br>- Use 're:' prefix for regex search<br>- Otherwise, fuzzy matching will be used</p>")
|
| 149 |
|
| 150 |
with gr.Row():
|
| 151 |
-
search_bar = gr.Textbox(lines=1, label="
|
| 152 |
sort_dropdown = gr.Dropdown(choices=sort_columns, label="Sort by", value=DEFAULT_SORT, scale=1)
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
with gr.Row():
|
| 155 |
x_axis = gr.Dropdown(choices=plot_columns, label="X-axis", value=DEFAULT_X)
|
|
@@ -164,13 +191,17 @@ with gr.Blocks(title="The timm Leaderboard") as app:
|
|
| 164 |
|
| 165 |
search_bar.submit(
|
| 166 |
update_leaderboard_and_plot,
|
| 167 |
-
inputs=[search_bar, sort_dropdown, x_axis, y_axis],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
outputs=[leaderboard, plot]
|
| 169 |
)
|
| 170 |
-
|
| 171 |
update_btn.click(
|
| 172 |
update_leaderboard_and_plot,
|
| 173 |
-
inputs=[search_bar, sort_dropdown, x_axis, y_axis],
|
| 174 |
outputs=[leaderboard, plot]
|
| 175 |
)
|
| 176 |
|
|
|
|
| 68 |
result.drop('arch_name', axis=1, inplace=True)
|
| 69 |
result.drop('crop_pct', axis=1, inplace=True)
|
| 70 |
result.drop('interpolation', axis=1, inplace=True)
|
| 71 |
+
|
| 72 |
+
result['highlighted'] = False
|
| 73 |
|
| 74 |
# Round numerical values
|
| 75 |
result = result.round(2)
|
|
|
|
| 108 |
|
| 109 |
return filtered_df
|
| 110 |
|
| 111 |
+
|
| 112 |
def create_scatter_plot(df, x_axis, y_axis):
|
| 113 |
fig = px.scatter(
|
| 114 |
df,
|
|
|
|
| 119 |
hover_data=['model'],
|
| 120 |
trendline='ols',
|
| 121 |
trendline_options=dict(log_x=True, log_y=True),
|
| 122 |
+
color='highlighted',
|
| 123 |
+
color_discrete_map={True: 'red', False: 'blue'},
|
| 124 |
title=f'{y_axis} vs {x_axis}'
|
| 125 |
)
|
| 126 |
+
fig.update_layout(showlegend=False)
|
| 127 |
+
|
| 128 |
return fig
|
| 129 |
|
| 130 |
+
|
| 131 |
# Load the leaderboard data
|
| 132 |
full_df = load_leaderboard()
|
| 133 |
|
|
|
|
| 140 |
DEFAULT_X = "infer_samples_per_sec"
|
| 141 |
DEFAULT_Y = "avg_top1"
|
| 142 |
|
| 143 |
+
def update_leaderboard_and_plot(
|
| 144 |
+
model_name=DEFAULT_SEARCH,
|
| 145 |
+
highlight_name=None,
|
| 146 |
+
sort_by=DEFAULT_SORT,
|
| 147 |
+
x_axis=DEFAULT_X,
|
| 148 |
+
y_axis=DEFAULT_Y,
|
| 149 |
+
):
|
| 150 |
+
filtered_df = filter_leaderboard(full_df, model_name, sort_by)
|
| 151 |
+
|
| 152 |
+
# Apply the highlight filter to the entire dataset so the output will be union (comparison) if the filters are disjoint
|
| 153 |
+
highlight_df = filter_leaderboard(full_df, highlight_name, sort_by) if highlight_name else None
|
| 154 |
+
|
| 155 |
+
# Combine filtered_df and highlight_df, removing duplicates
|
| 156 |
+
if highlight_df is not None:
|
| 157 |
+
combined_df = pd.concat([filtered_df, highlight_df]).drop_duplicates().reset_index(drop=True)
|
| 158 |
+
combined_df = combined_df.sort_values(by=sort_by, ascending=False)
|
| 159 |
+
combined_df['highlighted'] = combined_df['model'].isin(highlight_df['model'])
|
| 160 |
+
else:
|
| 161 |
+
combined_df = filtered_df
|
| 162 |
+
|
| 163 |
+
fig = create_scatter_plot(combined_df, x_axis, y_axis)
|
| 164 |
+
highlighted_df = combined_df.style.apply(lambda x: ['background-color: #ffcccc' if x['highlighted'] else '' for _ in x], axis=1)
|
| 165 |
+
|
| 166 |
+
return highlighted_df, fig
|
| 167 |
|
| 168 |
|
| 169 |
with gr.Blocks(title="The timm Leaderboard") as app:
|
|
|
|
| 172 |
gr.HTML("<p>Search tips:<br>- Use wildcards (* or ?) for pattern matching<br>- Use 're:' prefix for regex search<br>- Otherwise, fuzzy matching will be used</p>")
|
| 173 |
|
| 174 |
with gr.Row():
|
| 175 |
+
search_bar = gr.Textbox(lines=1, label="Model Filter", placeholder="e.g. resnet*, re:^vit, efficientnet", scale=3)
|
| 176 |
sort_dropdown = gr.Dropdown(choices=sort_columns, label="Sort by", value=DEFAULT_SORT, scale=1)
|
| 177 |
+
|
| 178 |
+
with gr.Row():
|
| 179 |
+
highlight_bar = gr.Textbox(lines=1, label="Model Highlight/Compare Filter", placeholder="e.g. convnext*, re:^efficient")
|
| 180 |
|
| 181 |
with gr.Row():
|
| 182 |
x_axis = gr.Dropdown(choices=plot_columns, label="X-axis", value=DEFAULT_X)
|
|
|
|
| 191 |
|
| 192 |
search_bar.submit(
|
| 193 |
update_leaderboard_and_plot,
|
| 194 |
+
inputs=[search_bar, highlight_bar, sort_dropdown, x_axis, y_axis],
|
| 195 |
+
outputs=[leaderboard, plot]
|
| 196 |
+
)
|
| 197 |
+
highlight_bar.submit(
|
| 198 |
+
update_leaderboard_and_plot,
|
| 199 |
+
inputs=[search_bar, highlight_bar, sort_dropdown, x_axis, y_axis],
|
| 200 |
outputs=[leaderboard, plot]
|
| 201 |
)
|
|
|
|
| 202 |
update_btn.click(
|
| 203 |
update_leaderboard_and_plot,
|
| 204 |
+
inputs=[search_bar, highlight_bar, sort_dropdown, x_axis, y_axis],
|
| 205 |
outputs=[leaderboard, plot]
|
| 206 |
)
|
| 207 |
|