# SAT / dataset.py
# Last change: "Update dataset.py" by Darknsu (commit 3974320, verified)
import numpy as np
import h5py
import json
import torch
import torch.utils.data as data
import os
import pickle
from multiprocessing import Pool
def load_json(file):
    """Read and return the parsed contents of the JSON file at path ``file``.

    The file is decoded explicitly as UTF-8 (the encoding JSON mandates)
    instead of the platform's locale default, so results do not depend on
    the host's locale settings.
    """
    with open(file, encoding="utf-8") as json_file:
        return json.load(json_file)
def calc_iou(a, b):
    """Return the 1-D overlap ratio of two ``(end, length)`` segments.

    Each segment covers ``[end - length, end]``. The numerator is the signed
    intersection length (negative when the segments are disjoint), and the
    denominator is the span from the earliest start to the latest end,
    floored at 1.
    """
    a_end, a_len = a[0], a[1]
    b_end, b_len = b[0], b[1]
    a_start = a_end - a_len
    b_start = b_end - b_len
    hull = max(a_end, b_end) - min(a_start, b_start)
    intersection = min(a_end, b_end) - max(a_start, b_start)
    return intersection / max(hull, 1)
def box_include(y, target):
    """Return whether ``target`` strictly encloses segment ``y``.

    Both arguments are ``(end, length)`` pairs covering ``[end - length, end]``.
    The result is True only when ``target`` starts strictly before ``y`` starts,
    ends strictly after ``y`` ends, and ``y`` ends after ``target`` starts.
    """
    y_start, y_end = y[0] - y[1], y[0]
    t_start, t_end = target[0] - target[1], target[0]
    return bool(y_end > t_start and t_start < y_start and t_end > y_end)
