Commit a888a00 by Lee93whut · 1 Parent(s): f3ed6b3

fix(demo): strengthen anti-loop by penalizing moves toward high-frequency cells


The original logic only penalized the argmax action taken when leaving a high-frequency cell, which could not stop A→B→A→B oscillation.
New: for every action, look ahead to its target cell; if that cell has been visited ≥2 times, apply the same 3.0×cnt penalty,
cutting off backtracking at the source and eliminating two-cell dead loops.

Files changed (1)
  1. app.py +11 -0
app.py CHANGED
@@ -40,6 +40,7 @@ import yaml
 # ── maze_env package (already installed; import directly) ──────────
 from maze_env import MazeEnv
 from maze_env.bfs import bfs as bfs_solve
+from maze_env.actions import DELTAS
 
 # ── src package (importable directly after `pip install -e .`) ──────────
 import torch.nn as nn
@@ -247,6 +248,16 @@ def dqn_rollout(
 action_candidate = int(q_values.argmax().item())
 q_values[action_candidate] -= 3.0 * cnt
 
+# For each action, look ahead to its target cell; if that cell is also
+# a frequently visited cell, apply an additional penalty.
+cur_r, cur_c = cur_pos
+N = env.grid_size
+for a, (dr, dc) in enumerate(DELTAS):
+    nr, nc = cur_r + dr, cur_c + dc
+    if 0 <= nr < N and 0 <= nc < N:
+        next_cnt = visited_count.get((nr, nc), 0)
+        if next_cnt >= 2:
+            q_values[a] -= 3.0 * next_cnt
+
 action = int(q_values.argmax().item())
 visited_count[cur_pos] = cnt + 1
 obs, _reward, terminated, truncated, info = env.step(action)
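
The look-ahead penalty added in this hunk can be tried in isolation. The following is a minimal sketch, not the code in app.py: it works on a plain Python list instead of the `q_values` tensor, it returns a copy rather than modifying in place, and it leaves out the pre-existing penalty on the current cell. The function names and the action order in `DELTAS` (up, down, left, right) are assumptions made for illustration; the real `maze_env.actions.DELTAS` may be ordered differently. The weight 3.0 and the threshold 2 are taken from the diff.

```python
from typing import Dict, List, Sequence, Tuple

Pos = Tuple[int, int]

# Assumed action order (up, down, left, right); hypothetical stand-in
# for maze_env.actions.DELTAS, whose actual contents are not shown here.
DELTAS: Sequence[Pos] = ((-1, 0), (1, 0), (0, -1), (0, 1))

PENALTY = 3.0    # weight used in the commit
MIN_VISITS = 2   # visit-count threshold used in the commit


def penalize_revisits(
    q_values: Sequence[float],
    cur_pos: Pos,
    visited_count: Dict[Pos, int],
    grid_size: int,
    deltas: Sequence[Pos] = DELTAS,
) -> List[float]:
    """Return a copy of q_values with PENALTY * visits subtracted for every
    action whose in-bounds target cell was visited at least MIN_VISITS times."""
    adjusted = list(q_values)
    cur_r, cur_c = cur_pos
    for a, (dr, dc) in enumerate(deltas):
        nr, nc = cur_r + dr, cur_c + dc
        # Out-of-bounds targets are skipped, as in the diff.
        if 0 <= nr < grid_size and 0 <= nc < grid_size:
            next_cnt = visited_count.get((nr, nc), 0)
            if next_cnt >= MIN_VISITS:
                adjusted[a] -= PENALTY * next_cnt
    return adjusted


def greedy_action(q_values: Sequence[float]) -> int:
    """Index of the largest Q-value (first index wins ties)."""
    return max(range(len(q_values)), key=q_values.__getitem__)


# Agent at the center of a 3x3 grid; the cell to its right was visited 3 times.
q = [0.1, 0.2, 0.0, 5.0]
adjusted = penalize_revisits(q, (1, 1), {(1, 2): 3}, grid_size=3)
print(adjusted)                 # [0.1, 0.2, 0.0, -4.0]
print(greedy_action(adjusted))  # 1 (down), instead of 3 (right)
```

In this example the unpenalized argmax would step right into the cell already visited three times. After subtracting 3.0 × 3 from that action, the greedy choice moves to a different neighbor, which is how the change is meant to break a two-cell loop.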