|
|
| """
|
| scripts/annotate_triples.py — Annotate gold-relevance labels for QA records.
|
|
|
| Implements the labeling protocol described in paper §8.1:
|
|
|
| "Gold relevance labels y_ℓ are assigned by shortest-path
|
| reachability: triples on any shortest path from seed to gold
|
| answer entity receive y=1; all others y=0. This yields 3.7M
|
| labeled training instances across hop depths 1–3."
|
|
|
| This script takes raw QA records (with seeds and gold answers) and
|
| emits a JSON file ready to be consumed by `caff.data.load_qa_split`.
|
|
|
| Input format
|
| ------------
|
| A JSON list of records:
|
| {
|
| "query_id": "pubmedqa_train_001",
|
| "question": "What drug targets ...?",
|
| "seeds": ["C1234567", "C7654321"], # UMLS CUIs
|
| "gold_answer": "C8888888", # UMLS CUI
|
| "answer_label": "yes" | "no" | "maybe" | text # optional
|
| }
|
|
|
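| Output uses the same format, with two fields added to each record:
|
| `n_candidates_per_hop` and `n_gold_positives`. In addition, this
|
| script pre-computes BFS subgraphs (writing them to the cache
|
| directory) and verifies that the gold-answer entity is reachable
|
| within L hops for at least SOME seeds. Records with no reachable
|
| gold answer are flagged (`n_gold_positives` is 0); records without
|
| seeds are skipped and do not appear in the output.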
|
|
|
| Usage
|
| -----
|
| python scripts/annotate_triples.py \
|
| --kg data/processed/merged_kg.tsv \
|
| --qa-input data/raw/pubmedqa_with_seeds.json \
|
| --qa-output data/processed/train.json \
|
| --cache-dir cache/bfs/ \
|
| --L 3 --K-r 20
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import argparse
|
| import json
|
| import logging
|
| from pathlib import Path
|
|
|
| from caff.data import (
|
| CachedBFSExtractor,
|
| KnowledgeGraph,
|
| annotate_gold_relevance,
|
| )
|
| from caff.utils.logging import setup_logging
|
|
|
| logger = logging.getLogger(__name__)
|
|
|
|
|
| def parse_args() -> argparse.Namespace:
|
| p = argparse.ArgumentParser(description="Annotate gold relevance.")
|
| p.add_argument("--kg", required=True, help="merged_kg.tsv")
|
| p.add_argument("--qa-input", required=True, help="Input QA JSON.")
|
| p.add_argument("--qa-output", required=True, help="Output annotated JSON.")
|
| p.add_argument("--cache-dir", default="cache/bfs",
|
| help="BFS subgraph cache directory.")
|
| p.add_argument("--L", type=int, default=3, help="Max BFS depth.")
|
| p.add_argument("--K-r", type=int, default=20,
|
| help="Frequency cap per (head, relation).")
|
| p.add_argument("--min-relation-freq", type=int, default=50)
|
| return p.parse_args()
|
|
|
|
|
def main() -> None:
    """Annotate QA records with BFS candidate and gold-positive counts.

    Reads the options from the command line, loads the knowledge graph and
    the QA records, runs the cached BFS extractor once per record, and
    writes the records back out as JSON with two added fields:

    * ``n_candidates_per_hop`` -- size of each entry of the extractor's
      ``candidate_sets_cap`` result.
    * ``n_gold_positives`` -- size of the extractor's ``gold_positives``
      result (0 when that result is empty or ``None``).

    Records whose ``seeds`` list is missing or empty are skipped with a
    warning and are not written to the output.

    Raises:
        KeyError: If a record has no ``query_id``.
    """
    args = parse_args()
    setup_logging(level="INFO")

    kg = KnowledgeGraph.from_tsv(
        args.kg, min_relation_freq=args.min_relation_freq
    )

    qa_path = Path(args.qa_input)
    with qa_path.open("r", encoding="utf-8") as f:
        records = json.load(f)
    logger.info(f"Loaded {len(records):,} QA records from {qa_path}")

    bfs = CachedBFSExtractor(
        kg, L=args.L, K_r=args.K_r, cache_dir=args.cache_dir
    )

    n_reachable = 0
    n_skipped = 0
    n_total_pos = 0
    n_total_cand = 0
    annotated_records: list[dict] = []

    for i, rec in enumerate(records):
        qid = rec["query_id"]
        seeds = rec.get("seeds", [])
        if not seeds:
            n_skipped += 1
            logger.warning(f" [{qid}] no seeds — skipping")
            continue

        bfs_data = bfs.extract(qid, seeds, gold_answer=rec.get("gold_answer"))
        per_hop = [len(C) for C in bfs_data["candidate_sets_cap"]]
        # ``gold_positives`` may be None; treat that as "no positives".
        n_pos = len(bfs_data["gold_positives"] or ())
        n_total_cand += sum(per_hop)
        n_total_pos += n_pos
        if n_pos > 0:
            n_reachable += 1

        # Shallow copy so the input record is left untouched.
        out = dict(rec)
        out["n_candidates_per_hop"] = per_hop
        out["n_gold_positives"] = n_pos
        annotated_records.append(out)

        if (i + 1) % 200 == 0:
            # Average over records actually annotated, not over ``i + 1``:
            # skipped records contribute nothing to the totals, and this
            # keeps the running figures consistent with the final summary.
            n_done = len(annotated_records)
            logger.info(
                f" ... {i + 1:,}/{len(records):,} "
                f"reachable={n_reachable:,} "
                f"avg_cands={n_total_cand / n_done:.0f} "
                f"avg_pos={n_total_pos / n_done:.1f}"
            )

    logger.info("─" * 60)
    logger.info("Annotation summary:")
    logger.info(f" Records: {len(annotated_records):,}")
    logger.info(f" Skipped (no seeds): {n_skipped:,}")
    logger.info(f" With ≥1 gold-positive triple: {n_reachable:,} "
                f"({100 * n_reachable / max(1, len(annotated_records)):.1f}%)")
    logger.info(f" Total candidate triples (post-FreqCap): {n_total_cand:,}")
    logger.info(f" Total gold-positive triples: {n_total_pos:,}")
    logger.info(" Paper §8.1 reports ≈3.7M labeled instances total")
    logger.info("─" * 60)

    out_path = Path(args.qa_output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(annotated_records, f, indent=2)
    logger.info(f"Wrote {len(annotated_records):,} records to {out_path}")
|
|
|
|
|
# Run the annotation pipeline only when executed as a script, so that
# importing this module has no side effects.
if __name__ == "__main__":
    main()