File size: 39,482 Bytes
0b73078
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62f7eb6
0b73078
62f7eb6
0b73078
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f5e74b
9b80da8
4f5e74b
0b73078
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b80da8
 
 
 
 
 
 
 
 
 
 
 
 
 
0b73078
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b80da8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b73078
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b80da8
0b73078
 
9b80da8
 
 
 
 
 
 
 
 
 
 
 
 
 
0b73078
9b80da8
 
 
 
0b73078
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
#!/usr/bin/env python3
"""
Root Cause Discovery Trajectory Analysis

Analyzes how agents discover root cause entities:
- T_encounter: When GT entity first appears in tool output
- T_investigate: When agent actively queries GT entity  
- T_assert: When agent asserts GT entity as root cause
- T_exonerate: When agent dismisses GT entity (if ever)
- T_recover: When agent corrects after exoneration

Metrics computed:
- Discovery efficiency (how early GT appears)
- Investigation delay (turns between seeing and investigating)
- Assertion delay (turns to confirm after investigating)
- Recovery rate (% of trials with successful recovery)
"""

import json
import sys
import re
import yaml
from pathlib import Path
from dataclasses import dataclass, field, asdict
from typing import Optional, List, Dict, Any
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
from tqdm import tqdm

PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from analysis_src.utils import find_latest_rollout_file

from analysis_src.model_styles import (
    get_display_name, MIN_FONT_SIZE, SINGLE_COLUMN_WIDTH, DOUBLE_COLUMN_WIDTH, PLOT_PARAMETERS
)

# Regex for standard K8s resource references of the form namespace/Kind/name.
# Capture groups: (namespace, Kind, name). Case-insensitive on the Kind.
K8S_ENTITY_PATTERN = re.compile(r'([\w-]+)/(Deployment|Service|Pod|ReplicaSet|ResourceQuota|StatefulSet|DaemonSet|Job|CronJob|ConfigMap|Secret|Endpoints|Ingress|PersistentVolumeClaim|PersistentVolume|ServiceAccount|Role|RoleBinding|ClusterRole|ClusterRoleBinding|NetworkPolicy|HorizontalPodAutoscaler|Node)/([\w-]+)', re.IGNORECASE)

def extract_k8s_entities(text: str) -> List[str]:
    """Return every K8s entity reference found in *text*.

    Each hit is normalized to the "namespace/Kind/name" form; the Kind's
    capitalization is preserved exactly as it appeared in the text.
    """
    return [
        f"{namespace}/{kind}/{name}"
        for namespace, kind, name in K8S_ENTITY_PATTERN.findall(text)
    ]

# Paths
# NOTE(review): PROJECT_ROOT was already defined near the imports above; this
# re-definition is redundant (harmless, since both resolve to the same path).
PROJECT_ROOT = Path(__file__).parent.parent
LEADERBOARD_DIR = PROJECT_ROOT / "ITBench-SRE-Agent" / "ITBench-Trajectories" / "ReAct-Agent-Trajectories"  # per-model trajectory dumps
GT_DIR = PROJECT_ROOT / "ITBench-SRE-Agent" / "ITBench-Lite" / "snapshots" / "sre"  # searched for ground_truth.yaml per scenario
OUTPUT_DIR = PROJECT_ROOT / "ITBench-SRE-Agent" / "ITBench-Trajectories" / "output" / "discovery"  # analysis artifacts

@dataclass
class GroundTruth:
    """Ground truth root cause entity info for one scenario.

    Populated by load_ground_truth() from the scenario's ground_truth.yaml.
    """
    scenario: str     # scenario directory name, e.g. "Scenario-3"
    entity_name: str  # name of the root-cause entity (falls back to group_id)
    entity_kind: str  # K8s kind of the root-cause group ('Unknown' if absent)
    group_id: str     # id of the group flagged root_cause in the YAML
    filters: List[str]  # regex patterns to match entity
    aliases: List[str]  # related entity group IDs
    propagation_entities: set = field(default_factory=set)  # All entities involved in propagation
    all_entities: list = field(default_factory=list)  # All entities defined in the scenario
    entity_filters: Dict[str, List[str]] = field(default_factory=dict)  # group_id -> filters mapping for all entities


@dataclass 
class EntityMention:
    """A mention of an entity in the agent's trajectory.

    Instances are created by parse_rollout(); the string fields below take
    only the literal values listed in the trailing comments.
    """
    turn: int  # 1-based turn number where the mention occurred
    mention_type: str  # 'encounter', 'investigate', 'assert', 'exonerate'
    context: str  # 'tool_output', 'tool_args', 'reasoning', 'final_output'
    text_snippet: str  # truncated excerpt of the matching text
    sentiment: str  # 'positive', 'negative', 'neutral'


