import torch
from transformers import AutoTokenizer, GPT2Model
import torch.nn as nn


class ChessMoveClassifier(nn.Module):
    """GPT-2 backbone with a linear classification head over chess moves.

    The final hidden state of the last real (non-padded) token is passed
    through dropout and a linear layer to produce one logit per move class.
    `num_labels` defaults to 4096 — presumably a 64x64 from-square/to-square
    move encoding; confirm against the label mapping used in training.
    """

    def __init__(self, model_name, num_labels=4096):
        super().__init__()
        self.base_model = GPT2Model.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)
        # Head maps the transformer's embedding width to the move-class logits.
        self.classifier = nn.Linear(self.base_model.config.n_embd, num_labels)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Return {"logits": (batch, num_labels)} for a batch of token ids.

        Extra keyword arguments (e.g. `labels` from a Trainer) are accepted
        and ignored.
        """
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state  # (batch, seq, n_embd)
        if attention_mask is not None:
            # Pool the last *non-padded* token per sequence. The original
            # `[:, -1, :]` picks a PAD embedding whenever sequences are
            # right-padded, which corrupts the logits for shorter games.
            last_idx = attention_mask.sum(dim=1).long() - 1
            batch_idx = torch.arange(hidden.size(0), device=hidden.device)
            pooled = hidden[batch_idx, last_idx]
        else:
            # No mask: assume unpadded input and take the final position.
            pooled = hidden[:, -1, :]
        logits = self.classifier(self.dropout(pooled))
        return {"logits": logits}


def model_fn(model_dir):
    """Model-server entry point: build the classifier and load fine-tuned weights.

    Expects `{model_dir}/model.pt` to contain a plain state_dict.
    Returns the model in eval mode on CPU.
    """
    model = ChessMoveClassifier(model_name="austindavis/ChessGPT_d12")
    # weights_only=True refuses arbitrary pickled objects — torch.load on an
    # untrusted checkpoint can otherwise execute code (requires torch >= 1.13).
    state_dict = torch.load(
        f"{model_dir}/model.pt", map_location="cpu", weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model