jadechoghari HF Staff commited on
Commit
02e974e
·
verified ·
1 Parent(s): ac99aae

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +77 -1
README.md CHANGED
@@ -3,4 +3,80 @@ license: gemma
3
  language:
4
  - en
5
  ---
6
- # π0 fast
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  language:
4
  - en
5
  ---
6
+ # π0 fast
7
+
8
+ π₀-FAST is a Vision-Language-Action model for general robot control that uses autoregressive next-token prediction to model continuous robot actions.
9
+
10
+ ## How to Get Started
11
+
12
+ ```python
13
+ import torch
14
+ from lerobot.policies.factory import make_pre_post_processors
15
+ import numpy as np
16
+ from lerobot.policies.pi0.modeling_pi0 import PI0FastPolicy
17
+
18
+ model_id = "lerobot/pi0fast-base"
19
+ model = PI0FastPolicy.from_pretrained(model_id)
20
+
21
+ # select your device here
22
+ device = torch.device("cuda")
23
+ preprocess, postprocess = make_pre_post_processors(
24
+ model.config,
25
+ model_id,
26
+ preprocessor_overrides={"device_processor": {"device": str(device)}},
27
+ )
28
+
29
+ IMAGE_HEIGHT = 224
30
+ IMAGE_WIDTH = 224
31
+ batch_size = 1
32
+ prompt = "Pick up the red block and place it in the bin"
33
+
34
+ # Create random RGB images in [0, 255] uint8 range (as PIL images would be)
35
+ # Then convert to [0, 1] float32 range for LeRobot
36
+ def fake_rgb(h, w):
37
+ arr = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
38
+ t = torch.from_numpy(arr).permute(2, 0, 1) # CHW
39
+ return t
40
+
41
+ DUMMY_STATE_DIM = 7
42
+ batch = {
43
+ f"observation.images.base_0_rgb": torch.stack(
44
+ [fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)]
45
+ ).to(device),
46
+ f"observation.images.left_wrist_0_rgb": torch.stack(
47
+ [fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)]
48
+ ).to(device),
49
+ f"observation.images.right_wrist_0_rgb": torch.stack(
50
+ [fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)]
51
+ ).to(device),
52
+ "observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
53
+ "task": [prompt for _ in range(batch_size)],
54
+ }
55
+
56
+ batch = preprocess(batch)
57
+ action = model.select_action(batch)
58
+ # or if you're training, do:
59
+ # loss, output_dict = policy.forward(batch)
60
+ # loss.backward()
61
+ action = postprocess(action)
62
+ print(action)
63
+ ```
64
+
65
+ ## How to Train the Model
66
+
67
+ ```bash
68
+ python src/lerobot/scripts/lerobot_train.py \
69
+ --dataset.repo_id=your_dataset \
70
+ --policy.type=pi0_fast \
71
+ --output_dir=./outputs/pi0fast_training \
72
+ --job_name=pi0fast_training \
73
+ --policy.pretrained_path=lerobot/pi0fast-base \
74
+ --policy.dtype=bfloat16 \
75
+ --policy.gradient_checkpointing=true \
76
+ --policy.chunk_size=10 \
77
+ --policy.n_action_steps=10 \
78
+ --policy.max_action_tokens=256 \
79
+ --steps=100000 \
80
+ --batch_size=4 \
81
+ --policy.device=cuda
82
+ ```