offline_stores_try_on / model /DensePose /densepose_extractor.py
Ali Mohsin
feat: Add virtual try-on system components including DensePose, SMPL, and pix2pixHD models, rendering, and utilities.
5db43ff
import argparse
import glob
import logging
import os
import sys
from typing import Any, ClassVar, Dict, List
import torch
import copy
sys.path.append("./model/DensePose")
from detectron2.config import CfgNode, get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.engine.defaults import DefaultPredictor
from detectron2.structures.instances import Instances
from detectron2.utils.logger import setup_logger
from densepose import add_densepose_config
from densepose.structures import DensePoseChartPredictorOutput, DensePoseEmbeddingPredictorOutput
from densepose.utils.logger import verbosity_to_level
from densepose.vis.base import CompoundVisualizer
from densepose.vis.bounding_box import ScoredBoundingBoxVisualizer
from densepose.vis.densepose_outputs_vertex import (
DensePoseOutputsTextureVisualizer,
DensePoseOutputsVertexVisualizer,
get_texture_atlases,
)
from densepose.vis.densepose_results import (
DensePoseResultsContourVisualizer,
DensePoseResultsFineSegmentationVisualizer,
DensePoseResultsUVisualizer,
DensePoseResultsVVisualizer,
)
from densepose.vis.densepose_results_textures import (
DensePoseResultsVisualizerWithTexture,
get_texture_atlas,
)
from densepose.vis.extractor import (
CompoundExtractor,
DensePoseOutputsExtractor,
DensePoseResultExtractor,
create_extractor,
)
from .apply_net import create_argument_parser, DumpAction
import torch
import torch.nn.functional as F
import numpy as np
import cv2
def get_palette(num_cls):
    """Build the colour map used to visualise a segmentation mask.

    Every class index is spread bit-by-bit across the R, G and B channels,
    filling each channel from its most significant bit downwards, so that
    neighbouring labels receive clearly different colours.

    Args:
        num_cls: Number of classes (palette entries) to generate.

    Returns:
        A flat list of ``3 * num_cls`` ints laid out as
        ``[r0, g0, b0, r1, g1, b1, ...]``.
    """
    colours = []
    for label in range(num_cls):
        rgb = [0, 0, 0]
        bits = label
        shift = 7
        while bits:
            # Consume three bits of the label per round: one for each channel.
            for channel in range(3):
                rgb[channel] |= ((bits >> channel) & 1) << shift
            shift -= 1
            bits >>= 3
        colours.extend(rgb)
    return colours
