phanerozoic's picture
8 segmentation head candidates with shared losses/utils and registry
0e8110e verified
"""Shared utilities for segmentation heads."""
import torch
import torch.nn.functional as F
def upsample_and_argmax(logits, target_size):
    """Upsample logits to a target spatial size and return per-pixel class indices.

    Args:
        logits: Tensor whose dim 1 holds class scores and whose trailing dims
            (``logits.shape[2:]``) are spatial. Bilinear interpolation as used
            here requires a 4-D ``(N, C, H, W)`` input.
        target_size: Desired spatial size, either an ``int`` (same size for
            every spatial dim, as ``F.interpolate`` interprets it) or a
            sequence such as a tuple, list or ``torch.Size`` of ``(H, W)``.

    Returns:
        Tensor of class indices obtained by ``argmax`` over dim 1, with the
        spatial dims equal to ``target_size``.
    """
    try:
        size = tuple(target_size)
    except TypeError:
        # Scalar size: expand to one entry per spatial dim so the comparison
        # below works and interpolate still receives an equivalent size.
        size = (target_size,) * (logits.dim() - 2)
    # Compare as plain tuples. A torch.Size never equals a list or an int, so
    # the raw comparison would trigger a redundant interpolation pass.
    if tuple(logits.shape[2:]) != size:
        logits = F.interpolate(logits, size=size, mode="bilinear", align_corners=False)
    return logits.argmax(dim=1)