camenduru's picture
thanks to show ❤
3bbb319
"""
This file contains functions that are used to perform data augmentation.
"""
from numpy.testing._private.utils import print_assert_equal
import torch
import numpy as np
import cv2
import skimage.transform
from PIL import Image
from pymaf_core import constants
def get_transform(center, scale, res, rot=0):
    """Build the 3x3 homogeneous matrix mapping image pixels into the crop.

    The crop is a square box of side ``200 * scale`` pixels centred on
    ``center``. It is scaled so that it fills an output of size ``res``:
    ``res[1]`` sets the horizontal scale and ``res[0]`` the vertical one.

    Args:
        center: (x, y) centre of the box in the original image.
        scale: box side length in units of 200 pixels.
        res: output resolution; ``res[0]`` pairs with y, ``res[1]`` with x.
        rot: rotation in degrees. When non-zero, the matrix returned by
            ``get_rot_transf(res, rot)`` is applied after the scaling.

    Returns:
        A (3, 3) float ndarray.
    """
    box = 200 * scale
    cx = float(center[0])
    cy = float(center[1])
    mat = np.zeros((3, 3))
    mat[2, 2] = 1
    # Diagonal: uniform scaling from box pixels to output pixels.
    mat[0, 0] = float(res[1]) / box
    mat[1, 1] = float(res[0]) / box
    # Last column: shift so the box centre lands in the middle of the output.
    mat[0, 2] = res[1] * (-cx / box + .5)
    mat[1, 2] = res[0] * (-cy / box + .5)
    if rot != 0:
        mat = np.dot(get_rot_transf(res, rot), mat)
    return mat
def get_rot_transf(res, rot):
    """Return a 3x3 homogeneous rotation about the centre of a ``res`` image.

    The matrix translates the point ``(res[1] / 2, res[0] / 2)`` to the
    origin, rotates by ``-rot`` degrees, and translates back.

    Args:
        res: image resolution; ``res[1]`` is used for x and ``res[0]`` for y.
        rot: rotation in degrees. 0 yields the identity.

    Returns:
        A (3, 3) float ndarray.
    """
    if rot == 0:
        return np.identity(3)
    # The sign is flipped to match the rotation direction used when cropping.
    angle = -rot * np.pi / 180
    sin_a, cos_a = np.sin(angle), np.cos(angle)
    rotation = np.array([[cos_a, -sin_a, 0.0],
                         [sin_a, cos_a, 0.0],
                         [0.0, 0.0, 1.0]])
    half_w = res[1] / 2
    half_h = res[0] / 2
    to_origin = np.array([[1.0, 0.0, -half_w],
                          [0.0, 1.0, -half_h],
                          [0.0, 0.0, 1.0]])
    from_origin = np.array([[1.0, 0.0, half_w],
                            [0.0, 1.0, half_h],
                            [0.0, 0.0, 1.0]])
    return np.dot(from_origin, np.dot(rotation, to_origin))
def transform(pt, center, scale, res, invert=0, rot=0):
    """Map a pixel location between the original image and the crop frame.

    Args:
        pt: (x, y) location. The -1 / +1 shifts below treat it as a 1-based
            pixel index.
        center, scale, res, rot: crop parameters forwarded to ``get_transform``.
        invert: if truthy, use the inverse matrix, i.e. map from the crop
            frame back to the original image.

    Returns:
        Integer ndarray of shape (2,). The mapped coordinates are truncated
        toward zero by ``astype(int)`` before the +1 shift.
    """
    mat = get_transform(center, scale, res, rot=rot)
    if invert:
        mat = np.linalg.inv(mat)
    homogeneous = np.array([pt[0] - 1, pt[1] - 1, 1.])
    mapped = np.dot(mat, homogeneous)
    return mapped[:2].astype(int) + 1
def transform_pts(coords, center, scale, res, invert=0, rot=0):
    """Apply ``transform`` to the first two columns of every row of ``coords``.

    Only columns 0 and 1 are replaced; any further columns are copied
    unchanged. The result is written into a copy, so ``coords`` itself is
    not modified.
    """
    result = coords.copy()
    for row, point in enumerate(coords):
        result[row, 0:2] = transform(point[0:2], center, scale, res, invert, rot)
    return result
