| import os |
| import re |
|
|
| import numpy as np |
| import pandas as pd |
| import requests |
| import torch |
| import torchvision |
| import torchvision.transforms as transforms |
| from PIL import Image |
| from torchvision.datasets.folder import IMG_EXTENSIONS, pil_loader |
| from torchvision.io import write_video |
| from torchvision.utils import save_image |
|
|
|
|
# File suffixes (lower-case, leading dot included) that are treated as videos.
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")

# URL matcher, assembled piece by piece: scheme, host (domain name,
# "localhost" or dotted IPv4 literal), optional port, optional path/query.
_URL_PATTERN = (
    r"^(?:http|ftp)s?://"
    r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|"
    r"localhost|"
    r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})"
    r"(?::\d+)?"
    r"(?:/?|[/?]\S+)$"
)

# The character classes above are upper-case only, so matching must ignore case.
regex = re.compile(_URL_PATTERN, flags=re.IGNORECASE)
|
|
| import numbers |
| import random |
|
|
| import numpy as np |
| import torch |
|
|
|
|
| def _is_tensor_video_clip(clip): |
| if not torch.is_tensor(clip): |
| raise TypeError("clip should be Tensor. Got %s" % type(clip)) |
|
|
| if not clip.ndimension() == 4: |
| raise ValueError("clip should be 4D. Got %dD" % clip.dim()) |
|
|
| return True |
|
|
|
|
def crop(clip, i, j, h, w):
    """
    Slice a spatial window out of the last two dimensions of a clip.

    Args:
        clip (torch.tensor): 4D video clip. Documented layout is (T, C, H, W).
        i (int): first row of the window.
        j (int): first column of the window.
        h (int): window height.
        w (int): window width.
    Returns:
        torch.tensor: ``clip[..., i:i+h, j:j+w]`` (a slice of the input).
    Raises:
        ValueError: if ``clip`` is not 4-dimensional.
    """
    rank = len(clip.size())
    if rank != 4:
        raise ValueError("clip should be a 4D tensor")
    rows = slice(i, i + h)
    cols = slice(j, j + w)
    return clip[..., rows, cols]
|
|
|
|
def resize(clip, target_size, interpolation_mode):
    """
    Resize a clip to an explicit (height, width).

    Args:
        clip (torch.tensor): tensor passed to ``torch.nn.functional.interpolate``.
        target_size: sequence of length 2, (height, width).
        interpolation_mode (str): ``mode`` forwarded to ``interpolate``
            (called with ``align_corners=False``).
    Returns:
        torch.tensor: the interpolated clip.
    Raises:
        ValueError: if ``target_size`` does not have exactly two entries.
    """
    if len(target_size) == 2:
        interpolate = torch.nn.functional.interpolate
        return interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
    raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
|
|
|
def resize_scale(clip, target_size, interpolation_mode):
    """
    Rescale a clip by a single factor so that its short edge maps to
    ``target_size[0]``.

    Args:
        clip (torch.tensor): tensor whose last two dimensions are (H, W).
        target_size: sequence of length 2; only its first entry is used to
            derive the scale factor.
        interpolation_mode (str): ``mode`` forwarded to ``interpolate``
            (called with ``align_corners=False``).
    Returns:
        torch.tensor: the clip interpolated with ``scale_factor``.
    Raises:
        ValueError: if ``target_size`` does not have exactly two entries.
    """
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    short_edge = min(clip.size(-2), clip.size(-1))
    factor = target_size[0] / short_edge
    return torch.nn.functional.interpolate(
        clip, scale_factor=factor, mode=interpolation_mode, align_corners=False
    )
|
|
|
|
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Cut a spatial window out of the clip, then resize that window.

    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        i (int): row of the upper left corner of the window.
        j (int): column of the upper left corner of the window.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): height and width of the resized clip
        interpolation_mode (str): mode forwarded to ``resize``.
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    window = crop(clip, i, j, h, w)
    return resize(window, size, interpolation_mode)
