| """Shared utilities for segmentation heads.""" | |
| import torch | |
| import torch.nn.functional as F | |
def upsample_and_argmax(logits, target_size):
    """Upsample logits to a target spatial size and return per-pixel class indices.

    Args:
        logits: Score tensor whose class/channel axis is dim 1 (the axis
            reduced by ``argmax``). Resizing uses ``mode="bilinear"``, which
            PyTorch only supports for 4-D ``(N, C, H, W)`` input.
        target_size: Desired spatial size. This may be a sequence such as a
            tuple, list or ``torch.Size``, or a single int. An int applies to
            every spatial dimension, mirroring ``F.interpolate``.

    Returns:
        Tensor of class indices with the class axis removed. It has shape
        ``(N, *target_size)`` when a resize happens; otherwise it is
        ``logits.argmax(dim=1)`` unchanged.
    """
    spatial = tuple(logits.shape[2:])
    # Normalize target_size to a tuple before comparing. A torch.Size never
    # compares equal to a list or an int, so comparing the raw argument would
    # always trigger a redundant interpolation. That interpolation also fails
    # outright for integer-typed logits, where bilinear mode is unimplemented.
    if isinstance(target_size, int):
        size = (target_size,) * len(spatial)
    else:
        size = tuple(target_size)
    if spatial != size:
        logits = F.interpolate(logits, size=size, mode="bilinear", align_corners=False)
    return logits.argmax(dim=1)