import json
import os
from itertools import chain
from os.path import join

import numpy as np
from joblib import Parallel, delayed
from tqdm import tqdm


def img_concat(data_root, output_root, data_name, token_num_per_sentence):
    """Concatenate every `token_num_per_sentence` JSON-line token records of
    `train.bin` into one record, and write the result alongside a meta file of
    (byte offset, token count) pairs for each concatenated record."""
    data_dir = join(data_root, data_name)
    out_dir = join(output_root, data_name)
    
    os.makedirs(out_dir, exist_ok=True)
    
    data_bin_path = os.path.join(data_dir, "train.bin")
    out_data_bin_path = os.path.join(out_dir, "train.bin")
    out_data_meta_path = os.path.join(out_dir, "train.bin.meta")
    
    # Each line of train.bin is a JSON object with a "tokens" field.
    with open(data_bin_path, "r") as bin_file:
        data_bin = bin_file.readlines()
    
    cu = 0  # cumulative byte offset of the next record in the output file
    new_data_bin = []
    cu_seq_len_list = []  # (byte offset, token count) per concatenated record
    
    sentence = []
    for index, data in enumerate(data_bin):
        data = json.loads(data)['tokens']
        # Flush the current group once it holds `token_num_per_sentence` records.
        if index > 0 and index % token_num_per_sentence == 0:
            tokens = list(chain(*sentence))
            seq_len = len(tokens)
            saved_bin = str.encode(json.dumps(dict(tokens=tokens)) + "\n")
            
            new_data_bin.append(saved_bin)
            cu_seq_len_list.append((cu, seq_len))
            cu += len(saved_bin)
            sentence = []
        
        sentence.append(data)
    
    # Flush the trailing (possibly shorter) group; skip it when the input was
    # empty or its length was an exact multiple, to avoid writing an empty record.
    if sentence:
        tokens = list(chain(*sentence))
        seq_len = len(tokens)
        saved_bin = str.encode(json.dumps(dict(tokens=tokens)) + "\n")
        
        new_data_bin.append(saved_bin)
        cu_seq_len_list.append((cu, seq_len))
        cu += len(saved_bin)
    
    with open(out_data_bin_path, "wb") as out_bin_file:
        out_bin_file.writelines(new_data_bin)
    # np.save always appends ".npy"; rename so the meta file keeps the
    # expected "train.bin.meta" name.
    np.save(out_data_meta_path, cu_seq_len_list)
    os.rename(f'{out_data_meta_path}.npy', out_data_meta_path)


if __name__ == '__main__':
    token_num_per_sentence = 6
    file_name = 'Rain13K'
    data_root = '/home/ma-user/work/data/vq_token'
    
    data_dir = join(data_root, file_name)
    output_dir = join(data_root, f'{file_name}-sentence_{token_num_per_sentence}')
    
    # Serial fallback, useful for debugging:
    # for data_name in tqdm(os.listdir(data_dir)):
    #     img_concat(data_dir, output_dir, data_name, token_num_per_sentence)
    
    Parallel(n_jobs=64)(delayed(img_concat)(data_dir, output_dir, data_name, token_num_per_sentence)
                        for data_name in tqdm(os.listdir(data_dir)))