|
|
|
|
|
|
|
|
|
|
|
from pathlib import Path |
|
|
import copy |
|
|
from abc import ABC, abstractmethod |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from .component import create_net |
|
|
from .logger import BaseLogger |
|
|
from lib import ParamSet |
|
|
from typing import List, Dict, Tuple, Union |
|
|
|
|
|
|
|
|
|
|
|
# Alias for a two-level mapping str -> str -> integer/float tensor.
# It is used inside the return annotations of the set_data() methods below.
_LabelTensor = Union[torch.IntTensor, torch.FloatTensor]
LabelDict = Dict[str, Dict[str, _LabelTensor]]
|
|
|
|
|
|
|
|
# Module-level logger, obtained from the project's BaseLogger under this module's name.
logger = BaseLogger.get_logger(__name__)
|
|
|
|
|
|
|
|
class BaseModel(ABC):
    """
    Base class to construct a model.

    It builds the network with create_net() from the given parameters, moves it
    to params.device, and offers helpers to snapshot, save and load weights.
    Subclasses must implement set_data().
    """
    def __init__(self, params: ParamSet) -> None:
        """
        Build the network and place it on the configured device.

        Args:
            params (ParamSet): parameters
        """
        self.params = params
        self.device = self.params.device

        self.network = create_net(
            mlp=self.params.mlp,
            net=self.params.net,
            num_outputs_for_label=self.params.num_outputs_for_label,
            mlp_num_inputs=self.params.mlp_num_inputs,
            in_channel=self.params.in_channel,
            vit_image_size=self.params.vit_image_size,
            pretrained=self.params.pretrained
        )
        self.network.to(self.device)

        # Weight snapshot and the epoch it was taken at; both are filled in by store_weight().
        self.acting_best_weight = None
        self.acting_best_epoch = None

    def train(self) -> None:
        """
        Make network training mode.
        """
        self.network.train()

    def eval(self) -> None:
        """
        Make network evaluation mode.
        """
        self.network.eval()

    @abstractmethod
    def set_data(
        self,
        data: Dict
    ) -> Tuple[
        Dict[str, torch.FloatTensor],
        Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
    ]:
        """
        Unpack a batch into the model input and the data needed to calculate loss.

        Must be implemented by subclasses.

        Args:
            data (Dict): dictionary of data
        """
        raise NotImplementedError

    def store_weight(self, at_epoch: Union[int, None] = None) -> None:
        """
        Keep a CPU copy of the current weights together with the epoch number.

        Args:
            at_epoch (int): epoch number when the weight is stored. Defaults to None.
        """
        self.acting_best_epoch = at_epoch

        # When the network is wrapped (e.g. by nn.DataParallel), read the weights
        # from the inner module so that the keys carry no 'module.' prefix.
        network = self.network.module if hasattr(self.network, 'module') else self.network

        # Copy tensor by tensor straight onto the CPU instead of deep-copying the
        # whole network first: no second copy of the model is created on the
        # device, and the snapshot lives on the CPU whether or not it is wrapped.
        state = network.state_dict()
        snapshot = type(state)(
            (name, tensor.detach().cpu().clone()) for name, tensor in state.items()
        )

        # Carry over the metadata that state_dict() attaches, if any.
        metadata = getattr(state, '_metadata', None)
        if metadata is not None:
            snapshot._metadata = metadata

        self.acting_best_weight = snapshot

    def save_weight(self, save_datetime_dir: str, as_best: bool = None) -> None:
        """
        Save the stored weight under <save_datetime_dir>/weights.

        The file is named 'weight_epoch-<epoch>.pt', or
        'weight_epoch-<epoch>_best.pt' when as_best is truthy.

        Args:
            save_datetime_dir (str): directory in which the 'weights' directory is created
            as_best (bool): True if weight is saved as best, otherwise False. Defaults to None.
        """
        save_dir = Path(save_datetime_dir, 'weights')
        save_dir.mkdir(parents=True, exist_ok=True)
        stem = 'weight_epoch-' + str(self.acting_best_epoch).zfill(3)
        save_path = Path(save_dir, stem + '.pt')

        if not as_best:
            torch.save(self.acting_best_weight, save_path)
            return

        save_path_as_best = Path(save_dir, stem + '_best' + '.pt')
        if save_path.exists():
            # A file for this epoch was already written: rename it instead of saving again.
            save_path.rename(save_path_as_best)
        else:
            torch.save(self.acting_best_weight, save_path_as_best)

    def load_weight(self, weight_path: Path) -> None:
        """
        Load weight from weight_path.

        Args:
            weight_path (Path): path to weight
        """
        logger.info(f"Load weight: {weight_path}.\n")
        # map_location lets a checkpoint holding CUDA tensors be read on a machine
        # without GPU; load_state_dict() then copies the values into the network's
        # own parameters on whatever device they are on.
        # NOTE(review): torch.load unpickles the file, so only load trusted weights.
        weight = torch.load(weight_path, map_location=torch.device('cpu'))
        self.network.load_state_dict(weight)
