Vikrantyadav11234 commited on
Commit
db2dd1f
·
verified ·
1 Parent(s): c31e9b4

Upload shuffle_aligned_datasets.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. shuffle_aligned_datasets.py +46 -0
shuffle_aligned_datasets.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import os
3
+ from tqdm import tqdm
4
+
5
+ def count_lines(filename):
6
+ with open(filename, 'rb') as f:
7
+ return sum(1 for _ in f)
8
+
9
+ def shuffle_datasets(src_file, tgt_file, output_src, output_tgt, chunk_size=1000000):
10
+ total_lines = count_lines(src_file)
11
+ print(f"Total lines: {total_lines}")
12
+
13
+ # Generate shuffled indices
14
+ indices = list(range(total_lines))
15
+ random.shuffle(indices)
16
+
17
+ with open(src_file, 'r', encoding='utf-8') as src, \
18
+ open(tgt_file, 'r', encoding='utf-8') as tgt, \
19
+ open(output_src, 'w', encoding='utf-8') as out_src, \
20
+ open(output_tgt, 'w', encoding='utf-8') as out_tgt:
21
+
22
+ src_lines = src.readlines()
23
+ tgt_lines = tgt.readlines()
24
+
25
+ for i in tqdm(range(0, total_lines, chunk_size), desc="Shuffling"):
26
+ chunk_indices = indices[i:i+chunk_size]
27
+ chunk_src = [src_lines[idx] for idx in chunk_indices]
28
+ chunk_tgt = [tgt_lines[idx] for idx in chunk_indices]
29
+
30
+ out_src.writelines(chunk_src)
31
+ out_tgt.writelines(chunk_tgt)
32
+
33
+ def main():
34
+ src_file = '/home/vikrant-MNMT/myenv/NMT_V2/train_cleaned_shuffled.src-filtered.src'
35
+ tgt_file = '/home/vikrant-MNMT/myenv/NMT_V2/train_shuffled.tgt-filtered.tgt'
36
+ output_src = '/home/vikrant-MNMT/myenv/NMT_V2/train_aggressively_shuffled.src'
37
+ output_tgt = '/home/vikrant-MNMT/myenv/NMT_V2/train_aggressively_shuffled.tgt'
38
+
39
+ print("Starting to aggressively shuffle datasets...")
40
+ shuffle_datasets(src_file, tgt_file, output_src, output_tgt)
41
+ print("Shuffling completed. Aggressively shuffled datasets saved to:")
42
+ print(f"Source: {output_src}")
43
+ print(f"Target: {output_tgt}")
44
+
45
+ if __name__ == "__main__":
46
+ main()