| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from unet import UNet |
| from torch.utils.data import DataLoader |
| from data import SegmentationDataset, transform_img |
|
|
# Shared preprocessing pipeline applied to both the training and the test split.
transform = transform_img()

# Batch size used by both loaders.
BATCH_SIZE = 8

train_dataset = SegmentationDataset("DUTS-TR-Image", "DUTS-TR-Mask", transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# The test loader must wrap test_dataset (not train_dataset) so that test metrics
# are measured on held-out data. Order does not matter for evaluation, so no shuffling.
test_dataset = SegmentationDataset("DUTS-TE-Image", "DUTS-TE-Mask", transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)
# BCEWithLogitsLoss applies the sigmoid itself, so the model is expected to output raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
|
|
def evaluate_model(model, dataloader, criterion, device, threshold=0.5):
    """Compute the average loss and pixel accuracy of ``model`` over ``dataloader``.

    The model is put into eval mode and run under ``torch.no_grad()``.
    Predictions are obtained by thresholding ``sigmoid(outputs)``, so the model
    is assumed to return logits (consistent with ``nn.BCEWithLogitsLoss``).

    Args:
        model: Module whose output is compared element-wise with the masks.
        dataloader: Iterable yielding ``(images, masks)`` batches.
        criterion: Loss function called as ``criterion(outputs, masks)``.
        device: Device each batch is moved to before the forward pass.
        threshold: Probability above which a pixel is predicted as foreground.

    Returns:
        Tuple ``(avg_loss, accuracy)``. ``avg_loss`` is the mean of the per-batch
        losses. ``accuracy`` is the fraction of elements where the thresholded
        prediction equals the mask.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_pixels = 0
    num_batches = 0

    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            total_loss += criterion(outputs, masks).item()
            # Count batches while iterating, so plain iterables without __len__ also work.
            num_batches += 1

            preds = torch.sigmoid(outputs) > threshold
            # NOTE(review): exact equality assumes masks hold hard 0/1 values.
            # Interpolated (soft) mask values would never count as correct -- confirm.
            total_correct += (preds == masks).sum().item()
            total_pixels += preds.numel()

    if num_batches == 0:
        raise ValueError("evaluate_model: dataloader yielded no batches")

    avg_loss = total_loss / num_batches
    accuracy = total_correct / total_pixels
    return avg_loss, accuracy
|
|
num_epochs = 2

# Metric history: one entry is appended to each list per epoch.
train_loss_lst = []
train_accuracy_lst = []
test_loss_lst = []
test_accuracy_lst = []

for epoch in range(num_epochs):
    print(f"Epoch: {epoch+1}")
    model.train()
    epoch_loss = 0
    # Reset the pixel counters every epoch so the reported train accuracy covers
    # only this epoch rather than accumulating over all previous epochs.
    total_correct = 0
    total_pixels = 0

    for images, masks in train_dataloader:
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics only: detach so no autograd graph is recorded for sigmoid/compare.
        preds = torch.sigmoid(outputs.detach()) > 0.5
        total_correct += (preds == masks).sum().item()
        total_pixels += preds.numel()

        epoch_loss += loss.item()

    # Running metrics collected while the weights were being updated this epoch.
    train_accuracy = total_correct / total_pixels
    avg_train_loss = epoch_loss / len(train_dataloader)
    print(f"Train loss at {epoch+1} epoch: {avg_train_loss}")
    print(f"Train accuracy at {epoch+1} epoch: {train_accuracy}")

    test_loss, test_accuracy = evaluate_model(model, test_dataloader, criterion, device)
    print(f"Test loss at {epoch+1} epoch: {test_loss}")
    print(f"Test accuracy at {epoch+1} epoch: {test_accuracy}")

    train_loss_lst.append(avg_train_loss)
    test_loss_lst.append(test_loss)
    train_accuracy_lst.append(train_accuracy)
    test_accuracy_lst.append(test_accuracy)