Upload 7 files
Browse files- .gitattributes +4 -0
- Screenshot 2025-11-07 at 1.03.27 PM.png +3 -0
- duel_dqn.py +178 -0
- flux_krea_00776_.png +3 -0
- output.mp4 +3 -0
- pyqt5_duel_dqn_super_mario_bros_tutorial/app.py +2294 -0
- pyqt5_duel_dqn_super_mario_bros_tutorial/flux_krea_00776_.png +3 -0
- pyqt5_duel_dqn_super_mario_bros_tutorial/installed_packages_dqn.txt +4 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
flux_krea_00776_.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
output.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
pyqt5_duel_dqn_super_mario_bros_tutorial/flux_krea_00776_.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
Screenshot[[:space:]]2025-11-07[[:space:]]at[[:space:]]1.03.27 PM.png filter=lfs diff=lfs merge=lfs -text
|
Screenshot 2025-11-07 at 1.03.27 PM.png
ADDED
|
Git LFS Details
|
duel_dqn.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pickle
|
| 2 |
+
import random
|
| 3 |
+
import time
|
| 4 |
+
from collections import deque
|
| 5 |
+
|
| 6 |
+
import gym_super_mario_bros
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
import torch.optim as optim
|
| 12 |
+
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
|
| 13 |
+
from nes_py.wrappers import JoypadSpace
|
| 14 |
+
|
| 15 |
+
from wrappers import *
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def arrange(s):
    """Turn one 3-D observation into a single-item batch with axis 2 first.

    An input of shape ``(d0, d1, d2)`` comes back with shape
    ``(1, d2, d0, d1)``.

    Args:
        s: Array-like with exactly three dimensions.

    Returns:
        A ``numpy.ndarray`` built from a fresh copy of ``s``, so it never
        shares memory with the caller's object.

    Raises:
        AssertionError: If ``s`` does not have exactly three dimensions.
    """
    # The previous guard compared ``type(s)`` with the *string*
    # "numpy.ndarray".  A type never equals a string, so the conversion
    # below always ran.  Do it unconditionally and keep the copy: callers
    # store the result in the replay buffer and must not alias a buffer
    # the environment may reuse.
    s = np.array(s)
    assert s.ndim == 3, f"expected a 3-D observation, got shape {s.shape}"
    ret = np.transpose(s, (2, 0, 1))
    return np.expand_dims(ret, 0)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class replay_memory:
    """Bounded FIFO store of transitions with uniform random sampling."""

    def __init__(self, N):
        """Create an empty buffer that keeps at most ``N`` transitions."""
        # A deque with a maxlen discards its oldest entry once it is full.
        self.memory = deque([], N)

    def __len__(self):
        """Return how many transitions are currently stored."""
        return len(self.memory)

    def push(self, transition):
        """Append ``transition`` as the newest entry."""
        self.memory.append(transition)

    def sample(self, n):
        """Return ``n`` distinct stored transitions chosen at random.

        ``random.sample`` raises ``ValueError`` when ``n`` exceeds the
        number of stored transitions.
        """
        batch = random.sample(self.memory, n)
        return batch
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class model(nn.Module):
    """Dueling Q-network: two convolutions, one hidden layer, two heads.

    The advantage head (``self.q``) and the scalar value head (``self.v``)
    are combined in ``forward`` as ``v + (adv - mean(adv))``.

    Args:
        n_frame: Number of input channels given to the first convolution.
        n_action: Number of outputs of the advantage head.
        device: Device that non-tensor inputs are moved to in ``forward``.
    """

    def __init__(self, n_frame, n_action, device):
        super(model, self).__init__()
        self.layer1 = nn.Conv2d(n_frame, 32, 8, 4)
        self.layer2 = nn.Conv2d(32, 64, 3, 1)
        # 20736 == 64 * 18 * 18: the flattened size the two convolutions
        # above produce for an 84 x 84 input.  Other input sizes will fail
        # at the ``view`` in ``forward``.
        self.fc = nn.Linear(20736, 512)
        self.q = nn.Linear(512, n_action)
        self.v = nn.Linear(512, 1)

        self.device = device
        # ``seq`` is never called as a network; it only exists so that
        # ``apply`` can visit every layer.  It also registers the layers a
        # second time, so ``state_dict()`` carries ``seq.*`` entries next
        # to the named ones.  Keep it: removing it would change the
        # checkpoint key set.
        self.seq = nn.Sequential(self.layer1, self.layer2, self.fc, self.q, self.v)

        self.seq.apply(init_weights)

    def forward(self, x):
        """Return Q-values of shape ``(batch, n_action)``.

        Args:
            x: Either a ``torch.Tensor`` (used as is, the caller is
                responsible for its device) or anything
                ``torch.as_tensor`` accepts, which is converted to float32
                and moved to ``self.device``.
        """
        if not isinstance(x, torch.Tensor):
            # Convert the dtype on the host first, then move.  This mirrors
            # the previous ``FloatTensor(x).to(device)`` order.
            x = torch.as_tensor(x, dtype=torch.float32).to(self.device)
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        x = x.view(-1, 20736)
        x = torch.relu(self.fc(x))
        adv = self.q(x)
        v = self.v(x)
        # Centre the advantages so that the value head carries the mean.
        q = v + (adv - adv.mean(-1, keepdim=True))

        return q
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def init_weights(m):
    """Initialise a ``Conv2d`` layer in place; leave any other module alone.

    The weight gets Xavier-uniform values and the bias, when the layer has
    one, is set to 0.01.  The one-argument signature fits
    ``nn.Module.apply``.

    Args:
        m: The module to (possibly) initialise.
    """
    if isinstance(m, nn.Conv2d):
        torch.nn.init.xavier_uniform_(m.weight)
        # ``Conv2d(..., bias=False)`` has ``bias is None``; filling it
        # unconditionally raised AttributeError.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.01)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _stack_states(states):
    """Stack per-transition state arrays into one batch array."""
    batch = np.array(states)
    # Stored states carry their own leading axis of length 1, which lands
    # on axis 1 after stacking.  Drop only that axis: a bare ``squeeze()``
    # would also collapse the batch axis when a single transition is
    # sampled.
    if batch.ndim > 1 and batch.shape[1] == 1:
        batch = batch.squeeze(1)
    return batch


def train(q, q_target, memory, batch_size, gamma, optimizer, device):
    """Run one double-DQN update of ``q`` on a sampled mini-batch.

    The online network ``q`` picks the greedy next action, and ``q_target``
    supplies that action's value for the bootstrap target.

    Args:
        q: Online network; called with a batch of states, returns Q-values.
        q_target: Target network with the same call signature as ``q``.
        memory: Object whose ``sample(batch_size)`` returns
            ``(s, r, a, s_prime, mask)`` tuples.  ``mask`` multiplies the
            bootstrap term, so it must be 0 for terminal transitions and 1
            otherwise.
        batch_size: Number of transitions to sample.
        gamma: Discount factor.
        optimizer: Optimizer that owns the parameters of ``q``.
        device: Device on which the reward, mask and action tensors are
            placed.

    Returns:
        The smooth-L1 loss of this step as a 0-dim tensor, detached from
        the autograd graph.
    """
    s, r, a, s_prime, done = map(list, zip(*memory.sample(batch_size)))
    s = _stack_states(s)
    s_prime = _stack_states(s_prime)

    r = torch.tensor(r, dtype=torch.float32).unsqueeze(-1).to(device)
    # Despite the name, this holds the continuation mask described above.
    done = torch.tensor(done, dtype=torch.float32).unsqueeze(-1).to(device)
    a = torch.tensor(a).view(-1, 1).long().to(device)

    with torch.no_grad():
        # Neither forward pass of the target computation needs gradients;
        # keeping the action selection inside this context avoids storing
        # activations for it.
        a_max = q(s_prime).max(1)[1].unsqueeze(-1)
        y = r + gamma * q_target(s_prime).gather(1, a_max) * done

    q_value = torch.gather(q(s), dim=1, index=a)

    # smooth_l1_loss already reduces with a mean.
    loss = F.smooth_l1_loss(q_value, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Detach so that callers who accumulate the loss do not keep this
    # step's graph alive.
    return loss.detach()
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def copy_weights(q, q_target):
    """Load the full state dict of ``q`` into ``q_target``.

    Both modules must expose the same state-dict keys, otherwise
    ``load_state_dict`` raises.
    """
    q_target.load_state_dict(q.state_dict())
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def main(env, q, q_target, optimizer, device):
    """Train ``q`` on ``env`` for up to 1,000,000 episodes.

    Every environment step stores one transition in a replay buffer.
    Once the buffer holds more than 2000 transitions, each step also runs
    one ``train`` update.  The function writes ``mario_q.pth`` and
    ``mario_q_target.pth`` whenever the target network is synchronised,
    and rewrites ``score.p`` after every report.

    Args:
        env: Environment with ``reset()``, a four-value ``step(action)``,
            ``action_space.sample()`` and ``unwrapped._stage``.
        q: Online network being optimised.
        q_target: Target network, overwritten from ``q`` periodically.
        optimizer: Optimizer that owns the parameters of ``q``.
        device: Device name, forwarded to ``train`` and printed in reports.
    """
    t = 0  # number of gradient updates performed so far
    gamma = 0.99
    batch_size = 256

    N = 50000
    eps = 0.001
    memory = replay_memory(N)
    update_interval = 50
    print_interval = 10

    score_lst = []
    total_score = 0.0
    loss = 0.0
    start_time = time.perf_counter()

    for k in range(1000000):
        s = arrange(env.reset())
        done = False

        while not done:
            if eps > np.random.rand():
                a = env.action_space.sample()
            else:
                # Acting never needs gradients.  ``.cpu()`` is a no-op for
                # CPU tensors, so one code path serves every device.
                with torch.no_grad():
                    a = np.argmax(q(s).cpu().numpy())
            s_prime, r, done, _ = env.step(a)
            s_prime = arrange(s_prime)
            total_score += r
            # Compress the reward magnitude before storing it; the score
            # above keeps the raw value.
            r = np.sign(r) * (np.sqrt(abs(r) + 1) - 1) + 0.001 * r
            # The last field is the continuation mask: 0 when the episode
            # ended on this step, 1 otherwise.
            memory.push((s, float(r), int(a), s_prime, int(1 - done)))
            s = s_prime
            stage = env.unwrapped._stage
            if len(memory) > 2000:
                # Detach so the running sum does not keep each step's
                # autograd graph reachable until the next report.
                loss += train(
                    q, q_target, memory, batch_size, gamma, optimizer, device
                ).detach()
                t += 1
            # NOTE(review): nesting was reconstructed from a flattened
            # listing.  As written, ``t`` stays 0 during warm-up, so this
            # sync and both saves run on every step until training starts.
            # Confirm this is intended.
            if t % update_interval == 0:
                copy_weights(q, q_target)
                torch.save(q.state_dict(), "mario_q.pth")
                torch.save(q_target.state_dict(), "mario_q_target.pth")

        if k % print_interval == 0:
            now = time.perf_counter()
            time_spent, start_time = now - start_time, now
            print(
                "%s |Epoch : %d | score : %f | loss : %.2f | stage : %d | time spent: %f"
                % (
                    device,
                    k,
                    total_score / print_interval,
                    loss / print_interval,
                    stage,
                    time_spent,
                )
            )
            score_lst.append(total_score / print_interval)
            total_score = 0
            loss = 0.0
            # Close the file deterministically instead of leaking the
            # handle on every report.
            with open("score.p", "wb") as score_file:
                pickle.dump(score_lst, score_file)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
if __name__ == "__main__":
    n_frame = 4

    # Build the environment inside-out: the raw game, then the
    # COMPLEX_MOVEMENT joypad mapping, then the project's own wrappers.
    env = wrap_mario(
        JoypadSpace(gym_super_mario_bros.make("SuperMarioBros-v0"), COMPLEX_MOVEMENT)
    )

    # Prefer CUDA, then MPS, and fall back to the CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    # Two independently initialised networks; only ``q`` is optimised.
    q = model(n_frame, env.action_space.n, device).to(device)
    q_target = model(n_frame, env.action_space.n, device).to(device)
    optimizer = optim.Adam(q.parameters(), lr=0.0001)
    print(device)

    main(env, q, q_target, optimizer, device)
|
flux_krea_00776_.png
ADDED
|
Git LFS Details
|
output.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c0f705a6effd35fbb11e8e5a337669c42ad244c62ba84a084f95f813a6fb981a
|
| 3 |
+
size 57430066
|
pyqt5_duel_dqn_super_mario_bros_tutorial/app.py
ADDED
|
@@ -0,0 +1,2294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
|
| 4 |
+
QHBoxLayout, QTextEdit, QPushButton, QTabWidget,
|
| 5 |
+
QLabel, QScrollArea, QFrame, QSplitter, QListWidget,
|
| 6 |
+
QListWidgetItem, QGroupBox, QProgressBar, QSpinBox,
|
| 7 |
+
QDoubleSpinBox, QCheckBox, QComboBox, QMessageBox)
|
| 8 |
+
from PyQt5.QtCore import Qt, QTimer, pyqtSignal
|
| 9 |
+
from PyQt5.QtGui import QFont, QPalette, QColor, QPixmap
|
| 10 |
+
import markdown
|
| 11 |
+
|
| 12 |
+
class CodeExerciseWidget(QWidget):
    """Panel for one coding exercise.

    Shows a heading and description, an editable code area pre-filled with
    a template, four action buttons and a read-only output log.  Emits
    ``exerciseCompleted(section_name, True)`` when the exercise is marked
    as completed.
    """

    exerciseCompleted = pyqtSignal(str, bool)  # section_name, completed

    def __init__(self, title, description, code_template, solution_code, section_name):
        super().__init__()
        self.title = title
        self.description = description
        self.code_template = code_template
        self.solution_code = solution_code
        self.section_name = section_name
        self.is_completed = False

        self.init_ui()

    def init_ui(self):
        """Stack heading, description, editor group and output group."""
        root = QVBoxLayout()

        # Heading
        root.addWidget(QLabel(f"<h2>{self.title}</h2>"))

        # Description, wrapped to the panel width
        description_label = QLabel(self.description)
        description_label.setWordWrap(True)
        root.addWidget(description_label)

        root.addWidget(self._build_editor_group())
        root.addWidget(self._build_output_group())

        self.setLayout(root)

    def _build_editor_group(self):
        """Create the code editor and its row of action buttons."""
        group = QGroupBox("Your Implementation")
        group_layout = QVBoxLayout()

        self.code_editor = QTextEdit()
        self.code_editor.setPlainText(self.code_template)
        self.code_editor.setFont(QFont("Courier", 10))
        group_layout.addWidget(self.code_editor)

        self.run_btn = QPushButton("Run Code")
        self.solution_btn = QPushButton("Show Solution")
        self.reset_btn = QPushButton("Reset")
        self.complete_btn = QPushButton("Mark as Completed")
        self.complete_btn.setStyleSheet("background-color: #28a745; color: white;")

        # Wire each button to its handler and lay them out left to right.
        button_row = QHBoxLayout()
        for button, handler in (
            (self.run_btn, self.run_code),
            (self.solution_btn, self.show_solution),
            (self.reset_btn, self.reset_code),
            (self.complete_btn, self.mark_completed),
        ):
            button.clicked.connect(handler)
            button_row.addWidget(button)
        group_layout.addLayout(button_row)

        group.setLayout(group_layout)
        return group

    def _build_output_group(self):
        """Create the read-only output log."""
        group = QGroupBox("Output")
        group_layout = QVBoxLayout()

        self.output_text = QTextEdit()
        self.output_text.setReadOnly(True)
        self.output_text.setFont(QFont("Courier", 9))
        group_layout.addWidget(self.output_text)

        group.setLayout(group_layout)
        return group

    def run_code(self):
        """Append placeholder messages; the editor content is not executed."""
        for message in (
            "Running your code...",
            "(In a full implementation, this would execute the Python code)",
        ):
            self.output_text.append(message)

    def show_solution(self):
        """Replace the editor content with the reference solution."""
        self.code_editor.setPlainText(self.solution_code)

    def reset_code(self):
        """Restore the template in the editor and clear the output log."""
        self.code_editor.setPlainText(self.code_template)
        self.output_text.clear()

    def mark_completed(self):
        """Flag the exercise as done, lock the button and notify listeners."""
        self.is_completed = True

        button = self.complete_btn
        button.setEnabled(False)
        button.setText("✓ Completed")
        button.setStyleSheet("background-color: #6c757d; color: white;")

        self.output_text.append("\n✓ Exercise marked as completed!")
        self.exerciseCompleted.emit(self.section_name, True)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class SectionCompletionWidget(QWidget):
    """One-row widget with a checkbox that records a section as completed."""

    def __init__(self, section_name):
        super().__init__()
        self.section_name = section_name
        self.is_completed = False

        self.init_ui()

    def init_ui(self):
        """Place the checkbox at the left and pad the rest of the row."""
        row = QHBoxLayout()

        self.checkbox = QCheckBox(f"Complete {self.section_name}")
        self.checkbox.stateChanged.connect(self.on_checkbox_changed)
        row.addWidget(self.checkbox)
        row.addStretch()

        self.setLayout(row)

    def on_checkbox_changed(self, state):
        """Mirror the checkbox state in ``is_completed``."""
        checked = state == Qt.Checked
        self.is_completed = checked
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class DuelingDQNTutorialApp(QMainWindow):
|
| 131 |
+
def __init__(self):
    """Set up the window, the completion bookkeeping and the widgets."""
    super().__init__()
    self.setWindowTitle("Dueling DQN for Super Mario Bros - Interactive Tutorial")
    self.setGeometry(100, 100, 1200, 800)

    # Completion flags, all initially unset.  Insertion order is the
    # order in which the sidebar lists the sections.
    self.section_completion = dict.fromkeys(
        (
            "Introduction",
            "Dueling DQN Theory",
            "Environment Setup",
            "Replay Memory",
            "Neural Network",
            "Training Algorithm",
            "Complete Implementation",
            "Exercises",
        ),
        False,
    )

    self.exercise_completion = dict.fromkeys(
        (
            "Replay Memory",
            "Dueling DQN",
            "Environment Wrapper",
            "Training Loop",
            "Reward Shaping",
            "Model Saving",
            "Hyperparameter Tuning",
            "Evaluation",
        ),
        False,
    )

    self.init_ui()
