Lee93whut commited on
Commit ·
4f4fb4a
1
Parent(s): 3b43b04
fix(demo): auto-infer input_channels from checkpoint weight shape
Browse filesconv.0.weight.shape[1] 直接读取通道数,兼容 3-channel (R1-R3)
和 4-channel (R4) 权重,消除 size mismatch 加载失败。
app.py
CHANGED
|
@@ -186,7 +186,8 @@ def load_model(algo: str = DEFAULT_ALGO, grid_size: int = GRID_SIZE) -> tuple[Op
|
|
| 186 |
saved_gs = ckpt.get("grid_size", grid_size)
|
| 187 |
algorithm = ckpt.get("algorithm", "vanilla").strip().lower()
|
| 188 |
NetClass = DuelingDQNNetwork if "dueling" in algorithm else DQNNetwork
|
| 189 |
-
|
|
|
|
| 190 |
net.load_state_dict(ckpt["state_dict"])
|
| 191 |
net.eval()
|
| 192 |
return net, saved_gs
|
|
|
|
| 186 |
saved_gs = ckpt.get("grid_size", grid_size)
|
| 187 |
algorithm = ckpt.get("algorithm", "vanilla").strip().lower()
|
| 188 |
NetClass = DuelingDQNNetwork if "dueling" in algorithm else DQNNetwork
|
| 189 |
+
in_ch = ckpt["state_dict"]["conv.0.weight"].shape[1]
|
| 190 |
+
net = NetClass(grid_size=saved_gs, input_channels=in_ch)
|
| 191 |
net.load_state_dict(ckpt["state_dict"])
|
| 192 |
net.eval()
|
| 193 |
return net, saved_gs
|