forensics-grpo / code /libs /datasets /data_utils.py
sdzt's picture
Add source code
33569f9 verified
Raw
History Blame Contribute Delete
3.47 kB
import os
import copy
import random
import numpy as np
import random
import torch
def trivial_batch_collator(batch):
"""
A batch collator that does nothing
"""
return batch
def worker_init_reset_seed(worker_id):
"""
Reset random seed for each worker
"""
seed = torch.initial_seed() % 2 ** 31
np.random.seed(seed)
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)
def truncate_feats(
data_dict,
max_seq_len,
trunc_thresh,
crop_ratio=None,
max_num_trials=200,
has_action=True,
no_trunc=False
):
"""
Truncate feats and time stamps in a dict item
data_dict = {'video_id' : str
'feats' : Tensor C x T
'segments' : Tensor N x 2 (in feature grid)
'labels' : Tensor N
'fps' : float
'feat_stride' : int
'feat_num_frames' : in
"""
# get the meta info
feat_len = data_dict['feats'].shape[1]
num_segs = data_dict['segments'].shape[0]
# seq_len < max_seq_len
if feat_len <= max_seq_len:
# do nothing
if crop_ratio == None:
return data_dict
# randomly crop the seq by setting max_seq_len to a value in [l, r]
else:
max_seq_len = random.randint(
max(round(crop_ratio[0] * feat_len), 1),
min(round(crop_ratio[1] * feat_len), feat_len),
)
# # corner case
if feat_len == max_seq_len:
return data_dict
# otherwise, deep copy the dict
data_dict = copy.deepcopy(data_dict)
# try a few times till a valid truncation with at least one action
for _ in range(max_num_trials):
# sample a random truncation of the video feats
st = random.randint(0, feat_len - max_seq_len)
ed = st + max_seq_len
window = torch.as_tensor([st, ed], dtype=torch.float32)
# compute the intersection between the sampled window and all segments
window = window[None].repeat(num_segs, 1)
left = torch.maximum(window[:, 0], data_dict['segments'][:, 0])
right = torch.minimum(window[:, 1], data_dict['segments'][:, 1])
inter = (right - left).clamp(min=0)
area_segs = torch.abs(
data_dict['segments'][:, 1] - data_dict['segments'][:, 0])
inter_ratio = inter / area_segs
# only select those segments over the thresh
seg_idx = (inter_ratio >= trunc_thresh)
if no_trunc:
# with at least one action and not truncating any actions
seg_trunc_idx = torch.logical_and(
(inter_ratio > 0.0), (inter_ratio < 1.0)
)
if (seg_idx.sum().item() > 0) and (seg_trunc_idx.sum().item() == 0):
break
elif has_action:
# with at least one action
if seg_idx.sum().item() > 0:
break
else:
# without any constraints
break
# feats: C x T
data_dict['feats'] = data_dict['feats'][:, st:ed].clone()
# segments: N x 2 in feature grids
data_dict['segments'] = torch.stack((left[seg_idx], right[seg_idx]), dim=1)
# shift the time stamps due to truncation
data_dict['segments'] = data_dict['segments'] - st
# labels: N
data_dict['labels'] = data_dict['labels'][seg_idx].clone()
return data_dict