|
| 160 |
+
|
| 161 |
+
def init_ui(self):
    """Lay out the sidebar and the content area side by side, then style."""
    container = QWidget()
    self.setCentralWidget(container)
    columns = QHBoxLayout(container)

    # Navigation sidebar on the left, tabbed content on the right.
    columns.addWidget(self.create_sidebar())
    columns.addWidget(self.create_content_area())

    self.apply_styling()
|
| 177 |
+
|
| 178 |
+
def create_sidebar(self):
    """Build and return the fixed-width left panel.

    From top to bottom it holds a heading, the navigation list, one
    completion checkbox per entry of ``self.section_completion``, and the
    progress summary.
    """
    panel = QWidget()
    panel.setFixedWidth(300)
    column = QVBoxLayout(panel)

    # Heading
    heading = QLabel("Dueling DQN Tutorial")
    heading.setFont(QFont("Arial", 14, QFont.Bold))
    heading.setAlignment(Qt.AlignCenter)
    column.addWidget(heading)

    # Navigation list
    section_titles = [
        "Introduction",
        "Dueling DQN Theory",
        "Environment Setup",
        "Replay Memory",
        "Neural Network",
        "Training Algorithm",
        "Complete Implementation",
        "Exercises",
    ]
    self.nav_list = QListWidget()
    self.nav_list.addItems(section_titles)
    self.nav_list.currentRowChanged.connect(self.change_content)
    column.addWidget(self.nav_list)

    # One completion checkbox per tracked section
    completion_box = QGroupBox("Section Completion")
    completion_rows = QVBoxLayout()
    self.section_widgets = {}
    for section in self.section_completion:
        row = SectionCompletionWidget(section)
        row.checkbox.stateChanged.connect(self.update_progress)
        self.section_widgets[section] = row
        completion_rows.addWidget(row)
    completion_box.setLayout(completion_rows)
    column.addWidget(completion_box)

    # Progress summary
    progress_box = QGroupBox("Your Progress")
    progress_rows = QVBoxLayout()

    self.progress_bar = QProgressBar()
    self.progress_bar.setValue(0)
    progress_rows.addWidget(QLabel("Overall Progress:"))
    progress_rows.addWidget(self.progress_bar)

    self.exercise_count = QLabel("Exercises: 0/8 completed")
    progress_rows.addWidget(self.exercise_count)

    self.section_count = QLabel("Sections: 0/8 completed")
    progress_rows.addWidget(self.section_count)

    progress_box.setLayout(progress_rows)
    column.addWidget(progress_box)

    # Push everything above towards the top of the panel.
    column.addStretch()

    return panel
|
| 239 |
+
|
| 240 |
+
def create_content_area(self):
    """Build the tab widget that hosts one page per tutorial section.

    Pages are created and appended in the order listed in ``pages``. The tab
    widget is kept on ``self.content_tabs`` and is also the return value.
    """
    tabs = QTabWidget()
    self.content_tabs = tabs

    # (page factory, tab title) pairs, in display order.
    pages = (
        (self.create_intro_tab, "Introduction"),
        (self.create_theory_tab, "Dueling DQN Theory"),
        (self.create_environment_tab, "Environment Setup"),
        (self.create_memory_tab, "Replay Memory"),
        (self.create_nn_tab, "Neural Network"),
        (self.create_training_tab, "Training Algorithm"),
        (self.create_implementation_tab, "Complete Implementation"),
        (self.create_exercises_tab, "Exercises"),
    )
    for build_page, title in pages:
        tabs.addTab(build_page(), title)

    return tabs
|
| 277 |
+
|
| 278 |
+
def create_intro_tab(self):
    """Build the "Introduction" page.

    From top to bottom the page holds a centered title, a framed illustration
    (or a text placeholder when no image file can be loaded) and a read-only
    rich-text welcome message.

    Returns:
        QWidget: container holding the assembled page.
    """
    widget = QWidget()
    layout = QVBoxLayout(widget)

    # Title
    title_label = QLabel("<h1 style='text-align: center; color: #2c3e50;'>Dueling DQN for Super Mario Bros</h1>")
    title_label.setAlignment(Qt.AlignCenter)
    layout.addWidget(title_label)

    # Framed container for the illustration (or its placeholder)
    image_frame = QFrame()
    image_frame.setFrameStyle(QFrame.Box)
    image_frame.setLineWidth(1)
    image_frame.setMidLineWidth(0)
    image_frame.setStyleSheet("QFrame { border: 1px solid #ddd; border-radius: 8px; background-color: #fafafa; }")
    image_layout = QVBoxLayout(image_frame)

    # Candidate image locations, tried in order. The first four are relative
    # to the current working directory; the last two are relative to this
    # module so the image is still found when the app is started from
    # another directory.
    module_dir = os.path.dirname(os.path.abspath(__file__))
    possible_paths = [
        "flux_krea_00776_.png",
        "./flux_krea_00776_.png",
        "images/flux_krea_00776_.png",
        "../flux_krea_00776_.png",
        os.path.join(module_dir, "flux_krea_00776_.png"),
        os.path.join(module_dir, "images", "flux_krea_00776_.png"),
    ]

    # Keep the first candidate that exists and decodes to a valid pixmap.
    pixmap = None
    for path in possible_paths:
        if not os.path.exists(path):
            continue
        candidate = QPixmap(path)
        if not candidate.isNull():
            pixmap = candidate
            break

    if pixmap is not None:
        image_label = QLabel()
        # Scale image to reasonable size while maintaining aspect ratio
        scaled_pixmap = pixmap.scaled(550, 350, Qt.KeepAspectRatio, Qt.SmoothTransformation)
        image_label.setPixmap(scaled_pixmap)
        image_label.setAlignment(Qt.AlignCenter)
        image_label.setStyleSheet("padding: 10px;")
        image_layout.addWidget(image_label)

        # Image caption
        caption = QLabel("<small><em>Dueling DQN Architecture Visualization generated by flux.1-krea.dev</em></small>")
        caption.setAlignment(Qt.AlignCenter)
        caption.setStyleSheet("color: #666; font-style: italic; padding: 5px;")
        image_layout.addWidget(caption)
    else:
        # Fallback placeholder. The line break is written as <br> rather than
        # "\n": Qt's rich-text auto-detection only scans up to the first
        # newline, so a newline ahead of the first tag makes the markup show
        # up as literal text.
        placeholder = QLabel("🎨 Architecture Visualization Image<br><small><em>flux_krea_00776_.png</em></small>")
        placeholder.setTextFormat(Qt.RichText)
        placeholder.setAlignment(Qt.AlignCenter)
        placeholder.setStyleSheet("color: #999; font-style: italic; padding: 40px; background-color: #f0f0f0; border-radius: 4px;")
        image_layout.addWidget(placeholder)

    layout.addWidget(image_frame)

    # Welcome section
    welcome_text = """
    <h2 style='color: #34495e;'>Welcome to the Interactive Tutorial!</h2>

    <p style='line-height: 1.6;'>This application will guide you through implementing a <strong>Dueling Deep Q-Network (DQN)</strong>
    to play Super Mario Bros. You'll learn both the theory behind Dueling DQN and get
    hands-on experience building each component through interactive exercises.</p>

    <div style='background-color: #e8f4fd; padding: 15px; border-radius: 5px; border-left: 4px solid #3498db;'>
    <strong>🚀 Interactive Learning:</strong> This tutorial combines theoretical explanations with
    hands-on coding exercises. Complete each section and exercise to track your progress!
    </div>

    <h3 style='color: #2c3e50; margin-top: 20px;'>What You'll Learn:</h3>
    <ul style='line-height: 1.6;'>
        <li><strong>Mathematical Foundation:</strong> Understand the dueling architecture that separates value and advantage streams</li>
        <li><strong>Environment Setup:</strong> Configure the Super Mario Bros environment with proper preprocessing</li>
        <li><strong>Experience Replay:</strong> Implement memory buffer for stable training</li>
        <li><strong>Neural Network:</strong> Build the dueling DQN with PyTorch</li>
        <li><strong>Training Algorithm:</strong> Master the DQN training loop with target networks</li>
        <li><strong>Complete Implementation:</strong> Combine all components into a working AI agent</li>
    </ul>

    <h3 style='color: #2c3e50;'>Prerequisites:</h3>
    <ul style='line-height: 1.6;'>
        <li>Basic Python programming knowledge</li>
        <li>Familiarity with PyTorch (helpful but not required)</li>
        <li>Understanding of basic reinforcement learning concepts</li>
        <li>Curiosity to build an AI that can play Super Mario Bros! 🎮</li>
    </ul>

    <div style='background-color: #fff3cd; padding: 15px; border-radius: 5px; border-left: 4px solid #ffc107; margin-top: 20px;'>
    <strong>💡 Pro Tip:</strong> Use the navigation panel on the left to move through sections.
    Mark sections as completed and work through exercises to build your understanding step by step.
    Each section builds upon the previous one!
    </div>
    """

    text_edit = QTextEdit()
    text_edit.setHtml(welcome_text)
    text_edit.setReadOnly(True)
    text_edit.setStyleSheet("""
        QTextEdit {
            background-color: white;
            border: 1px solid #ddd;
            border-radius: 5px;
            padding: 10px;
            line-height: 1.6;
        }
    """)
    layout.addWidget(text_edit)

    return widget
