# spatial-memory-checkpoints

DD-PPO PointNav checkpoints (Habitat, GPS-PointGoal task), full training
trajectory from initialisation to convergence.

| folder               | # checkpoints | frames per checkpoint |
| -------------------- | ------------- | --------------------- |
| `blind/`             | 35 (`0..34`)  | 10.06 M               |
| `coarse/`            | 50 (`0..49`)  |  5.0 M                |
| `foveated/`          | 50 (`0..49`)  |  5.0 M                |
| `foveated_logpolar/` | 50 (`0..49`)  |  5.0 M                |
| `uniform/`           | 50 (`0..49`)  |  5.0 M                |

The number of frames per checkpoint differs across folders, so to compare runs
at the same training step, convert the checkpoint index to an absolute frame
count (for example, `blind/ckpt.20.pth` and `coarse/ckpt.40.pth` both sit at ≈ 200 M frames).
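
The conversion can be sketched as below. This is a minimal illustration, not part of the released code: the helper names `ckpt_to_frames_m` and `nearest_ckpt` are hypothetical, and it assumes the absolute frame count is simply `index × frames per checkpoint` (index 0 being initialisation), which matches the 200 M example above.

```python
# Frames per checkpoint (in millions) and checkpoint counts, copied from the table above.
FRAMES_PER_CKPT_M = {
    "blind": 10.06,
    "coarse": 5.0,
    "foveated": 5.0,
    "foveated_logpolar": 5.0,
    "uniform": 5.0,
}
NUM_CKPTS = {
    "blind": 35,
    "coarse": 50,
    "foveated": 50,
    "foveated_logpolar": 50,
    "uniform": 50,
}


def ckpt_to_frames_m(folder: str, index: int) -> float:
    """Approximate frame count (millions) at a checkpoint.

    Assumption: frames = index * frames_per_checkpoint.
    """
    if not 0 <= index < NUM_CKPTS[folder]:
        raise ValueError(f"{folder} has no checkpoint {index}")
    return index * FRAMES_PER_CKPT_M[folder]


def nearest_ckpt(folder: str, target_frames_m: float) -> int:
    """Index of the checkpoint in `folder` closest to a target frame count (millions)."""
    index = round(target_frames_m / FRAMES_PER_CKPT_M[folder])
    # Clamp to the range of checkpoints that actually exist in the folder.
    return max(0, min(NUM_CKPTS[folder] - 1, index))


if __name__ == "__main__":
    target = ckpt_to_frames_m("coarse", 40)
    print(target, nearest_ckpt("blind", target))
```

With these helpers, `nearest_ckpt("blind", ckpt_to_frames_m("coarse", 40))` picks the `blind/` checkpoint to compare against `coarse/ckpt.40.pth`.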

## Load a checkpoint

```python
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="alunxu/spatial-memory-checkpoints",
    filename="foveated/ckpt.49.pth",
)
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
state_dict = ckpt["state_dict"]
config     = ckpt["config"]
```

Each `.pth` is a habitat-baselines checkpoint with keys `state_dict`,
`config`, and `extra_state`.
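
Before rebuilding the policy, it can help to confirm that a loaded checkpoint has the three keys listed above. The following is a small sketch under that assumption only; `summarize_checkpoint` is a hypothetical helper, not something provided by habitat-baselines, and it relies on nothing beyond each `state_dict` value exposing `numel()` (as `torch.Tensor` does).

```python
from typing import Any, Dict, Mapping

# Keys the README says every checkpoint contains.
EXPECTED_KEYS = ("state_dict", "config", "extra_state")


def summarize_checkpoint(ckpt: Mapping[str, Any]) -> Dict[str, int]:
    """Check the expected top-level keys and count tensors and parameters."""
    missing = [key for key in EXPECTED_KEYS if key not in ckpt]
    if missing:
        raise KeyError(f"checkpoint is missing keys: {missing}")
    state_dict = ckpt["state_dict"]
    # Sum element counts over every entry that looks like a tensor.
    num_parameters = sum(
        value.numel() for value in state_dict.values() if hasattr(value, "numel")
    )
    return {"num_tensors": len(state_dict), "num_parameters": num_parameters}
```

Calling `summarize_checkpoint(ckpt)` on the dictionary returned by `torch.load` above gives a quick size check before the weights are loaded into a policy.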

## Rebuild the policy and run rollouts

```python
from habitat_baselines.common.baseline_registry import baseline_registry

# `env` is not defined in this snippet: first build a Habitat environment
# from the checkpoint's config (env_config = config.habitat).
policy_cls = baseline_registry.get_policy(
    config.habitat_baselines.rl.policy.name)
policy = policy_cls.from_config(
    config=config,
    observation_space=env.observation_space,
    action_space=env.action_space,
)
policy.load_state_dict(state_dict)
policy.eval()

# policy.act(...) returns (action, recurrent_hidden_states) where
# recurrent_hidden_states has shape (num_envs, num_layers, hidden_dim).
# Pass it back at the next step to keep the recurrent state.
```

Code: <https://github.com/alunxu/foveated-cog-map>.