"""Train a dueling double-DQN agent on Super Mario Bros.

Running this file as a script builds the environment, creates an online and a
target network, and trains indefinitely while periodically writing model
checkpoints (``mario_q.pth`` / ``mario_q_target.pth``) and the score history
(``score.p``) to the working directory.
"""

import pickle
import random
import time
from collections import deque

import gym_super_mario_bros
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
from nes_py.wrappers import JoypadSpace

# The star import is kept as-is: ``wrap_mario`` (used in the script entry
# point below) is not defined in this file, so it has to come from here.
from wrappers import *


def arrange(s):
    """Turn a 3-D observation into a 4-D array with a leading axis of size 1.

    The input is always copied into a fresh ``numpy`` array, so the result
    never aliases a buffer owned by the environment. The axes are then
    reordered with ``(2, 0, 1)`` and a new axis is inserted at position 0.

    Args:
        s: Array-like observation with exactly three dimensions
            (presumably height x width x stacked frames -- confirm against
            the environment wrappers).

    Returns:
        ``numpy.ndarray`` of shape ``(1, s.shape[2], s.shape[0], s.shape[1])``.

    Raises:
        AssertionError: If the observation is not 3-dimensional.
    """
    # The original guard compared ``type(s)`` with the *string*
    # "numpy.ndarray", which is never equal, so the conversion always ran.
    # Keep that (copying) behaviour, minus the dead comparison.
    s = np.array(s)
    assert len(s.shape) == 3, "expected a 3-D observation, got shape %r" % (s.shape,)
    ret = np.transpose(s, (2, 0, 1))
    return np.expand_dims(ret, 0)


class replay_memory(object):
    """Fixed-capacity FIFO buffer of transitions with uniform sampling."""

    def __init__(self, N):
        """Create an empty buffer holding at most ``N`` transitions."""
        # A bounded deque silently discards the oldest entry once it is full.
        self.memory = deque(maxlen=N)

    def push(self, transition):
        """Append one transition to the buffer."""
        self.memory.append(transition)

    def sample(self, n):
        """Return ``n`` distinct transitions chosen uniformly at random.

        Raises:
            ValueError: If ``n`` is larger than the number of stored
                transitions (raised by ``random.sample``).
        """
        return random.sample(self.memory, n)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)


class model(nn.Module):
    """Convolutional dueling Q-network.

    Two convolutional layers feed a 512-unit fully connected layer, followed
    by separate advantage (``q``) and state-value (``v``) heads that are
    combined in ``forward``.
    """

    def __init__(self, n_frame, n_action, device):
        """Build the layers and initialise the convolution weights.

        Args:
            n_frame: Number of input channels of the first convolution.
            n_action: Number of outputs of the advantage head.
            device: Device that non-tensor inputs are moved to in ``forward``.
        """
        super(model, self).__init__()
        self.layer1 = nn.Conv2d(n_frame, 32, 8, 4)
        self.layer2 = nn.Conv2d(32, 64, 3, 1)
        # 20736 == 64 * 18 * 18; presumably the conv output for 84x84 input
        # frames -- confirm against the environment wrappers.
        self.fc = nn.Linear(20736, 512)
        self.q = nn.Linear(512, n_action)
        self.v = nn.Linear(512, 1)

        self.device = device
        # ``seq`` is never called; it only groups the layers so ``apply`` can
        # visit them. It is also a registered submodule, so its parameters
        # show up in ``state_dict`` under ``seq.*`` -- removing it would
        # change the checkpoint format.
        self.seq = nn.Sequential(self.layer1, self.layer2, self.fc, self.q, self.v)
        self.seq.apply(init_weights)

    def forward(self, x):
        """Compute Q-values for a batch of observations.

        Args:
            x: Either a ``torch.Tensor`` (used as-is, without moving it to
                another device) or array-like data, which is converted to a
                float tensor on ``self.device``.

        Returns:
            Tensor with one Q-value per action for every batch element.
        """
        if not isinstance(x, torch.Tensor):
            x = torch.FloatTensor(x).to(self.device)
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        x = x.view(-1, 20736)
        x = torch.relu(self.fc(x))
        adv = self.q(x)
        v = self.v(x)
        # Dueling aggregation: state value plus mean-centred advantages.
        return v + (adv - adv.mean(dim=-1, keepdim=True))


def init_weights(m):
    """Initialise ``m`` if it is a convolution; leave other modules alone.

    Convolution weights get Xavier-uniform initialisation and biases are set
    to 0.01. Modules of any other type are left untouched.
    """
    if isinstance(m, nn.Conv2d):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