@dataclass
class TrajectoryAnalysis:
    """Analysis results for a single trial.

    Produced by parse_rollout() (which leaves model/scenario/trial blank) and
    completed by analyze_model() (identity fields, judge score, final stage).
    Turn numbers are 1-based; a value of None means the event never happened.
    """
    model: str
    scenario: str
    trial: int
    total_turns: int
    gt_entity: str
    
    # Key timestamps (turn numbers, None if not found)
    t_encounter: Optional[int] = None    # GT first appears in a tool output (passive)
    t_investigate: Optional[int] = None  # GT first appears in tool arguments (active)
    t_assert: Optional[int] = None       # GT first asserted as root cause
    t_exonerate: Optional[int] = None    # GT first dismissed
    t_recover: Optional[int] = None      # assertion after an earlier exoneration
    
    # Final outcome (from judge scores if available)
    final_success: bool = False  # Did the final answer include GT?
    root_cause_f1: Optional[float] = None
    
    # Pipeline stage reached (for funnel analysis)
    # 0=none, 1=encounter, 2=investigate, 3=assert, 4=success
    max_stage_reached: int = 0
    
    # All mentions for detailed analysis
    mentions: List[EntityMention] = field(default_factory=list)
    
    # Exploration metrics
    total_entities_available: int = 0        # set equal to entities encountered (see parse_rollout)
    unique_entities_encountered: int = 0
    unique_entities_investigated: int = 0
    exploration_ratio: float = 0.0  # investigated / available
    
    # Coverage metrics
    on_chain_investigated: int = 0
    off_chain_investigated: int = 0  # Detoured
    propagation_coverage: float = 0.0  # % of chain entities investigated
    detour_rate: float = 0.0  # off_chain / total_investigated
    
    # Computed metrics
    discovery_efficiency: Optional[float] = None  # t_encounter / total_turns
    investigation_delay: Optional[int] = None  # t_investigate - t_encounter
    assertion_delay: Optional[int] = None  # t_assert - t_investigate
    had_recovery: bool = False  # True when an exoneration preceded the assertion


def check_entity_match(text: str, entity_info: Dict) -> bool:
    """Return True if *text* mentions the entity described by *entity_info*.

    A match is either the entity's 'id' as a case-insensitive substring of
    *text*, or the literal core of any 'filter' regex pattern (with the
    metacharacter fragments '\\b', '-.*' and '.*' stripped out).
    """
    haystack = text.lower()

    # Direct id match.
    name = entity_info.get('id', '').lower()
    if name and name in haystack:
        return True

    # Reduce each filter regex to its literal core and substring-match it.
    for pattern in entity_info.get('filter', []):
        needle = pattern.replace('\\b', '').replace('-.*', '').replace('.*', '')
        if needle and needle.lower() in haystack:
            return True

    return False


def load_ground_truth(scenario: str) -> Optional[GroundTruth]:
    """Load and parse ground truth YAML for a scenario.

    Searches for ground_truth.yaml in GT_DIR/v0.2-*/scenario/ground_truth.yaml
    (first version directory that contains the file wins).

    Returns None when no ground_truth.yaml exists for the scenario or when
    no group in it is flagged as the root cause.
    """
    # Find the version directory (e.g., v0.2-something)
    version_dirs = [d for d in GT_DIR.iterdir() if d.is_dir() and d.name.startswith("v0.2-")]

    for version_dir in version_dirs:
        gt_path = version_dir / scenario / "ground_truth.yaml"
        if gt_path.exists():
            with open(gt_path) as f:
                gt_data = yaml.safe_load(f)
            break
    else:
        # No version directory contained this scenario's ground truth.
        return None
    
    # Find the root cause group (the group with root_cause: true)
    root_cause_group = None
    all_groups = gt_data.get('groups', [])
    
    for group in all_groups:
        if group.get('root_cause', False):
            root_cause_group = group
            break
    
    if not root_cause_group:
        return None
    
    # Get fault entity info (only the first fault entry is used)
    fault_list = gt_data.get('fault', [])
    fault_info = fault_list[0] if fault_list else {}
    entity_info = fault_info.get('entity', {})
    
    # Collect all aliases: every alias group containing the root-cause id
    # contributes all of its members (including the root-cause id itself).
    aliases = []
    for alias_group in gt_data.get('aliases', []):
        if root_cause_group['id'] in alias_group:
            aliases.extend(alias_group)
            
    # Collect all entities in the fault propagation chain (sources + targets)
    propagation_entities = set()
    for prop in gt_data.get('propagations', []):
        if 'source' in prop:
            propagation_entities.add(prop['source'])
        if 'target' in prop:
            propagation_entities.add(prop['target'])
    
    # Add root cause itself if not already there (it should be as source)
    propagation_entities.add(root_cause_group['id'])
    
    # Build entity_filters mapping: group_id -> list of filter patterns.
    # The group id and 'name' field are appended as extra literal filters.
    entity_filters = {}
    for group in all_groups:
        group_id = group.get('id', '')
        filters = group.get('filter', [])
        # Also use the group id itself and 'name' field as filters
        name = group.get('name', '')
        all_filters = list(filters) if filters else []
        if group_id:
            all_filters.append(group_id)
        if name and name != group_id:
            all_filters.append(name)
        entity_filters[group_id] = all_filters
    
    gt_obj = GroundTruth(
        scenario=scenario,
        entity_name=entity_info.get('name', root_cause_group['id']),
        entity_kind=root_cause_group.get('kind', 'Unknown'),
        group_id=root_cause_group['id'],
        filters=root_cause_group.get('filter', []),
        aliases=aliases,
        propagation_entities=propagation_entities,
        entity_filters=entity_filters
    )
    
    # Attach all entities for exploration analysis
    gt_obj.all_entities = all_groups
    return gt_obj


