| | |
| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | import torchvision.transforms.functional as TF |
| | from PIL import Image |
| |
|
| |
|
class VaceImageProcessor(object):
    """Load still images and fit them into a latent-token budget.

    Attributes:
        downsample: Sequence whose last two entries are the height and width
            strides ``(dh, dw)``; only ``downsample[1:]`` is read here.
        seq_len: Default maximum of ``(oh // dh) * (ow // dw)`` for the
            output size; can be overridden per call.
    """

    def __init__(self, downsample=None, seq_len=None):
        self.downsample = downsample
        self.seq_len = seq_len

    def _pillow_convert(self, image, cvt_type='RGB'):
        """Convert a PIL image to ``cvt_type``, flattening alpha onto white.

        Palette ('P') images are first expanded to ``f'{cvt_type}A'`` so any
        transparency they carry is composited rather than dropped.
        """
        if image.mode == cvt_type:
            return image

        alpha_mode = f'{cvt_type}A'
        if image.mode == 'P':
            image = image.convert(alpha_mode)
        if image.mode == alpha_mode:
            # Paste over a white background, using the image's own alpha
            # channel as the mask.
            background = Image.new(
                cvt_type,
                size=(image.width, image.height),
                color=(255, 255, 255))
            background.paste(image, (0, 0), mask=image)
            return background
        return image.convert(cvt_type)

    def _load_image(self, img_path):
        """Open ``img_path`` as RGB; return ``None`` for ``None`` or ``''``."""
        if img_path is None or img_path == '':
            return None
        return self._pillow_convert(Image.open(img_path))

    def _resize_crop(self, img, oh, ow, normalize=True):
        """
        Resize, center crop, convert to tensor, and normalize.

        The image is scaled so that it covers ``(oh, ow)`` and then cropped
        around its center. With ``normalize=True`` the result is a tensor
        mapped from [0, 1] to [-1, 1] with a singleton axis inserted at
        dim 1; otherwise the cropped PIL image is returned.
        """
        iw, ih = img.size
        if iw != ow or ih != oh:
            # "Cover" scaling: the larger ratio guarantees both sides reach
            # the target, so the crop below never needs padding.
            scale = max(ow / iw, oh / ih)
            img = img.resize((round(scale * iw), round(scale * ih)),
                             resample=Image.Resampling.LANCZOS)
            assert img.width >= ow and img.height >= oh

            # center crop
            x1 = (img.width - ow) // 2
            y1 = (img.height - oh) // 2
            img = img.crop((x1, y1, x1 + ow, y1 + oh))

        if normalize:
            img = TF.to_tensor(img).sub_(0.5).div_(0.5).unsqueeze(1)
        return img

    def _image_preprocess(self, img, oh, ow, normalize=True, **kwargs):
        """Preprocessing hook; delegates to ``_resize_crop``."""
        return self._resize_crop(img, oh, ow, normalize)

    def load_image(self, data_key, **kwargs):
        """Load one image; see ``load_image_batch``."""
        return self.load_image_batch(data_key, **kwargs)

    def load_image_pair(self, data_key, data_key2, **kwargs):
        """Load two images with shared geometry; see ``load_image_batch``."""
        return self.load_image_batch(data_key, data_key2, **kwargs)

    def load_image_batch(self,
                         *data_key_batch,
                         normalize=True,
                         seq_len=None,
                         **kwargs):
        """Load several images and bring them all to one common size.

        The output size is derived from the first image that could be
        loaded: it is shrunk (never enlarged) so that
        ``(oh // dh) * (ow // dw) <= seq_len`` and rounded down to multiples
        of the strides.

        Args:
            *data_key_batch: Image paths. ``None`` or ``''`` marks a missing
                entry, which is returned as ``None`` in the same position.
            normalize: Forwarded to ``_image_preprocess``.
            seq_len: Per-call override of ``self.seq_len``.

        Returns:
            ``(*images, (oh, ow))`` with one entry per input key.

        Raises:
            ValueError: If ``downsample`` or the sequence length is not
                configured, or if no input key yields an image.
        """
        seq_len = self.seq_len if seq_len is None else seq_len
        if self.downsample is None or seq_len is None:
            raise ValueError(
                'VaceImageProcessor requires `downsample` and `seq_len` to '
                'be set before loading images')

        imgs = [self._load_image(data_key) for data_key in data_key_batch]
        # `_load_image` returns None for empty paths; take the geometry from
        # the first real image instead of blindly dereferencing imgs[0].
        reference = next((img for img in imgs if img is not None), None)
        if reference is None:
            raise ValueError(
                'load_image_batch needs at least one non-empty image path')

        w, h = reference.size
        dh, dw = self.downsample[1:]

        # Shrink only (scale <= 1) until the stride grid fits in seq_len.
        scale = min(1., np.sqrt(seq_len / ((h / dh) * (w / dw))))
        oh = int(h * scale) // dh * dh
        ow = int(w * scale) // dw * dw
        assert (oh // dh) * (ow // dw) <= seq_len

        imgs = [
            None if img is None else self._image_preprocess(
                img, oh, ow, normalize) for img in imgs
        ]
        return *imgs, (oh, ow)
| |
|
| |
|
class VaceVideoProcessor(object):
    """Sample frames from videos and resize them to fit a token budget.

    Attributes:
        downsample: ``(df, dh, dw)`` temporal / height / width strides.
        min_area, max_area: Pixel-area bounds. ``max_area`` caps the spatial
            size; ``min_area`` is only checked against ``seq_len`` in the
            constructor.
        min_fps, max_fps: Frame-rate bounds. Only ``max_fps`` is read by the
            methods in this class.
        zero_start: If true, clips always begin at timestamp 0 instead of a
            random offset.
        keep_last: If true, frames are spread over the whole clip so the
            last frame is kept (see ``_get_frameid_bbox_adjust_last``).
        seq_len: Budget shared by the frame count and the spatial grid.
    """

    def __init__(self, downsample, min_area, max_area, min_fps, max_fps,
                 zero_start, seq_len, keep_last, **kwargs):
        self.downsample = downsample
        self.min_area = min_area
        self.max_area = max_area
        self.min_fps = min_fps
        self.max_fps = max_fps
        self.zero_start = zero_start
        self.keep_last = keep_last
        self.seq_len = seq_len
        # The budget must at least hold one frame at the minimum area.
        assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])

    def set_area(self, area):
        """Pin both area bounds to ``area``."""
        self.min_area = area
        self.max_area = area

    def set_seq_len(self, seq_len):
        """Replace the sequence-length budget."""
        self.seq_len = seq_len

    @staticmethod
    def resize_crop(video: torch.Tensor, oh: int, ow: int):
        """
        Resize, center crop and normalize for decord loaded video (torch.Tensor type)

        Parameters:
          video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C)
          oh - target height (int)
          ow - target width (int)

        Returns:
            The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W)
        """
        # (T, H, W, C) -> (T, C, H, W) so F.interpolate sees frames as batch.
        video = video.permute(0, 3, 1, 2)

        ih, iw = video.shape[2:]
        if ih != oh or iw != ow:
            # "Cover" scaling, then crop the overflow around the center.
            scale = max(ow / iw, oh / ih)
            video = F.interpolate(
                video,
                size=(round(scale * ih), round(scale * iw)),
                mode='bicubic',
                antialias=True)
            assert video.size(3) >= ow and video.size(2) >= oh

            x1 = (video.size(3) - ow) // 2
            y1 = (video.size(2) - oh) // 2
            video = video[:, :, y1:y1 + oh, x1:x1 + ow]

        # (T, C, H, W) -> (C, T, H, W); map [0, 255] to [-1, 1].
        video = video.transpose(0, 1).float().div_(127.5).sub_(1.)
        return video

    def _video_preprocess(self, video, oh, ow):
        """Preprocessing hook; delegates to ``resize_crop``."""
        return self.resize_crop(video, oh, ow)

    @staticmethod
    def _stable_key_hash(key):
        """Return a hash of ``key`` that is identical across processes.

        Built-in ``hash()`` of ``str``/``bytes`` is salted per interpreter
        run, which would make a "seeded" RNG differ between runs. Paths are
        hashed with CRC32 instead; other objects fall back to ``hash()``.
        """
        import os
        import zlib

        if isinstance(key, os.PathLike):
            key = os.fspath(key)
        if isinstance(key, str):
            key = key.encode('utf-8')
        if isinstance(key, (bytes, bytearray)):
            return zlib.crc32(key)
        return hash(key)

    def _fit_latent_grid(self, h, w, max_latent_frames):
        """Split the ``seq_len`` budget between frames and spatial area.

        Args:
            h, w: Height and width of the (cropped) source in pixels.
            max_latent_frames: Upper bound on the frame count after
                temporal downsampling.

        Returns:
            ``(of, oh, ow)``: frame count and output height/width in pixels.
        """
        df, dh, dw = self.downsample
        ratio = h / w

        # Spatial budget in stride-grid cells, capped by the sequence budget,
        # the configured max area, and the source resolution.
        area_z = min(self.seq_len, self.max_area / (dh * dw),
                     (h // dh) * (w // dw))
        of = min(max_latent_frames, int(self.seq_len / area_z))

        # Split the per-frame budget into a grid matching the aspect ratio.
        target_area_z = min(area_z, int(self.seq_len / of))
        oh = round(np.sqrt(target_area_z * ratio))
        ow = int(target_area_z / oh)
        return (of - 1) * df + 1, oh * dh, ow * dw

    def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box,
                                  rng):
        """Pick a contiguous clip sampled at ``min(fps, max_fps)``.

        Returns:
            ``(frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps)``.
        """
        target_fps = min(fps, self.max_fps)
        duration = frame_timestamps[-1].mean()
        x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
        h, w = y2 - y1, x2 - x1
        df = self.downsample[0]

        # Clips shorter than one target-fps frame period would yield a frame
        # budget of zero or less and a division by zero; keep one frame.
        max_latent_frames = max(1, (int(duration * target_fps) - 1) // df + 1)
        of, oh, ow = self._fit_latent_grid(h, w, max_latent_frames)

        # sample frame ids
        target_duration = of / target_fps
        # Clamp so the sampling range stays valid when the clip is shorter
        # than the target duration.
        begin = 0. if self.zero_start else rng.uniform(
            0, max(0., duration - target_duration))
        timestamps = np.linspace(begin, begin + target_duration, of)
        # For each timestamp, the first frame whose [start, end) contains it.
        frame_ids = np.argmax(
            np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
                           timestamps[:, None] < frame_timestamps[None, :, 1]),
            axis=1).tolist()
        return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps

    def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w,
                                      crop_box, rng):
        """Spread the frames evenly over the whole clip.

        The effective frame rate is derived from the frame count and the
        clip duration; ``fps`` and ``rng`` are accepted for signature parity
        with ``_get_frameid_bbox_default`` and are not used.

        Returns:
            ``(frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps)``.
        """
        duration = frame_timestamps[-1].mean()
        x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
        h, w = y2 - y1, x2 - x1
        df = self.downsample[0]

        max_latent_frames = (len(frame_timestamps) - 1) // df + 1
        of, oh, ow = self._fit_latent_grid(h, w, max_latent_frames)

        # sample frame ids
        target_duration = duration
        target_fps = of / target_duration
        timestamps = np.linspace(0., target_duration, of)
        # Closed upper bound, unlike the default sampler.
        frame_ids = np.argmax(
            np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0],
                           timestamps[:, None] <= frame_timestamps[None, :, 1]),
            axis=1).tolist()
        return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps

    def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
        """Dispatch to the sampler selected by ``self.keep_last``."""
        if self.keep_last:
            return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h,
                                                      w, crop_box, rng)
        return self._get_frameid_bbox_default(fps, frame_timestamps, h, w,
                                              crop_box, rng)

    def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
        """Load one video; see ``load_video_batch``."""
        return self.load_video_batch(
            data_key, crop_box=crop_box, seed=seed, **kwargs)

    def load_video_pair(self,
                        data_key,
                        data_key2,
                        crop_box=None,
                        seed=2024,
                        **kwargs):
        """Load two videos with shared frame ids; see ``load_video_batch``."""
        return self.load_video_batch(
            data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)

    def load_video_batch(self,
                         *data_key_batch,
                         crop_box=None,
                         seed=2024,
                         **kwargs):
        """Read the same frames from several videos and preprocess them.

        Frame rate, timestamps and frame size are taken from the first
        video; the usable length is the shortest of all videos.

        Args:
            *data_key_batch: Video sources accepted by ``decord.VideoReader``.
            crop_box: Optional ``(x1, x2, y1, y2)`` pixel box applied before
                resizing.
            seed: Base seed for the temporal-offset RNG. It is combined with
                a process-independent hash of the first source.

        Returns:
            ``(*videos, frame_ids, (oh, ow), fps)`` where each video is the
            output of ``resize_crop``.

        Raises:
            ValueError: If no video source is given.
        """
        if not data_key_batch:
            raise ValueError('load_video_batch requires at least one video')
        rng = np.random.default_rng(
            seed + self._stable_key_hash(data_key_batch[0]) % 10000)

        # read video (imported lazily so decord is only needed here)
        import decord
        decord.bridge.set_bridge('torch')
        readers = [decord.VideoReader(data_k) for data_k in data_key_batch]

        fps = readers[0].get_avg_fps()
        length = min(len(r) for r in readers)
        frame_timestamps = np.array(
            [readers[0].get_frame_timestamp(i) for i in range(length)],
            dtype=np.float32)
        h, w = readers[0].next().shape[:2]
        frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(
            fps, frame_timestamps, h, w, crop_box, rng)

        # preprocess video
        videos = [
            reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :]
            for reader in readers
        ]
        videos = [self._video_preprocess(video, oh, ow) for video in videos]
        return *videos, frame_ids, (oh, ow), fps