|
| 388 |
+
|
| 389 |
+
def create_theory_tab(self):
    """Build the "Dueling DQN Theory" page.

    The page is a single read-only rich-text view showing static HTML that
    explains the value/advantage decomposition and compares it with standard
    DQN and PPO.

    Returns:
        QWidget: container holding the text view.
    """
    page = QWidget()
    page_layout = QVBoxLayout(page)

    html = """
    <h1>Dueling DQN Theory</h1>

    <h2>What is Dueling DQN?</h2>

    <p>Dueling DQN is an improvement over the standard Deep Q-Network that separates
    the estimation of state value and action advantages. This architecture leads to
    more stable and efficient learning.</p>

    <h2>Mathematical Foundation</h2>

    <p>In standard DQN, we estimate Q-values directly:</p>
    <p style="text-align: center;"><code>Q(s, a) = f(s, a)</code></p>

    <p>In Dueling DQN, we decompose the Q-value into two streams:</p>
    <p style="text-align: center;"><code>Q(s, a) = V(s) + A(s, a)</code></p>

    <p>Where:</p>
    <ul>
        <li><strong>V(s)</strong>: Value function - how good it is to be in state s</li>
        <li><strong>A(s, a)</strong>: Advantage function - how much better action a is compared to other actions in state s</li>
    </ul>

    <p>To ensure identifiability, we use the following aggregation:</p>
    <p style="text-align: center;"><code>Q(s, a) = V(s) + (A(s, a) - mean(A(s, ·)))</code></p>

    <h2>Benefits of Dueling Architecture</h2>

    <h3>Compared to Standard DQN:</h3>
    <ul>
        <li>Better generalization across actions</li>
        <li>More stable learning</li>
        <li>Faster convergence in many environments</li>
        <li>Ability to learn which states are valuable without having to learn the effect of each action</li>
    </ul>

    <h3>Compared to PPO (Proximal Policy Optimization):</h3>

    <table border="1" style="width:100%">
        <tr>
            <th>Dueling DQN</th>
            <th>PPO</th>
        </tr>
        <tr>
            <td>Value-based method</td>
            <td>Policy-based method</td>
        </tr>
        <tr>
            <td>Off-policy learning</td>
            <td>On-policy learning</td>
        </tr>
        <tr>
            <td>Discrete action spaces</td>
            <td>Continuous or discrete action spaces</td>
        </tr>
        <tr>
            <td>Generally more sample-efficient</td>
            <td>Generally more stable</td>
        </tr>
        <tr>
            <td>Easier to implement and debug</td>
            <td>More hyperparameters to tune</td>
        </tr>
    </table>

    <p>For Super Mario Bros, which has discrete actions, Dueling DQN strikes a good
    balance between sample efficiency and implementation complexity.</p>

    <h2>Architecture Diagram</h2>

    <p style="text-align: center;">
    [Convolutional Layers] → [Feature Vector] → [Value Stream] + [Advantage Stream] → [Q-values]
    </p>

    <p>The convolutional layers process the game frames, then the network splits into
    two streams that estimate V(s) and A(s, a) separately, which are then combined
    to produce the final Q-values.</p>

    <div style="background-color: #f8f9fa; padding: 15px; border-radius: 5px; margin-top: 20px;">
    <strong>Key Insight:</strong> The dueling architecture allows the network to learn
    which states are inherently valuable without having to learn the value of each action
    in those states. This is particularly useful in games like Mario where many actions
    lead to similar outcomes in safe states.
    </div>
    """

    # Static content: render the HTML and lock the view against edits.
    viewer = QTextEdit()
    viewer.setHtml(html)
    viewer.setReadOnly(True)
    page_layout.addWidget(viewer)

    return page
|
| 485 |
+
|
| 486 |
+
def create_environment_tab(self):
    """Build the "Environment Setup" page.

    The page is a single read-only rich-text view showing static HTML: the
    packages used, the wrappers applied, the ``arrange`` preprocessing
    helper, the action list and the install commands.

    Returns:
        QWidget: container holding the text view.
    """
    widget = QWidget()
    layout = QVBoxLayout(widget)

    # NOTE: the displayed ``arrange`` snippet uses isinstance() so that it
    # matches the "if needed" explanation below it; comparing type(s) to the
    # string "numpy.ndarray" is never true and would always convert.
    content = """
    <h1>Environment Setup</h1>

    <h2>Setting up Super Mario Bros</h2>

    <p>We use the <code>gym-super-mario-bros</code> package to create our environment.
    This provides a standardized interface to interact with the game.</p>

    <h3>Key Components:</h3>

    <pre><code>
    import gym_super_mario_bros
    from nes_py.wrappers import JoypadSpace
    from gym_super_mario_bros.actions import COMPLEX_MOVEMENT
    from wrappers import *
    </code></pre>

    <h3>Environment Wrappers</h3>

    <p>We apply several wrappers to preprocess the environment:</p>

    <ul>
        <li><strong>JoypadSpace</strong>: Limits actions to a predefined set (COMPLEX_MOVEMENT)</li>
        <li><strong>wrap_mario</strong>: Custom wrapper that applies:
            <ul>
                <li>Frame skipping</li>
                <li>Frame stacking</li>
                <li>Reward scaling</li>
                <li>Observation preprocessing</li>
            </ul>
        </li>
    </ul>

    <h3>Observation Preprocessing</h3>

    <p>The <code>arrange</code> function transforms the observation:</p>

    <pre><code>
    def arrange(s):
        if not isinstance(s, np.ndarray):
            s = np.array(s)
        assert len(s.shape) == 3
        ret = np.transpose(s, (2, 0, 1))
        return np.expand_dims(ret, 0)
    </code></pre>

    <p>This function:</p>
    <ol>
        <li>Converts the observation to a numpy array if needed</li>
        <li>Transposes dimensions from (H, W, C) to (C, H, W)</li>
        <li>Adds a batch dimension</li>
    </ol>

    <h3>Action Space</h3>

    <p>COMPLEX_MOVEMENT provides 12 possible actions that combine button presses:</p>

    <ul>
        <li>NOOP - No operation</li>
        <li>RIGHT - Move right</li>
        <li>RIGHT+A - Right and jump</li>
        <li>RIGHT+B - Right and run</li>
        <li>RIGHT+A+B - Right, jump and run</li>
        <li>A - Jump</li>
        <li>LEFT - Move left</li>
        <li>LEFT+A - Left and jump</li>
        <li>LEFT+B - Left and run</li>
        <li>LEFT+A+B - Left, jump and run</li>
        <li>DOWN - Duck</li>
        <li>UP - Look up</li>
    </ul>

    <h3>Installation Requirements</h3>

    <pre><code>
    pip install gym-super-mario-bros
    pip install nes_py
    pip install opencv-python
    pip install torch
    </code></pre>

    <div style="background-color: #f8f9fa; padding: 15px; border-radius: 5px; margin-top: 20px;">
    <strong>Note:</strong> The environment returns RGB frames of shape (240, 256, 3).
    We preprocess these to (4, 84, 84) by stacking 4 grayscale frames and resizing.
    This significantly reduces the input dimensionality while preserving temporal information.
    </div>
    """

    # Static content: render the HTML and lock the view against edits.
    text_edit = QTextEdit()
    text_edit.setHtml(content)
    text_edit.setReadOnly(True)
    layout.addWidget(text_edit)

    return widget
|
| 584 |
+
|
| 585 |
+
def create_memory_tab(self):
    """Build the "Replay Memory" page.

    The page is a single read-only rich-text view showing static HTML that
    presents the deque-based replay buffer, why experience replay is used,
    and the layout of a stored transition.

    Returns:
        QWidget: container holding the text view.
    """
    widget = QWidget()
    layout = QVBoxLayout(widget)

    # The transition layout described here follows the order used by the
    # training code shown on the other pages:
    # (state, reward, action, next_state, 1 - done).
    content = """
    <h1>Replay Memory</h1>

    <h2>Experience Replay</h2>

    <p>Experience replay is a key component of DQN that helps stabilize training
    by breaking correlations in sequential observations.</p>

    <h3>Implementation</h3>

    <pre><code>
    class replay_memory(object):
        def __init__(self, N):
            self.memory = deque(maxlen=N)

        def push(self, transition):
            self.memory.append(transition)

        def sample(self, n):
            return random.sample(self.memory, n)

        def __len__(self):
            return len(self.memory)
    </code></pre>

    <h3>Key Features:</h3>

    <ul>
        <li><strong>Fixed Size</strong>: Uses deque with maxlen to maintain a fixed memory size</li>
        <li><strong>Random Sampling</strong>: Samples random batches to break temporal correlations</li>
        <li><strong>Transition Storage</strong>: Stores (state, reward, action, next_state, not_done) tuples</li>
    </ul>

    <h3>Why Use Experience Replay?</h3>

    <ol>
        <li><strong>Data Efficiency</strong>: Each experience can be used multiple times</li>
        <li><strong>Reduced Correlation</strong>: Random sampling breaks correlation between consecutive samples</li>
        <li><strong>Smoother Learning</strong>: Averages behavior distribution over many previous states</li>
        <li><strong>Stable Training</strong>: Behaves like training on a stationary distribution</li>
    </ol>

    <h3>Transition Format</h3>

    <p>Each transition stored in memory contains:</p>

    <ul>
        <li><strong>s</strong>: Current state (preprocessed frame)</li>
        <li><strong>r</strong>: Reward received (transformed)</li>
        <li><strong>a</strong>: Action taken</li>
        <li><strong>s'</strong>: Next state</li>
        <li><strong>done</strong>: Not-done mask, stored as 1 - done (0 for terminal states, 1 otherwise)</li>
    </ul>

    <h3>Memory Management</h3>

    <p>The replay memory has a fixed capacity (N=50000 in our implementation).
    When the memory is full, older experiences are automatically removed as new ones are added.
    This ensures the agent always trains on recent experiences while maintaining diversity.</p>

    <div style="background-color: #f8f9fa; padding: 15px; border-radius: 5px; margin-top: 20px;">
    <strong>Best Practice:</strong> Wait until the replay memory has sufficient samples
    (typically 1000-5000) before starting training to ensure diverse batches and
    avoid overfitting to early experiences.
    </div>
    """

    # Static content: render the HTML and lock the view against edits.
    text_edit = QTextEdit()
    text_edit.setHtml(content)
    text_edit.setReadOnly(True)
    layout.addWidget(text_edit)

    return widget
|
| 662 |
+
|
| 663 |
+
def create_nn_tab(self):
    """Build the "Neural Network" page.

    The page is a single read-only rich-text view showing static HTML: the
    dueling model's source, a per-layer summary, the value/advantage
    combination formula and the weight initialisation helper.

    Returns:
        QWidget: container holding the text view.
    """
    page = QWidget()
    page_layout = QVBoxLayout(page)

    html = """
    <h1>Neural Network Architecture</h1>

    <h2>Dueling DQN Model</h2>

    <p>The neural network implements the dueling architecture with separate streams
    for value and advantage estimation.</p>

    <h3>Network Structure</h3>

    <pre><code>
    class model(nn.Module):
        def __init__(self, n_frame, n_action, device):
            super(model, self).__init__()
            self.layer1 = nn.Conv2d(n_frame, 32, 8, 4)
            self.layer2 = nn.Conv2d(32, 64, 3, 1)
            self.fc = nn.Linear(20736, 512)
            self.q = nn.Linear(512, n_action) # Advantage stream
            self.v = nn.Linear(512, 1) # Value stream

            self.device = device

        def forward(self, x):
            if type(x) != torch.Tensor:
                x = torch.FloatTensor(x).to(self.device)
            x = torch.relu(self.layer1(x))
            x = torch.relu(self.layer2(x))
            x = x.view(-1, 20736)
            x = torch.relu(self.fc(x))
            adv = self.q(x) # Advantage stream
            v = self.v(x) # Value stream
            # Combine using dueling formula
            q = v + (adv - 1 / adv.shape[-1] * adv.sum(-1, keepdim=True))
            return q
    </code></pre>

    <h3>Layer Details</h3>

    <ul>
        <li><strong>Conv2d(4, 32, 8, 4)</strong>: 32 filters, kernel size 8, stride 4</li>
        <li><strong>Conv2d(32, 64, 3, 1)</strong>: 64 filters, kernel size 3, stride 1</li>
        <li><strong>Linear(20736, 512)</strong>: Fully connected layer with 512 units</li>
        <li><strong>Value Head</strong>: Single output estimating V(s)</li>
        <li><strong>Advantage Head</strong>: n_action outputs estimating A(s, a)</li>
    </ul>

    <h3>Dueling Combination</h3>

    <p>The key innovation is how we combine value and advantage:</p>

    <pre><code>
    q = v + (adv - 1 / adv.shape[-1] * adv.sum(-1, keepdim=True))
    </code></pre>

    <p>This ensures that:</p>
    <ul>
        <li>The advantage function has zero mean for each state</li>
        <li>We can recover both V(s) and A(s, a) from Q(s, a)</li>
        <li>The network learns which states are valuable without having to learn
        the effect of each action in every state</li>
        <li>The value stream focuses on state quality, advantage stream on action quality</li>
    </ul>

    <h3>Weight Initialization</h3>

    <pre><code>
    def init_weights(m):
        if type(m) == nn.Conv2d:
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.01)
    </code></pre>

    <p>We use Xavier initialization for convolutional layers to maintain stable gradients
    throughout the network. This helps with convergence during training.</p>

    <h3>Feature Extraction</h3>

    <p>The convolutional layers extract spatial features from the input frames:</p>
    <ul>
        <li><strong>First conv layer</strong>: Detects basic features like edges and colors</li>
        <li><strong>Second conv layer</strong>: Combines basic features into more complex patterns</li>
        <li><strong>Flattening</strong>: Converts spatial features to a 1D vector for the fully connected layers</li>
    </ul>

    <div style="background-color: #f8f9fa; padding: 15px; border-radius: 5px; margin-top: 20px;">
    <strong>Architecture Choice:</strong> The 20736 value comes from the flattened output
    of the second convolutional layer. This is calculated based on the input dimensions
    and the network architecture. For a (4, 84, 84) input, the convolutions produce
    feature maps that when flattened give 20736 elements.
    </div>
    """

    # Static content: render the HTML and lock the view against edits.
    viewer = QTextEdit()
    viewer.setHtml(html)
    viewer.setReadOnly(True)
    page_layout.addWidget(viewer)

    return page
