de-Rodrigo commited on
Commit
dc8144e
·
1 Parent(s): 8b6c8b4

T-nse Test

Browse files
Files changed (1) hide show
  1. app.py +108 -48
app.py CHANGED
@@ -753,56 +753,56 @@ def run_model(model_name):
753
  reset_button = Button(label="Reset Colors", button_type="primary")
754
  line_source = ColumnDataSource(data={'x': [], 'y': []})
755
 
756
- if (reduction_method == "t-SNE" and N_COMPONENTS == 2) or (reduction_method == "PCA" and N_COMPONENTS == 2):
757
- fig, real_renderers, synthetic_renderers, pretrained_renderers = create_figure(
758
- result["dfs_reduced"],
759
- result["unique_subsets"],
760
- get_color_maps(result["unique_subsets"]),
761
- model_name
762
- )
763
- fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
764
- centers_real = calculate_cluster_centers(result["dfs_reduced"]["real"], result["unique_subsets"]["real"])
765
- real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
766
- synthetic_centers = {}
767
- synth_labels = sorted(result["dfs_reduced"]["synthetic"]['label'].unique().tolist())
768
- for label in synth_labels:
769
- subset = result["dfs_reduced"]["synthetic"][result["dfs_reduced"]["synthetic"]['label'] == label]
770
- if 'x' in subset.columns and 'y' in subset.columns:
771
- synthetic_centers[label] = [subset['x'].mean(), subset['y'].mean()]
772
- callback = CustomJS(args=dict(source=source_table, line_source=line_source,
773
- synthetic_centers=synthetic_centers,
774
- real_centers=real_centers_js,
775
- real_select=real_select),
776
- code="""
777
- var selected = source.selected.indices;
778
- if (selected.length > 0) {
779
- var idx = selected[0];
780
- var data = source.data;
781
- var synth_label = data['Synthetic'][idx];
782
- var real_label = real_select.value;
783
- var syn_coords = synthetic_centers[synth_label];
784
- var real_coords = real_centers[real_label];
785
- line_source.data = {'x': [syn_coords[0], real_coords[0]], 'y': [syn_coords[1], real_coords[1]]};
786
- line_source.change.emit();
787
- } else {
788
- line_source.data = {'x': [], 'y': []};
789
- line_source.change.emit();
790
- }
791
- """)
792
- source_table.selected.js_on_change('indices', callback)
793
- real_select.js_on_change('value', callback)
794
 
795
- reset_callback = CustomJS(args=dict(line_source=line_source),
796
- code="""
797
- line_source.data = {'x': [], 'y': []};
798
- line_source.change.emit();
799
- """)
800
- reset_button.js_on_event("button_click", reset_callback)
801
- layout = column(fig, result["scatter_fig"], column(real_select, reset_button, data_table))
802
- else:
803
- layout = column(result["scatter_fig"], column(real_select, reset_button, data_table))
804
 
805
- st.bokeh_chart(layout, use_container_width=True)
806
 
807
  buffer = io.BytesIO()
808
  df_table.to_excel(buffer, index=False)
@@ -1435,6 +1435,66 @@ def run_model(model_name):
1435
  key=f"download_pca_coordinates_{model_name}"
1436
  )
1437
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1438
 
1439
  def main():
1440
  config_style()
 
753
  reset_button = Button(label="Reset Colors", button_type="primary")
754
  line_source = ColumnDataSource(data={'x': [], 'y': []})
755
 
756
+ # if (reduction_method == "t-SNE" and N_COMPONENTS == 2) or (reduction_method == "PCA" and N_COMPONENTS == 2):
757
+ # fig, real_renderers, synthetic_renderers, pretrained_renderers = create_figure(
758
+ # result["dfs_reduced"],
759
+ # result["unique_subsets"],
760
+ # get_color_maps(result["unique_subsets"]),
761
+ # model_name
762
+ # )
763
+ # fig.line('x', 'y', source=line_source, line_width=2, line_color='black')
764
+ # centers_real = calculate_cluster_centers(result["dfs_reduced"]["real"], result["unique_subsets"]["real"])
765
+ # real_centers_js = {k: [v[0], v[1]] for k, v in centers_real.items()}
766
+ # synthetic_centers = {}
767
+ # synth_labels = sorted(result["dfs_reduced"]["synthetic"]['label'].unique().tolist())
768
+ # for label in synth_labels:
769
+ # subset = result["dfs_reduced"]["synthetic"][result["dfs_reduced"]["synthetic"]['label'] == label]
770
+ # if 'x' in subset.columns and 'y' in subset.columns:
771
+ # synthetic_centers[label] = [subset['x'].mean(), subset['y'].mean()]
772
+ # callback = CustomJS(args=dict(source=source_table, line_source=line_source,
773
+ # synthetic_centers=synthetic_centers,
774
+ # real_centers=real_centers_js,
775
+ # real_select=real_select),
776
+ # code="""
777
+ # var selected = source.selected.indices;
778
+ # if (selected.length > 0) {
779
+ # var idx = selected[0];
780
+ # var data = source.data;
781
+ # var synth_label = data['Synthetic'][idx];
782
+ # var real_label = real_select.value;
783
+ # var syn_coords = synthetic_centers[synth_label];
784
+ # var real_coords = real_centers[real_label];
785
+ # line_source.data = {'x': [syn_coords[0], real_coords[0]], 'y': [syn_coords[1], real_coords[1]]};
786
+ # line_source.change.emit();
787
+ # } else {
788
+ # line_source.data = {'x': [], 'y': []};
789
+ # line_source.change.emit();
790
+ # }
791
+ # """)
792
+ # source_table.selected.js_on_change('indices', callback)
793
+ # real_select.js_on_change('value', callback)
794
 
