Spaces:
Sleeping
Sleeping
Commit
·
71eb50f
1
Parent(s):
b79fb2d
Add Dimensions
Browse files
app.py
CHANGED
|
@@ -12,6 +12,10 @@ import io
|
|
| 12 |
import ot
|
| 13 |
from sklearn.linear_model import LinearRegression
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
TOOLTIPS = """
|
| 16 |
<div>
|
| 17 |
<div>
|
|
@@ -37,10 +41,6 @@ def config_style():
|
|
| 37 |
""", unsafe_allow_html=True)
|
| 38 |
st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
|
| 39 |
|
| 40 |
-
# =============================================================================
|
| 41 |
-
# Funciones de carga de datos y procesamiento (sin cambios en su mayoría)
|
| 42 |
-
# =============================================================================
|
| 43 |
-
|
| 44 |
def load_embeddings(model, version):
|
| 45 |
if model == "Donut":
|
| 46 |
df_real = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_secret_all_embeddings.csv")
|
|
@@ -95,8 +95,10 @@ def load_embeddings(model, version):
|
|
| 95 |
return None
|
| 96 |
|
| 97 |
def split_versions(df_combined, reduced):
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
| 100 |
df_real = df_combined[df_combined["version"] == "real"].copy()
|
| 101 |
df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
|
| 102 |
unique_real = sorted(df_real['label'].unique().tolist())
|
|
@@ -107,10 +109,14 @@ def split_versions(df_combined, reduced):
|
|
| 107 |
unique_subsets = {"real": unique_real, "synthetic": unique_synth}
|
| 108 |
return df_dict, unique_subsets
|
| 109 |
|
| 110 |
-
|
| 111 |
-
#
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
def compute_cluster_distance(synthetic_points, real_points, metric="wasserstein", bins=20):
|
| 116 |
if metric.lower() == "wasserstein":
|
|
@@ -125,13 +131,14 @@ def compute_cluster_distance(synthetic_points, real_points, metric="wasserstein"
|
|
| 125 |
center_real = np.mean(real_points, axis=0)
|
| 126 |
return np.linalg.norm(center_syn - center_real)
|
| 127 |
elif metric.lower() == "kl":
|
|
|
|
| 128 |
all_points = np.vstack([synthetic_points, real_points])
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
H_syn, _
|
| 134 |
-
H_real, _
|
| 135 |
eps = 1e-10
|
| 136 |
P = H_syn + eps
|
| 137 |
Q = H_real + eps
|
|
@@ -147,26 +154,22 @@ def compute_cluster_distances_synthetic_individual(synthetic_df: pd.DataFrame, d
|
|
| 147 |
groups = synthetic_df.groupby(['source', 'label'])
|
| 148 |
for (source, label), group in groups:
|
| 149 |
key = f"{label} ({source})"
|
| 150 |
-
data = group
|
| 151 |
distances[key] = {}
|
| 152 |
for real_label in real_labels:
|
| 153 |
-
real_data = df_real[df_real['label'] == real_label]
|
| 154 |
d = compute_cluster_distance(data, real_data, metric=metric, bins=bins)
|
| 155 |
distances[key][real_label] = d
|
| 156 |
for source, group in synthetic_df.groupby('source'):
|
| 157 |
key = f"Global ({source})"
|
| 158 |
-
data = group
|
| 159 |
distances[key] = {}
|
| 160 |
for real_label in real_labels:
|
| 161 |
-
real_data = df_real[df_real['label'] == real_label]
|
| 162 |
d = compute_cluster_distance(data, real_data, metric=metric, bins=bins)
|
| 163 |
distances[key][real_label] = d
|
| 164 |
return pd.DataFrame(distances).T
|
| 165 |
|
| 166 |
-
# =============================================================================
|
| 167 |
-
# Función para calcular continuidad (mide la preservación de la vecindad original en el embedding)
|
| 168 |
-
# =============================================================================
|
| 169 |
-
|
| 170 |
def compute_continuity(X, X_embedded, n_neighbors=5):
|
| 171 |
n = X.shape[0]
|
| 172 |
D_high = pairwise_distances(X, metric='euclidean')
|
|
@@ -187,10 +190,6 @@ def compute_continuity(X, X_embedded, n_neighbors=5):
|
|
| 187 |
continuity_value = 1 - norm * total
|
| 188 |
return continuity_value
|
| 189 |
|
| 190 |
-
# =============================================================================
|
| 191 |
-
# Funciones de visualización (sin cambios)
|
| 192 |
-
# =============================================================================
|
| 193 |
-
|
| 194 |
def create_table(df_distances):
|
| 195 |
df_table = df_distances.copy()
|
| 196 |
df_table.reset_index(inplace=True)
|
|
@@ -214,6 +213,7 @@ def create_table(df_distances):
|
|
| 214 |
return data_table, df_table, source_table
|
| 215 |
|
| 216 |
def create_figure(dfs, unique_subsets, color_maps, model_name):
|
|
|
|
| 217 |
fig = figure(width=600, height=600, tools="wheel_zoom,pan,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS, title="")
|
| 218 |
real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
|
| 219 |
marker="circle", color_mapping=color_maps["real"],
|
|
@@ -350,38 +350,36 @@ def calculate_cluster_centers(df, labels):
|
|
| 350 |
centers = {}
|
| 351 |
for label in labels:
|
| 352 |
subset = df[df['label'] == label]
|
| 353 |
-
if not subset.empty:
|
| 354 |
centers[label] = (subset['x'].mean(), subset['y'].mean())
|
| 355 |
return centers
|
| 356 |
|
| 357 |
-
# =============================================================================
|
| 358 |
-
# Pipeline central: reducción, cálculo de distancias y regresión global.
|
| 359 |
-
# Se agrega el parámetro distance_metric.
|
| 360 |
-
# Además, si se utiliza t-SNE, se calculan trustworthiness y continuity.
|
| 361 |
-
# =============================================================================
|
| 362 |
-
|
| 363 |
def compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method="t-SNE", distance_metric="wasserstein"):
|
| 364 |
if reduction_method == "PCA":
|
| 365 |
-
reducer = PCA(n_components=
|
| 366 |
else:
|
| 367 |
-
reducer = TSNE(n_components=
|
| 368 |
perplexity=tsne_params["perplexity"],
|
| 369 |
learning_rate=tsne_params["learning_rate"])
|
| 370 |
|
| 371 |
reduced = reducer.fit_transform(df_combined[embedding_cols].values)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
|
| 373 |
-
# Para PCA se captura la explained variance ratio
|
| 374 |
explained_variance = None
|
| 375 |
if reduction_method == "PCA":
|
| 376 |
explained_variance = reducer.explained_variance_ratio_
|
| 377 |
|
| 378 |
-
# Si se usa t-SNE, calculamos trustworthiness y continuity
|
| 379 |
trust = None
|
| 380 |
cont = None
|
| 381 |
if reduction_method == "t-SNE":
|
| 382 |
X = df_combined[embedding_cols].values
|
| 383 |
-
trust = trustworthiness(X, reduced, n_neighbors=
|
| 384 |
-
cont = compute_continuity(X, reduced, n_neighbors=
|
| 385 |
|
| 386 |
dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
|
| 387 |
|
|
@@ -453,15 +451,11 @@ def compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, r
|
|
| 453 |
"dfs_reduced": dfs_reduced,
|
| 454 |
"unique_subsets": unique_subsets,
|
| 455 |
"df_distances": df_distances,
|
| 456 |
-
"explained_variance": explained_variance,
|
| 457 |
-
"trustworthiness": trust,
|
| 458 |
-
"continuity": cont
|
| 459 |
}
|
| 460 |
|
| 461 |
-
# =============================================================================
|
| 462 |
-
# Optimización de parámetros para TSNE (se propaga también la métrica de distancia)
|
| 463 |
-
# =============================================================================
|
| 464 |
-
|
| 465 |
def optimize_tsne_params(df_combined, embedding_cols, df_f1, distance_metric):
|
| 466 |
perplexity_range = np.linspace(30, 50, 10)
|
| 467 |
learning_rate_range = np.linspace(200, 1000, 20)
|
|
@@ -490,11 +484,6 @@ def optimize_tsne_params(df_combined, embedding_cols, df_f1, distance_metric):
|
|
| 490 |
progress_text.text("Optimization completed!")
|
| 491 |
return best_params, best_R2
|
| 492 |
|
| 493 |
-
# =============================================================================
|
| 494 |
-
# Función principal run_model: incluye selector de versión, método de reducción, métrica de distancia,
|
| 495 |
-
# y, si se usa t-SNE, muestra trustworthiness y continuity.
|
| 496 |
-
# =============================================================================
|
| 497 |
-
|
| 498 |
def run_model(model_name):
|
| 499 |
version = st.selectbox("Select Model Version:", options=["vanilla", "finetuned_real"], key=f"version_{model_name}")
|
| 500 |
|
|
@@ -556,8 +545,9 @@ def run_model(model_name):
|
|
| 556 |
|
| 557 |
if reduction_method == "PCA" and result["explained_variance"] is not None:
|
| 558 |
st.subheader("Explained Variance Ratio")
|
|
|
|
| 559 |
variance_df = pd.DataFrame({
|
| 560 |
-
"Component":
|
| 561 |
"Explained Variance": result["explained_variance"]
|
| 562 |
})
|
| 563 |
st.table(variance_df)
|
|
@@ -565,6 +555,7 @@ def run_model(model_name):
|
|
| 565 |
st.subheader("t-SNE Quality Metrics")
|
| 566 |
st.write(f"Trustworthiness: {result['trustworthiness']:.4f}")
|
| 567 |
st.write(f"Continuity: {result['continuity']:.4f}")
|
|
|
|
| 568 |
|
| 569 |
data_table, df_table, source_table = create_table(result["df_distances"])
|
| 570 |
real_subset_names = list(df_table.columns[1:])
|
|
@@ -572,53 +563,58 @@ def run_model(model_name):
|
|
| 572 |
reset_button = Button(label="Reset Colors", button_type="primary")
|
| 573 |
line_source = ColumnDataSource(data={'x': [], 'y': []})
|
| 574 |
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
|
| 585 |
-
|
|
|
|
|
|
|
|
|
|
| 586 |
synthetic_centers=synthetic_centers,
|
| 587 |
real_centers=real_centers_js,
|
| 588 |
real_select=real_select),
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
|
| 596 |
-
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 601 |
line_source.data = {'x': [], 'y': []};
|
| 602 |
line_source.change.emit();
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
|
|
|
| 607 |
|
| 608 |
-
|
| 609 |
-
code="""
|
| 610 |
-
line_source.data = {'x': [], 'y': []};
|
| 611 |
-
line_source.change.emit();
|
| 612 |
-
""")
|
| 613 |
-
reset_button.js_on_event("button_click", reset_callback)
|
| 614 |
|
| 615 |
buffer = io.BytesIO()
|
| 616 |
df_table.to_excel(buffer, index=False)
|
| 617 |
buffer.seek(0)
|
| 618 |
|
| 619 |
-
layout = column(fig, result["scatter_fig"], column(real_select, reset_button, data_table))
|
| 620 |
-
st.bokeh_chart(layout, use_container_width=True)
|
| 621 |
-
|
| 622 |
st.download_button(
|
| 623 |
label="Export Table",
|
| 624 |
data=buffer,
|
|
|
|
| 12 |
import ot
|
| 13 |
from sklearn.linear_model import LinearRegression
|
| 14 |
|
| 15 |
+
# Number of components kept by the PCA reduction (t-SNE uses its own hard-coded n_components)
|
| 16 |
+
N_COMPONENTS = 100
|
| 17 |
+
TSNE_NEIGHBOURS = 150
|
| 18 |
+
|
| 19 |
TOOLTIPS = """
|
| 20 |
<div>
|
| 21 |
<div>
|
|
|
|
| 41 |
""", unsafe_allow_html=True)
|
| 42 |
st.markdown('<h1 class="main-title">Merit Embeddings 🎒📃🏆</h1>', unsafe_allow_html=True)
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
def load_embeddings(model, version):
|
| 45 |
if model == "Donut":
|
| 46 |
df_real = pd.read_csv(f"data/donut_{version}_de_Rodrigo_merit_secret_all_embeddings.csv")
|
|
|
|
| 95 |
return None
|
| 96 |
|
| 97 |
def split_versions(df_combined, reduced):
|
| 98 |
+
# Si el embedding es 2D se asignan las columnas x e y para visualización.
|
| 99 |
+
if reduced.shape[1] == 2:
|
| 100 |
+
df_combined['x'] = reduced[:, 0]
|
| 101 |
+
df_combined['y'] = reduced[:, 1]
|
| 102 |
df_real = df_combined[df_combined["version"] == "real"].copy()
|
| 103 |
df_synth = df_combined[df_combined["version"] == "synthetic"].copy()
|
| 104 |
unique_real = sorted(df_real['label'].unique().tolist())
|
|
|
|
| 109 |
unique_subsets = {"real": unique_real, "synthetic": unique_synth}
|
| 110 |
return df_dict, unique_subsets
|
| 111 |
|
| 112 |
+
def get_embedding_from_df(df):
    """Return the reduced embedding of the rows in ``df`` as a 2-D array.

    The 'embedding' column is preferred: each cell holds one row vector of
    the reduced embedding, and the vectors are stacked into a single array
    with one row per DataFrame row. The dimensionality is whatever the
    stored vectors have; it is not fixed by this function.

    When there is no 'embedding' column, the 'x' and 'y' plotting
    coordinates are used instead, giving an array with two columns.

    Args:
        df: DataFrame with either an 'embedding' column of equal-length
            vectors, or both 'x' and 'y' columns.

    Returns:
        A NumPy array with one row per row of ``df``.

    Raises:
        ValueError: if ``df`` has an 'embedding' column but no rows, or if
            it has neither an 'embedding' column nor both 'x' and 'y'.
    """
    if 'embedding' in df.columns:
        # np.stack rejects an empty sequence with an opaque message; fail
        # early with a clear one while keeping the same exception type.
        if df.empty:
            raise ValueError("El DataFrame está vacío; no hay embeddings que apilar.")
        return np.stack(df['embedding'].to_numpy())
    if 'x' in df.columns and 'y' in df.columns:
        return df[['x', 'y']].values
    raise ValueError("No se encontró embedding o coordenadas x,y en el DataFrame.")
|
| 120 |
|
| 121 |
def compute_cluster_distance(synthetic_points, real_points, metric="wasserstein", bins=20):
|
| 122 |
if metric.lower() == "wasserstein":
|
|
|
|
| 131 |
center_real = np.mean(real_points, axis=0)
|
| 132 |
return np.linalg.norm(center_syn - center_real)
|
| 133 |
elif metric.lower() == "kl":
|
| 134 |
+
# Para KL usamos histogramas multidimensionales con límites globales en cada dimensión
|
| 135 |
all_points = np.vstack([synthetic_points, real_points])
|
| 136 |
+
edges = [
|
| 137 |
+
np.linspace(np.min(all_points[:, i]), np.max(all_points[:, i]), bins+1)
|
| 138 |
+
for i in range(all_points.shape[1])
|
| 139 |
+
]
|
| 140 |
+
H_syn, _ = np.histogramdd(synthetic_points, bins=edges)
|
| 141 |
+
H_real, _ = np.histogramdd(real_points, bins=edges)
|
| 142 |
eps = 1e-10
|
| 143 |
P = H_syn + eps
|
| 144 |
Q = H_real + eps
|
|
|
|
| 154 |
groups = synthetic_df.groupby(['source', 'label'])
|
| 155 |
for (source, label), group in groups:
|
| 156 |
key = f"{label} ({source})"
|
| 157 |
+
data = get_embedding_from_df(group)
|
| 158 |
distances[key] = {}
|
| 159 |
for real_label in real_labels:
|
| 160 |
+
real_data = get_embedding_from_df(df_real[df_real['label'] == real_label])
|
| 161 |
d = compute_cluster_distance(data, real_data, metric=metric, bins=bins)
|
| 162 |
distances[key][real_label] = d
|
| 163 |
for source, group in synthetic_df.groupby('source'):
|
| 164 |
key = f"Global ({source})"
|
| 165 |
+
data = get_embedding_from_df(group)
|
| 166 |
distances[key] = {}
|
| 167 |
for real_label in real_labels:
|
| 168 |
+
real_data = get_embedding_from_df(df_real[df_real['label'] == real_label])
|
| 169 |
d = compute_cluster_distance(data, real_data, metric=metric, bins=bins)
|
| 170 |
distances[key][real_label] = d
|
| 171 |
return pd.DataFrame(distances).T
|
| 172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
def compute_continuity(X, X_embedded, n_neighbors=5):
|
| 174 |
n = X.shape[0]
|
| 175 |
D_high = pairwise_distances(X, metric='euclidean')
|
|
|
|
| 190 |
continuity_value = 1 - norm * total
|
| 191 |
return continuity_value
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
def create_table(df_distances):
|
| 194 |
df_table = df_distances.copy()
|
| 195 |
df_table.reset_index(inplace=True)
|
|
|
|
| 213 |
return data_table, df_table, source_table
|
| 214 |
|
| 215 |
def create_figure(dfs, unique_subsets, color_maps, model_name):
|
| 216 |
+
# Se crea solo si el embedding es 2D (ya que se usan 'x' y 'y' para visualizar)
|
| 217 |
fig = figure(width=600, height=600, tools="wheel_zoom,pan,reset,save", active_scroll="wheel_zoom", tooltips=TOOLTIPS, title="")
|
| 218 |
real_renderers = add_dataset_to_fig(fig, dfs["real"], unique_subsets["real"],
|
| 219 |
marker="circle", color_mapping=color_maps["real"],
|
|
|
|
| 350 |
centers = {}
|
| 351 |
for label in labels:
|
| 352 |
subset = df[df['label'] == label]
|
| 353 |
+
if not subset.empty and 'x' in subset.columns and 'y' in subset.columns:
|
| 354 |
centers[label] = (subset['x'].mean(), subset['y'].mean())
|
| 355 |
return centers
|
| 356 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
def compute_global_regression(df_combined, embedding_cols, tsne_params, df_f1, reduction_method="t-SNE", distance_metric="wasserstein"):
|
| 358 |
if reduction_method == "PCA":
|
| 359 |
+
reducer = PCA(n_components=N_COMPONENTS)
|
| 360 |
else:
|
| 361 |
+
reducer = TSNE(n_components=3, random_state=42,
|
| 362 |
perplexity=tsne_params["perplexity"],
|
| 363 |
learning_rate=tsne_params["learning_rate"])
|
| 364 |
|
| 365 |
reduced = reducer.fit_transform(df_combined[embedding_cols].values)
|
| 366 |
+
# Store the full reduced embedding (N_COMPONENTS dimensions for PCA, 3 for t-SNE)
|
| 367 |
+
df_combined['embedding'] = list(reduced)
|
| 368 |
+
# Si el embedding es 2D (por t-SNE o PCA con 2 componentes) asignamos x e y para visualización
|
| 369 |
+
if reduced.shape[1] == 2:
|
| 370 |
+
df_combined['x'] = reduced[:, 0]
|
| 371 |
+
df_combined['y'] = reduced[:, 1]
|
| 372 |
|
|
|
|
| 373 |
explained_variance = None
|
| 374 |
if reduction_method == "PCA":
|
| 375 |
explained_variance = reducer.explained_variance_ratio_
|
| 376 |
|
|
|
|
| 377 |
trust = None
|
| 378 |
cont = None
|
| 379 |
if reduction_method == "t-SNE":
|
| 380 |
X = df_combined[embedding_cols].values
|
| 381 |
+
trust = trustworthiness(X, reduced, n_neighbors=TSNE_NEIGHBOURS)
|
| 382 |
+
cont = compute_continuity(X, reduced, n_neighbors=TSNE_NEIGHBOURS)
|
| 383 |
|
| 384 |
dfs_reduced, unique_subsets = split_versions(df_combined, reduced)
|
| 385 |
|
|
|
|
| 451 |
"dfs_reduced": dfs_reduced,
|
| 452 |
"unique_subsets": unique_subsets,
|
| 453 |
"df_distances": df_distances,
|
| 454 |
+
"explained_variance": explained_variance,
|
| 455 |
+
"trustworthiness": trust,
|
| 456 |
+
"continuity": cont
|
| 457 |
}
|
| 458 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
def optimize_tsne_params(df_combined, embedding_cols, df_f1, distance_metric):
|
| 460 |
perplexity_range = np.linspace(30, 50, 10)
|
| 461 |
learning_rate_range = np.linspace(200, 1000, 20)
|
|
|
|
| 484 |
progress_text.text("Optimization completed!")
|
| 485 |
return best_params, best_R2
|
| 486 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
def run_model(model_name):
|
| 488 |
version = st.selectbox("Select Model Version:", options=["vanilla", "finetuned_real"], key=f"version_{model_name}")
|
| 489 |
|
|
|
|
| 545 |
|
| 546 |
if reduction_method == "PCA" and result["explained_variance"] is not None:
|
| 547 |
st.subheader("Explained Variance Ratio")
|
| 548 |
+
component_names = [f"PC{i+1}" for i in range(len(result["explained_variance"]))]
|
| 549 |
variance_df = pd.DataFrame({
|
| 550 |
+
"Component": component_names,
|
| 551 |
"Explained Variance": result["explained_variance"]
|
| 552 |
})
|
| 553 |
st.table(variance_df)
|
|
|
|
| 555 |
st.subheader("t-SNE Quality Metrics")
|
| 556 |
st.write(f"Trustworthiness: {result['trustworthiness']:.4f}")
|
| 557 |
st.write(f"Continuity: {result['continuity']:.4f}")
|
| 558 |
+
|
| 559 |
|
| 560 |
data_table, df_table, source_table = create_table(result["df_distances"])
|
| 561 |
real_subset_names = list(df_table.columns[1:])
|
|
|
|
| 563 |
reset_button = Button(label="Reset Colors", button_type="primary")
|
| 564 |
line_source = ColumnDataSource(data={'x': [], 'y': []})
|
| 565 |
|
| 566 |
+
# Si el embedding es 2D se crea el scatter plot de embeddings;
|
| 567 |
+
# dado que con PCA ahora usamos 4 dimensiones, este bloque se omite para PCA
|
| 568 |
+
if (reduction_method == "t-SNE" and N_COMPONENTS == 2) or (reduction_method == "PCA" and N_COMPONENTS == 2):
|
| 569 |
+
fig, real_renderers, synthetic_renderers = create_figure(result["dfs_reduced"], result["unique_subsets"], get_color_maps(result["unique_subsets"]), model_name)
|
| 570 |
+
fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
|
| 571 |
+
centers_real = calculate_cluster_centers(result["dfs_reduced"]["real"], result["unique_subsets"]["real"])
|
| 572 |
+
real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
|
| 573 |
+
synthetic_centers = {}
|
| 574 |
+
synth_labels = sorted(result["dfs_reduced"]["synthetic"]['label'].unique().tolist())
|
| 575 |
+
for label in synth_labels:
|
| 576 |
+
subset = result["dfs_reduced"]["synthetic"][result["dfs_reduced"]["synthetic"]['label'] == label]
|
| 577 |
+
if 'x' in subset.columns and 'y' in subset.columns:
|
| 578 |
+
synthetic_centers[label] = [subset['x'].mean(), subset['y'].mean()]
|
| 579 |
+
callback = CustomJS(args=dict(source=source_table, line_source=line_source,
|
| 580 |
synthetic_centers=synthetic_centers,
|
| 581 |
real_centers=real_centers_js,
|
| 582 |
real_select=real_select),
|
| 583 |
+
code="""
|
| 584 |
+
var selected = source.selected.indices;
|
| 585 |
+
if (selected.length > 0) {
|
| 586 |
+
var idx = selected[0];
|
| 587 |
+
var data = source.data;
|
| 588 |
+
var synth_label = data['Synthetic'][idx];
|
| 589 |
+
var real_label = real_select.value;
|
| 590 |
+
var syn_coords = synthetic_centers[synth_label];
|
| 591 |
+
var real_coords = real_centers[real_label];
|
| 592 |
+
line_source.data = {'x': [syn_coords[0], real_coords[0]], 'y': [syn_coords[1], real_coords[1]]};
|
| 593 |
+
line_source.change.emit();
|
| 594 |
+
} else {
|
| 595 |
+
line_source.data = {'x': [], 'y': []};
|
| 596 |
+
line_source.change.emit();
|
| 597 |
+
}
|
| 598 |
+
""")
|
| 599 |
+
source_table.selected.js_on_change('indices', callback)
|
| 600 |
+
real_select.js_on_change('value', callback)
|
| 601 |
+
|
| 602 |
+
reset_callback = CustomJS(args=dict(line_source=line_source),
|
| 603 |
+
code="""
|
| 604 |
line_source.data = {'x': [], 'y': []};
|
| 605 |
line_source.change.emit();
|
| 606 |
+
""")
|
| 607 |
+
reset_button.js_on_event("button_click", reset_callback)
|
| 608 |
+
layout = column(fig, result["scatter_fig"], column(real_select, reset_button, data_table))
|
| 609 |
+
else:
|
| 610 |
+
layout = column(result["scatter_fig"], column(real_select, reset_button, data_table))
|
| 611 |
|
| 612 |
+
st.bokeh_chart(layout, use_container_width=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 613 |
|
| 614 |
buffer = io.BytesIO()
|
| 615 |
df_table.to_excel(buffer, index=False)
|
| 616 |
buffer.seek(0)
|
| 617 |
|
|
|
|
|
|
|
|
|
|
| 618 |
st.download_button(
|
| 619 |
label="Export Table",
|
| 620 |
data=buffer,
|