# beatalignment / split_data.py
# Uploaded to the Hugging Face Hub via huggingface_hub (revision 151b875, verified).
import os
import random
import argparse
from pathlib import Path
from tqdm import tqdm
def split_data(input_file, train_file, test_file, test_ratio=0.1, seed=42):
    """
    Split tokenized data into training and test sets.

    Makes two passes over the input: one to count lines, one to route each
    line to the train or test file. Lines keep their original relative order
    within each output file.

    Args:
        input_file (str): Path to the input file containing tokenized sequences
        train_file (str): Path to write training sequences
        test_file (str): Path to write test sequences
        test_ratio (float): Proportion of data to use for testing (default: 0.1)
        seed (int): Random seed for reproducibility (default: 42)

    Raises:
        ValueError: If test_ratio is not in the range [0, 1].
    """
    if not 0.0 <= test_ratio <= 1.0:
        raise ValueError(f"test_ratio must be in [0, 1], got {test_ratio}")

    random.seed(seed)

    # First pass: count lines so we can size the split and the progress bar.
    print("Counting sequences in the file...")
    with open(input_file, 'r', encoding='utf-8') as f:
        total_sequences = sum(1 for _ in f)
    print(f"Total sequences found: {total_sequences}")

    # Determine how many sequences to put in the test set.
    test_count = int(total_sequences * test_ratio)
    train_count = total_sequences - test_count

    # random.sample draws test_count distinct indices directly, avoiding the
    # O(n) extra memory of building and shuffling a full index list.
    test_indices = set(random.sample(range(total_sequences), test_count))
    print(f"Splitting data: {train_count} training sequences, {test_count} test sequences")

    # Second pass: stream each line to the appropriate output file.
    with open(input_file, 'r', encoding='utf-8') as infile, \
         open(train_file, 'w', encoding='utf-8') as train_out, \
         open(test_file, 'w', encoding='utf-8') as test_out:
        for i, line in tqdm(enumerate(infile), total=total_sequences, desc="Splitting data"):
            if i in test_indices:
                test_out.write(line)
            else:
                train_out.write(line)

    print(f"Done! Training data saved to {train_file}, test data saved to {test_file}")

    # Report output sizes as a quick sanity check on the split.
    train_size_mb = os.path.getsize(train_file) / (1024 * 1024)
    test_size_mb = os.path.getsize(test_file) / (1024 * 1024)
    print(f"Training file size: {train_size_mb:.2f} MB")
    print(f"Test file size: {test_size_mb:.2f} MB")
if __name__ == "__main__":
    # Command-line entry point: parse options and delegate to split_data().
    cli = argparse.ArgumentParser(description="Split tokenized data into training and test sets")
    cli.add_argument("--input", type=str, default="./data/output.txt",
                     help="Input file path")
    cli.add_argument("--train", type=str, default="./data/train.txt",
                     help="Output path for training data")
    cli.add_argument("--test", type=str, default="./data/test.txt",
                     help="Output path for test data")
    cli.add_argument("--test-ratio", type=float, default=0.1,
                     help="Proportion of data to use for testing")
    cli.add_argument("--seed", type=int, default=42,
                     help="Random seed for reproducibility")
    opts = cli.parse_args()
    split_data(opts.input, opts.train, opts.test, opts.test_ratio, opts.seed)