|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch import nn |
|
|
|
|
|
from model import ResNet18 |
|
|
from preprocessing import PreprocessedImageFolder, augmentations, make_dls |
|
|
from trainer import ( |
|
|
LRFinderCB, |
|
|
ActivationStatsCB, |
|
|
AugmentCB, |
|
|
DeviceCB, |
|
|
MultiClassAccuracyCB, |
|
|
ProgressCB, |
|
|
Trainer, |
|
|
WandBCB, |
|
|
) |
|
|
|
|
|
def main() -> None:
    """Train a single-channel ResNet18 on ``./dataset`` and plot the losses.

    Builds train/validation datasets from ``./dataset/train`` and
    ``./dataset/test``, wraps them in data loaders (batch size 32, 2 worker
    processes), and trains with cross-entropy loss and SGD (lr=1e-4) using
    device, augmentation, progress, W&B and accuracy callbacks. Finally calls
    ``ProgressCB.plot_losses(save=True)``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): the second argument is None. Presumably this means
    # "no per-sample transform", with augmentation applied batch-wise by
    # AugmentCB below instead. TODO confirm against PreprocessedImageFolder.
    train_ds = PreprocessedImageFolder("./dataset/train", None)
    valid_ds = PreprocessedImageFolder("./dataset/test", None)
    dls = make_dls(train_ds, valid_ds, batch_size=32, num_workers=2)

    # in_channels=1 assumes single-channel (e.g. grayscale) images.
    # The number of output classes is taken from the training dataset.
    model = ResNet18(in_channels=1, num_classes=len(train_ds.classes))

    progress = ProgressCB(in_notebook=False)
    wandb_cb = WandBCB(proj_name="test", model_path="./model.pth")
    augment = AugmentCB(device=device, transform=augmentations)
    acc_cb = MultiClassAccuracyCB(with_wandb=True)

    trainer = Trainer(
        model,
        dls,
        F.cross_entropy,
        torch.optim.SGD,
        lr=1e-4,
        cbs=[DeviceCB(device), augment, progress, wandb_cb, acc_cb],
    )

    # NOTE(review): positional flags. Presumably (n_epochs, train, valid);
    # confirm against Trainer.fit and switch to keyword arguments.
    trainer.fit(5, True, True)

    progress.plot_losses(save=True)


# The data loaders start worker processes (num_workers=2). Under the "spawn"
# start method (the default on Windows and macOS), each worker re-imports this
# module. Without this guard, every worker would re-run dataset construction
# and training, and startup fails.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|