File size: 35,010 Bytes
cb8a7e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5d228c2
 
cb8a7e5
5d228c2
 
0f312db
cb8a7e5
5d228c2
 
 
 
cb8a7e5
5d228c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f79718
5d228c2
 
 
cb8a7e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6af579c
 
 
 
 
 
 
 
 
cb8a7e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b97db2
cb8a7e5
406b7cb
9b97db2
 
cb8a7e5
 
 
 
 
 
6af579c
 
 
 
 
 
 
 
cb8a7e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406b7cb
 
 
 
cb8a7e5
 
 
 
 
 
 
6af579c
 
 
 
 
 
 
 
cb8a7e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
406b7cb
 
 
 
cb8a7e5
e2019d7
 
 
 
 
 
 
 
 
 
cb8a7e5
 
 
 
e2019d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6de59a4
 
cb8a7e5
 
 
 
 
 
 
 
6de59a4
 
 
 
 
 
 
 
 
cb8a7e5
 
 
 
 
e2019d7
 
 
 
 
 
 
 
 
cb8a7e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
"""Page 0 - Graph Generation: Generate Attribution Graphs on Neuronpedia"""
import sys
from pathlib import Path

# Add the directory three levels above this file to sys.path so that the
# `eda` and `scripts` packages imported below can be resolved.
parent_dir = Path(__file__).parent.parent.parent
if str(parent_dir) not in sys.path:
    sys.path.insert(0, str(parent_dir))

import streamlit as st
import json
import os
from datetime import datetime

# Try to import PipelineState (optional - won't break if missing)
try:
    from eda.utils.pipeline_state import PipelineState
    PIPELINE_STATE_AVAILABLE = True
except ImportError:
    PIPELINE_STATE_AVAILABLE = False

# Import graph generation functions
try:
    from scripts.neuronpedia_graph_generation import (
        generate_attribution_graph,
        get_graph_stats,
        load_api_key,
        extract_static_metrics_from_json
    )
except ImportError as import_err:
    # Fallback: load the script directly from its file. Its name starts with a
    # digit ("00_..."), which an ordinary import statement cannot reference.
    import importlib.util
    script_path = parent_dir / "scripts" / "00_neuronpedia_graph_generation.py"
    if not script_path.is_file():
        # Fail with an actionable message instead of an opaque error from
        # exec_module() on a non-existent file.
        raise ImportError(
            "Could not import graph generation helpers: neither "
            f"`scripts.neuronpedia_graph_generation` nor {script_path} is available"
        ) from import_err
    spec = importlib.util.spec_from_file_location("neuronpedia_graph_generation", script_path)
    graph_gen = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(graph_gen)
    generate_attribution_graph = graph_gen.generate_attribution_graph
    get_graph_stats = graph_gen.get_graph_stats
    load_api_key = graph_gen.load_api_key
    extract_static_metrics_from_json = graph_gen.extract_static_metrics_from_json

st.set_page_config(page_title="Graph Generation", page_icon="🌐", layout="wide")

st.title("🌐 Attribution Graph Generation")

st.info("""
1. **Generate a new attribution graph on Neuronpedia** to analyze how the model predicts the next token. \n
2. **Analyze the graph** to understand the contribution of each feature.\n
3. **Filter Features by Cumulative Influence Coverage** for downstream analysis.
""")

# ===== SIDEBAR: CONFIGURATION =====

st.sidebar.header("Configuration")

# Neuronpedia API Key
st.sidebar.subheader("Neuronpedia API")

# Resolve the API key: first via load_api_key() (environment / secrets), then
# fall back to manual entry in the sidebar.
api_key = load_api_key()

if not api_key:
    st.sidebar.warning("⚠️ Neuronpedia API Key not found")
    st.sidebar.info("""
    Add `NEURONPEDIA_API_KEY=your-key` to HF Secrets
    or enter it below.
    """)
    
    # Allow manual input. Surrounding whitespace (a common copy/paste artifact)
    # is never part of a valid key, so strip it; whitespace-only input then
    # counts as "no key".
    api_key = st.sidebar.text_input(
        "Enter API Key:", 
        type="password", 
        key="neuronpedia_key_input",
        help="Enter your Neuronpedia API key"
    ).strip()
    
    if not api_key:
        st.error("""
        **Neuronpedia API Key Required!**
        
        1. Obtain an API key from [Neuronpedia](https://www.neuronpedia.org/)
        2. Enter it in the sidebar, OR
        3. Add to HF Spaces Secrets (Settings → Repository secrets):
           ```
           NEURONPEDIA_API_KEY = your-key-here
           ```
        """)
        st.stop()
    else:
        st.sidebar.success(f"✅ API Key entered ({len(api_key)} characters)")