|
|
|
|
def center_crop(clip, crop_size):
    """
    Take a centered (th, tw) window from the last two dimensions of a clip.

    Args:
        clip (torch.tensor): 4D video clip whose last two dims are (H, W).
        crop_size: pair (th, tw) giving the window height and width.
    Returns:
        torch.tensor: the centered crop.
    Raises:
        ValueError: if the clip is smaller than the requested window.
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    height, width = clip.size(-2), clip.size(-1)
    th, tw = crop_size
    if th > height or tw > width:
        raise ValueError("height and width must be no smaller than crop_size")

    top = int(round((height - th) / 2.0))
    left = int(round((width - tw) / 2.0))
    return crop(clip, top, left, th, tw)
|
|
|
|
def center_crop_using_short_edge(clip):
    """
    Take the centered square whose side equals the clip's shorter spatial edge.

    Args:
        clip (torch.tensor): 4D video clip whose last two dims are (H, W).
    Returns:
        torch.tensor: square crop of side ``min(H, W)``.
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    height, width = clip.size(-2), clip.size(-1)
    side = min(height, width)
    # Along the short edge the difference is zero, so that offset is 0.
    top = int(round((height - side) / 2.0))
    left = int(round((width - side) / 2.0))
    return crop(clip, top, left, side, side)
|
|
|
|
def resize_crop_to_fill(clip, target_size):
    """
    Resize a clip so it fully covers ``target_size`` while keeping its aspect
    ratio, then center-crop the overflow.

    NOTE: this module defines another function with the same name further
    down (taking a PIL image); at module level the later definition rebinds
    the name.

    Args:
        clip (torch.tensor): 4D video clip whose last two dims are (H, W).
        target_size: pair (th, tw), the output height and width.
    Returns:
        torch.tensor: clip whose last two dims are (th, tw).
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    h, w = clip.size(-2), clip.size(-1)
    th, tw = target_size[0], target_size[1]
    rh, rw = th / h, tw / w
    if rh > rw:
        # Height is the binding constraint: match it exactly, crop the width.
        sh, sw = th, round(w * rh)
        clip = resize(clip, (sh, sw), "bilinear")
        i = 0
        # Round the half-difference itself; rounding the integer difference
        # first would be a no-op and the offset would be truncated instead.
        j = int(round((sw - tw) / 2.0))
    else:
        # Width is the binding constraint: match it exactly, crop the height.
        sh, sw = round(h * rw), tw
        clip = resize(clip, (sh, sw), "bilinear")
        i = int(round((sh - th) / 2.0))
        j = 0
    assert i + th <= clip.size(-2) and j + tw <= clip.size(-1)
    return crop(clip, i, j, th, tw)
|
|
|
|
def random_shift_crop(clip):
    """
    Take a square crop whose side is the short edge, placed at a random
    position along the long edge.

    Args:
        clip (torch.tensor): 4D video clip whose last two dims are (H, W).
    Returns:
        torch.tensor: square crop of side ``min(H, W)``.
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    height, width = clip.size(-2), clip.size(-1)
    side = height if height <= width else width

    # Row offset is drawn before the column offset; along the short edge the
    # only admissible offset is 0.
    top = torch.randint(0, height - side + 1, size=(1,)).item()
    left = torch.randint(0, width - side + 1, size=(1,)).item()
    return crop(clip, top, left, side, side)
|
|
|
|
def to_tensor(clip):
    """
    Convert a uint8 clip to float and scale its values by 1/255.
    The dimension order is left untouched.

    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
    Raises:
        TypeError: if ``clip`` is not a tensor or is not of dtype uint8.
        ValueError: if ``clip`` is not 4-dimensional.
    """
    _is_tensor_video_clip(clip)
    if clip.dtype != torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
    as_float = clip.float()
    return as_float / 255.0
