import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque
import copy
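# DQN training script for "Quad": place 2x2 four-color pieces on an 8x8 board;
# connected groups of 3+ same-colored cells are cleared for points.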
class QuadEnv:
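    """Puzzle environment: an 8x8 board of cells holding -1 (empty) or a
    color index 0-3, plus a hand of three randomly generated 2x2 pieces."""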
def __init__(self):
self.reset()
def reset(self):
self.board = np.full((8, 8), -1, dtype=int)
self.hand = [self.generate_piece() for _ in range(3)]
return self.get_state()
def generate_piece(self):
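        # rejection-sample a 2x2 piece: redraw whenever any single color
        # occupies 3 or more of the 4 cells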
while True:
p = np.random.randint(0, 4, (2, 2)).tolist()
counts = [0] * 4
for r in range(2):
for c in range(2):
counts[p[r][c]] += 1
if max(counts) < 3:
return p
def get_state(self):
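        # observation: the flattened board only (the hand is not encoded)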
        return torch.tensor(self.board.flatten(), dtype=torch.float32).unsqueeze(0)
def can_place(self, piece, r, c):
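        # a 2x2 piece fits at (r, c) iff its footprint stays on the board
        # and all four target cells are empty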
for ir in range(2):
for ic in range(2):
if r+ir >= 8 or c+ic >= 8 or self.board[r+ir][c+ic] != -1:
return False
return True
def step(self, action_idx):
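        # decode a flat action index: 588 = 3 pieces x 4 rotations x 7x7 anchors
        # e.g. action 250 -> piece 1, rotation 1, row 0, col 5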
p_idx = action_idx // 196
rem = action_idx % 196
rot = rem // 49
rem2 = rem % 49
r = rem2 // 7
c = rem2 % 7
piece = self.hand[p_idx]
if piece is None or not self.can_place(piece, r, c):
            return self.get_state(), -50.0, True  # unplaceable action: terminate immediately with a penalty
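        # each iteration rotates the piece 90 degrees clockwise; the 2x2
        # footprint is unchanged, so checking can_place before rotating is safe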
for _ in range(rot):
piece = [[piece[1][0], piece[0][0]], [piece[1][1], piece[0][1]]]
for ir in range(2):
for ic in range(2):
self.board[r+ir][c+ic] = piece[ir][ic]
self.hand[p_idx] = self.generate_piece()
score, done = self.process_matches()
        # small reward for completing a normal placement, plus a larger bonus scaled by match score
reward = 1.0 + (score / 10.0)
return self.get_state(), float(reward), done
def process_matches(self):
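        # repeatedly flood-fill same-colored groups; clear groups of 3+ cells,
        # scoring 10 points per cell times the current combo multiplier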
score = 0
combo = 0
while True:
visited = [[False]*8 for _ in range(8)]
to_remove = set()
for r in range(8):
for c in range(8):
color = self.board[r][c]
if 0 <= color <= 3 and not visited[r][c]:
                        q = deque([(r, c)])
visited[r][c] = True
group = [(r, c)]
while q:
                            cr, cc = q.popleft()
for dr, dc in [(-1,0), (1,0), (0,-1), (0,1)]:
nr, nc = cr + dr, cc + dc
if 0 <= nr < 8 and 0 <= nc < 8 and not visited[nr][nc] and self.board[nr][nc] == color:
visited[nr][nc] = True
q.append((nr, nc))
group.append((nr, nc))
if len(group) >= 3:
for gr, gc in group:
to_remove.add((gr, gc))
if not to_remove:
break
combo += 1
score += len(to_remove) * 10 * combo
for rr, cc in to_remove:
self.board[rr][cc] = -1
        # check whether any piece in hand can still be placed somewhere
        any_valid = any(
            self.can_place(p, rr, cc)
            for p in self.hand if p is not None
            for rr in range(7)
            for cc in range(7)
        )
return score, not any_valid
class DQN(nn.Module):
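    """Two-hidden-layer MLP Q-network mapping the 64-cell board state to 588 action values."""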
def __init__(self, input_size, output_size):
        super().__init__()
self.fc = nn.Sequential(
nn.Linear(input_size, 256), nn.ReLU(),
nn.Linear(256, 256), nn.ReLU(),
nn.Linear(256, output_size)
)
def forward(self, x):
return self.fc(x)
def train():
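    # standard DQN: epsilon-greedy exploration, uniform experience replay,
    # and a target network synced every 10 epochs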
env = QuadEnv()
policy_net = DQN(64, 588)
target_net = copy.deepcopy(policy_net)
target_net.eval()
optimizer = optim.Adam(policy_net.parameters(), lr=0.0005)
memory = deque(maxlen=20000)
batch_size = 64
gamma = 0.95
epsilon = 1.0
epsilon_min = 0.05
epsilon_decay = 0.995
epochs = 2000
for epoch in range(epochs):
state = env.reset()
done = False
total_reward = 0
step_count = 0
while not done:
if random.random() < epsilon:
action_idx = random.randint(0, 587)
else:
with torch.no_grad():
action_idx = policy_net(state).argmax().item()
next_state, reward, done = env.step(action_idx)
memory.append((state, action_idx, reward, next_state, done))
state = next_state
total_reward += reward
step_count += 1
            # experience replay: learn from a random minibatch of past transitions
if len(memory) > batch_size:
batch = random.sample(memory, batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
states = torch.cat(states)
actions = torch.tensor(actions).unsqueeze(1)
rewards = torch.tensor(rewards, dtype=torch.float32)
next_states = torch.cat(next_states)
dones = torch.tensor(dones, dtype=torch.float32)
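                # Q(s, a) for the taken actions vs. the TD target
                # r + gamma * max_a' Q_target(s', a'), zeroed at terminal transitions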
q_values = policy_net(states).gather(1, actions).squeeze(1)
with torch.no_grad():
next_q_values = target_net(next_states).max(1)[0]
target_q_values = rewards + gamma * next_q_values * (1 - dones)
loss = nn.MSELoss()(q_values, target_q_values)
optimizer.zero_grad()
loss.backward()
optimizer.step()
epsilon = max(epsilon_min, epsilon * epsilon_decay)
        # periodically sync the target network with the policy network and log progress
        if epoch % 10 == 0:
            target_net.load_state_dict(policy_net.state_dict())
            print(f"Epoch {epoch} | Total Reward: {total_reward:.1f} | Steps: {step_count} | Epsilon: {epsilon:.3f}")
torch.save(policy_net.state_dict(), "model.pth")
print("Training Complete. Model saved as model.pth")
if __name__ == "__main__":
train()