File size: 7,063 Bytes
0b69a1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import torch
from torch import nn, Tensor
from torch.amp import autocast
from typing import List, Tuple, Dict

from .bregman_pytorch import sinkhorn
from .utils import _reshape_density

# Small constant added to denominators (per-image density normalization and the
# OT gradient below) so that empty / all-zero maps do not cause division by zero.
EPS = 1e-8


class OTLoss(nn.Module):
    """Optimal-transport (OT) loss between a predicted block density map and point annotations.

    For each image, a regularized OT problem is solved with ``sinkhorn`` between a
    uniform distribution over the ground-truth points and the normalized predicted
    density over a ``num_blocks_h x num_blocks_w`` grid of block centres. The dual
    variable ``beta`` returned by the solver is turned into a detached gradient for
    the unnormalized density, and the loss is defined as ``<gradient, density>`` so
    that autograd reproduces exactly that gradient.
    """

    def __init__(
        self,
        input_size: int,
        block_size: int,
        numItermax: int = 100,
        regularization: float = 10.0
    ) -> None:
        """
        Args:
            input_size: side length of the (square) input in image-space units;
                must be divisible by ``block_size``.
            block_size: side length of one density-map block in the same units.
            numItermax: maximum number of Sinkhorn iterations (``maxIter``).
            regularization: regularization strength passed to ``sinkhorn`` as ``reg``.
        """
        super().__init__()
        assert input_size % block_size == 0, (
            f"input_size ({input_size}) must be divisible by block_size ({block_size})"
        )

        self.input_size = input_size
        self.block_size = block_size
        self.num_blocks_h = input_size // block_size
        self.num_blocks_w = input_size // block_size
        self.numItermax = numItermax
        self.regularization = regularization

        # Block-centre coordinates in image space. They are constant because the crop
        # size is fixed. Kept as plain float32 tensors (not buffers) so module-level
        # dtype casts such as ``.half()`` do not affect them.
        self.coords_h = torch.arange(0, input_size, step=block_size, dtype=torch.float32) + block_size / 2
        self.coords_w = torch.arange(0, input_size, step=block_size, dtype=torch.float32) + block_size / 2
        self.coords_h, self.coords_w = self.coords_h.unsqueeze(0), self.coords_w.unsqueeze(0)  # [1, #coordinates]

    def set_numItermax(self, numItermax: int) -> None:
        """Change the maximum number of Sinkhorn iterations used by ``forward``."""
        self.numItermax = numItermax

    @autocast(device_type="cuda", enabled=True, dtype=torch.float32)  # avoid numerical instability
    def forward(self, pred_den_map: Tensor, pred_den_map_normed: Tensor, gt_points: List[Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
        """Compute the OT loss for a batch.

        Args:
            pred_den_map: predicted density of shape ``(B, 1, num_blocks_h, num_blocks_w)``.
            pred_den_map_normed: same shape. Presumably ``pred_den_map`` divided by its
                per-image sum (this is how ``DMLoss`` builds it).
            gt_points: ``B`` tensors of shape ``(#gt, >=2)``. Column 0 is matched against
                the width coordinates (x) and column 1 against the height coordinates (y).
                Points may live on any device and use any numeric dtype; they are moved
                to ``pred_den_map``'s device and cast to float32 here.

        Returns:
            ``(loss, w_dist, ot_obj_values)``, each a 1-element tensor summed over images
            that have at least one point:

            * ``loss``: surrogate whose gradient w.r.t. ``pred_den_map`` is the OT gradient.
            * ``w_dist``: transport cost ``sum(dist * P)``.
            * ``ot_obj_values``: ``<pred_den_map_normed, beta>``.
        """
        assert pred_den_map.shape[1:] == pred_den_map_normed.shape[1:] == (1, self.num_blocks_h, self.num_blocks_w), f"Expected pred_den_map to have shape (B, 1, {self.num_blocks_h}, {self.num_blocks_w}), but got {pred_den_map.shape} and {pred_den_map_normed.shape}"
        assert len(gt_points) == pred_den_map.shape[0] == pred_den_map_normed.shape[0], f"Expected gt_points to have length {pred_den_map_normed.shape[0]}, but got {len(gt_points)}"
        device = pred_den_map.device
        grid_shape = (1, self.num_blocks_h, self.num_blocks_w)

        loss = torch.zeros(1, device=device)
        ot_obj_values = torch.zeros(1, device=device)
        w_dist = torch.zeros(1, device=device)  # Wasserstein distance (transport cost)
        coords_h, coords_w = self.coords_h.to(device), self.coords_w.to(device)  # [1, #coordinates]

        for idx, points in enumerate(gt_points):
            num_points = len(points)
            if num_points == 0:
                # Nothing to transport to: this image contributes zero to every output.
                continue
            # The matmul below needs points on the grid's device and dtype. Converting
            # here accepts e.g. CPU float64 points coming straight from NumPy.
            points = points.to(device=device, dtype=coords_w.dtype)

            # Squared L2 distance from every point to every block centre, expanded per
            # axis as (a - b)^2 = a^2 - 2ab + b^2.
            x, y = points[:, 0].unsqueeze(1), points[:, 1].unsqueeze(1)  # [#gt, 1]
            x_dist = -2 * torch.matmul(x, coords_w) + x * x + coords_w * coords_w  # [#gt, #coordinates_w]
            y_dist = -2 * torch.matmul(y, coords_h) + y * y + coords_h * coords_h  # [#gt, #coordinates_h]
            dist = x_dist.unsqueeze(1) + y_dist.unsqueeze(2)  # [#gt, H, W]
            dist = dist.view(num_points, -1)  # [#gt, H * W], row-major like the density maps

            source_prob = pred_den_map_normed[idx].reshape(-1).detach()
            target_prob = torch.ones(num_points, device=device) / num_points
            # Solve the OT problem with Sinkhorn to obtain the dual variable beta.
            P, log = sinkhorn(
                a=target_prob,
                b=source_prob,
                C=dist,
                reg=self.regularization,
                maxIter=self.numItermax,
                log=True
            )
            beta = log["beta"]  # same size as source_prob: [H * W]
            w_dist += (dist * P).sum()
            ot_obj_values += (pred_den_map_normed[idx] * beta.reshape(grid_shape)).sum()

            # Gradient of <beta, density / count> w.r.t. the unnormalized density:
            #   grad = beta / count - <beta, density> / count^2
            # EPS keeps the denominators finite when the predicted count is zero.
            source_density = pred_den_map[idx].reshape(-1).detach()
            source_count = source_density.sum()
            denom = source_count * source_count + EPS
            gradient_1 = source_count / denom * beta  # [H * W]
            gradient_2 = (source_density * beta).sum() / denom  # scalar
            gradient = (gradient_1 - gradient_2).detach().view(grid_shape)
            # loss = <gradient, density>, so d(loss)/d(density) == gradient.
            loss += torch.sum(pred_den_map[idx] * gradient)

        return loss, w_dist, ot_obj_values


class DMLoss(nn.Module):
    """Weighted sum of an OT loss, an L1 ``tv`` term and an L1 count loss on a density map.

    ``loss = weight_ot * ot_loss + weight_tv * tv_loss + weight_cnt * cnt_loss``.

    The OT term is always computed, even when ``weight_ot == 0``, because its value
    and ``w_dist`` are always reported in the returned info dict. The ``tv`` and
    count terms are skipped (treated as 0) when their weight is not positive.
    """

    def __init__(
        self,
        input_size: int,
        block_size: int,
        numItermax: int = 100,
        regularization: float = 10.0,
        weight_ot: float = 0.1,
        weight_tv: float = 0.01,
        weight_cnt: float = 1.0,
    ) -> None:
        """
        Args:
            input_size: forwarded to ``OTLoss``; must be divisible by ``block_size``.
            block_size: forwarded to ``OTLoss``; also used to reshape a ground-truth
                map whose spatial size differs from the prediction.
            numItermax: maximum Sinkhorn iterations for the OT term.
            regularization: Sinkhorn regularization for the OT term.
            weight_ot: weight of the OT term.
            weight_tv: weight of the ``tv`` term; ``<= 0`` disables it.
            weight_cnt: weight of the count term; ``<= 0`` disables it.
        """
        super().__init__()
        self.input_size = input_size
        self.block_size = block_size
        self.weight_ot = weight_ot
        self.weight_tv = weight_tv
        self.weight_cnt = weight_cnt

        self.ot_loss = OTLoss(
            input_size=self.input_size,
            block_size=self.block_size,
            numItermax=numItermax,
            regularization=regularization,
        )
        self.tv_loss = nn.L1Loss(reduction="none")
        self.cnt_loss = nn.L1Loss(reduction="mean")

    @autocast(device_type="cuda", enabled=True, dtype=torch.float32)  # avoid numerical instability
    def forward(self, pred_den_map: Tensor, gt_den_map: Tensor, gt_points: List[Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Compute the combined loss.

        Args:
            pred_den_map: predicted density, shape ``(B, 1, H, W)`` (enforced by ``OTLoss``).
            gt_den_map: ground-truth density. If its spatial size differs from the
                prediction, it is passed through ``_reshape_density`` (presumably
                aggregating ``block_size`` blocks — TODO confirm) to match.
            gt_points: ``B`` point tensors; ``len(p)`` is used as the true count.

        Returns:
            ``(loss, loss_info)``. ``loss_info`` holds detached ``"ot_loss"``,
            ``"dm_loss"`` and ``"w_dist"``, plus ``"tv_loss"`` / ``"cnt_loss"`` when
            the respective weight is positive.
        """
        if gt_den_map.shape[-2:] != pred_den_map.shape[-2:]:
            gt_den_map = _reshape_density(gt_den_map, block_size=self.ot_loss.block_size)
        assert pred_den_map.shape == gt_den_map.shape, f"Expected pred_den_map and gt_den_map to have the same shape, got {pred_den_map.shape} and {gt_den_map.shape}"
        device = pred_den_map.device

        # Per-image counts; flatten(1) (unlike view) also works for non-contiguous input.
        pred_cnt = pred_den_map.flatten(1).sum(dim=1)  # [B]
        pred_den_map_normed = pred_den_map / (pred_cnt.view(-1, 1, 1, 1) + EPS)
        gt_cnt = torch.tensor([len(p) for p in gt_points], dtype=torch.float32, device=device)  # [B]
        gt_den_map_normed = gt_den_map / (gt_cnt.view(-1, 1, 1, 1) + EPS)

        ot_loss, w_dist, _ = self.ot_loss(pred_den_map, pred_den_map_normed, gt_points)

        tv_loss = 0
        if self.weight_tv > 0:
            # L1 distance between the normalized maps, summed per image, scaled by
            # the true count, then averaged over the batch.
            per_image_tv = self.tv_loss(pred_den_map_normed, gt_den_map_normed).sum(dim=(1, 2, 3))
            tv_loss = (per_image_tv * gt_cnt).mean()

        cnt_loss = self.cnt_loss(pred_cnt, gt_cnt) if self.weight_cnt > 0 else 0

        loss = ot_loss * self.weight_ot + tv_loss * self.weight_tv + cnt_loss * self.weight_cnt

        loss_info = {
            "ot_loss": ot_loss.detach(),
            "dm_loss": loss.detach(),
            "w_dist": w_dist.detach(),
        }
        if self.weight_tv > 0:
            loss_info["tv_loss"] = tv_loss.detach()
        if self.weight_cnt > 0:
            loss_info["cnt_loss"] = cnt_loss.detach()

        return loss, loss_info