|
|
|
|
def normalize(clip, mean, std, inplace=False):
    """
    Subtract ``mean`` and divide by ``std`` along dimension 0 of the clip.

    The statistics are reshaped to (N, 1, 1, 1), so ``mean[k]`` / ``std[k]``
    are applied to ``clip[k]``.
    NOTE(review): that matches a channel-first (C, T, H, W) layout; for a
    (T, C, H, W) clip the statistics would line up with frames, not channels
    -- confirm the layout callers pass.

    Args:
        clip (torch.tensor): 4D video clip to be normalized.
        mean (tuple): per-entry mean along dimension 0.
        std (tuple): per-entry standard deviation along dimension 0.
        inplace (bool): modify ``clip`` itself instead of a clone.
    Returns:
        normalized clip (torch.tensor): same shape as the input.
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    target = clip if inplace else clip.clone()
    mean_t = torch.as_tensor(mean, dtype=target.dtype, device=target.device)
    std_t = torch.as_tensor(std, dtype=target.dtype, device=target.device)
    target.sub_(mean_t[:, None, None, None])
    target.div_(std_t[:, None, None, None])
    return target
|
|
|
|
def hflip(clip):
    """
    Mirror a clip along its last (width) dimension.

    Args:
        clip (torch.tensor): Video clip to be flipped. Size is (T, C, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (T, C, H, W)
    """
    if not _is_tensor_video_clip(clip):
        raise ValueError("clip should be a 4D torch.tensor")
    width_dim = -1
    return clip.flip(width_dim)
|
|
|
|
class ResizeCrop:
    """
    Callable transform that forwards a clip to the module-level
    ``resize_crop_to_fill`` with a fixed target size.

    Args:
        size: a single number (expanded to a square ``(int, int)`` pair) or
            any other value, which is stored unchanged.
    """

    def __init__(self, size):
        is_scalar = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if is_scalar else size

    def __call__(self, clip):
        return resize_crop_to_fill(clip, self.size)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(size={self.size})"
|
|
|
|
class RandomCropVideo:
    """
    Crop a randomly placed window of fixed size out of a clip.

    Args:
        size: a single number (expanded to a square ``(int, int)`` pair) or
            a (height, width) pair, which is stored unchanged.
    """

    def __init__(self, size):
        is_scalar = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if is_scalar else size

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: randomly cropped video clip.
                size is (T, C, OH, OW)
        """
        top, left, height, width = self.get_params(clip)
        return crop(clip, top, left, height, width)

    def get_params(self, clip):
        """
        Pick the window ``(i, j, h, w)`` to crop from ``clip``.

        Returns ``(0, 0, H, W)`` when the clip already has the target size;
        otherwise draws the row offset, then the column offset, uniformly.

        Raises:
            ValueError: if the clip is smaller than the crop size.
        """
        h, w = clip.shape[-2:]
        th, tw = self.size

        if th > h or tw > w:
            raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")

        if (h, w) == (th, tw):
            return 0, 0, h, w

        top = torch.randint(0, h - th + 1, size=(1,)).item()
        left = torch.randint(0, w - tw + 1, size=(1,)).item()
        return top, left, th, tw

    def __repr__(self) -> str:
        return f"{type(self).__name__}(size={self.size})"
|
|
|
|
class CenterCropResizeVideo:
    """
    First use the short side for cropping length,
    center crop video, then resize to the specified size

    Args:
        size: a (height, width) tuple, or a single value used for both.
        interpolation_mode (str): mode forwarded to ``resize``.
    Raises:
        ValueError: if ``size`` is a tuple whose length is not 2.
    """

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: center cropped, then resized video clip.
                size is (T, C, size[0], size[1])
        """
        clip_center_crop = center_crop_using_short_edge(clip)
        clip_center_crop_resize = resize(
            clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode
        )
        return clip_center_crop_resize

    def __repr__(self) -> str:
        # The closing parenthesis was missing from the original format string.
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
|
|
|
|
class UCFCenterCropVideo:
    """
    First scale to the specified size in equal proportion to the short edge,
    then center cropping

    Args:
        size: a (height, width) tuple, or a single value used for both.
        interpolation_mode (str): mode forwarded to ``resize_scale``.
    Raises:
        ValueError: if ``size`` is a tuple whose length is not 2.
    """

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized / center cropped video clip.
                size is (T, C, size[0], size[1])
        """
        clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
        clip_center_crop = center_crop(clip_resize, self.size)
        return clip_center_crop

    def __repr__(self) -> str:
        # The closing parenthesis was missing from the original format string.
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
|
|
|
|
class KineticsRandomCropResizeVideo:
    """
    Slide along the long edge, with the short edge as crop size, and resize
    the crop to the desired size.

    Args:
        size: a (height, width) tuple, or a single value used for both.
        interpolation_mode (str): mode forwarded to ``resize``.
    Raises:
        ValueError: if ``size`` is a tuple whose length is not 2.
    """

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if not isinstance(size, tuple):
            size = (size, size)
        elif len(size) != 2:
            raise ValueError(f"size should be tuple (height, width), instead got {size}")
        self.size = size
        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        square = random_shift_crop(clip)
        return resize(square, self.size, self.interpolation_mode)
|
|
|
|
class CenterCropVideo:
    """
    Center crop a video clip to a fixed size.

    Args:
        size: a (height, width) tuple, or a single value used for both.
        interpolation_mode (str): stored on the instance; ``__call__`` does
            not use it.
    Raises:
        ValueError: if ``size`` is a tuple whose length is not 2.
    """

    def __init__(
        self,
        size,
        interpolation_mode="bilinear",
    ):
        if isinstance(size, tuple):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = size
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: center cropped video clip.
                size is (T, C, size[0], size[1])
        """
        clip_center_crop = center_crop(clip, self.size)
        return clip_center_crop

    def __repr__(self) -> str:
        # The closing parenthesis was missing from the original format string.
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"
|
|
|
|
class NormalizeVideo:
    """
    Callable wrapper around the module-level ``normalize``: mean subtraction
    followed by division by the standard deviation.

    Args:
        mean (3-tuple): pixel RGB mean
        std (3-tuple): pixel RGB standard deviation
        inplace (boolean): whether do in-place normalization
    """

    def __init__(self, mean, std, inplace=False):
        self.mean, self.std, self.inplace = mean, std, inplace

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
        """
        return normalize(clip, self.mean, self.std, self.inplace)

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
|
|
|
|
class ToTensorVideo:
    """
    Callable wrapper around the module-level ``to_tensor``: converts a uint8
    clip to float and divides its values by 255.0.
    """

    def __init__(self):
        # Stateless transform: nothing to configure.
        pass

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
        Return:
            clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
        """
        converted = to_tensor(clip)
        return converted

    def __repr__(self) -> str:
        return type(self).__name__
|
|
|
|
class RandomHorizontalFlipVideo:
    """
    Flip the video clip along the horizontal direction with a given probability
    Args:
        p (float): probability of the clip being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Size is (T, C, H, W)
        Return:
            clip (torch.tensor): Size is (T, C, H, W); the input itself is
                returned when no flip is drawn.
        """
        should_flip = random.random() < self.p
        return hflip(clip) if should_flip else clip

    def __repr__(self) -> str:
        return f"{type(self).__name__}(p={self.p})"
|
|
|
|
| |
| |
| |
class TemporalRandomCrop(object):
    """Temporally crop the given frame indices at a random location.

    Args:
        size (int): Desired length of frames will be seen in the model.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, total_frames):
        """Pick a random window of ``self.size`` frames.

        Args:
            total_frames (int): number of frames available.
        Returns:
            tuple(int, int): ``(begin_index, end_index)`` with
            ``end_index`` exclusive and never larger than ``total_frames``.
            When fewer than ``size`` frames exist the window is
            ``(0, total_frames)``.
        """
        # random.randint includes both endpoints, so the largest admissible
        # start is total_frames - size; subtracting one more would make the
        # final frame unreachable.
        last_start = max(0, total_frames - self.size)
        begin_index = random.randint(0, last_start)
        end_index = min(begin_index + self.size, total_frames)
        return begin_index, end_index
