# air/model.py
# Author: grfdjiwsd — commit 9fae7c3 ("Create model.py", verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
# This class MUST match the architecture of the model saved in the .pth file.
# By default this example assumes 3 output classes (e.g. cat, dog, bird)
# and 3x224x224 input images (3 channels, 224x224 pixels).
class SimpleCNN(nn.Module):
    """Small two-block convolutional image classifier.

    Architecture:
        [Conv3x3(pad=1) -> ReLU -> MaxPool2x2] x 2 -> Linear(128) -> ReLU
        -> Linear(num_classes)

    The attribute names (``conv1``, ``conv2``, ``fc1``, ``fc2``, ...) determine
    the ``state_dict`` keys, so they must stay stable for saved ``.pth``
    checkpoints to keep loading.

    Args:
        num_classes: Number of output logits produced by ``fc2``.
        in_channels: Number of channels in the input image (3 for RGB).
        image_size: Height and width of the square input image. Together with
            the two 2x2 poolings it fixes the input width of ``fc1``, so it
            must equal the value the checkpoint was trained with.

    Raises:
        ValueError: If ``image_size`` is too small to survive two poolings.
    """

    def __init__(self, num_classes=3, in_channels=3, image_size=224):
        super().__init__()
        if image_size < 4:
            raise ValueError(
                f"image_size must be at least 4 to survive two 2x2 poolings, "
                f"got {image_size}"
            )

        # Conv block 1: in_channels -> 16 feature maps. padding=1 keeps H/W
        # unchanged; the pool then halves them.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels, out_channels=16, kernel_size=3, padding=1
        )
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Conv block 2: 16 -> 32 feature maps, spatial size halved again.
        self.conv2 = nn.Conv2d(
            in_channels=16, out_channels=32, kernel_size=3, padding=1
        )
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Each 2x2/stride-2 pool floors the spatial size by half,
        # e.g. 224 -> 112 -> 56 for the default input size.
        pooled_size = image_size // 2 // 2
        flat_features = 32 * pooled_size * pooled_size

        # Classifier head.
        self.fc1 = nn.Linear(flat_features, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute class logits.

        Args:
            x: Image batch shaped ``(N, in_channels, image_size, image_size)``.
                A single unbatched ``(in_channels, image_size, image_size)``
                image is also accepted and treated as a batch of one.

        Returns:
            Raw (un-normalized) logits shaped ``(N, num_classes)``; no softmax
            is applied.

        Raises:
            RuntimeError: If the image size does not match the ``image_size``
                the model was built for (``fc1`` rejects the flattened width).
        """
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))

        # Flatten with an explicit batch dimension. Using -1 for the batch
        # would silently fold a wrongly-sized image into extra "batch" rows;
        # this way fc1 raises instead. reshape (unlike view) also copes with
        # non-contiguous activations such as channels_last tensors.
        batch_size = x.shape[0] if x.dim() == 4 else 1
        x = x.reshape(batch_size, -1)

        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x