"""Resume training of a DeePixBiS model from a checkpoint and save it back.

The script loads weights from ``./DeePixBiS.pth``, builds training and
validation data loaders from ``./train_data.csv`` and ``./test_data.csv``,
runs ``Trainer.fit()`` and writes the updated weights to the same
checkpoint file.
"""
import torch
import torch.nn as nn
from torchvision.transforms import Compose, ToTensor, RandomHorizontalFlip, Normalize, Resize, RandomRotation
import numpy as np
from torch.utils.data import DataLoader
from DeePixBis.Dataset import PixWiseDataset
from DeePixBis.Model import DeePixBiS
from DeePixBis.Loss import PixWiseBCELoss
from DeePixBis.Metrics import predict, test_accuracy, test_loss
from DeePixBis.Trainer import Trainer

# Single source of truth for the checkpoint that is both read and overwritten.
CHECKPOINT_PATH = './DeePixBiS.pth'

model = DeePixBiS()
# map_location='cpu' lets a checkpoint saved from a GPU run be loaded on a
# CPU-only host; load_state_dict copies the values into the model's existing
# parameters, so this does not constrain where the model lives afterwards.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))

loss_fn = PixWiseBCELoss()

opt = torch.optim.Adam(model.parameters(), lr=0.0001)

# Both pipelines resize to 224x224 and normalise each channel with
# mean 0.5 / std 0.5; only the training pipeline adds random augmentation.
train_tfms = Compose([
    Resize([224, 224]),
    RandomHorizontalFlip(),
    RandomRotation(10),
    ToTensor(),
    Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

test_tfms = Compose([
    Resize([224, 224]),
    ToTensor(),
    Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

train_dataset = PixWiseDataset('./train_data.csv', transform=train_tfms)
train_ds = train_dataset.dataset()

val_dataset = PixWiseDataset('./test_data.csv', transform=test_tfms)
val_ds = val_dataset.dataset()

batch_size = 10
# Pinned host memory only helps host-to-GPU copies; requesting it without
# CUDA just produces a warning, so enable it only when CUDA is present.
pin_memory = torch.cuda.is_available()
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=0, pin_memory=pin_memory)
# NOTE(review): shuffling the validation loader is kept from the original;
# it is probably unnecessary - confirm against how Trainer consumes val_dl.
val_dl = DataLoader(val_ds, batch_size, shuffle=True, num_workers=0, pin_memory=pin_memory)

# NOTE(review): the fourth positional argument (1) is presumably the number
# of epochs - confirm against the Trainer signature.
trainer = Trainer(train_dl, val_dl, model, 1, opt, loss_fn)

print('Training Beginning\n')
trainer.fit()
print('\nTraining Complete')
torch.save(model.state_dict(), CHECKPOINT_PATH)