# spatial-memory-checkpoints

DD-PPO PointNav checkpoints (Habitat, GPS-PointGoal task) covering the full
training trajectory, from initialisation to convergence.

| folder               | # checkpoints | frames per checkpoint |
| -------------------- | ------------- | --------------------- |
| `blind/`             | 35 (`0..34`)  | 10.06 M               |
| `coarse/`            | 50 (`0..49`)  | 5.0 M                 |
| `foveated/`          | 50 (`0..49`)  | 5.0 M                 |
| `foveated_logpolar/` | 50 (`0..49`)  | 5.0 M                 |
| `uniform/`           | 50 (`0..49`)  | 5.0 M                 |

The number of frames per checkpoint differs across folders (10.06 M for
`blind/`, 5.0 M for the others), so to compare runs at the same training
step, convert the checkpoint index to an absolute frame count. For example,
`blind/ckpt.20.pth` ≈ `coarse/ckpt.40.pth` ≈ 200 M frames.
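The index-to-frames conversion can be sketched as follows. The per-folder values come from the table; the names `FRAMES_PER_CKPT`, `LAST_INDEX`, `frames_at`, and `nearest_index` are illustrative only, and the sketch assumes that checkpoint `k` sits at roughly `k × frames per checkpoint`, which is what the example above implies.

```python
# Frames between consecutive checkpoints, taken from the table.
FRAMES_PER_CKPT = {
    "blind": 10.06e6,
    "coarse": 5.0e6,
    "foveated": 5.0e6,
    "foveated_logpolar": 5.0e6,
    "uniform": 5.0e6,
}

# Highest checkpoint index available in each folder, from the table.
LAST_INDEX = {
    "blind": 34,
    "coarse": 49,
    "foveated": 49,
    "foveated_logpolar": 49,
    "uniform": 49,
}


def frames_at(folder: str, index: int) -> float:
    """Approximate absolute frame count of `<folder>/ckpt.<index>.pth`.

    Assumption: checkpoint k was saved after about k * frames-per-checkpoint.
    """
    return index * FRAMES_PER_CKPT[folder]


def nearest_index(folder: str, target_frames: float) -> int:
    """Index in `folder` whose frame count is closest to `target_frames`."""
    idx = round(target_frames / FRAMES_PER_CKPT[folder])
    return max(0, min(LAST_INDEX[folder], idx))  # clamp to available range


target = frames_at("blind", 20)
match = nearest_index("coarse", target)
print(f"blind/ckpt.20.pth ~ {target / 1e6:.1f} M frames -> coarse/ckpt.{match}.pth")
```

The clamp keeps the result inside the indices that actually exist, so a `blind/` checkpoint late in training maps to the last checkpoint of a shorter run rather than to a missing file.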

## Load a checkpoint

```python
import torch
from huggingface_hub import hf_hub_download

# Download one checkpoint file from the Hub (cached locally after the first call).
ckpt_path = hf_hub_download(
    repo_id="alunxu/spatial-memory-checkpoints",
    filename="foveated/ckpt.49.pth",
)
# weights_only=False unpickles arbitrary Python objects (the stored config),
# so only load checkpoint files you trust.
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
state_dict = ckpt["state_dict"]
config = ckpt["config"]
```

Each `.pth` file is a habitat-baselines checkpoint with the top-level keys
`state_dict`, `config`, and `extra_state`.
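To check that a downloaded file has this layout before building a policy, an inspection helper along these lines may be useful. `describe_checkpoint` is a hypothetical name, not part of habitat-baselines; it relies only on the three top-level keys listed above and on `state_dict` mapping parameter names to tensors.

```python
from collections import Counter

EXPECTED_KEYS = ("state_dict", "config", "extra_state")


def describe_checkpoint(ckpt: dict) -> dict:
    """Summarise a loaded checkpoint dict without touching the weights."""
    missing = [k for k in EXPECTED_KEYS if k not in ckpt]
    state_dict = ckpt.get("state_dict", {})
    # Group parameter names by their first dotted component, e.g. "net".
    prefixes = Counter(name.split(".", 1)[0] for name in state_dict)
    return {
        "missing_keys": missing,
        "num_entries": len(state_dict),
        "top_level_prefixes": dict(prefixes),
    }


# Stand-in for the output of torch.load(...) so the sketch runs on its own;
# the parameter names here are made up.
fake_ckpt = {
    "state_dict": {"net.fc.weight": None, "net.fc.bias": None, "head.weight": None},
    "config": {},
}
print(describe_checkpoint(fake_ckpt))
```

With a real file, pass the `ckpt` dict returned by `torch.load` instead of `fake_ckpt`.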

## Rebuild the policy and run rollouts

```python
from habitat_baselines.common.baseline_registry import baseline_registry

# `state_dict` and `config` come from the loaded checkpoint.
# `env` is not created here: build it from the checkpoint's config
# (env_config = config.habitat) before running this block.
policy_cls = baseline_registry.get_policy(
    config.habitat_baselines.rl.policy.name
)
policy = policy_cls.from_config(
    config=config,
    observation_space=env.observation_space,
    action_space=env.action_space,
)
policy.load_state_dict(state_dict)
policy.eval()

# policy.act(...) returns (action, recurrent_hidden_states), where
# recurrent_hidden_states has shape (num_envs, num_layers, hidden_dim).
# Pass it back in at the next step to carry the recurrent state forward.
```
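If `load_state_dict` reports unexpected keys, one possible cause is a name prefix on the saved parameters: habitat-baselines trainers commonly save the state dict of the agent wrapper, whose parameter names start with `actor_critic.`. Whether these checkpoints carry that prefix is an assumption, not something stated here, so the sketch below strips it only when present. `strip_prefix` is a hypothetical helper, not a habitat-baselines function.

```python
def strip_prefix(state_dict: dict, prefix: str = "actor_critic.") -> dict:
    """Return a copy of `state_dict` with `prefix` removed from matching keys.

    Keys that do not start with `prefix` are kept unchanged.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Made-up parameter names, used only to show the effect.
example = {"actor_critic.net.fc.weight": 1, "net.fc.bias": 2}
cleaned = strip_prefix(example)
print(sorted(cleaned))

# With a real checkpoint this would be used as:
#   policy.load_state_dict(strip_prefix(state_dict))
```

Because unprefixed keys pass through untouched, the helper is safe to apply even when the checkpoint already matches the policy's parameter names.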

Code: <https://github.com/alunxu/foveated-cog-map>.