ar27111994 commited on
Commit
c1ff22f
·
verified ·
1 Parent(s): 76a149d

Upload lewm_train.py

Browse files
Files changed (1) hide show
  1. lewm_train.py +404 -0
lewm_train.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LeWorldModel (LeWM) Training Script
3
+ Reference: Maes et al., 2026 — Stable End-to-End JEPA from Pixels
4
+ arXiv: 2603.19312
5
+
6
+ This script trains LeWM on trajectory data (observations + actions).
7
+ Supports both real HDF5 datasets and a synthetic PushT-like benchmark
8
+ for rapid smoke-testing.
9
+ """
10
+
11
+ import os
12
+ import argparse
13
+ import math
14
+ import numpy as np
15
+ import h5py
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+ from torch.utils.data import Dataset, DataLoader
20
+ from einops import rearrange
21
+ from transformers import get_cosine_schedule_with_warmup
22
+
23
+ from lewm_model import build_lewm, SIGReg
24
+
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Dataset: HDF5 trajectory loader
28
+ # ---------------------------------------------------------------------------
29
+ class TrajectoryDataset(Dataset):
30
+ """
31
+ Loads offline trajectories from an HDF5 file.
32
+ Expected keys (standard from DINO-WM / LeWM datasets):
33
+ observations/pixels (N_episodes, T_max, H, W, C) uint8
34
+ actions (N_episodes, T_max, A) float32
35
+ We extract sub-trajectories of length `seq_len` with frame_skip.
36
+ """
37
+ def __init__(self, h5_path, seq_len=4, frameskip=5, img_size=224,
38
+ train=True, train_split=0.95):
39
+ self.seq_len = seq_len
40
+ self.frameskip = frameskip
41
+ self.img_size = img_size
42
+ self.train = train
43
+
44
+ with h5py.File(h5_path, 'r') as f:
45
+ pixels = f['observations']['pixels'][:] # (N, T, H, W, C)
46
+ actions = f['actions'][:] # (N, T, A)
47
+
48
+ # Convert to torch tensors
49
+ self.pixels = torch.from_numpy(pixels).permute(0, 1, 4, 2, 3).float() / 255.0 # (N,T,C,H,W)
50
+ self.actions = torch.from_numpy(actions).float()
51
+
52
+ # Pre-compute episode boundaries
53
+ N, T_max = self.pixels.shape[:2]
54
+ n_train = int(N * train_split)
55
+ if train:
56
+ self.pixels = self.pixels[:n_train]
57
+ self.actions = self.actions[:n_train]
58
+ else:
59
+ self.pixels = self.pixels[n_train:]
60
+ self.actions = self.actions[n_train:]
61
+
62
+ N, T_max = self.pixels.shape[:2]
63
+ self.indices = []
64
+ for ep in range(N):
65
+ valid = T_max - (seq_len * frameskip) - 1
66
+ if valid > 0:
67
+ for start in range(0, valid, frameskip):
68
+ self.indices.append((ep, start))
69
+
70
+ def __len__(self):
71
+ return len(self.indices)
72
+
73
+ def __getitem__(self, idx):
74
+ ep, start = self.indices[idx]
75
+ fs = self.frameskip
76
+ # Sample every frameskip-th frame
77
+ frame_indices = [start + t * fs for t in range(self.seq_len)]
78
+ obs = self.pixels[ep, frame_indices] # (T, C, H, W)
79
+ # Actions: group `frameskip` consecutive actions into a block (mean or sum)
80
+ acts = []
81
+ for t in range(self.seq_len):
82
+ act_block = self.actions[ep, start + t * fs: start + (t + 1) * fs]
83
+ acts.append(act_block.mean(dim=0))
84
+ acts = torch.stack(acts, dim=0) # (T, A)
85
+ return obs, acts
86
+
87
+
88
+ # ---------------------------------------------------------------------------
89
+ # Synthetic PushT-like dataset (for smoke-testing without 12 GB download)
90
+ # ---------------------------------------------------------------------------
91
+ class SyntheticPushTDataset(Dataset):
92
+ """
93
+ Generates synthetic 2D manipulation trajectories.
94
+ Agent (blue dot) pushes a T-shaped block toward a target.
95
+ Observations are rendered as 224×224 RGB images.
96
+ """
97
+ def __init__(self, n_episodes=2000, max_steps=196, img_size=224, seq_len=4, frameskip=5):
98
+ self.seq_len = seq_len
99
+ self.frameskip = frameskip
100
+ self.img_size = img_size
101
+ self.data = []
102
+ rng = np.random.RandomState(42)
103
+ min_steps = max(60, seq_len * frameskip + 10)
104
+ for _ in range(n_episodes):
105
+ length = rng.randint(min_steps, max( min_steps + 1, max_steps))
106
+ traj = self._generate_trajectory(length, rng)
107
+ self.data.append(traj)
108
+
109
+ def _generate_trajectory(self, length, rng):
110
+ img_size = self.img_size
111
+ # Agent pos, block pos, block angle
112
+ agent = rng.uniform(0.2, 0.8, size=(length, 2)).astype(np.float32)
113
+ block = rng.uniform(0.3, 0.7, size=(length, 2)).astype(np.float32)
114
+ angle = np.cumsum(rng.randn(length).astype(np.float32) * 0.1)
115
+ # Actions: dx, dy for agent (2D continuous)
116
+ actions = np.diff(agent, prepend=agent[:1], axis=0).astype(np.float32)
117
+ # Pad to uniform length by repeating last frame
118
+ pixels = np.zeros((length, 3, img_size, img_size), dtype=np.float32)
119
+ for t in range(length):
120
+ pixels[t] = self._render(agent[t], block[t], angle[t], img_size)
121
+ return {"pixels": pixels, "actions": actions}
122
+
123
+ @staticmethod
124
+ def _render(agent, block, angle, size):
125
+ canvas = np.ones((3, size, size), dtype=np.float32) * 0.9
126
+ # Draw agent (blue circle)
127
+ y, x = np.ogrid[:size, :size]
128
+ ax, ay = int(agent[0] * size), int(agent[1] * size)
129
+ mask = ((x - ax) ** 2 + (y - ay) ** 2) < (size * 0.03) ** 2
130
+ canvas[2][mask] = 0.3
131
+ canvas[0][mask] = 0.3
132
+ # Draw block (red T)
133
+ bx, by = int(block[0] * size), int(block[1] * size)
134
+ block_mask = ((x - bx) ** 2 + (y - by) ** 2) < (size * 0.05) ** 2
135
+ canvas[0][block_mask] = 0.9
136
+ canvas[1][block_mask] = 0.2
137
+ canvas[2][block_mask] = 0.2
138
+ return canvas
139
+
140
+ def __len__(self):
141
+ return len(self.data) * 50 # many sub-trajectories per episode
142
+
143
+ def __getitem__(self, idx):
144
+ ep = idx % len(self.data)
145
+ traj = self.data[ep]
146
+ max_start = len(traj["pixels"]) - self.seq_len * self.frameskip - 1
147
+ if max_start <= 0:
148
+ max_start = 1
149
+ start = np.random.randint(0, max_start)
150
+ fs = self.frameskip
151
+ frame_idx = [start + t * fs for t in range(self.seq_len)]
152
+ obs = torch.from_numpy(traj["pixels"][frame_idx])
153
+ acts = []
154
+ for t in range(self.seq_len):
155
+ a = traj["actions"][start + t * fs: start + (t + 1) * fs].mean(axis=0)
156
+ acts.append(a)
157
+ acts = torch.from_numpy(np.stack(acts, axis=0))
158
+ # Pad actions to effective dim (frameskip * action_dim)
159
+ A = acts.shape[-1]
160
+ pad = fs * A - A
161
+ if pad > 0:
162
+ acts = F.pad(acts, (0, pad))
163
+ return obs, acts
164
+
165
+
166
+ # ---------------------------------------------------------------------------
167
+ # Training loop
168
+ # ---------------------------------------------------------------------------
169
+ def train(args):
170
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
171
+ print(f"Device: {device}")
172
+
173
+ # Build model
174
+ model = build_lewm(
175
+ image_size=args.img_size,
176
+ patch_size=14,
177
+ embed_dim=args.embed_dim,
178
+ action_dim=args.action_dim,
179
+ history_size=args.history_size,
180
+ frameskip=args.frameskip,
181
+ predictor_depth=6,
182
+ predictor_heads=16,
183
+ predictor_mlp_dim=2048,
184
+ predictor_dropout=0.1,
185
+ ).to(device)
186
+
187
+ print(f"Model params: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
188
+
189
+ # Dataset
190
+ if args.use_synthetic:
191
+ dataset = SyntheticPushTDataset(
192
+ n_episodes=args.n_episodes,
193
+ seq_len=args.seq_len,
194
+ frameskip=args.frameskip,
195
+ img_size=args.img_size,
196
+ )
197
+ val_dataset = SyntheticPushTDataset(
198
+ n_episodes=max(1, args.n_episodes // 10),
199
+ seq_len=args.seq_len,
200
+ frameskip=args.frameskip,
201
+ img_size=args.img_size,
202
+ )
203
+ else:
204
+ dataset = TrajectoryDataset(
205
+ args.h5_path, seq_len=args.seq_len, frameskip=args.frameskip,
206
+ img_size=args.img_size, train=True,
207
+ )
208
+ val_dataset = TrajectoryDataset(
209
+ args.h5_path, seq_len=args.seq_len, frameskip=args.frameskip,
210
+ img_size=args.img_size, train=False,
211
+ )
212
+
213
+ loader = DataLoader(
214
+ dataset, batch_size=args.batch_size, shuffle=True,
215
+ num_workers=args.num_workers, drop_last=True, pin_memory=True,
216
+ )
217
+ val_loader = DataLoader(
218
+ val_dataset, batch_size=args.batch_size, shuffle=False,
219
+ num_workers=0, drop_last=False, pin_memory=True,
220
+ )
221
+
222
+ # Optimizer + scheduler
223
+ optimizer = torch.optim.AdamW(
224
+ model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
225
+ betas=(0.9, 0.95),
226
+ )
227
+ total_steps = len(loader) * args.epochs
228
+ scheduler = get_cosine_schedule_with_warmup(
229
+ optimizer, num_warmup_steps=int(0.05 * total_steps),
230
+ num_training_steps=total_steps,
231
+ )
232
+
233
+ # SIGReg
234
+ sigreg = SIGReg(knots=17, num_proj=1024).to(device)
235
+
236
+ # Training
237
+ best_val_loss = float('inf')
238
+ for epoch in range(args.epochs):
239
+ model.train()
240
+ epoch_loss = 0.0
241
+ epoch_pred = 0.0
242
+ epoch_sig = 0.0
243
+
244
+ for step, (obs, acts) in enumerate(loader):
245
+ obs = obs.to(device)
246
+ acts = acts.to(device)
247
+ b, t = obs.shape[:2]
248
+
249
+ # Encode
250
+ emb = model.encode(obs) # (B, T, D)
251
+ act_emb = model.action_encoder(acts)
252
+
253
+ # Predictor (history_size)
254
+ ctx_emb = emb[:, :args.history_size]
255
+ ctx_act = act_emb[:, :args.history_size]
256
+ pred_emb = model.predict(ctx_emb, ctx_act)
257
+
258
+ # Prediction loss
259
+ pred_loss = (pred_emb[:, :-1] - emb[:, 1:args.history_size]).pow(2).mean()
260
+
261
+ # SIGReg
262
+ sigreg_loss = sigreg(emb.transpose(0, 1))
263
+
264
+ loss = pred_loss + args.lambd * sigreg_loss
265
+
266
+ optimizer.zero_grad()
267
+ loss.backward()
268
+ if args.grad_clip > 0:
269
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
270
+ optimizer.step()
271
+ scheduler.step()
272
+
273
+ epoch_loss += loss.item()
274
+ epoch_pred += pred_loss.item()
275
+ epoch_sig += sigreg_loss.item()
276
+
277
+ if step % args.log_interval == 0:
278
+ print(f" [E{epoch}|S{step}] loss={loss.item():.4f} "
279
+ f"pred={pred_loss.item():.4f} sigreg={sigreg_loss.item():.4f}")
280
+
281
+ n = len(loader)
282
+ print(f"Epoch {epoch} | avg_loss={epoch_loss/n:.4f} "
283
+ f"avg_pred={epoch_pred/n:.4f} avg_sigreg={epoch_sig/n:.4f}")
284
+
285
+ # Validation
286
+ model.eval()
287
+ val_loss = 0.0
288
+ with torch.no_grad():
289
+ for obs, acts in val_loader:
290
+ obs = obs.to(device)
291
+ acts = acts.to(device)
292
+ emb = model.encode(obs)
293
+ act_emb = model.action_encoder(acts)
294
+ ctx_emb = emb[:, :args.history_size]
295
+ ctx_act = act_emb[:, :args.history_size]
296
+ pred_emb = model.predict(ctx_emb, ctx_act)
297
+ pred_loss = (pred_emb[:, :-1] - emb[:, 1:args.history_size]).pow(2).mean()
298
+ sigreg_loss = sigreg(emb.transpose(0, 1))
299
+ val_loss += (pred_loss + args.lambd * sigreg_loss).item()
300
+ val_loss /= max(1, len(val_loader))
301
+ print(f" Val loss: {val_loss:.4f}")
302
+
303
+ # Save best
304
+ if val_loss < best_val_loss:
305
+ best_val_loss = val_loss
306
+ ckpt = {
307
+ "model": model.state_dict(),
308
+ "optimizer": optimizer.state_dict(),
309
+ "scheduler": scheduler.state_dict(),
310
+ "epoch": epoch,
311
+ "args": vars(args),
312
+ }
313
+ out_path = os.path.join(args.output_dir, "best_model.pt")
314
+ os.makedirs(args.output_dir, exist_ok=True)
315
+ torch.save(ckpt, out_path)
316
+ print(f" Saved best model -> {out_path}")
317
+
318
+ # Final save
319
+ final_path = os.path.join(args.output_dir, "final_model.pt")
320
+ torch.save({"model": model.state_dict(), "args": vars(args)}, final_path)
321
+ print(f"Training complete. Saved to {final_path}")
322
+
323
+ # Push to hub
324
+ if args.push_to_hub:
325
+ from huggingface_hub import HfApi
326
+ api = HfApi()
327
+ repo_id = f"{args.hf_username}/{args.hub_model_id}"
328
+ api.create_repo(repo_id, repo_type="model", exist_ok=True)
329
+ api.upload_file(
330
+ path_or_fileobj=final_path,
331
+ path_in_repo="model.pt",
332
+ repo_id=repo_id,
333
+ repo_type="model",
334
+ )
335
+ # Save config
336
+ import json
337
+ config = {
338
+ "_target_": "lewm_model.LeWorldModel",
339
+ "encoder": {
340
+ "image_size": args.img_size,
341
+ "patch_size": 14,
342
+ "embed_dim": args.embed_dim,
343
+ "num_layers": 12,
344
+ "num_heads": 3,
345
+ },
346
+ "predictor": {
347
+ "num_frames": args.history_size,
348
+ "depth": 6,
349
+ "heads": 16,
350
+ "mlp_dim": 2048,
351
+ "dropout": 0.1,
352
+ },
353
+ "action_dim": args.action_dim,
354
+ "frameskip": args.frameskip,
355
+ "lambd": args.lambd,
356
+ }
357
+ config_path = os.path.join(args.output_dir, "config.json")
358
+ with open(config_path, "w") as f:
359
+ json.dump(config, f, indent=2)
360
+ api.upload_file(
361
+ path_or_fileobj=config_path,
362
+ path_in_repo="config.json",
363
+ repo_id=repo_id,
364
+ repo_type="model",
365
+ )
366
+ print(f"Pushed model to https://huggingface.co/{repo_id}")
367
+
368
+
369
+ # ---------------------------------------------------------------------------
370
+ # CLI
371
+ # ---------------------------------------------------------------------------
372
+ def get_args():
373
+ parser = argparse.ArgumentParser(description="Train LeWorldModel")
374
+ # Data
375
+ parser.add_argument("--h5_path", type=str, default="/tmp/pusht_expert_train.h5")
376
+ parser.add_argument("--use_synthetic", action="store_true", help="Use synthetic data for smoke testing")
377
+ parser.add_argument("--n_episodes", type=int, default=2000, help="Synthetic dataset size")
378
+ parser.add_argument("--seq_len", type=int, default=4)
379
+ parser.add_argument("--frameskip", type=int, default=5)
380
+ parser.add_argument("--img_size", type=int, default=224)
381
+ parser.add_argument("--action_dim", type=int, default=2)
382
+ parser.add_argument("--history_size", type=int, default=3)
383
+ # Model
384
+ parser.add_argument("--embed_dim", type=int, default=192)
385
+ parser.add_argument("--lambd", type=float, default=0.1, help="SIGReg weight")
386
+ # Training
387
+ parser.add_argument("--epochs", type=int, default=10)
388
+ parser.add_argument("--batch_size", type=int, default=128)
389
+ parser.add_argument("--lr", type=float, default=1e-3)
390
+ parser.add_argument("--weight_decay", type=float, default=0.05)
391
+ parser.add_argument("--grad_clip", type=float, default=1.0)
392
+ parser.add_argument("--num_workers", type=int, default=4)
393
+ parser.add_argument("--log_interval", type=int, default=50)
394
+ parser.add_argument("--output_dir", type=str, default="/tmp/lewm_output")
395
+ # Hub
396
+ parser.add_argument("--push_to_hub", action="store_true")
397
+ parser.add_argument("--hf_username", type=str, default="ar27111994")
398
+ parser.add_argument("--hub_model_id", type=str, default="lewm-synthetic-pusht")
399
+ return parser.parse_args()
400
+
401
+
402
+ if __name__ == "__main__":
403
+ args = get_args()
404
+ train(args)