---
library_name: pytorch
tags:
- reinforcement-learning
- robotics
- manipulation
- hand-object-interaction
- pytorch
license: other
---
# DragMesh-2 Evaluation Checkpoints
This repository contains the PyTorch policy checkpoints used for the
DragMesh-2 main-table evaluation. The release includes seven policy variants
evaluated on seven object-part tasks, for a total of 49 checkpoints.
The corresponding hand-object interaction trajectories are available in the
[DragMesh-2 dataset](https://huggingface.co/datasets/AIGeeksGroup/DragMesh-2).
The referenced objects and actionable parts originate from
[GAPartNet](https://pku-epic.github.io/GAPartNet/); GAPartNet assets are not
included in this model repository.
## Repository layout
```text
checkpoints/
  <variant>/
    object_<object_id>/
      handle_<part_id>/
        policy.pth
model_manifest.jsonl
```
Each path identifies the policy variant, the source object, and the
manipulated part. `model_manifest.jsonl` records the normalized checkpoint
path, original relative path, object category, file size, and SHA-256 digest.
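
The SHA-256 digests in the manifest can be used to check a download. The sketch below is one way to do that, not the project's own tooling. It assumes each manifest line is a JSON object whose path and digest live under the keys `path` and `sha256`; those key names are hypothetical, since the actual field names are not listed here, so inspect the file and adjust them.

```python
import hashlib
import json
from pathlib import Path


def sha256_of(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of a file, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def find_mismatches(manifest_path, root=".", path_key="path", digest_key="sha256"):
    """List manifest paths whose file is missing or has a different digest.

    path_key and digest_key are assumed names, not confirmed field names.
    """
    mismatches = []
    with open(manifest_path, encoding="utf-8") as manifest:
        for line in manifest:
            if not line.strip():
                continue  # tolerate blank lines
            record = json.loads(line)
            target = Path(root) / record[path_key]
            expected = record[digest_key].lower()
            if not target.is_file() or sha256_of(target) != expected:
                mismatches.append(record[path_key])
    return mismatches
```

An empty list from `find_mismatches("model_manifest.jsonl")` would mean every listed file is present and matches its recorded digest.
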
## Policy variants
| Name | Variant |
| --- | --- |
| `state` | State-only PPO baseline |
| `history` | Flat-history PPO baseline |
| `gla` | GLA without the auxiliary objective |
| `pica` | PICA without the GLA auxiliary objective (`v2c`) |
| `dragmesh2` | DragMesh-2 PICA policy |
| `gru` | GRU PPO baseline |
| `transformer` | Transformer PPO baseline |
## Evaluation tasks
| Category | Object ID | Part ID |
| --- | --- | --- |
| Dishwasher | `12583` | `handle_1` |
| Microwave | `7310` | `handle_1` |
| StorageFurniture | `45261` | `handle_7` |
| StorageFurniture | `45661` | `handle_3` |
| StorageFurniture | `45936` | `handle_1` |
| StorageFurniture | `46440` | `handle_5` |
| StorageFurniture | `48513` | `handle_2` |
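
As an illustration of how the variants and tasks combine, the following sketch enumerates the expected checkpoint paths. The variant names and task rows are copied from the tables above; the helper `checkpoint_path` is a hypothetical name, and it assumes the part directory is spelled exactly like the Part ID column (for example `handle_3`).

```python
from itertools import product
from pathlib import Path

VARIANTS = ["state", "history", "gla", "pica", "dragmesh2", "gru", "transformer"]

# (category, object ID, part ID) rows copied from the evaluation-task table.
TASKS = [
    ("Dishwasher", "12583", "handle_1"),
    ("Microwave", "7310", "handle_1"),
    ("StorageFurniture", "45261", "handle_7"),
    ("StorageFurniture", "45661", "handle_3"),
    ("StorageFurniture", "45936", "handle_1"),
    ("StorageFurniture", "46440", "handle_5"),
    ("StorageFurniture", "48513", "handle_2"),
]


def checkpoint_path(variant, object_id, part_id, root="checkpoints"):
    """Build the expected policy.pth path for one variant and one task."""
    return Path(root) / variant / f"object_{object_id}" / part_id / "policy.pth"


# Seven variants times seven tasks gives the full set of released checkpoints.
ALL_CHECKPOINTS = [
    checkpoint_path(variant, object_id, part_id)
    for variant, (_category, object_id, part_id) in product(VARIANTS, TASKS)
]
```
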
## Checkpoint format
Each `policy.pth` is a PyTorch training checkpoint with these top-level fields:
- `model`: policy state dictionary;
- `optimizer`: optimizer state;
- `running_mean_std`: observation normalization state;
- `reward_mean_std`: reward normalization state;
- `epoch`, `frame`, and `last_mean_rewards`: training metadata;
- `env_state`: serialized environment state when available.
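
A quick sanity check on a loaded checkpoint is to compare its top-level keys with the fields listed above. The helper below is a minimal sketch with a hypothetical name; treating every field except `env_state` as required is an assumption drawn from the "when available" wording, not a documented guarantee.

```python
REQUIRED_FIELDS = (
    "model",
    "optimizer",
    "running_mean_std",
    "reward_mean_std",
    "epoch",
    "frame",
    "last_mean_rewards",
)
OPTIONAL_FIELDS = ("env_state",)  # described as present only when available


def describe_checkpoint(checkpoint):
    """Report listed fields that are missing and keys that are not listed."""
    keys = set(checkpoint)
    missing = [name for name in REQUIRED_FIELDS if name not in keys]
    extra = sorted(keys - set(REQUIRED_FIELDS) - set(OPTIONAL_FIELDS))
    return {"missing": missing, "extra": extra}
```

The function works on any mapping, so it can be applied to the dictionary returned by `torch.load`.
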
Load checkpoints only in a trusted environment. With PyTorch 2.6 or later,
the checkpoint can be read in weights-only mode by allowlisting the NumPy
scalar types stored in its training metadata:
```python
from pathlib import Path

import numpy as np
import torch

checkpoint_path = Path(
    "checkpoints/dragmesh2/"
    "object_45661/handle_3/policy.pth"
)

# NumPy types pickled in the training metadata; weights-only loading
# rejects them unless they are allowlisted.
safe_globals = [
    np.core.multiarray.scalar,
    np.dtype,
    np.dtypes.Float32DType,
]

with torch.serialization.safe_globals(safe_globals):
    checkpoint = torch.load(
        checkpoint_path,
        map_location="cpu",
        weights_only=True,
    )

# Policy weights only; optimizer and normalization state sit under other keys.
policy_state_dict = checkpoint["model"]
```
The policy architecture and observation configuration must match the
corresponding DragMesh-2 experiment when restoring a checkpoint.
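
One way to confirm that an instantiated policy matches a checkpoint before restoring it is to compare parameter names and shapes. The sketch below is a generic pre-check with a hypothetical helper name, not part of the DragMesh-2 code. It only assumes that the values in both mappings expose a `.shape` attribute, as PyTorch tensors do.

```python
def compare_state_dicts(expected, loaded):
    """Compare parameter names and shapes of two state-dict-like mappings."""
    shared = set(expected) & set(loaded)
    return {
        # Parameters the model defines but the checkpoint lacks.
        "missing": sorted(set(expected) - set(loaded)),
        # Parameters in the checkpoint that the model does not define.
        "unexpected": sorted(set(loaded) - set(expected)),
        # Parameters present in both but with different shapes.
        "shape_mismatch": sorted(
            name
            for name in shared
            if tuple(expected[name].shape) != tuple(loaded[name].shape)
        ),
    }
```

With a model object named `policy` (a placeholder for whichever architecture the experiment used), a call such as `compare_state_dicts(policy.state_dict(), checkpoint["model"])` returning three empty lists suggests the architecture matches; any non-empty list points to a configuration difference.
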