795
+ # reset_callback = CustomJS(args=dict(line_source=line_source),
796
+ # code="""
797
+ # line_source.data = {'x': [], 'y': []};
798
+ # line_source.change.emit();
799
+ # """)
800
+ # reset_button.js_on_event("button_click", reset_callback)
801
+ # layout = column(fig, result["scatter_fig"], column(real_select, reset_button, data_table))
802
+ # else:
803
+ # layout = column(result["scatter_fig"], column(real_select, reset_button, data_table))
804
 
805
+ # st.bokeh_chart(layout, use_container_width=True)
806
 
807
  buffer = io.BytesIO()
808
  df_table.to_excel(buffer, index=False)
 
1435
  key=f"download_pca_coordinates_{model_name}"
1436
  )
1437
 
1438
+ elif reduction_method == "t-SNE":
1439
+ st.markdown("## t-SNE - Solo Muestras Reales")
1440
+ # -------------------------------------------------------------------------
1441
+ # 1. t-SNE sobre las muestras reales
1442
+ df_real_only = embeddings["real"].copy()
1443
+
1444
+ reducer_real = TSNE(n_components=2, perplexity=30, random_state=42)
1445
+ reduced_real = reducer_real.fit_transform(df_real_only[embedding_cols].values)
1446
+
1447
+ # Agregar columnas TSNE1, TSNE2
1448
+ df_real_only['TSNE1'] = reduced_real[:, 0]
1449
+ df_real_only['TSNE2'] = reduced_real[:, 1]
1450
+
1451
+ unique_labels_real = sorted(df_real_only['label'].unique().tolist())
1452
+
1453
+ # Mapeo de colores para las muestras reales usando la paleta Reds9
1454
+ num_labels = len(unique_labels_real)
1455
+ if num_labels <= 9:
1456
+ red_palette = Reds9[:num_labels]
1457
+ else:
1458
+ red_palette = (Reds9 * ((num_labels // 9) + 1))[:num_labels]
1459
+ real_color_mapping = {label: red_palette[i] for i, label in enumerate(unique_labels_real)}
1460
+
1461
+ # -------------------------------------------------------------------------
1462
+ # Crear plot interactivo con Bokeh
1463
+ st.subheader("t-SNE - Real: Visualización Interactiva")
1464
+
1465
+ source = ColumnDataSource(df_real_only)
1466
+
1467
+ hover = HoverTool(tooltips=[
1468
+ ("Index", "$index"),
1469
+ ("Label", "@label"),
1470
+ ("TSNE1", "@TSNE1"),
1471
+ ("TSNE2", "@TSNE2")
1472
+ ])
1473
+
1474
+ p = figure(
1475
+ width=800,
1476
+ height=600,
1477
+ title="t-SNE sobre muestras reales",
1478
+ tools=["pan", "wheel_zoom", "box_zoom", "reset", hover]
1479
+ )
1480
+
1481
+ for label in unique_labels_real:
1482
+ subset = df_real_only[df_real_only['label'] == label]
1483
+ p.scatter(
1484
+ x=subset["TSNE1"],
1485
+ y=subset["TSNE2"],
1486
+ size=8,
1487
+ color=real_color_mapping[label],
1488
+ alpha=0.7,
1489
+ legend_label=str(label)
1490
+ )
1491
+
1492
+ p.legend.title = "Label"
1493
+ p.legend.location = "top_right"
1494
+ p.xaxis.axis_label = "t-SNE 1"
1495
+ p.yaxis.axis_label = "t-SNE 2"
1496
+
1497
+ st.bokeh_chart(p, use_container_width=True)
1498
 
1499
  def main():
1500
  config_style()