File size: 2,085 Bytes
7e3a804
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
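# Generate ../data/<dataset>/<dataset>_<split>.pth files from the annotation .txt files in ../ln_data/anns/<dataset>.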
import os
import sys
import torch
sys.path.append('.')

import argparse
# Command-line interface: --dataset selects which annotation set to convert.
# NOTE: arguments are parsed at module level, so this reads sys.argv as soon
# as the file is executed or imported.
parser = argparse.ArgumentParser(description='Data preparation')
parser.add_argument(
    '--dataset',
    default='refcoco',
    choices=['refcoco', 'refcoco+', 'refcocog_google', 'refcocog_umd'],
    type=str,
)
args = parser.parse_args()

def _parse_annotation_line(line):
    """Parse one annotation line into ``(img_name, seg_id, box, des)`` records.

    The line is split on whitespace. The first token is the image name; the
    tokens up to the first ``'~'`` are comma-separated integer fields; the
    remaining tokens are descriptions separated from each other by ``'~'``.

    Only the first integer field is used: its first four values become
    ``box`` and its last value becomes ``seg_id``.
    NOTE(review): the meaning/order of the four box values is not visible
    here -- confirm against the annotation format.

    Returns:
        A list with one tuple per description. A blank line yields an empty
        list. As before, a line without ``'~'`` or ending in ``'~'`` still
        yields a record whose description is the empty string.

    Raises:
        ValueError: if an integer field contains a non-integer value.
        IndexError: if there is no integer field before the first ``'~'``.
    """
    tokens = line.split()
    if not tokens:
        # Blank / whitespace-only lines carry no annotation; skip them
        # instead of crashing on tokens[0].
        return []

    img_name = tokens[0]
    try:
        stop = tokens.index('~', 1)
    except ValueError:
        stop = len(tokens)

    # Every field is still converted so malformed integers are reported,
    # even though only the first field is used below.
    fields = [list(map(int, field.split(','))) for field in tokens[1:stop]]
    box = fields[0][:4]
    seg_id = fields[0][-1]

    records = []
    words = []
    for token in tokens[stop + 1:]:
        if token == '~':
            records.append((img_name, seg_id, box, ' '.join(words)))
            words = []
        else:
            words.append(token)
    # Whatever follows the last separator is the final description.
    records.append((img_name, seg_id, box, ' '.join(words)))
    return records


def main(args):
    """Convert every annotation file of ``args.dataset`` into a ``.pth`` file.

    Each file in ``../ln_data/anns/<dataset>`` is parsed line by line and the
    resulting list of ``(img_name, seg_id, box, des)`` tuples is written with
    ``torch.save`` to ``../data/<dataset>/<dataset>_<split>.pth``, where
    ``<split>`` is the part of the file name after the last ``'_'`` and
    before the first following ``'.'``.

    Args:
        args: object with a ``dataset`` attribute naming the annotation set.
    """
    dataset = args.dataset
    anns_dir = os.path.join('../ln_data/anns', dataset)
    out_dir = os.path.join('../data', dataset)

    input_txt_list = os.listdir(anns_dir)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(out_dir, exist_ok=True)

    for input_txt in input_txt_list:
        split = input_txt.split('_')[-1].split('.')[0]
        res = []
        with open(os.path.join(anns_dir, input_txt), encoding='utf-8') as f:
            for line in f:
                res.extend(_parse_annotation_line(line))

        imgset_path = '{0}_{1}.pth'.format(dataset, split)
        torch.save(res, os.path.join(out_dir, imgset_path))
    print(dataset, " done")

# Script entry point: run the conversion with the already-parsed CLI arguments.
if __name__ == '__main__':
    main(args)