def entity_matches(text: str, gt: GroundTruth) -> bool:
    """Return True if *text* mentions the ground-truth root-cause entity.

    Matching is case-insensitive and tries, in order: the entity name, the
    group id (tolerant of hyphen/space/no-separator variants), the literal
    core of each filter pattern, and finally each alias.
    """
    haystack = text.lower()

    # 1) Direct entity-name match.
    if gt.entity_name.lower() in haystack:
        return True

    # 2) Group-id match, normalizing hyphens both to spaces and to nothing
    #    so "cart-service", "cart service" and "cartservice" all line up.
    gid = gt.group_id.lower()
    if gid.replace('-', ' ') in haystack.replace('-', ' '):
        return True
    if gid.replace('-', '') in haystack.replace('-', ''):
        return True

    # 3) Literal core of each filter regex (metacharacters stripped).
    for pattern in gt.filters:
        needle = pattern.replace('\\b', '').replace('-.*', '').replace('.*', '')
        if needle.lower() in haystack:
            return True

    # 4) Aliases, compared with hyphens normalized to spaces.
    despaced = haystack.replace('-', ' ')
    for alias in gt.aliases:
        if alias.replace('-', ' ').lower() in despaced:
            return True

    return False


def is_entity_on_chain(entity_str: str, gt: GroundTruth) -> Optional[str]:
    """Map an entity string onto the fault propagation chain, if possible.

    Returns the matched propagation group_id when *entity_str* matches a
    chain entity (by bidirectional substring on the id, or by the literal
    core of any of its filter patterns); returns None when it is off-chain.

    entity_str: e.g., "otel-demo/Pod/frontend-abc123" or just "frontend"
    """
    probe = entity_str.lower()

    for group_id in gt.propagation_entities:
        gid = group_id.lower()

        # Bidirectional substring test on the group id itself.
        if gid in probe or probe in gid:
            return group_id

        # Literal core of each filter pattern; very short cores are skipped
        # to avoid spurious matches.
        for pattern in gt.entity_filters.get(group_id, []):
            needle = pattern.replace('\\b', '').replace('-.*', '').replace('.*', '').replace('\\', '')
            if len(needle) > 2 and needle.lower() in probe:
                return group_id

    return None


def get_entity_group_match(entity_str: str, gt: GroundTruth) -> Optional[str]:
    """Map an entity string onto ANY entity group defined in the scenario.

    Returns the matched group's id (by bidirectional substring on the id or
    name, or by the literal core of any filter pattern), or None when the
    string matches nothing in the scenario.
    """
    probe = entity_str.lower()

    for group in gt.all_entities:
        gid = group.get('id', '')
        name = group.get('name', '')

        # Bidirectional substring test on the group id.
        if gid:
            gid_l = gid.lower()
            if gid_l in probe or probe in gid_l:
                return gid

        # Same test on the human-readable name.
        if name:
            name_l = name.lower()
            if name_l in probe or probe in name_l:
                return gid

        # Literal core of each filter pattern (short cores skipped).
        for pattern in group.get('filter', []):
            needle = pattern.replace('\\b', '').replace('-.*', '').replace('.*', '').replace('\\', '')
            if len(needle) > 2 and needle.lower() in probe:
                return gid

    return None


def classify_sentiment(text: str, gt: GroundTruth) -> str:
    """Classify how a mention talks about the root-cause entity.

    Returns 'positive' (asserting as root cause), 'negative' (exonerating),
    or 'neutral' when no indicator phrase is found. Positive indicators take
    precedence over negative ones.

    Note: *gt* is kept for interface compatibility but is not consulted —
    the heuristic searches for cause/exoneration phrasing anywhere in *text*
    (callers invoke this only on text already known to mention the entity).
    """
    text_lower = text.lower()

    # Positive indicators (asserting as root cause)
    positive_patterns = [
        r'root\s*cause',
        r'is\s+the\s+cause',
        r'caused\s+by',
        r'source\s+of\s+(the\s+)?problem',
        r'culprit',
        r'responsible\s+for',
        r'likely\s+cause',
        r'appears\s+to\s+be\s+the\s+issue',
        r'primary\s+issue',
        r'main\s+issue',
    ]

    # Negative indicators (exonerating)
    negative_patterns = [
        r'not\s+the\s+(root\s*)?cause',
        r'ruled\s+out',
        r'is\s+not\s+responsible',
        r'working\s+(correctly|normally|fine)',
        r'healthy',
        r'no\s+issues?\s+(found|detected)',
        r'can\s+be\s+excluded',
        r'unlikely\s+to\s+be',
    ]

    if any(re.search(p, text_lower) for p in positive_patterns):
        return 'positive'
    if any(re.search(p, text_lower) for p in negative_patterns):
        return 'negative'
    return 'neutral'


def get_latest_rollout(trial_dir: Path) -> Optional[Path]:
    """Return the most recently modified rollout-*.jsonl under trial_dir.

    Searches recursively inside trial_dir/sessions; returns None when the
    sessions directory or any rollout file is missing.
    """
    sessions = trial_dir / "sessions"
    if not sessions.exists():
        return None

    candidates = list(sessions.glob("**/rollout-*.jsonl"))
    if not candidates:
        return None

    # Most recent by file modification time.
    return max(candidates, key=lambda p: p.stat().st_mtime)


