import logging
import os
import pathlib
import warnings

from logging import getLogger

import numpy as np
import pandas as pd

from decord import VideoReader, cpu

import torch

from video_classification_frozen.utils import make_transforms
from src.datasets.utils.weighted_sampler import DistributedWeightedSampler

_GLOBAL_SEED = 0
logger = getLogger()


def make_videodataset(
    data_paths,
    batch_size,
    frames_per_clip=8,
    frame_step=4,
    num_clips=1,
    random_clip_sampling=True,
    allow_clip_overlap=False,
    filter_short_videos=False,
    filter_long_videos=int(10**9),
    transform=None,
    shared_transform=None,
    rank=0,
    world_size=1,
    datasets_weights=None,
    collator=None,
    drop_last=True,
    num_workers=10,
    pin_mem=True,
    duration=None,
    log_dir=None,
):
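    """Build a VideoDataset, a distributed sampler, and a DataLoader.

    Returns a (dataset, data_loader, dist_sampler) tuple. When
    `datasets_weights` is given, samples are drawn with a
    DistributedWeightedSampler instead of the standard DistributedSampler.
    """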
    dataset = VideoDataset(
        data_paths=data_paths,
        datasets_weights=datasets_weights,
        frames_per_clip=frames_per_clip,
        frame_step=frame_step,
        num_clips=num_clips,
        random_clip_sampling=random_clip_sampling,
        allow_clip_overlap=allow_clip_overlap,
        filter_short_videos=filter_short_videos,
        filter_long_videos=filter_long_videos,
        duration=duration,
        shared_transform=shared_transform,
        transform=transform)

    logger.info('VideoDataset dataset created')
    if datasets_weights is not None:
        dist_sampler = DistributedWeightedSampler(
            dataset.sample_weights,
            num_replicas=world_size,
            rank=rank,
            shuffle=True)
    else:
        dist_sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=collator,
        sampler=dist_sampler,
        batch_size=batch_size,
        drop_last=drop_last,
        pin_memory=pin_mem,
        num_workers=num_workers,
        persistent_workers=num_workers > 0)
    logger.info('VideoDataset unsupervised data loader created')

    return dataset, data_loader, dist_sampler


class VideoDataset(torch.utils.data.Dataset):
    """ Video classification dataset. """

    def __init__(
        self,
        data_paths,
        datasets_weights=None,
        frames_per_clip=16,
        frame_step=4,
        num_clips=1,
        transform=None,
        shared_transform=None,
        random_clip_sampling=True,
        allow_clip_overlap=True,
        filter_short_videos=False,
        filter_long_videos=int(10**9),
        duration=None,
    ):
        self.data_paths = data_paths
        self.datasets_weights = datasets_weights
        self.frames_per_clip = frames_per_clip
        self.frame_step = frame_step
        self.num_clips = num_clips
        self.transform = transform
        self.shared_transform = shared_transform
        self.random_clip_sampling = random_clip_sampling
        self.allow_clip_overlap = allow_clip_overlap
        self.filter_short_videos = filter_short_videos
        self.filter_long_videos = filter_long_videos
        self.duration = duration

        if VideoReader is None:
            raise ImportError('Unable to import "decord" which is required to read videos.')

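        # Load video paths and labels from the provided manifests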
        samples, labels = [], []
        self.num_samples_per_dataset = []
        for data_path in self.data_paths:

            if data_path[-4:] == '.csv':
                # Space-delimited CSV: column 0 is the video path, column 1 the label
                data = pd.read_csv(data_path, header=None, delimiter=" ")
                samples += list(data.values[:, 0])
                labels += list(data.values[:, 1])
                num_samples = len(data)
                self.num_samples_per_dataset.append(num_samples)

            elif data_path[-4:] == '.npy':
                # Numpy manifest: an array of video paths, all assigned label 0
                data = np.load(data_path, allow_pickle=True)
                data = list(map(lambda x: repr(x)[1:-1], data))
                samples += data
                labels += [0] * len(data)
                num_samples = len(data)
                self.num_samples_per_dataset.append(num_samples)

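        # [Optional] Per-sample weights for the distributed weighted sampler:
        # each dataset's total weight is split evenly across its samples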
        self.sample_weights = None
        if self.datasets_weights is not None:
            self.sample_weights = []
            for dw, ns in zip(self.datasets_weights, self.num_samples_per_dataset):
                self.sample_weights += [dw / ns] * ns

        self.samples = samples
        self.labels = labels

    def __getitem__(self, index):
        sample = self.samples[index]

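        # Keep trying to load videos until a valid sample is found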
        loaded_video = False
        while not loaded_video:
            buffer, clip_indices = self.loadvideo_decord(sample)
            loaded_video = len(buffer) > 0
            if not loaded_video:
                index = np.random.randint(self.__len__())
                sample = self.samples[index]

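        # Label/annotation for the (possibly re-sampled) video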
        label = self.labels[index]

        def split_into_clips(video):
            """ Split video into a list of clips """
            fpc = self.frames_per_clip
            nc = self.num_clips
            return [video[i*fpc:(i+1)*fpc] for i in range(nc)]

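        # Parse video into frames and apply data augmentations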
        if self.shared_transform is not None:
            buffer = self.shared_transform(buffer)
        buffer = split_into_clips(buffer)
        if self.transform is not None:
            buffer = [self.transform(clip) for clip in buffer]
        return buffer, label, clip_indices, sample

    def loadvideo_decord(self, sample):
        """ Load video content using Decord """

        fname = sample
        if not os.path.exists(fname):
            warnings.warn(f'video path not found {fname=}')
            return [], None

        _fsize = os.path.getsize(fname)
        if _fsize < 1 * 1024:
            warnings.warn(f'video file too small (<1KB), skipping {fname=}')
            return [], None
        if _fsize > self.filter_long_videos:
            warnings.warn(f'skipping long video of size {_fsize=} (bytes)')
            return [], None

        try:
            vr = VideoReader(fname, num_threads=-1, ctx=cpu(0))
        except Exception:
            return [], None

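        # Resolve clip sampling parameters; if a clip duration (in seconds) is
        # given, derive the frame step from the video's average FPS instead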
        fpc = self.frames_per_clip
        fstp = self.frame_step
        if self.duration is not None:
            try:
                fps = vr.get_avg_fps()
                fstp = int(self.duration * fps / fpc)
            except Exception as e:
                warnings.warn(str(e))
        clip_len = int(fpc * fstp)

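        # Optionally skip videos that are shorter than a single clip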
        if self.filter_short_videos and len(vr) < clip_len:
            warnings.warn(f'skipping video of length {len(vr)}')
            return [], None

        vr.seek(0)  # go to start of video before sampling frames

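        # Partition the video into equal-sized segments and sample one clip
        # from each segment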
        partition_len = len(vr) // self.num_clips

        all_indices, clip_indices = [], []
        for i in range(self.num_clips):

            if partition_len > clip_len:
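                # Segment is longer than a clip: take a clip_len-frame window,
                # either at a random offset within the segment or from its start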
                end_indx = clip_len
                if self.random_clip_sampling:
                    end_indx = np.random.randint(clip_len, partition_len)
                start_indx = end_indx - clip_len
                indices = np.linspace(start_indx, end_indx, num=fpc)
                indices = np.clip(indices, start_indx, end_indx - 1).astype(np.int64)
                indices = indices + i * partition_len
            else:
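                # Segment is shorter than a clip and clips may not overlap:
                # sample every fstp-th frame in the segment and pad by repeating
                # the last frame until fpc frames are collected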
                if not self.allow_clip_overlap:
                    indices = np.linspace(0, partition_len, num=partition_len // fstp)
                    indices = np.concatenate((indices, np.ones(fpc - partition_len // fstp) * partition_len,))
                    indices = np.clip(indices, 0, partition_len - 1).astype(np.int64)
                    indices = indices + i * partition_len

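                # Segment is shorter than a clip but clips may overlap: sample a
                # full-length clip from the start of the video and shift each
                # successive clip so that the clips overlap one another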
                else:
                    sample_len = min(clip_len, len(vr)) - 1
                    indices = np.linspace(0, sample_len, num=sample_len // fstp)
                    indices = np.concatenate((indices, np.ones(fpc - sample_len // fstp) * sample_len,))
                    indices = np.clip(indices, 0, sample_len - 1).astype(np.int64)

                    clip_step = 0
                    if len(vr) > clip_len:
                        clip_step = (len(vr) - clip_len) // (self.num_clips - 1)
                    indices = indices + i * clip_step

            clip_indices.append(indices)
            all_indices.extend(list(indices))

        buffer = vr.get_batch(all_indices).asnumpy()
        return buffer, clip_indices

    def __len__(self):
        return len(self.samples)


if __name__ == '__main__':
    train_data_paths = ['D:/Softwares/kube/jepa-main/video_dataset/train_data.csv']

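    # Build the training-time augmentation pipeline from the frozen-evaluation utils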
    transform = make_transforms(
        training=True,
        num_views_per_clip=3,
        random_horizontal_flip=False,
        random_resize_aspect_ratio=(0.75, 4 / 3),
        random_resize_scale=(0.08, 1.0),
        reprob=0.25,
        auto_augment=True,
        motion_shift=False,
        crop_size=224,
    )
    dataset = VideoDataset(
        data_paths=train_data_paths,
        datasets_weights=None,
        frames_per_clip=16,
        frame_step=2,
        num_clips=2,
        random_clip_sampling=True,
        allow_clip_overlap=True,
        filter_short_videos=False,
        filter_long_videos=int(1e9),
        duration=None,
        shared_transform=None,
        transform=transform)
    item = dataset[0]
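    # Quick sanity check of a single sample. __getitem__ returns
    # (clips, label, clip_indices, path); the exact type/shape of each
    # transformed clip depends on the transform configuration above, so only
    # generic information is printed here.
    clips, label, clip_indices, path = item
    print(f'sample path: {path}')
    print(f'label: {label}')
    print(f'number of clips: {len(clips)}')
    print(f'frame indices per clip: {[len(ci) for ci in clip_indices]}')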
    print("test video loader!")