import torch
from torch import nn, Tensor
from torch.amp import autocast
from typing import List, Tuple, Dict
from .bregman_pytorch import sinkhorn
from .utils import _reshape_density
EPS = 1e-8
class OTLoss(nn.Module):
    """Optimal-transport loss between a predicted density map and GT points.

    Implements the OT term of DM-Count: an entropic-regularized transport
    problem between the normalized predicted density (source) and a uniform
    distribution over the ground-truth points (target) is solved with
    Sinkhorn, and the returned loss is constructed so that its gradient
    w.r.t. the raw predicted density equals the OT gradient.
    """

    def __init__(
        self,
        input_size: int,
        block_size: int,
        numItermax: int = 100,
        regularization: float = 10.0
    ) -> None:
        """
        Args:
            input_size: side length of the (square) input crop in pixels.
            block_size: side length of one density-map block in pixels;
                must evenly divide ``input_size``.
            numItermax: maximum number of Sinkhorn iterations.
            regularization: entropic regularization strength for Sinkhorn.
        """
        super().__init__()
        assert input_size % block_size == 0
        self.input_size = input_size
        self.block_size = block_size
        self.num_blocks_h = input_size // block_size
        self.num_blocks_w = input_size // block_size
        self.numItermax = numItermax
        self.regularization = regularization
        # Block-center coordinates in image space; constant because the crop
        # size is fixed. Shape [1, #coordinates] for broadcasting in forward().
        self.coords_h = torch.arange(0, input_size, step=block_size, dtype=torch.float32) + block_size / 2
        self.coords_w = torch.arange(0, input_size, step=block_size, dtype=torch.float32) + block_size / 2
        self.coords_h, self.coords_w = self.coords_h.unsqueeze(0), self.coords_w.unsqueeze(0)  # [1, #coordinates]

    def set_numItermax(self, numItermax: int) -> None:
        """Override the maximum number of Sinkhorn iterations."""
        self.numItermax = numItermax

    @autocast(device_type="cuda", enabled=True, dtype=torch.float32)  # avoid numerical instability
    def forward(self, pred_den_map: Tensor, pred_den_map_normed: Tensor, gt_points: List[Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute the OT loss for a batch.

        Args:
            pred_den_map: raw predicted density, [B, 1, num_blocks_h, num_blocks_w].
            pred_den_map_normed: the same map normalized to sum to 1 per image.
            gt_points: per-image GT point coordinates, each [#gt, 2] as (x, y).

        Returns:
            (loss, wasserstein_distance, ot_objective_value), each a 1-element tensor.
        """
        assert pred_den_map.shape[1:] == pred_den_map_normed.shape[1:] == (1, self.num_blocks_h, self.num_blocks_w), f"Expected pred_den_map to have shape (B, 1, {self.num_blocks_h}, {self.num_blocks_w}), but got {pred_den_map.shape} and {pred_den_map_normed.shape}"
        assert len(gt_points) == pred_den_map.shape[0] == pred_den_map_normed.shape[0], f"Expected gt_points to have length {pred_den_map_normed.shape[0]}, but got {len(gt_points)}"
        device = pred_den_map.device
        loss = torch.zeros(1, device=device)
        ot_obj_values = torch.zeros(1, device=device)
        w_dist = torch.zeros(1, device=device)  # Wasserstein distance
        coords_h, coords_w = self.coords_h.to(device), self.coords_w.to(device)  # [1, #coordinates]
        for idx, points in enumerate(gt_points):
            if len(points) == 0:
                continue  # no annotated points: this sample contributes nothing
            # Squared L2 distance between each GT point and each block center,
            # expanded as -2ab + a^2 + b^2. Result: [#gt, #coordinates] per axis.
            x, y = points[:, 0].unsqueeze(1), points[:, 1].unsqueeze(1)  # [#gt, 1]
            x_dist = -2 * torch.matmul(x, coords_w) + x * x + coords_w * coords_w  # [#gt, #coordinates]
            y_dist = -2 * torch.matmul(y, coords_h) + y * y + coords_h * coords_h  # [#gt, #coordinates]
            dist = x_dist.unsqueeze(1) + y_dist.unsqueeze(2)
            dist = dist.view((dist.shape[0], -1))  # [#gt, #coordinates * #coordinates]
            source_prob = pred_den_map_normed[idx].view(-1).detach()
            # Uniform target distribution over the GT points, created directly
            # on `device` (avoids a CPU allocation followed by a transfer).
            target_prob = torch.full((len(points),), 1.0 / len(points), device=device)
            # Solve the entropic OT problem; `beta` is the source dual potential.
            P, log = sinkhorn(
                a=target_prob,
                b=source_prob,
                C=dist,
                reg=self.regularization,
                maxIter=self.numItermax,
                log=True
            )
            beta = log["beta"]  # same size as source_prob: [#coordinates * #coordinates]
            w_dist += (dist * P).sum()
            ot_obj_values += (pred_den_map_normed[idx] * beta.view(1, self.num_blocks_h, self.num_blocks_w)).sum()
            # Gradient of the OT loss w.r.t. the raw (unnormalized) density:
            #   im_grad = beta / source_count - <beta, source_density> / source_count^2
            source_density = pred_den_map[idx].view(-1).detach()
            source_count = source_density.sum()
            gradient_1 = source_count / (source_count * source_count + EPS) * beta  # [#coordinates * #coordinates]
            gradient_2 = (source_density * beta).sum() / (source_count * source_count + EPS)  # scalar
            gradient = gradient_1 - gradient_2
            gradient = gradient.detach().view(1, self.num_blocks_h, self.num_blocks_w)
            # loss = <im_grad, predicted density>: its gradient w.r.t. the
            # predicted density is exactly im_grad.
            loss += torch.sum(pred_den_map[idx] * gradient)
        return loss, w_dist, ot_obj_values
class DMLoss(nn.Module):
    """DM-Count loss: weighted sum of OT, total-variation, and count terms."""

    def __init__(
        self,
        input_size: int,
        block_size: int,
        numItermax: int = 100,
        regularization: float = 10.0,
        weight_ot: float = 0.1,
        weight_tv: float = 0.01,
        weight_cnt: float = 1.0,
    ) -> None:
        """
        Args:
            input_size: side length of the (square) input crop in pixels.
            block_size: side length of one density-map block in pixels.
            numItermax: maximum Sinkhorn iterations for the OT term.
            regularization: entropic regularization strength for Sinkhorn.
            weight_ot: weight of the OT loss term.
            weight_tv: weight of the total-variation term; <= 0 disables it.
            weight_cnt: weight of the count term; <= 0 disables it.
        """
        super().__init__()
        self.input_size = input_size
        self.block_size = block_size
        # Loss-term weights (assigned once; the original duplicated
        # weight_ot/weight_tv assignments, which was redundant).
        self.weight_ot = weight_ot
        self.weight_tv = weight_tv
        self.weight_cnt = weight_cnt
        self.ot_loss = OTLoss(
            input_size=self.input_size,
            block_size=self.block_size,
            numItermax=numItermax,
            regularization=regularization,
        )
        self.tv_loss = nn.L1Loss(reduction="none")
        self.cnt_loss = nn.L1Loss(reduction="mean")

    @autocast(device_type="cuda", enabled=True, dtype=torch.float32)  # avoid numerical instability
    def forward(self, pred_den_map: Tensor, gt_den_map: Tensor, gt_points: List[Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Compute the combined DM loss.

        Args:
            pred_den_map: predicted density map, [B, 1, H, W].
            gt_den_map: GT density map; reshaped to the prediction's
                resolution if its spatial size differs.
            gt_points: per-image GT point coordinates, each [#gt, 2].

        Returns:
            (total_loss, loss_info) where loss_info holds detached components.
        """
        # Downsample the GT density to the prediction's block resolution if needed.
        gt_den_map = _reshape_density(gt_den_map, block_size=self.ot_loss.block_size) if gt_den_map.shape[-2:] != pred_den_map.shape[-2:] else gt_den_map
        assert pred_den_map.shape == gt_den_map.shape, f"Expected pred_den_map and gt_den_map to have the same shape, got {pred_den_map.shape} and {gt_den_map.shape}"
        pred_cnt = pred_den_map.view(pred_den_map.shape[0], -1).sum(dim=1)  # per-image predicted count
        pred_den_map_normed = pred_den_map / (pred_cnt.view(-1, 1, 1, 1) + EPS)
        # Build gt_cnt directly on the prediction's device (no CPU round-trip).
        gt_cnt = torch.tensor([len(p) for p in gt_points], dtype=torch.float32, device=pred_den_map.device)
        gt_den_map_normed = gt_den_map / (gt_cnt.view(-1, 1, 1, 1) + EPS)
        ot_loss, w_dist, _ = self.ot_loss(pred_den_map, pred_den_map_normed, gt_points)
        # TV term compares the normalized maps, re-weighted by the GT count.
        tv_loss = (self.tv_loss(pred_den_map_normed, gt_den_map_normed).sum(dim=(1, 2, 3)) * gt_cnt).mean() if self.weight_tv > 0 else 0
        cnt_loss = self.cnt_loss(pred_cnt, gt_cnt) if self.weight_cnt > 0 else 0
        loss = ot_loss * self.weight_ot + tv_loss * self.weight_tv + cnt_loss * self.weight_cnt
        loss_info = {
            "ot_loss": ot_loss.detach(),
            "dm_loss": loss.detach(),
            "w_dist": w_dist.detach(),
        }
        # Disabled terms are plain ints (0) and carry no .detach(); only report
        # the terms that were actually computed.
        if self.weight_tv > 0:
            loss_info["tv_loss"] = tv_loss.detach()
        if self.weight_cnt > 0:
            loss_info["cnt_loss"] = cnt_loss.detach()
        return loss, loss_info