
spatial-memory-checkpoints

DD-PPO PointNav checkpoints (Habitat, GPS-PointGoal task). Each folder holds a full training trajectory, from initialisation to convergence.

| folder | # checkpoints | frames per checkpoint |
|---|---|---|
| blind/ | 35 (0..34) | 10.06 M |
| coarse/ | 50 (0..49) | 5.0 M |
| foveated/ | 50 (0..49) | 5.0 M |
| foveated_logpolar/ | 50 (0..49) | 5.0 M |
| uniform/ | 50 (0..49) | 5.0 M |

Frames per checkpoint differ across folders, so to compare runs at the same training step, convert the checkpoint index to an absolute frame count (index × frames per checkpoint). For example, blind/ckpt.20.pth and coarse/ckpt.40.pth are both at ≈ 200 M frames.
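To pair checkpoints across folders, the conversion can be scripted. The sketch below hard-codes the counts and frames per checkpoint from the table and, like the example above, assumes a checkpoint's frame count is simply its index times its folder's frames per checkpoint. The helper names (`ckpt_filename`, `frames_at`, `nearest_ckpt`) are hypothetical and not part of the repo.

```python
"""Align checkpoints from different folders by absolute training frames (sketch)."""

# (number of checkpoints, frames per checkpoint), copied from the table above.
FOLDERS = {
    "blind": (35, 10_060_000),
    "coarse": (50, 5_000_000),
    "foveated": (50, 5_000_000),
    "foveated_logpolar": (50, 5_000_000),
    "uniform": (50, 5_000_000),
}


def ckpt_filename(folder: str, index: int) -> str:
    """Repo-relative path, usable as hf_hub_download(filename=...)."""
    count, _ = FOLDERS[folder]
    if not 0 <= index < count:
        raise IndexError(f"{folder}/ has checkpoints 0..{count - 1}, got {index}")
    return f"{folder}/ckpt.{index}.pth"


def frames_at(folder: str, index: int) -> int:
    """Approximate absolute frame count: index * frames per checkpoint (assumed)."""
    _, per_ckpt = FOLDERS[folder]
    ckpt_filename(folder, index)  # raises IndexError for an out-of-range index
    return index * per_ckpt


def nearest_ckpt(folder: str, frames: int) -> int:
    """Index in `folder` whose frame count is closest to `frames`, clamped to range."""
    count, per_ckpt = FOLDERS[folder]
    return min(max(round(frames / per_ckpt), 0), count - 1)


if __name__ == "__main__":
    f = frames_at("blind", 20)
    j = nearest_ckpt("coarse", f)
    print(ckpt_filename("blind", 20), f, "->", ckpt_filename("coarse", j), frames_at("coarse", j))
```

Clamping in `nearest_ckpt` matters for the blind run, which extends past the last checkpoint of the 5.0 M-frame folders; frame counts beyond that map to the final checkpoint rather than a non-existent file.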

Load a checkpoint

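```python
import torch
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(
    repo_id="alunxu/spatial-memory-checkpoints",
    filename="foveated/ckpt.49.pth",
)
# weights_only=False is needed because the checkpoint also pickles the config
# object, not just tensors; only use it on files you trust.
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
state_dict = ckpt["state_dict"]
config     = ckpt["config"]
```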

Each .pth is a habitat-baselines checkpoint with keys state_dict, config, and extra_state.
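When exact alignment matters, a frame count stored inside the checkpoint can be preferred over the index-based estimate. habitat-baselines PPO trainers commonly write their step counter into `extra_state` under a `"step"` key, but whether these particular files carry it is an assumption. The hypothetical helper below works on the plain dict returned by `torch.load`: it checks for the three keys, reads `extra_state["step"]` when present, and otherwise falls back to index × frames per checkpoint.

```python
from typing import Any, Mapping, Optional

EXPECTED_KEYS = ("state_dict", "config", "extra_state")


def training_frames(ckpt: Mapping[str, Any], index: Optional[int] = None,
                    frames_per_ckpt: Optional[int] = None) -> int:
    """Frames of experience behind a loaded checkpoint dict.

    Prefers extra_state["step"] (assumed to hold the trainer's step counter);
    falls back to index * frames_per_ckpt when that key is missing.
    """
    missing = [k for k in EXPECTED_KEYS if k not in ckpt]
    if missing:
        raise KeyError(f"not a habitat-baselines checkpoint, missing {missing}")
    step = (ckpt["extra_state"] or {}).get("step")
    if step is not None:
        return int(step)
    if index is None or frames_per_ckpt is None:
        raise ValueError("no extra_state['step']; pass index and frames_per_ckpt")
    return index * frames_per_ckpt
```

Usage, with `ckpt` loaded as above: `training_frames(ckpt, index=49, frames_per_ckpt=5_000_000)`.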

Rebuild the policy and run rollouts

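```python
from habitat_baselines.common.baseline_registry import baseline_registry

# `env` must be built from the checkpoint's own config (env config = config.habitat),
# e.g. with habitat.gym.make_gym_from_config(config), so that its observation and
# action spaces match those used in training.
policy_cls = baseline_registry.get_policy(
    config.habitat_baselines.rl.policy.name)
policy = policy_cls.from_config(
    config=config,
    observation_space=env.observation_space,
    action_space=env.action_space,
)

# habitat-baselines may store the PPO agent's state dict, whose keys carry an
# "actor_critic." prefix; strip it (if present) before loading into the bare policy.
prefix = "actor_critic."
if any(k.startswith(prefix) for k in state_dict):
    state_dict = {k[len(prefix):]: v for k, v in state_dict.items()
                  if k.startswith(prefix)}
policy.load_state_dict(state_dict)
policy.eval()

# policy.act(observations, rnn_hidden_states, prev_actions, masks) returns, depending
# on the habitat-baselines version, a tuple (value, action, action_log_probs,
# rnn_hidden_states) or a PolicyActionData with .actions and .rnn_hidden_states.
# rnn_hidden_states has shape (num_envs, policy.net.num_recurrent_layers, hidden_size);
# for an LSTM, num_recurrent_layers is 2 * num_layers (h and c are stacked).
# Pass it back at the next step to keep the recurrent state.
```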
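The hidden-state hand-off, together with the episode-reset rule, can be illustrated without Habitat. In habitat-baselines, a `masks` entry of 0 (False) marks the first step of a new episode and makes the recurrent encoder start from a zeroed hidden state. The sketch below uses a hypothetical stand-in `toy_act` in place of `policy.act` and plain lists instead of tensors; it shows only the bookkeeping for a single env (the real `act` also takes `prev_actions`), not the trained network.

```python
"""Stand-in rollout loop showing how the recurrent state is carried (sketch).

`toy_act` is a hypothetical placeholder for policy.act; hidden states are
plain lists instead of torch tensors of shape (num_envs, layers, hidden).
"""
from typing import List, Tuple

HIDDEN_SIZE = 4


def toy_act(obs: float, hidden: List[float], mask: bool) -> Tuple[int, List[float]]:
    # Mirror habitat-baselines: mask == False (episode start) zeroes the state.
    if not mask:
        hidden = [0.0] * HIDDEN_SIZE
    new_hidden = [0.5 * h + obs for h in hidden]  # toy recurrence
    action = 0 if new_hidden[0] > 1.0 else 1      # toy action rule
    return action, new_hidden


def rollout(observations: List[float], episode_starts: List[bool]) -> List[List[float]]:
    """Run toy_act over a stream of observations; return the hidden state after each step."""
    hidden = [0.0] * HIDDEN_SIZE
    history = []
    for obs, start in zip(observations, episode_starts):
        mask = not start                               # 0 at the first step of an episode
        _action, hidden = toy_act(obs, hidden, mask)   # carry the state forward
        history.append(hidden)
    return history
```

For example, `rollout([1.0, 1.0, 1.0], [True, False, True])` accumulates state on the second step and resets it on the third, when a new episode begins.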

Code: https://github.com/alunxu/foveated-cog-map