| """ |
| Token threshold calibration analysis. |
| |
| Collects evidence quality metrics paired with HHEM scores to find |
| the optimal MIN_EVIDENCE_TOKENS threshold. |
| |
| Usage: |
| python scripts/calibrate_token_threshold.py --samples 200 |
| python scripts/calibrate_token_threshold.py --analyze-only # Re-analyze existing data |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from dataclasses import asdict, dataclass |
| from pathlib import Path |
| from typing import TYPE_CHECKING |
|
|
| import numpy as np |
|
|
| from sage.config import ( |
| CHARS_PER_TOKEN, |
| EVALUATION_QUERIES, |
| MIN_EVIDENCE_CHUNKS, |
| MIN_EVIDENCE_TOKENS, |
| get_logger, |
| log_banner, |
| log_section, |
| ) |
| from sage.core import AggregationMethod, ProductScore |
| from sage.core.evidence import check_evidence_quality |
| from sage.services.retrieval import get_candidates |
|
|
| if TYPE_CHECKING: |
| from sage.adapters.hhem import HallucinationDetector |
| from sage.services.explanation import Explainer |
|
|
| logger = get_logger(__name__) |
|
|
| DEFAULT_OUTPUT = Path("data/calibration/token_threshold.json") |
|
|
|
|
@dataclass
class CalibrationRecord:
    """Single data point for threshold calibration.

    Serialized to/from JSON via dataclasses.asdict, so field names must
    stay stable across collection and re-analysis runs.
    """

    # Sample identity: the evaluation query and the product it matched.
    query: str
    product_id: str

    # Evidence quality metrics (token counts estimated via CHARS_PER_TOKEN).
    chunk_count: int
    total_tokens: int
    min_chunk_tokens: int
    max_chunk_tokens: int
    top_score: float
    avg_rating: float

    # Outcomes: what the quality gate WOULD have decided, and the HHEM verdict.
    gate_passed: bool
    hhem_score: float | None  # None when no HHEM score was produced
    is_hallucinated: bool | None
|
|
|
|
def estimate_tokens(text: str) -> int:
    """Estimate token count from character count.

    Uses the project-wide CHARS_PER_TOKEN heuristic with floor division,
    so very short strings may estimate to zero tokens.
    """
    n_chars = len(text)
    return n_chars // CHARS_PER_TOKEN
|
|
|
|
def collect_evidence_metrics(product: ProductScore) -> dict:
    """Extract detailed evidence metrics for a product.

    Returns a dict with chunk_count, total_tokens, min_chunk_tokens and
    max_chunk_tokens; all zeros when the product has no evidence chunks.
    """
    evidence = product.evidence
    if not evidence:
        return {
            "chunk_count": 0,
            "total_tokens": 0,
            "min_chunk_tokens": 0,
            "max_chunk_tokens": 0,
        }

    sizes = [estimate_tokens(chunk.text) for chunk in evidence]
    return {
        "chunk_count": len(sizes),
        "total_tokens": sum(sizes),
        "min_chunk_tokens": min(sizes),
        "max_chunk_tokens": max(sizes),
    }
|
|
|
|
def collect_calibration_data(
    explainer: Explainer,
    detector: HallucinationDetector,
    max_samples: int = 200,
    output_path: Path | None = None,
) -> list[CalibrationRecord]:
    """
    Collect calibration data by running pipeline with quality gate disabled.

    Key insight: We disable the gate to collect data on explanations that
    WOULD have been refused, so we can see their actual HHEM scores.

    Args:
        explainer: Explanation generator; the quality gate is bypassed per call.
        detector: HHEM-based hallucination detector.
        max_samples: Stop after collecting this many records.
        output_path: If given, records are also dumped to this JSON file.

    Returns:
        The collected calibration records (at most ``max_samples``).
    """
    records: list[CalibrationRecord] = []

    log_banner(logger, "COLLECTING CALIBRATION DATA", width=60)
    logger.info("Target samples: %d", max_samples)

    for query in EVALUATION_QUERIES:
        products = get_candidates(
            query=query,
            k=10,
            min_rating=2.0,
            aggregation=AggregationMethod.MAX,
        )

        for product in products:
            if len(records) >= max_samples:
                break

            # Record what the gate WOULD decide, without enforcing it.
            quality = check_evidence_quality(product)
            evidence_metrics = collect_evidence_metrics(product)

            # Gate disabled so refusal-worthy evidence still yields an
            # explanation we can score with HHEM.
            result = explainer.generate_explanation(
                query,
                product,
                max_evidence=5,
                enforce_quality_gate=False,
            )

            hhem = detector.check_explanation(
                result.evidence_texts,
                result.explanation,
            )

            record = CalibrationRecord(
                query=query,
                product_id=product.product_id,
                chunk_count=evidence_metrics["chunk_count"],
                total_tokens=evidence_metrics["total_tokens"],
                min_chunk_tokens=evidence_metrics["min_chunk_tokens"],
                max_chunk_tokens=evidence_metrics["max_chunk_tokens"],
                top_score=quality.top_score,
                avg_rating=product.avg_rating,
                gate_passed=quality.is_sufficient,
                hhem_score=hhem.score,
                is_hallucinated=hhem.is_hallucinated,
            )
            records.append(record)

            # BUGFIX: hhem_score is declared float | None; "%.3f" % None makes
            # the logging handler raise a formatting error. Pre-format instead.
            logger.info(
                "Sample %d: tokens=%d, gate=%s, hhem=%s",
                len(records),
                record.total_tokens,
                "PASS" if record.gate_passed else "FAIL",
                "n/a" if record.hhem_score is None else f"{record.hhem_score:.3f}",
            )

        if len(records) >= max_samples:
            break

    if output_path:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w") as f:
            json.dump([asdict(r) for r in records], f, indent=2)
        logger.info("Saved %d records to %s", len(records), output_path)

    return records
|
|
|
|
def load_calibration_data(path: Path) -> list[CalibrationRecord]:
    """Load previously collected calibration data."""
    raw = json.loads(path.read_text())
    return [CalibrationRecord(**entry) for entry in raw]
|
|
|
|
| def _compute_classification_metrics( |
| would_pass: np.ndarray, hallucinated: np.ndarray |
| ) -> dict: |
| """Compute precision, recall, F1 for a gate decision.""" |
| tp = int(np.sum(would_pass & ~hallucinated)) |
| fp = int(np.sum(would_pass & hallucinated)) |
| tn = int(np.sum(~would_pass & hallucinated)) |
| fn = int(np.sum(~would_pass & ~hallucinated)) |
|
|
| precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 |
| recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 |
| f1 = ( |
| 2 * precision * recall / (precision + recall) |
| if (precision + recall) > 0 |
| else 0.0 |
| ) |
|
|
| return { |
| "precision": round(precision, 3), |
| "recall": round(recall, 3), |
| "f1": round(f1, 3), |
| "tp": tp, |
| "fp": fp, |
| "tn": tn, |
| "fn": fn, |
| "n_passed": int(would_pass.sum()), |
| "n_refused": int((~would_pass).sum()), |
| } |
|
|
|
|
@dataclass
class _ValidatedRecords:
    """Pre-validated records with extracted numpy arrays.

    Built only from records whose hhem_score is not None, so the arrays
    are aligned and free of missing values.
    """

    tokens: np.ndarray  # total evidence tokens per record
    chunks: np.ndarray  # evidence chunk count per record
    hhem: np.ndarray  # HHEM scores (no Nones; filtered upstream)
    hallucinated: np.ndarray  # per-record hallucination flags
    count: int  # number of valid records (== len of each array)
|
|
|
|
def _validate_records(
    records: list[CalibrationRecord], min_samples: int = 10
) -> _ValidatedRecords | None:
    """Filter and validate records, returning None if insufficient."""
    usable = [r for r in records if r.hhem_score is not None]
    n_valid = len(usable)
    if n_valid < min_samples:
        logger.warning("Not enough valid samples for analysis: %d", n_valid)
        return None
    return _ValidatedRecords(
        tokens=np.array([r.total_tokens for r in usable]),
        chunks=np.array([r.chunk_count for r in usable]),
        hhem=np.array([r.hhem_score for r in usable]),
        hallucinated=np.array([r.is_hallucinated for r in usable]),
        count=n_valid,
    )
|
|
|
|
def _sweep_threshold(
    values: np.ndarray,
    hallucinated: np.ndarray,
    hhem: np.ndarray,
    thresholds: list[int],
) -> list[dict]:
    """Sweep a single threshold dimension and compute metrics.

    For each candidate threshold, classifies records as passing when
    ``values >= threshold`` and records gate metrics plus the mean HHEM
    score of the passing subset (0.0 when nothing passes).
    """
    rows = []
    for candidate in thresholds:
        passed = values >= candidate
        row = _compute_classification_metrics(passed, hallucinated)
        row["threshold"] = candidate
        if passed.any():
            row["mean_hhem_above"] = round(float(hhem[passed].mean()), 3)
        else:
            row["mean_hhem_above"] = 0.0
        rows.append(row)
    return rows
|
|
|
|
def analyze_threshold(records: list[CalibrationRecord]) -> dict:
    """Analyze data to find optimal token threshold.

    Sweeps token thresholds from 20 to 190 in steps of 10 and recommends
    the one maximizing F1.
    """
    validated = _validate_records(records)
    if validated is None:
        return {"error": "insufficient_samples", "count": len(records)}

    sweep = _sweep_threshold(
        validated.tokens,
        validated.hallucinated,
        validated.hhem,
        list(range(20, 200, 10)),
    )

    best = max(sweep, key=lambda row: row["f1"])
    corr = np.corrcoef(validated.tokens, validated.hhem)[0, 1]

    return {
        "recommended_threshold": best["threshold"],
        "best_f1": best["f1"],
        "best_precision": best["precision"],
        "best_recall": best["recall"],
        "correlation_tokens_hhem": round(float(corr), 3),
        "current_threshold": MIN_EVIDENCE_TOKENS,
        "total_samples": validated.count,
        "hallucination_rate": round(float(validated.hallucinated.mean()), 3),
        "mean_tokens": round(float(validated.tokens.mean()), 1),
        "median_tokens": round(float(np.median(validated.tokens)), 1),
        "all_thresholds": sweep,
    }
|
|
|
|
def analyze_chunk_threshold(records: list[CalibrationRecord]) -> dict:
    """Analyze optimal MIN_EVIDENCE_CHUNKS threshold.

    Sweeps chunk-count thresholds 1 through 5 and recommends the one
    maximizing F1.
    """
    validated = _validate_records(records)
    if validated is None:
        return {"error": "insufficient_samples", "count": len(records)}

    sweep = _sweep_threshold(
        validated.chunks,
        validated.hallucinated,
        validated.hhem,
        [1, 2, 3, 4, 5],
    )

    best = max(sweep, key=lambda row: row["f1"])
    corr = np.corrcoef(validated.chunks, validated.hhem)[0, 1]

    return {
        "recommended_threshold": best["threshold"],
        "best_f1": best["f1"],
        "best_precision": best["precision"],
        "best_recall": best["recall"],
        "correlation_chunks_hhem": round(float(corr), 3),
        "current_threshold": MIN_EVIDENCE_CHUNKS,
        "mean_chunks": round(float(validated.chunks.mean()), 2),
        "all_thresholds": sweep,
    }
|
|
|
|
def analyze_combined_thresholds(records: list[CalibrationRecord]) -> dict:
    """2D sweep of token and chunk thresholds to find optimal combination.

    A record passes the combined gate only when it clears BOTH the token
    and chunk thresholds; the grid point maximizing F1 is recommended.
    """
    validated = _validate_records(records)
    if validated is None:
        return {"error": "insufficient_samples", "count": len(records)}

    grid = []
    for tok_min in (20, 30, 40, 50, 75, 100):
        for chunk_min in (1, 2, 3):
            passed = (validated.tokens >= tok_min) & (
                validated.chunks >= chunk_min
            )
            row = _compute_classification_metrics(passed, validated.hallucinated)
            row["token_threshold"] = tok_min
            row["chunk_threshold"] = chunk_min
            grid.append(row)

    best = max(grid, key=lambda row: row["f1"])

    return {
        "recommended_token_threshold": best["token_threshold"],
        "recommended_chunk_threshold": best["chunk_threshold"],
        "best_f1": best["f1"],
        "best_precision": best["precision"],
        "best_recall": best["recall"],
        "current_token_threshold": MIN_EVIDENCE_TOKENS,
        "current_chunk_threshold": MIN_EVIDENCE_CHUNKS,
        "all_combinations": grid,
    }
|
|
|
|
def print_token_analysis(analysis: dict) -> None:
    """Print token threshold analysis.

    Logs a summary, the recommendation, and the full sweep as a table;
    " *" marks the recommended threshold row.
    """
    if "error" in analysis:
        logger.error("Analysis failed: %s", analysis["error"])
        return

    log_banner(logger, "TOKEN THRESHOLD ANALYSIS", width=60)

    log_section(logger, "Summary")
    logger.info("Total samples: %d", analysis["total_samples"])
    logger.info("Mean tokens: %.1f", analysis["mean_tokens"])
    logger.info("Median tokens: %.1f", analysis["median_tokens"])
    logger.info("Hallucination rate: %.1f%%", analysis["hallucination_rate"] * 100)
    logger.info("Token-HHEM corr: %+.3f", analysis["correlation_tokens_hhem"])

    log_section(logger, "Recommendation")
    logger.info("Current threshold: %d tokens", analysis["current_threshold"])
    logger.info("Recommended: %d tokens", analysis["recommended_threshold"])
    logger.info("Best F1: %.3f", analysis["best_f1"])
    logger.info("Best precision: %.3f", analysis["best_precision"])
    logger.info("Best recall: %.3f", analysis["best_recall"])

    log_section(logger, "Token Threshold Sweep")
    logger.info("thresh | F1 | prec | recall | passed | refused")
    logger.info("-------|-------|-------|--------|--------|--------")
    recommended = analysis["recommended_threshold"]
    for row in analysis["all_thresholds"]:
        suffix = " *" if row["threshold"] == recommended else ""
        logger.info(
            " %3d | %.3f | %.3f | %.3f | %3d | %3d%s",
            row["threshold"],
            row["f1"],
            row["precision"],
            row["recall"],
            row["n_passed"],
            row["n_refused"],
            suffix,
        )
|
|
|
|
def print_chunk_analysis(analysis: dict) -> None:
    """Print chunk threshold analysis.

    Logs a summary, the recommendation, and the sweep table; " *" marks
    the recommended threshold row.
    """
    if "error" in analysis:
        logger.error("Chunk analysis failed: %s", analysis["error"])
        return

    log_banner(logger, "CHUNK THRESHOLD ANALYSIS", width=60)

    log_section(logger, "Summary")
    logger.info("Mean chunks: %.2f", analysis["mean_chunks"])
    logger.info("Chunk-HHEM corr: %+.3f", analysis["correlation_chunks_hhem"])

    log_section(logger, "Recommendation")
    logger.info("Current threshold: %d chunks", analysis["current_threshold"])
    logger.info("Recommended: %d chunk(s)", analysis["recommended_threshold"])
    logger.info("Best F1: %.3f", analysis["best_f1"])

    log_section(logger, "Chunk Threshold Sweep")
    logger.info("chunks | F1 | prec | recall | passed | refused")
    logger.info("-------|-------|-------|--------|--------|--------")
    recommended = analysis["recommended_threshold"]
    for row in analysis["all_thresholds"]:
        suffix = " *" if row["threshold"] == recommended else ""
        logger.info(
            " %d | %.3f | %.3f | %.3f | %3d | %3d%s",
            row["threshold"],
            row["f1"],
            row["precision"],
            row["recall"],
            row["n_passed"],
            row["n_refused"],
            suffix,
        )
|
|
|
|
def print_combined_analysis(analysis: dict) -> None:
    """Print combined threshold analysis.

    Renders the 2D token x chunk sweep as a log table. Rows are sorted by
    (token_threshold, chunk_threshold); " *" marks the recommended
    combination and " (current)" marks the currently configured one.
    """
    if "error" in analysis:
        logger.error("Combined analysis failed: %s", analysis["error"])
        return

    log_banner(logger, "COMBINED THRESHOLD ANALYSIS", width=60)

    log_section(logger, "Current vs Recommended")
    logger.info(
        "Current: %d tokens, %d chunks",
        analysis["current_token_threshold"],
        analysis["current_chunk_threshold"],
    )
    logger.info(
        "Recommended: %d tokens, %d chunk(s)",
        analysis["recommended_token_threshold"],
        analysis["recommended_chunk_threshold"],
    )
    logger.info("Best F1: %.3f", analysis["best_f1"])
    logger.info("Precision: %.3f", analysis["best_precision"])
    logger.info("Recall: %.3f", analysis["best_recall"])

    log_section(logger, "2D Threshold Grid")
    logger.info("tokens | chunks | F1 | prec | recall | passed | refused")
    logger.info("-------|--------|-------|-------|--------|--------|--------")
    # Sort for a deterministic, readable grid regardless of sweep order.
    for r in sorted(
        analysis["all_combinations"],
        key=lambda x: (x["token_threshold"], x["chunk_threshold"]),
    ):
        is_best = (
            r["token_threshold"] == analysis["recommended_token_threshold"]
            and r["chunk_threshold"] == analysis["recommended_chunk_threshold"]
        )
        is_current = (
            r["token_threshold"] == analysis["current_token_threshold"]
            and r["chunk_threshold"] == analysis["current_chunk_threshold"]
        )
        # If recommended == current, the " *" marker wins.
        marker = " *" if is_best else (" (current)" if is_current else "")
        logger.info(
            " %3d | %d | %.3f | %.3f | %.3f | %3d | %3d%s",
            r["token_threshold"],
            r["chunk_threshold"],
            r["f1"],
            r["precision"],
            r["recall"],
            r["n_passed"],
            r["n_refused"],
            marker,
        )
|
|
|
|
def main() -> None:
    """CLI entry point: collect (or reload) calibration data, run all analyses."""
    parser = argparse.ArgumentParser(
        description="Calibrate MIN_EVIDENCE_TOKENS and MIN_EVIDENCE_CHUNKS thresholds"
    )
    parser.add_argument(
        "--samples", type=int, default=200, help="Number of samples to collect"
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=DEFAULT_OUTPUT,
        help="Output path for calibration data",
    )
    parser.add_argument(
        "--analyze-only",
        action="store_true",
        help="Skip collection, analyze existing data",
    )
    args = parser.parse_args()

    if args.analyze_only:
        if not args.output.exists():
            logger.error("No data file found at %s", args.output)
            return
        records = load_calibration_data(args.output)
        logger.info("Loaded %d records from %s", len(records), args.output)
    else:
        # Heavy model imports deferred so --analyze-only stays fast.
        from sage.adapters.hhem import HallucinationDetector
        from sage.services.explanation import Explainer

        records = collect_calibration_data(
            Explainer(), HallucinationDetector(), args.samples, args.output
        )

    all_analysis = {
        "token_analysis": analyze_threshold(records),
        "chunk_analysis": analyze_chunk_threshold(records),
        "combined_analysis": analyze_combined_thresholds(records),
    }

    print_token_analysis(all_analysis["token_analysis"])
    print_chunk_analysis(all_analysis["chunk_analysis"])
    print_combined_analysis(all_analysis["combined_analysis"])

    # e.g. token_threshold.json -> token_threshold.analysis.json
    analysis_path = args.output.with_suffix(".analysis.json")
    with open(analysis_path, "w") as f:
        json.dump(all_analysis, f, indent=2)
    logger.info("Saved analysis to %s", analysis_path)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|