import torch.nn as nn


class ChessModel(nn.Module):
    """Small convolutional network mapping stacked board planes to class logits.

    Architecture: conv1 -> relu -> conv2 -> relu -> flatten -> fc1 -> relu -> fc2.

    Both convolutions use ``kernel_size=3, padding=1`` with the default stride
    of 1, so they keep the spatial size unchanged. The flattened feature width
    fed to ``fc1`` is therefore ``board_size * board_size * 128``.

    Args:
        num_classes: Number of output logits produced by ``fc2``.
        in_channels: Number of input planes expected by ``conv1``
            (default 13, the value this model was originally hard-coded to).
        board_size: Height and width of the square input (default 8, giving
            8 * 8 * 128 = 8192 features into ``fc1``).
    """

    def __init__(self, num_classes, *, in_channels=13, board_size=8):
        super().__init__()
        # conv1 -> relu -> conv2 -> relu -> flatten -> fc1 -> relu -> fc2
        # Submodule names and creation order are kept stable so existing
        # state_dict checkpoints still load and seeded initialisation is unchanged.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.flatten = nn.Flatten()
        # padding=1 preserves the board's spatial size through both convs,
        # so the flattened width is board_size * board_size * 128.
        self.fc1 = nn.Linear(board_size * board_size * 128, 256)
        self.fc2 = nn.Linear(256, num_classes)
        self.relu = nn.ReLU()

        # Initialize weights (biases keep PyTorch's default initialisation).
        nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
        nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')
        # NOTE(review): fc1 feeds a ReLU, so Kaiming init may suit it better
        # than Xavier; kept as-is to preserve existing training behaviour.
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        """Return raw (unnormalised) logits of shape ``(batch, num_classes)``.

        Args:
            x: Input of shape ``(batch, in_channels, board_size, board_size)``.
                Any other spatial size makes the flattened width mismatch
                ``fc1.in_features``.
        """
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)  # Output raw logits (no softmax applied)
        return x