def get_judge_score(trial_dir: Path) -> Optional[float]:
    """Return root_cause_entity_f1 from the trial's judge output.

    Reads trial_dir/judge_output.json; returns None when the file is
    missing, unreadable, malformed JSON, or lacks the expected structure.
    """
    judge_path = trial_dir / "judge_output.json"
    if not judge_path.exists():
        return None

    # Narrow exception handling (was a bare `except:`, which also swallowed
    # KeyboardInterrupt/SystemExit): OSError for I/O failures,
    # JSONDecodeError for malformed files, AttributeError for unexpected
    # shapes (e.g. 'flat_scores' present but not a dict).
    try:
        with open(judge_path) as f:
            judge_data = json.load(f)
        return judge_data.get('flat_scores', {}).get('root_cause_entity_f1')
    except (OSError, json.JSONDecodeError, AttributeError):
        return None


def parse_rollout(rollout_path: Path, gt: GroundTruth) -> TrajectoryAnalysis:
    """Parse a rollout JSONL file and extract GT-entity mentions and metrics.

    Walks the rollout line by line, counting turns ('turn_context' records)
    and scanning 'response_item' payloads:
      - tool OUTPUTS are checked for passive GT sightings (t_encounter),
      - tool ARGUMENTS for active GT queries (t_investigate),
      - update_plan explanations for assertions/exonerations,
      - shell commands writing output.json for final root-cause assertions.

    Returns a TrajectoryAnalysis with identity fields (model/scenario/trial)
    left blank — the caller fills those in.
    """
    mentions = []
    turn_num = 0
    total_turns = 0

    # First-occurrence turn numbers for each pipeline event (None = never).
    t_encounter = None
    t_investigate = None
    t_assert = None
    t_exonerate = None
    t_recover = None

    # Exploration tracking: unique concrete K8s entities (namespace/Kind/name,
    # via the regex in extract_k8s_entities) seen in tool outputs vs.
    # explicitly referenced in tool arguments.
    encountered_entities = set()
    investigated_entities = set()

    # Track which entity GROUPS (scenario-level ids) were investigated,
    # split into on-chain (fault propagation) vs off-chain (detours).
    on_chain_groups_investigated = set()
    off_chain_groups_investigated = set()
    all_groups_investigated = set()

    with open(rollout_path) as f:
        for line in f:
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue

            # Each turn_context record starts a new turn.
            if obj.get('type') == 'turn_context':
                turn_num += 1
                total_turns = turn_num

            if obj.get('type') != 'response_item':
                continue

            payload = obj.get('payload', {})

            # --- Tool outputs: passive encounters ---
            if payload.get('type') == 'function_call_output':
                output = str(payload.get('output', ''))

                # Check for root cause match
                if entity_matches(output, gt):
                    sentiment = classify_sentiment(output, gt)
                    mentions.append(EntityMention(
                        turn=turn_num,
                        mention_type='encounter',
                        context='tool_output',
                        text_snippet=output[:200],
                        sentiment=sentiment
                    ))
                    if t_encounter is None:
                        t_encounter = turn_num

                # Broad exploration check using Regex
                found_entities = extract_k8s_entities(output)
                for entity in found_entities:
                    encountered_entities.add(entity)

            # --- Tool arguments: active investigation ---
            if payload.get('type') == 'function_call':
                args = payload.get('arguments', {})
                if isinstance(args, str):
                    try:
                        args = json.loads(args)
                    # NOTE(review): bare except also catches KeyboardInterrupt;
                    # json.JSONDecodeError would be the targeted choice here.
                    except:
                        args = {'raw': args}
                args_str = json.dumps(args)

                # Root cause check
                if entity_matches(args_str, gt):
                    mentions.append(EntityMention(
                        turn=turn_num,
                        mention_type='investigate',
                        context='tool_args',
                        text_snippet=args_str[:200],
                        sentiment='neutral'
                    ))
                    if t_investigate is None:
                        t_investigate = turn_num

                # Broad exploration check using Regex
                found_entities = extract_k8s_entities(args_str)
                for entity in found_entities:
                    investigated_entities.add(entity)

                    # Classify as on-chain or off-chain
                    on_chain_group = is_entity_on_chain(entity, gt)
                    if on_chain_group:
                        on_chain_groups_investigated.add(on_chain_group)
                        all_groups_investigated.add(on_chain_group)
                    else:
                        # Check if it matches any entity in scenario at all;
                        # entities outside the scenario are ignored entirely.
                        any_group = get_entity_group_match(entity, gt)
                        if any_group:
                            off_chain_groups_investigated.add(any_group)
                            all_groups_investigated.add(any_group)

                # Check update_plan explanations for assertions/exonerations;
                # neutral-sentiment mentions are counted as investigation.
                if payload.get('name') == 'update_plan':
                    explanation = args.get('explanation', '')
                    if entity_matches(explanation, gt):
                        sentiment = classify_sentiment(explanation, gt)
                        mention_type = 'assert' if sentiment == 'positive' else ('exonerate' if sentiment == 'negative' else 'investigate')
                        mentions.append(EntityMention(
                            turn=turn_num,
                            mention_type=mention_type,
                            context='reasoning',
                            text_snippet=explanation[:200],
                            sentiment=sentiment
                        ))

                        if mention_type == 'assert' and t_assert is None:
                            t_assert = turn_num
                        elif mention_type == 'exonerate' and t_exonerate is None:
                            t_exonerate = turn_num

                # Check shell commands for final-output generation that
                # asserts the GT entity as root cause (e.g. writing output.json).
                if payload.get('name') == 'shell':
                    cmd = args.get('command', [])
                    cmd_str = ' '.join(cmd) if isinstance(cmd, list) else str(cmd)

                    # Look for output generation with root cause assertions
                    if ('output.json' in cmd_str or 'root_cause' in cmd_str.lower()) and entity_matches(cmd_str, gt):
                        sentiment = classify_sentiment(cmd_str, gt)
                        if sentiment == 'positive' or 'root_cause' in cmd_str.lower():
                            mentions.append(EntityMention(
                                turn=turn_num,
                                mention_type='assert',
                                context='final_output',
                                text_snippet=cmd_str[:300],
                                sentiment='positive'
                            ))
                            if t_assert is None:
                                t_assert = turn_num

    # Check for recovery (exoneration followed by assertion)
    had_recovery = False
    if t_exonerate is not None and t_assert is not None and t_exonerate < t_assert:
        had_recovery = True
        t_recover = t_assert

    # Compute delay metrics; None whenever either endpoint is missing.
    discovery_efficiency = t_encounter / total_turns if t_encounter and total_turns > 0 else None
    investigation_delay = t_investigate - t_encounter if t_investigate and t_encounter else None
    assertion_delay = t_assert - t_investigate if t_assert and t_investigate else None

    # Compute max stage reached (without final success - that comes from judge)
    # 0=none, 1=encounter, 2=investigate, 3=assert
    max_stage = 0
    if t_encounter is not None:
        max_stage = 1
    if t_investigate is not None:
        max_stage = 2
    if t_assert is not None:
        max_stage = 3

    # Exploration metrics
    # Note: total_entities_available is hard to define with the regex approach
    # since we don't know the universe of entities; the encountered set is
    # used as the denominator instead.
    num_encountered = len(encountered_entities)
    num_investigated = len(investigated_entities)

    # Ratio: What % of things seen were actually investigated?
    expl_ratio = num_investigated / num_encountered if num_encountered > 0 else 0.0

    # Coverage metrics: on-chain (fault propagation) vs off-chain (detoured)
    n_on_chain = len(on_chain_groups_investigated)
    n_off_chain = len(off_chain_groups_investigated)
    total_investigated_groups = len(all_groups_investigated)

    # Propagation coverage: what % of the fault propagation chain was investigated?
    n_propagation_entities = len(gt.propagation_entities)
    prop_coverage = n_on_chain / n_propagation_entities if n_propagation_entities > 0 else 0.0

    # Detour rate: what % of investigated entities were off-chain (not in fault propagation)?
    det_rate = n_off_chain / total_investigated_groups if total_investigated_groups > 0 else 0.0

    return TrajectoryAnalysis(
        model="",  # Set by caller
        scenario="",  # Set by caller
        trial=0,  # Set by caller
        total_turns=total_turns,
        gt_entity=gt.entity_name,
        t_encounter=t_encounter,
        t_investigate=t_investigate,
        t_assert=t_assert,
        t_exonerate=t_exonerate,
        t_recover=t_recover,
        max_stage_reached=max_stage,
        mentions=mentions,
        total_entities_available=num_encountered, # Using encountered as the "available" set
        unique_entities_encountered=num_encountered,
        unique_entities_investigated=num_investigated,
        exploration_ratio=expl_ratio,
        # Coverage metrics (on-chain vs off-chain)
        on_chain_investigated=n_on_chain,
        off_chain_investigated=n_off_chain,
        propagation_coverage=prop_coverage,
        detour_rate=det_rate,
        # Computed metrics
        discovery_efficiency=discovery_efficiency,
        investigation_delay=investigation_delay,
        assertion_delay=assertion_delay,
        had_recovery=had_recovery
    )


