|
|
import os |
|
|
from torchvision import transforms |
|
|
from .transforms import * |
|
|
from .masking_generator import TubeMaskingGenerator, RandomMaskingGenerator |
|
|
from .mae import VideoMAE |
|
|
from .mae_multi import VideoMAE_multi |
|
|
from .kinetics import VideoClsDataset |
|
|
from .kinetics_sparse import VideoClsDataset_sparse |
|
|
from .anet import ANetDataset |
|
|
from .ssv2 import SSVideoClsDataset, SSRawFrameClsDataset |
|
|
from .hmdb import HMDBVideoClsDataset, HMDBRawFrameClsDataset |
|
|
|
|
|
|
|
|
class DataAugmentationForVideoMAE(object):
    """Pre-training augmentation pipeline plus masking generator.

    Calling an instance on a group of frames returns ``(data, mask)``, where
    ``data`` is the first element of the pair produced by ``self.transform``.
    ``mask`` is the output of the configured masking generator, or ``-1``
    when ``args.mask_type == 'attention'`` (no generator is built).
    """

    def __init__(self, args):
        """Build the augmentation pipeline and masking generator from ``args``.

        Reads ``args.input_size``, ``args.color_jitter``, ``args.flip`` and
        ``args.mask_type``. For 'tube'/'random' masking it also reads
        ``args.window_size`` and ``args.mask_ratio``.

        Raises:
            ValueError: if ``args.mask_type`` is not exactly one of
                'tube', 'random' or 'attention'.
        """
        # Per-channel RGB normalisation statistics (the usual ImageNet values).
        self.input_mean = [0.485, 0.456, 0.406]
        self.input_std = [0.229, 0.224, 0.225]
        normalize = GroupNormalize(self.input_mean, self.input_std)
        self.train_augmentation = GroupMultiScaleCrop(args.input_size, [1, .875, .75, .66])

        # Colour jitter is the only optional step; the rest of the pipeline
        # is identical with or without it.
        steps = [self.train_augmentation]
        if args.color_jitter > 0:
            steps.append(GroupColorJitter(args.color_jitter))
        steps += [
            GroupRandomHorizontalFlip(flip=args.flip),
            Stack(roll=False),
            ToTorchFormatTensor(div=True),
            normalize,
        ]
        self.transform = transforms.Compose(steps)

        self.masked_position_generator = self._build_mask_generator(args)

    @staticmethod
    def _build_mask_generator(args):
        """Return the masking generator selected by ``args.mask_type``.

        'attention' yields ``None``; ``__call__`` then returns ``-1`` in place
        of a mask.

        Raises:
            ValueError: for any other, unsupported ``mask_type``.
        """
        mask_type = args.mask_type
        if mask_type == 'tube':
            return TubeMaskingGenerator(args.window_size, args.mask_ratio)
        if mask_type == 'random':
            return RandomMaskingGenerator(args.window_size, args.mask_ratio)
        # Exact match: a substring test (`in 'attention'`) would also accept
        # '', 'att', 'tention', ...
        if mask_type == 'attention':
            return None
        # Fail at construction time instead of leaving the attribute unset
        # and crashing later with an AttributeError in __call__/__repr__.
        raise ValueError(
            f"Unsupported mask_type {mask_type!r}; expected 'tube', 'random' or 'attention'"
        )

    def __call__(self, images):
        """Augment ``images`` and pair the result with a mask (or -1)."""
        process_data, _ = self.transform(images)
        if self.masked_position_generator is None:
            return process_data, -1
        return process_data, self.masked_position_generator()

    def __repr__(self):
        parts = [
            "(DataAugmentationForVideoMAE,\n",
            " transform = %s,\n" % str(self.transform),
            " Masked position generator = %s,\n" % str(self.masked_position_generator),
            ")",
        ]
        return "".join(parts)
|
|
|
|
|
|
|
|
def build_pretraining_dataset(args):
    """Create the single-source ``VideoMAE`` pre-training dataset.

    The augmentation/masking pipeline is built from ``args`` and passed to
    ``VideoMAE`` together with the sampling options taken from ``args``.
    The pipeline's repr is printed before the dataset is returned.
    """
    augmentation = DataAugmentationForVideoMAE(args)

    # Options that do not depend on ``args``.
    fixed_options = dict(
        video_ext='mp4',
        is_color=True,
        modality='rgb',
        temporal_jitter=False,
        video_loader=True,
        lazy_init=False,
    )
    dataset = VideoMAE(
        root=None,
        setting=args.data_path,
        prefix=args.prefix,
        split=args.split,
        num_segments=args.num_segments,
        new_length=args.num_frames,
        new_step=args.sampling_rate,
        transform=augmentation,
        use_decord=args.use_decord,
        num_sample=args.num_sample,
        **fixed_options,
    )

    print("Data Aug = %s" % str(augmentation))
    return dataset