def crop(img, center, scale, res, rot=0):
    """Crop image according to the supplied bounding box.

    The box is described by ``center`` / ``scale`` (see ``get_transform``).
    Parts of the box that fall outside ``img`` are zero-filled. With a
    non-zero ``rot`` the box is padded, rotated, and the padding is removed
    again before resizing.

    NOTE(review): PIL reads ``res`` as (width, height), while
    ``get_transform`` pairs ``res[0]`` with y. The two agree only for a
    square ``res``.

    Args:
        img: image array of shape (H, W) or (H, W, C).
        center: (x, y) centre of the box in ``img`` pixels.
        scale: box side length in units of 200 pixels.
        res: output size, passed to ``PIL.Image.resize``.
        rot: rotation in degrees applied to the padded crop.

    Returns:
        Tuple ``(resized, cropped, new_shape)``:
        - ``resized``: the crop resized to ``res``, as uint8.
        - ``cropped``: the crop before resizing (float64 when ``rot == 0``,
          uint8 after rotation).
        - ``new_shape``: the shape of the crop buffer, including rotation
          padding.
    """
    # Upper-left and bottom-right corners of the box in original-image pixels.
    ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
    br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1

    # Padding so that when rotated the proper amount of context is included.
    pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
    if rot != 0:
        ul -= pad
        br += pad

    new_shape = [br[1] - ul[1], br[0] - ul[0]]
    if len(img.shape) > 2:
        new_shape += [img.shape[2]]
    new_img = np.zeros(new_shape)

    # Region of the crop buffer that overlaps the source image ...
    new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
    new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
    # ... and the matching region of the source image.
    old_x = max(0, ul[0]), min(len(img[0]), br[0])
    old_y = max(0, ul[1]), min(len(img), br[1])
    new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
                                                        old_x[0]:old_x[1]]

    if rot != 0:
        new_img = skimage.transform.rotate(new_img, rot).astype(np.uint8)
        # Remove the rotation padding. When pad == 0, ``a[0:-0]`` would be an
        # empty slice, so only trim when there is padding to remove.
        if pad > 0:
            new_img = new_img[pad:-pad, pad:-pad]

    new_img_resized = np.array(Image.fromarray(new_img.astype(np.uint8)).resize(res))
    return new_img_resized, new_img, new_shape
def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
    """'Undo' the image cropping/resizing.

    Resizes ``img`` back to the size of the box described by
    ``center`` / ``scale`` and pastes it into a zero image of shape
    ``orig_shape``. Parts of the box outside the original image are dropped.
    This function is used when evaluating mask/part segmentation.

    Args:
        img: cropped image of shape (h, w) or (h, w, C). Its first two
            dimensions are used as the crop resolution.
        center: (x, y) centre of the box in original-image pixels.
        scale: box side length in units of 200 pixels.
        orig_shape: shape of the returned image.
        rot: accepted for symmetry with ``crop`` but currently ignored; no
            rotation is undone.
        is_rgb: currently unused.

    Returns:
        uint8 ndarray of shape ``orig_shape``.
    """
    res = img.shape[:2]
    # Upper-left and bottom-right corners of the box in original-image pixels.
    ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
    br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1
    # Size of the box in original-image pixels.
    crop_h = int(br[1] - ul[1])
    crop_w = int(br[0] - ul[0])

    new_img = np.zeros(orig_shape, dtype=np.uint8)
    # Part of the box that lies inside the original image, in box coordinates ...
    crop_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
    crop_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
    # ... and the same region in original-image coordinates.
    orig_x = max(0, ul[0]), min(orig_shape[1], br[0])
    orig_y = max(0, ul[1]), min(orig_shape[0], br[1])

    # PIL's resize takes (width, height). Passing (height, width) transposed
    # the target size whenever the box is not exactly square.
    img = np.array(Image.fromarray(img.astype(np.uint8)).resize((crop_w, crop_h)))
    new_img[orig_y[0]:orig_y[1], orig_x[0]:orig_x[1]] = img[crop_y[0]:crop_y[1],
                                                            crop_x[0]:crop_x[1]]
    return new_img
