hari31416 committed on
Commit
ff6d1a9
·
1 Parent(s): c2aae40

Upload 2 files

Browse files
Files changed (2) hide show
  1. model.py +385 -0
  2. torch_train.py +543 -0
model.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModel
2
+ from datasets import load_dataset, Dataset, concatenate_datasets
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch.utils.data import DataLoader
6
+ from sklearn.model_selection import train_test_split
7
+ from sklearn.metrics import (
8
+ classification_report,
9
+ confusion_matrix,
10
+ accuracy_score,
11
+ precision_score,
12
+ )
13
+ from sklearn.ensemble import RandomForestClassifier
14
+ from xgboost import XGBClassifier
15
+ import torch.nn as nn
16
+ import torchmetrics
17
+ from torch.optim.lr_scheduler import CosineAnnealingLR
18
+
19
+ import numpy as np
20
+ import pandas as pd
21
+ import os
22
+ import pickle
23
+ import argparse
24
+ from torch_train import TorchTrain
25
+ from utilities import get_simple_logger
26
+
27
# Directory containing this file; all data/model paths are resolved relative to it.
FILE_DIR = os.path.dirname(os.path.realpath(__file__))
DATA_DIR = os.path.join(FILE_DIR, "data")
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
random_state = 42
# Seed numpy and torch up front so runs are reproducible.
np.random.seed(random_state)
torch.manual_seed(random_state)
34
+
35
+
36
class PDFDataLoader:
    """Map-style dataset over a table with ``embeddings`` and ``label`` columns.

    Instances are handed to ``torch.utils.data.DataLoader``; the final
    datasets are assembled in the `PDFDataSet` class.
    """

    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return the ``idx``-th (embedding, label) pair as float32 tensors on `device`."""
        record = self.df[idx]
        features = torch.from_numpy(np.asarray(record["embeddings"])).float()
        # Give the label a leading dimension so it matches the model's (N, 1) output.
        target = torch.from_numpy(np.expand_dims(np.asarray(record["label"]), axis=0))
        return features.to(device), target.to(device).float()
56
+
57
+
58
class PDFDataSet:
    """Builds train/validation/test datasets of sentence embeddings for the
    PDF classification task.

    Parameters
    ----------
    data_dir : str, optional
        Directory (joined with ``FILE_DIR``) containing ``train.csv`` and
        ``test.csv``, by default ``DATA_DIR``.
    fraction_test_data_in_train : float, optional
        Fraction of the test split to move into the training split, by
        default 0.2. Set to 0/None to disable.
    model_ckpt : str, optional
        HuggingFace checkpoint for the tokenizer and encoder, by default
        "encoder".
    """

    def __init__(
        self,
        data_dir=DATA_DIR,
        fraction_test_data_in_train=0.2,
        model_ckpt="encoder",
    ) -> None:
        self.data_dir = data_dir
        self.fraction_test_data_in_train = fraction_test_data_in_train
        self.model_ckpt = model_ckpt
        tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
        encoding_model = AutoModel.from_pretrained(model_ckpt)
        # The encoder is used for inference only: move it to the target
        # device and switch to eval mode once, up front.
        encoding_model = encoding_model.to(device)
        encoding_model = encoding_model.eval()
        self.encoding_model = encoding_model
        self.tokenizer = tokenizer
        self.logger = get_simple_logger("pdf_dataset")

    def create_datasets(self):
        """Load the CSVs and return ``(train, validation, test)`` HF Datasets.

        When ``fraction_test_data_in_train`` is truthy, that fraction of the
        test split is moved into the training split (and removed from test).
        """
        train_data_path = os.path.join(FILE_DIR, self.data_dir, "train.csv")
        test_data_path = os.path.join(FILE_DIR, self.data_dir, "test.csv")
        df = pd.read_csv(train_data_path)
        test_df = pd.read_csv(test_data_path)
        # Use the module-level seed everywhere for consistent, reproducible splits.
        train_df, validation_df = train_test_split(
            df, test_size=0.3, random_state=random_state
        )
        if self.fraction_test_data_in_train:
            self.logger.info(
                f"Adding {self.fraction_test_data_in_train} fraction of test dataset to the training set."
            )
            test_df, test_df_for_training = train_test_split(
                test_df,
                test_size=self.fraction_test_data_in_train,
                random_state=random_state,
            )
            train_df = pd.concat([train_df, test_df_for_training])

        train_dataset = Dataset.from_pandas(train_df)
        validation_dataset = Dataset.from_pandas(validation_df)
        test_dataset = Dataset.from_pandas(test_df)
        return train_dataset, validation_dataset, test_dataset

    def mean_pooling(self, model_output, attention_mask):
        """Mean-pool token embeddings, ignoring padded positions via the mask."""
        # First element of model_output contains all token embeddings.
        token_embeddings = model_output[0]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        # Clamp guards against division by zero for all-padding rows.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
            input_mask_expanded.sum(1), min=1e-9
        )

    def sentences_to_embedding(self, sentences):
        """Encode ``sentences`` into L2-normalized, mean-pooled embeddings.

        Returns a detached CPU tensor; for a single sentence the batch
        dimension is squeezed away.
        """
        # Tokenize sentences
        encoded_input = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        )
        # BUG FIX: the encoder lives on `device`, so the tokenized tensors must
        # be moved there too — previously this crashed when CUDA was available.
        encoded_input = encoded_input.to(device)
        # No gradients are needed for feature extraction.
        with torch.no_grad():
            model_output = self.encoding_model(**encoded_input)
        sentence_embeddings = self.mean_pooling(
            model_output, encoded_input["attention_mask"]
        )
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        # remove the batch dimension for single-sentence inputs
        sentence_embeddings = sentence_embeddings.squeeze()
        # Bring the result back to the CPU so it can be stored in a HF Dataset.
        return sentence_embeddings.detach().cpu()

    def get_embeddings(self, row):
        """`Dataset.map` helper: embed the row's ``content`` field."""
        return {
            "embeddings": self.sentences_to_embedding(
                sentences=row["content"],
            )
        }

    def create_embeddings(self):
        """Return ``(train, validation, test)`` datasets with an ``embeddings`` column."""
        train_dataset, validation_dataset, test_dataset = self.create_datasets()
        train_dataset = train_dataset.map(self.get_embeddings)
        validation_dataset = validation_dataset.map(self.get_embeddings)
        test_dataset = test_dataset.map(self.get_embeddings)
        return train_dataset, validation_dataset, test_dataset
