BonanDing committed
Commit 681f346 · 1 Parent(s): 8652b14

Reproduce Training & Fix distributed eval
README.md CHANGED
@@ -1,201 +1,66 @@
-<br>
-<p align="center">
-
-  <p align="center">
-    <img src="assets/worldmem_logo.png" alt="WORLDMEM Icon" width="80"/>
-  </p>
-  <h1 align="center"><strong>WorldMem: Long-term Consistent World Simulation <br> with Memory</strong></h1>
-  <p align="center"><span><a href=""></a></span>
-    <a href="https://xizaoqu.github.io">Zeqi Xiao<sup>1</sup></a>
-    <a href="https://nirvanalan.github.io/">Yushi Lan<sup>1</sup></a>
-    <a href="https://zhouyifan.net/about/">Yifan Zhou<sup>1</sup></a>
-    <a href="https://vicky0522.github.io/Wenqi-Ouyang/">Wenqi Ouyang<sup>1</sup></a>
-    <a href="https://williamyang1991.github.io/">Shuai Yang<sup>2</sup></a>
-    <a href="https://zengyh1900.github.io/">Yanhong Zeng<sup>3</sup></a>
-    <a href="https://xingangpan.github.io/">Xingang Pan<sup>1</sup></a> <br>
-    <sup>1</sup>S-Lab, Nanyang Technological University, <br> <sup>2</sup>Wangxuan Institute of Computer Technology, Peking University,<br> <sup>3</sup>Shanghai AI Laboratory
-  </p>
-</p>
-
-<p align="center">
-  <a href="https://arxiv.org/abs/2504.12369" target='_blank'>
-    <img src="https://img.shields.io/badge/arXiv-2504.12369-blue?">
-  </a>
-  <a href="https://xizaoqu.github.io/worldmem/" target='_blank'>
-    <img src="https://img.shields.io/badge/Project-&#x1F680-blue">
-  </a>
-  <a href="https://huggingface.co/spaces/yslan/worldmem" target="_blank">
-    <img src="https://img.shields.io/badge/🤗 HuggingFace-Demo-orange" />
-  </a>
-</p>
-
-https://github.com/user-attachments/assets/fb8a32e2-9470-4819-a93d-c38caf76d72c
-
-## Installation
-
-```
-conda create python=3.10 -n worldmem
-conda activate worldmem
-pip install -r requirements.txt
-conda install -c conda-forge ffmpeg=4.3.2
-```
-
-## Quick start
-
-```
-python app.py
-```
-
-## Run
-
-To enable cloud logging with [Weights & Biases (wandb)](https://wandb.ai/site), follow these steps:
-
-1. Sign up for a wandb account.
-2. Run the following command to log in:
-
-```bash
-wandb login
-```
-
-3. Open `configurations/training.yaml` and set the `entity` and `project` fields to your wandb username.
-
----
-
-### Training
-
-Download pretrained weights from [Oasis](https://github.com/etched-ai/open-oasis).
-
-Trained on 4 H100 GPUs, the model converges after approximately 500K steps.
-We observe that gradually increasing task difficulty improves performance, so we adopt a multi-stage training strategy:
-
-```bash
-sh train_stage_1.sh # Small range, no vertical turning
-sh train_stage_2.sh # Large range, no vertical turning
-sh train_stage_3.sh # Large range, with vertical turning
-```
-
-To resume training from a previous checkpoint, configure the `resume` and `output_dir` variables in the corresponding `.sh` script.
-
----
-
-### Inference
-
-To run inference:
-
-```bash
-sh infer.sh
-```
-
-You can either **load the diffusion model and VAE separately**:
-
-```bash
-+diffusion_model_path=zeqixiao/worldmem_checkpoints/diffusion_only.ckpt \
-+vae_path=zeqixiao/worldmem_checkpoints/vae_only.ckpt \
-+customized_load=true \
-+seperate_load=true \
-```
-
-Or **load a combined checkpoint**:
-
-```bash
-+load=your_model_path \
-+customized_load=true \
-+seperate_load=false \
-```
-
-### Evaluation
-
-To run evaluation:
-
-```bash
-sh evaluate.sh
-```
-
-This script reproduces the results in Table 1 (beyond the context window). It reports PSNR and LPIPS. Evaluating one case on one A100 GPU takes approximately 6 minutes. You can adjust `experiment.test.limit_batch` to set the number of cases to evaluate.
-
-Visual results are saved by default to a timestamped directory (e.g., `outputs/2025-11-30/00-02-42`).
-
-To calculate the FID score, run:
-
-```bash
-python calculate_fid.py --videos_dir <path_to_videos>
-```
-
-For example:
-
-```bash
-python calculate_fid.py --videos_dir outputs/2025-11-30/00-02-42/videos/test_vis
-```
-
-**Expected Results:**
-
-| Metric | Value  |
-|--------|--------|
-| PSNR   | 24.01  |
-| LPIPS  | 0.1667 |
-| FID    | 15.13  |
-
-*Note: FID is computed over 5000 frames.*
-
----
-
-## Dataset
-
-Download the Minecraft dataset from [Hugging Face](https://huggingface.co/datasets/zeqixiao/worldmem_minecraft_dataset).
-
-Place the dataset in the following directory structure:
-
-```
-data/
-└── minecraft/
-    ├── training/
-    ├── validation/
-    └── test/
-```
-
-## Data Generation
-
-After setting up the environment as described in [MineDojo's GitHub repository](https://github.com/MineDojo/MineDojo), you can generate data using the following command:
-
-```bash
-xvfb-run -a python data_generator.py -o data/test -z 4 --env_type plains
-```
-
-**Parameters:**
-- `-o`: Output directory for generated data
-- `-z`: Number of parallel workers
-- `--env_type`: Environment type (e.g., `plains`)
-
-## TODO
-
-- [x] Release inference models and weights;
-- [x] Release training pipeline on Minecraft;
-- [x] Release training data on Minecraft;
-- [x] Release evaluation scripts and data generator.
-
-## 🔗 Citation
-
-If you find our work helpful, please cite:
-
-```
-@misc{xiao2025worldmemlongtermconsistentworld,
-      title={WORLDMEM: Long-term Consistent World Simulation with Memory},
-      author={Zeqi Xiao and Yushi Lan and Yifan Zhou and Wenqi Ouyang and Shuai Yang and Yanhong Zeng and Xingang Pan},
-      year={2025},
-      eprint={2504.12369},
-      archivePrefix={arXiv},
-      primaryClass={cs.CV},
-      url={https://arxiv.org/abs/2504.12369},
-}
-```
-
-## 👏 Acknowledgements
-- [Diffusion Forcing](https://github.com/buoyancy99/diffusion-forcing): Diffusion Forcing provides flexible training and inference strategies for our method.
-- [MineDojo](https://github.com/MineDojo/MineDojo): We collect our Minecraft dataset with MineDojo.
-- [Open-oasis](https://github.com/etched-ai/open-oasis): Our model architecture is based on Open-oasis. We also use its pretrained VAE and DiT weights.
+# WorldMem
+
+Long-term consistent world simulation with memory.
+
+## Environment (conda)
+
+```bash
+conda create -n worldmem python=3.10
+conda activate worldmem
+pip install -r requirements.txt
+conda install -c conda-forge ffmpeg=4.3.2
+```
+
+## Data preparation (data folder)
+
+1. Download the Minecraft dataset:
+   https://huggingface.co/datasets/zeqixiao/worldmem_minecraft_dataset
+2. Place it under `data/` with this structure:
+
+```text
+data/
+└── minecraft/
+    ├── training/
+    ├── validation/
+    └── test/
+```
+
+The training and evaluation scripts expect the dataset to live at `data/minecraft` by default.
+
+## Training
+
+Run a single stage:
+
+```bash
+sh train_stage_1.sh
+sh train_stage_2.sh
+sh train_stage_3.sh
+```
+
+Run all stages:
+
+```bash
+sh train_3stages.sh
+```
+
+The stage scripts include dataset and checkpoint paths. Update those paths or override them on the CLI to match your local setup.
+
+## Training config (exp_video.yaml)
+
+Defaults live in `configurations/experiment/exp_video.yaml`.
+
+Common fields to edit:
+- `training.lr`
+- `training.precision`
+- `training.batch_size`
+- `training.max_steps`
+- `training.checkpointing.every_n_train_steps`
+- `validation.val_every_n_step`
+- `validation.batch_size`
+- `test.batch_size`
+
+You can also override these values from the CLI used in the scripts:
+
+```bash
+python -m main +name=train experiment.training.batch_size=8 experiment.training.max_steps=100000
+```
algorithms/worldmem/df_base.py CHANGED
@@ -33,6 +33,7 @@ class DiffusionForcingBase(BasePytorchAlgo):
         self.action_cond_dim = cfg.action_cond_dim
         self.causal = cfg.causal
 
+
         self.uncertainty_scale = cfg.uncertainty_scale
         self.timesteps = cfg.diffusion.timesteps
         self.sampling_timesteps = cfg.diffusion.sampling_timesteps
algorithms/worldmem/df_video.py CHANGED
@@ -3,6 +3,7 @@ import random
 import math
 import numpy as np
 import torch
+import torch.distributed as dist
 import torch.nn.functional as F
 import torchvision.transforms.functional as TF
 from torchvision.transforms import InterpolationMode
@@ -21,6 +22,7 @@ from .models.vae import VAE_models
 from .models.diffusion import Diffusion
 from .models.pose_prediction import PosePredictionNet
 import glob
+import wandb
 
 # Utility Functions
 def euler_to_rotation_matrix(pitch, yaw):
@@ -376,7 +378,8 @@ class WorldMemMinecraft(DiffusionForcingBase):
             ref_mode=self.ref_mode
         )
 
-        self.validation_lpips_model = LearnedPerceptualImagePatchSimilarity()
+        # Avoid distributed sync inside torchmetrics; reduce metrics manually across ranks.
+        self.validation_lpips_model = LearnedPerceptualImagePatchSimilarity(sync_on_compute=False)
         vae = VAE_models["vit-l-20-shallow-encoder"]()
         self.vae = vae.eval()
 
@@ -430,13 +433,13 @@ class WorldMemMinecraft(DiffusionForcingBase):
                         focal_length=self.focal_length,
                         image_height=xs.shape[-2], image_width=xs.shape[-1]
                     ).to(xs.dtype)
-                )
+                )  # [V(1 + memory_condition_length), B, H, W, 6]
                 frame_idx_list.append(
                     torch.cat([
                         frame_idx[i:i + 1] - frame_idx[i:i + 1],
                         frame_idx[-self.memory_condition_length:] - frame_idx[i:i + 1]
                     ]).clone()
-                )
+                )  # [V(1 + memory_condition_length), B] (0 for the current frame; memory frames carry indices relative to the current frame)
             input_pose_condition = torch.cat(input_pose_condition)
             frame_idx_list = torch.cat(frame_idx_list)
         else:
@@ -476,66 +479,78 @@ class WorldMemMinecraft(DiffusionForcingBase):
         return {"loss": loss}
 
     def on_validation_epoch_end(self, namespace="validation") -> None:
-        if not self.validation_step_outputs:
-            return
-
-        xs_pred = []
-        xs = []
-        for pred, gt in self.validation_step_outputs:
-            xs_pred.append(pred)
-            xs.append(gt)
-
-        xs_pred = torch.cat(xs_pred, 1)
-        if gt is not None:
-            xs = torch.cat(xs, 1)
-        else:
-            xs = None
-
-        if self.logger and self.log_video:
-            log_video(
-                xs_pred,
-                xs,
-                step=None if namespace == "test" else self.global_step,
-                namespace=namespace + "_vis",
-                context_frames=self.context_frames,
-                logger=self.logger.experiment,
-                save_local=self.save_local,
-                local_save_dir=self.local_save_dir,
-            )
-
-        if xs is not None:
-            # Move data to the same device as the LPIPS model for metric calculation
-            device = next(self.validation_lpips_model.parameters()).device
-            xs_pred_device = xs_pred.to(device)
-            xs_device = xs.to(device)
-
-            metric_dict = get_validation_metrics_for_videos(
-                xs_pred_device, xs_device,
-                lpips_model=self.validation_lpips_model,
-                lpips_batch_size=self.lpips_batch_size)
-
-            self.log_dict(
-                {"mse": metric_dict['mse'],
-                 "psnr": metric_dict['psnr'],
-                 "lpips": metric_dict['lpips']},
-                sync_dist=True
-            )
-
-            if self.log_curve:
-                psnr_values = metric_dict['frame_wise_psnr'].cpu().tolist()
-                frames = list(range(len(psnr_values)))
-                line_plot = wandb.plot.line_series(
-                    xs=frames,
-                    ys=[psnr_values],
-                    keys=["PSNR"],
-                    title="Frame-wise PSNR",
-                    xname="Frame index"
-                )
-
-                self.logger.experiment.log({"frame_wise_psnr_plot": line_plot})
-
+        if not hasattr(self, "_metric_device"):
+            return
+
+        if dist.is_available() and dist.is_initialized():
+            for tensor in (
+                self._mse_sum,
+                self._mse_count,
+                self._psnr_sum,
+                self._psnr_count,
+                self._lpips_sum,
+                self._lpips_count,
+            ):
+                dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+
+        mse = self._mse_sum / self._mse_count.clamp_min(1.0)
+        psnr = self._psnr_sum / self._psnr_count.clamp_min(1.0)
+        lpips = self._lpips_sum / self._lpips_count.clamp_min(1.0)
+
+        if self.trainer is None or self.trainer.is_global_zero:
+            if self._mse_count.item() > 0:
+                self.log_dict(
+                    {"mse": mse, "psnr": psnr, "lpips": lpips},
+                    sync_dist=False,
+                )
+
         self.validation_step_outputs.clear()
 
+    def on_validation_epoch_start(self) -> None:
+        self._reset_metric_accumulators()
+
+    def on_test_epoch_start(self) -> None:
+        self._reset_metric_accumulators()
+
+    def _reset_metric_accumulators(self) -> None:
+        self._metric_device = next(self.validation_lpips_model.parameters()).device
+        self._mse_sum = torch.tensor(0.0, device=self._metric_device)
+        self._mse_count = torch.tensor(0.0, device=self._metric_device)
+        self._psnr_sum = torch.tensor(0.0, device=self._metric_device)
+        self._psnr_count = torch.tensor(0.0, device=self._metric_device)
+        self._lpips_sum = torch.tensor(0.0, device=self._metric_device)
+        self._lpips_count = torch.tensor(0.0, device=self._metric_device)
+
+    def _update_metric_accumulators(self, xs_pred: torch.Tensor, xs_gt: torch.Tensor) -> None:
+        xs_pred_device = xs_pred.to(self._metric_device)
+        xs_device = xs_gt.to(self._metric_device)
+
+        metric_dict = get_validation_metrics_for_videos(
+            xs_pred_device,
+            xs_device,
+            lpips_model=self.validation_lpips_model,
+            lpips_batch_size=self.lpips_batch_size,
+        )
+
+        mse_val = metric_dict["mse"].detach()
+        psnr_val = metric_dict["psnr"].detach()
+        lpips_val = torch.tensor(metric_dict["lpips"], device=self._metric_device)
+
+        mse_count_batch = torch.tensor(float(xs_pred_device.numel()), device=self._metric_device)
+        psnr_count_batch = torch.tensor(float(xs_pred_device.shape[1]), device=self._metric_device)
+        lpips_count_batch = torch.tensor(
+            float(xs_pred_device.shape[0] * xs_pred_device.shape[1]), device=self._metric_device
+        )
+
+        self._mse_sum += mse_val * mse_count_batch
+        self._psnr_sum += psnr_val * psnr_count_batch
+        self._lpips_sum += lpips_val * lpips_count_batch
+        self._mse_count += mse_count_batch
+        self._psnr_count += psnr_count_batch
+        self._lpips_count += lpips_count_batch
+
+        del xs_pred_device, xs_device
+
     def _preprocess_batch(self, batch):
 
         xs, conditions, pose_conditions, frame_index = batch
@@ -554,7 +569,7 @@ class WorldMemMinecraft(DiffusionForcingBase):
         return xs, conditions, pose_conditions, c2w_mat, frame_index
 
     def encode(self, x):
-        # vae encoding
+        # vae encoding; x has shape (t, b, c, h, w)
        T = x.shape[0]
        H, W = x.shape[-2:]
        scaling_factor = 0.07843137255
@@ -783,8 +798,21 @@ class WorldMemMinecraft(DiffusionForcingBase):
         xs_pred = self.decode(xs_pred[n_context_frames:].to(conditions.device))
         xs_decode = self.decode(xs[n_context_frames:].to(conditions.device))
 
-        # Store results for evaluation (move to CPU to save GPU memory)
-        self.validation_step_outputs.append((xs_pred.detach().cpu(), xs_decode.detach().cpu()))
+        # Save videos for every batch (rank is encoded in filenames).
+        if self.logger and self.log_video:
+            log_video(
+                xs_pred,
+                xs_decode,
+                step=batch_idx,
+                namespace=namespace + "_vis",
+                context_frames=self.context_frames,
+                logger=self.logger.experiment,
+                save_local=self.save_local,
+                local_save_dir=self.local_save_dir,
+            )
+
+        # Stream metrics to avoid holding all outputs in memory.
+        self._update_metric_accumulators(xs_pred, xs_decode)
         return
 
     @torch.no_grad()
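The change above replaces torchmetrics' built-in cross-rank sync with manual sum/count accumulators that are `all_reduce`d once per epoch, so partial means merge exactly regardless of per-rank batch counts. A minimal standalone sketch of that pattern (class and variable names here are illustrative, not the repository's API):

```python
import torch
import torch.distributed as dist

class MeanAccumulator:
    """Streams a weighted mean as (sum, count) so ranks can be merged exactly."""

    def __init__(self, device="cpu"):
        self.total = torch.tensor(0.0, device=device)
        self.count = torch.tensor(0.0, device=device)

    def update(self, batch_mean: torch.Tensor, batch_count: float) -> None:
        # Store sum = mean * count; summing sums is exact, averaging averages is not.
        self.total += batch_mean.detach() * batch_count
        self.count += batch_count

    def compute(self) -> torch.Tensor:
        # In multi-process runs, sum partial totals across ranks first;
        # single-process runs skip the collective entirely.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(self.total, op=dist.ReduceOp.SUM)
            dist.all_reduce(self.count, op=dist.ReduceOp.SUM)
        return self.total / self.count.clamp_min(1.0)

acc = MeanAccumulator()
acc.update(torch.tensor(2.0), batch_count=4)  # e.g. mean metric over 4 frames
acc.update(torch.tensor(5.0), batch_count=1)  # mean over 1 frame
print(float(acc.compute()))  # (2*4 + 5*1) / 5 = 2.6
```

Note the `clamp_min(1.0)` guard mirrors the committed code: it keeps the division finite on ranks that saw no batches.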
algorithms/worldmem/models/diffusion.py CHANGED
@@ -169,7 +169,7 @@ class Diffusion(nn.Module):
             mode=mode, reference_length=reference_length, frame_idx=frame_idx)
         model_output = model_output.permute(1,0,2,3,4)
         x = x.permute(1,0,2,3,4)
-        t = t.permute(1,0)
+        t = t.permute(1,0)
 
         if self.objective == "pred_noise":
             pred_noise = torch.clamp(model_output, -self.clip_noise, self.clip_noise)
configurations/experiment/base_pytorch.yaml CHANGED
@@ -35,7 +35,8 @@ validation:
   val_every_n_step: 2000 # controls how frequently we run validation; can be float (fraction of epochs), int (steps), or null (if val_every_n_epoch is set)
   val_every_n_epoch: null # if you want to run validation every n epochs; requires val_every_n_step to be null.
   limit_batch: null # if null, run through the validation set. Otherwise limit the number of batches used for validation.
-  inference_mode: True # whether to run validation in inference mode (enable_grad won't work!)
+  # inference_mode: True # whether to run validation in inference mode (enable_grad won't work!)
+  inference_mode: False # whether to run validation in inference mode (enable_grad won't work!)
   data:
     num_workers: 4 # number of CPU threads for data preprocessing, for validation.
     shuffle: False # whether validation data will be shuffled
@@ -45,6 +46,7 @@ test:
   compile: False # whether to compile the model with torch.compile
   batch_size: 4 # test batch size per GPU; effective batch size is this number * gpus * nodes iff using distributed training
   limit_batch: null # if null, run through the test set. Otherwise limit the number of batches used for test.
+  inference_mode: False # whether to run test in inference mode (enable_grad won't work!)
   data:
     num_workers: 4 # number of CPU threads for data preprocessing, for test.
     shuffle: False # whether test data will be shuffled
configurations/experiment/exp_video.yaml CHANGED
@@ -7,6 +7,7 @@ training:
   lr: 2e-5
   precision: 16-mixed
   batch_size: 4
+  # batch_size: 8
   max_epochs: -1
   max_steps: 2000005
   checkpointing:
configurations/training.yaml CHANGED
@@ -8,7 +8,7 @@ defaults:
 debug: false # global debug flag will be passed into the configuration of experiment, dataset and algorithm
 
 wandb:
-  entity: xizaoqu # wandb account name / organization name [fixme]
+  entity: turlin # wandb account name / organization name [fixme]
   project: worldmem # wandb project name; if not provided, defaults to root folder name [fixme]
   mode: online # set wandb logging to online, offline or dryrun
datasets/video/base_video_dataset.py CHANGED
@@ -47,6 +47,7 @@ class BaseVideoDataset(torch.utils.data.Dataset, ABC):
         self.clips_per_video = np.clip(np.array(self.metadata) - self.n_frames + 1, a_min=1, a_max=None).astype(
             np.int32
         )
+
         self.cum_clips_per_video = np.cumsum(self.clips_per_video)
         self.transform = transforms.Resize((self.resolution, self.resolution), antialias=True)
datasets/video/minecraft_video_dataset.py CHANGED
@@ -126,7 +126,7 @@ class MinecraftVideoDataset(BaseVideoDataset):
             try:
                 return self.load_data(idx)
             except Exception as e:
-                # print(f"Retrying due to error: {e}")
+                print(f"Retrying due to error: {e}")
                 idx = (idx + 1) % len(self)
 
     def load_data(self, idx):
@@ -184,9 +184,9 @@ class MinecraftVideoDataset(BaseVideoDataset):
         dis = np.abs(poses[:, None] - poses_pool[None, :])
         dis[..., 3:][dis[..., 3:] > 180] = 360 - dis[..., 3:][dis[..., 3:] > 180]
 
-        spatial_match = (dis[..., :3] <= self.pos_range).sum(-1) >= 3
-        angular_match = (dis[..., 3:] <= self.angle_range).sum(-1) >= 2
-        not_exact_match = ((dis[..., :3] > 0).sum(-1) >= 1) | ((dis[..., 3:] > 0).sum(-1) >= 1)
+        spatial_match = (dis[..., :3] <= self.pos_range).sum(-1) >= 3  # X, Y, Z axes all within range
+        angular_match = (dis[..., 3:] <= self.angle_range).sum(-1) >= 2  # pitch and yaw both within range
+        not_exact_match = ((dis[..., :3] > 0).sum(-1) >= 1) | ((dis[..., 3:] > 0).sum(-1) >= 1)  # at least one axis differs
 
         valid_index = (spatial_match & angular_match & not_exact_match).sum(0)
         valid_index[:100] = 0 # skip unstable early frames
@@ -237,7 +237,7 @@ class MinecraftVideoDataset(BaseVideoDataset):
         timestamp = np.arange(self.n_frames)
 
         # === 7. Convert video to torch format ===
-        video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous()
+        video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous()  # (T, H, W, C) -> (T, C, H, W)
 
         # === 9. Return all items ===
         return (
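The pose-matching lines annotated above rely on NumPy broadcasting: `poses[:, None] - poses_pool[None, :]` yields an (N, M, 5) tensor of pairwise distances over (x, y, z, pitch, yaw), angles are wrapped to [0, 180], and three boolean reductions combine into the final mask. A toy sketch of the same logic, with made-up `pos_range` / `angle_range` thresholds and example poses:

```python
import numpy as np

# Illustrative thresholds, not the dataset's configured values.
pos_range, angle_range = 2.0, 30.0

poses = np.array([[0.0, 0.0, 0.0, 10.0, 350.0]])        # (N, 5): x, y, z, pitch, yaw
poses_pool = np.array([[1.0, 0.5, -1.0, 20.0, 5.0],     # close in position, yaw wraps to 15°
                       [10.0, 0.0, 0.0, 10.0, 350.0]])  # 10 units away along x

dis = np.abs(poses[:, None] - poses_pool[None, :])      # (N, M, 5) pairwise distances
dis[..., 3:][dis[..., 3:] > 180] = 360 - dis[..., 3:][dis[..., 3:] > 180]  # wrap angles

spatial_match = (dis[..., :3] <= pos_range).sum(-1) >= 3    # all of x, y, z within range
angular_match = (dis[..., 3:] <= angle_range).sum(-1) >= 2  # pitch and yaw within range
not_exact = ((dis[..., :3] > 0).sum(-1) >= 1) | ((dis[..., 3:] > 0).sum(-1) >= 1)

print(spatial_match & angular_match & not_exact)  # [[ True False]]
```

The `not_exact` term excludes a frame matching itself (zero distance on every axis), which is why the dataset's `not_exact_match` mask requires at least one strictly positive difference.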
evaluate.sh CHANGED
@@ -1,15 +1,27 @@
 export PYTHONWARNINGS="ignore"
+export CUDA_VISIBLE_DEVICES=4,5,6,7
+
+# export NCCL_DEBUG=INFO
+# export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
+# export TORCH_DISTRIBUTED_DEBUG=DETAIL
+# export NCCL_DEBUG_SUBSYS=COLL
+# # Optional but very helpful while debugging (slower):
+# export TORCH_NCCL_BLOCKING_WAIT=1
+export NCCL_TIMEOUT=7200
+export NCCL_P2P_DISABLE=1
+export HYDRA_FULL_ERROR=1
+
 wandb offline
 python -m main +name=infer \
     experiment.tasks=[test] \
     dataset.validation_multiplier=1 \
     +dataset.seed=42 \
-    +diffusion_model_path=zeqixiao/worldmem_checkpoints/diffusion_only.ckpt \
-    +vae_path=zeqixiao/worldmem_checkpoints/vae_only.ckpt \
+    +diffusion_model_path=/share_1/users/bonan_ding/worldmem_ckpt/diffusion_only.ckpt \
+    +vae_path=/share_1/users/bonan_ding/worldmem_ckpt/vae_only.ckpt \
     +customized_load=true \
     +seperate_load=true \
     dataset.n_frames=8 \
-    dataset.save_dir=data/minecraft \
+    dataset.save_dir=/share_1/users/bonan_ding/worldmem_data/minecraft \
     +dataset.n_frames_valid=700 \
     algorithm.diffusion.sampling_timesteps=20 \
     +algorithm.memory_condition_length=8 \
@@ -20,4 +32,4 @@ python -m main +name=infer \
     +algorithm.n_tokens=8 \
     algorithm.context_frames=600 \
     experiment.test.batch_size=1 \
-    experiment.test.limit_batch=10 \
+    experiment.test.limit_batch=160 \
experiments/exp_base.py CHANGED
@@ -9,7 +9,7 @@ from abc import ABC, abstractmethod
 from typing import Optional, Union, Literal, List, Dict
 import pathlib
 import os
-
+from datetime import timedelta
 import hydra
 import torch
 from lightning.pytorch.strategies.ddp import DDPStrategy
@@ -415,10 +415,11 @@ class BaseLightningExperiment(BaseExperiment):
             logger=self.logger,
             devices="auto",
             num_nodes=self.cfg.num_nodes,
-            strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
+            strategy=DDPStrategy(find_unused_parameters=False, timeout=timedelta(hours=1)) if torch.cuda.device_count() > 1 else "auto",
             callbacks=callbacks,
             limit_test_batches=self.cfg.test.limit_batch,
             precision=self.cfg.test.precision,
+            inference_mode=self.cfg.test.inference_mode,
             detect_anomaly=False,  # self.cfg.debug,
         )
infer.sh CHANGED
@@ -1,14 +1,21 @@
 export PYTHONWARNINGS="ignore"
+export NCCL_DEBUG=INFO
+export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
+export TORCH_DISTRIBUTED_DEBUG=DETAIL
+export NCCL_DEBUG_SUBSYS=COLL
+# Optional but very helpful while debugging (slower):
+export TORCH_NCCL_BLOCKING_WAIT=1
+export NCCL_P2P_DISABLE=1
 wandb offline
 python -m main +name=infer \
     experiment.tasks=[validation] \
     dataset.validation_multiplier=1 \
-    +diffusion_model_path=zeqixiao/worldmem_checkpoints/diffusion_only.ckpt \
-    +vae_path=zeqixiao/worldmem_checkpoints/vae_only.ckpt \
+    +diffusion_model_path=/share_1/users/bonan_ding/worldmem_ckpt/diffusion_only.ckpt \
+    +vae_path=/share_1/users/bonan_ding/worldmem_ckpt/vae_only.ckpt \
     +customized_load=true \
     +seperate_load=true \
     dataset.n_frames=8 \
-    dataset.save_dir=data/minecraft \
+    dataset.save_dir=/share_1/users/bonan_ding/worldmem_data/minecraft \
     +dataset.n_frames_valid=700 \
     +dataset.memory_condition_length=8 \
     +dataset.customized_validation=true \
main.py CHANGED
@@ -59,6 +59,10 @@ def run_local(cfg: DictConfig):
     OmegaConf.set_readonly(hydra_cfg, True)
 
     output_dir = Path(hydra_cfg.runtime.output_dir)
+    if not output_dir.exists():
+        output_dir.mkdir(parents=True, exist_ok=True)
+        if is_rank_zero:
+            print(cyan(f"Created output directory: {output_dir}"))
 
     if is_rank_zero:
         print(cyan(f"Outputs will be saved to:"), output_dir)
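The directory creation added above uses `Path.mkdir(parents=True, exist_ok=True)`, which is idempotent and safe when several ranks race to create the same output directory. A self-contained sketch under a throwaway temp path (the nested path here is illustrative, not the project's real `output_dir`):

```python
import tempfile
from pathlib import Path

base = Path(tempfile.mkdtemp())
output_dir = base / "outputs" / "2025-11-30" / "00-02-42"

# parents=True creates the intermediate directories; exist_ok=True makes a
# repeated call (or a concurrent call from another rank) a harmless no-op.
output_dir.mkdir(parents=True, exist_ok=True)
output_dir.mkdir(parents=True, exist_ok=True)  # second call does nothing

print(output_dir.is_dir())  # True
```

Without `exist_ok=True`, the second call (or the loser of a multi-rank race) would raise `FileExistsError`.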
requirements.txt CHANGED
@@ -1,28 +1,139 @@
- torch~=2.4.0
- torchvision~=0.19.1
- lightning~=2.1.2
- wandb~=0.17.0
- hydra-core~=1.3.2
- omegaconf~=2.3.0
- torchmetrics[image]==0.11.4
- wandb-osh==1.2.1
- gluonts[torch]==0.13.1
- pytorchvideo~=0.1.5
- colorama
- tqdm
- opencv-python
- matplotlib
- click
+ aiofiles==23.2.1
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.13.3
+ aiosignal==1.4.0
+ altair==5.5.0
+ annotated-doc==0.0.4
+ antlr4-python3-runtime==4.9.3
+ anyio==4.12.1
+ async-timeout==5.0.1
+ attrs==25.4.0
+ av==16.1.0
+ certifi==2026.1.4
+ charset-normalizer==3.4.4
+ click==8.3.1
+ colorama==0.4.6
+ colorlog==6.10.1
+ contourpy==1.3.2
+ cycler==0.12.1
+ decorator==4.4.2
+ diffusers==0.36.0
+ docker-pycreds==0.4.0
+ einops==0.8.1
+ exceptiongroup==1.3.1
+ fastapi==0.125.0
+ ffmpy==1.0.0
+ filelock==3.20.3
+ fonttools==4.61.1
+ frozenlist==1.8.0
+ fsspec==2024.12.0
+ fvcore==0.1.5.post20221221
+ gitdb==4.0.12
+ GitPython==3.1.46
+ gluonts==0.13.1
+ gradio==3.50.2
+ gradio_client==0.6.1
+ h11==0.16.0
+ h5py==3.15.1
+ hf-xet==1.2.0
+ httpcore==1.0.9
+ httpx==0.28.1
+ huggingface_hub==1.3.2
+ hydra-core==1.3.2
+ idna==3.11
+ ImageIO==2.37.2
+ imageio-ffmpeg==0.6.0
+ importlib_metadata==8.7.1
+ importlib_resources==6.5.2
+ internetarchive==5.7.1
+ iopath==0.1.10
+ Jinja2==3.1.6
+ jsonpatch==1.33
+ jsonpointer==3.0.0
+ jsonschema==4.26.0
+ jsonschema-specifications==2025.9.1
+ kiwisolver==1.4.9
+ lightning==2.1.4
+ lightning-utilities==0.15.2
+ lpips==0.1.4
+ MarkupSafe==2.1.5
+ matplotlib==3.10.8
  moviepy==1.0.3
- imageio
- einops
- pandas
- pyzmq
- pyrealsense2
- internetarchive
- h5py
- rotary_embedding_torch
- diffusers
- timm
- gradio
- spaces
+ mpmath==1.3.0
+ multidict==6.7.0
+ narwhals==2.15.0
+ networkx==3.4.2
+ numpy==1.26.4
+ nvidia-cublas-cu12==12.1.3.1
+ nvidia-cuda-cupti-cu12==12.1.105
+ nvidia-cuda-nvrtc-cu12==12.1.105
+ nvidia-cuda-runtime-cu12==12.1.105
+ nvidia-cudnn-cu12==9.1.0.70
+ nvidia-cufft-cu12==11.0.2.54
+ nvidia-curand-cu12==10.3.2.106
+ nvidia-cusolver-cu12==11.4.5.107
+ nvidia-cusparse-cu12==12.1.0.106
+ nvidia-nccl-cu12==2.20.5
+ nvidia-nvjitlink-cu12==12.9.86
+ nvidia-nvtx-cu12==12.1.105
+ omegaconf==2.3.0
+ opencv-python==4.11.0.86
+ orjson==3.11.5
+ packaging==24.2
+ pandas==2.3.3
+ parameterized==0.9.0
+ pillow==10.4.0
+ platformdirs==4.5.1
+ portalocker==3.2.0
+ proglog==0.1.12
+ propcache==0.4.1
+ protobuf==3.19.6
+ psutil==5.9.8
+ pydantic==1.10.26
+ pydub==0.25.1
+ pyparsing==3.3.1
+ pyrealsense2==2.56.5.9235
+ python-dateutil==2.9.0.post0
+ python-multipart==0.0.21
+ pytorch-lightning==2.6.0
+ pytorchvideo==0.1.5
+ pytz==2025.2
+ PyYAML==6.0.3
+ pyzmq==27.1.0
+ referencing==0.37.0
+ regex==2026.1.15
+ requests==2.32.5
+ rotary-embedding-torch==0.8.9
+ rpds-py==0.30.0
+ safetensors==0.7.0
+ scipy==1.15.3
+ semantic-version==2.10.0
+ sentry-sdk==2.49.0
+ setproctitle==1.3.7
+ shellingham==1.5.4
+ six==1.17.0
+ smmap==5.0.2
+ spaces==0.46.0
+ starlette==0.50.0
+ sympy==1.14.0
+ tabulate==0.9.0
+ termcolor==3.3.0
+ timm==1.0.24
+ toolz==0.12.1
+ torch==2.4.1
+ torch-fidelity==0.3.0
+ torchmetrics==0.11.4
+ torchvision==0.19.1
+ tqdm==4.67.1
+ triton==3.0.0
+ typer-slim==0.21.1
+ typing_extensions==4.15.0
+ tzdata==2025.3
+ urllib3==2.6.3
+ uvicorn==0.40.0
+ wandb==0.17.9
+ wandb_osh==1.2.1
+ websockets==11.0.3
+ yacs==0.1.8
+ yarl==1.22.0
+ zipp==3.23.0
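The rewritten requirements file pins exact versions (`==`) in place of the old loose specifiers (`~=` or none at all), which is what makes the reproduction environment deterministic. A quick standalone check that every line is an exact pin (illustrative, not part of the repo):

```python
import re

# Accept "name==version" lines; names may carry extras such as "[torch]".
PIN_RE = re.compile(r"^[A-Za-z0-9._-]+(\[[A-Za-z0-9,_-]+\])?==\S+$")

def unpinned(lines):
    """Return requirement lines that are not exact '==' pins."""
    reqs = [ln.strip() for ln in lines if ln.strip() and not ln.startswith("#")]
    return [r for r in reqs if not PIN_RE.match(r)]

old = ["torch~=2.4.0", "colorama", "moviepy==1.0.3"]
new = ["torch==2.4.1", "colorama==0.4.6", "moviepy==1.0.3"]
print(unpinned(old))  # → ['torch~=2.4.0', 'colorama']
print(unpinned(new))  # → []
```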
train_3stages.sh ADDED
@@ -0,0 +1,80 @@
+ wandb enabled
+ export CUDA_VISIBLE_DEVICES=0,1,2,3
+ export NCCL_P2P_DISABLE=1
+ # export HYDRA_FULL_ERROR=1
+ 
+ set -e           # Exit on any error
+ set -o pipefail  # Exit on pipe failures
+ 
+ # Stage 1
+ python -m main +name=train \
+ +diffusion_model_path=/share_1/users/bonan_ding/worldmem_ckpt/diffusion_only.ckpt \
+ +vae_path=/share_1/users/bonan_ding/worldmem_ckpt/vae_only.ckpt \
+ +customized_load=true \
+ +seperate_load=true \
+ +zero_init_gate=true \
+ dataset.n_frames=8 \
+ dataset.save_dir=/share_1/users/bonan_ding/worldmem_data/minecraft \
+ +dataset.n_frames_valid=700 \
+ +dataset.angle_range=110 \
+ +dataset.pos_range=2 \
+ +dataset.memory_condition_length=8 \
+ +dataset.customized_validation=true \
+ +dataset.add_timestamp_embedding=true \
+ +dataset.wo_updown=true \
+ +algorithm.n_tokens=8 \
+ +algorithm.memory_condition_length=8 \
+ algorithm.context_frames=600 \
+ +algorithm.relative_embedding=true \
+ +algorithm.log_video=true \
+ +algorithm.add_timestamp_embedding=true \
+ +algorithm.metrics=[lpips,psnr] \
+ experiment.training.checkpointing.every_n_train_steps=2500 \
+ experiment.training.max_steps=120000 \
+ +output_dir=/share_1/users/bonan_ding/worldmem_ckpt/reproduce_official_set \
+ 
+ # Stage 2
+ python -m main +name=train \
+ dataset.n_frames=8 \
+ dataset.save_dir=data/minecraft \
+ +dataset.n_frames_valid=700 \
+ +dataset.angle_range=110 \
+ +dataset.pos_range=8 \
+ +dataset.memory_condition_length=8 \
+ +dataset.customized_validation=true \
+ +dataset.add_timestamp_embedding=true \
+ +dataset.wo_updown=true \
+ +algorithm.n_tokens=8 \
+ +algorithm.memory_condition_length=8 \
+ algorithm.context_frames=600 \
+ +algorithm.relative_embedding=true \
+ +algorithm.log_video=true \
+ +algorithm.add_timestamp_embedding=true \
+ +algorithm.metrics=[lpips,psnr] \
+ experiment.training.checkpointing.every_n_train_steps=2500 \
+ resume=ot7jqmgn \
+ +output_dir=/share_1/users/bonan_ding/worldmem_ckpt/reproduce_official_set \
+ experiment.training.max_steps=240000
+ 
+ # Stage 3
+ python -m main +name=train \
+ dataset.n_frames=8 \
+ dataset.save_dir=data/minecraft \
+ +dataset.n_frames_valid=700 \
+ +dataset.angle_range=110 \
+ +dataset.pos_range=8 \
+ +dataset.memory_condition_length=8 \
+ +dataset.customized_validation=true \
+ +dataset.add_timestamp_embedding=true \
+ +dataset.wo_updown=false \
+ +algorithm.n_tokens=8 \
+ +algorithm.memory_condition_length=8 \
+ algorithm.context_frames=600 \
+ +algorithm.relative_embedding=true \
+ +algorithm.log_video=true \
+ +algorithm.add_timestamp_embedding=true \
+ +algorithm.metrics=[lpips,psnr] \
+ experiment.training.checkpointing.every_n_train_steps=2500 \
+ resume=ot7jqmgn \
+ +output_dir=/share_1/users/bonan_ding/worldmem_ckpt/reproduce_official_set \
+ experiment.training.max_steps=700000
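The `set -e` / `set -o pipefail` pair at the top of this script is what lets the three stages run unattended: without `pipefail`, a failure on the left side of a pipe is masked by the last command in the pipe. A standalone sketch of the difference, driven through `bash -c` from Python (illustrative, not part of the repo):

```python
import subprocess

def run(cmd: str) -> subprocess.CompletedProcess:
    """Run a snippet under bash and capture its output."""
    return subprocess.run(["bash", "-c", cmd], capture_output=True, text=True)

# Without pipefail the pipeline's status comes from `true`,
# so `set -e` does not stop the script.
print(run('set -e; false | true; echo "masked, rc=$?"').stdout.strip())
# → masked, rc=0

# With pipefail the pipeline fails and `set -e` aborts before the echo.
r = run('set -e -o pipefail; false | true; echo "not reached"')
print(r.returncode != 0, repr(r.stdout))  # → True ''
```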
train_stage_1.sh CHANGED
@@ -1,14 +1,16 @@
  wandb enabled
- 
+ export CUDA_VISIBLE_DEVICES=0,1,2,3
+ export NCCL_P2P_DISABLE=1
+ # export HYDRA_FULL_ERROR=1
  # set -e
  python -m main +name=train \
- +diffusion_model_path=your_diffusion_model_path \
- +vae_path=your_vae_path \
+ +diffusion_model_path=/share_1/users/bonan_ding/worldmem_ckpt/diffusion_only.ckpt \
+ +vae_path=/share_1/users/bonan_ding/worldmem_ckpt/vae_only.ckpt \
  +customized_load=true \
  +seperate_load=true \
  +zero_init_gate=true \
  dataset.n_frames=8 \
- dataset.save_dir=data/minecraft \
+ dataset.save_dir=/share_1/users/bonan_ding/worldmem_data/minecraft \
  +dataset.n_frames_valid=700 \
  +dataset.angle_range=110 \
  +dataset.pos_range=2 \
@@ -22,8 +24,7 @@ python -m main +name=train \
  +algorithm.relative_embedding=true \
  +algorithm.log_video=true \
  +algorithm.add_timestamp_embedding=true \
- algorithm.metrics=[lpips,psnr] \
+ +algorithm.metrics=[lpips,psnr] \
  experiment.training.checkpointing.every_n_train_steps=2500 \
- experiment.training.max_steps=120000
- 
- 
+ experiment.training.max_steps=120000 \
+ +output_dir=/share_1/users/bonan_ding/worldmem_ckpt/reproduce_official_set \
train_stage_2.sh CHANGED
@@ -1,9 +1,11 @@
  wandb enabled
- 
- # set -e
+ export CUDA_VISIBLE_DEVICES=4,5,6,7
+ export NCCL_P2P_DISABLE=1
+ # export HYDRA_FULL_ERROR=1
+ set -e
  python -m main +name=train \
  dataset.n_frames=8 \
- dataset.save_dir=data/minecraft \
+ dataset.save_dir=/share_1/users/bonan_ding/worldmem_data/minecraft \
  +dataset.n_frames_valid=700 \
  +dataset.angle_range=110 \
  +dataset.pos_range=8 \
@@ -17,9 +19,31 @@ python -m main +name=train \
  +algorithm.relative_embedding=true \
  +algorithm.log_video=true \
  +algorithm.add_timestamp_embedding=true \
- algorithm.metrics=[lpips,psnr] \
+ +algorithm.metrics=[lpips,psnr] \
  experiment.training.checkpointing.every_n_train_steps=2500 \
- resume=your_wandb_job_id e.g.yhht29bz \
- +output_dir=your_saving_path e.g. outputs/2025-05-18/15-16-32 \
+ resume=ot7jqmgn \
+ +output_dir=/share_1/users/bonan_ding/worldmem_ckpt/reproduce_official_set \
  experiment.training.max_steps=240000
  
+ # Stage 3
+ python -m main +name=train \
+ dataset.n_frames=8 \
+ dataset.save_dir=/share_1/users/bonan_ding/worldmem_data/minecraft \
+ +dataset.n_frames_valid=700 \
+ +dataset.angle_range=110 \
+ +dataset.pos_range=8 \
+ +dataset.memory_condition_length=8 \
+ +dataset.customized_validation=true \
+ +dataset.add_timestamp_embedding=true \
+ +dataset.wo_updown=false \
+ +algorithm.n_tokens=8 \
+ +algorithm.memory_condition_length=8 \
+ algorithm.context_frames=600 \
+ +algorithm.relative_embedding=true \
+ +algorithm.log_video=true \
+ +algorithm.add_timestamp_embedding=true \
+ +algorithm.metrics=[lpips,psnr] \
+ experiment.training.checkpointing.every_n_train_steps=2500 \
+ resume=ot7jqmgn \
+ +output_dir=/share_1/users/bonan_ding/worldmem_ckpt/reproduce_official_set \
+ experiment.training.max_steps=700000
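All of these command-line arguments are Hydra overrides in dotted-path syntax: `dataset.n_frames=8` changes a key that already exists in the base config, while a leading `+` (as in `+dataset.pos_range=8`) adds a key the config does not yet have. A tiny pure-Python sketch of that semantics (not Hydra's actual implementation):

```python
def apply_override(cfg, override, allow_new=False):
    """Set a dotted-path key in a nested dict, mimicking Hydra's
    `key=value` (existing key) vs `+key=value` (new key) overrides."""
    if override.startswith("+"):
        override, allow_new = override[1:], True
    path, value = override.split("=", 1)
    *parents, leaf = path.split(".")
    node = cfg
    for p in parents:
        node = node.setdefault(p, {}) if allow_new else node[p]
    if not allow_new and leaf not in node:
        raise KeyError(f"unknown key (use +{path} to add it): {path}")
    node[leaf] = value
    return cfg

cfg = {"dataset": {"n_frames": "4"}}
apply_override(cfg, "dataset.n_frames=8")    # existing key: plain override
apply_override(cfg, "+dataset.pos_range=8")  # new key: needs the '+' prefix
print(cfg)  # → {'dataset': {'n_frames': '8', 'pos_range': '8'}}
```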
train_stage_3.sh CHANGED
@@ -1,5 +1,7 @@
  wandb enabled
- 
+ export CUDA_VISIBLE_DEVICES=4,5
+ export NCCL_P2P_DISABLE=1
+ # export HYDRA_FULL_ERROR=1
  # set -e
  python -m main +name=train \
  dataset.n_frames=8 \
@@ -17,8 +19,9 @@ python -m main +name=train \
  +algorithm.relative_embedding=true \
  +algorithm.log_video=true \
  +algorithm.add_timestamp_embedding=true \
- algorithm.metrics=[lpips,psnr] \
+ +algorithm.metrics=[lpips,psnr] \
  experiment.training.checkpointing.every_n_train_steps=2500 \
- resume=your_wandb_job_id e.g.yhht29bz \
- +output_dir=your_saving_path e.g. outputs/2025-05-18/15-16-32 \
- experiment.training.max_steps=700000
+ resume=qyyk38nw \
+ +output_dir=/share_1/users/bonan_ding/worldmem_ckpt/reproduce_1 \
+ experiment.training.max_steps=350000
+ # experiment.training.max_steps=700000
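A pitfall worth noting with long backslash-continued commands like the ones above: a `\` continuation followed by a line that begins with `#` ends the command at the comment, so any override placed after the commented-out line is no longer passed to `python -m main` — it runs as a separate shell statement instead. A minimal demonstration with a hypothetical `printf` command, driven through bash from Python:

```python
import subprocess

# A '\' continuation followed by a '#' line: the comment terminates the
# printf command, so 'c=1' is NOT passed as an argument -- it runs as a
# separate shell statement (here, a plain variable assignment).
SCRIPT = r"""
printf 'arg=%s\n' a b \
  # this comment ends the printf command
c=1
echo "done, c=$c"
"""

out = subprocess.run(["bash", "-c", SCRIPT], capture_output=True, text=True).stdout
print(out, end="")
# arg=a
# arg=b
# done, c=1
```

Keeping commented-out overrides after the last active line (or outside the continued command entirely) avoids this.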
utils/distributed_utils.py CHANGED
@@ -1,3 +1,10 @@
- import wandb
+ import os
  
- is_rank_zero = wandb.run is not None
+ # Check standard environment variables for distributed training
+ # Default to True (rank 0) if not in a distributed environment
+ _rank = int(os.environ.get("RANK", 0))
+ _local_rank = int(os.environ.get("LOCAL_RANK", 0))
+ 
+ # We consider it rank zero if global rank is 0.
+ # Local rank check is usually redundant if rank is 0, but good for sanity.
+ is_rank_zero = _rank == 0
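The old check (`is_rank_zero = wandb.run is not None`) is only meaningful after `wandb.init()` has been called; the fix instead reads the standard `RANK` variable that launchers such as `torchrun` export for every worker. The rule can be exercised without any launcher — a standalone sketch mirroring the new module:

```python
def compute_is_rank_zero(env) -> bool:
    """Same rule as the new utils/distributed_utils.py: global RANK 0,
    or no distributed environment at all, counts as rank zero."""
    return int(env.get("RANK", 0)) == 0

print(compute_is_rank_zero({}))                                # → True  (single process)
print(compute_is_rank_zero({"RANK": "0", "LOCAL_RANK": "0"}))  # → True  (launcher rank 0)
print(compute_is_rank_zero({"RANK": "3", "LOCAL_RANK": "3"}))  # → False (worker rank)
```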