Spaces:
Runtime error
Runtime error
| import torch | |
| import random | |
| from torch import nn, Tensor | |
| import os | |
| import numpy as np | |
| import math | |
| import torch.nn.functional as F | |
| from torch import nn | |
class PoseProjector(nn.Module):
    """Two projection heads applied to the same hidden features.

    ``V_projector`` is a single linear layer producing one value per body
    point. ``Z_projector`` is a 3-layer MLP producing two values per body
    point.
    """

    def __init__(self, hidden_dim=256, num_body_points=17):
        """Create both heads.

        Args:
            hidden_dim: Size of the last dimension of the input features;
                also used as the hidden width of the Z head.
            num_body_points: Number of body points; the V head outputs
                ``num_body_points`` values and the Z head outputs
                ``num_body_points * 2``.
        """
        super().__init__()
        self.num_body_points = num_body_points

        # V head: only the bias is reset; the weight keeps nn.Linear's
        # default initialisation.
        self.V_projector = nn.Linear(hidden_dim, num_body_points)
        nn.init.constant_(self.V_projector.bias.data, 0)

        # Z head: weight and bias of the last layer are zeroed, so the head
        # outputs exactly zero until it is trained.
        self.Z_projector = MLP(hidden_dim, hidden_dim, num_body_points * 2, 3)
        final_layer = self.Z_projector.layers[-1]
        for param in (final_layer.weight, final_layer.bias):
            nn.init.constant_(param.data, 0)

    def forward(self, hs):
        """Run both heads on ``hs``.

        Args:
            hs: Features whose last dimension is ``hidden_dim``
                (documented by the original author as ..., bs, nq,
                hidden_dim).

        Returns:
            Tuple ``(Z, V)`` with last dimensions ``num_body_points * 2``
            and ``num_body_points`` respectively; leading dimensions are
            those of ``hs``.
        """
        return self.Z_projector(hs), self.V_projector(hs)
