|
|
import numpy as np
|
|
|
import h5py
|
|
|
import json
|
|
|
import torch
|
|
|
import torch.utils.data as data
|
|
|
import os
|
|
|
import pickle
|
|
|
from multiprocessing import Pool
|
|
|
|
|
|
def load_json(file):
    """Read a JSON document from disk and return the decoded object.

    Args:
        file: Path of the JSON file (anything ``open`` accepts).

    Returns:
        The Python object produced by ``json.load``.
    """
    # JSON interchange text is UTF-8; decode it explicitly instead of relying
    # on the platform's locale encoding.
    with open(file, encoding="utf-8") as json_file:
        return json.load(json_file)
|
|
|
|
|
|
def calc_iou(a, b):
    """Score the temporal overlap of two ``(end, length)`` boxes.

    Each box is an indexable whose element 0 is the end position and whose
    element 1 is the length, i.e. it covers ``[end - length, end]``. Extra
    elements (such as a class id) are ignored.

    Args:
        a: First box.
        b: Second box.

    Returns:
        ``overlap / max(span, 1)`` where ``overlap`` is the signed length of
        the intersection and ``span`` is the length of the smallest interval
        enclosing both boxes. The result is negative when the boxes are
        disjoint; callers in this file test ``>= 0`` to detect touching or
        overlapping boxes, so the value is deliberately not clamped.
    """
    a_start, a_end = a[0] - a[1], a[0]
    b_start, b_end = b[0] - b[1], b[0]

    overlap = min(a_end, b_end) - max(a_start, b_start)
    span = max(a_end, b_end) - min(a_start, b_start)

    # The denominator is floored at 1 so degenerate (zero-length) boxes do not
    # divide by zero.
    return overlap / max(span, 1)
|
|
|
|
|
|
def box_include(y, target):
    """Tell whether box ``y`` lies strictly inside box ``target``.

    Both boxes use the ``(end, length)`` layout: element 0 is the end
    position, element 1 the length, so a box covers ``[end - length, end]``.

    Args:
        y: Candidate inner box.
        target: Candidate outer box.

    Returns:
        ``True`` when ``target`` starts before ``y`` starts, ends after ``y``
        ends, and ``y`` ends after ``target`` starts; ``False`` otherwise
        (including when the two boxes are identical).
    """
    y_start, y_end = y[0] - y[1], y[0]
    target_start, target_end = target[0] - target[1], target[0]

    # bool() keeps the return type a plain Python bool even for NumPy inputs.
    return bool(
        y_end > target_start
        and target_start < y_start
        and target_end > y_end
    )
