File size: 835 Bytes
36c95ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
"""Module containing utilities for classification."""
import torch
from torch import nn


class ClassificationHead(nn.Module):
    """Classification head that mean-pools token features into class logits.

    The input is averaged over its second-to-last axis (the token/sequence
    axis), layer-normalized over the embedding axis, and projected to class
    logits by a single linear layer.

    Args:
        embed_size: size of the last (embedding) dimension of the input
            features; used as the width of the LayerNorm and the input width
            of the linear projection.
        num_classes: number of output classes, i.e. the size of the last
            dimension of the returned logits.

    Example:
        >>> feat = torch.rand(1, 256, 256)
        >>> head = ClassificationHead(256, 10)
        >>> head(feat).shape
        torch.Size([1, 10])
    """

    def __init__(self, embed_size: int = 768, num_classes: int = 10) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(embed_size)
        self.linear = nn.Linear(embed_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits for ``x``.

        Args:
            x: features of shape ``(..., seq_len, embed_size)``; the
                second-to-last axis is averaged away.

        Returns:
            Logits of shape ``(..., num_classes)``.
        """
        # Mean-pool over the token axis so every position contributes equally;
        # the remaining last axis must match ``embed_size`` for LayerNorm.
        pooled = x.mean(dim=-2)
        return self.linear(self.norm(pooled))