#!/usr/bin/env python
"""
scripts/annotate_triples.py — Annotate gold-relevance labels for QA records.
Implements the labeling protocol described in paper §8.1:
"Gold relevance labels y are assigned by shortest-path
reachability: triples on any shortest path from seed to gold
answer entity receive y=1; all others y=0. This yields 3.7M
labeled training instances across hop depths 1–3."
This script takes raw QA records (with seeds and gold answers) and
emits a JSON file ready to be consumed by `caff.data.load_qa_split`.
Input format
------------
A JSON list of records:
{
    "query_id": "pubmedqa_train_001",
    "question": "What drug targets ...?",
    "seeds": ["C1234567", "C7654321"],             # UMLS CUIs
    "gold_answer": "C8888888",                     # UMLS CUI
    "answer_label": "yes" | "no" | "maybe" | text  # optional
}
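Output is the same list of records with two fields added to each:
"n_candidates_per_hop" (the size of each capped per-hop candidate set)
and "n_gold_positives" (the number of gold-positive triples found for
the record's gold answer). Records without seeds are skipped with a
warning and omitted from the output. BFS subgraphs are computed through
a caching extractor backed by the cache directory. Records whose gold
answer is not reached are kept with n_gold_positives = 0; no separate
flag field is written.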
Usage
-----
python scripts/annotate_triples.py \
--kg data/processed/merged_kg.tsv \
--qa-input data/raw/pubmedqa_with_seeds.json \
--qa-output data/processed/train.json \
--cache-dir cache/bfs/ \
--L 3 --K-r 20
"""
from __future__ import annotations
import argparse
import json
import logging
from pathlib import Path
from caff.data import (
CachedBFSExtractor,
KnowledgeGraph,
annotate_gold_relevance,
)
from caff.utils.logging import setup_logging
# Module-level logger; main() configures logging via setup_logging(level="INFO").
logger = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description="Annotate gold relevance.")
p.add_argument("--kg", required=True, help="merged_kg.tsv")
p.add_argument("--qa-input", required=True, help="Input QA JSON.")
p.add_argument("--qa-output", required=True, help="Output annotated JSON.")
p.add_argument("--cache-dir", default="cache/bfs",
help="BFS subgraph cache directory.")
p.add_argument("--L", type=int, default=3, help="Max BFS depth.")
p.add_argument("--K-r", type=int, default=20,
help="Frequency cap per (head, relation).")
p.add_argument("--min-relation-freq", type=int, default=50)
return p.parse_args()
def main() -> None:
    """Annotate raw QA records and write them, with BFS statistics, as JSON.

    Steps: parse CLI arguments, load the merged KG, load the raw QA
    records, run the caching BFS extractor for every record that has
    seeds, record per-hop candidate counts and the number of gold-positive
    triples, log summary statistics, and write the annotated records.

    Records without seeds are skipped (with a warning) and omitted from
    the output. Every other record is copied and extended with:

    * ``n_candidates_per_hop`` -- sizes of the capped per-hop candidate
      sets returned by the extractor;
    * ``n_gold_positives`` -- number of gold-positive triples (0 when the
      extractor reports none).

    Raises:
        ValueError: If the ``--qa-input`` file does not hold a JSON list.
    """
    args = parse_args()
    setup_logging(level="INFO")

    # --- Load KG -------------------------------------------------------
    kg = KnowledgeGraph.from_tsv(
        args.kg, min_relation_freq=args.min_relation_freq
    )

    # --- Load raw QA records -------------------------------------------
    qa_path = Path(args.qa_input)
    with qa_path.open("r", encoding="utf-8") as f:
        records = json.load(f)
    # Iterating a dict would walk its keys and fail later with a confusing
    # error, so reject anything that is not a list up front.
    if not isinstance(records, list):
        raise ValueError(
            f"{qa_path}: expected a JSON list of QA records, "
            f"got {type(records).__name__}"
        )
    logger.info(f"Loaded {len(records):,} QA records from {qa_path}")

    # --- BFS extractor (subgraphs cached under --cache-dir) ------------
    bfs = CachedBFSExtractor(
        kg, L=args.L, K_r=args.K_r, cache_dir=args.cache_dir
    )

    # --- Annotate each record ------------------------------------------
    n_skipped = 0
    n_reachable = 0
    n_total_pos = 0
    n_total_cand = 0
    annotated_records: list[dict] = []
    for i, rec in enumerate(records, start=1):
        qid = rec["query_id"]
        seeds = rec.get("seeds", [])
        gold = rec.get("gold_answer")
        if not seeds:
            logger.warning(f" [{qid}] no seeds — skipping")
            n_skipped += 1
        else:
            bfs_data = bfs.extract(qid, seeds, gold_answer=gold)
            hop_sizes = [len(C) for C in bfs_data["candidate_sets_cap"]]
            # gold_positives may be None (presumably when there is no gold
            # answer); count that as zero positives.
            n_pos = len(bfs_data["gold_positives"] or set())
            n_total_cand += sum(hop_sizes)
            n_total_pos += n_pos
            if n_pos > 0:
                n_reachable += 1
            out = dict(rec)
            out["n_candidates_per_hop"] = hop_sizes
            out["n_gold_positives"] = n_pos
            annotated_records.append(out)

        # Progress cadence is over *input* records, so a skipped record
        # never suppresses a report. Averages are over annotated records only.
        if i % 200 == 0:
            n_done = max(1, len(annotated_records))
            logger.info(
                f" ... {i:,}/{len(records):,} "
                f"reachable={n_reachable:,} "
                f"avg_cands={n_total_cand / n_done:.0f} "
                f"avg_pos={n_total_pos / n_done:.1f}"
            )

    # --- Stats ---------------------------------------------------------
    n_out = len(annotated_records)
    logger.info("─" * 60)
    logger.info("Annotation summary:")
    logger.info(f" Records: {n_out:,}")
    logger.info(f" Skipped (no seeds): {n_skipped:,}")
    logger.info(f" With ≥1 gold-positive triple: {n_reachable:,} "
                f"({100 * n_reachable / max(1, n_out):.1f}%)")
    logger.info(f" Total candidate triples (post-FreqCap): {n_total_cand:,}")
    logger.info(f" Total gold-positive triples: {n_total_pos:,}")
    logger.info(" Paper §8.1 reports ≈3.7M labeled instances total")
    logger.info("─" * 60)

    # --- Write output --------------------------------------------------
    out_path = Path(args.qa_output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(annotated_records, f, indent=2)
    logger.info(f"Wrote {n_out:,} records to {out_path}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()