File size: 32,899 Bytes
097b6c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
"""
advanced_data_analysis.py
─────────────────────────
Extended dataset audit (builds on the earlier v2 analysis). Covers:
  A. Patient-level leakage detection
  B. Report quality scoring (6 dimensions)
  C. Image quality analysis (brightness, contrast, rotation, crop)
  D. Longitudinal / multi-study per patient analysis
  E. Vocabulary & terminology audit
  F. Section completeness audit
  G. Finding complexity distribution
  H. Report deduplication (copy-paste between studies)
  I. CheXpert 14-label frequency (negation-aware regex approximation, not CheXBERT itself)
  J. Outputs: JSON summary + per-report quality CSV (+ image-quality issues CSV when images are analysed)

Usage:
  python advanced_data_analysis.py \
      --reports_dir dataset/files \
      --images_root dataset \
      --images_glob "images_*" \
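      --metadata_csv mimic-cxr-2.0.0-metadata.csv \
      --out_dir analysis_out

  (--metadata_csv is optional; without it sections A and D are skipped.)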
"""

import argparse
import csv
import hashlib
import json
import math
import re
import statistics
from collections import Counter, defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# ─────────────────────────────────────────────────────────────────────────────
# A. PATIENT-LEVEL LEAKAGE DETECTION
# MIMIC study IDs are s#######; patient IDs are p#######.
# The metadata file (mimic-cxr-2.0.0-metadata.csv) maps them.
# Without it, the leakage (A) and longitudinal (D) checks are skipped, and
# split_by_patient() treats every study as its own patient.
# ─────────────────────────────────────────────────────────────────────────────

def load_study_to_patient(metadata_csv: Optional[Path]) -> Dict[str, str]:
    """Build a ``{study_id: patient_id}`` map from the MIMIC-CXR metadata CSV.

    Reads the ``study_id`` and ``subject_id`` columns and normalises each ID
    to the ``s<digits>`` / ``p<digits>`` form. Rows where either value is
    missing or empty are skipped. If a study appears on several rows, later
    rows overwrite earlier ones.

    Args:
        metadata_csv: Path to the metadata CSV, or ``None``.

    Returns:
        The mapping. It is empty when no path is given or the file does not
        exist; callers treat an empty map as "skip patient-level checks".
    """
    if metadata_csv is None or not metadata_csv.exists():
        return {}
    mapping: Dict[str, str] = {}
    # newline="" is required by the csv module for correct quoted-field handling.
    with metadata_csv.open(encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            # DictReader fills missing trailing fields with None, so coerce to ""
            # before .strip() instead of crashing on short/malformed rows.
            sid = (row.get("study_id") or "").strip().lstrip("s")
            pid = (row.get("subject_id") or "").strip().lstrip("p")
            if sid and pid:
                mapping[f"s{sid}"] = f"p{pid}"
    print(f"[leakage] Loaded {len(mapping)} study→patient mappings.")
    return mapping


def check_patient_leakage(
    train_ids: List[str], val_ids: List[str],
    study_to_patient: Dict[str, str],
) -> Dict:
    """
    Detect patient IDs that appear in BOTH train and val splits.

    A leaked patient means the model sees the same patient's anatomy during
    training and evaluation, which makes validation scores optimistic.

    Args:
        train_ids: Study IDs in the training split.
        val_ids: Study IDs in the validation split.
        study_to_patient: ``{study_id: patient_id}``. Studies missing from the
            map are treated as their own patient (keyed by the study ID).

    Returns:
        Distinct patient counts per split, the number of leaked patients,
        the leak rate as a percentage of val patients (0 if val is empty),
        and up to 10 sorted leaked IDs. When the map is empty, returns
        ``{"leakage_check": "skipped — no metadata CSV provided"}``.
    """
    if not study_to_patient:
        return {"leakage_check": "skipped — no metadata CSV provided"}

    train_patients = {study_to_patient.get(s, s) for s in train_ids}
    val_patients   = {study_to_patient.get(s, s) for s in val_ids}
    leaked = train_patients & val_patients

    result = {
        "train_patients": len(train_patients),
        "val_patients":   len(val_patients),
        "leaked_patients": len(leaked),
        "leak_rate_pct":  round(100 * len(leaked) / len(val_patients), 2) if val_patients else 0,
        "sample_leaked_ids": sorted(leaked)[:10],
    }
    if leaked:
        print(f"[leakage] ⚠️  {len(leaked)} patients appear in BOTH train and val!")
        print(f"  Leak rate: {result['leak_rate_pct']}% of val patients")
        print("  Fix: use split_by_patient() instead of split_by_study()")
    else:
        print("[leakage] ✓ No patient leakage detected.")
    return result


def split_by_patient(
    samples: List[Dict], study_to_patient: Dict[str, str],
    val_ratio: float = 0.02, seed: int = 42,
) -> Tuple[List[Dict], List[Dict]]:
    """
    Split samples into train/val by PATIENT rather than by study.

    All studies of one patient land in the same split, so no patient's
    anatomy appears in both. Studies missing from ``study_to_patient`` are
    treated as their own patient. Patients are sorted and then shuffled
    with ``random.Random(seed)``, so val membership is reproducible for a
    given set of patients and seed.

    ``int(n_patients * val_ratio)`` patients go to val, with a minimum of
    one whenever any patient exists. With no samples, both splits are empty.

    Returns:
        ``(train_samples, val_samples)``, each in input grouping order.
    """
    import random
    rng = random.Random(seed)
    patient_to_studies: Dict[str, List[Dict]] = defaultdict(list)
    for s in samples:
        pid = study_to_patient.get(s["study_id"], s["study_id"])
        patient_to_studies[pid].append(s)

    # Sort before shuffling so the outcome does not depend on dict order.
    patients = sorted(patient_to_studies)
    rng.shuffle(patients)
    # Clamp to the number of patients: empty input yields 0 val patients
    # (previously n_val was 1 and the log reported "-1" train patients).
    n_val = min(len(patients), max(1, int(len(patients) * val_ratio)))
    val_patients = set(patients[:n_val])

    train_all: List[Dict] = []
    val_all: List[Dict] = []
    for pid, studies in patient_to_studies.items():
        (val_all if pid in val_patients else train_all).extend(studies)

    print(f"[split_by_patient] Train: {len(train_all)} studies from {len(patients)-n_val} patients")
    print(f"[split_by_patient] Val:   {len(val_all)} studies from {n_val} patients")
    return train_all, val_all


# ─────────────────────────────────────────────────────────────────────────────
# B. REPORT QUALITY SCORING (6 dimensions, per report)
# ─────────────────────────────────────────────────────────────────────────────

NEGATION_RE = re.compile(
    r"\b(no|without|absent|not|free of|no evidence|no definite)\b", re.IGNORECASE)

def _is_negated(text: str, start: int, window: int = 8) -> bool:
    words = text[:start].split()
    ctx = " ".join(words[-window:])
    if NEGATION_RE.search(ctx): return True
    s = max(0, text.rfind(".", 0, start) + 1)
    return bool(re.search(r"\bno\b", text[s:start]))

def repetition_rate(text: str, n: int = 5) -> float:
    """Return the fraction of word ``n``-grams that duplicate an earlier one.

    Tokens are the lower-cased, whitespace-split words of ``text``. An n-gram
    that occurs ``k`` times contributes ``k - 1`` repeats. Returns 0.0 when
    there are fewer than ``n`` tokens.
    """
    tokens = text.lower().split()
    if len(tokens) < n:
        return 0.0
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    c = Counter(ngrams)
    return sum(v - 1 for v in c.values() if v > 1) / len(ngrams)

def score_report(report_text: str) -> Dict:
    """
    Score one report on 6 heuristic dimensions (each 0–1, higher = better).

    1. completeness      — 0.5 each for the words "finding(s)" and "impression"
    2. specificity       — fraction of 4 anatomic/clinical term groups present
    3. reasoning         — reasoning phrases ("consistent with", "concerning
                           for", ...); 2+ of the 3 phrase groups = full score
    4. no_repetition     — 1 - 5 * repetition_rate (0 once the rate is >= 0.2)
    5. length_ok         — 1.0 for 150–800 chars, linear penalty outside
    6. no_deid_artifacts — 1 - 0.1 per "___" placeholder (floored at 0)

    ``overall_quality`` is the weighted sum with weights 0.25, 0.20, 0.20,
    0.20, 0.10 and 0.05. Raw ``char_len``, ``repetition_rate`` and
    ``deid_count`` are returned alongside the rounded scores.
    """
    t = report_text.strip()
    tl = t.lower()

    # 1. completeness (keyword presence, not a real section parser)
    has_findings   = bool(re.search(r"\bfindings?\b", tl))
    has_impression = bool(re.search(r"\bimpression\b", tl))
    completeness   = (0.5 * has_findings + 0.5 * has_impression)

    # 2. specificity — presence of specific anatomic/clinical vocabulary.
    # Every alternative must be lower-case because it is matched against `tl`;
    # upper-case acronyms ("SVC", "PICC") could never match lower-cased text.
    specific_terms = [
        r"\b(cardiomediastinal|costophrenic|hemidiaphragm|parenchymal)\b",
        r"\b(consolidation|atelectasis|pneumothorax|effusion|edema)\b",
        r"\b(basilar|bibasilar|perihilar|retrocardiac|paratracheal)\b",
        r"\b(svc|ivc|picc|carina|clavicle|sternotomy)\b",
    ]
    hits = sum(bool(re.search(p, tl)) for p in specific_terms)
    specificity = min(hits / len(specific_terms), 1.0)

    # 3. reasoning language
    reasoning_terms = [
        r"\b(consistent with|compatible with|concerning for|suggestive of)\b",
        r"\b(may represent|likely represents?|cannot exclude|in the setting of)\b",
        r"\b(due to|related to|secondary to|in the appropriate clinical)\b",
    ]
    reason_hits = sum(bool(re.search(p, tl)) for p in reasoning_terms)
    reasoning = min(reason_hits / 2, 1.0)  # 2+ groups = full score

    # 4. no repetition
    rr = repetition_rate(t)
    no_repetition = max(0.0, 1.0 - rr * 5)  # rr >= 0.2 → score 0

    # 5. length in healthy range
    char_len = len(t)
    if 150 <= char_len <= 800:
        length_ok = 1.0
    elif char_len < 150:
        length_ok = char_len / 150
    else:
        length_ok = max(0.0, 1.0 - (char_len - 800) / 1400)

    # 6. no de-identification artifacts ("___" placeholders)
    deid_count = len(re.findall(r"\b___\b", t))
    no_deid = max(0.0, 1.0 - deid_count * 0.1)

    overall = (completeness * 0.25 + specificity * 0.20 + reasoning * 0.20
               + no_repetition * 0.20 + length_ok * 0.10 + no_deid * 0.05)

    return {
        "completeness":      round(completeness,   3),
        "specificity":       round(specificity,    3),
        "reasoning":         round(reasoning,      3),
        "no_repetition":     round(no_repetition,  3),
        "length_ok":         round(length_ok,      3),
        "no_deid_artifacts": round(no_deid,        3),
        "overall_quality":   round(overall,        3),
        "char_len":          char_len,
        "repetition_rate":   round(rr,             4),
        "deid_count":        deid_count,
    }


# ─────────────────────────────────────────────────────────────────────────────
# C. IMAGE QUALITY ANALYSIS
# ─────────────────────────────────────────────────────────────────────────────

def analyze_image_quality(image_path: str) -> Dict:
    """
    Compute simple quality metrics and flags for one image file.

    Flags (any subset may fire):
      underexposed / overexposed   mean grey level < 30 / > 225
      low_contrast                 grey-level std dev < 20
      too_small                    shorter side < 224 px
      extreme_aspect_ratio         width / height outside [0.5, 2.5]
      possible_rotation            vertical-edge energy > 3x horizontal-edge energy

    Returns the raw metrics plus ``flags`` and ``ok`` (True when no flag
    fired). Any exception is caught rather than raised, including PIL or
    numpy not being importable. In that case the result is
    ``{"flags": ["unreadable"], "ok": False, "error": <message>}``.
    """
    try:
        from PIL import Image as PILImage
        import numpy as np

        with PILImage.open(image_path) as source:
            pixels = np.array(source.convert("L"), dtype=np.float32)

        height, width = pixels.shape
        brightness = float(pixels.mean())
        spread = float(pixels.std())
        aspect = width / height if height > 0 else 1.0

        rules = [
            ("underexposed", brightness < 30),
            ("overexposed", brightness > 225),
            ("low_contrast", spread < 20),
            ("too_small", min(height, width) < 224),
            ("extreme_aspect_ratio", not 0.5 <= aspect <= 2.5),
        ]
        flags = [label for label, fired in rules if fired]

        # Differences along axis 0 respond to horizontal edges, along axis 1
        # to vertical edges. Heuristic (not validated here): a large
        # vertical/horizontal ratio is taken as a hint of a rotated image.
        horizontal = np.abs(np.diff(pixels, axis=0)).mean()
        vertical = np.abs(np.diff(pixels, axis=1)).mean()
        edge_ratio = float(vertical / (horizontal + 1e-8))
        if edge_ratio > 3.0:
            flags.append("possible_rotation")

        return {
            "mean_px": round(brightness, 1),
            "std_px": round(spread, 1),
            "width": width,
            "height": height,
            "aspect": round(aspect, 3),
            "edge_ratio": round(edge_ratio, 3),
            "flags": flags,
            "ok": not flags,
        }
    except Exception as err:
        return {"flags": ["unreadable"], "ok": False, "error": str(err)}


# ─────────────────────────────────────────────────────────────────────────────
# D. LONGITUDINAL ANALYSIS (multiple studies per patient)
# ─────────────────────────────────────────────────────────────────────────────

def analyze_longitudinal(
    samples: List[Dict], study_to_patient: Dict[str, str]
) -> Dict:
    """
    Summarise how many studies each patient contributes.

    Patients are bucketed by study count into "1", "2-5", "6-10" and "10+".
    The "10+" bucket (key kept for compatibility) means MORE than 10.
    Multi-study patients risk the model memorising patient anatomy rather
    than learning to read X-rays in general. Studies missing from
    ``study_to_patient`` count as their own patient.

    Returns:
        Unique patient count, the bucket distribution, the number of
        patients with >10 studies, the number of studies they contribute,
        and that number as a percentage of all samples (0.0 for no samples).
        Returns ``{"status": "skipped — no metadata"}`` when the map is empty.
    """
    if not study_to_patient:
        return {"status": "skipped — no metadata"}

    patient_counts: Counter = Counter(
        study_to_patient.get(s["study_id"], s["study_id"]) for s in samples)

    count_dist: Counter = Counter()
    for n in patient_counts.values():
        if n == 1:
            count_dist["1"] += 1
        elif n <= 5:
            count_dist["2-5"] += 1
        elif n <= 10:
            count_dist["6-10"] += 1
        else:
            count_dist["10+"] += 1

    heavy_patients = {pid: n for pid, n in patient_counts.items() if n > 10}
    studies_from_heavy = sum(heavy_patients.values())
    # Guard: an empty sample list previously raised ZeroDivisionError here.
    heavy_pct = 100 * studies_from_heavy / len(samples) if samples else 0.0

    print(f"\n[longitudinal] {len(patient_counts)} unique patients")
    print(f"  Studies per patient distribution: {dict(count_dist)}")
    print(f"  Patients with >10 studies: {len(heavy_patients)} "
          f"({studies_from_heavy} studies = {heavy_pct:.1f}% of data)")

    return {
        "unique_patients": len(patient_counts),
        "studies_per_patient_dist": dict(count_dist),
        "heavy_patients_count": len(heavy_patients),
        "studies_from_heavy_patients": studies_from_heavy,
        "heavy_patient_pct_of_data": round(heavy_pct, 2),
    }


# ─────────────────────────────────────────────────────────────────────────────
# E. VOCABULARY & TERMINOLOGY AUDIT
# ─────────────────────────────────────────────────────────────────────────────

# Radiology abbreviations worth expanding for the model, keyed by a
# case-sensitive whole-word regex. Consumed by expand_abbreviations() for
# rewriting and by vocabulary_audit() for frequency counting.
ABBREV_MAP = {
    r"\bAP\b":   "anteroposterior",
    r"\bPA\b":   "posteroanterior",
    r"\bLAT\b":  "lateral",
    r"\bSVC\b":  "superior vena cava",
    r"\bIVC\b":  "inferior vena cava",
    r"\bLLL\b":  "left lower lobe",
    r"\bRLL\b":  "right lower lobe",
    r"\bLUL\b":  "left upper lobe",
    r"\bRUL\b":  "right upper lobe",
    # NOTE(review): the left lung has no middle lobe (the lingula belongs to
    # the left upper lobe); confirm whether this entry is intended.
    r"\bLML\b":  "left middle lobe",
    r"\bRML\b":  "right middle lobe",
    r"\bCXR\b":  "chest X-ray",
    r"\bSOB\b":  "shortness of breath",
    r"\bCHF\b":  "congestive heart failure",
    r"\bCOPD\b": "chronic obstructive pulmonary disease",
    r"\bET\b":   "endotracheal",
    r"\bNG\b":   "nasogastric",
    r"\bPICC\b": "peripherally inserted central catheter",
    r"\bICD\b":  "implantable cardioverter-defibrillator",
    r"\bCABG\b": "coronary artery bypass graft",
}

def expand_abbreviations(text: str) -> str:
    """Return ``text`` with each ABBREV_MAP abbreviation replaced by its expansion.

    Substitutions run in ABBREV_MAP insertion order, each applied to the
    output of the previous one. Matching is case-sensitive, so lower-case
    words such as "ap" are left untouched.
    """
    expanded = text
    for abbrev_re in ABBREV_MAP:
        expanded = re.sub(abbrev_re, ABBREV_MAP[abbrev_re], expanded)
    return expanded

def vocabulary_audit(reports: List[str], top_n: int = 50) -> Dict:
    """
    Analyse the vocabulary used across all reports.

    Returned keys:
      unique_tokens        distinct lower-cased alphabetic tokens (3+ letters)
      top_tokens           the ``top_n`` most common such tokens, with counts
      abbreviation_freq    total case-sensitive matches per ABBREV_MAP pattern
                           (top 20)
      hedge_freq           for each hedge term, the number of reports
                           containing it (top 20)
      hedge_total_reports  number of reports containing at least one hedge term
      hedge_pct            that number as a percentage of all reports
                           (0.0 for an empty list)

    A report with several hedge terms is counted once in
    ``hedge_total_reports``, so ``hedge_pct`` never exceeds 100. The old
    code summed the per-term counts and could exceed 100%.
    """
    HEDGE_TERMS = [
        "may", "might", "could", "possible", "possibly", "probable", "probably",
        "likely", "unlikely", "cannot exclude", "suggest", "suspect",
        "questionable", "uncertain", "unclear", "limited",
    ]
    # Compile once rather than once per report per term.
    hedge_res = [(h, re.compile(rf"\b{h}\b", re.IGNORECASE)) for h in HEDGE_TERMS]

    all_tokens: Counter = Counter()
    abbrev_hits: Counter = Counter()
    hedge_hits: Counter = Counter()
    hedged_reports = 0

    for report in reports:
        all_tokens.update(re.findall(r"[a-zA-Z]{3,}", report.lower()))
        for abbrev_re in ABBREV_MAP:
            hits = len(re.findall(abbrev_re, report))
            if hits:
                abbrev_hits[abbrev_re] += hits
        found = [h for h, rx in hedge_res if rx.search(report)]
        hedge_hits.update(found)  # +1 per term present in this report
        if found:
            hedged_reports += 1

    n_reports = len(reports)
    hedge_pct = 100 * hedged_reports / n_reports if n_reports else 0.0

    print(f"\n[vocabulary] Unique tokens: {len(all_tokens)}")
    print(f"  Top-10 most common: {all_tokens.most_common(10)}")
    print(f"  Abbreviation usage (top 5): {abbrev_hits.most_common(5)}")
    print(f"  Hedge term usage (top 5):   {hedge_hits.most_common(5)}")
    print(f"  Reports with any hedge term: {hedged_reports} ({hedge_pct:.1f}%)")

    return {
        "unique_tokens":    len(all_tokens),
        "top_tokens":       all_tokens.most_common(top_n),
        "abbreviation_freq":dict(abbrev_hits.most_common(20)),
        "hedge_freq":       dict(hedge_hits.most_common(20)),
        "hedge_total_reports": hedged_reports,
        "hedge_pct": round(hedge_pct, 1),
    }


# ─────────────────────────────────────────────────────────────────────────────
# F. SECTION COMPLETENESS AUDIT
# ─────────────────────────────────────────────────────────────────────────────

def section_completeness_audit(samples: List[Dict]) -> Dict:
    """
    Categorise reports by which sections they mention.

    Each report falls in exactly one bucket: "both", "findings_only",
    "impression_only" or "neither". The bucket depends on whether the
    lower-cased text contains the word "finding(s)" and/or "impression"
    anywhere; this is a keyword heuristic, not a section parser. Models
    trained only on impression-only reports learn a different task than
    models trained on findings+impression reports.

    Returns:
        Bucket counts, up to 20 impression-only study IDs, and the
        percentages of "both" and "impression_only" (0.0 for empty input).
    """
    findings_re = re.compile(r"\bfindings?\b")
    impression_re = re.compile(r"\bimpression\b")

    cats: Counter = Counter()
    impression_only_ids: List[str] = []
    for s in samples:
        t = s["report_text"].lower()
        has_f = bool(findings_re.search(t))
        has_i = bool(impression_re.search(t))
        if has_f and has_i:
            cats["both"] += 1
        elif has_f:
            cats["findings_only"] += 1
        elif has_i:
            cats["impression_only"] += 1
            impression_only_ids.append(s["study_id"])
        else:
            cats["neither"] += 1

    total = len(samples)

    def pct(count: int) -> float:
        # Guard: an empty sample list previously raised ZeroDivisionError.
        return 100 * count / total if total else 0.0

    print("\n[sections] Report section breakdown:")
    for k, v in cats.most_common():
        print(f"  {k:<20} {v:>6} ({pct(v):.1f}%)")

    return {
        "section_counts": dict(cats),
        "impression_only_sample_ids": impression_only_ids[:20],
        "pct_both": round(pct(cats["both"]), 1),
        "pct_impression_only": round(pct(cats["impression_only"]), 1),
    }


# ─────────────────────────────────────────────────────────────────────────────
# G. FINDING COMPLEXITY DISTRIBUTION
# How many distinct positive findings does each report describe?
# ─────────────────────────────────────────────────────────────────────────────

# Finding category -> regex. count_positive_findings() runs each pattern over
# the lower-cased report with re.IGNORECASE, so the upper-case tokens in
# "device_any" (PICC, ET, NG) still match. A category counts at most once
# per report, and only if one of its mentions is not negated.
POSITIVE_FINDING_PATS = {
    "pneumothorax":    r"\bpneumothorax\b",
    "pleural_effusion":r"\bpleural effusion\b",
    "consolidation":   r"\bconsolidation\b",
    "pneumonia":       r"\bpneumonia\b",
    "atelectasis":     r"\batelectasis\b",
    "pulmonary_edema": r"\bpulmonary edema\b",
    "cardiomegaly":    r"\bcardiomegaly\b",
    "mass_nodule":     r"\b(pulmonary )?(mass|nodule|nodular)\b",  # also matches "nodular"
    "fracture":        r"\bfracture\b",
    "mediastinal_wide":r"\bmediastinal widening\b",
    "emphysema":       r"\bemphysema\b",
    "hiatal_hernia":   r"\bhiatal hernia\b",
    # Any support device / post-surgical marker; all lumped into one category.
    "device_any":      r"\b(PICC|endotracheal|ET tube|chest tube|pacemaker|sternotomy|NG tube)\b",
}

def count_positive_findings(report_text: str) -> int:
    """Count POSITIVE_FINDING_PATS categories with a non-negated mention.

    Each category contributes at most 1, however often it is mentioned.
    Negation is judged by _is_negated() on the lower-cased report.
    """
    lowered = report_text.lower()

    def _has_positive_mention(pattern: str) -> bool:
        # Stops at the first mention that is not negated.
        return any(
            not _is_negated(lowered, match.start())
            for match in re.finditer(pattern, lowered, re.IGNORECASE)
        )

    return sum(_has_positive_mention(p) for p in POSITIVE_FINDING_PATS.values())

def finding_complexity_distribution(samples: List[Dict]) -> Dict:
    """
    Histogram of positive finding categories per report.

    Each report is scored with count_positive_findings() and placed in
    bucket "0".."5" or "6+". Prints a text bar chart (30 chars = 100%) and
    returns ``{"complexity_distribution": {bucket: count}}``.
    """
    complexity: Counter = Counter()
    for s in samples:
        n = count_positive_findings(s["report_text"])
        complexity[str(n) if n <= 5 else "6+"] += 1

    total = len(samples)
    print("\n[complexity] Positive findings per report:")
    # Numeric ordering; "6+" sorts as 6. The loop body only runs when
    # complexity is non-empty, so total > 0 inside it.
    for k in sorted(complexity, key=lambda x: int(x.rstrip("+"))):
        v = complexity[k]
        bar = "█" * int(30 * v / total)
        print(f"  {k} findings: {v:>6} ({100*v/total:5.1f}%)  {bar}")

    return {"complexity_distribution": dict(complexity)}


# ─────────────────────────────────────────────────────────────────────────────
# H. REPORT DEDUPLICATION
# Detect near-identical reports (boilerplate templates or copy-paste)
# ─────────────────────────────────────────────────────────────────────────────

def fingerprint_report(text: str, length: int = 8) -> str:
    """Return a short hex fingerprint of ``text`` for duplicate detection.

    Normalisation steps, in order:
      1. lower-case the text;
      2. drop "___" de-identification placeholders;
      3. collapse whitespace runs to one space and strip.

    Placeholders are removed BEFORE whitespace is collapsed, so "a ___ b"
    and "a b" share a fingerprint. The old order left a double space behind.

    The result is the first ``length`` hex chars of the MD5 digest, used
    only for grouping, not for security. The default 8 chars (32 bits) can
    collide occasionally on large corpora; pass a larger ``length`` to
    reduce that.
    """
    normalised = re.sub(r"\b___\b", "", text.lower())
    normalised = re.sub(r"\s+", " ", normalised).strip()
    return hashlib.md5(normalised.encode()).hexdigest()[:length]

def find_duplicate_reports(samples: List[Dict], threshold: int = 3) -> Dict:
    """
    Group reports by fingerprint_report() and flag boilerplate templates.

    A fingerprint shared by ``threshold`` or more studies is a boilerplate
    group; such templates inflate the majority class. Prints a summary plus
    an 80-char preview of the three largest groups.

    Returns:
        Group counts, the number and percentage of studies in boilerplate
        groups (0.0 for empty input), and up to 10 largest groups with up
        to 5 sample study IDs each.
    """
    fp_groups: Dict[str, List[str]] = defaultdict(list)
    first_text: Dict[str, str] = {}
    for s in samples:
        fp = fingerprint_report(s["report_text"])
        fp_groups[fp].append(s["study_id"])
        # Keep one example text per group for the preview (avoids an O(n)
        # rescan of `samples` per printed group).
        first_text.setdefault(fp, s["report_text"])

    boilerplate = {fp: ids for fp, ids in fp_groups.items() if len(ids) >= threshold}
    # Largest groups first; sort is stable, so ties keep first-seen order.
    ranked = sorted(boilerplate.items(), key=lambda x: -len(x[1]))
    boilerplate_studies = sum(len(v) for v in boilerplate.values())
    unique_reports = sum(1 for v in fp_groups.values() if len(v) == 1)
    total = len(samples)

    def pct(count: int) -> float:
        # Guard: an empty sample list previously raised ZeroDivisionError.
        return 100 * count / total if total else 0.0

    print(f"\n[dedup] Fingerprint groups: {len(fp_groups)}")
    print(f"  Unique reports (appear once): {unique_reports} ({pct(unique_reports):.1f}%)")
    print(f"  Boilerplate groups (≥{threshold} copies): {len(boilerplate)}")
    print(f"  Studies using boilerplate: {boilerplate_studies} ({pct(boilerplate_studies):.1f}%)")
    for fp, ids in ranked[:3]:
        print(f"    fp={fp}: {len(ids)} copies — '{first_text[fp][:80]}...'")

    return {
        "total_fingerprint_groups": len(fp_groups),
        "unique_reports_count": unique_reports,
        "boilerplate_groups": len(boilerplate),
        "boilerplate_studies": boilerplate_studies,
        "boilerplate_pct": round(pct(boilerplate_studies), 2),
        "top_boilerplate_fps": [
            {"fp": fp, "count": len(ids), "sample_ids": ids[:5]}
            for fp, ids in ranked[:10]
        ],
    }


# ─────────────────────────────────────────────────────────────────────────────
# I. CHEXPERT 14-LABEL APPROXIMATION
# Approximates the 14 CheXpert labels via negation-aware regex.
# For production, use the actual CheXBERT model on HuggingFace.
# ─────────────────────────────────────────────────────────────────────────────

# Label -> (regex, needs_pos_check). chexpert_label_frequencies() runs each
# regex over the lower-cased report with re.IGNORECASE. When needs_pos_check
# is True, a report counts only if at least one match is not negated
# according to _is_negated(). "No Finding" uses False because its pattern is
# itself a negative phrase ("no acute ... process").
# NOTE(review): "No Finding" is counted independently of the other labels
# here, so a report can be both "No Finding" and positive for another label.
# "Edema" matches any "edema" mention, not only pulmonary edema.
CHEXPERT_14 = {
    "No Finding":               (r"\bno (acute|active|significant|focal) (cardiopulmonary|intrathoracic|pulmonary)?\s*(process|finding|abnormality|disease)\b", False),
    "Enlarged Cardiomediastinum":(r"\b(enlarged|widened|widening) (cardiomediastinum|mediastinum)\b", True),
    "Cardiomegaly":             (r"\bcardiomegaly\b", True),
    "Lung Opacity":             (r"\b(focal |patchy |basilar |lobar )?opacit(y|ies)\b", True),
    "Lung Lesion":              (r"\b(pulmonary )?(mass|nodule|lesion)\b", True),
    "Edema":                    (r"\b(pulmonary )?edema\b", True),
    "Consolidation":            (r"\bconsolidation\b", True),
    "Pneumonia":                (r"\bpneumonia\b", True),
    "Atelectasis":              (r"\batelectasis\b", True),
    "Pneumothorax":             (r"\bpneumothorax\b", True),
    "Pleural Effusion":         (r"\bpleural effusion\b", True),
    "Pleural Other":            (r"\b(pleural (thickening|plaque|scarring|calcification))\b", True),
    "Fracture":                 (r"\bfracture\b", True),
    "Support Devices":          (r"\b(PICC|endotracheal|ET tube|chest tube|pacemaker|NG tube|central (venous )?catheter|sternotomy|Port-?A-?Cath)\b", True),
}

def chexpert_label_frequencies(samples: List[Dict]) -> Dict:
    """
    Count how many reports are positive for each CHEXPERT_14 label.

    For labels with ``needs_pos_check=True``, a report counts only if at
    least one match is not negated (per _is_negated). For the others, any
    match counts. Each report contributes at most 1 per label.

    Returns:
        ``{"chexpert14_counts": {...}, "chexpert14_prevalence": {...}}``.
        Prevalence is a percentage of all samples. Labels with zero hits
        are omitted from both dicts, and empty input yields empty dicts.
    """
    label_counts: Counter = Counter()
    total = len(samples)
    for s in samples:
        tl = s["report_text"].lower()
        for label, (pattern, needs_pos_check) in CHEXPERT_14.items():
            if needs_pos_check:
                hit = any(not _is_negated(tl, m.start())
                          for m in re.finditer(pattern, tl, re.IGNORECASE))
            else:
                hit = re.search(pattern, tl, re.IGNORECASE) is not None
            if hit:
                label_counts[label] += 1

    def pct(count: int) -> float:
        # Guard: the table below iterates every label even with no samples,
        # which previously raised ZeroDivisionError.
        return 100 * count / total if total else 0.0

    print("\n[chexpert14] Label frequencies (negation-aware):")
    print(f"  {'Label':<30} {'Count':>7}  {'Prevalence':>10}")
    print(f"  {'-'*52}")
    for label in CHEXPERT_14:
        c = label_counts.get(label, 0)
        print(f"  {label:<30} {c:>7}  ({pct(c):6.2f}%)")

    return {
        "chexpert14_counts": dict(label_counts),
        "chexpert14_prevalence": {k: round(pct(v), 2)
                                   for k, v in label_counts.items()},
    }


# ─────────────────────────────────────────────────────────────────────────────
# MAIN RUNNER
# ─────────────────────────────────────────────────────────────────────────────

def load_samples(reports_dir: Path, min_chars: int = 40) -> List[Dict]:
    """
    Load every ``*.txt`` report located directly in ``reports_dir``.

    The glob is not recursive. Files are read as UTF-8, silently dropping
    undecodable bytes, and whitespace-stripped. Reports shorter than
    ``min_chars`` characters are skipped. Each sample is
    ``{"study_id": <file stem>, "report_text": <text>, "image_paths": []}``,
    returned in sorted path order.
    """
    texts = (
        (path.stem, path.read_text(encoding="utf-8", errors="ignore").strip())
        for path in sorted(reports_dir.glob("*.txt"))
    )
    return [
        {"study_id": stem, "report_text": body, "image_paths": []}
        for stem, body in texts
        if len(body) >= min_chars
    ]


def main():
    """
    Command-line entry point: load reports, run sections A–I, write outputs.

    Files written to --out_dir:
      report_quality_scores.csv       per-report quality scores (section B)
      image_quality_issues.csv        flagged images (section C, when run)
      advanced_analysis_summary.json  combined summary of every section

    Exits with a non-zero status (SystemExit) if no report passes the
    --min_chars filter, because every section needs at least one report.
    """
    import random

    parser = argparse.ArgumentParser()
    parser.add_argument("--reports_dir",  type=str, required=True)
    parser.add_argument("--images_root",  type=str, default="")
    parser.add_argument("--images_glob",  type=str, default="images_*")
    parser.add_argument("--metadata_csv", type=str, default="",
                        help="Path to mimic-cxr-2.0.0-metadata.csv (optional)")
    parser.add_argument("--out_dir",      type=str, default="analysis_out")
    parser.add_argument("--min_chars",    type=int, default=40)
    parser.add_argument("--skip_images",  action="store_true",
                        help="Skip image quality analysis (fast mode)")
    parser.add_argument("--quality_threshold", type=float, default=0.5,
                        help="Reports below this overall_quality score are flagged.")
    parser.add_argument("--max_images",   type=int, default=2000,
                        help="Maximum number of images sampled for image quality analysis.")
    parser.add_argument("--seed",         type=int, default=42,
                        help="RNG seed for the image quality sample (reproducible runs).")
    args = parser.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    reports_dir  = Path(args.reports_dir)
    metadata_csv = Path(args.metadata_csv) if args.metadata_csv else None

    print(f"Loading reports from: {reports_dir}")
    samples = load_samples(reports_dir, args.min_chars)
    print(f"Loaded {len(samples)} reports.")
    if not samples:
        # The CSV header below is taken from the first row and
        # statistics.mean() rejects empty input, so stop with a clear message
        # instead of an IndexError traceback.
        raise SystemExit(
            f"No .txt reports with >= {args.min_chars} chars found in {reports_dir}.")

    study_to_patient = load_study_to_patient(metadata_csv)

    summary = {"total_reports": len(samples)}

    # B. Per-report quality scores → CSV
    print("\n[B] Scoring report quality...")
    quality_rows = []
    score_dist = Counter()
    low_quality_ids = []
    for s in samples:
        q = score_report(s["report_text"])
        q["study_id"] = s["study_id"]
        quality_rows.append(q)
        # Bucket by the lower 0.1 boundary, e.g. 0.57 -> "0.5".
        bucket = f"{math.floor(q['overall_quality']*10)/10:.1f}"
        score_dist[bucket] += 1
        if q["overall_quality"] < args.quality_threshold:
            low_quality_ids.append(s["study_id"])

    quality_csv = out_dir / "report_quality_scores.csv"
    with quality_csv.open("w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(quality_rows[0].keys()))
        writer.writeheader()
        writer.writerows(quality_rows)
    print(f"  Quality CSV: {quality_csv}")
    print(f"  Score distribution: {dict(sorted(score_dist.items()))}")
    print(f"  Reports below quality {args.quality_threshold}: {len(low_quality_ids)}")
    summary["quality"] = {
        "score_distribution": dict(score_dist),
        "low_quality_count": len(low_quality_ids),
        "low_quality_sample_ids": low_quality_ids[:20],
        "mean_quality": round(statistics.mean(r["overall_quality"] for r in quality_rows), 3),
    }

    # C. Image quality on a seeded random sample of up to --max_images images
    if not args.skip_images and args.images_root:
        print("\n[C] Analysing image quality (sample)...")
        images_root = Path(args.images_root)
        all_images = []
        for glob_dir in sorted(images_root.glob(args.images_glob)):
            all_images.extend(glob_dir.rglob("*.jpg"))
            all_images.extend(glob_dir.rglob("*.png"))
        # rglob order depends on the filesystem: sort first, then shuffle with
        # a seeded RNG so the same images are sampled on every run.
        all_images.sort()
        random.Random(args.seed).shuffle(all_images)
        sample_imgs = all_images[:args.max_images]
        img_flags: Counter = Counter()
        bad_images = []
        for img_path in sample_imgs:
            result = analyze_image_quality(str(img_path))
            for flag in result.get("flags", []):
                img_flags[flag] += 1
            if not result.get("ok", True):
                bad_images.append({"path": str(img_path), "flags": result["flags"]})
        print(f"  Analysed {len(sample_imgs)} images.")
        print(f"  Flag counts: {dict(img_flags)}")
        img_issues_csv = out_dir / "image_quality_issues.csv"
        with img_issues_csv.open("w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=["path", "flags"])
            writer.writeheader()
            writer.writerows([{"path": r["path"], "flags": "|".join(r["flags"])}
                              for r in bad_images])
        summary["image_quality"] = {
            "sampled": len(sample_imgs),
            "flag_counts": dict(img_flags),
            "bad_image_count": len(bad_images),
        }

    # D. Longitudinal
    print("\n[D] Longitudinal analysis...")
    summary["longitudinal"] = analyze_longitudinal(samples, study_to_patient)

    # E. Vocabulary
    print("\n[E] Vocabulary audit...")
    summary["vocabulary"] = vocabulary_audit([s["report_text"] for s in samples])

    # F. Section completeness
    print("\n[F] Section completeness...")
    summary["section_completeness"] = section_completeness_audit(samples)

    # G. Finding complexity
    print("\n[G] Finding complexity...")
    summary["finding_complexity"] = finding_complexity_distribution(samples)

    # H. Deduplication
    print("\n[H] Deduplication...")
    summary["deduplication"] = find_duplicate_reports(samples)

    # I. CheXpert 14 labels
    print("\n[I] CheXpert 14-label frequencies...")
    summary["chexpert14"] = chexpert_label_frequencies(samples)

    # A. Leakage (needs metadata). Simulates a naive split by study: the first
    # 2% of samples (sorted by report path in load_samples) act as val.
    if study_to_patient:
        print("\n[A] Patient leakage check...")
        n = len(samples)
        n_val = int(n * 0.02)
        val_ids   = [s["study_id"] for s in samples[:n_val]]
        train_ids = [s["study_id"] for s in samples[n_val:]]
        summary["leakage"] = check_patient_leakage(train_ids, val_ids, study_to_patient)

    # Save full JSON summary (json.dumps escapes non-ASCII by default).
    summary_path = out_dir / "advanced_analysis_summary.json"
    summary_path.write_text(json.dumps(summary, indent=2, default=str), encoding="utf-8")
    print(f"\n{'='*60}")
    print(f"Full summary → {summary_path}")
    print(f"Quality CSV  → {quality_csv}")
    print(f"{'='*60}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()