PGC-AI-Chatbot / scripts /calibrate_threshold.py
Jacooo's picture
Deploy from GitHub: 49fa342
92dc348 verified
"""Empirical threshold calibration for VERIFIED_DENSE_THRESHOLD / VERIFIED_HYBRID_THRESHOLD.
Methodology
-----------
For each labeled query in the golden retrieval cases we retrieve a wide candidate pool
(match_threshold=0.20, match_count=20) from the real Supabase vector store, then label
every returned chunk as either True Positive (TP) or True Negative (TN):
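TP = source name matches the expected source AND
at least one expected keyword is found in the chunk content
(a chunk from the expected source whose page number equals the expected page
counts as TP even without a keyword hit; a case with no expected keywords
matches on source alone)
Everything else is TN.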
We then build two separate score distributions:
Dense path — all TP vs all TN cosine similarity scores
Cross-modal — same, but restricted to chunks also found by FTS
For each candidate threshold t ∈ [0.45, 0.80] we compute Youden's J statistic:
J(t) = TPR(t) − FPR(t)
= (TP above t / total TP) − (TN above t / total TN)
The threshold that maximises J is the operating point with minimum TP/TN overlap.
Requirements
------------
- Real Supabase connection (SUPABASE_URL + SUPABASE_KEY env vars or .env file)
- BGE-M3 model accessible (fastembed downloads on first run)
Usage
-----
cd "AI Chatbot"
python scripts/calibrate_threshold.py
"""
from __future__ import annotations
import asyncio
import json
import sys
from pathlib import Path
from typing import Any
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from app.retrieval_eval import load_golden_retrieval_cases
from app.vector_store import (
VERIFIED_DENSE_THRESHOLD,
VERIFIED_HYBRID_THRESHOLD,
search_knowledge,
search_knowledge_fts,
)
# ── constants ──────────────────────────────────────────────────────────────────
CALIB_THRESHOLD = 0.20  # retrieval floor for data collection (wide net)
CALIB_COUNT = 20  # chunks per query (wider than production default of 7)
HISTOGRAM_BINS = 15  # score buckets for the ASCII plot
HISTOGRAM_LOW = 0.40  # left edge of histogram x-axis
HISTOGRAM_HIGH = 0.82  # right edge
# Thresholds to evaluate, stored as integer percentages (0.45 → 0.80 in 0.01
# steps); each value is divided by 100 where it is used.
CANDIDATE_RANGE = range(45, 81)
# One scored chunk: (similarity, is_tp, is_cross_modal, case_id).
ScoreRecord = tuple[float, bool, bool, str]
# ── helpers ────────────────────────────────────────────────────────────────────
def _is_tp(
    chunk: dict[str, Any],
    expected_source: str,
    expected_keywords: list[str],
    expected_filename: str | None = None,
    expected_page: int | None = None,
) -> bool:
    """Decide whether a retrieved chunk is a true positive for a golden case.

    Rules, applied in order:

    1. The chunk's ``source`` (whitespace-stripped) must equal
       ``expected_source``; otherwise the chunk is a negative.
    2. If both an expected page and a chunk ``page_number`` are available and
       they are equal, the chunk is a positive regardless of keywords.
    3. If the case lists no keywords, the source match alone is enough.
    4. Otherwise at least one keyword must occur (case-insensitively) in the
       chunk ``content``.

    Args:
        chunk: Retrieved row; the keys ``source``, ``content`` and
            ``page_number`` are read, each may be missing or ``None``.
        expected_source: Source name the golden case expects.
        expected_keywords: Keywords of which at least one should appear.
        expected_filename: Accepted for call-site compatibility; currently not
            used by the matching rules.
        expected_page: Page number the golden case expects, if known.

    Returns:
        ``True`` when the chunk counts as a true positive, else ``False``.
    """
    # A row can carry an explicit ``None`` source; treat it like a missing one
    # instead of crashing on ``None.strip()``.
    source = (chunk.get("source") or "").strip()
    if source != expected_source.strip():
        return False

    # Exact page + source is a definite hit. A different page is not a
    # rejection: the answer may spill onto an adjacent page, so the keyword
    # check below still gets a chance.
    page = chunk.get("page_number")
    if expected_page is not None and page is not None and page == expected_page:
        return True

    if not expected_keywords:
        return True

    content = (chunk.get("content") or "").lower()
    return any(kw.lower() in content for kw in expected_keywords)
async def _collect_case_scores(
    query: str,
    expected_source: str,
    expected_keywords: list[str],
    case_id: str,
    expected_filename: str | None = None,
    expected_page: int | None = None,
) -> list[ScoreRecord]:
    """Retrieve candidates for one golden case and label each of them.

    Runs ``search_knowledge`` with the wide calibration settings
    (``CALIB_THRESHOLD`` / ``CALIB_COUNT``) and ``search_knowledge_fts`` with
    the same count, then emits one record per chunk returned by
    ``search_knowledge``.

    Args:
        query: Query text of the golden case.
        expected_source: Source name a true positive must carry.
        expected_keywords: Keywords used by ``_is_tp`` for labelling.
        case_id: Identifier stored in each record and used in the query label.
        expected_filename: Forwarded to ``_is_tp``.
        expected_page: Forwarded to ``_is_tp``.

    Returns:
        A list of ``(similarity, is_tp, is_cross_modal, case_id)`` tuples.
    """
    dense_chunks = await search_knowledge(
        query=query,
        match_threshold=CALIB_THRESHOLD,
        match_count=CALIB_COUNT,
        query_label=f"calib:{case_id}",
    )
    fts_chunks = await search_knowledge_fts(
        query=query,
        match_count=CALIB_COUNT,
    )

    # A chunk is "cross-modal" when the FTS result set contains the same
    # (filename, page_number) pair.
    # NOTE(review): rows lacking both fields all collapse onto (None, None) and
    # would match each other — confirm every row carries this metadata.
    fts_keys: set[tuple] = {
        (c.get("filename"), c.get("page_number")) for c in fts_chunks
    }

    records: list[ScoreRecord] = []
    for chunk in dense_chunks:
        # An explicit ``None`` similarity would break the later ``s >= t``
        # comparisons, so it is recorded as 0.0 just like a missing key.
        sim = chunk.get("similarity") or 0.0
        key = (chunk.get("filename"), chunk.get("page_number"))
        is_cross_modal = key in fts_keys
        tp = _is_tp(chunk, expected_source, expected_keywords, expected_filename, expected_page)
        records.append((sim, tp, is_cross_modal, case_id))

    tp_count = sum(1 for _, tp, _, _ in records if tp)
    print(f" → {len(records)} chunks returned, {tp_count} TP")
    return records
async def collect_all_scores() -> list[ScoreRecord]:
    """Gather score records for every usable golden retrieval case.

    A case is used when its ``expected_mode`` is ``"vector_rag"`` and both
    ``expected_found`` and ``expected_source`` are truthy. Cases are processed
    one after another and their records concatenated in case order.

    Returns:
        All records produced by ``_collect_case_scores`` across the cases.
    """
    labeled_cases = []
    for candidate in load_golden_retrieval_cases():
        if candidate.get("expected_mode") != "vector_rag":
            continue
        if not candidate.get("expected_found"):
            continue
        if not candidate.get("expected_source"):
            continue
        labeled_cases.append(candidate)

    print(f"\nCalibrating over {len(labeled_cases)} labeled cases (vector_rag + calibration):\n")

    collected: list[ScoreRecord] = []
    for entry in labeled_cases:
        query_text = str(entry["query"])
        source_name = str(entry["expected_source"])
        # A missing or empty keyword list is normalised to [].
        wanted_keywords: list[str] = entry.get("expected_content_keywords") or []
        ident = str(entry["case_id"])
        wanted_filename: str | None = entry.get("expected_filename")
        wanted_page: int | None = entry.get("expected_page")

        # Only the first 70 characters of the query are echoed.
        print(f' [{ident}] "{query_text[:70]}"')
        collected += await _collect_case_scores(
            query_text,
            source_name,
            wanted_keywords,
            ident,
            wanted_filename,
            wanted_page,
        )
    return collected
# ── analysis ───────────────────────────────────────────────────────────────────
def compute_optimal_threshold(
    records: list[ScoreRecord],
    cross_modal_only: bool = False,
    *,
    candidates: list[float] | None = None,
) -> tuple[float, float]:
    """Return ``(optimal_threshold, youden_j)`` maximising Youden's J.

    For each candidate threshold ``t``, ``J(t) = TPR(t) - FPR(t)`` where a
    score counts as "above" when ``score >= t``. On ties the first candidate
    in iteration order wins (the lowest one for the default ascending range).

    Args:
        records: ``(similarity, is_tp, is_cross_modal, case_id)`` tuples.
        cross_modal_only: When true, only records flagged cross-modal are used.
        candidates: Thresholds to evaluate. Defaults to ``CANDIDATE_RANGE``
            interpreted as percentages (each value divided by 100).

    Returns:
        The best threshold and its J value, or ``(0.0, 0.0)`` when the
        selection lacks positives or negatives, or there are no candidates.
    """
    subset = [
        (sim, tp)
        for sim, tp, is_cm, _ in records
        if (not cross_modal_only or is_cm)
    ]
    tp_scores = [s for s, tp in subset if tp]
    tn_scores = [s for s, tp in subset if not tp]
    # Both rates need a non-empty denominator; this also covers an empty subset.
    if not tp_scores or not tn_scores:
        return 0.0, 0.0

    if candidates is None:
        thresholds = [ti / 100.0 for ti in CANDIDATE_RANGE]
    else:
        thresholds = list(candidates)
    if not thresholds:
        return 0.0, 0.0

    best_t, best_j = 0.0, float("-inf")
    for t in thresholds:
        tpr = sum(1 for s in tp_scores if s >= t) / len(tp_scores)
        fpr = sum(1 for s in tn_scores if s >= t) / len(tn_scores)
        j = tpr - fpr
        # Strict ">" keeps the earliest candidate on ties.
        if j > best_j:
            best_t, best_j = t, j
    return best_t, best_j
def print_histogram(
    records: list[ScoreRecord],
    cross_modal_only: bool = False,
    *,
    bins: int | None = None,
    low: float | None = None,
    high: float | None = None,
) -> None:
    """Print a side-by-side ASCII histogram of TP and TN similarity scores.

    Each row is one half-open bucket ``[lo, hi)`` with its TP and TN counts
    and a bar per class, scaled so the fullest bucket is 24 characters wide.
    Scores outside ``[low, high)`` are counted in the header totals but do
    not appear in any bucket.

    Args:
        records: ``(similarity, is_tp, is_cross_modal, case_id)`` tuples.
        cross_modal_only: When true, only records flagged cross-modal are used.
        bins: Number of buckets; defaults to ``HISTOGRAM_BINS``.
        low: Left edge of the x-axis; defaults to ``HISTOGRAM_LOW``.
        high: Right edge of the x-axis; defaults to ``HISTOGRAM_HIGH``.

    Raises:
        ValueError: If ``bins`` is below 1 or ``high`` is not above ``low``.
    """
    subset = [
        (sim, tp)
        for sim, tp, is_cm, _ in records
        if (not cross_modal_only or is_cm)
    ]
    if not subset:
        print(" (no data)\n")
        return

    n_bins = HISTOGRAM_BINS if bins is None else bins
    x_min = HISTOGRAM_LOW if low is None else low
    x_max = HISTOGRAM_HIGH if high is None else high
    if n_bins < 1 or x_max <= x_min:
        raise ValueError("histogram needs bins >= 1 and high > low")

    tp_scores = [s for s, tp in subset if tp]
    tn_scores = [s for s, tp in subset if not tp]
    label = "Cross-modal chunks (dense AND fts)" if cross_modal_only else "All dense chunks"
    print(f"\n{label} (n={len(subset)}, TP={len(tp_scores)}, TN={len(tn_scores)})")
    print(f"{'Bucket':>14} {'TP':>4} {'TN':>4} {'TP (█)':25} {'TN (░)':25}")
    print("─" * 75)

    # Build the edges once so neighbouring buckets share the exact same
    # boundary value; recomputing ``lo + bin_width`` per row can leave
    # float-sized gaps between one bucket's ``hi`` and the next one's ``lo``.
    bin_width = (x_max - x_min) / n_bins
    edges = [x_min + i * bin_width for i in range(n_bins + 1)]
    buckets = list(zip(edges, edges[1:]))

    tp_counts = [sum(1 for s in tp_scores if lo <= s < hi) for lo, hi in buckets]
    tn_counts = [sum(1 for s in tn_scores if lo <= s < hi) for lo, hi in buckets]
    # Floor of 1 avoids dividing by zero when every bucket is empty.
    max_count = max(1, *tp_counts, *tn_counts)

    bar_width = 24
    for (lo, hi), tp_n, tn_n in zip(buckets, tp_counts, tn_counts):
        # Integer arithmetic keeps bar lengths exact (no float truncation).
        tp_bar = "█" * (tp_n * bar_width // max_count)
        tn_bar = "░" * (tn_n * bar_width // max_count)
        print(f" {lo:.2f}–{hi:.2f} {tp_n:>4} {tn_n:>4} {tp_bar:<25} {tn_bar}")
def print_precision_recall_table(
    records: list[ScoreRecord],
    cross_modal_only: bool = False,
) -> None:
    """Print TPR, FPR, Youden's J, precision and F1 for a sweep of thresholds.

    Thresholds run from 0.55 to 0.77 in steps of 0.02, and a score counts as
    retrieved when it is ``>=`` the threshold. Nothing is printed when the
    selection contains no true positives.

    Args:
        records: ``(similarity, is_tp, is_cross_modal, case_id)`` tuples.
        cross_modal_only: When true, only records flagged cross-modal are used.
    """
    positives: list[float] = []
    negatives: list[float] = []
    for score, is_positive, in_fts, _case in records:
        if cross_modal_only and not in_fts:
            continue
        if is_positive:
            positives.append(score)
        else:
            negatives.append(score)

    if not positives:
        return

    n_pos = len(positives)
    n_neg = len(negatives)

    print(f"\n{'Threshold':>12} {'TPR':>7} {'FPR':>7} {'J':>7} {'Prec':>7} {'F1':>7}")
    print("─" * 55)

    for pct in range(55, 78, 2):
        t = pct / 100.0
        hits = len([s for s in positives if s >= t])
        false_alarms = len([s for s in negatives if s >= t])

        tpr = hits / n_pos
        # With no negatives at all the false-positive rate is reported as 0.
        fpr = false_alarms / n_neg if n_neg else 0
        j = tpr - fpr

        retrieved = hits + false_alarms
        prec = hits / retrieved if retrieved else 0
        denom = prec + tpr
        f1 = 2 * prec * tpr / denom if denom else 0

        print(f" t={t:.2f} {tpr:>6.1%} {fpr:>6.1%} {j:>+7.3f} {prec:>6.1%} {f1:>6.1%}")
# ── main ───────────────────────────────────────────────────────────────────────
async def main() -> None:
    """Collect scores, print both calibration reports and the recommendation.

    The dense report uses every record; the cross-modal report uses only
    records flagged cross-modal. For each path the threshold maximising
    Youden's J is compared against the currently configured constant.
    ``compute_optimal_threshold`` signals insufficient data with a threshold
    of ``0.0``; in that case no recommendation or verdict is printed for
    that path.
    """
    records = await collect_all_scores()
    total_tp = sum(1 for _, tp, _, _ in records if tp)
    total_tn = sum(1 for _, tp, _, _ in records if not tp)
    cross_modal_total = sum(1 for _, _, is_cm, _ in records if is_cm)
    print(f"\nTotal data points: {len(records)} (TP={total_tp}, TN={total_tn}, cross-modal={cross_modal_total})\n")

    # ── Dense path ──
    print("=" * 75)
    print("DENSE PATH CALIBRATION")
    print("=" * 75)
    print_histogram(records, cross_modal_only=False)
    print_precision_recall_table(records, cross_modal_only=False)
    dense_t, dense_j = compute_optimal_threshold(records, cross_modal_only=False)

    # ── Cross-modal path ──
    print("\n" + "=" * 75)
    print("CROSS-MODAL PATH CALIBRATION (dense AND fts confirmed chunks only)")
    print("=" * 75)
    print_histogram(records, cross_modal_only=True)
    print_precision_recall_table(records, cross_modal_only=True)
    hybrid_t, hybrid_j = compute_optimal_threshold(records, cross_modal_only=True)

    # ── Recommendation ──
    print("\n" + "=" * 75)
    print("RECOMMENDED THRESHOLDS (argmax Youden's J)")
    print("=" * 75)
    # A falsy threshold is the "insufficient data" sentinel; never present it
    # as a real recommendation of 0.00.
    if dense_t:
        print(f" VERIFIED_DENSE_THRESHOLD = {dense_t:.2f} (J = {dense_j:+.3f})")
    else:
        print(" VERIFIED_DENSE_THRESHOLD = (insufficient labeled data)")
    if hybrid_t:
        print(f" VERIFIED_HYBRID_THRESHOLD = {hybrid_t:.2f} (J = {hybrid_j:+.3f})")
    else:
        print(" VERIFIED_HYBRID_THRESHOLD = (insufficient cross-modal data)")
    print()
    print(" Currently set:")
    print(f" VERIFIED_DENSE_THRESHOLD = {VERIFIED_DENSE_THRESHOLD}")
    print(f" VERIFIED_HYBRID_THRESHOLD = {VERIFIED_HYBRID_THRESHOLD}")
    print()
    if dense_t:
        _print_threshold_verdict("Dense", dense_t, VERIFIED_DENSE_THRESHOLD)
    if hybrid_t:
        _print_threshold_verdict("Hybrid", hybrid_t, VERIFIED_HYBRID_THRESHOLD)


def _print_threshold_verdict(label: str, recommended: float, current: float) -> None:
    """Print whether ``recommended`` agrees with the ``current`` setting."""
    # Candidates lie on a 0.01 grid. A plain "< 0.01" test is fooled by binary
    # float noise (0.57 - 0.56 evaluates to 0.00999...), which would report a
    # full one-step difference as "well-calibrated"; the epsilon prevents that.
    if abs(recommended - current) < 0.01 - 1e-9:
        print(f" ✅ {label} threshold looks well-calibrated.")
        return
    direction = "↑ raise" if recommended > current else "↓ lower"
    print(f" ⚠️ {label} threshold should change: {current} → {recommended:.2f} ({direction})")
# Run the calibration only when this file is executed as a script; importing
# the module must not start any retrieval.
if __name__ == "__main__":
    asyncio.run(main())