|
|
import os
|
|
|
import random
|
|
|
import argparse
|
|
|
from pathlib import Path
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
def split_data(input_file, train_file, test_file, test_ratio=0.1, seed=42):
    """
    Split tokenized data into training and test sets.

    The input file is read twice: once to count lines, once to stream each
    line into either the train or test output file.

    Args:
        input_file (str): Path to the input file containing tokenized sequences
        train_file (str): Path to write training sequences
        test_file (str): Path to write test sequences
        test_ratio (float): Proportion of data to use for testing (default: 0.1)
        seed (int): Random seed for reproducibility (default: 42)

    Raises:
        ValueError: If test_ratio is not in the range [0, 1].
    """
    if not 0.0 <= test_ratio <= 1.0:
        raise ValueError(f"test_ratio must be in [0, 1], got {test_ratio}")

    random.seed(seed)

    # First pass: count lines so we can size the split and the progress bar.
    # Explicit UTF-8 avoids platform-default-encoding failures (e.g. cp1252).
    print("Counting sequences in the file...")
    with open(input_file, 'r', encoding='utf-8') as f:
        total_sequences = sum(1 for _ in f)

    print(f"Total sequences found: {total_sequences}")

    test_count = int(total_sequences * test_ratio)
    train_count = total_sequences - test_count

    # Shuffle all line indices and take a prefix as the test set; a set gives
    # O(1) membership checks in the streaming pass below. (Kept identical to
    # the original scheme so a given seed selects the same lines.)
    all_indices = list(range(total_sequences))
    random.shuffle(all_indices)
    test_indices = set(all_indices[:test_count])

    print(f"Splitting data: {train_count} training sequences, {test_count} test sequences")

    # Second pass: stream the file once, routing each line to its output file.
    with open(input_file, 'r', encoding='utf-8') as infile, \
        open(train_file, 'w', encoding='utf-8') as train_out, \
        open(test_file, 'w', encoding='utf-8') as test_out:

        for i, line in tqdm(enumerate(infile), total=total_sequences, desc="Splitting data"):
            if i in test_indices:
                test_out.write(line)
            else:
                train_out.write(line)

    print(f"Done! Training data saved to {train_file}, test data saved to {test_file}")

    # Report output sizes as a quick sanity check.
    train_size_mb = os.path.getsize(train_file) / (1024 * 1024)
    test_size_mb = os.path.getsize(test_file) / (1024 * 1024)
    print(f"Training file size: {train_size_mb:.2f} MB")
    print(f"Test file size: {test_size_mb:.2f} MB")
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: parse the flags and hand them straight to split_data().
    cli = argparse.ArgumentParser(description="Split tokenized data into training and test sets")
    cli.add_argument("--input", type=str, default="./data/output.txt", help="Input file path")
    cli.add_argument("--train", type=str, default="./data/train.txt", help="Output path for training data")
    cli.add_argument("--test", type=str, default="./data/test.txt", help="Output path for test data")
    cli.add_argument("--test-ratio", type=float, default=0.1, help="Proportion of data to use for testing")
    cli.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")

    opts = cli.parse_args()
    split_data(opts.input, opts.train, opts.test, opts.test_ratio, opts.seed)