File size: 5,898 Bytes
e4b9a7b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 | # Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Callable, Dict, Optional, Sequence, Tuple
import torch
import torch.nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.optim.optimizer import Optimizer
from monai.engines.utils import get_devices_spec
from monai.utils import exact_version, optional_import
create_supervised_trainer, _ = optional_import("ignite.engine", "0.3.0", exact_version, "create_supervised_trainer")
create_supervised_evaluator, _ = optional_import("ignite.engine", "0.3.0", exact_version, "create_supervised_evaluator")
_prepare_batch, _ = optional_import("ignite.engine", "0.3.0", exact_version, "_prepare_batch")
if TYPE_CHECKING:
from ignite.engine import Engine
from ignite.metrics import Metric
else:
Engine, _ = optional_import("ignite.engine", "0.3.0", exact_version, "Engine")
Metric, _ = optional_import("ignite.metrics", "0.3.0", exact_version, "Metric")
def _default_transform(_x: torch.Tensor, _y: torch.Tensor, _y_pred: torch.Tensor, loss: torch.Tensor) -> float:
return loss.item()
def _default_eval_transform(
x: torch.Tensor, y: torch.Tensor, y_pred: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
return y_pred, y
def create_multigpu_supervised_trainer(
    net: torch.nn.Module,
    optimizer: Optimizer,
    loss_fn: Callable,
    devices: Optional[Sequence[torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable = _default_transform,
    distributed: bool = False,
) -> Engine:
    """
    Derived from `create_supervised_trainer` in Ignite.

    Factory function for creating a trainer for supervised models.

    Args:
        net: the network to train.
        optimizer: the optimizer to use.
        loss_fn: the loss function to use.
        devices: device(s) type specification (default: None).
            Applies to both model and batches. None is all devices used, empty list is CPU only.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.
        output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
            to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
        distributed: whether to wrap the network in `DistributedDataParallel` over `devices`. When
            several devices are given, the first one is the output device (DistributedDataParallel's
            default). For a CPU-only specification `device_ids` is passed as None.

    Returns:
        Engine: a trainer engine with supervised update function.

    Note:
        `engine.state.output` for this engine is defined by `output_transform` parameter and is the loss
        of the processed batch by default.
    """
    devices_ = get_devices_spec(devices)
    # Place the network on the primary device *before* wrapping it: DataParallel requires the
    # parameters to live on device_ids[0], and DistributedDataParallel inspects the parameter
    # devices at construction time. The same device is handed to Ignite's factory below.
    net = net.to(devices_[0])
    if distributed:
        # DistributedDataParallel only accepts `device_ids` for CUDA modules; CPU modules need None.
        ddp_device_ids = None if torch.device(devices_[0]).type == "cpu" else devices_
        net = DistributedDataParallel(net, device_ids=ddp_device_ids)
    elif len(devices_) > 1:
        # Replicate only over the requested devices instead of defaulting to every visible GPU.
        net = DataParallel(net, device_ids=devices_)
    return create_supervised_trainer(
        net, optimizer, loss_fn, devices_[0], non_blocking, prepare_batch, output_transform
    )
def create_multigpu_supervised_evaluator(
    net: torch.nn.Module,
    metrics: Optional[Dict[str, Metric]] = None,
    devices: Optional[Sequence[torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable = _default_eval_transform,
    distributed: bool = False,
) -> Engine:
    """
    Derived from `create_supervised_evaluator` in Ignite.

    Factory function for creating an evaluator for supervised models.

    Args:
        net: the network to evaluate.
        metrics: a map of metric names to Metrics.
        devices: device(s) type specification (default: None).
            Applies to both model and batches. None is all devices used, empty list is CPU only.
        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.
        output_transform: function that receives 'x', 'y', 'y_pred' and returns value
            to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)`
            which fits output expected by metrics. If you change it you should use `output_transform` in metrics.
        distributed: whether to wrap the network in `DistributedDataParallel` over `devices`. When
            several devices are given, the first one is the output device (DistributedDataParallel's
            default). For a CPU-only specification `device_ids` is passed as None.

    Note:
        `engine.state.output` for this engine is defined by `output_transform` parameter and is
        a tuple of `(batch_pred, batch_y)` by default.

    Returns:
        Engine: an evaluator engine with supervised inference function.
    """
    devices_ = get_devices_spec(devices)
    # Place the network on the primary device *before* wrapping it: DataParallel requires the
    # parameters to live on device_ids[0], and DistributedDataParallel inspects the parameter
    # devices at construction time. The same device is handed to Ignite's factory below.
    net = net.to(devices_[0])
    if distributed:
        # DistributedDataParallel only accepts `device_ids` for CUDA modules; CPU modules need None.
        ddp_device_ids = None if torch.device(devices_[0]).type == "cpu" else devices_
        net = DistributedDataParallel(net, device_ids=ddp_device_ids)
    elif len(devices_) > 1:
        # Replicate only over the requested devices instead of defaulting to every visible GPU.
        net = DataParallel(net, device_ids=devices_)
    return create_supervised_evaluator(net, metrics, devices_[0], non_blocking, prepare_batch, output_transform)
|