import torch
import numpy as np


class DQN(torch.nn.Module):
    """Two-layer fully connected Q-network: Linear -> ReLU -> Linear.

    The defaults (42 inputs, 7 outputs) fit a flattened 6x7 board with one
    Q-value per column -- presumably a Connect Four setup; confirm against
    the training code.

    Args:
        input_size: Length of the flattened state vector.
        hidden_size: Width of the single hidden layer.
        output_size: Number of actions (one Q-value per action).
    """

    def __init__(self, input_size=42, hidden_size=128, output_size=7):
        super().__init__()
        # Attribute names define the state_dict keys; keep them stable so
        # existing checkpoints still load.
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map states of shape (..., input_size) to Q-values of shape (..., output_size)."""
        x = self.fc1(x)
        x = self.relu(x)
        return self.fc2(x)


def load_model(path):
    """Load a DQN checkpoint from ``path`` onto the CPU and return it in eval mode.

    The file must contain a state dict for a ``DQN`` built with the default
    sizes, because the model is constructed with defaults before loading.

    Args:
        path: Filesystem path (or file-like object) accepted by ``torch.load``.

    Returns:
        The loaded ``DQN`` instance, switched to evaluation mode.
    """
    model = DQN()
    # SECURITY NOTE: torch.load unpickles the file. Only load checkpoints from
    # trusted sources; on older torch versions consider weights_only=True
    # (it is the default from torch 2.6 onward).
    state_dict = torch.load(path, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model


def get_best_action(board, model):
    """Return the playable column with the highest predicted Q-value.

    Args:
        board: 2-D grid (NumPy array or nested sequence). Row 0 is treated as
            the top row: column ``c`` is playable iff ``board[0][c] == 0``.
            The whole grid, flattened row-major, is fed to ``model`` as one
            state.
        model: Network mapping a ``(1, rows * cols)`` float32 tensor to
            ``(1, cols)`` Q-values, e.g. a :class:`DQN`.

    Returns:
        The chosen column index as a Python ``int``.

    Raises:
        ValueError: If every column is full, i.e. there is no legal move.
    """
    grid = np.asarray(board)
    state = torch.tensor(grid.reshape(1, -1), dtype=torch.float32)

    # A column whose top cell is occupied cannot accept another piece.
    full_columns = torch.as_tensor(grid[0] != 0)
    if bool(full_columns.all()):
        # argmax over an all -inf row would return column 0, an illegal move.
        raise ValueError("no valid actions: every column is full")

    with torch.no_grad():
        q_values = model(state)
    # Mask illegal columns so argmax can only land on a playable one.
    q_values[0, full_columns] = float("-inf")
    return int(torch.argmax(q_values[0]).item())