Lee93whut commited on
Commit ·
c8377dc
1
Parent(s): 006f45e
fix(demo): re-enable inference-side anti-loop Q-penalty
Browse filesR4 visited_map (ch3) makes Q-function Markov-correct, but coverage
gaps in training leave some states prone to 2-cell oscillation loops.
Add back inference-side guard (does not affect training distribution):
- Track per-cell visit count in visited_count dict
- When cnt >= 2: penalise current argmax action by 3.0 * cnt
- Force argmax on penalised Q-values → breaks oscillation
- q_values.clone() ensures original tensor is not mutated
Two-layer design:
Training layer: visited_map ch3 encodes history → Q internalisesit
Inference layer: Q-penalty as safety net for under-covered states
app.py
CHANGED
|
@@ -228,14 +228,26 @@ def dqn_rollout(
|
|
| 228 |
|
| 229 |
path = [env.agent_pos]
|
| 230 |
|
| 231 |
-
#
|
| 232 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
while True:
|
| 234 |
s = torch.from_numpy(obs).unsqueeze(0)
|
| 235 |
with torch.no_grad():
|
| 236 |
-
q_values = net(s)[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
|
| 238 |
action = int(q_values.argmax().item())
|
|
|
|
| 239 |
obs, _reward, terminated, truncated, info = env.step(action)
|
| 240 |
|
| 241 |
# 只在实际移动时追加(撞墙时位置不变,避免重复坐标导致动画抖帧)
|
|
|
|
| 228 |
|
| 229 |
path = [env.agent_pos]
|
| 230 |
|
| 231 |
+
# 推理侧 anti-loop 兜底:visited_map(ch3)已让 Q 函数内化访问历史,
|
| 232 |
+
# 但对未充分覆盖的状态仍可能陷入两格死循环。
|
| 233 |
+
# 访问次数 >= 2 时对当前 argmax 动作施加递进 Q 值惩罚作为安全网,
|
| 234 |
+
# 不修改网络权重,不影响训练分布。
|
| 235 |
+
visited_count: dict[tuple, int] = {}
|
| 236 |
+
|
| 237 |
while True:
|
| 238 |
s = torch.from_numpy(obs).unsqueeze(0)
|
| 239 |
with torch.no_grad():
|
| 240 |
+
q_values = net(s)[0].clone() # shape: (num_actions,)
|
| 241 |
+
|
| 242 |
+
# 对高频重访格子的当前最优动作施加惩罚
|
| 243 |
+
cur_pos = env.agent_pos
|
| 244 |
+
cnt = visited_count.get(cur_pos, 0)
|
| 245 |
+
if cnt >= 2:
|
| 246 |
+
action_candidate = int(q_values.argmax().item())
|
| 247 |
+
q_values[action_candidate] -= 3.0 * cnt
|
| 248 |
|
| 249 |
action = int(q_values.argmax().item())
|
| 250 |
+
visited_count[cur_pos] = cnt + 1
|
| 251 |
obs, _reward, terminated, truncated, info = env.step(action)
|
| 252 |
|
| 253 |
# 只在实际移动时追加(撞墙时位置不变,避免重复坐标导致动画抖帧)
|