eyupipler commited on
Commit
5fedaa6
·
verified ·
1 Parent(s): aaa9386

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +0 -10
model.py CHANGED
@@ -7,23 +7,19 @@ class SimpleCNN(nn.Module):
7
  super(SimpleCNN, self).__init__()
8
  self.num_classes = num_classes
9
  self.model_type = model_type
10
- # Model architectures assume 224x224 input
11
  if model_type == 'f':
12
- # After 3 pool layers: 224 -> 112 -> 56 -> 28
13
  self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
14
  self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
15
  self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
16
  self.fc1 = nn.Linear(64 * 28 * 28, 256)
17
  self.dropout = nn.Dropout(0.5)
18
  elif model_type == 'c':
19
- # After 3 pool layers: 224 -> 112 -> 56 -> 28
20
  self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
21
  self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
22
  self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
23
  self.fc1 = nn.Linear(128 * 28 * 28, 512)
24
  self.dropout = nn.Dropout(0.5)
25
  elif model_type == 'q':
26
- # After 4 pool layers: 224 -> 112 -> 56 -> 28 -> 14
27
  self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
28
  self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
29
  self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
@@ -51,21 +47,15 @@ class SimpleCNN(nn.Module):
51
 
52
 
53
  def load_model(version='c', device='cpu'):
54
- """
55
- Downloads and loads the SimpleCNN model for the specified version: 'f', 'c', or 'q'.
56
- Input images must be resized to 224x224.
57
- """
58
  model_type = version.lower()
59
  filename = f"Vbai-2.1{model_type}.pt"
60
 
61
- # Download the weight file from Hugging Face Hub
62
  weights_path = hf_hub_download(
63
  repo_id="Neurazum/Vbai-DPA-2.1",
64
  filename=filename,
65
  repo_type="model"
66
  )
67
 
68
- # Initialize and load model
69
  model = SimpleCNN(model_type=model_type, num_classes=6).to(device)
70
  state_dict = torch.load(weights_path, map_location=device)
71
  model.load_state_dict(state_dict, strict=False)
 
7
  super(SimpleCNN, self).__init__()
8
  self.num_classes = num_classes
9
  self.model_type = model_type
 
10
  if model_type == 'f':
 
11
  self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
12
  self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
13
  self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
14
  self.fc1 = nn.Linear(64 * 28 * 28, 256)
15
  self.dropout = nn.Dropout(0.5)
16
  elif model_type == 'c':
 
17
  self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
18
  self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
19
  self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
20
  self.fc1 = nn.Linear(128 * 28 * 28, 512)
21
  self.dropout = nn.Dropout(0.5)
22
  elif model_type == 'q':
 
23
  self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
24
  self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
25
  self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
 
47
 
48
 
49
  def load_model(version='c', device='cpu'):
 
 
 
 
50
  model_type = version.lower()
51
  filename = f"Vbai-2.1{model_type}.pt"
52
 
 
53
  weights_path = hf_hub_download(
54
  repo_id="Neurazum/Vbai-DPA-2.1",
55
  filename=filename,
56
  repo_type="model"
57
  )
58
 
 
59
  model = SimpleCNN(model_type=model_type, num_classes=6).to(device)
60
  state_dict = torch.load(weights_path, map_location=device)
61
  model.load_state_dict(state_dict, strict=False)