|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torchvision import transforms |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
from datasets import load_dataset |
|
|
|
|
|
from PIL import Image |
|
|
import matplotlib.pyplot as plt |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
import math |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
# Side length, in pixels, that every image and mask is resized to.
IMG_SIZE = 256
_TARGET_SIZE = (IMG_SIZE, IMG_SIZE)

# Image pipeline: resize to a square, convert to a tensor, then normalise each
# of the three channels with mean 0.5 and std 0.5.
image_transform = transforms.Compose(
    [
        transforms.Resize(_TARGET_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Mask pipeline: nearest-neighbour resizing so that no new (interpolated)
# label values are introduced, then a plain PIL -> tensor conversion.
mask_transform = transforms.Compose(
    [
        transforms.Resize(
            _TARGET_SIZE,
            interpolation=transforms.InterpolationMode.NEAREST,
        ),
        transforms.PILToTensor(),
    ]
)
|
|
|
|
|
|
|
|
|
|
|
def visualize_mask(mask: Image.Image, img: Image.Image = None):
    """
    Visualizes the segmentation mask. If an image is provided, it overlays the mask on the image.

    Mask values 1, 2 and 3 are drawn in red, green and blue respectively; every
    other value stays black.

    :param mask: The segmentation mask to visualize. Expects a single-channel pillow image
    :param img: Optional image to overlay the mask on. Expects a pillow image
    :return: None, the figure is displayed with matplotlib
    :raises ValueError: if the mask is not two-dimensional
    """
    mask_np = np.array(mask)
    if mask_np.ndim != 2:
        raise ValueError(
            f'mask must be a single-channel (H, W) image, got array of shape {mask_np.shape}'
        )

    class_colors = {
        1: [255, 0, 0],
        2: [0, 255, 0],
        3: [0, 0, 255],
    }

    h, w = mask_np.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for class_id, color in class_colors.items():
        color_mask[mask_np == class_id] = color

    if img is not None:
        img = img.convert('RGBA')
        overlay = Image.fromarray(color_mask).convert('RGBA')
        if overlay.size != img.size:
            # Image.blend only accepts images of identical size; nearest-neighbour
            # resampling keeps the overlay colours unmixed.
            overlay = overlay.resize(img.size, Image.NEAREST)
        picture = Image.blend(img, overlay, alpha=0.5)
        title = 'Segmentation Overlay'
    else:
        picture = color_mask
        title = 'Colorized Segmentation Mask'

    # Both variants share the same figure setup.
    plt.figure(figsize=(6, 6))
    plt.imshow(picture)
    plt.title(title)
    plt.axis('off')
    plt.show()
|
|
|
|
|
def visualize_mask_tensor(img_tensor: torch.Tensor, mask_tensor: torch.Tensor, alpha: float = 0.5, title: str = "Image Mask"):
    """
    Overlays a mask tensor onto an image tensor and displays the result.

    Mask value 1 is drawn red and mask value 2 blue; every other value is drawn
    black. The colour layer covers the whole image, so unlabelled pixels are
    darkened according to ``alpha``.

    :param img_tensor: Image tensor of shape (C, H, W)
    :param mask_tensor: Mask tensor of shape (H, W)
    :param alpha: Transparency for overlay
    :param title: Title shown above the figure
    """
    # detach()/cpu() allow tensors that track gradients or live on an
    # accelerator; a plain .numpy() call raises for both.
    img_np = img_tensor.detach().cpu().permute(1, 2, 0).numpy()
    mask_np = mask_tensor.detach().cpu().numpy()

    # NOTE(review): a tensor produced by image_transform has values outside
    # [0, 1], which imshow clips - presumably callers should un-normalise
    # first; confirm.

    h, w = mask_np.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)

    color_mask[mask_np == 1] = [255, 0, 0]
    color_mask[mask_np == 2] = [0, 0, 255]

    plt.figure(figsize=(6, 6))
    plt.imshow(img_np)
    plt.imshow(color_mask, alpha=alpha)
    plt.axis('off')
    plt.title(title)
    plt.show()
|
|
|
|
|
|
|
|
# Names for the two-way 'class' target (cat or dog).
class_labels = ['cat', 'dog']

# Names for the 37-way 'category' (breed) target.
# NOTE(review): presumably position i is the name of integer label i in the
# dataset - confirm against the dataset card before indexing with predictions.
category_labels = [
    'Abyssinian',
    'american bulldog',
    'american pit bull terrier',
    'basset hound',
    'beagle',
    'Bengal',
    'Birman',
    'Bombay',
    'boxer',
    'British Shorthair',
    'chihuahua',
    'Egyptian Mau',
    'english cocker spaniel',
    'english setter',
    'german shorthaired',
    'great pyrenees',
    'havanese',
    'japanese chin',
    'keeshond',
    'leonberger',
    'Maine Coon',
    'miniature pinscher',
    'newfoundland',
    'Persian',
    'pomeranian',
    'pug',
    'Ragdoll',
    'Russian Blue',
    'saint bernard',
    'samoyed',
    'scottish terrier',
    'shiba inu',
    'Siamese',
    'Sphynx',
    'staffordshire bull terrier',
    'wheaten terrier',
    'yorkshire terrier',
]
|
|
|
|
|
|
|
|
|
|
|
def transform_class_example(example: dict, include_ambiguous: bool = False) -> dict:
    """
    Converts one raw dataset example into tensors, labelled with its 'class' value.

    :param example: an example from a hf dataset; the keys 'class', 'msk' and 'img' are read
    :param include_ambiguous: how mask value 3 (the ambiguous class) is treated.
                              If true it is mapped to 1, if false it is mapped to 0 (background).

    :return: a dictionary with the transformed image ('image'), the remapped mask ('mask')
             and the value of example['class'] ('classification')
    """
    label = example['class']
    raw_mask = mask_transform(example["msk"]).squeeze(0)

    # Net remapping, written against the untouched input values:
    #   0 -> 1,   1 -> 0,   2 -> 1,   3 -> 1 or 0 (see include_ambiguous)
    # NOTE(review): raw values 0 and 2 both end up as 1 - confirm this matches
    # the mask encoding of the dataset.
    ambiguous_target = 1 if include_ambiguous else 0
    remapped = raw_mask.clone()
    remapped[(raw_mask == 0) | (raw_mask == 2)] = 1
    remapped[raw_mask == 1] = 0
    remapped[raw_mask == 3] = ambiguous_target

    transformed_image = image_transform(example["img"])
    return {
        "image": transformed_image,
        "mask": remapped,
        "classification": label,
    }
|
|
|
|
|
def transform_category_example(example: dict, include_ambiguous: bool = False) -> dict:
    """
    Converts one raw dataset example into tensors, labelled with its 'category' value.

    :param example: an example from a hf dataset; the keys 'category', 'msk' and 'img' are read
    :param include_ambiguous: how mask value 3 (the ambiguous class) is treated.
                              If true it is mapped to 1, if false it is mapped to 0 (background).
    :return: a dictionary with the transformed image ('image'), the remapped mask ('mask')
             and the value of example['category'] ('classification')
    """
    label = example['category']
    raw_mask = mask_transform(example["msk"]).squeeze(0)

    # Net remapping, written against the untouched input values:
    #   0 -> 1,   1 -> 0,   2 -> 1,   3 -> 1 or 0 (see include_ambiguous)
    # NOTE(review): raw values 0 and 2 both end up as 1 - confirm this matches
    # the mask encoding of the dataset.
    becomes_one = (raw_mask == 0) | (raw_mask == 2)
    becomes_zero = raw_mask == 1
    if include_ambiguous:
        becomes_one = becomes_one | (raw_mask == 3)
    else:
        becomes_zero = becomes_zero | (raw_mask == 3)

    remapped = raw_mask.clone()
    remapped[becomes_one] = 1
    remapped[becomes_zero] = 0

    transformed_image = image_transform(example["img"])
    return {
        "image": transformed_image,
        "mask": remapped,
        "classification": label,
    }
|
|
|
|
|
def generate_dataset(split: str, classification: str = 'class'):
    """
    Loads one split of 'cvdl/oxford-pets' and converts every example to tensors.

    :param split: which split to generate. Can be train, test or valid
    :param classification: Can either be class or category. Class is either cat or dog, while category is the breed.
    :return: the dataset of the given split with the columns 'image', 'mask' and
             'classification', formatted as torch tensors.
    :raises ValueError: if split or classification is not one of the accepted values
    """
    # Validate both arguments up front, before anything is loaded.
    if split not in ('test', 'train', 'valid'):
        raise ValueError('Split must be either test train or valid')
    if classification not in ('class', 'category'):
        raise ValueError('Classification must be either class or category. Class is either cat or dog, while category is the breed.')

    raw_split = load_dataset('cvdl/oxford-pets', split=split)

    # Both label kinds drop the same raw columns; only the per-example
    # transform differs.
    per_example = (
        transform_class_example
        if classification == 'class'
        else transform_category_example
    )
    converted = raw_split.map(
        per_example,
        remove_columns=['bbox', 'img', 'msk', 'category', 'class'],
    )

    converted.set_format(type='torch', columns=['image', 'mask', 'classification'])
    return converted
|
|
|
|
|
# Cat-vs-dog ('class') variant of the train and test splits.
train_set = generate_dataset('train', classification='class')
test_set = generate_dataset('test', classification='class')

# NOTE(review): `mask` is not read anywhere in the visible part of the file.
mask = True
batch_size = 32

# Training batches are reshuffled every epoch so that they do not follow the
# storage order of the dataset; evaluation keeps the stored order.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size)
|
|
|
|
|
class LeNet(nn.Module):
    """
    LeNet-style CNN for 3-channel images.

    Three 5x5 convolutions (3 -> 6 -> 12 -> 24 channels), each followed by ReLU
    and 2x2 max-pooling, feed four fully connected layers
    (1024 -> 512 -> 128 -> num_classes). The output is raw logits.

    :param num_classes: number of output logits
    :param input_size: spatial size of the input images, either a single number
                       for square images or a (height, width) pair. The default
                       of 256 yields the original 24*28*28 flattened features.
    :raises ValueError: if the input is too small to pass the three conv/pool stages
    """

    def __init__(self, num_classes=2, input_size=256):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 6, 5, 1)
        self.conv2 = nn.Conv2d(6, 12, 5, 1)
        self.conv3 = nn.Conv2d(12, 24, 5, 1)

        # Accept either one number (square input) or a (height, width) pair.
        try:
            height, width = input_size
        except TypeError:
            height = width = input_size

        feat_h = self._size_after_conv_stack(height)
        feat_w = self._size_after_conv_stack(width)
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                f'input_size {input_size!r} is too small for three 5x5 conv + 2x2 pool stages'
            )

        self.lin1 = nn.Linear(24 * feat_h * feat_w, 1024)
        self.lin2 = nn.Linear(1024, 512)
        self.lin3 = nn.Linear(512, 128)
        self.lin4 = nn.Linear(128, num_classes)

    @staticmethod
    def _size_after_conv_stack(size):
        """Returns the spatial extent left after the three conv/pool stages."""
        for _ in range(3):
            # An unpadded 5x5 convolution removes 4 pixels, and the 2x2
            # max-pool then halves the extent, rounding down.
            size = (size - 4) // 2
        return size

    def forward(self, x):
        """
        Computes class logits for a batch of images.

        :param x: batch of images of shape (N, 3, H, W), where (H, W) matches input_size
        :return: logits of shape (N, num_classes)
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.max_pool2d(F.relu(conv(x)), (2, 2))
        x = torch.flatten(x, 1)
        for hidden in (self.lin1, self.lin2, self.lin3):
            x = F.relu(hidden(x))
        return self.lin4(x)
|
|
|
|
|
def train_model(model, train_loader, optimizer, criterion, epoch):
    """
    Trains the model for one pass over train_loader.

    :param model: the network to train; it is switched to training mode
    :param train_loader: iterable of batches, each a mapping with the keys
                         'image' and 'classification'
    :param optimizer: optimizer that updates the model's parameters
    :param criterion: loss called as criterion(output, target)
    :param epoch: zero-based epoch index, only used for the printed message
    :return: the loss averaged over the batches (0.0 if there were no batches)
    """
    model.train()
    train_loss = 0.0
    num_batches = 0
    # Hand the loader itself to tqdm: wrapping it in enumerate() hides its
    # length, so the bar could show neither a total nor an ETA.
    for batch in tqdm(train_loader):
        optimizer.zero_grad()

        output = model(batch['image'])
        loss = criterion(output, batch['classification'])
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        num_batches += 1

    # Average over the batches actually processed; an empty loader would
    # otherwise cause a division by zero.
    if num_batches:
        train_loss /= num_batches
    print('[Training set] Epoch: {:d}, Average loss: {:.4f}'.format(epoch+1, train_loss))

    return train_loss
|
|
|
|
|
def test_model(model, test_loader, epoch):
    """
    Measures the classification accuracy of the model on test_loader.

    :param model: the network to evaluate; it is switched to evaluation mode
    :param test_loader: loader yielding batches with the keys 'image' and
                        'classification'; its ``dataset`` attribute supplies
                        the number of examples
    :param epoch: zero-based epoch index, only used for the printed message
    :return: fraction of correctly classified examples (0.0 for an empty dataset)
    """
    model.eval()
    correct = 0

    with torch.no_grad():
        # Hand the loader itself to tqdm so the bar knows the total.
        for batch in tqdm(test_loader):
            output = model(batch['image'])
            # Index of the largest logit per example is the predicted label.
            predictions = output.argmax(dim=1)
            targets = batch['classification'].view_as(predictions)
            correct += predictions.eq(targets).sum().item()

    num_examples = len(test_loader.dataset)
    # Guard against an empty dataset instead of dividing by zero.
    test_acc = correct / num_examples if num_examples else 0.0
    print('[Test set] Epoch: {:d}, Accuracy: {:.2f}%\n'.format(
        epoch+1, 100. * test_acc))

    return test_acc
|
|
|
|
|
# Two-output LeNet trained with cross-entropy and Adam.
model = LeNet(num_classes=2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

best_acc = 0
for epoch in range(100):
    print("now training")
    train_model(model, train_loader, optimizer, criterion, epoch)

    print("now testing")
    acc = test_model(model, test_loader, epoch)

    # Fold this epoch's accuracy into the running best, then report it.
    best_acc = max(best_acc, acc)
    print(f"epoch {epoch+1}, best_acc {best_acc}")

    # NOTE(review): the source's indentation was lost, so it is unclear whether
    # this save ran once per epoch or only once after the loop. It is kept
    # inside the loop here (the file name embeds the epoch number) - confirm.
    torch.save(model.state_dict(), f'animals_nn_epoch_{epoch+1}.pth')
|
|
|