133
+
134
+
135
class PDFModel(nn.Module):
    """A small MLP: ReLU hidden layers followed by a sigmoid output,
    producing probabilities for binary classification on embedding vectors.
    """

    def __init__(self, input_size, hidden_sizes, output_size):
        super(PDFModel, self).__init__()
        # Build the hidden stack; keep the explicit linear_i/relu_i names so
        # state_dict keys stay stable.
        self.seq_model = nn.Sequential()
        in_features = input_size
        for idx, width in enumerate(hidden_sizes):
            self.seq_model.add_module(f"linear_{idx}", nn.Linear(in_features, width))
            self.seq_model.add_module(f"relu_{idx}", nn.ReLU())
            in_features = width
        self.last_layer = nn.Linear(in_features, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map inputs to sigmoid probabilities of shape (..., output_size)."""
        hidden = self.seq_model(x)
        return self.sigmoid(self.last_layer(hidden))
150
+
151
+
152
def evaluate_model(y_true, y_pred, model_name, split="train"):
    """Print accuracy, precision and a classification report for predictions.

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted labels.
    model_name : str
        Model name, used only in the printed header.
    split : str, optional
        Name of the data split being evaluated, by default "train".
    """
    # Compute all metrics before printing anything.
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred)
    report = classification_report(y_true, y_pred)
    separator = "------" * 10
    print(separator)
    print(f"Evaluating for the model: {model_name} for {split} dataset...")
    print(f"Accuracy: {acc}")
    print(f"Precision: {prec}")
    print(report)
    print(separator)
162
+
163
+
164
def train_dl_model(
    train_data,
    validation_data,
    epochs=30,
    input_shape=384,
    hidden_sizes=None,
):
    """Build a `PDFModel` and train it with `TorchTrain`.

    Parameters
    ----------
    train_data : torch.utils.data.DataLoader
        Training batches of (embeddings, labels).
    validation_data : torch.utils.data.DataLoader
        Validation batches of (embeddings, labels).
    epochs : int, optional
        Number of training epochs, by default 30.
    input_shape : int, optional
        Dimensionality of the input embeddings, by default 384.
    hidden_sizes : list of int, optional
        Hidden layer widths; defaults to [32, 16] when None.

    Returns
    -------
    tuple
        (history dict returned by `TorchTrain.fit`, the trained `PDFModel`).
    """
    # BUG FIX: avoid a mutable default argument; [32, 16] stays the effective default.
    if hidden_sizes is None:
        hidden_sizes = [32, 16]
    model = PDFModel(input_size=input_shape, hidden_sizes=hidden_sizes, output_size=1)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.BCELoss()
    accuracy = torchmetrics.Accuracy(
        task="binary", num_classes=2, threshold=0.5, average="macro"
    )
    precision = torchmetrics.Precision(task="binary", average="macro")
    metrics = {
        "accuracy": accuracy,
        "precision": precision,
    }
    scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=0.0001)
    tt = TorchTrain(model, optimizer, loss_fn, metrics=metrics, scheduler=scheduler)
    history = tt.fit(train_data, validation_data, verbose=True, epochs=epochs)
    return history, model
186
+
187
+
188
def evaluate_models(fraction_test_data_in_train=0.1):
    """Train and evaluate the DL model plus RandomForest and XGBoost baselines.

    Embeds the data once, trains the torch MLP, then fits the two tree models
    with pre-tuned hyperparameters, printing metrics for each split.

    Parameters
    ----------
    fraction_test_data_in_train : float, optional
        Fraction of the test split moved into the training split, by default 0.1.
    """
    print("Creating Embeddings...")
    ds = PDFDataSet(fraction_test_data_in_train=fraction_test_data_in_train)
    train_dataset, validation_dataset, test_dataset = ds.create_embeddings()
    print("Done\n")

    print("Training DL Model")
    # Create dataset for DL models:
    BATCH_SIZE = 8
    train_dataloader = PDFDataLoader(train_dataset)
    validation_dataloader = PDFDataLoader(validation_dataset)

    train_data = DataLoader(train_dataloader, batch_size=BATCH_SIZE, shuffle=True)
    validation_data = DataLoader(
        validation_dataloader,
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    # Infer the embedding dimensionality from the first training batch.
    for X, _ in train_data:
        input_shape = int(X.shape[1])
        break
    epochs = 30
    hidden_sizes = [32, 16]
    # BUG FIX: input_shape was previously computed but never passed, so the
    # model silently fell back to the hard-coded default of 384.
    history, model = train_dl_model(
        train_data=train_data,
        validation_data=validation_data,
        epochs=epochs,
        input_shape=input_shape,
        hidden_sizes=hidden_sizes,
    )
    print("Done\n")
    print("Evaluating DL Model")
    # NOTE(review): TorchTrain moves the model to its DEVICE; this CPU-tensor
    # forward pass assumes a CPU run — confirm before running on CUDA.
    y_test_pred = model(torch.from_numpy(np.array(test_dataset["embeddings"])).float())
    y_test_pred = y_test_pred.detach().numpy()
    # Threshold the sigmoid outputs at 0.5 to get hard labels.
    y_test_pred = np.where(y_test_pred > 0.5, 1, 0)
    evaluate_model(
        y_true=test_dataset["label"],
        y_pred=y_test_pred,
        model_name="DL Model",
        split="test",
    )
    print("Done\n")

    # ML Models
    print("Training and evaluating ML Models.")
    X_train = train_dataset["embeddings"]
    y_train = train_dataset["label"]
    X_validation = validation_dataset["embeddings"]
    y_validation = validation_dataset["label"]
    X_test = test_dataset["embeddings"]
    y_test = test_dataset["label"]
    # Pre-tuned hyperparameters (found via an earlier search, per the values below).
    rfc_best_params = {
        "max_depth": 23,
        "max_features": "log2",
        "n_estimators": 469,
    }

    xgb_best_params = {
        "max_depth": 25,
        "n_estimators": 372,
        "learning_rate": 0.2522824287799319,
    }
    print("Fitting RandomForest")
    rfc = RandomForestClassifier(**rfc_best_params)
    rfc.fit(X_train, y_train)
    evaluate_model(
        y_true=y_train,
        y_pred=rfc.predict(X_train),
        model_name="RandomForest",
        split="train",
    )
    evaluate_model(
        y_true=y_validation,
        y_pred=rfc.predict(X_validation),
        model_name="RandomForest",
        split="validation",
    )
    evaluate_model(
        y_true=y_test,
        y_pred=rfc.predict(X_test),
        model_name="RandomForest",
        split="test",
    )

    print("Fitting XGBoost")
    xgb = XGBClassifier(**xgb_best_params)
    xgb.fit(X_train, y_train)
    evaluate_model(
        y_true=y_train,
        y_pred=xgb.predict(X_train),
        model_name="XGBoost",
        split="train",
    )
    evaluate_model(
        y_true=y_validation,
        y_pred=xgb.predict(X_validation),
        model_name="XGBoost",
        split="validation",
    )
    evaluate_model(
        y_true=y_test,
        y_pred=xgb.predict(X_test),
        model_name="XGBoost",
        split="test",
    )
    print("All Done")
296
+
297
+
298
def train_and_save_final_model(model_save_path="final_model.pkl"):
    """Create, evaluate and pickle the final model.

    The final model is a RandomForestClassifier trained on all the training
    data (train + validation) plus 10% of the test data. Mixing in 10% of the
    test data is necessary because its distribution differs markedly from the
    training data; since that 10% is used for training, it is excluded when
    computing the final accuracy.

    Parameters
    ----------
    model_save_path : str, optional
        The path to save the final model, by default "final_model.pkl"

    Returns
    -------
    None

    Examples
    --------
    >>> train_and_save_final_model()
    >>> train_and_save_final_model(model_save_path="final_model.pkl")
    """
    print("Creating Embeddings...")
    model_save_path = os.path.join(FILE_DIR, model_save_path)
    ds = PDFDataSet(fraction_test_data_in_train=0.1)
    train_dataset, validation_dataset, test_dataset = ds.create_embeddings()
    # Fold the validation split into training — all non-test data is used.
    train_dataset = concatenate_datasets([train_dataset, validation_dataset])
    X_train = train_dataset["embeddings"]
    X_test = test_dataset["embeddings"]
    y_train = train_dataset["label"]
    y_test = test_dataset["label"]

    print("Training and evaluating the model...")
    # Pre-tuned RandomForest hyperparameters.
    rfc_best_params = {
        "max_depth": 23,
        "max_features": "log2",
        "n_estimators": 469,
    }
    rfc_model = RandomForestClassifier(**rfc_best_params)
    rfc_model.fit(X_train, y_train)
    for split_name, X, y in (
        ("train", X_train, y_train),
        ("test", X_test, y_test),
    ):
        evaluate_model(
            y_true=y,
            y_pred=rfc_model.predict(X),
            model_name="Final Model",
            split=split_name,
        )

    print("Saving the model...")
    with open(model_save_path, "wb") as f:
        pickle.dump(rfc_model, f)
    print(f"Model saved to: {model_save_path}")
351
+
352
+
353
def main(args):
    """Dispatch to training or evaluation based on the parsed CLI arguments."""
    if args.task == "train":
        train_and_save_final_model(model_save_path=args.model_save_path)
    elif args.task == "evaluate":
        evaluate_models(args.fraction)
361
+
362
+
363
if __name__ == "__main__":
    # CLI entry point: choose a task and (optionally) override the defaults.
    parser = argparse.ArgumentParser(description="Train and evaluate models")
    parser.add_argument(
        "--task",
        type=str,
        choices=["train", "evaluate"],
        required=True,
        help="Whether to train and save the best model or evaluate all the models.",
    )
    # Only consumed by --task=evaluate.
    parser.add_argument(
        "--fraction",
        type=float,
        default=0.1,
        help="Fraction of test data in train dataset",
    )
    # Only consumed by --task=train.
    parser.add_argument(
        "--model_save_path",
        type=str,
        default="final_model.pkl",
        help="Path to save the final model",
    )
    args = parser.parse_args()
    main(args)
torch_train.py ADDED
@@ -0,0 +1,543 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+
5
class TorchTrain:
    """A class for training a model in PyTorch.

    Parameters
    -----------
    model (torch.nn.Module): The PyTorch model to train.
    optimizer (torch.optim.Optimizer): The optimizer to use for training.
    loss_function (callable): The loss function to use for training.
    metrics (dict or callable, optional): The metrics to evaluate during training.
        If a dictionary, the keys are the metric names and the values are functions that
        take in `yhat` and `y` and return a metric value. If a callable, it should take
        in `yhat` and `y` and return a metric value. Defaults to None.

    Attributes
    -----------
    DEVICE (torch.device): The device to use for training (cuda if available, cpu otherwise).
    model (torch.nn.Module): The PyTorch model being trained.
    optimizer (torch.optim.Optimizer): The optimizer being used for training.
    loss_function (callable): The loss function being used for training.
    metrics (dict or callable): The metrics being evaluated during training.
    metrics_evaluated (dict): The metrics evaluated during training.
    train_loss (float): The average training loss.
    test_loss (float): The average test loss.
    train_iteration (int): The number of training iterations.
    test_iteration (int): The number of test iterations.
    train_metrics (dict): The metrics evaluated on the training data.
    test_metrics (dict): The metrics evaluated on the test data.
    """

    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __init__(
        self,
        model,
        optimizer,
        loss_function,
        metrics=None,
        scheduler=None,
        task_type="classification",
    ) -> None:
        """Initialize the TorchTrain object.

        Parameters
        -----------
        model : torch.nn.Module
            The PyTorch model to train.
        optimizer : torch.optim.Optimizer
            The optimizer to use for training.
        loss_function : callable
            The loss function to use for training.
        metrics : dict or callable, optional
            The metrics to evaluate during training. If a dictionary, the keys are the metric names
            and the values are functions that take in `yhat` and `y` and return a metric value.
            If a callable, it should take in `yhat` and `y` and return a metric value. Defaults to None.
        scheduler : torch.optim.lr_scheduler, optional
            The learning rate scheduler to use for training. Defaults to None.
        task_type : str, optional
            Either "classification" or another value; only affects `predict`,
            which rounds/argmaxes outputs for classification. Defaults to
            "classification".
        """
        self.model = model
        self.model.to(self.DEVICE)
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.metrics = self.__preprocess_metrics(metrics)
        self.scheduler = scheduler
        self.metrics_evaluated = {}
        # Running (summed) losses and iteration counters for the current epoch;
        # they are averaged and archived by __reset_counters at epoch end.
        self.train_loss = 0
        self.test_loss = 0
        self.train_iteration = 0
        self.test_iteration = 0
        self.train_metrics = {}
        self.test_metrics = {}
        self.history = {}
        # Per-epoch archives of averaged losses/metrics.
        self.train_loss_all = []
        self.test_loss_all = []
        self.train_metrics_all = []
        self.test_metrics_all = []
        # Guards so __scale_matrices divides by the iteration count at most once per epoch.
        self.__train_scaled = False
        self.__test_scaled = False
        self.task_type = task_type

    def __preprocess_metrics(self, metrics):
        """Preprocesses the given metrics into a {Title-cased name: fn} dict."""
        if metrics is None:
            return {}
        if isinstance(metrics, dict):
            return {key.title(): value for key, value in metrics.items()}
        else:
            # NOTE(review): despite the class docstring, a bare callable is
            # rejected here — only a dict (or None) is actually accepted.
            raise TypeError(
                "Metrics should be a dictionary of metrics or a function which takes yhat, y"
            )

    def __scale_matrices(self, loss, metrics, type="train"):
        """Scales the loss and metrics

        Parameters
        -----------
        loss : float
            The loss to scale
        metrics : dict
            The metrics to scale
        type : str, optional
            The type of scaling to do, either "train" or "test", by default "train"

        Returns
        --------
        loss : float
            The scaled loss
        metrics : dict
            The scaled metrics
        """
        # The scaled flags make this idempotent within an epoch: a second call
        # returns the inputs untouched instead of dividing twice.
        if type == "train" and not self.__train_scaled:
            scale = self.train_iteration
            self.__train_scaled = True
        elif type == "test" and not self.__test_scaled:
            scale = self.test_iteration
            self.__test_scaled = True
        else:
            return loss, metrics
        loss /= scale
        for key in metrics:
            metrics[key] /= scale
        return loss, metrics

    def __reset_counters(self):
        """Resets all the counters and loss objects for a new epoch"""
        # Average the summed losses/metrics by the iteration counts...
        self.train_loss, self.train_metrics = self.__scale_matrices(
            self.train_loss, self.train_metrics, type="train"
        )

        self.test_loss, self.test_metrics = self.__scale_matrices(
            self.test_loss, self.test_metrics, type="test"
        )

        # ...archive them, then zero everything for the next epoch.
        self.train_loss_all.append(self.train_loss)
        self.train_loss = 0

        self.test_loss_all.append(self.test_loss)
        self.test_loss = 0

        self.train_iteration = 0
        self.test_iteration = 0

        self.train_metrics_all.append(self.train_metrics)
        self.train_metrics = {}

        self.test_metrics_all.append(self.test_metrics)
        self.test_metrics = {}
        self.__train_scaled = False
        self.__test_scaled = False

    @property
    def loss(self):
        """Returns the training loss of the most recently completed epoch."""
        return self.train_loss_all[-1]

    def __create_history(self):
        """Creates the history dictionary (Keras-style: train_*/val_* lists per epoch)."""
        history = {
            "train_loss": self.train_loss_all,
            "val_loss": self.test_loss_all,
        }
        for key, value in self.metrics.items():
            history[f"train_{key.lower()}"] = []
            history[f"val_{key.lower()}"] = []

        for item in self.train_metrics_all:
            for key, value in item.items():
                history[f"train_{key.lower()}"].append(value)

        for item in self.test_metrics_all:
            for key, value in item.items():
                history[f"val_{key.lower()}"].append(value)
        return history

    def __parse_val(self, val):
        """Parses the given value to a float"""
        if isinstance(val, torch.Tensor):
            val = val.item()
        elif isinstance(val, np.ndarray):
            val = float(val)
        elif isinstance(val, (int, float)):
            pass
        else:
            raise TypeError(
                f"The given Metric function should return a tensor, numpy array, int, or float.\n\
                Got {type(val)}"
            )
        return val

    def _train_step(self, x, y):
        """Perform a single training step.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor.
        y : torch.Tensor
            The target tensor.

        Returns
        -------
        tuple
            A tuple containing the loss and the predicted output tensor.
        """
        self.model.train()
        yhat = self.model(x)
        l = self.loss_function(yhat, y)
        self.optimizer.zero_grad()
        l.backward()
        self.optimizer.step()
        self.train_iteration += 1
        return l.item(), yhat

    def _test_step(self, x, y):
        """Perform a single testing step.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor.
        y : torch.Tensor
            The target tensor.

        Returns
        -------
        tuple
            A tuple containing the loss and the predicted output tensor.
        """
        self.model.eval()
        # inference_mode disables autograd for the forward pass.
        with torch.inference_mode():
            yhat = self.model(x)
            l = self.loss_function(yhat, y)
        self.test_iteration += 1
        return l.item(), yhat

    def predict(self, x):
        """Make predictions on a batch of data.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor.

        Returns
        -------
        torch.Tensor
            The predicted output tensor. For classification, 1-D outputs are
            rounded (binary) and given a trailing dim; 2-D outputs are argmaxed.
        """
        self.model.eval()
        yhat = self.model(x)
        if self.task_type == "classification":
            if len(yhat.shape) == 1:
                # round
                yhat = torch.round(yhat)
                yhat = yhat.unsqueeze(1)
            else:
                yhat = torch.argmax(yhat, dim=1)

        return yhat

    def __calculate_metrics(self, yhat, y):
        """Calculate the metrics for a batch of data.

        Parameters
        ----------
        yhat : torch.Tensor
            The predicted output tensor.
        y : torch.Tensor
            The target tensor.

        Returns
        -------
        dict
            A dictionary containing the values of the metrics.
        """
        metrics = {}
        for key, metric in self.metrics.items():
            val = metric(yhat, y)
            # NOTE(review): this duplicates __parse_val inline.
            if isinstance(val, torch.Tensor):
                val = val.item()
            elif isinstance(val, np.ndarray):
                val = float(val)
            elif isinstance(val, (int, float)):
                pass
            else:
                raise TypeError(
                    f"Metric {key} should return a tensor, numpy array, int, or float"
                )
            metrics[key] = val
        self.metrics_evaluated = metrics
        return metrics

    def __progress_bar(self, cur_iter, all_iter):
        """Creates a progress bar showing the progress of the current batch.

        Parameters
        ----------
        cur_iter : int
            The current batch number.
        all_iter : int
            The total number of batches.

        Returns
        -------
        str
            The progress bar, in the form of "[====----]".
        """
        len_progress_bar = 20
        progress = int((cur_iter + 1) / all_iter * len_progress_bar)
        progress_bar = "=" * progress + "-" * (len_progress_bar - progress)
        return f"[{progress_bar}]"

    def progress(self, cur_iter, all_iter, loss, metrics, on="train"):
        """Prints a progress bar showing the progress of the current batch.

        Parameters
        ----------
        cur_iter : int
            The current batch number.
        all_iter : int
            The total number of batches.
        loss : float
            The current loss. Should be averaged over all batches.
        metrics : dict
            The metrics evaluated on the current batch.
        on : str, optional
            Whether the progress bar is for the training or testing data. Defaults to "train".

        Returns
        -------
        str
            The progress bar, in the form of "10/100[====----]".

        Notes
        -----
        The progress bar shows the progress of the current batch as a bar of equal signs ("=") and
        hyphens ("-"). The length of the bar is fixed at 20 characters. The current batch number
        and total number of batches are displayed at the beginning of the progress bar. The current
        loss and any metrics evaluated on the current batch are displayed at the end of the progress
        bar.
        """
        progress_bar = self.__progress_bar(cur_iter=cur_iter, all_iter=all_iter)

        if on.lower() == "train":
            iteration = self.train_iteration
            prefix = f"Epoch {(self.current_epoch+1):2d}/{self.epochs:2d} Batch "
        else:
            iteration = self.test_iteration
            prefix = "Epoch "

        # Losses/metrics arrive as running sums; divide by the iteration count
        # to display per-batch averages.
        text = f"{prefix}{cur_iter:>4d}/{all_iter:>4d}{progress_bar} {on.title()} loss: {loss/iteration:.4f}"
        for metric_name, metric_value in metrics.items():
            text += f" | {on.title()} {metric_name}: {metric_value/iteration:.4f}"

        return text

    def update_metrics(self, cur_metrics, new_metrics):
        """Update the metrics with the values for a new batch of data.

        Parameters
        ----------
        cur_metrics : dict
            The current values of the metrics.
        new_metrics : dict
            The values of the metrics for a new batch of data.

        Returns
        -------
        dict
            A dictionary containing the updated (summed) values of the metrics.
        """
        for key, value in new_metrics.items():
            if key not in cur_metrics:
                cur_metrics[key] = value
            else:
                cur_metrics[key] += value
        return cur_metrics

    def fit(
        self,
        train_loader,
        validation_data_loader=None,
        epochs=1,
        verbose=True,
        train_steps_per_epoch=None,
        validation_steps_per_epoch=None,
    ):
        """Fit the PyTorch model.

        Parameters
        ----------
        train_loader : torch.utils.data.DataLoader
            The data loader for the training data.
        validation_data_loader : torch.utils.data.DataLoader, optional
            The data loader for the test data. Defaults to None.
        epochs : int, optional
            The number of epochs to train for. Defaults to 1.
        verbose : bool, optional
            Whether to print the training progress during training. Defaults to True.
        train_steps_per_epoch : int, optional
            The number of batches to train on per epoch. Defaults to None.
        validation_steps_per_epoch : int, optional
            The number of batches to test on per epoch. Defaults to None.

        Returns
        -------
        dict
            The history dictionary with per-epoch train/val losses and metrics.

        Examples
        --------
        >>> model = MyModel()
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        >>> loss_function = nn.CrossEntropyLoss()
        >>> scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
        >>> train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
        >>> validation_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
        >>> trainer = TorchTrain(model, optimizer, loss_function, scheduler=scheduler)
        >>> trainer.fit(train_loader, validation_data_loader=validation_data_loader, epochs=10, verbose=True)
        """
        self.epochs = epochs
        if train_steps_per_epoch is None:
            train_steps_per_epoch = len(train_loader)
        if validation_data_loader is not None:
            if validation_steps_per_epoch is None:
                validation_steps_per_epoch = len(validation_data_loader)

        for epoch in range(epochs):
            self.current_epoch = epoch
            for i, (x, y) in enumerate(train_loader):
                x = x.to(self.DEVICE)
                # Multi-target batches arrive as a list/tuple of tensors.
                if isinstance(y, list) or isinstance(y, tuple):
                    y = [y_.to(self.DEVICE) for y_ in y]
                else:
                    y = y.to(self.DEVICE)

                train_loss, yhat = self._train_step(x, y)
                self.train_loss += train_loss
                metrics = self.__calculate_metrics(yhat, y)
                self.train_metrics = self.update_metrics(self.train_metrics, metrics)

                b_progress = self.progress(
                    i + 1,
                    train_steps_per_epoch,
                    self.train_loss,
                    self.train_metrics,
                    on="train",
                )
                if i == train_steps_per_epoch - 1:
                    # Last batch of the epoch: print with a newline and stop.
                    print(b_progress)
                    break
                else:
                    if verbose:
                        # Carriage return keeps the bar on one line.
                        print(b_progress, end="\r")
            if validation_data_loader is not None:
                for i, (x, y) in enumerate(validation_data_loader):
                    x = x.to(self.DEVICE)
                    if isinstance(y, list) or isinstance(y, tuple):
                        y = [y_.to(self.DEVICE) for y_ in y]
                    else:
                        y = y.to(self.DEVICE)
                    test_loss, yhat = self._test_step(x, y)
                    self.test_loss += test_loss
                    metrics = self.__calculate_metrics(yhat, y)
                    self.test_metrics = self.update_metrics(self.test_metrics, metrics)
                    if i == validation_steps_per_epoch - 1:
                        break
                test_progress = self.progress(
                    epoch + 1,
                    epochs,
                    self.test_loss,
                    self.test_metrics,
                    on="test",
                )
                print(test_progress)
            # Average and archive this epoch's sums, then zero the counters.
            self.__reset_counters()
            if self.scheduler is not None:
                self.scheduler.step()
            if verbose and self.scheduler is not None:
                print(f"New Learning rate: {self.scheduler.get_last_lr()[0]:.6f}")

        return self.__create_history()

    def save(self, path):
        """Save the model to a file.

        Parameters
        ----------
        path : str
            The path to the file to save the model to.
        """
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        """Load the model from a file.

        Parameters
        ----------
        path : str
            The path to the file to load the model from.
        """
        self.model.load_state_dict(torch.load(path))

    def evaluate(self, data_loader, metric):
        """Evaluate the model on a data loader and the given metric.

        Parameters
        ----------
        data_loader : torch.utils.data.DataLoader
            The data loader to evaluate the model on.
        metric : function
            The metric to evaluate the model with.

        Returns
        -------
        float
            The score of the model on the given metric, averaged over batches.
        """
        running_score = 0
        data_length = len(data_loader)
        for i, (x, y) in enumerate(data_loader):
            progress_bar = self.__progress_bar(i, data_length)
            x = x.to(self.DEVICE)
            if isinstance(y, list) or isinstance(y, tuple):
                y = [y_.to(self.DEVICE) for y_ in y]
            else:
                y = y.to(self.DEVICE)

            # NOTE(review): rounds raw model outputs — assumes a sigmoid-style
            # binary classifier; confirm before using for other task types.
            yhat = self.model(x)
            yhat = torch.round(yhat)
            score = metric(y, yhat)
            score = self.__parse_val(score)
            running_score += score

            progress_bar = f"{i+1}/{data_length}" + progress_bar
            progress_bar += f" Score: {(running_score/(i+1)):4f}"
            print(progress_bar, end="\r")
        return running_score / (len(data_loader))