|
| 765 |
+
|
| 766 |
+
def create_training_tab(self):
    """Build the "Training Algorithm" page.

    The page is a single read-only rich-text view showing static HTML: the
    ``train`` function, the loss and target computations, the Double DQN
    action selection and the hyperparameters used.

    Returns:
        QWidget: container holding the text view.
    """
    widget = QWidget()
    layout = QVBoxLayout(widget)

    # The loss bullet under "Key Components" describes what smooth_l1_loss
    # does; the displayed train() function performs no gradient clipping.
    content = """
    <h1>Training Algorithm</h1>

    <h2>DQN Training Loop</h2>

    <p>The training process follows the standard DQN algorithm with the addition
    of the dueling architecture.</p>

    <h3>Key Components</h3>

    <ul>
        <li><strong>Target Network</strong>: Separate network for stable Q-targets</li>
        <li><strong>Experience Replay</strong>: Stores and samples past experiences</li>
        <li><strong>Epsilon-Greedy Exploration</strong>: Balances exploration and exploitation</li>
        <li><strong>Periodic Updates</strong>: Syncs target network with main network</li>
        <li><strong>Robust Loss</strong>: smooth_l1_loss (Huber loss) caps the gradient of large TD errors; no explicit gradient clipping is applied</li>
    </ul>

    <h3>Training Function</h3>

    <pre><code>
    def train(q, q_target, memory, batch_size, gamma, optimizer, device):
        s, r, a, s_prime, done = list(map(list, zip(*memory.sample(batch_size))))
        s = np.array(s).squeeze()
        s_prime = np.array(s_prime).squeeze()
        a_max = q(s_prime).max(1)[1].unsqueeze(-1)
        r = torch.FloatTensor(r).unsqueeze(-1).to(device)
        done = torch.FloatTensor(done).unsqueeze(-1).to(device)
        with torch.no_grad():
            y = r + gamma * q_target(s_prime).gather(1, a_max) * done
        a = torch.tensor(a).unsqueeze(-1).to(device)
        q_value = torch.gather(q(s), dim=1, index=a.view(-1, 1).long())

        loss = F.smooth_l1_loss(q_value, y).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss
    </code></pre>

    <h3>Loss Calculation</h3>

    <p>We use Smooth L1 Loss (Huber loss) which is more robust to outliers than MSE:</p>

    <pre><code>
    loss = F.smooth_l1_loss(q_value, y).mean()
    </code></pre>

    <p>Smooth L1 Loss behaves like L1 loss when the error is large (reducing sensitivity to outliers)
    and like L2 loss when the error is small (providing smooth gradients).</p>

    <h3>Target Calculation</h3>

    <p>The target Q-values are calculated using the target network:</p>

    <pre><code>
    y = r + gamma * q_target(s_prime).gather(1, a_max) * done
    </code></pre>

    <p>Where:</p>
    <ul>
        <li><strong>r</strong>: Immediate reward</li>
        <li><strong>gamma</strong>: Discount factor (0.99)</li>
        <li><strong>q_target(s_prime)</strong>: Q-values from target network</li>
        <li><strong>a_max</strong>: Best action according to online network (Double DQN)</li>
        <li><strong>done</strong>: Terminal state indicator (0 for terminal states)</li>
    </ul>

    <h3>Double DQN</h3>

    <p>Note that we use the Double DQN approach: the online network selects the action,
    but the target network evaluates it. This reduces overestimation bias.</p>

    <h3>Main Training Loop</h3>

    <p>The main function handles:</p>

    <ol>
        <li>Environment interaction and experience collection</li>
        <li>Epsilon-greedy action selection</li>
        <li>Reward transformation and experience storage</li>
        <li>Network training with experience replay</li>
        <li>Target network updates at fixed intervals</li>
        <li>Model checkpointing and progress logging</li>
    </ol>

    <h3>Hyperparameters</h3>

    <ul>
        <li><strong>Gamma</strong>: 0.99 (discount factor for future rewards)</li>
        <li><strong>Batch Size</strong>: 256 (experiences per training step)</li>
        <li><strong>Memory Size</strong>: 50,000 (maximum replay buffer size)</li>
        <li><strong>Learning Rate</strong>: 0.0001 (Adam optimizer)</li>
        <li><strong>Update Interval</strong>: 50 steps (target network sync frequency)</li>
        <li><strong>Epsilon</strong>: 0.001 (fixed exploration rate during training)</li>
        <li><strong>Print Interval</strong>: 10 episodes (progress reporting frequency)</li>
    </ul>

    <div style="background-color: #f8f9fa; padding: 15px; border-radius: 5px; margin-top: 20px;">
    <strong>Training Strategy:</strong> The agent starts training only after collecting
    2000 experiences to ensure diverse training data. The fixed epsilon of 0.001 means
    the agent explores 0.1% of the time, exploiting its learned policy 99.9% of the time.
    </div>
    """

    # Static content: render the HTML and lock the view against edits.
    text_edit = QTextEdit()
    text_edit.setHtml(content)
    text_edit.setReadOnly(True)
    layout.addWidget(text_edit)

    return widget
|
| 881 |
+
|
| 882 |
+
def create_implementation_tab(self):
    """Build the "Complete Implementation" page.

    The page is a single read-only rich-text view showing static HTML: the
    full ``main`` training loop, the reward transformation, device
    selection, model construction, run instructions, sample output and the
    checkpoint files written.

    Returns:
        QWidget: container holding the text view.
    """
    widget = QWidget()
    layout = QVBoxLayout(widget)

    # The reward transform shown below (signed square root plus a small
    # linear term) is unbounded, so the text says it "damps" extreme
    # rewards, not that it bounds them.
    content = """
    <h1>Complete Implementation</h1>

    <h2>Putting It All Together</h2>

    <p>Now let's look at the complete implementation that brings all components together.</p>

    <h3>Main Function</h3>

    <pre><code>
    def main(env, q, q_target, optimizer, device):
        t = 0
        gamma = 0.99
        batch_size = 256

        N = 50000
        eps = 0.001
        memory = replay_memory(N)
        update_interval = 50
        print_interval = 10

        score_lst = []
        total_score = 0.0
        loss = 0.0
        start_time = time.perf_counter()

        for k in range(1000000):
            s = arrange(env.reset())
            done = False

            while not done:
                if eps > np.random.rand():
                    a = env.action_space.sample()
                else:
                    if device == "cpu":
                        a = np.argmax(q(s).detach().numpy())
                    else:
                        a = np.argmax(q(s).cpu().detach().numpy())
                s_prime, r, done, _ = env.step(a)
                s_prime = arrange(s_prime)
                total_score += r
                r = np.sign(r) * (np.sqrt(abs(r) + 1) - 1) + 0.001 * r
                memory.push((s, float(r), int(a), s_prime, int(1 - done)))
                s = s_prime
                stage = env.unwrapped._stage
                if len(memory) > 2000:
                    loss += train(q, q_target, memory, batch_size, gamma, optimizer, device)
                    t += 1
                if t % update_interval == 0:
                    copy_weights(q, q_target)
                    torch.save(q.state_dict(), "mario_q.pth")
                    torch.save(q_target.state_dict(), "mario_q_target.pth")

            if k % print_interval == 0:
                time_spent, start_time = (
                    time.perf_counter() - start_time,
                    time.perf_counter(),
                )
                print(
                    "%s |Epoch : %d | score : %f | loss : %.2f | stage : %d | time spent: %f"
                    % (
                        device,
                        k,
                        total_score / print_interval,
                        loss / print_interval,
                        stage,
                        time_spent,
                    )
                )
                score_lst.append(total_score / print_interval)
                total_score = 0
                loss = 0.0
                pickle.dump(score_lst, open("score.p", "wb"))
    </code></pre>

    <h3>Reward Shaping</h3>

    <p>We apply a transformation to the rewards to improve learning:</p>

    <pre><code>
    r = np.sign(r) * (np.sqrt(abs(r) + 1) - 1) + 0.001 * r
    </code></pre>

    <p>This transformation:</p>
    <ul>
        <li>Compresses large rewards while preserving their sign using square root</li>
        <li>Adds a small linear component (0.001 * r) to maintain reward differences</li>
        <li>Helps with reward scaling issues in environments with varying reward magnitudes</li>
        <li>Makes the learning process more stable by damping extreme rewards</li>
    </ul>

    <h3>Device Handling</h3>

    <p>The code automatically detects and uses available hardware:</p>

    <pre><code>
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    </code></pre>

    <p>This ensures optimal performance across different systems while maintaining compatibility.</p>

    <h3>Model Initialization</h3>

    <pre><code>
    q = model(n_frame, env.action_space.n, device).to(device)
    q_target = model(n_frame, env.action_space.n, device).to(device)
    optimizer = optim.Adam(q.parameters(), lr=0.0001)
    </code></pre>

    <h3>Running the Code</h3>

    <p>To run the complete implementation:</p>

    <ol>
        <li>Install dependencies: <code>pip install gym-super-mario-bros nes_py opencv-python torch</code></li>
        <li>Create the <code>wrappers.py</code> file with environment wrappers</li>
        <li>Run the script: <code>python duel_dqn.py</code></li>
    </ol>

    <h3>Expected Output</h3>

    <p>During training, you should see output like:</p>

    <pre><code>
    cuda |Epoch : 0 | score : 125.400000 | loss : 0.00 | stage : 1 | time spent: 12.345678
    cuda |Epoch : 10 | score : 256.800000 | loss : 1.23 | stage : 1 | time spent: 15.678901
    cuda |Epoch : 20 | score : 512.100000 | loss : 0.87 | stage : 2 | time spent: 18.901234
    </code></pre>

    <h3>Model Checkpoints</h3>

    <p>The training process automatically saves model checkpoints:</p>
    <ul>
        <li><code>mario_q.pth</code>: Main Q-network weights</li>
        <li><code>mario_q_target.pth</code>: Target network weights</li>
        <li><code>score.p</code>: Training scores for analysis</li>
    </ul>

    <div style="background-color: #f8f9fa; padding: 15px; border-radius: 5px; margin-top: 20px;">
    <strong>Training Time:</strong> Training will take many hours or days to produce a competent agent.
    You can monitor progress through the printed statistics and saved model checkpoints.
    For faster results, consider using a pre-trained model or reducing the environment complexity.
    </div>
    """

    # Static content: render the HTML and lock the view against edits.
    text_edit = QTextEdit()
    text_edit.setHtml(content)
    text_edit.setReadOnly(True)
    layout.addWidget(text_edit)

    return widget
