eyupipler committed on
Commit
a0f945e
·
verified ·
1 Parent(s): 3899ece

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +6 -3
model.py CHANGED
@@ -7,21 +7,23 @@ class SimpleCNN(nn.Module):
7
  super(SimpleCNN, self).__init__()
8
  self.num_classes = num_classes
9
  self.model_type = model_type
10
- # Define convolutional and fc layers based on model_type
11
  if model_type == 'f':
12
- # Two pool layers: 448 -> 224 -> 112 -> 56 -> 28
13
  self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
14
  self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
15
  self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
16
  self.fc1 = nn.Linear(64 * 28 * 28, 256)
17
  self.dropout = nn.Dropout(0.5)
18
  elif model_type == 'c':
 
19
  self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
20
  self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
21
  self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
22
  self.fc1 = nn.Linear(128 * 28 * 28, 512)
23
  self.dropout = nn.Dropout(0.5)
24
  elif model_type == 'q':
 
25
  self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
26
  self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
27
  self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
@@ -51,6 +53,7 @@ class SimpleCNN(nn.Module):
51
  def load_model(version='c', device='cpu'):
52
  """
53
  Downloads and loads the SimpleCNN model for the specified version: 'f', 'c', or 'q'.
 
54
  """
55
  model_type = version.lower()
56
  filename = f"Vbai-2.1{model_type}.pt"
@@ -65,6 +68,6 @@ def load_model(version='c', device='cpu'):
65
  # Initialize and load model
66
  model = SimpleCNN(model_type=model_type, num_classes=6).to(device)
67
  state_dict = torch.load(weights_path, map_location=device)
68
- model.load_state_dict(state_dict)
69
  model.eval()
70
  return model
 
7
  super(SimpleCNN, self).__init__()
8
  self.num_classes = num_classes
9
  self.model_type = model_type
10
+ # Model architectures assume 224x224 input
11
  if model_type == 'f':
12
+ # After 3 pool layers: 224 -> 112 -> 56 -> 28
13
  self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
14
  self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
15
  self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
16
  self.fc1 = nn.Linear(64 * 28 * 28, 256)
17
  self.dropout = nn.Dropout(0.5)
18
  elif model_type == 'c':
19
+ # After 3 pool layers: 224 -> 112 -> 56 -> 28
20
  self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
21
  self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
22
  self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
23
  self.fc1 = nn.Linear(128 * 28 * 28, 512)
24
  self.dropout = nn.Dropout(0.5)
25
  elif model_type == 'q':
26
+ # After 4 pool layers: 224 -> 112 -> 56 -> 28 -> 14
27
  self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
28
  self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
29
  self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
 
53
  def load_model(version='c', device='cpu'):
54
  """
55
  Downloads and loads the SimpleCNN model for the specified version: 'f', 'c', or 'q'.
56
+ Input images must be resized to 224x224.
57
  """
58
  model_type = version.lower()
59
  filename = f"Vbai-2.1{model_type}.pt"
 
68
  # Initialize and load model
69
  model = SimpleCNN(model_type=model_type, num_classes=6).to(device)
70
  state_dict = torch.load(weights_path, map_location=device)
71
+ model.load_state_dict(state_dict, strict=False)
72
  model.eval()
73
  return model