import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque
import copy


class QuadEnv:
    """8x8 board environment: place 2x2 four-color pieces, clear connected groups of 3+."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.board = np.full((8, 8), -1, dtype=int)  # -1 marks an empty cell
        self.hand = [self.generate_piece() for _ in range(3)]
        return self.get_state()

    def generate_piece(self):
        # Redraw until no single color occupies 3 or more of the 4 cells.
        while True:
            p = np.random.randint(0, 4, (2, 2)).tolist()
            counts = [0] * 4
            for r in range(2):
                for c in range(2):
                    counts[p[r][c]] += 1
            if max(counts) < 3:
                return p

    def get_state(self):
        # Note: the state encodes only the board; the hand is not observed.
        return torch.tensor(self.board.flatten(), dtype=torch.float32).unsqueeze(0)

    def can_place(self, piece, r, c):
        for ir in range(2):
            for ic in range(2):
                if r + ir >= 8 or c + ic >= 8 or self.board[r + ir][c + ic] != -1:
                    return False
        return True

    def step(self, action_idx):
        # Decode the flat action index: 3 hand slots x 4 rotations x 7x7 anchors = 588.
        p_idx = action_idx // 196
        rem = action_idx % 196
        rot = rem // 49
        rem2 = rem % 49
        r = rem2 // 7
        c = rem2 % 7
        piece = self.hand[p_idx]
        if piece is None or not self.can_place(piece, r, c):
            # An invalid placement ends the episode immediately with a penalty.
            return self.get_state(), -50.0, True
        # Rotate the 2x2 piece 90 degrees clockwise `rot` times.
        for _ in range(rot):
            piece = [[piece[1][0], piece[0][0]], [piece[1][1], piece[0][1]]]
        for ir in range(2):
            for ic in range(2):
                self.board[r + ir][c + ic] = piece[ir][ic]
        self.hand[p_idx] = self.generate_piece()
        score, done = self.process_matches()
        # Small reward for any legal placement, plus a score-scaled bonus.
        reward = 1.0 + (score / 10.0)
        return self.get_state(), float(reward), done

    def process_matches(self):
        score = 0
        combo = 0
        while True:
            visited = [[False] * 8 for _ in range(8)]
            to_remove = set()
            # Flood-fill connected same-color groups; groups of 3+ are cleared.
            for r in range(8):
                for c in range(8):
                    color = self.board[r][c]
                    if 0 <= color <= 3 and not visited[r][c]:
                        q = deque([(r, c)])
                        visited[r][c] = True
                        group = [(r, c)]
                        while q:
                            cr, cc = q.popleft()
                            for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                                nr, nc = cr + dr, cc + dc
                                if (0 <= nr < 8 and 0 <= nc < 8
                                        and not visited[nr][nc]
                                        and self.board[nr][nc] == color):
                                    visited[nr][nc] = True
                                    q.append((nr, nc))
                                    group.append((nr, nc))
                        if len(group) >= 3:
                            for gr, gc in group:
                                to_remove.add((gr, gc))
            if not to_remove:
                break
            combo += 1
            score += len(to_remove) * 10 * combo  # later chain clears score higher
            for rr, cc in to_remove:
                self.board[rr][cc] = -1
        # Check whether any piece in hand can still be placed anywhere.
        any_valid = False
        for p in self.hand:
            if p is not None:
                for rr in range(7):
                    for cc in range(7):
                        if self.can_place(p, rr, cc):
                            any_valid = True
                            break
                    if any_valid:
                        break
            if any_valid:
                break
        return score, not any_valid


class DQN(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, output_size),
        )

    def forward(self, x):
        return self.fc(x)


def train():
    env = QuadEnv()
    policy_net = DQN(64, 588)
    target_net = copy.deepcopy(policy_net)
    target_net.eval()
    optimizer = optim.Adam(policy_net.parameters(), lr=0.0005)
    criterion = nn.MSELoss()
    memory = deque(maxlen=20000)
    batch_size = 64
    gamma = 0.95
    epsilon = 1.0
    epsilon_min = 0.05
    epsilon_decay = 0.995
    epochs = 2000

    for epoch in range(epochs):
        state = env.reset()
        done = False
        total_reward = 0
        step_count = 0
        while not done:
            # Epsilon-greedy action selection over the 588 discrete actions.
            if random.random() < epsilon:
                action_idx = random.randint(0, 587)
            else:
                with torch.no_grad():
                    action_idx = policy_net(state).argmax().item()
            next_state, reward, done = env.step(action_idx)
            memory.append((state, action_idx, reward, next_state, done))
            state = next_state
            total_reward += reward
            step_count += 1

            # Experience replay: sample a minibatch and take one TD step.
            if len(memory) > batch_size:
                batch = random.sample(memory, batch_size)
                states, actions, rewards, next_states, dones = zip(*batch)
                states = torch.cat(states)
                actions = torch.tensor(actions).unsqueeze(1)
                rewards = torch.tensor(rewards, dtype=torch.float32)
                next_states = torch.cat(next_states)
                dones = torch.tensor(dones, dtype=torch.float32)
                q_values = policy_net(states).gather(1, actions).squeeze(1)
                with torch.no_grad():
                    next_q_values = target_net(next_states).max(1)[0]
                target_q_values = rewards + gamma * next_q_values * (1 - dones)
                loss = criterion(q_values, target_q_values)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        if epoch % 10 == 0:
            # Periodically sync the target network with the policy network.
            target_net.load_state_dict(policy_net.state_dict())
            print(f"Epoch {epoch} | Total Reward: {total_reward:.1f} | "
                  f"Steps: {step_count} | Epsilon: {epsilon:.3f}")

    torch.save(policy_net.state_dict(), "model.pth")
    print("Training Complete. Model saved as model.pth")


if __name__ == "__main__":
    train()
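

# --- Optional evaluation sketch (not part of the original training script) ---
# A minimal greedy rollout using the saved weights, assuming the QuadEnv and
# DQN definitions above. The `evaluate` name, the episode count, and the idea
# of printing per-episode totals are illustrative assumptions, not behavior
# defined by the script itself; call it manually (e.g. from an interactive
# session) after training has produced model.pth.
def evaluate(model_path="model.pth", episodes=5):
    env = QuadEnv()
    net = DQN(64, 588)
    net.load_state_dict(torch.load(model_path))
    net.eval()
    for ep in range(episodes):
        state = env.reset()
        done = False
        total_reward = 0.0
        while not done:
            with torch.no_grad():
                # Greedy action selection: no exploration at evaluation time.
                action_idx = net(state).argmax().item()
            state, reward, done = env.step(action_idx)
            total_reward += reward
        print(f"Eval episode {ep} | Total Reward: {total_reward:.1f}")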