# landmarkclassifier / src/model.py
# Commit 4b7c478 ("real initial commit.") by mewhenmonkeyavatar (2.69 kB)
import torch
import torch.nn as nn
# Define the CNN architecture.
# MY CODE HERE: residual convolutional blocks (ResidualBlock) feeding an MLP classifier head (MyModel).
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block.

    Main path: 3x3 conv (stride ``stride``) -> BatchNorm -> ReLU -> 3x3 conv
    -> BatchNorm.  The shortcut is the identity when neither the stride nor
    the channel count changes; otherwise it is a strided 1x1 conv followed by
    BatchNorm, so both branches have the same shape and can be summed.  The
    sum is passed through a final ReLU.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        stride: stride of the first 3x3 conv and of the projection shortcut.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        # Keep this creation order: it fixes the order in which parameters are
        # randomly initialised and the layout of the state_dict.
        # Convs are bias-free because each one is immediately followed by BatchNorm.
        self.conv1 = nn.Conv2d(
            in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_ch)

        # A projection is required whenever the main path changes the spatial
        # size (stride) or the number of channels; otherwise add x unchanged.
        changes_shape = stride != 1 or in_ch != out_ch
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
            if changes_shape
            else nn.Identity()
        )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``.

        Assumes ``x`` is an NCHW image batch with ``in_ch`` channels — confirm with callers.
        """
        hidden = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        residual += self.shortcut(x)
        return self.relu(residual)
class MyModel(nn.Module):
    """CNN classifier: residual convolutional backbone followed by an MLP head.

    Backbone: four ``ResidualBlock`` stages widening 3 -> 64 -> 128 -> 256 -> 512
    channels with strides 1, 2, 2, 2, then adaptive average pooling to 1x1.
    Head: flatten, two ``Linear -> BatchNorm1d -> ReLU -> Dropout`` stages
    (512 -> 512 -> 192), and a final ``Linear`` to ``num_classes``.  ``forward``
    returns the raw output of that last Linear layer (no softmax is applied).

    Args:
        num_classes: size of the output layer.
        dropout: drop probability shared by both Dropout layers in the head.

    NOTE: ``BatchNorm1d`` in the head cannot normalise a batch containing a
    single sample while in training mode, so train with batch size > 1.
    """

    def __init__(self, num_classes: int = 1000, dropout: float = 0.7) -> None:
        super().__init__()

        # (in_channels, out_channels, stride) for each residual stage, in order.
        stage_specs = ((3, 64, 1), (64, 128, 2), (128, 256, 2), (256, 512, 2))
        stages = [ResidualBlock(c_in, c_out, s) for c_in, c_out, s in stage_specs]
        # The capitalised attribute names are baked into state_dict keys
        # (e.g. "Backbone.0.conv1.weight"); renaming them would break checkpoints.
        self.Backbone = nn.Sequential(*stages, nn.AdaptiveAvgPool2d((1, 1)))

        def hidden_stage(n_in, n_out):
            # One fully connected stage of the head, as a flat list of modules
            # so the Sequential indices stay contiguous.
            return [
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ]

        self.Classifier = nn.Sequential(
            nn.Flatten(),
            *hidden_stage(512, 512),
            *hidden_stage(512, 192),
            nn.Linear(192, num_classes),
        )

    def forward(self, x):
        """Map an image batch to ``(batch, num_classes)`` scores.

        Assumes ``x`` is NCHW with 3 channels (the first stage expects 3 input
        channels) — confirm with callers.
        """
        features = self.Backbone(x)
        return self.Classifier(features)
######################################################################################
# TESTS
######################################################################################
import pytest
@pytest.fixture(scope="session")
def data_loaders():
    """Session-scoped fixture returning the project's data loaders (batch size 2).

    The import is done inside the fixture, so importing this module does not by
    itself require the sibling ``data`` module.
    """
    from .data import get_data_loaders as build_loaders

    loaders = build_loaders(batch_size=2)
    return loaders
def test_model_construction(data_loaders):
    """Smoke test: MyModel maps one training batch to a (batch, n_classes) tensor."""
    n_classes = 23
    model = MyModel(num_classes=n_classes, dropout=0.3)

    # Only the first batch is needed; the fixture builds loaders with batch_size=2.
    images, _labels = next(iter(data_loaders["train"]))
    out = model(images)

    assert isinstance(out, torch.Tensor), (
        "The output of the .forward method should be a Tensor of size ([batch_size], [n_classes])"
    )
    assert out.shape == torch.Size([2, n_classes]), (
        f"Expected an output tensor of size (2, 23), got {out.shape}"
    )