CAFF / scripts /annotate_triples.py
MrDhifallah's picture
Upload folder using huggingface_hub
634ebe8 verified
Raw
History Blame Contribute Delete
5.99 kB
#!/usr/bin/env python
"""
scripts/annotate_triples.py — Annotate gold-relevance labels for QA records.
Implements the labeling protocol described in paper §8.1:
"Gold relevance labels y_ℓ are assigned by shortest-path
reachability: triples on any shortest path from seed to gold
answer entity receive y=1; all others y=0. This yields 3.7M
labeled training instances across hop depths 1–3."
This script takes raw QA records (with seeds and gold answers) and
emits a JSON file ready to be consumed by `caff.data.load_qa_split`.
Input format
------------
A JSON list of records:
{
"query_id": "pubmedqa_train_001",
"question": "What drug targets ...?",
"seeds": ["C1234567", "C7654321"], # UMLS CUIs
"gold_answer": "C8888888", # UMLS CUI
"answer_label": "yes" | "no" | "maybe" | text # optional
}
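Output is the same list of records, each extended with two fields:
`n_candidates_per_hop` (candidate-set size at each BFS hop, after the
per-(head, relation) frequency cap) and `n_gold_positives` (number of
gold-positive triples found). BFS subgraphs are pre-computed through a
cached extractor (cache under --cache-dir). Records for which no
gold-positive triple is found (gold answer not reached within L hops)
are kept with `n_gold_positives == 0`; records with no seeds are skipped
and omitted from the output.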
Usage
-----
python scripts/annotate_triples.py \
--kg data/processed/merged_kg.tsv \
--qa-input data/raw/pubmedqa_with_seeds.json \
--qa-output data/processed/train.json \
--cache-dir cache/bfs/ \
--L 3 --K-r 20
"""
from __future__ import annotations
import argparse
import json
import logging
from pathlib import Path
from caff.data import (
CachedBFSExtractor,
KnowledgeGraph,
annotate_gold_relevance,
)
from caff.utils.logging import setup_logging
# Module-level logger named after this module; main() configures the
# handlers/level through setup_logging().
logger = logging.getLogger(name=__name__)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the annotation script.

    Args:
        argv: Arguments to parse, excluding the program name. ``None``
            (the default) parses ``sys.argv[1:]``, which is the original
            zero-argument behaviour; passing an explicit list lets tests
            and other scripts drive the parser directly.

    Returns:
        Namespace with attributes ``kg``, ``qa_input``, ``qa_output``,
        ``cache_dir``, ``L``, ``K_r`` and ``min_relation_freq``.
    """
    p = argparse.ArgumentParser(description="Annotate gold relevance.")
    p.add_argument("--kg", required=True, help="merged_kg.tsv")
    p.add_argument("--qa-input", required=True, help="Input QA JSON.")
    p.add_argument("--qa-output", required=True, help="Output annotated JSON.")
    p.add_argument("--cache-dir", default="cache/bfs",
                   help="BFS subgraph cache directory.")
    p.add_argument("--L", type=int, default=3, help="Max BFS depth.")
    # argparse derives dest "K_r" from "--K-r".
    p.add_argument("--K-r", type=int, default=20,
                   help="Frequency cap per (head, relation).")
    p.add_argument("--min-relation-freq", type=int, default=50,
                   help="Relation-frequency threshold forwarded to "
                        "KnowledgeGraph.from_tsv (min_relation_freq).")
    return p.parse_args(argv)
def main() -> None:
    """Annotate raw QA records and write them to ``--qa-output``.

    Steps:
      1. Load the merged KG from ``--kg`` (``--min-relation-freq`` is
         forwarded to ``KnowledgeGraph.from_tsv``).
      2. Load the raw QA records (a JSON list) from ``--qa-input``.
      3. For every record that has seeds, run the cached BFS extractor and
         store, on a shallow copy of the record, the per-hop candidate-set
         sizes (``n_candidates_per_hop``) and the number of gold-positive
         triples (``n_gold_positives``).
      4. Log summary statistics and write the annotated list as JSON.

    Records without seeds are skipped (and counted in the summary).
    Records for which no gold-positive triple is found are kept with
    ``n_gold_positives == 0``.

    Raises:
        ValueError: If the input JSON's top level is not a list.
    """
    args = parse_args()
    setup_logging(level="INFO")

    # ─── Load KG ─────────────────────────────────────────────────
    kg = KnowledgeGraph.from_tsv(
        args.kg, min_relation_freq=args.min_relation_freq
    )

    # ─── Load raw QA records ────────────────────────────────────
    qa_path = Path(args.qa_input)
    with qa_path.open("r", encoding="utf-8") as f:
        records = json.load(f)
    if not isinstance(records, list):
        # A top-level object would otherwise be iterated key-by-key and
        # fail later with an unhelpful TypeError on rec["query_id"].
        raise ValueError(
            f"{qa_path}: expected a JSON list of QA records, "
            f"got {type(records).__name__}"
        )
    n_records = len(records)
    logger.info(f"Loaded {n_records:,} QA records from {qa_path}")

    # ─── Pre-compute BFS subgraphs (with cache) ─────────────────
    bfs = CachedBFSExtractor(
        kg, L=args.L, K_r=args.K_r, cache_dir=args.cache_dir
    )

    # ─── Annotate each record ───────────────────────────────────
    n_reachable = 0
    n_skipped = 0
    n_total_pos = 0
    n_total_cand = 0
    annotated_records: list[dict] = []
    for i, rec in enumerate(records, start=1):
        qid = rec["query_id"]
        seeds = rec.get("seeds", [])
        gold = rec.get("gold_answer")
        if not seeds:
            logger.warning(f" [{qid}] no seeds — skipping")
            n_skipped += 1
        else:
            bfs_data = bfs.extract(qid, seeds, gold_answer=gold)
            # Candidate-set size per hop (post frequency cap); computed once
            # and reused for both the running total and the output record.
            cands_per_hop = [len(C) for C in bfs_data["candidate_sets_cap"]]
            # "gold_positives" may be None (presumably when no gold answer
            # is available) — treat that as zero positives.
            n_pos = len(bfs_data["gold_positives"] or set())
            n_total_cand += sum(cands_per_hop)
            n_total_pos += n_pos
            if n_pos > 0:
                n_reachable += 1
            out = dict(rec)
            out["n_candidates_per_hop"] = cands_per_hop
            out["n_gold_positives"] = n_pos
            annotated_records.append(out)

        # Report progress even when this record was skipped; averages are
        # taken over annotated (non-skipped) records only.
        if i % 200 == 0:
            n_done = max(1, len(annotated_records))
            logger.info(
                f" ... {i:,}/{n_records:,} "
                f"reachable={n_reachable:,} "
                f"avg_cands={n_total_cand / n_done:.0f} "
                f"avg_pos={n_total_pos / n_done:.1f}"
            )

    # ─── Stats ──────────────────────────────────────────────────
    n_annotated = len(annotated_records)
    logger.info("─" * 60)
    logger.info("Annotation summary:")
    logger.info(f" Records: {n_annotated:,}")
    logger.info(f" Skipped (no seeds): {n_skipped:,}")
    logger.info(f" With ≥1 gold-positive triple: {n_reachable:,} "
                f"({100 * n_reachable / max(1, n_annotated):.1f}%)")
    logger.info(f" Total candidate triples (post-FreqCap): {n_total_cand:,}")
    logger.info(f" Total gold-positive triples: {n_total_pos:,}")
    logger.info(" Paper §8.1 reports ≈3.7M labeled instances total")
    logger.info("─" * 60)

    # ─── Write output ───────────────────────────────────────────
    out_path = Path(args.qa_output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(annotated_records, f, indent=2)
    logger.info(f"Wrote {n_annotated:,} records to {out_path}")
if __name__ == "__main__":
    # Script entry point: run the annotation pipeline only when executed
    # directly, never on import.
    main()