|
| 1041 |
+
|
| 1042 |
+
def create_exercises_tab(self):
|
| 1043 |
+
widget = QWidget()
|
| 1044 |
+
layout = QVBoxLayout(widget)
|
| 1045 |
+
|
| 1046 |
+
# Create a tab widget for different exercises
|
| 1047 |
+
exercise_tabs = QTabWidget()
|
| 1048 |
+
|
| 1049 |
+
# Exercise 1: Replay Memory
|
| 1050 |
+
exercise1 = CodeExerciseWidget(
|
| 1051 |
+
"Exercise 1: Implement Replay Memory",
|
| 1052 |
+
"Create a replay memory class that stores transitions and can sample random batches.",
|
| 1053 |
+
"""import random
|
| 1054 |
+
from collections import deque
|
| 1055 |
+
|
| 1056 |
+
class ReplayMemory:
|
| 1057 |
+
def __init__(self, capacity):
|
| 1058 |
+
# TODO: Initialize memory with maximum capacity
|
| 1059 |
+
pass
|
| 1060 |
+
|
| 1061 |
+
def push(self, transition):
|
| 1062 |
+
# TODO: Add a transition to memory
|
| 1063 |
+
pass
|
| 1064 |
+
|
| 1065 |
+
def sample(self, batch_size):
|
| 1066 |
+
# TODO: Return a random sample of batch_size transitions
|
| 1067 |
+
pass
|
| 1068 |
+
|
| 1069 |
+
def __len__(self):
|
| 1070 |
+
# TODO: Return current size of memory
|
| 1071 |
+
pass
|
| 1072 |
+
""",
|
| 1073 |
+
"""import random
|
| 1074 |
+
from collections import deque
|
| 1075 |
+
|
| 1076 |
+
class ReplayMemory:
|
| 1077 |
+
def __init__(self, capacity):
|
| 1078 |
+
self.memory = deque(maxlen=capacity)
|
| 1079 |
+
|
| 1080 |
+
def push(self, transition):
|
| 1081 |
+
self.memory.append(transition)
|
| 1082 |
+
|
| 1083 |
+
def sample(self, batch_size):
|
| 1084 |
+
return random.sample(self.memory, batch_size)
|
| 1085 |
+
|
| 1086 |
+
def __len__(self):
|
| 1087 |
+
return len(self.memory)
|
| 1088 |
+
""",
|
| 1089 |
+
"Replay Memory"
|
| 1090 |
+
)
|
| 1091 |
+
exercise1.exerciseCompleted.connect(self.on_exercise_completed)
|
| 1092 |
+
exercise_tabs.addTab(exercise1, "Replay Memory")
|
| 1093 |
+
|
| 1094 |
+
# Exercise 2: Dueling DQN Model
|
| 1095 |
+
exercise2 = CodeExerciseWidget(
|
| 1096 |
+
"Exercise 2: Implement Dueling DQN Model",
|
| 1097 |
+
"Create the neural network with separate value and advantage streams.",
|
| 1098 |
+
"""import torch
|
| 1099 |
+
import torch.nn as nn
|
| 1100 |
+
import torch.nn.functional as F
|
| 1101 |
+
|
| 1102 |
+
class DuelingDQN(nn.Module):
|
| 1103 |
+
def __init__(self, input_shape, num_actions):
|
| 1104 |
+
super(DuelingDQN, self).__init__()
|
| 1105 |
+
# TODO: Define convolutional layers
|
| 1106 |
+
# TODO: Define value stream
|
| 1107 |
+
# TODO: Define advantage stream
|
| 1108 |
+
|
| 1109 |
+
def forward(self, x):
|
| 1110 |
+
# TODO: Implement forward pass with dueling architecture
|
| 1111 |
+
pass
|
| 1112 |
+
""",
|
| 1113 |
+
"""import torch
|
| 1114 |
+
import torch.nn as nn
|
| 1115 |
+
import torch.nn.functional as F
|
| 1116 |
+
|
| 1117 |
+
class DuelingDQN(nn.Module):
|
| 1118 |
+
def __init__(self, input_shape, num_actions):
|
| 1119 |
+
super(DuelingDQN, self).__init__()
|
| 1120 |
+
self.conv1 = nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4)
|
| 1121 |
+
self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
|
| 1122 |
+
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
|
| 1123 |
+
|
| 1124 |
+
# Calculate size after convolutions
|
| 1125 |
+
conv_size = self._get_conv_size(input_shape)
|
| 1126 |
+
|
| 1127 |
+
self.fc_adv = nn.Linear(conv_size, 512)
|
| 1128 |
+
self.fc_val = nn.Linear(conv_size, 512)
|
| 1129 |
+
|
| 1130 |
+
self.advantage = nn.Linear(512, num_actions)
|
| 1131 |
+
self.value = nn.Linear(512, 1)
|
| 1132 |
+
|
| 1133 |
+
def _get_conv_size(self, shape):
|
| 1134 |
+
x = torch.zeros(1, *shape)
|
| 1135 |
+
x = F.relu(self.conv1(x))
|
| 1136 |
+
x = F.relu(self.conv2(x))
|
| 1137 |
+
x = F.relu(self.conv3(x))
|
| 1138 |
+
return x.view(1, -1).size(1)
|
| 1139 |
+
|
| 1140 |
+
def forward(self, x):
|
| 1141 |
+
x = F.relu(self.conv1(x))
|
| 1142 |
+
x = F.relu(self.conv2(x))
|
| 1143 |
+
x = F.relu(self.conv3(x))
|
| 1144 |
+
x = x.view(x.size(0), -1)
|
| 1145 |
+
|
| 1146 |
+
adv = F.relu(self.fc_adv(x))
|
| 1147 |
+
val = F.relu(self.fc_val(x))
|
| 1148 |
+
|
| 1149 |
+
adv = self.advantage(adv)
|
| 1150 |
+
val = self.value(val).expand(x.size(0), adv.size(1))
|
| 1151 |
+
|
| 1152 |
+
# Combine using dueling formula
|
| 1153 |
+
q = val + adv - adv.mean(1, keepdim=True).expand(x.size(0), adv.size(1))
|
| 1154 |
+
return q
|
| 1155 |
+
""",
|
| 1156 |
+
"Dueling DQN"
|
| 1157 |
+
)
|
| 1158 |
+
exercise2.exerciseCompleted.connect(self.on_exercise_completed)
|
| 1159 |
+
exercise_tabs.addTab(exercise2, "Dueling DQN")
|
| 1160 |
+
|
| 1161 |
+
# Exercise 3: Environment Wrapper
|
| 1162 |
+
exercise3 = CodeExerciseWidget(
|
| 1163 |
+
"Exercise 3: Environment Wrapper",
|
| 1164 |
+
"Create a wrapper to preprocess the Super Mario Bros environment.",
|
| 1165 |
+
"""import gym
|
| 1166 |
+
import numpy as np
|
| 1167 |
+
from collections import deque
|
| 1168 |
+
|
| 1169 |
+
class MarioEnvironmentWrapper:
|
| 1170 |
+
def __init__(self, env, stack_frames=4):
|
| 1171 |
+
self.env = env
|
| 1172 |
+
self.stack_frames = stack_frames
|
| 1173 |
+
# TODO: Initialize frame stack
|
| 1174 |
+
|
| 1175 |
+
def reset(self):
|
| 1176 |
+
# TODO: Reset environment and initialize frame stack
|
| 1177 |
+
pass
|
| 1178 |
+
|
| 1179 |
+
def step(self, action):
|
| 1180 |
+
# TODO: Take step and update frame stack
|
| 1181 |
+
pass
|
| 1182 |
+
|
| 1183 |
+
def _preprocess_frame(self, frame):
|
| 1184 |
+
# TODO: Preprocess frame (grayscale, resize, normalize)
|
| 1185 |
+
pass
|
| 1186 |
+
""",
|
| 1187 |
+
"""import gym
|
| 1188 |
+
import numpy as np
|
| 1189 |
+
from collections import deque
|
| 1190 |
+
import cv2
|
| 1191 |
+
|
| 1192 |
+
class MarioEnvironmentWrapper:
|
| 1193 |
+
def __init__(self, env, stack_frames=4):
|
| 1194 |
+
self.env = env
|
| 1195 |
+
self.stack_frames = stack_frames
|
| 1196 |
+
self.frames = deque(maxlen=stack_frames)
|
| 1197 |
+
|
| 1198 |
+
def reset(self):
|
| 1199 |
+
frame = self.env.reset()
|
| 1200 |
+
frame = self._preprocess_frame(frame)
|
| 1201 |
+
for _ in range(self.stack_frames):
|
| 1202 |
+
self.frames.append(frame)
|
| 1203 |
+
return np.array(self.frames)
|
| 1204 |
+
|
| 1205 |
+
def step(self, action):
|
| 1206 |
+
next_frame, reward, done, info = self.env.step(action)
|
| 1207 |
+
next_frame = self._preprocess_frame(next_frame)
|
| 1208 |
+
self.frames.append(next_frame)
|
| 1209 |
+
return np.array(self.frames), reward, done, info
|
| 1210 |
+
|
| 1211 |
+
def _preprocess_frame(self, frame):
|
| 1212 |
+
# Convert to grayscale
|
| 1213 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
| 1214 |
+
# Resize
|
| 1215 |
+
frame = cv2.resize(frame, (84, 84))
|
| 1216 |
+
# Normalize
|
| 1217 |
+
frame = frame / 255.0
|
| 1218 |
+
return frame
|
| 1219 |
+
""",
|
| 1220 |
+
"Environment Wrapper"
|
| 1221 |
+
)
|
| 1222 |
+
exercise3.exerciseCompleted.connect(self.on_exercise_completed)
|
| 1223 |
+
exercise_tabs.addTab(exercise3, "Environment Wrapper")
|
| 1224 |
+
|
| 1225 |
+
# Exercise 4: Training Loop
|
| 1226 |
+
exercise4 = CodeExerciseWidget(
|
| 1227 |
+
"Exercise 4: Training Loop Implementation",
|
| 1228 |
+
"Implement the main training loop with epsilon-greedy exploration and target network updates.",
|
| 1229 |
+
"""def train_dqn(env, model, target_model, optimizer, device, episodes=1000):
|
| 1230 |
+
memory = ReplayMemory(10000)
|
| 1231 |
+
batch_size = 32
|
| 1232 |
+
gamma = 0.99
|
| 1233 |
+
epsilon = 1.0
|
| 1234 |
+
epsilon_min = 0.01
|
| 1235 |
+
epsilon_decay = 0.995
|
| 1236 |
+
target_update = 10
|
| 1237 |
+
|
| 1238 |
+
scores = []
|
| 1239 |
+
|
| 1240 |
+
for episode in range(episodes):
|
| 1241 |
+
state = env.reset()
|
| 1242 |
+
total_reward = 0
|
| 1243 |
+
done = False
|
| 1244 |
+
|
| 1245 |
+
while not done:
|
| 1246 |
+
# TODO: Implement epsilon-greedy action selection
|
| 1247 |
+
# TODO: Take action and store experience
|
| 1248 |
+
# TODO: Train model if enough experiences
|
| 1249 |
+
# TODO: Update target network periodically
|
| 1250 |
+
pass
|
| 1251 |
+
|
| 1252 |
+
# TODO: Decay epsilon
|
| 1253 |
+
# TODO: Log progress
|
| 1254 |
+
|
| 1255 |
+
return scores
|
| 1256 |
+
""",
|
| 1257 |
+
"""def train_dqn(env, model, target_model, optimizer, device, episodes=1000):
|
| 1258 |
+
memory = ReplayMemory(10000)
|
| 1259 |
+
batch_size = 32
|
| 1260 |
+
gamma = 0.99
|
| 1261 |
+
epsilon = 1.0
|
| 1262 |
+
epsilon_min = 0.01
|
| 1263 |
+
epsilon_decay = 0.995
|
| 1264 |
+
target_update = 10
|
| 1265 |
+
|
| 1266 |
+
scores = []
|
| 1267 |
+
|
| 1268 |
+
for episode in range(episodes):
|
| 1269 |
+
state = env.reset()
|
| 1270 |
+
total_reward = 0
|
| 1271 |
+
done = False
|
| 1272 |
+
steps = 0
|
| 1273 |
+
|
| 1274 |
+
while not done:
|
| 1275 |
+
# Epsilon-greedy action selection
|
| 1276 |
+
if np.random.random() <= epsilon:
|
| 1277 |
+
action = env.action_space.sample()
|
| 1278 |
+
else:
|
| 1279 |
+
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
|
| 1280 |
+
q_values = model(state_tensor)
|
| 1281 |
+
action = q_values.argmax().item()
|
| 1282 |
+
|
| 1283 |
+
# Take action
|
| 1284 |
+
next_state, reward, done, info = env.step(action)
|
| 1285 |
+
total_reward += reward
|
| 1286 |
+
|
| 1287 |
+
# Store experience
|
| 1288 |
+
memory.push((state, action, reward, next_state, done))
|
| 1289 |
+
state = next_state
|
| 1290 |
+
|
| 1291 |
+
# Train model if enough experiences
|
| 1292 |
+
if len(memory) > batch_size:
|
| 1293 |
+
batch = memory.sample(batch_size)
|
| 1294 |
+
states, actions, rewards, next_states, dones = zip(*batch)
|
| 1295 |
+
|
| 1296 |
+
# Convert to tensors
|
| 1297 |
+
states = torch.FloatTensor(np.array(states)).to(device)
|
| 1298 |
+
actions = torch.LongTensor(actions).to(device)
|
| 1299 |
+
rewards = torch.FloatTensor(rewards).to(device)
|
| 1300 |
+
next_states = torch.FloatTensor(np.array(next_states)).to(device)
|
| 1301 |
+
dones = torch.BoolTensor(dones).to(device)
|
| 1302 |
+
|
| 1303 |
+
# Compute Q-values
|
| 1304 |
+
current_q = model(states).gather(1, actions.unsqueeze(1))
|
| 1305 |
+
|
| 1306 |
+
# Compute target Q-values
|
| 1307 |
+
with torch.no_grad():
|
| 1308 |
+
next_actions = model(next_states).argmax(1)
|
| 1309 |
+
next_q = target_model(next_states).gather(1, next_actions.unsqueeze(1))
|
| 1310 |
+
target_q = rewards.unsqueeze(1) + gamma * next_q * (~dones).unsqueeze(1)
|
| 1311 |
+
|
| 1312 |
+
# Compute loss and update
|
| 1313 |
+
loss = F.smooth_l1_loss(current_q, target_q)
|
| 1314 |
+
optimizer.zero_grad()
|
| 1315 |
+
loss.backward()
|
| 1316 |
+
optimizer.step()
|
| 1317 |
+
|
| 1318 |
+
steps += 1
|
| 1319 |
+
|
| 1320 |
+
# Update target network
|
| 1321 |
+
if steps % target_update == 0:
|
| 1322 |
+
target_model.load_state_dict(model.state_dict())
|
| 1323 |
+
|
| 1324 |
+
# Decay epsilon
|
| 1325 |
+
epsilon = max(epsilon_min, epsilon * epsilon_decay)
|
| 1326 |
+
|
| 1327 |
+
scores.append(total_reward)
|
| 1328 |
+
|
| 1329 |
+
if episode % 100 == 0:
|
| 1330 |
+
print(f"Episode {episode}, Score: {total_reward}, Epsilon: {epsilon:.3f}")
|
| 1331 |
+
|
| 1332 |
+
return scores
|
| 1333 |
+
""",
|
| 1334 |
+
"Training Loop"
|
| 1335 |
+
)
|
| 1336 |
+
exercise4.exerciseCompleted.connect(self.on_exercise_completed)
|
| 1337 |
+
exercise_tabs.addTab(exercise4, "Training Loop")
|
| 1338 |
+
|
| 1339 |
+
# Exercise 5: Reward Shaping
|
| 1340 |
+
exercise5 = CodeExerciseWidget(
|
| 1341 |
+
"Exercise 5: Reward Shaping Function",
|
| 1342 |
+
"Implement a reward transformation function to improve learning stability.",
|
| 1343 |
+
"""def transform_reward(reward, info, prev_info=None):
|
| 1344 |
+
\"\"\"
|
| 1345 |
+
Transform the raw reward to improve learning.
|
| 1346 |
+
|
| 1347 |
+
Args:
|
| 1348 |
+
reward: Raw reward from environment
|
| 1349 |
+
info: Current step info dictionary
|
| 1350 |
+
prev_info: Previous step info (for computing deltas)
|
| 1351 |
+
|
| 1352 |
+
Returns:
|
| 1353 |
+
Transformed reward value
|
| 1354 |
+
\"\"\"
|
| 1355 |
+
transformed_reward = reward
|
| 1356 |
+
|
| 1357 |
+
# TODO: Implement reward transformations:
|
| 1358 |
+
# 1. Scale large rewards
|
| 1359 |
+
# 2. Add small positive rewards for progress
|
| 1360 |
+
# 3. Penalize time-wasting
|
| 1361 |
+
# 4. Reward level completion
|
| 1362 |
+
|
| 1363 |
+
return transformed_reward
|
| 1364 |
+
""",
|
| 1365 |
+
"""def transform_reward(reward, info, prev_info=None):
|
| 1366 |
+
\"\"\"
|
| 1367 |
+
Transform the raw reward to improve learning.
|
| 1368 |
+
|
| 1369 |
+
Args:
|
| 1370 |
+
reward: Raw reward from environment
|
| 1371 |
+
info: Current step info dictionary
|
| 1372 |
+
prev_info: Previous step info (for computing deltas)
|
| 1373 |
+
|
| 1374 |
+
Returns:
|
| 1375 |
+
Transformed reward value
|
| 1376 |
+
\"\"\"
|
| 1377 |
+
transformed_reward = reward
|
| 1378 |
+
|
| 1379 |
+
# 1. Scale large rewards using a signed square-root transform (plus a small linear term)
|
| 1380 |
+
if abs(reward) > 1:
|
| 1381 |
+
transformed_reward = np.sign(reward) * (np.sqrt(abs(reward) + 1) - 1) + 0.001 * reward
|
| 1382 |
+
|
| 1383 |
+
# 2. Add small positive reward for x-position progress
|
| 1384 |
+
if prev_info is not None:
|
| 1385 |
+
x_pos = info.get('x_pos', 0)
|
| 1386 |
+
prev_x_pos = prev_info.get('x_pos', 0)
|
| 1387 |
+
x_progress = x_pos - prev_x_pos
|
| 1388 |
+
|
| 1389 |
+
# Reward moving right, penalize moving left
|
| 1390 |
+
if x_progress > 0:
|
| 1391 |
+
transformed_reward += 0.1
|
| 1392 |
+
elif x_progress < -1: # Small left movements might be okay
|
| 1393 |
+
transformed_reward -= 0.2
|
| 1394 |
+
|
| 1395 |
+
# 3. Penalize time-wasting with a small constant penalty applied on every step
|
| 1396 |
+
time_penalty = -0.01
|
| 1397 |
+
transformed_reward += time_penalty
|
| 1398 |
+
|
| 1399 |
+
# 4. Large reward for completing level
|
| 1400 |
+
if info.get('flag_get', False):
|
| 1401 |
+
transformed_reward += 100
|
| 1402 |
+
|
| 1403 |
+
# 5. Reward coin collection
|
| 1404 |
+
if prev_info is not None:
|
| 1405 |
+
coins = info.get('coins', 0)
|
| 1406 |
+
prev_coins = prev_info.get('coins', 0)
|
| 1407 |
+
if coins > prev_coins:
|
| 1408 |
+
transformed_reward += 1.0
|
| 1409 |
+
|
| 1410 |
+
# 6. Reward defeating enemies
|
| 1411 |
+
if prev_info is not None:
|
| 1412 |
+
score = info.get('score', 0)
|
| 1413 |
+
prev_score = prev_info.get('score', 0)
|
| 1414 |
+
if score > prev_score and reward == 0:
|
| 1415 |
+
# Enemy defeated (score increased but no immediate reward)
|
| 1416 |
+
transformed_reward += 0.5
|
| 1417 |
+
|
| 1418 |
+
# Clip reward to reasonable range
|
| 1419 |
+
transformed_reward = np.clip(transformed_reward, -5, 10)
|
| 1420 |
+
|
| 1421 |
+
return transformed_reward
|
| 1422 |
+
|
| 1423 |
+
|
| 1424 |
+
# Alternative: The reward transformation from the original code
|
| 1425 |
+
def original_reward_transform(reward):
|
| 1426 |
+
\"\"\"The reward transformation used in the original duel_dqn.py\"\"\"
|
| 1427 |
+
return np.sign(reward) * (np.sqrt(abs(reward) + 1) - 1) + 0.001 * reward
|
| 1428 |
+
""",
|
| 1429 |
+
"Reward Shaping"
|
| 1430 |
+
)
|
| 1431 |
+
exercise5.exerciseCompleted.connect(self.on_exercise_completed)
|
| 1432 |
+
exercise_tabs.addTab(exercise5, "Reward Shaping")
|
| 1433 |
+
|
| 1434 |
+
# Exercise 6: Model Saving/Loading
|
| 1435 |
+
exercise6 = CodeExerciseWidget(
|
| 1436 |
+
"Exercise 6: Model Saving and Loading",
|
| 1437 |
+
"Implement functions to save and load trained models with proper error handling.",
|
| 1438 |
+
"""import torch
|
| 1439 |
+
import os
|
| 1440 |
+
|
| 1441 |
+
class ModelManager:
|
| 1442 |
+
def __init__(self, model_dir='models'):
|
| 1443 |
+
self.model_dir = model_dir
|
| 1444 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 1445 |
+
|
| 1446 |
+
def save_checkpoint(self, model, target_model, optimizer, episode, loss, filename):
|
| 1447 |
+
\"\"\"
|
| 1448 |
+
Save training checkpoint including model weights and training state.
|
| 1449 |
+
\"\"\"
|
| 1450 |
+
# TODO: Implement checkpoint saving
|
| 1451 |
+
pass
|
| 1452 |
+
|
| 1453 |
+
def load_checkpoint(self, filename, model, target_model=None, optimizer=None):
|
| 1454 |
+
\"\"\"
|
| 1455 |
+
Load training checkpoint and restore state.
|
| 1456 |
+
\"\"\"
|
| 1457 |
+
# TODO: Implement checkpoint loading
|
| 1458 |
+
pass
|
| 1459 |
+
|
| 1460 |
+
def save_best_model(self, model, score, filename):
|
| 1461 |
+
\"\"\"
|
| 1462 |
+
Save the best model based on validation score.
|
| 1463 |
+
\"\"\"
|
| 1464 |
+
# TODO: Implement best model saving
|
| 1465 |
+
pass
|
| 1466 |
+
""",
|
| 1467 |
+
"""import torch
|
| 1468 |
+
import os
|
| 1469 |
+
import json
|
| 1470 |
+
|
| 1471 |
+
class ModelManager:
|
| 1472 |
+
def __init__(self, model_dir='models'):
|
| 1473 |
+
self.model_dir = model_dir
|
| 1474 |
+
os.makedirs(model_dir, exist_ok=True)
|
| 1475 |
+
|
| 1476 |
+
def save_checkpoint(self, model, target_model, optimizer, episode, loss, scores, filename):
|
| 1477 |
+
\"\"\"
|
| 1478 |
+
Save training checkpoint including model weights and training state.
|
| 1479 |
+
\"\"\"
|
| 1480 |
+
checkpoint = {
|
| 1481 |
+
'episode': episode,
|
| 1482 |
+
'model_state_dict': model.state_dict(),
|
| 1483 |
+
'target_model_state_dict': target_model.state_dict(),
|
| 1484 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 1485 |
+
'loss': loss,
|
| 1486 |
+
'scores': scores,
|
| 1487 |
+
'model_architecture': str(model)
|
| 1488 |
+
}
|
| 1489 |
+
|
| 1490 |
+
filepath = os.path.join(self.model_dir, filename)
|
| 1491 |
+
torch.save(checkpoint, filepath)
|
| 1492 |
+
|
| 1493 |
+
# Also save metadata as JSON
|
| 1494 |
+
metadata = {
|
| 1495 |
+
'episode': episode,
|
| 1496 |
+
'loss': loss,
|
| 1497 |
+
'final_score': scores[-1] if scores else 0,
|
| 1498 |
+
'timestamp': time.time()
|
| 1499 |
+
}
|
| 1500 |
+
metadata_path = os.path.join(self.model_dir, f"{filename}_metadata.json")
|
| 1501 |
+
with open(metadata_path, 'w') as f:
|
| 1502 |
+
json.dump(metadata, f, indent=2)
|
| 1503 |
+
|
| 1504 |
+
print(f"Checkpoint saved: {filepath}")
|
| 1505 |
+
|
| 1506 |
+
def load_checkpoint(self, filename, model, target_model=None, optimizer=None):
|
| 1507 |
+
\"\"\"
|
| 1508 |
+
Load training checkpoint and restore state.
|
| 1509 |
+
|
| 1510 |
+
Returns:
|
| 1511 |
+
dict: Checkpoint data including episode and scores
|
| 1512 |
+
\"\"\"
|
| 1513 |
+
filepath = os.path.join(self.model_dir, filename)
|
| 1514 |
+
|
| 1515 |
+
if not os.path.exists(filepath):
|
| 1516 |
+
raise FileNotFoundError(f"Checkpoint file not found: {filepath}")
|
| 1517 |
+
|
| 1518 |
+
checkpoint = torch.load(filepath, map_location='cpu')
|
| 1519 |
+
|
| 1520 |
+
# Load model weights
|
| 1521 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
| 1522 |
+
|
| 1523 |
+
# Load target model weights if provided
|
| 1524 |
+
if target_model is not None and 'target_model_state_dict' in checkpoint:
|
| 1525 |
+
target_model.load_state_dict(checkpoint['target_model_state_dict'])
|
| 1526 |
+
|
| 1527 |
+
# Load optimizer state if provided
|
| 1528 |
+
if optimizer is not None and 'optimizer_state_dict' in checkpoint:
|
| 1529 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 1530 |
+
|
| 1531 |
+
print(f"Checkpoint loaded: {filepath}")
|
| 1532 |
+
print(f"Resuming from episode {checkpoint['episode']}")
|
| 1533 |
+
|
| 1534 |
+
return {
|
| 1535 |
+
'episode': checkpoint['episode'],
|
| 1536 |
+
'loss': checkpoint.get('loss', 0),
|
| 1537 |
+
'scores': checkpoint.get('scores', [])
|
| 1538 |
+
}
|
| 1539 |
+
|
| 1540 |
+
def save_best_model(self, model, score, filename='best_model.pth'):
|
| 1541 |
+
\"\"\"
|
| 1542 |
+
Save the best model based on validation score.
|
| 1543 |
+
\"\"\"
|
| 1544 |
+
filepath = os.path.join(self.model_dir, filename)
|
| 1545 |
+
|
| 1546 |
+
# Save model state
|
| 1547 |
+
torch.save(model.state_dict(), filepath)
|
| 1548 |
+
|
| 1549 |
+
# Save score information
|
| 1550 |
+
score_info = {
|
| 1551 |
+
'score': score,
|
| 1552 |
+
'timestamp': time.time()
|
| 1553 |
+
}
|
| 1554 |
+
score_path = os.path.join(self.model_dir, f"{filename}_score.json")
|
| 1555 |
+
with open(score_path, 'w') as f:
|
| 1556 |
+
json.dump(score_info, f, indent=2)
|
| 1557 |
+
|
| 1558 |
+
print(f"Best model saved with score {score}: {filepath}")
|
| 1559 |
+
|
| 1560 |
+
def list_checkpoints(self):
|
| 1561 |
+
\"\"\"List all available checkpoints.\"\"\"
|
| 1562 |
+
checkpoints = []
|
| 1563 |
+
for file in os.listdir(self.model_dir):
|
| 1564 |
+
if file.endswith('.pth') and not file.endswith('_metadata.json'):
|
| 1565 |
+
checkpoints.append(file)
|
| 1566 |
+
return sorted(checkpoints)
|
| 1567 |
+
|
| 1568 |
+
|
| 1569 |
+
# Usage example:
|
| 1570 |
+
def setup_model_persistence():
|
| 1571 |
+
\"\"\"Example of how to use the ModelManager\"\"\"
|
| 1572 |
+
manager = ModelManager()
|
| 1573 |
+
|
| 1574 |
+
# Example saving
|
| 1575 |
+
# manager.save_checkpoint(q, q_target, optimizer, episode, loss, scores, 'checkpoint_1000.pth')
|
| 1576 |
+
|
| 1577 |
+
# Example loading
|
| 1578 |
+
# checkpoint_data = manager.load_checkpoint('checkpoint_1000.pth', q, q_target, optimizer)
|
| 1579 |
+
# start_episode = checkpoint_data['episode'] + 1
|
| 1580 |
+
|
| 1581 |
+
return manager
|
| 1582 |
+
""",
|
| 1583 |
+
"Model Saving"
|
| 1584 |
+
)
|
| 1585 |
+
exercise6.exerciseCompleted.connect(self.on_exercise_completed)
|
| 1586 |
+
exercise_tabs.addTab(exercise6, "Model Saving")
|
| 1587 |
+
|
| 1588 |
+
# Exercise 7: Hyperparameter Tuning
|
| 1589 |
+
exercise7 = CodeExerciseWidget(
|
| 1590 |
+
"Exercise 7: Hyperparameter Configuration System",
|
| 1591 |
+
"Create a flexible configuration system for managing hyperparameters.",
|
| 1592 |
+
"""class DQNConfig:
|
| 1593 |
+
\"\"\"Configuration class for DQN hyperparameters.\"\"\"
|
| 1594 |
+
|
| 1595 |
+
def __init__(self):
|
| 1596 |
+
# TODO: Define all hyperparameters with default values
|
| 1597 |
+
pass
|
| 1598 |
+
|
| 1599 |
+
def from_dict(self, config_dict):
|
| 1600 |
+
\"\"\"Update configuration from dictionary.\"\"\"
|
| 1601 |
+
# TODO: Implement dictionary-based configuration
|
| 1602 |
+
pass
|
| 1603 |
+
|
| 1604 |
+
def to_dict(self):
|
| 1605 |
+
\"\"\"Convert configuration to dictionary.\"\"\"
|
| 1606 |
+
# TODO: Implement conversion to dictionary
|
| 1607 |
+
pass
|
| 1608 |
+
|
| 1609 |
+
def validate(self):
|
| 1610 |
+
\"\"\"Validate configuration parameters.\"\"\"
|
| 1611 |
+
# TODO: Implement validation logic
|
| 1612 |
+
pass
|
| 1613 |
+
""",
|
| 1614 |
+
"""import json
|
| 1615 |
+
import yaml
|
| 1616 |
+
from dataclasses import dataclass, asdict
|
| 1617 |
+
from typing import List, Optional
|
| 1618 |
+
|
| 1619 |
+
@dataclass
|
| 1620 |
+
class DQNConfig:
|
| 1621 |
+
\"\"\"Configuration class for DQN hyperparameters.\"\"\"
|
| 1622 |
+
|
| 1623 |
+
# Environment settings
|
| 1624 |
+
env_name: str = "SuperMarioBros-v0"
|
| 1625 |
+
action_space: str = "COMPLEX_MOVEMENT"
|
| 1626 |
+
frame_stack: int = 4
|
| 1627 |
+
frame_skip: int = 4
|
| 1628 |
+
|
| 1629 |
+
# Training hyperparameters
|
| 1630 |
+
learning_rate: float = 0.0001
|
| 1631 |
+
gamma: float = 0.99
|
| 1632 |
+
batch_size: int = 32
|
| 1633 |
+
buffer_size: int = 10000
|
| 1634 |
+
episodes: int = 10000
|
| 1635 |
+
|
| 1636 |
+
# Exploration settings
|
| 1637 |
+
epsilon_start: float = 1.0
|
| 1638 |
+
epsilon_end: float = 0.01
|
| 1639 |
+
epsilon_decay: float = 0.995
|
| 1640 |
+
epsilon_decay_steps: int = 10000
|
| 1641 |
+
|
| 1642 |
+
# Network architecture
|
| 1643 |
+
conv_filters: List[int] = None
|
| 1644 |
+
conv_kernel_sizes: List[int] = None
|
| 1645 |
+
conv_strides: List[int] = None
|
| 1646 |
+
hidden_sizes: List[int] = None
|
| 1647 |
+
|
| 1648 |
+
# Training schedule
|
| 1649 |
+
learning_starts: int = 1000
|
| 1650 |
+
target_update_frequency: int = 1000
|
| 1651 |
+
train_frequency: int = 4
|
| 1652 |
+
|
| 1653 |
+
# Reward shaping
|
| 1654 |
+
reward_scale: float = 1.0
|
| 1655 |
+
reward_clip: float = 10.0
|
| 1656 |
+
use_reward_shaping: bool = True
|
| 1657 |
+
|
| 1658 |
+
# Device and logging
|
| 1659 |
+
device: str = "auto"
|
| 1660 |
+
save_interval: int = 100
|
| 1661 |
+
log_interval: int = 10
|
| 1662 |
+
eval_interval: int = 100
|
| 1663 |
+
|
| 1664 |
+
def __post_init__(self):
|
| 1665 |
+
\"\"\"Set default values for lists.\"\"\"
|
| 1666 |
+
if self.conv_filters is None:
|
| 1667 |
+
self.conv_filters = [32, 64, 64]
|
| 1668 |
+
if self.conv_kernel_sizes is None:
|
| 1669 |
+
self.conv_kernel_sizes = [8, 4, 3]
|
| 1670 |
+
if self.conv_strides is None:
|
| 1671 |
+
self.conv_strides = [4, 2, 1]
|
| 1672 |
+
if self.hidden_sizes is None:
|
| 1673 |
+
self.hidden_sizes = [512]
|
| 1674 |
+
|
| 1675 |
+
@classmethod
|
| 1676 |
+
def from_dict(cls, config_dict):
|
| 1677 |
+
\"\"\"Create configuration from dictionary.\"\"\"
|
| 1678 |
+
return cls(**config_dict)
|
| 1679 |
+
|
| 1680 |
+
def to_dict(self):
|
| 1681 |
+
\"\"\"Convert configuration to dictionary.\"\"\"
|
| 1682 |
+
return asdict(self)
|
| 1683 |
+
|
| 1684 |
+
def save(self, filename):
|
| 1685 |
+
\"\"\"Save configuration to JSON file.\"\"\"
|
| 1686 |
+
with open(filename, 'w') as f:
|
| 1687 |
+
json.dump(self.to_dict(), f, indent=2)
|
| 1688 |
+
|
| 1689 |
+
@classmethod
|
| 1690 |
+
def load(cls, filename):
|
| 1691 |
+
\"\"\"Load configuration from JSON file.\"\"\"
|
| 1692 |
+
with open(filename, 'r') as f:
|
| 1693 |
+
config_dict = json.load(f)
|
| 1694 |
+
return cls.from_dict(config_dict)
|
| 1695 |
+
|
| 1696 |
+
def validate(self):
|
| 1697 |
+
\"\"\"Validate configuration parameters.\"\"\"
|
| 1698 |
+
assert self.learning_rate > 0, "Learning rate must be positive"
|
| 1699 |
+
assert 0 <= self.gamma <= 1, "Gamma must be between 0 and 1"
|
| 1700 |
+
assert self.batch_size > 0, "Batch size must be positive"
|
| 1701 |
+
assert self.buffer_size >= self.batch_size, "Buffer size must be >= batch size"
|
| 1702 |
+
assert 0 <= self.epsilon_end <= self.epsilon_start <= 1, "Invalid epsilon values"
|
| 1703 |
+
assert len(self.conv_filters) == len(self.conv_kernel_sizes) == len(self.conv_strides), \
|
| 1704 |
+
"Conv configuration lists must have same length"
|
| 1705 |
+
|
| 1706 |
+
|
| 1707 |
+
# Example configurations for different scenarios
|
| 1708 |
+
def get_fast_config():
|
| 1709 |
+
\"\"\"Configuration for fast training (lower quality).\"\"\"
|
| 1710 |
+
return DQNConfig(
|
| 1711 |
+
batch_size=16,
|
| 1712 |
+
buffer_size=5000,
|
| 1713 |
+
episodes=2000,
|
| 1714 |
+
epsilon_decay=0.99,
|
| 1715 |
+
target_update_frequency=500,
|
| 1716 |
+
save_interval=50
|
| 1717 |
+
)
|
| 1718 |
+
|
| 1719 |
+
def get_high_quality_config():
|
| 1720 |
+
\"\"\"Configuration for high-quality training (slower).\"\"\"
|
| 1721 |
+
return DQNConfig(
|
| 1722 |
+
batch_size=64,
|
| 1723 |
+
buffer_size=50000,
|
| 1724 |
+
episodes=50000,
|
| 1725 |
+
epsilon_decay=0.999,
|
| 1726 |
+
target_update_frequency=1000,
|
| 1727 |
+
conv_filters=[64, 128, 128],
|
| 1728 |
+
hidden_sizes=[512, 256]
|
| 1729 |
+
)
|
| 1730 |
+
|
| 1731 |
+
def get_debug_config():
|
| 1732 |
+
\"\"\"Configuration for debugging.\"\"\"
|
| 1733 |
+
return DQNConfig(
|
| 1734 |
+
episodes=100,
|
| 1735 |
+
batch_size=8,
|
| 1736 |
+
buffer_size=1000,
|
| 1737 |
+
save_interval=10,
|
| 1738 |
+
log_interval=1
|
| 1739 |
+
)
|
| 1740 |
+
|
| 1741 |
+
|
| 1742 |
+
# Usage example:
|
| 1743 |
+
def setup_training_with_config():
|
| 1744 |
+
config = DQNConfig()
|
| 1745 |
+
config.validate()
|
| 1746 |
+
|
| 1747 |
+
# Save config
|
| 1748 |
+
config.save('training_config.json')
|
| 1749 |
+
|
| 1750 |
+
# Load config
|
| 1751 |
+
# loaded_config = DQNConfig.load('training_config.json')
|
| 1752 |
+
|
| 1753 |
+
return config
|
| 1754 |
+
""",
|
| 1755 |
+
"Hyperparameter Tuning"
|
| 1756 |
+
)
|
| 1757 |
+
exercise7.exerciseCompleted.connect(self.on_exercise_completed)
|
| 1758 |
+
exercise_tabs.addTab(exercise7, "Hyperparameter Tuning")
|
| 1759 |
+
|
| 1760 |
+
# Exercise 8: Evaluation Metrics
|
| 1761 |
+
exercise8 = CodeExerciseWidget(
|
| 1762 |
+
"Exercise 8: Evaluation and Metrics System",
|
| 1763 |
+
"Implement comprehensive evaluation metrics and visualization tools.",
|
| 1764 |
+
"""import matplotlib.pyplot as plt
|
| 1765 |
+
import numpy as np
|
| 1766 |
+
from collections import deque
|
| 1767 |
+
|
| 1768 |
+
class TrainingMetrics:
|
| 1769 |
+
def __init__(self, window_size=100):
|
| 1770 |
+
# TODO: Initialize metric trackers
|
| 1771 |
+
pass
|
| 1772 |
+
|
| 1773 |
+
def update(self, episode, score, loss, epsilon, steps):
|
| 1774 |
+
# TODO: Update all metrics
|
| 1775 |
+
pass
|
| 1776 |
+
|
| 1777 |
+
def get_recent_scores(self):
|
| 1778 |
+
# TODO: Return recent scores for plotting
|
| 1779 |
+
pass
|
| 1780 |
+
|
| 1781 |
+
def plot_training_progress(self):
|
| 1782 |
+
# TODO: Create training progress visualization
|
| 1783 |
+
pass
|
| 1784 |
+
|
| 1785 |
+
def generate_report(self):
|
| 1786 |
+
# TODO: Generate training summary report
|
| 1787 |
+
pass
|
| 1788 |
+
|
| 1789 |
+
def evaluate_agent(env, model, episodes=10, render=False):
|
| 1790 |
+
\"\"\"
|
| 1791 |
+
Evaluate the trained agent on multiple episodes.
|
| 1792 |
+
\"\"\"
|
| 1793 |
+
# TODO: Implement agent evaluation
|
| 1794 |
+
pass
|
| 1795 |
+
""",
|
| 1796 |
+
"""import matplotlib.pyplot as plt
|
| 1797 |
+
import numpy as np
|
| 1798 |
+
from collections import deque
|
| 1799 |
+
import pandas as pd
|
| 1800 |
+
import seaborn as sns
|
| 1801 |
+
from typing import Dict, List, Tuple
|
| 1802 |
+
|
| 1803 |
+
class TrainingMetrics:
|
| 1804 |
+
def __init__(self, window_size=100):
|
| 1805 |
+
self.window_size = window_size
|
| 1806 |
+
|
| 1807 |
+
# Metrics storage
|
| 1808 |
+
self.episode_scores = []
|
| 1809 |
+
self.episode_losses = []
|
| 1810 |
+
self.episode_steps = []
|
| 1811 |
+
self.episode_epsilons = []
|
| 1812 |
+
self.episode_times = []
|
| 1813 |
+
|
| 1814 |
+
# Rolling averages
|
| 1815 |
+
self.recent_scores = deque(maxlen=window_size)
|
| 1816 |
+
self.recent_losses = deque(maxlen=window_size)
|
| 1817 |
+
|
| 1818 |
+
# Best performance tracking
|
| 1819 |
+
self.best_score = -np.inf
|
| 1820 |
+
self.best_episode = 0
|
| 1821 |
+
|
| 1822 |
+
def update(self, episode, score, loss, epsilon, steps, time_taken):
|
| 1823 |
+
\"\"\"Update all metrics with new episode data.\"\"\"
|
| 1824 |
+
self.episode_scores.append(score)
|
| 1825 |
+
self.episode_losses.append(loss)
|
| 1826 |
+
self.episode_epsilons.append(epsilon)
|
| 1827 |
+
self.episode_steps.append(steps)
|
| 1828 |
+
self.episode_times.append(time_taken)
|
| 1829 |
+
|
| 1830 |
+
self.recent_scores.append(score)
|
| 1831 |
+
self.recent_losses.append(loss)
|
| 1832 |
+
|
| 1833 |
+
# Update best score
|
| 1834 |
+
if score > self.best_score:
|
| 1835 |
+
self.best_score = score
|
| 1836 |
+
self.best_episode = episode
|
| 1837 |
+
|
| 1838 |
+
def get_recent_scores(self) -> List[float]:
|
| 1839 |
+
\"\"\"Return recent scores for plotting.\"\"\"
|
| 1840 |
+
return list(self.recent_scores)
|
| 1841 |
+
|
| 1842 |
+
def get_moving_averages(self) -> Dict[str, float]:
|
| 1843 |
+
\"\"\"Calculate moving averages of key metrics.\"\"\"
|
| 1844 |
+
if len(self.recent_scores) == 0:
|
| 1845 |
+
return {}
|
| 1846 |
+
|
| 1847 |
+
return {
|
| 1848 |
+
'score_ma': np.mean(self.recent_scores),
|
| 1849 |
+
'loss_ma': np.mean(self.recent_losses),
|
| 1850 |
+
'score_std': np.std(self.recent_scores),
|
| 1851 |
+
'loss_std': np.std(self.recent_losses)
|
| 1852 |
+
}
|
| 1853 |
+
|
| 1854 |
+
def plot_training_progress(self, save_path=None):
|
| 1855 |
+
\"\"\"Create comprehensive training progress visualization.\"\"\"
|
| 1856 |
+
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
|
| 1857 |
+
|
| 1858 |
+
episodes = range(len(self.episode_scores))
|
| 1859 |
+
|
| 1860 |
+
# Plot 1: Scores
|
| 1861 |
+
ax1.plot(episodes, self.episode_scores, 'b-', alpha=0.3, label='Raw Scores')
|
| 1862 |
+
if len(episodes) >= self.window_size:
|
| 1863 |
+
moving_avg = pd.Series(self.episode_scores).rolling(self.window_size).mean()
|
| 1864 |
+
ax1.plot(episodes, moving_avg, 'r-', linewidth=2, label=f'Moving Avg ({self.window_size})')
|
| 1865 |
+
ax1.axhline(y=self.best_score, color='g', linestyle='--', label=f'Best: {self.best_score:.1f}')
|
| 1866 |
+
ax1.set_xlabel('Episode')
|
| 1867 |
+
ax1.set_ylabel('Score')
|
| 1868 |
+
ax1.set_title('Training Scores')
|
| 1869 |
+
ax1.legend()
|
| 1870 |
+
ax1.grid(True, alpha=0.3)
|
| 1871 |
+
|
| 1872 |
+
# Plot 2: Loss
|
| 1873 |
+
ax2.plot(episodes, self.episode_losses, 'r-', alpha=0.3)
|
| 1874 |
+
if len(episodes) >= self.window_size:
|
| 1875 |
+
moving_avg_loss = pd.Series(self.episode_losses).rolling(self.window_size).mean()
|
| 1876 |
+
ax2.plot(episodes, moving_avg_loss, 'black', linewidth=2)
|
| 1877 |
+
ax2.set_xlabel('Episode')
|
| 1878 |
+
ax2.set_ylabel('Loss')
|
| 1879 |
+
ax2.set_title('Training Loss')
|
| 1880 |
+
ax2.grid(True, alpha=0.3)
|
| 1881 |
+
|
| 1882 |
+
# Plot 3: Epsilon
|
| 1883 |
+
ax3.plot(episodes, self.episode_epsilons, 'g-')
|
| 1884 |
+
ax3.set_xlabel('Episode')
|
| 1885 |
+
ax3.set_ylabel('Epsilon')
|
| 1886 |
+
ax3.set_title('Exploration Rate')
|
| 1887 |
+
ax3.grid(True, alpha=0.3)
|
| 1888 |
+
|
| 1889 |
+
# Plot 4: Steps per episode
|
| 1890 |
+
ax4.plot(episodes, self.episode_steps, 'purple', alpha=0.3)
|
| 1891 |
+
if len(episodes) >= self.window_size:
|
| 1892 |
+
moving_avg_steps = pd.Series(self.episode_steps).rolling(self.window_size).mean()
|
| 1893 |
+
ax4.plot(episodes, moving_avg_steps, 'black', linewidth=2)
|
| 1894 |
+
ax4.set_xlabel('Episode')
|
| 1895 |
+
ax4.set_ylabel('Steps')
|
| 1896 |
+
ax4.set_title('Steps per Episode')
|
| 1897 |
+
ax4.grid(True, alpha=0.3)
|
| 1898 |
+
|
| 1899 |
+
plt.tight_layout()
|
| 1900 |
+
|
| 1901 |
+
if save_path:
|
| 1902 |
+
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
| 1903 |
+
print(f"Training plot saved to {save_path}")
|
| 1904 |
+
|
| 1905 |
+
plt.show()
|
| 1906 |
+
|
| 1907 |
+
def generate_report(self) -> Dict:
|
| 1908 |
+
\"\"\"Generate comprehensive training summary report.\"\"\"
|
| 1909 |
+
if not self.episode_scores:
|
| 1910 |
+
return {}
|
| 1911 |
+
|
| 1912 |
+
ma = self.get_moving_averages()
|
| 1913 |
+
|
| 1914 |
+
report = {
|
| 1915 |
+
'total_episodes': len(self.episode_scores),
|
| 1916 |
+
'best_score': self.best_score,
|
| 1917 |
+
'best_episode': self.best_episode,
|
| 1918 |
+
'final_score': self.episode_scores[-1],
|
| 1919 |
+
'average_score': np.mean(self.episode_scores),
|
| 1920 |
+
'median_score': np.median(self.episode_scores),
|
| 1921 |
+
'std_score': np.std(self.episode_scores),
|
| 1922 |
+
'average_loss': np.mean(self.episode_losses),
|
| 1923 |
+
'average_steps': np.mean(self.episode_steps),
|
| 1924 |
+
'total_training_time': np.sum(self.episode_times),
|
| 1925 |
+
'recent_score_ma': ma.get('score_ma', 0),
|
| 1926 |
+
'recent_loss_ma': ma.get('loss_ma', 0),
|
| 1927 |
+
'recent_score_std': ma.get('score_std', 0)
|
| 1928 |
+
}
|
| 1929 |
+
|
| 1930 |
+
return report
|
| 1931 |
+
|
| 1932 |
+
def save_metrics(self, filename):
|
| 1933 |
+
\"\"\"Save metrics to file.\"\"\"
|
| 1934 |
+
metrics_data = {
|
| 1935 |
+
'episode_scores': self.episode_scores,
|
| 1936 |
+
'episode_losses': self.episode_losses,
|
| 1937 |
+
'episode_epsilons': self.episode_epsilons,
|
| 1938 |
+
'episode_steps': self.episode_steps,
|
| 1939 |
+
'episode_times': self.episode_times,
|
| 1940 |
+
'best_score': self.best_score,
|
| 1941 |
+
'best_episode': self.best_episode
|
| 1942 |
+
}
|
| 1943 |
+
|
| 1944 |
+
np.savez(filename, **metrics_data)
|
| 1945 |
+
print(f"Metrics saved to {filename}")
|
| 1946 |
+
|
| 1947 |
+
|
| 1948 |
+
def evaluate_agent(env, model, device, episodes=10, render=False, epsilon=0.01):
|
| 1949 |
+
\"\"\"
|
| 1950 |
+
Evaluate the trained agent on multiple episodes.
|
| 1951 |
+
|
| 1952 |
+
Returns:
|
| 1953 |
+
dict: Evaluation metrics
|
| 1954 |
+
\"\"\"
|
| 1955 |
+
model.eval()
|
| 1956 |
+
episode_scores = []
|
| 1957 |
+
episode_steps = []
|
| 1958 |
+
episode_infos = []
|
| 1959 |
+
|
| 1960 |
+
for episode in range(episodes):
|
| 1961 |
+
state = env.reset()
|
| 1962 |
+
total_reward = 0
|
| 1963 |
+
steps = 0
|
| 1964 |
+
done = False
|
| 1965 |
+
episode_info = {}
|
| 1966 |
+
|
| 1967 |
+
while not done:
|
| 1968 |
+
if render:
|
| 1969 |
+
env.render()
|
| 1970 |
+
|
| 1971 |
+
# Epsilon-greedy evaluation (small epsilon for evaluation)
|
| 1972 |
+
if np.random.random() < epsilon:
|
| 1973 |
+
action = env.action_space.sample()
|
| 1974 |
+
else:
|
| 1975 |
+
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
|
| 1976 |
+
with torch.no_grad():
|
| 1977 |
+
q_values = model(state_tensor)
|
| 1978 |
+
action = q_values.argmax().item()
|
| 1979 |
+
|
| 1980 |
+
next_state, reward, done, info = env.step(action)
|
| 1981 |
+
total_reward += reward
|
| 1982 |
+
state = next_state
|
| 1983 |
+
steps += 1
|
| 1984 |
+
|
| 1985 |
+
# Store final info
|
| 1986 |
+
if done:
|
| 1987 |
+
episode_info = info.copy()
|
| 1988 |
+
|
| 1989 |
+
episode_scores.append(total_reward)
|
| 1990 |
+
episode_steps.append(steps)
|
| 1991 |
+
episode_infos.append(episode_info)
|
| 1992 |
+
|
| 1993 |
+
print(f"Evaluation Episode {episode + 1}/{episodes}: "
|
| 1994 |
+
f"Score: {total_reward:.1f}, Steps: {steps}")
|
| 1995 |
+
|
| 1996 |
+
model.train()
|
| 1997 |
+
|
| 1998 |
+
# Calculate evaluation metrics
|
| 1999 |
+
metrics = {
|
| 2000 |
+
'mean_score': np.mean(episode_scores),
|
| 2001 |
+
'std_score': np.std(episode_scores),
|
| 2002 |
+
'median_score': np.median(episode_scores),
|
| 2003 |
+
'max_score': np.max(episode_scores),
|
| 2004 |
+
'min_score': np.min(episode_scores),
|
| 2005 |
+
'mean_steps': np.mean(episode_steps),
|
| 2006 |
+
'completion_rate': sum(1 for info in episode_infos if info.get('flag_get', False)) / episodes,
|
| 2007 |
+
'average_coins': np.mean([info.get('coins', 0) for info in episode_infos]),
|
| 2008 |
+
'average_x_pos': np.mean([info.get('x_pos', 0) for info in episode_infos])
|
| 2009 |
+
}
|
| 2010 |
+
|
| 2011 |
+
print(f"\\nEvaluation Results ({episodes} episodes):")
|
| 2012 |
+
print(f"Average Score: {metrics['mean_score']:.2f} ± {metrics['std_score']:.2f}")
|
| 2013 |
+
print(f"Completion Rate: {metrics['completion_rate']:.1%}")
|
| 2014 |
+
print(f"Average Coins: {metrics['average_coins']:.1f}")
|
| 2015 |
+
print(f"Average Final X Position: {metrics['average_x_pos']:.1f}")
|
| 2016 |
+
|
| 2017 |
+
return metrics
|
| 2018 |
+
|
| 2019 |
+
|
| 2020 |
+
# Usage example:
|
| 2021 |
+
def setup_evaluation_system():
|
| 2022 |
+
metrics = TrainingMetrics(window_size=100)
|
| 2023 |
+
|
| 2024 |
+
# Example usage during training:
|
| 2025 |
+
# for episode in range(episodes):
|
| 2026 |
+
# score, loss, steps, epsilon, time_taken = run_episode(...)
|
| 2027 |
+
# metrics.update(episode, score, loss, epsilon, steps, time_taken)
|
| 2028 |
+
#
|
| 2029 |
+
# if episode % 100 == 0:
|
| 2030 |
+
# metrics.plot_training_progress()
|
| 2031 |
+
# report = metrics.generate_report()
|
| 2032 |
+
|
| 2033 |
+
return metrics
|
| 2034 |
+
""",
|
| 2035 |
+
"Evaluation"
|
| 2036 |
+
)
|
| 2037 |
+
exercise8.exerciseCompleted.connect(self.on_exercise_completed)
|
| 2038 |
+
exercise_tabs.addTab(exercise8, "Evaluation")
|
| 2039 |
+
|
| 2040 |
+
layout.addWidget(exercise_tabs)
|
| 2041 |
+
return widget
|
| 2042 |
+
|
| 2043 |
+
def change_content(self, index):
    """Switch the content tab stack to the page at position ``index``."""
    tab_stack = self.content_tabs
    tab_stack.setCurrentIndex(index)
