de-Rodrigo commited on
Commit
8f0b327
2 Parent(s): 7b68f93 bfe8480

Merge branch 'main' of https://huggingface.co/spaces/de-Rodrigo/Embeddings

Browse files
Files changed (1) hide show
  1. app.py +132 -83
app.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
4
  from bokeh.plotting import figure
5
  from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button, HoverTool, LinearColorMapper, ColorBar, FuncTickFormatter, FixedTicker
6
  from bokeh.layouts import column
7
- from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9, Greys9, BuGn9, Greens9
8
  from sklearn.decomposition import PCA
9
  from sklearn.manifold import TSNE, trustworthiness
10
  from sklearn.metrics import pairwise_distances
@@ -14,6 +14,8 @@ from sklearn.linear_model import LinearRegression
14
  from scipy.stats import binned_statistic_2d
15
  import json
16
  import itertools
 
 
17
 
18
 
19
  N_COMPONENTS = 3
@@ -1041,6 +1043,17 @@ def run_model(model_name):
1041
 
1042
  # -------------------------------------------------------------------------
1043
  # 4. C谩lculo de distancias y scatter plot: Distance vs F1 (usando PC1 y PC2 globales)
 
 
 
 
 
 
 
 
 
 
 
1044
  real_labels_new = sorted(df_all["real"]['label'].unique().tolist())
