subnet32-llm-detector / scripts /holdout_pairs_jsonl.py
ThaoTran7's picture
incomplete commit
485127c
#!/usr/bin/env python3
"""Split a JSONL (with pair_index) into train_remainder + stress by pair_index mod."""
import argparse
import json
from collections import defaultdict
def _write_groups(path, groups):
    """Write every row of every group to *path*, one JSON object per line."""
    with open(path, "w", encoding="utf-8") as out:
        for group in groups:
            for row in group:
                out.write(json.dumps(row, ensure_ascii=False) + "\n")


def main(argv=None):
    """Split a JSONL file into train and stress files by ``pair_index % n``.

    Rows are grouped by their integer ``pair_index``. Groups whose
    ``pair_index % n == r`` are written to ``--stress_out``; all other groups
    are written to ``--train_out``. Groups are emitted in ascending
    ``pair_index`` order and rows inside a group keep their input order, so
    rows sharing a ``pair_index`` always land in the same file. Blank input
    lines are skipped.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Raises:
        SystemExit: With status 2 (via ``ArgumentParser.error``) when ``--n``
            is zero or ``--r`` is a remainder that ``pair_index % n`` can never
            produce.
    """
    p = argparse.ArgumentParser(
        # argparse only %-formats a description that contains "%(prog)", so a
        # doubled "%%" here would be printed literally in --help.
        description=(
            "Group rows by pair_index; pairs with pair_index % n == r "
            "go to stress, rest to train_out."
        )
    )
    p.add_argument("--input", required=True)
    p.add_argument("--train_out", required=True)
    p.add_argument("--stress_out", required=True)
    p.add_argument("--n", type=int, default=5, help="Modulus (default 5)")
    p.add_argument("--r", type=int, default=0, help="Remainder class sent to stress (default 0)")
    args = p.parse_args(argv)

    # Validate before touching any file: a zero modulus would otherwise crash
    # with ZeroDivisionError only after the whole input had been read.
    if args.n == 0:
        p.error("--n must be non-zero (it is used as a modulus)")
    # r is a reachable remainder exactly when reducing it modulo n leaves it
    # unchanged; an unreachable r would silently produce an empty stress file.
    if args.r % args.n != args.r:
        p.error(
            f"--r={args.r} can never equal pair_index % {args.n}; "
            "the stress split would always be empty"
        )

    by_pair = defaultdict(list)
    with open(args.input, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            row = json.loads(line)
            by_pair[int(row["pair_index"])].append(row)

    stress_pairs = []
    train_pairs = []
    for pair_index, rows in sorted(by_pair.items()):
        if pair_index % args.n == args.r:
            stress_pairs.append(rows)
        else:
            train_pairs.append(rows)

    _write_groups(args.train_out, train_pairs)
    _write_groups(args.stress_out, stress_pairs)
    print(
        f"Wrote train={args.train_out} (pairs={len(train_pairs)}) "
        f"stress={args.stress_out} (pairs={len(stress_pairs)})"
    )
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()