File size: 23,474 Bytes
ee3f635 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import base64
import numpy as np
from io import BytesIO
import torch
from PIL import Image
from torch.nn import functional as F
class DensePoseTransformData(object):
    """
    Lookup tables used to keep DensePose labels and UV coordinates
    consistent under a horizontal flip.

    Attributes:
        mask_label_symmetries: list mapping each segmentation mask label
            to its label after a horizontal flip
        point_label_symmetries: list mapping each point (part) label
            to its label after a horizontal flip
        uv_symmetries: dict with "U_transforms" and "V_transforms" entries,
            stored exactly as passed to the constructor
    """

    # Horizontal symmetry label transforms used for horizontal flip
    MASK_LABEL_SYMMETRIES = [0, 1, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 14]
    # fmt: off
    POINT_LABEL_SYMMETRIES = [ 0, 1, 2, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15, 18, 17, 20, 19, 22, 21, 24, 23]  # noqa
    # fmt: on

    def __init__(self, uv_symmetries):
        """
        Args:
            uv_symmetries: dict with "U_transforms" and "V_transforms" lookup
                tables (see `load` for the format produced from a .mat file)
        """
        self.mask_label_symmetries = DensePoseTransformData.MASK_LABEL_SYMMETRIES
        self.point_label_symmetries = DensePoseTransformData.POINT_LABEL_SYMMETRIES
        self.uv_symmetries = uv_symmetries

    @staticmethod
    def load(fpath):
        """
        Load UV symmetry tables from a MATLAB file.

        The file must contain "U_transforms" and "V_transforms" entries, each
        a 1 x K cell array of equally sized matrices. Every entry is converted
        to a float tensor and the K matrices are stacked along dimension 0.

        The tables are placed on the current CUDA device when CUDA is
        available and stay on the CPU otherwise, so that loading also works
        on CPU-only machines.

        Args:
            fpath: path of the .mat file, forwarded to `scipy.io.loadmat`
        Returns:
            DensePoseTransformData holding the loaded tables
        """
        import scipy.io

        uv_symmetry_map = scipy.io.loadmat(fpath)
        if torch.cuda.is_available():
            device = torch.device("cuda", torch.cuda.current_device())
        else:
            device = torch.device("cpu")
        uv_symmetry_map_torch = {}
        for key in ["U_transforms", "V_transforms"]:
            map_src = uv_symmetry_map[key]
            planes = [
                torch.from_numpy(map_src[0, i]).to(dtype=torch.float)
                for i in range(map_src.shape[1])
            ]
            uv_symmetry_map_torch[key] = torch.stack(planes, dim=0).to(device=device)
        return DensePoseTransformData(uv_symmetry_map_torch)
