codelion committed on
Commit
3937446
·
verified ·
1 Parent(s): 9672d7f

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -14
app.py CHANGED
@@ -539,23 +539,81 @@ def create_embedding_visualization(df: pd.DataFrame, color_by: str = 'is_positiv
539
  if color_by not in plot_df.columns:
540
  color_by = 'is_positive' if 'is_positive' in plot_df.columns else None
541
 
 
 
 
 
 
542
  if color_by and color_by in plot_df.columns:
543
- fig = px.scatter(
544
- plot_df, x='x', y='y',
545
- color=color_by,
546
- hover_data=['sentence' if 'sentence' in plot_df.columns else 'pivot_token'],
547
- title="Embedding Space Visualization (t-SNE)",
548
- template="plotly_dark"
549
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550
  else:
551
- fig = px.scatter(
552
- plot_df, x='x', y='y',
553
- hover_data=['sentence' if 'sentence' in plot_df.columns else 'pivot_token'],
554
- title="Embedding Space Visualization (t-SNE)",
555
- template="plotly_dark"
556
- )
 
 
 
 
 
 
 
 
 
557
 
558
- fig.update_layout(height=500)
 
 
 
 
 
 
 
559
 
560
  return fig
561
 
 
539
  if color_by not in plot_df.columns:
540
  color_by = 'is_positive' if 'is_positive' in plot_df.columns else None
541
 
542
+ fig = go.Figure()
543
+
544
+ # Determine text field for hover
545
+ text_field = 'sentence' if 'sentence' in plot_df.columns else 'pivot_token'
546
+
547
  if color_by and color_by in plot_df.columns:
548
+ # Group by color column for separate traces
549
+ if color_by == 'is_positive':
550
+ # Special handling for boolean is_positive
551
+ for is_pos in [True, False]:
552
+ mask = plot_df[color_by] == is_pos
553
+ subset = plot_df[mask]
554
+ if len(subset) > 0:
555
+ hover_texts = [str(row.get(text_field, ''))[:100] for _, row in subset.iterrows()]
556
+ fig.add_trace(go.Scatter(
557
+ x=subset['x'].tolist(),
558
+ y=subset['y'].tolist(),
559
+ mode='markers',
560
+ name='Positive' if is_pos else 'Negative',
561
+ marker=dict(
562
+ size=8,
563
+ color='#22c55e' if is_pos else '#ef4444',
564
+ opacity=0.7
565
+ ),
566
+ hovertext=hover_texts,
567
+ hoverinfo='text'
568
+ ))
569
+ else:
570
+ # Categorical coloring
571
+ unique_vals = plot_df[color_by].unique()
572
+ colors = ['#6366f1', '#22c55e', '#ef4444', '#f59e0b', '#8b5cf6',
573
+ '#ec4899', '#14b8a6', '#f97316', '#06b6d4', '#84cc16']
574
+ for i, val in enumerate(unique_vals):
575
+ mask = plot_df[color_by] == val
576
+ subset = plot_df[mask]
577
+ if len(subset) > 0:
578
+ hover_texts = [str(row.get(text_field, ''))[:100] for _, row in subset.iterrows()]
579
+ fig.add_trace(go.Scatter(
580
+ x=subset['x'].tolist(),
581
+ y=subset['y'].tolist(),
582
+ mode='markers',
583
+ name=str(val),
584
+ marker=dict(
585
+ size=8,
586
+ color=colors[i % len(colors)],
587
+ opacity=0.7
588
+ ),
589
+ hovertext=hover_texts,
590
+ hoverinfo='text'
591
+ ))
592
  else:
593
+ # No color grouping
594
+ hover_texts = [str(row.get(text_field, ''))[:100] for _, row in plot_df.iterrows()]
595
+ fig.add_trace(go.Scatter(
596
+ x=plot_df['x'].tolist(),
597
+ y=plot_df['y'].tolist(),
598
+ mode='markers',
599
+ name='Embeddings',
600
+ marker=dict(
601
+ size=8,
602
+ color='#6366f1',
603
+ opacity=0.7
604
+ ),
605
+ hovertext=hover_texts,
606
+ hoverinfo='text'
607
+ ))
608
 
609
+ fig.update_layout(
610
+ title="Embedding Space Visualization (t-SNE)",
611
+ xaxis_title="t-SNE 1",
612
+ yaxis_title="t-SNE 2",
613
+ template="plotly_dark",
614
+ height=500,
615
+ showlegend=True
616
+ )
617
 
618
  return fig
619