Spaces:
Runtime error
Runtime error
| import os | |
| from torchvision import transforms | |
| from .transforms import * | |
| from .masking_generator import TubeMaskingGenerator, RandomMaskingGenerator | |
| from .mae import VideoMAE | |
| from .mae_multi import VideoMAE_multi | |
| from .kinetics import VideoClsDataset | |
| from .kinetics_sparse import VideoClsDataset_sparse | |
| from .anet import ANetDataset | |
| from .ssv2 import SSVideoClsDataset, SSRawFrameClsDataset | |
| from .hmdb import HMDBVideoClsDataset, HMDBRawFrameClsDataset | |
class DataAugmentationForVideoMAE(object):
    """Frame-group augmentation plus mask generation for VideoMAE pretraining.

    Calling an instance applies the composed frame-group transform (multi-scale
    crop, optional colour jitter, horizontal flip controlled by ``args.flip``,
    stacking, tensor conversion, ImageNet mean/std normalisation) and pairs the
    result with a mask from the configured masking generator.
    """

    def __init__(self, args):
        self.input_mean = [0.485, 0.456, 0.406]  # IMAGENET_DEFAULT_MEAN
        self.input_std = [0.229, 0.224, 0.225]  # IMAGENET_DEFAULT_STD
        normalize = GroupNormalize(self.input_mean, self.input_std)
        self.train_augmentation = GroupMultiScaleCrop(args.input_size, [1, .875, .75, .66])

        # Both pipelines are identical except for the optional colour-jitter
        # step right after the crop, so build the step list once.
        steps = [self.train_augmentation]
        if args.color_jitter > 0:
            steps.append(GroupColorJitter(args.color_jitter))
        steps += [
            GroupRandomHorizontalFlip(flip=args.flip),
            Stack(roll=False),
            ToTorchFormatTensor(div=True),
            normalize,
        ]
        self.transform = transforms.Compose(steps)
        self.masked_position_generator = self._build_masked_position_generator(args)

    @staticmethod
    def _build_masked_position_generator(args):
        """Return the mask generator selected by ``args.mask_type``.

        'tube'      -> TubeMaskingGenerator(args.window_size, args.mask_ratio)
        'random'    -> RandomMaskingGenerator(args.window_size, args.mask_ratio)
        'attention' -> None (``__call__`` then returns -1 in place of a mask)

        Raises:
            ValueError: for any other ``mask_type``. Previously a substring test
                accepted values like '' or 'att', and unknown values left the
                attribute unset, failing later with AttributeError.
        """
        mask_type = args.mask_type
        if mask_type == 'tube':
            return TubeMaskingGenerator(args.window_size, args.mask_ratio)
        if mask_type == 'random':
            return RandomMaskingGenerator(args.window_size, args.mask_ratio)
        if mask_type == 'attention':
            return None
        raise ValueError(
            f"Unsupported mask_type {mask_type!r}; "
            "expected 'tube', 'random' or 'attention'")

    def __call__(self, images):
        """Transform ``images`` and return ``(processed, mask)``.

        ``mask`` is ``-1`` when no masking generator is configured.
        """
        # NOTE(review): the composed transform is assumed to return a 2-tuple
        # (data, label-like passthrough); only the first element is kept.
        process_data, _ = self.transform(images)
        if self.masked_position_generator is None:
            return process_data, -1
        return process_data, self.masked_position_generator()

    def __repr__(self):
        text = "(DataAugmentationForVideoMAE,\n"
        text += " transform = %s,\n" % str(self.transform)
        text += " Masked position generator = %s,\n" % str(self.masked_position_generator)
        text += ")"
        return text
def build_pretraining_dataset(args):
    """Create the single-source ``VideoMAE`` pretraining dataset.

    The clip loader is configured from ``args`` (data path, prefix, split,
    segment/frame/stride counts, decord flag, samples per video) and uses a
    ``DataAugmentationForVideoMAE`` built from the same ``args``. The
    augmentation description is printed before returning.
    """
    augmentation = DataAugmentationForVideoMAE(args)
    loader_options = dict(
        root=None,
        setting=args.data_path,
        prefix=args.prefix,
        split=args.split,
        video_ext='mp4',
        is_color=True,
        modality='rgb',
        num_segments=args.num_segments,
        new_length=args.num_frames,
        new_step=args.sampling_rate,
        transform=augmentation,
        temporal_jitter=False,
        video_loader=True,
        use_decord=args.use_decord,
        lazy_init=False,
        num_sample=args.num_sample,
    )
    pretrain_set = VideoMAE(**loader_options)
    print("Data Aug = %s" % str(augmentation))
    return pretrain_set
