| | import os |
| | from pathlib import Path |
| | from copy import deepcopy |
| | from typing import Optional |
| | import numpy as np |
| | import PIL.Image as Image |
| | import torch |
| | import torchvision.transforms.functional as F |
| | from torch.utils.data import Dataset, default_collate |
| | import json |
| | from vhap.util.log import get_logger |
| | from vhap.config.base import DataConfig |
| |
|
| |
|
| | logger = get_logger(__name__) |
| |
|
| |
|
| | class VideoDataset(Dataset): |
| | def __init__( |
| | self, |
| | cfg: DataConfig, |
| | img_to_tensor: bool = False, |
| | batchify_all_views: bool = False, |
| | ): |
| | """ |
| | Args: |
| | root_folder: Path to dataset with the following directory layout |
| | <root_folder>/ |
| | |---images/ |
| | | |---<timestep_id>.jpg |
| | | |
| | |---alpha_maps/ |
| | | |---<timestep_id>.png |
| | | |
| | |---landmark2d/ |
| | |---face-alignment/ |
| | | |---<camera_id>.npz |
| | | |
| | |---STAR/ |
| | |---<camera_id>.npz |
| | """ |
| | super().__init__() |
| | self.cfg = cfg |
| | self.img_to_tensor = img_to_tensor |
| | self.batchify_all_views = batchify_all_views |
| | |
| | sequence_paths = self.match_sequences() |
| | if len(sequence_paths) > 1: |
| | logger.info(f"Found multiple sequences: {sequence_paths}") |
| | raise ValueError(f"Found multiple sequences by '{cfg.sequence}': \n" + "\n\t".join([str(x) for x in sequence_paths])) |
| | elif len(sequence_paths) == 0: |
| | raise ValueError(f"Cannot find sequence: {cfg.sequence}") |
| | self.sequence_path = sequence_paths[0] |
| | logger.info(f"Initializing dataset from {self.sequence_path}") |
| |
|
| | self.define_properties() |
| | self.load_camera_params() |
| |
|
| | |
| | self.timestep_ids = set( |
| | f.split('.')[0].split('_')[-1] |
| | for f in os.listdir(self.sequence_path / self.properties['rgb']['folder']) if f.endswith(self.properties['rgb']['suffix']) |
| | ) |
| | self.timestep_ids = sorted(self.timestep_ids) |
| | self.timestep_indices = list(range(len(self.timestep_ids))) |
| |
|
| | self.filter_division(cfg.division) |
| | self.filter_subset(cfg.subset) |
| | |
| | logger.info(f"number of timesteps: {self.num_timesteps}, number of cameras: {self.num_cameras}") |
| |
|
| | |
| | self.items = [] |
| | for fi, timestep_index in enumerate(self.timestep_indices): |
| | for ci, camera_id in enumerate(self.camera_ids): |
| | self.items.append( |
| | { |
| | "timestep_index": fi, |
| | "timestep_index_original": timestep_index, |
| | "timestep_id": self.timestep_ids[timestep_index], |
| | "camera_index": ci, |
| | "camera_id": camera_id, |
| | } |
| | ) |
| | |
| | def match_sequences(self): |
| | logger.info(f"Looking for sequence '{self.cfg.sequence}' at {self.cfg.root_folder}") |
| | return list(filter(lambda x: x.is_dir(), self.cfg.root_folder.glob(f"{self.cfg.sequence}*"))) |
| | |
| | def define_properties(self): |
| | self.properties = { |
| | "rgb": { |
| | "folder": f"images_{self.cfg.n_downsample_rgb}" |
| | if self.cfg.n_downsample_rgb |
| | else "images", |
| | "per_timestep": True, |
| | |
| | "suffix": "png", |
| | }, |
| | "alpha_map": { |
| | "folder": "alpha_maps", |
| | "per_timestep": True, |
| | "suffix": "jpg", |
| | }, |
| | "landmark2d/face-alignment": { |
| | "folder": "landmark2d/face-alignment", |
| | "per_timestep": False, |
| | "suffix": "npz", |
| | }, |
| | "landmark2d/STAR": { |
| | "folder": "landmark2d/STAR", |
| | "per_timestep": False, |
| | "suffix": "npz", |
| | }, |
| | "landmark2d/lms": { |
| | "folder": "landmark2d/landmarks", |
| | "per_timestep": False, |
| | "suffix": "npz", |
| | }, |
| | } |
| |
|
| | @staticmethod |
| | def get_number_after_prefix(string, prefix): |
| | i = string.find(prefix) |
| | if i != -1: |
| | number_begin = i + len(prefix) |
| | assert number_begin < len(string), f"No number found behind prefix '{prefix}'" |
| | assert string[number_begin].isdigit(), f"No number found behind prefix '{prefix}'" |
| |
|
| | non_digit_indices = [i for i, c in enumerate(string[number_begin:]) if not c.isdigit()] |
| | if len(non_digit_indices) > 0: |
| | number_end = number_begin + min(non_digit_indices) |
| | return int(string[number_begin:number_end]) |
| | else: |
| | return int(string[number_begin:]) |
| | else: |
| | return None |
| | |
| | def filter_division(self, division): |
| | pass |
| |
|
| | def filter_subset(self, subset): |
| | if subset is not None: |
| | if 'ti' in subset: |
| | ti = self.get_number_after_prefix(subset, 'ti') |
| | if 'tj' in subset: |
| | tj = self.get_number_after_prefix(subset, 'tj') |
| | self.timestep_indices = self.timestep_indices[ti:tj+1] |
| | else: |
| | self.timestep_indices = self.timestep_indices[ti:ti+1] |
| | elif 'tn' in subset: |
| | tn = self.get_number_after_prefix(subset, 'tn') |
| | tn_all = len(self.timestep_indices) |
| | tn = min(tn, tn_all) |
| | self.timestep_indices = self.timestep_indices[::tn_all // tn][:tn] |
| | elif 'ts' in subset: |
| | ts = self.get_number_after_prefix(subset, 'ts') |
| | self.timestep_indices = self.timestep_indices[::ts] |
| | if 'ci' in subset: |
| | ci = self.get_number_after_prefix(subset, 'ci') |
| | self.camera_ids = self.camera_ids[ci:ci+1] |
| | elif 'cn' in subset: |
| | cn = self.get_number_after_prefix(subset, 'cn') |
| | cn_all = len(self.camera_ids) |
| | cn = min(cn, cn_all) |
| | self.camera_ids = self.camera_ids[::cn_all // cn][:cn] |
| | elif 'cs' in subset: |
| | cs = self.get_number_after_prefix(subset, 'cs') |
| | self.camera_ids = self.camera_ids[::cs] |
| | |
| | def load_camera_params(self): |
| | self.camera_ids = ['0'] |
| |
|
| | |
| | f, h, w = 512, 512, 512 |
| | K = torch.Tensor([ |
| | [f, 0, w], |
| | [0, f, h], |
| | [0, 0, 1] |
| | ]) |
| |
|
| | orientation = torch.eye(3)[None, ...] |
| | location = torch.Tensor([0, 0, 1])[None, ..., None] |
| |
|
| | c2w = torch.cat([orientation, location], dim=-1) |
| |
|
| | if self.cfg.target_extrinsic_type == "w2c": |
| | R = orientation.transpose(-1, -2) |
| | T = orientation.transpose(-1, -2) @ -location |
| | w2c = torch.cat([R, T], dim=-1) |
| | extrinsic = w2c |
| | elif self.cfg.target_extrinsic_type == "c2w": |
| | extrinsic = c2w |
| | else: |
| | raise NotImplementedError(f"Unknown extrinsic type: {self.cfg.target_extrinsic_type}") |
| |
|
| | self.camera_params = {} |
| | for i, camera_id in enumerate(self.camera_ids): |
| | self.camera_params[camera_id] = {"intrinsic": K, "extrinsic": extrinsic[i]} |
| |
|
| | return self.camera_params |
| |
|
| | def __len__(self): |
| | if self.batchify_all_views: |
| | return self.num_timesteps |
| | else: |
| | return len(self.items) |
| |
|
| | def __getitem__(self, i): |
| | if self.batchify_all_views: |
| | return self.getitem_by_timestep(i) |
| | else: |
| | return self.getitem_single_image(i) |
| |
|
| | def getitem_single_image(self, i): |
| | item = deepcopy(self.items[i]) |
| |
|
| | rgb_path = self.get_property_path("rgb", i) |
| | item["rgb"] = np.array(Image.open(rgb_path))[:, :, :3] |
| |
|
| | camera_param = self.camera_params[item["camera_id"]] |
| | item["intrinsic"] = camera_param["intrinsic"].clone() |
| | item["extrinsic"] = camera_param["extrinsic"].clone() |
| |
|
| | if self.cfg.use_alpha_map or self.cfg.background_color is not None: |
| | alpha_path = self.get_property_path("alpha_map", i) |
| | item["alpha_map"] = np.array(Image.open(alpha_path)) |
| |
|
| | if self.cfg.use_landmark: |
| | timestep_index = self.items[i]["timestep_index"] |
| |
|
| | landmark_path = self.get_property_path("landmark2d/lms", i) |
| | landmark_npz = np.load(landmark_path) |
| |
|
| | lms_eyes_path = os.path.join(os.path.dirname(landmark_path),'iris.json') |
| |
|
| | item["lmk2d"] = landmark_npz["face_landmark_2d"][timestep_index] |
| | if (item["lmk2d"][:, :2] == -1).sum() > 0: |
| | item["lmk2d"][:, 2:] = 0.0 |
| | else: |
| | item["lmk2d"][:, 2:] = 1.0 |
| |
|
| | if(os.path.exists(lms_eyes_path)): |
| | with open(lms_eyes_path,'r') as f: |
| | lms_eye = json.load(f) |
| | lms_eye = np.array([lms_eye[key] for key in lms_eye][timestep_index]).reshape((2,2)) / 1024. |
| | lms_eye = np.concatenate([lms_eye,np.ones((2,1))],axis=1)[(1,0),:] |
| | item["lmk2d"] = np.concatenate([item["lmk2d"], lms_eye], 0) |
| | else: |
| | item["lmk2d"] = np.concatenate([item["lmk2d"]], 0) |
| |
|
| | item = self.apply_transforms(item) |
| | return item |
| |
|
| | def getitem_by_timestep(self, timestep_index): |
| | begin = timestep_index * self.num_cameras |
| | indices = range(begin, begin + self.num_cameras) |
| | item = default_collate([self.getitem_single_image(i) for i in indices]) |
| |
|
| | item["num_cameras"] = self.num_cameras |
| | return item |
| |
|
| | def apply_transforms(self, item): |
| | item = self.apply_scale_factor(item) |
| | item = self.apply_background_color(item) |
| | item = self.apply_to_tensor(item) |
| | return item |
| |
|
| | def apply_to_tensor(self, item): |
| | if self.img_to_tensor: |
| | if "rgb" in item: |
| | item["rgb"] = F.to_tensor(item["rgb"]) |
| |
|
| | if "alpha_map" in item: |
| | item["alpha_map"] = F.to_tensor(item["alpha_map"]) |
| | return item |
| |
|
| | def apply_scale_factor(self, item): |
| | assert self.cfg.scale_factor <= 1.0 |
| |
|
| | if "rgb" in item: |
| | H, W, _ = item["rgb"].shape |
| | h, w = int(H * self.cfg.scale_factor), int(W * self.cfg.scale_factor) |
| | rgb = Image.fromarray(item["rgb"]).resize( |
| | (w, h), resample=Image.BILINEAR |
| | ) |
| | item["rgb"] = np.array(rgb) |
| | |
| | |
| | if "lmk2d" in item: |
| | item["lmk2d"][..., 0] *= w |
| | item["lmk2d"][..., 1] *= h |
| | |
| | if "lmk2d_iris" in item: |
| | item["lmk2d_iris"][..., 0] *= w |
| | item["lmk2d_iris"][..., 1] *= h |
| |
|
| | if "bbox_2d" in item: |
| | item["bbox_2d"][[0, 2]] *= w |
| | item["bbox_2d"][[1, 3]] *= h |
| |
|
| | |
| | n_downsample_rgb = self.cfg.n_downsample_rgb if self.cfg.n_downsample_rgb else 1 |
| | scale_factor = self.cfg.scale_factor / n_downsample_rgb |
| | item["scale_factor"] = scale_factor |
| | if scale_factor < 1.0: |
| | if "intrinsic" in item: |
| | item["intrinsic"][:2] *= scale_factor |
| | if "alpha_map" in item: |
| | h, w = item["rgb"].shape[:2] |
| | alpha_map = Image.fromarray(item["alpha_map"]).resize( |
| | (w, h), Image.Resampling.BILINEAR |
| | ) |
| | item["alpha_map"] = np.array(alpha_map) |
| | return item |
| |
|
| | def apply_background_color(self, item): |
| | if self.cfg.background_color is not None: |
| | assert ( |
| | "alpha_map" in item |
| | ), "'alpha_map' is required to apply background color." |
| | fg = item["rgb"] |
| | if self.cfg.background_color == "white": |
| | bg = np.ones_like(fg) * 255 |
| | elif self.cfg.background_color == "black": |
| | bg = np.zeros_like(fg) |
| | else: |
| | raise NotImplementedError( |
| | f"Unknown background color: {self.cfg.background_color}." |
| | ) |
| |
|
| | |
| | w = item["alpha_map"] / 255 |
| | img = (w * fg + (1 - w) * bg).astype(np.uint8) |
| | item["rgb"] = img |
| | return item |
| |
|
| | def get_property_path( |
| | self, |
| | name, |
| | index: Optional[int] = None, |
| | timestep_id: Optional[str] = None, |
| | camera_id: Optional[str] = None, |
| | ): |
| | p = self.properties[name] |
| | folder = p["folder"] if "folder" in p else None |
| | per_timestep = p["per_timestep"] |
| | suffix = p["suffix"] |
| |
|
| | path = self.sequence_path |
| | if folder is not None: |
| | path = path / folder |
| |
|
| | if self.num_cameras > 1: |
| | if camera_id is None: |
| | assert ( |
| | index is not None), "index is required when camera_id is not provided." |
| | camera_id = self.items[index]["camera_id"] |
| | if "cam_id_prefix" in p: |
| | camera_id = p["cam_id_prefix"] + camera_id |
| | else: |
| | camera_id = "" |
| |
|
| | if per_timestep: |
| | if timestep_id is None: |
| | assert index is not None, "index is required when timestep_id is not provided." |
| | timestep_id = self.items[index]["timestep_id"] |
| | if len(camera_id) > 0: |
| | path /= f"{camera_id}_{timestep_id}.{suffix}" |
| | else: |
| | path /= f"{timestep_id}.{suffix}" |
| | else: |
| | if len(camera_id) > 0: |
| | path /= f"{camera_id}.{suffix}" |
| | else: |
| | path = Path(str(path) + f".{suffix}") |
| |
|
| | return path |
| | |
| | def get_property_path_list(self, name): |
| | paths = [] |
| | for i in range(len(self.items)): |
| | img_path = self.get_property_path(name, i) |
| | paths.append(img_path) |
| | return paths |
| |
|
| | @property |
| | def num_timesteps(self): |
| | return len(self.timestep_indices) |
| |
|
| | @property |
| | def num_cameras(self): |
| | return len(self.camera_ids) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | import tyro |
| | from tqdm import tqdm |
| | from torch.utils.data import DataLoader |
| | from vhap.config.base import DataConfig, import_module |
| |
|
| | cfg = tyro.cli(DataConfig) |
| | cfg.use_landmark = False |
| | dataset = import_module(cfg._target)( |
| | cfg=cfg, |
| | img_to_tensor=False, |
| | batchify_all_views=True, |
| | ) |
| |
|
| | print(len(dataset)) |
| |
|
| | sample = dataset[0] |
| | print(sample.keys()) |
| | print(sample["rgb"].shape) |
| |
|
| | dataloader = DataLoader(dataset, batch_size=None, shuffle=False, num_workers=1) |
| | for item in tqdm(dataloader): |
| | pass |
| |
|