Spaces:
Sleeping
Sleeping
Commit
·
ff7de4b
1
Parent(s):
ce4ccf6
Improve Data Visualization
Browse files
app.py
CHANGED
|
@@ -4,7 +4,7 @@ import numpy as np
|
|
| 4 |
from bokeh.plotting import figure
|
| 5 |
from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button, HoverTool, LinearColorMapper, ColorBar, FuncTickFormatter, FixedTicker
|
| 6 |
from bokeh.layouts import column
|
| 7 |
-
from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9, Greys9, BuGn9, Greens9
|
| 8 |
from sklearn.decomposition import PCA
|
| 9 |
from sklearn.manifold import TSNE, trustworthiness
|
| 10 |
from sklearn.metrics import pairwise_distances
|
|
@@ -14,6 +14,8 @@ from sklearn.linear_model import LinearRegression
|
|
| 14 |
from scipy.stats import binned_statistic_2d
|
| 15 |
import json
|
| 16 |
import itertools
|
|
|
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
N_COMPONENTS = 3
|
|
@@ -1041,6 +1043,17 @@ def run_model(model_name):
|
|
| 1041 |
|
| 1042 |
# -------------------------------------------------------------------------
|
| 1043 |
# 4. Cálculo de distancias y scatter plot: Distance vs F1 (usando PC1 y PC2 globales)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1044 |
real_labels_new = sorted(df_all["real"]['label'].unique().tolist())
|
| 1045 |
df_distances_new = compute_cluster_distances_synthetic_individual(
|
| 1046 |
df_all["synthetic"],
|
|
@@ -1144,11 +1157,15 @@ def run_model(model_name):
|
|
| 1144 |
selected_feature = st.selectbox("Select heatmap feature:",
|
| 1145 |
options=feature_options, key=f"heatmap_{model_name}")
|
| 1146 |
select_extra_dataset_hm = st.selectbox("Select a dataset:",
|
| 1147 |
-
options=
|
| 1148 |
|
| 1149 |
# Definir el rango de posiciones (x, y)
|
| 1150 |
-
x_min, x_max = df_heatmap['x'].min(), df_heatmap['x'].max()
|
| 1151 |
-
y_min, y_max = df_heatmap['y'].min(), df_heatmap['y'].max()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1152 |
|
| 1153 |
grid_size = 50
|
| 1154 |
x_bins = np.linspace(x_min, x_max, grid_size + 1)
|
|
@@ -1177,7 +1194,15 @@ def run_model(model_name):
|
|
| 1177 |
# Transponer la matriz para alinear correctamente los ejes
|
| 1178 |
heatmap_data = heat_stat.T
|
| 1179 |
|
| 1180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1181 |
|
| 1182 |
heatmap_fig = figure(title=f"Heatmap de '{selected_feature}'",
|
| 1183 |
x_range=(x_min, x_max), y_range=(y_min, y_max),
|
|
@@ -1223,7 +1248,7 @@ def run_model(model_name):
|
|
| 1223 |
'img': df_extra['img'],
|
| 1224 |
'label': df_extra['name']
|
| 1225 |
})
|
| 1226 |
-
extra_renderer = heatmap_fig.circle('x', 'y', size=
|
| 1227 |
|
| 1228 |
hover_tool_points = HoverTool(renderers=[invisible_renderer], tooltips=TOOLTIPS)
|
| 1229 |
heatmap_fig.add_tools(hover_tool_points)
|
|
|
|
| 4 |
from bokeh.plotting import figure
|
| 5 |
from bokeh.models import ColumnDataSource, DataTable, TableColumn, CustomJS, Select, Button, HoverTool, LinearColorMapper, ColorBar, FuncTickFormatter, FixedTicker
|
| 6 |
from bokeh.layouts import column
|
| 7 |
+
from bokeh.palettes import Reds9, Blues9, Oranges9, Purples9, Greys9, BuGn9, Greens9, RdYlGn11, linear_palette
|
| 8 |
from sklearn.decomposition import PCA
|
| 9 |
from sklearn.manifold import TSNE, trustworthiness
|
| 10 |
from sklearn.metrics import pairwise_distances
|
|
|
|
| 14 |
from scipy.stats import binned_statistic_2d
|
| 15 |
import json
|
| 16 |
import itertools
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
import matplotlib.colors as mcolors
|
| 19 |
|
| 20 |
|
| 21 |
N_COMPONENTS = 3
|
|
|
|
| 1043 |
|
| 1044 |
# -------------------------------------------------------------------------
|
| 1045 |
# 4. Cálculo de distancias y scatter plot: Distance vs F1 (usando PC1 y PC2 globales)
|
| 1046 |
+
model_options = ["es-digital-paragraph-degradation-seq", "es-digital-line-degradation-seq", "es-digital-seq", "es-digital-rotation-degradation-seq", "es-digital-zoom-degradation-seq", "es-render-seq"]
|
| 1047 |
+
model_options_with_default = [""]
|
| 1048 |
+
model_options_with_default.extend(model_options)
|
| 1049 |
+
|
| 1050 |
+
|
| 1051 |
+
# Genera una paleta de 256 colores basada en RdYlGn11
|
| 1052 |
+
cmap = plt.get_cmap("RdYlGn")
|
| 1053 |
+
red_green_palette = [mcolors.rgb2hex(cmap(i)) for i in np.linspace(0, 1, 256)]
|
| 1054 |
+
|
| 1055 |
+
|
| 1056 |
+
|
| 1057 |
real_labels_new = sorted(df_all["real"]['label'].unique().tolist())
|
| 1058 |
df_distances_new = compute_cluster_distances_synthetic_individual(
|
| 1059 |
df_all["synthetic"],
|
|
|
|
| 1157 |
selected_feature = st.selectbox("Select heatmap feature:",
|
| 1158 |
options=feature_options, key=f"heatmap_{model_name}")
|
| 1159 |
select_extra_dataset_hm = st.selectbox("Select a dataset:",
|
| 1160 |
+
options=model_options_with_default, key=f"heatmap_extra_dataset_{model_name}")
|
| 1161 |
|
| 1162 |
# Definir el rango de posiciones (x, y)
|
| 1163 |
+
# x_min, x_max = df_heatmap['x'].min(), df_heatmap['x'].max()
|
| 1164 |
+
# y_min, y_max = df_heatmap['y'].min(), df_heatmap['y'].max()
|
| 1165 |
+
|
| 1166 |
+
x_min, x_max = -4, 4
|
| 1167 |
+
y_min, y_max = -4, 4
|
| 1168 |
+
|
| 1169 |
|
| 1170 |
grid_size = 50
|
| 1171 |
x_bins = np.linspace(x_min, x_max, grid_size + 1)
|
|
|
|
| 1194 |
# Transponer la matriz para alinear correctamente los ejes
|
| 1195 |
heatmap_data = heat_stat.T
|
| 1196 |
|
| 1197 |
+
if selected_feature in model_options:
|
| 1198 |
+
color_mapper = LinearColorMapper(
|
| 1199 |
+
palette=red_green_palette,
|
| 1200 |
+
low=0,
|
| 1201 |
+
high=1,
|
| 1202 |
+
nan_color='rgba(0, 0, 0, 0)'
|
| 1203 |
+
)
|
| 1204 |
+
else:
|
| 1205 |
+
color_mapper = LinearColorMapper(palette="Viridis256", low=np.nanmin(heatmap_data), high=np.nanmax(heatmap_data), nan_color='rgba(0, 0, 0, 0)')
|
| 1206 |
|
| 1207 |
heatmap_fig = figure(title=f"Heatmap de '{selected_feature}'",
|
| 1208 |
x_range=(x_min, x_max), y_range=(y_min, y_max),
|
|
|
|
| 1248 |
'img': df_extra['img'],
|
| 1249 |
'label': df_extra['name']
|
| 1250 |
})
|
| 1251 |
+
extra_renderer = heatmap_fig.circle('x', 'y', size=5, source=source_extra_points, fill_alpha=0, line_alpha=0.5, color="purple")
|
| 1252 |
|
| 1253 |
hover_tool_points = HoverTool(renderers=[invisible_renderer], tooltips=TOOLTIPS)
|
| 1254 |
heatmap_fig.add_tools(hover_tool_points)
|