Commit
·
c34e772
1
Parent(s):
8157f53
feat: Add update colours button
Browse files
app.py
CHANGED
|
@@ -232,40 +232,30 @@ DATASETS = [
|
|
| 232 |
]
|
| 233 |
|
| 234 |
|
| 235 |
-
def
|
| 236 |
-
"""
|
| 237 |
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
|
| 242 |
-
|
| 243 |
-
[language.name for language in ALL_LANGUAGES.values()],
|
| 244 |
-
key=lambda language_name: language_name.lower(),
|
| 245 |
-
)
|
| 246 |
-
danish_models = sorted(
|
| 247 |
-
list({model_id for model_id in results_dfs[DANISH].index}),
|
| 248 |
-
key=lambda model_id: model_id.lower(),
|
| 249 |
-
)
|
| 250 |
|
| 251 |
# Get distinct RGB values for all models
|
| 252 |
all_models = list(
|
| 253 |
{model_id for df in results_dfs.values() for model_id in df.index}
|
| 254 |
)
|
| 255 |
-
colour_mapping
|
| 256 |
|
| 257 |
for i in it.count():
|
| 258 |
min_colour_distance = MIN_COLOUR_DISTANCE_BETWEEN_MODELS - i
|
| 259 |
-
|
| 260 |
-
if i > 0:
|
| 261 |
-
logger.info(
|
| 262 |
-
f"All retries failed. Trying again with min colour distance "
|
| 263 |
-
f"{min_colour_distance}."
|
| 264 |
-
)
|
| 265 |
-
|
| 266 |
retries_left = 10 * len(all_models)
|
| 267 |
for model_id in all_models:
|
| 268 |
-
random.seed(hash(model_id) + i)
|
| 269 |
r, g, b = 0, 0, 0
|
| 270 |
too_bright, similar_to_other_model = True, True
|
| 271 |
while (too_bright or similar_to_other_model) and retries_left > 0:
|
|
@@ -287,6 +277,28 @@ def main() -> None:
|
|
| 287 |
)
|
| 288 |
break
|
| 289 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
|
| 291 |
gr.Markdown(INTRO_MARKDOWN)
|
| 292 |
|
|
@@ -340,6 +352,11 @@ def main() -> None:
|
|
| 340 |
interactive=True,
|
| 341 |
scale=1,
|
| 342 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 343 |
with gr.Row():
|
| 344 |
plot = gr.Plot(
|
| 345 |
value=produce_radial_plot(
|
|
@@ -349,7 +366,6 @@ def main() -> None:
|
|
| 349 |
show_scale=show_scale_checkbox.value,
|
| 350 |
plot_width=plot_width_slider.value,
|
| 351 |
plot_height=plot_height_slider.value,
|
| 352 |
-
colour_mapping=colour_mapping,
|
| 353 |
results_dfs=results_dfs,
|
| 354 |
),
|
| 355 |
)
|
|
@@ -371,7 +387,6 @@ def main() -> None:
|
|
| 371 |
update_plot_kwargs = dict(
|
| 372 |
fn=partial(
|
| 373 |
produce_radial_plot,
|
| 374 |
-
colour_mapping=colour_mapping,
|
| 375 |
results_dfs=results_dfs,
|
| 376 |
),
|
| 377 |
inputs=[
|
|
@@ -391,6 +406,11 @@ def main() -> None:
|
|
| 391 |
plot_width_slider.change(**update_plot_kwargs)
|
| 392 |
plot_height_slider.change(**update_plot_kwargs)
|
| 393 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
demo.launch()
|
| 395 |
|
| 396 |
|
|
@@ -483,7 +503,6 @@ def produce_radial_plot(
|
|
| 483 |
show_scale: bool,
|
| 484 |
plot_width: int,
|
| 485 |
plot_height: int,
|
| 486 |
-
colour_mapping: dict[str, tuple[int, int, int]],
|
| 487 |
results_dfs: dict[Language, pd.DataFrame] | None,
|
| 488 |
) -> go.Figure:
|
| 489 |
"""Produce a radial plot as a plotly figure.
|
|
@@ -501,8 +520,6 @@ def produce_radial_plot(
|
|
| 501 |
The width of the plot.
|
| 502 |
plot_height:
|
| 503 |
The height of the plot.
|
| 504 |
-
colour_mapping:
|
| 505 |
-
A mapping from model ids to RGB triplets.
|
| 506 |
results_dfs:
|
| 507 |
The results dataframes for each language.
|
| 508 |
|
|
|
|
| 232 |
]
|
| 233 |
|
| 234 |
|
| 235 |
+
def update_colour_mapping(results_dfs: dict[Language, pd.DataFrame]) -> None:
|
| 236 |
+
"""Get a mapping from model ids to RGB triplets.
|
| 237 |
|
| 238 |
+
Args:
|
| 239 |
+
results_dfs:
|
| 240 |
+
The results dataframes for each language.
|
| 241 |
+
"""
|
| 242 |
+
global colour_mapping
|
| 243 |
+
global seed
|
| 244 |
+
seed += 1
|
| 245 |
|
| 246 |
+
gr.Info(f"Updating colour mapping...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
|
| 248 |
# Get distinct RGB values for all models
|
| 249 |
all_models = list(
|
| 250 |
{model_id for df in results_dfs.values() for model_id in df.index}
|
| 251 |
)
|
| 252 |
+
colour_mapping = dict()
|
| 253 |
|
| 254 |
for i in it.count():
|
| 255 |
min_colour_distance = MIN_COLOUR_DISTANCE_BETWEEN_MODELS - i
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
retries_left = 10 * len(all_models)
|
| 257 |
for model_id in all_models:
|
| 258 |
+
random.seed(hash(model_id) + i + seed)
|
| 259 |
r, g, b = 0, 0, 0
|
| 260 |
too_bright, similar_to_other_model = True, True
|
| 261 |
while (too_bright or similar_to_other_model) and retries_left > 0:
|
|
|
|
| 277 |
)
|
| 278 |
break
|
| 279 |
|
| 280 |
+
|
| 281 |
+
def main() -> None:
|
| 282 |
+
"""Produce a radial plot."""
|
| 283 |
+
|
| 284 |
+
global last_fetch
|
| 285 |
+
results_dfs = fetch_results()
|
| 286 |
+
last_fetch = dt.datetime.now()
|
| 287 |
+
|
| 288 |
+
all_languages = sorted(
|
| 289 |
+
[language.name for language in ALL_LANGUAGES.values()],
|
| 290 |
+
key=lambda language_name: language_name.lower(),
|
| 291 |
+
)
|
| 292 |
+
danish_models = sorted(
|
| 293 |
+
list({model_id for model_id in results_dfs[DANISH].index}),
|
| 294 |
+
key=lambda model_id: model_id.lower(),
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
global colour_mapping
|
| 298 |
+
global seed
|
| 299 |
+
seed = 4242
|
| 300 |
+
update_colour_mapping(results_dfs=results_dfs)
|
| 301 |
+
|
| 302 |
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
|
| 303 |
gr.Markdown(INTRO_MARKDOWN)
|
| 304 |
|
|
|
|
| 352 |
interactive=True,
|
| 353 |
scale=1,
|
| 354 |
)
|
| 355 |
+
update_colours_button = gr.Button(
|
| 356 |
+
value="Update colours",
|
| 357 |
+
interactive=True,
|
| 358 |
+
scale=1,
|
| 359 |
+
)
|
| 360 |
with gr.Row():
|
| 361 |
plot = gr.Plot(
|
| 362 |
value=produce_radial_plot(
|
|
|
|
| 366 |
show_scale=show_scale_checkbox.value,
|
| 367 |
plot_width=plot_width_slider.value,
|
| 368 |
plot_height=plot_height_slider.value,
|
|
|
|
| 369 |
results_dfs=results_dfs,
|
| 370 |
),
|
| 371 |
)
|
|
|
|
| 387 |
update_plot_kwargs = dict(
|
| 388 |
fn=partial(
|
| 389 |
produce_radial_plot,
|
|
|
|
| 390 |
results_dfs=results_dfs,
|
| 391 |
),
|
| 392 |
inputs=[
|
|
|
|
| 406 |
plot_width_slider.change(**update_plot_kwargs)
|
| 407 |
plot_height_slider.change(**update_plot_kwargs)
|
| 408 |
|
| 409 |
+
# Update colours when the button is clicked
|
| 410 |
+
update_colours_button.click(
|
| 411 |
+
fn=partial(update_colour_mapping, results_dfs=results_dfs),
|
| 412 |
+
).then(**update_plot_kwargs)
|
| 413 |
+
|
| 414 |
demo.launch()
|
| 415 |
|
| 416 |
|
|
|
|
| 503 |
show_scale: bool,
|
| 504 |
plot_width: int,
|
| 505 |
plot_height: int,
|
|
|
|
| 506 |
results_dfs: dict[Language, pd.DataFrame] | None,
|
| 507 |
) -> go.Figure:
|
| 508 |
"""Produce a radial plot as a plotly figure.
|
|
|
|
| 520 |
The width of the plot.
|
| 521 |
plot_height:
|
| 522 |
The height of the plot.
|
|
|
|
|
|
|
| 523 |
results_dfs:
|
| 524 |
The results dataframes for each language.
|
| 525 |
|