| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import matplotlib.pyplot as plt |
| import torch |
| from torchvision.utils import draw_bounding_boxes, draw_keypoints |
|
|
|
|
# Landmark index pairs (edges) to draw between 2D face landmarks.
#
# Each row below describes one chain of consecutive landmark indices
# `first..last`. Every neighbouring pair (i, i + 1) becomes an edge. When
# `closed` is True, an extra (first, last) edge closes the chain into a loop.
# NOTE(review): the index ranges look like the common 68-point face layout
# (jaw, brows, nose bridge/base, eyes, outer/inner lips); confirm against
# the landmark detector that produces `lmk2d`.
connectivity_face = [
    edge
    for first, last, closed in (
        (0, 16, False),
        (17, 21, False),
        (22, 26, False),
        (27, 30, False),
        (31, 35, False),
        (36, 41, True),
        (42, 47, True),
        (48, 59, True),
        (60, 67, True),
    )
    for edge in [(i, i + 1) for i in range(first, last)]
    + ([(first, last)] if closed else [])
]
|
|
|
|
def plot_landmarks_2d(
    img: torch.Tensor,
    lmks: torch.Tensor,
    connectivity=None,
    colors="white",
    unit=1,
    input_float=False,
):
    """Draw 2D landmarks, and optionally edges between them, onto an image.

    Args:
        img: Image in the layout ``torchvision.utils.draw_keypoints`` expects
            (channels-first). By default it is uint8 in [0, 255]; when
            ``input_float`` is True it is a float image in [0, 1].
        lmks: Landmark coordinates of shape ``(num_instances, K, 2)``. A
            single instance of shape ``(K, 2)`` is also accepted; the
            instance dimension is added automatically.
        connectivity: Optional sequence of ``(i, j)`` landmark index pairs to
            join with line segments.
        colors: Colour forwarded to ``draw_keypoints``.
        unit: Size multiplier; keypoint radius and line width are both
            ``2 * unit``.
        input_float: If True, convert ``img`` from float [0, 1] to uint8
            before drawing and back to float32 [0, 1] afterwards.

    Returns:
        The annotated image. It is uint8, or float32 scaled to [0, 1] when
        ``input_float`` is True.
    """
    if input_float:
        # Round instead of truncating, and clamp so values slightly outside
        # [0, 1] saturate. An unclamped out-of-range float->uint8 cast is not
        # well defined.
        img = img.mul(255).round().clamp(0, 255).byte()

    if lmks.dim() == 2:
        # draw_keypoints requires a leading instance dimension.
        lmks = lmks[None]

    img = draw_keypoints(
        img,
        lmks,
        connectivity=connectivity,
        colors=colors,
        radius=2 * unit,
        width=2 * unit,
    )

    if input_float:
        img = img.float() / 255
    return img
|
|
|
|
def blend(a, b, w):
    """Linearly mix two images as ``w * a + (1 - w) * b`` and return uint8.

    The mix is computed in floating point, rounded to the nearest integer
    (truncation would bias every pixel downwards), then clamped to [0, 255].
    The clamp keeps weights outside [0, 1] from producing an undefined
    float->uint8 cast.

    Args:
        a: Image tensor weighted by ``w``.
        b: Image tensor with the same shape as ``a`` (or broadcastable to
            it), weighted by ``1 - w``.
        w: Blend weight, a Python number or a broadcastable tensor.

    Returns:
        A uint8 tensor holding the blended image.
    """
    mixed = a.float() * w + b.float() * (1 - w)
    return mixed.round().clamp(0, 255).to(torch.uint8)
|
|
|
|
if __name__ == "__main__":
    # Visual check of a NeRSemble sequence. Each frame is shown with its 2D
    # bounding box, face landmarks and iris landmarks (whichever the dataset
    # item provides) blended on top.
    # Controls: a key press advances to the next frame and mouse clicks are
    # ignored. Closing the window stops the loop (when the backend returns
    # control on close).
    from argparse import ArgumentParser

    from torch.utils.data import DataLoader

    from vhap.data.nersemble_dataset import NeRSembleDataset

    parser = ArgumentParser()
    parser.add_argument("--root_folder", type=str, required=True)
    parser.add_argument("--subject", type=str, required=True)
    parser.add_argument("--sequence", type=str, required=True)
    parser.add_argument("--division", default=None)
    parser.add_argument("--subset", default=None)
    parser.add_argument("--scale_factor", type=float, default=1.0)
    parser.add_argument("--blend_weight", type=float, default=0.6)
    args = parser.parse_args()

    dataset = NeRSembleDataset(
        root_folder=args.root_folder,
        subject=args.subject,
        sequence=args.sequence,
        division=args.division,
        subset=args.subset,
        n_downsample_rgb=2,
        scale_factor=args.scale_factor,
        use_landmark=True,
    )
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

    # Reuse one figure for every frame. The axes are cleared before each
    # imshow; otherwise image artists pile up frame after frame, growing
    # memory and slowing redraws.
    fig, ax = plt.subplots()

    for item in dataloader:
        # Drawing sizes grow with the dataset's scale factor.
        unit = int(item["scale_factor"][0] * 3) + 1

        # The torchvision draw_* helpers take channels-first images.
        # NOTE(review): assumes item["rgb"] is (batch, H, W, C) uint8 — confirm.
        vis = item["rgb"][0].permute(2, 0, 1)

        if "bbox_2d" in item:
            # Only the first four values are used as the box coordinates.
            bbox = item["bbox_2d"][0][:4]
            tmp = draw_bounding_boxes(vis, bbox[None, ...], width=5 * unit)
            vis = blend(tmp, vis, args.blend_weight)

        if "lmk2d" in item:
            # Keep only the x, y columns.
            face_landmark = item["lmk2d"][0][:, :2]
            tmp = plot_landmarks_2d(
                vis,
                face_landmark[None, ...],
                connectivity=connectivity_face,
                colors="white",
                unit=unit,
            )
            vis = blend(tmp, vis, args.blend_weight)

        if "lmk2d_iris" in item:
            iris_landmark = item["lmk2d_iris"][0][:, :2]
            tmp = plot_landmarks_2d(
                vis,
                iris_landmark[None, ...],
                colors="blue",
                unit=unit,
            )
            vis = blend(tmp, vis, args.blend_weight)

        ax.clear()
        ax.imshow(vis.permute(1, 2, 0).numpy())
        fig.canvas.draw_idle()

        # waitforbuttonpress returns True for a key press, False for a mouse
        # click and None when no event arrived. Only a key press advances.
        while not fig.waitforbuttonpress(timeout=-1):
            if not plt.fignum_exists(fig.number):
                break
        if not plt.fignum_exists(fig.number):
            # Window was closed: stop instead of drawing into a dead figure.
            break
|
|