|
| 2045 |
+
|
| 2046 |
+
def on_exercise_completed(self, exercise_name, completed):
    """Mark ``exercise_name`` as done and refresh the progress display.

    A falsy ``completed`` is ignored: this handler only ever sets a flag to
    ``True`` and never clears one.
    """
    # NOTE(review): indentation is not recoverable from this view; this assumes
    # the progress refresh belongs to the ``if completed`` branch -- confirm.
    if not completed:
        return
    self.exercise_completion[exercise_name] = True
    self.update_progress()
|
| 2050 |
+
|
| 2051 |
+
def update_progress(self):
    """Recompute tutorial completion and refresh every progress indicator.

    Steps, in order:

    1. Copy each section widget's ``is_completed`` value into
       ``self.section_completion``.
    2. Compute overall progress on a 0-100 scale, weighted 50% sections and
       50% exercises.  A category with no entries contributes 0 instead of
       raising ``ZeroDivisionError``.
    3. Move the progress bar: increases are animated with a ``QTimer`` that
       drives ``animate_progress``; anything else is applied immediately.
       Any animation still running from an earlier call is cancelled first,
       so at most one timer ever touches the bar.
    4. Update the two counter labels and the check marks / background
       colours of the navigation list.
    5. Show the completion dialog when every section and exercise is done
       (and there is at least one of them).
    """
    # Update section completion from checkboxes
    for section, widget in self.section_widgets.items():
        self.section_completion[section] = widget.is_completed

    # Calculate progress
    completed_sections = sum(self.section_completion.values())
    completed_exercises = sum(self.exercise_completion.values())
    total_sections = len(self.section_completion)
    total_exercises = len(self.exercise_completion)

    # Calculate overall progress (50% sections, 50% exercises); an empty
    # category is treated as contributing nothing rather than dividing by 0.
    section_progress = (
        (completed_sections / total_sections) * 50 if total_sections else 0
    )
    exercise_progress = (
        (completed_exercises / total_exercises) * 50 if total_exercises else 0
    )
    overall_progress = section_progress + exercise_progress

    # Cancel an animation left over from a previous update.  Without this,
    # overlapping timers fight over the bar (a stale one can snap it back to
    # an outdated target) and every timer stays alive as a child of the
    # window for the rest of its lifetime.
    stale_timer = getattr(self, "_progress_timer", None)
    if stale_timer is not None:
        stale_timer.stop()
        stale_timer.deleteLater()
        self._progress_timer = None

    # Animate progress bar
    current_value = self.progress_bar.value()
    if overall_progress > current_value:
        # Animate progress increase; the timer is remembered so the next
        # call can cancel it.
        timer = QTimer(self)
        self._progress_timer = timer
        timer.timeout.connect(
            lambda: self.animate_progress(current_value, overall_progress, timer)
        )
        timer.start(50)
    else:
        self.progress_bar.setValue(int(overall_progress))

    # Update UI
    self.exercise_count.setText(f"Exercises: {completed_exercises}/{total_exercises} completed")
    self.section_count.setText(f"Sections: {completed_sections}/{total_sections} completed")

    # Update list widget with completion indicators
    for i in range(self.nav_list.count()):
        item = self.nav_list.item(i)
        # Strip a check mark added by an earlier pass to recover the key.
        section_name = item.text().replace('✓ ', '')
        if self.section_completion.get(section_name, False):
            item.setText(f"✓ {section_name}")
            item.setBackground(QColor(40, 167, 69, 50))  # Light green background
        else:
            item.setText(section_name)
            item.setBackground(QColor(255, 255, 255))  # White background

    # Show celebration message when all completed (never for an empty tutorial)
    everything_done = (
        completed_sections == total_sections
        and completed_exercises == total_exercises
    )
    if everything_done and (total_sections or total_exercises):
        self.show_completion_celebration()
