"""Shared utilities for segmentation heads."""

import torch
import torch.nn.functional as F


def upsample_and_argmax(logits, target_size):
    """Resize ``logits`` to ``target_size`` and return per-pixel class indices.

    Args:
        logits: Score tensor whose class axis is dim 1 and whose spatial axes
            are ``logits.shape[2:]``. Bilinear interpolation only accepts 4-D
            input, so this is expected to be ``(N, C, H, W)``.
        target_size: Desired spatial size, either as a sequence of two ints
            (tuple, list, ``torch.Size``) or as a single int for a square
            output.

    Returns:
        Tensor of class indices from ``argmax`` over dim 1. When ``logits`` is
        ``(N, C, H, W)``, the result is ``(N, H_out, W_out)``.
    """
    # Normalize before comparing. A torch.Size (a tuple subclass) never
    # compares equal to a list or an int. Without this, the "already the
    # right size" shortcut was skipped whenever callers passed e.g. [H, W].
    # An int becomes a square size, which matches F.interpolate's own
    # handling of an int `size`.
    if isinstance(target_size, int):
        size = (target_size, target_size)
    else:
        size = tuple(int(dim) for dim in target_size)

    if tuple(logits.shape[2:]) != size:
        logits = F.interpolate(logits, size=size, mode="bilinear", align_corners=False)
    return logits.argmax(dim=1)