| |
| """Merge cleaned training data with converted APTNER data for R7. |
| |
| Produces: |
| - data/processed/r7_5class_train.jsonl (existing train + APTNER train new) |
| - data/processed/r7_5class_valid.jsonl (existing valid + APTNER dev new) |
| """ |
| import json |
| from collections import defaultdict |
| from pathlib import Path |
|
|
|
|
def load_jsonl(path: Path):
    """Read a JSON Lines file into a list of decoded records.

    Args:
        path: Location of the file to read; each non-blank line must hold
            one JSON document.

    Returns:
        A list with one decoded JSON value per non-blank line, in file order.
        Lines that are empty or whitespace-only are skipped.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode explicitly as UTF-8 rather than the locale default, so records
    # containing non-ASCII text load identically on every platform.
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
|
|
|
|
def count_entities(records):
    """Tally entity annotations across a collection of records.

    Each record must have a ``"spans"`` mapping. The label of an entry is
    the part of its key before the first ``":"``, and the entry contributes
    ``len(positions)`` entities to that label.

    Args:
        records: Iterable of dict-like records with a ``"spans"`` mapping.

    Returns:
        A ``(total, per_label)`` tuple: the overall entity count and a plain
        dict mapping each label to its entity count.
    """
    per_label = {}
    for record in records:
        for span_key, positions in record["spans"].items():
            label, _, _ = span_key.partition(":")
            per_label[label] = per_label.get(label, 0) + len(positions)
    # Every counted position belongs to exactly one label, so the grand
    # total is the sum of the per-label counts.
    return sum(per_label.values()), per_label
|
|
|
|
def main(base: Path = Path("/home/ubuntu/alkyline/data/processed")):
    """Merge the cleaned datasets with the converted APTNER data and report.

    Reads four JSONL files from ``base``, concatenates existing train with
    APTNER train and existing valid with APTNER dev, writes the merged sets
    to ``r7_5class_train.jsonl`` and ``r7_5class_valid.jsonl`` in the same
    directory, and prints sentence/entity statistics for each split.

    Args:
        base: Directory holding the input files and receiving the outputs.
            Defaults to the original hard-coded processed-data directory.
    """
    print("Loading existing cleaned data...")
    train_existing = load_jsonl(base / "enriched_5class_train_cleaned.jsonl")
    valid_existing = load_jsonl(base / "enriched_5class_valid_cleaned.jsonl")

    print("Loading converted APTNER data...")
    aptner_train = load_jsonl(base / "aptner_5class_train.jsonl")
    aptner_dev = load_jsonl(base / "aptner_5class_dev.jsonl")

    # Plain concatenation: existing records first, APTNER records appended.
    train_merged = train_existing + aptner_train
    valid_merged = valid_existing + aptner_dev

    train_out = base / "r7_5class_train.jsonl"
    valid_out = base / "r7_5class_valid.jsonl"

    for path, records in [(train_out, train_merged), (valid_out, valid_merged)]:
        # ensure_ascii=False emits raw non-ASCII characters, so the file must
        # be written as UTF-8 explicitly instead of the locale default.
        with open(path, "w", encoding="utf-8") as f:
            for r in records:
                f.write(json.dumps(r, ensure_ascii=False) + "\n")

    print("\n=== R7 Dataset Statistics ===\n")

    for name, existing, new, merged in [
        ("Train", train_existing, aptner_train, train_merged),
        ("Valid", valid_existing, aptner_dev, valid_merged),
    ]:
        e_total, e_by_class = count_entities(existing)
        n_total, n_by_class = count_entities(new)
        m_total, m_by_class = count_entities(merged)

        print(f"--- {name} ---")
        print(f" Existing: {len(existing):>6} sentences, {e_total:>6} entities")
        print(f" + APTNER: {len(new):>6} sentences, {n_total:>6} entities")
        print(f" = Merged: {len(merged):>6} sentences, {m_total:>6} entities")
        print("\n Per-class breakdown (merged):")
        # Union of labels seen in either source, in alphabetical order.
        all_labels = sorted(set(e_by_class) | set(n_by_class))
        for label in all_labels:
            e = e_by_class.get(label, 0)
            n = n_by_class.get(label, 0)
            m = m_by_class.get(label, 0)
            print(f" {label:<16} {e:>6} existing + {n:>5} new = {m:>6}")
        print()

    print(f"Written: {train_out}")
    print(f"Written: {valid_out}")
|
|
|
|
# Script entry point: run the merge only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|