| """ |
| scripts/threshold_sweep.py — Sweep decision thresholds against the |
| dev set using a saved best.pt checkpoint, without retraining. |
| |
| Why this is fast: |
| The model emits per-triple scores once. The threshold only |
| enters precision_recall_f1 (and hop_stratified_precision). |
| MAP and NDCG are threshold-independent. We score the dev set |
| once, then re-evaluate F1 across many thresholds. |
| |
| Usage: |
| python scripts/threshold_sweep.py |
| [--config configs/caff_orphanet.yaml] |
| [--checkpoint runs/caff_orphanet/seed_42/best.pt] |
| [--thresholds 0.20,0.25,0.30,0.35,0.40,0.45,0.50,0.55,0.60] |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import logging |
| import sys |
| from collections import defaultdict |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import yaml |
|
|
# Repository root is the parent of this script's directory. It is put at the
# front of sys.path so the `caff` package imported below resolves when the
# script is run directly rather than as an installed module.
ROOT = Path(__file__).parent.parent
if sys.path.count(str(ROOT)) == 0:
    sys.path.insert(0, str(ROOT))
|
|
| from caff import ( |
| AblationFlags, |
| CAFFConfig, |
| CAFFEvaluator, |
| CAFFModel, |
| CAFFTripleDataset, |
| CachedBFSExtractor, |
| FrozenBioEncoder, |
| KnowledgeGraph, |
| RelationEmbeddingCache, |
| load_qa_split, |
| ) |
| from caff.evaluator import ( |
| hop_stratified_precision, |
| mean_average_precision, |
| mean_ndcg_at_k, |
| precision_recall_f1, |
| ) |
| from caff.utils import set_global_seed |
|
|
# Root logging setup for the script: INFO and above, with a wall-clock
# timestamp, a padded level name, and the emitting logger's name.
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
    level=logging.INFO,
)
# Script-wide logger used by main() for progress and error reporting.
logger = logging.getLogger("threshold_sweep")
|
|
|
|
def load_config(yaml_path: Path) -> tuple[CAFFConfig, AblationFlags]:
    """Load the run configuration and ablation flags from a YAML file.

    The file is expected to hold two optional top-level mappings:
    ``config`` (keyword arguments for ``CAFFConfig``) and ``ablation``
    (keyword arguments for ``AblationFlags``). The original note says this
    mirrors ``train.py::load_config``.

    An empty file, or a section that is absent or explicitly null, falls
    back to the default-constructed object instead of crashing.

    Args:
        yaml_path: Path to the YAML configuration file.

    Returns:
        A ``(config, ablation)`` tuple.
    """
    with yaml_path.open("r", encoding="utf-8") as f:
        # safe_load returns None for an empty document; treat that as "no
        # overrides" rather than failing on ``None.get`` below.
        raw = yaml.safe_load(f) or {}
    # ``key:`` with no value parses to None, which cannot be **-expanded.
    cfg_dict = raw.get("config") or {}
    abl_dict = raw.get("ablation") or {}
    config = CAFFConfig(**cfg_dict)
    ablation = AblationFlags(**abl_dict) if abl_dict else AblationFlags()
    return config, ablation
|
|
|
|
def _parse_thresholds(spec: str) -> list[float]:
    """Parse a comma-separated list of thresholds, skipping empty items.

    Raises:
        ValueError: If an item is not a number or is not finite.
    """
    values: list[float] = []
    for token in spec.split(","):
        token = token.strip()
        if not token:
            # Tolerate stray/trailing commas.
            continue
        value = float(token)  # raises ValueError on malformed input
        if not np.isfinite(value):
            raise ValueError(f"threshold must be finite, got {token!r}")
        values.append(value)
    return values


def _group_by_query(instances, scores, labels) -> dict[str, tuple[np.ndarray, np.ndarray]]:
    """Bucket (score, label) pairs by ``query_id`` for the ranking metrics."""
    buckets: dict[str, tuple[list[float], list[int]]] = defaultdict(lambda: ([], []))
    for inst, sc, lbl in zip(instances, scores.tolist(), labels.tolist()):
        q_scores, q_labels = buckets[inst.query_id]
        q_scores.append(sc)
        q_labels.append(lbl)
    return {
        qid: (np.array(q_scores), np.array(q_labels))
        for qid, (q_scores, q_labels) in buckets.items()
    }


def main() -> int:
    """Score the dev set once with a saved checkpoint, then sweep thresholds.

    Parses the command line, loads the config, knowledge graph, encoder,
    BFS extractor, dev split and checkpoint, scores every dev candidate a
    single time, reports MAP / NDCG@10, and prints a table of precision,
    recall, F1 and per-hop precision for each requested threshold.

    Returns:
        Process exit code: 0 on success, 1 when an input problem was
        reported (missing config/checkpoint, invalid or empty threshold
        list, unexpected checkpoint layout, or an empty scored dev set).
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--config", default="configs/caff_orphanet.yaml",
    )
    parser.add_argument(
        "--checkpoint",
        default="runs/caff_orphanet/seed_42/best.pt",
    )
    parser.add_argument(
        "--thresholds",
        default="0.20,0.25,0.30,0.35,0.40,0.45,0.50,0.55,0.60",
    )
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--cache-dir", default="cache")
    args = parser.parse_args()

    # --- Validate inputs before doing any expensive loading. ---
    config_path = Path(args.config)
    ckpt_path = Path(args.checkpoint)
    if not config_path.exists():
        logger.error(f"Config not found: {config_path}")
        return 1
    if not ckpt_path.exists():
        logger.error(f"Checkpoint not found: {ckpt_path}")
        return 1

    try:
        thresholds = _parse_thresholds(args.thresholds)
    except ValueError as exc:
        # Report a usable message instead of a traceback for a CLI typo.
        logger.error(f"Invalid --thresholds value: {exc}")
        return 1
    if not thresholds:
        logger.error("No thresholds provided.")
        return 1

    # --- Config ---
    config, ablation = load_config(config_path)
    set_global_seed(config.seed, deterministic=config.deterministic)
    logger.info(f"Loaded config: {config_path.name}")
    logger.info(f" KG path: {config.kg_path}")
    logger.info(f" Dev path: {config.dev_path}")
    logger.info(f" Encoder: {config.encoder_name}")

    # --- Knowledge graph ---
    logger.info("Loading KG ...")
    kg = KnowledgeGraph.from_tsv(
        config.kg_path,
        min_relation_freq=config.min_relation_freq,
    )

    # --- Encoder and relation embedding cache ---
    logger.info(f"Loading encoder: {config.encoder_name}")
    encoder = FrozenBioEncoder(config.encoder_name, device=args.device)
    rel_cache_path = Path(args.cache_dir) / "relation_embeddings.pt"
    relation_cache = RelationEmbeddingCache(
        encoder=encoder,
        relations=kg.relations,
        cache_path=rel_cache_path,
    )

    # --- BFS candidate extractor ---
    bfs = CachedBFSExtractor(
        kg, L=config.L, K_r=config.K_r,
        cache_dir=Path(args.cache_dir) / "bfs",
    )

    # --- Dev dataset ---
    dev_recs = load_qa_split(config.dev_path)
    dev_ds = CAFFTripleDataset(dev_recs, bfs, require_gold=True)
    logger.info(f"Dev dataset: {len(dev_ds):,} triple instances")

    # --- Model and checkpoint ---
    model = CAFFModel(config, relation_cache, ablation=ablation).to(args.device)
    # NOTE(security): weights_only=False performs a full unpickle, which can
    # execute arbitrary code. Only load checkpoints from a trusted source.
    payload = torch.load(ckpt_path, map_location=args.device, weights_only=False)
    if not isinstance(payload, dict) or "model" not in payload:
        logger.error(
            f"Unexpected checkpoint format. Keys: "
            f"{list(payload.keys()) if isinstance(payload, dict) else type(payload)}"
        )
        return 1
    # strict=False: mismatched keys are reported as warnings, not errors.
    missing, unexpected = model.load_state_dict(payload["model"], strict=False)
    if missing:
        logger.warning(
            f"Missing keys when loading: "
            f"{missing[:5]}{'...' if len(missing)>5 else ''}"
        )
    if unexpected:
        logger.warning(
            f"Unexpected keys when loading: "
            f"{unexpected[:5]}{'...' if len(unexpected)>5 else ''}"
        )
    model.eval()

    if "metrics" in payload:
        m = payload["metrics"]
        logger.info(
            f"Checkpoint training metrics: "
            f"epoch={m.get('epoch')}, "
            f"dev_f1={m.get('dev_f1')}, dev_map={m.get('dev_map')}"
        )

    # --- Score the dev set once ---
    # The evaluator requires a threshold at construction; the first one is
    # passed only to satisfy that, the sweep below applies each threshold.
    evaluator = CAFFEvaluator(
        config=config,
        encoder=encoder,
        mode="teacher_forced",
        threshold=thresholds[0],
    )
    logger.info("Scoring dev set once (this is the slow part)...")
    scores, instances, _retained = evaluator._score_dataset(model, dev_ds)
    if len(scores) == 0:
        # Nothing to evaluate; a positive rate or F1 would be undefined.
        logger.error("Dev set produced no candidate scores; nothing to sweep.")
        return 1
    labels = np.array([i.label for i in instances])
    logger.info(
        f" Done. {len(scores):,} candidate scores; "
        f"{int(labels.sum()):,} positives "
        f"({100*labels.mean():.2f}% positive rate)"
    )

    # --- Threshold-independent ranking metrics ---
    per_query = _group_by_query(instances, scores, labels)
    map_val = mean_average_precision(per_query)
    ndcg_val = mean_ndcg_at_k(per_query, k=10)
    logger.info(f"MAP (threshold-independent): {map_val:.4f}")
    logger.info(f"NDCG@10 (threshold-independent): {ndcg_val:.4f}")

    # --- Threshold sweep table ---
    print()
    print("=" * 80)
    print(
        f"{'theta':>6} | {'precision':>9} | {'recall':>7} | "
        f"{'F1':>7} | {'hop1':>6} | {'hop2':>6} | {'hop3':>6}"
    )
    print("=" * 80)
    best_f1 = -1.0
    best_t = None
    for t in thresholds:
        prf = precision_recall_f1(scores, labels, t)
        hop_prec = hop_stratified_precision(instances, scores, t)
        f1 = prf["f1"]
        # Strict '>' keeps the earliest threshold on ties; a NaN F1 never wins.
        if f1 > best_f1:
            best_f1 = f1
            best_t = t
        print(
            f"{t:>6.2f} | {prf['precision']:>9.4f} | {prf['recall']:>7.4f} | "
            f"{f1:>7.4f} | {hop_prec.get(1, 0.0):>6.4f} | "
            f"{hop_prec.get(2, 0.0):>6.4f} | {hop_prec.get(3, 0.0):>6.4f}"
        )
    print("=" * 80)
    if best_t is None:
        # Every F1 compared false against the sentinel (e.g. all NaN).
        print("Best threshold: n/a (F1 undefined at every threshold)")
    else:
        print(f"Best threshold: {best_t:.2f} (F1 = {best_f1:.4f})")
    print(f"MAP = {map_val:.4f} NDCG@10 = {ndcg_val:.4f} (constant across thresholds)")
    return 0
|
|
|
|
| if __name__ == "__main__": |
| sys.exit(main()) |
|
|