"""Written by Eitan Kosman."""
import logging
import os
import time
from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor, nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from utils.callbacks import Callback
from utils.types import Device
def get_torch_device() -> Device:
"""
Retrieves the device to run torch models, with preferability to GPU (denoted as cuda by torch)
Returns: Device to run the models
"""
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model(model_path: str) -> nn.Module:
"""Loads a Pytorch model (CPU compatible, PyTorch >=2.6)."""
logging.info(f"Load the model from: {model_path}")
from network.anomaly_detector_model import AnomalyDetector
# Wrap torch.load with safe_globals and weights_only=False
with torch.serialization.safe_globals([AnomalyDetector]):
model = torch.load(model_path, map_location="cpu", weights_only=False)
logging.info(model)
return model
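

# A minimal usage sketch for load_model (the checkpoint path here is
# hypothetical):
#
#   model = load_model("checkpoints/epoch_10.pt")
#   model.eval()
#
# safe_globals allowlists AnomalyDetector so the fully pickled module can be
# deserialized under PyTorch >= 2.6, where weights_only=True became the default.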
class TorchModel(nn.Module):
"""Wrapper class for a torch model to make it comfortable to train and load
models."""
def __init__(self, model: nn.Module) -> None:
super().__init__()
        self.device = get_torch_device()
        self.iteration = 0
        # Move the wrapped model to the selected device up front so it matches
        # the device that batches are transferred to in do_epoch/evaluate
        self.model = model.to(self.device)
        self.is_data_parallel = False
        self.callbacks: List[Callback] = []
def register_callback(self, callback_fn: Callback) -> None:
"""
Register a callback to be called after each evaluation run
Args:
callback_fn: a callable that accepts 2 inputs (output, target)
- output is the model's output
- target is the values of the target variable
"""
self.callbacks.append(callback_fn)
def data_parallel(self):
"""Transfers the model to data parallel mode."""
self.is_data_parallel = True
if not isinstance(self.model, torch.nn.DataParallel):
self.model = torch.nn.DataParallel(self.model, device_ids=[0, 1])
return self
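
    # Note: save() below unwraps the DataParallel wrapper before pickling, so
    # checkpoints written in data-parallel mode remain loadable on a single GPU.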
@classmethod
def load_model(cls, model_path: str):
"""
Loads a pickled model
Args:
model_path: path to the pickled model
Returns: TorchModel class instance wrapping the provided model
"""
return cls(load_model(model_path))
def notify_callbacks(self, notification, *args, **kwargs) -> None:
"""Calls all callbacks registered with this class.
Args:
notification: The type of notification to be called.
"""
for callback in self.callbacks:
try:
method = getattr(callback, notification)
method(*args, **kwargs)
except (AttributeError, TypeError) as e:
logging.error(
f"callback {callback.__class__.__name__} doesn't fully implement the required interface {e}" # pylint: disable=line-too-long
)
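
    # A minimal callback sketch (illustrative; it assumes utils.callbacks.Callback
    # can be subclassed, and lists exactly the notifications this class emits):
    #
    #   class LossLogger(Callback):
    #       def on_training_start(self, epochs): ...
    #       def on_epoch_start(self, epoch, num_batches): ...
    #       def on_epoch_step(self, global_iteration, iteration, loss): ...
    #       def on_epoch_end(self, loss): ...
    #       def on_training_iteration_end(self, train_loss, val_loss): ...
    #       def on_training_end(self, model): ...
    #       def on_evaluation_start(self, num_batches): ...
    #       def on_evaluation_step(self, iteration, outputs, targets, loss): ...
    #       def on_evaluation_end(self): ...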
def fit(
self,
train_iter: DataLoader,
criterion: nn.Module,
optimizer: Optimizer,
eval_iter: Optional[DataLoader] = None,
epochs: int = 10,
network_model_path_base: Optional[str] = None,
save_every: Optional[int] = None,
evaluate_every: Optional[int] = None,
) -> None:
"""
Args:
train_iter: iterator for training
criterion: loss function
optimizer: optimizer for the algorithm
eval_iter: iterator for evaluation
            epochs: number of epochs
            network_model_path_base: directory in which to save model checkpoints
            save_every: save a model checkpoint every this many epochs
            evaluate_every: run evaluation every this many epochs.
                If the evaluation is expensive, you probably want to
                choose a high value for this
"""
criterion = criterion.to(self.device)
self.notify_callbacks("on_training_start", epochs)
for epoch in range(epochs):
train_loss = self.do_epoch(
criterion=criterion,
optimizer=optimizer,
data_iter=train_iter,
epoch=epoch,
)
if save_every and network_model_path_base and epoch % save_every == 0:
logging.info(f"Save the model after epoch {epoch}")
self.save(os.path.join(network_model_path_base, f"epoch_{epoch}.pt"))
val_loss = None
if eval_iter and evaluate_every and epoch % evaluate_every == 0:
logging.info(f"Evaluating after epoch {epoch}")
val_loss = self.evaluate(
criterion=criterion,
data_iter=eval_iter,
)
self.notify_callbacks("on_training_iteration_end", train_loss, val_loss)
self.notify_callbacks("on_training_end", self.model)
        # Save the final model regardless of the save_every schedule
        # (epochs == epoch + 1 after the loop, and avoids a NameError if epochs == 0)
        if network_model_path_base:
            self.save(os.path.join(network_model_path_base, f"epoch_{epochs}.pt"))
def evaluate(self, criterion: nn.Module, data_iter: DataLoader) -> float:
"""
Evaluates the model
Args:
criterion: Loss function for calculating the evaluation
data_iter: torch data iterator
"""
self.eval()
self.notify_callbacks("on_evaluation_start", len(data_iter))
total_loss = 0
with torch.no_grad():
for iteration, (batch, targets) in enumerate(data_iter):
batch = self.data_to_device(batch, self.device)
targets = self.data_to_device(targets, self.device)
outputs = self.model(batch)
loss = criterion(outputs, targets)
self.notify_callbacks(
"on_evaluation_step",
iteration,
outputs.detach().cpu(),
targets.detach().cpu(),
loss.item(),
)
total_loss += loss.item()
        avg_loss = total_loss / len(data_iter)
        self.notify_callbacks("on_evaluation_end")
        return avg_loss
def do_epoch(
self,
criterion: nn.Module,
optimizer: Optimizer,
data_iter: DataLoader,
epoch: int,
) -> float:
"""Perform a whole epoch.
Args:
criterion (nn.Module): Loss function to be used.
optimizer (Optimizer): Optimizer to use for minimizing the loss function.
data_iter (DataLoader): Loader for data samples used for training the model.
epoch (int): The epoch number.
Returns:
float: Average training loss calculated during the epoch.
"""
total_loss = 0
total_time = 0.0
self.train()
self.notify_callbacks("on_epoch_start", epoch, len(data_iter))
for iteration, (batch, targets) in enumerate(data_iter):
self.iteration += 1
start_time = time.time()
batch = self.data_to_device(batch, self.device)
targets = self.data_to_device(targets, self.device)
outputs = self.model(batch)
loss = criterion(outputs, targets)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
end_time = time.time()
total_time += end_time - start_time
self.notify_callbacks(
"on_epoch_step",
self.iteration,
iteration,
loss.item(),
)
        avg_loss = total_loss / len(data_iter)
        self.notify_callbacks("on_epoch_end", avg_loss)
        return avg_loss
    def data_to_device(
        self, data: Union[Tensor, List[Tensor], Tuple[Tensor, ...]], device: Device
    ) -> Union[Tensor, List[Tensor], Tuple[Tensor, ...]]:
        """
        Transfers tensor data to a device
        Args:
            data: a torch tensor, or a list/tuple of torch tensors
            device: target device
        """
if isinstance(data, list):
data = [d.to(device) for d in data]
elif isinstance(data, tuple):
            data = tuple(d.to(device) for d in data)
else:
data = data.to(device)
return data
def save(self, model_path: str) -> None:
"""Saves the model to the given path.
If currently using data parallel, the method
will save the original model and not the data parallel instance of it
Args:
model_path: target path to save the model to
"""
if self.is_data_parallel:
torch.save(self.model.module, model_path)
else:
torch.save(self.model, model_path)
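
    # Note: torch.save here pickles the entire module object rather than a
    # state_dict, which is why load_model above needs weights_only=False and a
    # safe_globals allowlist under PyTorch >= 2.6.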
    def get_model(self) -> nn.Module:
        """Returns the underlying model, unwrapping DataParallel if necessary."""
if self.is_data_parallel:
return self.model.module
return self.model
    def forward(self, *args, **kwargs):
        """Delegates the forward pass to the wrapped model."""
return self.model(*args, **kwargs)
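

if __name__ == "__main__":
    # Minimal end-to-end smoke test (illustrative only: the toy network,
    # synthetic data, and hyperparameters are assumptions, not part of the
    # original training pipeline)
    from torch.utils.data import TensorDataset

    net = nn.Linear(8, 1)
    wrapper = TorchModel(net)
    dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
    loader = DataLoader(dataset, batch_size=16)
    wrapper.fit(
        train_iter=loader,
        criterion=nn.MSELoss(),
        optimizer=torch.optim.Adam(net.parameters(), lr=1e-3),
        epochs=2,
    )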