# Source: Important_NMT_DOCs / shuffle_aligned_datasets.py
# (uploaded by Vikrantyadav11234 with huggingface_hub, commit db2dd1f)
import random
import os
from tqdm import tqdm
def count_lines(filename):
    """Return the number of lines in ``filename``.

    The file is opened in binary mode, so lines are split on ``b"\\n"`` only
    (no newline translation, no decoding). A final line that lacks a
    trailing newline is still counted.
    """
    total = 0
    with open(filename, 'rb') as handle:
        for _line in handle:
            total += 1
    return total
def shuffle_datasets(src_file, tgt_file, output_src, output_tgt,
                     chunk_size=1000000, seed=None):
    """Shuffle two line-aligned text files with one shared random permutation.

    Line ``i`` of ``src_file`` and line ``i`` of ``tgt_file`` form a pair;
    after shuffling, each pair is still at the same line position in
    ``output_src`` and ``output_tgt``. Both inputs are read fully into
    memory before either output is opened, so an output path may safely
    equal an input path.

    Args:
        src_file: Path of the source-side file (read as UTF-8 text).
        tgt_file: Path of the target-side file (read as UTF-8 text).
        output_src: Path the shuffled source lines are written to (overwritten).
        output_tgt: Path the shuffled target lines are written to (overwritten).
        chunk_size: Number of lines written per batch; also the progress-bar
            step. Must be >= 1.
        seed: Optional seed for a private ``random.Random``. ``None`` uses the
            module-level ``random`` generator, as before, so a prior
            ``random.seed(...)`` call still makes the result reproducible.

    Raises:
        ValueError: If ``chunk_size`` < 1, or the two inputs have different
            line counts (i.e. they are not aligned).
    """
    # A negative step would make range() empty and silently write nothing.
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    # Read everything before opening any output, so a read/decode error (or
    # an output path equal to an input path) cannot truncate data.
    with open(src_file, 'r', encoding='utf-8') as src:
        src_lines = src.readlines()
    with open(tgt_file, 'r', encoding='utf-8') as tgt:
        tgt_lines = tgt.readlines()

    # Count the lines actually loaded. A separate binary-mode count can
    # disagree with text-mode readlines() (e.g. on stray '\r'), and only
    # checking src would hide a shorter or longer tgt.
    total_lines = len(src_lines)
    if len(tgt_lines) != total_lines:
        raise ValueError(
            f"Line count mismatch: {src_file} has {total_lines} lines but "
            f"{tgt_file} has {len(tgt_lines)}; files are not aligned"
        )
    print(f"Total lines: {total_lines}")

    # Only the last line can lack '\n' (text mode normalises line endings).
    # If it were shuffled into the middle, it would be glued onto the next
    # line and break the pairing.
    for lines in (src_lines, tgt_lines):
        if lines and not lines[-1].endswith('\n'):
            lines[-1] += '\n'

    # Generate shuffled indices shared by both sides.
    rng = random if seed is None else random.Random(seed)
    indices = list(range(total_lines))
    rng.shuffle(indices)

    with open(output_src, 'w', encoding='utf-8') as out_src, \
            open(output_tgt, 'w', encoding='utf-8') as out_tgt:
        for start in tqdm(range(0, total_lines, chunk_size), desc="Shuffling"):
            chunk_indices = indices[start:start + chunk_size]
            out_src.writelines([src_lines[idx] for idx in chunk_indices])
            out_tgt.writelines([tgt_lines[idx] for idx in chunk_indices])
def main(argv=None):
    """Shuffle the NMT_V2 training src/tgt pair with a shared permutation.

    Called with no argument, this uses the original hard-coded paths exactly,
    so existing ``main()`` callers are unaffected. When ``argv`` is given
    (the ``__main__`` guard passes ``sys.argv[1:]``), any path can be
    overridden with a command-line flag.

    Args:
        argv: Optional list of command-line arguments to parse. ``None`` means
            "no overrides" (it does not read ``sys.argv``).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Shuffle two line-aligned files with the same permutation."
    )
    # NOTE(review): the source default contains "cleaned" but the target
    # default does not -- confirm these two files are the matching pair.
    parser.add_argument(
        '--src-file',
        default='/home/vikrant-MNMT/myenv/NMT_V2/train_cleaned_shuffled.src-filtered.src',
    )
    parser.add_argument(
        '--tgt-file',
        default='/home/vikrant-MNMT/myenv/NMT_V2/train_shuffled.tgt-filtered.tgt',
    )
    parser.add_argument(
        '--output-src',
        default='/home/vikrant-MNMT/myenv/NMT_V2/train_aggressively_shuffled.src',
    )
    parser.add_argument(
        '--output-tgt',
        default='/home/vikrant-MNMT/myenv/NMT_V2/train_aggressively_shuffled.tgt',
    )
    args = parser.parse_args([] if argv is None else argv)

    print("Starting to aggressively shuffle datasets...")
    shuffle_datasets(args.src_file, args.tgt_file, args.output_src, args.output_tgt)
    print("Shuffling completed. Aggressively shuffled datasets saved to:")
    print(f"Source: {args.output_src}")
    print(f"Target: {args.output_tgt}")


if __name__ == "__main__":
    import sys

    main(sys.argv[1:])