#!/usr/bin/env python3
"""Merge cleaned training data with converted APTNER data for R7.

Produces:
  - data/processed/r7_5class_train.jsonl  (existing train + new APTNER train)
  - data/processed/r7_5class_valid.jsonl  (existing valid + new APTNER dev)
"""

import json
from collections import defaultdict
from pathlib import Path
from typing import Dict, Iterable, List, Tuple

# Directory holding both the input JSONL files and the merged outputs.
DEFAULT_BASE = Path("/home/ubuntu/alkyline/data/processed")


def load_jsonl(path: Path) -> List[dict]:
    """Read a JSON Lines file and return its records in file order.

    Whitespace-only lines are skipped. The file is decoded as UTF-8
    explicitly rather than with the locale default, because this pipeline
    writes JSONL with ``ensure_ascii=False`` (raw non-ASCII characters).

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON; the
            message is prefixed with ``"<path>:<line number>:"``.
    """
    records = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Keep the exception type, but say which file/line is broken.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
    return records


def count_entities(records: Iterable[dict]) -> Tuple[int, Dict[str, int]]:
    """Count entity mentions in ``records``, overall and per label.

    Each record must have a ``"spans"`` mapping. The label of a key is the
    text before its first ``":"``. Every element of the key's value counts as
    one mention.

    Returns:
        ``(total, by_label)``, where ``by_label`` maps label -> mention count.
        A label whose keys only map to empty collections is present with 0.
    """
    counts = defaultdict(int)
    total = 0
    for record in records:
        for key, positions in record["spans"].items():
            label = key.split(":", 1)[0]
            n = len(positions)
            counts[label] += n
            total += n
    return total, dict(counts)


def _write_jsonl(path: Path, records: Iterable[dict]) -> None:
    """Write ``records`` to ``path`` as UTF-8 JSON Lines (non-ASCII kept raw)."""
    # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters, which
    # a non-UTF-8 locale default encoding could fail to encode.
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def _print_split_stats(
    name: str, existing: List[dict], new: List[dict], merged: List[dict]
) -> None:
    """Print sentence/entity totals for one split plus a per-label breakdown."""
    e_total, e_by_class = count_entities(existing)
    n_total, n_by_class = count_entities(new)
    m_total, m_by_class = count_entities(merged)

    print(f"--- {name} ---")
    print(f" Existing: {len(existing):>6} sentences, {e_total:>6} entities")
    print(f" + APTNER: {len(new):>6} sentences, {n_total:>6} entities")
    print(f" = Merged: {len(merged):>6} sentences, {m_total:>6} entities")
    print("\n Per-class breakdown (merged):")
    for label in sorted(set(e_by_class) | set(n_by_class)):
        e = e_by_class.get(label, 0)
        n = n_by_class.get(label, 0)
        m = m_by_class.get(label, 0)
        print(f" {label:<16} {e:>6} existing + {n:>5} new = {m:>6}")
    print()


def main(base: Path = DEFAULT_BASE) -> None:
    """Build the R7 train/valid JSONL files and print dataset statistics.

    Reads the cleaned existing splits and the converted APTNER splits from
    ``base``. Writes ``r7_5class_train.jsonl`` (existing train + APTNER
    train) and ``r7_5class_valid.jsonl`` (existing valid + APTNER dev) into
    ``base``, with existing records first. Any existing output file is
    overwritten.

    Args:
        base: directory containing the inputs and receiving the outputs.
    """
    base = Path(base)

    # Load sources
    print("Loading existing cleaned data...")
    train_existing = load_jsonl(base / "enriched_5class_train_cleaned.jsonl")
    valid_existing = load_jsonl(base / "enriched_5class_valid_cleaned.jsonl")
    print("Loading converted APTNER data...")
    aptner_train = load_jsonl(base / "aptner_5class_train.jsonl")
    aptner_dev = load_jsonl(base / "aptner_5class_dev.jsonl")

    # Merge (plain concatenation; no de-duplication is performed)
    train_merged = train_existing + aptner_train
    valid_merged = valid_existing + aptner_dev

    # Write
    train_out = base / "r7_5class_train.jsonl"
    valid_out = base / "r7_5class_valid.jsonl"
    _write_jsonl(train_out, train_merged)
    _write_jsonl(valid_out, valid_merged)

    # Statistics
    print("\n=== R7 Dataset Statistics ===\n")
    _print_split_stats("Train", train_existing, aptner_train, train_merged)
    _print_split_stats("Valid", valid_existing, aptner_dev, valid_merged)

    print(f"Written: {train_out}")
    print(f"Written: {valid_out}")


if __name__ == "__main__":
    main()