else:
    st.sidebar.success(f"✅ API Key loaded ({len(api_key)} characters)")

# Save to session_state for reuse in other pages. (The empty-key case has
# already called st.stop() above.)
if api_key:
    st.session_state['neuronpedia_api_key'] = api_key

# ===== SECTION: GENERATE NEW GRAPH =====

st.header("🌐 Generate New Attribution Graph")

# Step 1: the prompt whose next-token prediction will be attributed.
st.subheader("1️⃣ Prompt Configuration")

prompt = st.text_area(
    "Prompt to analyze",
    value="The capital of state containing Dallas is",
    height=100,
    help="Enter the prompt to analyze. The model will try to predict the next token.",
)

# Graph-building parameters; each value is passed to generate_attribution_graph()
# when the user clicks "Generate Graph".
st.subheader("Graph Parameters")

with st.expander("Advanced configuration", expanded=False):
    model_col, threshold_col = st.columns(2)

    # Left column: model, SAE source set and graph size limit.
    model_col.write("**Model & Source Set**")

    model_id = model_col.selectbox(
        "Model ID",
        ["gemma-2-2b", "gpt2-small", "gemma-2-9b"],
        help="Model to analyze",
    )

    # Alternative source set previously used here: "gemmascope-transcoder-16k"
    source_set_name = model_col.text_input(
        "Source Set Name",
        value="clt-hp",
        help="Name of the SAE source set to use",
    )

    max_feature_nodes = model_col.number_input(
        "Max Feature Nodes",
        min_value=100,
        max_value=10000,
        value=5000,
        step=100,
        help="Maximum number of feature nodes to include",
    )

    # Right column: node/edge pruning thresholds and logit selection.
    threshold_col.write("**Thresholds**")

    node_threshold = threshold_col.slider(
        "Node Threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.8,
        step=0.05,
        help="Minimum importance threshold to include a node",
    )

    edge_threshold = threshold_col.slider(
        "Edge Threshold",
        min_value=0.0,
        max_value=1.0,
        value=0.85,
        step=0.05,
        help="Minimum importance threshold to include an edge",
    )

    max_n_logits = threshold_col.number_input(
        "Max N Logits",
        min_value=1,
        max_value=50,
        value=10,
        step=1,
        help="Maximum number of logits to consider",
    )

    desired_logit_prob = threshold_col.slider(
        "Desired Logit Probability",
        min_value=0.5,
        max_value=0.99,
        value=0.95,
        step=0.01,
        help="Desired cumulative probability for logits",
    )

slug = st.text_input(
    "Custom slug (optional)",
    value="",
    help="If empty, will be generated automatically",
)

# GENERATION
st.subheader("Generation")

# Narrow column for the action button, wider one for the save option.
button_col, option_col = st.columns([1, 2])
generate_button = button_col.button("🌐 Generate Graph", type="primary", use_container_width=True)
save_locally = option_col.checkbox("Save locally", value=True)