def gen_encoder_output_proposals(memory: Tensor,
                                 memory_padding_mask: Tensor,
                                 spatial_shapes: Tensor,
                                 learnedwh=None):
    r"""Generate one box proposal per location of a flattened feature pyramid.

    For every level, each cell gets a proposal ``(x, y, w, h)``: ``x``/``y``
    are the cell centre divided by the number of unpadded columns/rows of
    that sample, and ``w``/``h`` are ``0.05 * 2**lvl`` (or
    ``learnedwh.sigmoid() * 2**lvl`` when ``learnedwh`` is given). Proposals
    are then mapped through the inverse sigmoid.

    A location is invalid when it is padded or when any of its four values
    lies outside the open interval (0.01, 0.99). Invalid locations get
    ``inf`` in the proposals and ``0`` in the returned memory.

    Input:
        - memory: bs, \sum{hw}, d_model
        - memory_padding_mask: bs, \sum{hw} (bool, True marks padding)
        - spatial_shapes: nlevel, 2 (iterated as (H, W) pairs)
        - learnedwh: 2, optional
    Output:
        - output_memory: bs, \sum{hw}, d_model
        - output_proposals: bs, \sum{hw}, 4
    """
    # The 3-way unpack also enforces that memory is 3-dimensional.
    N_, _, _ = memory.shape
    proposals = []
    _cur = 0  # offset of the current level inside the flattened axis
    for lvl, (H_, W_) in enumerate(spatial_shapes):
        mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(
            N_, H_, W_, 1)
        # Unpadded extent, measured along the first column / first row.
        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

        # indexing='ij' is the historical default; passing it explicitly
        # keeps the result unchanged and avoids the torch UserWarning.
        grid_y, grid_x = torch.meshgrid(
            torch.linspace(0,
                           H_ - 1,
                           H_,
                           dtype=torch.float32,
                           device=memory.device),
            torch.linspace(0,
                           W_ - 1,
                           W_,
                           dtype=torch.float32,
                           device=memory.device),
            indexing='ij')
        grid = torch.cat(
            [grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)  # H_, W_, 2

        scale = torch.cat([valid_W.unsqueeze(-1),
                           valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
        # Cell centres normalised by the valid width/height of each sample.
        grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale

        # Box size doubles with every level.
        if learnedwh is not None:
            wh = torch.ones_like(grid) * learnedwh.sigmoid() * (2.0**lvl)
        else:
            wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)

        proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
        proposals.append(proposal)
        _cur += (H_ * W_)

    output_proposals = torch.cat(proposals, 1)
    output_proposals_valid = ((output_proposals > 0.01) &
                              (output_proposals < 0.99)).all(-1, keepdim=True)
    output_proposals = torch.log(output_proposals /
                                 (1 - output_proposals))  # unsigmoid
    output_proposals = output_proposals.masked_fill(
        memory_padding_mask.unsqueeze(-1), float('inf'))
    output_proposals = output_proposals.masked_fill(~output_proposals_valid,
                                                    float('inf'))

    output_memory = memory
    output_memory = output_memory.masked_fill(
        memory_padding_mask.unsqueeze(-1), 0.0)
    output_memory = output_memory.masked_fill(~output_proposals_valid, 0.0)
    return output_memory, output_proposals
class RandomBoxPerturber():
    """Multiply anchors by a random factor close to 1 and clamp to [0, 1].

    Each coordinate ``c`` becomes ``c * (1 + (u - 0.5) * s)`` where ``u`` is
    drawn uniformly from [0, 1) and ``s`` is the noise scale configured for
    that coordinate.
    """

    def __init__(self,
                 x_noise_scale=0.2,
                 y_noise_scale=0.2,
                 w_noise_scale=0.2,
                 h_noise_scale=0.2) -> None:
        """Store the per-coordinate noise scales in (x, y, w, h) order."""
        per_coord = [x_noise_scale, y_noise_scale, w_noise_scale, h_noise_scale]
        self.noise_scale = torch.Tensor(per_coord)

    def __call__(self, refanchors: Tensor) -> Tensor:
        """Return a perturbed copy of ``refanchors``.

        Args:
            refanchors: 3-D tensor; the original code names its dimensions
                (nq, bs, query_dim). Only the first ``query_dim`` noise
                scales are used.

        Returns:
            New tensor of the same shape with values clamped to [0, 1].
        """
        # Unpacking three values also rejects inputs that are not 3-D.
        _, _, query_dim = refanchors.shape
        scales = self.noise_scale.to(refanchors.device)[:query_dim]
        uniform = torch.rand_like(refanchors)
        perturbed = refanchors * (1 + (uniform - 0.5) * scales)
        # In-place clamp is safe: `perturbed` is a fresh tensor.
        perturbed.clamp_(0, 1)
        return perturbed
def sigmoid_focal_loss(inputs,
                       targets,
                       num_boxes,
                       alpha: float = 0.25,
                       gamma: float = 2):
    """Sigmoid focal loss (RetinaNet, https://arxiv.org/abs/1708.02002).

    Args:
        inputs: Float tensor of logits with at least two dimensions.
        targets: Float tensor with the same shape as ``inputs`` holding the
            binary label of each element (0 negative, 1 positive).
        num_boxes: Normaliser the summed loss is divided by.
        alpha: Weight of positive examples; negatives get ``1 - alpha``.
            Default 0.25. Pass a negative value to disable the weighting.
        gamma: Exponent of the modulating factor ``(1 - p_t)`` that
            down-weights easy examples. Default 2.

    Returns:
        Scalar tensor: the loss averaged over dimension 1, summed over the
        remaining dimensions and divided by ``num_boxes``.
    """
    bce = F.binary_cross_entropy_with_logits(inputs,
                                             targets,
                                             reduction='none')
    prob = inputs.sigmoid()
    # p_t is the probability the model assigns to the true class.
    p_t = prob * targets + (1 - prob) * (1 - targets)
    focal = bce * ((1 - p_t)**gamma)

    if alpha >= 0:
        class_weight = alpha * targets + (1 - alpha) * (1 - targets)
        focal = class_weight * focal

    per_row = focal.mean(1)
    return per_row.sum() / num_boxes
class MLP(nn.Module):
    """Stack of linear layers with ReLU between them (a simple FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        """Build ``num_layers`` linear layers.

        Args:
            input_dim: Input feature size of the first layer.
            hidden_dim: Width of every intermediate layer.
            output_dim: Output feature size of the last layer.
            num_layers: Number of linear layers in the stack.
        """
        super().__init__()
        self.num_layers = num_layers
        # Layer widths end to end: input, (num_layers - 1) hidden, output.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out)
             for fan_in, fan_out in zip(widths[:-1], widths[1:])])

    def forward(self, x):
        """Apply the layers in order; ReLU follows all but the last one."""
        last_index = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            if index < last_index:
                x = F.relu(x)
        return x
| def _get_activation_fn(activation, d_model=256, batch_dim=0): | |
| """Return an activation function given a string.""" | |
| if activation == 'relu': | |
| return F.relu | |
| if activation == 'gelu': | |
| return F.gelu | |
| if activation == 'glu': | |
| return F.glu | |
| if activation == 'prelu': | |
| return nn.PReLU() | |
| if activation == 'selu': | |
| return F.selu | |
| raise RuntimeError(F'activation should be relu/gelu, not {activation}.') | |
def gen_sineembed_for_position(pos_tensor):
    """Sinusoidal embedding of 2-D points or 4-D boxes.

    Every coordinate is multiplied by ``2 * pi`` and expanded to 128
    interleaved sin/cos features using temperature 10000.

    Args:
        pos_tensor: 3-D tensor whose last dimension is 2 (x, y) or
            4 (x, y, w, h).

    Returns:
        Tensor with the same two leading dimensions and a last dimension of
        256 (ordered y, x) or 512 (ordered y, x, w, h).

    Raises:
        ValueError: If the last dimension is neither 2 nor 4.
    """
    scale = 2 * math.pi
    dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
    dim_t = 10000**(2 * (dim_t // 2) / 128)

    def _encode(coord):
        # Even frequency slots take sin, odd slots take cos; stacking on a
        # new last axis and flattening interleaves them.
        angles = (coord * scale)[:, :, None] / dim_t
        return torch.stack(
            (angles[:, :, 0::2].sin(), angles[:, :, 1::2].cos()),
            dim=3).flatten(2)

    pos_x = _encode(pos_tensor[:, :, 0])
    pos_y = _encode(pos_tensor[:, :, 1])

    last_dim = pos_tensor.size(-1)
    if last_dim == 2:
        return torch.cat((pos_y, pos_x), dim=2)
    if last_dim == 4:
        pos_w = _encode(pos_tensor[:, :, 2])
        pos_h = _encode(pos_tensor[:, :, 3])
        return torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
    raise ValueError('Unknown pos_tensor shape(-1):{}'.format(
        pos_tensor.size(-1)))
def oks_overlaps(kpt_preds, kpt_gts, kpt_valids, kpt_areas, sigmas):
    """Object keypoint similarity between predicted and target poses.

    Args:
        kpt_preds: Predicted keypoints; reshaped to (n, K, 2) from a last
            dimension of size 2K.
        kpt_gts: Target keypoints, same layout as ``kpt_preds``.
        kpt_valids: Per-keypoint weights multiplied into the similarity,
            broadcastable to (n, K).
        kpt_areas: Per-pose area of shape (n,), used to scale distances.
        sigmas: K per-keypoint constants; converted to a tensor with the
            dtype and device of ``kpt_preds``.

    Returns:
        Tensor of shape (n,): the valid-weighted mean of
        ``exp(-d^2 / (2 * area * (2 * sigma)^2))`` over keypoints.
    """
    sigma_t = kpt_preds.new_tensor(sigmas)
    variances = (sigma_t * 2)**2
    assert kpt_preds.size(0) == kpt_gts.size(0)

    # View both inputs as (n, K, 2) so x and y can be indexed directly.
    pred_xy = kpt_preds.reshape(-1, kpt_preds.size(-1) // 2, 2)
    gt_xy = kpt_gts.reshape(-1, kpt_gts.size(-1) // 2, 2)
    dx = pred_xy[:, :, 0] - gt_xy[:, :, 0]
    dy = pred_xy[:, :, 1] - gt_xy[:, :, 1]
    sq_dist = dx**2 + dy**2

    denom = kpt_areas[:, None] * variances[None, :] * 2
    similarity = torch.exp(-(sq_dist / denom)) * kpt_valids
    # The small constant keeps poses without any valid keypoint finite.
    return similarity.sum(dim=1) / (kpt_valids.sum(dim=1) + 1e-6)
def oks_loss(pred,
             target,
             valid=None,
             area=None,
             linear=False,
             sigmas=None,
             eps=1e-6):
    """Per-pose OKS loss between predicted and target poses.

    The loss is ``-log(oks)`` by default or ``1 - oks`` when ``linear`` is
    set, then multiplied by ``n / (n + eps)`` where ``n`` is the sum of
    ``valid`` for that pose, so poses with no valid keypoint contribute 0.

    Args:
        pred (torch.Tensor): Predicted poses of format (x1, y1, x2, y2, ...),
            shape (n, 2K).
        target (torch.Tensor): Corresponding gt poses, shape (n, 2K).
        valid (torch.Tensor): Per-keypoint validity forwarded to
            ``oks_overlaps``; although it defaults to None, the code calls
            ``valid.sum`` so a tensor is required.
        area (torch.Tensor): Per-pose area forwarded to ``oks_overlaps``.
        linear (bool, optional): If True, use linear scale of loss instead
            of log scale. Default: False.
        sigmas: Per-keypoint constants forwarded to ``oks_overlaps``.
        eps (float): Lower clamp for oks (avoids log(0)) and stabiliser of
            the validity factor.

    Return:
        torch.Tensor: Loss tensor with one entry per pose.
    """
    overlap = oks_overlaps(pred, target, valid, area, sigmas).clamp(min=eps)
    per_pose = (1 - overlap) if linear else -overlap.log()
    n_valid = valid.sum(-1)
    return per_pose * n_valid / (n_valid + eps)
class OKSLoss(nn.Module):
    """OKS loss module.

    Computes the oks loss between a set of predicted poses and target poses
    using a fixed table of per-keypoint sigmas.

    Args:
        linear (bool): If True, use linear scale of loss instead of log scale.
            Default: False.
        num_keypoints (int): Selects the sigma table; 17, 14 and 6 are
            supported.
        eps (float): Eps to avoid log(0).
        reduction (str): Options are "none", "mean" and "sum". Stored and
            consulted only for the all-zero-weight shortcut in ``forward``.
        loss_weight (float): Weight of loss.

    Raises:
        ValueError: If ``num_keypoints`` has no sigma table.
    """

    def __init__(self,
                 linear=False,
                 num_keypoints=17,
                 eps=1e-6,
                 reduction='mean',
                 loss_weight=1.0):
        super(OKSLoss, self).__init__()
        self.linear = linear
        self.eps = eps
        self.reduction = reduction
        self.loss_weight = loss_weight

        # Sigma tables are written in tenths and rescaled once below.
        if num_keypoints == 17:
            raw_sigmas = np.array([
                .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07,
                1.07, .87, .87, .89, .89
            ], dtype=np.float32)
        elif num_keypoints == 14:
            # No explicit dtype here, so this table is float64.
            raw_sigmas = np.array([
                .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89,
                .79, .79
            ])
        elif num_keypoints == 6:
            raw_sigmas = np.array([.25, .25, .25, .25, .25, .25],
                                  dtype=np.float32)
        else:
            raise ValueError(f'Unsupported keypoints number {num_keypoints}')
        self.sigmas = raw_sigmas / 10.0

    def forward(self,
                pred,
                target,
                valid,
                area,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction.
            target (torch.Tensor): The learning target of the prediction.
            valid (torch.Tensor): The visible flag of the target pose.
            area (torch.Tensor): The area of the target pose.
            weight (torch.Tensor, optional): The weight of loss for each
                prediction. Defaults to None. It is only used for the
                all-zero shortcut and a shape check; it is not multiplied
                into the returned loss.
            avg_factor (int, optional): Accepted for interface
                compatibility; not used by this implementation.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None. Options are "none", "mean" and "sum".

        Returns:
            torch.Tensor: ``(pred * weight).sum()`` (i.e. zero) when no
            weight is positive and the reduction is not "none"; otherwise
            the unreduced per-pose loss scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        if weight is not None:
            if not torch.any(weight > 0) and reduction != 'none':
                # Nothing is weighted in: return a zero that still depends
                # on pred.
                if pred.dim() == weight.dim() + 1:
                    weight = weight.unsqueeze(1)
                return (pred * weight).sum()  # 0
            if weight.dim() > 1:
                # TODO: remove this in the future
                # reduce the weight of shape (n, 4) to (n,) to match the
                # iou_loss of shape (n,)
                assert weight.shape == pred.shape
                weight = weight.mean(-1)

        per_pose = oks_loss(pred,
                            target,
                            valid=valid,
                            area=area,
                            linear=self.linear,
                            sigmas=self.sigmas,
                            eps=self.eps)
        return self.loss_weight * per_pose