def rot_aa(aa, rot):
    """Rotate axis-angle parameters by an in-plane rotation of ``rot`` degrees.

    The rotation matrix acts in the x-y plane and leaves z fixed, using the
    angle ``-rot``. It is left-multiplied onto the rotation encoded by ``aa``.

    Args:
        aa: axis-angle vector in a form accepted by ``cv2.Rodrigues``.
        rot: rotation angle in degrees.

    Returns:
        The rotated orientation as a length-3 axis-angle ndarray.
    """
    theta = np.deg2rad(-rot)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rot_z = np.array([[cos_t, -sin_t, 0],
                      [sin_t, cos_t, 0],
                      [0, 0, 1]])
    # Axis-angle -> rotation matrix, apply the extra rotation, then convert back.
    body_mat, _ = cv2.Rodrigues(aa)
    rotated, _ = cv2.Rodrigues(np.dot(rot_z, body_mat))
    return rotated.T[0]
def flip_img(img):
    """Mirror an RGB image or a mask left-to-right.

    Axis 1 is reversed, so channel-last arrays such as (256, 256, 3) and
    plain (H, W) masks are both handled.
    """
    mirrored = np.fliplr(img)
    return mirrored
def flip_kp(kp, is_smpl=False, type='body'):
    """Flip keypoints horizontally.

    The rows are reordered with the left/right permutation from ``constants``
    that matches ``type`` and ``len(kp)``, and column 0 (x) is negated.
    Presumably the x coordinates are centred on 0 (e.g. normalized to
    [-1, 1]) so that negation mirrors them; confirm with callers.

    Args:
        kp: keypoint array indexed as ``kp[joint, coord]``.
        is_smpl: for ``type='body'``, use the SMPL joint permutations.
        type: one of 'body', 'hand', 'face', 'feet'.

    Returns:
        The permuted keypoints with x negated.

    Raises:
        AssertionError: if ``type`` is not a supported value.
        ValueError: if ``len(kp)`` has no known permutation for 'body' or
            'hand'.
    """
    assert type in ['body', 'hand', 'face', 'feet']
    num_kp = len(kp)
    if type == 'body':
        if num_kp == 24:
            flipped_parts = (constants.SMPL_JOINTS_FLIP_PERM if is_smpl
                             else constants.J24_FLIP_PERM)
        elif num_kp == 49:
            flipped_parts = (constants.SMPL_J49_FLIP_PERM if is_smpl
                             else constants.J49_FLIP_PERM)
        else:
            raise ValueError(
                f"flip_kp: unsupported number of body keypoints {num_kp} "
                "(expected 24 or 49)")
    elif type == 'hand':
        if num_kp == 21:
            flipped_parts = constants.SINGLE_HAND_FLIP_PERM
        elif num_kp == 42:
            flipped_parts = constants.LRHAND_FLIP_PERM
        else:
            raise ValueError(
                f"flip_kp: unsupported number of hand keypoints {num_kp} "
                "(expected 21 or 42)")
    elif type == 'face':
        flipped_parts = constants.FACE_FLIP_PERM
    else:  # type == 'feet'
        flipped_parts = constants.FEEF_FLIP_PERM
    kp = kp[flipped_parts]
    kp[:, 0] = - kp[:, 0]
    return kp
def flip_pose(pose):
    """Flip an SMPL pose vector left/right.

    Joints are reordered with ``constants.SMPL_POSE_FLIP_PERM``. Then the
    second and third component of every axis-angle triplet is negated.
    """
    mirrored = pose[constants.SMPL_POSE_FLIP_PERM]
    for offset in (1, 2):
        mirrored[offset::3] = -mirrored[offset::3]
    return mirrored
def flip_aa(pose):
    """Negate the 2nd and 3rd component of every axis-angle triplet.

    Accepts a flat vector of shape (3K,) or a batch of shape (B, 3K). The
    input is modified in place and also returned.

    Raises:
        NotImplementedError: if ``pose`` has any other number of dimensions.
    """
    if len(pose.shape) not in (1, 2):
        raise NotImplementedError
    # ``...`` covers the optional leading batch axis, so one form fits both ranks.
    pose[..., 1::3] = -pose[..., 1::3]
    pose[..., 2::3] = -pose[..., 2::3]
    return pose
