| | |
| | |
| |
|
| | from pathlib import Path |
| | import copy |
| | from abc import ABC, abstractmethod |
| | import torch |
| | import torch.nn as nn |
| | from .component import create_net |
| | from .logger import BaseLogger |
| | from lib import ParamSet |
| | from typing import List, Dict, Tuple, Union |
| |
|
| | |
| | |
# Alias for nested label dictionaries:
# outer str key -> inner str key -> integer or float tensor.
LabelDict = Dict[
    str,
    Dict[str, Union[torch.IntTensor, torch.FloatTensor]]
]


# Module-level logger obtained from the project's BaseLogger.
logger = BaseLogger.get_logger(__name__)
| |
|
| |
|
class BaseModel(ABC):
    """
    Abstract base class for models.

    It builds the network with create_net() from the given parameters,
    places it on the configured device, and provides helpers to switch
    between training and evaluation mode and to keep, save and load weights.
    Subclasses must implement set_data().
    """
    def __init__(self, params: ParamSet) -> None:
        """
        Build the network and move it to the device given by params.

        Args:
            params (ParamSet): parameters
        """
        self.params = params
        self.device = self.params.device

        self.network = create_net(
            mlp=self.params.mlp,
            net=self.params.net,
            num_outputs_for_label=self.params.num_outputs_for_label,
            mlp_num_inputs=self.params.mlp_num_inputs,
            in_channel=self.params.in_channel,
            vit_image_size=self.params.vit_image_size,
            pretrained=self.params.pretrained
        )
        self.network.to(self.device)

        # Weight snapshot kept by store_weight() and the epoch it was taken at.
        self.acting_best_weight = None
        self.acting_best_epoch = None

    def train(self) -> None:
        """
        Make network training mode.
        """
        self.network.train()

    def eval(self) -> None:
        """
        Make network evaluation mode.
        """
        self.network.eval()

    @abstractmethod
    def set_data(
                self,
                data: Dict
                ) -> Tuple[
                        Dict[str, torch.FloatTensor],
                        Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
                    ]:
        """
        Unpack data into the input of the model and the data for calculating loss.

        Must be implemented by subclasses.

        Args:
            data (Dict): dictionary of data

        Returns:
            Tuple[
                Dict[str, torch.FloatTensor],
                Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
                ]: input of model and data for calculating loss.
        """
        raise NotImplementedError

    def store_weight(self, at_epoch: int = None) -> None:
        """
        Store a CPU copy of the current weight and the epoch number.

        The snapshot is independent of the live network, so further training
        does not modify it. When the network is wrapped (i.e. it exposes
        a `module` attribute, as nn.DataParallel does), the weight of the
        wrapped module is stored so that keys carry no wrapper prefix.

        Args:
            at_epoch (int): epoch number when the weight is stored
        """
        self.acting_best_epoch = at_epoch

        # Unwrap instead of deep-copying the whole network: only the state dict is needed.
        _network = self.network.module if hasattr(self.network, 'module') else self.network

        # deepcopy detaches the snapshot from the live parameters and
        # keeps the state dict container intact.
        best_weight = copy.deepcopy(_network.state_dict())

        # Always keep the snapshot on CPU, wrapped or not, so that the saved
        # file does not depend on the device used for training.
        for name, value in best_weight.items():
            if torch.is_tensor(value):
                best_weight[name] = value.cpu()

        self.acting_best_weight = best_weight

    def save_weight(self, save_datetime_dir: str, as_best: bool = None) -> None:
        """
        Save the weight kept by store_weight() under <save_datetime_dir>/weights.

        The file name is 'weight_epoch-<epoch>.pt', where <epoch> is the stored
        epoch number zero-padded to three digits. When as_best is true,
        '_best' is appended before the extension.

        Args:
            save_datetime_dir (str): directory in which the 'weights' directory is created
            as_best (bool): True if weight is saved as best, otherwise False. Defaults to None.
        """
        save_dir = Path(save_datetime_dir, 'weights')
        save_dir.mkdir(parents=True, exist_ok=True)
        save_stem = 'weight_epoch-' + str(self.acting_best_epoch).zfill(3)
        save_path = Path(save_dir, save_stem + '.pt')

        if not as_best:
            torch.save(self.acting_best_weight, save_path)
            return

        save_path_as_best = Path(save_dir, save_stem + '_best' + '.pt')
        if save_path.exists():
            # The weight of this epoch has already been saved: just rename the file.
            save_path.rename(save_path_as_best)
        else:
            torch.save(self.acting_best_weight, save_path_as_best)

    def load_weight(self, weight_path: Path) -> None:
        """
        Load weight from weight_path into the network.

        Args:
            weight_path (Path): path to weight
        """
        logger.info(f"Load weight: {weight_path}.\n")
        # map_location makes the tensors land on this model's device, so a weight
        # saved from another device can be loaded where that device is missing.
        weight = torch.load(weight_path, map_location=self.device)
        self.network.load_state_dict(weight)
