|
|
import random |
|
|
import os |
|
|
from tqdm import tqdm |
|
|
|
|
|
def count_lines(filename):
    """Return how many lines the file at *filename* contains.

    The file is read in binary mode, so no decoding happens: a line is any
    run of bytes terminated by a newline byte, plus a final unterminated run
    if the file does not end with a newline. An empty file yields 0.
    """
    n = 0
    with open(filename, 'rb') as stream:
        for _line in stream:
            n += 1
    return n
|
|
|
|
|
def _read_terminated_lines(path):
    """Read UTF-8 text file *path* into a list of lines that all end in '\\n'.

    ``readlines`` leaves the final line without a newline when the file has no
    trailing newline. If such a line were shuffled into the middle of the
    output, it would be concatenated with the following line, merging two
    sentences and breaking source/target alignment, so it is terminated here.
    """
    with open(path, 'r', encoding='utf-8') as fh:
        lines = fh.readlines()
    if lines and not lines[-1].endswith('\n'):
        lines[-1] += '\n'
    return lines


def shuffle_datasets(src_file, tgt_file, output_src, output_tgt, chunk_size=1000000, seed=None):
    """Shuffle a parallel corpus while keeping source/target lines paired.

    Both inputs are read fully into memory (random access by index is needed),
    one permutation of line indices is drawn, and that same permutation is
    applied to both files. Output is written in batches of ``chunk_size``
    lines; this only controls write granularity and the progress bar, not
    peak memory.

    Args:
        src_file: Path of the UTF-8 source-side text file.
        tgt_file: Path of the UTF-8 target-side text file. It must contain the
            same number of lines as ``src_file``.
        output_src: Path the shuffled source lines are written to.
        output_tgt: Path the shuffled target lines are written to.
        chunk_size: Number of lines written per batch; must be >= 1.
        seed: Optional seed for a private ``random.Random`` so a run can be
            reproduced. ``None`` (default) uses the module-level ``random``
            state, exactly as before.

    Raises:
        ValueError: If ``chunk_size`` < 1, or if the two inputs have different
            line counts.
    """
    if chunk_size < 1:
        # range() with a negative step would be empty and silently write nothing.
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    # Read (and validate) both inputs before opening the outputs, so a failure,
    # or an output path that aliases an input, never truncates any file.
    src_lines = _read_terminated_lines(src_file)
    tgt_lines = _read_terminated_lines(tgt_file)
    if len(src_lines) != len(tgt_lines):
        raise ValueError(
            f"Line count mismatch: {src_file} has {len(src_lines)} lines, "
            f"{tgt_file} has {len(tgt_lines)} lines"
        )

    # Count what was actually read in text mode. A separate binary count can
    # disagree (e.g. lone '\r' is a line break under universal newlines).
    total_lines = len(src_lines)
    print(f"Total lines: {total_lines}")

    indices = list(range(total_lines))
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(indices)

    with open(output_src, 'w', encoding='utf-8') as out_src, \
            open(output_tgt, 'w', encoding='utf-8') as out_tgt:
        for start in tqdm(range(0, total_lines, chunk_size), desc="Shuffling"):
            chunk_indices = indices[start:start + chunk_size]
            out_src.writelines(src_lines[idx] for idx in chunk_indices)
            out_tgt.writelines(tgt_lines[idx] for idx in chunk_indices)
|
|
|
|
|
def main():
    """Shuffle the hard-coded NMT training pair and report where it was written.

    The input and output locations are fixed literals below; edit them to
    process a different corpus.
    """
    # (source, target) paths of the corpus to shuffle.
    inputs = (
        '/home/vikrant-MNMT/myenv/NMT_V2/train_cleaned_shuffled.src-filtered.src',
        '/home/vikrant-MNMT/myenv/NMT_V2/train_shuffled.tgt-filtered.tgt',
    )
    # (source, target) paths the shuffled corpus is written to.
    outputs = (
        '/home/vikrant-MNMT/myenv/NMT_V2/train_aggressively_shuffled.src',
        '/home/vikrant-MNMT/myenv/NMT_V2/train_aggressively_shuffled.tgt',
    )

    print("Starting to aggressively shuffle datasets...")
    shuffle_datasets(*inputs, *outputs)
    print("Shuffling completed. Aggressively shuffled datasets saved to:")
    for label, path in zip(("Source", "Target"), outputs):
        print(f"{label}: {path}")
|
|
|
|
|
if __name__ == "__main__":
    # Run the shuffle only when executed as a script, not when imported.
    main()
|
|
|