gopal7093 commited on
Commit
e62bb77
·
verified ·
1 Parent(s): 54fd8ef

Create connect4_agent.py

Browse files
Files changed (1) hide show
  1. connect4_agent.py +28 -0
connect4_agent.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ class DQN(torch.nn.Module):
5
+ def __init__(self, input_size=42, hidden_size=128, output_size=7):
6
+ super(DQN, self).__init__()
7
+ self.fc1 = torch.nn.Linear(input_size, hidden_size)
8
+ self.relu = torch.nn.ReLU()
9
+ self.fc2 = torch.nn.Linear(hidden_size, output_size)
10
+
11
+ def forward(self, x):
12
+ x = self.fc1(x)
13
+ x = self.relu(x)
14
+ return self.fc2(x)
15
+
16
+ def load_model(path):
17
+ model = DQN()
18
+ model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
19
+ model.eval()
20
+ return model
21
+
22
+ def get_best_action(board, model):
23
+ flat_state = torch.tensor(board.flatten(), dtype=torch.float32).unsqueeze(0)
24
+ with torch.no_grad():
25
+ q_values = model(flat_state)
26
+ valid_actions = [c for c in range(7) if board[0][c] == 0]
27
+ q_values[0, [i for i in range(7) if i not in valid_actions]] = -float('inf')
28
+ return torch.argmax(q_values).item()