| |
| """Split a JSONL (with pair_index) into train_remainder + stress by pair_index mod.""" |
|
|
| import argparse |
| import json |
| from collections import defaultdict |
|
|
|
|
def _parse_args(argv):
    """Parse and validate CLI arguments; exits via ``parser.error`` on bad values."""
    p = argparse.ArgumentParser(
        # Single '%': argparse only %-formats a description that contains
        # '%(prog)', so a doubled '%%' would be shown literally in --help.
        description="Group rows by pair_index; pairs with pair_index % n == r go to stress, rest to train_out."
    )
    p.add_argument("--input", required=True)
    p.add_argument("--train_out", required=True)
    p.add_argument("--stress_out", required=True)
    p.add_argument("--n", type=int, default=5, help="Modulus (default 5)")
    p.add_argument("--r", type=int, default=0, help="Remainder class sent to stress (default 0)")
    args = p.parse_args(argv)

    # n == 0 would raise ZeroDivisionError on the first row; negative n is
    # not a meaningful modulus for this split.
    if args.n <= 0:
        p.error(f"--n must be a positive integer (got {args.n})")
    # With n > 0, pair_index % n is always in [0, n); any other r would
    # silently produce an empty stress set.
    if not 0 <= args.r < args.n:
        p.error(f"--r must satisfy 0 <= r < n (got r={args.r}, n={args.n})")
    return args


def _group_by_pair(path):
    """Read the JSONL file at *path* and return ``{pair_index: [rows in input order]}``."""
    by_pair = defaultdict(list)
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing newline
            row = json.loads(line)
            by_pair[int(row["pair_index"])].append(row)
    return by_pair


def _write_jsonl(path, groups):
    """Write every row of every group in *groups* to *path*, one JSON object per line."""
    with open(path, "w", encoding="utf-8") as out:
        for grp in groups:
            for row in grp:
                out.write(json.dumps(row, ensure_ascii=False) + "\n")


def main(argv=None):
    """Split a JSONL file into train and stress sets by ``pair_index`` modulo ``n``.

    Every non-blank input line must be a JSON object with a ``pair_index``
    field that ``int()`` accepts. Rows are grouped by ``pair_index``. A whole
    group goes to ``--stress_out`` when ``pair_index % n == r`` and to
    ``--train_out`` otherwise, so rows sharing a pair are never split across
    the two outputs. Groups are written in ascending ``pair_index`` order, and
    rows within a group keep their input order.

    Args:
        argv: Argument list to parse. ``None`` (the default) means
            ``sys.argv[1:]``, so a plain ``main()`` call behaves as before.

    Raises:
        SystemExit: from argparse on invalid arguments, including
            ``--n <= 0`` or ``--r`` outside ``[0, n)``.
        json.JSONDecodeError / KeyError / ValueError: on a malformed input
            line, a missing ``pair_index``, or a non-integer ``pair_index``.
    """
    args = _parse_args(argv)
    by_pair = _group_by_pair(args.input)

    stress_pairs = []
    train_pairs = []
    for pi, rows in sorted(by_pair.items()):
        (stress_pairs if pi % args.n == args.r else train_pairs).append(rows)

    _write_jsonl(args.train_out, train_pairs)
    _write_jsonl(args.stress_out, stress_pairs)
    print(
        f"Wrote train={args.train_out} (pairs={len(train_pairs)}) "
        f"stress={args.stress_out} (pairs={len(stress_pairs)})"
    )
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script. main() returns None,
    # so SystemExit(None) ends the process with status 0, as before.
    raise SystemExit(main())
|
|