File size: 3,426 Bytes
434b0b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Multi-HMR
# Copyright (c) 2024-present NAVER Corp.
# CC BY-NC-SA 4.0 license

import torch
import numpy as np
from PIL import Image, ImageOps
import torch.nn.functional as F
import cv2
import time

# Per-channel mean and standard deviation (three values each) applied by
# normalize_rgb_tensor / normalize_rgb / denormalize_rgb when their
# ImageNet-normalization flag is enabled. Values are expressed on the 0-1
# scale, i.e. they are applied after the pixel values were divided by 255.
IMG_NORM_MEAN = [0.485, 0.456, 0.406]
IMG_NORM_STD = [0.229, 0.224, 0.225]


def normalize_rgb_tensor(img, imgenet_normalization=True):
    """
    Rescale a 0/255 tensor to 0/1 and optionally standardize it per channel.

    Args:
        - img: torch.Tensor - must broadcast against (1,3,1,1), e.g. (B,3,H,W)
        - imgenet_normalization: bool - when True, subtract IMG_NORM_MEAN and
          divide by IMG_NORM_STD channel-wise (parameter name kept as-is for
          backward compatibility with keyword callers)
    Return:
        - torch.Tensor with the same shape as the broadcast result
    """
    scaled = img / 255.0
    if not imgenet_normalization:
        return scaled

    # Build the statistics on the input's device so no cross-device op occurs.
    mean = torch.tensor(IMG_NORM_MEAN, device=scaled.device).view(1, 3, 1, 1)
    std = torch.tensor(IMG_NORM_STD, device=scaled.device).view(1, 3, 1, 1)
    return (scaled - mean) / std


def normalize_rgb(img, imagenet_normalization=True):
    """
    Convert a channel-last 0/255 image into a channel-first float32 array.

    Args:
        - img: np.array - (H,W,3) - np.uint8 - 0/255
        - imagenet_normalization: bool - when True, subtract IMG_NORM_MEAN and
          divide by IMG_NORM_STD per channel after scaling to 0/1
    Return:
        - img: np.array - (3,H,W) - np.float32 - roughly -3/3 when normalized,
          0/1 otherwise
    """
    # Scale to 0/1 first, then move the channel axis to the front.
    chw = np.transpose(img.astype(np.float32) / 255.0, (2, 0, 1))
    if imagenet_normalization:
        mean = np.asarray(IMG_NORM_MEAN).reshape(3, 1, 1)
        std = np.asarray(IMG_NORM_STD).reshape(3, 1, 1)
        # mean/std are float64, so this promotes; cast back down below.
        chw = (chw - mean) / std
    return chw.astype(np.float32)


def denormalize_rgb(img, imagenet_normalization=True):
    """
    Invert normalize_rgb: turn a channel-first float image back into uint8.

    Args:
        - img: np.array - (3,H,W) - float - roughly -3/3 when it was
          ImageNet-normalized, 0/1 otherwise
        - imagenet_normalization: bool - when True, multiply by IMG_NORM_STD
          and add IMG_NORM_MEAN per channel before rescaling to 0/255
    Return:
        - img: np.array - (H,W,3) - np.uint8 - 0/255
    """
    if imagenet_normalization:
        img = (img * np.asarray(IMG_NORM_STD).reshape(3, 1, 1)) + np.asarray(
            IMG_NORM_MEAN
        ).reshape(3, 1, 1)
    img = np.transpose(img, (1, 2, 0)) * 255.0
    # A bare astype(np.uint8) truncates (99.99999 -> 99) and wraps values that
    # fall outside 0..255. Round to nearest and saturate so that a
    # normalize -> denormalize round trip reproduces the original pixels.
    img = np.clip(np.rint(img), 0, 255).astype(np.uint8)
    return img