def analyze_model(model_dir: Path, gt_cache: Dict[str, GroundTruth]) -> List[TrajectoryAnalysis]:
    """Analyze every scenario/trial rollout under one model directory.

    Handles both flat layouts (model_dir/Scenario-*) and nested layouts
    (model_dir/<bench>/Scenario-*). Scenarios without cached ground truth
    and trials without a rollout file are skipped; per-trial parse errors
    are printed and skipped.
    """
    analyses = []
    model_name = model_dir.name.replace("react with code_", "").split("_07ccdb1")[0]

    # Descend one level if the Scenario folders are nested under a single
    # bench subdirectory (e.g. model_dir/sre/Scenario-1).
    direct_subdirs = [d for d in model_dir.iterdir() if d.is_dir()]
    if not any(d.name.startswith("Scenario") for d in direct_subdirs):
        candidates = [d for d in direct_subdirs if not d.name.startswith(".")]
        if len(candidates) == 1:
            model_dir = candidates[0]
        else:
            for candidate in candidates:
                if any(d.name.startswith("Scenario") for d in candidate.iterdir() if d.is_dir()):
                    model_dir = candidate
                    break

    scenario_dirs = [d for d in sorted(model_dir.iterdir()) if d.is_dir() and d.name.startswith("Scenario-")]
    for scenario_dir in tqdm(scenario_dirs, desc=f"  {model_name} scenarios"):
        scenario = scenario_dir.name
        gt = gt_cache.get(scenario)
        if gt is None:
            continue

        trial_dirs = [d for d in sorted(scenario_dir.iterdir()) if d.is_dir() and d.name.isdigit()]
        for trial_dir in tqdm(trial_dirs, desc=f"    {scenario} trials"):
            trial_num = int(trial_dir.name)
            rollout_path = find_latest_rollout_file(trial_dir)
            if rollout_path is None:
                continue

            try:
                analysis = parse_rollout(rollout_path, gt)
                analysis.model = model_name
                analysis.scenario = scenario
                analysis.trial = trial_num

                # Judge score determines final success and the top funnel stage.
                f1_score = get_judge_score(trial_dir)
                analysis.root_cause_f1 = f1_score
                if f1_score is not None and f1_score > 0:
                    analysis.final_success = True
                    analysis.max_stage_reached = 4  # Success!

                analyses.append(analysis)
            except Exception as e:
                print(f"Error processing {model_name}/{scenario}/{trial_num}: {e}")

    return analyses