class DensePoseDataRelative(object):
    """
    Dense pose relative annotations that can be applied to any bounding box:
        x - normalized X coordinates [0, 255] of annotated points
        y - normalized Y coordinates [0, 255] of annotated points
        i - body part labels 0,...,24 for annotated points
        u - body part U coordinates [0, 1] for annotated points
        v - body part V coordinates [0, 1] for annotated points
        segm - 256x256 segmentation mask with values 0,...,14
    To obtain absolute x and y data wrt some bounding box one needs to first
    divide the data by 256, multiply by the respective bounding box size
    and add bounding box offset:
        x_img = x0 + x_norm * w / 256.0
        y_img = y0 + y_norm * h / 256.0
    Segmentation masks are typically sampled to get image-based masks.
    """

    # Key for normalized X coordinates in annotation dict
    X_KEY = "dp_x"
    # Key for normalized Y coordinates in annotation dict
    Y_KEY = "dp_y"
    # Key for U part coordinates in annotation dict
    U_KEY = "dp_U"
    # Key for V part coordinates in annotation dict
    V_KEY = "dp_V"
    # Key for I point labels in annotation dict
    I_KEY = "dp_I"
    # Key for segmentation mask in annotation dict
    S_KEY = "dp_masks"
    # Number of body parts in segmentation masks
    N_BODY_PARTS = 14
    # Number of parts in point labels
    N_PART_LABELS = 24
    # Side length of the square segmentation mask
    MASK_SIZE = 256

    def __init__(self, annotation, cleanup=False):
        """
        Args:
            annotation (dict): annotation containing all DensePose keys
                (see `validate_annotation`)
            cleanup (bool): if True, the DensePose keys are removed from
                `annotation` once the data has been extracted
        Raises:
            AssertionError: if a required key is missing from `annotation`
        """
        is_valid, reason_not_valid = DensePoseDataRelative.validate_annotation(annotation)
        assert is_valid, "Invalid DensePose annotations: {}".format(reason_not_valid)
        self.x = torch.as_tensor(annotation[DensePoseDataRelative.X_KEY])
        self.y = torch.as_tensor(annotation[DensePoseDataRelative.Y_KEY])
        self.i = torch.as_tensor(annotation[DensePoseDataRelative.I_KEY])
        self.u = torch.as_tensor(annotation[DensePoseDataRelative.U_KEY])
        self.v = torch.as_tensor(annotation[DensePoseDataRelative.V_KEY])
        self.segm = DensePoseDataRelative.extract_segmentation_mask(annotation)
        self.device = torch.device("cpu")
        if cleanup:
            DensePoseDataRelative.cleanup_annotation(annotation)

    def to(self, device):
        """
        Return the data on the given device.

        Returns `self` when the data is already on `device`; otherwise a new
        instance is built (without running `__init__`) from the moved tensors.
        """
        if self.device == device:
            return self
        new_data = DensePoseDataRelative.__new__(DensePoseDataRelative)
        new_data.x = self.x.to(device)
        new_data.y = self.y.to(device)
        new_data.i = self.i.to(device)
        new_data.u = self.u.to(device)
        new_data.v = self.v.to(device)
        new_data.segm = self.segm.to(device)
        new_data.device = device
        return new_data

    @staticmethod
    def extract_segmentation_mask(annotation):
        """
        Decode the per-part masks stored under `S_KEY` into one label mask.

        Returns:
            float32 tensor of size (MASK_SIZE, MASK_SIZE) where pixels covered
            by the i-th (0-based) part mask are set to i + 1 and all other
            pixels are 0. Later parts overwrite earlier ones where they overlap.
        """
        import pycocotools.mask as mask_utils

        poly_specs = annotation[DensePoseDataRelative.S_KEY]
        segm = torch.zeros((DensePoseDataRelative.MASK_SIZE,) * 2, dtype=torch.float32)
        for i in range(DensePoseDataRelative.N_BODY_PARTS):
            poly_i = poly_specs[i]
            # empty entries denote parts without a mask
            if poly_i:
                mask_i = mask_utils.decode(poly_i)
                segm[mask_i > 0] = i + 1
        return segm

    @staticmethod
    def validate_annotation(annotation):
        """
        Check that all DensePose keys are present in the annotation.

        Returns:
            (True, None) if the annotation is valid, otherwise (False, reason)
            where reason names the first missing key
        """
        for key in [
            DensePoseDataRelative.X_KEY,
            DensePoseDataRelative.Y_KEY,
            DensePoseDataRelative.I_KEY,
            DensePoseDataRelative.U_KEY,
            DensePoseDataRelative.V_KEY,
            DensePoseDataRelative.S_KEY,
        ]:
            if key not in annotation:
                return False, "no {key} data in the annotation".format(key=key)
        return True, None

    @staticmethod
    def cleanup_annotation(annotation):
        """
        Remove all DensePose keys from the annotation dict (in place).
        """
        for key in [
            DensePoseDataRelative.X_KEY,
            DensePoseDataRelative.Y_KEY,
            DensePoseDataRelative.I_KEY,
            DensePoseDataRelative.U_KEY,
            DensePoseDataRelative.V_KEY,
            DensePoseDataRelative.S_KEY,
        ]:
            if key in annotation:
                del annotation[key]

    def apply_transform(self, transforms, densepose_transform_data):
        """
        Apply the transforms to the annotated points and the segmentation
        mask (in place). Only horizontal flips have an effect.
        """
        self._transform_pts(transforms, densepose_transform_data)
        self._transform_segm(transforms, densepose_transform_data)

    @staticmethod
    def _is_hflip(transforms):
        """
        Tell whether the transform list amounts to a horizontal flip, i.e.
        contains an odd number of `HFlipTransform` entries.
        """
        import detectron2.data.transforms as T

        # NOTE: This assumes that HorizFlipTransform is the only one that does flip
        return sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1

    def _transform_pts(self, transforms, dp_transform_data):
        """
        Mirror the point X coordinates and remap point labels and UV
        coordinates if the transforms contain a horizontal flip.
        """
        if self._is_hflip(transforms):
            # mirror X coordinates with respect to the mask width
            self.x = self.segm.size(1) - self.x
            self._flip_iuv_semantics(dp_transform_data)

    def _flip_iuv_semantics(self, dp_transform_data: "DensePoseTransformData") -> None:
        """
        Replace point labels with their symmetric counterparts and look up
        the U, V coordinates of a flipped point in the symmetry tables.
        """
        i_old = self.i.clone()
        uv_symmetries = dp_transform_data.uv_symmetries
        pt_label_symmetries = dp_transform_data.point_label_symmetries
        for i in range(self.N_PART_LABELS):
            if i + 1 in i_old:
                # select points by their labels *before* the flip, so that
                # already relabeled points are not processed twice
                annot_indices_i = i_old == i + 1
                if pt_label_symmetries[i + 1] != i + 1:
                    self.i[annot_indices_i] = pt_label_symmetries[i + 1]
                # UV values are scaled by 255 to index into the tables
                u_loc = (self.u[annot_indices_i] * 255).long()
                v_loc = (self.v[annot_indices_i] * 255).long()
                self.u[annot_indices_i] = uv_symmetries["U_transforms"][i][v_loc, u_loc].to(
                    device=self.u.device
                )
                self.v[annot_indices_i] = uv_symmetries["V_transforms"][i][v_loc, u_loc].to(
                    device=self.v.device
                )

    def _transform_segm(self, transforms, dp_transform_data):
        """
        Mirror the segmentation mask and remap its labels if the transforms
        contain a horizontal flip.
        """
        if self._is_hflip(transforms):
            self.segm = torch.flip(self.segm, [1])
            self._flip_segm_semantics(dp_transform_data)

    def _flip_segm_semantics(self, dp_transform_data):
        """
        Replace segmentation mask labels with their symmetric counterparts.
        """
        # labels are read from a snapshot so that swapped pairs do not
        # overwrite each other
        old_segm = self.segm.clone()
        mask_label_symmetries = dp_transform_data.mask_label_symmetries
        for i in range(self.N_BODY_PARTS):
            if mask_label_symmetries[i + 1] != i + 1:
                self.segm[old_segm == i + 1] = mask_label_symmetries[i + 1]
