File size: 29,746 Bytes
cb8a7e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 |
"""Utilities for interactive visualization of extracted graphs"""
import traceback

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from plotly.subplots import make_subplots
def _parse_sae_feature_index(node_id):
    """Return the SAE feature index encoded in ``node_id``, or ``None``.

    SAE (cross layer transcoder) node ids look like
    ``"layer_featureIndex_sequence"`` (e.g. ``"24_79427_7"``); the second
    ``_``-separated field is the feature index. A non-string ``node_id``
    raises ``TypeError``, which the caller treats as an unusable node.
    """
    if not node_id or '_' not in node_id:
        return None
    try:
        return int(node_id.split('_')[1])
    except ValueError:
        return None


def _extract_scatter_rows(nodes, token_map):
    """Convert raw graph nodes into plot rows.

    Args:
        nodes: Iterable of node dicts (``graph_data['nodes']``).
        token_map: Mapping ``ctx_idx -> prompt token`` used for labels.

    Returns:
        ``(rows, skipped)``: ``rows`` is a list of dicts with keys ``layer``,
        ``ctx_idx``, ``token``, ``id``, ``influence``, ``feature`` (plus
        ``node_influence`` when the node carries that field); ``skipped``
        describes every SAE node whose ``node_id`` could not be parsed.
    """
    rows = []
    skipped = []
    for node in nodes:
        layer_val = node.get('layer', '')
        # Any ValueError/TypeError while reading a node (non-numeric layer,
        # unhashable ctx_idx, non-string node_id, ...) silently drops it.
        try:
            # The embedding layer 'E' is plotted at -1; other layers must be ints.
            layer_numeric = -1 if str(layer_val).upper() == 'E' else int(layer_val)

            # Missing/zero influence gets a tiny positive value to stay visible.
            influence_val = node.get('influence', 0)
            if influence_val is None or influence_val == 0:
                influence_val = 0.001

            ctx_idx_val = node.get('ctx_idx', 0)
            token_str = token_map.get(ctx_idx_val, f"ctx_{ctx_idx_val}")

            node_id = node.get('node_id', '')
            if node.get('feature_type', '') == 'cross layer transcoder':
                feature_idx = _parse_sae_feature_index(node_id)
                if feature_idx is None:
                    skipped.append(f"layer={layer_val}, node_id={node_id}, type=SAE")
                    continue
            else:
                # Non-SAE nodes (embeddings, logits, reconstruction errors, ...)
                # get -1: their node_id does not encode a feature index.
                feature_idx = -1

            row = {
                'layer': layer_numeric,
                'ctx_idx': ctx_idx_val,
                'token': token_str,
                'id': node_id,
                'influence': influence_val,
                'feature': feature_idx,
            }
            # Keep the marginal influence when the JSON already provides it.
            if 'node_influence' in node:
                row['node_influence'] = node['node_influence']
            rows.append(row)
        except (ValueError, TypeError):
            continue
    return rows, skipped


def _report_skipped_nodes(skipped_nodes, preview=10):
    """Warn about skipped SAE nodes and list the first ``preview`` of them."""
    if not skipped_nodes:
        return
    st.warning(f"⚠️ {len(skipped_nodes)} feature nodes with malformed node_id were skipped")
    with st.expander("Skipped nodes details"):
        for node_info in skipped_nodes[:preview]:
            st.text(node_info)
        if len(skipped_nodes) > preview:
            st.text(f"... and {len(skipped_nodes) - preview} more nodes")


def _assign_sub_columns(df, bin_width=0.3, max_bins=5):
    """Spread nodes sharing a (ctx_idx, layer) cell over sub-columns, in place.

    Adds ``sub_column`` (x offset, centred on 0) and ``ctx_idx_display``
    (``ctx_idx + sub_column``) so overlapping markers are drawn side by side.
    A cell with n > 1 nodes uses ``min(max_bins, ceil(sqrt(n)))`` sub-columns.
    """
    # Float from the start: assigning fractional offsets into an int column
    # triggers pandas' incompatible-dtype deprecation.
    df['sub_column'] = 0.0
    for _, group in df.groupby(['ctx_idx', 'layer']):
        n_nodes = len(group)
        if n_nodes <= 1:
            continue
        n_bins = min(max_bins, int(np.ceil(np.sqrt(n_nodes))))
        offsets = ((np.arange(n_nodes) % n_bins) - (n_bins - 1) / 2) * bin_width
        df.loc[group.index, 'sub_column'] = offsets
    df['ctx_idx_display'] = df['ctx_idx'] + df['sub_column']