1045
  df_distances_new = compute_cluster_distances_synthetic_individual(
1046
  df_all["synthetic"],
@@ -1132,103 +1145,139 @@ def run_model(model_name):
1132
  if 'img' not in df_all["real"].columns:
1133
  st.error("La columna 'img' no se encuentra en las muestras reales para hacer el merge con heatmaps.csv.")
1134
  else:
1135
- # Crear columna 'name' a partir del nombre final de la URL de la imagen
1136
  df_all["real"]["name"] = df_all["real"]["img"].apply(
1137
  lambda x: x.split("/")[-1].replace(".png", "") if isinstance(x, str) else x
1138
  )
1139
- # Realizar merge de las posiciones reales con el CSV de heatmaps
1140
- df_heatmap = pd.merge(df_all["real"], df_heat, on="name", how="inner")
1141
 
1142
- # Extraer las caracter铆sticas disponibles (excluyendo 'name')
1143
  feature_options = [col for col in df_heat.columns if col != "name"]
1144
  selected_feature = st.selectbox("Select heatmap feature:",
1145
  options=feature_options, key=f"heatmap_{model_name}")
1146
  select_extra_dataset_hm = st.selectbox("Select a dataset:",
1147
- options=["-", "es-digital-line-degradation-seq", "es-digital-seq", "es-digital-rotation-degradation-seq", "es-digital-zoom-degradation-seq", "es-render-seq"], key=f"heatmap_extra_dataset_{model_name}")
1148
-
1149
- # Definir el rango de posiciones (x, y)
1150
- x_min, x_max = df_heatmap['x'].min(), df_heatmap['x'].max()
1151
- y_min, y_max = df_heatmap['y'].min(), df_heatmap['y'].max()
1152
 
 
 
 
1153
  grid_size = 50
1154
  x_bins = np.linspace(x_min, x_max, grid_size + 1)
1155
  y_bins = np.linspace(y_min, y_max, grid_size + 1)
1156
 
1157
- cat_mapping = None
1158
- if df_heatmap[selected_feature].dtype == bool or not pd.api.types.is_numeric_dtype(df_heatmap[selected_feature]):
1159
- cat = df_heatmap[selected_feature].astype('category')
1160
- cat_mapping = list(cat.cat.categories)
1161
- df_heatmap[selected_feature] = cat.cat.codes
1162
-
1163
- try:
1164
- heat_stat, x_edges, y_edges, binnumber = binned_statistic_2d(
1165
- df_heatmap['x'], df_heatmap['y'], df_heatmap[selected_feature],
1166
- statistic='mean', bins=[x_bins, y_bins]
1167
- )
1168
- except TypeError:
1169
- cat = df_heatmap[selected_feature].astype('category')
1170
- cat_mapping = list(cat.cat.categories)
1171
- df_heatmap[selected_feature] = cat.cat.codes
1172
- heat_stat, x_edges, y_edges, binnumber = binned_statistic_2d(
1173
- df_heatmap['x'], df_heatmap['y'], df_heatmap[selected_feature],
1174
- statistic='mean', bins=[x_bins, y_bins]
1175
- )
1176
-
1177
- # Transponer la matriz para alinear correctamente los ejes
1178
- heatmap_data = heat_stat.T
1179
-
1180
- color_mapper = LinearColorMapper(palette="Viridis256", low=np.nanmin(heatmap_data), high=np.nanmax(heatmap_data), nan_color='rgba(0, 0, 0, 0)')
1181
-
1182
- heatmap_fig = figure(title=f"Heatmap de '{selected_feature}'",
1183
- x_range=(x_min, x_max), y_range=(y_min, y_max),
1184
- width=600, height=600,
1185
- tools="pan,wheel_zoom,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS)
1186
- heatmap_fig.image(image=[heatmap_data], x=x_min, y=y_min,
1187
- dw=x_max - x_min, dh=y_max - y_min,
1188
- color_mapper=color_mapper)
1189
-
1190
- color_bar = ColorBar(color_mapper=color_mapper, location=(0, 0))
1191
- if cat_mapping is not None:
1192
- ticks = list(range(len(cat_mapping)))
1193
- color_bar.ticker = FixedTicker(ticks=ticks)
1194
- categories_json = json.dumps(cat_mapping)
1195
- color_bar.formatter = FuncTickFormatter(code=f"""
1196
- var categories = {categories_json};
1197
- var index = Math.round(tick);
1198
- if(index >= 0 && index < categories.length) {{
1199
- return categories[index];
1200
- }} else {{
1201
- return "";
1202
- }}
1203
- """)
1204
- heatmap_fig.add_layout(color_bar, 'right')
1205
-
1206
- source_points = ColumnDataSource(data={
1207
- 'x': df_heatmap['x'],
1208
- 'y': df_heatmap['y'],
1209
- 'img': df_heatmap['img'],
1210
- 'label': df_heatmap['name']
1211
- })
1212
- invisible_renderer = heatmap_fig.circle('x', 'y', size=10, source=source_points, fill_alpha=0, line_alpha=0.5)
1213
-
1214
- if select_extra_dataset_hm != "-":
1215
- df_extra = df_all["synthetic"][df_all["synthetic"]["source"] == select_extra_dataset_hm]
1216
- if 'name' not in df_extra.columns:
1217
- df_extra["name"] = df_extra["img"].apply(
1218
- lambda x: x.split("/")[-1].replace(".png", "") if isinstance(x, str) else x
1219
  )
1220
- source_extra_points = ColumnDataSource(data={
1221
- 'x': df_extra['x'],
1222
- 'y': df_extra['y'],
1223
- 'img': df_extra['img'],
1224
- 'label': df_extra['name']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1225
  })
1226
- extra_renderer = heatmap_fig.circle('x', 'y', size=10, source=source_extra_points, fill_alpha=0, line_alpha=0.5, color="red")
1227
-
1228
- hover_tool_points = HoverTool(renderers=[invisible_renderer], tooltips=TOOLTIPS)
1229
- heatmap_fig.add_tools(hover_tool_points)
1230
-
1231
- st.bokeh_chart(heatmap_fig)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1232
 
1233
  def main():
1234
  config_style()
 
4
  from bokeh.plotting import figure
5
  from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button, HoverTool, LinearColorMapper, ColorBar, FuncTickFormatter, FixedTicker
6
  from bokeh.layouts import column
7
+ from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9, Greys9, BuGn9, Greens9, RdYlGn11, linear_palette
8
  from sklearn.decomposition import PCA
9
  from sklearn.manifold import TSNE, trustworthiness
10
  from sklearn.metrics import pairwise_distances
 
14
  from scipy.stats import binned_statistic_2d
15
  import json
16
  import itertools
17
+ import matplotlib.pyplot as plt
18
+ import matplotlib.colors as mcolors
19
 
20
 
21
  N_COMPONENTS = 3
 
1043
 
1044
  # -------------------------------------------------------------------------
1045
  # 4. C谩lculo de distancias y scatter plot: Distance vs F1 (usando PC1 y PC2 globales)
1046
+ model_options = ["es-digital-paragraph-degradation-seq", "es-digital-line-degradation-seq", "es-digital-seq", "es-digital-rotation-degradation-seq", "es-digital-zoom-degradation-seq", "es-render-seq"]
1047
+ model_options_with_default = [""]
1048
+ model_options_with_default.extend(model_options)
1049
+
1050
+
1051
+ # Genera una paleta de 256 colores basada en RdYlGn11
1052
+ cmap = plt.get_cmap("RdYlGn")
1053
+ red_green_palette = [mcolors.rgb2hex(cmap(i)) for i in np.linspace(0, 1, 256)]
1054
+
1055
+
1056
+
1057
  real_labels_new = sorted(df_all["real"]['label'].unique().tolist())
1058
  df_distances_new = compute_cluster_distances_synthetic_individual(
1059
  df_all["synthetic"],
 
1145
  if 'img' not in df_all["real"].columns:
1146
  st.error("La columna 'img' no se encuentra en las muestras reales para hacer el merge con heatmaps.csv.")
1147
  else:
1148
+ # Crear columna 'name' en las muestras reales (si a煤n no existe)
1149
  df_all["real"]["name"] = df_all["real"]["img"].apply(
1150
  lambda x: x.split("/")[-1].replace(".png", "") if isinstance(x, str) else x
1151
  )
1152
+ # Merge de las posiciones reales con el CSV de heatmaps (se usa el merge base)
1153
+ df_heatmap_base = pd.merge(df_all["real"], df_heat, on="name", how="inner")
1154
 
1155
+ # Extraer opciones de feature (excluyendo 'name')
1156
  feature_options = [col for col in df_heat.columns if col != "name"]
1157
  selected_feature = st.selectbox("Select heatmap feature:",
1158
  options=feature_options, key=f"heatmap_{model_name}")
1159
  select_extra_dataset_hm = st.selectbox("Select a dataset:",
1160
+ options=model_options_with_default, key=f"heatmap_extra_dataset_{model_name}")
 
 
 
 
1161
 
1162
+ # Definir un rango fijo para los ejes (por ejemplo, de -4 a 4) y rejilla
1163
+ x_min, x_max = -4, 4
1164
+ y_min, y_max = -4, 4
1165
  grid_size = 50
1166
  x_bins = np.linspace(x_min, x_max, grid_size + 1)
1167
  y_bins = np.linspace(y_min, y_max, grid_size + 1)
1168
 
1169
+ # Generar heatmaps para cada combinaci贸n de componentes
1170
+ pairs = list(itertools.combinations(range(N_COMPONENTS), 2))
1171
+ for (i, j) in pairs:
1172
+ x_comp = f'PC{i+1}'
1173
+ y_comp = f'PC{j+1}'
1174
+ st.markdown(f"### Heatmap: {x_comp} vs {y_comp}")
1175
+
1176
+ # Crear un DataFrame de heatmap para la combinaci贸n actual a partir del merge base
1177
+ df_heatmap = df_heatmap_base.copy()
1178
+ df_heatmap["x"] = df_heatmap[x_comp]
1179
+ df_heatmap["y"] = df_heatmap[y_comp]
1180
+
1181
+ # Si la feature seleccionada no es num茅rica, convertir a c贸digos y guardar la correspondencia
1182
+ cat_mapping = None
1183
+ if df_heatmap[selected_feature].dtype == bool or not pd.api.types.is_numeric_dtype(df_heatmap[selected_feature]):
1184
+ cat = df_heatmap[selected_feature].astype('category')
1185
+ cat_mapping = list(cat.cat.categories)
1186
+ df_heatmap[selected_feature] = cat.cat.codes
1187
+
1188
+ # Calcular la estad铆stica binned (por ejemplo, la media) en la rejilla
1189
+ try:
1190
+ heat_stat, x_edges, y_edges, binnumber = binned_statistic_2d(
1191
+ df_heatmap['x'], df_heatmap['y'], df_heatmap[selected_feature],
1192
+ statistic='mean', bins=[x_bins, y_bins]
1193
+ )
1194
+ except TypeError:
1195
+ cat = df_heatmap[selected_feature].astype('category')
1196
+ cat_mapping = list(cat.cat.categories)
1197
+ df_heatmap[selected_feature] = cat.cat.codes
1198
+ heat_stat, x_edges, y_edges, binnumber = binned_statistic_2d(
1199
+ df_heatmap['x'], df_heatmap['y'], df_heatmap[selected_feature],
1200
+ statistic='mean', bins=[x_bins, y_bins]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1201
  )
1202
+ # Transponer la matriz para alinear correctamente los ejes
1203
+ heatmap_data = heat_stat.T
1204
+
1205
+ # Definir el color mapper: si la feature es de un modelo (model_options) usar la paleta rojo-verde con rango 0 a 1
1206
+ if selected_feature in model_options:
1207
+ color_mapper = LinearColorMapper(
1208
+ palette=red_green_palette,
1209
+ low=0,
1210
+ high=1,
1211
+ nan_color='rgba(0, 0, 0, 0)'
1212
+ )
1213
+ else:
1214
+ color_mapper = LinearColorMapper(
1215
+ palette="Viridis256",
1216
+ low=np.nanmin(heatmap_data),
1217
+ high=np.nanmax(heatmap_data),
1218
+ nan_color='rgba(0, 0, 0, 0)'
1219
+ )
1220
+
1221
+ # Crear la figura para el heatmap con la misma escala para x e y
1222
+ heatmap_fig = figure(title=f"Heatmap de '{selected_feature}' ({x_comp} vs {y_comp})",
1223
+ x_range=(x_min, x_max), y_range=(y_min, y_max),
1224
+ width=600, height=600,
1225
+ tools="pan,wheel_zoom,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS,
1226
+ sizing_mode="fixed")
1227
+ heatmap_fig.match_aspect = True
1228
+
1229
+ heatmap_fig.xaxis.axis_label = x_comp
1230
+ heatmap_fig.yaxis.axis_label = y_comp
1231
+ # Dibujar la imagen del heatmap
1232
+ heatmap_fig.image(image=[heatmap_data], x=x_min, y=y_min,
1233
+ dw=x_max - x_min, dh=y_max - y_min,
1234
+ color_mapper=color_mapper)
1235
+
1236
+ # Agregar la barra de color
1237
+ color_bar = ColorBar(color_mapper=color_mapper, location=(0, 0))
1238
+ if cat_mapping is not None:
1239
+ ticks = list(range(len(cat_mapping)))
1240
+ color_bar.ticker = FixedTicker(ticks=ticks)
1241
+ categories_json = json.dumps(cat_mapping)
1242
+ color_bar.formatter = FuncTickFormatter(code=f"""
1243
+ var categories = {categories_json};
1244
+ var index = Math.round(tick);
1245
+ if(index >= 0 && index < categories.length) {{
1246
+ return categories[index];
1247
+ }} else {{
1248
+ return "";
1249
+ }}
1250
+ """)
1251
+ heatmap_fig.add_layout(color_bar, 'right')
1252
+
1253
+ # Agregar renderer invisible para tooltips (usando puntos en cada bin)
1254
+ source_points = ColumnDataSource(data={
1255
+ 'x': df_heatmap['x'],
1256
+ 'y': df_heatmap['y'],
1257
+ 'img': df_heatmap['img'],
1258
+ 'label': df_heatmap['name']
1259
  })
1260
+ invisible_renderer = heatmap_fig.circle('x', 'y', size=10, source=source_points, fill_alpha=0, line_alpha=0.5)
1261
+
1262
+ # Si se selecciona un dataset extra, proyectar sus puntos en la combinaci贸n actual
1263
+ if select_extra_dataset_hm != "-":
1264
+ df_extra = df_all["synthetic"][df_all["synthetic"]["source"] == select_extra_dataset_hm].copy()
1265
+ df_extra["x"] = df_extra[x_comp]
1266
+ df_extra["y"] = df_extra[y_comp]
1267
+ if 'name' not in df_extra.columns:
1268
+ df_extra["name"] = df_extra["img"].apply(lambda x: x.split("/")[-1].replace(".png", "") if isinstance(x, str) else x)
1269
+ source_extra_points = ColumnDataSource(data={
1270
+ 'x': df_extra['x'],
1271
+ 'y': df_extra['y'],
1272
+ 'img': df_extra['img'],
1273
+ 'label': df_extra['name']
1274
+ })
1275
+ extra_renderer = heatmap_fig.circle('x', 'y', size=5, source=source_extra_points, fill_alpha=0, line_alpha=0.5, color="purple")
1276
+
1277
+ hover_tool_points = HoverTool(renderers=[invisible_renderer], tooltips=TOOLTIPS)
1278
+ heatmap_fig.add_tools(hover_tool_points)
1279
+
1280
+ st.bokeh_chart(heatmap_fig)
1281
 
1282
  def main():
1283
  config_style()