File size: 7,961 Bytes
a8eb6e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import importlib
import os
from unittest.mock import MagicMock, patch

import pytest
from safetensors.torch import load_file

from .utils import require_package

# Module-wide marker: every test in this file is skipped when either the CI or the
# GITHUB_ACTIONS environment variable is set to "true".
pytestmark = pytest.mark.skipif(
    any(os.environ.get(flag) == "true" for flag in ("CI", "GITHUB_ACTIONS")),
    reason="This test requires peft and is very slow, not meant for CI",
)


def run_command(cmd, module, args):
    """Call ``main()`` of a ``lerobot.scripts`` module as if it was started from the command line.

    Args:
        cmd: Program name placed at ``sys.argv[0]``.
        module: Name of the module below ``lerobot.scripts`` whose ``main()`` is called.
        args: List of command-line arguments that follow ``cmd`` in ``sys.argv``.
    """
    script = importlib.import_module(f"lerobot.scripts.{module}")
    argv = [cmd] + args

    # sys.argv is swapped only while main() runs; patch() restores it on exit.
    with patch("sys.argv", argv):
        script.main()


def lerobot_train(args):
    """Run the ``lerobot-train`` script in-process with ``args`` as its command-line arguments."""
    outcome = run_command("lerobot-train", "lerobot_train", args)
    return outcome


def lerobot_record(args):
    """Run the ``lerobot-record`` script in-process with ``args`` as its command-line arguments."""
    outcome = run_command("lerobot-record", "lerobot_record", args)
    return outcome


def resolve_model_id_for_peft_training(policy_type):
    """Return the id of the pretrained model to start PEFT training from for ``policy_type``.

    PEFT training needs a pretrained model, so every supported policy type is
    mapped to one here.

    Raises:
        ValueError: If no pretrained model is known for ``policy_type``.
    """
    pretrained_by_policy = (("smolvla", "lerobot/smolvla_base"),)

    for known_type, pretrained_id in pretrained_by_policy:
        if policy_type == known_type:
            return pretrained_id

    raise ValueError(f"No pretrained model known for {policy_type}. PEFT training will not work.")


@pytest.mark.parametrize("policy_type", ["smolvla"])
@require_package("peft")
def test_peft_training_push_to_hub_works(policy_type, tmp_path):
    """Train with PEFT and push-to-hub enabled, then check the adapter files reach the upload.

    Hub access is patched out. Only the presence of the adapter weights, the adapter
    config and the policy config in the uploaded folder is asserted.
    """
    model_id = resolve_model_id_for_peft_training(policy_type)
    output_dir = tmp_path / f"output_{policy_type}"
    uploaded_names = set()

    def record_upload_folder(*args, **kwargs):
        # Replacement for HfApi.upload_folder: remember the folder's entries instead of uploading.
        # The listing can contain more than a real upload would, because the
        # {allow,ignore}_patterns arguments of upload_folder() are not applied here.
        for entry in os.listdir(kwargs["folder_path"]):
            uploaded_names.add(entry)
        return MagicMock()

    train_args = [
        f"--policy.path={model_id}",
        "--policy.push_to_hub=true",
        "--policy.repo_id=foo/bar",
        "--policy.input_features=null",
        "--policy.output_features=null",
        "--peft.method=LORA",
        "--dataset.repo_id=lerobot/pusht",
        "--dataset.episodes=[0, 1]",
        "--steps=1",
        f"--output_dir={output_dir}",
    ]

    with (
        patch("huggingface_hub.HfApi.create_repo"),
        patch("huggingface_hub.HfApi.upload_folder", record_upload_folder),
    ):
        lerobot_train(train_args)

        for expected_name in ("adapter_model.safetensors", "config.json", "adapter_config.json"):
            assert expected_name in uploaded_names


@pytest.mark.parametrize("policy_type", ["smolvla"])
@require_package("peft")
def test_peft_training_works(policy_type, tmp_path):
    """Fine-tune a pretrained policy with LoRA for one step and inspect the saved checkpoint."""
    model_id = resolve_model_id_for_peft_training(policy_type)
    output_dir = tmp_path / f"output_{policy_type}"

    train_args = [
        f"--policy.path={model_id}",
        "--policy.push_to_hub=false",
        "--policy.input_features=null",
        "--policy.output_features=null",
        "--peft.method=LORA",
        "--dataset.repo_id=lerobot/pusht",
        "--dataset.episodes=[0, 1]",
        "--steps=1",
        f"--output_dir={output_dir}",
    ]
    lerobot_train(train_args)

    policy_dir = output_dir / "checkpoints" / "last" / "pretrained_model"

    # The checkpoint must hold the adapter weights, the adapter config and the policy config.
    expected_files = ("adapter_config.json", "adapter_model.safetensors", "config.json")
    for filename in expected_files:
        assert (policy_dir / filename).exists()

    # Default case: a pre-trained policy is trained on new data. The assumption is that
    # policy-specific modules are targeted while the action and state projections are fully
    # fine-tuned, so those projections have to show up in the saved adapter state dict.
    state_dict = load_file(policy_dir / "adapter_model.safetensors")

    adapted_keys = [
        "state_proj",
        "action_in_proj",
        "action_out_proj",
        "action_time_mlp_in",
        "action_time_mlp_out",
    ]

    # A module counts as found when at least one state dict key contains ".<module>.".
    found_keys = {
        module_key
        for module_key in adapted_keys
        if any(f".{module_key}." in state_dict_key for state_dict_key in state_dict)
    }

    assert found_keys == set(adapted_keys)


