File size: 5,028 Bytes
625a17f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import argparse
import json
import tqdm
import cv2
import os
import numpy as np
from pycocotools import mask as mask_utils
import random
from PIL import Image
import copy
import string


def extract_object_name(text):
    """Return the clause after the first standalone word "is" in *text*.

    The caption convention is "<subject> is <object description>"; this
    returns the "<object description>" part, stripped of surrounding
    whitespace, or None when no " is " separator is present.

    Note: splitting on the word-bounded " is " (rather than the bare
    substring "is") avoids false splits inside words such as "this" or
    "fish", and maxsplit=1 keeps any later "is" inside the description.
    """
    parts = text.split(" is ", 1)
    if len(parts) > 1:
        return parts[1].strip()
    return None


if __name__ == "__main__":
    # Number of augmented (reference-frame) samples drawn per target frame.
    augu_num = 4
    original_path = "/data/work-gcp-europe-west4-a/yuqian_fu/datasets/DAVIS/2017/trainval_val_psalm_instruction_train_correct_new.json"
    with open(original_path, "r") as fp:
        datas_original = json.load(fp)
    # Newly augmented samples are collected here; the final output is the
    # concatenation of the original data and the new data.
    new_data = []
    json_path = '/data/work-gcp-europe-west4-a/yuqian_fu/datasets/DAVIS/2017/trainval_val_psalm_withtext_train_new_targetframe.json'
    save_path = "/data/work-gcp-europe-west4-a/yuqian_fu/datasets/DAVIS/2017/trainval_val_psalm_train_new_augument_n4_instruction.json"
    with open(json_path, "r") as fp:
        datas = json.load(fp)

    # k counts newly created samples; sent_id continues the sentence-id
    # sequence from the instruction dataset (9444 is presumably the next
    # free id in the source data — TODO confirm against the original file).
    total_num = len(datas)
    k = 0
    sent_id = 9444

    # Read the list of video names belonging to the training split.
    set_path = "/data/work-gcp-europe-west4-a/yuqian_fu/datasets/DAVIS/2017/trainval/ImageSets/2017/train.txt"
    with open(set_path, 'r') as f:
        video_names = [line.strip() for line in f]

    for video in video_names:
        # Augmentation only pairs frames belonging to the same video.
        data_list = [d for d in datas if d["video_name"] == video]
        for data in data_list:
            # Category ids present in the current (target) frame.
            sample_unique_instances = {ann['category_id'] for ann in data['anns']}

            # Guard against videos with fewer than augu_num frames:
            # random.sample raises ValueError when the population is
            # smaller than the requested sample size.
            data_sample = random.sample(data_list, min(augu_num, len(data_list)))
            for sample in data_sample:
                # A frame cannot serve as its own reference frame.
                if data['new_img_id'] == sample['new_img_id']:
                    continue
                # The reference frame must contain at least as many
                # objects as the current frame.
                if len(data['anns']) > len(sample['anns']):
                    continue

                # The reference frame's categories must cover every
                # category of the current frame.
                unique_instances = {ann['category_id'] for ann in sample['anns']}
                if not sample_unique_instances <= unique_instances:
                    continue

                first_frame_anns = copy.deepcopy(sample['anns'])
                if len(data['anns']) < len(first_frame_anns):
                    # Keep only the reference annotations whose category
                    # also appears in the current frame.
                    first_frame_anns = [
                        ann for ann in first_frame_anns
                        if ann['category_id'] in sample_unique_instances
                    ]
                # Sanity check: both frames now describe the same objects.
                assert len(data['anns']) == len(first_frame_anns)

                skip_text = False
                instruct_list = []
                for anno in first_frame_anns:
                    text = anno["text"]
                    # Extract the clause after "is" (the object description).
                    raw = extract_object_name(text)
                    if raw is None:
                        # Malformed caption: drop this reference frame.
                        skip_text = True
                        print(sample['image'])
                        break
                    raw_lower = raw.lower()
                    # Drop the color word "green" and surrounding whitespace.
                    result = raw_lower.replace("green", "").strip()
                    # Strip all punctuation in one pass.
                    sent = result.translate(str.maketrans('', '', string.punctuation))
                    tokens = sent.split()
                    instruct_list.append({
                        "tokens": tokens,
                        "raw": raw,
                        "sent_id": sent_id,
                        "sent": sent,
                    })
                    sent_id += 1
                if skip_text:
                    continue

                new_data.append({
                    'image': data['image'],
                    'image_info': data['image_info'],
                    'anns': data['anns'],
                    'first_frame_image': sample['image'],
                    'first_frame_anns': first_frame_anns,
                    # New image ids continue after the source dataset's range.
                    'new_img_id': total_num + k,
                    'video_name': data['video_name'],
                    "instruction": instruct_list,
                })
                k += 1

    data_all = datas_original + new_data
    with open(save_path, 'w') as f:
        json.dump(data_all, f)
    print(f'Save at {save_path}. Total sample: {len(data_all)}')