| """ |
| scripts/probe_hc3_model.py - Model-level confirmation of why HC3 has |
| zero gradient (Hypothesis B). |
| |
| Background |
| ---------- |
| The data probe (probe_hc3_keys.py) showed that 51.67% of positive |
| anchors DO have a same-(query_id, relation, hop) negative, so the |
| miner is NOT starved (Hypothesis A rejected). |
| |
| Hypothesis B: within a single (query_id, hop) group, the trainer |
| assigns ONE shared teacher-forced z_prev to every instance. A |
| positive anchor and its same-key negative therefore receive the |
| SAME z_prev, the SAME query embedding q, and (because the key fixes |
| the relation) the SAME relation embedding E_r. The scorer is a |
| deterministic function of (W_ctx(z), v, q, E_r), so it returns an |
| IDENTICAL score for the positive and the negative. Then |
| |
| L_HC3 = relu(s_neg - s_pos + margin) = relu(0 + margin) = margin |
| |
| is constant in the parameters, and its gradient is exactly zero. |
| |
| This script confirms that empirically. It loads a trained checkpoint, |
| replicates the EXACT scoring path used by trainer._score_hc3_instance |
| (model.get_hop_W_ctx -> relation_cache.get_batch -> scorer.score_candidates), |
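| and for several same-key (pos, neg) pairs mined from the training set |
| (each pair is scored with one synthetic z_prev, seeded per (query_id, hop), |
| that the positive and the negative share) reports: |
| |
| * |s_pos - s_neg| (expected ~0 if Hypothesis B holds) |
| * the HC3 loss value |
| * the summed absolute gradient over all parameters after l_hc3.backward() |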
| |
| Run: |
| python scripts/probe_hc3_model.py --checkpoint runs/caff_orphanet/seed_42/best.pt --device cuda |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import sys |
| from collections import defaultdict |
| from pathlib import Path |
|
|
| import torch |
|
|
# Directory that contains this script's own directory.
ROOT = Path(__file__).parent.parent
# Put ROOT at the front of the import path (only once) before the project
# imports that follow; presumably that is where `caff` and `evaluate` live.
if sys.path.count(str(ROOT)) == 0:
    sys.path.insert(0, str(ROOT))
|
|
| from caff import ( |
| CAFFTripleDataset, |
| HC3Loss, |
| TrainingInstance, |
| ) |
| from caff.data import CachedBFSExtractor, KnowledgeGraph, load_qa_split |
| |
| from evaluate import load_checkpoint |
|
|
|
|
def score_instance(model, inst, q_embedding, device):
    """Score one instance along the path trainer._score_hc3_instance is said to use.

    Steps: move ``inst.z_prev`` to *device* and add a leading batch axis, ask
    the model for the hop-specific context weights, fetch the embedding of
    ``inst.relation`` from the relation cache, then call the scorer of that hop.

    Args:
        model: object exposing ``get_hop_W_ctx``, ``relation_cache``,
            ``hop_scorers`` and ``v``.
        inst: instance exposing ``hop`` (1-based), ``z_prev`` and ``relation``.
        q_embedding: query embedding handed to the scorer unchanged.
        device: target device for ``inst.z_prev``.

    Returns:
        The scorer output with its leading axis squeezed away.
    """
    # Hops are 1-based on the instance but index the per-hop modules from 0.
    hop_index = inst.hop - 1
    z_batched = inst.z_prev.to(device).unsqueeze(0)
    context_weights = model.get_hop_W_ctx(hop_index, z_batched).squeeze(0)
    relation_emb = model.relation_cache.get_batch([inst.relation])
    raw_score = model.hop_scorers[hop_index].score_candidates(
        context_weights, model.v, q_embedding, relation_emb
    )
    return raw_score.squeeze(0)
|
|
|
|
def _stable_seed(qid, hop) -> int:
    """Return a seed in [0, 2**31) for (qid, hop) that is the same in every run.

    The built-in ``hash()`` of a str is salted per interpreter process unless
    PYTHONHASHSEED is pinned, so it cannot give a reproducible seed.
    """
    import zlib

    return zlib.crc32(repr((qid, hop)).encode("utf-8")) % (2**31)


def _mine_same_key_pairs(instances, limit: int) -> list:
    """Collect at most *limit* (pos, neg, query_id, hop, relation) tuples.

    Instances are bucketed by (query_id, hop) and then by relation. A key
    yields one pair when it holds at least one instance with ``label == 1``
    and at least one with any other label; the first of each, in iteration
    order, is used.
    """
    if limit <= 0:
        return []

    # Only the first positive and first negative per key are ever paired,
    # so that is all that is kept.
    first_seen = defaultdict(lambda: defaultdict(dict))
    for inst in instances:
        side = "pos" if inst.label == 1 else "neg"
        first_seen[(inst.query_id, inst.hop)][inst.relation].setdefault(side, inst)

    pairs = []
    for (qid, hop), by_rel in first_seen.items():
        for rel, sides in by_rel.items():
            if "pos" in sides and "neg" in sides:
                pairs.append((sides["pos"], sides["neg"], qid, hop, rel))
                if len(pairs) >= limit:
                    return pairs
    return pairs


def _grad_summary(model) -> tuple:
    """Return (sum of |grad| over all parameters, number of parameters whose sum is > 0)."""
    total = 0.0
    nonzero = 0
    for _, param in model.named_parameters():
        if param.grad is None:
            continue
        magnitude = param.grad.abs().sum().item()
        total += magnitude
        if magnitude > 0:
            nonzero += 1
    return total, nonzero


def main() -> int:
    """Probe whether same-key HC3 pairs get identical scores and zero gradient.

    Parses the command-line flags, loads the checkpoint, mines
    same-(query_id, hop, relation) positive/negative pairs from the training
    split, scores both members of every pair with one shared synthetic
    ``z_prev``, and prints the score gaps, the HC3 loss and a gradient summary.

    Returns:
        0 whenever the run completes, including when no pair is found.
        An invalid ``--num-pairs`` exits through argparse instead.
    """
    import statistics

    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", default="runs/caff_orphanet/seed_42/best.pt")
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--cache-dir", default="cache")
    ap.add_argument("--num-pairs", type=int, default=20,
                    help="How many same-key (pos, neg) pairs to test.")
    args = ap.parse_args()
    if args.num_pairs < 1:
        ap.error("--num-pairs must be a positive integer")

    # Fall back to CPU only when a CUDA device was requested but is missing;
    # any other requested device is passed through untouched.
    device = args.device
    if device.startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"
    print(f"Device: {device}")

    cache_dir = Path(args.cache_dir)
    model, config, ablation, encoder, kg = load_checkpoint(
        args.checkpoint, device, cache_dir
    )
    # NOTE(review): train mode is kept from the original probe. If the model
    # has stochastic layers (e.g. dropout), two forward passes on identical
    # inputs may differ in this mode - confirm against the trainer.
    model.train()
    print(f"Loaded checkpoint: {args.checkpoint}")
    print(f" ablation: csv={ablation.use_csv} dbm={ablation.use_dbm} "
          f"hc3={ablation.use_hc3} dc={ablation.use_dc}")

    bfs = CachedBFSExtractor(kg, L=config.L, K_r=config.K_r,
                             cache_dir=cache_dir / "bfs")
    train_recs = load_qa_split(config.train_path)
    ds = CAFFTripleDataset(train_recs, bfs, require_gold=True)

    pairs = _mine_same_key_pairs(ds.instances, args.num_pairs)
    print(f"\nFound {len(pairs)} same-key (pos, neg) pairs in the SAME group.")
    if not pairs:
        print("No same-group same-key pairs found; cannot test Hypothesis B.")
        return 0

    dim = config.d
    hc3 = HC3Loss(margin=config.gamma_C)

    # One encoder call per query id; later pairs of the same query reuse it.
    q_cache: dict[str, torch.Tensor] = {}

    def get_q(qid, question):
        if qid not in q_cache:
            with torch.no_grad():
                emb = encoder.encode([question])[0]
            q_cache[qid] = emb.to(device)
        return q_cache[qid]

    abs_diffs = []
    pos_score_list = []
    neg_score_list = []
    for pos_src, neg_src, qid, hop, rel in pairs:
        # Both members of a pair get a copy of the same z_prev, standing in
        # for the single shared z_prev described in Hypothesis B.
        generator = torch.Generator().manual_seed(_stable_seed(qid, hop))
        z_shared = torch.randn(dim, generator=generator)

        pos_inst = TrainingInstance(
            query_id=qid, question=getattr(pos_src, "question", ""),
            head=getattr(pos_src, "head", ""), relation=rel,
            tail=getattr(pos_src, "tail", ""), hop=hop, label=1,
            z_prev=z_shared.clone(),
        )
        neg_inst = TrainingInstance(
            query_id=qid, question=getattr(neg_src, "question", ""),
            head=getattr(neg_src, "head", ""), relation=rel,
            tail=getattr(neg_src, "tail", ""), hop=hop, label=0,
            z_prev=z_shared.clone(),
        )
        q_emb = get_q(qid, pos_inst.question)
        s_pos = score_instance(model, pos_inst, q_emb, device)
        s_neg = score_instance(model, neg_inst, q_emb, device)
        abs_diffs.append((s_pos - s_neg).abs().item())
        pos_score_list.append(s_pos)
        neg_score_list.append(s_neg)

    l_hc3 = hc3(torch.stack(pos_score_list), torch.stack(neg_score_list))

    rule = "=" * 70
    print("\n" + rule)
    print("MODEL-LEVEL HC3 PROBE (Hypothesis B test)")
    print(rule)
    print(f" pairs tested : {len(abs_diffs)}")
    print(f" mean |s_pos - s_neg| : {statistics.mean(abs_diffs):.3e}")
    print(f" max |s_pos - s_neg| : {max(abs_diffs):.3e}")
    print(f" HC3 loss value : {l_hc3.item():.6f} (margin={config.gamma_C})")
    print(f" requires_grad : {l_hc3.requires_grad}")

    model.zero_grad()
    if l_hc3.requires_grad:
        l_hc3.backward()
        total, nz = _grad_summary(model)
        print(f" total |grad| sum : {total:.6e}")
        print(f" params with grad>0 : {nz}")
        print(rule)
        print()
        # Largest score gap still counted as "identical scores".
        tolerance = 1e-6
        if max(abs_diffs) < tolerance:
            print("VERDICT (Hypothesis B CONFIRMED):")
            print(" With a shared group z_prev, the positive and negative")
            print(" receive IDENTICAL scores (|s_pos - s_neg| ~ 0). The HC3")
            print(" loss is the constant margin and its gradient is zero.")
            print(" => HC3 cannot learn on same-group pairs. Since the miner")
            print(" only pairs same-(query,relation,hop) instances, which")
            print(" always live in the same group, HC3 is inert. This is")
            print(" why Full CAFF and NoHC3 have identical weights.")
        else:
            print("VERDICT: scores differ; Hypothesis B does NOT fully hold.")
            print(" Investigate whether negatives come from different groups.")
    else:
        print(" l_hc3.requires_grad is False -> gradient path broken upstream.")
    print()
    return 0
|
|
|
|
# Script entry point: run the probe and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
|