File size: 13,468 Bytes
6835659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
"""
MSCI Threshold Calibration

Calibrates MSCI thresholds using ROC analysis to find optimal
classification boundaries for "coherent" vs "incoherent" samples.

Key analyses:
- ROC curve: MSCI as classifier
- AUC (Area Under Curve)
- Optimal threshold via Youden's J statistic
- Precision-Recall analysis
"""

from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
from scipy import stats


@dataclass
class CalibrationResult:
    """Container for the outcome of an MSCI threshold calibration run."""
    optimal_threshold: float
    youden_j: float
    auc: float
    sensitivity_at_optimal: float  # True positive rate at the optimal threshold
    specificity_at_optimal: float  # True negative rate at the optimal threshold
    precision_at_optimal: float
    f1_at_optimal: float
    roc_curve: Dict[str, List[float]]

    # Field names serialized by to_dict(), in output order.
    _EXPORT_FIELDS = (
        "optimal_threshold",
        "youden_j",
        "auc",
        "sensitivity_at_optimal",
        "specificity_at_optimal",
        "precision_at_optimal",
        "f1_at_optimal",
        "roc_curve",
    )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize every calibration field into a plain dictionary."""
        return {name: getattr(self, name) for name in self._EXPORT_FIELDS}


class ThresholdCalibrator:
    """
    Calibrates MSCI thresholds for coherence classification.

    Uses human judgments as the validation target to find the optimal
    MSCI threshold that maximizes discrimination between coherent
    and incoherent samples. Note: human judgments serve as the
    best available reference, not absolute ground truth.
    """

    def __init__(self, human_threshold: float = 0.6):
        """
        Initialize calibrator.

        Args:
            human_threshold: Human score at or above which a sample is
                considered "coherent" (e.g., 0.6 = 3/5 or higher on a
                Likert scale).
        """
        self.human_threshold = human_threshold

    def _binarize_human(self, human_scores: List[float]) -> List[int]:
        """Binarize human scores: 1 = coherent (>= human_threshold), else 0."""
        return [1 if h >= self.human_threshold else 0 for h in human_scores]

    @staticmethod
    def _confusion_counts(
        y_true: List[int],
        y_pred: List[int],
    ) -> Tuple[int, int, int, int]:
        """
        Compute binary confusion-matrix counts in a single pass.

        Args:
            y_true: Ground-truth binary labels (0/1)
            y_pred: Predicted binary labels (0/1)

        Returns:
            Tuple of (tp, fp, tn, fn)
        """
        tp = fp = tn = fn = 0
        for yt, yp in zip(y_true, y_pred):
            if yt == 1:
                if yp == 1:
                    tp += 1
                else:
                    fn += 1
            else:
                if yp == 1:
                    fp += 1
                else:
                    tn += 1
        return tp, fp, tn, fn

    def compute_roc_curve(
        self,
        msci_scores: List[float],
        human_scores: List[float],
        n_thresholds: int = 100,
    ) -> Tuple[List[float], List[float], List[float]]:
        """
        Compute ROC curve points.

        Args:
            msci_scores: MSCI scores (predictor)
            human_scores: Human scores (validation target, normalized 0-1)
            n_thresholds: Number of threshold points

        Returns:
            Tuple of (thresholds, tpr_list, fpr_list)
        """
        # Binarize human scores: 1 = coherent, 0 = incoherent
        y_true = self._binarize_human(human_scores)

        # Sweep thresholds slightly beyond the observed MSCI range so the
        # curve includes the degenerate (0,0) and (1,1) corner points.
        min_msci = min(msci_scores)
        max_msci = max(msci_scores)
        thresholds = np.linspace(min_msci - 0.01, max_msci + 0.01, n_thresholds)

        tpr_list: List[float] = []  # True positive rate (sensitivity)
        fpr_list: List[float] = []  # False positive rate (1 - specificity)

        for threshold in thresholds:
            # Predict: 1 if MSCI >= threshold
            y_pred = [1 if m >= threshold else 0 for m in msci_scores]
            tp, fp, tn, fn = self._confusion_counts(y_true, y_pred)

            tpr_list.append(tp / (tp + fn) if (tp + fn) > 0 else 0)
            fpr_list.append(fp / (fp + tn) if (fp + tn) > 0 else 0)

        return list(thresholds), tpr_list, fpr_list

    def compute_auc(
        self,
        fpr_list: List[float],
        tpr_list: List[float],
    ) -> float:
        """
        Compute Area Under ROC Curve using the trapezoidal rule.

        Args:
            fpr_list: False positive rates
            tpr_list: True positive rates

        Returns:
            AUC value
        """
        # Sort by FPR (ties broken by TPR) for proper left-to-right integration
        sorted_points = sorted(zip(fpr_list, tpr_list))
        sorted_fpr = [p[0] for p in sorted_points]
        sorted_tpr = [p[1] for p in sorted_points]

        # Trapezoidal integration over consecutive (FPR, TPR) points
        auc = 0.0
        for i in range(1, len(sorted_fpr)):
            auc += (sorted_fpr[i] - sorted_fpr[i - 1]) * (sorted_tpr[i] + sorted_tpr[i - 1]) / 2

        return auc

    def find_optimal_threshold(
        self,
        thresholds: List[float],
        tpr_list: List[float],
        fpr_list: List[float],
    ) -> Tuple[float, float, int]:
        """
        Find the optimal threshold using Youden's J statistic.

        J = sensitivity + specificity - 1 = TPR - FPR

        Args:
            thresholds: MSCI threshold values
            tpr_list: True positive rates
            fpr_list: False positive rates

        Returns:
            Tuple of (optimal_threshold, youden_j, optimal_index)
        """
        youden_j = [tpr - fpr for tpr, fpr in zip(tpr_list, fpr_list)]
        # argmax returns the first index on ties, i.e. the lowest such threshold
        optimal_idx = int(np.argmax(youden_j))

        return thresholds[optimal_idx], youden_j[optimal_idx], optimal_idx

    def calibrate(
        self,
        msci_scores: List[float],
        human_scores: List[float],
    ) -> CalibrationResult:
        """
        Perform full threshold calibration.

        Args:
            msci_scores: MSCI scores
            human_scores: Human coherence scores (normalized 0-1)

        Returns:
            CalibrationResult with optimal threshold and metrics

        Raises:
            ValueError: If score lists differ in length or fewer than
                10 paired samples are available.
        """
        if len(msci_scores) != len(human_scores):
            raise ValueError("Score lists must have same length")

        if len(msci_scores) < 10:
            raise ValueError("Need at least 10 samples for calibration")

        # Compute ROC curve
        thresholds, tpr_list, fpr_list = self.compute_roc_curve(
            msci_scores, human_scores
        )

        # Compute AUC
        auc = self.compute_auc(fpr_list, tpr_list)

        # Find optimal threshold
        optimal_threshold, youden_j, opt_idx = self.find_optimal_threshold(
            thresholds, tpr_list, fpr_list
        )

        # Metrics at the optimal operating point
        sensitivity = tpr_list[opt_idx]
        specificity = 1 - fpr_list[opt_idx]

        # Precision and F1 at the optimal threshold
        y_true = self._binarize_human(human_scores)
        y_pred = [1 if m >= optimal_threshold else 0 for m in msci_scores]
        tp, fp, tn, fn = self._confusion_counts(y_true, y_pred)

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = sensitivity
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        return CalibrationResult(
            optimal_threshold=optimal_threshold,
            youden_j=youden_j,
            auc=auc,
            sensitivity_at_optimal=sensitivity,
            specificity_at_optimal=specificity,
            precision_at_optimal=precision,
            f1_at_optimal=f1,
            roc_curve={
                "thresholds": thresholds,
                "tpr": tpr_list,
                "fpr": fpr_list,
            },
        )

    def calibrate_from_human_eval(
        self,
        human_eval_path: Path,
    ) -> CalibrationResult:
        """
        Calibrate from a human evaluation session file.

        Args:
            human_eval_path: Path to human evaluation session JSON

        Returns:
            CalibrationResult
        """
        from src.evaluation.human_eval_schema import EvaluationSession

        session = EvaluationSession.load(Path(human_eval_path))

        msci_scores = []
        human_scores = []

        # Build sample_id -> MSCI mapping. Compare against None explicitly:
        # a legitimate MSCI score of 0.0 is falsy and must not be dropped.
        sample_msci = {
            s.sample_id: s.msci_score
            for s in session.samples
            if s.msci_score is not None
        }

        for evaluation in session.evaluations:
            # Skip re-ratings so each sample contributes one judgment
            if evaluation.is_rerating:
                continue
            if evaluation.sample_id not in sample_msci:
                continue

            msci_scores.append(sample_msci[evaluation.sample_id])
            human_scores.append(evaluation.weighted_score())

        return self.calibrate(msci_scores, human_scores)

    def evaluate_thresholds(
        self,
        msci_scores: List[float],
        human_scores: List[float],
        thresholds: List[float],
    ) -> Dict[str, Dict[str, float]]:
        """
        Evaluate classification performance at multiple thresholds.

        Args:
            msci_scores: MSCI scores
            human_scores: Human scores
            thresholds: Thresholds to evaluate

        Returns:
            Dict mapping threshold (formatted to 3 decimals) to metrics
        """
        y_true = self._binarize_human(human_scores)
        results = {}

        for threshold in thresholds:
            y_pred = [1 if m >= threshold else 0 for m in msci_scores]
            tp, fp, tn, fn = self._confusion_counts(y_true, y_pred)

            accuracy = (tp + tn) / len(y_true) if y_true else 0
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

            results[f"{threshold:.3f}"] = {
                "accuracy": accuracy,
                "precision": precision,
                "recall": recall,
                "f1": f1,
                "true_positives": tp,
                "true_negatives": tn,
                "false_positives": fp,
                "false_negatives": fn,
            }

        return results

    def generate_report(
        self,
        calibration_result: CalibrationResult,
        output_path: Optional[Path] = None,
    ) -> Dict[str, Any]:
        """
        Generate a calibration report, optionally saving it to disk.

        Args:
            calibration_result: Result from calibrate()
            output_path: Optional path to save the report JSON

        Returns:
            Complete calibration report (always includes the full ROC curve;
            the saved file omits it to reduce size)
        """
        report = {
            "analysis_type": "MSCI Threshold Calibration",
            "purpose": "Find optimal MSCI threshold for coherence classification",
            "method": "ROC analysis with Youden's J optimization",
            "human_threshold": self.human_threshold,
            "results": calibration_result.to_dict(),
        }

        # AUC interpretation (conventional discrimination bands)
        auc = calibration_result.auc
        if auc >= 0.9:
            auc_interp = "Excellent discrimination"
        elif auc >= 0.8:
            auc_interp = "Good discrimination"
        elif auc >= 0.7:
            auc_interp = "Acceptable discrimination"
        elif auc >= 0.6:
            auc_interp = "Poor discrimination"
        else:
            auc_interp = "Failed discrimination (no better than chance)"

        report["interpretation"] = {
            "auc_interpretation": auc_interp,
            "optimal_threshold": calibration_result.optimal_threshold,
            "threshold_usage": (
                f"Samples with MSCI >= {calibration_result.optimal_threshold:.3f} "
                f"should be classified as 'coherent'"
            ),
            "expected_performance": {
                "sensitivity": f"{calibration_result.sensitivity_at_optimal:.1%} of coherent samples correctly identified",
                "specificity": f"{calibration_result.specificity_at_optimal:.1%} of incoherent samples correctly rejected",
                "precision": f"{calibration_result.precision_at_optimal:.1%} of 'coherent' predictions are correct",
            },
        }

        # Recommendations keyed on the "acceptable" AUC cutoff (0.7)
        if auc >= 0.7:
            report["recommendations"] = [
                f"Use MSCI threshold of {calibration_result.optimal_threshold:.3f} for binary classification",
                "MSCI provides meaningful discrimination between coherent and incoherent samples",
            ]
        else:
            report["recommendations"] = [
                "MSCI alone may not reliably distinguish coherent from incoherent samples",
                "Consider combining MSCI with other metrics",
                "Human evaluation may be necessary for borderline cases",
            ]

        if output_path:
            # Exclude full ROC curve from saved file to reduce size.
            # Shallow-copy both levels so the returned report is untouched.
            report_to_save = report.copy()
            if "roc_curve" in report_to_save.get("results", {}):
                report_to_save["results"] = report_to_save["results"].copy()
                del report_to_save["results"]["roc_curve"]
                report_to_save["results"]["roc_curve_note"] = "Excluded from file (100 points)"

            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            with output_path.open("w", encoding="utf-8") as f:
                json.dump(report_to_save, f, indent=2, ensure_ascii=False)

        return report