Reproduce Training & Fix distributed eval

Changed files:
- README.md +38 -173
- algorithms/worldmem/df_base.py +1 -0
- algorithms/worldmem/df_video.py +88 -60
- algorithms/worldmem/models/diffusion.py +1 -1
- configurations/experiment/base_pytorch.yaml +3 -1
- configurations/experiment/exp_video.yaml +1 -0
- configurations/training.yaml +1 -1
- datasets/video/base_video_dataset.py +1 -0
- datasets/video/minecraft_video_dataset.py +5 -5
- evaluate.sh +16 -4
- experiments/exp_base.py +3 -2
- infer.sh +10 -3
- main.py +4 -0
- requirements.txt +138 -27
- train_3stages.sh +80 -0
- train_stage_1.sh +9 -8
- train_stage_2.sh +30 -6
- train_stage_3.sh +8 -5
- utils/distributed_utils.py +9 -2
README.md
CHANGED

@@ -1,201 +1,66 @@
-<p align="center">
-<p align="center">
-  <img src="assets/worldmem_logo.png" alt="WORLDMEM Icon" width="80"/>
-</p>
-<h1 align="center"><strong>WorldMem: Long-term Consistent World Simulation <br> with Memory</strong></h1>
-<p align="center"><span><a href=""></a></span>
-  <a href="https://xizaoqu.github.io">Zeqi Xiao<sup>1</sup></a>
-  <a href="https://nirvanalan.github.io/">Yushi Lan<sup>1</sup></a>
-  <a href="https://zhouyifan.net/about/">Yifan Zhou<sup>1</sup></a>
-  <a href="https://vicky0522.github.io/Wenqi-Ouyang/">Wenqi Ouyang<sup>1</sup></a>
-  <a href="https://williamyang1991.github.io/">Shuai Yang<sup>2</sup></a>
-  <a href="https://zengyh1900.github.io/">Yanhong Zeng<sup>3</sup></a>
-  <a href="https://xingangpan.github.io/">Xingang Pan<sup>1</sup></a> <br>
-  <sup>1</sup>S-Lab, Nanyang Technological University, <br> <sup>2</sup>Wangxuan Institute of Computer Technology, Peking University,<br> <sup>3</sup>Shanghai AI Laboratory
-</p>
-</p>
-<p align="center">
-  <a href="https://arxiv.org/abs/2504.12369" target='_blank'>
-    <img src="https://img.shields.io/badge/arXiv-2504.12369-blue?">
-  </a>
-  <a href="https://xizaoqu.github.io/worldmem/" target='_blank'>
-    <img src="https://img.shields.io/badge/Project-🚀-blue">
-  </a>
-  <a href="https://huggingface.co/spaces/yslan/worldmem" target="_blank">
-    <img src="https://img.shields.io/badge/🤗 HuggingFace-Demo-orange" />
-  </a>
-</p>
-
-https://github.com/user-attachments/assets/fb8a32e2-9470-4819-a93d-c38caf76d72c
-
-## Installation
+# WorldMem
+
+Long-term consistent world simulation with memory.
+
+## Environment (conda)
+
+```bash
+conda create -n worldmem python=3.10
 conda activate worldmem
 pip install -r requirements.txt
 conda install -c conda-forge ffmpeg=4.3.2
 ```
-
-```
-python app.py
-```
-
-## Run
-
-To enable cloud logging with [Weights & Biases (wandb)](https://wandb.ai/site), follow these steps:
-
-1. Sign up for a wandb account.
-2. Run the following command to log in:
-
-```bash
-wandb login
-```
-
-3. Open `configurations/training.yaml` and set the `entity` and `project` fields to your wandb username.
-
----
-
-### Training
-
-Download pretrained weights from [Oasis](https://github.com/etched-ai/open-oasis).
-
-Trained on 4 H100 GPUs, the model converges after approximately 500K steps.
-We observe that gradually increasing task difficulty improves performance. Thus, we adopt a multi-stage training strategy:
-
-```bash
-sh train_stage_1.sh # Small range, no vertical turning
-sh train_stage_2.sh # Large range, no vertical turning
-sh train_stage_3.sh # Large range, with vertical turning
-```
-
-To resume training from a previous checkpoint, configure the `resume` and `output_dir` variables in the corresponding `.sh` script.
-
----
-
-### Inference
-
-To run inference:
-
-```bash
-sh infer.sh
-```
-
-You can either **load the diffusion model and VAE separately**:
-
-```bash
-+diffusion_model_path=zeqixiao/worldmem_checkpoints/diffusion_only.ckpt \
-+vae_path=zeqixiao/worldmem_checkpoints/vae_only.ckpt \
-+customized_load=true \
-+seperate_load=true \
-```
-
-Or **load a combined checkpoint**:
-
-```bash
-+load=your_model_path \
-+customized_load=true \
-+seperate_load=false \
-```
-
-### Evaluation
-
-To run evaluation:
-
-```bash
-sh evaluate.sh
-```
-
-This script reproduces the results in Table 1 (beyond context window). It reports PSNR and LPIPS. Evaluating 1 case on 1 A100 GPU takes approximately 6 minutes. You can adjust `experiment.test.limit_batch` to specify the number of cases to evaluate.
-
-Visual results are saved by default to a timestamped directory (e.g., `outputs/2025-11-30/00-02-42`).
-
-To calculate the FID score, run:
-
-```bash
-python calculate_fid.py --videos_dir <path_to_videos>
-```
-
-For example:
-
-```bash
-python calculate_fid.py --videos_dir outputs/2025-11-30/00-02-42/videos/test_vis
-```
-
-| Metric | Value  |
-|--------|--------|
-| PSNR   | 24.01  |
-| LPIPS  | 0.1667 |
-| FID    | 15.13  |
-
-*Note: FID is computed over 5000 frames.*
-
----
-
-## Dataset
-
-Download the Minecraft dataset from [Hugging Face](https://huggingface.co/datasets/zeqixiao/worldmem_minecraft_dataset).
-
-Place the dataset in the following directory structure:
-
+## Data preparation (data folder)
+
+1. Download the Minecraft dataset:
+   https://huggingface.co/datasets/zeqixiao/worldmem_minecraft_dataset
+2. Place it under `data/` with this structure:
+
-```
+```text
 data/
 └── minecraft/
     ├── training/
+    ├── validation/
     └── test/
 ```
-
-```bash
-```
-
-- `-o`: Output directory for generated data
-- `-z`: Number of parallel workers
-- `--env_type`: Environment type (e.g., `plains`)
-
-- [x] Release training pipeline on Minecraft;
-- [x] Release training data on Minecraft;
-- [x] Release evaluation scripts and data generator.
-
-```
-@misc{xiao2025worldmemlongtermconsistentworld,
-      title={WORLDMEM: Long-term Consistent World Simulation with Memory},
-      author={Zeqi Xiao and Yushi Lan and Yifan Zhou and Wenqi Ouyang and Shuai Yang and Yanhong Zeng and Xingang Pan},
-      year={2025},
-      eprint={2504.12369},
-      archivePrefix={arXiv},
-      primaryClass={cs.CV},
-      url={https://arxiv.org/abs/2504.12369},
-}
-```
-
-## 👏 Acknowledgements
-- [Diffusion Forcing](https://github.com/buoyancy99/diffusion-forcing): Diffusion Forcing provides flexible training and inference strategies for our methods.
-- [Minedojo](https://github.com/MineDojo/MineDojo): We collect our Minecraft dataset from Minedojo.
-- [Open-oasis](https://github.com/etched-ai/open-oasis): Our model architecture is based on Open-oasis. We also use pretrained VAE and DiT weights from it.
+
+The training and evaluation scripts expect the dataset to live at `data/minecraft` by default.
+
+## Training
+
+Run a single stage:
+
+```bash
+sh train_stage_1.sh
+sh train_stage_2.sh
+sh train_stage_3.sh
+```
+
+Run all stages:
+
+```bash
+sh train_3stages.sh
+```
+
+The stage scripts include dataset and checkpoint paths. Update those paths or override them on the CLI to match your local setup.
+
+## Training config (exp_video.yaml)
+
+Defaults live in `configurations/experiment/exp_video.yaml`.
+
+Common fields to edit:
+- `training.lr`
+- `training.precision`
+- `training.batch_size`
+- `training.max_steps`
+- `training.checkpointing.every_n_train_steps`
+- `validation.val_every_n_step`
+- `validation.batch_size`
+- `test.batch_size`
+
+You can also override values from the CLI used in the scripts:
+
+```bash
+python -m main +name=train experiment.training.batch_size=8 experiment.training.max_steps=100000
+```
algorithms/worldmem/df_base.py
CHANGED

@@ -33,6 +33,7 @@ class DiffusionForcingBase(BasePytorchAlgo):
         self.action_cond_dim = cfg.action_cond_dim
         self.causal = cfg.causal

+
         self.uncertainty_scale = cfg.uncertainty_scale
         self.timesteps = cfg.diffusion.timesteps
         self.sampling_timesteps = cfg.diffusion.sampling_timesteps
algorithms/worldmem/df_video.py
CHANGED

@@ -3,6 +3,7 @@ import random
 import math
 import numpy as np
 import torch
+import torch.distributed as dist
 import torch.nn.functional as F
 import torchvision.transforms.functional as TF
 from torchvision.transforms import InterpolationMode

@@ -21,6 +22,7 @@ from .models.vae import VAE_models
 from .models.diffusion import Diffusion
 from .models.pose_prediction import PosePredictionNet
 import glob
+import wandb

 # Utility Functions
 def euler_to_rotation_matrix(pitch, yaw):

@@ -376,7 +378,8 @@ class WorldMemMinecraft(DiffusionForcingBase):
             ref_mode=self.ref_mode
         )

+        # Avoid distributed sync inside torchmetrics; reduce metrics manually across ranks.
+        self.validation_lpips_model = LearnedPerceptualImagePatchSimilarity(sync_on_compute=False)
         vae = VAE_models["vit-l-20-shallow-encoder"]()
         self.vae = vae.eval()

@@ -430,13 +433,13 @@ class WorldMemMinecraft(DiffusionForcingBase):
                     focal_length=self.focal_length,
                     image_height=xs.shape[-2], image_width=xs.shape[-1]
                 ).to(xs.dtype)
-            )
+            )  # [(1 + memory_condition_length), B, H, W, 6]
             frame_idx_list.append(
                 torch.cat([
                     frame_idx[i:i + 1] - frame_idx[i:i + 1],
                     frame_idx[-self.memory_condition_length:] - frame_idx[i:i + 1]
                 ]).clone()
-            )
+            )  # [(1 + memory_condition_length), B]; 0 for the current frame, memory frames use indices relative to the current frame
             input_pose_condition = torch.cat(input_pose_condition)
             frame_idx_list = torch.cat(frame_idx_list)
         else:

@@ -476,66 +479,78 @@ class WorldMemMinecraft(DiffusionForcingBase):
         return {"loss": loss}

     def on_validation_epoch_end(self, namespace="validation") -> None:
-        if not self.validation_step_outputs:
+        if not hasattr(self, "_metric_device"):
             return
-
-        xs_pred = []
-        xs = []
-        for pred, gt in self.validation_step_outputs:
-            xs_pred.append(pred)
-            xs.append(gt)
-
-        xs_pred = torch.cat(xs_pred, 1)
-        if gt is not None:
-            xs = torch.cat(xs, 1)
-        else:
-            xs = None
-
-            lpips_model=self.validation_lpips_model,
-            lpips_batch_size=self.lpips_batch_size)
-
-        self.log_dict(
-            {"mse": metric_dict['mse'],
-             "psnr": metric_dict['psnr'],
-             "lpips": metric_dict['lpips']},
-            sync_dist=True
-        )
-
-        if self.log_curve:
-            psnr_values = metric_dict['frame_wise_psnr'].cpu().tolist()
-            frames = list(range(len(psnr_values)))
-            line_plot = wandb.plot.line_series(
-                xs=frames,
-                ys=[psnr_values],
-                keys=["PSNR"],
-                title="Frame-wise PSNR",
-                xname="Frame index",
-            )
-            self.logger.experiment.log({"frame_wise_psnr_plot": line_plot})
-
+
+        if dist.is_available() and dist.is_initialized():
+            for tensor in (
+                self._mse_sum,
+                self._mse_count,
+                self._psnr_sum,
+                self._psnr_count,
+                self._lpips_sum,
+                self._lpips_count,
+            ):
+                dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
+
+        mse = self._mse_sum / self._mse_count.clamp_min(1.0)
+        psnr = self._psnr_sum / self._psnr_count.clamp_min(1.0)
+        lpips = self._lpips_sum / self._lpips_count.clamp_min(1.0)
+
+        if self.trainer is None or self.trainer.is_global_zero:
+            if self._mse_count.item() > 0:
+                self.log_dict(
+                    {"mse": mse, "psnr": psnr, "lpips": lpips},
+                    sync_dist=False,
+                )

         self.validation_step_outputs.clear()

+    def on_validation_epoch_start(self) -> None:
+        self._reset_metric_accumulators()
+
+    def on_test_epoch_start(self) -> None:
+        self._reset_metric_accumulators()
+
+    def _reset_metric_accumulators(self) -> None:
+        self._metric_device = next(self.validation_lpips_model.parameters()).device
+        self._mse_sum = torch.tensor(0.0, device=self._metric_device)
+        self._mse_count = torch.tensor(0.0, device=self._metric_device)
+        self._psnr_sum = torch.tensor(0.0, device=self._metric_device)
+        self._psnr_count = torch.tensor(0.0, device=self._metric_device)
+        self._lpips_sum = torch.tensor(0.0, device=self._metric_device)
+        self._lpips_count = torch.tensor(0.0, device=self._metric_device)
+
+    def _update_metric_accumulators(self, xs_pred: torch.Tensor, xs_gt: torch.Tensor) -> None:
+        xs_pred_device = xs_pred.to(self._metric_device)
+        xs_device = xs_gt.to(self._metric_device)
+
+        metric_dict = get_validation_metrics_for_videos(
+            xs_pred_device,
+            xs_device,
+            lpips_model=self.validation_lpips_model,
+            lpips_batch_size=self.lpips_batch_size,
+        )
+
+        mse_val = metric_dict["mse"].detach()
+        psnr_val = metric_dict["psnr"].detach()
+        lpips_val = torch.tensor(metric_dict["lpips"], device=self._metric_device)
+
+        mse_count_batch = torch.tensor(float(xs_pred_device.numel()), device=self._metric_device)
+        psnr_count_batch = torch.tensor(float(xs_pred_device.shape[1]), device=self._metric_device)
+        lpips_count_batch = torch.tensor(
+            float(xs_pred_device.shape[0] * xs_pred_device.shape[1]), device=self._metric_device
+        )
+
+        self._mse_sum += mse_val * mse_count_batch
+        self._psnr_sum += psnr_val * psnr_count_batch
+        self._lpips_sum += lpips_val * lpips_count_batch
+        self._mse_count += mse_count_batch
+        self._psnr_count += psnr_count_batch
+        self._lpips_count += lpips_count_batch
+
+        del xs_pred_device, xs_device
+
     def _preprocess_batch(self, batch):

         xs, conditions, pose_conditions, frame_index = batch

@@ -554,7 +569,7 @@ class WorldMemMinecraft(DiffusionForcingBase):
         return xs, conditions, pose_conditions, c2w_mat, frame_index

     def encode(self, x):
-        # vae encoding
+        # vae encoding x with shape (t b c h w)
         T = x.shape[0]
         H, W = x.shape[-2:]
         scaling_factor = 0.07843137255

@@ -783,8 +798,21 @@ class WorldMemMinecraft(DiffusionForcingBase):
         xs_pred = self.decode(xs_pred[n_context_frames:].to(conditions.device))
         xs_decode = self.decode(xs[n_context_frames:].to(conditions.device))

+        # Save videos for every batch (rank is encoded in filenames).
+        if self.logger and self.log_video:
+            log_video(
+                xs_pred,
+                xs_decode,
+                step=batch_idx,
+                namespace=namespace + "_vis",
+                context_frames=self.context_frames,
+                logger=self.logger.experiment,
+                save_local=self.save_local,
+                local_save_dir=self.local_save_dir,
+            )
+
+        # Stream metrics to avoid holding all outputs in memory.
+        self._update_metric_accumulators(xs_pred, xs_decode)
         return

     @torch.no_grad()
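The new accumulator methods above replace torchmetrics' implicit cross-rank sync with explicit count-weighted sums that can later be combined by a plain SUM all-reduce. The arithmetic behind that pattern can be sketched without torch; `StreamingMean` is an illustrative name, not a class from the repo:

```python
class StreamingMean:
    """Count-weighted streaming mean, mirroring the _mse_sum/_mse_count pattern."""

    def __init__(self):
        self.total = 0.0  # weighted sum of per-batch means
        self.count = 0.0  # number of contributing samples

    def update(self, batch_mean, batch_count):
        # Store the weighted sum so partial results from several ranks
        # can later be combined with a plain SUM reduction.
        self.total += batch_mean * batch_count
        self.count += batch_count

    def merge(self, other):
        # Equivalent of dist.all_reduce(..., op=dist.ReduceOp.SUM) across two ranks.
        self.total += other.total
        self.count += other.count

    def compute(self):
        # clamp_min(1.0) in the diff plays the same role as max(count, 1.0) here.
        return self.total / max(self.count, 1.0)
```

Merging two accumulators and then computing gives the same result as pooling all samples on one rank, which is why a SUM all-reduce over (sum, count) pairs is safe even with uneven per-rank batch counts.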
algorithms/worldmem/models/diffusion.py
CHANGED

@@ -169,7 +169,7 @@ class Diffusion(nn.Module):
                                   mode=mode, reference_length=reference_length, frame_idx=frame_idx)
         model_output = model_output.permute(1,0,2,3,4)
         x = x.permute(1,0,2,3,4)
-        t = t.permute(1,0)
+        t = t.permute(1,0)

         if self.objective == "pred_noise":
             pred_noise = torch.clamp(model_output, -self.clip_noise, self.clip_noise)
configurations/experiment/base_pytorch.yaml
CHANGED

@@ -35,7 +35,8 @@ validation:
   val_every_n_step: 2000 # controls how frequently we run validation; can be float (fraction of epochs), int (steps), or null (if val_every_n_epoch is set)
   val_every_n_epoch: null # if you want to run validation every n epochs, requires val_every_n_step to be null.
   limit_batch: null # if null, run through the validation set. Otherwise limit the number of batches used for validation.
-  inference_mode: True # whether to run validation in inference mode (enable_grad won't work!)
+  # inference_mode: True # whether to run validation in inference mode (enable_grad won't work!)
+  inference_mode: False # whether to run validation in inference mode (enable_grad won't work!)
   data:
     num_workers: 4 # number of CPU threads for data preprocessing, for validation.
     shuffle: False # whether validation data will be shuffled

@@ -45,6 +46,7 @@ test:
   compile: False # whether to compile the model with torch.compile
   batch_size: 4 # test batch size per GPU; effective batch size is this number * gpus * nodes iff using distributed training
   limit_batch: null # if null, run through the test set. Otherwise limit the number of batches used for test.
+  inference_mode: False # whether to run test in inference mode (enable_grad won't work!)
   data:
     num_workers: 4 # number of CPU threads for data preprocessing, for test.
     shuffle: False # whether test data will be shuffled
configurations/experiment/exp_video.yaml
CHANGED

@@ -7,6 +7,7 @@ training:
   lr: 2e-5
   precision: 16-mixed
   batch_size: 4
+  # batch_size: 8
   max_epochs: -1
   max_steps: 2000005
   checkpointing:
configurations/training.yaml
CHANGED

@@ -8,7 +8,7 @@ defaults:
 debug: false # global debug flag will be passed into configuration of experiment, dataset and algorithm

 wandb:
-  entity:
+  entity: turlin # wandb account name / organization name [fixme]
   project: worldmem # wandb project name; if not provided, defaults to root folder name [fixme]
   mode: online # set wandb logging to online, offline or dryrun
datasets/video/base_video_dataset.py
CHANGED

@@ -47,6 +47,7 @@ class BaseVideoDataset(torch.utils.data.Dataset, ABC):
         self.clips_per_video = np.clip(np.array(self.metadata) - self.n_frames + 1, a_min=1, a_max=None).astype(
             np.int32
         )
+
         self.cum_clips_per_video = np.cumsum(self.clips_per_video)
         self.transform = transforms.Resize((self.resolution, self.resolution), antialias=True)
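The `clips_per_video`/`cum_clips_per_video` lines above give each video `max(length - n_frames + 1, 1)` clips and build a cumulative index over them. How a flat clip index could map back to a (video, offset) pair under that scheme can be sketched in pure Python; `build_clip_index` and `locate_clip` are hypothetical helpers, not functions from the repo:

```python
import bisect

def build_clip_index(video_lengths, n_frames):
    # Each video contributes max(length - n_frames + 1, 1) clips,
    # matching the np.clip(..., a_min=1) call above.
    clips_per_video = [max(length - n_frames + 1, 1) for length in video_lengths]
    cum, total = [], 0
    for c in clips_per_video:
        total += c
        cum.append(total)  # cumulative clip counts, like np.cumsum
    return cum

def locate_clip(cum, idx):
    # Binary-search the cumulative counts to find the owning video,
    # then subtract its start to get the clip offset inside that video.
    video_id = bisect.bisect_right(cum, idx)
    start = cum[video_id - 1] if video_id > 0 else 0
    return video_id, idx - start
```

For example, lengths `[10, 5, 3]` with `n_frames=4` yield clip counts `[7, 2, 1]`, so flat index 7 is the first clip of the second video.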
datasets/video/minecraft_video_dataset.py
CHANGED

@@ -126,7 +126,7 @@ class MinecraftVideoDataset(BaseVideoDataset):
         try:
             return self.load_data(idx)
         except Exception as e:
+            print(f"Retrying due to error: {e}")
             idx = (idx + 1) % len(self)

     def load_data(self, idx):

@@ -184,9 +184,9 @@ class MinecraftVideoDataset(BaseVideoDataset):
         dis = np.abs(poses[:, None] - poses_pool[None, :])
         dis[..., 3:][dis[..., 3:] > 180] = 360 - dis[..., 3:][dis[..., 3:] > 180]

-        spatial_match = (dis[..., :3] <= self.pos_range).sum(-1) >= 3
-        angular_match = (dis[..., 3:] <= self.angle_range).sum(-1) >= 2
-        not_exact_match = ((dis[..., :3] > 0).sum(-1) >= 1) | ((dis[..., 3:] > 0).sum(-1) >= 1)
+        spatial_match = (dis[..., :3] <= self.pos_range).sum(-1) >= 3  # X, Y, Z axes all within range
+        angular_match = (dis[..., 3:] <= self.angle_range).sum(-1) >= 2  # pitch and yaw both within range
+        not_exact_match = ((dis[..., :3] > 0).sum(-1) >= 1) | ((dis[..., 3:] > 0).sum(-1) >= 1)  # at least one axis differs, i.e. not the same pose

         valid_index = (spatial_match & angular_match & not_exact_match).sum(0)
         valid_index[:100] = 0 # skip unstable early frames

@@ -237,7 +237,7 @@ class MinecraftVideoDataset(BaseVideoDataset):
         timestamp = np.arange(self.n_frames)

         # === 7. Convert video to torch format ===
-        video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous()
+        video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous()  # (T, H, W, C) -> (T, C, H, W)
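The pose-matching lines above compare positions and angles against a pool, folding angular differences larger than 180° back through 360° before thresholding. The same test can be sketched standalone in pure Python, with illustrative thresholds and hypothetical helper names (the repo's version is vectorized with numpy):

```python
def angular_distance(a, b):
    # Fold the difference into [0, 180]: 10 deg and 350 deg are 20 deg apart,
    # mirroring dis[dis > 180] = 360 - dis[dis > 180] above.
    d = abs(a - b) % 360.0
    return 360.0 - d if d > 180.0 else d

def is_memory_match(pose, candidate, pos_range=2.0, angle_range=30.0):
    # pose = (x, y, z, pitch, yaw); combines spatial_match, angular_match,
    # and not_exact_match from the diff into one scalar check.
    spatial = all(abs(p - c) <= pos_range for p, c in zip(pose[:3], candidate[:3]))
    angular = all(angular_distance(p, c) <= angle_range for p, c in zip(pose[3:], candidate[3:]))
    not_exact = pose != candidate  # reject the frame itself
    return spatial and angular and not_exact
```

The wraparound step matters for yaw: without it, a camera at 350° would look 340° away from one at 10°, and near-identical viewpoints across the 0°/360° seam would never be selected as memory frames.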
evaluate.sh
CHANGED

@@ -1,15 +1,27 @@
 export PYTHONWARNINGS="ignore"
+export CUDA_VISIBLE_DEVICES=4,5,6,7
+
+# export NCCL_DEBUG=INFO
+# export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
+# export TORCH_DISTRIBUTED_DEBUG=DETAIL
+# export NCCL_DEBUG_SUBSYS=COLL
+# # Optional but very helpful while debugging (slower):
+# export TORCH_NCCL_BLOCKING_WAIT=1
+export NCCL_TIMEOUT=7200
+export NCCL_P2P_DISABLE=1
+export HYDRA_FULL_ERROR=1
+
 wandb offline
 python -m main +name=infer \
     experiment.tasks=[test] \
     dataset.validation_multiplier=1 \
     +dataset.seed=42 \
-    +diffusion_model_path=
-    +vae_path=
+    +diffusion_model_path=/share_1/users/bonan_ding/worldmem_ckpt/diffusion_only.ckpt \
+    +vae_path=/share_1/users/bonan_ding/worldmem_ckpt/vae_only.ckpt \
     +customized_load=true \
     +seperate_load=true \
     dataset.n_frames=8 \
-    dataset.save_dir=
+    dataset.save_dir=/share_1/users/bonan_ding/worldmem_data/minecraft \
    +dataset.n_frames_valid=700 \
    algorithm.diffusion.sampling_timesteps=20 \
    +algorithm.memory_condition_length=8 \

@@ -20,4 +32,4 @@ python -m main +name=infer \
    +algorithm.n_tokens=8 \
    algorithm.context_frames=600 \
    experiment.test.batch_size=1 \
-    experiment.test.limit_batch=
+    experiment.test.limit_batch=160 \
experiments/exp_base.py
CHANGED
@@ -9,7 +9,7 @@ from abc import ABC, abstractmethod
 from typing import Optional, Union, Literal, List, Dict
 import pathlib
 import os
-
+from datetime import timedelta
 import hydra
 import torch
 from lightning.pytorch.strategies.ddp import DDPStrategy
@@ -415,10 +415,11 @@ class BaseLightningExperiment(BaseExperiment):
             logger=self.logger,
             devices="auto",
             num_nodes=self.cfg.num_nodes,
-            strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
+            strategy=DDPStrategy(find_unused_parameters=False, timeout=timedelta(hours=1)) if torch.cuda.device_count() > 1 else "auto",
             callbacks=callbacks,
             limit_test_batches=self.cfg.test.limit_batch,
             precision=self.cfg.test.precision,
+            inference_mode=self.cfg.test.inference_mode,
             detect_anomaly=False,  # self.cfg.debug,
         )
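For context on the `timeout=timedelta(hours=1)` passed to `DDPStrategy` in this diff: Lightning forwards it to process-group initialization, so collectives that stall longer than the timeout raise instead of hanging; PyTorch's default process-group timeout is 30 minutes, which a long distributed eval rollout can exceed. A minimal stdlib sketch of the value (no Lightning required):

```python
from datetime import timedelta

# The DDP collective timeout used in this commit: one hour instead of
# torch.distributed's 30-minute default, to survive long eval rollouts.
ddp_timeout = timedelta(hours=1)
print(int(ddp_timeout.total_seconds()))
```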
infer.sh
CHANGED
@@ -1,14 +1,21 @@
 export PYTHONWARNINGS="ignore"
+export NCCL_DEBUG=INFO
+export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
+export TORCH_DISTRIBUTED_DEBUG=DETAIL
+export NCCL_DEBUG_SUBSYS=COLL
+# Optional but very helpful while debugging (slower):
+export TORCH_NCCL_BLOCKING_WAIT=1
+export NCCL_P2P_DISABLE=1
 wandb offline
 python -m main +name=infer \
 experiment.tasks=[validation] \
 dataset.validation_multiplier=1 \
-+diffusion_model_path=
-+vae_path=
++diffusion_model_path=/share_1/users/bonan_ding/worldmem_ckpt/diffusion_only.ckpt \
++vae_path=/share_1/users/bonan_ding/worldmem_ckpt/vae_only.ckpt \
 +customized_load=true \
 +seperate_load=true \
 dataset.n_frames=8 \
-dataset.save_dir=
+dataset.save_dir=/share_1/users/bonan_ding/worldmem_data/minecraft \
 +dataset.n_frames_valid=700 \
 +dataset.memory_condition_length=8 \
 +dataset.customized_validation=true \
CHANGED
|
@@ -59,6 +59,10 @@ def run_local(cfg: DictConfig):
|
|
| 59 |
OmegaConf.set_readonly(hydra_cfg, True)
|
| 60 |
|
| 61 |
output_dir = Path(hydra_cfg.runtime.output_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
if is_rank_zero:
|
| 64 |
print(cyan(f"Outputs will be saved to:"), output_dir)
|
|
|
|
| 59 |
OmegaConf.set_readonly(hydra_cfg, True)
|
| 60 |
|
| 61 |
output_dir = Path(hydra_cfg.runtime.output_dir)
|
| 62 |
+
if not output_dir.exists():
|
| 63 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 64 |
+
if is_rank_zero:
|
| 65 |
+
print(cyan(f"Created output directory: {output_dir}"))
|
| 66 |
|
| 67 |
if is_rank_zero:
|
| 68 |
print(cyan(f"Outputs will be saved to:"), output_dir)
|
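The added guard around `output_dir` is idempotent; a standalone sketch of the same `pathlib` pattern (the temporary path is illustrative, not the repo's layout):

```python
import tempfile
from pathlib import Path

# parents=True creates missing intermediate directories;
# exist_ok=True turns repeat calls into no-ops instead of FileExistsError.
base = Path(tempfile.mkdtemp())
output_dir = base / "outputs" / "run_0"
output_dir.mkdir(parents=True, exist_ok=True)
output_dir.mkdir(parents=True, exist_ok=True)  # safe to call again
print(output_dir.is_dir())
```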
requirements.txt
CHANGED
@@ -1,28 +1,139 @@
+aiofiles==23.2.1
+aiohappyeyeballs==2.6.1
+aiohttp==3.13.3
+aiosignal==1.4.0
+altair==5.5.0
+annotated-doc==0.0.4
+antlr4-python3-runtime==4.9.3
+anyio==4.12.1
+async-timeout==5.0.1
+attrs==25.4.0
+av==16.1.0
+certifi==2026.1.4
+charset-normalizer==3.4.4
+click==8.3.1
+colorama==0.4.6
+colorlog==6.10.1
+contourpy==1.3.2
+cycler==0.12.1
+decorator==4.4.2
+diffusers==0.36.0
+docker-pycreds==0.4.0
+einops==0.8.1
+exceptiongroup==1.3.1
+fastapi==0.125.0
+ffmpy==1.0.0
+filelock==3.20.3
+fonttools==4.61.1
+frozenlist==1.8.0
+fsspec==2024.12.0
+fvcore==0.1.5.post20221221
+gitdb==4.0.12
+GitPython==3.1.46
+gluonts==0.13.1
+gradio==3.50.2
+gradio_client==0.6.1
+h11==0.16.0
+h5py==3.15.1
+hf-xet==1.2.0
+httpcore==1.0.9
+httpx==0.28.1
+huggingface_hub==1.3.2
+hydra-core==1.3.2
+idna==3.11
+ImageIO==2.37.2
+imageio-ffmpeg==0.6.0
+importlib_metadata==8.7.1
+importlib_resources==6.5.2
+internetarchive==5.7.1
+iopath==0.1.10
+Jinja2==3.1.6
+jsonpatch==1.33
+jsonpointer==3.0.0
+jsonschema==4.26.0
+jsonschema-specifications==2025.9.1
+kiwisolver==1.4.9
+lightning==2.1.4
+lightning-utilities==0.15.2
+lpips==0.1.4
+MarkupSafe==2.1.5
+matplotlib==3.10.8
 moviepy==1.0.3
+mpmath==1.3.0
+multidict==6.7.0
+narwhals==2.15.0
+networkx==3.4.2
+numpy==1.26.4
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu12==2.20.5
+nvidia-nvjitlink-cu12==12.9.86
+nvidia-nvtx-cu12==12.1.105
+omegaconf==2.3.0
+opencv-python==4.11.0.86
+orjson==3.11.5
+packaging==24.2
+pandas==2.3.3
+parameterized==0.9.0
+pillow==10.4.0
+platformdirs==4.5.1
+portalocker==3.2.0
+proglog==0.1.12
+propcache==0.4.1
+protobuf==3.19.6
+psutil==5.9.8
+pydantic==1.10.26
+pydub==0.25.1
+pyparsing==3.3.1
+pyrealsense2==2.56.5.9235
+python-dateutil==2.9.0.post0
+python-multipart==0.0.21
+pytorch-lightning==2.6.0
+pytorchvideo==0.1.5
+pytz==2025.2
+PyYAML==6.0.3
+pyzmq==27.1.0
+referencing==0.37.0
+regex==2026.1.15
+requests==2.32.5
+rotary-embedding-torch==0.8.9
+rpds-py==0.30.0
+safetensors==0.7.0
+scipy==1.15.3
+semantic-version==2.10.0
+sentry-sdk==2.49.0
+setproctitle==1.3.7
+shellingham==1.5.4
+six==1.17.0
+smmap==5.0.2
+spaces==0.46.0
+starlette==0.50.0
+sympy==1.14.0
+tabulate==0.9.0
+termcolor==3.3.0
+timm==1.0.24
+toolz==0.12.1
+torch==2.4.1
+torch-fidelity==0.3.0
+torchmetrics==0.11.4
+torchvision==0.19.1
+tqdm==4.67.1
+triton==3.0.0
+typer-slim==0.21.1
+typing_extensions==4.15.0
+tzdata==2025.3
+urllib3==2.6.3
+uvicorn==0.40.0
+wandb==0.17.9
+wandb_osh==1.2.1
+websockets==11.0.3
+yacs==0.1.8
+yarl==1.22.0
+zipp==3.23.0
train_3stages.sh
ADDED
@@ -0,0 +1,80 @@
+wandb enabled
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+export NCCL_P2P_DISABLE=1
+# export HYDRA_FULL_ERROR=1
+
+set -e  # Exit on any error
+set -o pipefail  # Exit on pipe failures
+
+#Stage 1
+python -m main +name=train \
++diffusion_model_path=/share_1/users/bonan_ding/worldmem_ckpt/diffusion_only.ckpt \
++vae_path=/share_1/users/bonan_ding/worldmem_ckpt/vae_only.ckpt \
++customized_load=true \
++seperate_load=true \
++zero_init_gate=true \
+dataset.n_frames=8 \
+dataset.save_dir=/share_1/users/bonan_ding/worldmem_data/minecraft \
++dataset.n_frames_valid=700 \
++dataset.angle_range=110 \
++dataset.pos_range=2 \
++dataset.memory_condition_length=8 \
++dataset.customized_validation=true \
++dataset.add_timestamp_embedding=true \
++dataset.wo_updown=true \
++algorithm.n_tokens=8 \
++algorithm.memory_condition_length=8 \
+algorithm.context_frames=600 \
++algorithm.relative_embedding=true \
++algorithm.log_video=true \
++algorithm.add_timestamp_embedding=true \
++algorithm.metrics=[lpips,psnr] \
+experiment.training.checkpointing.every_n_train_steps=2500 \
+experiment.training.max_steps=120000 \
++output_dir=/share_1/users/bonan_ding/worldmem_ckpt/reproduce_official_set \
+
+#Stage 2
+python -m main +name=train \
+dataset.n_frames=8 \
+dataset.save_dir=data/minecraft \
++dataset.n_frames_valid=700 \
++dataset.angle_range=110 \
++dataset.pos_range=8 \
++dataset.memory_condition_length=8 \
++dataset.customized_validation=true \
++dataset.add_timestamp_embedding=true \
++dataset.wo_updown=true \
++algorithm.n_tokens=8 \
++algorithm.memory_condition_length=8 \
+algorithm.context_frames=600 \
++algorithm.relative_embedding=true \
++algorithm.log_video=true \
++algorithm.add_timestamp_embedding=true \
++algorithm.metrics=[lpips,psnr] \
+experiment.training.checkpointing.every_n_train_steps=2500 \
+resume=ot7jqmgn \
++output_dir=/share_1/users/bonan_ding/worldmem_ckpt/reproduce_official_set \
+experiment.training.max_steps=240000
+
+#Stage 3
+python -m main +name=train \
+dataset.n_frames=8 \
+dataset.save_dir=data/minecraft \
++dataset.n_frames_valid=700 \
++dataset.angle_range=110 \
++dataset.pos_range=8 \
++dataset.memory_condition_length=8 \
++dataset.customized_validation=true \
++dataset.add_timestamp_embedding=true \
++dataset.wo_updown=false \
++algorithm.n_tokens=8 \
++algorithm.memory_condition_length=8 \
+algorithm.context_frames=600 \
++algorithm.relative_embedding=true \
++algorithm.log_video=true \
++algorithm.add_timestamp_embedding=true \
++algorithm.metrics=[lpips,psnr] \
+experiment.training.checkpointing.every_n_train_steps=2500 \
+resume=ot7jqmgn \
++output_dir=/share_1/users/bonan_ding/worldmem_ckpt/reproduce_official_set \
+experiment.training.max_steps=700000
train_stage_1.sh
CHANGED
@@ -1,14 +1,16 @@
 wandb enabled
-
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+export NCCL_P2P_DISABLE=1
+# export HYDRA_FULL_ERROR=1
 # set -e
 python -m main +name=train \
-+diffusion_model_path=
-+vae_path=
++diffusion_model_path=/share_1/users/bonan_ding/worldmem_ckpt/diffusion_only.ckpt \
++vae_path=/share_1/users/bonan_ding/worldmem_ckpt/vae_only.ckpt \
 +customized_load=true \
 +seperate_load=true \
 +zero_init_gate=true \
 dataset.n_frames=8 \
-dataset.save_dir=
+dataset.save_dir=/share_1/users/bonan_ding/worldmem_data/minecraft \
 +dataset.n_frames_valid=700 \
 +dataset.angle_range=110 \
 +dataset.pos_range=2 \
@@ -22,8 +24,7 @@ python -m main +name=train \
 +algorithm.relative_embedding=true \
 +algorithm.log_video=true \
 +algorithm.add_timestamp_embedding=true \
-algorithm.metrics=[lpips,psnr] \
++algorithm.metrics=[lpips,psnr] \
 experiment.training.checkpointing.every_n_train_steps=2500 \
-experiment.training.max_steps=120000
-
-
+experiment.training.max_steps=120000 \
++output_dir=/share_1/users/bonan_ding/worldmem_ckpt/reproduce_official_set \
train_stage_2.sh
CHANGED
@@ -1,9 +1,11 @@
 wandb enabled
-
-
+export CUDA_VISIBLE_DEVICES=4,5,6,7
+export NCCL_P2P_DISABLE=1
+# export HYDRA_FULL_ERROR=1
+set -e
 python -m main +name=train \
 dataset.n_frames=8 \
-dataset.save_dir=
+dataset.save_dir=/share_1/users/bonan_ding/worldmem_data/minecraft \
 +dataset.n_frames_valid=700 \
 +dataset.angle_range=110 \
 +dataset.pos_range=8 \
@@ -17,9 +19,31 @@ python -m main +name=train \
 +algorithm.relative_embedding=true \
 +algorithm.log_video=true \
 +algorithm.add_timestamp_embedding=true \
-algorithm.metrics=[lpips,psnr] \
++algorithm.metrics=[lpips,psnr] \
 experiment.training.checkpointing.every_n_train_steps=2500 \
-resume=
-+output_dir=
+resume=ot7jqmgn \
++output_dir=/share_1/users/bonan_ding/worldmem_ckpt/reproduce_official_set \
 experiment.training.max_steps=240000
 
+#Stage 3
+python -m main +name=train \
+dataset.n_frames=8 \
+dataset.save_dir=/share_1/users/bonan_ding/worldmem_data/minecraft \
++dataset.n_frames_valid=700 \
++dataset.angle_range=110 \
++dataset.pos_range=8 \
++dataset.memory_condition_length=8 \
++dataset.customized_validation=true \
++dataset.add_timestamp_embedding=true \
++dataset.wo_updown=false \
++algorithm.n_tokens=8 \
++algorithm.memory_condition_length=8 \
+algorithm.context_frames=600 \
++algorithm.relative_embedding=true \
++algorithm.log_video=true \
++algorithm.add_timestamp_embedding=true \
++algorithm.metrics=[lpips,psnr] \
+experiment.training.checkpointing.every_n_train_steps=2500 \
+resume=ot7jqmgn \
++output_dir=/share_1/users/bonan_ding/worldmem_ckpt/reproduce_official_set \
+experiment.training.max_steps=700000
train_stage_3.sh
CHANGED
@@ -1,5 +1,7 @@
 wandb enabled
-
+export CUDA_VISIBLE_DEVICES=4,5
+export NCCL_P2P_DISABLE=1
+# export HYDRA_FULL_ERROR=1
 # set -e
 python -m main +name=train \
 dataset.n_frames=8 \
@@ -17,8 +19,9 @@ python -m main +name=train \
 +algorithm.relative_embedding=true \
 +algorithm.log_video=true \
 +algorithm.add_timestamp_embedding=true \
-algorithm.metrics=[lpips,psnr] \
++algorithm.metrics=[lpips,psnr] \
 experiment.training.checkpointing.every_n_train_steps=2500 \
-resume=
-+output_dir=
-experiment.training.max_steps=700000
+resume=qyyk38nw \
++output_dir=/share_1/users/bonan_ding/worldmem_ckpt/reproduce_1 \
+# experiment.training.max_steps=700000
+experiment.training.max_steps=350000
utils/distributed_utils.py
CHANGED
@@ -1,3 +1,10 @@
-import
+import os
 
-
+# Check the standard environment variables for distributed training.
+# Default to rank 0 if not running in a distributed environment.
+_rank = int(os.environ.get("RANK", 0))
+_local_rank = int(os.environ.get("LOCAL_RANK", 0))
+
+# We consider this process rank zero if the global rank is 0.
+# A local-rank check is usually redundant when rank is 0, but good for sanity.
+is_rank_zero = _rank == 0
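This module relies on the launcher convention that `torchrun` (and similar launchers) export `RANK`/`LOCAL_RANK` per process, while a plain single-process run exports neither, so the fallback to 0 treats non-distributed runs as rank zero. A small self-contained sketch of that convention (the helper name is illustrative, not part of the repo):

```python
import os

def is_rank_zero_env(env=None) -> bool:
    # Launchers like torchrun export RANK per process; without a launcher
    # the variable is absent and we fall back to global rank 0.
    env = os.environ if env is None else env
    return int(env.get("RANK", 0)) == 0

print(is_rank_zero_env({}))             # no launcher -> rank zero
print(is_rank_zero_env({"RANK": "2"}))  # worker rank 2 -> not rank zero
```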
|