|
|
|
|
def is_img(path):
    """Return True when the lower-cased extension of ``path`` is in IMG_EXTENSIONS."""
    _, suffix = os.path.splitext(path)
    return suffix.lower() in IMG_EXTENSIONS
|
|
|
|
def is_vid(path):
    """Return True when the lower-cased extension of ``path`` is in VID_EXTENSIONS."""
    _, suffix = os.path.splitext(path)
    return suffix.lower() in VID_EXTENSIONS
|
|
|
|
def is_url(url):
    """Return True when ``url`` matches the module-level URL pattern ``regex``."""
    found = regex.match(url)
    return found is not None
|
|
|
|
def read_file(input_path):
    """
    Load a table into a pandas DataFrame, choosing the reader by file suffix.

    Args:
        input_path (str): path ending in ``.csv`` or ``.parquet``.
    Returns:
        pandas.DataFrame: the loaded table.
    Raises:
        NotImplementedError: for any other suffix.
    """
    readers = (
        (".csv", pd.read_csv),
        (".parquet", pd.read_parquet),
    )
    for suffix, reader in readers:
        if input_path.endswith(suffix):
            return reader(input_path)
    raise NotImplementedError(f"Unsupported file format: {input_path}")
|
|
|
|
def download_url(input_path):
    """
    Download ``input_path`` into the local ``cache`` directory.

    Args:
        input_path (str): URL to fetch.
    Returns:
        str: path of the downloaded file, ``cache/<last URL path segment>``.
    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond in time.
    """
    from urllib.parse import urlparse

    output_dir = "cache"
    os.makedirs(output_dir, exist_ok=True)
    # Use only the URL path for the file name so a query string or fragment
    # cannot end up in the cached name (and therefore in its extension).
    base_name = os.path.basename(urlparse(input_path).path)
    output_path = os.path.join(output_dir, base_name)
    # Without a timeout a stalled server would block the caller forever.
    response = requests.get(input_path, timeout=60)
    # Fail loudly instead of saving an HTTP error page as if it were media.
    response.raise_for_status()
    with open(output_path, "wb") as handler:
        handler.write(response.content)
    print(f"URL {input_path} downloaded to {output_path}")
    return output_path