def normalize_2d_kp(kp_2d, crop_size=224, inv=False):
    """Map keypoints between pixel coordinates and the [-1, 1] range.

    Forward: ``[0, crop_size] -> [-1, 1]``. With ``inv=True``, the reverse
    mapping ``[-1, 1] -> [0, crop_size]`` is applied instead.
    """
    ratio = 1.0 / crop_size
    if inv:
        return (kp_2d + 1.0)/(2*ratio)
    return 2.0 * kp_2d * ratio - 1.0
def j2d_processing(kp, transf):
    """Apply batched affine transforms to 2D keypoints and normalize them.

    Steps:
    1. Each keypoint is extended with a homogeneous 1.
    2. It is multiplied by its sample's matrix in ``transf`` via ``torch.bmm``.
    3. Every output coordinate except the last is mapped from
       [0, constants.IMG_RES] to [-1, 1].

    Only the first two coordinates are returned.

    NOTE(review): presumably ``kp`` is [B, N, 2] and ``transf`` is
    [B, 3, 3]. With a [B, 2, 3] matrix only x would be normalized; confirm
    with callers.
    """
    batch, num_parts = kp.shape[:2]
    ones = torch.ones((batch, num_parts, 1)).to(kp)
    homogeneous = torch.cat([kp, ones], dim=-1)
    projected = torch.bmm(transf, homogeneous.transpose(1, 2)).transpose(1, 2)
    projected[:, :, :-1] = 2.*projected[:, :, :-1] / constants.IMG_RES - 1.
    return projected[:, :, :2]
def generate_heatmap(joints, heatmap_size, sigma=1, joints_vis=None):
    '''Render one unnormalized Gaussian heatmap per joint.

    Args:
        joints: tensor of shape [num_joints, >=2]. Columns 0 and 1 (x, y)
            are multiplied by the heatmap width/height, so presumably they
            are normalized to [0, 1]; confirm with callers.
        heatmap_size: an int for a square heatmap, or (width, height).
        sigma: Gaussian standard deviation in heatmap pixels.
        joints_vis: optional array of shape [num_joints, >=1]. Column 0 is
            copied into ``target_weight``.

    Returns:
        target: float32 tensor [num_joints, height, width] on the device of
            ``joints``, with peak value 1 at each drawn joint.
        target_weight: float32 ndarray [num_joints, 1] (1: visible,
            0: invisible). It is also set to 0 for joints whose Gaussian lies
            entirely outside the heatmap.
    '''
    num_joints = joints.shape[0]
    device = joints.device
    cur_device = torch.device(device.type, device.index)
    if not hasattr(heatmap_size, '__len__'):
        # A scalar size means a square heatmap: [width, height].
        heatmap_size = [heatmap_size, heatmap_size]
    assert len(heatmap_size) == 2
    width, height = heatmap_size[0], heatmap_size[1]

    target_weight = np.ones((num_joints, 1), dtype=np.float32)
    if joints_vis is not None:
        target_weight[:, 0] = joints_vis[:, 0]
    target = torch.zeros((num_joints, height, width),
                         dtype=torch.float32,
                         device=cur_device)

    # The Gaussian patch is the same for every joint, so build it once.
    # It is not normalized: the centre value equals 1.
    tmp_size = sigma * 3
    size = 2 * tmp_size + 1
    x = torch.arange(0, size, dtype=torch.float32, device=cur_device)
    y = x.unsqueeze(-1)
    x0 = y0 = size // 2
    g = torch.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))

    for joint_id in range(num_joints):
        mu_x = int(joints[joint_id][0] * width + 0.5)
        mu_y = int(joints[joint_id][1] * height + 0.5)
        # Patch bounds in heatmap pixels: ul inclusive, br exclusive.
        ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
        br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
        # Skip joints whose patch has no pixel inside the heatmap. Because br
        # is exclusive, br == 0 already means the patch lies fully outside.
        if ul[0] >= width or ul[1] >= height or br[0] <= 0 or br[1] <= 0:
            target_weight[joint_id] = 0
            continue
        # Overlapping region in patch coordinates ...
        g_x = max(0, -ul[0]), min(br[0], width) - ul[0]
        g_y = max(0, -ul[1]), min(br[1], height) - ul[1]
        # ... and the same region in heatmap coordinates.
        img_x = max(0, ul[0]), min(br[0], width)
        img_y = max(0, ul[1]), min(br[1], height)
        if target_weight[joint_id] > 0.5:
            target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
                g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
    return target, target_weight