def plot_pipeline_funnel(summary_df: pd.DataFrame):
    """
    Figure 1: Stacked bar showing where trials drop off in the pipeline.

    Pipeline stages:
    - Encounter: GT entity appears in tool OUTPUT (passive - agent didn't ask for it)
    - Investigate: GT entity appears in tool ARGUMENTS (active - agent explicitly queried it)
    - Assert: Agent declares GT as root cause
    - Success: Judge confirms correct answer

    Args:
        summary_df: Per-model summary frame; must contain n_trials,
            encounter_rate, success_rate and the exclusive n_stage_* counts.

    Side effects:
        Shows the figure and saves fig_conversion_funnel.png to OUTPUT_DIR.
    """
    # Filter out models with no data (e.g. mistral) and prettify names
    data = summary_df[summary_df['encounter_rate'] > 0].copy()
    data['model_clean'] = data['model'].apply(get_display_name)
    data = data.sort_values('success_rate', ascending=True)

    # Exclusive per-stage counts normalized to percent of trials; the five
    # segments of each bar sum to 100%.
    n_trials = data['n_trials']
    stage_pcts = [
        data['n_stage_0_none'] / n_trials * 100,
        data['n_stage_1_encounter_only'] / n_trials * 100,
        data['n_stage_2_investigate_only'] / n_trials * 100,
        data['n_stage_3_assert_only'] / n_trials * 100,
        data['n_stage_4_success'] / n_trials * 100,
    ]

    # Legend label and color per stage, in stack order (left -> right)
    stage_styles = [
        ('RC never seen', '#d73027'),               # Red - never encountered GT
        ('RC seen, not queried', '#fc8d59'),        # Orange - saw but didn't investigate
        ('RC queried, not asserted', '#fee08b'),    # Yellow - investigated but didn't assert
        ('RC asserted, not in output', '#d9ef8b'),  # Light green - asserted but wrong final answer
        ('RC asserted, in output', '#1a9850'),      # Green - success
    ]

    n_models = len(data)
    y = np.arange(n_models)
    bar_height = 0.7

    plt.rcParams.update(PLOT_PARAMETERS)

    # Create figure sized to fill half column with legend
    fig, ax = plt.subplots(figsize=(DOUBLE_COLUMN_WIDTH, 2.5))

    min_pct_threshold = 4  # only label segments wide enough to fit text (>= 4%)
    label_fontsize = MIN_FONT_SIZE - 3

    # Plot stacked bars by accumulating the left edge stage by stage;
    # label each segment wide enough to hold its percentage.
    left = np.zeros(n_models)
    for pct, (label, color) in zip(stage_pcts, stage_styles):
        ax.barh(y, pct, height=bar_height, left=left, label=label,
                color=color, edgecolor='white', linewidth=0.3)
        for i in range(n_models):
            if pct.iloc[i] >= min_pct_threshold:
                ax.text(left[i] + pct.iloc[i] / 2, y[i], f'{pct.iloc[i]:.0f}%',
                        ha='center', va='center', fontsize=label_fontsize,
                        color='black', weight='bold')
        left = left + pct.to_numpy()

    ax.set_yticks(y)
    ax.set_yticklabels(data['model_clean'], fontsize=MIN_FONT_SIZE)
    ax.set_xlabel('Trials (%)', fontsize=MIN_FONT_SIZE)
    ax.set_xlim(0, 100)
    ax.set_ylim(-0.5, n_models - 0.5)
    ax.tick_params(axis='x', labelsize=MIN_FONT_SIZE)

    # Legend below the plot - 2 columns, positioned below x-axis label
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.18), ncol=2,
              frameon=False, fontsize=MIN_FONT_SIZE, columnspacing=0.8,
              handletextpad=0.3, handlelength=1.0)

    # Tight margins - more bottom space for legend
    fig.subplots_adjust(left=0.28, right=0.99, top=0.99, bottom=0.38)

    plt.title("Root Cause Entity Discovery Funnel")
    plt.show()
    fig.savefig(OUTPUT_DIR / "fig_conversion_funnel.png")
    plt.close(fig)
    print("Saved: fig_conversion_funnel.png")


