import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

import utils.basic

def sequence_loss(
    flow_preds,
    flow_gt,
    valids,
    vis=None,
    gamma=0.8,
    use_huber_loss=False,
    loss_only_for_visible=False,
):
    """Loss function defined over a sequence of flow predictions."""
    total_flow_loss = 0.0
    for j in range(len(flow_gt)):
        B, S, N, D = flow_gt[j].shape
        B, S2, N = valids[j].shape
        assert S == S2
        n_predictions = len(flow_preds[j])
        flow_loss = 0.0
        for i in range(n_predictions):
            # later refinement iterations get exponentially larger weight
            i_weight = gamma ** (n_predictions - i - 1)
            flow_pred = flow_preds[j][i]
            if use_huber_loss:
                i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)  # B,S,N,2
            else:
                i_loss = (flow_pred - flow_gt[j]).abs()  # B,S,N,2
            i_loss = torch.mean(i_loss, dim=3)  # B,S,N
            valid_ = valids[j].clone()
            if loss_only_for_visible:
                valid_ = valid_ * vis[j]
            flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss, valid_)
        flow_loss = flow_loss / n_predictions
        total_flow_loss += flow_loss
    return total_flow_loss / len(flow_gt)

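# A minimal shape-check sketch for sequence_loss (added illustration, not part of
# the original file). It assumes flow_preds is a list (one entry per video chunk)
# of lists (one entry per refinement iteration) of B,S,N,2 tensors, matching the
# loops above; running it requires the repo's utils.basic.reduce_masked_mean.
def _example_sequence_loss():
    B, S, N = 2, 8, 16
    flow_gt = [torch.randn(B, S, N, 2)]
    valids = [torch.ones(B, S, N)]
    flow_preds = [[torch.randn(B, S, N, 2) for _ in range(4)]]  # 4 refinement iters
    return sequence_loss(flow_preds, flow_gt, valids)
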
def sequence_loss_dense(
    flow_preds,
    flow_gt,
    valids,
    vis=None,
    gamma=0.8,
    use_huber_loss=False,
    loss_only_for_visible=False,
):
    """Loss function defined over a sequence of dense (per-pixel) flow predictions."""
    total_flow_loss = 0.0
    for j in range(len(flow_gt)):
        B, S, D, H, W = flow_gt[j].shape
        B, S2, _, H, W = valids[j].shape
        assert S == S2
        n_predictions = len(flow_preds[j])
        flow_loss = 0.0
        for i in range(n_predictions):
            # later refinement iterations get exponentially larger weight
            i_weight = gamma ** (n_predictions - i - 1)
            flow_pred = flow_preds[j][i]  # B,S,2,H,W
            if use_huber_loss:
                i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)  # B,S,2,H,W
            else:
                i_loss = (flow_pred - flow_gt[j]).abs()  # B,S,2,H,W
            i_loss_ = torch.mean(i_loss, dim=2)  # B,S,H,W
            valid_ = valids[j].reshape(B, S, H, W)
            if loss_only_for_visible:
                valid_ = valid_ * vis[j].reshape(B, -1, H, W)  # usually B,S,H,W, but maybe B,1,H,W
            flow_loss += i_weight * utils.basic.reduce_masked_mean(i_loss_, valid_, broadcast=True)
        flow_loss = flow_loss / n_predictions
        total_flow_loss += flow_loss
    return total_flow_loss / len(flow_gt)

def huber_loss(x, y, delta=1.0):
    """Calculate element-wise Huber loss between x and y."""
    diff = x - y
    abs_diff = diff.abs()
    flag = (abs_diff <= delta).float()
    return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)

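# Quick numeric check of the two Huber regimes (added illustration, not in the
# original): with delta=1.0, a residual of 0.5 falls in the quadratic regime
# (0.5 * 0.5**2 = 0.125) and a residual of 2.0 in the linear regime
# (1.0 * (2.0 - 0.5) = 1.5).
def _check_huber_loss():
    zero = torch.tensor(0.0)
    assert torch.isclose(huber_loss(torch.tensor(0.5), zero), torch.tensor(0.125))
    assert torch.isclose(huber_loss(torch.tensor(2.0), zero), torch.tensor(1.5))
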
def sequence_BCE_loss(vis_preds, vis_gts, valids=None, use_logits=False):
    """Binary cross-entropy loss over a sequence of visibility predictions."""
    total_bce_loss = 0.0
    for j in range(len(vis_preds)):
        n_predictions = len(vis_preds[j])
        bce_loss = 0.0
        for i in range(n_predictions):
            if use_logits:
                loss = F.binary_cross_entropy_with_logits(vis_preds[j][i], vis_gts[j], reduction='none')
            else:
                loss = F.binary_cross_entropy(vis_preds[j][i], vis_gts[j], reduction='none')
            if valids is None:
                bce_loss += loss.mean()
            else:
                bce_loss += (loss * valids[j]).mean()
        bce_loss = bce_loss / n_predictions
        total_bce_loss += bce_loss
    return total_bce_loss / len(vis_preds)

def sequence_prob_loss(
    tracks: torch.Tensor,
    confidence: torch.Tensor,
    target_points: torch.Tensor,
    visibility: torch.Tensor,
    expected_dist_thresh: float = 12.0,
    use_logits=False,
):
    """Loss for classifying whether a point is within a pixel threshold of its target.

    Points with an error larger than 12 pixels are likely to be useless; marking
    them as occluded will actually improve Jaccard metrics and give qualitatively
    better results.
    """
    total_logprob_loss = 0.0
    for j in range(len(tracks)):
        n_predictions = len(tracks[j])
        logprob_loss = 0.0
        for i in range(n_predictions):
            err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=-1)
            valid = (err <= expected_dist_thresh**2).float()
            if use_logits:
                loss = F.binary_cross_entropy_with_logits(confidence[j][i], valid, reduction="none")
            else:
                loss = F.binary_cross_entropy(confidence[j][i], valid, reduction="none")
            loss *= visibility[j]
            loss = torch.mean(loss, dim=[1, 2])  # B
            logprob_loss += loss
        logprob_loss = logprob_loss / n_predictions
        total_logprob_loss += logprob_loss
    return total_logprob_loss / len(tracks)

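# A minimal shape sketch for sequence_prob_loss (added illustration, not in the
# original). It assumes the same list-of-lists layout as the flow losses above:
# per-chunk lists of per-iteration B,S,N,2 track tensors with B,S,N confidences.
# Note the returned loss has shape (B,); the caller is expected to reduce it.
def _example_sequence_prob_loss():
    B, S, N = 2, 8, 16
    tracks = [[torch.randn(B, S, N, 2) for _ in range(4)]]
    confidence = [[torch.rand(B, S, N) for _ in range(4)]]  # probabilities in [0,1)
    target_points = [torch.randn(B, S, N, 2)]
    visibility = [torch.ones(B, S, N)]
    return sequence_prob_loss(tracks, confidence, target_points, visibility)
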
def sequence_prob_loss_dense(
    tracks: torch.Tensor,
    confidence: torch.Tensor,
    target_points: torch.Tensor,
    visibility: torch.Tensor,
    expected_dist_thresh: float = 12.0,
    use_logits=False,
):
    """Loss for classifying whether a pixel is within a pixel threshold of its target.

    Points with an error larger than 12 pixels are likely to be useless; marking
    them as occluded will actually improve Jaccard metrics and give qualitatively
    better results.
    """
    total_logprob_loss = 0.0
    for j in range(len(tracks)):
        n_predictions = len(tracks[j])
        logprob_loss = 0.0
        for i in range(n_predictions):
            err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=2)
            positive = (err <= expected_dist_thresh**2).float()
            if use_logits:
                loss = F.binary_cross_entropy_with_logits(confidence[j][i].squeeze(2), positive, reduction="none")
            else:
                loss = F.binary_cross_entropy(confidence[j][i].squeeze(2), positive, reduction="none")
            loss *= visibility[j].squeeze(2)  # B,S,H,W
            loss = torch.mean(loss, dim=[1, 2, 3])  # B
            logprob_loss += loss
        logprob_loss = logprob_loss / n_predictions
        total_logprob_loss += logprob_loss
    return total_logprob_loss / len(tracks)

def masked_mean(data, mask, dim):
    if mask is None:
        return data.mean(dim=dim, keepdim=True)
    mask = mask.float()
    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
    mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
        mask_sum, min=1.0
    )
    return mask_mean

def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
    if mask is None:
        return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
    mask = mask.float()
    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
    mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
        mask_sum, min=1.0
    )
    mask_var = torch.sum(
        mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
    ) / torch.clamp(mask_sum, min=1.0)
    return mask_mean.squeeze(dim), mask_var.squeeze(dim)

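# Small demonstration of masked_mean_var (added illustration, not in the original):
# entries with mask == 0 are excluded from both statistics. Passing a list of dims
# to Tensor.squeeze() requires PyTorch >= 2.0.
def _example_masked_mean_var():
    data = torch.tensor([[1.0, 2.0, 100.0]])
    mask = torch.tensor([[1.0, 1.0, 0.0]])
    mean, var = masked_mean_var(data, mask, dim=[1])
    # mean -> tensor([1.5]), var -> tensor([0.25]); the masked-out 100.0 is ignored
    return mean, var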