Spaces:
Build error
Build error
| import torch | |
| from torch import nn | |
| from data import BOARD_VECTOR_SIZE | |
class ChessModel(nn.Module):
    """Board encoder with popularity, evaluation, and reconstruction heads.

    An MLP encoder maps a board vector to an ``embedding_dims``-wide
    embedding (SiLU-activated). Three heads read that shared embedding:

    * ``popularity_head``: MLP ending in ``Tanh``, so outputs lie in [-1, 1].
    * ``evaluation_head``: a single linear layer with an unbounded output.
    * ``reconstruction_head``: MLP ending in ``Sigmoid``, so outputs lie in
      [0, 1] (presumably the board vector is encoded in [0, 1] — confirm
      against ``data``).

    Args:
        embedding_dims: Width of the shared embedding.
        board_vector_size: Width of the input board vector and of the
            reconstruction output. Defaults to ``BOARD_VECTOR_SIZE`` from
            ``data``; pass it explicitly to build the model for a different
            board encoding.
    """

    def __init__(
        self,
        embedding_dims: int,
        *,
        board_vector_size: "int | None" = None,
    ) -> None:
        super().__init__()
        if board_vector_size is None:
            board_vector_size = BOARD_VECTOR_SIZE

        # NOTE: the layer order inside each nn.Sequential is kept exactly as
        # before so state_dict keys (e.g. "encoder.6.weight",
        # "evaluation_head.0.weight") still match saved checkpoints.
        self.encoder = nn.Sequential(
            nn.Linear(board_vector_size, 512),
            nn.SiLU(),
            nn.Linear(512, 1024),
            nn.SiLU(),
            nn.Linear(1024, 1024),
            nn.SiLU(),
            nn.Linear(1024, embedding_dims),
            nn.SiLU(),
        )
        self.popularity_head = nn.Sequential(
            nn.Linear(embedding_dims, 512),
            nn.SiLU(),
            nn.Linear(512, 1),
            nn.Tanh(),
        )
        # A deeper, tanh-bounded evaluation head (embedding -> 512 -> 1) was
        # dropped because it would take too long to train before the jam
        # ended. This single trainable linear layer stands in for it; note
        # its output is NOT bounded, unlike the popularity head.
        self.evaluation_head = nn.Sequential(
            nn.Linear(embedding_dims, 1),
        )
        self.reconstruction_head = nn.Sequential(
            nn.Linear(embedding_dims, 512),
            nn.SiLU(),
            nn.Linear(512, board_vector_size),
            nn.Sigmoid(),
        )

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the embedding, popularity, evaluation, and reconstruction.

        Args:
            x: Board tensor whose last dimension equals the model's board
                vector size; any leading (batch) dimensions are carried
                through unchanged by the linear layers.

        Returns:
            ``(embedding, popularity, evaluation, reconstruction)``. The last
            dimension is ``embedding_dims``, 1, 1 and the board vector size,
            respectively.
        """
        # One shared embedding feeds three heads: popularity, eval, and
        # board reconstruction (an autoencoder-style auxiliary target).
        embedding = self.encoder(x)
        popularity = self.popularity_head(embedding)
        evaluation = self.evaluation_head(embedding)
        reconstruction = self.reconstruction_head(embedding)
        return embedding, popularity, evaluation, reconstruction