import torch
import torch.nn as nn
import torch.nn.functional as F


class LargeNet(nn.Module):
    """Small CNN: two conv/ReLU/max-pool stages followed by two linear layers.

    The network takes 3-channel images and produces 7 raw scores (logits)
    per sample; no softmax is applied here.

    Both convolutions use 5x5 kernels without padding and each is followed
    by a 2x2 max-pool, so the flattened feature size of ``10 * 29 * 29``
    matches inputs whose spatial size reduces to 29x29 after both stages
    (e.g. 128x128: 128 -> 124 -> 62 -> 58 -> 29).
    """

    # Number of features per sample after the second pooling stage:
    # 10 channels of 29x29 activations. Shared by ``fc1`` and ``forward``
    # so the two cannot drift apart.
    _FLAT_FEATURES = 10 * 29 * 29

    def __init__(self):
        super().__init__()
        self.name = "large"
        self.conv1 = nn.Conv2d(3, 5, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(5, 10, 5)
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 32)
        self.fc2 = nn.Linear(32, 7)

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Image tensor with 3 channels, sized so that the feature map
                after both conv/pool stages is 10 x 29 x 29 per sample.

        Returns:
            Tensor of shape ``[batch_size, 7]`` holding the raw scores.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Collapse channel and spatial dims into one feature vector per sample.
        x = x.view(-1, self._FLAT_FEATURES)
        x = F.relu(self.fc1(x))
        # The former ``x.squeeze(1)`` was dropped: dim 1 always has size 7
        # here, so the squeeze never changed the result.
        return self.fc2(x)


def load_model(model_path, device='cpu'):
    """Load a ``LargeNet`` with trained weights and put it in eval mode.

    Args:
        model_path: Path (or file-like object) of a saved ``state_dict``.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The ``LargeNet`` instance on ``device``, switched to ``eval()`` mode.
    """
    model = LargeNet()
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from trusted sources; on torch
    # versions that support it, consider passing ``weights_only=True``.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model