# CAFF/scripts/probe_hc3_model.py
"""
scripts/probe_hc3_model.py - Model-level confirmation of why HC3 has
zero gradient (Hypothesis B).
Background
----------
The data probe (probe_hc3_keys.py) showed that 51.67% of positive
anchors DO have a same-(query_id, relation, hop) negative, so the
miner is NOT starved (Hypothesis A rejected).
Hypothesis B: within a single (query_id, hop) group, the trainer
assigns ONE shared teacher-forced z_prev to every instance. A
positive anchor and its same-key negative therefore receive the
SAME z_prev, the SAME query embedding q, and (because the key fixes
the relation) the SAME relation embedding E_r. The scorer is a
deterministic function of (W_ctx(z), v, q, E_r), so it returns an
IDENTICAL score for the positive and the negative. Then

    L_HC3 = relu(s_neg - s_pos + margin) = relu(0 + margin) = margin

is constant in the parameters, and its gradient is exactly zero.
This script confirms that empirically. It loads a trained checkpoint,
replicates the EXACT scoring path used by trainer._score_hc3_instance
(model.get_hop_W_ctx -> relation_cache.get_batch -> scorer.score_candidates),
and for several real same-key (pos, neg) pairs reports:

  * |s_pos - s_neg| (expected ~0 if Hypothesis B holds)
  * the HC3 loss value
  * the summed absolute gradient (total |grad|) after l_hc3.backward()

Run:
    python scripts/probe_hc3_model.py --checkpoint runs/caff_orphanet/seed_42/best.pt --device cuda
"""
from __future__ import annotations
import argparse
import sys
from collections import defaultdict
from pathlib import Path
import torch
# Put the directory two levels above this file on sys.path (once), so the
# project imports that follow resolve when the script is run directly.
ROOT = Path(__file__).parent.parent
if sys.path.count(str(ROOT)) == 0:
    sys.path.insert(0, str(ROOT))
from caff import (
CAFFTripleDataset,
HC3Loss,
TrainingInstance,
)
from caff.data import CachedBFSExtractor, KnowledgeGraph, load_qa_split # noqa
# Reuse the tested checkpoint loader from evaluate.py
from evaluate import load_checkpoint
def score_instance(model, inst, q_embedding, device):
    """Score a single instance with the scorer of its hop.

    Follows the call sequence
    model.get_hop_W_ctx -> relation_cache.get_batch -> scorer.score_candidates,
    which is intended to mirror trainer._score_hc3_instance.

    Args:
        model: object exposing get_hop_W_ctx, relation_cache, hop_scorers and v.
        inst: instance providing hop, z_prev and relation.
        q_embedding: query embedding passed through to the scorer unchanged.
        device: device the instance's z_prev is moved to before scoring.

    Returns:
        The scorer output with its leading dimension squeezed away.
    """
    # inst.hop is shifted down by one to index the per-hop model components.
    hop_index = inst.hop - 1

    # Add a leading batch axis of size 1 to z_prev; expected shape (1, d).
    z_batch = inst.z_prev.to(device).unsqueeze(0)

    # Drop the batch axis again; expected shape (d, d).
    context_matrix = model.get_hop_W_ctx(hop_index, z_batch).squeeze(0)

    # Single-relation lookup; expected shape (1, d).
    relation_emb = model.relation_cache.get_batch([inst.relation])

    raw_score = model.hop_scorers[hop_index].score_candidates(
        context_matrix, model.v, q_embedding, relation_emb
    )  # expected shape (1,)
    return raw_score.squeeze(0)
def _stable_seed(qid, hop):
    """Derive a 31-bit RNG seed from (qid, hop) that is identical across runs.

    The builtin hash() is salted per process for str keys, so it cannot back
    a seed that is supposed to be reproducible between invocations.
    """
    import zlib

    return zlib.crc32(f"{qid!r}|{hop!r}".encode("utf-8")) & 0x7FFFFFFF


def _collect_same_key_pairs(instances, limit):
    """Return at most `limit` (pos, neg, qid, hop, rel) same-key pairs.

    Instances are grouped by (query_id, hop); inside each group they are
    bucketed by relation, and the first positive / first negative of every
    relation that carries both labels forms one pair.
    """
    by_group = defaultdict(list)
    for inst in instances:
        by_group[(inst.query_id, inst.hop)].append(inst)

    pairs = []
    for (qid, hop), items in by_group.items():
        by_rel = defaultdict(lambda: {"pos": [], "neg": []})
        for item in items:
            by_rel[item.relation]["pos" if item.label == 1 else "neg"].append(item)
        for rel, bucket in by_rel.items():
            # Enforce the limit per pair so one large group cannot overshoot it.
            if len(pairs) >= limit:
                return pairs
            if bucket["pos"] and bucket["neg"]:
                pairs.append((bucket["pos"][0], bucket["neg"][0], qid, hop, rel))
    return pairs