def _compute_node_influence(df):
    """Return the marginal ``node_influence`` of each row of ``df``.

    Uses the JSON-provided ``node_influence`` column where available.
    Missing values (or older graphs without the field) fall back to a proxy:
    the difference between consecutive cumulative ``influence`` values in
    ascending order, with the smallest one keeping its own value. The result
    is aligned on ``df``'s index, so duplicate node ids are harmless.
    """
    ordered = df['influence'].sort_values(kind='stable')
    proxy = ordered.diff()
    if len(ordered) > 0:
        proxy.iloc[0] = ordered.iloc[0]
    proxy = proxy.reindex(df.index)
    if 'node_influence' in df.columns:
        provided = pd.to_numeric(df['node_influence'], errors='coerce')
        return provided.fillna(proxy)
    return proxy


def _influence_help_text(max_influence, node_threshold):
    """Build the markdown explaining the cumulative ``influence`` field."""
    lines = [
        f"**The `influence` field is the cumulative coverage (0-{max_influence:.2f})** "
        "calculated by circuit tracer pruning.",
        "When nodes are sorted by descending influence, a node with `influence=0.65` means that",
        "**up to that node** covers 65% of the total influence.",
    ]
    if node_threshold is not None:
        lines += ["", f"**Node threshold used during generation:** `{node_threshold}`"]
    return "\n".join(lines)


def _gini_from_concentration_curve(share_x, cumulative_y):
    """Gini coefficient from a cumulative-share curve sorted in DESCENDING order.

    ``share_x`` (fraction of items) and ``cumulative_y`` (fraction of total)
    are in [0, 1]. A descending curve lies above the diagonal, so
    ``G = 2 * area - 1``: 0 for an equal distribution, ``1 - 1/n`` when a
    single item holds everything. ``(0, 0)`` is prepended so the first
    segment's area is counted.
    """
    x = np.concatenate(([0.0], np.asarray(share_x, dtype=float)))
    y = np.concatenate(([0.0], np.asarray(cumulative_y, dtype=float)))
    area = np.sum((y[1:] + y[:-1]) * np.diff(x)) / 2.0
    return float(2.0 * area - 1.0)


def _render_group_statistics(plot_df):
    """Show node_influence statistics per marker group plus the size formula."""
    col_emb, col_feat = st.columns(2)
    panels = (
        (col_emb, "**🟩 Embeddings (green squares)**", 'embedding',
         "No embeddings in filtered dataset"),
        (col_feat, "**⚪ Features (gray circles)**", 'feature',
         "No features in filtered dataset"),
    )
    for column, title, node_type, empty_message in panels:
        with column:
            st.markdown(title)
            values = plot_df.loc[plot_df['node_type'] == node_type, 'node_influence']
            if len(values) > 0:
                st.metric("Nodes", len(values))
                st.metric("Max node_influence", f"{values.max():.6f}")
                st.metric("Mean node_influence", f"{values.mean():.6f}")
                st.metric("Min node_influence", f"{values.min():.6f}")
            else:
                st.info(empty_message)
    st.info("\n".join([
        "💡 **Size formula**: `size = (normalized_node_influence)³ × 1000 + 10`",
        "Size is normalized **per group** and uses **power 3** to emphasize differences:",
        "- A node with 50% of max → size = 0.5³ = 12.5% (much smaller)",
        "- A node with 80% of max → size = 0.8³ = 51.2%",
        "- A node with 100% of max → size = 1.0³ = 100%",
        "The 2 groups (embeddings and features) have independent scales.",
        'Note: in the JSON the "influence" field is the pre-pruning cumulative, so estimating '
        "node_influence as the difference between consecutive cumulatives is only a normalized "
        "proxy (to be renormalized on the current set), because the graph may already be "
        "topologically pruned and the selection does not coincide with a contiguous prefix of "
        "sorted nodes.",
    ]))