| | |
| |
|
| |
|
def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size,
                   device):
    """Fill in missing video/mask pairs and letterbox reference images.

    All three input lists are modified in place and also returned.

    Args:
        src_video: List of video tensors or ``None``.
        src_mask: List of mask tensors or ``None``, parallel to ``src_video``.
        src_ref_images: List whose entries are ``None`` or lists of reference
            image tensors. Each reference is expected to be laid out as
            ``(3, 1, H, W)``, which is what the squeeze/unsqueeze below and
            the canvas shape assume.
        num_frames: Frame count for generated placeholder videos and masks.
        image_size: ``(height, width)`` as any two-element sequence.
        device: Device for newly created tensors.

    Returns:
        ``(src_video, src_mask, src_ref_images)``.
    """
    canvas_height, canvas_width = image_size[0], image_size[1]

    # A pair that is entirely missing becomes an all-zero video with an
    # all-one mask.
    for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
        if sub_src_video is None and sub_src_mask is None:
            src_video[i] = torch.zeros(
                (3, num_frames, canvas_height, canvas_width), device=device)
            src_mask[i] = torch.ones(
                (1, num_frames, canvas_height, canvas_width), device=device)

    for i, ref_images in enumerate(src_ref_images):
        if ref_images is None:
            continue
        for j, ref_img in enumerate(ref_images):
            if ref_img is None:
                continue
            ref_height, ref_width = ref_img.shape[-2:]
            # Compare plain ints: `torch.Size != list` is always True, which
            # would needlessly re-interpolate already-matching references
            # whenever image_size is passed as a list.
            if (ref_height, ref_width) == (canvas_height, canvas_width):
                continue

            # Fit the reference inside the canvas, preserving aspect ratio.
            scale = min(canvas_height / ref_height, canvas_width / ref_width)
            # Keep at least one pixel per side so extreme aspect ratios do
            # not request a zero-sized interpolation.
            new_height = max(1, int(ref_height * scale))
            new_width = max(1, int(ref_width * scale))
            resized_image = F.interpolate(
                ref_img.squeeze(1).unsqueeze(0),
                size=(new_height, new_width),
                mode='bilinear',
                align_corners=False).squeeze(0).unsqueeze(1)

            # Center the resized reference on a canvas filled with ones.
            white_canvas = torch.ones(
                (3, 1, canvas_height, canvas_width), device=device)
            top = (canvas_height - new_height) // 2
            left = (canvas_width - new_width) // 2
            white_canvas[:, :, top:top + new_height,
                         left:left + new_width] = resized_image
            src_ref_images[i][j] = white_canvas
    return src_video, src_mask, src_ref_images
| |
|