import glob

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset

class MomentsDataset(Dataset):
    """Samples fixed-length frame clips from a folder of .mp4 videos."""

    def __init__(self, videos_folder, num_frames, samples_per_video, frame_size=512) -> None:
        super().__init__()

        self.videos_paths = glob.glob(f'{videos_folder}/*.mp4')
        # Resize the short side to frame_size, then center-crop to a square.
        self.resize = torchvision.transforms.Resize(size=frame_size)
        self.center_crop = torchvision.transforms.CenterCrop(size=frame_size)
        self.num_samples_per_video = samples_per_video
        self.num_frames = num_frames

    def __len__(self):
        # Each video contributes several samples per epoch.
        return len(self.videos_paths) * self.num_samples_per_video

    def __getitem__(self, idx):
        # Map the flat sample index back to its source video.
        video_idx = idx // self.num_samples_per_video
        video_path = self.videos_paths[video_idx]

        try:
            unsampled_video_frames, _audio, _info = torchvision.io.read_video(
                video_path, output_format="TCHW")

            # Jitter the starting frame, clamped so short videos stay in range.
            max_start = max(1, min(20, len(unsampled_video_frames) - self.num_frames))
            start_idx = np.random.randint(0, max_start)

            # Pick num_frames indices evenly spaced from start_idx to the last frame.
            sampled_indices = torch.tensor(
                np.linspace(start_idx, len(unsampled_video_frames) - 1, self.num_frames).astype(int))
            sampled_frames = unsampled_video_frames[sampled_indices]

            processed_frames = []
            for frame in sampled_frames:
                resized_cropped_frame = self.center_crop(self.resize(frame))
                processed_frames.append(resized_cropped_frame)

            # (num_frames, C, frame_size, frame_size), rescaled from uint8 to [0, 1].
            frames = torch.stack(processed_frames, dim=0)
            frames = frames.float() / 255.0
        except Exception as e:
            # Unreadable or too-short video: log it and retry with a random clip.
            print(f'Failed to load {video_path}: {e}')
            rand_idx = np.random.randint(0, len(self))
            return self.__getitem__(rand_idx)

        out_dict = {'frames': frames,
                    'caption': 'none',     # placeholder text fields
                    'keywords': 'none'}

        return out_dict
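
# Minimal usage sketch (an illustration, not part of the original module): the
# folder name 'videos', batch size, and frame counts below are assumed values.
# It loads one batch through a DataLoader and prints the clip tensor shape.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = MomentsDataset(videos_folder='videos', num_frames=16,
                             samples_per_video=4, frame_size=512)
    loader = DataLoader(dataset, batch_size=2, shuffle=True)

    batch = next(iter(loader))
    # Expected shape: (batch, num_frames, channels, frame_size, frame_size),
    # e.g. torch.Size([2, 16, 3, 512, 512]) for RGB input videos.
    print(batch['frames'].shape)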