eyupipler committed on
Commit
6498090
·
verified ·
1 Parent(s): 4d48c23

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +36 -23
model.py CHANGED
@@ -1,47 +1,60 @@
1
- from huggingface_hub import hf_hub_download
2
  import torch
3
  import torch.nn as nn
 
4
 
5
  class SimpleCNN(nn.Module):
6
- def __init__(self, num_classes=6):
7
  super(SimpleCNN, self).__init__()
8
- self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
9
- self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
10
- self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
11
- self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  self.relu = nn.ReLU()
13
- self.dropout = nn.Dropout(0.5)
14
- self._initialize_fc(num_classes)
15
-
16
- def _initialize_fc(self, num_classes):
17
- dummy_input = torch.zeros(1, 3, 448, 448)
18
- x = self.pool(self.relu(self.conv1(dummy_input)))
19
- x = self.pool(self.relu(self.conv2(x)))
20
- x = self.pool(self.relu(self.conv3(x)))
21
- flattened_size = x.view(x.size(0), -1).shape[1]
22
- self.fc1 = nn.Linear(flattened_size, 512)
23
- self.fc2 = nn.Linear(512, num_classes)
24
 
25
  def forward(self, x):
26
  x = self.pool(self.relu(self.conv1(x)))
27
  x = self.pool(self.relu(self.conv2(x)))
28
  x = self.pool(self.relu(self.conv3(x)))
 
 
29
  x = x.view(x.size(0), -1)
30
- x = self.dropout(self.relu(self.fc1(x)))
 
31
  x = self.fc2(x)
32
  return x
33
 
34
- def load_model(model_name: str, device: str = 'cpu'):
35
- version = model_name.split()[-1].lower()
36
- filename = f"Vbai-DPA-{version}.pt"
37
-
 
 
 
38
  weights_path = hf_hub_download(
39
  repo_id="Neurazum/Vbai-DPA-2.1",
40
  filename=filename,
41
  repo_type="model"
42
  )
43
 
44
- model = SimpleCNN(num_classes=6).to(device)
45
  state = torch.load(weights_path, map_location=device)
46
  model.load_state_dict(state)
47
  model.eval()
 
 
1
  import torch
2
  import torch.nn as nn
3
+ from huggingface_hub import hf_hub_download
4
 
5
class SimpleCNN(nn.Module):
    """
    Small CNN classifier with three capacity variants.

    ``model_type`` selects the architecture:
      * ``'f'`` — 3 conv blocks (16/32/64 channels), fc1 -> 256, dropout 0.5
      * ``'c'`` — 3 conv blocks (32/64/128 channels), fc1 -> 512, dropout 0.5
      * ``'q'`` — 4 conv blocks (64/128/256/512 channels), fc1 -> 1024, dropout 0.3

    NOTE(review): the hard-coded fc1 input sizes (``C * 28 * 28`` after three
    2x2 pools, ``512 * 14 * 14`` after four) imply the network expects
    224x224 RGB inputs — confirm against the preprocessing pipeline.
    """

    def __init__(self, model_type='f', num_classes=6):
        """
        Build the layer stack for the requested variant.

        Parameters
        ----------
        model_type : str
            One of ``'f'``, ``'c'``, ``'q'``. Raises ``ValueError`` otherwise.
        num_classes : int
            Size of the final classification layer.
        """
        super(SimpleCNN, self).__init__()
        self.num_classes = num_classes
        if model_type == 'f':
            self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
            self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
            self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
            self.fc1 = nn.Linear(64 * 28 * 28, 256)
            self.dropout = nn.Dropout(0.5)
        elif model_type == 'c':
            self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
            self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
            self.fc1 = nn.Linear(128 * 28 * 28, 512)
            self.dropout = nn.Dropout(0.5)
        elif model_type == 'q':
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
            self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
            self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
            self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
            self.fc1 = nn.Linear(512 * 14 * 14, 1024)
            self.dropout = nn.Dropout(0.3)
        else:
            # Fail fast: without this, an unknown model_type left conv1/fc1
            # undefined and the next line crashed with a confusing
            # AttributeError instead of a clear message.
            raise ValueError(
                f"Unknown model_type {model_type!r}; expected 'f', 'c' or 'q'"
            )
        self.fc2 = nn.Linear(self.fc1.out_features, num_classes)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        """Run the conv stack, flatten, and classify. Returns (N, num_classes) logits."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        x = self.pool(self.relu(self.conv3(x)))
        # Only the 'q' variant defines a fourth conv block.
        if hasattr(self, 'conv4'):
            x = self.pool(self.relu(self.conv4(x)))
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
43
 
44
+
45
def load_model(version='c', device='cpu'):
    """
    Download the pretrained Vbai-DPA 2.1 weights and build the matching model.

    Parameters
    ----------
    version : str
        Model variant: 'f', 'c', or 'q'. Selects both the checkpoint file
        (``Vbai-2.1{version}.pt``) and the SimpleCNN architecture.
    device : str
        Torch device string, e.g. 'cpu' or 'cuda'.

    Returns
    -------
    SimpleCNN
        The loaded model, moved to *device* and set to eval mode.

    Raises
    ------
    ValueError
        If *version* is not one of 'f', 'c', 'q'.
    """
    # Validate early so we fail before any network download.
    if version not in ('f', 'c', 'q'):
        raise ValueError(f"Unknown version {version!r}; expected 'f', 'c' or 'q'")

    filename = f"Vbai-2.1{version}.pt"

    weights_path = hf_hub_download(
        repo_id="Neurazum/Vbai-DPA-2.1",
        filename=filename,
        repo_type="model"
    )

    model = SimpleCNN(model_type=version, num_classes=6).to(device)
    # NOTE(review): torch.load unpickles a downloaded file; for a plain
    # state_dict consider torch.load(..., weights_only=True) on torch >= 1.13.
    state = torch.load(weights_path, map_location=device)
    model.load_state_dict(state)
    model.eval()
    # Bug fix: the function previously ended at model.eval() and implicitly
    # returned None, so callers never received the loaded model.
    return model