Lee93whut commited on
Commit ยท
a888a00
1
Parent(s): f3ed6b3
fix(demo): strengthen anti-loop by penalizing moves toward high-frequency cells
Browse filesๅ้ป่พๅชๆฉ็ฝใ็ฆปๅผ้ซ้ขๆ ผ็argmaxๅจไฝใ๏ผๆ ๆณ้ปๆญข AโBโAโB ๆฏ่กใ
ๆฐๅข๏ผๅฏนๆฏไธชๅจไฝ้ขๅค็ฎๆ ๆ ผ๏ผ็ฎๆ ๆ ผ่ฎฟ้ฎๆฌกๆฐโฅ2ๆถๅๆ ทๆฝๅ 3.0รcnt ๆฉ็ฝ๏ผ
ไปๆบๅคดๅฐๅ ตๅๅคด่ทฏ๏ผๆถ้คไธคๆ ผๆญปๅพช็ฏใ
app.py
CHANGED
|
@@ -40,6 +40,7 @@ import yaml
|
|
| 40 |
# โโ maze_env ๅ
๏ผๅทฒๅฎ่ฃ
๏ผ็ดๆฅๅฏผๅ
ฅ๏ผโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 41 |
from maze_env import MazeEnv
|
| 42 |
from maze_env.bfs import bfs as bfs_solve
|
|
|
|
| 43 |
|
| 44 |
# โโ src ๅ
๏ผpip install -e . ๅๅฏ็ดๆฅๅฏผๅ
ฅ๏ผโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 45 |
import torch.nn as nn
|
|
@@ -247,6 +248,16 @@ def dqn_rollout(
|
|
| 247 |
action_candidate = int(q_values.argmax().item())
|
| 248 |
q_values[action_candidate] -= 3.0 * cnt
|
| 249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
action = int(q_values.argmax().item())
|
| 251 |
visited_count[cur_pos] = cnt + 1
|
| 252 |
obs, _reward, terminated, truncated, info = env.step(action)
|
|
|
|
| 40 |
# โโ maze_env ๅ
๏ผๅทฒๅฎ่ฃ
๏ผ็ดๆฅๅฏผๅ
ฅ๏ผโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 41 |
from maze_env import MazeEnv
|
| 42 |
from maze_env.bfs import bfs as bfs_solve
|
| 43 |
+
from maze_env.actions import DELTAS
|
| 44 |
|
| 45 |
# โโ src ๅ
๏ผpip install -e . ๅๅฏ็ดๆฅๅฏผๅ
ฅ๏ผโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 46 |
import torch.nn as nn
|
|
|
|
| 248 |
action_candidate = int(q_values.argmax().item())
|
| 249 |
q_values[action_candidate] -= 3.0 * cnt
|
| 250 |
|
| 251 |
+
# ๅฏนๆฏไธชๅจไฝ้ขๅค็ฎๆ ๆ ผ๏ผ่ฅ็ฎๆ ๆ ผไนๆฏ้ซ้ข่ฎฟ้ฎๆ ผๅ้ขๅคๆฉ็ฝ
|
| 252 |
+
cur_r, cur_c = cur_pos
|
| 253 |
+
N = env.grid_size
|
| 254 |
+
for a, (dr, dc) in enumerate(DELTAS):
|
| 255 |
+
nr, nc = cur_r + dr, cur_c + dc
|
| 256 |
+
if 0 <= nr < N and 0 <= nc < N:
|
| 257 |
+
next_cnt = visited_count.get((nr, nc), 0)
|
| 258 |
+
if next_cnt >= 2:
|
| 259 |
+
q_values[a] -= 3.0 * next_cnt
|
| 260 |
+
|
| 261 |
action = int(q_values.argmax().item())
|
| 262 |
visited_count[cur_pos] = cnt + 1
|
| 263 |
obs, _reward, terminated, truncated, info = env.step(action)
|