Upload 3 files

Files changed:
- config.json (+63)
- model.safetensors (+3)
- train.py (+231)
config.json (ADDED)
@@ -0,0 +1,63 @@
{
    "n_obs_steps": 1,
    "normalization_mapping": {
        "VISUAL": "MEAN_STD",
        "STATE": "MEAN_STD",
        "ACTION": "MEAN_STD"
    },
    "input_features": {
        "observation.state": {
            "type": "STATE",
            "shape": [6]
        },
        "observation.images.laptop": {
            "type": "VISUAL",
            "shape": [3, 480, 640]
        },
        "observation.images.phone": {
            "type": "VISUAL",
            "shape": [3, 480, 640]
        }
    },
    "output_features": {
        "action": {
            "type": "ACTION",
            "shape": [6]
        }
    },
    "device": "cuda",
    "use_amp": false,
    "chunk_size": 100,
    "n_action_steps": 100,
    "vision_backbone": "resnet18",
    "pretrained_backbone_weights": "ResNet18_Weights.IMAGENET1K_V1",
    "replace_final_stride_with_dilation": false,
    "pre_norm": false,
    "dim_model": 512,
    "n_heads": 8,
    "dim_feedforward": 3200,
    "feedforward_activation": "relu",
    "n_encoder_layers": 4,
    "n_decoder_layers": 1,
    "use_vae": true,
    "latent_dim": 32,
    "n_vae_encoder_layers": 4,
    "temporal_ensemble_coeff": null,
    "dropout": 0.1,
    "kl_weight": 10.0,
    "optimizer_lr": 8e-05,
    "optimizer_weight_decay": 0.0001,
    "optimizer_lr_backbone": 1e-05
}
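These fields appear to describe a lerobot ACT (Action Chunking Transformer) policy: a ResNet-18 visual backbone over two 480x640 cameras, a 6-dimensional state and action space, and 100-step action chunks with a VAE latent of size 32. Below is a minimal sanity-check sketch using only the standard library; the local file path is an assumption about where the repo is cloned:

import json

# Hypothetical local path; adjust to wherever the repo is checked out.
with open("config.json") as f:
    cfg = json.load(f)

# Both cameras should be channel-first (3, H, W) for the ResNet backbone.
for name, feat in cfg["input_features"].items():
    if feat["type"] == "VISUAL":
        assert feat["shape"][0] == 3, f"{name} is not channel-first: {feat['shape']}"

# For this arm, state and action are both 6-dimensional joint vectors.
assert cfg["input_features"]["observation.state"]["shape"] == [6]
assert cfg["output_features"]["action"]["shape"] == [6]
print(f"chunk_size={cfg['chunk_size']}, n_action_steps={cfg['n_action_steps']}")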
model.safetensors (ADDED)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8c00f423fe0617078f0555f3e117822d9943c128ace49ec91d7ab0d813839ea4
size 206701072
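The weights file is stored as a Git LFS pointer; the actual tensors (about 207 MB) are resolved on clone or download. A minimal sketch for inspecting the resolved file, assuming the safetensors package is installed and the file has been downloaded to the working directory:

from safetensors.torch import load_file  # pip install safetensors

# Hypothetical local path to the resolved (non-pointer) weights file.
state_dict = load_file("model.safetensors")

total = sum(t.numel() for t in state_dict.values())
print(f"{len(state_dict)} tensors, {total:,} parameters")
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape), tensor.dtype)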
train.py (ADDED)
@@ -0,0 +1,231 @@
#!/usr/bin/env python

import logging
import os
import time
from contextlib import nullcontext
from pprint import pformat
from typing import Any

import torch
import torch.distributed as dist
from termcolor import colored
from torch.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.utils.data.distributed import DistributedSampler

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.common.utils.random_utils import set_seed
from lerobot.common.utils.train_utils import (
    get_step_checkpoint_dir,
    get_step_identifier,
    load_training_state,
    save_checkpoint,
    update_last_checkpoint,
)
from lerobot.common.utils.utils import (
    format_big_number,
    get_safe_torch_device,
    has_method,
    init_logging,
)
from lerobot.common.utils.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.eval import eval_policy


def update_policy(
    train_metrics: MetricsTracker,
    policy: PreTrainedPolicy,
    batch: Any,
    optimizer: Optimizer,
    grad_clip_norm: float,
    grad_scaler: GradScaler,
    lr_scheduler=None,
    use_amp: bool = False,
    lock=None,
) -> tuple[MetricsTracker, dict]:
    """Run one optimization step: forward, backward, clip, optimizer update."""
    start_time = time.perf_counter()
    device = get_device_from_parameters(policy)
    policy.train()
    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
        loss, output_dict = policy.forward(batch)

    grad_scaler.scale(loss).backward()
    # Unscale before clipping so the threshold applies to true gradient values.
    grad_scaler.unscale_(optimizer)

    grad_norm = torch.nn.utils.clip_grad_norm_(
        policy.parameters(), grad_clip_norm, error_if_nonfinite=False
    )

    with lock if lock is not None else nullcontext():
        grad_scaler.step(optimizer)
        grad_scaler.update()
        optimizer.zero_grad()

    if lr_scheduler is not None:
        lr_scheduler.step()

    # Some policies expose an extra post-step hook (e.g. target-network updates).
    if has_method(policy, "update"):
        policy.update()

    train_metrics.loss = loss.item()
    train_metrics.grad_norm = grad_norm.item()
    train_metrics.lr = optimizer.param_groups[0]["lr"]
    train_metrics.update_s = time.perf_counter() - start_time
    return train_metrics, output_dict


@parser.wrap()
def train(cfg: TrainPipelineConfig):
    cfg.validate()
    logging.info(pformat(cfg.to_dict()))

    # torchrun exports RANK/WORLD_SIZE; their presence selects distributed mode.
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        dist.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
        is_main_process = local_rank == 0
    else:
        device = get_safe_torch_device(cfg.policy.device, log=True)
        is_main_process = True
        local_rank = 0

    if cfg.seed is not None:
        # Offset the seed per rank so each process draws different samples.
        set_seed(cfg.seed + local_rank)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    if cfg.wandb.enable and cfg.wandb.project and is_main_process:
        wandb_logger = WandBLogger(cfg)
    else:
        wandb_logger = None
        if is_main_process:
            logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))

    logging.info("Creating dataset")
    # Let rank 0 download/build the dataset first; other ranks wait at the
    # barrier and then load from cache.
    if is_main_process:
        dataset = make_dataset(cfg)
        if dist.is_initialized():
            dist.barrier()
    else:
        if dist.is_initialized():
            dist.barrier()
        dataset = make_dataset(cfg)

    logging.info("Creating policy")
    policy = make_policy(cfg=cfg.policy, ds_meta=dataset.meta).to(device)

    if dist.is_initialized():
        policy = DDP(policy, device_ids=[device], output_device=device, find_unused_parameters=False)

    # Keep a handle to the unwrapped module for the optimizer factory and
    # checkpointing.
    raw_policy = policy.module if isinstance(policy, DDP) else policy

    logging.info("Creating optimizer and scheduler")
    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, raw_policy)
    grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)

    step = 0
    if cfg.resume:
        step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)

    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())

    if is_main_process:
        logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
        if cfg.env is not None:
            logging.info(f"{cfg.env.task=}")
        logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
        logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
        logging.info(f"{dataset.num_episodes=}")
        logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
        logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

    sampler = DistributedSampler(dataset, shuffle=True) if dist.is_initialized() else None
    dataloader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=cfg.batch_size,
        shuffle=(sampler is None),
        num_workers=cfg.num_workers,
        pin_memory=device.type != "cpu",
        drop_last=True,
    )
    dl_iter = cycle(dataloader)

    policy.train()

    train_metrics = {
        "loss": AverageMeter("loss", ":.3f"),
        "grad_norm": AverageMeter("grdn", ":.3f"),
        "lr": AverageMeter("lr", ":0.1e"),
        "update_s": AverageMeter("updt_s", ":.3f"),
        "dataloading_s": AverageMeter("data_s", ":.3f"),
    }

    train_tracker = MetricsTracker(
        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
    )

    if is_main_process:
        logging.info("Start offline training on a fixed dataset")

    for train_step in range(step, cfg.steps):
        if dist.is_initialized():
            # Changes the shuffle seed used when the dataloader restarts
            # (the `cycle` wrapper re-iterates after each pass).
            sampler.set_epoch(train_step)

        start_time = time.perf_counter()
        batch = next(dl_iter)
        train_tracker.dataloading_s = time.perf_counter() - start_time

        for key in batch:
            if isinstance(batch[key], torch.Tensor):
                batch[key] = batch[key].to(device, non_blocking=True)

        train_tracker, output_dict = update_policy(
            train_tracker, policy, batch, optimizer,
            cfg.optimizer.grad_clip_norm, grad_scaler,
            lr_scheduler=lr_scheduler, use_amp=cfg.policy.use_amp,
        )

        step += 1
        train_tracker.step()

        is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
        is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps

        if is_log_step and is_main_process:
            logging.info(train_tracker)
            if wandb_logger:
                wandb_log_dict = train_tracker.to_dict()
                if output_dict:
                    wandb_log_dict.update(output_dict)
                wandb_logger.log_dict(wandb_log_dict, step)
            train_tracker.reset_averages()

        if cfg.save_checkpoint and is_saving_step and is_main_process:
            logging.info(f"Checkpoint policy after step {step}")
            checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
            # Save the unwrapped module so the checkpoint loads without DDP.
            save_checkpoint(
                checkpoint_dir, step, cfg,
                policy.module if dist.is_initialized() else policy,
                optimizer, lr_scheduler,
            )
            update_last_checkpoint(checkpoint_dir)
            if wandb_logger:
                wandb_logger.log_policy(checkpoint_dir)

    if dist.is_initialized():
        dist.destroy_process_group()

    logging.info("End of training")


if __name__ == "__main__":
    init_logging()
    train()
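The script picks its launch mode purely from the environment: torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK, so `torchrun --nproc_per_node=N train.py` runs DDP while plain `python train.py` falls back to a single device. A minimal sketch isolating that detection logic:

import os

import torch

# Mirrors the launch-mode check at the top of train(): torchrun exports
# RANK/WORLD_SIZE/LOCAL_RANK, while a plain `python` invocation does not.
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", local_rank)
    print(f"DDP: rank {os.environ['RANK']} of {os.environ['WORLD_SIZE']} on {device}")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"single process on {device}")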