|
|
|
|
|
|
class VideoDataSet(data.Dataset):
    """Sliding-window dataset over pre-extracted per-frame video features.

    For every feature frame ``i`` (1-based) of every selected video one sample
    is produced that covers frames ``[i - segment_size, i)``; windows that
    start before the video are zero-padded at the front. Each sample carries
    per-anchor classification targets, per-anchor regression targets and one
    snippet-level classification target.

    Two on-disk caches are used: a JSON file with the feature length of each
    video (``video_len_file``) and an HDF5 file with the window labels
    (``proposal_label_file``). Feature tables are loaded eagerly into memory.
    """

    # Labels that are always part of the vocabulary, even when no scanned
    # annotation uses them, so class indices stay the same across runs.
    _EXPECTED_LABELS = (
        'Clean/Wipe', 'Close', 'Compress', 'Crack', 'Cut', 'Divide/Pull Apart',
        'Dry', 'Inspect/Read', 'Mix', 'Move Around', 'Open', 'Operate', 'Other',
        'Pour', 'Put', 'Squeeze', 'Take', 'Transfer', 'Turn off', 'Turn on', 'Wash',
        'Spread',
    )

    def __init__(self, opt, subset="train", video_name=None):
        """Index the videos, build (or load) the labels and load the features.

        Args:
            opt: Mapping of options. Keys read: ``mode``, ``predefined_fps``,
                ``video_anno``, ``split``, ``video_len_file``, ``setup``,
                ``num_of_class``, ``segment_size``, ``data_rescale``,
                ``anchors``, ``pos_threshold``, ``proposal_label_file``,
                ``data_format`` and the ``video_feature_*`` paths that belong
                to that format (plus ``rgb_only`` for ``"h5"``).
            subset: Subset to select; ``"full"`` selects every video. Feature
                paths ending in ``_train`` are used when this is ``"train"``,
                the ``_test`` ones otherwise.
            video_name: When given, restrict the dataset to this one video.

        Raises:
            ValueError: If ``video_name`` is not in the annotation database or
                a required feature entry/file is missing.
        """
        self.subset = subset
        self.mode = opt["mode"]
        self.predefined_fps = opt["predefined_fps"]
        self.video_anno_path = opt["video_anno"].format(opt["split"])
        self.video_len_path = opt["video_len_file"].format(self.subset + '_' + opt["setup"])
        self.num_of_class = opt["num_of_class"]
        self.segment_size = opt["segment_size"]
        self.label_name = []
        self.match_score = {}
        self.match_score_end = {}
        self.match_length = {}
        self.gt_action = {}
        self.cls_label = {}
        self.reg_label = {}
        self.snip_label = {}
        self.inputs = []
        self.inputs_all = []
        self.data_rescale = opt["data_rescale"]
        self.anchors = opt["anchors"]
        self.pos_threshold = opt["pos_threshold"]
        self.single_video_name = video_name

        # Order matters: each step consumes what the previous one produced.
        self._getDatasetDict()
        self._loadFeaturelen(opt)
        self._getMatchScore()
        self._makeInputSeq()
        self._loadPropLabel(opt['proposal_label_file'].format(self.subset + '_' + opt["setup"]))
        self._loadFeatures(opt)

    def _featureSplit(self):
        """Return the option-key suffix (``"train"``/``"test"``) for feature paths."""
        return "train" if self.subset == "train" else "test"

    @staticmethod
    def _readH5Features(path, videos):
        """Copy the datasets named in ``videos`` out of the HDF5 file at ``path``.

        Raises:
            ValueError: If a video has no dataset in the file.
        """
        features = {}
        with h5py.File(path, 'r') as feature_file:
            for name in videos:
                if name not in feature_file:
                    raise ValueError(f"Features for video {name} not found in {path}")
                features[name] = np.array(feature_file[name][:])
        return features

    def _loadFeatures(self, opt):
        """Populate ``feature_rgb_file`` / ``feature_flow_file`` for all videos.

        ``feature_flow_file`` is ``None`` for formats that carry a single
        feature stream (``"npz"``, ``"pt"``, and ``"h5"`` with ``rgb_only``).
        """
        split = self._featureSplit()
        data_format = opt['data_format']
        videos = self.video_list

        if data_format == "h5":
            self.feature_rgb_file = self._readH5Features(opt["video_feature_rgb_" + split], videos)
            if opt['rgb_only']:
                self.feature_flow_file = None
            else:
                self.feature_flow_file = self._readH5Features(
                    opt["video_feature_flow_" + split], videos)

        elif data_format == "pickle":
            source = opt["video_feature_all_" + split]
            # NOTE(security): unpickling runs arbitrary code; only point this
            # at feature files from a trusted source.
            with open(source, 'rb') as handle:
                feature_all = pickle.load(handle)
            self.feature_rgb_file = {}
            self.feature_flow_file = {}
            for name in videos:
                if name not in feature_all:
                    raise ValueError(f"Features for video {name} not found in {source}")
                self.feature_rgb_file[name] = feature_all[name]['rgb']
                self.feature_flow_file[name] = feature_all[name]['flow']

        elif data_format == "npz":
            prefix = opt["video_feature_all_" + split]
            self.feature_rgb_file = {}
            for name in videos:
                feature_path = prefix + name + '.npz'
                if not os.path.exists(feature_path):
                    raise ValueError(f"Feature file {feature_path} not found")
                # The context manager closes the archive's file handle.
                with np.load(feature_path) as archive:
                    self.feature_rgb_file[name] = archive['feats']
            self.feature_flow_file = None

        elif data_format == "npz_i3d":
            prefix = opt["video_feature_all_" + split]
            self.feature_rgb_file = {}
            self.feature_flow_file = {}
            for name in videos:
                feature_path = prefix + name + '.npz'
                if not os.path.exists(feature_path):
                    raise ValueError(f"Feature file {feature_path} not found")
                with np.load(feature_path) as archive:
                    self.feature_rgb_file[name] = archive['rgb']
                    self.feature_flow_file[name] = archive['flow']

        elif data_format == "pt":
            prefix = opt["video_feature_all_" + split]
            self.feature_rgb_file = {}
            for name in videos:
                feature_path = prefix + name + '.pt'
                if not os.path.exists(feature_path):
                    raise ValueError(f"Feature file {feature_path} not found")
                self.feature_rgb_file[name] = torch.load(feature_path)[:]
            self.feature_flow_file = None

        # NOTE(review): any other data_format falls through without loading
        # anything (unchanged from before); __getitem__ will then fail because
        # the feature tables do not exist. Consider rejecting it here.

    def _loadFeaturelen(self, opt):
        """Fill ``self.video_len`` (video name -> number of feature frames).

        Lengths come from the JSON cache at ``self.video_len_path`` when it
        covers every selected video. Otherwise the missing lengths are
        measured from the feature files and the cache is (re)written, so a
        cache created by a single-video run no longer breaks later runs.

        Raises:
            ValueError: If lengths must be measured and ``opt['data_format']``
                is not one of the supported formats.
        """
        cache_exists = os.path.exists(self.video_len_path)
        self.video_len = load_json(self.video_len_path) if cache_exists else {}
        missing = [name for name in self.video_list if name not in self.video_len]
        if cache_exists and not missing:
            return

        split = self._featureSplit()
        data_format = opt['data_format']

        if data_format == "h5":
            with h5py.File(opt["video_feature_rgb_" + split], 'r') as feature_file:
                for name in missing:
                    self.video_len[name] = len(feature_file[name])
        elif data_format == "pickle":
            # NOTE(security): see _loadFeatures; trusted files only.
            with open(opt["video_feature_all_" + split], 'rb') as handle:
                feature_file = pickle.load(handle)
            for name in missing:
                self.video_len[name] = len(feature_file[name]['rgb'])
        elif data_format == "npz":
            prefix = opt["video_feature_all_" + split]
            for name in missing:
                with np.load(prefix + name + '.npz') as archive:
                    self.video_len[name] = len(archive['feats'])
        elif data_format == "npz_i3d":
            prefix = opt["video_feature_all_" + split]
            for name in missing:
                with np.load(prefix + name + '.npz') as archive:
                    self.video_len[name] = len(archive['rgb'])
        elif data_format == "pt":
            prefix = opt["video_feature_all_" + split]
            for name in missing:
                self.video_len[name] = len(torch.load(prefix + name + '.pt'))
        else:
            # Fail before writing so an empty cache is never persisted.
            raise ValueError(f"Unsupported data_format: {data_format!r}")

        with open(self.video_len_path, "w") as outfile:
            json.dump(self.video_len, outfile, indent=2)

    def _getDatasetDict(self):
        """Select the videos of this subset and build the label vocabulary.

        Sets ``self.video_dict`` (name -> annotation entry), ``self.video_list``
        and ``self.label_name``. A video is selected when ``self.subset`` is
        ``"full"`` or is contained in the entry's ``'subset'`` value; with
        ``single_video_name`` only that video is considered. The vocabulary is
        the sorted union of every label found in the annotation database and
        ``_EXPECTED_LABELS``, so a class index does not depend on which videos
        were selected.

        Raises:
            ValueError: If ``single_video_name`` is set but absent from the
                annotation database.
        """
        anno_database = load_json(self.video_anno_path)['database']
        self.video_dict = {}

        if self.single_video_name:
            if self.single_video_name not in anno_database:
                raise ValueError(f"Video {self.single_video_name} not found in annotation database")
            candidates = {self.single_video_name: anno_database[self.single_video_name]}
        else:
            candidates = anno_database

        for video_name, video_info in candidates.items():
            video_subset = video_info['subset']
            if self.subset == "full" or self.subset in video_subset:
                self.video_dict[video_name] = video_info

        labels = set(self.label_name)
        for video_info in anno_database.values():
            for seg in video_info.get('annotations', []):
                labels.add(seg['label'])
        labels.update(self._EXPECTED_LABELS)

        self.label_name = sorted(labels)
        self.video_list = list(self.video_dict.keys())
        print(f"Labels in dataset.label_name: {self.label_name}")
        print(f"Number of labels: {len(self.label_name)}, Expected: {self.num_of_class-1}")
        print(f"{self.subset} subset video numbers: {len(self.video_list)}")

    def _getMatchScore(self):
        """Rasterise the annotations of every video onto its feature frames.

        For each video this stores

        * ``self.gt_action[name]``: array with one ``[end, length, class]``
          row per annotation, in feature-frame units;
        * ``self.match_score[name]``: float32 array of shape
          ``(video_len, num_of_class - 1)`` whose entry ``[t, c]`` is the
          1-based index of an annotation of class ``c`` covering frame ``t``
          (0 when there is none; a later annotation overwrites an earlier one
          of the same class).
        """
        self.action_end_count = torch.zeros(2)
        for video_name in self.video_list:
            video_info = self.video_dict[video_name]
            num_frames = self.video_len[video_name]
            # Converts annotation seconds into feature-frame positions.
            second_to_frame = num_frames / float(video_info['duration'])

            gt_bbox = []
            gt_edlen = []
            for annotation in video_info['annotations']:
                start = annotation['segment'][0] * second_to_frame
                end = annotation['segment'][1] * second_to_frame
                label = self.label_name.index(annotation['label'])
                gt_bbox.append([start, end, label])
                gt_edlen.append([end, end - start, label])

            gt_bbox = np.array(gt_bbox)
            self.gt_action[video_name] = np.array(gt_edlen)

            match_score = np.zeros((num_frames, self.num_of_class - 1), dtype=np.float32)
            for idx in range(gt_bbox.shape[0]):
                first = int(gt_bbox[idx, 0])
                last = int(gt_bbox[idx, 1]) + 1  # end frame is inclusive
                match_score[first:last, int(gt_bbox[idx, 2])] = idx + 1
            self.match_score[video_name] = match_score

    def _makeInputSeq(self):
        """Enumerate one window per feature frame of every video.

        Appends ``[video_name, st, ed, data_idx]`` to ``self.inputs_all`` for
        ``ed`` in ``1..duration`` with ``st = ed - segment_size`` (possibly
        negative) and a running ``data_idx`` that indexes the label arrays,
        then resets ``self.inputs`` to the full list.
        """
        data_idx = len(self.inputs_all) and 0
        for video_name in self.video_list:
            duration = self.match_score[video_name].shape[0]
            for ed in range(1, duration + 1):
                self.inputs_all.append([video_name, ed - self.segment_size, ed, data_idx])
                data_idx += 1

        self.inputs = self.inputs_all.copy()
        print(f"{self.subset} subset seg numbers: {len(self.inputs)}")

    def _activeActionIds(self, video_name, st, ed):
        """Return 0-based annotation indices present in frames ``[st, ed)``.

        Indices are listed once each, in order of first appearance when the
        window is scanned frame by frame and class by class.
        """
        window = self._get_train_label_with_class(video_name, st, ed)
        marks = (int(value) for value in window.reshape(-1).tolist())
        return list(dict.fromkeys(mark - 1 for mark in marks if mark > 0))

    def _makePropLabelUnit(self, i):
        """Compute the training targets of window ``self.inputs_all[i]``.

        Returns:
            A tuple ``(cls_anc, reg_anc, cls_snip)``:

            * ``cls_anc``: ``(num_anchors, num_of_class)`` multi-hot class
              targets; the last column is 1 when no annotation matched.
            * ``reg_anc``: ``(num_anchors, 2)`` regression targets
              ``[end offset / anchor, log(length / anchor)]`` of the last
              matching annotation, ``[0, -1e3]`` when nothing matched.
            * ``cls_snip``: ``(num_of_class,)`` multi-hot targets for the
              longest-anchor span ending at the window end.
        """
        video_name = self.inputs_all[i][0]
        ed = self.inputs_all[i][2]
        last_anchor = len(self.anchors) - 1

        cls_anc = []
        reg_anc = []
        for j, anchor in enumerate(self.anchors):
            cls_target = np.zeros(self.num_of_class)
            cls_target[-1] = 1  # background until a match is found
            reg_target = np.zeros(2)
            reg_target[-1] = -1e3
            # Anchor box in (end, length) layout, ending on the last frame.
            y_box = [ed - 1, anchor]

            for idx in self._activeActionIds(video_name, ed - anchor, ed):
                target_box = self.gt_action[video_name][idx]
                class_id = int(target_box[2])
                iou = calc_iou(y_box, target_box)
                # Besides the IoU test, the longest anchor also matches
                # actions that contain it and the shortest anchor matches
                # actions contained in it.
                matched = (
                    iou >= self.pos_threshold
                    or (j == last_anchor and box_include(y_box, target_box))
                    or (j == 0 and box_include(target_box, y_box))
                )
                if matched:
                    cls_target[class_id] = 1
                    cls_target[-1] = 0
                    reg_target[0] = 1.0 * (target_box[0] - y_box[0]) / anchor
                    reg_target[1] = np.log(1.0 * max(1, target_box[1]) / y_box[1])

            cls_anc.append(cls_target)
            reg_anc.append(reg_target)

        snip_target = np.zeros(self.num_of_class)
        snip_target[-1] = 1
        y_box = [ed - 1, self.anchors[-1]]
        for idx in self._activeActionIds(video_name, ed - self.anchors[-1], ed):
            target_box = self.gt_action[video_name][idx]
            # calc_iou is negative only for disjoint boxes.
            if calc_iou(y_box, target_box) >= 0:
                snip_target[int(target_box[2])] = 1
                snip_target[-1] = 0

        cls_anc = np.stack(cls_anc, axis=0)
        reg_anc = np.stack(reg_anc, axis=0)
        cls_snip = np.array(snip_target)
        return cls_anc, reg_anc, cls_snip

    def _computePropLabel(self):
        """Compute ``cls_label``/``reg_label``/``snip_label`` for all windows."""
        if not self.inputs_all:
            # Nothing to stack; keep the arrays well-formed but empty.
            num_anchors = len(self.anchors)
            self.cls_label = np.zeros((0, num_anchors, self.num_of_class), dtype=np.float32)
            self.reg_label = np.zeros((0, num_anchors, 2), dtype=np.float32)
            self.snip_label = np.zeros((0, self.num_of_class), dtype=np.float32)
            return

        # cpu_count() may be None or 1; a pool needs at least one worker.
        workers = max(1, (os.cpu_count() or 2) // 2)
        pool = Pool(workers)
        try:
            labels = pool.map(self._makePropLabelUnit, range(len(self.inputs_all)))
        finally:
            pool.close()
            pool.join()

        cls_rows, reg_rows, snip_rows = zip(*labels)
        # float32 matches what the HDF5 cache stores and returns.
        self.cls_label = np.stack(cls_rows, axis=0).astype(np.float32)
        self.reg_label = np.stack(reg_rows, axis=0).astype(np.float32)
        self.snip_label = np.stack(snip_rows, axis=0).astype(np.float32)

    def _loadPropLabel(self, filename):
        """Load the window labels from ``filename`` or compute and cache them.

        The cache is only trusted when it holds exactly one row per window of
        this dataset; otherwise the labels are recomputed. In single-video
        mode the shared cache file is never written, because its rows would
        not line up with the windows of a full run. ``action_frame_count``
        (per-class sum of ``cls_label``) is set in every case.

        Args:
            filename: Path of the HDF5 label cache.
        """
        expected_rows = len(self.inputs_all)
        cache_hit = False
        if os.path.exists(filename):
            with h5py.File(filename, 'r') as prop_label_file:
                cached_rows = prop_label_file['cls_label'].shape[0]
                if cached_rows == expected_rows:
                    self.cls_label = np.array(prop_label_file['cls_label'][:])
                    self.reg_label = np.array(prop_label_file['reg_label'][:])
                    self.snip_label = np.array(prop_label_file['snip_label'][:])
                    cache_hit = True
                else:
                    print(f"Label cache {filename} has {cached_rows} rows, "
                          f"expected {expected_rows}; recomputing labels")

        if not cache_hit:
            self._computePropLabel()
            if not self.single_video_name and expected_rows > 0:
                with h5py.File(filename, 'w') as outfile:
                    for key, values in (('/cls_label', self.cls_label),
                                        ('/reg_label', self.reg_label),
                                        ('/snip_label', self.snip_label)):
                        outfile.create_dataset(key, data=values, maxshape=values.shape,
                                               chunks=True, dtype=np.float32)

        frame_count = np.sum(self.cls_label.reshape((-1, self.cls_label.shape[-1])), axis=0)
        self.action_frame_count = torch.Tensor(frame_count)
        return

    def __getitem__(self, index):
        """Return ``(feature, cls_label, reg_label, snip_label)`` for a sample.

        ``feature`` holds the window's frames; when the window starts before
        frame 0 the missing leading rows are zero-filled.
        """
        video_name, st, ed, data_idx = self.inputs[index]
        if st >= 0:
            feature = self._get_base_data(video_name, st, ed)
        else:
            feature = self._get_base_data(video_name, 0, ed)
            # (left, right, top, bottom): add -st zero rows in front.
            feature = torch.nn.functional.pad(feature, (0, 0, -st, 0))

        cls_label = torch.Tensor(self.cls_label[data_idx])
        reg_label = torch.Tensor(self.reg_label[data_idx])
        snip_label = torch.Tensor(self.snip_label[data_idx])

        return feature, cls_label, reg_label, snip_label

    def _get_base_data(self, video_name, st, ed):
        """Return frames ``[st, ed)`` of a video as a tensor.

        RGB and flow features are concatenated along the feature axis when a
        flow table is present.
        """
        feature = self.feature_rgb_file[video_name][st:ed, :]
        if self.feature_flow_file is not None:
            feature_flow = self.feature_flow_file[video_name][st:ed, :]
            feature = np.append(feature, feature_flow, axis=1)

        # np.array copies, so the returned tensor never aliases the tables.
        return torch.from_numpy(np.array(feature))

    def _get_train_label_with_class(self, video_name, st, ed):
        """Return rows ``[st, ed)`` of a video's match-score table as a tensor.

        Positions before the first frame or after the last one are filled
        with zeros, so the result always has ``ed - st`` rows when
        ``st < ed``.
        """
        duration = len(self.match_score[video_name])
        st_padding = max(0, -st)
        ed_padding = max(0, ed - duration)
        st = max(st, 0)
        ed = min(ed, duration)

        match_score = torch.Tensor(self.match_score[video_name][st:ed])
        if st_padding > 0 or ed_padding > 0:
            match_score = torch.nn.functional.pad(
                match_score, (0, 0, st_padding, ed_padding))
        return match_score

    def __len__(self):
        """Return the number of currently selected samples."""
        return len(self.inputs)

    def reset_sample(self):
        """Make every window selectable again."""
        self.inputs = self.inputs_all.copy()

    def select_sample(self, idx):
        """Restrict the dataset to the windows at positions ``idx`` of ``inputs_all``."""
        self.inputs = [self.inputs_all[i] for i in idx]
        return
|
|
|
|
|
|
class SuppressDataSet(data.Dataset):
    """Dataset over the per-video ``input``/``label`` rows of an HDF5 file.

    The file named by ``opt["suppress_label_file"]`` holds one group per
    video with an ``input`` and a ``label`` dataset; every row index of
    ``input`` is one sample.
    """

    def __init__(self, opt, subset="train"):
        """Index all rows of the suppress-label file.

        Args:
            opt: Mapping of options. Keys read: ``mode``,
                ``suppress_label_file`` (a format string) and ``setup``.
            subset: Subset name used to format the file path.
        """
        self.subset = subset
        self.mode = opt["mode"]
        self.data_path = opt["suppress_label_file"].format(self.subset + "_" + opt['setup'])
        # Opened on first use (see data_file) so that each DataLoader worker
        # process ends up with its own handle.
        self._data_file = None

        with h5py.File(self.data_path, 'r') as index_file:
            self.video_list = list(index_file.keys())
            self.inputs = [
                [video_name, i]
                for video_name in self.video_list
                for i in range(index_file[video_name + '/input'].shape[0])
            ]

        print(f"{self.subset} subset seg numbers: {len(self.inputs)}")

    @property
    def data_file(self):
        """Read-only HDF5 handle for this process, opened lazily."""
        if getattr(self, "_data_file", None) is None:
            self._data_file = h5py.File(self.data_path, 'r')
        return self._data_file

    def __getstate__(self):
        """Return picklable state; the open HDF5 handle is dropped."""
        state = self.__dict__.copy()
        state["_data_file"] = None
        return state

    def __getitem__(self, index):
        """Return ``(input_seq, label)`` tensors for sample ``index``."""
        video_name, idx = self.inputs[index]
        source = self.data_file

        input_seq = torch.from_numpy(source[video_name + '/input'][idx])
        label = torch.from_numpy(source[video_name + '/label'][idx])

        return input_seq, label

    def __len__(self):
        """Return the number of indexed rows."""
        return len(self.inputs)