import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque
import copy
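# DQN training script for an 8x8, four-color puzzle environment: 2x2 pieces
# drawn into a 3-piece hand are placed on the board, and orthogonally
# connected groups of 3 or more same-colored cells are cleared with
# combo-multiplied scoring.
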
class QuadEnv:
    def __init__(self):
        self.reset()

    def reset(self):
        self.board = np.full((8, 8), -1, dtype=int)
        self.hand = [self.generate_piece() for _ in range(3)]
        return self.get_state()
    def generate_piece(self):
        # Reject 2x2 pieces in which any single color fills 3 or 4 cells.
        while True:
            p = np.random.randint(0, 4, (2, 2)).tolist()
            counts = [0] * 4
            for r in range(2):
                for c in range(2):
                    counts[p[r][c]] += 1
            if max(counts) < 3:
                return p
    def get_state(self):
        # Observation is the flattened board only; the hand is not encoded.
        return torch.tensor(self.board.flatten(), dtype=torch.float32).unsqueeze(0)
    def can_place(self, piece, r, c):
        # A 2x2 piece fits if all four target cells are on the board and empty.
        for ir in range(2):
            for ic in range(2):
                if r + ir >= 8 or c + ic >= 8 or self.board[r + ir][c + ic] != -1:
                    return False
        return True
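
    # Action layout: 588 = 3 hand slots x 4 rotations x 49 (7x7) anchor cells.
    # Example: action 250 decodes to piece 250 // 196 = 1, rotation
    # (250 % 196) // 49 = 1, and anchor (r, c) = (0, 5) from 250 % 49 = 5.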
    def step(self, action_idx):
        p_idx = action_idx // 196
        rem = action_idx % 196
        rot = rem // 49
        rem2 = rem % 49
        r = rem2 // 7
        c = rem2 % 7
        piece = self.hand[p_idx]
        if piece is None or not self.can_place(piece, r, c):
            return self.get_state(), -50.0, True  # invalid placement ends the episode with a penalty
        # Rotate the 2x2 piece 90 degrees clockwise `rot` times.
        for _ in range(rot):
            piece = [[piece[1][0], piece[0][0]], [piece[1][1], piece[0][1]]]
        for ir in range(2):
            for ic in range(2):
                self.board[r + ir][c + ic] = piece[ir][ic]
        self.hand[p_idx] = self.generate_piece()
        score, done = self.process_matches()
        # Small reward for completing a placement, larger reward from score
        reward = 1.0 + (score / 10.0)
        return self.get_state(), float(reward), done
    def process_matches(self):
        score = 0
        combo = 0
        while True:
            visited = [[False] * 8 for _ in range(8)]
            to_remove = set()
            # Flood-fill each unvisited colored cell to find connected groups.
            for r in range(8):
                for c in range(8):
                    color = self.board[r][c]
                    if 0 <= color <= 3 and not visited[r][c]:
                        q = deque([(r, c)])
                        visited[r][c] = True
                        group = [(r, c)]
                        while q:
                            cr, cc = q.popleft()
                            for dr, dc in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
                                nr, nc = cr + dr, cc + dc
                                if 0 <= nr < 8 and 0 <= nc < 8 and not visited[nr][nc] and self.board[nr][nc] == color:
                                    visited[nr][nc] = True
                                    q.append((nr, nc))
                                    group.append((nr, nc))
                        if len(group) >= 3:
                            for gr, gc in group:
                                to_remove.add((gr, gc))
            if not to_remove:
                break
            combo += 1
            score += len(to_remove) * 10 * combo
            for rr, cc in to_remove:
                self.board[rr][cc] = -1
        # Check whether any piece in hand can still be placed somewhere
        any_valid = False
        for p in self.hand:
            if p is not None:
                for rr in range(7):
                    for cc in range(7):
                        if self.can_place(p, rr, cc):
                            any_valid = True
                            break
                    if any_valid:
                        break
            if any_valid:
                break
        return score, not any_valid
class DQN(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, output_size)
        )

    def forward(self, x):
        return self.fc(x)
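
# The network maps the 64-cell flattened board to one Q-value per action
# (588 outputs) through two 256-unit ReLU hidden layers.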
def train():
    env = QuadEnv()
    policy_net = DQN(64, 588)
    target_net = copy.deepcopy(policy_net)  # frozen copy used for bootstrap targets
    target_net.eval()
    optimizer = optim.Adam(policy_net.parameters(), lr=0.0005)
    memory = deque(maxlen=20000)  # replay buffer
    batch_size = 64
    gamma = 0.95
    epsilon = 1.0
    epsilon_min = 0.05
    epsilon_decay = 0.995
    epochs = 2000
    for epoch in range(epochs):
        state = env.reset()
        done = False
        total_reward = 0
        step_count = 0
        while not done:
            # Epsilon-greedy action selection
            if random.random() < epsilon:
                action_idx = random.randint(0, 587)
            else:
                with torch.no_grad():
                    action_idx = policy_net(state).argmax().item()
            next_state, reward, done = env.step(action_idx)
            memory.append((state, action_idx, reward, next_state, done))
            state = next_state
            total_reward += reward
            step_count += 1
            # Experience replay: fit Q(s, a) to r + gamma * max_a' Q_target(s', a')
            if len(memory) > batch_size:
                batch = random.sample(memory, batch_size)
                states, actions, rewards, next_states, dones = zip(*batch)
                states = torch.cat(states)
                actions = torch.tensor(actions).unsqueeze(1)
                rewards = torch.tensor(rewards, dtype=torch.float32)
                next_states = torch.cat(next_states)
                dones = torch.tensor(dones, dtype=torch.float32)
                q_values = policy_net(states).gather(1, actions).squeeze(1)
                with torch.no_grad():
                    next_q_values = target_net(next_states).max(1)[0]
                target_q_values = rewards + gamma * next_q_values * (1 - dones)
                loss = nn.MSELoss()(q_values, target_q_values)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        # Periodically sync the target network and log progress
        if epoch % 10 == 0:
            target_net.load_state_dict(policy_net.state_dict())
            print(f"Epoch {epoch} | Total Reward: {total_reward:.1f} | Steps: {step_count} | Epsilon: {epsilon:.3f}")
    torch.save(policy_net.state_dict(), "model.pth")
    print("Training Complete. Model saved as model.pth")
if __name__ == "__main__":
    train()