@pytest.mark.parametrize("policy_type", ["smolvla"])
@require_package("peft")
def test_peft_training_params_are_fewer(policy_type, tmp_path):
    """Check that PEFT training leaves only part of the policy's parameters trainable.

    ``lerobot.scripts.lerobot_train.update_policy`` is replaced by a stub that counts the
    policy's total and trainable parameters instead of doing an optimization step. After
    training, the test also verifies that the stub actually ran, so it cannot pass
    without the parameter check having been executed.
    """
    output_dir = tmp_path / f"output_{policy_type}"
    model_id = resolve_model_id_for_peft_training(policy_type)

    # One (total, trainable) pair per call of the stub.
    observed_counts = []

    def dummy_update_policy(
        train_metrics, policy, batch, optimizer, grad_clip_norm: float, accelerator, **kwargs
    ):
        params_total = sum(p.numel() for p in policy.parameters())
        params_trainable = sum(p.numel() for p in policy.parameters() if p.requires_grad)
        observed_counts.append((params_total, params_trainable))

        # Fail right at the call site so the traceback points into the training loop.
        assert params_total > params_trainable

        return train_metrics, {}

    with patch("lerobot.scripts.lerobot_train.update_policy", dummy_update_policy):
        lerobot_train(
            [
                f"--policy.path={model_id}",
                "--policy.push_to_hub=false",
                "--policy.input_features=null",
                "--policy.output_features=null",
                "--peft.method=LORA",
                "--dataset.repo_id=lerobot/pusht",
                "--dataset.episodes=[0, 1]",
                "--steps=1",
                f"--output_dir={output_dir}",
            ]
        )

    # The in-stub assertion only counts if the stub was reached at least once.
    assert observed_counts, "update_policy was never called, so the parameter check did not run"

    for params_total, params_trainable in observed_counts:
        assert 0 < params_trainable < params_total, (
            f"expected 0 < trainable < total, got trainable={params_trainable}, total={params_total}"
        )


class DummyRobot:
    """Hardware-free robot stub used in place of a real robot in the ``lerobot-record`` test."""

    # Identity and connection state; the stub always reports itself as connected.
    name = "dummy"
    is_connected = True

    # Fixed feature descriptions and an empty camera list.
    cameras = []
    observation_features = {"obs1": 1.0, "obs2": 2.0}
    action_features = {"foo": 1.0, "bar": 2.0}

    def connect(self, *args):
        """Accept any positional arguments and do nothing."""

    def disconnect(self):
        """Do nothing."""


def dummy_make_robot_from_config(*_args, **_kwargs):
    """Return a new ``DummyRobot``; all positional and keyword arguments are ignored."""
    stub_robot = DummyRobot()
    return stub_robot


@pytest.mark.parametrize("policy_type", ["smolvla"])
@require_package("peft")
def test_peft_record_loads_policy(policy_type, tmp_path):
    """Train a PEFT policy for one step, then check that `lerobot-record` loads it as a PeftModel."""
    from peft import PeftModel

    model_id = resolve_model_id_for_peft_training(policy_type)
    output_dir = tmp_path / f"output_{policy_type}"

    train_args = [
        f"--policy.path={model_id}",
        "--policy.push_to_hub=false",
        "--policy.input_features=null",
        "--policy.output_features=null",
        "--peft.method=LORA",
        "--dataset.repo_id=lerobot/pusht",
        "--dataset.episodes=[0, 1]",
        "--steps=1",
        f"--output_dir={output_dir}",
    ]
    lerobot_train(train_args)

    single_task = "move the table"
    policy_dir = output_dir / "checkpoints" / "last" / "pretrained_model"
    dataset_dir = tmp_path / "eval_pusht"

    # Filled in by the record-loop replacement below.
    captured = {"policy": None}

    def capture_record_loop(*args, **kwargs):
        # Calls that carry no `dataset` keyword are ignored entirely.
        if "dataset" in kwargs:
            kwargs["dataset"].add_frame({"task": single_task})
            captured["policy"] = kwargs["policy"]

    record_args = [
        f"--policy.path={policy_dir}",
        "--robot.type=so101_follower",
        "--robot.port=/dev/null",
        "--dataset.repo_id=lerobot/eval_pusht",
        f'--dataset.single_task="{single_task}"',
        f"--dataset.root={dataset_dir}",
        "--dataset.push_to_hub=false",
    ]

    with (
        # swap in the dummy robot for whatever the robot config would create
        patch("lerobot.scripts.lerobot_record.make_robot_from_config", dummy_make_robot_from_config),
        # the real record loop is replaced; only successful loading of the policy matters here
        patch("lerobot.scripts.lerobot_record.record_loop", capture_record_loop),
        # no speech output during the test
        patch("lerobot.utils.utils.say"),
    ):
        lerobot_record(record_args)

        assert isinstance(captured["policy"], PeftModel)