File size: 705 Bytes
99ec8a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from PIL import Image
from typing import Dict


def load_image(img_path):
    """Open the image file at *img_path* and return an RGB copy of it.

    The source image is opened inside a ``with`` block so the file handle
    is released deterministically once the converted copy exists. Without
    this, some formats (e.g. multi-frame ones) may keep the file open until
    garbage collection.

    Args:
        img_path: Path to the image file, as accepted by ``PIL.Image.open``.

    Returns:
        A new ``PIL.Image.Image`` in ``'RGB'`` mode that does not depend on
        the now-closed source file.
    """
    with Image.open(img_path) as img:
        # convert() loads the pixel data and returns a new image (a copy even
        # when the mode is already RGB), so the result outlives the `with`.
        return img.convert('RGB')


class UnknownLabelError(KeyError, AssertionError):
    """Raised by ``LabelTransform`` when a label is missing from its mapping.

    Subclasses ``AssertionError`` so existing ``except AssertionError``
    handlers keep working. Also subclasses ``KeyError`` because this is a
    failed lookup, and ``KeyError`` is what the old plain ``dict`` access
    raised when asserts were disabled with ``python -O``.
    """


class LabelTransform:
    """Map labels to values, comparing labels case-insensitively.

    Keys of ``mapping_dict`` are lower-cased once at construction. Each call
    lower-cases the incoming label and looks it up.

    Args:
        mapping_dict: Mapping from string label to the value to return.
            Keys that differ only in case collapse to a single entry; the one
            iterated last wins.
    """

    def __init__(self, mapping_dict: Dict):
        self.mapping_dict = {k.lower(): v for k, v in mapping_dict.items()}
        # Sorted key list, used to list the valid labels in the error message.
        self._keys = sorted(self.mapping_dict)

    def __call__(self, label):
        """Return the value mapped to *label*, compared case-insensitively.

        Raises:
            UnknownLabelError: if the lower-cased label is not a known key.
        """
        key = label.lower()
        # Explicit raise instead of ``assert`` so the check survives ``-O``;
        # the dict lookup is O(1), unlike scanning the sorted key list.
        try:
            return self.mapping_dict[key]
        except KeyError:
            raise UnknownLabelError(
                f'label {key} not in label mapping_dict provided. '
                f'Available keys {self._keys}'
            ) from None