Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn, Tensor | |
| from typing import Any, List, Tuple, Dict | |
| from .dm_loss import DMLoss | |
| from .utils import _reshape_density | |
| class DACELoss(nn.Module): | |
| def __init__( | |
| self, | |
| bins: List[Tuple[float, float]], | |
| reduction: int, | |
| weight_count_loss: float = 1.0, | |
| count_loss: str = "mae", | |
| **kwargs: Any | |
| ) -> None: | |
| super().__init__() | |
| assert len(bins) > 0, f"Expected at least one bin, got {bins}" | |
| assert all([len(b) == 2 for b in bins]), f"Expected all bins to be of length 2, got {bins}" | |
| assert all([b[0] <= b[1] for b in bins]), f"Expected all bins to be in increasing order, got {bins}" | |
| self.bins = bins | |
| self.reduction = reduction | |
| self.cross_entropy_fn = nn.CrossEntropyLoss(reduction="none") | |
| count_loss = count_loss.lower() | |
| assert count_loss in ["mae", "mse", "dmcount"], f"Expected count_loss to be one of ['mae', 'mse', 'dmcount'], got {count_loss}" | |
| self.count_loss = count_loss | |
| if self.count_loss == "mae": | |
| self.use_dm_loss = False | |
| self.count_loss_fn = nn.L1Loss(reduction="none") | |
| elif self.count_loss == "mse": | |
| self.use_dm_loss = False | |
| self.count_loss_fn = nn.MSELoss(reduction="none") | |
| else: | |
| self.use_dm_loss = True | |
| assert "input_size" in kwargs, f"Expected input_size to be in kwargs when count_loss='dmcount', got {kwargs}" | |
| self.count_loss_fn = DMLoss(reduction=reduction, **kwargs) | |
| self.weight_count_loss = weight_count_loss | |
| def _bin_count(self, density_map: Tensor) -> Tensor: | |
| class_map = torch.zeros_like(density_map, dtype=torch.long) | |
| for idx, (low, high) in enumerate(self.bins): | |
| mask = (density_map >= low) & (density_map <= high) | |
| class_map[mask] = idx | |
| return class_map.squeeze(1) # remove channel dimension | |
| def forward(self, pred_class: Tensor, pred_density: Tensor, target_density: Tensor, target_points: List[Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]: | |
| target_density = _reshape_density(target_density, reduction=self.reduction) if target_density.shape[-2:] != pred_density.shape[-2:] else target_density | |
| assert pred_density.shape == target_density.shape, f"Expected pred_density and target_density to have the same shape, got {pred_density.shape} and {target_density.shape}" | |
| target_class = self._bin_count(target_density) | |
| cross_entropy_loss = self.cross_entropy_fn(pred_class, target_class).sum(dim=(-1, -2)).mean() | |
| if self.use_dm_loss: | |
| count_loss, loss_info = self.count_loss_fn(pred_density, target_density, target_points) | |
| loss_info["ce_loss"] = cross_entropy_loss.detach() | |
| else: | |
| count_loss = self.count_loss_fn(pred_density, target_density).sum(dim=(-1, -2, -3)).mean() | |
| loss_info = { | |
| "ce_loss": cross_entropy_loss.detach(), | |
| f"{self.count_loss}_loss": count_loss.detach(), | |
| } | |
| loss = cross_entropy_loss + self.weight_count_loss * count_loss | |
| loss_info["loss"] = loss.detach() | |
| return loss, loss_info | |