# ObjectRelator-Original / datasets / build_davis_augument.py
# Uploaded by YuqianFu using huggingface_hub (commit 625a17f, verified).
import argparse
import json
import tqdm
import cv2
import os
import numpy as np
from pycocotools import mask as mask_utils
import random
from PIL import Image
import copy
import string
def extract_object_name(text):
    """Return the part of *text* that follows the first standalone word "is".

    The sentence is split once on the first occurrence of "is" that forms a
    whole word, so an "is" embedded in another word ("This", "fish") or a
    later "is" in the description does not truncate the result.

    Args:
        text: Sentence to parse, e.g. "The green object is a bear.".

    Returns:
        The text after the first standalone "is" with surrounding whitespace
        stripped (this may be an empty string when nothing follows "is"), or
        None when *text* contains no standalone "is".
    """
    import re  # local import keeps this helper self-contained

    # maxsplit=1 keeps everything after the first "is", including any later
    # "is" that belongs to the description itself.
    parts = re.split(r"\bis\b", text, maxsplit=1)
    if len(parts) > 1:
        return parts[1].strip()
    return None
if __name__ == "__main__":
    # Build an augmented annotation file: every frame of a video is paired
    # with randomly drawn reference frames from the same video, an instruction
    # list is derived from the reference annotations' "text" fields, and the
    # resulting records are appended to the original records.

    # Number of reference frames randomly drawn for each frame.
    augu_num = 4

    original_path = "/data/work-gcp-europe-west4-a/yuqian_fu/datasets/DAVIS/2017/trainval_val_psalm_instruction_train_correct_new.json"
    with open(original_path, "r") as fp:
        datas_original = json.load(fp)

    # Holds the newly generated records; the final output is the original
    # records followed by these.
    new_data = []
    json_path = '/data/work-gcp-europe-west4-a/yuqian_fu/datasets/DAVIS/2017/trainval_val_psalm_withtext_train_new_targetframe.json'
    save_path = "/data/work-gcp-europe-west4-a/yuqian_fu/datasets/DAVIS/2017/trainval_val_psalm_train_new_augument_n4_instruction.json"
    with open(json_path, "r") as fp:
        datas = json.load(fp)

    # Counters used to number the new records.
    # NOTE(review): new ids start at len(datas), not len(datas_original) --
    # confirm the two files have the same number of records.
    total_num = len(datas)
    k = 0
    # NOTE(review): presumably the first sentence id not yet used by the
    # original data -- confirm.
    sent_id = 9444

    # Collect the video names, one per line of the split file.
    set_path = "/data/work-gcp-europe-west4-a/yuqian_fu/datasets/DAVIS/2017/trainval/ImageSets/2017/train.txt"
    video_names = []
    with open(set_path, 'r') as f:
        for line in f:
            video_names.append(line.strip())

    # Group the frames by video once (order preserved) instead of rescanning
    # the whole list for every video.
    frames_by_video = {}
    for data in datas:
        frames_by_video.setdefault(data["video_name"], []).append(data)

    # Translation table that strips all ASCII punctuation; built once.
    punctuation_table = str.maketrans('', '', string.punctuation)

    for video in video_names:
        # Augmentation is done within a single video.
        data_list = frames_by_video.get(video, [])
        for data in data_list:
            # Category ids of the objects annotated in the current frame.
            current_categories = [ann['category_id'] for ann in data['anns']]

            # random.sample raises ValueError when asked for more items than
            # the population holds, so clamp for videos with few frames.
            data_sample = random.sample(data_list, min(augu_num, len(data_list)))
            for sample in data_sample:
                # Filter the randomly chosen reference frames.
                # A frame cannot be its own reference.
                if data['new_img_id'] == sample['new_img_id']:
                    continue
                # The current frame must not have more objects than the
                # reference frame.
                if len(data['anns']) > len(sample['anns']):
                    continue
                # The reference frame's categories must include every category
                # of the current frame.
                reference_categories = [ann['category_id'] for ann in sample['anns']]
                if any(cat not in reference_categories for cat in current_categories):
                    continue

                first_frame_anns = copy.deepcopy(sample['anns'])
                if len(data['anns']) < len(first_frame_anns):
                    # Drop reference objects that do not appear in the current
                    # frame.
                    first_frame_anns = [
                        ann for ann in first_frame_anns
                        if ann['category_id'] in current_categories
                    ]
                assert len(data['anns']) == len(first_frame_anns), \
                    "reference annotations do not match the current frame's annotations"

                skip_text = False
                instruct_list = []
                for anno in first_frame_anns:
                    text = anno["text"]
                    # Take the part of the sentence that follows "is".
                    raw = extract_object_name(text)
                    if raw is None:
                        # No usable description: report the reference image
                        # and drop this pairing.
                        skip_text = True
                        print(sample['image'])
                        break
                    raw_lower = raw.lower()
                    # Remove the word "green" and surrounding whitespace.
                    result = raw_lower.replace("green", "").strip()
                    # Remove all punctuation.
                    sent = result.translate(punctuation_table)
                    tokens = sent.split()
                    sample_text = {
                        "tokens": tokens,
                        "raw": raw,
                        "sent_id": sent_id,
                        "sent": sent
                    }
                    # Ids consumed here are not rolled back if a later
                    # annotation causes the pairing to be skipped.
                    sent_id += 1
                    instruct_list.append(sample_text)
                if skip_text:
                    continue

                data_new = {
                    'image': data['image'],
                    'image_info': data['image_info'],
                    'anns': data['anns'],
                    'first_frame_image': sample['image'],
                    'first_frame_anns': first_frame_anns,
                    'new_img_id': total_num + k,
                    'video_name': data['video_name'],
                    "instruction": instruct_list
                }
                new_data.append(data_new)
                k += 1

    data_all = datas_original + new_data
    with open(save_path, 'w') as f:
        json.dump(data_all, f)
    print(f'Save at {save_path}. Total sample: {len(data_all)}')