# Chess_with_AI: src/chess_ML/model.py
# Commit 53b9b08 ("initial deploy"), uploaded by ncn2569.
import torch.nn as nn
class ChessModel(nn.Module):
    """Small CNN mapping a stacked-plane board encoding to raw class logits.

    Architecture: conv1 -> relu -> conv2 -> relu -> flatten -> fc1 -> relu -> fc2.
    Both convolutions use ``kernel_size=3, padding=1``, so the spatial size of
    the board is preserved until the flatten step.

    Args:
        num_classes: Number of output logits. Presumably the size of the move
            vocabulary used by the training code (TODO confirm with caller).
        in_channels: Number of input feature planes. The default of 13 matches
            the previously hard-coded value. What each plane encodes is decided
            by the caller's board encoder, not here.
        board_size: Side length of the square board. The default of 8 matches
            the previously hard-coded ``8 * 8 * 128`` flatten size.

    Submodule names (``conv1``, ``conv2``, ``fc1``, ``fc2``, ``relu``) are part
    of the ``state_dict`` key layout. Keep them stable so saved checkpoints
    keep loading.
    """

    def __init__(self, num_classes, *, in_channels=13, board_size=8):
        super().__init__()
        conv1_channels = 64
        conv2_channels = 128
        hidden_units = 256

        self.conv1 = nn.Conv2d(in_channels, conv1_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(conv1_channels, conv2_channels, kernel_size=3, padding=1)
        self.flatten = nn.Flatten()
        # padding=1 keeps H = W = board_size, so the flattened feature count is
        # board_size * board_size * conv2_channels (8 * 8 * 128 = 8192 by default).
        self.fc1 = nn.Linear(board_size * board_size * conv2_channels, hidden_units)
        self.fc2 = nn.Linear(hidden_units, num_classes)
        self.relu = nn.ReLU()

        # He init for the ReLU-fed conv layers and Glorot init for the linear
        # layers. Biases keep PyTorch's default initialization.
        nn.init.kaiming_uniform_(self.conv1.weight, nonlinearity='relu')
        nn.init.kaiming_uniform_(self.conv2.weight, nonlinearity='relu')
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, x):
        """Compute logits for a batch of encoded boards.

        Args:
            x: Tensor of shape ``(batch, in_channels, board_size, board_size)``.
                Any other spatial size breaks the ``fc1`` input dimension.

        Returns:
            Tensor of shape ``(batch, num_classes)`` containing raw, unnormalized
            logits. No softmax is applied here.
        """
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        return self.fc2(x)  # raw logits