|
|
|
|
|
|
|
|
def build_multi_pretraining_dataset(args):
    """Create the multi-source ``VideoMAE_multi`` pre-training dataset.

    Two augmentation pipelines are built from ``args``:
      * ``transform`` uses ``args.flip`` as given;
      * ``transform_ssv2`` is built with ``args.flip = False``, presumably
        because some SSV2 labels are direction-dependent (TODO confirm).

    ``args`` is mutated only temporarily. ``args.flip`` is restored even if
    building the second pipeline raises.
    """
    transform = DataAugmentationForVideoMAE(args)

    original_flip = args.flip
    args.flip = False
    try:
        transform_ssv2 = DataAugmentationForVideoMAE(args)
    finally:
        # Never leak the temporary override back to the caller's `args`.
        args.flip = original_flip

    dataset = VideoMAE_multi(
        root=None,
        setting=args.data_path,
        prefix=args.prefix,
        split=args.split,
        is_color=True,
        modality='rgb',
        num_segments=args.num_segments,
        new_length=args.num_frames,
        new_step=args.sampling_rate,
        transform=transform,
        transform_ssv2=transform_ssv2,
        temporal_jitter=False,
        video_loader=True,
        use_decord=args.use_decord,
        lazy_init=False,
        num_sample=args.num_sample)

    print("Data Aug = %s" % str(transform))
    print("Data Aug for SSV2 = %s" % str(transform_ssv2))
    return dataset
|
|
|
|
|
|
|
|
def _resolve_split(is_train, test_mode, data_path):
    """Return ``(mode, anno_path)``: train beats test; otherwise validation."""
    if is_train:
        return 'train', os.path.join(data_path, 'train.csv')
    if test_mode:
        return 'test', os.path.join(data_path, 'test.csv')
    return 'validation', os.path.join(data_path, 'val.csv')


def _select_dataset(args):
    """Map ``args.data_set`` to ``(dataset_cls, segment_based, nb_classes)``.

    ``segment_based`` datasets are built with ``clip_len=1``,
    ``num_segment=args.num_frames`` and ``filename_tmpl``. The others get
    ``clip_len=args.num_frames``, ``frame_sample_rate=args.sampling_rate``
    and ``num_segment=1``.

    Raises:
        NotImplementedError: for an unknown ``args.data_set``.
    """
    name = args.data_set
    if name in ('Kinetics', 'Kinetics_sparse', 'mitv1_sparse'):
        cls = VideoClsDataset_sparse if 'sparse' in name else VideoClsDataset
        return cls, False, args.nb_classes
    if name == 'SSV2':
        cls = SSVideoClsDataset if args.use_decord else SSRawFrameClsDataset
        return cls, True, 174
    if name == 'UCF101':
        return VideoClsDataset, False, 101
    if name == 'HMDB51':
        cls = HMDBVideoClsDataset if args.use_decord else HMDBRawFrameClsDataset
        return cls, True, 51
    if name in ('ANet', 'HACS', 'ANet_interval', 'HACS_interval'):
        cls = ANetDataset if 'interval' in name else VideoClsDataset_sparse
        return cls, False, args.nb_classes
    if name == 'basketball':
        return VideoClsDataset, False, 5
    print(f'Wrong: {name}')
    raise NotImplementedError()


def build_dataset(is_train, test_mode, args):
    """Build the classification dataset selected by ``args.data_set``.

    Args:
        is_train: truthy to use the training split (``train.csv``).
        test_mode: truthy (with ``is_train`` falsy) to use the test split
            (``test.csv``); otherwise ``val.csv`` is used. A truthy
            ``test_mode`` also selects 3 crops instead of 1.
        args: parsed options (``data_set``, ``data_path``, ``prefix``,
            ``split``, ``num_frames``, ``nb_classes``, ...).

    Returns:
        ``(dataset, nb_classes)``.

    Raises:
        NotImplementedError: for an unknown ``args.data_set``.
        ValueError: if the dataset's class count differs from
            ``args.nb_classes``.
    """
    print(f'Use Dataset: {args.data_set}')
    # Dispatch first so an unknown dataset fails before args.data_path is used.
    dataset_cls, segment_based, nb_classes = _select_dataset(args)
    # Truthiness (not `is True`) so the mode agrees with the num_crop choice.
    mode, anno_path = _resolve_split(is_train, test_mode, args.data_path)

    kwargs = dict(
        anno_path=anno_path,
        prefix=args.prefix,
        split=args.split,
        mode=mode,
        test_num_segment=args.test_num_segment,
        test_num_crop=args.test_num_crop,
        num_crop=1 if not test_mode else 3,
        keep_aspect_ratio=True,
        crop_size=args.input_size,
        short_side_size=args.short_side_size,
        new_height=256,
        new_width=320,
        args=args,
    )
    if segment_based:
        # Only these datasets read args.filename_tmpl.
        kwargs.update(clip_len=1, num_segment=args.num_frames,
                      filename_tmpl=args.filename_tmpl)
    else:
        kwargs.update(clip_len=args.num_frames,
                      frame_sample_rate=args.sampling_rate, num_segment=1)
    dataset = dataset_cls(**kwargs)

    # Explicit check: an `assert` would be stripped under `python -O`.
    if nb_classes != args.nb_classes:
        raise ValueError(
            f'{args.data_set} has {nb_classes} classes but args.nb_classes={args.nb_classes}'
        )
    print("Number of the class = %d" % args.nb_classes)

    return dataset, nb_classes
|
|
|