"""Split a line-oriented file of tokenized sequences into train and test files."""

import os
import random
import argparse
from pathlib import Path

try:
    from tqdm import tqdm
except ImportError:  # tqdm only draws a progress bar; the split itself does not need it
    def tqdm(iterable, **_kwargs):
        """Fallback used when tqdm is not installed: return *iterable* unchanged."""
        return iterable


def _count_lines(path):
    """Return the number of lines in *path* (lines are split on ``b"\\n"``)."""
    with open(path, "rb") as f:
        return sum(1 for _ in f)


def _pick_test_indices(total, test_count, seed):
    """Return the set of 0-based line indices that belong to the test split.

    A private ``random.Random(seed)`` is used instead of reseeding the global
    generator. It yields exactly the same shuffle as ``random.seed(seed)``
    followed by ``random.shuffle``, so splits produced by earlier versions of
    this script are reproduced unchanged.
    """
    indices = list(range(total))
    random.Random(seed).shuffle(indices)
    return set(indices[:test_count])


def _report_size(label, path):
    """Print the size of *path* in MiB, prefixed with *label*."""
    size_mb = os.path.getsize(path) / (1024 * 1024)
    print(f"{label} file size: {size_mb:.2f} MB")


def split_data(input_file, train_file, test_file, test_ratio=0.1, seed=42):
    """
    Split tokenized data into training and test sets.

    Each line of *input_file* is treated as one sequence. Lines are copied
    byte-for-byte (binary mode), so encodings and line endings are preserved.
    Each output file keeps the original relative order of its lines.

    Args:
        input_file (str): Path to the input file containing tokenized sequences
        train_file (str): Path to write training sequences
        test_file (str): Path to write test sequences
        test_ratio (float): Proportion of data to use for testing (default: 0.1).
            The test size is ``int(total * test_ratio)`` (rounded down).
        seed (int): Random seed for reproducibility (default: 42)

    Raises:
        ValueError: if ``test_ratio`` is outside ``[0, 1]``, or if any two of
            the input/train/test paths refer to the same file.
    """
    if not 0.0 <= test_ratio <= 1.0:
        raise ValueError(f"test_ratio must be between 0 and 1, got {test_ratio!r}")

    # Opening an output with "wb" truncates it immediately. If it were the
    # input file, the data would be destroyed before it was read.
    paths = [Path(p).resolve() for p in (input_file, train_file, test_file)]
    if len(set(paths)) != len(paths):
        raise ValueError("input, train and test paths must all refer to different files")

    # Count number of sequences in the file
    print("Counting sequences in the file...")
    total_sequences = _count_lines(input_file)
    print(f"Total sequences found: {total_sequences}")

    # Determine how many sequences to put in test set
    test_count = int(total_sequences * test_ratio)
    train_count = total_sequences - test_count
    test_indices = _pick_test_indices(total_sequences, test_count, seed)

    print(f"Splitting data: {train_count} training sequences, {test_count} test sequences")

    # Make sure the output directories exist (the input directory may differ).
    for out_path in paths[1:]:
        out_path.parent.mkdir(parents=True, exist_ok=True)

    # Split the data in a single streaming pass over the input.
    with open(input_file, "rb") as infile, \
         open(train_file, "wb") as train_out, \
         open(test_file, "wb") as test_out:
        for i, line in tqdm(enumerate(infile), total=total_sequences, desc="Splitting data"):
            (test_out if i in test_indices else train_out).write(line)

    print(f"Done! Training data saved to {train_file}, test data saved to {test_file}")

    # Print file sizes
    _report_size("Training", train_file)
    _report_size("Test", test_file)


def main(argv=None):
    """Parse command-line arguments (``sys.argv`` when *argv* is None) and run the split."""
    parser = argparse.ArgumentParser(description="Split tokenized data into training and test sets")
    parser.add_argument("--input", type=str, default="./data/output.txt", help="Input file path")
    parser.add_argument("--train", type=str, default="./data/train.txt", help="Output path for training data")
    parser.add_argument("--test", type=str, default="./data/test.txt", help="Output path for test data")
    parser.add_argument("--test-ratio", type=float, default=0.1, help="Proportion of data to use for testing")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    args = parser.parse_args(argv)

    try:
        split_data(args.input, args.train, args.test, args.test_ratio, args.seed)
    except ValueError as exc:
        # Report bad arguments as a normal CLI usage error instead of a traceback.
        parser.error(str(exc))


if __name__ == "__main__":
    main()