def normalized_coords_transform(x0, y0, w, h):
    """
    Build a point transform for the box with top left corner (x0, y0),
    width w and height h: the top left corner is mapped to (-1, -1) and the
    bottom right corner to (1, 1). Used to initialize the sampling grid for
    torch.grid_sample.

    Returns:
        function that takes a point (x, y) and returns the transformed
        point as a tuple
    """

    def to_normalized(point):
        x_norm = 2 * (point[0] - x0) / w - 1
        y_norm = 2 * (point[1] - y0) / h - 1
        return (x_norm, y_norm)

    return to_normalized
class DensePoseOutput(object):
    """
    Container for raw DensePose predictions: coarse segmentation (S),
    fine segmentation (I), U and V coordinates, and confidence estimates.
    """

    def __init__(self, S, I, U, V, confidences):
        """
        Args:
            S (`torch.Tensor`): coarse segmentation tensor of size (N, A, H, W)
            I (`torch.Tensor`): fine segmentation tensor of size (N, C, H, W)
            U (`torch.Tensor`): U coordinates for each fine segmentation label of size (N, C, H, W)
            V (`torch.Tensor`): V coordinates for each fine segmentation label of size (N, C, H, W)
            confidences (dict of str -> `torch.Tensor`) estimated confidence model parameters
        Raises:
            AssertionError: if the tensor shapes are inconsistent
        """
        self.S = S
        self.I = I  # noqa: E741
        self.U = U
        self.V = V
        self.confidences = confidences
        self._check_output_dims(S, I, U, V)

    def _check_output_dims(self, S, I, U, V):
        """
        Assert that all outputs are 4-dimensional, that S and I agree on the
        number of instances and on the plane size, and that I, U and V have
        identical shapes.
        """
        # each message reports the tensor that actually failed the check
        for title, tensor in (
            ("Segmentation", S),
            ("Part index", I),
            ("U coordinates", U),
            ("V coordinates", V),
        ):
            assert (
                len(tensor.size()) == 4
            ), "{} output should have 4 dimensions (NCHW), but has size {}".format(
                title, tensor.size()
            )
        assert len(S) == len(I), (
            "Number of output segmentation planes {} "
            "should be equal to the number of output part index "
            "planes {}".format(len(S), len(I))
        )
        assert S.size()[2:] == I.size()[2:], (
            "Output segmentation plane size {} "
            "should be equal to the output part index "
            "plane size {}".format(S.size()[2:], I.size()[2:])
        )
        assert I.size() == U.size(), (
            "Part index output shape {} "
            "should be the same as U coordinates output shape {}".format(I.size(), U.size())
        )
        assert I.size() == V.size(), (
            "Part index output shape {} "
            "should be the same as V coordinates output shape {}".format(I.size(), V.size())
        )

    def resize(self, image_size_hw):
        # do nothing - outputs are invariant to resize
        pass

    def _crop(self, S, I, U, V, bbox_old_xywh, bbox_new_xywh):
        """
        Resample S, I, U, V from bbox_old to the cropped bbox_new

        NOTE(review): the plane size is read from S.size(1) and S.size(2), so
        the inputs are presumably single-instance (C, H, W) tensors - confirm
        against callers once cropping is enabled.

        Returns:
            tuple (S_new, I_new, U_new, V_new) of resampled tensors
        """
        x0old, y0old, wold, hold = bbox_old_xywh
        x0new, y0new, wnew, hnew = bbox_new_xywh
        tr_coords = normalized_coords_transform(x0old, y0old, wold, hold)
        topleft_norm = tr_coords((x0new, y0new))
        bottomright_norm = tr_coords((x0new + wnew, y0new + hnew))
        hsize = S.size(1)
        wsize = S.size(2)
        # the ranges are truncated to the plane size; presumably this guards
        # against an extra sample produced by the floating point step
        ys = torch.arange(
            topleft_norm[1],
            bottomright_norm[1],
            (bottomright_norm[1] - topleft_norm[1]) / hsize,
        )[:hsize]
        xs = torch.arange(
            topleft_norm[0],
            bottomright_norm[0],
            (bottomright_norm[0] - topleft_norm[0]) / wsize,
        )[:wsize]
        grid_y, grid_x = torch.meshgrid(ys, xs)
        # grid_sample expects the last dimension to hold (x, y), in that order
        grid = torch.stack((grid_x, grid_y), dim=2).to(S.device)
        assert (
            grid.size(0) == hsize
        ), "Resampled grid expected " "height={}, actual height={}".format(hsize, grid.size(0))
        assert grid.size(1) == wsize, "Resampled grid expected " "width={}, actual width={}".format(
            wsize, grid.size(1)
        )
        batched_grid = torch.unsqueeze(grid, 0)

        def resample(tensor):
            # add a batch dimension for grid_sample and drop it afterwards
            return F.grid_sample(
                tensor.unsqueeze(0),
                batched_grid,
                mode="bilinear",
                padding_mode="border",
                align_corners=True,
            ).squeeze(0)

        return resample(S), resample(I), resample(U), resample(V)

    def crop(self, indices_cropped, bboxes_old, bboxes_new):
        """
        Crop outputs for selected bounding boxes to the new bounding boxes.
        """
        # cropping is ignored for now; `_crop` holds the per-instance
        # resampling that would be applied to each selected index
        pass

    def hflip(self, transform_data: "DensePoseTransformData") -> None:
        """
        Change S, I, U and V to take into account a Horizontal flip.
        """
        if self.I.shape[0] > 0:
            for el in "SIUV":
                # mirror along the width dimension
                setattr(self, el, torch.flip(getattr(self, el), [3]))
            self._flip_iuv_semantics_tensor(transform_data)
            self._flip_segm_semantics_tensor(transform_data)

    def _flip_iuv_semantics_tensor(self, dp_transform_data: "DensePoseTransformData") -> None:
        """
        Remap U, V values of all channels except channel 0 through the
        symmetry tables, then permute the channels of I, U and V according
        to the point label symmetries.
        """
        point_label_symmetries = dp_transform_data.point_label_symmetries
        uv_symmetries = dp_transform_data.uv_symmetries

        N, C, H, W = self.U.shape
        # UV values are clamped to [0, 1] and scaled by 255 to index the tables
        u_loc = (self.U[:, 1:, :, :].clamp(0, 1) * 255).long()
        v_loc = (self.V[:, 1:, :, :].clamp(0, 1) * 255).long()
        # channel c (c >= 1) uses lookup table c - 1
        Iindex = torch.arange(C - 1, device=self.U.device)[None, :, None, None].expand(
            N, C - 1, H, W
        )
        self.U[:, 1:, :, :] = uv_symmetries["U_transforms"][Iindex, v_loc, u_loc].to(
            device=self.U.device
        )
        self.V[:, 1:, :, :] = uv_symmetries["V_transforms"][Iindex, v_loc, u_loc].to(
            device=self.V.device
        )

        for el in "IUV":
            setattr(self, el, getattr(self, el)[:, point_label_symmetries, :, :])

    def _flip_segm_semantics_tensor(self, dp_transform_data):
        """
        Permute the channels of S according to the mask label symmetries;
        applied only if S has one channel per body part plus one more.
        """
        if self.S.shape[1] == DensePoseDataRelative.N_BODY_PARTS + 1:
            self.S = self.S[:, dp_transform_data.mask_label_symmetries, :, :]

    def to_result(self, boxes_xywh):
        """
        Convert DensePose outputs to results format. Results are more compact,
        but cannot be resampled any more
        """
        result = DensePoseResult(boxes_xywh, self.S, self.I, self.U, self.V)
        return result

    def __getitem__(self, item):
        """
        Select a subset of instances. An integer index keeps the leading
        instance dimension (of size 1); any other index is applied as is.

        Returns:
            DensePoseOutput with the selected tensors and confidences
        """
        keep_batch_dim = isinstance(item, int)

        def select(tensor):
            picked = tensor[item]
            return picked.unsqueeze(0) if keep_batch_dim else picked

        conf_selected = {key: select(self.confidences[key]) for key in self.confidences}
        return DensePoseOutput(
            select(self.S), select(self.I), select(self.U), select(self.V), conf_selected
        )

    def __str__(self):
        s = "DensePoseOutput S {}, I {}, U {}, V {}".format(
            list(self.S.size()), list(self.I.size()), list(self.U.size()), list(self.V.size())
        )
        s_conf = "confidences: [{}]".format(
            ", ".join([f"{key} {list(self.confidences[key].size())}" for key in self.confidences])
        )
        return ", ".join([s, s_conf])

    def __len__(self):
        # number of instances
        return self.S.size(0)
