arulloomba committed on
Commit
eb2a22a
·
verified ·
1 Parent(s): 3473eb4

Upload 3 files

Browse files
Files changed (3) hide show
  1. config.json +63 -0
  2. model.safetensors +3 -0
  3. train.py +231 -0
config.json ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "n_obs_steps": 1,
3
+ "normalization_mapping": {
4
+ "VISUAL": "MEAN_STD",
5
+ "STATE": "MEAN_STD",
6
+ "ACTION": "MEAN_STD"
7
+ },
8
+ "input_features": {
9
+ "observation.state": {
10
+ "type": "STATE",
11
+ "shape": [
12
+ 6
13
+ ]
14
+ },
15
+ "observation.images.laptop": {
16
+ "type": "VISUAL",
17
+ "shape": [
18
+ 3,
19
+ 480,
20
+ 640
21
+ ]
22
+ },
23
+ "observation.images.phone": {
24
+ "type": "VISUAL",
25
+ "shape": [
26
+ 3,
27
+ 480,
28
+ 640
29
+ ]
30
+ }
31
+ },
32
+ "output_features": {
33
+ "action": {
34
+ "type": "ACTION",
35
+ "shape": [
36
+ 6
37
+ ]
38
+ }
39
+ },
40
+ "device": "cuda",
41
+ "use_amp": false,
42
+ "chunk_size": 100,
43
+ "n_action_steps": 100,
44
+ "vision_backbone": "resnet18",
45
+ "pretrained_backbone_weights": "ResNet18_Weights.IMAGENET1K_V1",
46
+ "replace_final_stride_with_dilation": false,
47
+ "pre_norm": false,
48
+ "dim_model": 512,
49
+ "n_heads": 8,
50
+ "dim_feedforward": 3200,
51
+ "feedforward_activation": "relu",
52
+ "n_encoder_layers": 4,
53
+ "n_decoder_layers": 1,
54
+ "use_vae": true,
55
+ "latent_dim": 32,
56
+ "n_vae_encoder_layers": 4,
57
+ "temporal_ensemble_coeff": null,
58
+ "dropout": 0.1,
59
+ "kl_weight": 10.0,
60
+ "optimizer_lr": 8e-05,
61
+ "optimizer_weight_decay": 0.0001,
62
+ "optimizer_lr_backbone": 1e-05
63
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c00f423fe0617078f0555f3e117822d9943c128ace49ec91d7ab0d813839ea4
3
+ size 206701072
train.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import logging
4
+ import os
5
+ import time
6
+ from contextlib import nullcontext
7
+ from pprint import pformat
8
+ from typing import Any
9
+
10
+ import torch
11
+ import torch.distributed as dist
12
+ from termcolor import colored
13
+ from torch.amp import GradScaler
14
+ from torch.nn.parallel import DistributedDataParallel as DDP
15
+ from torch.optim import Optimizer
16
+ from torch.utils.data.distributed import DistributedSampler
17
+
18
+ from lerobot.common.datasets.factory import make_dataset
19
+ from lerobot.common.datasets.utils import cycle
20
+ from lerobot.common.envs.factory import make_env
21
+ from lerobot.common.optim.factory import make_optimizer_and_scheduler
22
+ from lerobot.common.policies.factory import make_policy
23
+ from lerobot.common.policies.pretrained import PreTrainedPolicy
24
+ from lerobot.common.policies.utils import get_device_from_parameters
25
+ from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
26
+ from lerobot.common.utils.random_utils import set_seed
27
+ from lerobot.common.utils.train_utils import (
28
+ get_step_checkpoint_dir,
29
+ get_step_identifier,
30
+ load_training_state,
31
+ save_checkpoint,
32
+ update_last_checkpoint,
33
+ )
34
+ from lerobot.common.utils.utils import (
35
+ format_big_number,
36
+ get_safe_torch_device,
37
+ has_method,
38
+ init_logging,
39
+ )
40
+ from lerobot.common.utils.wandb_utils import WandBLogger
41
+ from lerobot.configs import parser
42
+ from lerobot.configs.train import TrainPipelineConfig
43
+ from lerobot.scripts.eval import eval_policy
44
+
45
def update_policy(
    train_metrics: MetricsTracker,
    policy: PreTrainedPolicy,
    batch: Any,
    optimizer: Optimizer,
    grad_clip_norm: float,
    grad_scaler: GradScaler,
    lr_scheduler=None,
    use_amp: bool = False,
    lock=None,
) -> tuple[MetricsTracker, dict]:
    """Run one optimization step on `policy` with `batch`.

    Performs forward (optionally under autocast), scaled backward, gradient
    clipping, an optimizer step through the GradScaler, and an optional
    LR-scheduler step, then records loss/grad-norm/lr/timing into
    `train_metrics`.

    Args:
        train_metrics: tracker whose fields (loss, grad_norm, lr, update_s)
            are overwritten with this step's values.
        policy: the policy being trained; put into train mode here.
        batch: one batch of inputs, passed straight to `policy.forward`.
        optimizer: optimizer holding the policy's parameters.
        grad_clip_norm: max gradient norm for clipping.
        grad_scaler: AMP gradient scaler (no-op when disabled).
        lr_scheduler: optional per-step scheduler.
        use_amp: wrap the forward pass in `torch.autocast` when True.
        lock: optional mutex guarding the optimizer step.

    Returns:
        The updated `train_metrics` and the auxiliary dict from the forward.
    """
    t_start = time.perf_counter()
    device = get_device_from_parameters(policy)
    policy.train()

    # Only the forward pass runs under autocast; backward is done outside it.
    fwd_ctx = torch.autocast(device_type=device.type) if use_amp else nullcontext()
    with fwd_ctx:
        loss, output_dict = policy.forward(batch)

    grad_scaler.scale(loss).backward()
    # Unscale before clipping so the threshold applies to true gradient magnitudes.
    grad_scaler.unscale_(optimizer)
    grad_norm = torch.nn.utils.clip_grad_norm_(
        policy.parameters(),
        grad_clip_norm,
        error_if_nonfinite=False,
    )

    # grad_scaler.step skips the optimizer update when non-finite grads are found.
    step_ctx = lock if lock is not None else nullcontext()
    with step_ctx:
        grad_scaler.step(optimizer)
        grad_scaler.update()
    optimizer.zero_grad()

    if lr_scheduler is not None:
        lr_scheduler.step()

    # Some policies maintain internal state (e.g. EMA); let them refresh it.
    if has_method(policy, "update"):
        policy.update()

    train_metrics.loss = loss.item()
    train_metrics.grad_norm = grad_norm.item()
    train_metrics.lr = optimizer.param_groups[0]["lr"]
    train_metrics.update_s = time.perf_counter() - t_start
    return train_metrics, output_dict
