arcspan / scripts /merge_r7_data.py
chairulridjal's picture
Add files using upload-large-folder tool
3dac39e verified
#!/usr/bin/env python3
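"""Merge cleaned training data with converted APTNER data for R7.
Reads (from the processed-data directory):
- enriched_5class_{train,valid}_cleaned.jsonl (existing cleaned data)
- aptner_5class_{train,dev}.jsonl (converted APTNER data)
Produces:
- data/processed/r7_5class_train.jsonl (existing train + APTNER train)
- data/processed/r7_5class_valid.jsonl (existing valid + APTNER dev)
"""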
import json
from collections import defaultdict
from pathlib import Path
def load_jsonl(path: Path):
    """Read a JSON Lines file and return its records as a list.

    Blank (whitespace-only) lines are skipped; every other line must hold one
    complete JSON document.

    Args:
        path: File to read. It is decoded as UTF-8 explicitly: JSONL written
            with ``ensure_ascii=False`` (as this script does) contains raw
            non-ASCII characters, and the platform default encoding is not
            UTF-8 everywhere.

    Returns:
        One decoded object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON. The
            message is prefixed with ``<path>:<line number>:``.
    """
    records = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Keep the same exception type for callers, but report which
                # file and line broke; the original position is per-line only.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
    return records
def count_entities(records):
    """Tally entity spans across *records*, overall and per label.

    Each record must have a ``"spans"`` mapping. The label of a span key is
    the text before its first ``":"`` (the whole key if it has none), and the
    length of the key's value is added to that label's count.

    Returns:
        ``(total, per_label)``: the summed length over every span value, and
        a dict mapping each label to its own summed length.
    """
    per_label = {}
    for record in records:
        for span_key, occurrences in record["spans"].items():
            label, _sep, _rest = span_key.partition(":")
            per_label[label] = per_label.get(label, 0) + len(occurrences)
    # The grand total is just the sum of the per-label tallies.
    return sum(per_label.values()), per_label
def _write_jsonl(path, records):
    """Write *records* to *path* as JSON Lines, one object per line.

    The file is encoded as UTF-8 explicitly. ``ensure_ascii=False`` emits raw
    non-ASCII characters, which fail with ``UnicodeEncodeError`` (or land in
    a different encoding) when the platform default is not UTF-8.
    """
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in records)


def _print_split_stats(name, existing, new, merged):
    """Print sentence/entity totals and a per-class breakdown for one split.

    Args:
        name: Split label used in the section header (e.g. ``"Train"``).
        existing: Records from the existing cleaned data.
        new: Records from the converted APTNER data.
        merged: ``existing + new``.
    """
    e_total, e_by_class = count_entities(existing)
    n_total, n_by_class = count_entities(new)
    m_total, m_by_class = count_entities(merged)
    print(f"--- {name} ---")
    print(f" Existing: {len(existing):>6} sentences, {e_total:>6} entities")
    print(f" + APTNER: {len(new):>6} sentences, {n_total:>6} entities")
    print(f" = Merged: {len(merged):>6} sentences, {m_total:>6} entities")
    print("\n Per-class breakdown (merged):")
    # merged is existing + new, so the union of the two sources' labels is
    # the complete label set.
    for label in sorted(set(e_by_class) | set(n_by_class)):
        e = e_by_class.get(label, 0)
        n = n_by_class.get(label, 0)
        m = m_by_class.get(label, 0)
        print(f" {label:<16} {e:>6} existing + {n:>5} new = {m:>6}")
    print()


def main(base_dir=None):
    """Build the R7 train/valid JSONL files and print dataset statistics.

    Concatenates the existing cleaned 5-class splits with the converted
    APTNER splits (existing records first, APTNER records appended). It then
    writes the merged splits, prints sentence/entity counts and a per-class
    breakdown for each split, and finally reports the written paths.

    Args:
        base_dir: Directory that holds the input files and receives the
            outputs. Defaults to ``/home/ubuntu/alkyline/data/processed``,
            the previously hard-coded location.
    """
    if base_dir is None:
        base_dir = "/home/ubuntu/alkyline/data/processed"
    base = Path(base_dir)

    # Load sources
    print("Loading existing cleaned data...")
    train_existing = load_jsonl(base / "enriched_5class_train_cleaned.jsonl")
    valid_existing = load_jsonl(base / "enriched_5class_valid_cleaned.jsonl")
    print("Loading converted APTNER data...")
    aptner_train = load_jsonl(base / "aptner_5class_train.jsonl")
    aptner_dev = load_jsonl(base / "aptner_5class_dev.jsonl")

    # Merge: APTNER train extends train, APTNER dev extends valid.
    train_merged = train_existing + aptner_train
    valid_merged = valid_existing + aptner_dev

    # Write
    train_out = base / "r7_5class_train.jsonl"
    valid_out = base / "r7_5class_valid.jsonl"
    _write_jsonl(train_out, train_merged)
    _write_jsonl(valid_out, valid_merged)

    # Statistics
    print("\n=== R7 Dataset Statistics ===\n")
    _print_split_stats("Train", train_existing, aptner_train, train_merged)
    _print_split_stats("Valid", valid_existing, aptner_dev, valid_merged)

    print(f"Written: {train_out}")
    print(f"Written: {valid_out}")


if __name__ == "__main__":
    main()