Lee93whut commited on
Commit
c8377dc
·
1 Parent(s): 006f45e

fix(demo): re-enable inference-side anti-loop Q-penalty

Browse files

R4 visited_map (ch3) makes Q-function Markov-correct, but coverage
gaps in training leave some states prone to 2-cell oscillation loops.

Add back inference-side guard (does not affect training distribution):
- Track per-cell visit count in visited_count dict
- When cnt >= 2: penalise current argmax action by 3.0 * cnt
- Force argmax on penalised Q-values → breaks oscillation
- q_values.clone() ensures original tensor is not mutated

Two-layer design:
Training layer: visited_map ch3 encodes history → Q internalisesit
Inference layer: Q-penalty as safety net for under-covered states

Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -228,14 +228,26 @@ def dqn_rollout(
228
 
229
  path = [env.agent_pos]
230
 
231
- # 注:R4 起观测已包含 visited_map 第4通道(ch3),Agent 天然感知访问历史,
232
- # 无需在推理侧注 Q 值惩罚直接贪心执行即可。
 
 
 
 
233
  while True:
234
  s = torch.from_numpy(obs).unsqueeze(0)
235
  with torch.no_grad():
236
- q_values = net(s)[0] # shape: (num_actions,)
 
 
 
 
 
 
 
237
 
238
  action = int(q_values.argmax().item())
 
239
  obs, _reward, terminated, truncated, info = env.step(action)
240
 
241
  # 只在实际移动时追加(撞墙时位置不变,避免重复坐标导致动画抖帧)
 
228
 
229
  path = [env.agent_pos]
230
 
231
+ # 推理侧 anti-loop 兜底:visited_map(ch3)已让 Q 函数内化访问历史,
232
+ # 但对未充分覆盖的状态仍可能陷两格死循环
233
+ # 访问次数 >= 2 时对当前 argmax 动作施加递进 Q 值惩罚作为安全网,
234
+ # 不修改网络权重,不影响训练分布。
235
+ visited_count: dict[tuple, int] = {}
236
+
237
  while True:
238
  s = torch.from_numpy(obs).unsqueeze(0)
239
  with torch.no_grad():
240
+ q_values = net(s)[0].clone() # shape: (num_actions,)
241
+
242
+ # 对高频重访格子的当前最优动作施加惩罚
243
+ cur_pos = env.agent_pos
244
+ cnt = visited_count.get(cur_pos, 0)
245
+ if cnt >= 2:
246
+ action_candidate = int(q_values.argmax().item())
247
+ q_values[action_candidate] -= 3.0 * cnt
248
 
249
  action = int(q_values.argmax().item())
250
+ visited_count[cur_pos] = cnt + 1
251
  obs, _reward, terminated, truncated, info = env.step(action)
252
 
253
  # 只在实际移动时追加(撞墙时位置不变,避免重复坐标导致动画抖帧)