|
| |
|
class ModelMixin:
    """
    Mixin adding multi-GPU wrapping and network re-creation.

    The methods read self.params and self.device and assign self.network,
    so the host class has to provide them.
    """
    def to_gpu(self, gpu_ids: List[int]) -> None:
        """
        Wrap the network with nn.DataParallel on the given GPUs.

        Nothing is done when gpu_ids is an empty list.

        Args:
            gpu_ids (List[int]): GPU ids
        """
        if gpu_ids == []:
            return

        assert torch.cuda.is_available(), 'No available GPU on this machine.'
        self.network = nn.DataParallel(self.network, device_ids=gpu_ids)

    def init_network(self) -> None:
        """
        Re-create the network from the parameters and move it to the device.

        This method is used at test to reset the current weight by redefining network.
        """
        params = self.params
        net_args = {
            'mlp': params.mlp,
            'net': params.net,
            'num_outputs_for_label': params.num_outputs_for_label,
            'mlp_num_inputs': params.mlp_num_inputs,
            'in_channel': params.in_channel,
            'vit_image_size': params.vit_image_size,
            'pretrained': params.pretrained
        }
        self.network = create_net(**net_args)
        self.network.to(self.device)
| |
|
class ModelWidget(BaseModel, ModelMixin):
    """
    Convenience base combining BaseModel and ModelMixin,
    so that concrete models inherit from a single class.
    """
| |
|
| |
|
class MLPModel(ModelWidget):
    """
    Model whose network is fed with data['inputs'] only (MLP).
    """

    def __init__(self, params: ParamSet) -> None:
        """
        Args:
            params (ParamSet): parameters
        """
        super().__init__(params)

    def set_data(
                self,
                data: Dict
                ) -> Tuple[
                        Dict[str, torch.FloatTensor],
                        Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
                    ]:
        """
        Move the inputs and labels of data to the device and split them
        into the input of MLP and the data for calculating loss.
        When any element of data['periods'] is truthy (deepsurv),
        periods and network are returned as well.

        Args:
            data (Dict): dictionary of data

        Returns:
            Tuple[
                Dict[str, torch.FloatTensor],
                Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
                ]: input of model and data for calculating loss.
            eg.
            ([inputs], [labels]), or ([inputs], [labels, periods, network]) when deepsurv
        """
        in_data = {'inputs': data['inputs'].to(self.device)}

        labels_on_device = {}
        for label_name, label in data['labels'].items():
            labels_on_device[label_name] = label.to(self.device)
        loss_data = {'labels': labels_on_device}

        # Periods and the network itself are passed along only when periods are given.
        if any(data['periods']):
            loss_data['periods'] = data['periods'].to(self.device)
            loss_data['network'] = self.network.to(self.device)

        return in_data, loss_data

    def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward in_data['inputs'] through the network.

        Args:
            in_data (Dict[str, torch.Tensor]): data to be input into model

        Returns:
            Dict[str, torch.Tensor]: output
        """
        return self.network(in_data['inputs'])
| |
|
| |
|
class CVModel(ModelWidget):
    """
    Model whose network is fed with data['image'] only (CNN or ViT).
    """
    def __init__(self, params: ParamSet) -> None:
        """
        Args:
            params (ParamSet): parameters
        """
        super().__init__(params)

    def set_data(
                self,
                data: Dict
                ) -> Tuple[
                        Dict[str, torch.FloatTensor],
                        Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
                    ]:
        """
        Move the image and labels of data to the device and split them
        into the input of CNN or ViT and the data for calculating loss.
        When any element of data['periods'] is truthy (deepsurv),
        periods and network are returned as well.

        Args:
            data (Dict): dictionary of data

        Returns:
            Tuple[
                Dict[str, torch.FloatTensor],
                Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
                ]: input of model and data for calculating loss.
            eg.
            ([image], [labels]), or ([image], [labels, periods, network]) when deepsurv
        """
        in_data = {'image': data['image'].to(self.device)}

        labels_on_device = {}
        for label_name, label in data['labels'].items():
            labels_on_device[label_name] = label.to(self.device)
        loss_data = {'labels': labels_on_device}

        # Periods and the network itself are passed along only when periods are given.
        if any(data['periods']):
            loss_data['periods'] = data['periods'].to(self.device)
            loss_data['network'] = self.network.to(self.device)

        return in_data, loss_data

    def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward in_data['image'] through the network.

        Args:
            in_data (Dict[str, torch.Tensor]): data to be input into model

        Returns:
            Dict[str, torch.Tensor]: output
        """
        return self.network(in_data['image'])