# Session-state slots that must survive reruns; seeded with None on first visit only.
for _state_key in ("generation_result", "static_metrics_df", "extracted_graph_data", "extracted_csv_df"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

if generate_button:
    if not prompt.strip():
        st.error("Enter a valid prompt!")
        st.stop()
    
    progress_bar = st.progress(0)
    status_text = st.empty()
    
    try:
        status_text.text("Preparing...")
        progress_bar.progress(10)
        
        status_text.text("Sending request to Neuronpedia...")
        progress_bar.progress(30)
        
        result = generate_attribution_graph(
            prompt=prompt,
            api_key=api_key,
            model_id=model_id,
            source_set_name=source_set_name,
            slug=slug if slug.strip() else None,
            max_n_logits=max_n_logits,
            desired_logit_prob=desired_logit_prob,
            node_threshold=node_threshold,
            edge_threshold=edge_threshold,
            max_feature_nodes=max_feature_nodes,
            save_locally=save_locally,
            verbose=False
        )
        
        progress_bar.progress(100)
        status_text.empty()
        progress_bar.empty()
        
        # Add generation parameters to result for later use (Neuronpedia URL below)
        if result['success']:
            result['source_set_name'] = source_set_name
            result['node_threshold'] = node_threshold
            result['desired_logit_prob'] = desired_logit_prob
        
        # Rename file to new format if saved locally (BEFORE saving to session_state)
        if result['success'] and result.get('local_path') and save_locally and PIPELINE_STATE_AVAILABLE:
            old_path = Path(result['local_path'])
            if old_path.exists():
                # Generate new filename with st1_ prefix
                new_filename = PipelineState.generate_filename(
                    step=1,
                    file_type='graph',
                    prompt=prompt
                )
                new_path = old_path.parent / new_filename
                try:
                    old_path.rename(new_path)
                except OSError as rename_err:
                    # The graph itself was generated successfully; keep the
                    # original file name instead of reporting the whole
                    # generation as failed via the outer handler.
                    st.warning(f"⚠️ Could not rename graph file to `{new_filename}`: {rename_err}")
                else:
                    # Update result with absolute path (AFTER rename)
                    result['local_path'] = str(new_path.resolve())
                    result['renamed_to_new_format'] = True
        
        # Save result to session_state (with updated path)
        st.session_state.generation_result = result
        
        # Save Graph JSON to pipeline session_state for auto-loading in next steps.
        # On a successful generation, any graph cached from an earlier run is
        # dropped when the new one cannot be loaded, so later steps never
        # auto-load a stale graph.
        if result['success']:
            pipeline_graph = None
            if result.get('local_path'):
                try:
                    with open(result['local_path'], 'r', encoding='utf-8') as f:
                        graph_data = json.load(f)
                    pipeline_graph = {
                        'data': graph_data,
                        'filename': Path(result['local_path']).name,
                        'timestamp': datetime.now().isoformat()
                    }
                except Exception:
                    # Best effort: don't break the flow if the file can't be reloaded
                    pipeline_graph = None
            if pipeline_graph is not None:
                st.session_state['pipeline_graph_json'] = pipeline_graph
            else:
                st.session_state.pop('pipeline_graph_json', None)
        
        if result['success']:
            # Build Neuronpedia URL. Query values are URL-encoded so a custom
            # slug with spaces or '&' cannot corrupt the link. If the result
            # carries no model id, fall back to the model selected in the UI.
            from urllib.parse import urlencode
            query = urlencode({
                'sourceSet': result.get('source_set_name', 'clt-hp'),
                'slug': result.get('slug', ''),
                'pruningThreshold': result.get('node_threshold', 0.8),
                'densityThreshold': result.get('desired_logit_prob', 0.95),
            })
            neuronpedia_url = (
                f"https://www.neuronpedia.org/{result.get('model_id') or model_id}/graph?{query}"
            )
            
            if result.get('local_path'):
                # Get the filename for display
                filename = Path(result['local_path']).name
                st.success(f"✅ Graph generated successfully: `{filename}`\n\n" f"[**Open Graph on Neuronpedia**]({neuronpedia_url})")
                
                # Auto-download the generated graph JSON
                try:
                    import streamlit.components.v1 as components
                    import base64
                    
                    with open(result['local_path'], 'r', encoding='utf-8') as f:
                        graph_json_content = f.read()
                    
                    # Encode to base64 for JavaScript
                    b64 = base64.b64encode(graph_json_content.encode()).decode()
                    
                    # Auto-download with JavaScript. The filename is derived from
                    # user input, so emit it as a properly escaped JS string
                    # literal via json.dumps rather than splicing it raw.
                    html = f"""
                    <script>
                    function downloadFile() {{
                        const link = document.createElement('a');
                        link.href = 'data:application/json;base64,{b64}';
                        link.download = {json.dumps(filename)};
                        document.body.appendChild(link);
                        link.click();
                        document.body.removeChild(link);
                    }}
                    // Trigger download after a short delay
                    setTimeout(downloadFile, 100);
                    </script>
                    """
                    components.html(html, height=50)
                    
                except Exception as e:
                    st.warning(f"⚠️ Could not prepare auto-download: {e}")
            else:
                # No local file (e.g. "Save locally" unchecked): still confirm
                # success and offer the Neuronpedia link.
                st.success(f"✅ Graph generated successfully\n\n[**Open Graph on Neuronpedia**]({neuronpedia_url})")
        else:
            # Generation failed
            error_msg = result.get('error', 'Unknown error')
            st.error(f"❌ Graph generation failed!\n\nError: {error_msg}")
            
            # Show details if available
            if result.get('details'):
                with st.expander("Error details"):
                    st.code(result['details'])
    
    except Exception as e:
        progress_bar.empty()
        status_text.empty()
        st.error(f"Unexpected error: {str(e)}")
        with st.expander("Details"):
            import traceback
            st.code(traceback.format_exc())

st.markdown("---")

# ===== SECTION: ANALYZE GRAPH =====

st.subheader("2️⃣ Analyze Graph")

# Check if we just generated a graph
just_generated = st.session_state.get('generation_result') and st.session_state.generation_result.get('success')
generated_path = st.session_state.generation_result.get('local_path') if just_generated else None

json_dir = parent_dir / "output" / "graph_data"

if just_generated and generated_path:
    # Default to the just-generated graph (kept as the path stored in the result).
    gen_path = Path(generated_path)
    selected_json = str(gen_path)

    # "Use this file instead" calls st.rerun(), which discards plain local
    # variables, so the manual choice is persisted in session_state. It is
    # tied to the generated graph it replaced: once a newer graph is
    # generated (different generated_path), the old choice no longer applies.
    override = st.session_state.get('graph_file_override')
    if override and override.get('generated_path') == generated_path:
        selected_json = override['selected_json']
        st.info(f"📊 **Ready to analyze**: `{Path(selected_json).name}` (selected manually)")
    else:
        st.info(f"📊 **Ready to analyze**: `{gen_path.name}` (just generated)")
    
    # Option to select a different file
    with st.expander("📁 Select a different graph file", expanded=False):
        # Create directory if it doesn't exist
        if not json_dir.exists():
            try:
                json_dir.mkdir(parents=True, exist_ok=True)
            except OSError:
                pass  # Silently fail in expander
        
        if json_dir.exists():
            json_files = sorted(json_dir.glob("*.json"), key=lambda x: x.stat().st_mtime, reverse=True)
            if json_files:
                json_options = [str(f.relative_to(parent_dir)) for f in json_files]
                # Options are relative to parent_dir, while selected_json may be
                # absolute (the generated graph); normalise before looking it up
                # so the current file is pre-selected.
                current = Path(selected_json)
                if current.is_absolute():
                    try:
                        current = current.resolve().relative_to(parent_dir.resolve())
                    except ValueError:
                        pass  # Outside the project dir: fall back to the newest file
                current_str = str(current)
                default_idx = json_options.index(current_str) if current_str in json_options else 0
                
                selected_json_alt = st.selectbox(
                    "Select JSON file",
                    options=json_options,
                    index=default_idx,
                    key="alt_json_select",
                    help="JSON files sorted by date (most recent first)"
                )
                if st.button("Use this file instead"):
                    st.session_state['graph_file_override'] = {
                        'generated_path': generated_path,
                        'selected_json': selected_json_alt,
                    }
                    st.rerun()
            else:
                st.info("No other graph files found in `output/graph_data/`")
        else:
            st.warning("Directory `output/graph_data/` not accessible")
else:
    # Normal file selection (no graph just generated)
    st.write("""
    Extract static metrics (`node_influence`, `cumulative_influence`, `frac_external_raw`) from an existing graph.
    """)
    
    # Create directory if it doesn't exist
    if not json_dir.exists():
        try:
            json_dir.mkdir(parents=True, exist_ok=True)
        except Exception as e:
            st.warning(f"⚠️ Could not create directory: {e}")
    
    if json_dir.exists():
        json_files = sorted(json_dir.glob("*.json"), key=lambda x: x.stat().st_mtime, reverse=True)
        
        if json_files:
            # Use relative paths for display
            json_options = [str(f.relative_to(parent_dir)) for f in json_files]
            selected_json = st.selectbox(
                "Select JSON file",
                options=json_options,
                help="JSON files sorted by date (most recent first)"
            )
        else:
            st.warning("No JSON files found in `output/graph_data/`")
            selected_json = None
    else:
        st.warning("Directory `output/graph_data/` not found")
        selected_json = None

# Show file info and analysis button if we have a selected file
if selected_json:
    # Handle both absolute and relative paths
    file_path = Path(selected_json)
    if not file_path.is_absolute():
        file_path = parent_dir / selected_json
    
    # Check if file exists - if not, try to use session_state data
    file_exists = file_path.exists()
    
    # The in-memory graph cached at generation time is only a valid stand-in
    # for the generated graph itself, so require the selected file to be that
    # graph (same path) and the cache to belong to it (same filename).
    pipeline_graph = st.session_state.get('pipeline_graph_json')
    is_generated_file = bool(
        just_generated
        and generated_path
        and pipeline_graph
        and file_path.resolve() == Path(generated_path).resolve()
        and pipeline_graph.get('filename') == file_path.name
    )
    
    # If file doesn't exist but we have its data in session_state (just generated), use that
    use_session_data = False
    if not file_exists and is_generated_file:
        graph_data = pipeline_graph['data']
        use_session_data = True
        st.info("📦 Using graph data from current session (file system is temporary on HF Spaces)")
    elif not file_exists:
        st.error(f"❌ File not found: `{file_path.name}`")
        st.warning("The file may have been moved or renamed. Please refresh the page or select another file.")
        st.stop()
    
    # Graph parsed from disk below; reused by the Analyze button so the JSON
    # is not parsed a second time in the same run.
    loaded_graph = None
    
    # Get file stats and metadata.
    # NOTE(review): this rebinds the module-level `model_id` (the sidebar
    # selection) to the graph's model id; kept as-is in case later code relies on it.
    if use_session_data:
        # Use data from session_state
        file_size = len(json.dumps(graph_data)) / 1024 / 1024  # Approximate size
        file_time = datetime.fromisoformat(pipeline_graph['timestamp'])
        num_nodes = len(graph_data.get('nodes', []))
        num_links = len(graph_data.get('links', []))
        model_id = graph_data.get('metadata', {}).get('model_id', 'N/A')
    else:
        # Use file on disk
        file_size = file_path.stat().st_size / 1024 / 1024
        file_time = datetime.fromtimestamp(file_path.stat().st_mtime)
        
        # Load JSON to extract graph metadata
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                graph_metadata = json.load(f)
            num_nodes = len(graph_metadata.get('nodes', []))
            num_links = len(graph_metadata.get('links', []))
            model_id = graph_metadata.get('metadata', {}).get('model_id', 'N/A')
            loaded_graph = graph_metadata
        except Exception:
            num_nodes = None
            num_links = None
            model_id = None
    
    # Display file info and graph metadata
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Size", f"{file_size:.2f} MB")
    with col2:
        st.metric("Date", file_time.strftime("%Y-%m-%d %H:%M"))
    with col3:
        # Only add an ellipsis when the name was actually truncated
        display_name = file_path.name if len(file_path.name) <= 20 else file_path.name[:20] + "..."
        st.metric("Name", display_name)
    
    if num_nodes is not None and num_links is not None and model_id is not None:
        col4, col5, col6 = st.columns(3)
        with col4:
            st.metric("Nodes", num_nodes)
        with col5:
            st.metric("Links", num_links)
        with col6:
            st.metric("Model", model_id)
    
    # Extract button
    button_label = "📊 Analyze This Graph" if just_generated else "📊 Analyze Graph"
    if st.button(button_label, key="extract_existing", type="primary"):
        try:
            with st.spinner("Extracting metrics..."):
                # Session data is already in graph_data; otherwise reuse the
                # graph parsed above, or load it from disk if that failed.
                if not use_session_data:
                    if loaded_graph is not None:
                        graph_data = loaded_graph
                    else:
                        with open(file_path, 'r', encoding='utf-8') as f:
                            graph_data = json.load(f)
                
                csv_output_path = str(parent_dir / "output" / "graph_feature_static_metrics.csv")
                df = extract_static_metrics_from_json(
                    graph_data,
                    output_path=csv_output_path,
                    verbose=False
                )
                
                # Save in session_state to persist across reruns
                st.session_state.extracted_graph_data = graph_data
                st.session_state.extracted_csv_df = df
                st.session_state.analysis_performed = True
            
            st.success(f"✅ CSV generated: `{csv_output_path}`")
            st.info("📊 Scroll down to see interactive visualizations")
            
        except Exception as e:
            st.error(f"❌ Error: {str(e)}")

st.markdown("---")

# ===== EXTRACTED DATA VISUALIZATION (persists across reruns) =====

if st.session_state.extracted_graph_data is not None and st.session_state.extracted_csv_df is not None:
    graph_data = st.session_state.extracted_graph_data
    df = st.session_state.extracted_csv_df
    
    # Only show if analysis was performed
    if st.session_state.get('analysis_performed', False):
        st.header("Extracted Data Analysis")
        
        # CSV Metrics
        col1, col2, col3, col4, col5 = st.columns(5)
        with col1:
            st.metric("Features", len(df))
        with col2:
            st.metric("Unique Tokens", df['ctx_idx'].nunique())
        with col3:
            st.metric("Mean Activation", f"{df['activation'].mean():.3f}")
        with col4:
            # Use node_influence (marginal influence) for total sum
            st.metric("Sum Node Infl", f"{df['node_influence'].sum():.2f}")
        with col5:
            st.metric("Mean Frac Ext", f"{df['frac_external_raw'].mean():.3f}")
        
        with st.expander("View Complete Dataframe", expanded=False):
            st.dataframe(df, use_container_width=True, height=600)
        
        # Scatter plot: Layer vs Context Position with Influence

        # Prepare data from JSON for scatter plot
        if 'nodes' in graph_data:
            import pandas as pd
            import plotly.express as px
            
            # Extract prompt_tokens from metadata to map ctx_idx -> token
            # NOTE(review): not read in this block; kept in case later code uses it.
            prompt_tokens = graph_data.get('metadata', {}).get('prompt_tokens', [])
            
            # Scatter plot visualization with filter
            from eda.utils.graph_visualization import create_scatter_plot_with_filter
            filtered_features = create_scatter_plot_with_filter(graph_data)
            
            # Save filtered_features for export section. When the current
            # filter selects nothing, drop any earlier selection so a stale
            # feature set is not exported.
            if filtered_features is not None and len(filtered_features) > 0:
                st.session_state.filtered_features_export = filtered_features
            else:
                st.session_state.pop('filtered_features_export', None)

# ===== SUMMARY CHARTS: COVERAGE AND STRENGTH =====
# Only show if analysis was performed
if st.session_state.get('analysis_performed', False):
    # Data source: prefer extracted data, otherwise last generated graph
    graph_data_for_plots = None
    if st.session_state.get('extracted_graph_data') is not None:
        graph_data_for_plots = st.session_state.extracted_graph_data
    elif st.session_state.get('generation_result') is not None and st.session_state.generation_result.get('success'):
        graph_data_for_plots = st.session_state.generation_result.get('graph_data')

    if graph_data_for_plots is not None and 'nodes' in graph_data_for_plots:
        with st.expander("Summary Charts: Coverage and Strength", expanded=False):
            import pandas as pd
            import plotly.express as px
            import numpy as np

            nodes_df = pd.DataFrame(graph_data_for_plots['nodes'])
            is_feature = nodes_df['node_id'].astype(str).str[0].str.isdigit() & nodes_df['node_id'].astype(str).str.contains('_')
            feat_nodes = nodes_df.loc[is_feature].copy()
            
            if len(feat_nodes) == 0:
                st.warning("No features found in current data.")
            else:
                # Add slider to filter (reuse same logic as create_scatter_plot_with_filter)
                max_influence = feat_nodes['influence'].max()
                
                st.markdown("### Filter Features by Cumulative Influence")
                st.info(f"""
                **Use the slider to filter the charts below** based on cumulative influence coverage (0-{max_influence:.2f}).
                Summary charts will show only features with `influence <= threshold`.
                """)
                
                # Check if main slider already exists (from create_scatter_plot_with_filter)
                # If it exists, use it, otherwise create a new one
                slider_key = "cumulative_slider_summary"
                if "cumulative_slider_main" in st.session_state:
                    # Reuse main slider value
                    cumulative_threshold_summary = st.session_state.cumulative_slider_main
                    st.info(f"Synchronized with main slider: threshold = {cumulative_threshold_summary:.4f}")
                else:
                    # Create separate slider
                    cumulative_threshold_summary = st.slider(
                        "Cumulative Influence Threshold (summary charts)",
                        min_value=0.0,
                        max_value=float(max_influence),
                        value=float(max_influence),
                        step=0.01,
                        key=slider_key,
                        help=f"Keep only features with influence <= threshold. Range: 0.0 - {max_influence:.2f}"
                    )
                
                # Apply filter
                feat_nodes_filtered = feat_nodes[feat_nodes['influence'] <= cumulative_threshold_summary].copy()
                
                if len(feat_nodes_filtered) == 0:
                    st.warning("No features match the current filter. Increase the threshold.")
                else:
                    # Show filter statistics
                    col1, col2, col3 = st.columns(3)
                    with col1:
                        st.metric("Total Features", len(feat_nodes))
                    with col2:
                        st.metric("Filtered Features", len(feat_nodes_filtered))
                    with col3:
                        # Guard against division by zero when there are no feature nodes at all.
                        pct = (len(feat_nodes_filtered) / len(feat_nodes) * 100) if len(feat_nodes) > 0 else 0
                        st.metric("% Kept", f"{pct:.1f}%")

                    st.markdown("---")

                    # --- Per-feature statistics ---------------------------------------
                    # feature_key = node_id with its last "_<suffix>" removed, so all nodes
                    # that share that prefix are grouped as one feature (assumes the suffix
                    # is a per-node component of node_id — confirm against the id format).
                    # NOTE: this adds a column to feat_nodes_filtered in place.
                    feat_nodes_filtered['feature_key'] = feat_nodes_filtered['node_id'].str.rsplit('_', n=1).str[0]
                    # n_ctx: number of distinct ctx_idx values each feature appears in.
                    cov = (
                        feat_nodes_filtered.groupby('feature_key')['ctx_idx'].nunique()
                        .rename('n_ctx').reset_index()
                    )
                    # Mean influence / activation over all nodes of each feature.
                    per_feat = (
                        feat_nodes_filtered.groupby('feature_key')
                        .agg(mean_influence=('influence', 'mean'),
                             mean_activation=('activation', 'mean'))
                        .reset_index()
                    )
                    per_feat_cov = per_feat.merge(cov, on='feature_key', how='left')
                    # Node-level rows annotated with their feature's n_ctx (used by the violin plot).
                    nodes_with_cov = feat_nodes_filtered.merge(cov, on='feature_key', how='left')

                    # Activation/influence correlation across features. A correlation needs
                    # at least two points; with fewer we keep None so that both the scatter
                    # title and the insights panel fall back instead of raising NameError.
                    pearson = spearman = None
                    if len(per_feat_cov) >= 2:
                        pearson = float(per_feat_cov['mean_activation'].corr(per_feat_cov['mean_influence'], method='pearson'))
                        spearman = float(per_feat_cov['mean_activation'].corr(per_feat_cov['mean_influence'], method='spearman'))

                    # Chart 1: Coverage (Histogram + ECDF)
                    st.subheader("Feature Coverage (n_ctx)")
                    c1, c2 = st.columns(2)
                    with c1:
                        fig_hist = px.histogram(cov, x='n_ctx', color_discrete_sequence=['#4C78A8'])
                        fig_hist.update_layout(title='n_ctx distribution per feature',
                                               xaxis_title='Number of unique ctx_idx',
                                               yaxis_title='Number of features')
                        st.plotly_chart(fig_hist, use_container_width=True)
                    with c2:
                        fig_ecdf = px.ecdf(cov, x='n_ctx', color_discrete_sequence=['#F58518'])
                        fig_ecdf.update_layout(title='n_ctx ECDF',
                                               xaxis_title='Number of unique ctx_idx',
                                               yaxis_title='Cumulative fraction')
                        st.plotly_chart(fig_ecdf, use_container_width=True)

                    # Chart 2: Strength vs Coverage (Activation vs n_ctx and Scatter mean)
                    st.subheader("Strength vs Coverage")
                    c3, c4 = st.columns(2)
                    with c3:
                        fig_violin = px.violin(nodes_with_cov, x='n_ctx', y='activation', box=True, points=False)
                        fig_violin.update_layout(title='Activation per n_ctx',
                                                 xaxis_title='n_ctx (feature)',
                                                 yaxis_title='Activation (node)')
                        st.plotly_chart(fig_violin, use_container_width=True)
                    with c4:
                        fig_scatter = px.scatter(per_feat_cov, x='mean_activation', y='mean_influence',
                                                 color='n_ctx', size='n_ctx', hover_data=['feature_key'],
                                                 color_continuous_scale='Viridis')
                        # Correlations in the subtitle, when they could be computed.
                        if pearson is not None:
                            fig_scatter.update_layout(title=f'Mean activation vs mean influence<br>(r={pearson:.2f}, rho={spearman:.2f})')
                        else:
                            fig_scatter.update_layout(title='Mean activation vs mean influence')
                        fig_scatter.update_layout(xaxis_title='Mean activation (per feature)',
                                                  yaxis_title='Mean influence (per feature)')
                        st.plotly_chart(fig_scatter, use_container_width=True)

                    # Quick insights
                    with st.expander("Insights from charts", expanded=False):
                        if cov.empty:
                            st.markdown("No features left after filtering.")
                        else:
                            top_n_ctx = int(cov['n_ctx'].max())
                            top_features = cov.loc[cov['n_ctx'] == top_n_ctx, 'feature_key'].tolist()
                            n_top = len(top_features)
                            # The highest n_ctx only means "present in every context" when it
                            # equals the number of distinct contexts in the filtered data.
                            n_ctx_total = int(feat_nodes_filtered['ctx_idx'].nunique())
                            if top_n_ctx == n_ctx_total:
                                coverage_line = f"- {n_top} features present in all {top_n_ctx} contexts"
                            else:
                                coverage_line = (f"- {n_top} features reach the maximum coverage of "
                                                 f"{top_n_ctx} of {n_ctx_total} contexts")
                            top_list = ', '.join(f'`{name}`' for name in top_features[:5])

                            insight_lines = [
                                "**Coverage (n_ctx)**:",
                                f"- {len(cov)} unique features in filtered dataset",
                                coverage_line,
                                f"- Highest-coverage features ({top_n_ctx}): {top_list}",
                                "",
                                "**Strength vs Coverage**:",
                            ]
                            if pearson is None:
                                insight_lines.append(
                                    "- Activation-influence correlation: not available (fewer than 2 features)"
                                )
                            else:
                                insight_lines.append(
                                    f"- Activation-influence correlation: **r={pearson:.2f}** (Pearson), "
                                    f"**rho={spearman:.2f}** (Spearman)"
                                )
                                insight_lines.append(
                                    "- Negative correlation: features with high activation tend to have low influence"
                                    if pearson < -0.2
                                    else "- Weak or positive correlation between activation and influence"
                                )
                            st.markdown("\n".join(insight_lines))

                        # Node-level comparison: single-context features vs features seen in >= 5 contexts.
                        if not nodes_with_cov.empty:
                            g1 = nodes_with_cov[nodes_with_cov['n_ctx'] == 1]
                            g_multi = nodes_with_cov[nodes_with_cov['n_ctx'] >= 5]

                            if len(g1) > 0 and len(g_multi) > 0:
                                st.markdown("\n".join([
                                    "**Group comparison**:",
                                    f"- n_ctx=1: {len(g1)} nodes, mean_activation={g1['activation'].mean():.2f}, "
                                    f"mean_influence={g1['influence'].mean():.3f}",
                                    f"- n_ctx>=5: {len(g_multi)} nodes, mean_activation={g_multi['activation'].mean():.2f}, "
                                    f"mean_influence={g_multi['influence'].mean():.3f}",
                                ]))

# ===== EXPORT SELECTED FEATURES =====

if st.session_state.get('analysis_performed', False) and st.session_state.get('filtered_features_export') is not None:
    filtered_features = st.session_state.filtered_features_export

    if len(filtered_features) > 0:
        st.markdown("---")
        st.subheader("Export Selected Features")

        # Take a single timestamp for the whole export so that metadata.exported_at,
        # the pipeline timestamp and the pipeline filename all agree.
        export_time = datetime.now()

        # Deduplicate (layer, feature) pairs. int() turns numpy scalars into plain
        # Python ints so json.dumps can serialize them; zipping the two columns
        # avoids building a Series per row as iterrows() would.
        unique_features = {
            (int(layer), int(feature))
            for layer, feature in zip(filtered_features['layer'], filtered_features['feature'])
        }

        # Sorted list of {"layer": X, "index": Y} dicts
        features_export = [
            {"layer": layer, "index": feature}
            for layer, feature in sorted(unique_features)
        ]

        # Selected node_ids as well (for subgraph upload)
        node_ids_export = sorted(filtered_features['id'].unique().tolist())

        # Complete export with features AND node_ids
        export_data = {
            "features": features_export,
            "node_ids": node_ids_export,
            "metadata": {
                "n_features": len(features_export),
                "n_nodes": len(node_ids_export),
                "cumulative_threshold": st.session_state.get('cumulative_slider_main', None),
                "exported_at": export_time.isoformat()
            }
        }

        # Statistics
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Unique Features", len(features_export))
        with col2:
            st.metric("Selected Nodes", len(node_ids_export))
        with col3:
            st.metric("Unique Layers", len({f['layer'] for f in features_export}))

        # Save to pipeline session_state for auto-loading in next steps
        st.session_state['pipeline_selected_nodes'] = {
            'data': export_data,
            'filename': f"st1_feat_node_subset_{export_time.strftime('%Y%m%d_%H%M%S')}.json",
            'timestamp': export_time.isoformat()
        }

        # Download JSON (complete format). The legacy features-only download was
        # retired because all downstream tools accept the complete format.
        st.download_button(
            label="📥 Download Features+Nodes Subset",
            data=json.dumps(export_data, indent=2, ensure_ascii=False),
            file_name="selected_features_with_nodes.json",
            mime="application/json",
            help="Complete format with features and node_ids (for Node Grouping + Probe Prompts + batch_get_activations.py)",
            use_container_width=True,
            type="primary"
        )

        # Preview (truncated lists, full metadata)
        with st.expander("Preview Complete Export", expanded=False):
            st.json({
                "features": features_export[:5],
                "node_ids": node_ids_export[:10],
                "metadata": export_data["metadata"]
            })

# ===== FOOTER =====
# Static sidebar help: a divider, an "Info" heading, a short legend of the
# graph's node types, and the data-source credit.

_SIDEBAR_INFO_MD = """
**Attribution Graph**: visualizes how SAE features contribute to predictions.

**Elements**:
- Embedding nodes: input tokens
- Feature nodes: SAE latents
- Logit nodes: predicted tokens
"""

_sidebar = st.sidebar
_sidebar.markdown("---")
_sidebar.subheader("Info")
_sidebar.markdown(_SIDEBAR_INFO_MD)
_sidebar.caption("Powered by Neuronpedia API")