def unpatch(data, patch_size=14, c=3, img_size=224):
    """
    Reassemble a sequence of flattened square patches into an image tensor.

    Args:
        - data: torch.Tensor - (B,N,patch_size**2 * C) where each patch is
          laid out as (row, col, channel), or (B,N) holding one scalar per
          patch (the scalar is broadcast over the whole patch, giving C=1).
          N must be a perfect square (the patch grid is sqrt(N) x sqrt(N)).
        - patch_size: int - side length of one square patch, in pixels
        - c: int - unused; the channel count is always derived from the last
          dimension of `data`. Kept for backward compatibility.
        - img_size: int - unused; the output side is derived as
          sqrt(N) * patch_size, which is the only value for which the
          reshape can succeed. Kept for backward compatibility.
    Return:
        - torch.Tensor - (B,C,sqrt(N)*patch_size,sqrt(N)*patch_size)
    """
    if len(data.shape) == 2:
        # One value per patch: replicate it across every pixel of the patch.
        data = data[:, :, None].repeat(1, 1, patch_size**2)

    B, N, HWC = data.shape
    HW = patch_size**2
    c = HWC // HW  # channels per pixel
    h = w = int(N**0.5)  # patch grid is assumed square
    p = q = patch_size
    data = data.reshape(B, h, w, p, q, c)
    # (B, grid_row, grid_col, px_row, px_col, C) -> (B, C, grid_row, px_row,
    # grid_col, px_col) so that merging adjacent axes yields image rows/cols.
    data = torch.einsum("nhwpqc->nchpwq", data)
    return data.reshape(B, c, h * p, w * q)


def image_pad(img, img_size, device=torch.device("cuda")):
    """
    Fit a PIL image into an img_size x img_size square and build the
    normalized input tensor from it.

    Args:
        - img: PIL.Image - presumably RGB, since the result goes through
          normalize_rgb (TODO confirm against callers)
        - img_size: int - side length of the square output
        - device: torch.device - where the returned tensor is placed
    Return:
        - x: torch.Tensor - (1,3,img_size,img_size) - normalized image,
          padded with the default fill (zeros) on the short side
        - img_pil_bis: PIL.Image - the same fitted image padded with white
    """
    target = (img_size, img_size)

    # Shrink/grow to fit inside the square while keeping the aspect ratio.
    fitted = ImageOps.contain(img, target)

    # White-padded version, returned as-is alongside the tensor.
    img_pil_bis = ImageOps.pad(fitted.copy(), size=target, color=(255, 255, 255))

    # Default-colour (zero) padded version, used to build the tensor.
    zero_padded = ImageOps.pad(fitted, size=target)

    # PIL -> numpy -> normalized channel-first array -> batched torch tensor.
    normalized = normalize_rgb(np.asarray(zero_padded))
    x = torch.from_numpy(normalized).unsqueeze(0).to(device)
    return x, img_pil_bis


def image_pad_cuda(img, img_size, rot=0, device=torch.device("cuda"), vis=False):
    """
    Resize and pad a channel-last image to img_size x img_size using torch ops
    on `device`, then normalize it.

    Args:
        - img: array-like - (H,W,C) - 0/255. The channel axis is reversed
          before use; presumably the input is BGR and this converts it to
          RGB (TODO confirm against the capture code).
        - img_size: int - side length of the square output
        - rot: int - number of 90-degree rotations applied in the image plane
          (passed to torch.rot90); 0 disables rotation
        - device: torch.device - where all tensor work happens
        - vis: bool - when True, show the (rotated) frame in an OpenCV window
    Return:
        - resize_img: torch.Tensor - (1,C,img_size,img_size) - output of
          normalize_rgb_tensor on the padded image
        - img: torch.Tensor - (1,C,img_size,img_size) - padded image, 0/255,
          padded with the value 255
    """
    img = torch.Tensor(img).to(device)
    # Reverse the channel axis, add a batch axis, and go channel-first.
    img = torch.flip(img, dims=[2]).unsqueeze(0).permute(0, 3, 1, 2)
    if rot != 0:
        img = torch.rot90(img, rot, [2, 3])

    if vis:
        image = img.clone()[0].permute(1, 2, 0).cpu().numpy()
        if image.dtype != np.uint8:
            image = image.astype(np.uint8)
        # Reverse the channel axis again for display.
        cv2.imshow("k4a", image[..., ::-1])
        cv2.waitKey(1)

    # Uniform scale so that the longer side fits img_size.
    _, _, h, w = img.shape
    scale_factor = min(img_size / w, img_size / h)
    img = F.interpolate(img, scale_factor=scale_factor, mode="bilinear")

    _, _, h, w = img.shape

    # Split the missing pixels between both sides. The right/bottom pad takes
    # the remainder so the result is exactly img_size even when the gap is
    # odd (or when the scaled long side was floored to img_size - 1).
    pad_left = (img_size - w) // 2
    pad_right = img_size - w - pad_left
    pad_top = (img_size - h) // 2
    pad_bottom = img_size - h - pad_top

    img = F.pad(
        img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=255
    )

    # Normalize the padded 0/255 tensor.
    resize_img = normalize_rgb_tensor(img)
    return resize_img, img