| |
|
| |
|
class FusionModel(ModelWidget):
    """
    Model whose network is fed with both data['inputs'] and data['image']
    (MLP+CNN or MLP+ViT).
    """
    def __init__(self, params: ParamSet) -> None:
        """
        Args:
            params (ParamSet): parameters
        """
        super().__init__(params)

    def set_data(
                self,
                data: Dict
                ) -> Tuple[
                        Dict[str, torch.FloatTensor],
                        Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
                    ]:
        """
        Move the inputs, image and labels of data to the device and split them
        into the input of MLP+CNN or MLP+ViT and the data for calculating loss.
        When any element of data['periods'] is truthy (deepsurv),
        periods and network are returned as well.

        Args:
            data (Dict): dictionary of data

        Returns:
            Tuple[
                Dict[str, torch.FloatTensor],
                Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
                ]: input of model and data for calculating loss.
            eg.
            ([inputs, image], [labels]), or ([inputs, image], [labels, periods, network]) when deepsurv
        """
        in_data = {}
        in_data['inputs'] = data['inputs'].to(self.device)
        in_data['image'] = data['image'].to(self.device)

        labels_on_device = {}
        for label_name, label in data['labels'].items():
            labels_on_device[label_name] = label.to(self.device)
        loss_data = {'labels': labels_on_device}

        # Periods and the network itself are passed along only when periods are given.
        if any(data['periods']):
            loss_data['periods'] = data['periods'].to(self.device)
            loss_data['network'] = self.network.to(self.device)

        return in_data, loss_data

    def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward in_data['inputs'] and in_data['image'] through the network.

        Args:
            in_data (Dict[str, torch.Tensor]): data to be input into model

        Returns:
            Dict[str, torch.Tensor]: output
        """
        return self.network(in_data['inputs'], in_data['image'])
| |
|
| |
|
def create_model(params: ParamSet) -> nn.Module:
    """
    Pick the model class from params.mlp and params.net and instantiate it.

    - only mlp is set         -> MLPModel
    - only net is set         -> CVModel
    - both mlp and net are set -> FusionModel

    Args:
        params (ParamSet): parameters

    Returns:
        nn.Module: model
        NOTE(review): the returned object is an MLPModel, CVModel or
        FusionModel instance; confirm whether the annotation should be updated.

    Raises:
        ValueError: when neither params.mlp nor params.net is set.
    """
    has_mlp = params.mlp is not None
    has_net = params.net is not None

    if has_mlp and has_net:
        return FusionModel(params)
    if has_mlp:
        return MLPModel(params)
    if has_net:
        return CVModel(params)

    raise ValueError(f"Invalid model type: mlp={params.mlp}, net={params.net}.")
| |
|