Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn, Tensor | |
| from torch.cuda.amp import autocast | |
| from typing import List, Any, Tuple, Dict | |
| from .bregman_pytorch import sinkhorn | |
| from .utils import _reshape_density | |
# Small constant guarding divisions (density normalization, gradient denominators).
EPS = 1e-8
class OTLoss(nn.Module):
    """Optimal-transport loss between a predicted density map and GT head points.

    The entropic OT problem is solved with Sinkhorn iterations; the dual
    potential ``beta`` is then used to build a surrogate loss whose gradient
    w.r.t. the predicted density equals the OT gradient.
    """

    def __init__(
        self,
        input_size: int,
        reduction: int,
        norm_cood: bool,
        num_of_iter_in_ot: int = 100,
        reg: float = 10.0
    ) -> None:
        super().__init__()
        assert input_size % reduction == 0
        self.input_size = input_size
        self.reduction = reduction
        self.norm_cood = norm_cood
        self.num_of_iter_in_ot = num_of_iter_in_ot
        self.reg = reg
        # Cell-center coordinates in image space; constant because the crop size is fixed.
        centers = torch.arange(0, input_size, step=reduction, dtype=torch.float32) + reduction / 2
        self.density_size = centers.size(0)
        centers = centers.unsqueeze(0)  # [1, #cood]
        if self.norm_cood:
            # Map coordinates into [-1, 1].
            centers = centers / input_size * 2 - 1
        self.cood = centers
        self.output_size = self.cood.size(1)

    def forward(self, pred_density: Tensor, normed_pred_density: Tensor, target_points: List[Tensor]) -> Tuple[Tensor, float, Tensor]:
        """Return ``(loss, wasserstein_distance, ot_objective_values)``.

        ``pred_density`` is the raw map, ``normed_pred_density`` the per-sample
        probability map; ``target_points`` holds one [#gt, 2] tensor per sample.
        """
        batch_size = normed_pred_density.size(0)
        assert len(target_points) == batch_size, f"Expected target_points to have length {batch_size}, but got {len(target_points)}"
        assert self.output_size == normed_pred_density.size(2)
        device = pred_density.device
        loss = torch.zeros([1]).to(device)
        ot_obj_values = torch.zeros([1]).to(device)
        wd = 0  # accumulated Wasserstein distance (python float)
        cood = self.cood.to(device)
        for idx, points in enumerate(target_points):
            if len(points) == 0:
                # No annotated heads in this sample: nothing to transport.
                continue
            if self.norm_cood:
                points = points / self.input_size * 2 - 1
            # Squared L2 distance from every GT point to every density cell,
            # built separably per axis. [#gt, #cood] each.
            x = points[:, 0:1]  # [#gt, 1]
            y = points[:, 1:2]  # [#gt, 1]
            x_dist = -2 * torch.matmul(x, cood) + x * x + cood * cood
            y_dist = -2 * torch.matmul(y, cood) + y * y + cood * cood
            # Row index follows y, column index follows x — matches the [H, W]
            # layout of the density map after flattening.
            dist = y_dist.unsqueeze(2) + x_dist.unsqueeze(1)
            dist = dist.view((dist.size(0), -1))  # [#gt, #cood * #cood]
            source_prob = normed_pred_density[idx][0].view([-1]).detach()
            target_prob = (torch.ones([len(points)]) / len(points)).to(device)
            # Solve the entropic OT problem; beta is the source dual potential.
            plan, sink_log = sinkhorn(
                target_prob, source_prob, dist, self.reg,
                maxIter=self.num_of_iter_in_ot, log=True,
            )
            beta = sink_log["beta"]  # same size as source_prob: [#cood * #cood]
            ot_obj_values += torch.sum(
                normed_pred_density[idx] * beta.view([1, self.output_size, self.output_size])
            )
            # Gradient of the OT loss w.r.t. the raw predicted density:
            # im_grad = beta / count - <beta, density> / count^2
            source_density = pred_density[idx][0].view([-1]).detach()
            source_count = source_density.sum()
            grad_scale = source_count / (source_count * source_count + EPS) * beta
            grad_shift = (source_density * beta).sum() / (source_count * source_count + EPS)
            im_grad = (grad_scale - grad_shift).detach().view([1, self.output_size, self.output_size])
            # loss = <im_grad, pred_density>, so d(loss)/d(pred_density) = im_grad.
            loss += torch.sum(pred_density[idx] * im_grad)
            wd += torch.sum(dist * plan).item()
        return loss, wd, ot_obj_values
class DMLoss(nn.Module):
    """Distribution-matching loss: weighted OT term + TV term + count term."""

    def __init__(
        self,
        input_size: int,
        reduction: int,
        norm_cood: bool = False,
        weight_ot: float = 0.1,
        weight_tv: float = 0.01,
        **kwargs: Any
    ) -> None:
        super().__init__()
        # Extra kwargs (e.g. num_of_iter_in_ot, reg) are forwarded to OTLoss.
        self.ot_loss = OTLoss(input_size, reduction, norm_cood, **kwargs)
        self.tv_loss = nn.L1Loss(reduction="none")
        self.count_loss = nn.L1Loss(reduction="mean")
        self.weight_ot = weight_ot
        self.weight_tv = weight_tv

    def forward(self, pred_density: Tensor, target_density: Tensor, target_points: List[Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Return ``(total_loss, loss_info)`` where loss_info holds detached parts."""
        # Downsample the target density when its spatial size differs from the prediction.
        if target_density.shape[-2:] != pred_density.shape[-2:]:
            target_density = _reshape_density(target_density, reduction=self.ot_loss.reduction)
        assert pred_density.shape == target_density.shape, f"Expected pred_density and target_density to have the same shape, got {pred_density.shape} and {target_density.shape}"
        pred_count = pred_density.view(pred_density.shape[0], -1).sum(dim=1)
        # Normalize both maps to per-sample probability maps (EPS guards empty maps).
        normed_pred_density = pred_density / (pred_count.view(-1, 1, 1, 1) + EPS)
        target_count = torch.tensor([len(p) for p in target_points], dtype=torch.float32).to(target_density.device)
        normed_target_density = target_density / (target_count.view(-1, 1, 1, 1) + EPS)
        ot_loss, _, _ = self.ot_loss(pred_density, normed_pred_density, target_points)
        # TV term: per-sample L1 gap of the normalized maps, scaled by GT count.
        per_sample_tv = self.tv_loss(normed_pred_density, normed_target_density).sum(dim=(1, 2, 3))
        tv_loss = (per_sample_tv * target_count).mean()
        count_loss = self.count_loss(pred_count, target_count)
        total = ot_loss * self.weight_ot + tv_loss * self.weight_tv + count_loss
        loss_info = {
            "loss": total.detach(),
            "ot_loss": ot_loss.detach(),
            "tv_loss": tv_loss.detach(),
            "count_loss": count_loss.detach(),
        }
        return total, loss_info