Lee93whut commited on
Commit
4f4fb4a
·
1 Parent(s): 3b43b04

fix(demo): auto-infer input_channels from checkpoint weight shape

Browse files

conv.0.weight.shape[1] 直接读取通道数,兼容 3-channel (R1-R3)
和 4-channel (R4) 权重,消除 size mismatch 加载失败。

Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -186,7 +186,8 @@ def load_model(algo: str = DEFAULT_ALGO, grid_size: int = GRID_SIZE) -> tuple[Op
186
  saved_gs = ckpt.get("grid_size", grid_size)
187
  algorithm = ckpt.get("algorithm", "vanilla").strip().lower()
188
  NetClass = DuelingDQNNetwork if "dueling" in algorithm else DQNNetwork
189
- net = NetClass(grid_size=saved_gs)
 
190
  net.load_state_dict(ckpt["state_dict"])
191
  net.eval()
192
  return net, saved_gs
 
186
  saved_gs = ckpt.get("grid_size", grid_size)
187
  algorithm = ckpt.get("algorithm", "vanilla").strip().lower()
188
  NetClass = DuelingDQNNetwork if "dueling" in algorithm else DQNNetwork
189
+ in_ch = ckpt["state_dict"]["conv.0.weight"].shape[1]
190
+ net = NetClass(grid_size=saved_gs, input_channels=in_ch)
191
  net.load_state_dict(ckpt["state_dict"])
192
  net.eval()
193
  return net, saved_gs