foreversheikh committed on
Commit
1e3a246
·
verified ·
1 Parent(s): dae6254

Upload TorchUtils.py

Browse files
Files changed (1) hide show
  1. TorchUtils.py +284 -0
TorchUtils.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Written by Eitan Kosman."""
2
+
3
+ import logging
4
+ import os
5
+ import time
6
+ from typing import List, Optional, Union
7
+
8
+ import torch
9
+ from torch import Tensor, nn
10
+ from torch.optim import Optimizer
11
+ from torch.utils.data import DataLoader
12
+
13
+ from utils.callbacks import Callback
14
+ from utils.types import Device
15
+ import torch
16
+
17
+ from network.anomaly_detector_model import AnomalyDetector
18
+
19
+ # Use safe_globals context
20
+
21
+
22
+
23
def get_torch_device() -> Device:
    """Return the torch device to run models on, preferring GPU.

    Returns:
        The ``cuda`` device when a GPU is available, otherwise ``cpu``.
    """
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(device_name)
29
+
30
+
31
def load_model(model_path: str) -> nn.Module:
    """Load a fully pickled PyTorch model onto the CPU.

    The checkpoint is expected to contain a complete pickled ``nn.Module``
    (saved via ``torch.save(model)``), so ``weights_only=False`` is required
    on PyTorch >= 2.6, where the default flipped to ``True``.

    NOTE(security): ``weights_only=False`` runs the full pickle machinery and
    can execute arbitrary code — only load checkpoints from trusted sources.

    Args:
        model_path: Path to the pickled model file.

    Returns:
        The loaded model, mapped to the CPU.
    """
    logging.info(f"Load the model from: {model_path}")

    # The previous `torch.serialization.safe_globals([...])` wrapper was a
    # no-op here: the allowlist only affects the restricted weights-only
    # unpickler, which is bypassed by weights_only=False. The redundant
    # function-local AnomalyDetector import was removed for the same reason
    # (it is already imported at module level).
    model = torch.load(model_path, map_location="cpu", weights_only=False)

    logging.info(model)
    return model
43
+
44
+
45
+
46
class TorchModel(nn.Module):
    """Wrapper class for a torch model that adds training, evaluation,
    checkpointing and callback notification plumbing."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.device = get_torch_device()
        # Global step counter, incremented once per training batch.
        self.iteration = 0
        self.model = model
        self.is_data_parallel = False
        self.callbacks = []

    def register_callback(self, callback_fn: Callback) -> None:
        """
        Register a callback to be called after each evaluation run
        Args:
            callback_fn: a callable that accepts 2 inputs (output, target)
            - output is the model's output
            - target is the values of the target variable
        """
        self.callbacks.append(callback_fn)

    def data_parallel(self, device_ids: Optional[List[int]] = None):
        """Transfers the model to data parallel mode.

        Args:
            device_ids: GPU ids to replicate the model over. Defaults to
                [0, 1], preserving the previously hard-coded behavior.

        Returns:
            self, to allow call chaining.
        """
        self.is_data_parallel = True
        if not isinstance(self.model, torch.nn.DataParallel):
            # Generalized: the device list used to be hard-coded to [0, 1].
            ids = [0, 1] if device_ids is None else device_ids
            self.model = torch.nn.DataParallel(self.model, device_ids=ids)

        return self

    @classmethod
    def load_model(cls, model_path: str):
        """
        Loads a pickled model
        Args:
            model_path: path to the pickled model

        Returns: TorchModel class instance wrapping the provided model
        """
        return cls(load_model(model_path))

    def notify_callbacks(self, notification, *args, **kwargs) -> None:
        """Calls all callbacks registered with this class.

        Args:
            notification: Name of the callback method to invoke on each
                registered callback.
        """
        for callback in self.callbacks:
            try:
                method = getattr(callback, notification)
                method(*args, **kwargs)
            except (AttributeError, TypeError) as e:
                # Best effort: a partially implemented callback must not
                # abort training, so log the problem and continue.
                logging.error(
                    f"callback {callback.__class__.__name__} doesn't fully implement the required interface {e}"  # pylint: disable=line-too-long
                )

    def fit(
        self,
        train_iter: DataLoader,
        criterion: nn.Module,
        optimizer: Optimizer,
        eval_iter: Optional[DataLoader] = None,
        epochs: int = 10,
        network_model_path_base: Optional[str] = None,
        save_every: Optional[int] = None,
        evaluate_every: Optional[int] = None,
    ) -> None:
        """Runs the full training loop.

        Args:
            train_iter: iterator for training
            criterion: loss function
            optimizer: optimizer for the algorithm
            eval_iter: iterator for evaluation
            epochs: amount of epochs
            network_model_path_base: where to save the models
            save_every: saving model checkpoints every specified amount of epochs
            evaluate_every: perform evaluation every specified amount of epochs.
                            If the evaluation is expensive, you probably want to
                            choose a high value for this
        """
        criterion = criterion.to(self.device)
        self.notify_callbacks("on_training_start", epochs)

        for epoch in range(epochs):
            train_loss = self.do_epoch(
                criterion=criterion,
                optimizer=optimizer,
                data_iter=train_iter,
                epoch=epoch,
            )

            if save_every and network_model_path_base and epoch % save_every == 0:
                logging.info(f"Save the model after epoch {epoch}")
                self.save(os.path.join(network_model_path_base, f"epoch_{epoch}.pt"))

            val_loss = None
            if eval_iter and evaluate_every and epoch % evaluate_every == 0:
                logging.info(f"Evaluating after epoch {epoch}")
                val_loss = self.evaluate(
                    criterion=criterion,
                    data_iter=eval_iter,
                )

            self.notify_callbacks("on_training_iteration_end", train_loss, val_loss)

        self.notify_callbacks("on_training_end", self.model)
        # Save the last model anyway. `epochs` equals the old `epoch + 1`
        # after a completed loop, and also avoids a NameError when epochs == 0
        # (the loop variable would never be bound).
        if network_model_path_base:
            self.save(os.path.join(network_model_path_base, f"epoch_{epochs}.pt"))

    def evaluate(self, criterion: nn.Module, data_iter: DataLoader) -> float:
        """
        Evaluates the model
        Args:
            criterion: Loss function for calculating the evaluation
            data_iter: torch data iterator

        Returns:
            Average loss over all batches of `data_iter`.
        """
        self.eval()
        self.notify_callbacks("on_evaluation_start", len(data_iter))
        total_loss = 0.0

        # Gradients are not needed for evaluation; no_grad saves memory.
        with torch.no_grad():
            for iteration, (batch, targets) in enumerate(data_iter):
                batch = self.data_to_device(batch, self.device)
                targets = self.data_to_device(targets, self.device)

                outputs = self.model(batch)
                loss = criterion(outputs, targets)

                self.notify_callbacks(
                    "on_evaluation_step",
                    iteration,
                    outputs.detach().cpu(),
                    targets.detach().cpu(),
                    loss.item(),
                )

                total_loss += loss.item()

        avg_loss = total_loss / len(data_iter)
        self.notify_callbacks("on_evaluation_end")
        return avg_loss

    def do_epoch(
        self,
        criterion: nn.Module,
        optimizer: Optimizer,
        data_iter: DataLoader,
        epoch: int,
    ) -> float:
        """Perform a whole epoch.

        Args:
            criterion (nn.Module): Loss function to be used.
            optimizer (Optimizer): Optimizer to use for minimizing the loss function.
            data_iter (DataLoader): Loader for data samples used for training the model.
            epoch (int): The epoch number.

        Returns:
            float: Average training loss calculated during the epoch.
        """
        total_loss = 0.0
        self.train()
        self.notify_callbacks("on_epoch_start", epoch, len(data_iter))
        for iteration, (batch, targets) in enumerate(data_iter):
            # Bug fix: the global step used to be incremented twice per batch
            # (once at the top and once at the bottom of this loop), so
            # callbacks saw step numbers 1, 3, 5, ... Now it counts batches.
            self.iteration += 1
            batch = self.data_to_device(batch, self.device)
            targets = self.data_to_device(targets, self.device)

            outputs = self.model(batch)

            loss = criterion(outputs, targets)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            self.notify_callbacks(
                "on_epoch_step",
                self.iteration,
                iteration,
                loss.item(),
            )

        # (The old per-batch wall-clock accumulation was dead code — the
        # total was never read — and has been removed.)
        avg_loss = total_loss / len(data_iter)

        self.notify_callbacks("on_epoch_end", avg_loss)
        return avg_loss

    def data_to_device(
        self, data: Union[Tensor, List[Tensor]], device: Device
    ) -> Union[Tensor, List[Tensor]]:
        """
        Transfers tensor data to a device
        Args:
            data: torch tensor, or a list/tuple of tensors
            device: target device
        """
        if isinstance(data, list):
            data = [d.to(device) for d in data]
        elif isinstance(data, tuple):
            data = tuple(d.to(device) for d in data)
        else:
            data = data.to(device)

        return data

    def save(self, model_path: str) -> None:
        """Saves the model to the given path.

        If currently using data parallel, the method
        will save the original model and not the data parallel instance of it
        Args:
            model_path: target path to save the model to
        """
        if self.is_data_parallel:
            torch.save(self.model.module, model_path)
        else:
            torch.save(self.model, model_path)

    def get_model(self) -> nn.Module:
        """Returns the wrapped model, unwrapping DataParallel if active."""
        if self.is_data_parallel:
            return self.model.module

        return self.model

    def forward(self, *args, **kwargs):
        """Delegates forward calls to the wrapped model."""
        return self.model(*args, **kwargs)