Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| from huggingface_hub import hf_hub_download | |
class SimpleCNN(nn.Module):
    """Small convolutional image classifier with three size variants.

    Every convolution is 3x3 / stride 1 / padding 1 and is followed by ReLU
    and a 2x2 max-pool, which halves height and width. The flattened feature
    map then goes through ``fc1`` -> ReLU -> dropout -> ``fc2``.

    Variants (``model_type``):
        ``'f'``: convs 3->16->32->64, fc1 width 256, dropout 0.5.
        ``'c'``: convs 3->32->64->128, fc1 width 512, dropout 0.5.
        ``'q'``: convs 3->64->128->256->512, fc1 width 1024, dropout 0.3.

    ``fc1``'s input size is hard-coded for a 28x28 map after three pools
    ('f', 'c') or a 14x14 map after four pools ('q'). Both correspond to
    224x224 inputs.

    Layer attribute names (``conv1``..``conv4``, ``fc1``, ``fc2``) become the
    ``state_dict`` keys. Keep them stable so saved checkpoints still load.

    Args:
        model_type: One of ``'f'``, ``'c'`` or ``'q'`` (case-sensitive).
        num_classes: Number of logits produced by ``fc2``.

    Raises:
        ValueError: If ``model_type`` is not one of the known variants.
    """

    def __init__(self, model_type='f', num_classes=6):
        super().__init__()
        self.num_classes = num_classes
        self.model_type = model_type
        if model_type == 'f':
            self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
            self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
            self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
            self.fc1 = nn.Linear(64 * 28 * 28, 256)  # 3 pools: 224 -> 28
            self.dropout = nn.Dropout(0.5)
        elif model_type == 'c':
            self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
            self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
            self.fc1 = nn.Linear(128 * 28 * 28, 512)  # 3 pools: 224 -> 28
            self.dropout = nn.Dropout(0.5)
        elif model_type == 'q':
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
            self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
            self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
            self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
            self.fc1 = nn.Linear(512 * 14 * 14, 1024)  # 4 pools: 224 -> 14
            self.dropout = nn.Dropout(0.3)
        else:
            raise ValueError(f"Unknown model type: {model_type}")
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc2 = nn.Linear(self.fc1.out_features, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` must be ``(batch, 3, H, W)``, with H and W pooling down to the
        map size ``fc1`` was built for (224x224 inputs).
        """
        convs = [self.conv1, self.conv2, self.conv3]
        if self.model_type == 'q':
            convs.append(self.conv4)
        for conv in convs:
            x = self.pool(self.relu(conv(x)))
        # flatten() reshapes (copying only when necessary). view() raises on
        # non-contiguous feature maps such as channels_last tensors.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)
def load_model(version='c', device='cpu', strict=False):
    """Download a pretrained Vbai-DPA 2.1 checkpoint and return it for inference.

    The file ``Vbai-2.1{version}.pt`` is fetched from the
    ``Neurazum/Vbai-DPA-2.1`` model repository with ``hf_hub_download``. It is
    loaded into a 6-class ``SimpleCNN`` of the matching variant, and the model
    is switched to eval mode.

    Args:
        version: Model variant ``'f'``, ``'c'`` or ``'q'`` (case-insensitive).
        device: Device the model and its weights are placed on.
        strict: Forwarded to ``load_state_dict``. The default ``False`` keeps
            the historical lenient loading. Missing or unexpected keys are now
            reported with a ``RuntimeWarning`` instead of being ignored. Pass
            ``True`` to raise on any mismatch.

    Returns:
        The ``SimpleCNN`` on ``device``, in eval mode.

    Raises:
        ValueError: From ``SimpleCNN`` for an unknown variant. This check only
            runs after the download attempt, which may fail first because the
            file does not exist.
    """
    import warnings

    model_type = version.lower()
    filename = f"Vbai-2.1{model_type}.pt"
    weights_path = hf_hub_download(
        repo_id="Neurazum/Vbai-DPA-2.1",
        filename=filename,
        repo_type="model",
    )
    model = SimpleCNN(model_type=model_type, num_classes=6).to(device)
    # NOTE(security): torch.load unpickles the downloaded file, which can run
    # arbitrary code. Consider weights_only=True, which needs torch >= 1.13
    # and is the default since torch 2.6.
    state_dict = torch.load(weights_path, map_location=device)
    incompatible = model.load_state_dict(state_dict, strict=strict)
    # In non-strict mode, mismatched keys are skipped. An incompatible
    # checkpoint would otherwise leave layers at random init unnoticed.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"Checkpoint {filename!r} does not match SimpleCNN({model_type!r}): "
            f"missing keys {list(incompatible.missing_keys)}, "
            f"unexpected keys {list(incompatible.unexpected_keys)}",
            RuntimeWarning,
            stacklevel=2,
        )
    model.eval()
    return model