# File size: 834 Bytes
# a9d56ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from torchvision.transforms import v2

# Side length passed to ``v2.Resize`` (as a square ``(size, size)`` target),
# keyed by model nickname.
RESIZE = dict(
    effnet=384,
    resnet=224,
    mbnet=224,
    swin=256,
)


def get_preprocessing(model_type: str) -> v2.Compose:
    """
    Gets the right image preprocessing transform for each model

    Parameters
    ----------
    model_type : str
        Model nickname; must be one of the keys of ``RESIZE``

    Returns
    -------
    v2.Compose
        Preprocessing transform: conversion to an image tensor, resize to the
        model's square input size, conversion to float with value scaling,
        per-channel normalization, then 3-channel grayscale conversion

    Raises
    ------
    NotImplementedError
        If it's an invalid model_type
    """
    try:
        resize = RESIZE[model_type]
    except KeyError:
        # Raise the exception type the docstring promises instead of leaking
        # the dictionary lookup's KeyError to the caller.
        raise NotImplementedError(
            f"Unsupported model_type {model_type!r}; "
            f"expected one of {sorted(RESIZE)}"
        ) from None

    transform = v2.Compose(
        [
            v2.ToImage(),
            v2.Resize((resize, resize)),
            # scale=True rescales values to the float range while converting.
            v2.ToDtype(torch.float, scale=True),
            # Presumably the ImageNet channel statistics -- confirm against
            # the weights these models were trained with.
            v2.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            # NOTE(review): grayscale conversion runs on already-normalized
            # values; the order is kept as-is to preserve existing behavior.
            v2.Grayscale(3),
        ]
    )

    return transform