| # pylint: disable=R0801 | |
| """ | |
| talking_video_dataset.py | |
| This module defines the TalkingVideoDataset class, a custom PyTorch dataset | |
| for handling talking video data. The dataset uses video files, masks, and | |
| embeddings to prepare data for tasks such as video generation and | |
| speech-driven video animation. | |
| Classes: | |
| TalkingVideoDataset | |
| Dependencies: | |
| json | |
| random | |
| torch | |
| decord.VideoReader, decord.cpu | |
| PIL.Image | |
| torch.utils.data.Dataset | |
| torchvision.transforms | |
| Example: | |
| from talking_video_dataset import TalkingVideoDataset | |
| from torch.utils.data import DataLoader | |
| # Example configuration for the Wav2Vec model | |
| class Wav2VecConfig: | |
| def __init__(self, audio_type, model_scale, features): | |
| self.audio_type = audio_type | |
| self.model_scale = model_scale | |
| self.features = features | |
| wav2vec_cfg = Wav2VecConfig(audio_type="wav2vec2", model_scale="base", features="feature") | |
| # Initialize dataset | |
| dataset = TalkingVideoDataset( | |
| img_size=(512, 512), | |
| sample_rate=16000, | |
| audio_margin=2, | |
| n_motion_frames=0, | |
| n_sample_frames=16, | |
| data_meta_paths=["path/to/meta1.json", "path/to/meta2.json"], | |
| wav2vec_cfg=wav2vec_cfg, | |
| ) | |
| # Initialize dataloader | |
| dataloader = DataLoader(dataset, batch_size=4, shuffle=True) | |
| # Fetch one batch of data | |
| batch = next(iter(dataloader)) | |
| print(batch["pixel_values_vid"].shape) # Example output: (4, 16, 3, 512, 512) | |
| The TalkingVideoDataset class provides methods for loading video frames, masks, | |
| audio embeddings, and other relevant data, applying transformations, and preparing | |
| the data for training and evaluation in a deep learning pipeline. | |
| Attributes: | |
| img_size (tuple): The dimensions to resize the video frames to. | |
| sample_rate (int): The audio sample rate. | |
| audio_margin (int): The margin for audio sampling. | |
| n_motion_frames (int): The number of motion frames. | |
| n_sample_frames (int): The number of sample frames. | |
| data_meta_paths (list): List of paths to the JSON metadata files. | |
| wav2vec_cfg (object): Configuration for the Wav2Vec model. | |
| Methods: | |
| augmentation(images, transform, state=None): Apply transformation to input images. | |
| __getitem__(index): Get a sample from the dataset at the specified index. | |
| __len__(): Return the length of the dataset. | |
| """ | |
| import json | |
| import random | |
| from typing import List | |
| import torch | |
| from decord import VideoReader, cpu | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from torchvision import transforms | |
class TalkingVideoDataset(Dataset):
    """
    A dataset class for processing talking video data.

    Each sample combines a clip of video frames, the matching audio-embedding
    context windows, a reference frame (optionally with preceding motion
    frames), a target conditioning mask, and face/lip/full attention masks
    resized to 64/32/16/8 resolutions.

    Args:
        img_size (tuple, optional): The size of the output images. Defaults to (512, 512).
        sample_rate (int, optional): The sample rate of the audio data. Defaults to 16000.
        audio_margin (int, optional): Number of audio frames taken on each side of a
            video frame when building its audio context window. Defaults to 2.
        n_motion_frames (int, optional): The number of motion frames prepended to the
            reference image. Defaults to 0.
        n_sample_frames (int, optional): The number of video frames per sample. Defaults to 16.
        data_meta_paths (list, optional): Paths to JSON metadata files. Defaults to None,
            which is treated as an empty list.
        wav2vec_cfg (object, optional): Configuration for the wav2vec model; must expose
            ``audio_type``, ``model_scale`` and ``features`` attributes. Defaults to None.

    Attributes:
        img_size (tuple): The size of the output images.
        sample_rate (int): The sample rate of the audio data.
        audio_margin (int): The margin for the audio data.
        n_motion_frames (int): The number of motion frames.
        n_sample_frames (int): The number of sample frames.
        vid_meta (list): Merged metadata entries loaded from all meta files.
    """

    def __init__(
        self,
        img_size=(512, 512),
        sample_rate=16000,
        audio_margin=2,
        n_motion_frames=0,
        n_sample_frames=16,
        data_meta_paths=None,
        wav2vec_cfg=None,
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.img_size = img_size
        self.audio_margin = audio_margin
        self.n_motion_frames = n_motion_frames
        self.n_sample_frames = n_sample_frames
        self.audio_type = wav2vec_cfg.audio_type
        self.audio_model = wav2vec_cfg.model_scale
        self.audio_features = wav2vec_cfg.features

        # Robustness fix: a missing meta-path list yields an empty dataset
        # instead of a TypeError from iterating None.
        vid_meta = []
        for data_meta_path in (data_meta_paths or []):
            with open(data_meta_path, "r", encoding="utf-8") as f:
                vid_meta.extend(json.load(f))
        self.vid_meta = vid_meta
        self.length = len(self.vid_meta)

        # Normalized RGB transform for video / reference frames (maps to [-1, 1]).
        self.pixel_transform = transforms.Compose(
            [
                transforms.Resize(self.img_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        # Un-normalized transform for the conditioning mask (stays in [0, 1]).
        self.cond_transform = transforms.Compose(
            [
                transforms.Resize(self.img_size),
                transforms.ToTensor(),
            ]
        )
        # Attention-mask transforms at the resolutions consumed downstream.
        self.attn_transform_64 = transforms.Compose(
            [
                transforms.Resize((64, 64)),
                transforms.ToTensor(),
            ]
        )
        self.attn_transform_32 = transforms.Compose(
            [
                transforms.Resize((32, 32)),
                transforms.ToTensor(),
            ]
        )
        self.attn_transform_16 = transforms.Compose(
            [
                transforms.Resize((16, 16)),
                transforms.ToTensor(),
            ]
        )
        self.attn_transform_8 = transforms.Compose(
            [
                transforms.Resize((8, 8)),
                transforms.ToTensor(),
            ]
        )

    def augmentation(self, images, transform, state=None):
        """
        Apply the given transformation to the input images.

        Args:
            images (List[PIL.Image] or PIL.Image): The input images to be transformed.
            transform (torchvision.transforms.Compose): The transformation to apply.
            state (torch.ByteTensor, optional): RNG state to restore before applying
                the transform, so random augmentations are reproducible across calls.
                Defaults to None.

        Returns:
            torch.Tensor: The transformed images. Shape (f, c, h, w) when the input
            is a list of f images, otherwise (c, h, w) for a single image.
        """
        if state is not None:
            torch.set_rng_state(state)
        # Idiom fix: isinstance against the builtin `list`, not `typing.List`.
        if isinstance(images, list):
            transformed_images = [transform(img) for img in images]
            ret_tensor = torch.stack(transformed_images, dim=0)  # (f, c, h, w)
        else:
            ret_tensor = transform(images)  # (c, h, w)
        return ret_tensor

    def __getitem__(self, index):
        """
        Build one training sample from the metadata entry at ``index``.

        Returns:
            dict: Sample with keys ``video_dir``, ``pixel_values_vid``,
            ``pixel_values_mask``, ``pixel_values_face_mask``,
            ``pixel_values_lip_mask``, ``pixel_values_full_mask``,
            ``audio_tensor``, ``pixel_values_ref_img`` and ``face_emb``.
        """
        video_meta = self.vid_meta[index]
        video_path = video_meta["video_path"]
        mask_path = video_meta["mask_path"]
        lip_mask_union_path = video_meta.get("sep_mask_lip", None)
        face_mask_union_path = video_meta.get("sep_mask_face", None)
        full_mask_union_path = video_meta.get("sep_mask_border", None)
        face_emb_path = video_meta["face_emb_path"]
        audio_emb_path = video_meta[
            f"{self.audio_type}_emb_{self.audio_model}_{self.audio_features}"
        ]
        tgt_mask_pil = Image.open(mask_path)
        video_frames = VideoReader(video_path, ctx=cpu(0))
        assert tgt_mask_pil is not None, "Fail to load target mask."
        assert (video_frames is not None and len(video_frames) > 0), "Fail to load video frames."
        video_length = len(video_frames)

        # Fix: give the assert a message so a too-short clip is diagnosable.
        assert (
            video_length
            > self.n_sample_frames + self.n_motion_frames + 2 * self.audio_margin
        ), f"Video {video_path} too short ({video_length} frames) for sampling."

        # Bug fix: the lower bound must also cover `audio_margin`; otherwise
        # `start_idx - audio_margin` can go negative and silently wrap to the
        # END of `audio_emb`, pairing frames with unrelated audio.
        start_idx = random.randint(
            max(self.n_motion_frames, self.audio_margin),
            video_length - self.n_sample_frames - self.audio_margin - 1,
        )
        videos = video_frames[start_idx : start_idx + self.n_sample_frames]
        frame_list = [
            Image.fromarray(video).convert("RGB") for video in videos.asnumpy()
        ]
        # One shared PIL image per mask type, repeated for every sampled frame.
        face_masks_list = [Image.open(face_mask_union_path)] * self.n_sample_frames
        lip_masks_list = [Image.open(lip_mask_union_path)] * self.n_sample_frames
        full_masks_list = [Image.open(full_mask_union_path)] * self.n_sample_frames
        assert face_masks_list[0] is not None, "Fail to load face mask."
        assert lip_masks_list[0] is not None, "Fail to load lip mask."
        assert full_masks_list[0] is not None, "Fail to load full mask."

        # NOTE(security): torch.load unpickles arbitrary objects; embedding
        # files must come from a trusted source.
        face_emb = torch.load(face_emb_path)
        audio_emb = torch.load(audio_emb_path)

        # Offsets [-audio_margin, ..., +audio_margin] around each frame index,
        # e.g. [-2, -1, 0, 1, 2] when audio_margin == 2.
        indices = (
            torch.arange(2 * self.audio_margin + 1) - self.audio_margin
        )
        # (n_sample_frames, 2*audio_margin+1) grid of audio-frame indices.
        center_indices = torch.arange(
            start_idx,
            start_idx + self.n_sample_frames,
        ).unsqueeze(1) + indices.unsqueeze(0)
        audio_tensor = audio_emb[center_indices]

        ref_img_idx = random.randint(
            self.n_motion_frames,
            video_length - self.n_sample_frames - self.audio_margin - 1,
        )
        ref_img = video_frames[ref_img_idx].asnumpy()
        ref_img = Image.fromarray(ref_img)

        if self.n_motion_frames > 0:
            motions = video_frames[start_idx - self.n_motion_frames : start_idx]
            motion_list = [
                Image.fromarray(motion).convert("RGB") for motion in motions.asnumpy()
            ]

        # Apply all transforms under one shared RNG state so any random
        # augmentation is identical across frames and masks.
        state = torch.get_rng_state()
        pixel_values_vid = self.augmentation(frame_list, self.pixel_transform, state)
        pixel_values_mask = self.augmentation(tgt_mask_pil, self.cond_transform, state)
        pixel_values_mask = pixel_values_mask.repeat(3, 1, 1)  # 1-channel mask -> 3 channels
        pixel_values_face_mask = [
            self.augmentation(face_masks_list, self.attn_transform_64, state),
            self.augmentation(face_masks_list, self.attn_transform_32, state),
            self.augmentation(face_masks_list, self.attn_transform_16, state),
            self.augmentation(face_masks_list, self.attn_transform_8, state),
        ]
        pixel_values_lip_mask = [
            self.augmentation(lip_masks_list, self.attn_transform_64, state),
            self.augmentation(lip_masks_list, self.attn_transform_32, state),
            self.augmentation(lip_masks_list, self.attn_transform_16, state),
            self.augmentation(lip_masks_list, self.attn_transform_8, state),
        ]
        pixel_values_full_mask = [
            self.augmentation(full_masks_list, self.attn_transform_64, state),
            self.augmentation(full_masks_list, self.attn_transform_32, state),
            self.augmentation(full_masks_list, self.attn_transform_16, state),
            self.augmentation(full_masks_list, self.attn_transform_8, state),
        ]
        pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
        pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)
        if self.n_motion_frames > 0:
            pixel_values_motion = self.augmentation(
                motion_list, self.pixel_transform, state
            )
            # Reference image and its motion context are concatenated on the
            # frame axis: (1 + n_motion_frames, c, h, w).
            pixel_values_ref_img = torch.cat(
                [pixel_values_ref_img, pixel_values_motion], dim=0
            )

        sample = {
            "video_dir": video_path,
            "pixel_values_vid": pixel_values_vid,
            "pixel_values_mask": pixel_values_mask,
            "pixel_values_face_mask": pixel_values_face_mask,
            "pixel_values_lip_mask": pixel_values_lip_mask,
            "pixel_values_full_mask": pixel_values_full_mask,
            "audio_tensor": audio_tensor,
            "pixel_values_ref_img": pixel_values_ref_img,
            "face_emb": face_emb,
        }
        return sample

    def __len__(self):
        """Return the number of metadata entries in the dataset."""
        return len(self.vid_meta)