import torch
import torch.nn as nn
import torch.nn.functional as F


# This class MUST match the architecture of the model you saved in the .pth file.
# For this example, we assume 3 output classes (e.g., cat, dog, bird)
# and input images of size 3x224x224 (3 channels, 224x224 pixels).
class SimpleCNN(nn.Module):
    """Two conv/ReLU/max-pool stages followed by a two-layer fully connected head.

    The submodule attribute names (``conv1``, ``pool1``, ``conv2``, ``fc1``,
    ``fc2``, ...) become the ``state_dict`` keys, so they must stay unchanged
    for previously saved ``.pth`` checkpoints to load.

    Args:
        num_classes: Number of output logits produced by ``fc2``.
        image_size: Height and width of the square input images. The conv
            layers (kernel 3, padding 1) preserve the spatial size, and each
            2x2 / stride-2 max-pool floor-halves it, so the flattened feature
            size is ``32 * (image_size // 4) ** 2``. The default of 224 gives
            ``32 * 56 * 56``, matching the original architecture.

    Raises:
        ValueError: If ``image_size`` is smaller than 4, which would leave no
            spatial positions after the two pooling stages.
    """

    def __init__(self, num_classes: int = 3, image_size: int = 224) -> None:
        super().__init__()
        if image_size < 4:
            raise ValueError(f"image_size must be at least 4, got {image_size}")

        # Conv Layer 1: 3 -> 16 channels, spatial size preserved, then halved.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Conv Layer 2: 16 -> 32 channels, spatial size preserved, then halved.
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Two floor-halvings: e.g. 224 -> 112 -> 56.
        feature_side = image_size // 4
        self.fc1 = nn.Linear(32 * feature_side * feature_side, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` is expected to be ``(batch, 3, image_size, image_size)``.
        """
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        # Flatten every dimension except the batch axis. The previous
        # ``view(-1, 32 * 56 * 56)`` put -1 on the batch axis, so a wrongly
        # sized input could be silently split into extra "samples". Here a
        # size mismatch surfaces as a shape error in fc1 instead, and
        # non-contiguous inputs are handled too.
        x = torch.flatten(x, 1)
        x = self.relu3(self.fc1(x))
        return self.fc2(x)