import numpy as np
import h5py
import json
import torch
import torch.utils.data as data
import os
import pickle
from multiprocessing import Pool


def load_json(file):
    with open(file) as json_file:
        data = json.load(json_file)
        return data

def calc_iou(a, b):
    # Boxes are encoded as [end_frame, length]; recover the start/end bounds.
    st = a[0] - a[1]
    ed = a[0]
    target_st = b[0] - b[1]
    target_ed = b[0]
    sst = min(st, target_st)  # union start
    led = max(ed, target_ed)  # union end
    lst = max(st, target_st)  # intersection start
    sed = min(ed, target_ed)  # intersection end
    # Negative when the boxes are disjoint; callers only compare against
    # non-negative thresholds, so no clamping is needed here.
    iou = (sed - lst) / max(led - sst, 1)
    return iou


def box_include(y, target):
    # True when the interval of `target` strictly contains the interval of `y`.
    st = y[0] - y[1]
    ed = y[0]
    target_st = target[0] - target[1]
    target_ed = target[0]
    detection_point = target_st
    if ed > detection_point and target_st < st and target_ed > ed:
        return True
    return False
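
# Illustrative check of the box utilities (values invented for this sketch):
#
#   >>> round(calc_iou([100, 20], [105, 30]), 3)   # [80, 100) vs [75, 105)
#   0.667
#   >>> box_include([100, 10], [105, 30])          # [75, 105) contains [90, 100)
#   True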


class VideoDataSet(data.Dataset):
    def __init__(self, opt, subset="train", video_name=None):
        self.subset = subset
        self.mode = opt["mode"]
        self.predefined_fps = opt["predefined_fps"]
        self.video_anno_path = opt["video_anno"].format(opt["split"])
        self.video_len_path = opt["video_len_file"].format(self.subset + '_' + opt["setup"])
        self.num_of_class = opt["num_of_class"]
        self.segment_size = opt["segment_size"]
        self.label_name = []
        self.match_score = {}
        self.match_score_end = {}
        self.match_length = {}
        self.gt_action = {}
        self.cls_label = {}
        self.reg_label = {}
        self.snip_label = {}
        self.inputs = []
        self.inputs_all = []
        self.data_rescale = opt["data_rescale"]
        self.anchors = opt["anchors"]
        self.pos_threshold = opt["pos_threshold"]
        self.single_video_name = video_name

        self._getDatasetDict()
        self._loadFeaturelen(opt)
        self._getMatchScore()
        self._makeInputSeq()
        self._loadPropLabel(opt['proposal_label_file'].format(self.subset + '_' + opt["setup"]))
        self._loadFeatures(opt)

    def _loadFeatures(self, opt):
        # The train subset reads the *_train option keys; every other subset
        # (validation/test/full) reads the *_test keys.
        split = "train" if self.subset == "train" else "test"
        data_format = opt['data_format']

        if data_format == "h5":
            # One HDF5 file per stream, holding one dataset per video.
            rgb_path = opt["video_feature_rgb_" + split]
            feature_rgb_file = h5py.File(rgb_path, 'r')
            self.feature_rgb_file = {}
            for key in self.video_list:
                if key not in feature_rgb_file:
                    raise ValueError(f"Features for video {key} not found in {rgb_path}")
                self.feature_rgb_file[key] = np.array(feature_rgb_file[key][:])
            if opt['rgb_only']:
                self.feature_flow_file = None
            else:
                flow_path = opt["video_feature_flow_" + split]
                feature_flow_file = h5py.File(flow_path, 'r')
                self.feature_flow_file = {}
                for key in self.video_list:
                    if key not in feature_flow_file:
                        raise ValueError(f"Features for video {key} not found in {flow_path}")
                    self.feature_flow_file[key] = np.array(feature_flow_file[key][:])
        elif data_format == "pickle":
            # One pickle mapping video name -> {'rgb': array, 'flow': array}.
            all_path = opt["video_feature_all_" + split]
            feature_all = pickle.load(open(all_path, 'rb'))
            self.feature_rgb_file = {}
            self.feature_flow_file = {}
            for key in self.video_list:
                if key not in feature_all:
                    raise ValueError(f"Features for video {key} not found in {all_path}")
                self.feature_rgb_file[key] = feature_all[key]['rgb']
                self.feature_flow_file[key] = feature_all[key]['flow']
        elif data_format == "npz":
            # One <video>.npz per video with a single 'feats' array; no flow.
            self.feature_rgb_file = {}
            self.feature_flow_file = None
            for key in self.video_list:
                feature_path = os.path.join(opt["video_feature_all_" + split], key + '.npz')
                if not os.path.exists(feature_path):
                    raise ValueError(f"Feature file {feature_path} not found")
                self.feature_rgb_file[key] = np.load(feature_path)['feats']
        elif data_format == "npz_i3d":
            # One <video>.npz per video with separate 'rgb' and 'flow' arrays.
            self.feature_rgb_file = {}
            self.feature_flow_file = {}
            for key in self.video_list:
                feature_path = os.path.join(opt["video_feature_all_" + split], key + '.npz')
                if not os.path.exists(feature_path):
                    raise ValueError(f"Feature file {feature_path} not found")
                feats = np.load(feature_path)
                self.feature_rgb_file[key] = feats['rgb']
                self.feature_flow_file[key] = feats['flow']
        elif data_format == "pt":
            # One <video>.pt tensor per video; no flow.
            self.feature_rgb_file = {}
            self.feature_flow_file = None
            for key in self.video_list:
                feature_path = os.path.join(opt["video_feature_all_" + split], key + '.pt')
                if not os.path.exists(feature_path):
                    raise ValueError(f"Feature file {feature_path} not found")
                self.feature_rgb_file[key] = torch.load(feature_path)
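    # Illustrative (assumed) producer for the "npz" layout above: one
    # <video_name>.npz per video holding a (num_frames, feature_dim) array
    # under the key 'feats'. The shape and directory are placeholders:
    #
    #   feats = np.random.rand(1200, 2048).astype(np.float32)
    #   np.savez(os.path.join("data/features/train", "video_0001.npz"), feats=feats)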

    def _loadFeaturelen(self, opt):
        # Use the cached per-video feature lengths when available.
        if os.path.exists(self.video_len_path):
            self.video_len = load_json(self.video_len_path)
            return

        self.video_len = {}
        split = "train" if self.subset == "train" else "test"
        data_format = opt['data_format']
        if data_format == "h5":
            feature_file = h5py.File(opt["video_feature_rgb_" + split], 'r')
        elif data_format == "pickle":
            feature_file = pickle.load(open(opt["video_feature_all_" + split], 'rb'))
        elif data_format == "npz":
            feature_file = {}
            for file in self.video_list:
                feature_file[file] = np.load(os.path.join(opt["video_feature_all_" + split], file + '.npz'))['feats']
        elif data_format == "npz_i3d":
            feature_file = {}
            for file in self.video_list:
                feature_file[file] = np.load(os.path.join(opt["video_feature_all_" + split], file + '.npz'))
        elif data_format == "pt":
            feature_file = {}
            for file in self.video_list:
                feature_file[file] = torch.load(os.path.join(opt["video_feature_all_" + split], file + '.pt'))

        # 'pickle' and 'npz_i3d' keep rgb/flow streams separately, so measure
        # the rgb stream; the other formats store one array per video.
        for key in self.video_list:
            if data_format in ("pickle", "npz_i3d"):
                self.video_len[key] = len(feature_file[key]['rgb'])
            else:
                self.video_len[key] = len(feature_file[key])

        # Cache the lengths so later runs skip the feature scan.
        with open(self.video_len_path, "w") as outfile:
            json.dump(self.video_len, outfile, indent=2)

    def _getDatasetDict(self):
        anno_database = load_json(self.video_anno_path)
        anno_database = anno_database['database']
        self.video_dict = {}
        if self.single_video_name:
            if self.single_video_name not in anno_database:
                raise ValueError(f"Video {self.single_video_name} not found in annotation database")
            video_names = [self.single_video_name]
        else:
            video_names = list(anno_database.keys())

        # Keep every video belonging to the requested subset and collect the
        # label vocabulary from its annotations.
        for video_name in video_names:
            video_info = anno_database[video_name]
            video_subset = video_info['subset']
            if self.subset == "full" or self.subset in video_subset:
                self.video_dict[video_name] = video_info
                for seg in video_info['annotations']:
                    if seg['label'] not in self.label_name:
                        self.label_name.append(seg['label'])

        # Make sure every expected class is present even if the subset has no
        # example of it.
        expected_labels = [
            'Clean/Wipe', 'Close', 'Compress', 'Crack', 'Cut', 'Divide/Pull Apart',
            'Dry', 'Inspect/Read', 'Mix', 'Move Around', 'Open', 'Operate', 'Other',
            'Pour', 'Put', 'Squeeze', 'Take', 'Transfer', 'Turn off', 'Turn on', 'Wash',
            'Spread'
        ]
        for label in expected_labels:
            if label not in self.label_name:
                self.label_name.append(label)

        self.label_name.sort()
        self.video_list = list(self.video_dict.keys())
        print(f"Labels in dataset.label_name: {self.label_name}")
        print(f"Number of labels: {len(self.label_name)}, Expected: {self.num_of_class - 1}")
        print(f"{self.subset} subset video numbers: {len(self.video_list)}")

    def _getMatchScore(self):
        self.action_end_count = torch.zeros(2)
        for index in range(0, len(self.video_list)):
            video_name = self.video_list[index]
            video_info = self.video_dict[video_name]
            video_labels = video_info['annotations']
            gt_bbox = []
            gt_edlen = []

            # Convert annotated seconds to feature-frame indices.
            second_to_frame = self.video_len[video_name] / float(video_info['duration'])
            for j in range(len(video_labels)):
                tmp_info = video_labels[j]
                tmp_start = tmp_info['segment'][0] * second_to_frame
                tmp_end = tmp_info['segment'][1] * second_to_frame
                tmp_label = self.label_name.index(tmp_info['label'])
                gt_bbox.append([tmp_start, tmp_end, tmp_label])
                # Re-encode each box as [end, length, class], the format the
                # anchor utilities expect.
                gt_edlen.append([gt_bbox[-1][1], gt_bbox[-1][1] - gt_bbox[-1][0], tmp_label])

            gt_bbox = np.array(gt_bbox)
            gt_edlen = np.array(gt_edlen)
            self.gt_action[video_name] = gt_edlen

            # Per-frame, per-class map: cell (t, c) holds the 1-based index of
            # the ground-truth action of class c covering frame t, 0 otherwise.
            match_score = np.zeros((self.video_len[video_name], self.num_of_class - 1), dtype=np.float32)
            for idx in range(gt_bbox.shape[0]):
                ed = int(gt_bbox[idx, 1]) + 1
                st = int(gt_bbox[idx, 0])
                match_score[st:ed, int(gt_bbox[idx, 2])] = idx + 1
            self.match_score[video_name] = match_score
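    # Illustrative example: if the idx-th ground-truth action has class 3 and
    # spans frames 10..20, then match_score[10:21, 3] == idx + 1, while cells
    # left at 0 mean "no action of that class at that frame".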

    def _makeInputSeq(self):
        # Build one sliding window per frame: each window ends at frame i and
        # covers the preceding segment_size frames (starts may be negative).
        data_idx = 0
        for index in range(0, len(self.video_list)):
            video_name = self.video_list[index]
            duration = self.match_score[video_name].shape[0]
            for i in range(1, duration + 1):
                st = i - self.segment_size
                ed = i
                self.inputs_all.append([video_name, st, ed, data_idx])
                data_idx += 1

        self.inputs = self.inputs_all.copy()
        print(f"{self.subset} subset seg numbers: {len(self.inputs)}")
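    # Example: with segment_size=64 and a 100-frame video, the windows are
    # [-63, 1), [-62, 2), ..., [36, 100) -- one window ending at every frame,
    # with negative starts zero-padded later in __getitem__.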

    def _makePropLabelUnit(self, i):
        video_name = self.inputs_all[i][0]
        st = self.inputs_all[i][1]
        ed = self.inputs_all[i][2]
        cls_anc = []
        reg_anc = []

        for j in range(0, len(self.anchors)):
            # Default: background class set, sentinel regression targets.
            v1 = np.zeros(self.num_of_class)
            v1[-1] = 1
            v2 = np.zeros(2)
            v2[-1] = -1e3
            y_box = [ed - 1, self.anchors[j]]

            # Collect the ground-truth actions overlapping this anchor window.
            subset_label = self._get_train_label_with_class(video_name, ed - self.anchors[j], ed)
            idx_list = []
            for ii in range(0, subset_label.shape[0]):
                for jj in range(0, subset_label.shape[1]):
                    idx = int(subset_label[ii, jj])
                    if idx > 0 and idx - 1 not in idx_list:
                        idx_list.append(idx - 1)

            for idx in idx_list:
                target_box = self.gt_action[video_name][idx]
                cls = int(target_box[2])
                iou = calc_iou(y_box, target_box)
                # Positive if the IoU clears the threshold, if the largest
                # anchor lies inside the action, or if the smallest anchor
                # contains it.
                if iou >= self.pos_threshold \
                        or (j == len(self.anchors) - 1 and box_include(y_box, target_box)) \
                        or (j == 0 and box_include(target_box, y_box)):
                    v1[cls] = 1
                    v1[-1] = 0
                    # Regression targets: end offset and log length ratio.
                    v2[0] = 1.0 * (target_box[0] - y_box[0]) / self.anchors[j]
                    v2[1] = np.log(1.0 * max(1, target_box[1]) / y_box[1])

            cls_anc.append(v1)
            reg_anc.append(v2)

        # Snippet-level label over the largest anchor window: mark every class
        # whose ground truth overlaps the window at all.
        v0 = np.zeros(self.num_of_class)
        v0[-1] = 1
        y_box = [ed - 1, self.anchors[-1]]
        subset_label = self._get_train_label_with_class(video_name, ed - self.anchors[-1], ed)
        idx_list = []
        for ii in range(0, subset_label.shape[0]):
            for jj in range(0, subset_label.shape[1]):
                idx = int(subset_label[ii, jj])
                if idx > 0 and idx - 1 not in idx_list:
                    idx_list.append(idx - 1)

        for idx in idx_list:
            target_box = self.gt_action[video_name][idx]
            cls = int(target_box[2])
            iou = calc_iou(y_box, target_box)
            if iou >= 0:
                v0[cls] = 1
                v0[-1] = 0

        cls_anc = np.stack(cls_anc, axis=0)
        reg_anc = np.stack(reg_anc, axis=0)
        cls_snip = np.array(v0)
        return cls_anc, reg_anc, cls_snip
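    # Returned shapes: cls_anc is (len(self.anchors), num_of_class), reg_anc is
    # (len(self.anchors), 2) holding (end offset, log length ratio), and
    # cls_snip is (num_of_class,), the snippet-level multi-label vector.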

    def _loadPropLabel(self, filename):
        # Reuse cached proposal labels when they exist.
        if os.path.exists(filename):
            prop_label_file = h5py.File(filename, 'r')
            self.cls_label = np.array(prop_label_file['cls_label'][:])
            self.reg_label = np.array(prop_label_file['reg_label'][:])
            self.snip_label = np.array(prop_label_file['snip_label'][:])
            prop_label_file.close()
            self.action_frame_count = np.sum(self.cls_label.reshape((-1, self.cls_label.shape[-1])), axis=0)
            self.action_frame_count = torch.Tensor(self.action_frame_count)
            return

        # Windows are labelled independently, so fan the work out; max() keeps
        # at least one worker on single-core machines.
        pool = Pool(max(1, os.cpu_count() // 2))
        labels = pool.map(self._makePropLabelUnit, range(0, len(self.inputs_all)))
        pool.close()
        pool.join()

        cls_label = []
        reg_label = []
        snip_label = []
        for i in range(0, len(labels)):
            cls_label.append(labels[i][0])
            reg_label.append(labels[i][1])
            snip_label.append(labels[i][2])
        self.cls_label = np.stack(cls_label, axis=0)
        self.reg_label = np.stack(reg_label, axis=0)
        self.snip_label = np.stack(snip_label, axis=0)

        # Cache the freshly computed labels for subsequent runs.
        outfile = h5py.File(filename, 'w')
        dset_cls = outfile.create_dataset('/cls_label', self.cls_label.shape, maxshape=self.cls_label.shape, chunks=True, dtype=np.float32)
        dset_cls[...] = self.cls_label[...]
        dset_reg = outfile.create_dataset('/reg_label', self.reg_label.shape, maxshape=self.reg_label.shape, chunks=True, dtype=np.float32)
        dset_reg[...] = self.reg_label[...]
        dset_snip = outfile.create_dataset('/snip_label', self.snip_label.shape, maxshape=self.snip_label.shape, chunks=True, dtype=np.float32)
        dset_snip[...] = self.snip_label[...]
        outfile.close()

        return

    def __getitem__(self, index):
        video_name, st, ed, data_idx = self.inputs[index]
        if st >= 0:
            feature = self._get_base_data(video_name, st, ed)
        else:
            # Windows starting before frame 0 are zero-padded at the front so
            # every sample has segment_size rows.
            feature = self._get_base_data(video_name, 0, ed)
            padfunc2d = torch.nn.ConstantPad2d((0, 0, -st, 0), 0)
            feature = padfunc2d(feature)

        cls_label = torch.Tensor(self.cls_label[data_idx])
        reg_label = torch.Tensor(self.reg_label[data_idx])
        snip_label = torch.Tensor(self.snip_label[data_idx])

        return feature, cls_label, reg_label, snip_label

    def _get_base_data(self, video_name, st, ed):
        feature_rgb = self.feature_rgb_file[video_name]
        feature_rgb = feature_rgb[st:ed, :]

        # Concatenate the flow stream along the channel axis when present.
        if self.feature_flow_file is not None:
            feature_flow = self.feature_flow_file[video_name]
            feature_flow = feature_flow[st:ed, :]
            feature = np.append(feature_rgb, feature_flow, axis=1)
        else:
            feature = feature_rgb
        feature = torch.from_numpy(np.array(feature))

        return feature

    def _get_train_label_with_class(self, video_name, st, ed):
        # Slice the per-frame match map over [st, ed), zero-padding whatever
        # falls outside the video.
        duration = len(self.match_score[video_name])
        st_padding = 0
        ed_padding = 0
        if st < 0:
            st_padding = -st
            st = 0
        if ed > duration:
            ed_padding = ed - duration
            ed = duration

        match_score = torch.Tensor(self.match_score[video_name][st:ed])
        if st_padding > 0:
            padfunc2d = torch.nn.ConstantPad2d((0, 0, st_padding, 0), 0)
            match_score = padfunc2d(match_score)
        if ed_padding > 0:
            padfunc2d = torch.nn.ConstantPad2d((0, 0, 0, ed_padding), 0)
            match_score = padfunc2d(match_score)
        return match_score
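    # Note: torch.nn.ConstantPad2d((left, right, top, bottom), value) pads the
    # last two dimensions, so (0, 0, st_padding, 0) prepends st_padding zero
    # rows to the (time, class) matrix without touching the class axis.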

    def __len__(self):
        return len(self.inputs)

    def reset_sample(self):
        # Restore the full window list after a select_sample() call.
        self.inputs = self.inputs_all.copy()

    def select_sample(self, idx):
        # Restrict the dataset to the windows at the given indices.
        inputs = [self.inputs_all[i] for i in idx]
        self.inputs = inputs.copy()
        return


class SuppressDataSet(data.Dataset):
    def __init__(self, opt, subset="train"):
        self.subset = subset
        self.mode = opt["mode"]
        self.data_file = h5py.File(opt["suppress_label_file"].format(self.subset + "_" + opt['setup']), 'r')
        self.video_list = list(self.data_file.keys())
        self.inputs = []
        # One sample per stored row of each video's '/input' dataset.
        for index in range(0, len(self.video_list)):
            video_name = self.video_list[index]
            duration = self.data_file[video_name + '/input'].shape[0]
            for i in range(0, duration):
                self.inputs.append([video_name, i])

        print(f"{self.subset} subset seg numbers: {len(self.inputs)}")

    def __getitem__(self, index):
        video_name, idx = self.inputs[index]

        input_seq = self.data_file[video_name + '/input'][idx]
        label = self.data_file[video_name + '/label'][idx]

        input_seq = torch.from_numpy(input_seq)
        label = torch.from_numpy(label)

        return input_seq, label

    def __len__(self):
        return len(self.inputs)
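

# Usage sketch: the option keys below mirror what VideoDataSet reads; every
# path and value is an illustrative placeholder, so the annotation and
# feature files must exist before this will run end to end.
if __name__ == "__main__":
    example_opt = {
        "mode": "train",
        "predefined_fps": 25,
        "video_anno": "data/anno_{}.json",
        "split": "default",
        "video_len_file": "data/video_len_{}.json",
        "setup": "default",
        "num_of_class": 23,            # 22 action classes + background
        "segment_size": 64,
        "data_rescale": False,
        "anchors": [4, 8, 16, 32],
        "pos_threshold": 0.7,
        "data_format": "npz",          # per-video .npz files with a 'feats' key
        "video_feature_all_train": "data/features/train",
        "proposal_label_file": "data/prop_label_{}.h5",
        "rgb_only": True,
    }
    dataset = VideoDataSet(example_opt, subset="train")
    feature, cls_label, reg_label, snip_label = dataset[0]
    print(feature.shape, cls_label.shape, reg_label.shape, snip_label.shape)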