eyupipler commited on
Commit
3899ece
·
verified ·
1 Parent(s): 4e34900

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +34 -36
model.py CHANGED
@@ -5,59 +5,57 @@ from huggingface_hub import hf_hub_download
5
  class SimpleCNN(nn.Module):
6
  def __init__(self, model_type='f', num_classes=6):
7
  super(SimpleCNN, self).__init__()
 
8
  self.model_type = model_type
9
- # Define conv layers based on model_type
10
  if model_type == 'f':
11
- channels = [3, 16, 32, 64]
12
- dropout_p = 0.5
13
- fc_hidden = 256
 
 
 
14
  elif model_type == 'c':
15
- channels = [3, 32, 64, 128]
16
- dropout_p = 0.5
17
- fc_hidden = 512
 
 
18
  elif model_type == 'q':
19
- channels = [3, 64, 128, 256, 512]
20
- dropout_p = 0.3
21
- fc_hidden = 1024
 
 
 
22
  else:
23
  raise ValueError(f"Unknown model type: {model_type}")
24
 
25
- # Build conv blocks
26
- layers = []
27
- for in_c, out_c in zip(channels[:-1], channels[1:]):
28
- layers.append(nn.Conv2d(in_c, out_c, kernel_size=3, padding=1))
29
- layers.append(nn.ReLU())
30
- layers.append(nn.MaxPool2d(2))
31
- self.features = nn.Sequential(*layers)
32
- self.dropout = nn.Dropout(dropout_p)
33
-
34
- # Dynamically compute flattened size
35
- with torch.no_grad():
36
- dummy = torch.zeros(1, 3, 448, 448)
37
- feat = self.features(dummy)
38
- flattened_size = feat.view(1, -1).size(1)
39
-
40
- # Fully connected layers
41
- self.fc1 = nn.Linear(flattened_size, fc_hidden)
42
- self.fc2 = nn.Linear(fc_hidden, num_classes)
43
 
44
  def forward(self, x):
45
- x = self.features(x)
 
 
 
 
46
  x = x.view(x.size(0), -1)
47
- x = self.dropout(torch.relu(self.fc1(x)))
 
48
  x = self.fc2(x)
49
  return x
50
 
51
 
52
  def load_model(version='c', device='cpu'):
53
  """
54
- Loads the correct model based on version: 'f', 'c', or 'q'.
55
  """
56
- # Determine filename and model_type
57
  model_type = version.lower()
58
  filename = f"Vbai-2.1{model_type}.pt"
59
 
60
- # Download weights
61
  weights_path = hf_hub_download(
62
  repo_id="Neurazum/Vbai-DPA-2.1",
63
  filename=filename,
@@ -66,7 +64,7 @@ def load_model(version='c', device='cpu'):
66
 
67
  # Initialize and load model
68
  model = SimpleCNN(model_type=model_type, num_classes=6).to(device)
69
- state = torch.load(weights_path, map_location=device)
70
- model.load_state_dict(state)
71
  model.eval()
72
- return model
 
5
  class SimpleCNN(nn.Module):
6
  def __init__(self, model_type='f', num_classes=6):
7
  super(SimpleCNN, self).__init__()
8
+ self.num_classes = num_classes
9
  self.model_type = model_type
10
+ # Define convolutional and fc layers based on model_type
11
  if model_type == 'f':
12
+ # Two pool layers: 448 -> 224 -> 112 -> 56 -> 28
13
+ self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
14
+ self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
15
+ self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
16
+ self.fc1 = nn.Linear(64 * 28 * 28, 256)
17
+ self.dropout = nn.Dropout(0.5)
18
  elif model_type == 'c':
19
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
20
+ self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
21
+ self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
22
+ self.fc1 = nn.Linear(128 * 28 * 28, 512)
23
+ self.dropout = nn.Dropout(0.5)
24
  elif model_type == 'q':
25
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
26
+ self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
27
+ self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
28
+ self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
29
+ self.fc1 = nn.Linear(512 * 14 * 14, 1024)
30
+ self.dropout = nn.Dropout(0.3)
31
  else:
32
  raise ValueError(f"Unknown model type: {model_type}")
33
 
34
+ self.relu = nn.ReLU()
35
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
36
+ self.fc2 = nn.Linear(self.fc1.out_features, num_classes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def forward(self, x):
39
+ x = self.pool(self.relu(self.conv1(x)))
40
+ x = self.pool(self.relu(self.conv2(x)))
41
+ x = self.pool(self.relu(self.conv3(x)))
42
+ if self.model_type == 'q':
43
+ x = self.pool(self.relu(self.conv4(x)))
44
  x = x.view(x.size(0), -1)
45
+ x = self.relu(self.fc1(x))
46
+ x = self.dropout(x)
47
  x = self.fc2(x)
48
  return x
49
 
50
 
51
  def load_model(version='c', device='cpu'):
52
  """
53
+ Downloads and loads the SimpleCNN model for the specified version: 'f', 'c', or 'q'.
54
  """
 
55
  model_type = version.lower()
56
  filename = f"Vbai-2.1{model_type}.pt"
57
 
58
+ # Download the weight file from Hugging Face Hub
59
  weights_path = hf_hub_download(
60
  repo_id="Neurazum/Vbai-DPA-2.1",
61
  filename=filename,
 
64
 
65
  # Initialize and load model
66
  model = SimpleCNN(model_type=model_type, num_classes=6).to(device)
67
+ state_dict = torch.load(weights_path, map_location=device)
68
+ model.load_state_dict(state_dict)
69
  model.eval()
70
+ return model