def _load_gt_cache() -> dict:
    """Load every ground truth under GT_DIR (v0.2-* versions), keyed by scenario name.

    Returns an empty dict (with a warning) if GT_DIR does not exist.
    """
    gt_cache = {}
    if GT_DIR.exists():
        # Find version directories (e.g. v0.2-*), each holding Scenario-* dirs
        version_dirs = [d for d in GT_DIR.iterdir() if d.is_dir() and d.name.startswith("v0.2-")]
        for version_dir in version_dirs:
            scenario_dirs = [d for d in version_dir.iterdir()
                             if d.is_dir() and d.name.startswith("Scenario-")]
            for scenario_dir in tqdm(scenario_dirs, desc="Loading ground truths"):
                gt = load_ground_truth(scenario_dir.name)
                if gt:
                    gt_cache[scenario_dir.name] = gt
    else:
        print(f"Warning: GT_DIR not found at {GT_DIR}")
    return gt_cache


def _print_model_stats(results) -> None:
    """Print quick per-model counts of encounters, assertions and recoveries."""
    if not results:
        return
    encounters = [r for r in results if r.t_encounter is not None]
    asserts = [r for r in results if r.t_assert is not None]
    recoveries = [r for r in results if r.had_recovery]

    n = len(results)
    print(f"  Trials: {n}")
    print(f"  Encounters: {len(encounters)} ({100*len(encounters)/n:.1f}%)")
    print(f"  Assertions: {len(asserts)} ({100*len(asserts)/n:.1f}%)")
    print(f"  Recoveries: {len(recoveries)} ({100*len(recoveries)/n:.1f}%)")


def _trial_record(r) -> dict:
    """Flatten one trial analysis object into a CSV-ready record."""
    return {
        'model': r.model,
        'scenario': r.scenario,
        'trial': r.trial,
        'total_turns': r.total_turns,
        'gt_entity': r.gt_entity,
        't_encounter': r.t_encounter,
        't_investigate': r.t_investigate,
        't_assert': r.t_assert,
        't_exonerate': r.t_exonerate,
        't_recover': r.t_recover,
        'max_stage_reached': r.max_stage_reached,
        'final_success': r.final_success,
        'root_cause_f1': r.root_cause_f1,
        'discovery_efficiency': r.discovery_efficiency,
        'investigation_delay': r.investigation_delay,
        'assertion_delay': r.assertion_delay,
        'had_recovery': r.had_recovery,
        'n_mentions': len(r.mentions),
        'total_entities_available': r.total_entities_available,
        'unique_entities_encountered': r.unique_entities_encountered,
        'unique_entities_investigated': r.unique_entities_investigated,
        'exploration_ratio': r.exploration_ratio,
        # Coverage metrics (on-chain vs off-chain)
        'on_chain_investigated': r.on_chain_investigated,
        'off_chain_investigated': r.off_chain_investigated,
        'propagation_coverage': r.propagation_coverage,
        'detour_rate': r.detour_rate
    }


def _model_summary_row(model: str, model_data: pd.DataFrame) -> dict:
    """Aggregate one model's per-trial rows into a single summary record.

    Args:
        model: Model name (value of the 'model' column).
        model_data: The trial DataFrame rows belonging to this model.
    """
    n_total = len(model_data)

    # Exclusive stage counts: 0 none, 1 encounter, 2 investigate, 3 assert, 4 success
    stage_counts = model_data['max_stage_reached'].value_counts().to_dict()

    # Cumulative counts: trials that reached AT LEAST the given stage
    n_encounter = len(model_data[model_data['max_stage_reached'] >= 1])
    n_investigate = len(model_data[model_data['max_stage_reached'] >= 2])
    n_assert = len(model_data[model_data['max_stage_reached'] >= 3])
    n_success = len(model_data[model_data['max_stage_reached'] >= 4])

    # Subsets used for conditional averages
    with_encounter = model_data[model_data['t_encounter'].notna()]
    with_assert = model_data[model_data['t_assert'].notna()]
    with_recovery = model_data[model_data['had_recovery'] == True]
    with_success = model_data[model_data['final_success'] == True]

    return {
        'model': model,
        'n_trials': n_total,
        'n_scenarios': model_data['scenario'].nunique(),
        # Funnel rates (cumulative, relative to total trials)
        'encounter_rate': n_encounter / n_total if n_total > 0 else 0,
        'investigate_rate': n_investigate / n_total if n_total > 0 else 0,
        'assertion_rate': n_assert / n_total if n_total > 0 else 0,
        'success_rate': n_success / n_total if n_total > 0 else 0,
        # Conversion rate: given encounter, did model declare it as root cause?
        # This handles multi-root-cause scenarios better
        'conversion_rate': n_success / n_encounter if n_encounter > 0 else 0,
        # Drop-off at each stage (exclusive counts)
        'n_stage_0_none': stage_counts.get(0, 0),
        'n_stage_1_encounter_only': stage_counts.get(1, 0),
        'n_stage_2_investigate_only': stage_counts.get(2, 0),
        'n_stage_3_assert_only': stage_counts.get(3, 0),
        'n_stage_4_success': stage_counts.get(4, 0),
        # Legacy metrics
        'recovery_rate': len(with_recovery) / n_total if n_total > 0 else 0,
        'avg_t_encounter': with_encounter['t_encounter'].mean() if len(with_encounter) > 0 else None,
        'avg_t_assert': with_assert['t_assert'].mean() if len(with_assert) > 0 else None,
        'avg_total_turns': model_data['total_turns'].mean(),
        'avg_discovery_efficiency': with_encounter['discovery_efficiency'].mean() if len(with_encounter) > 0 else None,
        'avg_investigation_delay': with_encounter['investigation_delay'].mean() if len(with_encounter) > 0 else None,
        'avg_assertion_delay': with_assert['assertion_delay'].mean() if len(with_assert) > 0 else None,
        'avg_f1': with_success['root_cause_f1'].mean() if len(with_success) > 0 else None,
        'avg_exploration_ratio': model_data['exploration_ratio'].mean(),
        'avg_entities_investigated': model_data['unique_entities_investigated'].mean(),
        # Coverage metrics (fault propagation coverage)
        'avg_on_chain_investigated': model_data['on_chain_investigated'].mean(),
        'avg_off_chain_investigated': model_data['off_chain_investigated'].mean(),
        'avg_propagation_coverage': model_data['propagation_coverage'].mean(),
        'avg_detour_rate': model_data['detour_rate'].mean()
    }