def _render_pareto_analysis(plot_df):
    """Render the Pareto chart and statistics for feature nodes.

    Returns:
        The features sorted by descending ``node_influence`` (with rank and
        cumulative columns), or ``None`` when there is nothing to analyse.
    """
    features_only = plot_df[plot_df['node_type'] == 'feature']
    if features_only.empty:
        st.warning("⚠️ No features found in filtered dataset")
        return None

    sorted_df = features_only.sort_values('node_influence', ascending=False).reset_index(drop=True)
    n_features = len(sorted_df)
    sorted_df['rank'] = np.arange(1, n_features + 1)
    sorted_df['rank_pct'] = sorted_df['rank'] / n_features * 100

    total_node_inf = sorted_df['node_influence'].sum()
    if total_node_inf == 0:
        st.warning("⚠️ Total Node influence is 0")
        return None
    sorted_df['cumulative_node_influence'] = sorted_df['node_influence'].cumsum()
    sorted_df['cumulative_node_influence_pct'] = (
        sorted_df['cumulative_node_influence'] / total_node_inf * 100
    )

    fig_pareto = make_subplots(specs=[[{"secondary_y": True}]])

    # Bars: individual node_influence of the first 100 features. They share the
    # x-axis with the cumulative line, so they are placed at rank_pct too; the
    # absolute rank is kept for the hover label.
    top = sorted_df.head(100)
    fig_pareto.add_trace(
        go.Bar(
            x=top['rank_pct'],
            y=top['node_influence'],
            customdata=top['rank'],
            name='Node Influence',
            marker=dict(color='#2196F3', opacity=0.6),
            hovertemplate='<b>Rank: %{customdata}</b><br>Node Influence: %{y:.6f}<extra></extra>',
        ),
        secondary_y=False,
    )
    # Line: cumulative percentage over all features.
    fig_pareto.add_trace(
        go.Scatter(
            x=sorted_df['rank_pct'],
            y=sorted_df['cumulative_node_influence_pct'],
            mode='lines+markers',
            name='Cumulative %',
            line=dict(color='#FF5722', width=3),
            marker=dict(size=4),
            hovertemplate='<b>Top %{x:.1f}% features</b><br>Cumulative: %{y:.1f}%<extra></extra>',
        ),
        secondary_y=True,
    )
    # Pareto reference lines.
    for pct, label in [(80, '80%'), (90, '90%'), (95, '95%')]:
        fig_pareto.add_hline(y=pct, line_dash="dash", line_color="gray", opacity=0.5, secondary_y=True)
        fig_pareto.add_annotation(x=100, y=pct, text=label, showarrow=False, xanchor='left', yref='y2')

    # Knee: first feature at which the cumulative share reaches 80%.
    knee_idx = int((sorted_df['cumulative_node_influence_pct'] >= 80).idxmax())
    knee_rank_pct = sorted_df.loc[knee_idx, 'rank_pct']
    knee_cumul = sorted_df.loc[knee_idx, 'cumulative_node_influence_pct']
    fig_pareto.add_trace(
        go.Scatter(
            x=[knee_rank_pct],
            y=[knee_cumul],
            mode='markers',
            name='Knee (80%)',
            marker=dict(size=15, color='#4CAF50', symbol='diamond', line=dict(width=2, color='white')),
            hovertemplate=f'<b>Knee Point</b><br>Top {knee_rank_pct:.1f}% features<br>Cumulativa: {knee_cumul:.1f}%<extra></extra>',
            showlegend=True,
        ),
        secondary_y=True,
    )

    fig_pareto.update_xaxes(title_text="Rank % Features (by descending node_influence)")
    fig_pareto.update_yaxes(title_text="Node Influence (individual)", secondary_y=False)
    fig_pareto.update_yaxes(title_text="Cumulative % Node Influence", secondary_y=True, range=[0, 105])
    fig_pareto.update_layout(
        height=500,
        showlegend=True,
        template='plotly_white',
        legend=dict(x=0.02, y=0.98, xanchor='left', yanchor='top'),
        title="Pareto Chart: Node Influence of Features",
    )
    st.plotly_chart(fig_pareto, use_container_width=True)

    st.markdown("#### 📊 Pareto Statistics (Node Influence)")
    columns = st.columns(4)
    cumulative_pct = sorted_df['cumulative_node_influence_pct']
    for column, top_pct in zip(columns[:3], (10, 20, 50)):
        # Share held by the top k features (at least one); iloc is 0-based.
        k = max(1, n_features * top_pct // 100)
        share = cumulative_pct.iloc[k - 1]
        with column:
            st.metric(
                f"Top {top_pct}% features",
                f"{share:.1f}% node_influence",
                help=f"The top {k} most influential features cover {share:.1f}% of total influence",
            )
    with columns[3]:
        gini = _gini_from_concentration_curve(sorted_df['rank_pct'] / 100, cumulative_pct / 100)
        st.metric("Gini Coefficient", f"{gini:.3f}", help="0 = equal distribution, 1 = highly concentrated")

    # The slider keeps nodes with influence <= threshold, so the smallest value
    # that still keeps every feature up to the knee is their max influence.
    knee_count = knee_idx + 1
    knee_cumul_threshold = sorted_df.loc[:knee_idx, 'influence'].max()
    st.success(
        f"🎯 **Knee Point (80%)**: The first **{knee_rank_pct:.1f}%** of features ({knee_count} nodes)\n"
        "cover **80%** of total node_influence.\n"
        "💡 **Threshold Suggestion**: To focus on features up to the knee point (80%),\n"
        f"use `cumulative_threshold ≈ {knee_cumul_threshold:.4f}` in the slider above."
    )
    return sorted_df


def _render_node_influence_histogram(sorted_df):
    """Render the node_influence histogram and summary metrics for features."""
    fig_hist = px.histogram(
        sorted_df,
        x='node_influence',
        nbins=50,
        title='Node Influence Distribution (Features)',
        labels={'node_influence': 'Node Influence', 'count': 'Frequency'},
        color_discrete_sequence=['#2196F3'],
    )
    fig_hist.update_layout(height=350, template='plotly_white', showlegend=False)
    fig_hist.update_traces(marker=dict(opacity=0.7))
    st.plotly_chart(fig_hist, use_container_width=True)

    values = sorted_df['node_influence']
    summary = (
        ("Mean", values.mean()),
        ("Median", values.median()),
        ("Std Dev", values.std()),
        ("Max", values.max()),
    )
    for column, (label, value) in zip(st.columns(4), summary):
        with column:
            st.metric(label, f"{value:.6f}")


def _show_chart_error(exc):
    """Display a chart-rendering error with its traceback (call from an except block)."""
    st.error(f"❌ Error creating distribution chart: {str(exc)}")
    st.code(traceback.format_exc())


def create_scatter_plot_with_filter(graph_data):
    """Render an interactive layer-by-position scatter plot with influence filters.

    Each node is drawn at x = context position (spread into sub-columns when
    several nodes share a cell) and y = layer. Embeddings are green squares
    and all other non-logit nodes are gray circles. Marker area follows
    ``node_influence`` normalised per group. A slider filters on the
    cumulative ``influence`` field, and a checkbox hides reconstruction-error
    nodes. Group statistics, a Pareto analysis and a histogram follow the plot.

    Embeddings are nodes whose layer is ``'E'`` (mapped to -1). Logits are
    assumed to be the nodes in the highest numeric layer. Reconstruction-error
    nodes are the remaining nodes with ``feature == -1`` (non-SAE nodes).

    Args:
        graph_data: Graph dictionary with ``'nodes'`` (list of node dicts
            using keys such as ``layer``, ``ctx_idx``, ``influence``,
            ``node_id``, ``feature_type`` and optionally ``node_influence``)
            and an optional ``'metadata'`` dict (``prompt_tokens``,
            ``node_threshold``).

    Returns:
        DataFrame of the SAE feature nodes (no embeddings, logits or error
        nodes) that pass the current filters, useful for export; ``None`` when
        there are no nodes or no plottable node.
    """
    if 'nodes' not in graph_data:
        st.warning("⚠️ No nodes found in graph data")
        return

    # ``metadata`` (and its fields) may be missing or explicitly null.
    metadata = graph_data.get('metadata') or {}
    prompt_tokens = metadata.get('prompt_tokens') or []
    token_map = dict(enumerate(prompt_tokens))

    scatter_data, skipped_nodes = _extract_scatter_rows(graph_data['nodes'], token_map)
    _report_skipped_nodes(skipped_nodes)
    if not scatter_data:
        st.warning("⚠️ No valid nodes found for plotting")
        return

    scatter_df = pd.DataFrame(scatter_data)
    scatter_df['influence'] = scatter_df['influence'].fillna(0.001).replace(0, 0.001)
    # Binning the full set only provides the ``sub_column``/``ctx_idx_display``
    # columns of the returned export frame; the plot re-bins after filtering.
    _assign_sub_columns(scatter_df)

    # === Cumulative influence filter ===
    st.markdown("### 3️⃣ Filter Features by Cumulative Influence Coverage")
    max_influence = scatter_df['influence'].max()
    st.info(_influence_help_text(max_influence, metadata.get('node_threshold')))

    cumulative_threshold = st.slider(
        "Cumulative Influence Threshold",
        min_value=0.0,
        max_value=float(max_influence),
        value=float(max_influence),
        step=0.01,
        key="cumulative_slider_main",
        help=f"Keep only nodes with influence ≤ threshold. Range: 0.0 - {max_influence:.2f} (max in data)",
    )
    filter_error_nodes = st.checkbox(
        "Exclude Reconstruction Error Nodes (feature = -1)",
        value=False,
        key="filter_error_checkbox",
        help="Reconstruction error nodes represent the part of the model not explained by SAE features",
    )

    num_total = len(scatter_df)
    max_layer = scatter_df['layer'].max()
    is_embedding = scatter_df['layer'] == -1
    is_logit = scatter_df['layer'] == max_layer
    # Embeddings and logits also carry feature == -1, so they must be excluded
    # explicitly to count only reconstruction-error nodes.
    is_error = (scatter_df['feature'] == -1) & ~is_embedding & ~is_logit
    n_error_total = int(is_error.sum())
    pct_error_nodes = (n_error_total / num_total * 100) if num_total > 0 else 0

    # Embeddings and logits are always kept, whatever the threshold.
    if cumulative_threshold < float(max_influence):
        mask_keep = (scatter_df['influence'] <= cumulative_threshold) | is_embedding | is_logit
    else:
        mask_keep = pd.Series(True, index=scatter_df.index)

    n_error_excluded = 0
    if filter_error_nodes:
        # Error nodes that passed the threshold but are hidden by the checkbox.
        n_error_excluded = int((mask_keep & is_error).sum())
        mask_keep = mask_keep & ~is_error

    scatter_filtered = scatter_df[mask_keep].copy()
    emb_kept = is_embedding[mask_keep]
    logit_kept = is_logit[mask_keep]
    error_kept = is_error[mask_keep]
    sae_kept = ~(emb_kept | logit_kept | error_kept)

    # Effective threshold: max influence among kept non-embedding/non-logit nodes.
    non_special = scatter_filtered.loc[~(emb_kept | logit_kept), 'influence']
    threshold_influence = non_special.max() if len(non_special) > 0 else 0.0

    num_selected = len(scatter_filtered)
    n_embeddings = int(emb_kept.sum())
    n_error_nodes = int(error_kept.sum())
    n_features = int(sae_kept.sum())
    n_logits_excluded = int(logit_kept.sum())

    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("Total Nodes", num_total)
    with col2:
        st.metric("Selected Nodes", num_selected)
    with col3:
        pct = (num_selected / num_total * 100) if num_total > 0 else 0
        st.metric("% Nodes", f"{pct:.1f}%")
    with col4:
        st.metric("Influence Threshold", f"{threshold_influence:.6f}")

    # Plotting frame: re-binned on the filtered nodes, with marginal influence.
    plot_df = scatter_filtered.copy()
    _assign_sub_columns(plot_df)
    plot_df['node_influence'] = _compute_node_influence(plot_df)

    # Error KPIs over the filtered set (logits still included at this point).
    total_node_influence = plot_df['node_influence'].sum()
    error_node_influence = plot_df.loc[error_kept, 'node_influence'].sum()
    pct_error_influence = (
        (error_node_influence / total_node_influence * 100) if total_node_influence > 0 else 0
    )

    col1, col2 = st.columns(2)
    with col1:
        st.metric(
            "% Error Nodes",
            f"{pct_error_nodes:.1f}%",
            help=f"{n_error_total} out of {num_total} total nodes are reconstruction error (feature=-1)",
        )
    with col2:
        st.metric(
            "% Node Influence (Error)",
            f"{pct_error_influence:.1f}%",
            help=f"Reconstruction error nodes contribute {pct_error_influence:.1f}% of total node_influence",
        )

    info_parts = [f"{n_embeddings} embeddings", f"{n_features} features"]
    if n_error_nodes > 0:
        info_parts.append(f"{n_error_nodes} error nodes")
    excluded_parts = [f"{n_logits_excluded} logits"]
    if n_error_excluded > 0:
        excluded_parts.append(f"{n_error_excluded} error nodes")
    st.info(f"📊 Displaying {n_embeddings + n_features + n_error_nodes} nodes: {', '.join(info_parts)} ({', '.join(excluded_parts)} excluded)")

    # Logits are not plotted; everything that is not an embedding is a 'feature' marker.
    plot_df = plot_df[~logit_kept].copy()
    plot_df['node_type'] = np.where(plot_df['layer'] == -1, 'embedding', 'feature')

    # Marker size: (node_influence / group max)^3 * 1000 + 10, per group, so
    # the cube exaggerates differences and each group has its own scale.
    plot_df['influence_log'] = 0.0
    for node_type in ('embedding', 'feature'):
        in_group = plot_df['node_type'] == node_type
        if not in_group.any():
            continue
        magnitude = plot_df.loc[in_group, 'node_influence'].abs()
        peak = magnitude.max()
        if peak > 0:
            plot_df.loc[in_group, 'influence_log'] = (magnitude / peak) ** 3 * 1000 + 10
        else:
            plot_df.loc[in_group, 'influence_log'] = 10

    fig = px.scatter(
        plot_df,
        x='ctx_idx_display',
        y='layer',
        size='influence_log',
        symbol='node_type',
        symbol_map={'embedding': 'square', 'feature': 'circle'},
        color='node_type',
        color_discrete_map={'embedding': '#4CAF50', 'feature': '#808080'},
        labels={
            'id': 'Node ID',
            'ctx_idx_display': 'Context Position',
            'ctx_idx': 'ctx_idx',
            'layer': 'Layer',
            'influence': 'Cumulative Influence',
            'node_influence': 'Node Influence',
            'node_type': 'Node Type',
            'token': 'Token',
            'feature': 'Feature',
        },
        title='Features by Layer and Position (size: node_influence^3 normalized per group)',
        hover_data={
            'ctx_idx': True,
            'token': True,
            'layer': True,
            'node_type': True,
            'id': True,
            'feature': True,
            'node_influence': ':.6f',  # marginal influence (marker size)
            'influence': ':.4f',  # cumulative influence (slider filter)
            'ctx_idx_display': False,  # offset position is an internal detail
            'influence_log': False,  # size value is an internal detail
        },
    )

    # Translucent markers with a white outline so overlapping nodes stay readable.
    max_influence_log = plot_df['influence_log'].max()
    fig.update_traces(
        marker=dict(
            sizemode='area',
            sizeref=2.0 * max_influence_log / (50.0 ** 2) if max_influence_log > 0 else 1,
            sizemin=2,
            opacity=0.3,
            line=dict(width=1.5, color='white'),
        )
    )

    # X ticks at integer positions, labelled "ctx_idx: token".
    unique_ctx = sorted(plot_df['ctx_idx'].unique())
    tick_labels = [f"{ctx}: {token_map.get(ctx, '')}" for ctx in unique_ctx]
    fig.update_layout(
        template='plotly_white',
        height=600,
        showlegend=True,  # legend for the 2 node groups
        legend=dict(
            title="Node Type",
            orientation="v",
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.99,
            bgcolor="rgba(255,255,255,0.8)",
        ),
        xaxis=dict(
            gridcolor='lightgray',
            tickmode='array',
            tickvals=unique_ctx,
            ticktext=tick_labels,
            tickangle=-45,
        ),
        yaxis=dict(gridcolor='lightgray'),
    )
    st.plotly_chart(fig, use_container_width=True)

    with st.expander("📊 Statistics by Group (Size Normalization)", expanded=False):
        _render_group_statistics(plot_df)

    # Streamlit forbids nested expanders, so the histogram gets its own
    # top-level expander after the Pareto one.
    pareto_features = None
    with st.expander("📈 Pareto Analysis Node Influence (Features only)", expanded=False):
        try:
            pareto_features = _render_pareto_analysis(plot_df)
        except Exception as e:  # UI boundary: show the error instead of crashing the page
            _show_chart_error(e)
    if pareto_features is not None:
        with st.expander("📊 Node Influence Distribution Histogram", expanded=False):
            try:
                _render_node_influence_histogram(pareto_features)
            except Exception as e:  # UI boundary: show the error instead of crashing the page
                _show_chart_error(e)

    # Filtered SAE features only (no embeddings/logits/error nodes), for export.
    return scatter_filtered[sae_kept].copy()
|