Spaces:
Running
Running
| """Written by Eitan Kosman.""" | |
| import logging | |
| import os | |
| import time | |
| from typing import List, Optional, Union | |
| import torch | |
| from torch import Tensor, nn | |
| from torch.optim import Optimizer | |
| from torch.utils.data import DataLoader | |
| from utils.callbacks import Callback | |
| from utils.types import Device | |
| import torch | |
| from network.anomaly_detector_model import AnomalyDetector | |
# NOTE: PyTorch >= 2.6 refuses to unpickle arbitrary classes by default;
# model classes must be allow-listed via torch.serialization.safe_globals
# before torch.load(..., weights_only=False).
def get_torch_device() -> "Device":
    """Pick the device for running torch models, preferring GPU over CPU.

    Returns:
        ``torch.device("cuda")`` when a CUDA GPU is available,
        otherwise ``torch.device("cpu")``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def load_model(model_path: str) -> nn.Module:
    """Load a fully pickled PyTorch model onto the CPU (PyTorch >= 2.6 compatible).

    Args:
        model_path: path to the pickled model file.

    Returns:
        The deserialized ``nn.Module``.
    """
    logging.info(f"Load the model from: {model_path}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects —
    # only load checkpoints from trusted sources.
    # AnomalyDetector (imported at module level) must be allow-listed so
    # PyTorch >= 2.6 will unpickle it; the redundant function-local
    # re-import that used to live here was removed.
    with torch.serialization.safe_globals([AnomalyDetector]):
        model = torch.load(model_path, map_location="cpu", weights_only=False)
    logging.info(model)
    return model
class TorchModel(nn.Module):
    """Wrapper around a torch module that adds training, evaluation,
    checkpointing, data-parallel support and callback notifications.

    The wrapped model is reachable via :meth:`get_model`; ``forward`` simply
    delegates to it, so a ``TorchModel`` can be used wherever the wrapped
    module can.
    """

    def __init__(self, model: nn.Module) -> None:
        """
        Args:
            model: the torch module to wrap.
        """
        super().__init__()
        # Preferred execution device (GPU when available).
        self.device = get_torch_device()
        # Global step counter, incremented once per training batch.
        self.iteration = 0
        # NOTE(review): the wrapped model is NOT moved to self.device here;
        # presumably callers are responsible for that — confirm.
        self.model = model
        self.is_data_parallel = False
        self.callbacks: List["Callback"] = []

    def register_callback(self, callback_fn: "Callback") -> None:
        """
        Register a callback to be notified of training/evaluation events.

        Args:
            callback_fn: an object implementing the Callback interface
                (``on_training_start``, ``on_epoch_step``, ``on_evaluation_end``, ...).
        """
        self.callbacks.append(callback_fn)

    def data_parallel(self, device_ids: Optional[List[int]] = None):
        """Transfers the model to data parallel mode.

        Args:
            device_ids: CUDA device ids to parallelize over. Defaults to
                ``[0, 1]``, the historical hard-coded value, for backward
                compatibility.

        Returns:
            self, to allow call chaining.
        """
        if device_ids is None:
            device_ids = [0, 1]
        self.is_data_parallel = True
        if not isinstance(self.model, torch.nn.DataParallel):
            self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
        return self

    @classmethod
    def load_model(cls, model_path: str):
        """
        Loads a pickled model.

        Args:
            model_path: path to the pickled model

        Returns:
            TorchModel class instance wrapping the loaded model
        """
        # BUGFIX: the method takes ``cls`` but was missing @classmethod,
        # so ``TorchModel.load_model(path)`` raised a TypeError.
        return cls(load_model(model_path))

    def notify_callbacks(self, notification, *args, **kwargs) -> None:
        """Calls the named hook on every registered callback.

        Callbacks that do not implement the hook (or reject its arguments)
        are logged and skipped rather than aborting the run.

        Args:
            notification: name of the callback method to invoke.
        """
        for callback in self.callbacks:
            try:
                method = getattr(callback, notification)
                method(*args, **kwargs)
            except (AttributeError, TypeError) as e:
                logging.error(
                    f"callback {callback.__class__.__name__} doesn't fully implement the required interface {e}"  # pylint: disable=line-too-long
                )

    def fit(
        self,
        train_iter: DataLoader,
        criterion: nn.Module,
        optimizer: Optimizer,
        eval_iter: Optional[DataLoader] = None,
        epochs: int = 10,
        network_model_path_base: Optional[str] = None,
        save_every: Optional[int] = None,
        evaluate_every: Optional[int] = None,
    ) -> None:
        """Train the wrapped model.

        Args:
            train_iter: iterator for training
            criterion: loss function
            optimizer: optimizer for the algorithm
            eval_iter: iterator for evaluation
            epochs: amount of epochs
            network_model_path_base: where to save the models
            save_every: saving model checkpoints every specified amount of epochs
            evaluate_every: perform evaluation every specified amount of epochs.
                If the evaluation is expensive, you probably want to
                choose a high value for this
        """
        criterion = criterion.to(self.device)
        self.notify_callbacks("on_training_start", epochs)
        for epoch in range(epochs):
            train_loss = self.do_epoch(
                criterion=criterion,
                optimizer=optimizer,
                data_iter=train_iter,
                epoch=epoch,
            )
            if save_every and network_model_path_base and epoch % save_every == 0:
                logging.info(f"Save the model after epoch {epoch}")
                self.save(os.path.join(network_model_path_base, f"epoch_{epoch}.pt"))
            val_loss = None
            if eval_iter and evaluate_every and epoch % evaluate_every == 0:
                logging.info(f"Evaluating after epoch {epoch}")
                val_loss = self.evaluate(
                    criterion=criterion,
                    data_iter=eval_iter,
                )
            self.notify_callbacks("on_training_iteration_end", train_loss, val_loss)
        self.notify_callbacks("on_training_end", self.model)
        # Always save the final model.
        # BUGFIX: use ``epochs`` instead of ``epoch + 1`` — the loop variable
        # is undefined when epochs == 0 (they are equal after a full loop).
        if network_model_path_base:
            self.save(os.path.join(network_model_path_base, f"epoch_{epochs}.pt"))

    def evaluate(self, criterion: nn.Module, data_iter: DataLoader) -> float:
        """
        Evaluates the model.

        Args:
            criterion: Loss function for calculating the evaluation
            data_iter: torch data iterator

        Returns:
            Average loss over the evaluation set (0.0 for an empty loader).
        """
        self.eval()
        num_batches = len(data_iter)
        self.notify_callbacks("on_evaluation_start", num_batches)
        total_loss = 0.0
        with torch.no_grad():
            for iteration, (batch, targets) in enumerate(data_iter):
                batch = self.data_to_device(batch, self.device)
                targets = self.data_to_device(targets, self.device)
                outputs = self.model(batch)
                loss = criterion(outputs, targets)
                # Detach/move to CPU so callbacks can't keep GPU memory alive.
                self.notify_callbacks(
                    "on_evaluation_step",
                    iteration,
                    outputs.detach().cpu(),
                    targets.detach().cpu(),
                    loss.item(),
                )
                total_loss += loss.item()
        # Guard against an empty loader instead of dividing by zero.
        avg_loss = total_loss / num_batches if num_batches else 0.0
        self.notify_callbacks("on_evaluation_end")
        return avg_loss

    def do_epoch(
        self,
        criterion: nn.Module,
        optimizer: Optimizer,
        data_iter: DataLoader,
        epoch: int,
    ) -> float:
        """Perform a whole epoch.

        Args:
            criterion (nn.Module): Loss function to be used.
            optimizer (Optimizer): Optimizer to use for minimizing the loss function.
            data_iter (DataLoader): Loader for data samples used for training the model.
            epoch (int): The epoch number.

        Returns:
            float: Average training loss calculated during the epoch
                (0.0 for an empty loader).
        """
        total_loss = 0.0
        self.train()
        num_batches = len(data_iter)
        self.notify_callbacks("on_epoch_start", epoch, num_batches)
        for iteration, (batch, targets) in enumerate(data_iter):
            # BUGFIX: self.iteration was incremented twice per batch (once at
            # the top and once at the bottom of the loop); count once.
            self.iteration += 1
            batch = self.data_to_device(batch, self.device)
            targets = self.data_to_device(targets, self.device)
            outputs = self.model(batch)
            loss = criterion(outputs, targets)
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            self.notify_callbacks(
                "on_epoch_step",
                self.iteration,
                iteration,
                loss.item(),
            )
        # Guard against an empty loader instead of dividing by zero.
        avg_loss = total_loss / num_batches if num_batches else 0.0
        self.notify_callbacks("on_epoch_end", avg_loss)
        return avg_loss

    def data_to_device(
        self, data: Union[Tensor, List[Tensor]], device: "Device"
    ) -> Union[Tensor, List[Tensor]]:
        """
        Transfers tensor data to a device, preserving list/tuple structure.

        Args:
            data: a tensor, or a list/tuple of tensors
            device: target device
        """
        if isinstance(data, list):
            data = [d.to(device) for d in data]
        elif isinstance(data, tuple):
            data = tuple(d.to(device) for d in data)
        else:
            data = data.to(device)
        return data

    def save(self, model_path: str) -> None:
        """Saves the model to the given path.

        If currently using data parallel, the method will save the original
        model and not the data parallel instance of it.

        Args:
            model_path: target path to save the model to
        """
        if self.is_data_parallel:
            torch.save(self.model.module, model_path)
        else:
            torch.save(self.model, model_path)

    def get_model(self) -> nn.Module:
        """Return the wrapped model, unwrapping DataParallel if needed."""
        if self.is_data_parallel:
            return self.model.module
        return self.model

    def forward(self, *args, **kwargs):
        """Delegate the forward pass to the wrapped model."""
        return self.model(*args, **kwargs)