Spaces:
Running
Running
File size: 2,563 Bytes
4d48c23 6498090 4d48c23 6498090 4d48c23 3899ece 4e34900 6498090 3899ece 6498090 3899ece 6498090 3899ece 4e34900 3899ece 4d48c23 3899ece 4d48c23 3899ece 4d48c23 6498090 4e34900 6498090 4d48c23 4e34900 3899ece a0f945e 4d48c23 3899ece |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import warnings

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
class SimpleCNN(nn.Module):
    """Small convolutional classifier available in three size presets.

    Every preset is a stack of 3x3, stride-1, padding-1 convolutions, each
    followed by ReLU and a 2x2/stride-2 max-pool. A fully connected head
    (``fc1`` -> ReLU -> dropout -> ``fc2``) follows the stack.

    The flattened size baked into ``fc1`` expects a 28x28 feature map after
    three pools (presets 'f', 'c') or 14x14 after four pools (preset 'q').
    That implies an input of about 224x224. NOTE(review): confirm the exact
    input resolution against the training/preprocessing pipeline.

    Args:
        model_type: Preset key, one of 'f', 'c', 'q' (matched exactly,
            case-sensitive).
        num_classes: Number of logits produced by ``fc2``.

    Raises:
        ValueError: if ``model_type`` is not one of the known presets.
    """

    # (key, conv channel progression, spatial side fed to fc1, fc1 width, dropout p)
    _PRESETS = (
        ('f', (3, 16, 32, 64), 28, 256, 0.5),
        ('c', (3, 32, 64, 128), 28, 512, 0.5),
        ('q', (3, 64, 128, 256, 512), 14, 1024, 0.3),
    )

    def __init__(self, model_type='f', num_classes=6):
        super().__init__()
        self.num_classes = num_classes
        self.model_type = model_type

        spec = None
        for key, channels, side, hidden, drop in self._PRESETS:
            if model_type == key:
                spec = (channels, side, hidden, drop)
                break
        if spec is None:
            raise ValueError(f"Unknown model type: {model_type}")
        channels, side, hidden, drop = spec

        # Register conv1..convN first, then the head, so submodule order
        # (and therefore state_dict key order and init RNG consumption)
        # follows the declaration order conv*, fc1, dropout, relu, pool, fc2.
        for idx, (c_in, c_out) in enumerate(zip(channels, channels[1:]), start=1):
            setattr(
                self,
                f"conv{idx}",
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
            )
        self.fc1 = nn.Linear(channels[-1] * side * side, hidden)
        self.dropout = nn.Dropout(drop)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc2 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)``."""
        stages = [self.conv1, self.conv2, self.conv3]
        if self.model_type == 'q':
            # Only the 'q' preset has a fourth conv/pool stage.
            stages.append(self.conv4)
        for conv in stages:
            x = self.pool(self.relu(conv(x)))
        flat = x.view(x.size(0), -1)
        hidden = self.dropout(self.relu(self.fc1(flat)))
        return self.fc2(hidden)
def load_model(version='c', device='cpu'):
    """Download a pretrained Vbai-DPA 2.1 checkpoint and return it ready for inference.

    The file ``Vbai-2.1{version}.pt`` is fetched from the Hugging Face Hub
    model repository ``Neurazum/Vbai-DPA-2.1`` via ``hf_hub_download``. Its
    weights are loaded into a ``SimpleCNN`` of the matching preset with
    6 output classes. The model is moved to ``device`` and switched to eval
    mode.

    Args:
        version: Preset letter ('f', 'c' or 'q'). It is lower-cased before
            use, so 'C' and 'c' are equivalent.
        device: Device the model is moved to and the checkpoint tensors are
            mapped onto.

    Returns:
        The ``SimpleCNN`` instance in eval mode.

    Raises:
        ValueError: from ``SimpleCNN`` for an unknown preset. NOTE(review):
            an unknown version presumably fails earlier, inside
            ``hf_hub_download``, because no matching file exists in the repo.

    Warns:
        RuntimeWarning: if the checkpoint has missing or unexpected keys
            relative to the model. Those layers would otherwise silently
            keep their random initialisation.
    """
    model_type = version.lower()
    filename = f"Vbai-2.1{model_type}.pt"
    weights_path = hf_hub_download(
        repo_id="Neurazum/Vbai-DPA-2.1",
        filename=filename,
        repo_type="model",
    )
    model = SimpleCNN(model_type=model_type, num_classes=6).to(device)
    # SECURITY NOTE(review): torch.load unpickles the downloaded file. On
    # torch versions whose default is weights_only=False, a malicious
    # checkpoint could execute code. Consider passing weights_only=True once
    # the checkpoint is confirmed to be a plain state_dict.
    state_dict = torch.load(weights_path, map_location=device)
    # strict=False tolerates missing/unexpected keys (shape mismatches still
    # raise). Report them instead of discarding the result, so a mismatched
    # checkpoint does not silently yield a partly random model.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"{filename}: checkpoint does not fully match SimpleCNN('{model_type}'): "
            f"missing keys={list(incompatible.missing_keys)}, "
            f"unexpected keys={list(incompatible.unexpected_keys)}",
            RuntimeWarning,
            stacklevel=2,
        )
    model.eval()
    return model