def train(q, q_target, memory, batch_size, gamma, optimizer, device):
    """Run one double-DQN optimisation step on a sampled mini-batch.

    Args:
        q: Online network; it selects the next action and is the one updated.
        q_target: Target network; it evaluates the selected next action.
        memory: Replay buffer whose ``sample`` yields
            ``(s, r, a, s_prime, not_done)`` tuples.
        batch_size: Number of transitions to sample.
        gamma: Discount factor.
        optimizer: Optimiser that updates the parameters of ``q``.
        device: Device on which reward, mask and action tensors are created.

    Returns:
        Detached scalar tensor holding the smooth-L1 loss of this step.
    """
    s, r, a, s_prime, not_done = map(list, zip(*memory.sample(batch_size)))

    # Every stored state carries a leading axis of size 1, so joining along
    # axis 0 produces the batch axis directly. Unlike a blanket ``squeeze``
    # this keeps the batch axis for ``batch_size == 1`` and does not drop
    # other size-1 axes.
    s = np.concatenate(s, axis=0)
    s_prime = np.concatenate(s_prime, axis=0)

    r = torch.tensor(r, dtype=torch.float32, device=device).unsqueeze(-1)
    # The buffer stores ``1 - done``, so this is 0 for terminal transitions
    # and removes the bootstrap term for them.
    not_done = torch.tensor(not_done, dtype=torch.float32, device=device).unsqueeze(-1)
    a = torch.tensor(a, dtype=torch.long, device=device).unsqueeze(-1)

    with torch.no_grad():
        # Action selection only yields indices, so no gradient can flow
        # through it; keeping it under no_grad avoids building a graph.
        a_max = q(s_prime).max(1)[1].unsqueeze(-1)
        y = r + gamma * q_target(s_prime).gather(1, a_max) * not_done

    q_value = torch.gather(q(s), dim=1, index=a)
    # smooth_l1_loss already reduces to the mean by default.
    loss = F.smooth_l1_loss(q_value, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()


def copy_weights(q, q_target):
    """Overwrite the parameters of ``q_target`` with those of ``q``."""
    q_target.load_state_dict(q.state_dict())


def main(env, q, q_target, optimizer, device):
    """Run the training loop.

    Plays up to 1,000,000 episodes with an epsilon-greedy policy, trains the
    online network after every environment step once the replay buffer holds
    more than 2000 transitions, and regularly syncs the target network.

    Side effects: writes ``mario_q.pth`` and ``mario_q_target.pth`` every
    ``update_interval`` training steps, rewrites ``score.p`` after every
    episode, and prints a progress line every ``print_interval`` episodes.

    Args:
        env: Environment exposing ``reset``, ``step``, ``action_space`` and
            ``unwrapped._stage``.
        q: Online Q-network.
        q_target: Target Q-network.
        optimizer: Optimiser for the parameters of ``q``.
        device: Device identifier used for training and shown in the log.
    """
    t = 0
    gamma = 0.99
    batch_size = 256

    N = 50000
    eps = 0.001
    memory = replay_memory(N)
    update_interval = 50
    print_interval = 10

    score_lst = []
    total_score = 0.0
    loss = 0.0
    start_time = time.perf_counter()

    for k in range(1000000):
        s = arrange(env.reset())
        done = False

        while not done:
            if eps > np.random.rand():
                a = env.action_space.sample()
            else:
                # ``.cpu()`` is a no-op for CPU tensors, so one code path
                # serves every device.
                with torch.no_grad():
                    a = np.argmax(q(s).cpu().numpy())
            s_prime, r, done, _ = env.step(a)
            s_prime = arrange(s_prime)
            # The logged score uses the raw reward; the stored reward is
            # rescaled below.
            total_score += r
            r = np.sign(r) * (np.sqrt(abs(r) + 1) - 1) + 0.001 * r
            memory.push((s, float(r), int(a), s_prime, int(1 - done)))
            s = s_prime
            stage = env.unwrapped._stage
            if len(memory) > 2000:
                # Accumulate as a plain float so the running sum never turns
                # into a tensor.
                loss += float(
                    train(q, q_target, memory, batch_size, gamma, optimizer, device)
                )
                t += 1
            # NOTE: ``t`` only advances while training, so during warm-up
            # (t == 0) this branch runs on every environment step.
            if t % update_interval == 0:
                copy_weights(q, q_target)
                torch.save(q.state_dict(), "mario_q.pth")
                torch.save(q_target.state_dict(), "mario_q_target.pth")

        if k % print_interval == 0:
            now = time.perf_counter()
            time_spent = now - start_time
            start_time = now
            print(
                "%s |Epoch : %d | score : %f | loss : %.2f | stage : %d | time spent: %f"
                % (
                    device,
                    k,
                    total_score / print_interval,
                    loss / print_interval,
                    stage,
                    time_spent,
                )
            )
            score_lst.append(total_score / print_interval)
            total_score = 0.0
            loss = 0.0
        # Use a context manager so the file handle is closed on every
        # episode instead of being left to the garbage collector.
        with open("score.p", "wb") as score_file:
            pickle.dump(score_lst, score_file)


if __name__ == "__main__":
    n_frame = 4
    env = gym_super_mario_bros.make("SuperMarioBros-v0")
    env = JoypadSpace(env, COMPLEX_MOVEMENT)
    env = wrap_mario(env)

    # Prefer CUDA, then Apple MPS, and fall back to the CPU.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"

    q = model(n_frame, env.action_space.n, device).to(device)
    q_target = model(n_frame, env.action_space.n, device).to(device)
    optimizer = optim.Adam(q.parameters(), lr=0.0001)
    print(device)

    main(env, q, q_target, optimizer, device)