|
|
import torch |
|
|
from torch import nn, Tensor |
|
|
from torch.amp import autocast |
|
|
from typing import List, Tuple, Dict |
|
|
|
|
|
from .bregman_pytorch import sinkhorn |
|
|
from .utils import _reshape_density |
|
|
|
|
|
# Small additive constant used to avoid division by zero when normalizing
# density maps / counts (see OTLoss.forward and DMLoss.forward).
EPS = 1e-8
|
|
|
|
|
|
|
|
class OTLoss(nn.Module):
    """Optimal-transport loss between a predicted density map and ground-truth points.

    Follows the DM-Count formulation: an entropically regularized OT problem is
    solved with Sinkhorn iterations between the (normalized) predicted density
    over block centers and a uniform distribution over the annotated points.
    The dual variable ``beta`` returned by the solver is then used to form a
    surrogate gradient that is applied to the *unnormalized* prediction.

    Args:
        input_size: Side length (pixels) of the square input image.
        block_size: Side length (pixels) of each density block; must divide
            ``input_size`` evenly.
        numItermax: Maximum number of Sinkhorn iterations.
        regularization: Entropic regularization strength passed to ``sinkhorn``.
    """

    def __init__(
        self,
        input_size: int,
        block_size: int,
        numItermax: int = 100,
        regularization: float = 10.0
    ) -> None:
        super().__init__()
        # The density map must tile the image exactly.
        assert input_size % block_size == 0
        self.input_size = input_size
        self.block_size = block_size
        # Square grid: same number of blocks along both axes.
        self.num_blocks_h = input_size // block_size
        self.num_blocks_w = input_size // block_size
        self.numItermax = numItermax
        self.regularization = regularization

        # Pixel coordinates of block centers along each axis, shape (1, num_blocks).
        # The leading singleton dim lets matmul with (N, 1) point columns produce
        # an (N, num_blocks) outer product in forward().
        self.coords_h = torch.arange(0, input_size, step=block_size, dtype=torch.float32) + block_size / 2
        self.coords_w = torch.arange(0, input_size, step=block_size, dtype=torch.float32) + block_size / 2
        self.coords_h, self.coords_w = self.coords_h.unsqueeze(0), self.coords_w.unsqueeze(0)

    def set_numItermax(self, numItermax: int) -> None:
        """Update the maximum number of Sinkhorn iterations used in forward()."""
        self.numItermax = numItermax

    # Force float32 math under autocast: Sinkhorn iterations are numerically
    # fragile in half precision.
    @autocast(device_type="cuda", enabled=True, dtype=torch.float32)
    def forward(self, pred_den_map: Tensor, pred_den_map_normed: Tensor, gt_points: List[Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute the OT loss for a batch.

        Args:
            pred_den_map: Unnormalized predicted density, shape
                ``(B, 1, num_blocks_h, num_blocks_w)``.
            pred_den_map_normed: Prediction normalized to sum to 1 per sample,
                same shape as ``pred_den_map``.
            gt_points: Per-sample point annotations; each tensor is assumed to
                hold ``(x, y)`` pixel coordinates in columns 0 and 1 — confirm
                against the data loader.

        Returns:
            Tuple of 1-element tensors ``(loss, w_dist, ot_obj_values)``:
            the differentiable surrogate loss, the accumulated transport cost
            ``<C, P>``, and the accumulated dual objective term.
        """
        assert pred_den_map.shape[1:] == pred_den_map_normed.shape[1:] == (1, self.num_blocks_h, self.num_blocks_w), f"Expected pred_den_map to have shape (B, 1, {self.num_blocks_h}, {self.num_blocks_w}), but got {pred_den_map.shape} and {pred_den_map_normed.shape}"
        assert len(gt_points) == pred_den_map.shape[0] == pred_den_map_normed.shape[0], f"Expected gt_points to have length {pred_den_map_normed.shape[0]}, but got {len(gt_points)}"
        device = pred_den_map.device

        loss = torch.zeros(1, device=device)
        ot_obj_values = torch.zeros(1, device=device)
        w_dist = torch.zeros(1, device=device)
        coords_h, coords_w = self.coords_h.to(device), self.coords_w.to(device)
        for idx, points in enumerate(gt_points):
            # Samples with no annotated points contribute nothing to any term.
            if len(points) > 0:
                # Per-axis squared distances between each point and each block
                # center, via the expansion (a - b)^2 = a^2 - 2ab + b^2.
                # x, y: (N, 1); coords_*: (1, K) -> x_dist, y_dist: (N, K).
                x, y = points[:, 0].unsqueeze(1), points[:, 1].unsqueeze(1)
                x_dist = -2 * torch.matmul(x, coords_w) + x * x + coords_w * coords_w
                y_dist = -2 * torch.matmul(y, coords_h) + y * y + coords_h * coords_h
                # Combine axes into full squared Euclidean cost, then flatten
                # the block grid: (N, num_blocks_h * num_blocks_w).
                dist = x_dist.unsqueeze(1) + y_dist.unsqueeze(2)
                dist = dist.view((dist.shape[0], -1))

                # Marginals: detached predicted distribution over blocks vs.
                # uniform mass over the N ground-truth points.
                source_prob = pred_den_map_normed[idx].view(-1).detach()
                target_prob = (torch.ones(len(points)) / len(points)).to(device)

                # Entropic OT; log=True exposes the dual variables.
                P, log = sinkhorn(
                    a=target_prob,
                    b=source_prob,
                    C=dist,
                    reg=self.regularization,
                    maxIter=self.numItermax,
                    log=True
                )
                # beta: dual potential associated with the source (prediction) side.
                beta = log["beta"]
                # Wasserstein cost <C, P> (monitoring only; not differentiated).
                w_dist += (dist * P).sum()
                ot_obj_values += (pred_den_map_normed[idx] * beta.view(1, self.num_blocks_h, self.num_blocks_w)).sum()

                # Surrogate gradient of the OT term w.r.t. the *unnormalized*
                # density (chain rule through the normalization d/dz (z / sum z)),
                # following DM-Count. Everything here is detached; the gradient
                # is injected below via the inner product with pred_den_map.
                source_density = pred_den_map[idx].view(-1).detach()
                source_count = source_density.sum()
                gradient_1 = (source_count) / (source_count * source_count+ EPS) * beta
                gradient_2 = (source_density * beta).sum() / (source_count * source_count + EPS)
                gradient = gradient_1 - gradient_2
                gradient = gradient.detach().view(1, self.num_blocks_h, self.num_blocks_w)

                # <pred, stop_grad(gradient)>: backprop delivers `gradient` to
                # pred_den_map while the loss value itself is a surrogate.
                loss += torch.sum(pred_den_map[idx] * gradient)

        return loss, w_dist, ot_obj_values
|
|
|
|
|
|
|
|
class DMLoss(nn.Module):
    """Distribution Matching (DM-Count style) loss for density-map regression.

    Combines three weighted terms:
      * an optimal-transport loss between the normalized prediction and the
        ground-truth points (see ``OTLoss``),
      * a total-variation (L1) loss between normalized density maps, scaled
        per-sample by the ground-truth count,
      * an L1 count loss between predicted and ground-truth totals.

    Args:
        input_size: Side length (pixels) of the square input image.
        block_size: Side length (pixels) of each density block.
        numItermax: Maximum Sinkhorn iterations for the OT term.
        regularization: Entropic regularization for the OT term.
        weight_ot: Weight of the OT loss term.
        weight_tv: Weight of the TV loss term (term skipped when <= 0).
        weight_cnt: Weight of the count loss term (term skipped when <= 0).
    """

    def __init__(
        self,
        input_size: int,
        block_size: int,
        numItermax: int = 100,
        regularization: float = 10.0,
        weight_ot: float = 0.1,
        weight_tv: float = 0.01,
        weight_cnt: float = 1.0,
    ) -> None:
        super().__init__()
        self.input_size = input_size
        self.block_size = block_size
        # Fix: the original assigned weight_ot / weight_tv twice; assign once.
        self.weight_ot = weight_ot
        self.weight_tv = weight_tv
        self.weight_cnt = weight_cnt

        self.ot_loss = OTLoss(
            input_size=self.input_size,
            block_size=self.block_size,
            numItermax=numItermax,
            regularization=regularization,
        )
        # Per-element TV loss: reduced manually in forward() so each sample can
        # be scaled by its ground-truth count before averaging.
        self.tv_loss = nn.L1Loss(reduction="none")
        self.cnt_loss = nn.L1Loss(reduction="mean")

    # Run in float32 under autocast — the OT term is numerically fragile in
    # half precision.
    @autocast(device_type="cuda", enabled=True, dtype=torch.float32)
    def forward(self, pred_den_map: Tensor, gt_den_map: Tensor, gt_points: List[Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Compute the combined loss.

        Args:
            pred_den_map: Predicted density map, shape ``(B, 1, H', W')`` in
                block resolution.
            gt_den_map: Ground-truth density map; if given at full resolution
                it is reshaped to block resolution via ``_reshape_density``.
            gt_points: Per-sample annotation tensors; ``len(p)`` is the
                ground-truth count for each sample.

        Returns:
            ``(loss, loss_info)`` where ``loss`` is the weighted total and
            ``loss_info`` maps term names to detached scalar tensors
            ("ot_loss", "dm_loss", "w_dist", and optionally "tv_loss" /
            "cnt_loss" when their weights are positive).
        """
        # Bring the ground truth to the prediction's block resolution if needed.
        gt_den_map = _reshape_density(gt_den_map, block_size=self.block_size) if gt_den_map.shape[-2:] != pred_den_map.shape[-2:] else gt_den_map
        assert pred_den_map.shape == gt_den_map.shape, f"Expected pred_den_map and gt_den_map to have the same shape, got {pred_den_map.shape} and {gt_den_map.shape}"

        # Normalize both maps to probability distributions per sample.
        pred_cnt = pred_den_map.view(pred_den_map.shape[0], -1).sum(dim=1)
        pred_den_map_normed = pred_den_map / (pred_cnt.view(-1, 1, 1, 1) + EPS)
        # Fix: build gt_cnt directly on the target device instead of
        # CPU-construct + .to(device).
        gt_cnt = torch.tensor([len(p) for p in gt_points], dtype=torch.float32, device=pred_den_map.device)
        gt_den_map_normed = gt_den_map / (gt_cnt.view(-1, 1, 1, 1) + EPS)

        ot_loss, w_dist, _ = self.ot_loss(pred_den_map, pred_den_map_normed, gt_points)

        # TV between normalized maps, scaled per sample by the true count so
        # crowded images are not under-weighted; skipped entirely when disabled.
        tv_loss = (self.tv_loss(pred_den_map_normed, gt_den_map_normed).sum(dim=(1, 2, 3)) * gt_cnt).mean() if self.weight_tv > 0 else 0

        cnt_loss = self.cnt_loss(pred_cnt, gt_cnt) if self.weight_cnt > 0 else 0

        loss = ot_loss * self.weight_ot + tv_loss * self.weight_tv + cnt_loss * self.weight_cnt

        loss_info = {
            "ot_loss": ot_loss.detach(),
            "dm_loss": loss.detach(),
            "w_dist": w_dist.detach(),
        }
        # Only report terms that were actually computed (otherwise they are
        # plain ints with no .detach()).
        if self.weight_tv > 0:
            loss_info["tv_loss"] = tv_loss.detach()
        if self.weight_cnt > 0:
            loss_info["cnt_loss"] = cnt_loss.detach()

        return loss, loss_info
|
|
|