|
|
|
|
def temporal_random_crop(vframes, num_frames, frame_interval):
    """
    Sample ``num_frames`` evenly spaced frames from a random temporal window.

    A window of ``num_frames * frame_interval`` frames is drawn with
    ``TemporalRandomCrop``; frame indices are then spread over that window
    with ``np.linspace``.

    Args:
        vframes: indexable sequence of frames supporting ``len`` and
            indexing with an integer array.
        num_frames (int): number of frames to return.
        frame_interval (int): multiplier applied to ``num_frames`` to get the
            window length.
    Returns:
        The selected frames, ``vframes[indices]``.
    """
    sampler = TemporalRandomCrop(num_frames * frame_interval)
    first, last = sampler(len(vframes))
    assert (
        last - first >= num_frames
    ), f"Not enough frames to sample, {last} - {first} < {num_frames}"
    picks = np.linspace(first, last - 1, num_frames, dtype=int)
    return vframes[picks]
|
|
|
|
def get_transforms_video(name="center", image_size=(256, 256)):
    """
    Build the video preprocessing pipeline selected by ``name``.

    Every pipeline is ``ToTensorVideo`` -> spatial transform ->
    ``Normalize(mean=0.5, std=0.5)`` (in place).

    Args:
        name: ``"center"`` (``UCFCenterCropVideo``, square sizes only),
            ``"resize_crop"`` (``ResizeCrop``) or ``None``.
        image_size: (height, width) of the output frames.
    Returns:
        A ``transforms.Compose`` pipeline, or ``None`` when ``name`` is None.
    Raises:
        NotImplementedError: for an unknown ``name``.
    """
    if name is None:
        return None

    if name == "center":
        assert image_size[0] == image_size[1], "image_size must be square for center crop"
        spatial = UCFCenterCropVideo(image_size[0])
    elif name == "resize_crop":
        spatial = ResizeCrop(image_size)
    else:
        raise NotImplementedError(f"Transform {name} not implemented")

    steps = [
        ToTensorVideo(),
        spatial,
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ]
    return transforms.Compose(steps)
