Phaedrus33 committed on
Commit
50e652f
·
verified ·
1 Parent(s): 1dd75b6

Upload generate_traces_final.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. generate_traces_final.py +1272 -0
generate_traces_final.py ADDED
@@ -0,0 +1,1272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Final Reasoning Trace Generator - Train Data Only
4
+
5
+ Produces training-ready checkpoint JSON with reasoning traces for SFT/GRPO
6
+ from the labeled training set (train.csv, 2400 Type A questions).
7
+
8
+ Rule-based classification walkthrough (from data analysis we know these rules):
9
+ - Tier 1: C7 (speed), C2 (distance), C5 (handovers), C8 (RB)
10
+ - Tier 2: C1 detection (3 sub-rules + V16 overrides B/P3/P4)
11
+ - Tier 3: C4 interference (ratio filter)
12
+ - Tier 4: C6 collision (signal filters + V16 overrides P1/P2/G/J/P5b)
13
+ - Tier 5: C1/C3 tiebreaker (tilt/RSRP/SINR gate + rescue rules R1-R4)
14
+
15
+ V19 thresholds (calibrated on train set):
16
+ - tilt_high_c1: 28 with SINR gate (>=12 -> C3)
17
+ - rsrp_c1_medium: -90, rsrp_c3_medium: -82
18
+ - P4: -79, R1: collision_ratio >= 0.9
19
+ - R2: strong_neighbors < 0.8, R3: c4_interference >= 3.0
20
+
21
+ Output: outputs/traces_final/traces_final.json (~2400 traces)
22
+
23
+ Usage:
24
+ uv run python generate_traces_final.py
25
+ uv run python generate_traces_final.py --spot-check 5
26
+ uv run python generate_traces_final.py --validate-only outputs/traces_final/traces_final.json
27
+ """
28
+
29
+ import json
30
+ import argparse
31
+ import logging
32
+ from pathlib import Path
33
+ from collections import Counter
34
+ from typing import Dict, List, Optional, Tuple
35
+
36
+ import pandas as pd
37
+
38
+ from telco_utils import (
39
+ classify_question_type,
40
+ classify_type_a,
41
+ extract_type_a_options,
42
+ parse_type_a_question,
43
+ haversine,
44
+ check_c4_non_colocated,
45
+ check_c6_pci_collision,
46
+ get_min_rsrp,
47
+ get_strong_neighbor_count,
48
+ get_type_a_tilt,
49
+ get_type_a_avg_rsrp,
50
+ get_avg_off_axis_angle,
51
+ get_min_sinr_low_tp,
52
+ get_min_neighbor_diff,
53
+ get_pci_collision_ratio,
54
+ get_tp_threshold,
55
+ compute_v16_metrics,
56
+ classify_c1_vs_c3,
57
+ # Type B imports (used by SFT/GRPO/inference for prompt preparation)
58
+ parse_type_b_question,
59
+ parse_config_data,
60
+ detect_inter_freq_ho,
61
+ check_n1_in_config,
62
+ )
63
+
64
# Module-wide logging: timestamped INFO-level records.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Input data directory and the directory the generated traces are written to.
DATA_DIR = Path('the-ai-telco-troubleshooting-challenge20251127-8634-8qzscv')
OUTPUT_DIR = Path('outputs') / 'traces_final'

# Human-readable phrase for each root-cause code; the trace walkers embed
# these in their "The root cause is ..." conclusion lines.
CAUSE_DESCRIPTIONS = dict(
    C1='excessive downtilt causing weak far-end coverage',
    C2='coverage distance exceeding 1 km (over-shooting)',
    C3='a neighboring cell providing higher throughput',
    C4='non-colocated co-frequency interference',
    C5='frequent handovers degrading performance',
    C6='PCI mod 30 collision causing interference',
    C7='vehicle speed exceeding 40 km/h',
    C8='average scheduled RBs below threshold',
)

# Format specs per metric (used when printing the tier-1 checks).
METRIC_FORMATS = dict(
    max_speed='.1f',
    max_distance_low_tp='.2f',
    handover_count='.0f',
    avg_rb='.1f',
)
88
+
89
+
90
def safe_cmp_fmt(val: float, threshold: float, op: str, decimals: int = 1) -> str:
    """Format ``val`` so the printed number compares to ``threshold`` like ``val`` does.

    The value is first rendered with ``decimals`` fractional digits.  If
    rounding moves it across ``threshold`` (so that the printed text would
    contradict the real comparison), precision is increased one digit at a
    time until the printed number gives the same comparison result.

    Args:
        val: The raw metric value.
        threshold: The value ``val`` is being compared against.
        op: One of ``'<'``, ``'>'``, ``'<='``, ``'>='``.  Any other string is
            treated as a comparison that never passes.
        decimals: Minimum number of fractional digits to print.

    Returns:
        The formatted value.  If no fixed-point rendering within the search
        range preserves the comparison, ``repr(val)`` is returned.
    """
    def _passes(v, t, o):
        if o == '<':
            return v < t
        if o == '>':
            return v > t
        if o == '<=':
            return v <= t
        if o == '>=':
            return v >= t
        return False

    raw_passes = _passes(val, threshold, op)

    # The first three precisions match the historical search window; the
    # additional ones cover values that sit extremely close to the threshold
    # (previously the fallback re-used a precision already known to be wrong).
    max_extra_digits = 12
    for d in range(decimals, decimals + max_extra_digits + 1):
        shown = f"{val:.{d}f}"
        # Compare what the reader will actually see, not a separately rounded copy.
        if _passes(float(shown), threshold, op) == raw_passes:
            return shown
    # Shortest round-trip representation is exact by definition.
    return repr(val)
105
+
106
+
107
+ def _maybe_correct(lines: list, m: Dict, canonical: str, is_expert: bool) -> str:
108
+ """Post-process trace: if V19 concluded with the wrong cause, append expert correction."""
109
+ trace = '\n'.join(lines)
110
+ if not is_expert:
111
+ return trace
112
+ correct_conclusion = f"The root cause is {CAUSE_DESCRIPTIONS[canonical]}."
113
+ if correct_conclusion in trace:
114
+ return trace
115
+ # Wrong conclusion - append expert correction
116
+ lines.append("")
117
+ lines.append("However, examining additional indicators:")
118
+ _add_expert_reasoning(lines, m, canonical)
119
+ return '\n'.join(lines)
120
+
121
+
122
+ # =============================================================================
123
+ # METRICS COMPUTATION
124
+ # =============================================================================
125
+
126
def compute_all_metrics(question: str, drive_test: List[Dict], cells: Dict) -> Dict:
    """Compute all V19 metrics from parsed drive test and engineering params.

    Args:
        question: Raw question text; passed to ``get_tp_threshold``.
        drive_test: Drive-test rows.  Each row is indexed by ``'speed'``,
            ``'throughput'``, ``'serving_pci'``, ``'rb'``, ``'lon'``, ``'lat'``
            and (optionally) ``'sinr'``.
        cells: Mapping from PCI to a cell record with ``'lon'`` and ``'lat'``.

    Returns:
        Dict of metric name -> value, in a fixed insertion order.
    """
    tp_threshold = get_tp_threshold(question)
    metrics = {'tp_threshold': tp_threshold}

    # ---- Tier 1 -----------------------------------------------------------
    # Falsy readings (None / 0) are ignored for speed, throughput, PCI and RB.
    nonzero_speeds = [row['speed'] for row in drive_test if row['speed']]
    metrics['max_speed'] = max(nonzero_speeds, default=0.0)

    distances = []
    for row in drive_test:
        throughput = row['throughput']
        if not throughput or not throughput < tp_threshold:
            continue
        serving = row['serving_pci']
        if serving and serving in cells:
            site = cells[serving]
            distances.append(haversine(site['lon'], site['lat'], row['lon'], row['lat']))
    metrics['max_distance_low_tp'] = max(distances, default=0.0)

    # A handover is any change between consecutive (non-falsy) serving PCIs.
    serving_seq = [row['serving_pci'] for row in drive_test if row['serving_pci']]
    metrics['handover_count'] = sum(
        1 for prev, cur in zip(serving_seq, serving_seq[1:]) if cur != prev
    )

    rb_samples = [row['rb'] for row in drive_test if row['rb']]
    if rb_samples:
        metrics['avg_rb'] = sum(rb_samples) / len(rb_samples)
    else:
        # Sentinel well above the C8 threshold when no RB data is present.
        metrics['avg_rb'] = 999.0

    # ---- Tier 2/3 ---------------------------------------------------------
    metrics['serving_tilt'] = get_type_a_tilt(drive_test, cells)
    metrics['avg_rsrp'] = get_type_a_avg_rsrp(drive_test)
    metrics['min_rsrp'] = get_min_rsrp(drive_test)
    metrics['strong_neighbor_count'] = get_strong_neighbor_count(drive_test)
    metrics['min_neighbor_diff'] = get_min_neighbor_diff(drive_test)
    metrics['avg_off_axis'] = get_avg_off_axis_angle(drive_test, cells)
    metrics['min_sinr_low_tp'] = get_min_sinr_low_tp(drive_test)

    # avg_sinr feeds the SINR gate in the C1/C3 tiebreaker; None when absent.
    sinr_values = [s for s in (row.get('sinr') for row in drive_test) if s is not None]
    metrics['avg_sinr'] = sum(sinr_values) / len(sinr_values) if sinr_values else None

    _, c4_interference, _ = check_c4_non_colocated(drive_test, cells)
    metrics['c4_interference'] = c4_interference
    if c4_interference > 0:
        ratio = metrics['min_neighbor_diff'] / max(c4_interference, 1)
    else:
        ratio = 0.0
    metrics['ratio_nbdiff_interf'] = ratio

    _, has_collision, c6_detail = check_c6_pci_collision(drive_test)
    metrics['pci_collision'] = has_collision
    detail = c6_detail
    if has_collision:
        import re as _re
        # Rewrite the helper's compact detail string into a readable sentence;
        # keep the raw string when it does not have the expected shape.
        found = _re.search(r'serving (\d+)%30=(\d+) == neighbor (\d+)', c6_detail)
        if found:
            spci, smod, npci = found.groups()
            nmod = int(npci) % 30
            detail = f"serving PCI {spci} mod 30 = {smod}, neighbor PCI {npci} mod 30 = {nmod}"
    metrics['pci_collision_detail'] = detail
    metrics['pci_collision_ratio'] = get_pci_collision_ratio(drive_test)

    # ---- V16 override metrics (with defaults for missing keys) -------------
    v16 = compute_v16_metrics(drive_test, tp_threshold)
    v16_defaults = (
        ('post_ho_good_streak', 0),
        ('rsrp_recovery', 0.0),
        ('rsrp_change_during_prob', 0.0),
        ('rsrp_trend', 0.0),
        ('nb_within_5db_per_row', 0.0),
    )
    for key, fallback in v16_defaults:
        metrics[key] = v16.get(key, fallback)

    # ---- C6 filter signals -------------------------------------------------
    metrics['c6_c1_signal'] = metrics['serving_tilt'] >= 20
    metrics['c6_c3_signal'] = metrics['min_neighbor_diff'] < 3 and metrics['serving_tilt'] > 12
    metrics['c6_c3_off_axis_signal'] = metrics['avg_off_axis'] > 30

    return metrics
200
+
201
+
202
def format_metrics_block(m: Dict) -> str:
    """Render the computed metrics as an "Extracted metrics:" text block.

    One line is produced per metric in a fixed order, followed by the
    ``avg_sinr`` line, which prints ``N/A`` when the value is missing or None.
    """
    # (metric key, format spec, unit suffix).  A spec of None marks a boolean
    # flag rendered as yes/no; an empty spec means plain str() formatting.
    layout = (
        ('max_speed', '.1f', ' km/h'),
        ('max_distance_low_tp', '.2f', ' km'),
        ('handover_count', '', ''),
        ('avg_rb', '.1f', ''),
        ('serving_tilt', '.0f', ' deg'),
        ('avg_rsrp', '.3f', ' dBm'),
        ('min_rsrp', '.2f', ' dBm'),
        ('strong_neighbor_count', '.2f', ''),
        ('min_neighbor_diff', '.1f', ' dB'),
        ('c4_interference', '.2f', ' dB'),
        ('pci_collision', None, ''),
        ('pci_collision_ratio', '.2f', ''),
        ('avg_off_axis', '.1f', ' deg'),
        ('post_ho_good_streak', '', ''),
        ('rsrp_recovery', '.1f', ' dB'),
        ('rsrp_trend', '.2f', ''),
        ('nb_within_5db_per_row', '.2f', ''),
    )

    out = ["Extracted metrics:"]
    for key, spec, unit in layout:
        value = m[key]
        if spec is None:
            shown = 'yes' if value else 'no'
        else:
            shown = format(value, spec)
        out.append(f" {key} = {shown}{unit}")

    avg_sinr = m.get('avg_sinr')
    if avg_sinr is None:
        out.append(" avg_sinr = N/A")
    else:
        out.append(f" avg_sinr = {avg_sinr:.1f} dB")
    return '\n'.join(out)
229
+
230
+
231
+ # =============================================================================
232
+ # TRACE GENERATION: V19 CASCADE WALKER
233
+ # =============================================================================
234
+
235
def generate_trace(
    m: Dict,
    result: Dict,
    available_causes: set,
    ground_truth: str,
    is_expert: bool = False,
) -> str:
    """Walk the V19 cascade step by step and return the reasoning trace text.

    Args:
        m: Metrics dict (see ``compute_all_metrics``).
        result: Classifier output; ``'canonical'`` and ``'evidence'`` are read.
        available_causes: Cause codes offered as answer options.
        ground_truth: Labeled cause; used as the target in expert mode and as
            the fallback when ``result`` has no ``'canonical'`` entry.
        is_expert: When True the trace is steered to ``ground_truth`` and may
            receive an appended correction via ``_maybe_correct``.

    Returns:
        The newline-joined trace.
    """
    canonical = ground_truth if is_expert else result.get('canonical', ground_truth)
    evidence = result.get('evidence', {})
    v16_override = evidence.get('v16_override', '')

    lines = [
        f"I need to identify the root cause of throughput dropping below {m['tp_threshold']:.0f} Mbps.",
        "",
        format_metrics_block(m),
        "",
        # ===== STEP 1: TIER 1 CHECKS =====
        "Step 1 - Tier 1 checks:",
    ]
    if _walk_tier1(lines, m, canonical, available_causes):
        # Tier 1 only reports a hit for the canonical cause: no correction pass.
        return '\n'.join(lines)
    lines += ["All tier 1 causes ruled out.", ""]

    # ===== STEPS 2-4: each stage either resolves the answer or falls through =====
    gated_stages = (
        ("Step 2 - C1 detection rules:",
         lambda: _walk_c1_detection(lines, m, canonical, available_causes, evidence, v16_override)),
        ("Step 3 - C4 interference check:",
         lambda: _walk_c4(lines, m, canonical, available_causes, evidence)),
        ("Step 4 - C6 collision check:",
         lambda: _walk_c6(lines, m, canonical, available_causes, evidence, v16_override)),
    )
    for header, run_stage in gated_stages:
        lines.append(header)
        if run_stage():
            return _maybe_correct(lines, m, canonical, is_expert)
        lines.append("")

    # ===== STEP 5: C1/C3 TIEBREAKER =====
    lines.append("Step 5 - C1/C3 tiebreaker:")
    _walk_c1c3_tiebreaker(lines, m, canonical, available_causes, evidence, v16_override, is_expert)

    return _maybe_correct(lines, m, canonical, is_expert)
288
+
289
+
290
+ def _walk_tier1(lines: list, m: Dict, canonical: str, available: set) -> bool:
291
+ """Walk tier-1 cascade. Returns True if answer found here."""
292
+ checks = [
293
+ ('C7', 'max_speed', 40, '>', 'km/h'),
294
+ ('C2', 'max_distance_low_tp', 1.0, '>', 'km'),
295
+ ('C5', 'handover_count', 3, '>=', ''),
296
+ ('C8', 'avg_rb', 170, '<', ''),
297
+ ]
298
+ for code, metric, thresh, op, unit in checks:
299
+ if code not in available:
300
+ continue
301
+ val = m[metric]
302
+ fmt = METRIC_FORMATS[metric]
303
+ triggered = (
304
+ (op == '>' and val > thresh) or
305
+ (op == '>=' and val >= thresh) or
306
+ (op == '<' and val < thresh)
307
+ )
308
+ suffix = f" {unit}".rstrip()
309
+ val_str = f"{val:{fmt}}"
310
+ if triggered and code == canonical:
311
+ cmp_op = '>' if op == '>' else '>=' if op == '>=' else '<'
312
+ lines.append(f"{metric} = {val_str}{suffix} {cmp_op} {thresh} -> {code}.")
313
+ lines.append(f"The root cause is {CAUSE_DESCRIPTIONS[code]}.")
314
+ return True
315
+ elif triggered:
316
+ cmp_op = '>' if op == '>' else '>=' if op == '>=' else '<'
317
+ lines.append(f"{metric} = {val_str}{suffix} {cmp_op} {thresh} -> would suggest {code}.")
318
+ else:
319
+ inv_op = '<=' if op == '>' else '<' if op == '>=' else '>='
320
+ lines.append(f"{metric} = {val_str}{suffix} {inv_op} {thresh} -> not {code}.")
321
+ return False
322
+
323
+
324
+ def _walk_c1_detection(lines: list, m: Dict, canonical: str, available: set,
325
+ evidence: Dict, v16_override: str) -> bool:
326
+ """Walk C1 detection rules. Returns True if answer resolved here."""
327
+ if 'C1' not in available:
328
+ lines.append("C1 not in available options -> skip.")
329
+ return False
330
+
331
+ r1 = m['min_rsrp'] < -90 and not m['pci_collision'] and m['c4_interference'] < 3
332
+ r2 = m['strong_neighbor_count'] < 0.5 and m['serving_tilt'] >= 15
333
+ r3 = m['pci_collision'] and m['strong_neighbor_count'] < 0.5
334
+
335
+ rsrp_s = safe_cmp_fmt(m['min_rsrp'], -90, '<', decimals=2)
336
+ c4_s = safe_cmp_fmt(m['c4_interference'], 3, '<', decimals=2)
337
+ nb_s1 = safe_cmp_fmt(m['strong_neighbor_count'], 0.5, '<', decimals=2)
338
+ nb_s3 = safe_cmp_fmt(m['strong_neighbor_count'], 0.5, '<', decimals=2)
339
+
340
+ lines.append(
341
+ f"Rule 1: min_rsrp = {rsrp_s} {'<' if m['min_rsrp'] < -90 else '>='} -90"
342
+ f", pci_collision = {'yes' if m['pci_collision'] else 'no'}"
343
+ f", c4_interference = {c4_s} {'<' if m['c4_interference'] < 3 else '>='} 3"
344
+ f" -> {'TRIGGERED' if r1 else 'no'}."
345
+ )
346
+ lines.append(
347
+ f"Rule 2: strong_neighbor_count = {nb_s1} {'<' if m['strong_neighbor_count'] < 0.5 else '>='} 0.5"
348
+ f", serving_tilt = {m['serving_tilt']:.0f} {'>=' if m['serving_tilt'] >= 15 else '<'} 15"
349
+ f" -> {'TRIGGERED' if r2 else 'no'}."
350
+ )
351
+ lines.append(
352
+ f"Rule 3: pci_collision = {'yes' if m['pci_collision'] else 'no'}"
353
+ f", strong_neighbor_count = {nb_s3} {'<' if m['strong_neighbor_count'] < 0.5 else '>='} 0.5"
354
+ f" -> {'TRIGGERED' if r3 else 'no'}."
355
+ )
356
+
357
+ if not (r1 or r2 or r3):
358
+ lines.append("No C1 detection rule triggered.")
359
+ return False
360
+
361
+ fired = 'Rule 1' if r1 else 'Rule 2' if r2 else 'Rule 3'
362
+ lines.append(f"C1 detected via {fired}.")
363
+
364
+ # V16 override checks (always show all 3)
365
+ lines.append("V16 override checks:")
366
+
367
+ b_fires = 'C3' in available and m['post_ho_good_streak'] >= 2
368
+ lines.append(
369
+ f" B: post_ho_good_streak = {m['post_ho_good_streak']} {'>=' if m['post_ho_good_streak'] >= 2 else '<'} 2"
370
+ f" -> {'OVERRIDE to C3' if b_fires else 'no'}."
371
+ )
372
+ if b_fires:
373
+ lines.append(f"The root cause is {CAUSE_DESCRIPTIONS['C3']}.")
374
+ return True
375
+
376
+ p3_fires = 'C6' in available and m['pci_collision_ratio'] > 0.70
377
+ cr_s = safe_cmp_fmt(m['pci_collision_ratio'], 0.70, '>', decimals=2)
378
+ lines.append(
379
+ f" P3: pci_collision_ratio = {cr_s} {'>' if m['pci_collision_ratio'] > 0.70 else '<='} 0.70"
380
+ f" -> {'OVERRIDE to C6' if p3_fires else 'no'}."
381
+ )
382
+ if p3_fires:
383
+ lines.append(f"The root cause is {CAUSE_DESCRIPTIONS['C6']}.")
384
+ return True
385
+
386
+ p4_fires = 'C3' in available and m['avg_rsrp'] > -79 and m['strong_neighbor_count'] > 1.0
387
+ rsrp_p4_s = safe_cmp_fmt(m['avg_rsrp'], -79, '>', decimals=3)
388
+ nb_p4_s = safe_cmp_fmt(m['strong_neighbor_count'], 1.0, '>', decimals=2)
389
+ lines.append(
390
+ f" P4: avg_rsrp = {rsrp_p4_s} {'>' if m['avg_rsrp'] > -79 else '<='} -79"
391
+ f", strong_neighbor_count = {nb_p4_s} {'>' if m['strong_neighbor_count'] > 1.0 else '<='} 1.0"
392
+ f" -> {'OVERRIDE to C3' if p4_fires else 'no'}."
393
+ )
394
+ if p4_fires:
395
+ lines.append(f"The root cause is {CAUSE_DESCRIPTIONS['C3']}.")
396
+ return True
397
+
398
+ lines.append("No override. C1 confirmed.")
399
+ lines.append(f"The root cause is {CAUSE_DESCRIPTIONS['C1']}.")
400
+ return True
401
+
402
+
403
+ def _walk_c4(lines: list, m: Dict, canonical: str, available: set, evidence: Dict) -> bool:
404
+ """Walk C4 interference check. Returns True if answer resolved here."""
405
+ if 'C4' not in available:
406
+ lines.append("C4 not in available options -> skip.")
407
+ return False
408
+
409
+ c4_s = safe_cmp_fmt(m['c4_interference'], 3, '<', decimals=2)
410
+ if m['c4_interference'] < 3:
411
+ lines.append(f"c4_interference = {c4_s} < 3 dB -> not C4.")
412
+ return False
413
+
414
+ ratio_skip = m['ratio_nbdiff_interf'] < -0.5 and m['c4_interference'] < 12
415
+ lines.append(
416
+ f"c4_interference = {c4_s} >= 3 dB."
417
+ )
418
+ ratio_s = safe_cmp_fmt(m['ratio_nbdiff_interf'], -0.5, '<', decimals=2)
419
+ c4_12_s = safe_cmp_fmt(m['c4_interference'], 12, '<', decimals=2)
420
+ lines.append(
421
+ f"Ratio filter: ratio_nbdiff_interf = {ratio_s}"
422
+ f" {'<' if m['ratio_nbdiff_interf'] < -0.5 else '>='} -0.5"
423
+ f", c4_interference = {c4_12_s}"
424
+ f" {'<' if m['c4_interference'] < 12 else '>='} 12"
425
+ f" -> {'FILTERED (neighbors dominate, skip C4)' if ratio_skip else 'no filter, C4 confirmed'}."
426
+ )
427
+
428
+ if ratio_skip:
429
+ return False
430
+
431
+ lines.append(f"The root cause is {CAUSE_DESCRIPTIONS['C4']}.")
432
+ return True
433
+
434
+
435
+ def _walk_c6(lines: list, m: Dict, canonical: str, available: set,
436
+ evidence: Dict, v16_override: str) -> bool:
437
+ """Walk C6 collision check with filtering. Returns True if answer resolved here."""
438
+ if 'C6' not in available:
439
+ lines.append("C6 not in available options -> skip.")
440
+ return False
441
+
442
+ if not m['pci_collision']:
443
+ lines.append("pci_collision = no -> not C6.")
444
+ return False
445
+
446
+ lines.append(f"pci_collision = yes ({m['pci_collision_detail']}).")
447
+
448
+ lines.append("Filter signals:")
449
+ lines.append(
450
+ f" c1_signal: serving_tilt = {m['serving_tilt']:.0f} {'>=' if m['serving_tilt'] >= 20 else '<'} 20"
451
+ f" -> {'yes' if m['c6_c1_signal'] else 'no'}."
452
+ )
453
+ lines.append(
454
+ f" c3_signal: min_neighbor_diff = {m['min_neighbor_diff']:.1f} {'<' if m['min_neighbor_diff'] < 3 else '>='} 3"
455
+ f" AND serving_tilt = {m['serving_tilt']:.0f} {'>' if m['serving_tilt'] > 12 else '<='} 12"
456
+ f" -> {'yes' if m['c6_c3_signal'] else 'no'}."
457
+ )
458
+ lines.append(
459
+ f" c3_off_axis: avg_off_axis = {m['avg_off_axis']:.1f} {'>' if m['avg_off_axis'] > 30 else '<='} 30"
460
+ f" -> {'yes' if m['c6_c3_off_axis_signal'] else 'no'}."
461
+ )
462
+
463
+ no_signal = not m['c6_c1_signal'] and not m['c6_c3_signal'] and not m['c6_c3_off_axis_signal']
464
+
465
+ if no_signal:
466
+ lines.append("No filter signals -> genuine collision path.")
467
+ return _walk_c6_no_signal_path(lines, m, canonical, available)
468
+
469
+ if m['c6_c3_off_axis_signal']:
470
+ rsrp_offaxis_s = safe_cmp_fmt(m['min_rsrp'], -90, '<', decimals=2)
471
+ if m['min_rsrp'] < -90 and 'C1' in available:
472
+ lines.append(f"Off-axis signal + min_rsrp = {rsrp_offaxis_s} < -90 -> downtilt path.")
473
+ return _walk_c6_offaxis_c1_path(lines, m, canonical, available)
474
+ elif 'C3' in available:
475
+ lines.append(f"Off-axis signal + min_rsrp = {rsrp_offaxis_s} >= -90 -> neighbor-better path.")
476
+ return _walk_c6_offaxis_c3_path(lines, m, canonical, available)
477
+
478
+ signals = []
479
+ if m['c6_c1_signal']:
480
+ signals.append('c1_signal (high tilt)')
481
+ if m['c6_c3_signal']:
482
+ signals.append('c3_signal (small neighbor diff)')
483
+ lines.append(f"Filter triggered: {', '.join(signals)} -> collision not primary cause, fall through to C1/C3.")
484
+ return False
485
+
486
+
487
def _walk_c6_no_signal_path(lines: list, m: Dict, canonical: str, available: set) -> bool:
    """C6 path with no filter signal: B override, then the P1 ratio check.

    Decision order: B (post-handover good streak) -> C3; otherwise P1 with
    ``pci_collision_ratio >= 1.0`` -> C6; otherwise a tilt/RSRP-trend
    override to C1, a default fallback to C3, or C6 as last resort.
    ``canonical`` is not read here.

    Returns:
        Always True: every branch appends a conclusion line.
    """
    lines.append("V16 override checks:")

    streak = m['post_ho_good_streak']
    long_streak = streak >= 2
    b_fires = 'C3' in available and long_streak
    b_op = '>=' if long_streak else '<'
    b_out = 'OVERRIDE to C3' if b_fires else 'no'
    lines.append(f" B: post_ho_good_streak = {streak} {b_op} 2 -> {b_out}.")

    if b_fires:
        verdict = 'C3'
    elif m['pci_collision_ratio'] >= 1.0:
        lines.append(f"P1: pci_collision_ratio = {m['pci_collision_ratio']:.2f} >= 1.0 -> genuine C6.")
        verdict = 'C6'
    else:
        lines.append(f"P1: pci_collision_ratio = {m['pci_collision_ratio']:.2f} < 1.0 -> not genuine C6.")
        if 'C1' in available and m['serving_tilt'] > 10 and m['rsrp_trend'] > 0.4:
            lines.append(
                f" serving_tilt = {m['serving_tilt']:.0f} > 10, rsrp_trend = {m['rsrp_trend']:.2f} > 0.4"
                " -> OVERRIDE to C1."
            )
            verdict = 'C1'
        elif 'C3' in available:
            lines.append(" Default fallback -> C3.")
            verdict = 'C3'
        else:
            lines.append(" No better option -> keep C6.")
            verdict = 'C6'

    lines.append(f"The root cause is {CAUSE_DESCRIPTIONS[verdict]}.")
    return True
521
+
522
+
523
def _walk_c6_offaxis_c1_path(lines: list, m: Dict, canonical: str, available: set) -> bool:
    """C6 off-axis + weak RSRP -> C1, unless a V16 override (B/P3/P4) fires.

    The overrides are evaluated lazily in order; the first that fires ends
    the walk with its own conclusion.  ``canonical`` is not read here.

    Returns:
        Always True: every branch appends a conclusion line.
    """
    def override_checks():
        # Each step yields (trace line, target cause or None).  Being a
        # generator, a later check is only computed if the earlier ones failed.
        streak = m['post_ho_good_streak']
        long_streak = streak >= 2
        b_fires = 'C3' in available and long_streak
        b_op = '>=' if long_streak else '<'
        b_out = 'OVERRIDE to C3' if b_fires else 'no'
        yield (f" B: post_ho_good_streak = {streak} {b_op} 2 -> {b_out}.",
               'C3' if b_fires else None)

        ratio = m['pci_collision_ratio']
        high_ratio = ratio > 0.70
        p3_fires = 'C6' in available and high_ratio
        ratio_txt = safe_cmp_fmt(ratio, 0.70, '>', decimals=2)
        p3_op = '>' if high_ratio else '<='
        p3_out = 'OVERRIDE to C6' if p3_fires else 'no'
        yield (f" P3: pci_collision_ratio = {ratio_txt} {p3_op} 0.70 -> {p3_out}.",
               'C6' if p3_fires else None)

        avg_rsrp = m['avg_rsrp']
        strong_nb = m['strong_neighbor_count']
        good_rsrp = avg_rsrp > -79
        many_nb = strong_nb > 1.0
        p4_fires = 'C3' in available and good_rsrp and many_nb
        rsrp_txt = safe_cmp_fmt(avg_rsrp, -79, '>', decimals=3)
        nb_txt = safe_cmp_fmt(strong_nb, 1.0, '>', decimals=2)
        rsrp_op = '>' if good_rsrp else '<='
        nb_op = '>' if many_nb else '<='
        p4_out = 'OVERRIDE to C3' if p4_fires else 'no'
        yield (
            f" P4: avg_rsrp = {rsrp_txt} {rsrp_op} -79"
            f", strong_neighbor_count = {nb_txt} {nb_op} 1.0"
            f" -> {p4_out}.",
            'C3' if p4_fires else None,
        )

    lines.append("V16 override checks:")
    for text, target in override_checks():
        lines.append(text)
        if target is not None:
            lines.append(f"The root cause is {CAUSE_DESCRIPTIONS[target]}.")
            return True

    lines.append("No override. C1 confirmed (off-axis downtilt path).")
    lines.append(f"The root cause is {CAUSE_DESCRIPTIONS['C1']}.")
    return True
561
+
562
+
563
def _walk_c6_offaxis_c3_path(lines: list, m: Dict, canonical: str, available: set) -> bool:
    """C6 off-axis + good RSRP -> C3, unless a V16 override (P2/G/J/P5b) fires.

    The overrides are evaluated lazily in order; the first that fires ends
    the walk with its own conclusion.  ``canonical`` is not read here.

    Returns:
        Always True: every branch appends a conclusion line.
    """
    has_c1 = 'C1' in available

    def override_checks():
        # Each step yields (trace line, target cause or None); later checks
        # are only computed if the earlier ones did not fire.
        ratio = m['pci_collision_ratio']
        high_ratio = ratio > 0.70
        p2_fires = 'C6' in available and high_ratio
        ratio_txt = safe_cmp_fmt(ratio, 0.70, '>', decimals=2)
        p2_op = '>' if high_ratio else '<='
        p2_out = 'OVERRIDE to C6' if p2_fires else 'no'
        yield (f" P2: pci_collision_ratio = {ratio_txt} {p2_op} 0.70 -> {p2_out}.",
               'C6' if p2_fires else None)

        change = m['rsrp_change_during_prob']
        trend = m['rsrp_trend']
        nb5 = m['nb_within_5db_per_row']
        big_change = change > 5
        rising = trend > 0.5
        few_close_nb = nb5 < 1.0
        g_fires = has_c1 and big_change and rising and few_close_nb
        change_txt = safe_cmp_fmt(change, 5, '>')
        trend_txt = safe_cmp_fmt(trend, 0.5, '>', decimals=2)
        nb5_txt = safe_cmp_fmt(nb5, 1.0, '<', decimals=2)
        change_op = '>' if big_change else '<='
        trend_op = '>' if rising else '<='
        nb5_op = '<' if few_close_nb else '>='
        g_out = 'OVERRIDE to C1' if g_fires else 'no'
        yield (
            f" G: rsrp_change = {change_txt} {change_op} 5"
            f", rsrp_trend = {trend_txt} {trend_op} 0.5"
            f", nb_5db = {nb5_txt} {nb5_op} 1.0"
            f" -> {g_out}.",
            'C1' if g_fires else None,
        )

        recovery = m['rsrp_recovery']
        big_recovery = recovery > 15
        j_fires = has_c1 and big_recovery
        recovery_txt = safe_cmp_fmt(recovery, 15, '>')
        j_op = '>' if big_recovery else '<='
        j_out = 'OVERRIDE to C1' if j_fires else 'no'
        yield (f" J: rsrp_recovery = {recovery_txt} {j_op} 15 -> {j_out}.",
               'C1' if j_fires else None)

        tilt = m['serving_tilt']
        tilted = tilt > 6
        p5b_fires = has_c1 and tilted and few_close_nb
        nb5b_txt = safe_cmp_fmt(nb5, 1.0, '<', decimals=2)
        tilt_op = '>' if tilted else '<='
        p5b_out = 'OVERRIDE to C1' if p5b_fires else 'no'
        yield (
            f" P5b: serving_tilt = {tilt:.0f} {tilt_op} 6"
            f", nb_5db = {nb5b_txt} {nb5_op} 1.0"
            f" -> {p5b_out}.",
            'C1' if p5b_fires else None,
        )

    lines.append("V16 override checks:")
    for text, target in override_checks():
        lines.append(text)
        if target is not None:
            lines.append(f"The root cause is {CAUSE_DESCRIPTIONS[target]}.")
            return True

    lines.append("No override. C3 confirmed (off-axis neighbor-better path).")
    lines.append(f"The root cause is {CAUSE_DESCRIPTIONS['C3']}.")
    return True
618
+
619
+
620
def _walk_c1c3_tiebreaker(lines: list, m: Dict, canonical: str, available: set,
                          evidence: Dict, v16_override: str, is_expert: bool):
    """Walk the V19 C1/C3 tiebreaker: SINR gate, RSRP bands, rescue rules.

    The prediction and confidence come from ``classify_c1_vs_c3``; the
    headline line printed here re-derives the matching branch from tilt,
    average RSRP and average SINR.  Low-confidence cases are delegated to
    ``_walk_rescue``; high/medium cases run the V16 override display when
    both C1 and C3 are available.  ``evidence`` and ``v16_override`` are not
    read here.  Nothing is returned; ``lines`` is extended in place.
    """
    tilt = m['serving_tilt']
    avg_rsrp = m['avg_rsrp']
    avg_sinr = m.get('avg_sinr')

    pred, conf = classify_c1_vs_c3(tilt, avg_rsrp, m['min_neighbor_diff'], avg_sinr)

    # Both renderings are computed up front, as before, regardless of branch.
    rsrp_vs_c1 = safe_cmp_fmt(avg_rsrp, -90, '<', decimals=3)
    rsrp_vs_c3 = safe_cmp_fmt(avg_rsrp, -82, '>', decimals=3)

    mid_band = f"serving_tilt = {tilt:.0f} (12-27 range), avg_rsrp = "
    if tilt >= 28:
        # tilt >= 28 with SINR gate
        if avg_sinr is not None and avg_sinr >= 12:
            headline = (
                f"serving_tilt = {tilt:.0f} >= 28, avg_sinr = {avg_sinr:.1f} >= 12"
                " -> SINR gate: high confidence C3 (good SINR despite high tilt)."
            )
        else:
            # NOTE(review): a missing SINR prints as "N/A < 12" - confirm
            # whether that wording is intended for the training traces.
            sinr_txt = "N/A" if avg_sinr is None else f"{avg_sinr:.1f}"
            headline = (
                f"serving_tilt = {tilt:.0f} >= 28, avg_sinr = {sinr_txt} < 12"
                " -> high confidence C1."
            )
    elif tilt < 12:
        headline = f"serving_tilt = {tilt:.0f} < 12 -> high confidence C3."
    elif avg_rsrp < -90:
        headline = mid_band + f"{rsrp_vs_c1} < -90 -> medium confidence C1."
    elif avg_rsrp > -82:
        headline = mid_band + f"{rsrp_vs_c3} > -82 -> medium confidence C3."
    else:
        headline = mid_band + f"{avg_rsrp:.3f} (-90 to -82) -> low confidence {pred}."
    lines.append(headline)

    # Low confidence -> rescue rules
    if conf == 'low':
        lines.append("Low confidence -> applying rescue rules:")
        _walk_rescue(lines, m, canonical, available, is_expert)
        return

    both_available = 'C1' in available and 'C3' in available
    if conf not in ('high', 'medium') or not both_available:
        # NOTE(review): no conclusion line is emitted on this path; in
        # non-expert mode the trace then ends without "The root cause is ...".
        # Confirm whether that is intended.
        return

    lines.append("V16 override checks:")
    show_overrides = _show_c3_overrides if pred == 'C3' else _show_c1_overrides
    if show_overrides(lines, m, available, canonical):
        lines.append(f"The root cause is {CAUSE_DESCRIPTIONS[canonical]}.")
    elif pred == canonical:
        lines.append(f"No override. {pred} confirmed.")
        lines.append(f"The root cause is {CAUSE_DESCRIPTIONS[canonical]}.")
    else:
        lines.append(f"Deterministic classifier predicts {pred}, but examining additional indicators:")
        _add_expert_reasoning(lines, m, canonical)
686
+
687
+
688
def _walk_rescue(lines: list, m: Dict, canonical: str, available: set, is_expert: bool):
    """Walk V19 rescue rules R1-R4 for low-confidence C1/C3 cases.

    Rules are tried in order and the first that fires ends the walk:
    R1 ``pci_collision_ratio >= 0.9`` -> C6, R2 ``strong_neighbor_count < 0.8``
    -> C1, R3 ``c4_interference >= 3.0`` -> C1, R4 default -> C3.  A rule can
    only fire if its target cause is in ``available``.

    In expert mode, when the fired rule disagrees with ``canonical``, a
    hedging line is written and ``_add_expert_reasoning`` supplies the rest;
    otherwise the rule's own conclusion is written.  Nothing is returned.
    """
    def conclude(code, hedge):
        # Shared ending for all four rules.
        if canonical == code or not is_expert:
            lines.append(f"The root cause is {CAUSE_DESCRIPTIONS[code]}.")
        else:
            lines.append(hedge)
            _add_expert_reasoning(lines, m, canonical)

    # (tag, metric name, value, limit, operator, target cause)
    rules = (
        ('R1', 'pci_collision_ratio', m['pci_collision_ratio'], 0.9, '>=', 'C6'),
        ('R2', 'strong_neighbor_count', m['strong_neighbor_count'], 0.8, '<', 'C1'),
        ('R3', 'c4_interference', m['c4_interference'], 3.0, '>=', 'C1'),
    )
    opposite = {'>=': '<', '<': '>='}

    for tag, name, value, limit, op, code in rules:
        shown = safe_cmp_fmt(value, limit, op, decimals=2)
        holds = value >= limit if op == '>=' else value < limit
        fires = holds and code in available
        shown_op = op if holds else opposite[op]
        outcome = code if fires else 'no'
        lines.append(f" {tag}: {name} = {shown} {shown_op} {limit} -> {outcome}.")
        if fires:
            conclude(code, f"Rescue rule suggests {code}, but examining additional indicators:")
            return

    # R4: default -> C3
    lines.append(" R4: default fallback -> C3.")
    conclude('C3', "Default rescue suggests C3, but examining additional indicators:")
746
+
747
+
748
def _show_c3_overrides(lines: list, m: Dict, available: set, canonical: str) -> bool:
    """Append the V16 override checks applied to a C3 prediction.

    Checks run in the order P2, G, J, P5b. Each check appends one line
    showing the compared values; the first check that fires also appends
    the root-cause sentence and ends the walk.

    Args:
        lines: Trace lines; extended in place.
        m: Metrics dict; reads 'pci_collision_ratio',
            'rsrp_change_during_prob', 'rsrp_trend', 'nb_within_5db_per_row',
            'rsrp_recovery' and 'serving_tilt'.
        available: Cause labels offered by the question; an override only
            fires when its target cause is present.
        canonical: Not used by this function.

    Returns:
        True if an override fired (conclusion already appended), else False.
    """

    def report(text: str, fires, cause: str) -> bool:
        # One line per check, plus the conclusion when the check fires.
        lines.append(f"{text} -> {('OVERRIDE to ' + cause) if fires else 'no'}.")
        if fires:
            lines.append(f"The root cause is {CAUSE_DESCRIPTIONS[cause]}.")
            return True
        return False

    c1_offered = 'C1' in available

    # P2: heavy PCI collision overrides to C6.
    collision = m['pci_collision_ratio']
    collision_high = collision > 0.70
    p2_fires = 'C6' in available and collision_high
    collision_txt = safe_cmp_fmt(collision, 0.70, '>', decimals=2)
    if report(
        f" P2: pci_collision_ratio = {collision_txt} {'>' if collision_high else '<='} 0.70",
        p2_fires,
        'C6',
    ):
        return True

    # G: large RSRP change with positive trend and few close neighbors.
    rsrp_change = m['rsrp_change_during_prob']
    rsrp_trend = m['rsrp_trend']
    close_nb = m['nb_within_5db_per_row']
    change_big = rsrp_change > 5
    trend_up = rsrp_trend > 0.5
    few_close = close_nb < 1.0
    g_fires = c1_offered and change_big and trend_up and few_close
    change_txt = safe_cmp_fmt(rsrp_change, 5, '>')
    trend_txt = safe_cmp_fmt(rsrp_trend, 0.5, '>', decimals=2)
    close_txt = safe_cmp_fmt(close_nb, 1.0, '<', decimals=2)
    if report(
        f" G: rsrp_change = {change_txt} {'>' if change_big else '<='} 5"
        f", rsrp_trend = {trend_txt} {'>' if trend_up else '<='} 0.5"
        f", nb_5db = {close_txt} {'<' if few_close else '>='} 1.0",
        g_fires,
        'C1',
    ):
        return True

    # J: strong RSRP recovery.
    recovery = m['rsrp_recovery']
    recovery_big = recovery > 15
    j_fires = c1_offered and recovery_big
    recovery_txt = safe_cmp_fmt(recovery, 15, '>')
    if report(
        f" J: rsrp_recovery = {recovery_txt} {'>' if recovery_big else '<='} 15",
        j_fires,
        'C1',
    ):
        return True

    # P5b: tilt above 6 with few close neighbors.
    tilt = m['serving_tilt']
    tilt_above = tilt > 6
    p5b_fires = c1_offered and tilt_above and few_close
    close_txt_p5b = safe_cmp_fmt(close_nb, 1.0, '<', decimals=2)
    return report(
        f" P5b: serving_tilt = {tilt:.0f} {'>' if tilt_above else '<='} 6"
        f", nb_5db = {close_txt_p5b} {'<' if few_close else '>='} 1.0",
        p5b_fires,
        'C1',
    )
799
+
800
+
801
def _show_c1_overrides(lines: list, m: Dict, available: set, canonical: str) -> bool:
    """Append the V16 override checks applied to a C1 prediction.

    Checks run in the order P3 then P4. Each appends one line showing the
    compared values; the first that fires also appends the root-cause
    sentence and ends the walk.

    Args:
        lines: Trace lines; extended in place.
        m: Metrics dict; reads 'pci_collision_ratio', 'avg_rsrp' and
            'strong_neighbor_count'.
        available: Cause labels offered by the question; an override only
            fires when its target cause is present.
        canonical: Not used by this function.

    Returns:
        True if an override fired (conclusion already appended), else False.
    """

    def report(text: str, fires, cause: str) -> bool:
        # One line per check, plus the conclusion when the check fires.
        lines.append(f"{text} -> {('OVERRIDE to ' + cause) if fires else 'no'}.")
        if fires:
            lines.append(f"The root cause is {CAUSE_DESCRIPTIONS[cause]}.")
            return True
        return False

    # P3: heavy PCI collision overrides to C6.
    collision = m['pci_collision_ratio']
    collision_high = collision > 0.70
    p3_fires = 'C6' in available and collision_high
    collision_txt = safe_cmp_fmt(collision, 0.70, '>', decimals=2)
    if report(
        f" P3: pci_collision_ratio = {collision_txt} {'>' if collision_high else '<='} 0.70",
        p3_fires,
        'C6',
    ):
        return True

    # P4: strong serving signal together with more than one strong neighbor.
    avg_rsrp = m['avg_rsrp']
    strong_nb = m['strong_neighbor_count']
    rsrp_strong = avg_rsrp > -79
    many_neighbors = strong_nb > 1.0
    p4_fires = 'C3' in available and rsrp_strong and many_neighbors
    rsrp_txt = safe_cmp_fmt(avg_rsrp, -79, '>', decimals=3)
    nb_txt = safe_cmp_fmt(strong_nb, 1.0, '>', decimals=2)
    return report(
        f" P4: avg_rsrp = {rsrp_txt} {'>' if rsrp_strong else '<='} -79"
        f", strong_neighbor_count = {nb_txt} {'>' if many_neighbors else '<='} 1.0",
        p4_fires,
        'C3',
    )
826
+
827
+
828
def _add_expert_reasoning(lines: list, m: Dict, canonical: str):
    """Append an explanation plus conclusion for a case decided by ``canonical``.

    Used where the rule-based prediction differs from the ground truth.
    One explanatory paragraph is chosen by ``canonical`` (C1, C3, C6, C4 or
    a generic fallback), followed by the "The root cause is ..." sentence.

    Args:
        lines: Trace lines; extended in place with exactly two entries.
        m: Metrics dict; 'serving_tilt' is always read, the remaining keys
            depend on the branch taken.
        canonical: Cause label the explanation argues for.
    """
    tilt = m['serving_tilt']
    if tilt >= 20:
        tilt_desc = 'high'
    elif tilt >= 12:
        tilt_desc = 'moderate'
    else:
        tilt_desc = 'low'

    if canonical == 'C1':
        nb = m['strong_neighbor_count']
        nb_diff = m['min_neighbor_diff']
        if nb < 1.0:
            explanation = (
                f"strong_neighbor_count = {nb:.2f} - few strong neighbors during low TP."
                f" min_neighbor_diff = {nb_diff:.1f} dB - neighbors generally weaker."
                f" Pattern of few strong neighbors despite {tilt_desc} tilt ({tilt:.0f} deg) suggests downtilt is"
                f" causing signal weakness at the cell edge."
            )
        else:
            explanation = (
                f"strong_neighbor_count = {nb:.2f} - some neighbors within 6 dB."
                f" min_neighbor_diff = {nb_diff:.1f} dB - but neighbors provide weaker signal overall."
                f" Despite nearby neighbors, the negative neighbor difference shows the serving cell's"
                f" {tilt_desc} tilt ({tilt:.0f} deg) is degrading coverage, not that a neighbor provides"
                f" better throughput."
            )
    elif canonical == 'C3':
        explanation = (
            f"strong_neighbor_count = {m['strong_neighbor_count']:.2f} - multiple neighbors within 6 dB."
            f" min_neighbor_diff = {m['min_neighbor_diff']:.1f} dB - at least one neighbor provides"
            f" comparable or stronger signal. A neighboring cell can provide higher throughput."
        )
    elif canonical == 'C6':
        explanation = (
            f"Although {tilt_desc} tilt ({tilt:.0f} deg) initially suggested downtilt rather than collision,"
            f" pci_collision_ratio = {m['pci_collision_ratio']:.2f} indicates collision is present in"
            f" {m['pci_collision_ratio']*100:.0f}% of drive test rows."
            f" avg_off_axis = {m['avg_off_axis']:.1f} deg - UE in the main beam where collision"
            f" has maximum impact. The persistent PCI mod 30 collision overrides the tilt signal."
        )
    elif canonical == 'C4':
        explanation = (
            f"c4_interference = {m['c4_interference']:.2f} dB shows significant non-colocated"
            f" co-frequency interference. Despite other indicators, the interference level"
            f" is the primary throughput degradation factor."
        )
    else:
        explanation = (
            f"Further analysis of the drive test data confirms the root cause"
            f" is {CAUSE_DESCRIPTIONS.get(canonical, canonical)}."
        )

    lines.append(explanation)
    # NOTE(review): the fallback above tolerates an unknown label via .get(),
    # but this lookup raises KeyError for one - confirm which is intended.
    lines.append(f"The root cause is {CAUSE_DESCRIPTIONS[canonical]}.")
876
+
877
+
878
+ # =============================================================================
879
+ # TYPE B METRICS (used by SFT/GRPO/inference for prompt preparation)
880
+ # =============================================================================
881
+
882
def compute_type_b_metrics(question: str) -> Optional[Dict]:
    """Compute all Type B metrics from a question string.

    Includes config/signaling parsing: n1_in_config, inter_freq_ho,
    a2_thld, n_configured_neighbors.

    A drive-test value counts as present when it is not ``None``. Zero is a
    legitimate measurement (e.g. a SINR of 0, or serving PCI 0) and is kept;
    filtering on truthiness would silently drop such rows from the averages
    and from the handover count.

    Args:
        question: Raw question text containing the drive-test, signaling
            and configuration data.

    Returns:
        A dict of metrics (see the keys of the returned literal), or
        ``None`` when no drive-test rows could be parsed.
    """
    drive_test, signaling = parse_type_b_question(question)
    if not drive_test:
        return None

    def column(key: str) -> list:
        """Collect the non-missing values of ``key`` across drive-test rows."""
        return [d[key] for d in drive_test if d.get(key) is not None]

    def mean(values: list, default):
        """Arithmetic mean of ``values``, or ``default`` when empty."""
        return sum(values) / len(values) if values else default

    rsrps = column('rsrp')
    sinrs = column('sinr')
    cce_fails = column('cce_fail_rate')
    blers = column('initial_bler')
    rb_slots = column('rb_slot')

    # Fallback values are used when a column has no samples at all.
    avg_rsrp = mean(rsrps, -90)
    avg_sinr = mean(sinrs, 10)
    avg_cce_fail = mean(cce_fails, 0)
    avg_bler = mean(blers, 0)
    avg_rb = mean(rb_slots, 200)

    avg_n1_rsrp = mean(column('neighbor1_rsrp'), -120)
    avg_n2_rsrp = mean(column('neighbor2_rsrp'), -120)
    avg_n3_rsrp = mean(column('neighbor3_rsrp'), -120)
    min_neighbor_diff = avg_rsrp - avg_n1_rsrp

    # Population standard deviation of serving RSRP, normalized by |mean|.
    if len(rsrps) > 1:
        std_rsrp = (sum((r - avg_rsrp) ** 2 for r in rsrps) / len(rsrps)) ** 0.5
    else:
        std_rsrp = 0
    rsrp_var_norm = std_rsrp / abs(avg_rsrp) if avg_rsrp != 0 else 0

    # A handover is counted whenever consecutive rows report different
    # serving PCIs.
    pcis = column('serving_pci')
    actual_handovers = sum(1 for prev, cur in zip(pcis, pcis[1:]) if cur != prev)

    ratio_a3_ho = signaling['a3_events'] / max(actual_handovers, 1)
    rrc_reestablish = signaling.get('rrc_reestablish', 0)

    # Conditional metrics during low-TP rows
    low_tp_rows = [d for d in drive_test if d.get('throughput') is not None and d['throughput'] < 100]

    def safe_avg(rows, key):
        vals = [d[key] for d in rows if d.get(key) is not None]
        return sum(vals) / len(vals) if vals else None

    low_tp_avg_mcs = safe_avg(low_tp_rows, 'avg_mcs')
    low_tp_avg_sinr = safe_avg(low_tp_rows, 'sinr')
    low_tp_avg_bler = safe_avg(low_tp_rows, 'initial_bler')

    # Stays None unless all three low-TP averages are available.
    phy_healthy_during_low_tp = None
    if low_tp_avg_mcs is not None and low_tp_avg_sinr is not None and low_tp_avg_bler is not None:
        phy_healthy_during_low_tp = (
            low_tp_avg_mcs > 10 and
            low_tp_avg_sinr > 8 and
            low_tp_avg_bler < 15
        )

    neighbors_within_3dB = 0
    neighbors_within_5dB = 0
    for avg_n in (avg_n1_rsrp, avg_n2_rsrp, avg_n3_rsrp):
        # Averages at or below -115 (including the -120 fallback) are ignored.
        if avg_n > -115:
            diff = avg_rsrp - avg_n
            if diff < 3:
                neighbors_within_3dB += 1
            if diff < 5:
                neighbors_within_5dB += 1

    # Share of rows (with both readings present) where neighbor 1 beats the
    # serving cell.
    n1_stronger_count = 0
    n1_total = 0
    for d in drive_test:
        if d.get('rsrp') is not None and d.get('neighbor1_rsrp') is not None:
            n1_total += 1
            if d['neighbor1_rsrp'] > d['rsrp']:
                n1_stronger_count += 1
    n1_stronger_pct = (n1_stronger_count / n1_total * 100) if n1_total > 0 else 0

    # Configuration and signaling table parsing
    config_cells = parse_config_data(question)
    inter_freq_ho = detect_inter_freq_ho(question)
    n1_in_config, serving_pci, n1_pci = check_n1_in_config(question, drive_test, config_cells)

    # Extract a2_thld and n_configured_neighbors from config.
    # `is not None` rather than truthiness so that serving PCI 0 is looked up.
    a2_thld = None
    n_configured_neighbors = 0
    if serving_pci is not None and serving_pci in config_cells:
        cfg = config_cells[serving_pci]
        a2_thld = cfg.get('a2_rsrp_thld')
        n_configured_neighbors = cfg.get('n_configured_neighbors', 0)

    return {
        'avg_rsrp': avg_rsrp,
        'avg_sinr': avg_sinr,
        'avg_cce_fail': avg_cce_fail,
        'avg_bler': avg_bler,
        'avg_rb': avg_rb,
        'actual_handovers': actual_handovers,
        'a3_events': signaling['a3_events'],
        'ratio_a3_ho': ratio_a3_ho,
        'rrc_reestablish': rrc_reestablish,
        'rsrp_var_norm': rsrp_var_norm,
        'min_neighbor_diff': min_neighbor_diff,
        'low_tp_avg_mcs': low_tp_avg_mcs,
        'low_tp_avg_sinr': low_tp_avg_sinr,
        'low_tp_avg_bler': low_tp_avg_bler,
        'phy_healthy_during_low_tp': phy_healthy_during_low_tp,
        'neighbors_within_3dB': neighbors_within_3dB,
        'neighbors_within_5dB': neighbors_within_5dB,
        'n1_stronger_pct': n1_stronger_pct,
        'n1_in_config': n1_in_config,
        'inter_freq_ho': inter_freq_ho,
        'a2_thld': a2_thld,
        'n_configured_neighbors': n_configured_neighbors,
    }
999
+
1000
+
1001
def format_type_b_metrics_block(m: Dict) -> str:
    """Render Type B metrics as the text block placed in the user message.

    Args:
        m: Metrics dict. The always-present metrics and the four low-TP
            metrics are read with ``m[key]``; 'n1_in_config',
            'inter_freq_ho', 'a2_thld' and 'n_configured_neighbors' are
            optional and fall back to defaults when absent.

    Returns:
        Newline-joined "name = value" lines under an "Extracted metrics:"
        header; ``None`` values of optional metrics render as "N/A".
    """
    # (key, format spec, unit suffix) for metrics that are always numeric.
    always = (
        ('avg_rsrp', '.1f', ' dBm'),
        ('avg_sinr', '.1f', ' dB'),
        ('avg_cce_fail', '.2f', ''),
        ('avg_bler', '.1f', '%'),
        ('avg_rb', '.0f', ''),
        ('actual_handovers', '', ''),
        ('a3_events', '', ''),
        ('ratio_a3_ho', '.2f', ''),
        ('rrc_reestablish', '', ''),
        ('rsrp_var_norm', '.3f', ''),
        ('min_neighbor_diff', '.1f', ' dB'),
    )
    # Same layout, but the value may be None and then prints as N/A.
    nullable = (
        ('low_tp_avg_mcs', '.1f', ''),
        ('low_tp_avg_sinr', '.1f', ' dB'),
        ('low_tp_avg_bler', '.1f', '%'),
        ('phy_healthy_during_low_tp', '', ''),
    )

    out = ["Extracted metrics:"]
    out += [f" {key} = {m[key]:{spec}}{unit}" for key, spec, unit in always]

    for key, spec, unit in nullable:
        value = m[key]
        rendered = "N/A" if value is None else f"{value:{spec}}{unit}"
        out.append(f" {key} = {rendered}")

    out.append(f" neighbors_within_3dB = {m['neighbors_within_3dB']}")
    out.append(f" neighbors_within_5dB = {m['neighbors_within_5dB']}")
    out.append(f" n1_stronger_pct = {m['n1_stronger_pct']:.1f}%")

    out.append(f" n1_in_config = {m.get('n1_in_config', 'N/A')}")
    out.append(f" inter_freq_ho = {m.get('inter_freq_ho', False)}")
    a2_thld = m.get('a2_thld')
    out.append(" a2_thld = N/A" if a2_thld is None else f" a2_thld = {m['a2_thld']}")
    out.append(f" n_configured_neighbors = {m.get('n_configured_neighbors', 0)}")

    return '\n'.join(out)
1053
+
1054
+
1055
+ # =============================================================================
1056
+ # TRACE GENERATOR
1057
+ # =============================================================================
1058
+
1059
def generate_type_a_traces(
    train_csv: Path,
    output_dict: Dict,
    stats: Dict,
    spot_check: int = 0,
):
    """Generate Type A traces from train.csv with ground truth labels.

    Rows whose question is not classified as 'type_a_telco', rows without
    a ground-truth answer, and rows whose ground truth is not among the
    question's options are skipped (the latter two with a warning).

    Args:
        train_csv: Path to train.csv (must have 'ID', 'question', 'answer' columns)
        output_dict: dict to add traces to, keyed by question ID
        stats: stats dict to update in place ('total', 'by_cause',
            'by_confidence', 'expert', 'deterministic', 'correct',
            'incorrect', 'trace_lengths')
        spot_check: accepted but not used by this function
    """
    logger.info(f"Loading training data from {train_csv}")
    train_df = pd.read_csv(train_csv)
    logger.info(f"Loaded {len(train_df)} training questions")

    ground_truth_map = dict(zip(train_df['ID'], train_df['answer']))
    logger.info(f"Ground truth labels: {len(ground_truth_map)}")

    for _, record in train_df.iterrows():
        qid = record['ID']
        question = record['question']

        if classify_question_type(question) != 'type_a_telco':
            continue

        ground_truth = ground_truth_map.get(qid)
        if not ground_truth:
            logger.warning(f"{qid}: no ground truth, skipping")
            continue

        result = classify_type_a(question)
        drive_test, cells = parse_type_a_question(question)
        option_map = extract_type_a_options(question)

        # Invert option label -> cause so the ground-truth cause can be
        # translated back to the option label shown in the question.
        cause_to_label = {}
        for option_label, cause in option_map.items():
            cause_to_label[cause] = option_label
        available_causes = set(option_map.values())

        answer_label = cause_to_label.get(ground_truth)
        if not answer_label:
            logger.warning(f"{qid}: ground truth {ground_truth} not in options, skipping")
            continue

        m = compute_all_metrics(question, drive_test, cells)

        # A trace is "expert" when the classifier defers to the LLM, or
        # when its deterministic prediction disagrees with the truth.
        needs_llm = result['confidence'] == 'needs_llm'
        mispredicted = (not needs_llm) and result['canonical'] != ground_truth
        if mispredicted:
            logger.info(f"{qid}: V19 predicted {result['canonical']} but truth is {ground_truth}")
        is_expert = needs_llm or mispredicted

        source = 'expert' if is_expert else 'deterministic'
        trace_text = generate_trace(m, result, available_causes, ground_truth, is_expert=is_expert)

        output_dict[qid] = {
            'question': question,
            'expected_answer': answer_label,
            'derived_answer': answer_label,
            'reasoning_trace': f"<think>\n{trace_text}\n</think>",
            'attempts': 1,
            'success': True,
            'source': source,
            'question_type': 'type_a',
        }

        stats['total'] += 1
        stats['by_cause'][ground_truth] += 1
        stats['by_confidence'][result['confidence']] += 1
        stats[source] += 1
        outcome = 'correct' if result['canonical'] == ground_truth else 'incorrect'
        stats[outcome] += 1
        stats['trace_lengths'].append(len(trace_text))
1131
+
1132
+
1133
def _print_summary(stats, spot_check, checkpoint):
    """Log the generation summary and, optionally, a few sample traces.

    Args:
        stats: Stats dict filled during generation ('total',
            'deterministic', 'expert', 'correct', 'by_cause',
            'by_confidence', 'trace_lengths').
        spot_check: Maximum number of sample traces to log; nothing is
            shown when it is not positive. Samples come from the first
            three deterministic and first two expert entries, so at most
            five are ever shown.
        checkpoint: Mapping of question ID to trace record.
    """

    def banner(title: str) -> None:
        rule = "=" * 60
        logger.info(rule)
        logger.info(title)
        logger.info(rule)

    banner("GENERATION SUMMARY")
    logger.info(f"Type A traces: {stats['total']}")
    logger.info(f" Deterministic: {stats['deterministic']}")
    logger.info(f" Expert: {stats['expert']}")
    logger.info(f" V19 accuracy: {stats['correct']}/{stats['total']}")

    for heading, key in (("By cause:", 'by_cause'), ("By confidence:", 'by_confidence')):
        logger.info("")
        logger.info(heading)
        for name in sorted(stats[key]):
            logger.info(f" {name}: {stats[key][name]}")

    lengths = stats['trace_lengths']
    if lengths:
        logger.info(
            f"Trace length (chars): min={min(lengths)}, max={max(lengths)}, "
            f"mean={sum(lengths)/len(lengths):.0f}"
        )

    logger.info(f"\nTotal traces: {len(checkpoint)}")

    if not spot_check > 0:
        return

    logger.info("")
    banner(f"SPOT CHECK ({spot_check} samples)")

    entries = list(checkpoint.items())
    deterministic = [pair for pair in entries if pair[1]['source'] == 'deterministic']
    expert = [pair for pair in entries if pair[1]['source'] == 'expert']

    shown = 0
    for label, pool in (("DETERMINISTIC", deterministic[:3]), ("EXPERT", expert[:2])):
        for qid, data in pool:
            if shown >= spot_check:
                break
            logger.info(f"\n--- [{label}] {qid} -> {data['expected_answer']} ({data['source']}) ---")
            trace = data['reasoning_trace']
            # Long traces are truncated in the log, with the full size noted.
            logger.info(trace[:2000])
            if len(trace) > 2000:
                logger.info(f"... ({len(trace)} chars total)")
            shown += 1
1176
+
1177
+
1178
+ # =============================================================================
1179
+ # VALIDATION
1180
+ # =============================================================================
1181
+
1182
def validate_checkpoint(checkpoint_path: Path, train_csv: Path = None):
    """Check trace formatting and, when train.csv is given, the answers.

    Every trace must start with '<think>' and (ignoring trailing
    whitespace) end with '</think>'. If ``train_csv`` is provided and
    exists, each 'type_a' entry's 'expected_answer' is compared with the
    'answer' column for the same ID.

    Args:
        checkpoint_path: JSON file mapping question ID to trace record.
        train_csv: Optional path to train.csv with 'ID' and 'answer' columns.

    Returns:
        True when no issue was found, otherwise False (the first 20 issues
        are logged as errors).
    """
    with open(checkpoint_path) as handle:
        checkpoint = json.load(handle)

    issues = []
    for qid, data in checkpoint.items():
        trace = data.get('reasoning_trace', '')
        has_open = trace.startswith('<think>')
        has_close = trace.rstrip().endswith('</think>')
        if not has_open:
            issues.append(f"{qid}: missing <think> tag")
        if not has_close:
            issues.append(f"{qid}: missing </think> closing tag")

    if train_csv and train_csv.exists():
        train_df = pd.read_csv(train_csv)
        train_answers = dict(zip(train_df['ID'], train_df['answer']))

        # NOTE(review): JSON keys are always strings; if train.csv IDs are
        # parsed as integers no lookup will match - confirm ID dtype.
        # NOTE(review): generate_type_a_traces stores the option label in
        # 'expected_answer' while 'answer' is used there as the cause; this
        # comparison presumably relies on both being the same vocabulary.
        for qid, data in checkpoint.items():
            if data.get('question_type') != 'type_a':
                continue
            gt = train_answers.get(qid)
            if gt and data['expected_answer'] != gt:
                issues.append(f"{qid}: answer={data['expected_answer']} != truth={gt}")

    if issues:
        logger.error(f"Validation found {len(issues)} issues:")
        for issue in issues[:20]:
            logger.error(f" {issue}")
        return False

    type_a = len([v for v in checkpoint.values() if v.get('question_type') == 'type_a'])
    logger.info(f"Validation passed: {len(checkpoint)} traces ({type_a} Type A), format OK")
    return True
1214
+
1215
+
1216
+ # =============================================================================
1217
+ # MAIN
1218
+ # =============================================================================
1219
+
1220
def main():
    """Command-line entry point: generate, save, summarize and validate traces.

    With ``--validate-only PATH`` only the given checkpoint is validated
    against DATA_DIR/train.csv. Otherwise Type A traces are generated from
    DATA_DIR/train.csv, written as JSON to ``--output``, summarized
    (``--spot-check`` sets how many samples are logged) and then validated.
    """
    parser = argparse.ArgumentParser(
        description="Generate reasoning traces from train.csv for SFT/GRPO training",
    )
    parser.add_argument('--output', type=str, default=str(OUTPUT_DIR / 'traces_final.json'))
    parser.add_argument('--spot-check', type=int, default=0)
    parser.add_argument('--validate-only', type=str, default=None)
    args = parser.parse_args()

    train_csv = DATA_DIR / 'train.csv'

    if args.validate_only:
        validate_checkpoint(Path(args.validate_only), train_csv)
        return

    output_path = Path(args.output)

    if not train_csv.exists():
        logger.error(f"Training data not found: {train_csv}")
        return

    checkpoint = {}
    stats = {
        'total': 0,
        'deterministic': 0,
        'expert': 0,
        'correct': 0,
        'incorrect': 0,
        'by_cause': Counter(),
        'by_confidence': Counter(),
        'trace_lengths': [],
    }

    rule = "=" * 60
    logger.info(rule)
    logger.info("GENERATING TRACES FROM TRAIN.CSV")
    logger.info(rule)
    generate_type_a_traces(
        train_csv=train_csv,
        output_dict=checkpoint,
        stats=stats,
    )
    logger.info(f"Generated {len(checkpoint)} traces")

    # Save (creating the output directory first if needed)
    logger.info(f"Saving {len(checkpoint)} traces to {output_path}")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w') as out_file:
        json.dump(checkpoint, out_file, indent=2)

    _print_summary(stats, args.spot_check, checkpoint)

    # Validate the file that was just written
    validate_checkpoint(output_path, train_csv=train_csv)
1269
+
1270
+
1271
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()