File size: 1,730 Bytes
db2dd1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import random
import os
from tqdm import tqdm

def count_lines(filename):
    """Return how many lines *filename* contains.

    The file is opened in binary mode, so no decoding or newline translation
    happens: only newline bytes split lines. A final line without a trailing
    newline still counts, and an empty file yields 0.
    """
    total = 0
    with open(filename, 'rb') as handle:
        # enumerate leaves `total` at the 1-based index of the last line seen.
        for total, _line in enumerate(handle, start=1):
            pass
    return total

def shuffle_datasets(src_file, tgt_file, output_src, output_tgt, chunk_size=1000000, seed=None):
    """Shuffle a parallel corpus while keeping source/target lines aligned.

    Both inputs are read fully into memory as UTF-8 text. One random
    permutation of line indices is drawn, and output line i of each file is
    taken from the same permuted index, so sentence pairs stay together.

    Args:
        src_file: Path of the source-side input file.
        tgt_file: Path of the target-side input file. It must have the same
            number of lines as ``src_file``.
        output_src: Path the shuffled source lines are written to (overwritten).
        output_tgt: Path the shuffled target lines are written to (overwritten).
        chunk_size: Number of lines written per progress-bar step; must be > 0.
        seed: Optional seed for a private ``random.Random``. When ``None``, the
            module-level ``random`` generator is used, so an earlier
            ``random.seed(...)`` call still controls the permutation.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or if the two inputs
            have different line counts. Both checks happen before any output
            file is opened, so existing outputs are not truncated.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    with open(src_file, 'r', encoding='utf-8') as src:
        src_lines = src.readlines()
    with open(tgt_file, 'r', encoding='utf-8') as tgt:
        tgt_lines = tgt.readlines()

    # Count the decoded lines that are actually shuffled. A binary-mode count
    # can disagree, because text mode also treats a lone '\r' as a line break.
    total_lines = len(src_lines)
    print(f"Total lines: {total_lines}")
    if len(tgt_lines) != total_lines:
        raise ValueError(
            f"Line count mismatch: {src_file} has {total_lines} lines, "
            f"{tgt_file} has {len(tgt_lines)}"
        )

    # Only the final line can lack a newline. If it were moved into the middle
    # of the output, it would be glued onto the following line and break the
    # source/target alignment.
    for lines in (src_lines, tgt_lines):
        if lines and not lines[-1].endswith('\n'):
            lines[-1] += '\n'

    indices = list(range(total_lines))
    if seed is None:
        random.shuffle(indices)
    else:
        random.Random(seed).shuffle(indices)

    with open(output_src, 'w', encoding='utf-8') as out_src, \
         open(output_tgt, 'w', encoding='utf-8') as out_tgt:
        for start in tqdm(range(0, total_lines, chunk_size), desc="Shuffling"):
            chunk_indices = indices[start:start + chunk_size]
            out_src.writelines(src_lines[idx] for idx in chunk_indices)
            out_tgt.writelines(tgt_lines[idx] for idx in chunk_indices)

def main():
    """Shuffle the hard-coded NMT_V2 training corpus and report where it went."""
    inputs = (
        '/home/vikrant-MNMT/myenv/NMT_V2/train_cleaned_shuffled.src-filtered.src',
        '/home/vikrant-MNMT/myenv/NMT_V2/train_shuffled.tgt-filtered.tgt',
    )
    outputs = (
        '/home/vikrant-MNMT/myenv/NMT_V2/train_aggressively_shuffled.src',
        '/home/vikrant-MNMT/myenv/NMT_V2/train_aggressively_shuffled.tgt',
    )

    print("Starting to aggressively shuffle datasets...")
    shuffle_datasets(*inputs, *outputs)
    print("Shuffling completed. Aggressively shuffled datasets saved to:")
    for label, path in zip(("Source", "Target"), outputs):
        print(f"{label}: {path}")


if __name__ == "__main__":
    main()