| | import torch |
| | import torch.nn as nn |
| | from DeePixBis.Metrics import test_accuracy, test_loss |
| |
|
| |
|
class Trainer():
    """Train a model that outputs ``(net_mask, net_label)``, validating after each epoch.

    Training batches from ``train_dl`` must yield ``(img, mask, label)``
    triples; the model is called as ``model(img)`` and the loss as
    ``loss_fn(net_mask, net_label, mask, label)``.
    """

    def __init__(self, train_dl, val_dl, model, epochs, opt, loss_fn, device='cpu',
                 log_interval=9):
        """Store the training configuration and move ``model`` to ``device``.

        Args:
            train_dl: Iterable of ``(img, mask, label)`` training batches.
            val_dl: Validation data handed to ``test_accuracy`` / ``test_loss``.
            model: Module returning ``(net_mask, net_label)`` for an image batch.
            epochs: Number of epochs ``fit`` runs.
            opt: Optimizer exposing ``zero_grad()`` / ``step()`` (presumably
                over ``model``'s parameters).
            loss_fn: Callable ``(net_mask, net_label, mask, label) -> loss``.
            device: Device the model and every training batch are moved to.
            log_interval: Print the batch loss every ``log_interval`` batches;
                a falsy value (0 / None) disables per-batch logging.
        """
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.model = model.to(device)
        self.epochs = epochs
        self.opt = opt
        self.loss_fn = loss_fn
        self.device = device
        self.log_interval = log_interval

    def train_one_epoch(self, num):
        """Run one optimisation pass over ``train_dl``, then evaluate on ``val_dl``.

        Args:
            num: Zero-based epoch index, used only for the progress header.

        Returns:
            ``(accuracy, loss)`` on the validation loader, as returned by
            ``test_accuracy`` and ``test_loss``.
        """
        print(f'\nEpoch ({num+1}/{self.epochs})')
        print('----------------------------------')

        # Re-enter training mode every epoch: validation at the end of the
        # previous epoch puts the model in eval mode.
        self.model.train()
        for batch, (img, mask, label) in enumerate(self.train_dl):
            img = img.to(self.device)
            mask = mask.to(self.device)
            label = label.to(self.device)

            net_mask, net_label = self.model(img)
            loss = self.loss_fn(net_mask, net_label, mask, label)

            self.opt.zero_grad()
            loss.backward()
            self.opt.step()

            if self.log_interval and batch % self.log_interval == 0:
                # .item() prints the scalar value rather than the tensor repr.
                print(f'Loss : {loss.item()}')

        test_acc, test_los = self._validate()
        print(f'Test Accuracy : {test_acc} Test Loss : {test_los}')
        return test_acc, test_los

    def _validate(self):
        """Compute validation accuracy and loss in eval mode without autograd."""
        self.model.eval()
        with torch.no_grad():
            test_acc = test_accuracy(self.model, self.val_dl)
            test_los = test_loss(self.model, self.val_dl, self.loss_fn)
        return test_acc, test_los

    def fit(self):
        """Train for ``self.epochs`` epochs, validating after each one.

        Returns:
            ``(accuracies, losses)``: two lists with one entry per epoch. Both
            hold the validation metrics produced by ``train_one_epoch``.
        """
        val_acc_history = []
        val_loss_history = []
        for epoch in range(self.epochs):
            val_acc, val_loss = self.train_one_epoch(epoch)
            val_acc_history.append(val_acc)
            val_loss_history.append(val_loss)
        return val_acc_history, val_loss_history
| |
|