def main() -> int:
    """Run the model-level HC3 probe and print a verdict on Hypothesis B.

    Loads a checkpoint, collects same-(query_id, relation, hop) pos/neg
    pairs from the training set, scores both members of each pair with one
    shared z_prev, and reports the score gap, the HC3 loss and the summed
    absolute gradient after backward().

    Returns:
        0 on every path, including when no same-key pair is found.
    """
    import statistics

    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", default="runs/caff_orphanet/seed_42/best.pt")
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--cache-dir", default="cache")
    ap.add_argument("--num-pairs", type=int, default=20,
                    help="How many same-key (pos, neg) pairs to test.")
    args = ap.parse_args()

    # Fall back to CPU only when a CUDA device was requested but is missing;
    # any other explicitly requested device is honoured as given.
    device = args.device
    if device.startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"
    print(f"Device: {device}")

    # 1. Load trained model + config + encoder + kg via the tested helper
    model, config, ablation, encoder, kg = load_checkpoint(
        args.checkpoint, device, Path(args.cache_dir)
    )
    # train() only switches the module mode (e.g. dropout / batch-norm
    # behaviour); autograd tracking does not depend on it.
    # NOTE(review): if the model contains dropout, pos/neg scores may differ
    # here for reasons unrelated to Hypothesis B - confirm the intended mode.
    model.train()
    print(f"Loaded checkpoint: {args.checkpoint}")
    print(f" ablation: csv={ablation.use_csv} dbm={ablation.use_dbm} "
          f"hc3={ablation.use_hc3} dc={ablation.use_dc}")

    # 2. Build train dataset to get real instances + grouping
    bfs = CachedBFSExtractor(kg, L=config.L, K_r=config.K_r,
                             cache_dir=Path(args.cache_dir) / "bfs")
    train_recs = load_qa_split(config.train_path)
    ds = CAFFTripleDataset(train_recs, bfs, require_gold=True)

    # 3. Find (query_id, relation, hop) keys that carry BOTH labels, grouped
    #    by (query_id, hop) so both members of a pair come from one group.
    pairs = _collect_same_key_pairs(ds.instances, args.num_pairs)
    print(f"\nFound {len(pairs)} same-key (pos, neg) pairs in the SAME group.")
    if not pairs:
        print("No same-group same-key pairs found; cannot test Hypothesis B.")
        return 0

    # 4. Score every pair with ONE shared z_prev per (query, hop) group.
    #    The concrete vector is synthetic; what matters for the test is that
    #    the positive and the negative receive the identical z.
    dim = config.d
    hc3 = HC3Loss(margin=config.gamma_C)
    abs_diffs = []

    # Query embeddings come from the encoder under no_grad, so q is a
    # constant input shared by the positive and the negative.
    q_cache: dict[str, torch.Tensor] = {}

    def get_q(qid, question):
        if qid not in q_cache:
            with torch.no_grad():
                emb = encoder.encode([question])[0]
            q_cache[qid] = emb.to(device)
        return q_cache[qid]

    pos_score_list = []
    neg_score_list = []
    for pos_src, neg_src, qid, hop, rel in pairs:
        # Reproducible, non-trivial z shared by both members of the pair.
        gen = torch.Generator().manual_seed(_stable_seed(qid, hop))
        z_shared = torch.randn(dim, generator=gen)

        pos_inst = TrainingInstance(
            query_id=qid, question=getattr(pos_src, "question", ""),
            head=getattr(pos_src, "head", ""), relation=rel,
            tail=getattr(pos_src, "tail", ""), hop=hop, label=1,
            z_prev=z_shared.clone(),
        )
        neg_inst = TrainingInstance(
            query_id=qid, question=getattr(neg_src, "question", ""),
            head=getattr(neg_src, "head", ""), relation=rel,
            tail=getattr(neg_src, "tail", ""), hop=hop, label=0,
            z_prev=z_shared.clone(),
        )

        q_emb = get_q(qid, pos_inst.question)
        s_pos = score_instance(model, pos_inst, q_emb, device)
        s_neg = score_instance(model, neg_inst, q_emb, device)
        abs_diffs.append((s_pos - s_neg).abs().item())
        pos_score_list.append(s_pos)
        neg_score_list.append(s_neg)

    pos_t = torch.stack(pos_score_list)
    neg_t = torch.stack(neg_score_list)
    l_hc3 = hc3(pos_t, neg_t)

    print("\n" + "=" * 70)
    print("MODEL-LEVEL HC3 PROBE (Hypothesis B test)")
    print("=" * 70)
    print(f" pairs tested : {len(abs_diffs)}")
    print(f" mean |s_pos - s_neg| : {statistics.mean(abs_diffs):.3e}")
    print(f" max |s_pos - s_neg| : {max(abs_diffs):.3e}")
    print(f" HC3 loss value : {l_hc3.item():.6f} (margin={config.gamma_C})")
    print(f" requires_grad : {l_hc3.requires_grad}")

    model.zero_grad()
    if l_hc3.requires_grad:
        l_hc3.backward()
        # Sum of absolute gradient entries over all parameters, plus the
        # number of parameters whose gradient is not identically zero.
        total = 0.0
        nz = 0
        for _, p in model.named_parameters():
            if p.grad is not None:
                g = p.grad.abs().sum().item()
                total += g
                if g > 0:
                    nz += 1
        print(f" total |grad| sum : {total:.6e}")
        print(f" params with grad>0 : {nz}")
        print("=" * 70)
        print()
        same = max(abs_diffs) < 1e-6
        if same:
            print("VERDICT (Hypothesis B CONFIRMED):")
            print(" With a shared group z_prev, the positive and negative")
            print(" receive IDENTICAL scores (|s_pos - s_neg| ~ 0). The HC3")
            print(" loss is the constant margin and its gradient is zero.")
            print(" => HC3 cannot learn on same-group pairs. Since the miner")
            print(" only pairs same-(query,relation,hop) instances, which")
            print(" always live in the same group, HC3 is inert. This is")
            print(" why Full CAFF and NoHC3 have identical weights.")
        else:
            print("VERDICT: scores differ; Hypothesis B does NOT fully hold.")
            print(" Investigate whether negatives come from different groups.")
    else:
        print(" l_hc3.requires_grad is False -> gradient path broken upstream.")
    print()
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())