import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque
import copy
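
# Self-contained DQN training script for a color-matching placement puzzle:
# 2x2 pieces drawn from four colors are placed on an 8x8 board, and connected
# groups of three or more same-colored cells are cleared for points.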


class QuadEnv:
    def __init__(self):
        self.reset()

    def reset(self):
        # -1 marks an empty cell; colors are 0..3.
        self.board = np.full((8, 8), -1, dtype=int)
        self.hand = [self.generate_piece() for _ in range(3)]
        return self.get_state()

    def generate_piece(self):
        # Reject pieces with three or more cells of one color, which would
        # clear themselves immediately on placement.
        while True:
            p = np.random.randint(0, 4, (2, 2)).tolist()
            counts = [0] * 4
            for r in range(2):
                for c in range(2):
                    counts[p[r][c]] += 1
            if max(counts) < 3:
                return p

    def get_state(self):
        # The observation is the flattened 8x8 board (the hand is not encoded).
        return torch.FloatTensor(self.board.flatten()).unsqueeze(0)

    def can_place(self, piece, r, c):
        # A 2x2 piece fits only if all four target cells are on the board and empty.
        for ir in range(2):
            for ic in range(2):
                if r + ir >= 8 or c + ic >= 8 or self.board[r + ir][c + ic] != -1:
                    return False
        return True

    def step(self, action_idx):
        # Decode the flat action index: 3 hand slots x 4 rotations x 49 anchor
        # positions (rows/columns 0..6) = 588 actions, so
        # action_idx = p_idx * 196 + rot * 49 + r * 7 + c.
        p_idx = action_idx // 196
        rem = action_idx % 196
        rot = rem // 49
        rem2 = rem % 49
        r = rem2 // 7
        c = rem2 % 7

        piece = self.hand[p_idx]
        if piece is None or not self.can_place(piece, r, c):
            # An illegal placement ends the episode with a large penalty.
            return self.get_state(), -50.0, True

        # Rotate the 2x2 piece clockwise `rot` times.
        for _ in range(rot):
            piece = [[piece[1][0], piece[0][0]], [piece[1][1], piece[0][1]]]

        for ir in range(2):
            for ic in range(2):
                self.board[r + ir][c + ic] = piece[ir][ic]

        self.hand[p_idx] = self.generate_piece()
        score, done = self.process_matches()

        # Small reward for a legal placement plus a bonus scaled by the cleared score.
        reward = 1.0 + (score / 10.0)
        return self.get_state(), float(reward), done

    def process_matches(self):
        score = 0
        combo = 0
        while True:
            # Flood-fill connected groups of the same color; groups of three
            # or more cells are cleared and scored.
            visited = [[False] * 8 for _ in range(8)]
            to_remove = set()
            for r in range(8):
                for c in range(8):
                    color = self.board[r][c]
                    if 0 <= color <= 3 and not visited[r][c]:
                        q = deque([(r, c)])
                        visited[r][c] = True
                        group = [(r, c)]
                        while q:
                            cr, cc = q.popleft()
                            for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                                nr, nc = cr + dr, cc + dc
                                if 0 <= nr < 8 and 0 <= nc < 8 and not visited[nr][nc] and self.board[nr][nc] == color:
                                    visited[nr][nc] = True
                                    q.append((nr, nc))
                                    group.append((nr, nc))
                        if len(group) >= 3:
                            for gr, gc in group:
                                to_remove.add((gr, gc))
            if not to_remove:
                break
            combo += 1
            score += len(to_remove) * 10 * combo
            for rr, cc in to_remove:
                self.board[rr][cc] = -1

        # The episode ends when no piece in the hand fits anywhere on the board.
        any_valid = False
        for p in self.hand:
            if p is not None:
                for rr in range(7):
                    for cc in range(7):
                        if self.can_place(p, rr, cc):
                            any_valid = True
                            break
                    if any_valid:
                        break
            if any_valid:
                break

        return score, not any_valid


class DQN(nn.Module):
    def __init__(self, input_size, output_size):
        super(DQN, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, output_size)
        )

    def forward(self, x):
        return self.fc(x)
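

# Standard DQN training loop: epsilon-greedy exploration, a replay buffer of
# recent transitions, mini-batch TD updates on the policy network, and a
# target network that is synced periodically.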
def train():
    env = QuadEnv()
    policy_net = DQN(64, 588)  # 64 board cells in, 588 actions out
    target_net = copy.deepcopy(policy_net)
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters(), lr=0.0005)
    memory = deque(maxlen=20000)  # replay buffer

    batch_size = 64
    gamma = 0.95
    epsilon = 1.0
    epsilon_min = 0.05
    epsilon_decay = 0.995

    epochs = 2000
    for epoch in range(epochs):
        state = env.reset()
        done = False
        total_reward = 0
        step_count = 0

        while not done:
            # Epsilon-greedy action selection over the flat action space.
            if random.random() < epsilon:
                action_idx = random.randint(0, 587)
            else:
                with torch.no_grad():
                    action_idx = policy_net(state).argmax().item()

            next_state, reward, done = env.step(action_idx)
            memory.append((state, action_idx, reward, next_state, done))
            state = next_state
            total_reward += reward
            step_count += 1

            if len(memory) > batch_size:
                # Sample a mini-batch and take one TD step on the policy network.
                batch = random.sample(memory, batch_size)
                states, actions, rewards, next_states, dones = zip(*batch)

                states = torch.cat(states)
                actions = torch.tensor(actions).unsqueeze(1)
                rewards = torch.tensor(rewards, dtype=torch.float32)
                next_states = torch.cat(next_states)
                dones = torch.tensor(dones, dtype=torch.float32)

                q_values = policy_net(states).gather(1, actions).squeeze(1)
                with torch.no_grad():
                    next_q_values = target_net(next_states).max(1)[0]
                target_q_values = rewards + gamma * next_q_values * (1 - dones)

                loss = nn.MSELoss()(q_values, target_q_values)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        if epoch % 10 == 0:
            # Periodically sync the target network and log progress.
            target_net.load_state_dict(policy_net.state_dict())
            print(f"Epoch {epoch} | Total Reward: {total_reward:.1f} | Steps: {step_count} | Epsilon: {epsilon:.3f}")

    torch.save(policy_net.state_dict(), "model.pth")
    print("Training Complete. Model saved as model.pth")


if __name__ == "__main__":
    train()