|
|
|
|
|
|
|
|
class ModelMixin:
    """
    Mixin with helpers for GPU placement and for rebuilding the network.

    The methods read and write self.network and read self.params and
    self.device, which the class this is mixed into has to provide.
    """
    def to_gpu(self, gpu_ids: List[int]) -> None:
        """
        Wrap the network in nn.DataParallel over the given GPUs.

        Nothing happens when gpu_ids is an empty list.

        Args:
            gpu_ids (List[int]): GPU ids
        """
        if gpu_ids == []:
            return

        assert torch.cuda.is_available(), 'No available GPU on this machine.'
        self.network = nn.DataParallel(self.network, device_ids=gpu_ids)

    def init_network(self) -> None:
        """
        Rebuild the network from the parameters and move it to the device.

        This method is used at test to reset the current weight by redefining network.
        """
        params = self.params
        net_kwargs = {
            'mlp': params.mlp,
            'net': params.net,
            'num_outputs_for_label': params.num_outputs_for_label,
            'mlp_num_inputs': params.mlp_num_inputs,
            'in_channel': params.in_channel,
            'vit_image_size': params.vit_image_size,
            'pretrained': params.pretrained,
        }
        self.network = create_net(**net_kwargs)
        self.network.to(self.device)
|
|
|
|
|
class ModelWidget(BaseModel, ModelMixin):
    """
    Convenience base combining BaseModel and ModelMixin, so that concrete
    models only need to inherit from a single class.
    """
|
|
|
|
|
|
|
|
class MLPModel(ModelWidget):
    """
    Model wrapper whose network takes only data['inputs'] (MLP).
    """

    def __init__(self, params: ParamSet) -> None:
        """
        Args:
            params (ParamSet): parameters
        """
        super().__init__(params)

    def set_data(
        self,
        data: Dict
    ) -> Tuple[
        Dict[str, torch.FloatTensor],
        Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
    ]:
        """
        Split a batch into the MLP input and the data for calculating loss,
        moving both to the device.

        When data['periods'] has any truthy element (deepsurv), the periods
        and the network are returned as well.

        Args:
            data (Dict): dictionary of data

        Returns:
            Tuple[
                Dict[str, torch.FloatTensor],
                Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
            ]: input of model and data for calculating loss.
            eg.
            ([inputs], [labels]), or ([inputs], [labels, periods, network]) when deepsurv
        """
        in_data = {'inputs': data['inputs'].to(self.device)}

        labels_on_device = {}
        for label_name, label in data['labels'].items():
            labels_on_device[label_name] = label.to(self.device)
        for_loss = {'labels': labels_on_device}

        # Periods and network are appended only when periods are present.
        if any(data['periods']):
            for_loss['periods'] = data['periods'].to(self.device)
            for_loss['network'] = self.network.to(self.device)

        return in_data, for_loss

    def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward in_data['inputs'] through the network.

        Args:
            in_data (Dict[str, torch.Tensor]): data to be input into model

        Returns:
            Dict[str, torch.Tensor]: output
        """
        return self.network(in_data['inputs'])
|
|
|
|
|
|
|
|
class CVModel(ModelWidget):
    """
    Model wrapper whose network takes only data['image'] (CNN or ViT).
    """
    def __init__(self, params: ParamSet) -> None:
        """
        Args:
            params (ParamSet): parameters
        """
        super().__init__(params)

    def set_data(
        self,
        data: Dict
    ) -> Tuple[
        Dict[str, torch.FloatTensor],
        Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
    ]:
        """
        Split a batch into the CNN or ViT input and the data for calculating
        loss, moving both to the device.

        When data['periods'] has any truthy element (deepsurv), the periods
        and the network are returned as well.

        Args:
            data (Dict): dictionary of data

        Returns:
            Tuple[
                Dict[str, torch.FloatTensor],
                Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
            ]: input of model and data for calculating loss.
            eg.
            ([image], [labels]), or ([image], [labels, periods, network]) when deepsurv
        """
        in_data = {'image': data['image'].to(self.device)}

        labels_on_device = {}
        for label_name, label in data['labels'].items():
            labels_on_device[label_name] = label.to(self.device)
        for_loss = {'labels': labels_on_device}

        # Periods and network are appended only when periods are present.
        if any(data['periods']):
            for_loss['periods'] = data['periods'].to(self.device)
            for_loss['network'] = self.network.to(self.device)

        return in_data, for_loss

    def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward in_data['image'] through the network.

        Args:
            in_data (Dict[str, torch.Tensor]): data to be input into model

        Returns:
            Dict[str, torch.Tensor]: output
        """
        return self.network(in_data['image'])
