Spaces:
Runtime error
Runtime error
| from abc import ABC, abstractmethod | |
| from collections.abc import Callable | |
| from functools import wraps | |
| from typing import Any | |
| from numpy.typing import NDArray | |
| import numpy as np | |
| import torch | |
| class Imputer(ABC): | |
| ''' ... ''' | |
| def __init__(self, | |
| modalities: dict[str, dict[str, Any]], | |
| is_embedding: dict[str, bool] | None = None | |
| ) -> None: | |
| ''' ... ''' | |
| self.modalities = modalities | |
| self.is_embedding = is_embedding | |
| def __call__(self, | |
| smp: dict[str, int | NDArray[np.float32] | None], | |
| ) -> dict[str, int | NDArray[np.float32]]: | |
| ''' ... ''' | |
| pass | |
def _keyerror_hint(func):
    ''' Decorate ``func`` so that a KeyError is re-raised with a hint.

    Args:
        func: the callable to wrap.

    Returns:
        A wrapper that forwards all arguments to ``func`` and returns its
        result unchanged.

    Raises:
        ValueError: when ``func`` raises ``KeyError``; the original
            ``KeyError`` is chained as ``__cause__``.
    '''
    @wraps(func)  # preserve the wrapped callable's __name__ / __doc__
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except KeyError as err:
            raise ValueError('Format the data using Formatter module.') from err
    return wrapper
class ConstantImputer(Imputer):
    ''' Impute missing entries with constant zero values.

    Entries of the sample that are not ``None`` are passed through
    unchanged.
    '''

    # Length of the zero vector substituted for a missing embedding.
    EMBEDDING_DIM = 256

    def __call__(self,
        smp: dict[str, int | NDArray[np.float32] | None],
    ) -> dict[str, int | NDArray[np.float32]]:
        ''' Return a copy of ``smp`` with every ``None`` replaced by zeros.

        For each modality in ``self.modalities`` whose value is ``None``:

        * flagged truthy in ``self.is_embedding`` -> float32 zero vector
          of length ``EMBEDDING_DIM``;
        * type ``'categorical'`` -> ``0``;
        * type ``'numerical'`` or ``'imaging'`` -> float32 zeros of shape
          ``info['shape']``.

        Args:
            smp: one sample; must hold a key for every modality.

        Returns:
            A new dict with one entry per modality.

        Raises:
            KeyError: if ``smp`` lacks a modality key.
            ValueError: if a missing modality has an unsupported type.
        '''
        new = {}
        for k, info in self.modalities.items():
            if smp[k] is not None:
                new[k] = smp[k]
            elif self.is_embedding is not None and self.is_embedding.get(k, False):
                # The embedding flag takes precedence over the declared type.
                new[k] = np.zeros(self.EMBEDDING_DIM, dtype=np.float32)
            elif info['type'] == 'categorical':
                new[k] = 0
            elif info['type'] in ('numerical', 'imaging'):
                new[k] = np.zeros(tuple(info['shape']), dtype=np.float32)
            else:
                raise ValueError(
                    f"Unsupported modality type {info['type']!r} "
                    f"for modality {k!r}."
                )
        return new
class FrequencyImputer(Imputer):
    ''' Impute missing numerical entries by sampling observed values.

    At construction, the non-``None`` values of every modality are
    collected from a reference dataset into ``self.pool``.  Missing
    numerical entries are later replaced by a value drawn uniformly at
    random from that pool.
    '''
    def __init__(self,
        modalities: dict[str, dict[str, Any]],
        dat: list[dict[str, int | NDArray[np.float32] | None]],
    ) -> None:
        ''' Build the per-modality pool of observed values.

        Args:
            modalities: per-modality metadata keyed by modality name.
            dat: reference samples; each must hold a key for every
                modality, with ``None`` marking a missing value.
        '''
        super().__init__(modalities)

        # List[Dict] to Dict[List], dropping the missing (None) values.
        self.pool = {
            k: [smp[k] for smp in dat if smp[k] is not None]
            for k in modalities
        }

    def __call__(self,
        smp: dict[str, int | NDArray[np.float32] | None],
    ) -> dict[str, int | NDArray[np.float32]]:
        ''' Return a copy of ``smp`` with every ``None`` entry imputed.

        For each modality in ``self.modalities`` whose value is ``None``:

        * type ``'categorical'`` -> ``0``;
        * type ``'numerical'`` -> a random draw from the pool, or
          ``np.array([0], dtype=np.float32)`` when the pool is empty;
        * type ``'imaging'`` -> float32 zeros of shape ``info['shape']``.

        Args:
            smp: one sample; must hold a key for every modality.

        Returns:
            A new dict with one entry per modality.

        Raises:
            KeyError: if ``smp`` lacks a modality key.
            ValueError: if a missing modality has an unsupported type.
        '''
        new = {}
        for k, info in self.modalities.items():
            if smp[k] is not None:
                new[k] = smp[k]
                continue

            kind = info['type']
            if kind == 'categorical':
                new[k] = 0
            elif kind == 'numerical':
                # np.random.randint(0, 0) raises, so test for an empty
                # (or absent) pool explicitly instead of catching everything.
                candidates = self.pool.get(k)
                if candidates:
                    rnd_idx = np.random.randint(0, len(candidates))
                    new[k] = np.array(candidates[rnd_idx])
                else:
                    new[k] = np.array([0], dtype=np.float32)
            elif kind == 'imaging':
                new[k] = np.zeros(tuple(info['shape']), dtype=np.float32)
            else:
                raise ValueError(
                    f"Unsupported modality type {kind!r} for modality {k!r}."
                )
        return new