85
+
86
@parser.wrap()
def train(cfg: TrainPipelineConfig):
    """Offline training entry point.

    Builds the dataset, policy, optimizer and dataloader from `cfg`, then runs
    a fixed number of optimization steps with periodic logging and
    checkpointing. Supports single-process training and multi-process DDP
    (activated when torchrun-style RANK/WORLD_SIZE env vars are present).
    """
    cfg.validate()
    logging.info(pformat(cfg.to_dict()))

    # Distributed setup: presence of RANK/WORLD_SIZE means we were launched
    # by torchrun, so join the NCCL process group and pin this process to its
    # local GPU. Otherwise fall back to the device named in the config.
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        dist.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
        is_main_process = (local_rank == 0)
    else:
        device = get_safe_torch_device(cfg.policy.device, log=True)
        is_main_process = True
        local_rank = 0

    # Offset the seed per rank so DDP workers don't draw identical randomness.
    if cfg.seed is not None:
        set_seed(cfg.seed + local_rank)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # Only the main process talks to W&B; other ranks log locally.
    if cfg.wandb.enable and cfg.wandb.project and is_main_process:
        wandb_logger = WandBLogger(cfg)
    else:
        wandb_logger = None
        if is_main_process:
            logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))

    logging.info("Creating dataset")
    # Rank 0 materializes the dataset first (e.g. any one-time download/cache
    # work), then the barrier releases the other ranks to load from cache.
    if is_main_process:
        dataset = make_dataset(cfg)
        if dist.is_initialized():
            dist.barrier()
    else:
        if dist.is_initialized():
            dist.barrier()
        dataset = make_dataset(cfg)

    logging.info("Creating policy")
    policy = make_policy(cfg=cfg.policy, ds_meta=dataset.meta).to(device)

    if dist.is_initialized():
        policy = DDP(policy, device_ids=[device], output_device=device, find_unused_parameters=False)

    # The unwrapped module is what the optimizer factory needs; the DDP
    # wrapper shares the same parameters.
    raw_policy = policy.module if isinstance(policy, DDP) else policy

    logging.info("Creating optimizer and scheduler")
    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, raw_policy)
    grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)

    step = 0
    if cfg.resume:
        # NOTE(review): every rank reloads the training state from the same
        # checkpoint path — presumably a shared filesystem; verify for multi-node.
        step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)

    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())

    if is_main_process:
        logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
        if cfg.env is not None:
            logging.info(f"{cfg.env.task=}")
        logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
        logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
        logging.info(f"{dataset.num_episodes=}")
        logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
        logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

    # Under DDP each rank sees a disjoint shard of the dataset; otherwise a
    # plain shuffled dataloader is used.
    sampler = DistributedSampler(dataset, shuffle=True) if dist.is_initialized() else None
    dataloader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=cfg.batch_size,
        shuffle=(sampler is None),
        num_workers=cfg.num_workers,
        pin_memory=device.type != "cpu",
        drop_last=True,
    )
    # `cycle` yields batches forever, restarting the dataloader when exhausted,
    # so training is driven by a step count rather than epochs.
    dl_iter = cycle(dataloader)

    policy.train()

    train_metrics = {
        "loss": AverageMeter("loss", ":.3f"),
        "grad_norm": AverageMeter("grdn", ":.3f"),
        "lr": AverageMeter("lr", ":0.1e"),
        "update_s": AverageMeter("updt_s", ":.3f"),
        "dataloading_s": AverageMeter("data_s", ":.3f"),
    }

    train_tracker = MetricsTracker(
        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
    )

    if is_main_process:
        logging.info("Start offline training on a fixed dataset")

    for _ in range(step, cfg.steps):
        if dist.is_initialized():
            # NOTE(review): `set_epoch` is fed the global step (and reuses the
            # `_` loop variable as a value). It only changes shuffling when the
            # dataloader iterator is re-created inside `cycle`; confirm intent.
            sampler.set_epoch(_)

        start_time = time.perf_counter()
        batch = next(dl_iter)
        train_tracker.dataloading_s = time.perf_counter() - start_time

        # Move all tensor entries of the batch to the training device.
        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device, non_blocking=True)

        train_tracker, output_dict = update_policy(
            train_tracker, policy, batch, optimizer,
            cfg.optimizer.grad_clip_norm, grad_scaler,
            lr_scheduler=lr_scheduler, use_amp=cfg.policy.use_amp
        )

        # `step` counts completed optimization steps, so increment before the
        # log/save-frequency checks below.
        step += 1
        train_tracker.step()

        is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
        is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps

        if is_log_step and is_main_process:
            logging.info(train_tracker)
            if wandb_logger:
                wandb_log_dict = train_tracker.to_dict()
                if output_dict:
                    wandb_log_dict.update(output_dict)
                wandb_logger.log_dict(wandb_log_dict, step)
            train_tracker.reset_averages()

        if cfg.save_checkpoint and is_saving_step and is_main_process:
            logging.info(f"Checkpoint policy after step {step}")
            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
            # Save the unwrapped module so the checkpoint is loadable without DDP.
            save_checkpoint(checkpoint_dir, step, cfg, policy.module if dist.is_initialized() else policy, optimizer, lr_scheduler)
            update_last_checkpoint(checkpoint_dir)
            if wandb_logger:
                wandb_logger.log_policy(checkpoint_dir)

    if dist.is_initialized():
        dist.destroy_process_group()

    logging.info("End of training")
228
+
229
if __name__ == "__main__":
    # Configure logging before `train()` parses the CLI config and starts.
    init_logging()
    train()