Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -109,7 +109,9 @@ def filter_leaderboard(df, model_name, sort_by):
|
|
| 109 |
return filtered_df
|
| 110 |
|
| 111 |
|
| 112 |
-
def create_scatter_plot(df, x_axis, y_axis):
|
|
|
|
|
|
|
| 113 |
fig = px.scatter(
|
| 114 |
df,
|
| 115 |
x=x_axis,
|
|
@@ -120,10 +122,27 @@ def create_scatter_plot(df, x_axis, y_axis):
|
|
| 120 |
trendline='ols',
|
| 121 |
trendline_options=dict(log_x=True, log_y=True),
|
| 122 |
color='highlighted',
|
| 123 |
-
color_discrete_map={True:
|
| 124 |
title=f'{y_axis} vs {x_axis}'
|
| 125 |
)
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
return fig
|
| 129 |
|
|
@@ -160,7 +179,7 @@ def update_leaderboard_and_plot(
|
|
| 160 |
else:
|
| 161 |
combined_df = filtered_df
|
| 162 |
|
| 163 |
-
fig = create_scatter_plot(combined_df, x_axis, y_axis)
|
| 164 |
display_df = combined_df.drop(columns=['highlighted'])
|
| 165 |
display_df = display_df.style.apply(lambda x: ['background-color: #FFA500' if combined_df.loc[x.name, 'highlighted'] else '' for _ in x], axis=1).format(precision=2)
|
| 166 |
return display_df, fig
|
|
|
|
| 109 |
return filtered_df
|
| 110 |
|
| 111 |
|
| 112 |
+
def create_scatter_plot(df, x_axis, y_axis, model_filter, highlight_filter):
|
| 113 |
+
selected_color = 'orange'
|
| 114 |
+
|
| 115 |
fig = px.scatter(
|
| 116 |
df,
|
| 117 |
x=x_axis,
|
|
|
|
| 122 |
trendline='ols',
|
| 123 |
trendline_options=dict(log_x=True, log_y=True),
|
| 124 |
color='highlighted',
|
| 125 |
+
color_discrete_map={True: selected_color, False: 'blue'},
|
| 126 |
title=f'{y_axis} vs {x_axis}'
|
| 127 |
)
|
| 128 |
+
|
| 129 |
+
# Create legend labels
|
| 130 |
+
legend_labels = {}
|
| 131 |
+
if highlight_filter:
|
| 132 |
+
legend_labels[True] = f'{highlight_filter}'
|
| 133 |
+
legend_labels[False] = f'{model_filter or "all models"}'
|
| 134 |
+
else:
|
| 135 |
+
legend_labels[False] = f'{model_filter or "all models"}'
|
| 136 |
+
|
| 137 |
+
# Update legend
|
| 138 |
+
for trace in fig.data:
|
| 139 |
+
if isinstance(trace.marker.color, str): # This is for the scatter traces
|
| 140 |
+
trace.name = legend_labels.get(trace.marker.color == selected_color, '')
|
| 141 |
+
|
| 142 |
+
fig.update_layout(
|
| 143 |
+
showlegend=True,
|
| 144 |
+
legend_title_text='Model Selection'
|
| 145 |
+
)
|
| 146 |
|
| 147 |
return fig
|
| 148 |
|
|
|
|
| 179 |
else:
|
| 180 |
combined_df = filtered_df
|
| 181 |
|
| 182 |
+
fig = create_scatter_plot(combined_df, x_axis, y_axis, model_name, highlight_name)
|
| 183 |
display_df = combined_df.drop(columns=['highlighted'])
|
| 184 |
display_df = display_df.style.apply(lambda x: ['background-color: #FFA500' if combined_df.loc[x.name, 'highlighted'] else '' for _ in x], axis=1).format(precision=2)
|
| 185 |
return display_df, fig
|