Spaces:
Sleeping
Sleeping
| import cv2 | |
| import numpy as np | |
| from torch.utils.data import Dataset | |
| from PIL import Image | |
class VideoDataset(Dataset):
    """Random-access dataset over the frames of a video file.

    Frames are decoded through ``cv2.VideoCapture``, converted from BGR to
    RGB, wrapped in a ``PIL.Image`` and optionally passed through
    ``transforms``. The object can be used as a context manager; leaving the
    ``with`` block releases the underlying capture.

    The attributes ``width``, ``height``, ``frame_rate`` and ``frame_count``
    hold the values reported by the capture when it was opened.
    """

    def __init__(self, path: str, transforms=None):
        """Open ``path`` and cache the properties reported by the capture.

        Args:
            path: Location of the video, handed to ``cv2.VideoCapture``.
            transforms: Optional callable applied to every PIL frame before
                it is returned; ``None`` returns the PIL image unchanged.
        """
        # NOTE(review): the capture is not checked with isOpened(); an
        # unreadable path presumably yields a dataset reporting 0 frames
        # rather than an error - confirm whether callers rely on that.
        self.cap = cv2.VideoCapture(path)
        self.transforms = transforms
        self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.frame_rate = self.cap.get(cv2.CAP_PROP_FPS)
        self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def __len__(self):
        """Return the frame count reported by the capture at open time."""
        return self.frame_count

    def __getitem__(self, idx):
        """Return the frame at ``idx``, or a list of frames for a slice.

        Args:
            idx: Frame index or ``slice``. Negative indices count from the
                end, as with built-in sequences.

        Returns:
            The frame as an RGB ``PIL.Image`` (or whatever ``transforms``
            returns for it); a list of those for a slice.

        Raises:
            IndexError: If a negative index reaches before the first frame,
                or if the capture cannot read a frame at the position.
        """
        if isinstance(idx, slice):
            return [self[i] for i in range(*idx.indices(len(self)))]

        # Normalize negative indices so ds[-1] addresses the last frame
        # instead of being forwarded to the capture as a negative position.
        pos = idx
        if pos < 0:
            pos += len(self)
            if pos < 0:
                raise IndexError(f'Idx: {idx} out of length: {len(self)}')

        # Seek only when the capture is not already positioned at the
        # requested frame.
        if self.cap.get(cv2.CAP_PROP_POS_FRAMES) != pos:
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, pos)

        ret, img = self.cap.read()
        if not ret:
            # The upper bound is enforced by the failed read, not by
            # frame_count.
            raise IndexError(f'Idx: {idx} out of length: {len(self)}')

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(img)

        # Compare against None explicitly: a callable that happens to be
        # falsy (e.g. an empty container with __len__) must still be applied.
        if self.transforms is not None:
            img = self.transforms(img)
        return img

    def __enter__(self):
        """Return the dataset itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Release the video capture; exceptions are not suppressed."""
        self.cap.release()