|
|
|
|
def get_transforms_image(name="center", image_size=(256, 256)):
    """
    Build the image preprocessing pipeline selected by ``name``.

    Every pipeline is spatial transform on the PIL image -> ``ToTensor`` ->
    ``Normalize(mean=0.5, std=0.5)`` (in place).

    Args:
        name: ``"center"`` (``center_crop_arr``, square sizes only),
            ``"resize_crop"`` (``resize_crop_to_fill``) or ``None``.
        image_size: (height, width) of the output image.
    Returns:
        A ``transforms.Compose`` pipeline, or ``None`` when ``name`` is None.
    Raises:
        NotImplementedError: for an unknown ``name``.
    """
    if name is None:
        return None

    if name == "center":
        assert image_size[0] == image_size[1], "Image size must be square for center crop"
        spatial = transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size[0]))
    elif name == "resize_crop":
        spatial = transforms.Lambda(lambda pil_image: resize_crop_to_fill(pil_image, image_size))
    else:
        raise NotImplementedError(f"Transform {name} not implemented")

    steps = [
        spatial,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    ]
    return transforms.Compose(steps)
|
|
|
|
def read_image_from_path(path, transform=None, transform_name="center", num_frames=1, image_size=(256, 256)):
    """
    Load an image and present it as a still video.

    Args:
        path: image file, loaded with ``pil_loader``.
        transform: callable applied to the PIL image; when None one is built
            with ``get_transforms_image(image_size=..., name=transform_name)``.
        transform_name: pipeline name used when ``transform`` is None.
        num_frames (int): how many times the image is repeated along a new
            leading dimension.
        image_size: target size used when ``transform`` is None.
    Returns:
        The transformed image repeated ``num_frames`` times, with the first
        two dimensions swapped afterwards.
    """
    if transform is None:
        transform = get_transforms_image(image_size=image_size, name=transform_name)
    frame = transform(pil_loader(path))
    stacked = frame.unsqueeze(0).repeat(num_frames, 1, 1, 1)
    return stacked.permute(1, 0, 2, 3)
|
|
|
|
def read_video_from_path(path, transform=None, transform_name="center", image_size=(256, 256)):
    """
    Decode a video file and run it through a transform pipeline.

    Args:
        path: video file, read with ``torchvision.io.read_video``
            (``output_format="TCHW"``).
        transform: callable applied to the decoded frames; when None one is
            built with ``get_transforms_video(image_size=..., name=transform_name)``.
        transform_name: pipeline name used when ``transform`` is None.
        image_size: target size used when ``transform`` is None.
    Returns:
        The transformed frames with the first two dimensions swapped.
    """
    # Audio frames and stream info are decoded as well but not used.
    frames, _audio, _info = torchvision.io.read_video(filename=path, pts_unit="sec", output_format="TCHW")
    if transform is None:
        transform = get_transforms_video(image_size=image_size, name=transform_name)
    return transform(frames).permute(1, 0, 2, 3)
|
|
|
|
def read_from_path(path, image_size, transform_name="center"):
    """
    Load an image or a video from a local path or a URL.

    URLs are first downloaded with ``download_url``; the reader is then
    chosen from the lower-cased file extension.

    Args:
        path (str): local path or URL.
        image_size: target size forwarded to the reader.
        transform_name: pipeline name forwarded to the reader.
    Returns:
        The result of ``read_video_from_path`` or ``read_image_from_path``.
    """
    if is_url(path):
        path = download_url(path)
    suffix = os.path.splitext(path)[-1].lower()
    if suffix in VID_EXTENSIONS:
        return read_video_from_path(path, image_size=image_size, transform_name=transform_name)
    assert suffix in IMG_EXTENSIONS, f"Unsupported file format: {suffix}"
    return read_image_from_path(path, image_size=image_size, transform_name=transform_name)