|
|
|
|
|
|
|
|
class FusionModel(ModelWidget):
    """
    Model wrapper whose network takes both data['inputs'] and data['image']
    (MLP+CNN or MLP+ViT).
    """
    def __init__(self, params: ParamSet) -> None:
        """
        Args:
            params (ParamSet): parameters
        """
        super().__init__(params)

    def set_data(
        self,
        data: Dict
    ) -> Tuple[
        Dict[str, torch.FloatTensor],
        Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
    ]:
        """
        Split a batch into the MLP+CNN or MLP+ViT input and the data for
        calculating loss, moving both to the device.

        When data['periods'] has any truthy element (deepsurv), the periods
        and the network are returned as well.

        Args:
            data (Dict): dictionary of data

        Returns:
            Tuple[
                Dict[str, torch.FloatTensor],
                Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
            ]: input of model and data for calculating loss.
            eg.
            ([inputs, image], [labels]), or ([inputs, image], [labels, periods, network]) when deepsurv
        """
        in_data = {}
        in_data['inputs'] = data['inputs'].to(self.device)
        in_data['image'] = data['image'].to(self.device)

        labels_on_device = {}
        for label_name, label in data['labels'].items():
            labels_on_device[label_name] = label.to(self.device)
        for_loss = {'labels': labels_on_device}

        # Periods and network are appended only when periods are present.
        if any(data['periods']):
            for_loss['periods'] = data['periods'].to(self.device)
            for_loss['network'] = self.network.to(self.device)

        return in_data, for_loss

    def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward in_data['inputs'] and in_data['image'] through the network.

        Args:
            in_data (Dict[str, torch.Tensor]): data to be input into model

        Returns:
            Dict[str, torch.Tensor]: output
        """
        tabular, picture = in_data['inputs'], in_data['image']
        return self.network(tabular, picture)
|
|
|
|
|
|
|
|
def create_model(params: ParamSet) -> ModelWidget:
    """
    Construct the model wrapper selected by params.mlp and params.net.

    - only params.mlp is set: MLPModel
    - only params.net is set: CVModel
    - both are set: FusionModel

    Args:
        params (ParamSet): parameters

    Returns:
        ModelWidget: MLPModel, CVModel or FusionModel. Note that these are
        wrappers holding the network in their .network attribute, not
        nn.Module instances themselves.

    Raises:
        ValueError: when both params.mlp and params.net are None.
    """
    has_mlp = params.mlp is not None
    has_net = params.net is not None

    if has_mlp and has_net:
        return FusionModel(params)
    if has_mlp:
        return MLPModel(params)
    if has_net:
        return CVModel(params)
    raise ValueError(f"Invalid model type: mlp={params.mlp}, net={params.net}.")
|
|
|