| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class LargeNet(nn.Module):
    """Small CNN: two conv+pool stages followed by two fully connected layers.

    The network maps 3-channel images to 7 unnormalised output scores per
    sample (no softmax is applied here). Because both convolutions are 5x5
    without padding and each is followed by 2x2 max-pooling, the hard-coded
    ``fc1`` input size of ``10 * 29 * 29`` only matches images whose height
    and width are each in 128..131 (e.g. 3x128x128).
    """

    def __init__(self):
        super().__init__()
        self.name = "large"
        self.conv1 = nn.Conv2d(3, 5, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(5, 10, 5)
        self.fc1 = nn.Linear(10 * 29 * 29, 32)
        self.fc2 = nn.Linear(32, 7)

    def forward(self, x):
        """Compute output scores for a batch of images.

        Args:
            x: Image tensor ``[batch_size, 3, H, W]`` with H and W in 128..131.

        Returns:
            Tensor of shape ``[batch_size, 7]`` with raw (pre-softmax) scores.

        Raises:
            RuntimeError: if the convolutional features are not 10x29x29 per
                sample, i.e. the input spatial size is incompatible with fc1.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Validate the per-sample feature shape before flattening: a bare
        # view(-1, N) would otherwise silently regroup the features of a
        # mis-sized input into a different number of rows (wrong batch size).
        if x.shape[-3:] != (10, 29, 29):
            raise RuntimeError(
                f"LargeNet expects 10x29x29 features per sample (input images "
                f"of 128..131 pixels per side), got {tuple(x.shape[-3:])}"
            )
        x = x.view(-1, 10 * 29 * 29)
        x = F.relu(self.fc1(x))
        # Shape [batch_size, 7]; fc2 has 7 outputs, so there is no singleton
        # dimension to squeeze.
        return self.fc2(x)
def load_model(model_path, device='cpu'):
    """Build a LargeNet, fill it with saved weights, and prepare it for inference.

    Args:
        model_path: Location of the saved state dict, passed to ``torch.load``.
        device: Device used both as ``map_location`` when reading the weights
            and as the destination the model is moved to.

    Returns:
        The ``LargeNet`` instance on ``device``, switched to evaluation mode.
    """
    net = LargeNet()
    # NOTE(review): torch.load is pickle-based; only load checkpoints from
    # trusted sources (newer PyTorch versions support weights_only=True).
    weights = torch.load(model_path, map_location=device)
    net.load_state_dict(weights)
    # Module.to() and Module.eval() both act in place and return the module.
    return net.to(device).eval()