|
| 2095 |
+
|
| 2096 |
+
def animate_progress(self, start, end, timer):
    """Advance the progress bar by one animation tick toward ``end``.

    Each tick raises the bar by 2.  Once that step would reach or pass
    ``end``, the bar is set to ``int(end)`` and ``timer`` is stopped.
    ``start`` is accepted but not used by this method.
    """
    bar = self.progress_bar
    next_value = bar.value() + 2
    if next_value < end:
        bar.setValue(next_value)
        return
    # Target reached: clamp to the target and end the animation.
    bar.setValue(int(end))
    timer.stop()
|
| 2103 |
+
|
| 2104 |
+
def show_completion_celebration(self):
    """Show a modal information dialog congratulating the user.

    The dialog is parented to this window and blocks in ``exec_()`` until
    it is dismissed.
    """
    # Rich-text body; the literal's line layout is reproduced as it appears
    # in the source view.
    message_html = """
<h2>Amazing Achievement! 🏆</h2>
<p>You've successfully completed the entire Dueling DQN tutorial!</p>
<p><strong>You now have:</strong></p>
<ul>
<li>Solid understanding of Dueling DQN architecture</li>
<li>Hands-on experience building reinforcement learning systems</li>
<li>Working knowledge of PyTorch for deep RL</li>
<li>Complete implementation of an AI that can play Super Mario Bros!</li>
</ul>
<p>Keep exploring and building amazing AI projects! 🚀</p>
"""
    dialog = QMessageBox(self)
    dialog.setWindowTitle("🎉 Congratulations!")
    dialog.setText(message_html)
    dialog.setIcon(QMessageBox.Information)
    dialog.exec_()
