File size: 523 Bytes
710f946
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import collections.abc
import operator
from typing import Sequence

import torchvision.transforms as transforms


def preprocess(target_input_size: Sequence[int]) -> transforms.Compose:
    """Return the inference transform used by the demo model.

    The pipeline resizes the input to exactly ``(H, W)`` and then converts it
    with ``ToTensor``. Because an explicit ``(H, W)`` pair is passed to
    ``Resize``, the aspect ratio is not preserved.

    Args:
        target_input_size: The model's expected input shape as ``(C, H, W)``.
            Any non-string sequence of length 3 is accepted (list, tuple,
            ``torch.Size``, ...). The channel count ``C`` is ignored here.

    Returns:
        A ``transforms.Compose`` of ``Resize((H, W))`` followed by
        ``ToTensor()``.

    Raises:
        ValueError: If ``target_input_size`` is not a length-3 sequence, or if
            ``H`` or ``W`` is not a positive integer.
    """
    # Accept any Sequence so the check matches the annotation. Reject
    # str/bytes explicitly: they are sequences too and a 3-character string
    # would otherwise slip through the length check.
    if (
        not isinstance(target_input_size, collections.abc.Sequence)
        or isinstance(target_input_size, (str, bytes))
        or len(target_input_size) != 3
    ):
        raise ValueError("target_input_size must be (C, H, W)")

    _, raw_height, raw_width = target_input_size

    dims = []
    for label, value in (("H", raw_height), ("W", raw_width)):
        try:
            # operator.index accepts int-like values (int, numpy integers) but
            # rejects floats. A fractional size therefore fails here, at
            # construction, rather than later when the transform is applied.
            dim = operator.index(value)
        except TypeError:
            raise ValueError(
                f"{label} in target_input_size must be an integer, got {value!r}"
            ) from None
        if dim <= 0:
            raise ValueError(
                f"{label} in target_input_size must be positive, got {dim}"
            )
        dims.append(dim)
    height, width = dims

    return transforms.Compose([
        transforms.Resize((height, width)),
        transforms.ToTensor(),
    ])