# camenduru's picture
# thanks to show ❤
# 3bbb319
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import os
import cv2
import numpy as np
import os.path as osp
from sklearn.random_projection import johnson_lindenstrauss_min_dim
import torch
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_tensor
import torch.nn.functional as F
from torchvision.transforms import Normalize
from utils.smooth_bbox import get_all_bbox_params
from .data_utils.img_utils import get_single_image_crop_demo
from utils.cam_params import read_cam_params, homo_vector
from pymaf_core import path_config, constants
from pymaf_core.cfgs import cfg
from utils.imutils import crop, flip_img, flip_pose, flip_aa, flip_kp, transform, get_transform, get_rot_transf, rot_aa
class Inference(Dataset):
    """Dataset yielding network inputs for each detection at inference time.

    Two modes are supported:

    * ``full_body=False``: each item is the crop returned by
      ``get_single_image_crop_demo`` (plus the transformed 2D keypoints when
      ``joints2d`` was given).
    * ``full_body=True``: each item is a dict holding a normalized body crop,
      normalized hand/face crops sampled from the body crop, the inverse of
      each part's sampling transform, the original image size and the
      person id.

    Args:
        image_folder: Directory scanned for ``.png``/``.jpg`` files. Only read
            when ``pre_load_imgs`` is None.
        frames: Index applied to the sorted file list to select the images
            to use.
        bboxes: Precomputed boxes, stored as ``self.bboxes``. Replaced by
            boxes derived from ``joints2d`` when keypoints are given.
        joints2d: Per-detection 2D keypoints. Column 2 is used as a
            validity/confidence value (> 0 means valid) in full-body mode.
        scale: Multiplier applied to the crop scale.
        crop_size: Crop size forwarded to ``get_single_image_crop_demo``.
        pre_load_imgs: Optional already-loaded image used instead of reading
            from ``image_folder``.
        full_body: Select the full-body (body + hands + face) mode.
        person_ids: Per-item person identifiers (full-body mode only).
        wb_kps: Dict with keys ``'joints2d_face'``, ``'joints2d_lhand'`` and
            ``'joints2d_rhand'``; required when ``full_body`` is True.

    Raises:
        ValueError: In full-body mode, when the keypoints of a part are empty.
    """

    def __init__(self, image_folder, frames, bboxes=None, joints2d=None, scale=1.0, crop_size=224,
                 pre_load_imgs=None, full_body=False, person_ids=None, wb_kps=None):
        # Fresh containers per instance instead of shared mutable defaults.
        person_ids = [] if person_ids is None else person_ids
        wb_kps = {} if wb_kps is None else wb_kps

        self.pre_load_imgs = pre_load_imgs
        # Stays None when images are preloaded and nothing is read from disk.
        self.image_file_names = None
        if pre_load_imgs is None:
            file_names = sorted(
                osp.join(image_folder, x)
                for x in os.listdir(image_folder)
                if x.endswith(('.png', '.jpg'))
            )
            self.image_file_names = np.array(file_names)[frames]

        self.bboxes = bboxes
        self.joints2d = joints2d
        self.scale_factor = scale
        self.crop_size = crop_size
        self.frames = frames
        self.has_keypoints = joints2d is not None
        self.full_body = full_body
        self.person_ids = person_ids

        self.normalize_img = Normalize(mean=constants.IMG_NORM_MEAN, std=constants.IMG_NORM_STD)
        self.norm_joints2d = np.zeros_like(self.joints2d)

        if self.has_keypoints:
            if not self.full_body:
                bboxes, time_pt1, time_pt2 = get_all_bbox_params(joints2d, vis_thresh=0.3)
                bboxes[:, 2:] = 150. / bboxes[:, 2:]
                self.bboxes = np.stack([bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 2]]).T

                # Keep files, keypoints and frames aligned with the retained
                # [time_pt1, time_pt2) range; there are no files to trim when
                # the image was preloaded.
                if self.image_file_names is not None:
                    self.image_file_names = self.image_file_names[time_pt1:time_pt2]
                self.joints2d = joints2d[time_pt1:time_pt2]
                self.frames = frames[time_pt1:time_pt2]
            else:
                self.bboxes, self.scales = self._full_body_bboxes(joints2d)

        if self.full_body:
            joints2d_face = wb_kps['joints2d_face']
            joints2d_lhand = wb_kps['joints2d_lhand']
            joints2d_rhand = wb_kps['joints2d_rhand']
            joints_part = {'lhand': joints2d_lhand, 'rhand': joints2d_rhand, 'face': joints2d_face}

            self.bboxes_part = {}
            self.joints2d_part = {}
            for part, joints in joints_part.items():
                if len(joints) == 0:
                    raise ValueError(f"no 2D keypoints provided for part '{part}'")
                self.joints2d_part[part] = joints

    @staticmethod
    def _center_and_extent(kp2d_valid):
        """Return the box center [cx, cy] and longest side of the given keypoints."""
        x_min, y_min = min(kp2d_valid[:, 0]), min(kp2d_valid[:, 1])
        x_max, y_max = max(kp2d_valid[:, 0]), max(kp2d_valid[:, 1])
        center = [(x_max + x_min) / 2., (y_max + y_min) / 2.]
        extent = max(x_max - x_min, y_max - y_min)
        return center, extent

    def _full_body_bboxes(self, joints2d):
        """Build one [cx, cy, w, h] box and one crop scale per detection."""
        bboxes = []
        scales = []
        for j2d in joints2d:
            # Only keypoints with a positive third column contribute.
            center, extent = self._center_and_extent(j2d[j2d[:, 2] > 0.])
            scale = self.scale_factor * 1.2 * extent / 200.

            res = [constants.IMG_RES, constants.IMG_RES]
            # Upper-left and bottom-right crop corners mapped back through the
            # inverse crop transform.
            ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
            br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1

            center = [(ul[0] + br[0]) / 2., (ul[1] + br[1]) / 2.]
            width_height = [br[0] - ul[0], br[1] - ul[1]]
            bboxes.append(np.array(center + width_height))
            scales.append(scale)
        return np.stack(bboxes), np.array(scales)

    def __len__(self):
        """Return the number of boxes, i.e. the number of items."""
        return len(self.bboxes)

    def rgb_processing(self, rgb_img, center, scale, res, rot=0., flip=0):
        """Crop an RGB image and optionally flip the crops.

        Args:
            rgb_img: Image forwarded to ``crop``.
            center, scale, res, rot: Crop parameters forwarded to ``crop``.
            flip: When truthy, both crops are passed through ``flip_img``.

        Returns:
            ``(crop_img_resized, crop_img, crop_shape)``: the two crops are
            float32, transposed from (H, W, C) to (C, H, W) and divided by
            255; ``crop_shape`` is returned unchanged from ``crop``.
        """
        crop_img_resized, crop_img, crop_shape = crop(rgb_img, center, scale, res, rot=rot)
        if flip:
            crop_img_resized = flip_img(crop_img_resized)
            crop_img = flip_img(crop_img)
        crop_img_resized = np.transpose(crop_img_resized.astype('float32'), (2, 0, 1)) / 255.0
        crop_img = np.transpose(crop_img.astype('float32'), (2, 0, 1)) / 255.0
        return crop_img_resized, crop_img, crop_shape

    def j2d_processing(self, kp, t, f, is_smpl=False, is_hand=False, is_face=False, is_feet=False):
        """Transform 2D keypoints into normalized crop coordinates.

        Args:
            kp: Keypoint array; columns 0-1 are x/y. The input is copied and
                left untouched.
            t: Transform applied to each homogeneous point ``[x, y, 1]``.
            f: When truthy, keypoints are flipped with ``flip_kp``.
            is_smpl, is_hand, is_face, is_feet: Select the ``flip_kp`` variant.

        Returns:
            float32 array of the same shape, with every column except the
            last mapped through ``2 * v / IMG_RES - 1``.
        """
        kp = kp.copy()
        nparts = kp.shape[0]
        for i in range(nparts):
            pt = kp[i, 0:2]
            new_pt = np.dot(t, np.array([pt[0], pt[1], 1.]).T)
            kp[i, 0:2] = new_pt[:2]
        # convert to normalized coordinates
        kp[:, :-1] = 2. * kp[:, :-1] / constants.IMG_RES - 1.
        # flip the x coordinates
        if f:
            if is_hand:
                kp = flip_kp(kp, type='hand')
            elif is_face:
                kp = flip_kp(kp, type='face')
            elif is_feet:
                kp = flip_kp(kp, type='feet')
            else:
                kp = flip_kp(kp, is_smpl)
        kp = kp.astype('float32')
        return kp

    def _load_image(self, idx):
        """Return the RGB image for item ``idx``.

        Raises:
            OSError: If the image file cannot be read.
        """
        if self.pre_load_imgs is not None:
            # NOTE(review): the preloaded image is shared by every index,
            # mirroring the previous behavior -- confirm it is a single image
            # rather than a per-item sequence.
            return self.pre_load_imgs
        path = self.image_file_names[idx]
        img = cv2.imread(path)
        if img is None:
            raise OSError(f'failed to read image: {path}')
        # Reverse the channel axis (BGR -> RGB); copy to get a contiguous array.
        return img[:, :, ::-1].copy()

    def __getitem__(self, idx):
        """Return the network input(s) for item ``idx``.

        Returns:
            Body-only mode: ``(norm_img, kp_2d)`` when keypoints are
            available, otherwise ``norm_img``.
            Full-body mode: a dict with keys ``img_body``, ``orig_height``,
            ``orig_width``, ``person_id`` and, for each part in
            lhand/rhand/face, ``img_<part>`` and ``<part>_theta_inv``.
        """
        img = self._load_image(idx)

        if not self.full_body:
            bbox = self.bboxes[idx]
            j2d = self.joints2d[idx] if self.has_keypoints else None
            # NOTE(review): passes the RGB image without float conversion,
            # as the previously commented-out loader did -- confirm this is
            # what get_single_image_crop_demo expects.
            norm_img, raw_img, kp_2d = get_single_image_crop_demo(
                img,
                bbox,
                kp_2d=j2d,
                scale=self.scale_factor,
                crop_size=self.crop_size)
            if self.has_keypoints:
                return norm_img, kp_2d
            return norm_img

        return self._get_full_body_item(idx, img)

    def _get_full_body_item(self, idx, img):
        """Assemble the full-body item dict (body, hand and face crops)."""
        # NOTE(review): assumes the image converts to an (H, W, C) array.
        img_orig = np.asarray(img).astype(np.float32)
        orig_height, orig_width = img_orig.shape[:2]

        item = {}
        scale = self.scale_factor
        rot = 0.
        flip = 0

        j2d = self.joints2d[idx]
        center, extent = self._center_and_extent(j2d[j2d[:, 2] > 0.])
        sc = 1.2 * extent / 200.

        img_body, _, _ = self.rgb_processing(img_orig, center, sc * scale, [constants.IMG_RES, constants.IMG_RES])
        item['img_body'] = self.normalize_img(torch.from_numpy(img_body).float())
        item['orig_height'] = orig_height
        item['orig_width'] = orig_width
        item['person_id'] = self.person_ids[idx]

        # Second pass at 8x resolution; only the crop that rgb_processing
        # returns in second position is used, as the source of the part crops.
        _, img_crop, _ = self.rgb_processing(img_orig, center, sc * scale,
                                             [constants.IMG_RES * 8, constants.IMG_RES * 8])

        kps_transf = get_transform(center, sc * scale, [constants.IMG_RES, constants.IMG_RES], rot=rot)

        lhand_kp2d = self.joints2d_part['lhand'][idx]
        rhand_kp2d = self.joints2d_part['rhand'][idx]
        face_kp2d = self.joints2d_part['face'][idx]

        hand_kp2d = self.j2d_processing(np.concatenate([lhand_kp2d, rhand_kp2d]).copy(), kps_transf, flip, is_hand=True)
        face_kp2d = self.j2d_processing(face_kp2d.copy(), kps_transf, flip, is_face=True)

        n_hand_kp = len(constants.HAND_NAMES)
        part_kp2d_dict = {'lhand': hand_kp2d[:n_hand_kp], 'rhand': hand_kp2d[n_hand_kp:], 'face': face_kp2d}

        crop_hf_img_size = torch.Size([1, 3, cfg.MODEL.PyMAF.HF_IMG_SIZE, cfg.MODEL.PyMAF.HF_IMG_SIZE])

        for part in ['lhand', 'rhand', 'face']:
            kp2d = part_kp2d_dict[part]
            kp2d_valid = kp2d[kp2d[:, 2] > 0.]
            if len(kp2d_valid) > 0:
                center_part, extent_part = self._center_and_extent(kp2d_valid)
                scale_part = 2. * extent_part / 2

            # handle invalid part keypoints: fall back to a fixed centered
            # window and zero out the last keypoint column
            if len(kp2d_valid) < 1 or scale_part < 0.01:
                center_part = [0, 0]
                scale_part = 0.5
                kp2d[:, 2] = 0

            center_part = torch.tensor(center_part).float()

            # Scale + translation sampling transform for this part.
            theta_part = torch.zeros(1, 2, 3)
            theta_part[:, 0, 0] = scale_part
            theta_part[:, 1, 1] = scale_part
            theta_part[:, :, -1] = center_part

            grid = F.affine_grid(theta_part.detach(), crop_hf_img_size, align_corners=False)
            img_part = F.grid_sample(torch.from_numpy(img_crop[None]), grid.cpu(), align_corners=False).squeeze(0)
            item[f'img_{part}'] = self.normalize_img(img_part.float())

            # Inverse of the scale + translation transform above.
            theta_i_inv = torch.zeros_like(theta_part)
            theta_i_inv[:, 0, 0] = 1. / theta_part[:, 0, 0]
            theta_i_inv[:, 1, 1] = 1. / theta_part[:, 1, 1]
            theta_i_inv[:, :, -1] = - theta_part[:, :, -1] / theta_part[:, 0, 0].unsqueeze(-1)
            item[f'{part}_theta_inv'] = theta_i_inv[0]

        return item
class ImageFolder(Dataset):
    """Dataset over the ``.png``/``.jpg`` files of a single directory.

    Files are collected once at construction time and ordered by their
    sorted path. Each item is the image converted from BGR to RGB and
    passed through ``to_tensor``.
    """

    def __init__(self, image_folder):
        """Collect and sort the image paths found directly in ``image_folder``."""
        paths = []
        for entry in os.listdir(image_folder):
            if entry.endswith(('.png', '.jpg')):
                paths.append(osp.join(image_folder, entry))
        paths.sort()
        self.image_file_names = paths

    def __len__(self):
        """Return the number of collected image files."""
        return len(self.image_file_names)

    def __getitem__(self, idx):
        """Load image ``idx``, convert it to RGB and return ``to_tensor`` of it."""
        bgr = cv2.imread(self.image_file_names[idx])
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return to_tensor(rgb)