|
| 2121 |
+
|
| 2122 |
+
|
| 2123 |
+
def apply_styling(self):
    """Install the application's Qt style sheet on this window.

    The sheet covers the main window, the navigation list, group boxes, the
    progress bar, tabs, push buttons, text edits and check boxes.
    """
    # NOTE(review): the ``QListWidget::item`` background rule may take
    # precedence over per-item ``setBackground()`` brushes set elsewhere in
    # this class -- verify the completion highlight is actually visible.
    stylesheet = """
QMainWindow {
background: qlineargradient(x1: 0, y1: 0, x2: 1, y2: 1,
stop: 0 #f8f9fa, stop: 1 #e9ecef);
font-family: 'Segoe UI', Arial, sans-serif;
}

QWidget {
background: transparent;
color: #2c3e50;
}

/* Navigation Sidebar */
QListWidget {
background: white;
border: 1px solid #dee2e6;
border-radius: 8px;
outline: none;
font-size: 14px;
}

QListWidget::item {
background: white;
border-bottom: 1px solid #f8f9fa;
padding: 12px 15px;
color: #495057;
}

QListWidget::item:selected {
background: qlineargradient(x1: 0, y1: 0, x2: 1, y2: 0,
stop: 0 #007bff, stop: 1 #0056b3);
color: white;
font-weight: bold;
border-radius: 4px;
}

QListWidget::item:hover:!selected {
background: #e9ecef;
border-radius: 4px;
}

/* Progress Section */
QGroupBox {
background: white;
border: 1px solid #dee2e6;
border-radius: 8px;
margin-top: 10px;
padding-top: 15px;
font-weight: bold;
color: #495057;
}

QGroupBox::title {
subcontrol-origin: margin;
left: 10px;
padding: 0 8px 0 8px;
background: white;
color: #495057;
}

QProgressBar {
border: 1px solid #ced4da;
border-radius: 4px;
text-align: center;
background: #e9ecef;
color: #495057;
}

QProgressBar::chunk {
background: qlineargradient(x1: 0, y1: 0, x2: 1, y2: 0,
stop: 0 #28a745, stop: 1 #20c997);
border-radius: 3px;
}

/* Tabs */
QTabWidget::pane {
border: 1px solid #dee2e6;
border-radius: 8px;
background: white;
}

QTabBar::tab {
background: #f8f9fa;
border: 1px solid #dee2e6;
border-bottom: none;
padding: 8px 16px;
margin-right: 2px;
border-top-left-radius: 4px;
border-top-right-radius: 4px;
color: #6c757d;
}

QTabBar::tab:selected {
background: white;
border-bottom: 2px solid #007bff;
color: #007bff;
font-weight: bold;
}

QTabBar::tab:hover:!selected {
background: #e9ecef;
color: #495057;
}

/* Buttons */
QPushButton {
background: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
stop: 0 #007bff, stop: 1 #0056b3);
color: white;
border: none;
padding: 8px 16px;
border-radius: 4px;
font-weight: bold;
}

QPushButton:hover {
background: qlineargradient(x1: 0, y1: 0, x2: 0, y2: 1,
stop: 0 #0056b3, stop: 1 #004085);
}

QPushButton:pressed {
background: #004085;
}

/* Text areas */
QTextEdit {
background: white;
border: 1px solid #dee2e6;
border-radius: 5px;
padding: 10px;
selection-background-color: #007bff;
}

/* Checkboxes */
QCheckBox {
spacing: 8px;
color: #495057;
}

QCheckBox::indicator {
width: 16px;
height: 16px;
border: 2px solid #adb5bd;
border-radius: 3px;
background: white;
}

QCheckBox::indicator:checked {
background: #007bff;
border: 2px solid #007bff;
}

QCheckBox::indicator:checked:hover {
background: #0056b3;
border: 2px solid #0056b3;
}
"""
    self.setStyleSheet(stylesheet)
|
| 2282 |
+
|
| 2283 |
+
|
| 2284 |
+
if __name__ == "__main__":
    # Script entry point: build the Qt application, show the tutorial window
    # and hand control to the event loop until the user quits.
    app = QApplication(sys.argv)

    # Application-wide default font
    app.setFont(QFont("Segoe UI", 10))

    main_window = DuelingDQNTutorialApp()
    main_window.show()

    # The event loop's return code becomes the process exit status.
    exit_code = app.exec_()
    sys.exit(exit_code)
|
pyqt5_duel_dqn_super_mario_bros_tutorial/flux_krea_00776_.png
ADDED
|
Git LFS Details
|
pyqt5_duel_dqn_super_mario_bros_tutorial/installed_packages_dqn.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Markdown==3.10
|
| 2 |
+
PyQt5==5.15.11
|
| 3 |
+
PyQt5-Qt5==5.15.17
|
| 4 |
+
PyQt5_sip==12.17.1
|