|
|
|
|
def save_sample(x, save_path=None, fps=8, normalize=True, value_range=(-1, 1), force_video=False, verbose=True):
    """
    Save a sample as a PNG image (single frame) or an MP4 video.

    The caller's tensor is left unmodified.

    Args:
        x (Tensor): shape [C, T, H, W]
        save_path (str): output path WITHOUT extension; ``.png`` or ``.mp4``
            is appended. Required despite the ``None`` default.
        fps (int): frame rate used when writing a video.
        normalize (bool): map ``value_range`` to [0, 1] before saving.
        value_range (tuple): (low, high) range of the values in ``x``.
        force_video (bool): write a video even when ``x`` has a single frame.
        verbose (bool): print the output path after saving.
    Returns:
        str: the path that was written, including the extension.
    Raises:
        ValueError: if ``save_path`` is None.
    """
    assert x.ndim == 4
    if save_path is None:
        raise ValueError("save_path must be provided")

    if not force_video and x.shape[1] == 1:
        save_path += ".png"
        x = x.squeeze(1)
        save_image([x], save_path, normalize=normalize, value_range=value_range)
    else:
        save_path += ".mp4"
        if normalize:
            low, high = value_range
            # clamp() allocates a fresh tensor; the in-place ops that follow
            # touch only that copy, never the caller's data.
            x = x.clamp(min=low, max=high).sub_(low).div_(max(high - low, 1e-5))

        # mul() is out-of-place as well, so this is safe when normalize=False.
        x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 3, 0).to("cpu", torch.uint8)
        write_video(save_path, x, fps=fps, video_codec="h264")
    if verbose:
        print(f"Saved to {save_path}")
    return save_path
|
|
|
|
def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    The image is halved with BOX resampling while its short edge is at least
    twice ``image_size``, rescaled with BICUBIC so the short edge equals
    ``image_size``, and finally cropped to a centered square.

    Args:
        pil_image: PIL image to crop.
        image_size (int): side of the square output.
    Returns:
        PIL image built from the centered ``image_size`` x ``image_size`` window.
    """
    while min(*pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    factor = image_size / min(*pil_image.size)
    rescaled = tuple(round(side * factor) for side in pil_image.size)
    pil_image = pil_image.resize(rescaled, resample=Image.BICUBIC)

    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    window = pixels[top : top + image_size, left : left + image_size]
    return Image.fromarray(window)
|
|
|
|
def resize_crop_to_fill(pil_image, image_size):
    """
    Resize the input so it fully covers ``image_size`` while keeping its
    aspect ratio, then center-crop the overflow.

    This definition rebinds the module-level name that an earlier,
    tensor-based function in this file also uses, so callers that pass video
    clips end up here. Tensors are therefore dispatched to a tensor
    implementation instead of failing on the PIL-only code path.

    Args:
        pil_image: a PIL image, or a 4D torch tensor whose last two
            dimensions are (H, W).
        image_size: pair (th, tw), the output height and width.
    Returns:
        A PIL image of size (tw, th) for PIL input, or a tensor whose last
        two dimensions are (th, tw) for tensor input.
    """
    if torch.is_tensor(pil_image):
        return _resize_crop_to_fill_tensor(pil_image, image_size)

    w, h = pil_image.size  # PIL reports (width, height)
    th, tw = image_size
    rh, rw = th / h, tw / w
    if rh > rw:
        # Height is the binding constraint: match it exactly, crop the width.
        sh, sw = th, round(w * rh)
        image = pil_image.resize((sw, sh), Image.BICUBIC)
        i = 0
        j = int(round((sw - tw) / 2.0))
    else:
        # Width is the binding constraint: match it exactly, crop the height.
        sh, sw = round(h * rw), tw
        image = pil_image.resize((sw, sh), Image.BICUBIC)
        i = int(round((sh - th) / 2.0))
        j = 0
    arr = np.array(image)
    assert i + th <= arr.shape[0] and j + tw <= arr.shape[1]
    return Image.fromarray(arr[i : i + th, j : j + tw])


def _resize_crop_to_fill_tensor(clip, target_size):
    """Tensor counterpart of resize_crop_to_fill: bilinear resize, then center crop."""
    if clip.ndimension() != 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())
    h, w = clip.size(-2), clip.size(-1)
    th, tw = target_size[0], target_size[1]
    rh, rw = th / h, tw / w
    if rh > rw:
        sh, sw = th, round(w * rh)
    else:
        sh, sw = round(h * rw), tw
    clip = torch.nn.functional.interpolate(clip, size=(sh, sw), mode="bilinear", align_corners=False)
    # Along the exactly matched edge the difference is zero, so its offset is 0.
    i = int(round((sh - th) / 2.0))
    j = int(round((sw - tw) / 2.0))
    assert i + th <= clip.size(-2) and j + tw <= clip.size(-1)
    return clip[..., i : i + th, j : j + tw]
|
|