class VideoDataSet(data.Dataset):
    """Sliding-window dataset over per-frame video features with anchor labels.

    For every end position ``ed`` (1..number of feature rows) of every selected
    video there is one sample. It covers feature rows
    ``[ed - segment_size, ed)``, zero-padded at the top when that window starts
    before row 0, together with precomputed per-anchor classification and
    regression labels and a snippet-level classification label.

    Args:
        opt: Configuration dict. Keys read here include ``mode``,
            ``predefined_fps``, ``video_anno``, ``split``, ``video_len_file``,
            ``setup``, ``num_of_class``, ``segment_size``, ``data_rescale``,
            ``anchors``, ``pos_threshold``, ``proposal_label_file``,
            ``data_format`` and the ``video_feature_*_train`` /
            ``video_feature_*_test`` paths (plus ``rgb_only`` for ``h5``).
        subset: Annotation subset to keep; ``"full"`` keeps every video.
            ``"train"`` reads the ``*_train`` feature paths, and any other
            value reads the ``*_test`` paths.
        video_name: If given, restrict the dataset to this single video.
    """

    def __init__(self, opt, subset="train", video_name=None):
        self.subset = subset
        self.mode = opt["mode"]
        self.predefined_fps = opt["predefined_fps"]
        self.video_anno_path = opt["video_anno"].format(opt["split"])
        self.video_len_path = opt["video_len_file"].format(self.subset + '_' + opt["setup"])
        self.num_of_class = opt["num_of_class"]
        self.segment_size = opt["segment_size"]
        self.label_name = []
        self.match_score = {}
        self.match_score_end = {}
        self.match_length = {}
        self.gt_action = {}
        self.cls_label = {}
        self.reg_label = {}
        self.snip_label = {}
        self.inputs = []
        self.inputs_all = []
        self.data_rescale = opt["data_rescale"]
        self.anchors = opt["anchors"]
        self.pos_threshold = opt["pos_threshold"]
        self.single_video_name = video_name
        self._getDatasetDict()
        self._loadFeaturelen(opt)
        self._getMatchScore()
        self._makeInputSeq()
        self._loadPropLabel(opt['proposal_label_file'].format(self.subset + '_' + opt["setup"]))
        # Keep this after _loadPropLabel: Pool.map there pickles the bound method
        # (and therefore `self`) for the workers, so loading features earlier
        # would presumably make that transfer far more expensive.
        self._loadFeatures(opt)

    def _featureSplit(self):
        """Return the feature-path suffix for this subset: 'train' or 'test'."""
        return "train" if self.subset == "train" else "test"

    @staticmethod
    def _readH5Features(path, keys):
        """Copy the datasets named by ``keys`` out of HDF5 file ``path`` into memory.

        Raises:
            ValueError: if any key is missing from the file.
        """
        features = {}
        with h5py.File(path, 'r') as h5_file:
            for name in keys:
                if name not in h5_file:
                    raise ValueError(f"Features for video {name} not found in {path}")
                features[name] = np.array(h5_file[name][:])
        return features

    def _loadFeatures(self, opt):
        """Load every video's features into ``feature_rgb_file`` / ``feature_flow_file``.

        ``feature_flow_file`` is None for the ``npz`` and ``pt`` formats, and for
        ``h5`` when ``opt['rgb_only']`` is set.

        Raises:
            ValueError: if a video's features are missing or ``data_format`` is unknown.
        """
        split = self._featureSplit()
        data_format = opt['data_format']
        keys = self.video_list
        if data_format == "h5":
            self.feature_rgb_file = self._readH5Features(opt["video_feature_rgb_" + split], keys)
            if opt['rgb_only']:
                self.feature_flow_file = None
            else:
                self.feature_flow_file = self._readH5Features(opt["video_feature_flow_" + split], keys)
        elif data_format == "pickle":
            all_path = opt["video_feature_all_" + split]
            # NOTE(review): pickle.load can execute arbitrary code; only load trusted feature files.
            with open(all_path, 'rb') as pickle_file:
                feature_All = pickle.load(pickle_file)
            self.feature_rgb_file = {}
            self.feature_flow_file = {}
            for name in keys:
                if name not in feature_All:
                    raise ValueError(f"Features for video {name} not found in {all_path}")
                self.feature_rgb_file[name] = feature_All[name]['rgb']
                self.feature_flow_file[name] = feature_All[name]['flow']
        elif data_format in ("npz", "npz_i3d", "pt"):
            # One feature file per video inside the `video_feature_all_*` directory.
            feature_dir = opt["video_feature_all_" + split]
            extension = '.pt' if data_format == "pt" else '.npz'
            self.feature_rgb_file = {}
            self.feature_flow_file = {} if data_format == "npz_i3d" else None
            for name in keys:
                feature_path = os.path.join(feature_dir, name + extension)
                if not os.path.exists(feature_path):
                    raise ValueError(f"Feature file {feature_path} not found")
                if data_format == "pt":
                    self.feature_rgb_file[name] = torch.load(feature_path)[:]
                else:
                    # Reading the arrays inside `with` closes the underlying npz archive.
                    with np.load(feature_path) as npz_file:
                        if data_format == "npz":
                            self.feature_rgb_file[name] = npz_file['feats'][:]
                        else:
                            self.feature_rgb_file[name] = npz_file['rgb']
                            self.feature_flow_file[name] = npz_file['flow']
        else:
            raise ValueError(f"Unsupported data_format: {data_format}")

    def _loadFeaturelen(self, opt):
        """Fill ``video_len`` (video name -> number of feature rows).

        Reads the JSON cache at ``video_len_path`` if it exists. Otherwise it
        measures the features named by ``opt`` and writes that cache.

        Raises:
            ValueError: if ``data_format`` is unknown and no cache exists.
        """
        if os.path.exists(self.video_len_path):
            self.video_len = load_json(self.video_len_path)
            return
        split = self._featureSplit()
        data_format = opt['data_format']
        keys = self.video_list
        self.video_len = {}
        if data_format == "h5":
            with h5py.File(opt["video_feature_rgb_" + split], 'r') as feature_file:
                for name in keys:
                    self.video_len[name] = len(feature_file[name])
        elif data_format == "pickle":
            with open(opt["video_feature_all_" + split], 'rb') as pickle_file:
                feature_file = pickle.load(pickle_file)
            for name in keys:
                self.video_len[name] = len(feature_file[name]['rgb'])
        elif data_format in ("npz", "npz_i3d"):
            array_key = 'feats' if data_format == "npz" else 'rgb'
            for name in keys:
                npz_path = os.path.join(opt["video_feature_all_" + split], name + '.npz')
                with np.load(npz_path) as npz_file:
                    self.video_len[name] = len(npz_file[array_key])
        elif data_format == "pt":
            for name in keys:
                pt_path = os.path.join(opt["video_feature_all_" + split], name + '.pt')
                self.video_len[name] = len(torch.load(pt_path))
        else:
            raise ValueError(f"Unsupported data_format: {data_format}")
        with open(self.video_len_path, "w", encoding="utf-8") as outfile:
            json.dump(self.video_len, outfile, indent=2)

    def _getDatasetDict(self):
        """Select this subset's videos from the annotation file and collect labels.

        Populates ``video_dict`` (name -> annotation entry), ``video_list`` and
        the sorted ``label_name`` list, whose positions are used as class indices.

        Raises:
            ValueError: if ``video_name`` was given but is not in the annotations.
        """
        anno_database = load_json(self.video_anno_path)['database']
        self.video_dict = {}
        if self.single_video_name:
            if self.single_video_name not in anno_database:
                raise ValueError(f"Video {self.single_video_name} not found in annotation database")
            candidate_names = [self.single_video_name]
        else:
            candidate_names = list(anno_database)
        for video_name in candidate_names:
            video_info = anno_database[video_name]
            # "full" keeps everything; otherwise an `in` test against the entry's
            # 'subset' value (a substring test if that value is a string).
            if self.subset == "full" or self.subset in video_info['subset']:
                self.video_dict[video_name] = video_info
                for seg in video_info['annotations']:
                    if seg['label'] not in self.label_name:
                        self.label_name.append(seg['label'])
        # Ensure all 22 EGTEA action classes are included. The list is appended
        # unconditionally and then sorted, so class indices stay the same across
        # subsets as long as the annotations only use these labels.
        expected_labels = [
            'Clean/Wipe', 'Close', 'Compress', 'Crack', 'Cut', 'Divide/Pull Apart',
            'Dry', 'Inspect/Read', 'Mix', 'Move Around', 'Open', 'Operate', 'Other',
            'Pour', 'Put', 'Squeeze', 'Take', 'Transfer', 'Turn off', 'Turn on', 'Wash',
            'Spread'  # Assumed missing label; replace with actual label if known
        ]
        for label in expected_labels:
            if label not in self.label_name:
                self.label_name.append(label)
        self.label_name.sort()
        self.video_list = list(self.video_dict.keys())
        print(f"Labels in dataset.label_name: {self.label_name}")
        print(f"Number of labels: {len(self.label_name)}, Expected: {self.num_of_class-1}")
        print(f"{self.subset} subset video numbers: {len(self.video_list)}")

    def _getMatchScore(self):
        """Rasterise each video's annotations onto its feature rows.

        For each video:
        - ``gt_action[video]`` is an array of ``[end, length, class]`` rows in
          feature-row units.
        - ``match_score[video]`` has shape ``(video_len, num_of_class - 1)``.
          Rows ``int(start)`` through ``int(end)`` of the annotation's class
          column hold its 1-based annotation index, and all other entries are 0.
          When same-class annotations overlap, the later one overwrites the
          earlier.
        """
        # NOTE(review): action_end_count is not read anywhere in this file.
        self.action_end_count = torch.zeros(2)
        for video_name in self.video_list:
            video_info = self.video_dict[video_name]
            # Annotations are in seconds; scale them to feature-row units.
            second_to_frame = self.video_len[video_name] / float(video_info['duration'])
            gt_bbox = []
            gt_edlen = []
            for seg in video_info['annotations']:
                tmp_start = seg['segment'][0] * second_to_frame
                tmp_end = seg['segment'][1] * second_to_frame
                tmp_label = self.label_name.index(seg['label'])
                gt_bbox.append([tmp_start, tmp_end, tmp_label])
                gt_edlen.append([tmp_end, tmp_end - tmp_start, tmp_label])
            gt_bbox = np.array(gt_bbox)
            self.gt_action[video_name] = np.array(gt_edlen)
            match_score = np.zeros((self.video_len[video_name], self.num_of_class - 1), dtype=np.float32)
            for idx in range(gt_bbox.shape[0]):
                st = int(gt_bbox[idx, 0])
                ed = int(gt_bbox[idx, 1]) + 1
                match_score[st:ed, int(gt_bbox[idx, 2])] = idx + 1
            self.match_score[video_name] = match_score

    def _makeInputSeq(self):
        """Enumerate one ``[video, st, ed, data_idx]`` sample per end position.

        ``ed`` runs from 1 to the number of rows, ``st = ed - segment_size``
        (may be negative), and ``data_idx`` is a running index across all
        videos that addresses the label arrays.
        """
        data_idx = 0
        for video_name in self.video_list:
            duration = self.match_score[video_name].shape[0]
            for ed in range(1, duration + 1):
                self.inputs_all.append([video_name, ed - self.segment_size, ed, data_idx])
                data_idx += 1
        self.inputs = self.inputs_all.copy()
        print(f"{self.subset} subset seg numbers: {len(self.inputs)}")

    def _overlappingActionIdx(self, video_name, st, ed):
        """Return 0-based annotation indices present in match_score rows [st, ed).

        The window is scanned row-major and each index is kept at its first
        appearance. Order matters, because in _makePropLabelUnit the last
        matching action overwrites the regression target.
        """
        window = self._get_train_label_with_class(video_name, st, ed)
        seen = {}
        for value in window.reshape(-1).tolist():
            idx = int(value)
            if idx > 0:
                seen.setdefault(idx - 1, None)
        return list(seen)

    def _makePropLabelUnit(self, i):
        """Build ``(cls_anc, reg_anc, cls_snip)`` for sample ``i`` of ``inputs_all``.

        For anchor ``j`` of length ``a``, the candidate box ends at row ``ed - 1``
        and has length ``a``.

        ``cls_anc[j]`` is multi-hot over ``num_of_class`` entries; the last
        entry is 1 only when nothing matched. An action matches when any of
        these holds:
        - its ``calc_iou`` is at least ``pos_threshold``;
        - for the last anchor, the box lies inside the action;
        - for the first anchor, the action lies inside the box.

        ``reg_anc[j]`` holds ``[(action_end - box_end) / a, log(max(1, action_len) / a)]``
        for the last matched action, or ``[0, -1e3]`` if there is none.

        ``cls_snip`` marks the class of every action in the last anchor's window
        whose ``calc_iou`` with that anchor's box is ``>= 0``.
        """
        video_name = self.inputs_all[i][0]
        ed = self.inputs_all[i][2]
        gt_actions = self.gt_action[video_name]
        last_anchor = len(self.anchors) - 1
        cls_anc = []
        reg_anc = []
        for j, anchor in enumerate(self.anchors):
            v1 = np.zeros(self.num_of_class)
            v1[-1] = 1
            v2 = np.zeros(2)
            v2[-1] = -1e3
            y_box = [ed - 1, anchor]
            for idx in self._overlappingActionIdx(video_name, ed - anchor, ed):
                target_box = gt_actions[idx]
                cls = int(target_box[2])
                iou = calc_iou(y_box, target_box)
                if (iou >= self.pos_threshold
                        or (j == last_anchor and box_include(y_box, target_box))
                        or (j == 0 and box_include(target_box, y_box))):
                    v1[cls] = 1
                    v1[-1] = 0
                    v2[0] = 1.0 * (target_box[0] - y_box[0]) / anchor
                    v2[1] = np.log(1.0 * max(1, target_box[1]) / y_box[1])
            cls_anc.append(v1)
            reg_anc.append(v2)
        v0 = np.zeros(self.num_of_class)
        v0[-1] = 1
        y_box = [ed - 1, self.anchors[-1]]
        for idx in self._overlappingActionIdx(video_name, ed - self.anchors[-1], ed):
            target_box = gt_actions[idx]
            if calc_iou(y_box, target_box) >= 0:
                v0[int(target_box[2])] = 1
                v0[-1] = 0
        return np.stack(cls_anc, axis=0), np.stack(reg_anc, axis=0), np.array(v0)

    def _loadPropLabel(self, filename):
        """Load the label arrays from HDF5 cache ``filename``, or build and cache them.

        Sets ``cls_label``, ``reg_label`` and ``snip_label`` (indexed by
        ``data_idx``). In both paths it also sets ``action_frame_count``, the
        per-class sum of ``cls_label`` over all samples and anchors.
        """
        if os.path.exists(filename):
            with h5py.File(filename, 'r') as prop_label_file:
                self.cls_label = np.array(prop_label_file['cls_label'][:])
                self.reg_label = np.array(prop_label_file['reg_label'][:])
                self.snip_label = np.array(prop_label_file['snip_label'][:])
        else:
            # os.cpu_count() may return None, and Pool(0) raises ValueError on a
            # single-core host, so always request at least one worker.
            n_workers = max(1, (os.cpu_count() or 2) // 2)
            with Pool(n_workers) as pool:
                labels = pool.map(self._makePropLabelUnit, range(len(self.inputs_all)))
            self.cls_label = np.stack([label[0] for label in labels], axis=0)
            self.reg_label = np.stack([label[1] for label in labels], axis=0)
            self.snip_label = np.stack([label[2] for label in labels], axis=0)
            with h5py.File(filename, 'w') as outfile:
                for key, value in (('cls_label', self.cls_label),
                                   ('reg_label', self.reg_label),
                                   ('snip_label', self.snip_label)):
                    outfile.create_dataset('/' + key, data=value, maxshape=value.shape,
                                           chunks=True, dtype=np.float32)
        # Computed on both paths so the attribute always exists after construction.
        self.action_frame_count = torch.Tensor(
            np.sum(self.cls_label.reshape((-1, self.cls_label.shape[-1])), axis=0))

    def __getitem__(self, index):
        """Return ``(feature, cls_label, reg_label, snip_label)`` for sample ``index``.

        ``feature`` always has ``ed - st`` (= ``segment_size``) rows. When the
        window starts before row 0, the missing rows are zero-padded at the top.
        """
        video_name, st, ed, data_idx = self.inputs[index]
        if st >= 0:
            feature = self._get_base_data(video_name, st, ed)
        else:
            feature = self._get_base_data(video_name, 0, ed)
            padfunc2d = torch.nn.ConstantPad2d((0, 0, -st, 0), 0)
            feature = padfunc2d(feature)
        cls_label = torch.Tensor(self.cls_label[data_idx])
        reg_label = torch.Tensor(self.reg_label[data_idx])
        snip_label = torch.Tensor(self.snip_label[data_idx])
        return feature, cls_label, reg_label, snip_label

    def _get_base_data(self, video_name, st, ed):
        """Return feature rows [st, ed) as a tensor, with flow columns appended if present."""
        feature_rgb = self.feature_rgb_file[video_name][st:ed, :]
        if self.feature_flow_file is not None:
            feature_flow = self.feature_flow_file[video_name][st:ed, :]
            feature = np.append(feature_rgb, feature_flow, axis=1)
        else:
            feature = feature_rgb
        return torch.from_numpy(np.array(feature))

    def _get_train_label_with_class(self, video_name, st, ed):
        """Return match_score rows [st, ed) as a tensor with exactly ``ed - st`` rows.

        Rows that fall before 0 or past the end of the video are zero-filled.
        """
        duration = len(self.match_score[video_name])
        st_padding = 0
        ed_padding = 0
        if st < 0:
            st_padding = -st
            st = 0
        if ed > duration:
            ed_padding = ed - duration
            ed = duration
        match_score = torch.Tensor(self.match_score[video_name][st:ed])
        if st_padding > 0:
            padfunc2d = torch.nn.ConstantPad2d((0, 0, st_padding, 0), 0)
            match_score = padfunc2d(match_score)
        if ed_padding > 0:
            padfunc2d = torch.nn.ConstantPad2d((0, 0, 0, ed_padding), 0)
            match_score = padfunc2d(match_score)
        return match_score

    def __len__(self):
        return len(self.inputs)

    def reset_sample(self):
        """Make every sample in ``inputs_all`` active again."""
        self.inputs = self.inputs_all.copy()

    def select_sample(self, idx):
        """Restrict the active samples to the ``inputs_all`` positions listed in ``idx``."""
        self.inputs = [self.inputs_all[i] for i in idx]
        return
class SuppressDataSet(data.Dataset):
    """Row-level dataset backed by an HDF5 file of per-video input/label arrays.

    The file path is ``opt["suppress_label_file"]`` formatted with
    ``"<subset>_<setup>"``. Every top-level key of the file is treated as a
    video holding ``input`` and ``label`` datasets, and sample ``k`` of a video
    is row ``k`` of both. The file stays open for the object's lifetime, and
    rows are read from it on each ``__getitem__`` call.
    """

    def __init__(self, opt, subset="train"):
        self.subset = subset
        self.mode = opt["mode"]
        label_path = opt["suppress_label_file"].format(self.subset + "_" + opt['setup'])
        self.data_file = h5py.File(label_path, 'r')
        self.video_list = list(self.data_file.keys())
        # One (video, row) pair per row of each video's `input` dataset.
        self.inputs = [
            [name, row]
            for name in self.video_list
            for row in range(self.data_file[name + '/input'].shape[0])
        ]
        print(f"{self.subset} subset seg numbers: {len(self.inputs)}")

    def __getitem__(self, index):
        """Return the ``(input_seq, label)`` tensors for sample ``index``."""
        name, row = self.inputs[index]
        input_seq = torch.from_numpy(self.data_file[name + '/input'][row])
        label = torch.from_numpy(self.data_file[name + '/label'][row])
        return input_seq, label

    def __len__(self):
        return len(self.inputs)