Spaces:
Sleeping
Sleeping
| import torch.nn as nn | |
class ChessModel(nn.Module):
    """Small CNN mapping a stack of square board planes to ``num_classes`` raw logits.

    Architecture: conv1 -> ReLU -> conv2 -> ReLU -> flatten -> fc1 -> ReLU -> fc2.
    Both convolutions use ``kernel_size=3, padding=1``, which preserves the
    spatial size. That is why ``fc1`` receives
    ``board_size * board_size * conv2.out_channels`` features.

    Args:
        num_classes: Number of output logits (presumably the size of the move
            vocabulary — NOTE(review): confirm against the caller).
        in_channels: Number of input feature planes. Defaults to 13, the value
            this model was originally hard-coded for. What each plane encodes
            is defined by the caller's board encoding, not by this class.
        board_size: Height and width of the square input. Defaults to 8.

    Input:
        ``x`` of shape ``(batch, in_channels, board_size, board_size)``.

    Output:
        Tensor of shape ``(batch, num_classes)`` containing raw logits; no
        softmax is applied.
    """

    def __init__(self, num_classes: int, *, in_channels: int = 13, board_size: int = 8):
        super().__init__()
        # Modules are created in the same order as the original hard-coded
        # version, so parameter initialisation under a fixed RNG seed is
        # unchanged.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.flatten = nn.Flatten()
        # kernel_size=3 with padding=1 keeps the board at board_size x board_size,
        # so the flattened feature count is board_size**2 * conv2's out-channels
        # (8 * 8 * 128 = 8192 with the defaults).
        self.fc1 = nn.Linear(board_size * board_size * self.conv2.out_channels, 256)
        self.fc2 = nn.Linear(256, num_classes)
        # ReLU has no parameters or state, so a single instance is reused
        # after every layer.
        self.relu = nn.ReLU()

        # Explicit weight initialisation. Biases are left at PyTorch's default
        # layer initialisation.
        nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
        nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        """Return raw logits of shape ``(batch, num_classes)`` for board tensor ``x``."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        # No softmax here: callers receive unnormalised logits.
        return self.fc2(x)