import math
from typing import Dict, List, Optional, Sequence, Tuple

from torch import Tensor, nn
|
|
|
|
class MultiscaleMAE(nn.Module):
    """Weighted sum of absolute-error (L1) losses computed at several spatial scales.

    For every scale ``s`` both input maps are average-pooled with a
    non-overlapping ``s x s`` window (no pooling for ``s == 1``). The absolute
    error is summed over the last two (spatial) dimensions and averaged over
    the remaining ones. The per-scale losses are then combined with per-scale
    weights.

    Args:
        scales: Distinct positive integer pooling factors (list or tuple).
            They are stored in ascending order in ``self.scales``.
        min_scale_weight: Lower bound of the automatically computed weights.
        max_scale_weight: Weight given to the largest scale when weights are
            computed automatically. Must be ``>= min_scale_weight >= 0``.
        alpha: Decay factor in the open interval (0, 1). When weights are
            computed automatically, the weight of scale ``s`` is
            ``min + (max - min) * alpha ** log2(largest_scale / s)``.
        weights: Optional explicit per-scale weights, aligned with ``scales``
            *as passed* (i.e. before sorting). When given, they replace the
            automatic weights and are reordered together with ``scales``.

    Raises:
        AssertionError: If any of the argument constraints above is violated.
    """

    def __init__(
        self,
        scales: Sequence[int] = (1, 2, 4),
        min_scale_weight: float = 0.0,
        max_scale_weight: float = 1.0,
        alpha: float = 0.5,
        weights: Optional[Sequence[float]] = None,
    ) -> None:
        super().__init__()
        assert isinstance(scales, (list, tuple)) and len(scales) > 0 and all(isinstance(s, int) and s > 0 for s in scales), f"Expected scales to be a list of positive integers, got {scales}"
        # Per-scale state is stored under scale-specific attribute names, so a
        # repeated scale would overwrite its own weight and be counted twice.
        assert len(set(scales)) == len(scales), f"Expected scales to be unique, got {scales}"
        assert max_scale_weight >= min_scale_weight >= 0, f"Expected max_scale_weight to be greater than or equal to min_scale_weight, got {min_scale_weight} and {max_scale_weight}"
        assert 1 > alpha > 0, f"Expected alpha to be between 0 and 1, got {alpha}"
        self.min_scale_weight, self.max_scale_weight = min_scale_weight, max_scale_weight

        if weights is None:
            scales = sorted(scales)
            # The largest scale gets max_scale_weight. Each halving of the scale
            # shrinks the offset above min_scale_weight by a factor of alpha.
            weights = [
                min_scale_weight + (max_scale_weight - min_scale_weight) * alpha ** (math.log2(scales[-1] / s))
                for s in scales
            ]
        else:
            assert len(scales) == len(weights), f"Expected scales and weights to have the same length, got {len(scales)} and {len(weights)}"
            # Sort scales and user-supplied weights together so each weight
            # stays attached to the scale it was given for.
            pairs = sorted(zip(scales, weights), key=lambda pair: pair[0])
            scales = [s for s, _ in pairs]
            weights = [w for _, w in pairs]

        self.scales, self.weights = scales, weights

        for scale, weight in zip(scales, weights):
            # Pooling with a 1x1 window would be a no-op, so use Identity instead.
            setattr(self, f"pool_{scale}", nn.AvgPool2d(kernel_size=scale, stride=scale) if scale > 1 else nn.Identity())
            setattr(self, f"weight_{scale}", weight)
            setattr(self, f"mae_loss_fn_{scale}", nn.L1Loss(reduction="none"))

    def forward(
        self,
        pred_den_map: Tensor,
        gt_den_map: Tensor,
    ) -> Tuple[Tensor, Dict]:
        """Compute the weighted multi-scale L1 loss.

        Args:
            pred_den_map: Predicted map, a 4-D tensor with exactly one channel
                at dimension 1 and the same shape as ``gt_den_map``.
            gt_den_map: Target map with the same shape as ``pred_den_map``.

        Returns:
            A tuple ``(loss, loss_info)``. ``loss`` is the weighted sum of the
            per-scale losses. ``loss_info`` maps ``"mae_loss_{scale}"`` to the
            detached, unweighted loss of each scale.

        Notes:
            Average (not sum) pooling is used, so at scale ``s`` each pooled cell
            holds the window mean. With ``AvgPool2d``'s default
            ``ceil_mode=False``, trailing rows/columns that do not fill a whole
            ``s x s`` window are ignored.
        """
        assert len(pred_den_map.shape) == 4, f"Expected pred_den_map to have 4 dimensions, got {len(pred_den_map.shape)}"
        assert len(gt_den_map.shape) == 4, f"Expected gt_den_map to have 4 dimensions, got {len(gt_den_map.shape)}"
        assert pred_den_map.shape[1] == gt_den_map.shape[1] == 1, f"Expected pred_den_map and gt_den_map to have 1 channel, got {pred_den_map.shape[1]} and {gt_den_map.shape[1]}"
        assert pred_den_map.shape == gt_den_map.shape, f"Expected pred_den_map and gt_den_map to have the same shape, got {pred_den_map.shape} and {gt_den_map.shape}"

        loss, loss_info = 0, {}
        for scale in self.scales:
            pool = getattr(self, f"pool_{scale}")
            weight = getattr(self, f"weight_{scale}")
            loss_fn = getattr(self, f"mae_loss_fn_{scale}")

            pred_den_map_pool = pool(pred_den_map)
            gt_den_map_pool = pool(gt_den_map)

            # Sum the elementwise absolute error over the two spatial dims, then
            # average over the remaining (batch, channel) entries.
            mae_loss_scale = loss_fn(pred_den_map_pool, gt_den_map_pool).sum(dim=(-1, -2)).mean()
            loss += weight * mae_loss_scale
            loss_info[f"mae_loss_{scale}"] = mae_loss_scale.detach()

        return loss, loss_info
|
|