def build_multi_pretraining_dataset(args):
    """Create the multi-source ``VideoMAE_multi`` pretraining dataset.

    Two augmentation pipelines are built from ``args``: the default one (flip
    as configured by ``args.flip``) and an SSV2-specific one with horizontal
    flipping disabled. ``args.flip`` is temporarily overridden to build the
    second pipeline and is always restored afterwards, even on error.

    Returns:
        The constructed ``VideoMAE_multi`` dataset.
    """
    transform = DataAugmentationForVideoMAE(args)

    # Build the SSV2 pipeline with flip disabled. NOTE(review): presumably
    # because SSV2 labels depend on motion direction (left/right) — confirm.
    original_flip = args.flip
    args.flip = False
    try:
        transform_ssv2 = DataAugmentationForVideoMAE(args)
    finally:
        # Restore the caller's namespace even if construction raised.
        args.flip = original_flip

    dataset = VideoMAE_multi(
        root=None,
        setting=args.data_path,
        prefix=args.prefix,
        split=args.split,
        is_color=True,
        modality='rgb',
        num_segments=args.num_segments,
        new_length=args.num_frames,
        new_step=args.sampling_rate,
        transform=transform,
        transform_ssv2=transform_ssv2,
        temporal_jitter=False,
        video_loader=True,
        use_decord=args.use_decord,
        lazy_init=False,
        num_sample=args.num_sample)
    print("Data Aug = %s" % str(transform))
    print("Data Aug for SSV2 = %s" % str(transform_ssv2))
    return dataset
def _resolve_split(is_train, test_mode, data_path):
    """Return ``(mode, anno_path)`` for the requested split under ``data_path``.

    Training wins over test mode; with neither flag set the validation split
    (``val.csv``) is used.
    """
    if is_train:
        return 'train', os.path.join(data_path, 'train.csv')
    if test_mode:
        return 'test', os.path.join(data_path, 'test.csv')
    return 'validation', os.path.join(data_path, 'val.csv')


def build_dataset(is_train, test_mode, args):
    """Build the fine-tuning / evaluation dataset selected by ``args.data_set``.

    Args:
        is_train: truthy -> 'train' split (``train.csv``).
        test_mode: when not training, truthy -> 'test' split (``test.csv``),
            otherwise 'validation' (``val.csv``). Truthy also selects
            ``num_crop=3`` instead of 1.
        args: configuration namespace (data_set, data_path, prefix, split,
            num_frames, sampling_rate, input_size, short_side_size,
            test_num_segment, test_num_crop, nb_classes, and for SSV2/HMDB51
            also use_decord and filename_tmpl).

    Returns:
        ``(dataset, nb_classes)``.

    Raises:
        NotImplementedError: if ``args.data_set`` is not supported.
        ValueError: if the dataset's class count differs from
            ``args.nb_classes``.
    """
    print(f'Use Dataset: {args.data_set}')
    data_set = args.data_set

    # Pick the dataset class, how frames are sampled, and the class count.
    # 'clip' style: one clip of num_frames frames taken every sampling_rate.
    # 'segment' style: num_frames segments of one frame each.
    if data_set in ('Kinetics', 'Kinetics_sparse', 'mitv1_sparse'):
        func = VideoClsDataset_sparse if 'sparse' in data_set else VideoClsDataset
        sampling, nb_classes = 'clip', args.nb_classes
    elif data_set == 'SSV2':
        func = SSVideoClsDataset if args.use_decord else SSRawFrameClsDataset
        sampling, nb_classes = 'segment', 174
    elif data_set == 'UCF101':
        func = VideoClsDataset
        sampling, nb_classes = 'clip', 101
    elif data_set == 'HMDB51':
        func = HMDBVideoClsDataset if args.use_decord else HMDBRawFrameClsDataset
        sampling, nb_classes = 'segment', 51
    elif data_set in ('ANet', 'HACS', 'ANet_interval', 'HACS_interval'):
        func = ANetDataset if 'interval' in data_set else VideoClsDataset_sparse
        sampling, nb_classes = 'clip', args.nb_classes
    else:
        print(f'Wrong: {args.data_set}')
        raise NotImplementedError(f'Unsupported dataset: {args.data_set}')

    mode, anno_path = _resolve_split(is_train, test_mode, args.data_path)

    common_kwargs = dict(
        anno_path=anno_path,
        prefix=args.prefix,
        split=args.split,
        mode=mode,
        test_num_segment=args.test_num_segment,
        test_num_crop=args.test_num_crop,
        num_crop=1 if not test_mode else 3,
        keep_aspect_ratio=True,
        crop_size=args.input_size,
        short_side_size=args.short_side_size,
        new_height=256,
        new_width=320,
        args=args,
    )
    if sampling == 'clip':
        sampling_kwargs = dict(
            clip_len=args.num_frames,
            frame_sample_rate=args.sampling_rate,
            num_segment=1,
        )
    else:
        sampling_kwargs = dict(
            clip_len=1,
            num_segment=args.num_frames,
            filename_tmpl=args.filename_tmpl,
        )
    dataset = func(**common_kwargs, **sampling_kwargs)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if nb_classes != args.nb_classes:
        raise ValueError(
            f'{args.data_set} has {nb_classes} classes but '
            f'args.nb_classes is {args.nb_classes}')
    print("Number of the class = %d" % args.nb_classes)
    return dataset, nb_classes