File size: 3,474 Bytes
import os
import copy
import random
import numpy as np
import random
import torch
def trivial_batch_collator(batch):
    """
    A batch collator that does nothing.

    Args:
        batch: the list of samples gathered for one batch.

    Returns:
        The very same ``batch`` object, unmodified (no stacking, no copying).
    """
    return batch
def worker_init_reset_seed(worker_id):
    """
    Reset random seed for each worker.

    The seed is derived from ``torch.initial_seed()`` and applied to NumPy's
    global RNG and Python's ``random`` module, and exported as
    ``PYTHONHASHSEED``.

    Args:
        worker_id: index of the worker; unused here, but part of the
            signature so the function can be passed as a worker-init hook.
    """
    # fold torch's seed into 31 bits before handing it to the other RNGs
    seed = torch.initial_seed() % 2 ** 31
    np.random.seed(seed)
    random.seed(seed)
    # NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so this
    # only matters for processes that inherit this environment later on.
    os.environ["PYTHONHASHSEED"] = str(seed)
def truncate_feats(
    data_dict,
    max_seq_len,
    trunc_thresh,
    crop_ratio=None,
    max_num_trials=200,
    has_action=True,
    no_trunc=False
):
    """
    Truncate feats and time stamps in a dict item.

    data_dict = {'video_id'        : str
                 'feats'           : Tensor C x T
                 'segments'        : Tensor N x 2 (in feature grid)
                 'labels'          : Tensor N
                 'fps'             : float
                 'feat_stride'     : int
                 'feat_num_frames' : int}

    Args:
        data_dict: the item described above; only 'feats', 'segments' and
            'labels' are read and rewritten here.
        max_seq_len: maximum number of time steps T to keep.
        trunc_thresh: a segment is kept only if the fraction of it that lies
            inside the sampled window is >= this value.
        crop_ratio: optional pair (lo, hi). When the sequence already fits
            into ``max_seq_len``, the window length is drawn uniformly from
            [round(lo * T), round(hi * T)] (clipped to [1, T]) instead of
            leaving the item untouched.
        max_num_trials: number of random windows to try before giving up on
            the constraints below. At least one window is always sampled.
        has_action: require at least one kept segment in the window.
        no_trunc: require at least one kept segment and no segment that is
            only partially covered by the window (takes precedence over
            ``has_action``).

    Returns:
        ``data_dict`` itself when no truncation is needed; otherwise a deep
        copy whose 'feats' are sliced to the window, whose 'segments' are
        clipped to the window and shifted to start at 0, and whose 'labels'
        are filtered accordingly. If no trial satisfies the constraints, the
        last sampled window is used, which may leave zero segments.
    """
    # get the meta info
    feat_len = data_dict['feats'].shape[1]
    num_segs = data_dict['segments'].shape[0]

    # seq_len <= max_seq_len
    if feat_len <= max_seq_len:
        # do nothing
        if crop_ratio is None:
            return data_dict
        # randomly crop the seq by setting max_seq_len to a value in [l, r]
        max_seq_len = random.randint(
            max(round(crop_ratio[0] * feat_len), 1),
            min(round(crop_ratio[1] * feat_len), feat_len),
        )
        # corner case: the sampled length covers the whole sequence
        if feat_len == max_seq_len:
            return data_dict

    # otherwise, deep copy the dict so the caller's item is left untouched
    data_dict = copy.deepcopy(data_dict)
    segments = data_dict['segments']
    # segment lengths do not depend on the sampled window, compute them once;
    # a zero-length segment yields a NaN ratio below and is never selected
    area_segs = torch.abs(segments[:, 1] - segments[:, 0])

    # try a few times till a valid truncation with at least one action;
    # always sample at least once so st / ed / seg_idx are defined even
    # when max_num_trials <= 0
    for _ in range(max(max_num_trials, 1)):
        # sample a random truncation of the video feats
        st = random.randint(0, feat_len - max_seq_len)
        ed = st + max_seq_len
        window = torch.as_tensor([st, ed], dtype=torch.float32)

        # compute the intersection between the sampled window and all segments
        window = window[None].repeat(num_segs, 1)
        left = torch.maximum(window[:, 0], segments[:, 0])
        right = torch.minimum(window[:, 1], segments[:, 1])
        inter = (right - left).clamp(min=0)
        inter_ratio = inter / area_segs

        # only select those segments over the thresh
        seg_idx = (inter_ratio >= trunc_thresh)

        if no_trunc:
            # with at least one action and not truncating any actions
            seg_trunc_idx = torch.logical_and(
                (inter_ratio > 0.0), (inter_ratio < 1.0)
            )
            if (seg_idx.sum().item() > 0) and (seg_trunc_idx.sum().item() == 0):
                break
        elif has_action:
            # with at least one action
            if seg_idx.sum().item() > 0:
                break
        else:
            # without any constraints
            break

    # feats: C x T
    data_dict['feats'] = data_dict['feats'][:, st:ed].clone()
    # segments: N x 2 in feature grids
    data_dict['segments'] = torch.stack((left[seg_idx], right[seg_idx]), dim=1)
    # shift the time stamps due to truncation
    data_dict['segments'] = data_dict['segments'] - st
    # labels: N
    data_dict['labels'] = data_dict['labels'][seg_idx].clone()

    return data_dict
|