Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,6 +2,8 @@ import gradio as gr
|
|
| 2 |
import json
|
| 3 |
import numpy as np
|
| 4 |
import pandas as pd
|
|
|
|
|
|
|
| 5 |
import operator
|
| 6 |
|
| 7 |
pd.options.plotting.backend = "plotly"
|
|
@@ -9,6 +11,12 @@ pd.options.plotting.backend = "plotly"
|
|
| 9 |
|
| 10 |
TITLE = "Diffusion Professions Cluster Explorer"
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
clusters_dicts = dict(
|
| 13 |
(num_cl, json.load(open(f"clusters/professions_to_clusters_{num_cl}.json")))
|
| 14 |
for num_cl in [12, 24, 48]
|
|
@@ -142,7 +150,8 @@ def show_examplars(num_clusters, prof_name, mod_name, cl_id):
|
|
| 142 |
examplars_dict = clusters_dicts[num_clusters][df_models[mod_name]][prof_name][
|
| 143 |
"cluster_examplars"
|
| 144 |
][str(cl_id)]
|
| 145 |
-
|
|
|
|
| 146 |
|
| 147 |
|
| 148 |
with gr.Blocks(title=TITLE) as demo:
|
|
@@ -249,7 +258,7 @@ with gr.Blocks(title=TITLE) as demo:
|
|
| 249 |
)
|
| 250 |
with gr.Row():
|
| 251 |
examplars_plot = (
|
| 252 |
-
gr.
|
| 253 |
) # TODO: turn this into a plot with the actual images
|
| 254 |
demo.load(
|
| 255 |
show_examplars,
|
|
|
|
| 2 |
import json
|
| 3 |
import numpy as np
|
| 4 |
import pandas as pd
|
| 5 |
+
from datasets import load_from_disk
|
| 6 |
+
from itertools import chain
|
| 7 |
import operator
|
| 8 |
|
| 9 |
pd.options.plotting.backend = "plotly"
|
|
|
|
| 11 |
|
| 12 |
TITLE = "Diffusion Professions Cluster Explorer"
|
| 13 |
|
| 14 |
+
professions = load_from_disk("professions")
|
| 15 |
+
professions_df = professions.to_pandas()
|
| 16 |
+
|
| 17 |
+
def get_image(model, fname):
|
| 18 |
+
return professions.select(professions_df[(professions_df["image_path"]==fname) & (professions_df["model"]==model)].index)["image"][0]
|
| 19 |
+
|
| 20 |
clusters_dicts = dict(
|
| 21 |
(num_cl, json.load(open(f"clusters/professions_to_clusters_{num_cl}.json")))
|
| 22 |
for num_cl in [12, 24, 48]
|
|
|
|
| 150 |
examplars_dict = clusters_dicts[num_clusters][df_models[mod_name]][prof_name][
|
| 151 |
"cluster_examplars"
|
| 152 |
][str(cl_id)]
|
| 153 |
+
l = list(chain(*[examplars_dict[k] for k in examplars_dict]))
|
| 154 |
+
return [get_image(model,fname) for _,model,fname in l]
|
| 155 |
|
| 156 |
|
| 157 |
with gr.Blocks(title=TITLE) as demo:
|
|
|
|
| 258 |
)
|
| 259 |
with gr.Row():
|
| 260 |
examplars_plot = (
|
| 261 |
+
gr.Gallery()
|
| 262 |
) # TODO: turn this into a plot with the actual images
|
| 263 |
demo.load(
|
| 264 |
show_examplars,
|