def extract_all_data():
    """
    Run the full extraction pipeline.

    Loads ground truths, analyzes every model's rollouts, and writes
    discovery_trials.csv (per trial) and discovery_summary.csv (per model)
    to OUTPUT_DIR.

    Returns:
        (summary_df, trial_df, trials_n): per-model summary frame, per-trial
        frame, and the total number of trials analyzed.
    """
    # Create output directory
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    print("\nLoading ground truth data...")
    gt_cache = _load_gt_cache()
    print(f"Loaded {len(gt_cache)} ground truth files")

    # Find all agent directories (excluding hidden and backup directories)
    model_dirs = [d for d in LEADERBOARD_DIR.iterdir()
                  if d.is_dir() and not d.name.startswith(".") and not d.name.startswith("backup_")]
    print(f"Found {len(model_dirs)} agent models")

    # Analyze each model
    all_results = []
    for model_dir in tqdm(model_dirs, desc="Analyzing models"):
        # Strip the harness prefix and the commit-hash suffix from the dir name
        model_name = model_dir.name.replace("react with code_", "").split("_07ccdb1")[0]
        print(f"\nAnalyzing {model_name}...")

        results = analyze_model(model_dir, gt_cache)
        all_results.extend(results)
        _print_model_stats(results)

    print("\n" + "=" * 60)
    print("Generating output files...")

    # Per-trial CSV
    trial_df = pd.DataFrame([_trial_record(r) for r in all_results])
    trial_df.to_csv(OUTPUT_DIR / "discovery_trials.csv", index=False)
    print(f"Saved: {OUTPUT_DIR / 'discovery_trials.csv'}")

    # Per-model summary CSV
    summary_df = pd.DataFrame([
        _model_summary_row(model, trial_df[trial_df['model'] == model])
        for model in trial_df['model'].unique()
    ])
    summary_df.to_csv(OUTPUT_DIR / "discovery_summary.csv", index=False)
    print(f"Saved: {OUTPUT_DIR / 'discovery_summary.csv'}")

    return summary_df, trial_df, len(all_results)


def main():
    """Run the analysis pipeline and print funnel / drop-off summary tables."""
    banner = "=" * 60
    print(banner)
    print("Root Cause Discovery Trajectory Analysis")
    print(banner)

    summary_df, trial_df, trials_n = extract_all_data()

    wide_rule = "=" * 80
    thin_rule = "-" * 80

    # Funnel table: cumulative share of trials that reached each stage
    print("\n" + wide_rule)
    print("Discovery Pipeline Funnel:")
    print(thin_rule)
    print(f"{'Model':<25} {'Trials':>7} {'Encntr':>8} {'Invest':>8} {'Assert':>8} {'Success':>8}")
    print(thin_rule)
    rate_cols = ('encounter_rate', 'investigate_rate', 'assertion_rate', 'success_rate')
    for row in summary_df.itertuples(index=False):
        rates = " ".join(f"{getattr(row, col) * 100:>7.0f}%" for col in rate_cols)
        print(f"{row.model:<25} {row.n_trials:>7} {rates}")

    # Drop-off table: exclusive counts of trials that stalled at each stage
    print("\n" + wide_rule)
    print("Drop-off Analysis (where trials stopped):")
    print(thin_rule)
    print(f"{'Model':<25} {'None':>7} {'Enc→X':>7} {'Inv→X':>7} {'Ass→X':>7} {'✓':>7}")
    print(thin_rule)
    stage_cols = ('n_stage_0_none', 'n_stage_1_encounter_only',
                  'n_stage_2_investigate_only', 'n_stage_3_assert_only',
                  'n_stage_4_success')
    for row in summary_df.itertuples(index=False):
        counts = " ".join(f"{getattr(row, col):>7}" for col in stage_cols)
        print(f"{row.model:<25} {counts}")

    print(f"\nTotal trials analyzed: {trials_n}")
    print(f"\nOutput saved to: {OUTPUT_DIR}")


# Entry point when executed as a script (not on import).
if __name__ == "__main__":
    main()