| from PIL import Image | |
| from typing import Dict | |
def load_image(img_path) -> "Image.Image":
    """Open the image at ``img_path`` and return it converted to RGB.

    Args:
        img_path: Whatever ``PIL.Image.open`` accepts as its first argument.

    Returns:
        A new image in ``'RGB'`` mode.
    """
    # The context manager closes the source file handle deterministically;
    # ``convert`` returns a separate copy, so the result remains usable
    # after the ``with`` block exits.
    with Image.open(img_path) as src:
        return src.convert('RGB')
class LabelTransform:
    """Callable that maps a label string to a value, ignoring case.

    Attributes are set in ``__init__``:
        mapping_dict: copy of the supplied mapping with lower-cased keys.
        _keys: sorted list of the lower-cased keys (used in error messages).
    """

    def __init__(self, mapping_dict: Dict):
        """Store a lower-cased-key copy of ``mapping_dict``.

        Args:
            mapping_dict: Mapping whose keys support ``.lower()``. Keys that
                differ only by case collapse to a single entry; the one
                iterated last wins.
        """
        self.mapping_dict = {k.lower(): v for k, v in mapping_dict.items()}
        self._keys = sorted(self.mapping_dict)

    def __call__(self, label):
        """Return the value mapped to ``label`` (compared case-insensitively).

        Args:
            label: Label to look up; it is lower-cased before the lookup.

        Returns:
            The value stored for the lower-cased label.

        Raises:
            AssertionError: If the lower-cased label is not a known key.
        """
        label = label.lower()
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if label not in self.mapping_dict:
            raise AssertionError(
                f'label {label} not in label mapping_dict provided.'
                f'Available keys {self._keys}'
            )
        return self.mapping_dict[label]