---
license: mit
library_name: pytorch
tags:
- robotics
- progress-estimation
- behavior-cloning
---

# SARM Progress Prediction

Stage-aware progress prediction model for robot manipulation tasks.

## Model Description

SARM predicts:

- **Progress**: how far through the task the execution is (0.0 to 1.0)
- **Stage**: which stage of the task is being executed

The model uses a transformer architecture to process sequences of RGB images and robot states.

**Task**: clearing_food_from_table_into_fridge

**Dataset**: IliaLarchenko/behavior_224_rgb

## Model Details

### Architecture

- **Type**: Transformer with dual prediction heads (stage classification + progress regression)
- **Model dimension**: 768
- **Attention heads**: 12
- **Transformer layers**: 8
- **MLP dimension**: 512
- **Number of stages**: 100
- **Number of tasks**: 50

### Training Details

- **Checkpoint**: `best_model.pt`
- **Training step**: 4800
- **Epoch**: unknown
- **Training loss**: unknown
- **Validation loss**: 1.0866
- **Batch size**: 16
- **Learning rate**: 0.0001
- **Max sequence length**: 13

## Usage

### Download and Load Model

```python
# hf_model_hub and model are local modules that accompany this model's code
from hf_model_hub import download_model_from_hub
from model import SARM
import torch
import json

# Download the checkpoint and config from the Hub
files = download_model_from_hub(
    repo_id="YOUR_USERNAME/YOUR_REPO",
    checkpoint_name="best_model.pt",
    output_dir="./downloaded_model",
)

# Load the model configuration
with open(files["config"], "r") as f:
    config = json.load(f)

# Build the model from the config's "model" section
model_config = config["model"]
model = SARM(
    d_model=model_config["d_model"],
    n_heads=model_config["n_heads"],
    n_layers=model_config["n_layers"],
    d_mlp=model_config["d_mlp"],
    num_stages=model_config["num_stages"],
    d_state=model_config["d_state"],
    num_tasks=model_config["num_tasks"],
)

# Load the checkpoint weights (mapped to CPU so this also works without a GPU)
checkpoint = torch.load(files["checkpoint"], map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()
```

### Run Inference

```python
# Assuming you have images, states, tasks, and padding_mask prepared
# (see the input-preparation sketch below)
with torch.no_grad():
    stage_logits, progress = model(images, states, tasks, padding_mask)

# Get predictions for the last frame in the sequence
predicted_stage = stage_logits[:, -1].argmax(dim=-1)
predicted_progress = progress[:, -1]
```
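### Prepare Inputs

The dataset-specific preprocessing code is not reproduced on this card. The sketch below shows one plausible way to shape the `images`, `states`, `tasks`, and `padding_mask` tensors; the tensor layouts, pixel scaling, and mask convention are assumptions inferred from the configuration (`max_sequence_length: 13`, `image_size: 224`, `d_state: 256`, `task_number: 25`) and should be checked against the repository's dataset code.

```python
import torch

B, T = 1, 13            # batch size, max sequence length from the config
C, H, W = 3, 224, 224   # RGB frames at image_size = 224
D_STATE = 256           # robot state dimension (d_state)

images = torch.zeros(B, T, C, H, W)   # frame pixels, assumed scaled to [0, 1]
states = torch.zeros(B, T, D_STATE)   # one robot state vector per frame
tasks = torch.tensor([25])            # task index (task_number from the config)

# If the clip has fewer than T real frames, mark the padded tail.
num_valid = 10                                      # e.g. 10 real frames
padding_mask = torch.zeros(B, T, dtype=torch.bool)
padding_mask[:, num_valid:] = True                  # True = padded (assumed convention)
```

Filled with real frames and states, these tensors can be passed directly to the inference snippet above.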
## Training Data

This model was trained on the **IliaLarchenko/behavior_224_rgb** dataset for robot manipulation tasks.

- Training episodes: 90
- Validation episodes: 15

## Intended Use

- Progress estimation for robot manipulation tasks
- Stage classification for multi-step tasks
- Adaptive window sampling for VLA training
- Task monitoring and intervention detection

## Limitations

- Trained on specific tasks from the BEHAVIOR dataset
- Requires RGB images (224x224) and robot state information
- Fixed maximum sequence length (13 frames)

## Citation

If you use this model, please cite:

```bibtex
@misc{sarm-model,
  author = {Your Name},
  title = {SARM Progress Prediction},
  year = {2025},
  publisher = {HuggingFace},
  url = {https://huggingface.co/YOUR_USERNAME/YOUR_REPO}
}
```

## Training Configuration

<details>
<summary>Click to expand full training configuration</summary>

```json
{
  "metadata": {
    "model_name": "SARM Progress Prediction",
    "description": "Stage-aware progress prediction model for robot manipulation tasks",
    "task": "clearing_food_from_table_into_fridge",
    "task_number": 25,
    "dataset": "IliaLarchenko/behavior_224_rgb",
    "version": "1.0",
    "author": "Your Name",
    "tags": ["robotics", "progress-estimation", "behavior-cloning"]
  },
  "model": {
    "d_model": 768,
    "n_heads": 12,
    "n_layers": 8,
    "d_mlp": 512,
    "num_stages": 100,
    "d_state": 256,
    "num_tasks": 50
  },
  "training": {
    "max_steps": 10000,
    "learning_rate": 0.0001,
    "weight_decay": 0.0001,
    "batch_size": 16,
    "gradient_accumulation_steps": 4,
    "max_grad_norm": 1.0,
    "scheduler": "cosine",
    "stage_loss_weight": 1.0,
    "progress_loss_weight": 1.0,
    "validation_steps": 100,
    "save_steps": 200
  },
  "data": {
    "max_sequence_length": 13,
    "image_size": 224,
    "num_workers": 10,
    "val_workers": 10,
    "val_samples": 500,
    "train_episodes": [
      1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
      19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,
      35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
      51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66,
      67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82,
      83, 84, 85, 86, 87, 88, 89, 90
    ],
    "val_episodes": [91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105],
    "seed": 42
  },
  "logging": {
    "project_name": "sarm-training",
    "run_name": null,
    "log_freq": 10,
    "checkpoint_dir": "checkpoints_sarm_25_2"
  }
}
```

</details>
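The configuration above can be parsed to recover a few quantities quoted earlier on this card. A minimal sketch, assuming the config was saved to `./downloaded_model/config.json` by the download step in the Usage section:

```python
import json

with open("./downloaded_model/config.json", "r") as f:
    config = json.load(f)

train_cfg = config["training"]
data_cfg = config["data"]

# Gradient accumulation means each optimizer step sees 16 * 4 = 64 samples.
effective_batch = train_cfg["batch_size"] * train_cfg["gradient_accumulation_steps"]
print(f"Effective batch size: {effective_batch}")            # 64
print(f"Train episodes: {len(data_cfg['train_episodes'])}")  # 90
print(f"Val episodes: {len(data_cfg['val_episodes'])}")      # 15
```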