---
license: mit
library_name: pytorch
tags:
- robotics
- progress-estimation
- behavior-cloning
---

# SARM Progress Prediction

Stage-aware progress prediction model for robot manipulation tasks.

## Model Description

SARM predicts:
- **Progress**: How far through the task the robot is (0.0 to 1.0)
- **Stage**: Which stage of the task is being executed

The model uses a transformer architecture to process sequences of RGB images and robot states.

- **Task**: clearing_food_from_table_into_fridge
- **Dataset**: IliaLarchenko/behavior_224_rgb

## Model Details

### Architecture
- **Type**: Transformer with dual prediction heads (stage classification + progress regression)
- **Model dimension**: 768
- **Attention heads**: 12
- **Transformer layers**: 8
- **MLP dimension**: 512
- **Number of stages**: 100
- **Number of tasks**: 50

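As a quick sanity check on the values above, the sketch below derives the per-head dimension. It assumes standard multi-head attention, where the model dimension is split evenly across heads; the `SarmDims` class is a hypothetical container for illustration and is not part of the SARM code.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SarmDims:
    """Architecture values copied from this model card (class name is hypothetical)."""
    d_model: int = 768
    n_heads: int = 12
    n_layers: int = 8
    d_mlp: int = 512
    num_stages: int = 100
    num_tasks: int = 50

    def head_dim(self) -> int:
        # Assumes standard multi-head attention: d_model is split evenly across heads.
        if self.d_model % self.n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        return self.d_model // self.n_heads


dims = SarmDims()
print(dims.head_dim())  # 768 // 12 = 64
```
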
### Training Details
- **Checkpoint**: `best_model.pt`
- **Training step**: 4800
- **Epoch**: unknown
- **Training loss**: unknown
- **Validation loss**: 1.0865614609792829
- **Batch size**: 16
- **Learning rate**: 0.0001
- **Max sequence length**: 13

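The batch size listed above combines with the `gradient_accumulation_steps` value of 4 from the training configuration in this card. The sketch below works out the effective batch size under two assumptions that the card does not confirm: that gradients are accumulated over four batches per optimizer update, and that "Training step" counts optimizer updates rather than individual batches.

```python
# Values from the training configuration in this card.
batch_size = 16
gradient_accumulation_steps = 4
checkpoint_step = 4800

# Assumption: each optimizer update accumulates gradients over
# `gradient_accumulation_steps` batches of `batch_size` sequences.
effective_batch_size = batch_size * gradient_accumulation_steps

# Assumption: "Training step" counts optimizer updates, not individual batches.
sequences_seen = checkpoint_step * effective_batch_size

print(effective_batch_size)  # 64
print(sequences_seen)        # 307200
```
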
## Usage

### Download and Load Model

```python
# `hf_model_hub` and `model` are assumed to come from the SARM training code, not from PyPI
from hf_model_hub import download_model_from_hub
from model import SARM
import torch
import json

# Download model and config
files = download_model_from_hub(
    repo_id="YOUR_USERNAME/YOUR_REPO",
    checkpoint_name="best_model.pt",
    output_dir="./downloaded_model"
)

# Load config
with open(files["config"], "r") as f:
    config = json.load(f)

# Create model from the "model" section of the config
model_config = config["model"]
model = SARM(
    d_model=model_config["d_model"],
    n_heads=model_config["n_heads"],
    n_layers=model_config["n_layers"],
    d_mlp=model_config["d_mlp"],
    num_stages=model_config["num_stages"],
    d_state=model_config["d_state"],
    num_tasks=model_config["num_tasks"],
)

# Load checkpoint onto the CPU so this also works on machines without a GPU
checkpoint = torch.load(files["checkpoint"], map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # switch to inference mode
```

### Run Inference

```python
# Assumes `images`, `states`, `tasks`, and `padding_mask` are already prepared
with torch.no_grad():
    stage_logits, progress = model(images, states, tasks, padding_mask)

# Get predictions for the last frame
predicted_stage = stage_logits[:, -1].argmax(dim=-1)
predicted_progress = progress[:, -1]
```

## Training Data

This model was trained on the **IliaLarchenko/behavior_224_rgb** dataset for robot manipulation tasks.

- Training episodes: 90
- Validation episodes: 15

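The episode counts above follow from the `train_episodes` and `val_episodes` lists in the training configuration of this card (IDs 1 to 90 and 91 to 105). A minimal sketch that rebuilds the two lists and checks that they do not overlap:

```python
# Episode IDs as listed in the training configuration of this card.
train_episodes = list(range(1, 91))    # 1..90
val_episodes = list(range(91, 106))    # 91..105

# The two splits should not share any episode.
overlap = set(train_episodes) & set(val_episodes)

print(len(train_episodes), len(val_episodes), len(overlap))  # 90 15 0
```
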
## Intended Use

- Progress estimation for robot manipulation tasks
- Stage classification for multi-step tasks
- Adaptive window sampling for VLA training
- Task monitoring and intervention detection

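As an illustration of the task monitoring use case, the sketch below flags a possible stall when predicted progress barely increases over a short window. The `is_stalled` helper, the window size, and the threshold are hypothetical and not part of SARM; a real monitor would need values tuned on recorded episodes.

```python
from typing import Sequence


def is_stalled(progress: Sequence[float], window: int = 5, min_gain: float = 0.01) -> bool:
    """Return True if progress rose by less than `min_gain` across the last `window` predictions.

    `window` and `min_gain` are illustrative values, not tuned for this model.
    """
    if len(progress) < window:
        return False  # not enough history to decide
    recent = progress[-window:]
    return (recent[-1] - recent[0]) < min_gain


advancing = [0.10, 0.15, 0.21, 0.28, 0.33]
stuck = [0.40, 0.41, 0.40, 0.40, 0.405]
print(is_stalled(advancing), is_stalled(stuck))  # False True
```

The per-step values would come from `predicted_progress` collected over successive inference calls.
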
## Limitations

- Trained on specific tasks from the BEHAVIOR dataset
- Requires RGB images (224x224) and robot state information
- Fixed sequence length input

## Citation

If you use this model, please cite:

```bibtex
@misc{sarm-model,
  author = {Your Name},
  title = {SARM Progress Prediction},
  year = {2025},
  publisher = {HuggingFace},
  url = {https://huggingface.co/YOUR_USERNAME/YOUR_REPO}
}
```

## Training Configuration

<details>
<summary>Click to expand full training configuration</summary>

```json
{
  "metadata": {
    "model_name": "SARM Progress Prediction",
    "description": "Stage-aware progress prediction model for robot manipulation tasks",
    "task": "clearing_food_from_table_into_fridge",
    "task_number": 25,
    "dataset": "IliaLarchenko/behavior_224_rgb",
    "version": "1.0",
    "author": "Your Name",
    "tags": [
      "robotics",
      "progress-estimation",
      "behavior-cloning"
    ]
  },
  "model": {
    "d_model": 768,
    "n_heads": 12,
    "n_layers": 8,
    "d_mlp": 512,
    "num_stages": 100,
    "d_state": 256,
    "num_tasks": 50
  },
  "training": {
    "max_steps": 10000,
    "learning_rate": 0.0001,
    "weight_decay": 0.0001,
    "batch_size": 16,
    "gradient_accumulation_steps": 4,
    "max_grad_norm": 1.0,
    "scheduler": "cosine",
    "stage_loss_weight": 1.0,
    "progress_loss_weight": 1.0,
    "validation_steps": 100,
    "save_steps": 200
  },
  "data": {
    "max_sequence_length": 13,
    "image_size": 224,
    "num_workers": 10,
    "val_workers": 10,
    "val_samples": 500,
    "train_episodes": [
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
      16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
      31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
      46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60,
      61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75,
      76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90
    ],
    "val_episodes": [
      91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105
    ],
    "seed": 42
  },
  "logging": {
    "project_name": "sarm-training",
    "run_name": null,
    "log_freq": 10,
    "checkpoint_dir": "checkpoints_sarm_25_2"
  }
}
```

</details>

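The configuration names a `cosine` scheduler with `learning_rate` 0.0001 and `max_steps` 10000 but gives no further detail. The sketch below shows one common form of cosine decay, assuming no warmup and a final learning rate of 0; the actual schedule used in training may differ, and `cosine_lr` is a hypothetical helper.

```python
import math


def cosine_lr(step: int, base_lr: float = 1e-4, max_steps: int = 10_000) -> float:
    """Cosine decay from `base_lr` at step 0 to 0 at `max_steps`.

    Assumes no warmup and a final learning rate of 0; the configuration
    only says "cosine", so the real schedule may differ.
    """
    step = min(max(step, 0), max_steps)  # clamp to the valid range
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * step / max_steps))


# Learning rate at the start, at the saved checkpoint (step 4800), and at the end.
for step in (0, 4800, 10_000):
    print(step, cosine_lr(step))
```

Under these assumptions the checkpoint at step 4800 would sit just before the midpoint of the decay, at roughly half the base learning rate.
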