import argparse
import json
import tqdm
import cv2
import os
import numpy as np
from pycocotools import mask as mask_utils
import random
from PIL import Image
import copy
import string
def extract_object_name(text):
    """Return the stripped text segment that follows the first "is" in *text*.

    Note: the split matches every occurrence of the substring "is", even
    inside words such as "this", so with multiple occurrences the result is
    the text between the first and second occurrence. Returns None when
    "is" does not appear at all.
    """
    segments = text.split("is")
    return segments[1].strip() if len(segments) > 1 else None
# Punctuation-stripping table, built once instead of per annotation.
_PUNCT_TABLE = str.maketrans('', '', string.punctuation)


if __name__ == "__main__":
    # Number of reference frames randomly sampled per target frame when
    # generating augmented (target, reference) pairs.
    augu_num = 4

    # Existing instruction-style training data; the augmented samples are
    # appended to it before saving.
    original_path = "/data/work-gcp-europe-west4-a/yuqian_fu/datasets/DAVIS/2017/trainval_val_psalm_instruction_train_correct_new.json"
    with open(original_path, "r") as fp:
        datas_original = json.load(fp)

    # Newly generated samples; the final output is the original data
    # concatenated with these.
    new_data = []

    json_path = '/data/work-gcp-europe-west4-a/yuqian_fu/datasets/DAVIS/2017/trainval_val_psalm_withtext_train_new_targetframe.json'
    save_path = "/data/work-gcp-europe-west4-a/yuqian_fu/datasets/DAVIS/2017/trainval_val_psalm_train_new_augument_n4_instruction.json"
    with open(json_path, "r") as fp:
        datas = json.load(fp)

    # k counts generated samples (used to assign fresh image ids after the
    # source dataset's range); sent_id continues an external sentence-id
    # sequence -- 9444 is presumably the next free id, TODO confirm.
    total_num = len(datas)
    k = 0
    sent_id = 9444

    # Collect the training video names, one per line.
    set_path = "/data/work-gcp-europe-west4-a/yuqian_fu/datasets/DAVIS/2017/trainval/ImageSets/2017/train.txt"
    with open(set_path, 'r') as f:
        video_names = [line.strip() for line in f]

    # Group frames by video once. The original rescanned all of `datas`
    # for every video, which is O(videos * frames).
    frames_by_video = {}
    for data in datas:
        frames_by_video.setdefault(data["video_name"], []).append(data)

    for video in video_names:
        # Augmentation only pairs frames from the same video.
        data_list = frames_by_video.get(video, [])
        for data in data_list:
            # Category ids present in the target frame (duplicates kept,
            # matching the original membership semantics).
            sample_unique_instances = [ann['category_id'] for ann in data['anns']]
            # Randomly pick candidate reference frames. Guard against short
            # videos: random.sample raises ValueError when the population is
            # smaller than the requested count.
            data_sample = random.sample(data_list, min(augu_num, len(data_list)))
            for sample in data_sample:
                # Never pair a frame with itself.
                if data['new_img_id'] == sample['new_img_id']:
                    continue
                # The target frame must not contain more objects than the
                # reference frame.
                if len(data['anns']) > len(sample['anns']):
                    continue
                # The reference frame's categories must cover every category
                # present in the target frame.
                unique_instances = [ann['category_id'] for ann in sample['anns']]
                if any(cat_id not in unique_instances for cat_id in sample_unique_instances):
                    continue

                first_frame_anns = copy.deepcopy(sample['anns'])
                if len(data['anns']) < len(first_frame_anns):
                    # Drop reference annotations whose category does not
                    # appear in the target frame.
                    first_frame_anns = [ann for ann in first_frame_anns
                                        if ann['category_id'] in sample_unique_instances]
                # After filtering, both frames must describe the same objects.
                assert len(data['anns']) == len(first_frame_anns)

                # Build one instruction entry per reference annotation.
                skip_text = False
                instruct_list = []
                for anno in first_frame_anns:
                    # Extract the object description that follows "is".
                    raw = extract_object_name(anno["text"])
                    if raw is None:
                        # Malformed annotation text: report it and skip this
                        # (target, reference) pairing entirely.
                        skip_text = True
                        print(sample['image'])
                        break
                    # Lower-case, drop the color word "green", and strip all
                    # punctuation before tokenizing on whitespace.
                    result = raw.lower().replace("green", "").strip()
                    sent = result.translate(_PUNCT_TABLE)
                    instruct_list.append({
                        "tokens": sent.split(),
                        "raw": raw,
                        "sent_id": sent_id,
                        "sent": sent,
                    })
                    sent_id += 1
                if skip_text:
                    continue

                new_data.append({
                    'image': data['image'],
                    'image_info': data['image_info'],
                    'anns': data['anns'],
                    'first_frame_image': sample['image'],
                    'first_frame_anns': first_frame_anns,
                    # New image ids continue after the source dataset's range.
                    'new_img_id': total_num + k,
                    'video_name': data['video_name'],
                    "instruction": instruct_list,
                })
                k += 1

    data_all = datas_original + new_data
    with open(save_path, 'w') as f:
        json.dump(data_all, f)
    print(f'Save at {save_path}. Total sample: {len(data_all)}')