attilczuk commited on
Commit
b8d1b9e
·
verified ·
1 Parent(s): 8e2213b

Upload model card for best_model.pt

Browse files
Files changed (1) hide show
  1. README.md +300 -0
README.md ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ library_name: pytorch
4
+ tags:
5
+ - robotics
6
+ - progress-estimation
7
+ - behavior-cloning
8
+ ---
9
+
10
+ # SARM Progress Prediction
11
+
12
+ Stage-aware progress prediction model for robot manipulation tasks
13
+
14
+ ## Model Description
15
+
16
+ SARM predicts:
17
+ - **Progress**: How far through the task (0.0 to 1.0)
18
+ - **Stage**: Which stage of the task is being executed
19
+
20
+ The model uses a transformer architecture to process sequences of RGB images and robot states.
21
+
22
+ **Task**: clearing_food_from_table_into_fridge
23
+ **Dataset**: IliaLarchenko/behavior_224_rgb
24
+
25
+ ## Model Details
26
+
27
+ ### Architecture
28
+ - **Type**: Transformer with dual prediction heads (stage classification + progress regression)
29
+ - **Model dimension**: 768
30
+ - **Attention heads**: 12
31
+ - **Transformer layers**: 8
32
+ - **MLP dimension**: 512
33
+ - **Number of stages**: 100
34
+ - **Number of tasks**: 50
35
+
36
+ ### Training Details
37
+ - **Checkpoint**: `best_model.pt`
38
+ - **Training step**: 4800
39
+ - **Epoch**: unknown
40
+ - **Training loss**: unknown
41
+ - **Validation loss**: 1.0865614609792829
42
+ - **Batch size**: 16
43
+ - **Learning rate**: 0.0001
44
+ - **Max sequence length**: 13
45
+
46
+ ## Usage
47
+
48
+ ### Download and Load Model
49
+
50
+ ```python
51
+ from hf_model_hub import download_model_from_hub
52
+ from model import SARM
53
+ import torch
54
+ import json
55
+
56
+ # Download model and config
57
+ files = download_model_from_hub(
58
+ repo_id="YOUR_USERNAME/YOUR_REPO",
59
+ checkpoint_name="best_model.pt",
60
+ output_dir="./downloaded_model"
61
+ )
62
+
63
+ # Load config
64
+ with open(files["config"], "r") as f:
65
+ config = json.load(f)
66
+
67
+ # Create model
68
+ model_config = config["model"]
69
+ model = SARM(
70
+ d_model=model_config["d_model"],
71
+ n_heads=model_config["n_heads"],
72
+ n_layers=model_config["n_layers"],
73
+ d_mlp=model_config["d_mlp"],
74
+ num_stages=model_config["num_stages"],
75
+ d_state=model_config["d_state"],
76
+ num_tasks=model_config["num_tasks"],
77
+ )
78
+
79
+ # Load checkpoint
80
+ checkpoint = torch.load(files["checkpoint"])
81
+ model.load_state_dict(checkpoint["model_state_dict"])
82
+ model.eval()
83
+ ```
84
+
85
+ ### Run Inference
86
+
87
+ ```python
88
+ # Assuming you have images and states prepared
89
+ with torch.no_grad():
90
+ stage_logits, progress = model(images, states, tasks, padding_mask)
91
+
92
+ # Get predictions for the last frame
93
+ predicted_stage = stage_logits[:, -1].argmax(dim=-1)
94
+ predicted_progress = progress[:, -1]
95
+ ```
96
+
97
+ ## Training Data
98
+
99
+ This model was trained on the **IliaLarchenko/behavior_224_rgb** for robot manipulation tasks.
100
+
101
+ Training episodes: 90 episodes
102
+ Validation episodes: 15 episodes
103
+
104
+ ## Intended Use
105
+
106
+ - Progress estimation for robot manipulation tasks
107
+ - Stage classification for multi-step tasks
108
+ - Adaptive window sampling for VLA training
109
+ - Task monitoring and intervention detection
110
+
111
+ ## Limitations
112
+
113
+ - Trained on specific tasks from BEHAVIOR dataset
114
+ - Requires RGB images (224x224) and robot state information
115
+ - Fixed sequence length input
116
+
117
+ ## Citation
118
+
119
+ If you use this model, please cite:
120
+
121
+ ```bibtex
122
+ @misc{sarm-model,
123
+ author = {Your Name},
124
+ title = {SARM Progress Prediction},
125
+ year = {2025},
126
+ publisher = {HuggingFace},
127
+ url = {https://huggingface.co/YOUR_USERNAME/YOUR_REPO}
128
+ }
129
+ ```
130
+
131
+ ## Training Configuration
132
+
133
+ <details>
134
+ <summary>Click to expand full training configuration</summary>
135
+
136
+ ```json
137
+ {
138
+ "metadata": {
139
+ "model_name": "SARM Progress Prediction",
140
+ "description": "Stage-aware progress prediction model for robot manipulation tasks",
141
+ "task": "clearing_food_from_table_into_fridge",
142
+ "task_number": 25,
143
+ "dataset": "IliaLarchenko/behavior_224_rgb",
144
+ "version": "1.0",
145
+ "author": "Your Name",
146
+ "tags": [
147
+ "robotics",
148
+ "progress-estimation",
149
+ "behavior-cloning"
150
+ ]
151
+ },
152
+ "model": {
153
+ "d_model": 768,
154
+ "n_heads": 12,
155
+ "n_layers": 8,
156
+ "d_mlp": 512,
157
+ "num_stages": 100,
158
+ "d_state": 256,
159
+ "num_tasks": 50
160
+ },
161
+ "training": {
162
+ "max_steps": 10000,
163
+ "learning_rate": 0.0001,
164
+ "weight_decay": 0.0001,
165
+ "batch_size": 16,
166
+ "gradient_accumulation_steps": 4,
167
+ "max_grad_norm": 1.0,
168
+ "scheduler": "cosine",
169
+ "stage_loss_weight": 1.0,
170
+ "progress_loss_weight": 1.0,
171
+ "validation_steps": 100,
172
+ "save_steps": 200
173
+ },
174
+ "data": {
175
+ "max_sequence_length": 13,
176
+ "image_size": 224,
177
+ "num_workers": 10,
178
+ "val_workers": 10,
179
+ "val_samples": 500,
180
+ "train_episodes": [
181
+ 1,
182
+ 2,
183
+ 3,
184
+ 4,
185
+ 5,
186
+ 6,
187
+ 7,
188
+ 8,
189
+ 9,
190
+ 10,
191
+ 11,
192
+ 12,
193
+ 13,
194
+ 14,
195
+ 15,
196
+ 16,
197
+ 17,
198
+ 18,
199
+ 19,
200
+ 20,
201
+ 21,
202
+ 22,
203
+ 23,
204
+ 24,
205
+ 25,
206
+ 26,
207
+ 27,
208
+ 28,
209
+ 29,
210
+ 30,
211
+ 31,
212
+ 32,
213
+ 33,
214
+ 34,
215
+ 35,
216
+ 36,
217
+ 37,
218
+ 38,
219
+ 39,
220
+ 40,
221
+ 41,
222
+ 42,
223
+ 43,
224
+ 44,
225
+ 45,
226
+ 46,
227
+ 47,
228
+ 48,
229
+ 49,
230
+ 50,
231
+ 51,
232
+ 52,
233
+ 53,
234
+ 54,
235
+ 55,
236
+ 56,
237
+ 57,
238
+ 58,
239
+ 59,
240
+ 60,
241
+ 61,
242
+ 62,
243
+ 63,
244
+ 64,
245
+ 65,
246
+ 66,
247
+ 67,
248
+ 68,
249
+ 69,
250
+ 70,
251
+ 71,
252
+ 72,
253
+ 73,
254
+ 74,
255
+ 75,
256
+ 76,
257
+ 77,
258
+ 78,
259
+ 79,
260
+ 80,
261
+ 81,
262
+ 82,
263
+ 83,
264
+ 84,
265
+ 85,
266
+ 86,
267
+ 87,
268
+ 88,
269
+ 89,
270
+ 90
271
+ ],
272
+ "val_episodes": [
273
+ 91,
274
+ 92,
275
+ 93,
276
+ 94,
277
+ 95,
278
+ 96,
279
+ 97,
280
+ 98,
281
+ 99,
282
+ 100,
283
+ 101,
284
+ 102,
285
+ 103,
286
+ 104,
287
+ 105
288
+ ],
289
+ "seed": 42
290
+ },
291
+ "logging": {
292
+ "project_name": "sarm-training",
293
+ "run_name": null,
294
+ "log_freq": 10,
295
+ "checkpoint_dir": "checkpoints_sarm_25_2"
296
+ }
297
+ }
298
+ ```
299
+
300
+ </details>