class DensePoseResult(object):
    """
    Compact DensePose results: for every bounding box, a uint8 array of
    size (3, h, w) holding part labels, U and V values, stored as a
    Base64-encoded PNG together with its shape.
    """

    def __init__(self, boxes_xywh, S, I, U, V):
        """
        Args:
            boxes_xywh (`torch.Tensor`): bounding boxes of size (N, 4)
            S, I, U, V (`torch.Tensor`): DensePose outputs; the i-th entry
                along dimension 0 corresponds to the i-th box
        """
        self.results = []
        self.boxes_xywh = boxes_xywh.cpu().tolist()
        assert len(boxes_xywh.size()) == 2
        assert boxes_xywh.size(1) == 4
        for i, box_xywh in enumerate(boxes_xywh):
            # indexing with [[i]] keeps the leading dimension
            result_i = self._output_to_result(box_xywh, S[[i]], I[[i]], U[[i]], V[[i]])
            result_numpy_i = result_i.cpu().numpy()
            result_encoded_i = DensePoseResult.encode_png_data(result_numpy_i)
            result_encoded_with_shape_i = (result_numpy_i.shape, result_encoded_i)
            self.results.append(result_encoded_with_shape_i)

    def __str__(self):
        s = "DensePoseResult: N={} [{}]".format(
            len(self.results), ", ".join([str(list(r[0])) for r in self.results])
        )
        return s

    def _output_to_result(self, box_xywh, S, I, U, V):
        """
        Resample the outputs of one instance to its bounding box size and
        pack them into a single uint8 tensor.

        Args:
            box_xywh: bounding box (x, y, w, h); only w and h are used,
                truncated to integers and not smaller than 1
            S, I, U, V (`torch.Tensor`): 4-dimensional outputs of one instance
        Returns:
            uint8 tensor of size (3, h, w): plane 0 holds the part label
            (0 wherever the coarse segmentation predicts label 0), planes 1
            and 2 hold the U and V values of the predicted part, scaled by
            255 and clamped to [0, 255]
        """
        _, _, w, h = box_xywh
        w = max(int(w), 1)
        h = max(int(h), 1)
        result = torch.zeros([3, h, w], dtype=torch.uint8, device=U.device)
        assert (
            len(S.size()) == 4
        ), "AnnIndex tensor size should have {} " "dimensions but has {}".format(4, len(S.size()))
        s_bbox = F.interpolate(S, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
        assert (
            len(I.size()) == 4
        ), "IndexUV tensor size should have {} " "dimensions but has {}".format(4, len(I.size()))
        # part labels are zeroed out where the coarse segmentation label is 0
        i_bbox = (
            F.interpolate(I, (h, w), mode="bilinear", align_corners=False).argmax(dim=1)
            * (s_bbox > 0).long()
        ).squeeze(0)
        assert len(U.size()) == 4, "U tensor size should have {} " "dimensions but has {}".format(
            4, len(U.size())
        )
        u_bbox = F.interpolate(U, (h, w), mode="bilinear", align_corners=False)
        assert len(V.size()) == 4, "V tensor size should have {} " "dimensions but has {}".format(
            4, len(V.size())
        )
        v_bbox = F.interpolate(V, (h, w), mode="bilinear", align_corners=False)
        result[0] = i_bbox
        for part_id in range(1, u_bbox.size(1)):
            part_mask = i_bbox == part_id
            result[1][part_mask] = (
                (u_bbox[0, part_id][part_mask] * 255).clamp(0, 255).to(torch.uint8)
            )
            result[2][part_mask] = (
                (v_bbox[0, part_id][part_mask] * 255).clamp(0, 255).to(torch.uint8)
            )
        assert (
            result.size(1) == h
        ), "Results height {} should be equal to bounding box height {}".format(result.size(1), h)
        assert (
            result.size(2) == w
        ), "Results width {} should be equal to bounding box width {}".format(result.size(2), w)
        return result

    @staticmethod
    def encode_png_data(arr):
        """
        Encode array data as a PNG image using the highest compression rate
        @param arr [in] Data stored in an array of size (3, M, N) of type uint8
        @return Base64-encoded string containing PNG-compressed data
        """
        assert len(arr.shape) == 3, "Expected a 3D array as an input," " got a {0}D array".format(
            len(arr.shape)
        )
        assert arr.shape[0] == 3, "Expected first array dimension of size 3," " got {0}".format(
            arr.shape[0]
        )
        assert arr.dtype == np.uint8, "Expected an array of type np.uint8, " " got {0}".format(
            arr.dtype
        )
        # the image expects channels in the last dimension
        data = np.moveaxis(arr, 0, -1)
        im = Image.fromarray(data)
        fstream = BytesIO()
        im.save(fstream, format="png", optimize=True)
        s = base64.encodebytes(fstream.getvalue()).decode()
        return s

    @staticmethod
    def decode_png_data(shape, s):
        """
        Decode array data from a string that contains PNG-compressed data
        @param shape [in] Shape (3, M, N) of the array to be returned
        @param s [in] Base64-encoded string containing PNG-compressed data
        @return Data stored in an array of size (3, M, N) of type uint8
        """
        fstream = BytesIO(base64.decodebytes(s.encode()))
        im = Image.open(fstream)
        # convert the whole image at once and move channels to the front
        data = np.moveaxis(np.array(im, dtype=np.uint8), -1, 0)
        return data.reshape(shape)

    def __len__(self):
        return len(self.results)

    def __getitem__(self, item):
        """
        Returns:
            tuple ((shape, encoded PNG string), box as a list) for the
            selected instance
        """
        result_encoded = self.results[item]
        bbox_xywh = self.boxes_xywh[item]
        return result_encoded, bbox_xywh
class DensePoseList(object):
    """
    List of per-instance DensePose annotations (entries may be None)
    together with the matching absolute XYXY boxes and the image size.
    """

    _TORCH_DEVICE_CPU = torch.device("cpu")

    def __init__(self, densepose_datas, boxes_xyxy_abs, image_size_hw, device=_TORCH_DEVICE_CPU):
        """
        Args:
            densepose_datas: sequence of DensePoseDataRelative or None entries
            boxes_xyxy_abs: boxes, one per entry of `densepose_datas`
            image_size_hw: image size as (height, width)
            device: device to which annotations and boxes are moved
        """
        assert len(densepose_datas) == len(boxes_xyxy_abs), (
            "Attempt to initialize DensePoseList with {} DensePose datas "
            "and {} boxes".format(len(densepose_datas), len(boxes_xyxy_abs))
        )
        self.densepose_datas = []
        for entry in densepose_datas:
            assert isinstance(entry, DensePoseDataRelative) or entry is None, (
                "Attempt to initialize DensePoseList with DensePose datas "
                "of type {}, expected DensePoseDataRelative".format(type(entry))
            )
            # missing annotations are kept as None placeholders
            if entry is None:
                self.densepose_datas.append(None)
            else:
                self.densepose_datas.append(entry.to(device))
        self.boxes_xyxy_abs = boxes_xyxy_abs.to(device)
        self.image_size_hw = image_size_hw
        self.device = device

    def to(self, device):
        """
        Return this list if it already lives on `device`, otherwise a new
        list built on that device.
        """
        if self.device != device:
            return DensePoseList(
                self.densepose_datas, self.boxes_xyxy_abs, self.image_size_hw, device
            )
        return self

    def __iter__(self):
        return iter(self.densepose_datas)

    def __len__(self):
        return len(self.densepose_datas)

    def __repr__(self):
        return "{}(num_instances={}, image_width={}, image_height={})".format(
            self.__class__.__name__,
            len(self.densepose_datas),
            self.image_size_hw[1],
            self.image_size_hw[0],
        )

    def __getitem__(self, item):
        """
        An integer index returns the stored annotation itself; a slice, a
        boolean mask tensor or any other iterable of indices returns a new
        DensePoseList with the selected annotations and boxes.
        """
        if isinstance(item, int):
            return self.densepose_datas[item]

        if isinstance(item, slice):
            chosen = self.densepose_datas[item]
        elif isinstance(item, torch.Tensor) and (item.dtype == torch.bool):
            chosen = [self.densepose_datas[pos] for pos, flag in enumerate(item) if flag > 0]
        else:
            chosen = [self.densepose_datas[pos] for pos in item]
        chosen_boxes = self.boxes_xyxy_abs[item]
        return DensePoseList(chosen, chosen_boxes, self.image_size_hw, self.device)
|