File size: 383 Bytes
0e8110e
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
"""Shared utilities for segmentation heads."""

import torch
import torch.nn.functional as F


def upsample_and_argmax(logits, target_size):
    """Upsample logits to the target spatial size and return class indices.

    Args:
        logits: Tensor whose dimension 1 holds per-class scores and whose
            dimensions from index 2 onward are spatial. Bilinear
            interpolation requires a 4-D input, so this is expected to be
            laid out as (N, C, H, W) whenever resizing is needed.
        target_size: Desired spatial size. May be any sequence (tuple,
            list, ``torch.Size``) or a single int, which applies to every
            spatial dimension, as accepted by ``F.interpolate``.

    Returns:
        Tensor of indices of the highest-scoring class along dimension 1,
        with that dimension removed.
    """
    current_size = tuple(logits.shape[2:])
    # Normalise the requested size to a tuple before comparing. Comparing
    # ``torch.Size`` (a tuple subclass) directly against a list or an int is
    # always unequal, which would trigger a redundant resampling pass.
    try:
        requested_size = tuple(target_size)
    except TypeError:
        # Scalar size: the same extent applies to every spatial dimension.
        requested_size = (target_size,) * len(current_size)

    if current_size != requested_size:
        logits = F.interpolate(
            logits, size=target_size, mode="bilinear", align_corners=False
        )
    return logits.argmax(dim=1)