class DensePoseExtractor(DumpAction):
    """Wrap a detectron2 DensePose predictor and convert its outputs to image-space maps.

    The constructor builds a ``DefaultPredictor`` from the hard-coded R50-FPN
    s1x DensePose config and weights. The public helpers run that predictor
    on a single ``H x W x 3`` image and paste the per-person results back into
    a full-image array. Methods taking ``isRGB`` swap an RGB input to BGR
    before prediction; ``forward`` always performs that swap and
    ``get_hand_mask`` never does.
    """

    def __init__(self):
        # No arguments are registered on the parser, so this parses to an empty
        # namespace; config path and weights URL are hard-coded below instead.
        self.parser = argparse.ArgumentParser()
        self.dp_model = DumpAction()
        self.args = self.parser.parse_args([])
        opts = []
        self.cfg = './model/DensePose/configs/densepose_rcnn_R_50_FPN_s1x.yaml'
        self.model = 'https://dl.fbaipublicfiles.com/densepose/densepose_rcnn_R_50_FPN_s1x/165712039/model_final_162be9.pkl'
        cfg = self.dp_model.setup_config(self.cfg, self.model, self.args, opts)
        self.predictor = DefaultPredictor(cfg)
        # 25-entry colour map, one RGB row per label, for visualisation.
        self.palette = np.array(get_palette(25), np.uint8).reshape(-1, 3)

    # ------------------------------------------------------------------ helpers

    def _run_predictor(self, img, extract=True):
        """Run the predictor on a BGR image and collect what the public helpers use.

        Args:
            img: ``H x W x 3`` image in BGR order.
            extract: When true, also run the DensePose result extractor.

        Returns:
            ``(instances, result)``. ``result['pred_boxes_XYXY']`` is the CPU
            box tensor (present when the model produced boxes). If ``extract``
            is true and DensePose outputs exist, ``result['pred_densepose']`` is
            the first element of the extractor's output, indexed per instance
            by the callers.

        Raises:
            TypeError: if the DensePose output type has no matching extractor.
        """
        with torch.no_grad():
            outputs = self.predictor(img)["instances"]
        result = {}
        if outputs.has("pred_boxes"):
            result["pred_boxes_XYXY"] = outputs.get("pred_boxes").tensor.cpu()
        if extract and outputs.has("pred_densepose"):
            dp_output = outputs.pred_densepose
            if isinstance(dp_output, DensePoseChartPredictorOutput):
                extractor = DensePoseResultExtractor()
            elif isinstance(dp_output, DensePoseEmbeddingPredictorOutput):
                extractor = DensePoseOutputsExtractor()
            else:
                raise TypeError(
                    f"Unsupported DensePose output type: {type(dp_output).__name__}"
                )
            result["pred_densepose"] = extractor(outputs)[0]
        return outputs, result

    @staticmethod
    def _rounded_box(boxes, idx):
        """Return row ``idx`` of an ``(N, 4)`` XYXY box tensor as four rounded ints."""
        return [int(round(coord)) for coord in boxes[idx].tolist()]

    def _paste_iuv(self, result, idx, img_shape):
        """Paste instance ``idx``'s labels and UV maps into a full-image uint8 IUV array.

        Channel 0 holds the integer part label. Channels 1 and 2 hold U and V
        multiplied by 255; they are presumably in [0, 1] beforehand, and the
        uint8 cast truncates them. Pixels outside the person's box are 0.
        """
        x_min, y_min, _, _ = self._rounded_box(result['pred_boxes_XYXY'], idx)
        densepose = result['pred_densepose'][idx]
        labels, uv = densepose.labels, densepose.uv
        mask_h, mask_w = labels.shape
        raw_h, raw_w, _ = img_shape
        # Allocate next to the predictions rather than hard-coding CUDA; the
        # result is moved to host memory below anyway.
        iuv = torch.zeros(3, raw_h, raw_w, device=labels.device)
        iuv[0, y_min:y_min + mask_h, x_min:x_min + mask_w] = labels.float()
        iuv[1:, y_min:y_min + mask_h, x_min:x_min + mask_w] = uv
        iuv[1:] *= 255.0
        return iuv.permute(1, 2, 0).cpu().numpy().astype(np.uint8)

    # --------------------------------------------------------------- public API

    def forward(self, img):
        """Return a ``3 x H x W`` CUDA tensor with normalised labels and UV of instance 0.

        Args:
            img: ``H x W x 3`` image in RGB order; it is swapped to BGR here.

        Returns:
            Channel 0 is the part label divided by 24 and channels 1-2 are U and
            V. Values are zero outside the person's box. Instance 0 is used,
            not the largest box.

        Raises:
            IndexError: if the predictor returns no boxes.
        """
        img = img[:, :, [2, 1, 0]]  # convert RGB to BGR
        _, result = self._run_predictor(img)
        x_min, y_min, _, _ = self._rounded_box(result['pred_boxes_XYXY'], 0)
        densepose = result['pred_densepose'][0]
        labels, uv = densepose.labels, densepose.uv
        mask_h, mask_w = labels.shape
        raw_h, raw_w, _ = img.shape
        output_tensor = torch.zeros(3, raw_h, raw_w).cuda()
        output_tensor[0, y_min:y_min + mask_h, x_min:x_min + mask_w] = labels.float() / 24.0
        output_tensor[1:, y_min:y_min + mask_h, x_min:x_min + mask_w] = uv
        return output_tensor

    def get_soft_map(self, img, isRGB=False):
        """Return a soft 3-channel body-part map for the largest detected person.

        The person's fine-segmentation scores are resized to its box. Each
        channel is then min-max normalised to [0, 1]; a constant channel maps
        to 0 instead of producing NaN. The channels are max-pooled into three
        groups and pasted into a full-size canvas.

        Args:
            img: ``H x W x 3`` image, BGR unless ``isRGB`` is true.
            isRGB: Swap channels to BGR before prediction.

        Returns:
            ``H x W x 3`` uint8 image with R = left arm/hand, G = torso, legs,
            feet and head, B = right arm/hand, or ``None`` if nothing is
            detected.
        """
        if isRGB:
            img = img[:, :, [2, 1, 0]]  # convert to BGR
        outputs, result = self._run_predictor(img, extract=False)
        boxes = result['pred_boxes_XYXY']
        if len(boxes) == 0:
            return None
        max_id = self.get_max_index(boxes)
        x_min, y_min, x_max, y_max = self._rounded_box(boxes, max_id)
        w = max(x_max - x_min, 1)
        h = max(y_max - y_min, 1)

        fine_segm = outputs.pred_densepose.fine_segm
        # Resample only the selected person's scores to the box: (C, h, w).
        soft_map = F.interpolate(
            fine_segm[[max_id]], (h, w), mode="bilinear", align_corners=False
        )[0]

        # Per-channel min-max normalisation. A zero range (constant channel)
        # would give 0/0 = NaN; dividing by 1 there yields 0 instead.
        flat = soft_map.flatten(1)
        minv = flat.min(dim=1).values.view(-1, 1, 1)
        span = flat.max(dim=1).values.view(-1, 1, 1) - minv
        span = torch.where(span > 0, span, torch.ones_like(span))
        soft_map = (soft_map - minv) / span

        raw_h, raw_w, _ = img.shape
        output_map = torch.zeros(
            (soft_map.shape[0], raw_h, raw_w), dtype=torch.float32, device=soft_map.device
        )
        output_map[:, y_min:y_min + h, x_min:x_min + w] = soft_map

        # Channel groups by DensePose part id; channel 0 is in no group.
        torsoleghead_index = [1, 2, 5, 6, 7, 9, 8, 10, 11, 13, 12, 14, 23, 24]
        left_arm_index = [4, 15, 17, 19, 21]
        right_arm_index = [3, 16, 18, 20, 22]
        r_channel = output_map[left_arm_index].max(dim=0).values
        g_channel = output_map[torsoleghead_index].max(dim=0).values
        b_channel = output_map[right_arm_index].max(dim=0).values
        result_img = torch.stack((r_channel, g_channel, b_channel), dim=2).cpu().numpy()
        return (result_img * 255).astype(np.uint8)

    def get_IUV(self, img, isRGB=False):
        """Return the ``H x W x 3`` uint8 IUV image of the largest detected person.

        Channel 0 is the raw part label and channels 1-2 are U and V scaled by
        255. Returns ``None`` when nothing is detected.
        """
        if isRGB:
            img = img[:, :, [2, 1, 0]]  # convert to BGR
        _, result = self._run_predictor(img)
        boxes = result['pred_boxes_XYXY']
        if len(boxes) == 0:
            return None
        max_id = self.get_max_index(boxes)
        return self._paste_iuv(result, max_id, img.shape)

    def get_max_index(self, boxes):
        """Return the index of the largest-area box in an ``(N, 4)`` XYXY collection.

        Raises ``ValueError`` (from ``np.argmax``) when ``boxes`` is empty.
        """
        areas = [(box[2] - box[0]) * (box[3] - box[1]) for box in boxes]
        return int(np.argmax(areas))

    def IUV2img(self, IUV: np.ndarray) -> np.ndarray:
        """Rescale the label channel of an IUV image from 0-24 to 0-255 for display."""
        IUV = IUV.astype(np.float32)
        IUV[:, :, 0] /= 24.0
        IUV[:, :, 0] *= 255
        return IUV.astype(np.uint8)

    def get_dp_img(self, img, isRGB=False):
        """Return the display-scaled IUV image of the largest person, or ``None``."""
        IUV = self.get_IUV(img, isRGB)
        # get_IUV returns None when nothing is detected; do not pass it on.
        return None if IUV is None else self.IUV2img(IUV)

    def get_hand_mask(self, img):
        """Return an ``H x W`` boolean mask of hand pixels (labels 3 and 4) for instance 0.

        ``img`` must already be BGR. An all-False mask is returned when nothing
        is detected.
        """
        _, result = self._run_predictor(img)
        raw_h, raw_w, _ = img.shape
        hand_mask = np.zeros((raw_h, raw_w), dtype=bool)
        boxes = result['pred_boxes_XYXY']
        if len(boxes) == 0:
            return hand_mask
        x_min, y_min, _, _ = self._rounded_box(boxes, 0)
        labels = result['pred_densepose'][0].labels.cpu().numpy()
        mask_h, mask_w = labels.shape
        hand_mask[y_min:y_min + mask_h, x_min:x_min + mask_w] = (labels == 3) | (labels == 4)
        return hand_mask

    def get_vis_img(self, img_path, save_path='seg.jpg'):
        """Write ``forward``'s output, scaled to 0-255, to ``save_path`` and return it.

        Args:
            img_path: Either an ``H x W x 3`` RGB image array, passed straight
                to :meth:`forward`, or a path to an image file. A path is read
                with ``cv2.imread`` and converted BGR -> RGB first.
            save_path: Output image file.

        Returns:
            The ``H x W x 3`` uint8 image that was written.

        Raises:
            FileNotFoundError: if a path is given and the image cannot be read.
        """
        if isinstance(img_path, (str, os.PathLike)):
            bgr = cv2.imread(os.fspath(img_path))
            if bgr is None:
                raise FileNotFoundError(f"Could not read image: {img_path}")
            img_path = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        output_tensor = self.forward(img_path).cpu()
        output_img = (output_tensor.permute(1, 2, 0) * 255).numpy().astype(np.uint8)
        cv2.imwrite(save_path, output_img)
        return output_img

    def get_bbox(self, img, isRGB=False):
        """Return ``(box, IUV)`` for the largest detected person.

        ``box`` is the unrounded ``[x_min, y_min, x_max, y_max]`` float list and
        ``IUV`` is the same array :meth:`get_IUV` would produce. Returns
        ``(None, None)`` when nothing is detected.
        """
        if isRGB:
            img = img[:, :, [2, 1, 0]]  # convert to BGR
        _, result = self._run_predictor(img)
        boxes = result['pred_boxes_XYXY']
        if len(boxes) == 0:
            return None, None
        max_id = self.get_max_index(boxes)
        res_box_list = boxes[max_id].tolist()  # fresh list of floats
        return res_box_list, self._paste_iuv(result, max_id, img.shape)

    def get_trans2roi(self, img, new_h, new_w, isRGB=False):
        """Compute an affine crop mapping the largest person onto a ``new_h x new_w`` ROI.

        The ROI is centred on the person's box and padded by 10 % along the
        limiting dimension. The other dimension is extended to match the
        ``new_h / new_w`` aspect ratio.

        Returns:
            ``(trans, inv_trans, roi_IUV)``, where:

            - ``trans`` is the 2x3 image->ROI affine matrix.
            - ``inv_trans`` is the ROI->image matrix.
            - ``roi_IUV`` is the IUV map warped into the ROI (nearest
              neighbour, zero border).

            Returns ``(None, None, None)`` when nothing is detected, so callers
            can always unpack three values.
        """
        bbox, IUV = self.get_bbox(img, isRGB)
        if bbox is None:
            return None, None, None
        x_min, y_min, x_max, y_max = bbox
        x_center = (x_min + x_max) / 2
        y_center = (y_min + y_max) / 2
        if (y_max - y_min) / ((x_max - x_min) + 1e-5) > (new_h / new_w):  # too tall
            half_y = (y_max - y_center) * 1.1
            half_x = half_y * new_w / new_h
        else:
            half_x = (x_max - x_center) * 1.1
            half_y = half_x * new_h / new_w
        center = np.array([x_center, y_center], np.float32)
        # Three corners (bottom-left, top-left, top-right) fully define the affine map.
        src = np.stack([
            center + np.array([-half_x, half_y], np.float32),
            center + np.array([-half_x, -half_y], np.float32),
            center + np.array([half_x, -half_y], np.float32),
        ]).astype(np.float32)
        dst = np.array([[0, new_h - 1], [0, 0], [new_w - 1, 0]], np.float32)
        trans = cv2.getAffineTransform(src, dst)
        inv_trans = cv2.getAffineTransform(dst, src)
        roi_IUV = cv2.warpAffine(IUV, trans, (new_w, new_h),
                                 flags=cv2.INTER_NEAREST,
                                 borderMode=cv2.BORDER_CONSTANT,
                                 borderValue=(0, 0, 0))
        return trans, inv_trans, roi_IUV
# 1, 2 = Torso, 3 = Right Hand, 4 = Left Hand, 5 = Left Foot, 6 = Right Foot, 7, 9 = Upper Leg Right, 8, 10 = Upper Leg Left, 11, 13 = Lower Leg Right, 12, 14 = Lower Leg Left, 15, 17 = Upper Arm Left, 16, 18 = Upper Arm Right, 19, 21 = Lower Arm Left, 20, 22 = Lower Arm Right, 23, 24 = Head;