update lfs
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +5 -0
- LICENSE.md +14 -0
- README.md +201 -0
- algorithms/README.md +21 -0
- algorithms/__init__.py +0 -0
- algorithms/common/README.md +5 -0
- algorithms/common/__init__.py +0 -0
- algorithms/common/base_algo.py +21 -0
- algorithms/common/base_pytorch_algo.py +252 -0
- algorithms/common/metrics/__init__.py +3 -0
- algorithms/common/metrics/fid.py +1 -0
- algorithms/common/metrics/fvd.py +158 -0
- algorithms/common/metrics/lpips.py +1 -0
- algorithms/common/models/__init__.py +0 -0
- algorithms/common/models/cnn.py +141 -0
- algorithms/common/models/mlp.py +22 -0
- algorithms/worldmem/__init__.py +2 -0
- algorithms/worldmem/df_base.py +307 -0
- algorithms/worldmem/df_video.py +926 -0
- algorithms/worldmem/models/attention.py +342 -0
- algorithms/worldmem/models/cameractrl_module.py +12 -0
- algorithms/worldmem/models/diffusion.py +520 -0
- algorithms/worldmem/models/dit.py +572 -0
- algorithms/worldmem/models/pose_prediction.py +42 -0
- algorithms/worldmem/models/rotary_embedding_torch.py +302 -0
- algorithms/worldmem/models/utils.py +163 -0
- algorithms/worldmem/models/vae.py +359 -0
- algorithms/worldmem/pose_prediction.py +374 -0
- app.py +576 -0
- assets/desert.png +3 -0
- assets/ice_plains.png +3 -0
- assets/place.png +3 -0
- assets/plains.png +3 -0
- assets/rain_sunflower_plains.png +3 -0
- assets/savanna.png +3 -0
- assets/sunflower_plains.png +3 -0
- assets/worldmem_logo.png +3 -0
- calculate_fid.py +277 -0
- configurations/algorithm/base_algo.yaml +3 -0
- configurations/algorithm/base_pytorch_algo.yaml +4 -0
- configurations/algorithm/df_base.yaml +42 -0
- configurations/algorithm/df_video_worldmemminecraft.yaml +38 -0
- configurations/dataset/base_dataset.yaml +3 -0
- configurations/dataset/base_video.yaml +14 -0
- configurations/dataset/video_minecraft.yaml +14 -0
- configurations/experiment/base_experiment.yaml +2 -0
- configurations/experiment/base_pytorch.yaml +50 -0
- configurations/experiment/exp_video.yaml +31 -0
- configurations/huggingface.yaml +60 -0
- configurations/training.yaml +16 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
<<<<<<< HEAD
|
| 37 |
+
=======
|
| 38 |
+
assets/* filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
>>>>>>> def529c (Baseline WorldMem)
|
LICENSE.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# S-Lab License 1.0
|
| 2 |
+
|
| 3 |
+
Copyright 2025 S-Lab
|
| 4 |
+
|
| 5 |
+
Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
| 6 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
| 7 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
| 8 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\
|
| 9 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 10 |
+
4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
---
|
| 14 |
+
For the commercial use of the code, please consult Prof. Chen Change Loy (ccloy@ntu.edu.sg)
|
README.md
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
<br>
|
| 3 |
+
<p align="center">
|
| 4 |
+
|
| 5 |
+
<p align="center">
|
| 6 |
+
<img src="assets/worldmem_logo.png" alt="WORLDMEM Icon" width="80"/>
|
| 7 |
+
</p>
|
| 8 |
+
<h1 align="center"><strong>WorldMem: Long-term Consistent World Simulation <br> with Memory</strong></h1>
|
| 9 |
+
<p align="center"><span><a href=""></a></span>
|
| 10 |
+
<a href="https://xizaoqu.github.io">Zeqi Xiao<sup>1</sup></a>
|
| 11 |
+
<a href="https://nirvanalan.github.io/">Yushi Lan<sup>1</sup></a>
|
| 12 |
+
<a href="https://zhouyifan.net/about/">Yifan Zhou<sup>1</sup></a>
|
| 13 |
+
<a href="https://vicky0522.github.io/Wenqi-Ouyang/">Wenqi Ouyang<sup>1</sup></a>
|
| 14 |
+
<a href="https://williamyang1991.github.io/">Shuai Yang<sup>2</sup></a>
|
| 15 |
+
<a href="https://zengyh1900.github.io/">Yanhong Zeng<sup>3</sup></a>
|
| 16 |
+
<a href="https://xingangpan.github.io/">Xingang Pan<sup>1</sup></a> <br>
|
| 17 |
+
<sup>1</sup>S-Lab, Nanyang Technological University, <br> <sup>2</sup>Wangxuan Institute of Computer Technology, Peking University,<br> <sup>3</sup>Shanghai AI Laboratory
|
| 18 |
+
</p>
|
| 19 |
+
</p>
|
| 20 |
+
|
| 21 |
+
<p align="center">
|
| 22 |
+
<a href="https://arxiv.org/abs/2504.12369" target='_blank'>
|
| 23 |
+
<img src="https://img.shields.io/badge/arXiv-2504.12369-blue?">
|
| 24 |
+
</a>
|
| 25 |
+
<a href="https://xizaoqu.github.io/worldmem/" target='_blank'>
|
| 26 |
+
<img src="https://img.shields.io/badge/Project-🚀-blue">
|
| 27 |
+
</a>
|
| 28 |
+
<a href="https://huggingface.co/spaces/yslan/worldmem" target="_blank">
|
| 29 |
+
<img src="https://img.shields.io/badge/🤗 HuggingFace-Demo-orange" />
|
| 30 |
+
</a>
|
| 31 |
+
</p>
|
| 32 |
+
|
| 33 |
+
https://github.com/user-attachments/assets/fb8a32e2-9470-4819-a93d-c38caf76d72c
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
## Installation
|
| 37 |
+
|
| 38 |
+
```
|
| 39 |
+
conda create python=3.10 -n worldmem
|
| 40 |
+
conda activate worldmem
|
| 41 |
+
pip install -r requirements.txt
|
| 42 |
+
conda install -c conda-forge ffmpeg=4.3.2
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
## Quick start
|
| 47 |
+
|
| 48 |
+
```
|
| 49 |
+
python app.py
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
## Run
|
| 53 |
+
|
| 54 |
+
To enable cloud logging with [Weights & Biases (wandb)](https://wandb.ai/site), follow these steps:
|
| 55 |
+
|
| 56 |
+
1. Sign up for a wandb account.
|
| 57 |
+
2. Run the following command to log in:
|
| 58 |
+
|
| 59 |
+
```bash
|
| 60 |
+
wandb login
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
3. Open `configurations/training.yaml` and set the `entity` and `project` field to your wandb username.
|
| 64 |
+
|
| 65 |
+
---
|
| 66 |
+
|
| 67 |
+
### Training
|
| 68 |
+
|
| 69 |
+
Download pretrained weights from [Oasis](https://github.com/etched-ai/open-oasis).
|
| 70 |
+
|
| 71 |
+
Training the model on 4 H100 GPUs, it converges after approximately 500K steps.
|
| 72 |
+
We observe that gradually increasing task difficulty improves performance. Thus, we adopt a multi-stage training strategy:
|
| 73 |
+
,
|
| 74 |
+
```bash
|
| 75 |
+
sh train_stage_1.sh # Small range, no vertical turning
|
| 76 |
+
sh train_stage_2.sh # Large range, no vertical turning
|
| 77 |
+
sh train_stage_3.sh # Large range, with vertical turning
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
To resume training from a previous checkpoint, configure the `resume` and `output_dir` variables in the corresponding `.sh` script.
|
| 81 |
+
|
| 82 |
+
---
|
| 83 |
+
|
| 84 |
+
### Inference
|
| 85 |
+
|
| 86 |
+
To run inference:
|
| 87 |
+
|
| 88 |
+
```bash
|
| 89 |
+
sh infer.sh
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
You can either **load the diffusion model and VAE separately**:
|
| 93 |
+
|
| 94 |
+
```bash
|
| 95 |
+
+diffusion_model_path=zeqixiao/worldmem_checkpoints/diffusion_only.ckpt \
|
| 96 |
+
+vae_path=zeqixiao/worldmem_checkpoints/vae_only.ckpt \
|
| 97 |
+
+customized_load=true \
|
| 98 |
+
+seperate_load=true \
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
Or **load a combined checkpoint**:
|
| 102 |
+
|
| 103 |
+
```bash
|
| 104 |
+
+load=your_model_path \
|
| 105 |
+
+customized_load=true \
|
| 106 |
+
+seperate_load=false \
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
### Evaluation
|
| 110 |
+
|
| 111 |
+
To run evaluation:
|
| 112 |
+
|
| 113 |
+
```bash
|
| 114 |
+
sh evaluate.sh
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
This script reproduces the results in Table 1 (beyond context window). It will generate PSNR and Lpips. Evaluating 1 case on 1 A100 GPU takes approximately 6 minutes. You can adjust `experiment.test.limit_batch` to specify the number of cases to evaluate.
|
| 118 |
+
|
| 119 |
+
Visual results will be saved by default to a timestamped directory (e.g., `outputs/2025-11-30/00-02-42`).
|
| 120 |
+
|
| 121 |
+
To calculate the FID score, run:
|
| 122 |
+
|
| 123 |
+
```bash
|
| 124 |
+
python calculate_fid.py --videos_dir <path_to_videos>
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
For example:
|
| 128 |
+
|
| 129 |
+
```bash
|
| 130 |
+
python calculate_fid.py --videos_dir outputs/2025-11-30/00-02-42/videos/test_vis
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
**Expected Results:**
|
| 134 |
+
|
| 135 |
+
| Metric | Value |
|
| 136 |
+
|--------|--------|
|
| 137 |
+
| PSNR | 24.01 |
|
| 138 |
+
| LPIPS | 0.1667 |
|
| 139 |
+
| FID | 15.13 |
|
| 140 |
+
|
| 141 |
+
*Note: FID is computed over 5000 frames.*
|
| 142 |
+
|
| 143 |
+
---
|
| 144 |
+
|
| 145 |
+
## Dataset
|
| 146 |
+
|
| 147 |
+
Download the Minecraft dataset from [Hugging Face](https://huggingface.co/datasets/zeqixiao/worldmem_minecraft_dataset)
|
| 148 |
+
|
| 149 |
+
Place the dataset in the following directory structure:
|
| 150 |
+
|
| 151 |
+
```
|
| 152 |
+
data/
|
| 153 |
+
└── minecraft/
|
| 154 |
+
├── training/
|
| 155 |
+
└── validation/
|
| 156 |
+
└── test/
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
## Data Generation
|
| 160 |
+
|
| 161 |
+
After setting up the environment as described in [MineDojo's GitHub repository](https://github.com/MineDojo/MineDojo), you can generate data using the following command:
|
| 162 |
+
|
| 163 |
+
```bash
|
| 164 |
+
xvfb-run -a python data_generator.py -o data/test -z 4 --env_type plains
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
**Parameters:**
|
| 168 |
+
- `-o`: Output directory for generated data
|
| 169 |
+
- `-z`: Number of parallel workers
|
| 170 |
+
- `--env_type`: Environment type (e.g., `plains`)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
## TODO
|
| 174 |
+
|
| 175 |
+
- [x] Release inference models and weights;
|
| 176 |
+
- [x] Release training pipeline on Minecraft;
|
| 177 |
+
- [x] Release training data on Minecraft;
|
| 178 |
+
- [x] Release evaluation scripts and data generator.
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
## 🔗 Citation
|
| 183 |
+
|
| 184 |
+
If you find our work helpful, please cite:
|
| 185 |
+
|
| 186 |
+
```
|
| 187 |
+
@misc{xiao2025worldmemlongtermconsistentworld,
|
| 188 |
+
title={WORLDMEM: Long-term Consistent World Simulation with Memory},
|
| 189 |
+
author={Zeqi Xiao and Yushi Lan and Yifan Zhou and Wenqi Ouyang and Shuai Yang and Yanhong Zeng and Xingang Pan},
|
| 190 |
+
year={2025},
|
| 191 |
+
eprint={2504.12369},
|
| 192 |
+
archivePrefix={arXiv},
|
| 193 |
+
primaryClass={cs.CV},
|
| 194 |
+
url={https://arxiv.org/abs/2504.12369},
|
| 195 |
+
}
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
## 👏 Acknowledgements
|
| 199 |
+
- [Diffusion Forcing](https://github.com/buoyancy99/diffusion-forcing): Diffusion Forcing provides flexible training and inference strategies for our methods.
|
| 200 |
+
- [Minedojo](https://github.com/MineDojo/MineDojo): We collect our Minecraft dataset from Minedojo.
|
| 201 |
+
- [Open-oasis](https://github.com/etched-ai/open-oasis): Our model architecture is based on Open-oasis. We also use pretrained VAE and DiT weight from it.
|
algorithms/README.md
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# algorithms
|
| 2 |
+
|
| 3 |
+
`algorithms` folder is designed to contain implementation of algorithms or models.
|
| 4 |
+
Content in `algorithms` can be loosely grouped components (e.g. models) or an algorithm has already has all
|
| 5 |
+
components chained together (e.g. Lightning Module, RL algo).
|
| 6 |
+
You should create a folder name after your own algorithm or baselines in it.
|
| 7 |
+
|
| 8 |
+
Two example can be found in `examples` subfolder.
|
| 9 |
+
|
| 10 |
+
The `common` subfolder is designed to contain general purpose classes that's useful for many projects, e.g MLP.
|
| 11 |
+
|
| 12 |
+
You should not run any `.py` file from algorithms folder.
|
| 13 |
+
Instead, you write unit tests / debug python files in `debug` and launch script in `experiments`.
|
| 14 |
+
|
| 15 |
+
You are discouraged from putting visualization utilities in algorithms, as those should go to `utils` in project root.
|
| 16 |
+
|
| 17 |
+
Each algorithm class takes in a DictConfig file `cfg` in its `__init__`, which allows you to pass in arguments via configuration file in `configurations/algorithm` or [command line override](https://hydra.cc/docs/tutorials/basic/your_first_app/simple_cli/).
|
| 18 |
+
|
| 19 |
+
---
|
| 20 |
+
|
| 21 |
+
This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
|
algorithms/__init__.py
ADDED
|
File without changes
|
algorithms/common/README.md
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
THis folder contains models / algorithms that are considered general for many algorithms.
|
| 2 |
+
|
| 3 |
+
---
|
| 4 |
+
|
| 5 |
+
This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep the above sentence in `README.md` and the `LICENSE` file to credit the author.
|
algorithms/common/__init__.py
ADDED
|
File without changes
|
algorithms/common/base_algo.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
from omegaconf import DictConfig
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class BaseAlgo(ABC):
|
| 8 |
+
"""
|
| 9 |
+
A base class for generic algorithms.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
def __init__(self, cfg: DictConfig):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.cfg = cfg
|
| 15 |
+
|
| 16 |
+
@abstractmethod
|
| 17 |
+
def run(*args: Any, **kwargs: Any) -> Any:
|
| 18 |
+
"""
|
| 19 |
+
Run the algorithm.
|
| 20 |
+
"""
|
| 21 |
+
raise NotImplementedError
|
algorithms/common/base_pytorch_algo.py
ADDED
|
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
import warnings
|
| 3 |
+
from typing import Any, Union, Sequence, Optional
|
| 4 |
+
|
| 5 |
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
| 6 |
+
from omegaconf import DictConfig
|
| 7 |
+
import lightning.pytorch as pl
|
| 8 |
+
import torch
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import wandb
|
| 12 |
+
import einops
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class BasePytorchAlgo(pl.LightningModule, ABC):
|
| 16 |
+
"""
|
| 17 |
+
A base class for Pytorch algorithms using Pytorch Lightning.
|
| 18 |
+
See https://lightning.ai/docs/pytorch/stable/starter/introduction.html for more details.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
def __init__(self, cfg: DictConfig):
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.cfg = cfg
|
| 24 |
+
self._build_model()
|
| 25 |
+
|
| 26 |
+
@abstractmethod
|
| 27 |
+
def _build_model(self):
|
| 28 |
+
"""
|
| 29 |
+
Create all pytorch nn.Modules here.
|
| 30 |
+
"""
|
| 31 |
+
raise NotImplementedError
|
| 32 |
+
|
| 33 |
+
@abstractmethod
|
| 34 |
+
def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
|
| 35 |
+
r"""Here you compute and return the training loss and some additional metrics for e.g. the progress bar or
|
| 36 |
+
logger.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
|
| 40 |
+
batch_idx: The index of this batch.
|
| 41 |
+
dataloader_idx: (only if multiple dataloaders used) The index of the dataloader that produced this batch.
|
| 42 |
+
|
| 43 |
+
Return:
|
| 44 |
+
Any of these options:
|
| 45 |
+
- :class:`~torch.Tensor` - The loss tensor
|
| 46 |
+
- ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
|
| 47 |
+
- ``None`` - Skip to the next batch. This is only supported for automatic optimization.
|
| 48 |
+
This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.
|
| 49 |
+
|
| 50 |
+
In this step you'd normally do the forward pass and calculate the loss for a batch.
|
| 51 |
+
You can also do fancier things like multiple forward passes or something model specific.
|
| 52 |
+
|
| 53 |
+
Example::
|
| 54 |
+
|
| 55 |
+
def training_step(self, batch, batch_idx):
|
| 56 |
+
x, y, z = batch
|
| 57 |
+
out = self.encoder(x)
|
| 58 |
+
loss = self.loss(out, x)
|
| 59 |
+
return loss
|
| 60 |
+
|
| 61 |
+
To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:
|
| 62 |
+
|
| 63 |
+
.. code-block:: python
|
| 64 |
+
|
| 65 |
+
def __init__(self):
|
| 66 |
+
super().__init__()
|
| 67 |
+
self.automatic_optimization = False
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Multiple optimizers (e.g.: GANs)
|
| 71 |
+
def training_step(self, batch, batch_idx):
|
| 72 |
+
opt1, opt2 = self.optimizers()
|
| 73 |
+
|
| 74 |
+
# do training_step with encoder
|
| 75 |
+
...
|
| 76 |
+
opt1.step()
|
| 77 |
+
# do training_step with decoder
|
| 78 |
+
...
|
| 79 |
+
opt2.step()
|
| 80 |
+
|
| 81 |
+
Note:
|
| 82 |
+
When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
|
| 83 |
+
normalized by ``accumulate_grad_batches`` internally.
|
| 84 |
+
|
| 85 |
+
"""
|
| 86 |
+
return super().training_step(*args, **kwargs)
|
| 87 |
+
|
| 88 |
+
def configure_optimizers(self):
|
| 89 |
+
"""
|
| 90 |
+
Return an optimizer. If you need to use more than one optimizer, refer to pytorch lightning documentation:
|
| 91 |
+
https://lightning.ai/docs/pytorch/stable/common/optimization.html
|
| 92 |
+
"""
|
| 93 |
+
parameters = self.parameters()
|
| 94 |
+
return torch.optim.Adam(parameters, lr=self.cfg.lr)
|
| 95 |
+
|
| 96 |
+
def log_video(
|
| 97 |
+
self,
|
| 98 |
+
key: str,
|
| 99 |
+
video: Union[np.ndarray, torch.Tensor],
|
| 100 |
+
mean: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
|
| 101 |
+
std: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
|
| 102 |
+
fps: int = 5,
|
| 103 |
+
format: str = "mp4",
|
| 104 |
+
):
|
| 105 |
+
"""
|
| 106 |
+
Log video to wandb. WandbLogger in pytorch lightning does not support video logging yet, so we call wandb directly.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
video: a numpy array or tensor, either in form (time, channel, height, width) or in the form
|
| 110 |
+
(batch, time, channel, height, width). The content must be be in 0-255 if under dtype uint8
|
| 111 |
+
or [0, 1] otherwise.
|
| 112 |
+
mean: optional, the mean to unnormalize video tensor, assuming unnormalized data is in [0, 1].
|
| 113 |
+
std: optional, the std to unnormalize video tensor, assuming unnormalized data is in [0, 1].
|
| 114 |
+
key: the name of the video.
|
| 115 |
+
fps: the frame rate of the video.
|
| 116 |
+
format: the format of the video. Can be either "mp4" or "gif".
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
if isinstance(video, torch.Tensor):
|
| 120 |
+
video = video.detach().cpu().numpy()
|
| 121 |
+
|
| 122 |
+
expand_shape = [1] * (len(video.shape) - 2) + [3, 1, 1]
|
| 123 |
+
if std is not None:
|
| 124 |
+
if isinstance(std, (float, int)):
|
| 125 |
+
std = [std] * 3
|
| 126 |
+
if isinstance(std, torch.Tensor):
|
| 127 |
+
std = std.detach().cpu().numpy()
|
| 128 |
+
std = np.array(std).reshape(*expand_shape)
|
| 129 |
+
video = video * std
|
| 130 |
+
if mean is not None:
|
| 131 |
+
if isinstance(mean, (float, int)):
|
| 132 |
+
mean = [mean] * 3
|
| 133 |
+
if isinstance(mean, torch.Tensor):
|
| 134 |
+
mean = mean.detach().cpu().numpy()
|
| 135 |
+
mean = np.array(mean).reshape(*expand_shape)
|
| 136 |
+
video = video + mean
|
| 137 |
+
|
| 138 |
+
if video.dtype != np.uint8:
|
| 139 |
+
video = np.clip(video, a_min=0, a_max=1) * 255
|
| 140 |
+
video = video.astype(np.uint8)
|
| 141 |
+
|
| 142 |
+
self.logger.experiment.log(
|
| 143 |
+
{
|
| 144 |
+
key: wandb.Video(video, fps=fps, format=format),
|
| 145 |
+
},
|
| 146 |
+
step=self.global_step,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
def log_image(
|
| 150 |
+
self,
|
| 151 |
+
key: str,
|
| 152 |
+
image: Union[np.ndarray, torch.Tensor, Image.Image, Sequence[Image.Image]],
|
| 153 |
+
mean: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
|
| 154 |
+
std: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
|
| 155 |
+
**kwargs: Any,
|
| 156 |
+
):
|
| 157 |
+
"""
|
| 158 |
+
Log image(s) using WandbLogger.
|
| 159 |
+
Args:
|
| 160 |
+
key: the name of the video.
|
| 161 |
+
image: a single image or a batch of images. If a batch of images, the shape should be (batch, channel, height, width).
|
| 162 |
+
mean: optional, the mean to unnormalize image tensor, assuming unnormalized data is in [0, 1].
|
| 163 |
+
std: optional, the std to unnormalize tensor, assuming unnormalized data is in [0, 1].
|
| 164 |
+
kwargs: optional, WandbLogger log_image kwargs, such as captions=xxx.
|
| 165 |
+
"""
|
| 166 |
+
if isinstance(image, Image.Image):
|
| 167 |
+
image = [image]
|
| 168 |
+
elif len(image) and not isinstance(image[0], Image.Image):
|
| 169 |
+
if isinstance(image, torch.Tensor):
|
| 170 |
+
image = image.detach().cpu().numpy()
|
| 171 |
+
|
| 172 |
+
if len(image.shape) == 3:
|
| 173 |
+
image = image[None]
|
| 174 |
+
|
| 175 |
+
if image.shape[1] == 3:
|
| 176 |
+
if image.shape[-1] == 3:
|
| 177 |
+
warnings.warn(f"Two channels in shape {image.shape} have size 3, assuming channel first.")
|
| 178 |
+
image = einops.rearrange(image, "b c h w -> b h w c")
|
| 179 |
+
|
| 180 |
+
if std is not None:
|
| 181 |
+
if isinstance(std, (float, int)):
|
| 182 |
+
std = [std] * 3
|
| 183 |
+
if isinstance(std, torch.Tensor):
|
| 184 |
+
std = std.detach().cpu().numpy()
|
| 185 |
+
std = np.array(std)[None, None, None]
|
| 186 |
+
image = image * std
|
| 187 |
+
if mean is not None:
|
| 188 |
+
if isinstance(mean, (float, int)):
|
| 189 |
+
mean = [mean] * 3
|
| 190 |
+
if isinstance(mean, torch.Tensor):
|
| 191 |
+
mean = mean.detach().cpu().numpy()
|
| 192 |
+
mean = np.array(mean)[None, None, None]
|
| 193 |
+
image = image + mean
|
| 194 |
+
|
| 195 |
+
if image.dtype != np.uint8:
|
| 196 |
+
image = np.clip(image, a_min=0.0, a_max=1.0) * 255
|
| 197 |
+
image = image.astype(np.uint8)
|
| 198 |
+
image = [img for img in image]
|
| 199 |
+
|
| 200 |
+
self.logger.log_image(key=key, images=image, **kwargs)
|
| 201 |
+
|
| 202 |
+
def log_gradient_stats(self):
|
| 203 |
+
"""Log gradient statistics such as the mean or std of norm."""
|
| 204 |
+
|
| 205 |
+
with torch.no_grad():
|
| 206 |
+
grad_norms = []
|
| 207 |
+
gpr = [] # gradient-to-parameter ratio
|
| 208 |
+
for param in self.parameters():
|
| 209 |
+
if param.grad is not None:
|
| 210 |
+
grad_norms.append(torch.norm(param.grad).item())
|
| 211 |
+
gpr.append(torch.norm(param.grad) / torch.norm(param))
|
| 212 |
+
if len(grad_norms) == 0:
|
| 213 |
+
return
|
| 214 |
+
grad_norms = torch.tensor(grad_norms)
|
| 215 |
+
gpr = torch.tensor(gpr)
|
| 216 |
+
self.log_dict(
|
| 217 |
+
{
|
| 218 |
+
"train/grad_norm/min": grad_norms.min(),
|
| 219 |
+
"train/grad_norm/max": grad_norms.max(),
|
| 220 |
+
"train/grad_norm/std": grad_norms.std(),
|
| 221 |
+
"train/grad_norm/mean": grad_norms.mean(),
|
| 222 |
+
"train/grad_norm/median": torch.median(grad_norms),
|
| 223 |
+
"train/gpr/min": gpr.min(),
|
| 224 |
+
"train/gpr/max": gpr.max(),
|
| 225 |
+
"train/gpr/std": gpr.std(),
|
| 226 |
+
"train/gpr/mean": gpr.mean(),
|
| 227 |
+
"train/gpr/median": torch.median(gpr),
|
| 228 |
+
}
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
def register_data_mean_std(
|
| 232 |
+
self, mean: Union[str, float, Sequence], std: Union[str, float, Sequence], namespace: str = "data"
|
| 233 |
+
):
|
| 234 |
+
"""
|
| 235 |
+
Register mean and std of data as tensor buffer.
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
mean: the mean of data.
|
| 239 |
+
std: the std of data.
|
| 240 |
+
namespace: the namespace of the registered buffer.
|
| 241 |
+
"""
|
| 242 |
+
for k, v in [("mean", mean), ("std", std)]:
|
| 243 |
+
if isinstance(v, str):
|
| 244 |
+
if v.endswith(".npy"):
|
| 245 |
+
v = torch.from_numpy(np.load(v))
|
| 246 |
+
elif v.endswith(".pt"):
|
| 247 |
+
v = torch.load(v)
|
| 248 |
+
else:
|
| 249 |
+
raise ValueError(f"Unsupported file type {v.split('.')[-1]}.")
|
| 250 |
+
else:
|
| 251 |
+
v = torch.tensor(v)
|
| 252 |
+
self.register_buffer(f"{namespace}_{k}", v.float().to(self.device))
|
algorithms/common/metrics/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .fid import FrechetInceptionDistance
|
| 2 |
+
from .lpips import LearnedPerceptualImagePatchSimilarity
|
| 3 |
+
from .fvd import FrechetVideoDistance
|
algorithms/common/metrics/fid.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from torchmetrics.image.fid import FrechetInceptionDistance
|
algorithms/common/metrics/fvd.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adopted from https://github.com/cvpr2022-stylegan-v/stylegan-v
|
| 3 |
+
Verified to be the same as tf version by https://github.com/universome/fvd-comparison
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import io
|
| 7 |
+
import re
|
| 8 |
+
import requests
|
| 9 |
+
import html
|
| 10 |
+
import hashlib
|
| 11 |
+
import urllib
|
| 12 |
+
import urllib.request
|
| 13 |
+
from typing import Any, List, Tuple, Union, Dict
|
| 14 |
+
import scipy
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def open_url(
|
| 22 |
+
url: str,
|
| 23 |
+
num_attempts: int = 10,
|
| 24 |
+
verbose: bool = True,
|
| 25 |
+
return_filename: bool = False,
|
| 26 |
+
) -> Any:
|
| 27 |
+
"""Download the given URL and return a binary-mode file object to access the data."""
|
| 28 |
+
assert num_attempts >= 1
|
| 29 |
+
|
| 30 |
+
# Doesn't look like an URL scheme so interpret it as a local filename.
|
| 31 |
+
if not re.match("^[a-z]+://", url):
|
| 32 |
+
return url if return_filename else open(url, "rb")
|
| 33 |
+
|
| 34 |
+
# Handle file URLs. This code handles unusual file:// patterns that
|
| 35 |
+
# arise on Windows:
|
| 36 |
+
#
|
| 37 |
+
# file:///c:/foo.txt
|
| 38 |
+
#
|
| 39 |
+
# which would translate to a local '/c:/foo.txt' filename that's
|
| 40 |
+
# invalid. Drop the forward slash for such pathnames.
|
| 41 |
+
#
|
| 42 |
+
# If you touch this code path, you should test it on both Linux and
|
| 43 |
+
# Windows.
|
| 44 |
+
#
|
| 45 |
+
# Some internet resources suggest using urllib.request.url2pathname() but
|
| 46 |
+
# but that converts forward slashes to backslashes and this causes
|
| 47 |
+
# its own set of problems.
|
| 48 |
+
if url.startswith("file://"):
|
| 49 |
+
filename = urllib.parse.urlparse(url).path
|
| 50 |
+
if re.match(r"^/[a-zA-Z]:", filename):
|
| 51 |
+
filename = filename[1:]
|
| 52 |
+
return filename if return_filename else open(filename, "rb")
|
| 53 |
+
|
| 54 |
+
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
|
| 55 |
+
|
| 56 |
+
# Download.
|
| 57 |
+
url_name = None
|
| 58 |
+
url_data = None
|
| 59 |
+
with requests.Session() as session:
|
| 60 |
+
if verbose:
|
| 61 |
+
print("Downloading %s ..." % url, end="", flush=True)
|
| 62 |
+
for attempts_left in reversed(range(num_attempts)):
|
| 63 |
+
try:
|
| 64 |
+
with session.get(url) as res:
|
| 65 |
+
res.raise_for_status()
|
| 66 |
+
if len(res.content) == 0:
|
| 67 |
+
raise IOError("No data received")
|
| 68 |
+
|
| 69 |
+
if len(res.content) < 8192:
|
| 70 |
+
content_str = res.content.decode("utf-8")
|
| 71 |
+
if "download_warning" in res.headers.get("Set-Cookie", ""):
|
| 72 |
+
links = [
|
| 73 |
+
html.unescape(link)
|
| 74 |
+
for link in content_str.split('"')
|
| 75 |
+
if "export=download" in link
|
| 76 |
+
]
|
| 77 |
+
if len(links) == 1:
|
| 78 |
+
url = requests.compat.urljoin(url, links[0])
|
| 79 |
+
raise IOError("Google Drive virus checker nag")
|
| 80 |
+
if "Google Drive - Quota exceeded" in content_str:
|
| 81 |
+
raise IOError(
|
| 82 |
+
"Google Drive download quota exceeded -- please try again later"
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
match = re.search(
|
| 86 |
+
r'filename="([^"]*)"',
|
| 87 |
+
res.headers.get("Content-Disposition", ""),
|
| 88 |
+
)
|
| 89 |
+
url_name = match[1] if match else url
|
| 90 |
+
url_data = res.content
|
| 91 |
+
if verbose:
|
| 92 |
+
print(" done")
|
| 93 |
+
break
|
| 94 |
+
except KeyboardInterrupt:
|
| 95 |
+
raise
|
| 96 |
+
except:
|
| 97 |
+
if not attempts_left:
|
| 98 |
+
if verbose:
|
| 99 |
+
print(" failed")
|
| 100 |
+
raise
|
| 101 |
+
if verbose:
|
| 102 |
+
print(".", end="", flush=True)
|
| 103 |
+
|
| 104 |
+
# Return data as file object.
|
| 105 |
+
assert not return_filename
|
| 106 |
+
return io.BytesIO(url_data)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def compute_fvd(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
|
| 110 |
+
mu_gen, sigma_gen = compute_stats(feats_fake)
|
| 111 |
+
mu_real, sigma_real = compute_stats(feats_real)
|
| 112 |
+
|
| 113 |
+
m = np.square(mu_gen - mu_real).sum()
|
| 114 |
+
s, _ = scipy.linalg.sqrtm(
|
| 115 |
+
np.dot(sigma_gen, sigma_real), disp=False
|
| 116 |
+
) # pylint: disable=no-member
|
| 117 |
+
fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
|
| 118 |
+
|
| 119 |
+
return float(fid)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 123 |
+
mu = feats.mean(axis=0) # [d]
|
| 124 |
+
sigma = np.cov(feats, rowvar=False) # [d, d]
|
| 125 |
+
|
| 126 |
+
return mu, sigma
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class FrechetVideoDistance(nn.Module):
|
| 130 |
+
def __init__(self):
|
| 131 |
+
super().__init__()
|
| 132 |
+
detector_url = (
|
| 133 |
+
"https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1"
|
| 134 |
+
)
|
| 135 |
+
# Return raw features before the softmax layer.
|
| 136 |
+
self.detector_kwargs = dict(rescale=False, resize=True, return_features=True)
|
| 137 |
+
with open_url(detector_url, verbose=False) as f:
|
| 138 |
+
self.detector = torch.jit.load(f).eval()
|
| 139 |
+
|
| 140 |
+
@torch.no_grad()
|
| 141 |
+
def compute(self, videos_fake: torch.Tensor, videos_real: torch.Tensor):
|
| 142 |
+
"""
|
| 143 |
+
:param videos_fake: predicted video tensor of shape (frame, batch, channel, height, width)
|
| 144 |
+
:param videos_real: ground-truth observation tensor of shape (frame, batch, channel, height, width)
|
| 145 |
+
:return:
|
| 146 |
+
"""
|
| 147 |
+
n_frames, batch_size, c, h, w = videos_fake.shape
|
| 148 |
+
if n_frames < 2:
|
| 149 |
+
raise ValueError("Video must have more than 1 frame for FVD")
|
| 150 |
+
|
| 151 |
+
videos_fake = videos_fake.permute(1, 2, 0, 3, 4).contiguous()
|
| 152 |
+
videos_real = videos_real.permute(1, 2, 0, 3, 4).contiguous()
|
| 153 |
+
|
| 154 |
+
# detector takes in tensors of shape [batch_size, c, video_len, h, w] with range -1 to 1
|
| 155 |
+
feats_fake = self.detector(videos_fake, **self.detector_kwargs).cpu().numpy()
|
| 156 |
+
feats_real = self.detector(videos_real, **self.detector_kwargs).cpu().numpy()
|
| 157 |
+
|
| 158 |
+
return compute_fvd(feats_fake, feats_real)
|
algorithms/common/metrics/lpips.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
algorithms/common/models/__init__.py
ADDED
|
File without changes
|
algorithms/common/models/cnn.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def is_square_of_two(num):
|
| 7 |
+
if num <= 0:
|
| 8 |
+
return False
|
| 9 |
+
return num & (num - 1) == 0
|
| 10 |
+
|
| 11 |
+
class CnnEncoder(nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
Simple cnn encoder that encodes a 64x64 image to embeddings
|
| 14 |
+
"""
|
| 15 |
+
def __init__(self, embedding_size, activation_function='relu'):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.act_fn = getattr(F, activation_function)
|
| 18 |
+
self.embedding_size = embedding_size
|
| 19 |
+
self.fc = nn.Linear(1024, self.embedding_size)
|
| 20 |
+
self.conv1 = nn.Conv2d(3, 32, 4, stride=2)
|
| 21 |
+
self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
|
| 22 |
+
self.conv3 = nn.Conv2d(64, 128, 4, stride=2)
|
| 23 |
+
self.conv4 = nn.Conv2d(128, 256, 4, stride=2)
|
| 24 |
+
self.modules = [self.conv1, self.conv2, self.conv3, self.conv4]
|
| 25 |
+
|
| 26 |
+
def forward(self, observation):
|
| 27 |
+
batch_size = observation.shape[0]
|
| 28 |
+
hidden = self.act_fn(self.conv1(observation))
|
| 29 |
+
hidden = self.act_fn(self.conv2(hidden))
|
| 30 |
+
hidden = self.act_fn(self.conv3(hidden))
|
| 31 |
+
hidden = self.act_fn(self.conv4(hidden))
|
| 32 |
+
hidden = self.fc(hidden.view(batch_size, 1024))
|
| 33 |
+
return hidden
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class CnnDecoder(nn.Module):
|
| 37 |
+
"""
|
| 38 |
+
Simple Cnn decoder that decodes an embedding to 64x64 images
|
| 39 |
+
"""
|
| 40 |
+
def __init__(self, embedding_size, activation_function='relu'):
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.act_fn = getattr(F, activation_function)
|
| 43 |
+
self.embedding_size = embedding_size
|
| 44 |
+
self.fc = nn.Linear(embedding_size, 128)
|
| 45 |
+
self.conv1 = nn.ConvTranspose2d(128, 128, 5, stride=2)
|
| 46 |
+
self.conv2 = nn.ConvTranspose2d(128, 64, 5, stride=2)
|
| 47 |
+
self.conv3 = nn.ConvTranspose2d(64, 32, 6, stride=2)
|
| 48 |
+
self.conv4 = nn.ConvTranspose2d(32, 3, 6, stride=2)
|
| 49 |
+
self.modules = [self.conv1, self.conv2, self.conv3, self.conv4]
|
| 50 |
+
|
| 51 |
+
def forward(self, embedding):
|
| 52 |
+
batch_size = embedding.shape[0]
|
| 53 |
+
hidden = self.fc(embedding)
|
| 54 |
+
hidden = hidden.view(batch_size, 128, 1, 1)
|
| 55 |
+
hidden = self.act_fn(self.conv1(hidden))
|
| 56 |
+
hidden = self.act_fn(self.conv2(hidden))
|
| 57 |
+
hidden = self.act_fn(self.conv3(hidden))
|
| 58 |
+
observation = self.conv4(hidden)
|
| 59 |
+
return observation
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class FullyConvEncoder(nn.Module):
|
| 63 |
+
"""
|
| 64 |
+
Simple fully convolutional encoder, with 2D input and 2D output
|
| 65 |
+
"""
|
| 66 |
+
def __init__(self,
|
| 67 |
+
input_shape=(3, 64, 64),
|
| 68 |
+
embedding_shape=(8, 16, 16),
|
| 69 |
+
activation_function='relu',
|
| 70 |
+
init_channels=16,
|
| 71 |
+
):
|
| 72 |
+
super().__init__()
|
| 73 |
+
|
| 74 |
+
assert len(input_shape) == 3, "input_shape must be a tuple of length 3"
|
| 75 |
+
assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3"
|
| 76 |
+
assert input_shape[1] == input_shape[2] and is_square_of_two(input_shape[1]), "input_shape must be square"
|
| 77 |
+
assert embedding_shape[1] == embedding_shape[2], "embedding_shape must be square"
|
| 78 |
+
assert input_shape[1] % embedding_shape[1] == 0, "input_shape must be divisible by embedding_shape"
|
| 79 |
+
assert is_square_of_two(init_channels), "init_channels must be a square of 2"
|
| 80 |
+
|
| 81 |
+
depth = int(math.sqrt(input_shape[1] / embedding_shape[1])) + 1
|
| 82 |
+
channels_per_layer = [init_channels * (2 ** i) for i in range(depth)]
|
| 83 |
+
self.act_fn = getattr(F, activation_function)
|
| 84 |
+
|
| 85 |
+
self.downs = nn.ModuleList([])
|
| 86 |
+
self.downs.append(nn.Conv2d(input_shape[0], channels_per_layer[0], kernel_size=3, stride=1, padding=1))
|
| 87 |
+
|
| 88 |
+
for i in range(1, depth):
|
| 89 |
+
self.downs.append(nn.Conv2d(channels_per_layer[i-1], channels_per_layer[i],
|
| 90 |
+
kernel_size=3, stride=2, padding=1))
|
| 91 |
+
|
| 92 |
+
# Bottleneck layer
|
| 93 |
+
self.downs.append(nn.Conv2d(channels_per_layer[-1], embedding_shape[0], kernel_size=1, stride=1, padding=0))
|
| 94 |
+
|
| 95 |
+
def forward(self, observation):
|
| 96 |
+
hidden = observation
|
| 97 |
+
for layer in self.downs:
|
| 98 |
+
hidden = self.act_fn(layer(hidden))
|
| 99 |
+
return hidden
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class FullyConvDecoder(nn.Module):
|
| 103 |
+
"""
|
| 104 |
+
Simple fully convolutional decoder, with 2D input and 2D output
|
| 105 |
+
"""
|
| 106 |
+
def __init__(self,
|
| 107 |
+
embedding_shape=(8, 16, 16),
|
| 108 |
+
output_shape=(3, 64, 64),
|
| 109 |
+
activation_function='relu',
|
| 110 |
+
init_channels=16,
|
| 111 |
+
):
|
| 112 |
+
super().__init__()
|
| 113 |
+
|
| 114 |
+
assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3"
|
| 115 |
+
assert len(output_shape) == 3, "output_shape must be a tuple of length 3"
|
| 116 |
+
assert output_shape[1] == output_shape[2] and is_square_of_two(output_shape[1]), "output_shape must be square"
|
| 117 |
+
assert embedding_shape[1] == embedding_shape[2], "input_shape must be square"
|
| 118 |
+
assert output_shape[1] % embedding_shape[1] == 0, "output_shape must be divisible by input_shape"
|
| 119 |
+
assert is_square_of_two(init_channels), "init_channels must be a square of 2"
|
| 120 |
+
|
| 121 |
+
depth = int(math.sqrt(output_shape[1] / embedding_shape[1])) + 1
|
| 122 |
+
channels_per_layer = [init_channels * (2 ** i) for i in range(depth)]
|
| 123 |
+
self.act_fn = getattr(F, activation_function)
|
| 124 |
+
|
| 125 |
+
self.ups = nn.ModuleList([])
|
| 126 |
+
self.ups.append(nn.ConvTranspose2d(embedding_shape[0], channels_per_layer[-1],
|
| 127 |
+
kernel_size=1, stride=1, padding=0))
|
| 128 |
+
|
| 129 |
+
for i in range(1, depth):
|
| 130 |
+
self.ups.append(nn.ConvTranspose2d(channels_per_layer[-i], channels_per_layer[-i-1],
|
| 131 |
+
kernel_size=3, stride=2, padding=1, output_padding=1))
|
| 132 |
+
|
| 133 |
+
self.output_layer = nn.ConvTranspose2d(channels_per_layer[0], output_shape[0],
|
| 134 |
+
kernel_size=3, stride=1, padding=1)
|
| 135 |
+
|
| 136 |
+
def forward(self, embedding):
|
| 137 |
+
hidden = embedding
|
| 138 |
+
for layer in self.ups:
|
| 139 |
+
hidden = self.act_fn(layer(hidden))
|
| 140 |
+
|
| 141 |
+
return self.output_layer(hidden)
|
algorithms/common/models/mlp.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Type, Optional
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class SimpleMlp(nn.Module):
|
| 8 |
+
"""
|
| 9 |
+
A class for very simple multi layer perceptron
|
| 10 |
+
"""
|
| 11 |
+
def __init__(self, in_dim=2, out_dim=1, hidden_dim=64, n_layers=2,
|
| 12 |
+
activation: Type[nn.Module] = nn.ReLU, output_activation: Optional[Type[nn.Module]] = None):
|
| 13 |
+
super(SimpleMlp, self).__init__()
|
| 14 |
+
layers = [nn.Linear(in_dim, hidden_dim), activation()]
|
| 15 |
+
layers.extend([nn.Linear(hidden_dim, hidden_dim), activation()] * (n_layers - 2))
|
| 16 |
+
layers.append(nn.Linear(hidden_dim, out_dim))
|
| 17 |
+
if output_activation:
|
| 18 |
+
layers.append(output_activation())
|
| 19 |
+
self.net = nn.Sequential(*layers)
|
| 20 |
+
|
| 21 |
+
def forward(self, x):
|
| 22 |
+
return self.net(x)
|
algorithms/worldmem/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .df_video import WorldMemMinecraft
|
| 2 |
+
from .pose_prediction import PosePrediction
|
algorithms/worldmem/df_base.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
|
| 3 |
+
template [repo](https://github.com/buoyancy99/research-template).
|
| 4 |
+
By its MIT license, you must keep the above sentence in `README.md`
|
| 5 |
+
and the `LICENSE` file to credit the author.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Optional
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from omegaconf import DictConfig
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from typing import Any
|
| 15 |
+
from einops import rearrange
|
| 16 |
+
|
| 17 |
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
| 18 |
+
|
| 19 |
+
from algorithms.common.base_pytorch_algo import BasePytorchAlgo
|
| 20 |
+
from .models.diffusion import Diffusion
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DiffusionForcingBase(BasePytorchAlgo):
|
| 24 |
+
def __init__(self, cfg: DictConfig):
|
| 25 |
+
self.cfg = cfg
|
| 26 |
+
self.x_shape = cfg.x_shape
|
| 27 |
+
self.frame_stack = cfg.frame_stack
|
| 28 |
+
self.x_stacked_shape = list(self.x_shape)
|
| 29 |
+
self.x_stacked_shape[0] *= cfg.frame_stack
|
| 30 |
+
self.guidance_scale = cfg.guidance_scale
|
| 31 |
+
self.context_frames = cfg.context_frames
|
| 32 |
+
self.chunk_size = cfg.chunk_size
|
| 33 |
+
self.action_cond_dim = cfg.action_cond_dim
|
| 34 |
+
self.causal = cfg.causal
|
| 35 |
+
|
| 36 |
+
self.uncertainty_scale = cfg.uncertainty_scale
|
| 37 |
+
self.timesteps = cfg.diffusion.timesteps
|
| 38 |
+
self.sampling_timesteps = cfg.diffusion.sampling_timesteps
|
| 39 |
+
self.clip_noise = cfg.diffusion.clip_noise
|
| 40 |
+
|
| 41 |
+
self.cfg.diffusion.cum_snr_decay = self.cfg.diffusion.cum_snr_decay ** (self.frame_stack * cfg.frame_skip)
|
| 42 |
+
|
| 43 |
+
self.validation_step_outputs = []
|
| 44 |
+
super().__init__(cfg)
|
| 45 |
+
|
| 46 |
+
def _build_model(self):
|
| 47 |
+
self.diffusion_model = Diffusion(
|
| 48 |
+
x_shape=self.x_stacked_shape,
|
| 49 |
+
action_cond_dim=self.action_cond_dim,
|
| 50 |
+
is_causal=self.causal,
|
| 51 |
+
cfg=self.cfg.diffusion,
|
| 52 |
+
)
|
| 53 |
+
self.register_data_mean_std(self.cfg.data_mean, self.cfg.data_std)
|
| 54 |
+
|
| 55 |
+
def configure_optimizers(self):
|
| 56 |
+
params = tuple(self.diffusion_model.parameters())
|
| 57 |
+
optimizer_dynamics = torch.optim.AdamW(
|
| 58 |
+
params, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay, betas=self.cfg.optimizer_beta
|
| 59 |
+
)
|
| 60 |
+
return optimizer_dynamics
|
| 61 |
+
|
| 62 |
+
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
|
| 63 |
+
# update params
|
| 64 |
+
optimizer.step(closure=optimizer_closure)
|
| 65 |
+
|
| 66 |
+
# manually warm up lr without a scheduler
|
| 67 |
+
if self.trainer.global_step < self.cfg.warmup_steps:
|
| 68 |
+
lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.cfg.warmup_steps)
|
| 69 |
+
for pg in optimizer.param_groups:
|
| 70 |
+
pg["lr"] = lr_scale * self.cfg.lr
|
| 71 |
+
|
| 72 |
+
def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
|
| 73 |
+
xs, conditions, masks = self._preprocess_batch(batch)
|
| 74 |
+
|
| 75 |
+
rand_length = torch.randint(3,xs.shape[0]-2, (1,))[0].item()
|
| 76 |
+
xs = torch.cat([xs[:rand_length], xs[rand_length-3:rand_length-1]])
|
| 77 |
+
conditions = torch.cat([conditions[:rand_length], conditions[rand_length-3:rand_length-1]])
|
| 78 |
+
masks = torch.cat([masks[:rand_length], masks[rand_length-3:rand_length-1]])
|
| 79 |
+
noise_levels=self._generate_noise_levels(xs)
|
| 80 |
+
noise_levels[:rand_length] = 15 # stable_noise_levels
|
| 81 |
+
noise_levels[rand_length+1:] = 15 # stable_noise_levels
|
| 82 |
+
|
| 83 |
+
xs_pred, loss = self.diffusion_model(xs, conditions, noise_levels=noise_levels)
|
| 84 |
+
loss = self.reweight_loss(loss, masks)
|
| 85 |
+
|
| 86 |
+
# log the loss
|
| 87 |
+
if batch_idx % 20 == 0:
|
| 88 |
+
self.log("training/loss", loss)
|
| 89 |
+
|
| 90 |
+
xs = self._unstack_and_unnormalize(xs)
|
| 91 |
+
xs_pred = self._unstack_and_unnormalize(xs_pred)
|
| 92 |
+
|
| 93 |
+
output_dict = {
|
| 94 |
+
"loss": loss,
|
| 95 |
+
"xs_pred": xs_pred,
|
| 96 |
+
"xs": xs,
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
return output_dict
|
| 100 |
+
|
| 101 |
+
@torch.no_grad()
|
| 102 |
+
def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT:
|
| 103 |
+
xs, conditions, masks = self._preprocess_batch(batch)
|
| 104 |
+
n_frames, batch_size, *_ = xs.shape
|
| 105 |
+
xs_pred = []
|
| 106 |
+
curr_frame = 0
|
| 107 |
+
|
| 108 |
+
# context
|
| 109 |
+
n_context_frames = self.context_frames // self.frame_stack
|
| 110 |
+
xs_pred = xs[:n_context_frames].clone()
|
| 111 |
+
curr_frame += n_context_frames
|
| 112 |
+
|
| 113 |
+
if self.condtion_similar_length:
|
| 114 |
+
n_frames -= self.condtion_similar_length
|
| 115 |
+
|
| 116 |
+
pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
|
| 117 |
+
while curr_frame < n_frames:
|
| 118 |
+
if self.chunk_size > 0:
|
| 119 |
+
horizon = min(n_frames - curr_frame, self.chunk_size)
|
| 120 |
+
else:
|
| 121 |
+
horizon = n_frames - curr_frame
|
| 122 |
+
assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
|
| 123 |
+
scheduling_matrix = self._generate_scheduling_matrix(horizon)
|
| 124 |
+
|
| 125 |
+
chunk = torch.randn((horizon, batch_size, *self.x_stacked_shape), device=self.device)
|
| 126 |
+
chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise)
|
| 127 |
+
xs_pred = torch.cat([xs_pred, chunk], 0)
|
| 128 |
+
|
| 129 |
+
# sliding window: only input the last n_tokens frames
|
| 130 |
+
start_frame = max(0, curr_frame + horizon - self.n_tokens)
|
| 131 |
+
|
| 132 |
+
pbar.set_postfix(
|
| 133 |
+
{
|
| 134 |
+
"start": start_frame,
|
| 135 |
+
"end": curr_frame + horizon,
|
| 136 |
+
}
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
if self.condtion_similar_length:
|
| 140 |
+
xs_pred = torch.cat([xs_pred, xs[curr_frame-self.condtion_similar_length:curr_frame].clone()], 0)
|
| 141 |
+
|
| 142 |
+
for m in range(scheduling_matrix.shape[0] - 1):
|
| 143 |
+
|
| 144 |
+
from_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m]))[
|
| 145 |
+
:, None
|
| 146 |
+
].repeat(batch_size, axis=1)
|
| 147 |
+
to_noise_levels = np.concatenate(
|
| 148 |
+
(
|
| 149 |
+
np.zeros((curr_frame,), dtype=np.int64),
|
| 150 |
+
scheduling_matrix[m + 1],
|
| 151 |
+
)
|
| 152 |
+
)[
|
| 153 |
+
:, None
|
| 154 |
+
].repeat(batch_size, axis=1)
|
| 155 |
+
|
| 156 |
+
if self.condtion_similar_length:
|
| 157 |
+
from_noise_levels = np.concatenate([from_noise_levels, np.array([[0,0,0,0]*self.condtion_similar_length])], axis=0)
|
| 158 |
+
to_noise_levels = np.concatenate([to_noise_levels, np.array([[0,0,0,0]*self.condtion_similar_length])], axis=0)
|
| 159 |
+
|
| 160 |
+
from_noise_levels = torch.from_numpy(from_noise_levels).to(self.device)
|
| 161 |
+
to_noise_levels = torch.from_numpy(to_noise_levels).to(self.device)
|
| 162 |
+
|
| 163 |
+
# update xs_pred by DDIM or DDPM sampling
|
| 164 |
+
# input frames within the sliding window
|
| 165 |
+
|
| 166 |
+
try:
|
| 167 |
+
input_condition = conditions[start_frame : curr_frame + horizon].clone()
|
| 168 |
+
except:
|
| 169 |
+
import pdb;pdb.set_trace()
|
| 170 |
+
if self.condtion_similar_length:
|
| 171 |
+
input_condition = torch.cat([conditions[start_frame : curr_frame + horizon], conditions[-self.condtion_similar_length:]], dim=0)
|
| 172 |
+
xs_pred[start_frame:] = self.diffusion_model.sample_step(
|
| 173 |
+
xs_pred[start_frame:],
|
| 174 |
+
input_condition,
|
| 175 |
+
from_noise_levels[start_frame:],
|
| 176 |
+
to_noise_levels[start_frame:],
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
if self.condtion_similar_length:
|
| 180 |
+
xs_pred = xs_pred[:-self.condtion_similar_length]
|
| 181 |
+
|
| 182 |
+
curr_frame += horizon
|
| 183 |
+
pbar.update(horizon)
|
| 184 |
+
|
| 185 |
+
if self.condtion_similar_length:
|
| 186 |
+
xs = xs[:-self.condtion_similar_length]
|
| 187 |
+
# FIXME: loss
|
| 188 |
+
loss = F.mse_loss(xs_pred, xs, reduction="none")
|
| 189 |
+
loss = self.reweight_loss(loss, masks)
|
| 190 |
+
self.validation_step_outputs.append((xs_pred.detach().cpu(), xs.detach().cpu()))
|
| 191 |
+
|
| 192 |
+
return loss
|
| 193 |
+
|
| 194 |
+
def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
|
| 195 |
+
return self.validation_step(*args, **kwargs, namespace="test")
|
| 196 |
+
|
| 197 |
+
def on_test_epoch_end(self) -> None:
|
| 198 |
+
self.on_validation_epoch_end(namespace="test")
|
| 199 |
+
|
| 200 |
+
def _generate_noise_levels(self, xs: torch.Tensor, masks: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 201 |
+
"""
|
| 202 |
+
Generate noise levels for training.
|
| 203 |
+
"""
|
| 204 |
+
num_frames, batch_size, *_ = xs.shape
|
| 205 |
+
match self.cfg.noise_level:
|
| 206 |
+
case "random_all": # entirely random noise levels
|
| 207 |
+
noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
|
| 208 |
+
case "same":
|
| 209 |
+
noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
|
| 210 |
+
noise_levels[1:] = noise_levels[0]
|
| 211 |
+
|
| 212 |
+
if masks is not None:
|
| 213 |
+
# for frames that are not available, treat as full noise
|
| 214 |
+
discard = torch.all(~rearrange(masks.bool(), "(t fs) b -> t b fs", fs=self.frame_stack), -1)
|
| 215 |
+
noise_levels = torch.where(discard, torch.full_like(noise_levels, self.timesteps - 1), noise_levels)
|
| 216 |
+
|
| 217 |
+
return noise_levels
|
| 218 |
+
|
| 219 |
+
def _generate_scheduling_matrix(self, horizon: int):
|
| 220 |
+
match self.cfg.scheduling_matrix:
|
| 221 |
+
case "pyramid":
|
| 222 |
+
return self._generate_pyramid_scheduling_matrix(horizon, self.uncertainty_scale)
|
| 223 |
+
case "full_sequence":
|
| 224 |
+
return np.arange(self.sampling_timesteps, -1, -1)[:, None].repeat(horizon, axis=1)
|
| 225 |
+
case "autoregressive":
|
| 226 |
+
return self._generate_pyramid_scheduling_matrix(horizon, self.sampling_timesteps)
|
| 227 |
+
case "trapezoid":
|
| 228 |
+
return self._generate_trapezoid_scheduling_matrix(horizon, self.uncertainty_scale)
|
| 229 |
+
|
| 230 |
+
def _generate_pyramid_scheduling_matrix(self, horizon: int, uncertainty_scale: float):
|
| 231 |
+
height = self.sampling_timesteps + int((horizon - 1) * uncertainty_scale) + 1
|
| 232 |
+
scheduling_matrix = np.zeros((height, horizon), dtype=np.int64)
|
| 233 |
+
for m in range(height):
|
| 234 |
+
for t in range(horizon):
|
| 235 |
+
scheduling_matrix[m, t] = self.sampling_timesteps + int(t * uncertainty_scale) - m
|
| 236 |
+
|
| 237 |
+
return np.clip(scheduling_matrix, 0, self.sampling_timesteps)
|
| 238 |
+
|
| 239 |
+
def _generate_trapezoid_scheduling_matrix(self, horizon: int, uncertainty_scale: float):
|
| 240 |
+
height = self.sampling_timesteps + int((horizon + 1) // 2 * uncertainty_scale)
|
| 241 |
+
scheduling_matrix = np.zeros((height, horizon), dtype=np.int64)
|
| 242 |
+
for m in range(height):
|
| 243 |
+
for t in range((horizon + 1) // 2):
|
| 244 |
+
scheduling_matrix[m, t] = self.sampling_timesteps + int(t * uncertainty_scale) - m
|
| 245 |
+
scheduling_matrix[m, -t] = self.sampling_timesteps + int(t * uncertainty_scale) - m
|
| 246 |
+
|
| 247 |
+
return np.clip(scheduling_matrix, 0, self.sampling_timesteps)
|
| 248 |
+
|
| 249 |
+
def reweight_loss(self, loss, weight=None):
|
| 250 |
+
# Note there is another part of loss reweighting (fused_snr) inside the Diffusion class!
|
| 251 |
+
loss = rearrange(loss, "t b (fs c) ... -> t b fs c ...", fs=self.frame_stack)
|
| 252 |
+
if weight is not None:
|
| 253 |
+
expand_dim = len(loss.shape) - len(weight.shape) - 1
|
| 254 |
+
weight = rearrange(
|
| 255 |
+
weight,
|
| 256 |
+
"(t fs) b ... -> t b fs ..." + " 1" * expand_dim,
|
| 257 |
+
fs=self.frame_stack,
|
| 258 |
+
)
|
| 259 |
+
loss = loss * weight
|
| 260 |
+
|
| 261 |
+
return loss.mean()
|
| 262 |
+
|
| 263 |
+
def _preprocess_batch(self, batch):
|
| 264 |
+
xs = batch[0]
|
| 265 |
+
batch_size, n_frames = xs.shape[:2]
|
| 266 |
+
|
| 267 |
+
if n_frames % self.frame_stack != 0:
|
| 268 |
+
raise ValueError("Number of frames must be divisible by frame stack size")
|
| 269 |
+
if self.context_frames % self.frame_stack != 0:
|
| 270 |
+
raise ValueError("Number of context frames must be divisible by frame stack size")
|
| 271 |
+
|
| 272 |
+
masks = torch.ones(n_frames, batch_size).to(xs.device)
|
| 273 |
+
n_frames = n_frames // self.frame_stack
|
| 274 |
+
|
| 275 |
+
if self.action_cond_dim:
|
| 276 |
+
conditions = batch[1]
|
| 277 |
+
conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1)
|
| 278 |
+
conditions = rearrange(conditions, "b (t fs) d -> t b (fs d)", fs=self.frame_stack).contiguous()
|
| 279 |
+
|
| 280 |
+
# f, _, _ = conditions.shape
|
| 281 |
+
# predefined_1 = torch.tensor([0,0,0,1]).to(conditions.device)
|
| 282 |
+
# predefined_2 = torch.tensor([0,0,1,0]).to(conditions.device)
|
| 283 |
+
# conditions[:f//2] = predefined_1
|
| 284 |
+
# conditions[f//2:] = predefined_2
|
| 285 |
+
else:
|
| 286 |
+
conditions = [None for _ in range(n_frames)]
|
| 287 |
+
|
| 288 |
+
xs = self._normalize_x(xs)
|
| 289 |
+
xs = rearrange(xs, "b (t fs) c ... -> t b (fs c) ...", fs=self.frame_stack).contiguous()
|
| 290 |
+
|
| 291 |
+
return xs, conditions, masks
|
| 292 |
+
|
| 293 |
+
def _normalize_x(self, xs):
|
| 294 |
+
shape = [1] * (xs.ndim - self.data_mean.ndim) + list(self.data_mean.shape)
|
| 295 |
+
mean = self.data_mean.reshape(shape)
|
| 296 |
+
std = self.data_std.reshape(shape)
|
| 297 |
+
return (xs - mean) / std
|
| 298 |
+
|
| 299 |
+
def _unnormalize_x(self, xs):
|
| 300 |
+
shape = [1] * (xs.ndim - self.data_mean.ndim) + list(self.data_mean.shape)
|
| 301 |
+
mean = self.data_mean.reshape(shape)
|
| 302 |
+
std = self.data_std.reshape(shape)
|
| 303 |
+
return xs * std + mean
|
| 304 |
+
|
| 305 |
+
def _unstack_and_unnormalize(self, xs):
|
| 306 |
+
xs = rearrange(xs, "t b (fs c) ... -> (t fs) b c ...", fs=self.frame_stack)
|
| 307 |
+
return self._unnormalize_x(xs)
|
algorithms/worldmem/df_video.py
ADDED
|
@@ -0,0 +1,926 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torchvision.transforms.functional as TF
|
| 8 |
+
from torchvision.transforms import InterpolationMode
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from packaging import version as pver
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from omegaconf import DictConfig
|
| 14 |
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
| 15 |
+
from algorithms.common.metrics import (
|
| 16 |
+
LearnedPerceptualImagePatchSimilarity,
|
| 17 |
+
)
|
| 18 |
+
from utils.logging_utils import log_video, get_validation_metrics_for_videos
|
| 19 |
+
from .df_base import DiffusionForcingBase
|
| 20 |
+
from .models.vae import VAE_models
|
| 21 |
+
from .models.diffusion import Diffusion
|
| 22 |
+
from .models.pose_prediction import PosePredictionNet
|
| 23 |
+
import glob
|
| 24 |
+
|
| 25 |
+
# Utility Functions
|
| 26 |
+
def euler_to_rotation_matrix(pitch, yaw):
|
| 27 |
+
"""
|
| 28 |
+
Convert pitch and yaw angles (in radians) to a 3x3 rotation matrix.
|
| 29 |
+
Supports batch input.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
pitch (torch.Tensor): Pitch angles in radians.
|
| 33 |
+
yaw (torch.Tensor): Yaw angles in radians.
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
torch.Tensor: Rotation matrix of shape (batch_size, 3, 3).
|
| 37 |
+
"""
|
| 38 |
+
cos_pitch, sin_pitch = torch.cos(pitch), torch.sin(pitch)
|
| 39 |
+
cos_yaw, sin_yaw = torch.cos(yaw), torch.sin(yaw)
|
| 40 |
+
|
| 41 |
+
R_pitch = torch.stack([
|
| 42 |
+
torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch),
|
| 43 |
+
torch.zeros_like(pitch), cos_pitch, -sin_pitch,
|
| 44 |
+
torch.zeros_like(pitch), sin_pitch, cos_pitch
|
| 45 |
+
], dim=-1).reshape(-1, 3, 3)
|
| 46 |
+
|
| 47 |
+
R_yaw = torch.stack([
|
| 48 |
+
cos_yaw, torch.zeros_like(yaw), sin_yaw,
|
| 49 |
+
torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),
|
| 50 |
+
-sin_yaw, torch.zeros_like(yaw), cos_yaw
|
| 51 |
+
], dim=-1).reshape(-1, 3, 3)
|
| 52 |
+
|
| 53 |
+
return torch.matmul(R_yaw, R_pitch)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def euler_to_camera_to_world_matrix(pose):
|
| 57 |
+
"""
|
| 58 |
+
Convert (x, y, z, pitch, yaw) to a 4x4 camera-to-world transformation matrix using torch.
|
| 59 |
+
Supports both (5,) and (f, b, 5) shaped inputs.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
pose (torch.Tensor): Pose tensor of shape (5,) or (f, b, 5).
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
torch.Tensor: Camera-to-world transformation matrix of shape (4, 4).
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
origin_dim = pose.ndim
|
| 69 |
+
if origin_dim == 1:
|
| 70 |
+
pose = pose.unsqueeze(0).unsqueeze(0) # Convert (5,) -> (1, 1, 5)
|
| 71 |
+
elif origin_dim == 2:
|
| 72 |
+
pose = pose.unsqueeze(0)
|
| 73 |
+
|
| 74 |
+
x, y, z, pitch, yaw = pose[..., 0], pose[..., 1], pose[..., 2], pose[..., 3], pose[..., 4]
|
| 75 |
+
pitch, yaw = torch.deg2rad(pitch), torch.deg2rad(yaw)
|
| 76 |
+
|
| 77 |
+
# Compute rotation matrix (batch mode)
|
| 78 |
+
R = euler_to_rotation_matrix(pitch, yaw) # Shape (f*b, 3, 3)
|
| 79 |
+
|
| 80 |
+
# Create the 4x4 transformation matrix
|
| 81 |
+
eye = torch.eye(4, dtype=torch.float32, device=pose.device)
|
| 82 |
+
camera_to_world = eye.repeat(R.shape[0], 1, 1) # Shape (f*b, 4, 4)
|
| 83 |
+
|
| 84 |
+
# Assign rotation
|
| 85 |
+
camera_to_world[:, :3, :3] = R
|
| 86 |
+
|
| 87 |
+
# Assign translation
|
| 88 |
+
camera_to_world[:, :3, 3] = torch.stack([x.reshape(-1), y.reshape(-1), z.reshape(-1)], dim=-1)
|
| 89 |
+
|
| 90 |
+
# Reshape back to (f, b, 4, 4) if needed
|
| 91 |
+
if origin_dim == 3:
|
| 92 |
+
return camera_to_world.view(pose.shape[0], pose.shape[1], 4, 4)
|
| 93 |
+
elif origin_dim == 2:
|
| 94 |
+
return camera_to_world.view(pose.shape[0], 4, 4)
|
| 95 |
+
else:
|
| 96 |
+
return camera_to_world.squeeze(0).squeeze(0) # Convert (1,1,4,4) -> (4,4)
|
| 97 |
+
|
| 98 |
+
def is_inside_fov_3d_hv(points, center, center_pitch, center_yaw, fov_half_h, fov_half_v):
|
| 99 |
+
"""
|
| 100 |
+
Check whether points are within a given 3D field of view (FOV)
|
| 101 |
+
with separately defined horizontal and vertical ranges.
|
| 102 |
+
|
| 103 |
+
The center view direction is specified by pitch and yaw (in degrees).
|
| 104 |
+
|
| 105 |
+
:param points: (N, B, 3) Sample point coordinates
|
| 106 |
+
:param center: (3,) Center coordinates of the FOV
|
| 107 |
+
:param center_pitch: Pitch angle of the center view (in degrees)
|
| 108 |
+
:param center_yaw: Yaw angle of the center view (in degrees)
|
| 109 |
+
:param fov_half_h: Horizontal half-FOV angle (in degrees)
|
| 110 |
+
:param fov_half_v: Vertical half-FOV angle (in degrees)
|
| 111 |
+
:return: Boolean tensor (N, B), indicating whether each point is inside the FOV
|
| 112 |
+
"""
|
| 113 |
+
# Compute vectors relative to the center
|
| 114 |
+
vectors = points - center # shape (N, B, 3)
|
| 115 |
+
x = vectors[..., 0]
|
| 116 |
+
y = vectors[..., 1]
|
| 117 |
+
z = vectors[..., 2]
|
| 118 |
+
|
| 119 |
+
# Compute horizontal angle (yaw): measured with respect to the z-axis as the forward direction,
|
| 120 |
+
# and the x-axis as left-right, resulting in a range of -180 to 180 degrees.
|
| 121 |
+
azimuth = torch.atan2(x, z) * (180 / math.pi)
|
| 122 |
+
|
| 123 |
+
# Compute vertical angle (pitch): measured with respect to the horizontal plane,
|
| 124 |
+
# resulting in a range of -90 to 90 degrees.
|
| 125 |
+
elevation = torch.atan2(y, torch.sqrt(x**2 + z**2)) * (180 / math.pi)
|
| 126 |
+
|
| 127 |
+
# Compute the angular difference from the center view (handling circular angle wrap-around)
|
| 128 |
+
diff_azimuth = (azimuth - center_yaw).abs() % 360
|
| 129 |
+
diff_elevation = (elevation - center_pitch).abs() % 360
|
| 130 |
+
|
| 131 |
+
# Adjust values greater than 180 degrees to the shorter angular difference
|
| 132 |
+
diff_azimuth = torch.where(diff_azimuth > 180, 360 - diff_azimuth, diff_azimuth)
|
| 133 |
+
diff_elevation = torch.where(diff_elevation > 180, 360 - diff_elevation, diff_elevation)
|
| 134 |
+
|
| 135 |
+
# Check if both horizontal and vertical angles are within their respective FOV limits
|
| 136 |
+
return (diff_azimuth < fov_half_h) & (diff_elevation < fov_half_v)
|
| 137 |
+
|
| 138 |
+
def generate_points_in_sphere(n_points, radius):
|
| 139 |
+
# Sample three independent uniform distributions
|
| 140 |
+
samples_r = torch.rand(n_points) # For radius distribution
|
| 141 |
+
samples_phi = torch.rand(n_points) # For azimuthal angle phi
|
| 142 |
+
samples_u = torch.rand(n_points) # For polar angle theta
|
| 143 |
+
|
| 144 |
+
# Apply cube root to ensure uniform volumetric distribution
|
| 145 |
+
r = radius * torch.pow(samples_r, 1/3)
|
| 146 |
+
# Azimuthal angle phi uniformly distributed in [0, 2π]
|
| 147 |
+
phi = 2 * math.pi * samples_phi
|
| 148 |
+
# Convert u to theta to ensure cos(theta) is uniformly distributed
|
| 149 |
+
theta = torch.acos(1 - 2 * samples_u)
|
| 150 |
+
|
| 151 |
+
# Convert spherical coordinates to Cartesian coordinates
|
| 152 |
+
x = r * torch.sin(theta) * torch.cos(phi)
|
| 153 |
+
y = r * torch.sin(theta) * torch.sin(phi)
|
| 154 |
+
z = r * torch.cos(theta)
|
| 155 |
+
|
| 156 |
+
points = torch.stack((x, y, z), dim=1)
|
| 157 |
+
return points
|
| 158 |
+
|
| 159 |
+
def tensor_max_with_number(tensor, number):
|
| 160 |
+
number_tensor = torch.tensor(number, dtype=tensor.dtype, device=tensor.device)
|
| 161 |
+
result = torch.max(tensor, number_tensor)
|
| 162 |
+
return result
|
| 163 |
+
|
| 164 |
+
def custom_meshgrid(*args):
|
| 165 |
+
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
|
| 166 |
+
if pver.parse(torch.__version__) < pver.parse('1.10'):
|
| 167 |
+
return torch.meshgrid(*args)
|
| 168 |
+
else:
|
| 169 |
+
return torch.meshgrid(*args, indexing='ij')
|
| 170 |
+
|
| 171 |
+
def camera_to_world_to_world_to_camera(camera_to_world: torch.Tensor) -> torch.Tensor:
|
| 172 |
+
"""
|
| 173 |
+
Convert Camera-to-World matrices to World-to-Camera matrices for a tensor with shape (f, b, 4, 4).
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
camera_to_world (torch.Tensor): A tensor of shape (f, b, 4, 4), where:
|
| 177 |
+
f = number of frames,
|
| 178 |
+
b = batch size.
|
| 179 |
+
|
| 180 |
+
Returns:
|
| 181 |
+
torch.Tensor: A tensor of shape (f, b, 4, 4) representing the World-to-Camera matrices.
|
| 182 |
+
"""
|
| 183 |
+
# Ensure input is a 4D tensor
|
| 184 |
+
assert camera_to_world.ndim == 4 and camera_to_world.shape[2:] == (4, 4), \
|
| 185 |
+
"Input must be of shape (f, b, 4, 4)"
|
| 186 |
+
|
| 187 |
+
# Extract the rotation (R) and translation (T) parts
|
| 188 |
+
R = camera_to_world[:, :, :3, :3] # Shape: (f, b, 3, 3)
|
| 189 |
+
T = camera_to_world[:, :, :3, 3] # Shape: (f, b, 3)
|
| 190 |
+
|
| 191 |
+
# Initialize an identity matrix for the output
|
| 192 |
+
world_to_camera = torch.eye(4, device=camera_to_world.device).unsqueeze(0).unsqueeze(0)
|
| 193 |
+
world_to_camera = world_to_camera.repeat(camera_to_world.size(0), camera_to_world.size(1), 1, 1) # Shape: (f, b, 4, 4)
|
| 194 |
+
|
| 195 |
+
# Compute the rotation (transpose of R)
|
| 196 |
+
world_to_camera[:, :, :3, :3] = R.transpose(2, 3)
|
| 197 |
+
|
| 198 |
+
# Compute the translation (-R^T * T)
|
| 199 |
+
world_to_camera[:, :, :3, 3] = -torch.matmul(R.transpose(2, 3), T.unsqueeze(-1)).squeeze(-1)
|
| 200 |
+
|
| 201 |
+
return world_to_camera.to(camera_to_world.dtype)
|
| 202 |
+
|
| 203 |
+
def convert_to_plucker(poses, curr_frame, focal_length, image_width, image_height):
|
| 204 |
+
|
| 205 |
+
intrinsic = np.asarray([focal_length * image_width,
|
| 206 |
+
focal_length * image_height,
|
| 207 |
+
0.5 * image_width,
|
| 208 |
+
0.5 * image_height], dtype=np.float32)
|
| 209 |
+
|
| 210 |
+
c2ws = get_relative_pose(poses, zero_first_frame_scale=curr_frame)
|
| 211 |
+
c2ws = rearrange(c2ws, "t b m n -> b t m n")
|
| 212 |
+
|
| 213 |
+
K = torch.as_tensor(intrinsic, device=poses.device, dtype=poses.dtype).repeat(c2ws.shape[0],c2ws.shape[1],1) # [B, F, 4]
|
| 214 |
+
plucker_embedding = ray_condition(K, c2ws, image_height, image_width, device=c2ws.device)
|
| 215 |
+
plucker_embedding = rearrange(plucker_embedding, "b t h w d -> t b h w d").contiguous()
|
| 216 |
+
|
| 217 |
+
return plucker_embedding
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def get_relative_pose(abs_c2ws, zero_first_frame_scale):
|
| 221 |
+
abs_w2cs = camera_to_world_to_world_to_camera(abs_c2ws)
|
| 222 |
+
target_cam_c2w = torch.tensor([
|
| 223 |
+
[1, 0, 0, 0],
|
| 224 |
+
[0, 1, 0, 0],
|
| 225 |
+
[0, 0, 1, 0],
|
| 226 |
+
[0, 0, 0, 1]
|
| 227 |
+
]).to(abs_c2ws.device).to(abs_c2ws.dtype)
|
| 228 |
+
abs2rel = target_cam_c2w @ abs_w2cs[zero_first_frame_scale]
|
| 229 |
+
ret_poses = [abs2rel @ abs_c2w for abs_c2w in abs_c2ws]
|
| 230 |
+
ret_poses = torch.stack(ret_poses)
|
| 231 |
+
return ret_poses
|
| 232 |
+
|
| 233 |
+
def ray_condition(K, c2w, H, W, device):
|
| 234 |
+
# c2w: B, V, 4, 4
|
| 235 |
+
# K: B, V, 4
|
| 236 |
+
|
| 237 |
+
B = K.shape[0]
|
| 238 |
+
|
| 239 |
+
j, i = custom_meshgrid(
|
| 240 |
+
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
|
| 241 |
+
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
|
| 242 |
+
)
|
| 243 |
+
i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
| 244 |
+
j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
| 245 |
+
|
| 246 |
+
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
|
| 247 |
+
|
| 248 |
+
zs = torch.ones_like(i, device=device, dtype=c2w.dtype) # [B, HxW]
|
| 249 |
+
xs = -(i - cx) / fx * zs
|
| 250 |
+
ys = -(j - cy) / fy * zs
|
| 251 |
+
|
| 252 |
+
zs = zs.expand_as(ys)
|
| 253 |
+
|
| 254 |
+
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
|
| 255 |
+
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
|
| 256 |
+
|
| 257 |
+
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
|
| 258 |
+
rays_o = c2w[..., :3, 3] # B, V, 3
|
| 259 |
+
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
|
| 260 |
+
# c2w @ dirctions
|
| 261 |
+
rays_dxo = torch.linalg.cross(rays_o, rays_d)
|
| 262 |
+
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
|
| 263 |
+
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
|
| 264 |
+
|
| 265 |
+
return plucker
|
| 266 |
+
|
| 267 |
+
def random_transform(tensor):
|
| 268 |
+
"""
|
| 269 |
+
Apply the same random translation, rotation, and scaling to all frames in the batch.
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
tensor (torch.Tensor): Input tensor of shape (F, B, 3, H, W).
|
| 273 |
+
|
| 274 |
+
Returns:
|
| 275 |
+
torch.Tensor: Transformed tensor of shape (F, B, 3, H, W).
|
| 276 |
+
"""
|
| 277 |
+
if tensor.ndim != 5:
|
| 278 |
+
raise ValueError("Input tensor must have shape (F, B, 3, H, W)")
|
| 279 |
+
|
| 280 |
+
F, B, C, H, W = tensor.shape
|
| 281 |
+
|
| 282 |
+
# Generate random transformation parameters
|
| 283 |
+
max_translate = 0.2 # Translate up to 20% of width/height
|
| 284 |
+
max_rotate = 30 # Rotate up to 30 degrees
|
| 285 |
+
max_scale = 0.2 # Scale change by up to +/- 20%
|
| 286 |
+
|
| 287 |
+
translate_x = random.uniform(-max_translate, max_translate) * W
|
| 288 |
+
translate_y = random.uniform(-max_translate, max_translate) * H
|
| 289 |
+
rotate_angle = random.uniform(-max_rotate, max_rotate)
|
| 290 |
+
scale_factor = 1 + random.uniform(-max_scale, max_scale)
|
| 291 |
+
|
| 292 |
+
# Apply the same transformation to all frames and batches
|
| 293 |
+
|
| 294 |
+
tensor = tensor.reshape(F*B, C, H, W)
|
| 295 |
+
transformed_tensor = TF.affine(
|
| 296 |
+
tensor,
|
| 297 |
+
angle=rotate_angle,
|
| 298 |
+
translate=(translate_x, translate_y),
|
| 299 |
+
scale=scale_factor,
|
| 300 |
+
shear=(0, 0),
|
| 301 |
+
interpolation=InterpolationMode.BILINEAR,
|
| 302 |
+
fill=0
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
transformed_tensor = transformed_tensor.reshape(F, B, C, H, W)
|
| 306 |
+
return transformed_tensor
|
| 307 |
+
|
| 308 |
+
def save_tensor_as_png(tensor, file_path):
|
| 309 |
+
"""
|
| 310 |
+
Save a 3*H*W tensor as a PNG image.
|
| 311 |
+
|
| 312 |
+
Args:
|
| 313 |
+
tensor (torch.Tensor): Input tensor of shape (3, H, W).
|
| 314 |
+
file_path (str): Path to save the PNG file.
|
| 315 |
+
"""
|
| 316 |
+
if tensor.ndim != 3 or tensor.shape[0] != 3:
|
| 317 |
+
raise ValueError("Input tensor must have shape (3, H, W)")
|
| 318 |
+
|
| 319 |
+
# Convert tensor to PIL Image
|
| 320 |
+
image = TF.to_pil_image(tensor)
|
| 321 |
+
|
| 322 |
+
# Save image
|
| 323 |
+
image.save(file_path)
|
| 324 |
+
|
| 325 |
+
class WorldMemMinecraft(DiffusionForcingBase):
|
| 326 |
+
"""
|
| 327 |
+
Video generation for MineCraft with memory.
|
| 328 |
+
"""
|
| 329 |
+
|
| 330 |
+
def __init__(self, cfg: DictConfig):
|
| 331 |
+
"""
|
| 332 |
+
Initialize the WorldMemMinecraft class with the given configuration.
|
| 333 |
+
|
| 334 |
+
Args:
|
| 335 |
+
cfg (DictConfig): Configuration object.
|
| 336 |
+
"""
|
| 337 |
+
self.n_tokens = cfg.n_frames // cfg.frame_stack # number of max tokens for the model
|
| 338 |
+
self.n_frames = cfg.n_frames
|
| 339 |
+
if hasattr(cfg, "n_tokens"):
|
| 340 |
+
self.n_tokens = cfg.n_tokens // cfg.frame_stack
|
| 341 |
+
self.memory_condition_length = cfg.memory_condition_length
|
| 342 |
+
self.pose_cond_dim = getattr(cfg, "pose_cond_dim", 5)
|
| 343 |
+
|
| 344 |
+
self.use_plucker = getattr(cfg, "use_plucker", True)
|
| 345 |
+
self.relative_embedding = getattr(cfg, "relative_embedding", True)
|
| 346 |
+
self.state_embed_only_on_qk = getattr(cfg, "state_embed_only_on_qk", True)
|
| 347 |
+
self.use_memory_attention = getattr(cfg, "use_memory_attention", True)
|
| 348 |
+
self.add_timestamp_embedding = getattr(cfg, "add_timestamp_embedding", True)
|
| 349 |
+
self.ref_mode = getattr(cfg, "ref_mode", 'sequential')
|
| 350 |
+
self.log_curve = getattr(cfg, "log_curve", False)
|
| 351 |
+
self.focal_length = getattr(cfg, "focal_length", 0.35)
|
| 352 |
+
self.log_video = cfg.log_video
|
| 353 |
+
self.save_local = getattr(cfg, "save_local", True)
|
| 354 |
+
self.local_save_dir = getattr(cfg, "local_save_dir", None)
|
| 355 |
+
self.lpips_batch_size = getattr(cfg, "lpips_batch_size", 16)
|
| 356 |
+
self.next_frame_length = getattr(cfg, "next_frame_length", 1)
|
| 357 |
+
self.require_pose_prediction = getattr(cfg, "require_pose_prediction", False)
|
| 358 |
+
|
| 359 |
+
super().__init__(cfg)
|
| 360 |
+
|
| 361 |
+
def _build_model(self):
|
| 362 |
+
|
| 363 |
+
self.diffusion_model = Diffusion(
|
| 364 |
+
reference_length=self.memory_condition_length,
|
| 365 |
+
x_shape=self.x_stacked_shape,
|
| 366 |
+
action_cond_dim=self.action_cond_dim,
|
| 367 |
+
pose_cond_dim=self.pose_cond_dim,
|
| 368 |
+
is_causal=self.causal,
|
| 369 |
+
cfg=self.cfg.diffusion,
|
| 370 |
+
is_dit=True,
|
| 371 |
+
use_plucker=self.use_plucker,
|
| 372 |
+
relative_embedding=self.relative_embedding,
|
| 373 |
+
state_embed_only_on_qk=self.state_embed_only_on_qk,
|
| 374 |
+
use_memory_attention=self.use_memory_attention,
|
| 375 |
+
add_timestamp_embedding=self.add_timestamp_embedding,
|
| 376 |
+
ref_mode=self.ref_mode
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
self.validation_lpips_model = LearnedPerceptualImagePatchSimilarity()
|
| 380 |
+
vae = VAE_models["vit-l-20-shallow-encoder"]()
|
| 381 |
+
self.vae = vae.eval()
|
| 382 |
+
|
| 383 |
+
if self.require_pose_prediction:
|
| 384 |
+
self.pose_prediction_model = PosePredictionNet()
|
| 385 |
+
|
| 386 |
+
def _generate_noise_levels(self, xs: torch.Tensor, masks = None) -> torch.Tensor:
|
| 387 |
+
"""
|
| 388 |
+
Generate noise levels for training.
|
| 389 |
+
"""
|
| 390 |
+
num_frames, batch_size, *_ = xs.shape
|
| 391 |
+
match self.cfg.noise_level:
|
| 392 |
+
case "random_all": # entirely random noise levels
|
| 393 |
+
noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
|
| 394 |
+
case "same":
|
| 395 |
+
noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
|
| 396 |
+
noise_levels[1:] = noise_levels[0]
|
| 397 |
+
|
| 398 |
+
if masks is not None:
|
| 399 |
+
# for frames that are not available, treat as full noise
|
| 400 |
+
discard = torch.all(~rearrange(masks.bool(), "(t fs) b -> t b fs", fs=self.frame_stack), -1)
|
| 401 |
+
noise_levels = torch.where(discard, torch.full_like(noise_levels, self.timesteps - 1), noise_levels)
|
| 402 |
+
|
| 403 |
+
return noise_levels
|
| 404 |
+
|
| 405 |
+
def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
|
| 406 |
+
"""
|
| 407 |
+
Perform a single training step.
|
| 408 |
+
|
| 409 |
+
This function processes the input batch,
|
| 410 |
+
encodes the input frames, generates noise levels, and computes the loss using the diffusion model.
|
| 411 |
+
|
| 412 |
+
Args:
|
| 413 |
+
batch: Input batch of data containing frames, conditions, poses, etc.
|
| 414 |
+
batch_idx: Index of the current batch.
|
| 415 |
+
|
| 416 |
+
Returns:
|
| 417 |
+
dict: A dictionary containing the training loss.
|
| 418 |
+
"""
|
| 419 |
+
xs, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
|
| 420 |
+
|
| 421 |
+
if self.use_plucker:
|
| 422 |
+
if self.relative_embedding:
|
| 423 |
+
input_pose_condition = []
|
| 424 |
+
frame_idx_list = []
|
| 425 |
+
for i in range(self.n_frames):
|
| 426 |
+
input_pose_condition.append(
|
| 427 |
+
convert_to_plucker(
|
| 428 |
+
torch.cat([c2w_mat[i:i + 1], c2w_mat[-self.memory_condition_length:]]).clone(),
|
| 429 |
+
0,
|
| 430 |
+
focal_length=self.focal_length,
|
| 431 |
+
image_height=xs.shape[-2],image_width=xs.shape[-1]
|
| 432 |
+
).to(xs.dtype)
|
| 433 |
+
)
|
| 434 |
+
frame_idx_list.append(
|
| 435 |
+
torch.cat([
|
| 436 |
+
frame_idx[i:i + 1] - frame_idx[i:i + 1],
|
| 437 |
+
frame_idx[-self.memory_condition_length:] - frame_idx[i:i + 1]
|
| 438 |
+
]).clone()
|
| 439 |
+
)
|
| 440 |
+
input_pose_condition = torch.cat(input_pose_condition)
|
| 441 |
+
frame_idx_list = torch.cat(frame_idx_list)
|
| 442 |
+
else:
|
| 443 |
+
input_pose_condition = convert_to_plucker(
|
| 444 |
+
c2w_mat, 0, focal_length=self.focal_length
|
| 445 |
+
).to(xs.dtype)
|
| 446 |
+
frame_idx_list = frame_idx
|
| 447 |
+
else:
|
| 448 |
+
input_pose_condition = pose_conditions.to(xs.dtype)
|
| 449 |
+
frame_idx_list = None
|
| 450 |
+
|
| 451 |
+
xs = self.encode(xs)
|
| 452 |
+
|
| 453 |
+
noise_levels = self._generate_noise_levels(xs)
|
| 454 |
+
|
| 455 |
+
if self.memory_condition_length:
|
| 456 |
+
noise_levels[-self.memory_condition_length:] = self.diffusion_model.stabilization_level
|
| 457 |
+
conditions[-self.memory_condition_length:] *= 0
|
| 458 |
+
|
| 459 |
+
_, loss = self.diffusion_model(
|
| 460 |
+
xs,
|
| 461 |
+
conditions,
|
| 462 |
+
input_pose_condition,
|
| 463 |
+
noise_levels=noise_levels,
|
| 464 |
+
reference_length=self.memory_condition_length,
|
| 465 |
+
frame_idx=frame_idx_list
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
if self.memory_condition_length:
|
| 469 |
+
loss = loss[:-self.memory_condition_length]
|
| 470 |
+
|
| 471 |
+
loss = self.reweight_loss(loss, None)
|
| 472 |
+
|
| 473 |
+
if batch_idx % 20 == 0:
|
| 474 |
+
self.log("training/loss", loss.cpu())
|
| 475 |
+
|
| 476 |
+
return {"loss": loss}
|
| 477 |
+
|
| 478 |
+
def on_validation_epoch_end(self, namespace="validation") -> None:
|
| 479 |
+
if not self.validation_step_outputs:
|
| 480 |
+
return
|
| 481 |
+
|
| 482 |
+
xs_pred = []
|
| 483 |
+
xs = []
|
| 484 |
+
for pred, gt in self.validation_step_outputs:
|
| 485 |
+
xs_pred.append(pred)
|
| 486 |
+
xs.append(gt)
|
| 487 |
+
|
| 488 |
+
xs_pred = torch.cat(xs_pred, 1)
|
| 489 |
+
if gt is not None:
|
| 490 |
+
xs = torch.cat(xs, 1)
|
| 491 |
+
else:
|
| 492 |
+
xs = None
|
| 493 |
+
|
| 494 |
+
if self.logger and self.log_video:
|
| 495 |
+
log_video(
|
| 496 |
+
xs_pred,
|
| 497 |
+
xs,
|
| 498 |
+
step=None if namespace == "test" else self.global_step,
|
| 499 |
+
namespace=namespace + "_vis",
|
| 500 |
+
context_frames=self.context_frames,
|
| 501 |
+
logger=self.logger.experiment,
|
| 502 |
+
save_local=self.save_local,
|
| 503 |
+
local_save_dir=self.local_save_dir,
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
if xs is not None:
|
| 507 |
+
# Move data to the same device as LPIPS model for metric calculation
|
| 508 |
+
device = next(self.validation_lpips_model.parameters()).device
|
| 509 |
+
xs_pred_device = xs_pred.to(device)
|
| 510 |
+
xs_device = xs.to(device)
|
| 511 |
+
|
| 512 |
+
metric_dict = get_validation_metrics_for_videos(
|
| 513 |
+
xs_pred_device, xs_device,
|
| 514 |
+
lpips_model=self.validation_lpips_model,
|
| 515 |
+
lpips_batch_size=self.lpips_batch_size)
|
| 516 |
+
|
| 517 |
+
self.log_dict(
|
| 518 |
+
{"mse": metric_dict['mse'],
|
| 519 |
+
"psnr": metric_dict['psnr'],
|
| 520 |
+
"lpips": metric_dict['lpips']},
|
| 521 |
+
sync_dist=True
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
if self.log_curve:
|
| 525 |
+
psnr_values = metric_dict['frame_wise_psnr'].cpu().tolist()
|
| 526 |
+
frames = list(range(len(psnr_values)))
|
| 527 |
+
line_plot = wandb.plot.line_series(
|
| 528 |
+
xs = frames,
|
| 529 |
+
ys = [psnr_values],
|
| 530 |
+
keys = ["PSNR"],
|
| 531 |
+
title = "Frame-wise PSNR",
|
| 532 |
+
xname = "Frame index"
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
self.logger.experiment.log({"frame_wise_psnr_plot": line_plot})
|
| 536 |
+
|
| 537 |
+
self.validation_step_outputs.clear()
|
| 538 |
+
|
| 539 |
+
def _preprocess_batch(self, batch):
|
| 540 |
+
|
| 541 |
+
xs, conditions, pose_conditions, frame_index = batch
|
| 542 |
+
|
| 543 |
+
if self.action_cond_dim:
|
| 544 |
+
conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1)
|
| 545 |
+
conditions = rearrange(conditions, "b t d -> t b d").contiguous()
|
| 546 |
+
else:
|
| 547 |
+
raise NotImplementedError("Only support external cond.")
|
| 548 |
+
|
| 549 |
+
pose_conditions = rearrange(pose_conditions, "b t d -> t b d").contiguous()
|
| 550 |
+
c2w_mat = euler_to_camera_to_world_matrix(pose_conditions)
|
| 551 |
+
xs = rearrange(xs, "b t c ... -> t b c ...").contiguous()
|
| 552 |
+
frame_index = rearrange(frame_index, "b t -> t b").contiguous()
|
| 553 |
+
|
| 554 |
+
return xs, conditions, pose_conditions, c2w_mat, frame_index
|
| 555 |
+
|
| 556 |
+
def encode(self, x):
|
| 557 |
+
# vae encoding
|
| 558 |
+
T = x.shape[0]
|
| 559 |
+
H, W = x.shape[-2:]
|
| 560 |
+
scaling_factor = 0.07843137255
|
| 561 |
+
|
| 562 |
+
x = rearrange(x, "t b c h w -> (t b) c h w")
|
| 563 |
+
with torch.no_grad():
|
| 564 |
+
x = self.vae.encode(x * 2 - 1).mean * scaling_factor
|
| 565 |
+
x = rearrange(x, "(t b) (h w) c -> t b c h w", t=T, h=H // self.vae.patch_size, w=W // self.vae.patch_size)
|
| 566 |
+
return x
|
| 567 |
+
|
| 568 |
+
def decode(self, x):
|
| 569 |
+
total_frames = x.shape[0]
|
| 570 |
+
scaling_factor = 0.07843137255
|
| 571 |
+
x = rearrange(x, "t b c h w -> (t b) (h w) c")
|
| 572 |
+
with torch.no_grad():
|
| 573 |
+
x = (self.vae.decode(x / scaling_factor) + 1) / 2
|
| 574 |
+
x = rearrange(x, "(t b) c h w-> t b c h w", t=total_frames)
|
| 575 |
+
return x
|
| 576 |
+
|
| 577 |
+
def _generate_condition_indices(self, curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, horizon):
|
| 578 |
+
"""
|
| 579 |
+
Generate indices for condition similarity based on the current frame and pose conditions.
|
| 580 |
+
"""
|
| 581 |
+
if curr_frame < memory_condition_length:
|
| 582 |
+
random_idx = [i for i in range(curr_frame)] + [0] * (memory_condition_length - curr_frame)
|
| 583 |
+
random_idx = np.repeat(np.array(random_idx)[:, None], xs_pred.shape[1], -1)
|
| 584 |
+
else:
|
| 585 |
+
# Generate points in a sphere and filter based on field of view
|
| 586 |
+
num_samples = 10000
|
| 587 |
+
radius = 30
|
| 588 |
+
points = generate_points_in_sphere(num_samples, radius).to(pose_conditions.device)
|
| 589 |
+
points = points[:, None].repeat(1, pose_conditions.shape[1], 1)
|
| 590 |
+
points += pose_conditions[curr_frame, :, :3][None]
|
| 591 |
+
fov_half_h = torch.tensor(105 / 2, device=pose_conditions.device)
|
| 592 |
+
fov_half_v = torch.tensor(75 / 2, device=pose_conditions.device)
|
| 593 |
+
|
| 594 |
+
# in_fov1 = is_inside_fov_3d_hv(
|
| 595 |
+
# points, pose_conditions[curr_frame, :, :3],
|
| 596 |
+
# pose_conditions[curr_frame, :, -2], pose_conditions[curr_frame, :, -1],
|
| 597 |
+
# fov_half_h, fov_half_v
|
| 598 |
+
# )
|
| 599 |
+
|
| 600 |
+
in_fov1 = torch.stack([
|
| 601 |
+
is_inside_fov_3d_hv(points, pc[:, :3], pc[:, -2], pc[:, -1], fov_half_h, fov_half_v)
|
| 602 |
+
for pc in pose_conditions[curr_frame:curr_frame+horizon]
|
| 603 |
+
])
|
| 604 |
+
|
| 605 |
+
in_fov1 = torch.sum(in_fov1, 0) > 0
|
| 606 |
+
|
| 607 |
+
# Compute overlap ratios and select indices
|
| 608 |
+
in_fov_list = torch.stack([
|
| 609 |
+
is_inside_fov_3d_hv(points, pc[:, :3], pc[:, -2], pc[:, -1], fov_half_h, fov_half_v)
|
| 610 |
+
for pc in pose_conditions[:curr_frame]
|
| 611 |
+
])
|
| 612 |
+
|
| 613 |
+
random_idx = []
|
| 614 |
+
for _ in range(memory_condition_length):
|
| 615 |
+
overlap_ratio = ((in_fov1.bool() & in_fov_list).sum(1)) / in_fov1.sum()
|
| 616 |
+
|
| 617 |
+
confidence = overlap_ratio + (curr_frame - frame_idx[:curr_frame]) / curr_frame * (-0.2)
|
| 618 |
+
|
| 619 |
+
if len(random_idx) > 0:
|
| 620 |
+
confidence[torch.cat(random_idx)] = -1e10
|
| 621 |
+
_, r_idx = torch.topk(confidence, k=1, dim=0)
|
| 622 |
+
random_idx.append(r_idx[0])
|
| 623 |
+
|
| 624 |
+
# choice 1: directly remove overlapping region
|
| 625 |
+
occupied_mask = in_fov_list[r_idx[0, range(in_fov1.shape[-1])], :, range(in_fov1.shape[-1])].permute(1,0)
|
| 626 |
+
in_fov1 = in_fov1 & ~occupied_mask
|
| 627 |
+
|
| 628 |
+
# choice 2: apply similarity filter
|
| 629 |
+
# cos_sim = F.cosine_similarity(xs_pred.to(r_idx.device)[r_idx[:, range(in_fov1.shape[1])],
|
| 630 |
+
# range(in_fov1.shape[1])], xs_pred.to(r_idx.device)[:curr_frame], dim=2)
|
| 631 |
+
# cos_sim = cos_sim.mean((-2,-1))
|
| 632 |
+
|
| 633 |
+
# mask_sim = cos_sim>0.9
|
| 634 |
+
# in_fov_list = in_fov_list & ~mask_sim[:,None].to(in_fov_list.device)
|
| 635 |
+
|
| 636 |
+
random_idx = torch.stack(random_idx).cpu()
|
| 637 |
+
|
| 638 |
+
return random_idx
|
| 639 |
+
|
| 640 |
+
def _prepare_conditions(self,
|
| 641 |
+
start_frame, curr_frame, horizon, conditions,
|
| 642 |
+
pose_conditions, c2w_mat, frame_idx, random_idx,
|
| 643 |
+
image_width, image_height):
|
| 644 |
+
"""
|
| 645 |
+
Prepare input conditions and pose conditions for sampling.
|
| 646 |
+
"""
|
| 647 |
+
|
| 648 |
+
padding = torch.zeros((len(random_idx),) + conditions.shape[1:], device=conditions.device, dtype=conditions.dtype)
|
| 649 |
+
input_condition = torch.cat([conditions[start_frame:curr_frame + horizon], padding], dim=0)
|
| 650 |
+
|
| 651 |
+
batch_size = conditions.shape[1]
|
| 652 |
+
|
| 653 |
+
if self.use_plucker:
|
| 654 |
+
if self.relative_embedding:
|
| 655 |
+
frame_idx_list = []
|
| 656 |
+
input_pose_condition = []
|
| 657 |
+
for i in range(start_frame, curr_frame + horizon):
|
| 658 |
+
input_pose_condition.append(convert_to_plucker(torch.cat([c2w_mat[i:i+1],c2w_mat[random_idx[:,range(batch_size)], range(batch_size)]]).clone(), 0, focal_length=self.focal_length,
|
| 659 |
+
image_width=image_width, image_height=image_height).to(conditions.dtype))
|
| 660 |
+
frame_idx_list.append(torch.cat([frame_idx[i:i+1]-frame_idx[i:i+1], frame_idx[random_idx[:,range(batch_size)], range(batch_size)]-frame_idx[i:i+1]]))
|
| 661 |
+
input_pose_condition = torch.cat(input_pose_condition)
|
| 662 |
+
frame_idx_list = torch.cat(frame_idx_list)
|
| 663 |
+
|
| 664 |
+
else:
|
| 665 |
+
input_pose_condition = torch.cat([c2w_mat[start_frame : curr_frame + horizon], c2w_mat[random_idx[:,range(batch_size)], range(batch_size)]], dim=0).clone()
|
| 666 |
+
input_pose_condition = convert_to_plucker(input_pose_condition, 0, focal_length=self.focal_length)
|
| 667 |
+
frame_idx_list = None
|
| 668 |
+
else:
|
| 669 |
+
input_pose_condition = torch.cat([pose_conditions[start_frame : curr_frame + horizon], pose_conditions[random_idx[:,range(batch_size)], range(batch_size)]], dim=0).clone()
|
| 670 |
+
frame_idx_list = None
|
| 671 |
+
|
| 672 |
+
return input_condition, input_pose_condition, frame_idx_list
|
| 673 |
+
|
| 674 |
+
def _prepare_noise_levels(self, scheduling_matrix, m, curr_frame, batch_size, memory_condition_length):
|
| 675 |
+
"""
|
| 676 |
+
Prepare noise levels for the current sampling step.
|
| 677 |
+
"""
|
| 678 |
+
from_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m]))[:, None].repeat(batch_size, axis=1)
|
| 679 |
+
to_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m + 1]))[:, None].repeat(batch_size, axis=1)
|
| 680 |
+
if memory_condition_length:
|
| 681 |
+
from_noise_levels = np.concatenate([from_noise_levels, np.zeros((memory_condition_length, from_noise_levels.shape[-1]), dtype=np.int32)], axis=0)
|
| 682 |
+
to_noise_levels = np.concatenate([to_noise_levels, np.zeros((memory_condition_length, from_noise_levels.shape[-1]), dtype=np.int32)], axis=0)
|
| 683 |
+
from_noise_levels = torch.from_numpy(from_noise_levels).to(self.device)
|
| 684 |
+
to_noise_levels = torch.from_numpy(to_noise_levels).to(self.device)
|
| 685 |
+
return from_noise_levels, to_noise_levels
|
| 686 |
+
|
| 687 |
+
def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT:
|
| 688 |
+
"""
|
| 689 |
+
Perform a single validation step.
|
| 690 |
+
|
| 691 |
+
This function processes the input batch, encodes frames, generates predictions using a sliding window approach,
|
| 692 |
+
and handles condition similarity logic for sampling. The results are decoded and stored for evaluation.
|
| 693 |
+
|
| 694 |
+
Args:
|
| 695 |
+
batch: Input batch of data containing frames, conditions, poses, etc.
|
| 696 |
+
batch_idx: Index of the current batch.
|
| 697 |
+
namespace: Namespace for logging (default: "validation").
|
| 698 |
+
|
| 699 |
+
Returns:
|
| 700 |
+
None: Appends the predicted and ground truth frames to `self.validation_step_outputs`.
|
| 701 |
+
"""
|
| 702 |
+
# Preprocess the input batch
|
| 703 |
+
memory_condition_length = self.memory_condition_length
|
| 704 |
+
xs_raw, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
# Encode frames in chunks if necessary
|
| 708 |
+
total_frame = xs_raw.shape[0]
|
| 709 |
+
if total_frame > 10:
|
| 710 |
+
xs = torch.cat([
|
| 711 |
+
self.encode(xs_raw[int(total_frame * i / 10):int(total_frame * (i + 1) / 10)]).cpu()
|
| 712 |
+
for i in range(10)
|
| 713 |
+
])
|
| 714 |
+
else:
|
| 715 |
+
xs = self.encode(xs_raw).cpu()
|
| 716 |
+
|
| 717 |
+
n_frames, batch_size, *_ = xs.shape
|
| 718 |
+
curr_frame = 0
|
| 719 |
+
|
| 720 |
+
# Initialize context frames
|
| 721 |
+
n_context_frames = self.context_frames // self.frame_stack
|
| 722 |
+
xs_pred = xs[:n_context_frames].clone()
|
| 723 |
+
curr_frame += n_context_frames
|
| 724 |
+
|
| 725 |
+
# Progress bar for sampling
|
| 726 |
+
pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
|
| 727 |
+
|
| 728 |
+
while curr_frame < n_frames:
|
| 729 |
+
# Determine the horizon for the current chunk
|
| 730 |
+
horizon = min(n_frames - curr_frame, self.chunk_size) if self.chunk_size > 0 else n_frames - curr_frame
|
| 731 |
+
assert horizon <= self.n_tokens, "Horizon exceeds the number of tokens."
|
| 732 |
+
|
| 733 |
+
# Generate scheduling matrix and initialize noise
|
| 734 |
+
scheduling_matrix = self._generate_scheduling_matrix(horizon)
|
| 735 |
+
chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:]))
|
| 736 |
+
chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise).to(xs_pred.device)
|
| 737 |
+
xs_pred = torch.cat([xs_pred, chunk], 0)
|
| 738 |
+
|
| 739 |
+
# Sliding window: only input the last `n_tokens` frames
|
| 740 |
+
start_frame = max(0, curr_frame + horizon - self.n_tokens)
|
| 741 |
+
pbar.set_postfix({"start": start_frame, "end": curr_frame + horizon})
|
| 742 |
+
|
| 743 |
+
# Handle condition similarity logic
|
| 744 |
+
if memory_condition_length:
|
| 745 |
+
random_idx = self._generate_condition_indices(
|
| 746 |
+
curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, horizon
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
|
| 750 |
+
|
| 751 |
+
# Prepare input conditions and pose conditions
|
| 752 |
+
input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions(
|
| 753 |
+
start_frame, curr_frame, horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx,
|
| 754 |
+
image_width=xs_raw.shape[-1], image_height=xs_raw.shape[-2]
|
| 755 |
+
)
|
| 756 |
+
|
| 757 |
+
# Perform sampling for each step in the scheduling matrix
|
| 758 |
+
for m in range(scheduling_matrix.shape[0] - 1):
|
| 759 |
+
from_noise_levels, to_noise_levels = self._prepare_noise_levels(
|
| 760 |
+
scheduling_matrix, m, curr_frame, batch_size, memory_condition_length
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
xs_pred[start_frame:] = self.diffusion_model.sample_step(
|
| 764 |
+
xs_pred[start_frame:].to(input_condition.device),
|
| 765 |
+
input_condition,
|
| 766 |
+
input_pose_condition,
|
| 767 |
+
from_noise_levels[start_frame:],
|
| 768 |
+
to_noise_levels[start_frame:],
|
| 769 |
+
current_frame=curr_frame,
|
| 770 |
+
mode="validation",
|
| 771 |
+
reference_length=memory_condition_length,
|
| 772 |
+
frame_idx=frame_idx_list
|
| 773 |
+
).cpu()
|
| 774 |
+
|
| 775 |
+
# Remove condition similarity frames if applicable
|
| 776 |
+
if memory_condition_length:
|
| 777 |
+
xs_pred = xs_pred[:-memory_condition_length]
|
| 778 |
+
|
| 779 |
+
curr_frame += horizon
|
| 780 |
+
pbar.update(horizon)
|
| 781 |
+
|
| 782 |
+
# Decode predictions and ground truth
|
| 783 |
+
xs_pred = self.decode(xs_pred[n_context_frames:].to(conditions.device))
|
| 784 |
+
xs_decode = self.decode(xs[n_context_frames:].to(conditions.device))
|
| 785 |
+
|
| 786 |
+
# Store results for evaluation (move to CPU to save GPU memory)
|
| 787 |
+
self.validation_step_outputs.append((xs_pred.detach().cpu(), xs_decode.detach().cpu()))
|
| 788 |
+
return
|
| 789 |
+
|
| 790 |
+
@torch.no_grad()
|
| 791 |
+
def interactive(self, first_frame, new_actions, first_pose, device,
|
| 792 |
+
memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx):
|
| 793 |
+
|
| 794 |
+
memory_condition_length = self.memory_condition_length
|
| 795 |
+
|
| 796 |
+
if memory_latent_frames is None:
|
| 797 |
+
first_frame = torch.from_numpy(first_frame)
|
| 798 |
+
new_actions = torch.from_numpy(new_actions)
|
| 799 |
+
first_pose = torch.from_numpy(first_pose)
|
| 800 |
+
first_frame_encode = self.encode(first_frame[None, None].to(device))
|
| 801 |
+
memory_latent_frames = first_frame_encode.cpu()
|
| 802 |
+
memory_actions = new_actions[None, None].to(device)
|
| 803 |
+
memory_poses = first_pose[None, None].to(device)
|
| 804 |
+
new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
|
| 805 |
+
memory_c2w = new_c2w_mat[None, None].to(device)
|
| 806 |
+
memory_frame_idx = torch.tensor([[0]]).to(device)
|
| 807 |
+
return first_frame.cpu().numpy(), memory_latent_frames.cpu().numpy(), memory_actions.cpu().numpy(), memory_poses.cpu().numpy(), memory_c2w.cpu().numpy(), memory_frame_idx.cpu().numpy()
|
| 808 |
+
else:
|
| 809 |
+
memory_latent_frames = torch.from_numpy(memory_latent_frames)
|
| 810 |
+
memory_actions = torch.from_numpy(memory_actions).to(device)
|
| 811 |
+
memory_poses = torch.from_numpy(memory_poses).to(device)
|
| 812 |
+
memory_c2w = torch.from_numpy(memory_c2w).to(device)
|
| 813 |
+
memory_frame_idx = torch.from_numpy(memory_frame_idx).to(device)
|
| 814 |
+
new_actions = new_actions.to(device)
|
| 815 |
+
|
| 816 |
+
curr_frame = 0
|
| 817 |
+
batch_size = 1
|
| 818 |
+
horizon = self.next_frame_length
|
| 819 |
+
n_frames = curr_frame + horizon
|
| 820 |
+
# context
|
| 821 |
+
n_context_frames = len(memory_latent_frames)
|
| 822 |
+
xs_pred = memory_latent_frames[:n_context_frames].clone()
|
| 823 |
+
curr_frame += n_context_frames
|
| 824 |
+
|
| 825 |
+
pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
|
| 826 |
+
|
| 827 |
+
new_pose_condition_list = []
|
| 828 |
+
last_frame = xs_pred[-1].clone()
|
| 829 |
+
last_pose_condition = memory_poses[-1].clone()
|
| 830 |
+
curr_actions = new_actions.clone()
|
| 831 |
+
for hi in range(len(new_actions)):
|
| 832 |
+
last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
|
| 833 |
+
new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None, hi], last_pose_condition)
|
| 834 |
+
new_pose_condition_offset[:,3:] = torch.round(new_pose_condition_offset[:,3:])
|
| 835 |
+
new_pose_condition = last_pose_condition + new_pose_condition_offset
|
| 836 |
+
new_pose_condition[:,3:] = new_pose_condition[:,3:] * 15
|
| 837 |
+
new_pose_condition[:,3:] %= 360
|
| 838 |
+
last_pose_condition = new_pose_condition.clone()
|
| 839 |
+
new_pose_condition_list.append(new_pose_condition[None])
|
| 840 |
+
new_pose_condition_list = torch.cat(new_pose_condition_list, 0)
|
| 841 |
+
|
| 842 |
+
ai = 0
|
| 843 |
+
while ai < len(new_actions):
|
| 844 |
+
next_horizon = min(horizon, len(new_actions) - ai)
|
| 845 |
+
last_frame = xs_pred[-1].clone()
|
| 846 |
+
curr_actions = new_actions[ai:ai+next_horizon].clone()
|
| 847 |
+
|
| 848 |
+
new_pose_condition = new_pose_condition_list[ai:ai+next_horizon].clone()
|
| 849 |
+
|
| 850 |
+
new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)
|
| 851 |
+
memory_poses = torch.cat([memory_poses, new_pose_condition])
|
| 852 |
+
memory_actions = torch.cat([memory_actions, curr_actions[:, None]])
|
| 853 |
+
memory_c2w = torch.cat([memory_c2w, new_c2w_mat])
|
| 854 |
+
new_indices = memory_frame_idx[-1,0] + torch.arange(next_horizon, device=memory_frame_idx.device) + 1
|
| 855 |
+
|
| 856 |
+
memory_frame_idx = torch.cat([memory_frame_idx, new_indices[:, None]])
|
| 857 |
+
|
| 858 |
+
conditions = memory_actions.clone()
|
| 859 |
+
pose_conditions = memory_poses.clone()
|
| 860 |
+
c2w_mat = memory_c2w .clone()
|
| 861 |
+
frame_idx = memory_frame_idx.clone()
|
| 862 |
+
|
| 863 |
+
# generation on frame
|
| 864 |
+
scheduling_matrix = self._generate_scheduling_matrix(next_horizon)
|
| 865 |
+
chunk = torch.randn((next_horizon, batch_size, *xs_pred.shape[2:])).to(xs_pred.device)
|
| 866 |
+
chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise)
|
| 867 |
+
|
| 868 |
+
xs_pred = torch.cat([xs_pred, chunk], 0)
|
| 869 |
+
|
| 870 |
+
# sliding window: only input the last n_tokens frames
|
| 871 |
+
start_frame = max(0, curr_frame - self.n_tokens)
|
| 872 |
+
|
| 873 |
+
pbar.set_postfix(
|
| 874 |
+
{
|
| 875 |
+
"start": start_frame,
|
| 876 |
+
"end": curr_frame + next_horizon,
|
| 877 |
+
}
|
| 878 |
+
)
|
| 879 |
+
|
| 880 |
+
# Handle condition similarity logic
|
| 881 |
+
if memory_condition_length:
|
| 882 |
+
random_idx = self._generate_condition_indices(
|
| 883 |
+
curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, next_horizon
|
| 884 |
+
)
|
| 885 |
+
|
| 886 |
+
# random_idx = np.unique(random_idx)[:, None]
|
| 887 |
+
# memory_condition_length = len(random_idx)
|
| 888 |
+
xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
|
| 889 |
+
|
| 890 |
+
# Prepare input conditions and pose conditions
|
| 891 |
+
input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions(
|
| 892 |
+
start_frame, curr_frame, next_horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx,
|
| 893 |
+
image_width=first_frame.shape[-1], image_height=first_frame.shape[-2]
|
| 894 |
+
)
|
| 895 |
+
|
| 896 |
+
# Perform sampling for each step in the scheduling matrix
|
| 897 |
+
for m in range(scheduling_matrix.shape[0] - 1):
|
| 898 |
+
from_noise_levels, to_noise_levels = self._prepare_noise_levels(
|
| 899 |
+
scheduling_matrix, m, curr_frame, batch_size, memory_condition_length
|
| 900 |
+
)
|
| 901 |
+
|
| 902 |
+
xs_pred[start_frame:] = self.diffusion_model.sample_step(
|
| 903 |
+
xs_pred[start_frame:].to(input_condition.device),
|
| 904 |
+
input_condition,
|
| 905 |
+
input_pose_condition,
|
| 906 |
+
from_noise_levels[start_frame:],
|
| 907 |
+
to_noise_levels[start_frame:],
|
| 908 |
+
current_frame=curr_frame,
|
| 909 |
+
mode="validation",
|
| 910 |
+
reference_length=memory_condition_length,
|
| 911 |
+
frame_idx=frame_idx_list
|
| 912 |
+
).cpu()
|
| 913 |
+
|
| 914 |
+
|
| 915 |
+
if memory_condition_length:
|
| 916 |
+
xs_pred = xs_pred[:-memory_condition_length]
|
| 917 |
+
|
| 918 |
+
curr_frame += next_horizon
|
| 919 |
+
pbar.update(next_horizon)
|
| 920 |
+
ai += next_horizon
|
| 921 |
+
|
| 922 |
+
memory_latent_frames = torch.cat([memory_latent_frames, xs_pred[n_context_frames:]])
|
| 923 |
+
xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
|
| 924 |
+
|
| 925 |
+
return xs_pred.cpu().numpy(), memory_latent_frames.cpu().numpy(), memory_actions.cpu().numpy(), \
|
| 926 |
+
memory_poses.cpu().numpy(), memory_c2w.cpu().numpy(), memory_frame_idx.cpu().numpy()
|
algorithms/worldmem/models/attention.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Based on https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/attention.py
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from collections import namedtuple
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
from .rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
class TemporalAxialAttention(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
dim: int,
|
| 18 |
+
heads: int,
|
| 19 |
+
dim_head: int,
|
| 20 |
+
reference_length: int,
|
| 21 |
+
rotary_emb: RotaryEmbedding,
|
| 22 |
+
is_causal: bool = True,
|
| 23 |
+
is_temporal_independent: bool = False,
|
| 24 |
+
use_domain_adapter = False
|
| 25 |
+
):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.inner_dim = dim_head * heads
|
| 28 |
+
self.heads = heads
|
| 29 |
+
self.head_dim = dim_head
|
| 30 |
+
self.inner_dim = dim_head * heads
|
| 31 |
+
self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
|
| 32 |
+
|
| 33 |
+
self.use_domain_adapter = use_domain_adapter
|
| 34 |
+
if self.use_domain_adapter:
|
| 35 |
+
lora_rank = 8
|
| 36 |
+
self.lora_A = nn.Linear(dim, lora_rank, bias=False)
|
| 37 |
+
self.lora_B = nn.Linear(lora_rank, self.inner_dim * 3, bias=False)
|
| 38 |
+
|
| 39 |
+
self.to_out = nn.Linear(self.inner_dim, dim)
|
| 40 |
+
|
| 41 |
+
self.rotary_emb = rotary_emb
|
| 42 |
+
self.is_causal = is_causal
|
| 43 |
+
self.is_temporal_independent = is_temporal_independent
|
| 44 |
+
|
| 45 |
+
self.reference_length = reference_length
|
| 46 |
+
|
| 47 |
+
def forward(self, x: torch.Tensor):
|
| 48 |
+
B, T, H, W, D = x.shape
|
| 49 |
+
|
| 50 |
+
# if T>=9:
|
| 51 |
+
# try:
|
| 52 |
+
# # x = torch.cat([x[:,:-1],x[:,16-T:17-T],x[:,-1:]], dim=1)
|
| 53 |
+
# x = torch.cat([x[:,16-T:17-T],x], dim=1)
|
| 54 |
+
# except:
|
| 55 |
+
# import pdb;pdb.set_trace()
|
| 56 |
+
# print("="*50)
|
| 57 |
+
# print(x.shape)
|
| 58 |
+
|
| 59 |
+
B, T, H, W, D = x.shape
|
| 60 |
+
|
| 61 |
+
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
| 62 |
+
|
| 63 |
+
if self.use_domain_adapter:
|
| 64 |
+
q_lora, k_lora, v_lora = self.lora_B(self.lora_A(x)).chunk(3, dim=-1)
|
| 65 |
+
q = q+q_lora
|
| 66 |
+
k = k+k_lora
|
| 67 |
+
v = v+v_lora
|
| 68 |
+
|
| 69 |
+
q = rearrange(q, "B T H W (h d) -> (B H W) h T d", h=self.heads)
|
| 70 |
+
k = rearrange(k, "B T H W (h d) -> (B H W) h T d", h=self.heads)
|
| 71 |
+
v = rearrange(v, "B T H W (h d) -> (B H W) h T d", h=self.heads)
|
| 72 |
+
|
| 73 |
+
q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
|
| 74 |
+
k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)
|
| 75 |
+
|
| 76 |
+
q, k, v = map(lambda t: t.contiguous(), (q, k, v))
|
| 77 |
+
|
| 78 |
+
if self.is_temporal_independent:
|
| 79 |
+
attn_bias = torch.ones((T, T), dtype=q.dtype, device=q.device)
|
| 80 |
+
attn_bias = attn_bias.masked_fill(attn_bias == 1, float('-inf'))
|
| 81 |
+
attn_bias[range(T), range(T)] = 0
|
| 82 |
+
elif self.is_causal:
|
| 83 |
+
attn_bias = torch.triu(torch.ones((T, T), dtype=q.dtype, device=q.device), diagonal=1)
|
| 84 |
+
attn_bias = attn_bias.masked_fill(attn_bias == 1, float('-inf'))
|
| 85 |
+
attn_bias[(T-self.reference_length):] = float('-inf')
|
| 86 |
+
attn_bias[range(T), range(T)] = 0
|
| 87 |
+
else:
|
| 88 |
+
attn_bias = None
|
| 89 |
+
|
| 90 |
+
try:
|
| 91 |
+
x = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_bias)
|
| 92 |
+
except:
|
| 93 |
+
import pdb;pdb.set_trace()
|
| 94 |
+
|
| 95 |
+
x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
|
| 96 |
+
x = x.to(q.dtype)
|
| 97 |
+
|
| 98 |
+
# linear proj
|
| 99 |
+
x = self.to_out(x)
|
| 100 |
+
|
| 101 |
+
# if T>=10:
|
| 102 |
+
# try:
|
| 103 |
+
# # x = torch.cat([x[:,:-2],x[:,-1:]], dim=1)
|
| 104 |
+
# x = x[:,1:]
|
| 105 |
+
# except:
|
| 106 |
+
# import pdb;pdb.set_trace()
|
| 107 |
+
# print(x.shape)
|
| 108 |
+
return x
|
| 109 |
+
|
| 110 |
+
class SpatialAxialAttention(nn.Module):
|
| 111 |
+
def __init__(
|
| 112 |
+
self,
|
| 113 |
+
dim: int,
|
| 114 |
+
heads: int,
|
| 115 |
+
dim_head: int,
|
| 116 |
+
rotary_emb: RotaryEmbedding,
|
| 117 |
+
use_domain_adapter = False
|
| 118 |
+
):
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.inner_dim = dim_head * heads
|
| 121 |
+
self.heads = heads
|
| 122 |
+
self.head_dim = dim_head
|
| 123 |
+
self.inner_dim = dim_head * heads
|
| 124 |
+
self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
|
| 125 |
+
self.use_domain_adapter = use_domain_adapter
|
| 126 |
+
if self.use_domain_adapter:
|
| 127 |
+
lora_rank = 8
|
| 128 |
+
self.lora_A = nn.Linear(dim, lora_rank, bias=False)
|
| 129 |
+
self.lora_B = nn.Linear(lora_rank, self.inner_dim * 3, bias=False)
|
| 130 |
+
|
| 131 |
+
self.to_out = nn.Linear(self.inner_dim, dim)
|
| 132 |
+
|
| 133 |
+
self.rotary_emb = rotary_emb
|
| 134 |
+
|
| 135 |
+
def forward(self, x: torch.Tensor):
|
| 136 |
+
B, T, H, W, D = x.shape
|
| 137 |
+
|
| 138 |
+
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
| 139 |
+
|
| 140 |
+
if self.use_domain_adapter:
|
| 141 |
+
q_lora, k_lora, v_lora = self.lora_B(self.lora_A(x)).chunk(3, dim=-1)
|
| 142 |
+
q = q+q_lora
|
| 143 |
+
k = k+k_lora
|
| 144 |
+
v = v+v_lora
|
| 145 |
+
|
| 146 |
+
q = rearrange(q, "B T H W (h d) -> (B T) h H W d", h=self.heads)
|
| 147 |
+
k = rearrange(k, "B T H W (h d) -> (B T) h H W d", h=self.heads)
|
| 148 |
+
v = rearrange(v, "B T H W (h d) -> (B T) h H W d", h=self.heads)
|
| 149 |
+
|
| 150 |
+
freqs = self.rotary_emb.get_axial_freqs(H, W)
|
| 151 |
+
q = apply_rotary_emb(freqs, q)
|
| 152 |
+
k = apply_rotary_emb(freqs, k)
|
| 153 |
+
|
| 154 |
+
# prepare for attn
|
| 155 |
+
q = rearrange(q, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
|
| 156 |
+
k = rearrange(k, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
|
| 157 |
+
v = rearrange(v, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
|
| 158 |
+
|
| 159 |
+
x = F.scaled_dot_product_attention(query=q, key=k, value=v, is_causal=False)
|
| 160 |
+
|
| 161 |
+
x = rearrange(x, "(B T) h (H W) d -> B T H W (h d)", B=B, H=H, W=W)
|
| 162 |
+
x = x.to(q.dtype)
|
| 163 |
+
|
| 164 |
+
# linear proj
|
| 165 |
+
x = self.to_out(x)
|
| 166 |
+
return x
|
| 167 |
+
|
| 168 |
+
class MemTemporalAxialAttention(nn.Module):
|
| 169 |
+
def __init__(
|
| 170 |
+
self,
|
| 171 |
+
dim: int,
|
| 172 |
+
heads: int,
|
| 173 |
+
dim_head: int,
|
| 174 |
+
rotary_emb: RotaryEmbedding,
|
| 175 |
+
is_causal: bool = True,
|
| 176 |
+
):
|
| 177 |
+
super().__init__()
|
| 178 |
+
self.inner_dim = dim_head * heads
|
| 179 |
+
self.heads = heads
|
| 180 |
+
self.head_dim = dim_head
|
| 181 |
+
self.inner_dim = dim_head * heads
|
| 182 |
+
self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
|
| 183 |
+
self.to_out = nn.Linear(self.inner_dim, dim)
|
| 184 |
+
|
| 185 |
+
self.rotary_emb = rotary_emb
|
| 186 |
+
self.is_causal = is_causal
|
| 187 |
+
|
| 188 |
+
self.reference_length = 3
|
| 189 |
+
|
| 190 |
+
def forward(self, x: torch.Tensor):
|
| 191 |
+
B, T, H, W, D = x.shape
|
| 192 |
+
|
| 193 |
+
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
q = rearrange(q, "B T H W (h d) -> (B H W) h T d", h=self.heads)
|
| 197 |
+
k = rearrange(k, "B T H W (h d) -> (B H W) h T d", h=self.heads)
|
| 198 |
+
v = rearrange(v, "B T H W (h d) -> (B H W) h T d", h=self.heads)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
|
| 203 |
+
# k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)
|
| 204 |
+
|
| 205 |
+
q, k, v = map(lambda t: t.contiguous(), (q, k, v))
|
| 206 |
+
|
| 207 |
+
# if T == 21000:
|
| 208 |
+
# # 手动计算缩放点积分数
|
| 209 |
+
# _, _, _, d_k = q.shape
|
| 210 |
+
# scores = torch.einsum("b h n d, b h m d -> b h n m", q, k) / (d_k ** 0.5) # Shape: (B, T_q, T_k)
|
| 211 |
+
|
| 212 |
+
# # 计算注意力图 (Attention Map)
|
| 213 |
+
# attention_map = F.softmax(scores, dim=-1) # Shape: (B, T_q, T_k)
|
| 214 |
+
# b_, h_, n_, m_ = attention_map.shape
|
| 215 |
+
# attention_map = attention_map.reshape(1, int(np.sqrt(b_/1)), int(np.sqrt(b_/1)), h_, n_, m_)
|
| 216 |
+
# attention_map = attention_map.mean(3)
|
| 217 |
+
|
| 218 |
+
# attn_bias = torch.zeros((T, T), dtype=q.dtype, device=q.device)
|
| 219 |
+
# T_origin = T - self.reference_length
|
| 220 |
+
# attn_bias[:T_origin, T_origin:] = 1
|
| 221 |
+
# attn_bias[range(T), range(T)] = 1
|
| 222 |
+
|
| 223 |
+
# attention_map = attention_map * attn_bias
|
| 224 |
+
|
| 225 |
+
# # print 注意力图
|
| 226 |
+
# import matplotlib.pyplot as plt
|
| 227 |
+
# fig, axes = plt.subplots(21000, 21000, figsize=(9, 9)) # 调整figsize以适配图像大小
|
| 228 |
+
|
| 229 |
+
# # 遍历3*3维度
|
| 230 |
+
# for i in range(21000):
|
| 231 |
+
# for j in range(21000):
|
| 232 |
+
# # 取出第(i, j)个子图像
|
| 233 |
+
# img = attention_map[0, :, :, i, j].cpu().numpy()
|
| 234 |
+
# axes[i, j].imshow(img, cmap='viridis') # 可以自定义cmap
|
| 235 |
+
# axes[i, j].axis('off') # 隐藏坐标轴
|
| 236 |
+
|
| 237 |
+
# # 调整子图间距
|
| 238 |
+
# plt.tight_layout()
|
| 239 |
+
# plt.savefig('attention_map.png')
|
| 240 |
+
# import pdb; pdb.set_trace()
|
| 241 |
+
# plt.close()
|
| 242 |
+
|
| 243 |
+
attn_bias = torch.zeros((T, T), dtype=q.dtype, device=q.device)
|
| 244 |
+
attn_bias = attn_bias.masked_fill(attn_bias == 0, float('-inf'))
|
| 245 |
+
T_origin = T - self.reference_length
|
| 246 |
+
attn_bias[:T_origin, T_origin:] = 0
|
| 247 |
+
attn_bias[range(T), range(T)] = 0
|
| 248 |
+
|
| 249 |
+
# if T==121000:
|
| 250 |
+
# import pdb;pdb.set_trace()
|
| 251 |
+
|
| 252 |
+
try:
|
| 253 |
+
x = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_bias)
|
| 254 |
+
except:
|
| 255 |
+
import pdb;pdb.set_trace()
|
| 256 |
+
|
| 257 |
+
x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
|
| 258 |
+
x = x.to(q.dtype)
|
| 259 |
+
|
| 260 |
+
# linear proj
|
| 261 |
+
x = self.to_out(x)
|
| 262 |
+
return x
|
| 263 |
+
|
| 264 |
+
class MemFullAttention(nn.Module):
|
| 265 |
+
def __init__(
|
| 266 |
+
self,
|
| 267 |
+
dim: int,
|
| 268 |
+
heads: int,
|
| 269 |
+
dim_head: int,
|
| 270 |
+
reference_length: int,
|
| 271 |
+
rotary_emb: RotaryEmbedding,
|
| 272 |
+
is_causal: bool = True
|
| 273 |
+
):
|
| 274 |
+
super().__init__()
|
| 275 |
+
self.inner_dim = dim_head * heads
|
| 276 |
+
self.heads = heads
|
| 277 |
+
self.head_dim = dim_head
|
| 278 |
+
self.inner_dim = dim_head * heads
|
| 279 |
+
self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
|
| 280 |
+
self.to_out = nn.Linear(self.inner_dim, dim)
|
| 281 |
+
|
| 282 |
+
self.rotary_emb = rotary_emb
|
| 283 |
+
self.is_causal = is_causal
|
| 284 |
+
|
| 285 |
+
self.reference_length = reference_length
|
| 286 |
+
|
| 287 |
+
self.store = None
|
| 288 |
+
|
| 289 |
+
def forward(self, x: torch.Tensor, relative_embedding=False,
|
| 290 |
+
extra_condition=None,
|
| 291 |
+
state_embed_only_on_qk=False,
|
| 292 |
+
reference_length=None):
|
| 293 |
+
|
| 294 |
+
B, T, H, W, D = x.shape
|
| 295 |
+
|
| 296 |
+
if state_embed_only_on_qk:
|
| 297 |
+
q, k, _ = self.to_qkv(x+extra_condition).chunk(3, dim=-1)
|
| 298 |
+
_, _, v = self.to_qkv(x).chunk(3, dim=-1)
|
| 299 |
+
else:
|
| 300 |
+
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
| 301 |
+
|
| 302 |
+
if relative_embedding:
|
| 303 |
+
length = reference_length+1
|
| 304 |
+
n_frames = T // length
|
| 305 |
+
x = x.reshape(B, n_frames, length, H, W, D)
|
| 306 |
+
|
| 307 |
+
x_list = []
|
| 308 |
+
|
| 309 |
+
for i in range(n_frames):
|
| 310 |
+
if i == n_frames-1:
|
| 311 |
+
q_i = rearrange(q[:, i*length:], "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 312 |
+
k_i = rearrange(k[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 313 |
+
v_i = rearrange(v[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 314 |
+
else:
|
| 315 |
+
q_i = rearrange(q[:, i*length:i*length+1], "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 316 |
+
k_i = rearrange(k[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 317 |
+
v_i = rearrange(v[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 318 |
+
|
| 319 |
+
q_i, k_i, v_i = map(lambda t: t.contiguous(), (q_i, k_i, v_i))
|
| 320 |
+
x_i = F.scaled_dot_product_attention(query=q_i, key=k_i, value=v_i)
|
| 321 |
+
x_i = rearrange(x_i, "B h (T H W) d -> B T H W (h d)", B=B, H=H, W=W)
|
| 322 |
+
x_i = x_i.to(q.dtype)
|
| 323 |
+
x_list.append(x_i)
|
| 324 |
+
|
| 325 |
+
x = torch.cat(x_list, dim=1)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
else:
|
| 329 |
+
T_ = T - reference_length
|
| 330 |
+
q = rearrange(q, "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 331 |
+
k = rearrange(k[:, T_:], "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 332 |
+
v = rearrange(v[:, T_:], "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 333 |
+
|
| 334 |
+
q, k, v = map(lambda t: t.contiguous(), (q, k, v))
|
| 335 |
+
x = F.scaled_dot_product_attention(query=q, key=k, value=v)
|
| 336 |
+
x = rearrange(x, "B h (T H W) d -> B T H W (h d)", B=B, H=H, W=W)
|
| 337 |
+
x = x.to(q.dtype)
|
| 338 |
+
|
| 339 |
+
# linear proj
|
| 340 |
+
x = self.to_out(x)
|
| 341 |
+
|
| 342 |
+
return x
|
algorithms/worldmem/models/cameractrl_module.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
class SimpleCameraPoseEncoder(nn.Module):
|
| 3 |
+
def __init__(self, c_in, c_out, hidden_dim=128):
|
| 4 |
+
super(SimpleCameraPoseEncoder, self).__init__()
|
| 5 |
+
self.model = nn.Sequential(
|
| 6 |
+
nn.Linear(c_in, hidden_dim),
|
| 7 |
+
nn.ReLU(),
|
| 8 |
+
nn.Linear(hidden_dim, c_out)
|
| 9 |
+
)
|
| 10 |
+
def forward(self, x):
|
| 11 |
+
return self.model(x)
|
| 12 |
+
|
algorithms/worldmem/models/diffusion.py
ADDED
|
@@ -0,0 +1,520 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Callable
|
| 2 |
+
from collections import namedtuple
|
| 3 |
+
from omegaconf import DictConfig
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from .utils import linear_beta_schedule, cosine_beta_schedule, sigmoid_beta_schedule, extract
|
| 9 |
+
from .dit import DiT_models
|
| 10 |
+
|
| 11 |
+
ModelPrediction = namedtuple("ModelPrediction", ["pred_noise", "pred_x_start", "model_out"])
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Diffusion(nn.Module):
|
| 15 |
+
# Special thanks to lucidrains for the implementation of the base Diffusion model
|
| 16 |
+
# https://github.com/lucidrains/denoising-diffusion-pytorch
|
| 17 |
+
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
x_shape: torch.Size,
|
| 21 |
+
reference_length: int,
|
| 22 |
+
action_cond_dim: int,
|
| 23 |
+
pose_cond_dim,
|
| 24 |
+
is_causal: bool,
|
| 25 |
+
cfg: DictConfig,
|
| 26 |
+
is_dit: bool=False,
|
| 27 |
+
use_plucker=False,
|
| 28 |
+
relative_embedding=False,
|
| 29 |
+
state_embed_only_on_qk=False,
|
| 30 |
+
use_memory_attention=False,
|
| 31 |
+
add_timestamp_embedding=False,
|
| 32 |
+
ref_mode='sequential'
|
| 33 |
+
):
|
| 34 |
+
super().__init__()
|
| 35 |
+
self.cfg = cfg
|
| 36 |
+
|
| 37 |
+
self.x_shape = x_shape
|
| 38 |
+
self.action_cond_dim = action_cond_dim
|
| 39 |
+
self.timesteps = cfg.timesteps
|
| 40 |
+
self.sampling_timesteps = cfg.sampling_timesteps
|
| 41 |
+
self.beta_schedule = cfg.beta_schedule
|
| 42 |
+
self.schedule_fn_kwargs = cfg.schedule_fn_kwargs
|
| 43 |
+
self.objective = cfg.objective
|
| 44 |
+
self.use_fused_snr = cfg.use_fused_snr
|
| 45 |
+
self.snr_clip = cfg.snr_clip
|
| 46 |
+
self.cum_snr_decay = cfg.cum_snr_decay
|
| 47 |
+
self.ddim_sampling_eta = cfg.ddim_sampling_eta
|
| 48 |
+
self.clip_noise = cfg.clip_noise
|
| 49 |
+
self.arch = cfg.architecture
|
| 50 |
+
self.stabilization_level = cfg.stabilization_level
|
| 51 |
+
self.is_causal = is_causal
|
| 52 |
+
self.is_dit = is_dit
|
| 53 |
+
self.reference_length = reference_length
|
| 54 |
+
self.pose_cond_dim = pose_cond_dim
|
| 55 |
+
self.use_plucker = use_plucker
|
| 56 |
+
self.relative_embedding = relative_embedding
|
| 57 |
+
self.state_embed_only_on_qk = state_embed_only_on_qk
|
| 58 |
+
self.use_memory_attention = use_memory_attention
|
| 59 |
+
self.add_timestamp_embedding = add_timestamp_embedding
|
| 60 |
+
self.ref_mode = ref_mode
|
| 61 |
+
|
| 62 |
+
self._build_model()
|
| 63 |
+
self._build_buffer()
|
| 64 |
+
|
| 65 |
+
def _build_model(self):
|
| 66 |
+
x_channel = self.x_shape[0]
|
| 67 |
+
if self.is_dit:
|
| 68 |
+
self.model = DiT_models["DiT-S/2"](action_cond_dim=self.action_cond_dim,
|
| 69 |
+
pose_cond_dim=self.pose_cond_dim, reference_length=self.reference_length,
|
| 70 |
+
use_plucker=self.use_plucker,
|
| 71 |
+
relative_embedding=self.relative_embedding,
|
| 72 |
+
state_embed_only_on_qk=self.state_embed_only_on_qk,
|
| 73 |
+
use_memory_attention=self.use_memory_attention,
|
| 74 |
+
add_timestamp_embedding=self.add_timestamp_embedding,
|
| 75 |
+
ref_mode=self.ref_mode)
|
| 76 |
+
else:
|
| 77 |
+
raise NotImplementedError
|
| 78 |
+
|
| 79 |
+
def _build_buffer(self):
|
| 80 |
+
if self.beta_schedule == "linear":
|
| 81 |
+
beta_schedule_fn = linear_beta_schedule
|
| 82 |
+
elif self.beta_schedule == "cosine":
|
| 83 |
+
beta_schedule_fn = cosine_beta_schedule
|
| 84 |
+
elif self.beta_schedule == "sigmoid":
|
| 85 |
+
beta_schedule_fn = sigmoid_beta_schedule
|
| 86 |
+
else:
|
| 87 |
+
raise ValueError(f"unknown beta schedule {self.beta_schedule}")
|
| 88 |
+
|
| 89 |
+
betas = beta_schedule_fn(self.timesteps, **self.schedule_fn_kwargs)
|
| 90 |
+
|
| 91 |
+
alphas = 1.0 - betas
|
| 92 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
| 93 |
+
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
|
| 94 |
+
|
| 95 |
+
# sampling related parameters
|
| 96 |
+
assert self.sampling_timesteps <= self.timesteps
|
| 97 |
+
self.is_ddim_sampling = self.sampling_timesteps < self.timesteps
|
| 98 |
+
|
| 99 |
+
# helper function to register buffer from float64 to float32
|
| 100 |
+
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
|
| 101 |
+
|
| 102 |
+
register_buffer("betas", betas)
|
| 103 |
+
register_buffer("alphas_cumprod", alphas_cumprod)
|
| 104 |
+
register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
|
| 105 |
+
|
| 106 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
| 107 |
+
|
| 108 |
+
register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
|
| 109 |
+
register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
|
| 110 |
+
register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
|
| 111 |
+
register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
|
| 112 |
+
register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))
|
| 113 |
+
|
| 114 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
| 115 |
+
|
| 116 |
+
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
| 117 |
+
|
| 118 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
| 119 |
+
|
| 120 |
+
register_buffer("posterior_variance", posterior_variance)
|
| 121 |
+
|
| 122 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
| 123 |
+
|
| 124 |
+
register_buffer(
|
| 125 |
+
"posterior_log_variance_clipped",
|
| 126 |
+
torch.log(posterior_variance.clamp(min=1e-20)),
|
| 127 |
+
)
|
| 128 |
+
register_buffer(
|
| 129 |
+
"posterior_mean_coef1",
|
| 130 |
+
betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
|
| 131 |
+
)
|
| 132 |
+
register_buffer(
|
| 133 |
+
"posterior_mean_coef2",
|
| 134 |
+
(1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# calculate p2 reweighting
|
| 138 |
+
|
| 139 |
+
# register_buffer(
|
| 140 |
+
# "p2_loss_weight",
|
| 141 |
+
# (self.p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))
|
| 142 |
+
# ** -self.p2_loss_weight_gamma,
|
| 143 |
+
# )
|
| 144 |
+
|
| 145 |
+
# derive loss weight
|
| 146 |
+
# https://arxiv.org/abs/2303.09556
|
| 147 |
+
# snr: signal noise ratio
|
| 148 |
+
snr = alphas_cumprod / (1 - alphas_cumprod)
|
| 149 |
+
clipped_snr = snr.clone()
|
| 150 |
+
clipped_snr.clamp_(max=self.snr_clip)
|
| 151 |
+
|
| 152 |
+
register_buffer("clipped_snr", clipped_snr)
|
| 153 |
+
register_buffer("snr", snr)
|
| 154 |
+
|
| 155 |
+
def add_shape_channels(self, x):
|
| 156 |
+
return rearrange(x, f"... -> ...{' 1' * len(self.x_shape)}")
|
| 157 |
+
|
| 158 |
+
def model_predictions(self, x, t, action_cond=None, current_frame=None,
|
| 159 |
+
pose_cond=None, mode="training", reference_length=None, frame_idx=None):
|
| 160 |
+
x = x.permute(1,0,2,3,4)
|
| 161 |
+
action_cond = action_cond.permute(1,0,2)
|
| 162 |
+
if pose_cond is not None and pose_cond[0] is not None:
|
| 163 |
+
try:
|
| 164 |
+
pose_cond = pose_cond.permute(1,0,2)
|
| 165 |
+
except:
|
| 166 |
+
pass
|
| 167 |
+
t = t.permute(1,0)
|
| 168 |
+
model_output = self.model(x, t, action_cond, current_frame=current_frame, pose_cond=pose_cond,
|
| 169 |
+
mode=mode, reference_length=reference_length, frame_idx=frame_idx)
|
| 170 |
+
model_output = model_output.permute(1,0,2,3,4)
|
| 171 |
+
x = x.permute(1,0,2,3,4)
|
| 172 |
+
t = t.permute(1,0)
|
| 173 |
+
|
| 174 |
+
if self.objective == "pred_noise":
|
| 175 |
+
pred_noise = torch.clamp(model_output, -self.clip_noise, self.clip_noise)
|
| 176 |
+
x_start = self.predict_start_from_noise(x, t, pred_noise)
|
| 177 |
+
|
| 178 |
+
elif self.objective == "pred_x0":
|
| 179 |
+
x_start = model_output
|
| 180 |
+
pred_noise = self.predict_noise_from_start(x, t, x_start)
|
| 181 |
+
|
| 182 |
+
elif self.objective == "pred_v":
|
| 183 |
+
v = model_output
|
| 184 |
+
x_start = self.predict_start_from_v(x, t, v)
|
| 185 |
+
pred_noise = self.predict_noise_from_start(x, t, x_start)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
return ModelPrediction(pred_noise, x_start, model_output)
|
| 189 |
+
|
| 190 |
+
def predict_start_from_noise(self, x_t, t, noise):
|
| 191 |
+
return (
|
| 192 |
+
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
| 193 |
+
- extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
def predict_noise_from_start(self, x_t, t, x0):
|
| 197 |
+
return (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / extract(
|
| 198 |
+
self.sqrt_recipm1_alphas_cumprod, t, x_t.shape
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
def predict_v(self, x_start, t, noise):
|
| 202 |
+
return (
|
| 203 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise
|
| 204 |
+
- extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
def predict_start_from_v(self, x_t, t, v):
|
| 208 |
+
return (
|
| 209 |
+
extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
|
| 210 |
+
- extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
def q_mean_variance(self, x_start, t):
|
| 214 |
+
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
| 215 |
+
variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape)
|
| 216 |
+
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
| 217 |
+
return mean, variance, log_variance
|
| 218 |
+
|
| 219 |
+
def q_posterior(self, x_start, x_t, t):
|
| 220 |
+
posterior_mean = (
|
| 221 |
+
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
| 222 |
+
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
| 223 |
+
)
|
| 224 |
+
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
| 225 |
+
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
| 226 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
| 227 |
+
|
| 228 |
+
def q_sample(self, x_start, t, noise=None):
|
| 229 |
+
if noise is None:
|
| 230 |
+
noise = torch.randn_like(x_start)
|
| 231 |
+
noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
|
| 232 |
+
return (
|
| 233 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
| 234 |
+
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
def p_mean_variance(self, x, t, action_cond=None, pose_cond=None, reference_length=None):
|
| 238 |
+
model_pred = self.model_predictions(x=x, t=t, action_cond=action_cond,
|
| 239 |
+
pose_cond=pose_cond, reference_length=reference_length,
|
| 240 |
+
frame_idx=frame_idx)
|
| 241 |
+
x_start = model_pred.pred_x_start
|
| 242 |
+
return self.q_posterior(x_start=x_start, x_t=x, t=t)
|
| 243 |
+
|
| 244 |
+
def compute_loss_weights(self, noise_levels: torch.Tensor):
|
| 245 |
+
|
| 246 |
+
snr = self.snr[noise_levels]
|
| 247 |
+
clipped_snr = self.clipped_snr[noise_levels]
|
| 248 |
+
normalized_clipped_snr = clipped_snr / self.snr_clip
|
| 249 |
+
normalized_snr = snr / self.snr_clip
|
| 250 |
+
|
| 251 |
+
if not self.use_fused_snr:
|
| 252 |
+
# min SNR reweighting
|
| 253 |
+
match self.objective:
|
| 254 |
+
case "pred_noise":
|
| 255 |
+
return clipped_snr / snr
|
| 256 |
+
case "pred_x0":
|
| 257 |
+
return clipped_snr
|
| 258 |
+
case "pred_v":
|
| 259 |
+
return clipped_snr / (snr + 1)
|
| 260 |
+
|
| 261 |
+
cum_snr = torch.zeros_like(normalized_snr)
|
| 262 |
+
for t in range(0, noise_levels.shape[0]):
|
| 263 |
+
if t == 0:
|
| 264 |
+
cum_snr[t] = normalized_clipped_snr[t]
|
| 265 |
+
else:
|
| 266 |
+
cum_snr[t] = self.cum_snr_decay * cum_snr[t - 1] + (1 - self.cum_snr_decay) * normalized_clipped_snr[t]
|
| 267 |
+
|
| 268 |
+
cum_snr = F.pad(cum_snr[:-1], (0, 0, 1, 0), value=0.0)
|
| 269 |
+
clipped_fused_snr = 1 - (1 - cum_snr * self.cum_snr_decay) * (1 - normalized_clipped_snr)
|
| 270 |
+
fused_snr = 1 - (1 - cum_snr * self.cum_snr_decay) * (1 - normalized_snr)
|
| 271 |
+
|
| 272 |
+
match self.objective:
|
| 273 |
+
case "pred_noise":
|
| 274 |
+
return clipped_fused_snr / fused_snr
|
| 275 |
+
case "pred_x0":
|
| 276 |
+
return clipped_fused_snr * self.snr_clip
|
| 277 |
+
case "pred_v":
|
| 278 |
+
return clipped_fused_snr * self.snr_clip / (fused_snr * self.snr_clip + 1)
|
| 279 |
+
case _:
|
| 280 |
+
raise ValueError(f"unknown objective {self.objective}")
|
| 281 |
+
|
| 282 |
+
def forward(
|
| 283 |
+
self,
|
| 284 |
+
x: torch.Tensor,
|
| 285 |
+
action_cond: Optional[torch.Tensor],
|
| 286 |
+
pose_cond,
|
| 287 |
+
noise_levels: torch.Tensor,
|
| 288 |
+
reference_length,
|
| 289 |
+
frame_idx=None
|
| 290 |
+
):
|
| 291 |
+
noise = torch.randn_like(x)
|
| 292 |
+
noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
|
| 293 |
+
|
| 294 |
+
noised_x = self.q_sample(x_start=x, t=noise_levels, noise=noise)
|
| 295 |
+
|
| 296 |
+
model_pred = self.model_predictions(x=noised_x, t=noise_levels, action_cond=action_cond,
|
| 297 |
+
pose_cond=pose_cond,reference_length=reference_length, frame_idx=frame_idx)
|
| 298 |
+
|
| 299 |
+
pred = model_pred.model_out
|
| 300 |
+
x_pred = model_pred.pred_x_start
|
| 301 |
+
|
| 302 |
+
if self.objective == "pred_noise":
|
| 303 |
+
target = noise
|
| 304 |
+
elif self.objective == "pred_x0":
|
| 305 |
+
target = x
|
| 306 |
+
elif self.objective == "pred_v":
|
| 307 |
+
target = self.predict_v(x, noise_levels, noise)
|
| 308 |
+
else:
|
| 309 |
+
raise ValueError(f"unknown objective {self.objective}")
|
| 310 |
+
|
| 311 |
+
# 训练的时候每个frame随便给噪声
|
| 312 |
+
loss = F.mse_loss(pred, target.detach(), reduction="none")
|
| 313 |
+
loss_weight = self.compute_loss_weights(noise_levels)
|
| 314 |
+
|
| 315 |
+
loss_weight = loss_weight.view(*loss_weight.shape, *((1,) * (loss.ndim - 2)))
|
| 316 |
+
|
| 317 |
+
loss = loss * loss_weight
|
| 318 |
+
|
| 319 |
+
return x_pred, loss
|
| 320 |
+
|
| 321 |
+
def sample_step(
|
| 322 |
+
self,
|
| 323 |
+
x: torch.Tensor,
|
| 324 |
+
action_cond: Optional[torch.Tensor],
|
| 325 |
+
pose_cond,
|
| 326 |
+
curr_noise_level: torch.Tensor,
|
| 327 |
+
next_noise_level: torch.Tensor,
|
| 328 |
+
guidance_fn: Optional[Callable] = None,
|
| 329 |
+
current_frame=None,
|
| 330 |
+
mode="training",
|
| 331 |
+
reference_length=None,
|
| 332 |
+
frame_idx=None
|
| 333 |
+
):
|
| 334 |
+
real_steps = torch.linspace(-1, self.timesteps - 1, steps=self.sampling_timesteps + 1, device=x.device).long()
|
| 335 |
+
|
| 336 |
+
# convert noise levels (0 ~ sampling_timesteps) to real noise levels (-1 ~ timesteps - 1)
|
| 337 |
+
curr_noise_level = real_steps[curr_noise_level]
|
| 338 |
+
next_noise_level = real_steps[next_noise_level]
|
| 339 |
+
|
| 340 |
+
if self.is_ddim_sampling:
|
| 341 |
+
return self.ddim_sample_step(
|
| 342 |
+
x=x,
|
| 343 |
+
action_cond=action_cond,
|
| 344 |
+
pose_cond=pose_cond,
|
| 345 |
+
curr_noise_level=curr_noise_level,
|
| 346 |
+
next_noise_level=next_noise_level,
|
| 347 |
+
guidance_fn=guidance_fn,
|
| 348 |
+
current_frame=current_frame,
|
| 349 |
+
mode=mode,
|
| 350 |
+
reference_length=reference_length,
|
| 351 |
+
frame_idx=frame_idx
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
# FIXME: temporary code for checking ddpm sampling
|
| 355 |
+
assert torch.all(
|
| 356 |
+
(curr_noise_level - 1 == next_noise_level) | ((curr_noise_level == -1) & (next_noise_level == -1))
|
| 357 |
+
), "Wrong noise level given for ddpm sampling."
|
| 358 |
+
|
| 359 |
+
assert (
|
| 360 |
+
self.sampling_timesteps == self.timesteps
|
| 361 |
+
), "sampling_timesteps should be equal to timesteps for ddpm sampling."
|
| 362 |
+
|
| 363 |
+
return self.ddpm_sample_step(
|
| 364 |
+
x=x,
|
| 365 |
+
action_cond=action_cond,
|
| 366 |
+
pose_cond=pose_cond,
|
| 367 |
+
curr_noise_level=curr_noise_level,
|
| 368 |
+
guidance_fn=guidance_fn,
|
| 369 |
+
reference_length=reference_length,
|
| 370 |
+
frame_idx=frame_idx
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
def ddpm_sample_step(
|
| 374 |
+
self,
|
| 375 |
+
x: torch.Tensor,
|
| 376 |
+
action_cond: Optional[torch.Tensor],
|
| 377 |
+
pose_cond,
|
| 378 |
+
curr_noise_level: torch.Tensor,
|
| 379 |
+
guidance_fn: Optional[Callable] = None,
|
| 380 |
+
reference_length=None,
|
| 381 |
+
frame_idx=None,
|
| 382 |
+
):
|
| 383 |
+
clipped_curr_noise_level = torch.where(
|
| 384 |
+
curr_noise_level < 0,
|
| 385 |
+
torch.full_like(curr_noise_level, self.stabilization_level - 1, dtype=torch.long),
|
| 386 |
+
curr_noise_level,
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
# treating as stabilization would require us to scale with sqrt of alpha_cum
|
| 390 |
+
orig_x = x.clone().detach()
|
| 391 |
+
scaled_context = self.q_sample(
|
| 392 |
+
x,
|
| 393 |
+
clipped_curr_noise_level,
|
| 394 |
+
noise=torch.zeros_like(x),
|
| 395 |
+
)
|
| 396 |
+
x = torch.where(self.add_shape_channels(curr_noise_level < 0), scaled_context, orig_x)
|
| 397 |
+
|
| 398 |
+
if guidance_fn is not None:
|
| 399 |
+
raise NotImplementedError("Guidance function is not implemented for ddpm sampling yet.")
|
| 400 |
+
|
| 401 |
+
else:
|
| 402 |
+
model_mean, _, model_log_variance = self.p_mean_variance(
|
| 403 |
+
x=x,
|
| 404 |
+
t=clipped_curr_noise_level,
|
| 405 |
+
action_cond=action_cond,
|
| 406 |
+
pose_cond=pose_cond,
|
| 407 |
+
reference_length=reference_length,
|
| 408 |
+
frame_idx=frame_idx
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
noise = torch.where(
|
| 412 |
+
self.add_shape_channels(clipped_curr_noise_level > 0),
|
| 413 |
+
torch.randn_like(x),
|
| 414 |
+
0,
|
| 415 |
+
)
|
| 416 |
+
noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
|
| 417 |
+
x_pred = model_mean + torch.exp(0.5 * model_log_variance) * noise
|
| 418 |
+
|
| 419 |
+
# only update frames where the noise level decreases
|
| 420 |
+
return torch.where(self.add_shape_channels(curr_noise_level == -1), orig_x, x_pred)
|
| 421 |
+
|
| 422 |
+
def ddim_sample_step(
|
| 423 |
+
self,
|
| 424 |
+
x: torch.Tensor,
|
| 425 |
+
action_cond: Optional[torch.Tensor],
|
| 426 |
+
pose_cond,
|
| 427 |
+
curr_noise_level: torch.Tensor,
|
| 428 |
+
next_noise_level: torch.Tensor,
|
| 429 |
+
guidance_fn: Optional[Callable] = None,
|
| 430 |
+
current_frame=None,
|
| 431 |
+
mode="training",
|
| 432 |
+
reference_length=None,
|
| 433 |
+
frame_idx=None
|
| 434 |
+
):
|
| 435 |
+
# convert noise level -1 to self.stabilization_level - 1
|
| 436 |
+
clipped_curr_noise_level = torch.where(
|
| 437 |
+
curr_noise_level < 0,
|
| 438 |
+
torch.full_like(curr_noise_level, self.stabilization_level - 1, dtype=torch.long),
|
| 439 |
+
curr_noise_level,
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
# treating as stabilization would require us to scale with sqrt of alpha_cum
|
| 443 |
+
orig_x = x.clone().detach()
|
| 444 |
+
scaled_context = self.q_sample(
|
| 445 |
+
x,
|
| 446 |
+
clipped_curr_noise_level,
|
| 447 |
+
noise=torch.zeros_like(x),
|
| 448 |
+
)
|
| 449 |
+
x = torch.where(self.add_shape_channels(curr_noise_level < 0), scaled_context, orig_x)
|
| 450 |
+
|
| 451 |
+
alpha = self.alphas_cumprod[clipped_curr_noise_level]
|
| 452 |
+
alpha_next = torch.where(
|
| 453 |
+
next_noise_level < 0,
|
| 454 |
+
torch.ones_like(next_noise_level),
|
| 455 |
+
self.alphas_cumprod[next_noise_level],
|
| 456 |
+
)
|
| 457 |
+
sigma = torch.where(
|
| 458 |
+
next_noise_level < 0,
|
| 459 |
+
torch.zeros_like(next_noise_level),
|
| 460 |
+
self.ddim_sampling_eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt(),
|
| 461 |
+
)
|
| 462 |
+
c = (1 - alpha_next - sigma**2).sqrt()
|
| 463 |
+
|
| 464 |
+
alpha_next = self.add_shape_channels(alpha_next)
|
| 465 |
+
c = self.add_shape_channels(c)
|
| 466 |
+
sigma = self.add_shape_channels(sigma)
|
| 467 |
+
|
| 468 |
+
if guidance_fn is not None:
|
| 469 |
+
with torch.enable_grad():
|
| 470 |
+
x = x.detach().requires_grad_()
|
| 471 |
+
|
| 472 |
+
model_pred = self.model_predictions(
|
| 473 |
+
x=x,
|
| 474 |
+
t=clipped_curr_noise_level,
|
| 475 |
+
action_cond=action_cond,
|
| 476 |
+
pose_cond=pose_cond,
|
| 477 |
+
current_frame=current_frame,
|
| 478 |
+
mode=mode,
|
| 479 |
+
reference_length=reference_length,
|
| 480 |
+
frame_idx=frame_idx
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
guidance_loss = guidance_fn(model_pred.pred_x_start)
|
| 484 |
+
grad = -torch.autograd.grad(
|
| 485 |
+
guidance_loss,
|
| 486 |
+
x,
|
| 487 |
+
)[0]
|
| 488 |
+
|
| 489 |
+
pred_noise = model_pred.pred_noise + (1 - alpha_next).sqrt() * grad
|
| 490 |
+
x_start = self.predict_start_from_noise(x, clipped_curr_noise_level, pred_noise)
|
| 491 |
+
|
| 492 |
+
else:
|
| 493 |
+
# print(clipped_curr_noise_level)
|
| 494 |
+
model_pred = self.model_predictions(
|
| 495 |
+
x=x,
|
| 496 |
+
t=clipped_curr_noise_level,
|
| 497 |
+
action_cond=action_cond,
|
| 498 |
+
pose_cond=pose_cond,
|
| 499 |
+
current_frame=current_frame,
|
| 500 |
+
mode=mode,
|
| 501 |
+
reference_length=reference_length,
|
| 502 |
+
frame_idx=frame_idx
|
| 503 |
+
)
|
| 504 |
+
x_start = model_pred.pred_x_start
|
| 505 |
+
pred_noise = model_pred.pred_noise
|
| 506 |
+
|
| 507 |
+
noise = torch.randn_like(x)
|
| 508 |
+
noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
|
| 509 |
+
|
| 510 |
+
x_pred = x_start * alpha_next.sqrt() + pred_noise * c + sigma * noise
|
| 511 |
+
|
| 512 |
+
# only update frames where the noise level decreases
|
| 513 |
+
mask = curr_noise_level == next_noise_level
|
| 514 |
+
x_pred = torch.where(
|
| 515 |
+
self.add_shape_channels(mask),
|
| 516 |
+
orig_x,
|
| 517 |
+
x_pred,
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
return x_pred
|
algorithms/worldmem/models/dit.py
ADDED
|
@@ -0,0 +1,572 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
References:
|
| 3 |
+
- DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
|
| 4 |
+
- Diffusion Forcing: https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/unet3d.py
|
| 5 |
+
- Latte: https://github.com/Vchitect/Latte/blob/main/models/latte.py
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Optional, Literal
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
from .rotary_embedding_torch import RotaryEmbedding
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
from .attention import SpatialAxialAttention, TemporalAxialAttention, MemTemporalAxialAttention, MemFullAttention
|
| 14 |
+
from timm.models.vision_transformer import Mlp
|
| 15 |
+
from timm.layers.helpers import to_2tuple
|
| 16 |
+
import math
|
| 17 |
+
from collections import namedtuple
|
| 18 |
+
from typing import Optional, Callable
|
| 19 |
+
from .cameractrl_module import SimpleCameraPoseEncoder
|
| 20 |
+
|
| 21 |
+
def modulate(x, shift, scale):
|
| 22 |
+
fixed_dims = [1] * len(shift.shape[1:])
|
| 23 |
+
shift = shift.repeat(x.shape[0] // shift.shape[0], *fixed_dims)
|
| 24 |
+
scale = scale.repeat(x.shape[0] // scale.shape[0], *fixed_dims)
|
| 25 |
+
while shift.dim() < x.dim():
|
| 26 |
+
shift = shift.unsqueeze(-2)
|
| 27 |
+
scale = scale.unsqueeze(-2)
|
| 28 |
+
return x * (1 + scale) + shift
|
| 29 |
+
|
| 30 |
+
def gate(x, g):
|
| 31 |
+
fixed_dims = [1] * len(g.shape[1:])
|
| 32 |
+
g = g.repeat(x.shape[0] // g.shape[0], *fixed_dims)
|
| 33 |
+
while g.dim() < x.dim():
|
| 34 |
+
g = g.unsqueeze(-2)
|
| 35 |
+
return g * x
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class PatchEmbed(nn.Module):
|
| 39 |
+
"""2D Image to Patch Embedding"""
|
| 40 |
+
|
| 41 |
+
def __init__(
|
| 42 |
+
self,
|
| 43 |
+
img_height=256,
|
| 44 |
+
img_width=256,
|
| 45 |
+
patch_size=16,
|
| 46 |
+
in_chans=3,
|
| 47 |
+
embed_dim=768,
|
| 48 |
+
norm_layer=None,
|
| 49 |
+
flatten=True,
|
| 50 |
+
):
|
| 51 |
+
super().__init__()
|
| 52 |
+
img_size = (img_height, img_width)
|
| 53 |
+
patch_size = to_2tuple(patch_size)
|
| 54 |
+
self.img_size = img_size
|
| 55 |
+
self.patch_size = patch_size
|
| 56 |
+
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
| 57 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
| 58 |
+
self.flatten = flatten
|
| 59 |
+
|
| 60 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 61 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 62 |
+
|
| 63 |
+
def forward(self, x, random_sample=False):
|
| 64 |
+
B, C, H, W = x.shape
|
| 65 |
+
assert random_sample or (H == self.img_size[0] and W == self.img_size[1]), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
| 66 |
+
|
| 67 |
+
x = self.proj(x)
|
| 68 |
+
if self.flatten:
|
| 69 |
+
x = rearrange(x, "B C H W -> B (H W) C")
|
| 70 |
+
else:
|
| 71 |
+
x = rearrange(x, "B C H W -> B H W C")
|
| 72 |
+
x = self.norm(x)
|
| 73 |
+
return x
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class TimestepEmbedder(nn.Module):
|
| 77 |
+
"""
|
| 78 |
+
Embeds scalar timesteps into vector representations.
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
def __init__(self, hidden_size, frequency_embedding_size=256, freq_type='time_step'):
|
| 82 |
+
super().__init__()
|
| 83 |
+
self.mlp = nn.Sequential(
|
| 84 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True), # hidden_size is diffusion model hidden size
|
| 85 |
+
nn.SiLU(),
|
| 86 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 87 |
+
)
|
| 88 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 89 |
+
self.freq_type = freq_type
|
| 90 |
+
|
| 91 |
+
@staticmethod
|
| 92 |
+
def timestep_embedding(t, dim, max_period=10000, freq_type='time_step'):
|
| 93 |
+
"""
|
| 94 |
+
Create sinusoidal timestep embeddings.
|
| 95 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
| 96 |
+
These may be fractional.
|
| 97 |
+
:param dim: the dimension of the output.
|
| 98 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 99 |
+
:return: an (N, D) Tensor of positional embeddings.
|
| 100 |
+
"""
|
| 101 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
| 102 |
+
half = dim // 2
|
| 103 |
+
|
| 104 |
+
if freq_type == 'time_step':
|
| 105 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
| 106 |
+
elif freq_type == 'spatial': # ~(-5 5)
|
| 107 |
+
freqs = torch.linspace(1.0, half, half).to(device=t.device) * torch.pi
|
| 108 |
+
elif freq_type == 'angle': # 0-360
|
| 109 |
+
freqs = torch.linspace(1.0, half, half).to(device=t.device) * torch.pi / 180
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
args = t[:, None].float() * freqs[None]
|
| 113 |
+
|
| 114 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 115 |
+
if dim % 2:
|
| 116 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 117 |
+
return embedding
|
| 118 |
+
|
| 119 |
+
def forward(self, t):
|
| 120 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size, freq_type=self.freq_type)
|
| 121 |
+
t_emb = self.mlp(t_freq)
|
| 122 |
+
return t_emb
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class FinalLayer(nn.Module):
|
| 126 |
+
"""
|
| 127 |
+
The final layer of DiT.
|
| 128 |
+
"""
|
| 129 |
+
|
| 130 |
+
def __init__(self, hidden_size, patch_size, out_channels):
|
| 131 |
+
super().__init__()
|
| 132 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 133 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
| 134 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
| 135 |
+
|
| 136 |
+
def forward(self, x, c):
|
| 137 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
| 138 |
+
x = modulate(self.norm_final(x), shift, scale)
|
| 139 |
+
x = self.linear(x)
|
| 140 |
+
return x
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class SpatioTemporalDiTBlock(nn.Module):
|
| 144 |
+
def __init__(
|
| 145 |
+
self,
|
| 146 |
+
hidden_size,
|
| 147 |
+
num_heads,
|
| 148 |
+
reference_length,
|
| 149 |
+
mlp_ratio=4.0,
|
| 150 |
+
is_causal=True,
|
| 151 |
+
spatial_rotary_emb: Optional[RotaryEmbedding] = None,
|
| 152 |
+
temporal_rotary_emb: Optional[RotaryEmbedding] = None,
|
| 153 |
+
reference_rotary_emb=None,
|
| 154 |
+
use_plucker=False,
|
| 155 |
+
relative_embedding=False,
|
| 156 |
+
state_embed_only_on_qk=False,
|
| 157 |
+
use_memory_attention=False,
|
| 158 |
+
ref_mode='sequential'
|
| 159 |
+
):
|
| 160 |
+
super().__init__()
|
| 161 |
+
self.is_causal = is_causal
|
| 162 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 163 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
| 164 |
+
|
| 165 |
+
self.s_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 166 |
+
self.s_attn = SpatialAxialAttention(
|
| 167 |
+
hidden_size,
|
| 168 |
+
heads=num_heads,
|
| 169 |
+
dim_head=hidden_size // num_heads,
|
| 170 |
+
rotary_emb=spatial_rotary_emb
|
| 171 |
+
)
|
| 172 |
+
self.s_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 173 |
+
self.s_mlp = Mlp(
|
| 174 |
+
in_features=hidden_size,
|
| 175 |
+
hidden_features=mlp_hidden_dim,
|
| 176 |
+
act_layer=approx_gelu,
|
| 177 |
+
drop=0,
|
| 178 |
+
)
|
| 179 |
+
self.s_adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
|
| 180 |
+
|
| 181 |
+
self.t_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 182 |
+
self.t_attn = TemporalAxialAttention(
|
| 183 |
+
hidden_size,
|
| 184 |
+
heads=num_heads,
|
| 185 |
+
dim_head=hidden_size // num_heads,
|
| 186 |
+
is_causal=is_causal,
|
| 187 |
+
rotary_emb=temporal_rotary_emb,
|
| 188 |
+
reference_length=reference_length
|
| 189 |
+
)
|
| 190 |
+
self.t_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 191 |
+
self.t_mlp = Mlp(
|
| 192 |
+
in_features=hidden_size,
|
| 193 |
+
hidden_features=mlp_hidden_dim,
|
| 194 |
+
act_layer=approx_gelu,
|
| 195 |
+
drop=0,
|
| 196 |
+
)
|
| 197 |
+
self.t_adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
|
| 198 |
+
|
| 199 |
+
self.use_memory_attention = use_memory_attention
|
| 200 |
+
if self.use_memory_attention:
|
| 201 |
+
self.r_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 202 |
+
self.ref_type = "full_ref"
|
| 203 |
+
if self.ref_type == "temporal_ref":
|
| 204 |
+
self.r_attn = MemTemporalAxialAttention(
|
| 205 |
+
hidden_size,
|
| 206 |
+
heads=num_heads,
|
| 207 |
+
dim_head=hidden_size // num_heads,
|
| 208 |
+
is_causal=is_causal,
|
| 209 |
+
rotary_emb=None
|
| 210 |
+
)
|
| 211 |
+
elif self.ref_type == "full_ref":
|
| 212 |
+
self.r_attn = MemFullAttention(
|
| 213 |
+
hidden_size,
|
| 214 |
+
heads=num_heads,
|
| 215 |
+
dim_head=hidden_size // num_heads,
|
| 216 |
+
is_causal=is_causal,
|
| 217 |
+
rotary_emb=reference_rotary_emb,
|
| 218 |
+
reference_length=reference_length
|
| 219 |
+
)
|
| 220 |
+
self.r_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 221 |
+
self.r_mlp = Mlp(
|
| 222 |
+
in_features=hidden_size,
|
| 223 |
+
hidden_features=mlp_hidden_dim,
|
| 224 |
+
act_layer=approx_gelu,
|
| 225 |
+
drop=0,
|
| 226 |
+
)
|
| 227 |
+
self.r_adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
|
| 228 |
+
|
| 229 |
+
self.use_plucker = use_plucker
|
| 230 |
+
if use_plucker:
|
| 231 |
+
self.pose_cond_mlp = nn.Linear(hidden_size, hidden_size)
|
| 232 |
+
self.temporal_pose_cond_mlp = nn.Linear(hidden_size, hidden_size)
|
| 233 |
+
|
| 234 |
+
self.reference_length = reference_length
|
| 235 |
+
self.relative_embedding = relative_embedding
|
| 236 |
+
self.state_embed_only_on_qk = state_embed_only_on_qk
|
| 237 |
+
|
| 238 |
+
self.ref_mode = ref_mode
|
| 239 |
+
|
| 240 |
+
if self.ref_mode == 'parallel':
|
| 241 |
+
self.parallel_map = nn.Linear(hidden_size, hidden_size)
|
| 242 |
+
|
| 243 |
+
def forward(self, x, c, current_frame=None, timestep=None, is_last_block=False,
|
| 244 |
+
pose_cond=None, mode="training", c_action_cond=None, reference_length=None):
|
| 245 |
+
B, T, H, W, D = x.shape
|
| 246 |
+
|
| 247 |
+
# spatial block
|
| 248 |
+
|
| 249 |
+
s_shift_msa, s_scale_msa, s_gate_msa, s_shift_mlp, s_scale_mlp, s_gate_mlp = self.s_adaLN_modulation(c).chunk(6, dim=-1)
|
| 250 |
+
x = x + gate(self.s_attn(modulate(self.s_norm1(x), s_shift_msa, s_scale_msa)), s_gate_msa)
|
| 251 |
+
x = x + gate(self.s_mlp(modulate(self.s_norm2(x), s_shift_mlp, s_scale_mlp)), s_gate_mlp)
|
| 252 |
+
|
| 253 |
+
# temporal block
|
| 254 |
+
if c_action_cond is not None:
|
| 255 |
+
t_shift_msa, t_scale_msa, t_gate_msa, t_shift_mlp, t_scale_mlp, t_gate_mlp = self.t_adaLN_modulation(c_action_cond).chunk(6, dim=-1)
|
| 256 |
+
else:
|
| 257 |
+
t_shift_msa, t_scale_msa, t_gate_msa, t_shift_mlp, t_scale_mlp, t_gate_mlp = self.t_adaLN_modulation(c).chunk(6, dim=-1)
|
| 258 |
+
|
| 259 |
+
x_t = x + gate(self.t_attn(modulate(self.t_norm1(x), t_shift_msa, t_scale_msa)), t_gate_msa)
|
| 260 |
+
x_t = x_t + gate(self.t_mlp(modulate(self.t_norm2(x_t), t_shift_mlp, t_scale_mlp)), t_gate_mlp)
|
| 261 |
+
|
| 262 |
+
if self.ref_mode == 'sequential':
|
| 263 |
+
x = x_t
|
| 264 |
+
|
| 265 |
+
# memory block
|
| 266 |
+
relative_embedding = self.relative_embedding # and mode == "training"
|
| 267 |
+
|
| 268 |
+
if self.use_memory_attention:
|
| 269 |
+
r_shift_msa, r_scale_msa, r_gate_msa, r_shift_mlp, r_scale_mlp, r_gate_mlp = self.r_adaLN_modulation(c).chunk(6, dim=-1)
|
| 270 |
+
|
| 271 |
+
if pose_cond is not None:
|
| 272 |
+
if self.use_plucker:
|
| 273 |
+
input_cond = self.pose_cond_mlp(pose_cond)
|
| 274 |
+
|
| 275 |
+
if relative_embedding:
|
| 276 |
+
n_frames = x.shape[1] - reference_length
|
| 277 |
+
x1_relative_embedding = []
|
| 278 |
+
r_shift_msa_relative_embedding = []
|
| 279 |
+
r_scale_msa_relative_embedding = []
|
| 280 |
+
for i in range(n_frames):
|
| 281 |
+
x1_relative_embedding.append(torch.cat([x[:,i:i+1], x[:, -reference_length:]], dim=1).clone())
|
| 282 |
+
r_shift_msa_relative_embedding.append(torch.cat([r_shift_msa[:,i:i+1], r_shift_msa[:, -reference_length:]], dim=1).clone())
|
| 283 |
+
r_scale_msa_relative_embedding.append(torch.cat([r_scale_msa[:,i:i+1], r_scale_msa[:, -reference_length:]], dim=1).clone())
|
| 284 |
+
x1_zero_frame = torch.cat(x1_relative_embedding, dim=1)
|
| 285 |
+
r_shift_msa = torch.cat(r_shift_msa_relative_embedding, dim=1)
|
| 286 |
+
r_scale_msa = torch.cat(r_scale_msa_relative_embedding, dim=1)
|
| 287 |
+
|
| 288 |
+
# if current_frame == 18:
|
| 289 |
+
# import pdb;pdb.set_trace()
|
| 290 |
+
|
| 291 |
+
if self.state_embed_only_on_qk:
|
| 292 |
+
attn_input = x1_zero_frame
|
| 293 |
+
extra_condition = input_cond
|
| 294 |
+
else:
|
| 295 |
+
attn_input = input_cond + x1_zero_frame
|
| 296 |
+
extra_condition = None
|
| 297 |
+
else:
|
| 298 |
+
attn_input = input_cond + x
|
| 299 |
+
extra_condition = None
|
| 300 |
+
# print("input_cond2:", input_cond.abs().mean())
|
| 301 |
+
# print("c:", c.abs().mean())
|
| 302 |
+
# input_cond = x1
|
| 303 |
+
|
| 304 |
+
x = x + gate(self.r_attn(modulate(self.r_norm1(attn_input), r_shift_msa, r_scale_msa),
|
| 305 |
+
relative_embedding=relative_embedding,
|
| 306 |
+
extra_condition=extra_condition,
|
| 307 |
+
state_embed_only_on_qk=self.state_embed_only_on_qk,
|
| 308 |
+
reference_length=reference_length), r_gate_msa)
|
| 309 |
+
else:
|
| 310 |
+
# pose_cond *= 0
|
| 311 |
+
x = x + gate(self.r_attn(modulate(self.r_norm1(x+pose_cond[:,:,None, None]), r_shift_msa, r_scale_msa),
|
| 312 |
+
current_frame=current_frame, timestep=timestep,
|
| 313 |
+
is_last_block=is_last_block,
|
| 314 |
+
reference_length=reference_length), r_gate_msa)
|
| 315 |
+
else:
|
| 316 |
+
x = x + gate(self.r_attn(modulate(self.r_norm1(x), r_shift_msa, r_scale_msa), current_frame=current_frame, timestep=timestep,
|
| 317 |
+
is_last_block=is_last_block), r_gate_msa)
|
| 318 |
+
|
| 319 |
+
x = x + gate(self.r_mlp(modulate(self.r_norm2(x), r_shift_mlp, r_scale_mlp)), r_gate_mlp)
|
| 320 |
+
|
| 321 |
+
if self.ref_mode == 'parallel':
|
| 322 |
+
x = x_t + self.parallel_map(x)
|
| 323 |
+
|
| 324 |
+
return x
|
| 325 |
+
|
| 326 |
+
# print((x1-x2).abs().sum())
|
| 327 |
+
# r_shift_msa, r_scale_msa, r_gate_msa, r_shift_mlp, r_scale_mlp, r_gate_mlp = self.r_adaLN_modulation(c).chunk(6, dim=-1)
|
| 328 |
+
# x2 = x1 + gate(self.r_attn(modulate(self.r_norm1(x_), r_shift_msa, r_scale_msa)), r_gate_msa)
|
| 329 |
+
# x2 = gate(self.r_mlp(modulate(self.r_norm2(x2), r_shift_mlp, r_scale_mlp)), r_gate_mlp)
|
| 330 |
+
# x = x1 + x2
|
| 331 |
+
|
| 332 |
+
# print(x.mean())
|
| 333 |
+
# return x
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class DiT(nn.Module):
|
| 337 |
+
"""
|
| 338 |
+
Diffusion model with a Transformer backbone.
|
| 339 |
+
"""
|
| 340 |
+
|
| 341 |
+
def __init__(
|
| 342 |
+
self,
|
| 343 |
+
input_h=18,
|
| 344 |
+
input_w=32,
|
| 345 |
+
patch_size=2,
|
| 346 |
+
in_channels=16,
|
| 347 |
+
hidden_size=1024,
|
| 348 |
+
depth=12,
|
| 349 |
+
num_heads=16,
|
| 350 |
+
mlp_ratio=4.0,
|
| 351 |
+
action_cond_dim=25,
|
| 352 |
+
pose_cond_dim=4,
|
| 353 |
+
max_frames=32,
|
| 354 |
+
reference_length=8,
|
| 355 |
+
use_plucker=False,
|
| 356 |
+
relative_embedding=False,
|
| 357 |
+
state_embed_only_on_qk=False,
|
| 358 |
+
use_memory_attention=False,
|
| 359 |
+
add_timestamp_embedding=False,
|
| 360 |
+
ref_mode='sequential'
|
| 361 |
+
):
|
| 362 |
+
super().__init__()
|
| 363 |
+
self.in_channels = in_channels
|
| 364 |
+
self.out_channels = in_channels
|
| 365 |
+
self.patch_size = patch_size
|
| 366 |
+
self.num_heads = num_heads
|
| 367 |
+
self.max_frames = max_frames
|
| 368 |
+
|
| 369 |
+
self.x_embedder = PatchEmbed(input_h, input_w, patch_size, in_channels, hidden_size, flatten=False)
|
| 370 |
+
self.t_embedder = TimestepEmbedder(hidden_size)
|
| 371 |
+
|
| 372 |
+
self.add_timestamp_embedding = add_timestamp_embedding
|
| 373 |
+
if self.add_timestamp_embedding:
|
| 374 |
+
self.timestamp_embedding = TimestepEmbedder(hidden_size)
|
| 375 |
+
|
| 376 |
+
frame_h, frame_w = self.x_embedder.grid_size
|
| 377 |
+
|
| 378 |
+
self.spatial_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads // 2, freqs_for="pixel", max_freq=256)
|
| 379 |
+
self.temporal_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads)
|
| 380 |
+
# self.reference_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads // 2, freqs_for="pixel", max_freq=256)
|
| 381 |
+
self.reference_rotary_emb = None
|
| 382 |
+
|
| 383 |
+
self.external_cond = nn.Linear(action_cond_dim, hidden_size) if action_cond_dim > 0 else nn.Identity()
|
| 384 |
+
|
| 385 |
+
# self.pose_cond = nn.Linear(pose_cond_dim, hidden_size) if pose_cond_dim > 0 else nn.Identity()
|
| 386 |
+
|
| 387 |
+
self.use_plucker = use_plucker
|
| 388 |
+
if not self.use_plucker:
|
| 389 |
+
self.position_embedder = TimestepEmbedder(hidden_size, freq_type='spatial')
|
| 390 |
+
self.angle_embedder = TimestepEmbedder(hidden_size, freq_type='angle')
|
| 391 |
+
else:
|
| 392 |
+
self.pose_embedder = SimpleCameraPoseEncoder(c_in=6, c_out=hidden_size)
|
| 393 |
+
|
| 394 |
+
self.blocks = nn.ModuleList(
|
| 395 |
+
[
|
| 396 |
+
SpatioTemporalDiTBlock(
|
| 397 |
+
hidden_size,
|
| 398 |
+
num_heads,
|
| 399 |
+
mlp_ratio=mlp_ratio,
|
| 400 |
+
is_causal=True,
|
| 401 |
+
reference_length=reference_length,
|
| 402 |
+
spatial_rotary_emb=self.spatial_rotary_emb,
|
| 403 |
+
temporal_rotary_emb=self.temporal_rotary_emb,
|
| 404 |
+
reference_rotary_emb=self.reference_rotary_emb,
|
| 405 |
+
use_plucker=self.use_plucker,
|
| 406 |
+
relative_embedding=relative_embedding,
|
| 407 |
+
state_embed_only_on_qk=state_embed_only_on_qk,
|
| 408 |
+
use_memory_attention=use_memory_attention,
|
| 409 |
+
ref_mode=ref_mode
|
| 410 |
+
)
|
| 411 |
+
for _ in range(depth)
|
| 412 |
+
]
|
| 413 |
+
)
|
| 414 |
+
self.use_memory_attention = use_memory_attention
|
| 415 |
+
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
|
| 416 |
+
self.initialize_weights()
|
| 417 |
+
|
| 418 |
+
def initialize_weights(self):
|
| 419 |
+
# Initialize transformer layers:
|
| 420 |
+
def _basic_init(module):
|
| 421 |
+
if isinstance(module, nn.Linear):
|
| 422 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 423 |
+
if module.bias is not None:
|
| 424 |
+
nn.init.constant_(module.bias, 0)
|
| 425 |
+
|
| 426 |
+
self.apply(_basic_init)
|
| 427 |
+
|
| 428 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
| 429 |
+
w = self.x_embedder.proj.weight.data
|
| 430 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 431 |
+
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
| 432 |
+
|
| 433 |
+
# Initialize timestep embedding MLP:
|
| 434 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
| 435 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
| 436 |
+
|
| 437 |
+
if self.use_memory_attention:
|
| 438 |
+
if not self.use_plucker:
|
| 439 |
+
nn.init.normal_(self.position_embedder.mlp[0].weight, std=0.02)
|
| 440 |
+
nn.init.normal_(self.position_embedder.mlp[2].weight, std=0.02)
|
| 441 |
+
|
| 442 |
+
nn.init.normal_(self.angle_embedder.mlp[0].weight, std=0.02)
|
| 443 |
+
nn.init.normal_(self.angle_embedder.mlp[2].weight, std=0.02)
|
| 444 |
+
|
| 445 |
+
if self.add_timestamp_embedding:
|
| 446 |
+
nn.init.normal_(self.timestamp_embedding.mlp[0].weight, std=0.02)
|
| 447 |
+
nn.init.normal_(self.timestamp_embedding.mlp[2].weight, std=0.02)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
# Zero-out adaLN modulation layers in DiT blocks:
|
| 451 |
+
for block in self.blocks:
|
| 452 |
+
nn.init.constant_(block.s_adaLN_modulation[-1].weight, 0)
|
| 453 |
+
nn.init.constant_(block.s_adaLN_modulation[-1].bias, 0)
|
| 454 |
+
nn.init.constant_(block.t_adaLN_modulation[-1].weight, 0)
|
| 455 |
+
nn.init.constant_(block.t_adaLN_modulation[-1].bias, 0)
|
| 456 |
+
|
| 457 |
+
if self.use_plucker and self.use_memory_attention:
|
| 458 |
+
nn.init.constant_(block.pose_cond_mlp.weight, 0)
|
| 459 |
+
nn.init.constant_(block.pose_cond_mlp.bias, 0)
|
| 460 |
+
|
| 461 |
+
# Zero-out output layers:
|
| 462 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
| 463 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
| 464 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
| 465 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
| 466 |
+
|
| 467 |
+
def unpatchify(self, x):
|
| 468 |
+
"""
|
| 469 |
+
x: (N, H, W, patch_size**2 * C)
|
| 470 |
+
imgs: (N, H, W, C)
|
| 471 |
+
"""
|
| 472 |
+
c = self.out_channels
|
| 473 |
+
p = self.x_embedder.patch_size[0]
|
| 474 |
+
h = x.shape[1]
|
| 475 |
+
w = x.shape[2]
|
| 476 |
+
|
| 477 |
+
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
| 478 |
+
x = torch.einsum("nhwpqc->nchpwq", x)
|
| 479 |
+
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
| 480 |
+
return imgs
|
| 481 |
+
|
| 482 |
+
def forward(self, x, t, action_cond=None, pose_cond=None, current_frame=None, mode=None,
|
| 483 |
+
reference_length=None, frame_idx=None):
|
| 484 |
+
"""
|
| 485 |
+
Forward pass of DiT.
|
| 486 |
+
x: (B, T, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
| 487 |
+
t: (B, T,) tensor of diffusion timesteps
|
| 488 |
+
"""
|
| 489 |
+
|
| 490 |
+
B, T, C, H, W = x.shape
|
| 491 |
+
|
| 492 |
+
# add spatial embeddings
|
| 493 |
+
x = rearrange(x, "b t c h w -> (b t) c h w")
|
| 494 |
+
|
| 495 |
+
x = self.x_embedder(x) # (B*T, C, H, W) -> (B*T, H/2, W/2, D) , C = 16, D = d_model
|
| 496 |
+
# restore shape
|
| 497 |
+
x = rearrange(x, "(b t) h w d -> b t h w d", t=T)
|
| 498 |
+
# embed noise steps
|
| 499 |
+
t = rearrange(t, "b t -> (b t)")
|
| 500 |
+
|
| 501 |
+
c_t = self.t_embedder(t) # (N, D)
|
| 502 |
+
c = c_t.clone()
|
| 503 |
+
c = rearrange(c, "(b t) d -> b t d", t=T)
|
| 504 |
+
|
| 505 |
+
if torch.is_tensor(action_cond):
|
| 506 |
+
try:
|
| 507 |
+
c_action_cond = c + self.external_cond(action_cond)
|
| 508 |
+
except:
|
| 509 |
+
import pdb;pdb.set_trace()
|
| 510 |
+
else:
|
| 511 |
+
c_action_cond = None
|
| 512 |
+
|
| 513 |
+
if torch.is_tensor(pose_cond):
|
| 514 |
+
if not self.use_plucker:
|
| 515 |
+
pose_cond = pose_cond.to(action_cond.dtype)
|
| 516 |
+
b_, t_, d_ = pose_cond.shape
|
| 517 |
+
pos_emb = self.position_embedder(rearrange(pose_cond[...,:3], "b t d -> (b t d)"))
|
| 518 |
+
angle_emb = self.angle_embedder(rearrange(pose_cond[...,3:], "b t d -> (b t d)"))
|
| 519 |
+
pos_emb = rearrange(pos_emb, "(b t d) c -> b t d c", b=b_, t=t_, d=3).sum(-2)
|
| 520 |
+
angle_emb = rearrange(angle_emb, "(b t d) c -> b t d c", b=b_, t=t_, d=2).sum(-2)
|
| 521 |
+
pc = pos_emb + angle_emb
|
| 522 |
+
else:
|
| 523 |
+
pose_cond = pose_cond[:, :, ::40, ::40]
|
| 524 |
+
# pc = self.pose_embedder(pose_cond)[0]
|
| 525 |
+
# pc = pc.permute(0,2,3,4,1)
|
| 526 |
+
pc = self.pose_embedder(pose_cond)
|
| 527 |
+
pc = pc.permute(1,0,2,3,4)
|
| 528 |
+
|
| 529 |
+
if torch.is_tensor(frame_idx) and self.add_timestamp_embedding:
|
| 530 |
+
bb = frame_idx.shape[1]
|
| 531 |
+
frame_idx = rearrange(frame_idx, "t b -> (b t)")
|
| 532 |
+
frame_idx = self.timestamp_embedding(frame_idx)
|
| 533 |
+
frame_idx = rearrange(frame_idx, "(b t) d -> b t d", b=bb)
|
| 534 |
+
pc = pc + frame_idx[:, :, None, None]
|
| 535 |
+
|
| 536 |
+
# pc = pc + rearrange(c_t.clone(), "(b t) d -> b t d", t=T)[:,:,None,None] # add time condition for different timestep scaling
|
| 537 |
+
else:
|
| 538 |
+
pc = None
|
| 539 |
+
|
| 540 |
+
for i, block in enumerate(self.blocks):
|
| 541 |
+
x = block(x, c, current_frame=current_frame, timestep=t, is_last_block= (i+1 == len(self.blocks)),
|
| 542 |
+
pose_cond=pc, mode=mode, c_action_cond=c_action_cond, reference_length=reference_length) # (N, T, H, W, D)
|
| 543 |
+
x = self.final_layer(x, c) # (N, T, H, W, patch_size ** 2 * out_channels)
|
| 544 |
+
# unpatchify
|
| 545 |
+
x = rearrange(x, "b t h w d -> (b t) h w d")
|
| 546 |
+
x = self.unpatchify(x) # (N, out_channels, H, W)
|
| 547 |
+
x = rearrange(x, "(b t) c h w -> b t c h w", t=T)
|
| 548 |
+
return x
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def DiT_S_2(action_cond_dim, pose_cond_dim, reference_length,
|
| 552 |
+
use_plucker, relative_embedding,
|
| 553 |
+
state_embed_only_on_qk, use_memory_attention, add_timestamp_embedding,
|
| 554 |
+
ref_mode):
|
| 555 |
+
return DiT(
|
| 556 |
+
patch_size=2,
|
| 557 |
+
hidden_size=1024,
|
| 558 |
+
depth=16,
|
| 559 |
+
num_heads=16,
|
| 560 |
+
action_cond_dim=action_cond_dim,
|
| 561 |
+
pose_cond_dim=pose_cond_dim,
|
| 562 |
+
reference_length=reference_length,
|
| 563 |
+
use_plucker=use_plucker,
|
| 564 |
+
relative_embedding=relative_embedding,
|
| 565 |
+
state_embed_only_on_qk=state_embed_only_on_qk,
|
| 566 |
+
use_memory_attention=use_memory_attention,
|
| 567 |
+
add_timestamp_embedding=add_timestamp_embedding,
|
| 568 |
+
ref_mode=ref_mode
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
DiT_models = {"DiT-S/2": DiT_S_2}
|
algorithms/worldmem/models/pose_prediction.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class PosePredictionNet(nn.Module):
|
| 6 |
+
def __init__(self, img_channels=16, img_feat_dim=256, pose_dim=5, action_dim=25, hidden_dim=128):
|
| 7 |
+
super(PosePredictionNet, self).__init__()
|
| 8 |
+
|
| 9 |
+
self.cnn = nn.Sequential(
|
| 10 |
+
nn.Conv2d(img_channels, 32, kernel_size=3, stride=2, padding=1),
|
| 11 |
+
nn.ReLU(),
|
| 12 |
+
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
|
| 13 |
+
nn.ReLU(),
|
| 14 |
+
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
| 15 |
+
nn.ReLU(),
|
| 16 |
+
nn.AdaptiveAvgPool2d((1, 1))
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
self.fc_img = nn.Linear(128, img_feat_dim)
|
| 20 |
+
|
| 21 |
+
self.mlp_motion = nn.Sequential(
|
| 22 |
+
nn.Linear(pose_dim + action_dim, hidden_dim),
|
| 23 |
+
nn.ReLU(),
|
| 24 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 25 |
+
nn.ReLU()
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
self.fc_out = nn.Sequential(
|
| 29 |
+
nn.Linear(img_feat_dim + hidden_dim, hidden_dim),
|
| 30 |
+
nn.ReLU(),
|
| 31 |
+
nn.Linear(hidden_dim, pose_dim)
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
def forward(self, img, action, pose):
|
| 35 |
+
img_feat = self.cnn(img).view(img.size(0), -1)
|
| 36 |
+
img_feat = self.fc_img(img_feat)
|
| 37 |
+
|
| 38 |
+
motion_feat = self.mlp_motion(torch.cat([pose, action], dim=1))
|
| 39 |
+
fused_feat = torch.cat([img_feat, motion_feat], dim=1)
|
| 40 |
+
pose_next_pred = self.fc_out(fused_feat)
|
| 41 |
+
|
| 42 |
+
return pose_next_pred
|
algorithms/worldmem/models/rotary_embedding_torch.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapted from https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
from math import pi, log
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch.nn import Module, ModuleList
|
| 10 |
+
from torch.amp import autocast
|
| 11 |
+
from torch import nn, einsum, broadcast_tensors, Tensor
|
| 12 |
+
|
| 13 |
+
from einops import rearrange, repeat
|
| 14 |
+
|
| 15 |
+
from typing import Literal
|
| 16 |
+
|
| 17 |
+
# helper functions
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def exists(val):
|
| 21 |
+
return val is not None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def default(val, d):
|
| 25 |
+
return val if exists(val) else d
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# broadcat, as tortoise-tts was using it
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def broadcat(tensors, dim=-1):
|
| 32 |
+
broadcasted_tensors = broadcast_tensors(*tensors)
|
| 33 |
+
return torch.cat(broadcasted_tensors, dim=dim)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# rotary embedding helper functions
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def rotate_half(x):
|
| 40 |
+
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
| 41 |
+
x1, x2 = x.unbind(dim=-1)
|
| 42 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 43 |
+
return rearrange(x, "... d r -> ... (d r)")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@autocast("cuda", enabled=False)
|
| 47 |
+
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
|
| 48 |
+
dtype = t.dtype
|
| 49 |
+
|
| 50 |
+
if t.ndim == 3:
|
| 51 |
+
seq_len = t.shape[seq_dim]
|
| 52 |
+
freqs = freqs[-seq_len:]
|
| 53 |
+
|
| 54 |
+
rot_dim = freqs.shape[-1]
|
| 55 |
+
end_index = start_index + rot_dim
|
| 56 |
+
|
| 57 |
+
assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
|
| 58 |
+
|
| 59 |
+
# Split t into three parts: left, middle (to be transformed), and right
|
| 60 |
+
t_left = t[..., :start_index]
|
| 61 |
+
t_middle = t[..., start_index:end_index]
|
| 62 |
+
t_right = t[..., end_index:]
|
| 63 |
+
|
| 64 |
+
# Apply rotary embeddings without modifying t in place
|
| 65 |
+
t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
|
| 66 |
+
|
| 67 |
+
out = torch.cat((t_left, t_transformed, t_right), dim=-1)
|
| 68 |
+
|
| 69 |
+
return out.type(dtype)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# learned rotation helpers
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
|
| 76 |
+
if exists(freq_ranges):
|
| 77 |
+
rotations = einsum("..., f -> ... f", rotations, freq_ranges)
|
| 78 |
+
rotations = rearrange(rotations, "... r f -> ... (r f)")
|
| 79 |
+
|
| 80 |
+
rotations = repeat(rotations, "... n -> ... (n r)", r=2)
|
| 81 |
+
return apply_rotary_emb(rotations, t, start_index=start_index)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# classes
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class RotaryEmbedding(Module):
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
dim,
|
| 91 |
+
custom_freqs: Tensor | None = None,
|
| 92 |
+
freqs_for: Literal["lang", "pixel", "constant"] = "lang",
|
| 93 |
+
theta=10000,
|
| 94 |
+
max_freq=10,
|
| 95 |
+
num_freqs=1,
|
| 96 |
+
learned_freq=False,
|
| 97 |
+
use_xpos=False,
|
| 98 |
+
xpos_scale_base=512,
|
| 99 |
+
interpolate_factor=1.0,
|
| 100 |
+
theta_rescale_factor=1.0,
|
| 101 |
+
seq_before_head_dim=False,
|
| 102 |
+
cache_if_possible=True,
|
| 103 |
+
cache_max_seq_len=8192,
|
| 104 |
+
):
|
| 105 |
+
super().__init__()
|
| 106 |
+
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
| 107 |
+
# has some connection to NTK literature
|
| 108 |
+
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
| 109 |
+
|
| 110 |
+
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
| 111 |
+
|
| 112 |
+
self.freqs_for = freqs_for
|
| 113 |
+
|
| 114 |
+
if exists(custom_freqs):
|
| 115 |
+
freqs = custom_freqs
|
| 116 |
+
elif freqs_for == "lang":
|
| 117 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 118 |
+
elif freqs_for == "pixel":
|
| 119 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
| 120 |
+
elif freqs_for == "spacetime":
|
| 121 |
+
time_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 122 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
| 123 |
+
elif freqs_for == "constant":
|
| 124 |
+
freqs = torch.ones(num_freqs).float()
|
| 125 |
+
|
| 126 |
+
if freqs_for == "spacetime":
|
| 127 |
+
self.time_freqs = nn.Parameter(time_freqs, requires_grad=learned_freq)
|
| 128 |
+
self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
|
| 129 |
+
|
| 130 |
+
self.cache_if_possible = cache_if_possible
|
| 131 |
+
self.cache_max_seq_len = cache_max_seq_len
|
| 132 |
+
|
| 133 |
+
self.register_buffer("cached_freqs", torch.zeros(cache_max_seq_len, dim), persistent=False)
|
| 134 |
+
self.register_buffer("cached_freqs_seq_len", torch.tensor(0), persistent=False)
|
| 135 |
+
|
| 136 |
+
self.learned_freq = learned_freq
|
| 137 |
+
|
| 138 |
+
# dummy for device
|
| 139 |
+
|
| 140 |
+
self.register_buffer("dummy", torch.tensor(0), persistent=False)
|
| 141 |
+
|
| 142 |
+
# default sequence dimension
|
| 143 |
+
|
| 144 |
+
self.seq_before_head_dim = seq_before_head_dim
|
| 145 |
+
self.default_seq_dim = -3 if seq_before_head_dim else -2
|
| 146 |
+
|
| 147 |
+
# interpolation factors
|
| 148 |
+
|
| 149 |
+
assert interpolate_factor >= 1.0
|
| 150 |
+
self.interpolate_factor = interpolate_factor
|
| 151 |
+
|
| 152 |
+
# xpos
|
| 153 |
+
|
| 154 |
+
self.use_xpos = use_xpos
|
| 155 |
+
|
| 156 |
+
if not use_xpos:
|
| 157 |
+
return
|
| 158 |
+
|
| 159 |
+
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
| 160 |
+
self.scale_base = xpos_scale_base
|
| 161 |
+
|
| 162 |
+
self.register_buffer("scale", scale, persistent=False)
|
| 163 |
+
self.register_buffer("cached_scales", torch.zeros(cache_max_seq_len, dim), persistent=False)
|
| 164 |
+
self.register_buffer("cached_scales_seq_len", torch.tensor(0), persistent=False)
|
| 165 |
+
|
| 166 |
+
# add apply_rotary_emb as static method
|
| 167 |
+
|
| 168 |
+
self.apply_rotary_emb = staticmethod(apply_rotary_emb)
|
| 169 |
+
|
| 170 |
+
@property
|
| 171 |
+
def device(self):
|
| 172 |
+
return self.dummy.device
|
| 173 |
+
|
| 174 |
+
def get_seq_pos(self, seq_len, device, dtype, offset=0):
|
| 175 |
+
return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor
|
| 176 |
+
|
| 177 |
+
def rotate_queries_or_keys(self, t, freqs, seq_dim=None, offset=0, scale=None):
|
| 178 |
+
seq_dim = default(seq_dim, self.default_seq_dim)
|
| 179 |
+
|
| 180 |
+
assert not self.use_xpos or exists(scale), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
|
| 181 |
+
|
| 182 |
+
device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
|
| 183 |
+
|
| 184 |
+
seq = self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset)
|
| 185 |
+
|
| 186 |
+
seq_freqs = self.forward(seq, freqs, seq_len=seq_len, offset=offset)
|
| 187 |
+
|
| 188 |
+
if seq_dim == -3:
|
| 189 |
+
seq_freqs = rearrange(seq_freqs, "n d -> n 1 d")
|
| 190 |
+
|
| 191 |
+
return apply_rotary_emb(seq_freqs, t, scale=default(scale, 1.0), seq_dim=seq_dim)
|
| 192 |
+
|
| 193 |
+
def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
|
| 194 |
+
dtype, device, seq_dim = (
|
| 195 |
+
q.dtype,
|
| 196 |
+
q.device,
|
| 197 |
+
default(seq_dim, self.default_seq_dim),
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
|
| 201 |
+
assert q_len <= k_len
|
| 202 |
+
|
| 203 |
+
q_scale = k_scale = 1.0
|
| 204 |
+
|
| 205 |
+
if self.use_xpos:
|
| 206 |
+
seq = self.get_seq_pos(k_len, dtype=dtype, device=device)
|
| 207 |
+
|
| 208 |
+
q_scale = self.get_scale(seq[-q_len:]).type(dtype)
|
| 209 |
+
k_scale = self.get_scale(seq).type(dtype)
|
| 210 |
+
|
| 211 |
+
rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, scale=q_scale, offset=k_len - q_len + offset)
|
| 212 |
+
rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, scale=k_scale**-1)
|
| 213 |
+
|
| 214 |
+
rotated_q = rotated_q.type(q.dtype)
|
| 215 |
+
rotated_k = rotated_k.type(k.dtype)
|
| 216 |
+
|
| 217 |
+
return rotated_q, rotated_k
|
| 218 |
+
|
| 219 |
+
def rotate_queries_and_keys(self, q, k, freqs, seq_dim=None):
|
| 220 |
+
seq_dim = default(seq_dim, self.default_seq_dim)
|
| 221 |
+
|
| 222 |
+
assert self.use_xpos
|
| 223 |
+
device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
|
| 224 |
+
|
| 225 |
+
seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
|
| 226 |
+
|
| 227 |
+
seq_freqs = self.forward(seq, freqs, seq_len=seq_len)
|
| 228 |
+
scale = self.get_scale(seq, seq_len=seq_len).to(dtype)
|
| 229 |
+
|
| 230 |
+
if seq_dim == -3:
|
| 231 |
+
seq_freqs = rearrange(seq_freqs, "n d -> n 1 d")
|
| 232 |
+
scale = rearrange(scale, "n d -> n 1 d")
|
| 233 |
+
|
| 234 |
+
rotated_q = apply_rotary_emb(seq_freqs, q, scale=scale, seq_dim=seq_dim)
|
| 235 |
+
rotated_k = apply_rotary_emb(seq_freqs, k, scale=scale**-1, seq_dim=seq_dim)
|
| 236 |
+
|
| 237 |
+
rotated_q = rotated_q.type(q.dtype)
|
| 238 |
+
rotated_k = rotated_k.type(k.dtype)
|
| 239 |
+
|
| 240 |
+
return rotated_q, rotated_k
|
| 241 |
+
|
| 242 |
+
def get_scale(self, t: Tensor, seq_len: int | None = None, offset=0):
|
| 243 |
+
assert self.use_xpos
|
| 244 |
+
|
| 245 |
+
should_cache = self.cache_if_possible and exists(seq_len) and (offset + seq_len) <= self.cache_max_seq_len
|
| 246 |
+
|
| 247 |
+
if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales_seq_len.item():
|
| 248 |
+
return self.cached_scales[offset : (offset + seq_len)]
|
| 249 |
+
|
| 250 |
+
scale = 1.0
|
| 251 |
+
if self.use_xpos:
|
| 252 |
+
power = (t - len(t) // 2) / self.scale_base
|
| 253 |
+
scale = self.scale ** rearrange(power, "n -> n 1")
|
| 254 |
+
scale = repeat(scale, "n d -> n (d r)", r=2)
|
| 255 |
+
|
| 256 |
+
if should_cache and offset == 0:
|
| 257 |
+
self.cached_scales[:seq_len] = scale.detach()
|
| 258 |
+
self.cached_scales_seq_len.copy_(seq_len)
|
| 259 |
+
|
| 260 |
+
return scale
|
| 261 |
+
|
| 262 |
+
def get_axial_freqs(self, *dims):
|
| 263 |
+
Colon = slice(None)
|
| 264 |
+
all_freqs = []
|
| 265 |
+
|
| 266 |
+
for ind, dim in enumerate(dims):
|
| 267 |
+
# only allow pixel freqs for last two dimensions
|
| 268 |
+
use_pixel = (self.freqs_for == "pixel" or self.freqs_for == "spacetime") and ind >= len(dims) - 2
|
| 269 |
+
if use_pixel:
|
| 270 |
+
pos = torch.linspace(-1, 1, steps=dim, device=self.device)
|
| 271 |
+
else:
|
| 272 |
+
pos = torch.arange(dim, device=self.device)
|
| 273 |
+
|
| 274 |
+
if self.freqs_for == "spacetime" and not use_pixel:
|
| 275 |
+
seq_freqs = self.forward(pos, self.time_freqs, seq_len=dim)
|
| 276 |
+
else:
|
| 277 |
+
seq_freqs = self.forward(pos, self.freqs, seq_len=dim)
|
| 278 |
+
|
| 279 |
+
all_axis = [None] * len(dims)
|
| 280 |
+
all_axis[ind] = Colon
|
| 281 |
+
|
| 282 |
+
new_axis_slice = (Ellipsis, *all_axis, Colon)
|
| 283 |
+
all_freqs.append(seq_freqs[new_axis_slice])
|
| 284 |
+
|
| 285 |
+
all_freqs = broadcast_tensors(*all_freqs)
|
| 286 |
+
return torch.cat(all_freqs, dim=-1)
|
| 287 |
+
|
| 288 |
+
@autocast("cuda", enabled=False)
|
| 289 |
+
def forward(self, t: Tensor, freqs: Tensor, seq_len=None, offset=0):
|
| 290 |
+
should_cache = self.cache_if_possible and not self.learned_freq and exists(seq_len) and self.freqs_for != "pixel" and (offset + seq_len) <= self.cache_max_seq_len
|
| 291 |
+
|
| 292 |
+
if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs_seq_len.item():
|
| 293 |
+
return self.cached_freqs[offset : (offset + seq_len)].detach()
|
| 294 |
+
|
| 295 |
+
freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
|
| 296 |
+
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
|
| 297 |
+
|
| 298 |
+
if should_cache and offset == 0:
|
| 299 |
+
self.cached_freqs[:seq_len] = freqs.detach()
|
| 300 |
+
self.cached_freqs_seq_len.copy_(seq_len)
|
| 301 |
+
|
| 302 |
+
return freqs
|
algorithms/worldmem/models/utils.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapted from https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/utils.py
|
| 3 |
+
Action format derived from VPT https://github.com/openai/Video-Pre-Training
|
| 4 |
+
Adapted from https://github.com/etched-ai/open-oasis/blob/master/utils.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torchvision.io import read_image, read_video
|
| 11 |
+
from torchvision.transforms.functional import resize
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
from typing import Mapping, Sequence
|
| 14 |
+
from einops import rearrange, parse_shape
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def exists(val):
|
| 18 |
+
return val is not None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def default(val, d):
|
| 22 |
+
if exists(val):
|
| 23 |
+
return val
|
| 24 |
+
return d() if callable(d) else d
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def extract(a, t, x_shape):
|
| 28 |
+
f, b = t.shape
|
| 29 |
+
out = a[t]
|
| 30 |
+
return out.reshape(f, b, *((1,) * (len(x_shape) - 2)))
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def linear_beta_schedule(timesteps):
|
| 34 |
+
"""
|
| 35 |
+
linear schedule, proposed in original ddpm paper
|
| 36 |
+
"""
|
| 37 |
+
scale = 1000 / timesteps
|
| 38 |
+
beta_start = scale * 0.0001
|
| 39 |
+
beta_end = scale * 0.02
|
| 40 |
+
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def cosine_beta_schedule(timesteps, s=0.008):
|
| 44 |
+
"""
|
| 45 |
+
cosine schedule
|
| 46 |
+
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
| 47 |
+
"""
|
| 48 |
+
steps = timesteps + 1
|
| 49 |
+
t = torch.linspace(0, timesteps, steps, dtype=torch.float64) / timesteps
|
| 50 |
+
alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
|
| 51 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
| 52 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
| 53 |
+
return torch.clip(betas, 0, 0.999)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
|
| 58 |
+
"""
|
| 59 |
+
sigmoid schedule
|
| 60 |
+
proposed in https://arxiv.org/abs/2212.11972 - Figure 8
|
| 61 |
+
better for images > 64x64, when used during training
|
| 62 |
+
"""
|
| 63 |
+
steps = timesteps + 1
|
| 64 |
+
t = torch.linspace(0, timesteps, steps, dtype=torch.float64) / timesteps
|
| 65 |
+
v_start = torch.tensor(start / tau).sigmoid()
|
| 66 |
+
v_end = torch.tensor(end / tau).sigmoid()
|
| 67 |
+
alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
|
| 68 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
| 69 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
| 70 |
+
return torch.clip(betas, 0, 0.999)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
ACTION_KEYS = [
|
| 74 |
+
"inventory",
|
| 75 |
+
"ESC",
|
| 76 |
+
"hotbar.1",
|
| 77 |
+
"hotbar.2",
|
| 78 |
+
"hotbar.3",
|
| 79 |
+
"hotbar.4",
|
| 80 |
+
"hotbar.5",
|
| 81 |
+
"hotbar.6",
|
| 82 |
+
"hotbar.7",
|
| 83 |
+
"hotbar.8",
|
| 84 |
+
"hotbar.9",
|
| 85 |
+
"forward",
|
| 86 |
+
"back",
|
| 87 |
+
"left",
|
| 88 |
+
"right",
|
| 89 |
+
"cameraX",
|
| 90 |
+
"cameraY",
|
| 91 |
+
"jump",
|
| 92 |
+
"sneak",
|
| 93 |
+
"sprint",
|
| 94 |
+
"swapHands",
|
| 95 |
+
"attack",
|
| 96 |
+
"use",
|
| 97 |
+
"pickItem",
|
| 98 |
+
"drop",
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
|
| 103 |
+
actions_one_hot = torch.zeros(len(actions), len(ACTION_KEYS))
|
| 104 |
+
for i, current_actions in enumerate(actions):
|
| 105 |
+
for j, action_key in enumerate(ACTION_KEYS):
|
| 106 |
+
if action_key.startswith("camera"):
|
| 107 |
+
if action_key == "cameraX":
|
| 108 |
+
value = current_actions["camera"][0]
|
| 109 |
+
elif action_key == "cameraY":
|
| 110 |
+
value = current_actions["camera"][1]
|
| 111 |
+
else:
|
| 112 |
+
raise ValueError(f"Unknown camera action key: {action_key}")
|
| 113 |
+
max_val = 20
|
| 114 |
+
bin_size = 0.5
|
| 115 |
+
num_buckets = int(max_val / bin_size)
|
| 116 |
+
value = (value - num_buckets) / num_buckets
|
| 117 |
+
assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}"
|
| 118 |
+
else:
|
| 119 |
+
value = current_actions[action_key]
|
| 120 |
+
assert 0 <= value <= 1, f"Action value must be in [0, 1] got {value}"
|
| 121 |
+
actions_one_hot[i, j] = value
|
| 122 |
+
|
| 123 |
+
return actions_one_hot
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
IMAGE_EXTENSIONS = {"png", "jpg", "jpeg"}
|
| 127 |
+
VIDEO_EXTENSIONS = {"mp4"}
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def load_prompt(path, video_offset=None, n_prompt_frames=1):
|
| 131 |
+
if path.lower().split(".")[-1] in IMAGE_EXTENSIONS:
|
| 132 |
+
print("prompt is image; ignoring video_offset and n_prompt_frames")
|
| 133 |
+
prompt = read_image(path)
|
| 134 |
+
# add frame dimension
|
| 135 |
+
prompt = rearrange(prompt, "c h w -> 1 c h w")
|
| 136 |
+
elif path.lower().split(".")[-1] in VIDEO_EXTENSIONS:
|
| 137 |
+
prompt = read_video(path, pts_unit="sec")[0]
|
| 138 |
+
if video_offset is not None:
|
| 139 |
+
prompt = prompt[video_offset:]
|
| 140 |
+
prompt = prompt[:n_prompt_frames]
|
| 141 |
+
else:
|
| 142 |
+
raise ValueError(f"unrecognized prompt file extension; expected one in {IMAGE_EXTENSIONS} or {VIDEO_EXTENSIONS}")
|
| 143 |
+
assert prompt.shape[0] == n_prompt_frames, f"input prompt {path} had less than n_prompt_frames={n_prompt_frames} frames"
|
| 144 |
+
prompt = resize(prompt, (360, 640))
|
| 145 |
+
# add batch dimension
|
| 146 |
+
prompt = rearrange(prompt, "t c h w -> 1 t c h w")
|
| 147 |
+
prompt = prompt.float() / 255.0
|
| 148 |
+
return prompt
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def load_actions(path, action_offset=None):
|
| 152 |
+
if path.endswith(".actions.pt"):
|
| 153 |
+
actions = one_hot_actions(torch.load(path))
|
| 154 |
+
elif path.endswith(".one_hot_actions.pt"):
|
| 155 |
+
actions = torch.load(path, weights_only=True)
|
| 156 |
+
else:
|
| 157 |
+
raise ValueError("unrecognized action file extension; expected '*.actions.pt' or '*.one_hot_actions.pt'")
|
| 158 |
+
if action_offset is not None:
|
| 159 |
+
actions = actions[action_offset:]
|
| 160 |
+
actions = torch.cat([torch.zeros_like(actions[:1]), actions], dim=0)
|
| 161 |
+
# add batch dimension
|
| 162 |
+
actions = rearrange(actions, "t d -> 1 t d")
|
| 163 |
+
return actions
|
algorithms/worldmem/models/vae.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
References:
|
| 3 |
+
- VQGAN: https://github.com/CompVis/taming-transformers
|
| 4 |
+
- MAE: https://github.com/facebookresearch/mae
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import math
|
| 9 |
+
import functools
|
| 10 |
+
from collections import namedtuple
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from einops import rearrange
|
| 15 |
+
from timm.models.vision_transformer import Mlp
|
| 16 |
+
from timm.layers.helpers import to_2tuple
|
| 17 |
+
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
|
| 18 |
+
from .dit import PatchEmbed
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DiagonalGaussianDistribution(object):
|
| 22 |
+
def __init__(self, parameters, deterministic=False, dim=1):
|
| 23 |
+
self.parameters = parameters
|
| 24 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
|
| 25 |
+
if dim == 1:
|
| 26 |
+
self.dims = [1, 2, 3]
|
| 27 |
+
elif dim == 2:
|
| 28 |
+
self.dims = [1, 2]
|
| 29 |
+
else:
|
| 30 |
+
raise NotImplementedError
|
| 31 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
| 32 |
+
self.deterministic = deterministic
|
| 33 |
+
self.std = torch.exp(0.5 * self.logvar)
|
| 34 |
+
self.var = torch.exp(self.logvar)
|
| 35 |
+
if self.deterministic:
|
| 36 |
+
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
| 37 |
+
|
| 38 |
+
def sample(self):
|
| 39 |
+
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
|
| 40 |
+
return x
|
| 41 |
+
|
| 42 |
+
def mode(self):
|
| 43 |
+
return self.mean
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Attention(nn.Module):
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
dim,
|
| 50 |
+
num_heads,
|
| 51 |
+
frame_height,
|
| 52 |
+
frame_width,
|
| 53 |
+
qkv_bias=False,
|
| 54 |
+
):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.num_heads = num_heads
|
| 57 |
+
head_dim = dim // num_heads
|
| 58 |
+
self.frame_height = frame_height
|
| 59 |
+
self.frame_width = frame_width
|
| 60 |
+
|
| 61 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 62 |
+
self.proj = nn.Linear(dim, dim)
|
| 63 |
+
|
| 64 |
+
rotary_freqs = RotaryEmbedding(
|
| 65 |
+
dim=head_dim // 4,
|
| 66 |
+
freqs_for="pixel",
|
| 67 |
+
max_freq=frame_height * frame_width,
|
| 68 |
+
).get_axial_freqs(frame_height, frame_width)
|
| 69 |
+
self.register_buffer("rotary_freqs", rotary_freqs, persistent=False)
|
| 70 |
+
|
| 71 |
+
def forward(self, x):
|
| 72 |
+
B, N, C = x.shape
|
| 73 |
+
assert N == self.frame_height * self.frame_width
|
| 74 |
+
|
| 75 |
+
q, k, v = self.qkv(x).chunk(3, dim=-1)
|
| 76 |
+
|
| 77 |
+
q = rearrange(
|
| 78 |
+
q,
|
| 79 |
+
"b (H W) (h d) -> b h H W d",
|
| 80 |
+
H=self.frame_height,
|
| 81 |
+
W=self.frame_width,
|
| 82 |
+
h=self.num_heads,
|
| 83 |
+
)
|
| 84 |
+
k = rearrange(
|
| 85 |
+
k,
|
| 86 |
+
"b (H W) (h d) -> b h H W d",
|
| 87 |
+
H=self.frame_height,
|
| 88 |
+
W=self.frame_width,
|
| 89 |
+
h=self.num_heads,
|
| 90 |
+
)
|
| 91 |
+
v = rearrange(
|
| 92 |
+
v,
|
| 93 |
+
"b (H W) (h d) -> b h H W d",
|
| 94 |
+
H=self.frame_height,
|
| 95 |
+
W=self.frame_width,
|
| 96 |
+
h=self.num_heads,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
q = apply_rotary_emb(self.rotary_freqs, q)
|
| 100 |
+
k = apply_rotary_emb(self.rotary_freqs, k)
|
| 101 |
+
|
| 102 |
+
q = rearrange(q, "b h H W d -> b h (H W) d")
|
| 103 |
+
k = rearrange(k, "b h H W d -> b h (H W) d")
|
| 104 |
+
v = rearrange(v, "b h H W d -> b h (H W) d")
|
| 105 |
+
|
| 106 |
+
x = F.scaled_dot_product_attention(q, k, v)
|
| 107 |
+
x = rearrange(x, "b h N d -> b N (h d)")
|
| 108 |
+
|
| 109 |
+
x = self.proj(x)
|
| 110 |
+
return x
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class AttentionBlock(nn.Module):
|
| 114 |
+
def __init__(
|
| 115 |
+
self,
|
| 116 |
+
dim,
|
| 117 |
+
num_heads,
|
| 118 |
+
frame_height,
|
| 119 |
+
frame_width,
|
| 120 |
+
mlp_ratio=4.0,
|
| 121 |
+
qkv_bias=False,
|
| 122 |
+
attn_causal=False,
|
| 123 |
+
act_layer=nn.GELU,
|
| 124 |
+
norm_layer=nn.LayerNorm,
|
| 125 |
+
):
|
| 126 |
+
super().__init__()
|
| 127 |
+
self.norm1 = norm_layer(dim)
|
| 128 |
+
self.attn = Attention(
|
| 129 |
+
dim,
|
| 130 |
+
num_heads,
|
| 131 |
+
frame_height,
|
| 132 |
+
frame_width,
|
| 133 |
+
qkv_bias=qkv_bias,
|
| 134 |
+
)
|
| 135 |
+
self.norm2 = norm_layer(dim)
|
| 136 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 137 |
+
self.mlp = Mlp(
|
| 138 |
+
in_features=dim,
|
| 139 |
+
hidden_features=mlp_hidden_dim,
|
| 140 |
+
act_layer=act_layer,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
def forward(self, x):
|
| 144 |
+
x = x + self.attn(self.norm1(x))
|
| 145 |
+
x = x + self.mlp(self.norm2(x))
|
| 146 |
+
return x
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class AutoencoderKL(nn.Module):
|
| 150 |
+
def __init__(
|
| 151 |
+
self,
|
| 152 |
+
latent_dim,
|
| 153 |
+
input_height=256,
|
| 154 |
+
input_width=256,
|
| 155 |
+
patch_size=16,
|
| 156 |
+
enc_dim=768,
|
| 157 |
+
enc_depth=6,
|
| 158 |
+
enc_heads=12,
|
| 159 |
+
dec_dim=768,
|
| 160 |
+
dec_depth=6,
|
| 161 |
+
dec_heads=12,
|
| 162 |
+
mlp_ratio=4.0,
|
| 163 |
+
norm_layer=functools.partial(nn.LayerNorm, eps=1e-6),
|
| 164 |
+
use_variational=True,
|
| 165 |
+
**kwargs,
|
| 166 |
+
):
|
| 167 |
+
super().__init__()
|
| 168 |
+
self.input_height = input_height
|
| 169 |
+
self.input_width = input_width
|
| 170 |
+
self.patch_size = patch_size
|
| 171 |
+
self.seq_h = input_height // patch_size
|
| 172 |
+
self.seq_w = input_width // patch_size
|
| 173 |
+
self.seq_len = self.seq_h * self.seq_w
|
| 174 |
+
self.patch_dim = 3 * patch_size**2
|
| 175 |
+
|
| 176 |
+
self.latent_dim = latent_dim
|
| 177 |
+
self.enc_dim = enc_dim
|
| 178 |
+
self.dec_dim = dec_dim
|
| 179 |
+
|
| 180 |
+
# patch
|
| 181 |
+
self.patch_embed = PatchEmbed(input_height, input_width, patch_size, 3, enc_dim)
|
| 182 |
+
|
| 183 |
+
# encoder
|
| 184 |
+
self.encoder = nn.ModuleList(
|
| 185 |
+
[
|
| 186 |
+
AttentionBlock(
|
| 187 |
+
enc_dim,
|
| 188 |
+
enc_heads,
|
| 189 |
+
self.seq_h,
|
| 190 |
+
self.seq_w,
|
| 191 |
+
mlp_ratio,
|
| 192 |
+
qkv_bias=True,
|
| 193 |
+
norm_layer=norm_layer,
|
| 194 |
+
)
|
| 195 |
+
for i in range(enc_depth)
|
| 196 |
+
]
|
| 197 |
+
)
|
| 198 |
+
self.enc_norm = norm_layer(enc_dim)
|
| 199 |
+
|
| 200 |
+
# bottleneck
|
| 201 |
+
self.use_variational = use_variational
|
| 202 |
+
mult = 2 if self.use_variational else 1
|
| 203 |
+
self.quant_conv = nn.Linear(enc_dim, mult * latent_dim)
|
| 204 |
+
self.post_quant_conv = nn.Linear(latent_dim, dec_dim)
|
| 205 |
+
|
| 206 |
+
# decoder
|
| 207 |
+
self.decoder = nn.ModuleList(
|
| 208 |
+
[
|
| 209 |
+
AttentionBlock(
|
| 210 |
+
dec_dim,
|
| 211 |
+
dec_heads,
|
| 212 |
+
self.seq_h,
|
| 213 |
+
self.seq_w,
|
| 214 |
+
mlp_ratio,
|
| 215 |
+
qkv_bias=True,
|
| 216 |
+
norm_layer=norm_layer,
|
| 217 |
+
)
|
| 218 |
+
for i in range(dec_depth)
|
| 219 |
+
]
|
| 220 |
+
)
|
| 221 |
+
self.dec_norm = norm_layer(dec_dim)
|
| 222 |
+
self.predictor = nn.Linear(dec_dim, self.patch_dim) # decoder to patch
|
| 223 |
+
|
| 224 |
+
# initialize this weight first
|
| 225 |
+
self.initialize_weights()
|
| 226 |
+
|
| 227 |
+
def initialize_weights(self):
|
| 228 |
+
# initialization
|
| 229 |
+
# initialize nn.Linear and nn.LayerNorm
|
| 230 |
+
self.apply(self._init_weights)
|
| 231 |
+
|
| 232 |
+
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
|
| 233 |
+
w = self.patch_embed.proj.weight.data
|
| 234 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 235 |
+
|
| 236 |
+
def _init_weights(self, m):
|
| 237 |
+
if isinstance(m, nn.Linear):
|
| 238 |
+
# we use xavier_uniform following official JAX ViT:
|
| 239 |
+
nn.init.xavier_uniform_(m.weight)
|
| 240 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 241 |
+
nn.init.constant_(m.bias, 0.0)
|
| 242 |
+
elif isinstance(m, nn.LayerNorm):
|
| 243 |
+
nn.init.constant_(m.bias, 0.0)
|
| 244 |
+
nn.init.constant_(m.weight, 1.0)
|
| 245 |
+
|
| 246 |
+
def patchify(self, x):
|
| 247 |
+
# patchify
|
| 248 |
+
bsz, _, h, w = x.shape
|
| 249 |
+
x = x.reshape(
|
| 250 |
+
bsz,
|
| 251 |
+
3,
|
| 252 |
+
self.seq_h,
|
| 253 |
+
self.patch_size,
|
| 254 |
+
self.seq_w,
|
| 255 |
+
self.patch_size,
|
| 256 |
+
).permute([0, 1, 3, 5, 2, 4]) # [b, c, h, p, w, p] --> [b, c, p, p, h, w]
|
| 257 |
+
x = x.reshape(bsz, self.patch_dim, self.seq_h, self.seq_w) # --> [b, cxpxp, h, w]
|
| 258 |
+
x = x.permute([0, 2, 3, 1]).reshape(bsz, self.seq_len, self.patch_dim) # --> [b, hxw, cxpxp]
|
| 259 |
+
return x
|
| 260 |
+
|
| 261 |
+
def unpatchify(self, x):
|
| 262 |
+
bsz = x.shape[0]
|
| 263 |
+
# unpatchify
|
| 264 |
+
x = x.reshape(bsz, self.seq_h, self.seq_w, self.patch_dim).permute([0, 3, 1, 2]) # [b, h, w, cxpxp] --> [b, cxpxp, h, w]
|
| 265 |
+
x = x.reshape(
|
| 266 |
+
bsz,
|
| 267 |
+
3,
|
| 268 |
+
self.patch_size,
|
| 269 |
+
self.patch_size,
|
| 270 |
+
self.seq_h,
|
| 271 |
+
self.seq_w,
|
| 272 |
+
).permute([0, 1, 4, 2, 5, 3]) # [b, c, p, p, h, w] --> [b, c, h, p, w, p]
|
| 273 |
+
x = x.reshape(
|
| 274 |
+
bsz,
|
| 275 |
+
3,
|
| 276 |
+
self.input_height,
|
| 277 |
+
self.input_width,
|
| 278 |
+
) # [b, c, hxp, wxp]
|
| 279 |
+
return x
|
| 280 |
+
|
| 281 |
+
def encode(self, x):
|
| 282 |
+
# patchify
|
| 283 |
+
x = self.patch_embed(x)
|
| 284 |
+
|
| 285 |
+
# encoder
|
| 286 |
+
for blk in self.encoder:
|
| 287 |
+
x = blk(x)
|
| 288 |
+
x = self.enc_norm(x)
|
| 289 |
+
|
| 290 |
+
# bottleneck
|
| 291 |
+
moments = self.quant_conv(x)
|
| 292 |
+
if not self.use_variational:
|
| 293 |
+
moments = torch.cat((moments, torch.zeros_like(moments)), 2)
|
| 294 |
+
posterior = DiagonalGaussianDistribution(moments, deterministic=(not self.use_variational), dim=2)
|
| 295 |
+
return posterior
|
| 296 |
+
|
| 297 |
+
def decode(self, z):
|
| 298 |
+
# bottleneck
|
| 299 |
+
z = self.post_quant_conv(z)
|
| 300 |
+
|
| 301 |
+
# decoder
|
| 302 |
+
for blk in self.decoder:
|
| 303 |
+
z = blk(z)
|
| 304 |
+
z = self.dec_norm(z)
|
| 305 |
+
|
| 306 |
+
# predictor
|
| 307 |
+
z = self.predictor(z)
|
| 308 |
+
|
| 309 |
+
# unpatchify
|
| 310 |
+
dec = self.unpatchify(z)
|
| 311 |
+
return dec
|
| 312 |
+
|
| 313 |
+
def autoencode(self, input, sample_posterior=True):
|
| 314 |
+
posterior = self.encode(input)
|
| 315 |
+
if self.use_variational and sample_posterior:
|
| 316 |
+
z = posterior.sample()
|
| 317 |
+
else:
|
| 318 |
+
z = posterior.mode()
|
| 319 |
+
dec = self.decode(z)
|
| 320 |
+
return dec, posterior, z
|
| 321 |
+
|
| 322 |
+
def get_input(self, batch, k):
|
| 323 |
+
x = batch[k]
|
| 324 |
+
if len(x.shape) == 3:
|
| 325 |
+
x = x[..., None]
|
| 326 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
| 327 |
+
return x
|
| 328 |
+
|
| 329 |
+
def forward(self, inputs, labels, split="train"):
|
| 330 |
+
rec, post, latent = self.autoencode(inputs)
|
| 331 |
+
return rec, post, latent
|
| 332 |
+
|
| 333 |
+
def get_last_layer(self):
|
| 334 |
+
return self.predictor.weight
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def ViT_L_20_Shallow_Encoder(**kwargs):
|
| 338 |
+
if "latent_dim" in kwargs:
|
| 339 |
+
latent_dim = kwargs.pop("latent_dim")
|
| 340 |
+
else:
|
| 341 |
+
latent_dim = 16
|
| 342 |
+
return AutoencoderKL(
|
| 343 |
+
latent_dim=latent_dim,
|
| 344 |
+
patch_size=20,
|
| 345 |
+
enc_dim=1024,
|
| 346 |
+
enc_depth=6,
|
| 347 |
+
enc_heads=16,
|
| 348 |
+
dec_dim=1024,
|
| 349 |
+
dec_depth=12,
|
| 350 |
+
dec_heads=16,
|
| 351 |
+
input_height=360,
|
| 352 |
+
input_width=640,
|
| 353 |
+
**kwargs,
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
VAE_models = {
|
| 358 |
+
"vit-l-20-shallow-encoder": ViT_L_20_Shallow_Encoder,
|
| 359 |
+
}
|
algorithms/worldmem/pose_prediction.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from omegaconf import DictConfig
|
| 2 |
+
import torch
|
| 3 |
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
| 4 |
+
from algorithms.common.metrics import (
|
| 5 |
+
FrechetInceptionDistance,
|
| 6 |
+
LearnedPerceptualImagePatchSimilarity,
|
| 7 |
+
FrechetVideoDistance,
|
| 8 |
+
)
|
| 9 |
+
from .df_base import DiffusionForcingBase
|
| 10 |
+
from utils.logging_utils import log_video, get_validation_metrics_for_videos
|
| 11 |
+
from .models.vae import VAE_models
|
| 12 |
+
from .models.dit import DiT_models
|
| 13 |
+
from einops import rearrange
|
| 14 |
+
from torch import autocast
|
| 15 |
+
import numpy as np
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from .models.pose_prediction import PosePredictionNet
|
| 19 |
+
import torchvision.transforms.functional as TF
|
| 20 |
+
import random
|
| 21 |
+
from torchvision.transforms import InterpolationMode
|
| 22 |
+
from PIL import Image
|
| 23 |
+
import math
|
| 24 |
+
from packaging import version as pver
|
| 25 |
+
import torch.distributed as dist
|
| 26 |
+
import matplotlib.pyplot as plt
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
import math
|
| 30 |
+
import wandb
|
| 31 |
+
|
| 32 |
+
import torch.nn as nn
|
| 33 |
+
from algorithms.common.base_pytorch_algo import BasePytorchAlgo
|
| 34 |
+
|
| 35 |
+
class PosePrediction(BasePytorchAlgo):
|
| 36 |
+
|
| 37 |
+
def __init__(self, cfg: DictConfig):
|
| 38 |
+
|
| 39 |
+
super().__init__(cfg)
|
| 40 |
+
|
| 41 |
+
def _build_model(self):
|
| 42 |
+
self.pose_prediction_model = PosePredictionNet()
|
| 43 |
+
vae = VAE_models["vit-l-20-shallow-encoder"]()
|
| 44 |
+
self.vae = vae.eval()
|
| 45 |
+
|
| 46 |
+
def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
|
| 47 |
+
xs, conditions, pose_conditions= batch
|
| 48 |
+
pose_conditions[:,:,3:] = pose_conditions[:,:,3:] // 15
|
| 49 |
+
xs = self.encode(xs)
|
| 50 |
+
|
| 51 |
+
b,f,c,h,w = xs.shape
|
| 52 |
+
xs = xs[:,:-1].reshape(-1, c, h, w)
|
| 53 |
+
conditions = conditions[:,1:].reshape(-1, 25)
|
| 54 |
+
offset_gt = pose_conditions[:,1:] - pose_conditions[:,:-1]
|
| 55 |
+
pose_conditions = pose_conditions[:,:-1].reshape(-1, 5)
|
| 56 |
+
offset_gt = offset_gt.reshape(-1, 5)
|
| 57 |
+
offset_gt[:, 3][offset_gt[:, 3]==23] = -1
|
| 58 |
+
offset_gt[:, 3][offset_gt[:, 3]==-23] = 1
|
| 59 |
+
offset_gt[:, 4][offset_gt[:, 4]==23] = -1
|
| 60 |
+
offset_gt[:, 4][offset_gt[:, 4]==-23] = 1
|
| 61 |
+
|
| 62 |
+
offset_pred = self.pose_prediction_model(xs, conditions, pose_conditions)
|
| 63 |
+
criterion = nn.MSELoss()
|
| 64 |
+
loss = criterion(offset_pred, offset_gt)
|
| 65 |
+
if batch_idx % 200 == 0:
|
| 66 |
+
self.log("training/loss", loss.cpu())
|
| 67 |
+
output_dict = {
|
| 68 |
+
"loss": loss}
|
| 69 |
+
return output_dict
|
| 70 |
+
|
| 71 |
+
def encode(self, x):
|
| 72 |
+
# vae encoding
|
| 73 |
+
B = x.shape[1]
|
| 74 |
+
T = x.shape[0]
|
| 75 |
+
H, W = x.shape[-2:]
|
| 76 |
+
scaling_factor = 0.07843137255
|
| 77 |
+
|
| 78 |
+
x = rearrange(x, "t b c h w -> (t b) c h w")
|
| 79 |
+
with torch.no_grad():
|
| 80 |
+
with autocast("cuda", dtype=torch.half):
|
| 81 |
+
x = self.vae.encode(x * 2 - 1).mean * scaling_factor
|
| 82 |
+
x = rearrange(x, "(t b) (h w) c -> t b c h w", t=T, h=H // self.vae.patch_size, w=W // self.vae.patch_size)
|
| 83 |
+
# x = x[:, :n_prompt_frames]
|
| 84 |
+
return x
|
| 85 |
+
|
| 86 |
+
def decode(self, x):
|
| 87 |
+
total_frames = x.shape[0]
|
| 88 |
+
scaling_factor = 0.07843137255
|
| 89 |
+
x = rearrange(x, "t b c h w -> (t b) (h w) c")
|
| 90 |
+
with torch.no_grad():
|
| 91 |
+
with autocast("cuda", dtype=torch.half):
|
| 92 |
+
x = (self.vae.decode(x / scaling_factor) + 1) / 2
|
| 93 |
+
|
| 94 |
+
x = rearrange(x, "(t b) c h w-> t b c h w", t=total_frames)
|
| 95 |
+
return x
|
| 96 |
+
|
| 97 |
+
def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT:
|
| 98 |
+
xs, conditions, pose_conditions= batch
|
| 99 |
+
pose_conditions[:,:,3:] = pose_conditions[:,:,3:] // 15
|
| 100 |
+
xs = self.encode(xs)
|
| 101 |
+
|
| 102 |
+
b,f,c,h,w = xs.shape
|
| 103 |
+
xs = xs[:,:-1].reshape(-1, c, h, w)
|
| 104 |
+
conditions = conditions[:,1:].reshape(-1, 25)
|
| 105 |
+
offset_gt = pose_conditions[:,1:] - pose_conditions[:,:-1]
|
| 106 |
+
pose_conditions = pose_conditions[:,:-1].reshape(-1, 5)
|
| 107 |
+
offset_gt = offset_gt.reshape(-1, 5)
|
| 108 |
+
offset_gt[:, 3][offset_gt[:, 3]==23] = -1
|
| 109 |
+
offset_gt[:, 3][offset_gt[:, 3]==-23] = 1
|
| 110 |
+
offset_gt[:, 4][offset_gt[:, 4]==23] = -1
|
| 111 |
+
offset_gt[:, 4][offset_gt[:, 4]==-23] = 1
|
| 112 |
+
|
| 113 |
+
offset_pred = self.pose_prediction_model(xs, conditions, pose_conditions)
|
| 114 |
+
|
| 115 |
+
criterion = nn.MSELoss()
|
| 116 |
+
loss = criterion(offset_pred, offset_gt)
|
| 117 |
+
|
| 118 |
+
if batch_idx % 200 == 0:
|
| 119 |
+
self.log("validation/loss", loss.cpu())
|
| 120 |
+
output_dict = {
|
| 121 |
+
"loss": loss}
|
| 122 |
+
return
|
| 123 |
+
|
| 124 |
+
@torch.no_grad()
|
| 125 |
+
def interactive(self, batch, context_frames, device):
|
| 126 |
+
with torch.cuda.amp.autocast():
|
| 127 |
+
condition_similar_length = self.condition_similar_length
|
| 128 |
+
# xs_raw, conditions, pose_conditions, c2w_mat, masks, frame_idx = self._preprocess_batch(batch)
|
| 129 |
+
|
| 130 |
+
first_frame, new_conditions, new_pose_conditions, new_c2w_mat, new_frame_idx = batch
|
| 131 |
+
|
| 132 |
+
if self.frames is None:
|
| 133 |
+
first_frame_encode = self.encode(first_frame[None, None].to(device))
|
| 134 |
+
self.frames = first_frame_encode.to(device)
|
| 135 |
+
self.actions = new_conditions[None, None].to(device)
|
| 136 |
+
self.poses = new_pose_conditions[None, None].to(device)
|
| 137 |
+
self.memory_c2w = new_c2w_mat[None, None].to(device)
|
| 138 |
+
self.frame_idx = torch.tensor([[new_frame_idx]]).to(device)
|
| 139 |
+
return first_frame
|
| 140 |
+
else:
|
| 141 |
+
self.actions = torch.cat([self.actions, new_conditions[None, None].to(device)])
|
| 142 |
+
self.poses = torch.cat([self.poses, new_pose_conditions[None, None].to(device)])
|
| 143 |
+
self.memory_c2w = torch.cat([self.memory_c2w, new_c2w_mat[None, None].to(device)])
|
| 144 |
+
self.frame_idx = torch.cat([self.frame_idx, torch.tensor([[new_frame_idx]]).to(device)])
|
| 145 |
+
|
| 146 |
+
conditions = self.actions.clone()
|
| 147 |
+
pose_conditions = self.poses.clone()
|
| 148 |
+
c2w_mat = self.memory_c2w .clone()
|
| 149 |
+
frame_idx = self.frame_idx.clone()
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
curr_frame = 0
|
| 153 |
+
horizon = 1
|
| 154 |
+
batch_size = 1
|
| 155 |
+
n_frames = curr_frame + horizon
|
| 156 |
+
# context
|
| 157 |
+
n_context_frames = context_frames // self.frame_stack
|
| 158 |
+
xs_pred = self.frames[:n_context_frames].clone()
|
| 159 |
+
curr_frame += n_context_frames
|
| 160 |
+
|
| 161 |
+
pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
|
| 162 |
+
|
| 163 |
+
# generation on frame
|
| 164 |
+
scheduling_matrix = self._generate_scheduling_matrix(horizon)
|
| 165 |
+
chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:])).to(xs_pred.device)
|
| 166 |
+
chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise)
|
| 167 |
+
|
| 168 |
+
xs_pred = torch.cat([xs_pred, chunk], 0)
|
| 169 |
+
|
| 170 |
+
# sliding window: only input the last n_tokens frames
|
| 171 |
+
start_frame = max(0, curr_frame + horizon - self.n_tokens)
|
| 172 |
+
|
| 173 |
+
pbar.set_postfix(
|
| 174 |
+
{
|
| 175 |
+
"start": start_frame,
|
| 176 |
+
"end": curr_frame + horizon,
|
| 177 |
+
}
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
if condition_similar_length:
|
| 181 |
+
|
| 182 |
+
if curr_frame < condition_similar_length:
|
| 183 |
+
random_idx = [i for i in range(curr_frame)] + [0] * (condition_similar_length-curr_frame)
|
| 184 |
+
random_idx = np.repeat(np.array(random_idx)[:,None], xs_pred.shape[1], -1)
|
| 185 |
+
else:
|
| 186 |
+
num_samples = 10000
|
| 187 |
+
radius = 30
|
| 188 |
+
samples = torch.rand((num_samples, 1), device=pose_conditions.device)
|
| 189 |
+
angles = 2 * np.pi * torch.rand((num_samples,), device=pose_conditions.device)
|
| 190 |
+
# points = radius * torch.sqrt(samples) * torch.stack((torch.cos(angles), torch.sin(angles)), dim=1)
|
| 191 |
+
|
| 192 |
+
points = generate_points_in_sphere(num_samples, radius).to(pose_conditions.device)
|
| 193 |
+
points = points[:, None].repeat(1, pose_conditions.shape[1], 1)
|
| 194 |
+
points += pose_conditions[curr_frame, :, :3][None]
|
| 195 |
+
fov_half_h = torch.tensor(105/2, device=pose_conditions.device)
|
| 196 |
+
fov_half_v = torch.tensor(75/2, device=pose_conditions.device)
|
| 197 |
+
# in_fov1 = is_inside_fov(points, pose_conditions[curr_frame, :, [0, 2]], pose_conditions[curr_frame, :, -1], fov_half)
|
| 198 |
+
|
| 199 |
+
in_fov1 = is_inside_fov_3d_hv(points, pose_conditions[curr_frame, :, :3],
|
| 200 |
+
pose_conditions[curr_frame, :, -2], pose_conditions[curr_frame, :, -1],
|
| 201 |
+
fov_half_h, fov_half_v)
|
| 202 |
+
|
| 203 |
+
in_fov_list = []
|
| 204 |
+
for pc in pose_conditions[:curr_frame]:
|
| 205 |
+
in_fov_list.append(is_inside_fov_3d_hv(points, pc[:, :3], pc[:, -2], pc[:, -1],
|
| 206 |
+
fov_half_h, fov_half_v))
|
| 207 |
+
|
| 208 |
+
in_fov_list = torch.stack(in_fov_list)
|
| 209 |
+
# v3
|
| 210 |
+
random_idx = []
|
| 211 |
+
|
| 212 |
+
for csl in range(self.condition_similar_length // 2):
|
| 213 |
+
overlap_ratio = ((in_fov1[None].bool() & in_fov_list).sum(1))/in_fov1.sum()
|
| 214 |
+
# mask = distance > (in_fov1.bool().sum(0) / 4)
|
| 215 |
+
#_, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
|
| 216 |
+
|
| 217 |
+
# if csl > self.condition_similar_length:
|
| 218 |
+
# _, r_idx = torch.topk(overlap_ratio, k=1, dim=0)
|
| 219 |
+
# else:
|
| 220 |
+
# _, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
|
| 221 |
+
|
| 222 |
+
_, r_idx = torch.topk(overlap_ratio, k=1, dim=0)
|
| 223 |
+
# _, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
|
| 224 |
+
|
| 225 |
+
# if curr_frame >=93:
|
| 226 |
+
# import pdb;pdb.set_trace()
|
| 227 |
+
|
| 228 |
+
# start_time = time.time()
|
| 229 |
+
cos_sim = F.cosine_similarity(xs_pred.to(r_idx.device)[r_idx[:, range(in_fov1.shape[1])],
|
| 230 |
+
range(in_fov1.shape[1])], xs_pred.to(r_idx.device)[:curr_frame], dim=2)
|
| 231 |
+
cos_sim = cos_sim.mean((-2,-1))
|
| 232 |
+
|
| 233 |
+
mask_sim = cos_sim>0.9
|
| 234 |
+
in_fov_list = in_fov_list & ~mask_sim[:,None].to(in_fov_list.device)
|
| 235 |
+
|
| 236 |
+
random_idx.append(r_idx)
|
| 237 |
+
|
| 238 |
+
for bi in range(conditions.shape[1]):
|
| 239 |
+
if len(torch.nonzero(conditions[:,bi,24] == 1))==0:
|
| 240 |
+
pass
|
| 241 |
+
else:
|
| 242 |
+
last_idx = torch.nonzero(conditions[:,bi,24] == 1)[-1]
|
| 243 |
+
in_fov_list[:last_idx,:,bi] = False
|
| 244 |
+
|
| 245 |
+
for csl in range(self.condition_similar_length // 2):
|
| 246 |
+
overlap_ratio = ((in_fov1[None].bool() & in_fov_list).sum(1))/in_fov1.sum()
|
| 247 |
+
# mask = distance > (in_fov1.bool().sum(0) / 4)
|
| 248 |
+
#_, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
|
| 249 |
+
|
| 250 |
+
# if csl > self.condition_similar_length:
|
| 251 |
+
# _, r_idx = torch.topk(overlap_ratio, k=1, dim=0)
|
| 252 |
+
# else:
|
| 253 |
+
# _, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
|
| 254 |
+
|
| 255 |
+
_, r_idx = torch.topk(overlap_ratio, k=1, dim=0)
|
| 256 |
+
# _, r_idx = torch.topk(overlap_ratio / tensor_max_with_number((frame_idx[curr_frame] - frame_idx[:curr_frame]), 10), k=1, dim=0)
|
| 257 |
+
|
| 258 |
+
# if curr_frame >=93:
|
| 259 |
+
# import pdb;pdb.set_trace()
|
| 260 |
+
|
| 261 |
+
# start_time = time.time()
|
| 262 |
+
cos_sim = F.cosine_similarity(xs_pred.to(r_idx.device)[r_idx[:, range(in_fov1.shape[1])],
|
| 263 |
+
range(in_fov1.shape[1])], xs_pred.to(r_idx.device)[:curr_frame], dim=2)
|
| 264 |
+
cos_sim = cos_sim.mean((-2,-1))
|
| 265 |
+
|
| 266 |
+
mask_sim = cos_sim>0.9
|
| 267 |
+
in_fov_list = in_fov_list & ~mask_sim[:,None].to(in_fov_list.device)
|
| 268 |
+
|
| 269 |
+
random_idx.append(r_idx)
|
| 270 |
+
|
| 271 |
+
random_idx = torch.cat(random_idx).cpu()
|
| 272 |
+
condition_similar_length = len(random_idx)
|
| 273 |
+
|
| 274 |
+
xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
|
| 275 |
+
|
| 276 |
+
if condition_similar_length:
|
| 277 |
+
# import pdb;pdb.set_trace()
|
| 278 |
+
padding = torch.zeros((condition_similar_length,) + conditions.shape[1:], device=conditions.device, dtype=conditions.dtype)
|
| 279 |
+
input_condition = torch.cat([conditions[start_frame : curr_frame + horizon], padding], dim=0)
|
| 280 |
+
if self.pose_cond_dim:
|
| 281 |
+
# if not self.use_plucker:
|
| 282 |
+
input_pose_condition = torch.cat([pose_conditions[start_frame : curr_frame + horizon], pose_conditions[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]], dim=0).clone()
|
| 283 |
+
|
| 284 |
+
if self.use_plucker:
|
| 285 |
+
if self.all_zero_frame:
|
| 286 |
+
frame_idx_list = []
|
| 287 |
+
input_pose_condition = []
|
| 288 |
+
for i in range(start_frame, curr_frame + horizon):
|
| 289 |
+
input_pose_condition.append(convert_to_plucker(torch.cat([c2w_mat[i:i+1],c2w_mat[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]]).clone(), 0, focal_length=self.focal_length, is_old_setting=self.old_setting).to(xs_pred.dtype))
|
| 290 |
+
frame_idx_list.append(torch.cat([frame_idx[i:i+1]-frame_idx[i:i+1], frame_idx[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]-frame_idx[i:i+1]]))
|
| 291 |
+
input_pose_condition = torch.cat(input_pose_condition)
|
| 292 |
+
frame_idx_list = torch.cat(frame_idx_list)
|
| 293 |
+
|
| 294 |
+
# print(frame_idx_list[:,0])
|
| 295 |
+
else:
|
| 296 |
+
# print(curr_frame-start_frame)
|
| 297 |
+
# input_pose_condition = torch.cat([c2w_mat[start_frame : curr_frame + horizon], c2w_mat[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]], dim=0).clone()
|
| 298 |
+
# import pdb;pdb.set_trace()
|
| 299 |
+
if self.last_frame_refer:
|
| 300 |
+
input_pose_condition = torch.cat([c2w_mat[start_frame : curr_frame + horizon], c2w_mat[-1:]], dim=0).clone()
|
| 301 |
+
else:
|
| 302 |
+
input_pose_condition = torch.cat([c2w_mat[start_frame : curr_frame + horizon], c2w_mat[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]], dim=0).clone()
|
| 303 |
+
|
| 304 |
+
if self.zero_curr:
|
| 305 |
+
# print("="*50)
|
| 306 |
+
input_pose_condition = convert_to_plucker(input_pose_condition, curr_frame-start_frame, focal_length=self.focal_length, is_old_setting=self.old_setting)
|
| 307 |
+
# input_pose_condition[:curr_frame-start_frame] = input_pose_condition[curr_frame-start_frame:curr_frame-start_frame+1]
|
| 308 |
+
# input_pose_condition = convert_to_plucker(input_pose_condition, -self.condition_similar_length-1, focal_length=self.focal_length)
|
| 309 |
+
else:
|
| 310 |
+
input_pose_condition = convert_to_plucker(input_pose_condition, -condition_similar_length, focal_length=self.focal_length, is_old_setting=self.old_setting)
|
| 311 |
+
frame_idx_list = None
|
| 312 |
+
else:
|
| 313 |
+
input_pose_condition = torch.cat([pose_conditions[start_frame : curr_frame + horizon], pose_conditions[random_idx[:,range(xs_pred.shape[1])], range(xs_pred.shape[1])]], dim=0).clone()
|
| 314 |
+
frame_idx_list = None
|
| 315 |
+
else:
|
| 316 |
+
input_condition = conditions[start_frame : curr_frame + horizon]
|
| 317 |
+
input_pose_condition = None
|
| 318 |
+
frame_idx_list = None
|
| 319 |
+
|
| 320 |
+
for m in range(scheduling_matrix.shape[0] - 1):
|
| 321 |
+
from_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m]))[
|
| 322 |
+
:, None
|
| 323 |
+
].repeat(batch_size, axis=1)
|
| 324 |
+
to_noise_levels = np.concatenate(
|
| 325 |
+
(
|
| 326 |
+
np.zeros((curr_frame,), dtype=np.int64),
|
| 327 |
+
scheduling_matrix[m + 1],
|
| 328 |
+
)
|
| 329 |
+
)[
|
| 330 |
+
:, None
|
| 331 |
+
].repeat(batch_size, axis=1)
|
| 332 |
+
|
| 333 |
+
if condition_similar_length:
|
| 334 |
+
from_noise_levels = np.concatenate([from_noise_levels, np.zeros((condition_similar_length,from_noise_levels.shape[-1]), dtype=np.int32)], axis=0)
|
| 335 |
+
to_noise_levels = np.concatenate([to_noise_levels, np.zeros((condition_similar_length,from_noise_levels.shape[-1]), dtype=np.int32)], axis=0)
|
| 336 |
+
|
| 337 |
+
from_noise_levels = torch.from_numpy(from_noise_levels).to(self.device)
|
| 338 |
+
to_noise_levels = torch.from_numpy(to_noise_levels).to(self.device)
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
if input_pose_condition is not None:
|
| 342 |
+
input_pose_condition = input_pose_condition.to(xs_pred.dtype)
|
| 343 |
+
|
| 344 |
+
xs_pred[start_frame:] = self.diffusion_model.sample_step(
|
| 345 |
+
xs_pred[start_frame:],
|
| 346 |
+
input_condition,
|
| 347 |
+
input_pose_condition,
|
| 348 |
+
from_noise_levels[start_frame:],
|
| 349 |
+
to_noise_levels[start_frame:],
|
| 350 |
+
current_frame=curr_frame,
|
| 351 |
+
mode="validation",
|
| 352 |
+
reference_length=condition_similar_length,
|
| 353 |
+
frame_idx=frame_idx_list
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
# if curr_frame > 14:
|
| 357 |
+
# import pdb;pdb.set_trace()
|
| 358 |
+
|
| 359 |
+
# if xs_pred_back is not None:
|
| 360 |
+
# xs_pred = torch.cat([xs_pred[:6], xs_pred_back[6:12], xs_pred[6:]], dim=0)
|
| 361 |
+
|
| 362 |
+
# import pdb;pdb.set_trace()
|
| 363 |
+
if condition_similar_length: # and curr_frame+1!=n_frames:
|
| 364 |
+
xs_pred = xs_pred[:-condition_similar_length]
|
| 365 |
+
|
| 366 |
+
curr_frame += horizon
|
| 367 |
+
pbar.update(horizon)
|
| 368 |
+
|
| 369 |
+
self.frames = torch.cat([self.frames, xs_pred[n_context_frames:]])
|
| 370 |
+
|
| 371 |
+
xs_pred = self.decode(xs_pred[n_context_frames:])
|
| 372 |
+
|
| 373 |
+
return xs_pred[-1,0].cpu()
|
| 374 |
+
|
app.py
ADDED
|
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import sys
|
| 5 |
+
import subprocess
|
| 6 |
+
import time
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
import hydra
|
| 10 |
+
from omegaconf import DictConfig, OmegaConf
|
| 11 |
+
from omegaconf.omegaconf import open_dict
|
| 12 |
+
|
| 13 |
+
import numpy as np
|
| 14 |
+
import torch
|
| 15 |
+
import torchvision.transforms as transforms
|
| 16 |
+
import cv2
|
| 17 |
+
import subprocess
|
| 18 |
+
from PIL import Image
|
| 19 |
+
from datetime import datetime
|
| 20 |
+
import spaces
|
| 21 |
+
from algorithms.worldmem import WorldMemMinecraft
|
| 22 |
+
from huggingface_hub import hf_hub_download
|
| 23 |
+
import tempfile
|
| 24 |
+
import os
|
| 25 |
+
import requests
|
| 26 |
+
from huggingface_hub import model_info
|
| 27 |
+
|
| 28 |
+
from experiments.exp_base import load_custom_checkpoint
|
| 29 |
+
|
| 30 |
+
torch.set_float32_matmul_precision("high")
|
| 31 |
+
|
| 32 |
+
def download_assets_if_needed():
|
| 33 |
+
ASSETS_URL_BASE = "https://huggingface.co/spaces/yslan/worldmem/resolve/main/assets/examples"
|
| 34 |
+
ASSETS_DIR = "assets/examples"
|
| 35 |
+
ASSETS = ['case1.npz', 'case2.npz', 'case3.npz', 'case4.npz']
|
| 36 |
+
|
| 37 |
+
if not os.path.exists(ASSETS_DIR):
|
| 38 |
+
os.makedirs(ASSETS_DIR)
|
| 39 |
+
|
| 40 |
+
# Download assets if they don't exist (total 4 files)
|
| 41 |
+
for filename in ASSETS:
|
| 42 |
+
filepath = os.path.join(ASSETS_DIR, filename)
|
| 43 |
+
if not os.path.exists(filepath):
|
| 44 |
+
print(f"Downloading {filename}...")
|
| 45 |
+
url = f"{ASSETS_URL_BASE}/{filename}"
|
| 46 |
+
response = requests.get(url)
|
| 47 |
+
if response.status_code == 200:
|
| 48 |
+
with open(filepath, "wb") as f:
|
| 49 |
+
f.write(response.content)
|
| 50 |
+
else:
|
| 51 |
+
print(f"Failed to download {filename}: {response.status_code}")
|
| 52 |
+
|
| 53 |
+
def parse_input_to_tensor(input_str):
|
| 54 |
+
"""
|
| 55 |
+
Convert an input string into a (sequence_length, 25) tensor, where each row is a one-hot representation
|
| 56 |
+
of the corresponding action key.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
input_str (str): A string consisting of "WASD" characters (e.g., "WASDWS").
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
torch.Tensor: A tensor of shape (sequence_length, 25), where each row is a one-hot encoded action.
|
| 63 |
+
"""
|
| 64 |
+
# Get the length of the input sequence
|
| 65 |
+
seq_len = len(input_str)
|
| 66 |
+
|
| 67 |
+
# Initialize a zero tensor of shape (seq_len, 25)
|
| 68 |
+
action_tensor = torch.zeros((seq_len, 25))
|
| 69 |
+
|
| 70 |
+
# Iterate through the input string and update the corresponding positions
|
| 71 |
+
for i, char in enumerate(input_str):
|
| 72 |
+
action, value = KEY_TO_ACTION.get(char.upper()) # Convert to uppercase to handle case insensitivity
|
| 73 |
+
if action and action in ACTION_KEYS:
|
| 74 |
+
index = ACTION_KEYS.index(action)
|
| 75 |
+
action_tensor[i, index] = value # Set the corresponding action index to 1
|
| 76 |
+
|
| 77 |
+
return action_tensor
|
| 78 |
+
|
| 79 |
+
def load_image_as_tensor(image_path: str) -> torch.Tensor:
|
| 80 |
+
"""
|
| 81 |
+
Load an image and convert it to a 0-1 normalized tensor.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
image_path (str): Path to the image file.
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
torch.Tensor: Image tensor of shape (C, H, W), normalized to [0,1].
|
| 88 |
+
"""
|
| 89 |
+
if isinstance(image_path, str):
|
| 90 |
+
image = Image.open(image_path).convert("RGB") # Ensure it's RGB
|
| 91 |
+
else:
|
| 92 |
+
image = image_path
|
| 93 |
+
transform = transforms.Compose([
|
| 94 |
+
transforms.ToTensor(), # Converts to tensor and normalizes to [0,1]
|
| 95 |
+
])
|
| 96 |
+
return transform(image)
|
| 97 |
+
|
| 98 |
+
def enable_amp(model, precision="16-mixed"):
|
| 99 |
+
original_forward = model.forward
|
| 100 |
+
|
| 101 |
+
def amp_forward(*args, **kwargs):
|
| 102 |
+
with torch.autocast("cuda", dtype=torch.float16 if precision == "16-mixed" else torch.bfloat16):
|
| 103 |
+
return original_forward(*args, **kwargs)
|
| 104 |
+
|
| 105 |
+
model.forward = amp_forward
|
| 106 |
+
return model
|
| 107 |
+
|
| 108 |
+
download_assets_if_needed()
|
| 109 |
+
|
| 110 |
+
ACTION_KEYS = [
|
| 111 |
+
"inventory",
|
| 112 |
+
"ESC",
|
| 113 |
+
"hotbar.1",
|
| 114 |
+
"hotbar.2",
|
| 115 |
+
"hotbar.3",
|
| 116 |
+
"hotbar.4",
|
| 117 |
+
"hotbar.5",
|
| 118 |
+
"hotbar.6",
|
| 119 |
+
"hotbar.7",
|
| 120 |
+
"hotbar.8",
|
| 121 |
+
"hotbar.9",
|
| 122 |
+
"forward",
|
| 123 |
+
"back",
|
| 124 |
+
"left",
|
| 125 |
+
"right",
|
| 126 |
+
"cameraY",
|
| 127 |
+
"cameraX",
|
| 128 |
+
"jump",
|
| 129 |
+
"sneak",
|
| 130 |
+
"sprint",
|
| 131 |
+
"swapHands",
|
| 132 |
+
"attack",
|
| 133 |
+
"use",
|
| 134 |
+
"pickItem",
|
| 135 |
+
"drop",
|
| 136 |
+
]
|
| 137 |
+
|
| 138 |
+
# Mapping of input keys to action names
|
| 139 |
+
KEY_TO_ACTION = {
|
| 140 |
+
"Q": ("forward", 1),
|
| 141 |
+
"E": ("back", 1),
|
| 142 |
+
"W": ("cameraY", -1),
|
| 143 |
+
"S": ("cameraY", 1),
|
| 144 |
+
"A": ("cameraX", -1),
|
| 145 |
+
"D": ("cameraX", 1),
|
| 146 |
+
"U": ("drop", 1),
|
| 147 |
+
"N": ("noop", 1),
|
| 148 |
+
"1": ("hotbar.1", 1),
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
example_images = [
|
| 152 |
+
["1", "assets/ice_plains.png", "turn rightgo backward→look up→turn left→look down→turn right→go forward→turn left", 20, 3, 8],
|
| 153 |
+
["2", "assets/place.png", "put item→go backward→put item→go backward→go around", 20, 3, 8],
|
| 154 |
+
["3", "assets/rain_sunflower_plains.png", "turn right→look up→turn right→look down→turn left→go backward→turn left", 20, 3, 8],
|
| 155 |
+
["4", "assets/desert.png", "turn 360 degree→turn right→go forward→turn left", 20, 3, 8],
|
| 156 |
+
]
|
| 157 |
+
|
| 158 |
+
video_frames = []
|
| 159 |
+
input_history = ""
|
| 160 |
+
ICE_PLAINS_IMAGE = "assets/ice_plains.png"
|
| 161 |
+
DESERT_IMAGE = "assets/desert.png"
|
| 162 |
+
SAVANNA_IMAGE = "assets/savanna.png"
|
| 163 |
+
PLAINS_IMAGE = "assets/plans.png"
|
| 164 |
+
PLACE_IMAGE = "assets/place.png"
|
| 165 |
+
SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
|
| 166 |
+
SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"
|
| 167 |
+
|
| 168 |
+
device = torch.device('cuda')
|
| 169 |
+
|
| 170 |
+
def save_video(frames, path="output.mp4", fps=10):
|
| 171 |
+
temp_path = path[:-4] + "_temp.mp4"
|
| 172 |
+
h, w, _ = frames[0].shape
|
| 173 |
+
|
| 174 |
+
out = cv2.VideoWriter(temp_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
| 175 |
+
for frame in frames:
|
| 176 |
+
out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
|
| 177 |
+
out.release()
|
| 178 |
+
|
| 179 |
+
ffmpeg_cmd = [
|
| 180 |
+
"ffmpeg", "-y", "-i", temp_path,
|
| 181 |
+
"-c:v", "libx264", "-crf", "23", "-preset", "medium",
|
| 182 |
+
path
|
| 183 |
+
]
|
| 184 |
+
subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
| 185 |
+
os.remove(temp_path)
|
| 186 |
+
|
| 187 |
+
cfg = OmegaConf.load("configurations/huggingface.yaml")
|
| 188 |
+
worldmem = WorldMemMinecraft(cfg)
|
| 189 |
+
load_custom_checkpoint(algo=worldmem.diffusion_model, checkpoint_path=cfg.diffusion_path)
|
| 190 |
+
load_custom_checkpoint(algo=worldmem.vae, checkpoint_path=cfg.vae_path)
|
| 191 |
+
load_custom_checkpoint(algo=worldmem.pose_prediction_model, checkpoint_path=cfg.pose_predictor_path)
|
| 192 |
+
worldmem.to("cuda").eval()
|
| 193 |
+
# worldmem = enable_amp(worldmem, precision="16-mixed")
|
| 194 |
+
|
| 195 |
+
actions = np.zeros((1, 25), dtype=np.float32)
|
| 196 |
+
poses = np.zeros((1, 5), dtype=np.float32)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def get_duration_single_image_to_long_video(first_frame, action, first_pose, device, memory_latent_frames, memory_actions,
|
| 201 |
+
memory_poses, memory_c2w, memory_frame_idx):
|
| 202 |
+
return 5 * len(action) if memory_actions is not None else 5
|
| 203 |
+
|
| 204 |
+
@spaces.GPU(duration=get_duration_single_image_to_long_video)
|
| 205 |
+
def run_interactive(first_frame, action, first_pose, device, memory_latent_frames, memory_actions,
|
| 206 |
+
memory_poses, memory_c2w, memory_frame_idx):
|
| 207 |
+
new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = worldmem.interactive(first_frame,
|
| 208 |
+
action,
|
| 209 |
+
first_pose,
|
| 210 |
+
device=device,
|
| 211 |
+
memory_latent_frames=memory_latent_frames,
|
| 212 |
+
memory_actions=memory_actions,
|
| 213 |
+
memory_poses=memory_poses,
|
| 214 |
+
memory_c2w=memory_c2w,
|
| 215 |
+
memory_frame_idx=memory_frame_idx)
|
| 216 |
+
|
| 217 |
+
return new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
|
| 218 |
+
|
| 219 |
+
def set_denoising_steps(denoising_steps, sampling_timesteps_state):
|
| 220 |
+
worldmem.sampling_timesteps = denoising_steps
|
| 221 |
+
worldmem.diffusion_model.sampling_timesteps = denoising_steps
|
| 222 |
+
sampling_timesteps_state = denoising_steps
|
| 223 |
+
print("set denoising steps to", worldmem.sampling_timesteps)
|
| 224 |
+
return sampling_timesteps_state
|
| 225 |
+
|
| 226 |
+
def set_context_length(context_length, sampling_context_length_state):
|
| 227 |
+
worldmem.n_tokens = context_length
|
| 228 |
+
sampling_context_length_state = context_length
|
| 229 |
+
print("set context length to", worldmem.n_tokens)
|
| 230 |
+
return sampling_context_length_state
|
| 231 |
+
|
| 232 |
+
def set_memory_condition_length(memory_condition_length, sampling_memory_condition_length_state):
|
| 233 |
+
worldmem.memory_condition_length = memory_condition_length
|
| 234 |
+
sampling_memory_condition_length_state = memory_condition_length
|
| 235 |
+
print("set memory length to", worldmem.memory_condition_length)
|
| 236 |
+
return sampling_memory_condition_length_state
|
| 237 |
+
|
| 238 |
+
def set_next_frame_length(next_frame_length, sampling_next_frame_length_state):
|
| 239 |
+
worldmem.next_frame_length = next_frame_length
|
| 240 |
+
sampling_next_frame_length_state = next_frame_length
|
| 241 |
+
print("set next frame length to", worldmem.next_frame_length)
|
| 242 |
+
return sampling_next_frame_length_state
|
| 243 |
+
|
| 244 |
+
def generate(keys, input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx):
|
| 245 |
+
input_actions = parse_input_to_tensor(keys)
|
| 246 |
+
|
| 247 |
+
if memory_latent_frames is None:
|
| 248 |
+
new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = run_interactive(video_frames[0],
|
| 249 |
+
actions[0],
|
| 250 |
+
poses[0],
|
| 251 |
+
device=device,
|
| 252 |
+
memory_latent_frames=memory_latent_frames,
|
| 253 |
+
memory_actions=memory_actions,
|
| 254 |
+
memory_poses=memory_poses,
|
| 255 |
+
memory_c2w=memory_c2w,
|
| 256 |
+
memory_frame_idx=memory_frame_idx)
|
| 257 |
+
|
| 258 |
+
new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = run_interactive(video_frames[0],
|
| 259 |
+
input_actions,
|
| 260 |
+
None,
|
| 261 |
+
device=device,
|
| 262 |
+
memory_latent_frames=memory_latent_frames,
|
| 263 |
+
memory_actions=memory_actions,
|
| 264 |
+
memory_poses=memory_poses,
|
| 265 |
+
memory_c2w=memory_c2w,
|
| 266 |
+
memory_frame_idx=memory_frame_idx)
|
| 267 |
+
|
| 268 |
+
video_frames = np.concatenate([video_frames, new_frame[:,0]])
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
out_video = video_frames.transpose(0,2,3,1).copy()
|
| 272 |
+
out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
|
| 273 |
+
out_video = (out_video * 255).astype(np.uint8)
|
| 274 |
+
|
| 275 |
+
last_frame = out_video[-1].copy()
|
| 276 |
+
border_thickness = 2
|
| 277 |
+
out_video[-len(new_frame):, :border_thickness, :, :] = [255, 0, 0]
|
| 278 |
+
out_video[-len(new_frame):, -border_thickness:, :, :] = [255, 0, 0]
|
| 279 |
+
out_video[-len(new_frame):, :, :border_thickness, :] = [255, 0, 0]
|
| 280 |
+
out_video[-len(new_frame):, :, -border_thickness:, :] = [255, 0, 0]
|
| 281 |
+
|
| 282 |
+
temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
| 283 |
+
save_video(out_video, temporal_video_path)
|
| 284 |
+
input_history += keys
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
# now = datetime.now()
|
| 288 |
+
# folder_name = now.strftime("%Y-%m-%d_%H-%M-%S")
|
| 289 |
+
# folder_path = os.path.join("/mnt/xiaozeqi/worldmem/output_material", folder_name)
|
| 290 |
+
# os.makedirs(folder_path, exist_ok=True)
|
| 291 |
+
# data_dict = {
|
| 292 |
+
# "input_history": input_history,
|
| 293 |
+
# "video_frames": video_frames,
|
| 294 |
+
# "memory_latent_frames": memory_latent_frames,
|
| 295 |
+
# "memory_actions": memory_actions,
|
| 296 |
+
# "memory_poses": memory_poses,
|
| 297 |
+
# "memory_c2w": memory_c2w,
|
| 298 |
+
# "memory_frame_idx": memory_frame_idx,
|
| 299 |
+
# }
|
| 300 |
+
|
| 301 |
+
# np.savez(os.path.join(folder_path, "data_bundle.npz"), **data_dict)
|
| 302 |
+
|
| 303 |
+
return last_frame, temporal_video_path, input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
|
| 304 |
+
|
| 305 |
+
def reset(selected_image):
|
| 306 |
+
memory_latent_frames = None
|
| 307 |
+
memory_poses = None
|
| 308 |
+
memory_actions = None
|
| 309 |
+
memory_c2w = None
|
| 310 |
+
memory_frame_idx = None
|
| 311 |
+
video_frames = load_image_as_tensor(selected_image).numpy()[None]
|
| 312 |
+
input_history = ""
|
| 313 |
+
|
| 314 |
+
new_frame, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = run_interactive(video_frames[0],
|
| 315 |
+
actions[0],
|
| 316 |
+
poses[0],
|
| 317 |
+
device=device,
|
| 318 |
+
memory_latent_frames=memory_latent_frames,
|
| 319 |
+
memory_actions=memory_actions,
|
| 320 |
+
memory_poses=memory_poses,
|
| 321 |
+
memory_c2w=memory_c2w,
|
| 322 |
+
memory_frame_idx=memory_frame_idx,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
return input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
|
| 326 |
+
|
| 327 |
+
def on_image_click(selected_image):
|
| 328 |
+
input_history, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx = reset(selected_image)
|
| 329 |
+
return input_history, selected_image, selected_image, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
|
| 330 |
+
|
| 331 |
+
def set_memory(examples_case):
|
| 332 |
+
if examples_case == '1':
|
| 333 |
+
data_bundle = np.load("assets/examples/case1.npz")
|
| 334 |
+
input_history = data_bundle['input_history'].item()
|
| 335 |
+
video_frames = data_bundle['memory_frames']
|
| 336 |
+
memory_latent_frames = data_bundle['self_frames']
|
| 337 |
+
memory_actions = data_bundle['self_actions']
|
| 338 |
+
memory_poses = data_bundle['self_poses']
|
| 339 |
+
memory_c2w = data_bundle['self_memory_c2w']
|
| 340 |
+
memory_frame_idx = data_bundle['self_frame_idx']
|
| 341 |
+
elif examples_case == '2':
|
| 342 |
+
data_bundle = np.load("assets/examples/case2.npz")
|
| 343 |
+
input_history = data_bundle['input_history'].item()
|
| 344 |
+
video_frames = data_bundle['memory_frames']
|
| 345 |
+
memory_latent_frames = data_bundle['self_frames']
|
| 346 |
+
memory_actions = data_bundle['self_actions']
|
| 347 |
+
memory_poses = data_bundle['self_poses']
|
| 348 |
+
memory_c2w = data_bundle['self_memory_c2w']
|
| 349 |
+
memory_frame_idx = data_bundle['self_frame_idx']
|
| 350 |
+
elif examples_case == '3':
|
| 351 |
+
data_bundle = np.load("assets/examples/case3.npz")
|
| 352 |
+
input_history = data_bundle['input_history'].item()
|
| 353 |
+
video_frames = data_bundle['memory_frames']
|
| 354 |
+
memory_latent_frames = data_bundle['self_frames']
|
| 355 |
+
memory_actions = data_bundle['self_actions']
|
| 356 |
+
memory_poses = data_bundle['self_poses']
|
| 357 |
+
memory_c2w = data_bundle['self_memory_c2w']
|
| 358 |
+
memory_frame_idx = data_bundle['self_frame_idx']
|
| 359 |
+
elif examples_case == '4':
|
| 360 |
+
data_bundle = np.load("assets/examples/case4.npz")
|
| 361 |
+
input_history = data_bundle['input_history'].item()
|
| 362 |
+
video_frames = data_bundle['memory_frames']
|
| 363 |
+
memory_latent_frames = data_bundle['self_frames']
|
| 364 |
+
memory_actions = data_bundle['self_actions']
|
| 365 |
+
memory_poses = data_bundle['self_poses']
|
| 366 |
+
memory_c2w = data_bundle['self_memory_c2w']
|
| 367 |
+
memory_frame_idx = data_bundle['self_frame_idx']
|
| 368 |
+
|
| 369 |
+
out_video = video_frames.transpose(0,2,3,1)
|
| 370 |
+
out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
|
| 371 |
+
out_video = (out_video * 255).astype(np.uint8)
|
| 372 |
+
|
| 373 |
+
temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
|
| 374 |
+
save_video(out_video, temporal_video_path)
|
| 375 |
+
|
| 376 |
+
return input_history, out_video[-1], temporal_video_path, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx
|
| 377 |
+
|
| 378 |
+
css = """
|
| 379 |
+
h1 {
|
| 380 |
+
text-align: center;
|
| 381 |
+
display:block;
|
| 382 |
+
}
|
| 383 |
+
"""
|
| 384 |
+
|
| 385 |
+
with gr.Blocks(css=css) as demo:
|
| 386 |
+
gr.Markdown(
|
| 387 |
+
"""
|
| 388 |
+
# WORLDMEM: Long-term Consistent World Simulation with Memory
|
| 389 |
+
"""
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
gr.Markdown(
|
| 393 |
+
"""
|
| 394 |
+
## 🚀 How to Explore WorldMem
|
| 395 |
+
|
| 396 |
+
Follow these simple steps to get started:
|
| 397 |
+
|
| 398 |
+
1. **Choose a scene**.
|
| 399 |
+
2. **Input your action sequence**.
|
| 400 |
+
3. **Click "Generate"**.
|
| 401 |
+
|
| 402 |
+
- You can continuously click **"Generate"** to **extend the video** and observe how well the world maintains consistency over time.
|
| 403 |
+
- For best performance, we recommend **running locally** (1s/frame on H100) instead of Spaces (5s/frame).
|
| 404 |
+
- ⭐️ If you like this project, please [give it a star on GitHub]()!
|
| 405 |
+
- 💬 For questions or feedback, feel free to open an issue or email me at **zeqixiao1@gmail.com**.
|
| 406 |
+
|
| 407 |
+
Happy exploring! 🌍
|
| 408 |
+
"""
|
| 409 |
+
)
|
| 410 |
+
# <div style="text-align: center;">
|
| 411 |
+
# <!-- Public Website -->
|
| 412 |
+
# <a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
|
| 413 |
+
# <img src="https://img.shields.io/badge/public_website-8A2BE2">
|
| 414 |
+
# </a>
|
| 415 |
+
|
| 416 |
+
# <!-- GitHub Stars -->
|
| 417 |
+
# <a style="display:inline-block; margin-left: .5em" href="https://github.com/NIRVANALAN/GaussianAnything">
|
| 418 |
+
# <img src="https://img.shields.io/github/stars/NIRVANALAN/GaussianAnything?style=social">
|
| 419 |
+
# </a>
|
| 420 |
+
|
| 421 |
+
# <!-- Project Page -->
|
| 422 |
+
# <a style="display:inline-block; margin-left: .5em" href="https://nirvanalan.github.io/projects/GA/">
|
| 423 |
+
# <img src="https://img.shields.io/badge/project_page-blue">
|
| 424 |
+
# </a>
|
| 425 |
+
|
| 426 |
+
# <!-- arXiv Paper -->
|
| 427 |
+
# <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/XXXX.XXXXX">
|
| 428 |
+
# <img src="https://img.shields.io/badge/arXiv-paper-red">
|
| 429 |
+
# </a>
|
| 430 |
+
# </div>
|
| 431 |
+
|
| 432 |
+
example_actions = {"turn left→turn right": "AAAAAAAAAAAADDDDDDDDDDDD",
|
| 433 |
+
"turn 360 degree": "AAAAAAAAAAAAAAAAAAAAAAAA",
|
| 434 |
+
"turn right→go backward→look up→turn left→look down": "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
|
| 435 |
+
"turn right→go forward→turn right": "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
|
| 436 |
+
"turn right→look up→turn right→look down": "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSS",
|
| 437 |
+
"put item→go backward→put item→go backward":"SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEE"}
|
| 438 |
+
|
| 439 |
+
selected_image = gr.State(ICE_PLAINS_IMAGE)
|
| 440 |
+
|
| 441 |
+
with gr.Row(variant="panel"):
|
| 442 |
+
with gr.Column():
|
| 443 |
+
gr.Markdown("🖼️ Start from this frame.")
|
| 444 |
+
image_display = gr.Image(value=selected_image.value, interactive=False, label="Current Frame")
|
| 445 |
+
with gr.Column():
|
| 446 |
+
gr.Markdown("🎞️ Generated videos. New contents are marked in red box.")
|
| 447 |
+
video_display = gr.Video(autoplay=True, loop=True)
|
| 448 |
+
|
| 449 |
+
gr.Markdown("### 🏞️ Choose a scene and start generation.")
|
| 450 |
+
|
| 451 |
+
with gr.Row():
|
| 452 |
+
image_display_1 = gr.Image(value=SUNFLOWERS_IMAGE, interactive=False, label="Sunflower Plains")
|
| 453 |
+
image_display_2 = gr.Image(value=DESERT_IMAGE, interactive=False, label="Desert")
|
| 454 |
+
image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna")
|
| 455 |
+
image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains")
|
| 456 |
+
image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains")
|
| 457 |
+
image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place")
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
with gr.Row(variant="panel"):
|
| 461 |
+
with gr.Column(scale=2):
|
| 462 |
+
gr.Markdown("### 🕹️ Input action sequences for interaction.")
|
| 463 |
+
input_box = gr.Textbox(label="Action Sequences", placeholder="Enter action sequences here, e.g. (AAAAAAAAAAAADDDDDDDDDDDD)", lines=1, max_lines=1)
|
| 464 |
+
log_output = gr.Textbox(label="History Sequences", interactive=False)
|
| 465 |
+
gr.Markdown(
|
| 466 |
+
"""
|
| 467 |
+
### 💡 Action Key Guide
|
| 468 |
+
|
| 469 |
+
<pre style="font-family: monospace; font-size: 14px; line-height: 1.6;">
|
| 470 |
+
W: Turn up S: Turn down A: Turn left D: Turn right
|
| 471 |
+
Q: Go forward E: Go backward N: No-op U: Use item
|
| 472 |
+
</pre>
|
| 473 |
+
"""
|
| 474 |
+
)
|
| 475 |
+
gr.Markdown("### 👇 Click to quickly set action sequence examples.")
|
| 476 |
+
with gr.Row():
|
| 477 |
+
buttons = []
|
| 478 |
+
for action_key in list(example_actions.keys())[:2]:
|
| 479 |
+
with gr.Column(scale=len(action_key)):
|
| 480 |
+
buttons.append(gr.Button(action_key))
|
| 481 |
+
with gr.Row():
|
| 482 |
+
for action_key in list(example_actions.keys())[2:4]:
|
| 483 |
+
with gr.Column(scale=len(action_key)):
|
| 484 |
+
buttons.append(gr.Button(action_key))
|
| 485 |
+
with gr.Row():
|
| 486 |
+
for action_key in list(example_actions.keys())[4:6]:
|
| 487 |
+
with gr.Column(scale=len(action_key)):
|
| 488 |
+
buttons.append(gr.Button(action_key))
|
| 489 |
+
|
| 490 |
+
with gr.Column(scale=1):
|
| 491 |
+
submit_button = gr.Button("🎬 Generate!", variant="primary")
|
| 492 |
+
reset_btn = gr.Button("🔄 Reset")
|
| 493 |
+
|
| 494 |
+
# gr.Markdown("<div style='flex-grow:1; height: 100px'></div>")
|
| 495 |
+
|
| 496 |
+
gr.Markdown("### ⚙️ Advanced Settings")
|
| 497 |
+
|
| 498 |
+
slider_denoising_step = gr.Slider(
|
| 499 |
+
minimum=10, maximum=50, value=worldmem.sampling_timesteps, step=1,
|
| 500 |
+
label="Denoising Steps",
|
| 501 |
+
info="Higher values yield better quality but slower speed"
|
| 502 |
+
)
|
| 503 |
+
slider_context_length = gr.Slider(
|
| 504 |
+
minimum=2, maximum=10, value=worldmem.n_tokens, step=1,
|
| 505 |
+
label="Context Length",
|
| 506 |
+
info="How many previous frames in temporal context window."
|
| 507 |
+
)
|
| 508 |
+
slider_memory_condition_length = gr.Slider(
|
| 509 |
+
minimum=4, maximum=16, value=worldmem.memory_condition_length, step=1,
|
| 510 |
+
label="Memory Length",
|
| 511 |
+
info="How many previous frames in memory window. (Recommended: 1, multi-frame generation is not stable yet)"
|
| 512 |
+
)
|
| 513 |
+
slider_next_frame_length = gr.Slider(
|
| 514 |
+
minimum=1, maximum=5, value=worldmem.next_frame_length, step=1,
|
| 515 |
+
label="Next Frame Length",
|
| 516 |
+
info="How many next frames to generate at once."
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
sampling_timesteps_state = gr.State(worldmem.sampling_timesteps)
|
| 520 |
+
sampling_context_length_state = gr.State(worldmem.n_tokens)
|
| 521 |
+
sampling_memory_condition_length_state = gr.State(worldmem.memory_condition_length)
|
| 522 |
+
sampling_next_frame_length_state = gr.State(worldmem.next_frame_length)
|
| 523 |
+
|
| 524 |
+
video_frames = gr.State(load_image_as_tensor(selected_image.value)[None].numpy())
|
| 525 |
+
memory_latent_frames = gr.State()
|
| 526 |
+
memory_actions = gr.State()
|
| 527 |
+
memory_poses = gr.State()
|
| 528 |
+
memory_c2w = gr.State()
|
| 529 |
+
memory_frame_idx = gr.State()
|
| 530 |
+
|
| 531 |
+
def set_action(action):
|
| 532 |
+
return action
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
for button, action_key in zip(buttons, list(example_actions.keys())):
|
| 537 |
+
button.click(set_action, inputs=[gr.State(value=example_actions[action_key])], outputs=input_box)
|
| 538 |
+
|
| 539 |
+
gr.Markdown("### 👇 Click to review generated examples, and continue generation based on them.")
|
| 540 |
+
|
| 541 |
+
example_case = gr.Textbox(label="Case", visible=False)
|
| 542 |
+
image_output = gr.Image(visible=False)
|
| 543 |
+
|
| 544 |
+
examples = gr.Examples(
|
| 545 |
+
examples=example_images,
|
| 546 |
+
inputs=[example_case, image_output, log_output, slider_denoising_step, slider_context_length, slider_memory_condition_length],
|
| 547 |
+
cache_examples=False
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
example_case.change(
|
| 551 |
+
fn=set_memory,
|
| 552 |
+
inputs=[example_case],
|
| 553 |
+
outputs=[log_output, image_display, video_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx]
|
| 554 |
+
)
|
| 555 |
+
|
| 556 |
+
submit_button.click(generate, inputs=[input_box, log_output, video_frames,
|
| 557 |
+
memory_latent_frames, memory_actions, memory_poses,
|
| 558 |
+
memory_c2w, memory_frame_idx],
|
| 559 |
+
outputs=[image_display, video_display, log_output,
|
| 560 |
+
video_frames, memory_latent_frames, memory_actions, memory_poses,
|
| 561 |
+
memory_c2w, memory_frame_idx])
|
| 562 |
+
|
| 563 |
+
reset_btn.click(reset, inputs=[selected_image], outputs=[log_output, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
|
| 564 |
+
image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
|
| 565 |
+
image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
|
| 566 |
+
image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
|
| 567 |
+
image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
|
| 568 |
+
image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=[log_output, selected_image, image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
|
| 569 |
+
image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=[log_output, selected_image,image_display, video_frames, memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx])
|
| 570 |
+
|
| 571 |
+
slider_denoising_step.change(fn=set_denoising_steps, inputs=[slider_denoising_step, sampling_timesteps_state], outputs=sampling_timesteps_state)
|
| 572 |
+
slider_context_length.change(fn=set_context_length, inputs=[slider_context_length, sampling_context_length_state], outputs=sampling_context_length_state)
|
| 573 |
+
slider_memory_condition_length.change(fn=set_memory_condition_length, inputs=[slider_memory_condition_length, sampling_memory_condition_length_state], outputs=sampling_memory_condition_length_state)
|
| 574 |
+
slider_next_frame_length.change(fn=set_next_frame_length, inputs=[slider_next_frame_length, sampling_next_frame_length_state], outputs=sampling_next_frame_length_state)
|
| 575 |
+
|
| 576 |
+
demo.launch(share=True)
|
assets/desert.png
ADDED
|
Git LFS Details
|
assets/ice_plains.png
ADDED
|
Git LFS Details
|
assets/place.png
ADDED
|
Git LFS Details
|
assets/plains.png
ADDED
|
Git LFS Details
|
assets/rain_sunflower_plains.png
ADDED
|
Git LFS Details
|
assets/savanna.png
ADDED
|
Git LFS Details
|
assets/sunflower_plains.png
ADDED
|
Git LFS Details
|
assets/worldmem_logo.png
ADDED
|
Git LFS Details
|
calculate_fid.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Calculate FID (Fréchet Inception Distance) between predicted and ground truth videos.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python calculate_fid.py --videos_dir /path/to/videos
|
| 7 |
+
python calculate_fid.py --videos_dir /path/to/videos --batch_size 32
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
import argparse
|
| 15 |
+
import cv2
|
| 16 |
+
from torchmetrics.image.fid import FrechetInceptionDistance
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def load_video_frames(video_path, max_frames=None):
|
| 20 |
+
"""
|
| 21 |
+
Load frames from a video file.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
video_path: Path to the video file
|
| 25 |
+
max_frames: Maximum number of frames to load (None = all frames)
|
| 26 |
+
|
| 27 |
+
Returns:
|
| 28 |
+
torch.Tensor: Video frames with shape (T, C, H, W) in range [0, 255]
|
| 29 |
+
"""
|
| 30 |
+
cap = cv2.VideoCapture(str(video_path))
|
| 31 |
+
frames = []
|
| 32 |
+
frame_count = 0
|
| 33 |
+
|
| 34 |
+
while True:
|
| 35 |
+
ret, frame = cap.read()
|
| 36 |
+
if not ret:
|
| 37 |
+
break
|
| 38 |
+
|
| 39 |
+
# Convert BGR to RGB
|
| 40 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 41 |
+
frames.append(frame)
|
| 42 |
+
frame_count += 1
|
| 43 |
+
|
| 44 |
+
if max_frames and frame_count >= max_frames:
|
| 45 |
+
break
|
| 46 |
+
|
| 47 |
+
cap.release()
|
| 48 |
+
|
| 49 |
+
if len(frames) == 0:
|
| 50 |
+
raise ValueError(f"No frames loaded from {video_path}")
|
| 51 |
+
|
| 52 |
+
# Convert to tensor: (T, H, W, C) -> (T, C, H, W)
|
| 53 |
+
frames = np.stack(frames, axis=0)
|
| 54 |
+
frames = torch.from_numpy(frames).permute(0, 3, 1, 2)
|
| 55 |
+
|
| 56 |
+
return frames
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def load_videos_from_directory(video_dir, max_frames_per_video=None, max_videos=None):
|
| 60 |
+
"""
|
| 61 |
+
Load all videos from a directory.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
video_dir: Directory containing .mp4 files
|
| 65 |
+
max_frames_per_video: Maximum frames to load per video
|
| 66 |
+
max_videos: Maximum number of videos to load
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
torch.Tensor: All frames concatenated with shape (N, C, H, W)
|
| 70 |
+
"""
|
| 71 |
+
video_dir = Path(video_dir)
|
| 72 |
+
video_paths = sorted(list(video_dir.glob("**/*.mp4")))
|
| 73 |
+
|
| 74 |
+
if max_videos:
|
| 75 |
+
video_paths = video_paths[:max_videos]
|
| 76 |
+
|
| 77 |
+
all_frames = []
|
| 78 |
+
|
| 79 |
+
print(f"Loading videos from {video_dir}")
|
| 80 |
+
print(f"Found {len(video_paths)} videos")
|
| 81 |
+
|
| 82 |
+
for video_path in tqdm(video_paths, desc="Loading videos"):
|
| 83 |
+
try:
|
| 84 |
+
frames = load_video_frames(video_path, max_frames=max_frames_per_video)
|
| 85 |
+
all_frames.append(frames)
|
| 86 |
+
except Exception as e:
|
| 87 |
+
print(f"\nWarning: Failed to load {video_path.name}: {e}")
|
| 88 |
+
continue
|
| 89 |
+
|
| 90 |
+
if len(all_frames) == 0:
|
| 91 |
+
raise ValueError(f"No videos loaded from {video_dir}")
|
| 92 |
+
|
| 93 |
+
# Concatenate all frames: (N_videos, T, C, H, W) -> (N_total_frames, C, H, W)
|
| 94 |
+
all_frames = torch.cat(all_frames, dim=0)
|
| 95 |
+
|
| 96 |
+
print(f"Loaded {all_frames.shape[0]} frames total")
|
| 97 |
+
print(f"Frame shape: {all_frames.shape[1:]}")
|
| 98 |
+
|
| 99 |
+
return all_frames
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def calculate_fid(pred_dir, gt_dir, batch_size=32, device='cuda',
|
| 103 |
+
max_frames_per_video=None, max_videos=None):
|
| 104 |
+
"""
|
| 105 |
+
Calculate FID between predicted and ground truth videos.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
pred_dir: Directory containing predicted videos
|
| 109 |
+
gt_dir: Directory containing ground truth videos
|
| 110 |
+
batch_size: Batch size for FID calculation
|
| 111 |
+
device: Device to use ('cuda' or 'cpu')
|
| 112 |
+
max_frames_per_video: Maximum frames to load per video
|
| 113 |
+
max_videos: Maximum number of videos to load from each directory
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
float: FID score
|
| 117 |
+
"""
|
| 118 |
+
print("="*60)
|
| 119 |
+
print("FID Calculation")
|
| 120 |
+
print("="*60)
|
| 121 |
+
print(f"Pred directory: {pred_dir}")
|
| 122 |
+
print(f"GT directory: {gt_dir}")
|
| 123 |
+
print(f"Device: {device}")
|
| 124 |
+
print(f"Batch size: {batch_size}")
|
| 125 |
+
print("="*60 + "\n")
|
| 126 |
+
|
| 127 |
+
# Check if directories exist
|
| 128 |
+
pred_dir = Path(pred_dir)
|
| 129 |
+
gt_dir = Path(gt_dir)
|
| 130 |
+
|
| 131 |
+
if not pred_dir.exists():
|
| 132 |
+
raise ValueError(f"Pred directory does not exist: {pred_dir}")
|
| 133 |
+
if not gt_dir.exists():
|
| 134 |
+
raise ValueError(f"GT directory does not exist: {gt_dir}")
|
| 135 |
+
|
| 136 |
+
# Load videos
|
| 137 |
+
print("\n[1/3] Loading predicted videos...")
|
| 138 |
+
pred_frames = load_videos_from_directory(
|
| 139 |
+
pred_dir,
|
| 140 |
+
max_frames_per_video=max_frames_per_video,
|
| 141 |
+
max_videos=max_videos
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
print("\n[2/3] Loading ground truth videos...")
|
| 145 |
+
gt_frames = load_videos_from_directory(
|
| 146 |
+
gt_dir,
|
| 147 |
+
max_frames_per_video=max_frames_per_video,
|
| 148 |
+
max_videos=max_videos
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# Initialize FID model
|
| 152 |
+
print("\n[3/3] Calculating FID...")
|
| 153 |
+
fid_model = FrechetInceptionDistance(normalize=True).to(device)
|
| 154 |
+
|
| 155 |
+
# Process pred frames in batches
|
| 156 |
+
print("Processing predicted frames...")
|
| 157 |
+
num_pred_frames = pred_frames.shape[0]
|
| 158 |
+
for i in tqdm(range(0, num_pred_frames, batch_size)):
|
| 159 |
+
batch = pred_frames[i:i+batch_size]
|
| 160 |
+
batch = batch.to(device)
|
| 161 |
+
fid_model.update(batch, real=False)
|
| 162 |
+
|
| 163 |
+
# Process gt frames in batches
|
| 164 |
+
print("Processing ground truth frames...")
|
| 165 |
+
num_gt_frames = gt_frames.shape[0]
|
| 166 |
+
for i in tqdm(range(0, num_gt_frames, batch_size)):
|
| 167 |
+
batch = gt_frames[i:i+batch_size]
|
| 168 |
+
batch = batch.to(device)
|
| 169 |
+
fid_model.update(batch, real=True)
|
| 170 |
+
|
| 171 |
+
# Compute FID
|
| 172 |
+
fid_score = fid_model.compute().item()
|
| 173 |
+
|
| 174 |
+
return fid_score
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def main():
|
| 178 |
+
parser = argparse.ArgumentParser(
|
| 179 |
+
description="Calculate FID between predicted and ground truth videos"
|
| 180 |
+
)
|
| 181 |
+
parser.add_argument(
|
| 182 |
+
"--videos_dir",
|
| 183 |
+
type=str,
|
| 184 |
+
default="/mnt/worldmem_valid/outputs/2025-12-01/08-09-46/videos/test_vis",
|
| 185 |
+
help="Base directory containing 'pred' and 'gt' subdirectories"
|
| 186 |
+
)
|
| 187 |
+
parser.add_argument(
|
| 188 |
+
"--pred_dir",
|
| 189 |
+
type=str,
|
| 190 |
+
default=None,
|
| 191 |
+
help="Override pred directory (default: {videos_dir}/pred)"
|
| 192 |
+
)
|
| 193 |
+
parser.add_argument(
|
| 194 |
+
"--gt_dir",
|
| 195 |
+
type=str,
|
| 196 |
+
default=None,
|
| 197 |
+
help="Override gt directory (default: {videos_dir}/gt)"
|
| 198 |
+
)
|
| 199 |
+
parser.add_argument(
|
| 200 |
+
"--batch_size",
|
| 201 |
+
type=int,
|
| 202 |
+
default=32,
|
| 203 |
+
help="Batch size for FID calculation (default: 32)"
|
| 204 |
+
)
|
| 205 |
+
parser.add_argument(
|
| 206 |
+
"--device",
|
| 207 |
+
type=str,
|
| 208 |
+
default="cuda" if torch.cuda.is_available() else "cpu",
|
| 209 |
+
help="Device to use (default: cuda if available)"
|
| 210 |
+
)
|
| 211 |
+
parser.add_argument(
|
| 212 |
+
"--max_frames_per_video",
|
| 213 |
+
type=int,
|
| 214 |
+
default=None,
|
| 215 |
+
help="Maximum frames to load per video (default: None, load all)"
|
| 216 |
+
)
|
| 217 |
+
parser.add_argument(
|
| 218 |
+
"--max_videos",
|
| 219 |
+
type=int,
|
| 220 |
+
default=50,
|
| 221 |
+
help="Maximum number of videos to load (default: None, load all)"
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
args = parser.parse_args()
|
| 225 |
+
|
| 226 |
+
# Determine pred and gt directories
|
| 227 |
+
videos_dir = Path(args.videos_dir)
|
| 228 |
+
|
| 229 |
+
if args.pred_dir:
|
| 230 |
+
pred_dir = Path(args.pred_dir)
|
| 231 |
+
else:
|
| 232 |
+
pred_dir = videos_dir / "pred"
|
| 233 |
+
|
| 234 |
+
if args.gt_dir:
|
| 235 |
+
gt_dir = Path(args.gt_dir)
|
| 236 |
+
else:
|
| 237 |
+
gt_dir = videos_dir / "gt"
|
| 238 |
+
|
| 239 |
+
# Calculate FID
|
| 240 |
+
try:
|
| 241 |
+
fid_score = calculate_fid(
|
| 242 |
+
pred_dir=pred_dir,
|
| 243 |
+
gt_dir=gt_dir,
|
| 244 |
+
batch_size=args.batch_size,
|
| 245 |
+
device=args.device,
|
| 246 |
+
max_frames_per_video=args.max_frames_per_video,
|
| 247 |
+
max_videos=args.max_videos
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
# Print results
|
| 251 |
+
print("\n" + "="*60)
|
| 252 |
+
print("RESULTS")
|
| 253 |
+
print("="*60)
|
| 254 |
+
print(f"FID Score: {fid_score:.4f}")
|
| 255 |
+
print("="*60)
|
| 256 |
+
|
| 257 |
+
# Save results to file
|
| 258 |
+
output_file = videos_dir / "fid_results.txt"
|
| 259 |
+
with open(output_file, 'w') as f:
|
| 260 |
+
f.write(f"FID Score: {fid_score:.4f}\n")
|
| 261 |
+
f.write(f"Pred directory: {pred_dir}\n")
|
| 262 |
+
f.write(f"GT directory: {gt_dir}\n")
|
| 263 |
+
|
| 264 |
+
print(f"\nResults saved to: {output_file}")
|
| 265 |
+
|
| 266 |
+
except Exception as e:
|
| 267 |
+
print(f"\n✗ Error: {e}")
|
| 268 |
+
import traceback
|
| 269 |
+
traceback.print_exc()
|
| 270 |
+
return 1
|
| 271 |
+
|
| 272 |
+
return 0
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
if __name__ == "__main__":
|
| 276 |
+
exit(main())
|
| 277 |
+
|
configurations/algorithm/base_algo.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This will be passed as the cfg to Algo.__init__(cfg) of your algorithm class
|
| 2 |
+
|
| 3 |
+
debug: ${debug} # inherited from configurations/config.yaml
|
configurations/algorithm/base_pytorch_algo.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- base_algo # inherits from configurations/algorithm/base_algo.yaml
|
| 3 |
+
|
| 4 |
+
lr: ${experiment.training.lr}
|
configurations/algorithm/df_base.yaml
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- base_pytorch_algo
|
| 3 |
+
|
| 4 |
+
# dataset-dependent configurations
|
| 5 |
+
x_shape: ${dataset.observation_shape}
|
| 6 |
+
frame_stack: 1
|
| 7 |
+
frame_skip: 1
|
| 8 |
+
data_mean: ${dataset.data_mean}
|
| 9 |
+
data_std: ${dataset.data_std}
|
| 10 |
+
external_cond_dim: 0 #${dataset.action_dim}
|
| 11 |
+
context_frames: ${dataset.context_length}
|
| 12 |
+
# training hyperparameters
|
| 13 |
+
weight_decay: 1e-4
|
| 14 |
+
warmup_steps: 10000
|
| 15 |
+
optimizer_beta: [0.9, 0.999]
|
| 16 |
+
# diffusion-related
|
| 17 |
+
uncertainty_scale: 1
|
| 18 |
+
guidance_scale: 0.0
|
| 19 |
+
chunk_size: 1 # -1 for full trajectory diffusion, number to specify diffusion chunk size
|
| 20 |
+
scheduling_matrix: autoregressive
|
| 21 |
+
noise_level: random_all
|
| 22 |
+
causal: True
|
| 23 |
+
|
| 24 |
+
diffusion:
|
| 25 |
+
# training
|
| 26 |
+
objective: pred_x0
|
| 27 |
+
beta_schedule: cosine
|
| 28 |
+
schedule_fn_kwargs: {}
|
| 29 |
+
clip_noise: 20.0
|
| 30 |
+
use_snr: False
|
| 31 |
+
use_cum_snr: False
|
| 32 |
+
use_fused_snr: False
|
| 33 |
+
snr_clip: 5.0
|
| 34 |
+
cum_snr_decay: 0.98
|
| 35 |
+
timesteps: 1000
|
| 36 |
+
# sampling
|
| 37 |
+
sampling_timesteps: 50 # fixme, numer of diffusion steps, should be increased
|
| 38 |
+
ddim_sampling_eta: 1.0
|
| 39 |
+
stabilization_level: 10
|
| 40 |
+
# architecture
|
| 41 |
+
architecture:
|
| 42 |
+
network_size: 64
|
configurations/algorithm/df_video_worldmemminecraft.yaml
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- df_base
|
| 3 |
+
|
| 4 |
+
n_frames: ${dataset.n_frames}
|
| 5 |
+
frame_skip: ${dataset.frame_skip}
|
| 6 |
+
metadata: ${dataset.metadata}
|
| 7 |
+
|
| 8 |
+
# training hyperparameters
|
| 9 |
+
weight_decay: 2e-3
|
| 10 |
+
warmup_steps: 1000
|
| 11 |
+
optimizer_beta: [0.9, 0.99]
|
| 12 |
+
action_cond_dim: 25
|
| 13 |
+
use_plucker: true
|
| 14 |
+
|
| 15 |
+
diffusion:
|
| 16 |
+
# training
|
| 17 |
+
beta_schedule: sigmoid
|
| 18 |
+
objective: pred_v
|
| 19 |
+
use_fused_snr: True
|
| 20 |
+
cum_snr_decay: 0.96
|
| 21 |
+
clip_noise: 20.
|
| 22 |
+
# sampling
|
| 23 |
+
sampling_timesteps: 20
|
| 24 |
+
ddim_sampling_eta: 0.0
|
| 25 |
+
stabilization_level: 15
|
| 26 |
+
# architecture
|
| 27 |
+
architecture:
|
| 28 |
+
network_size: 64
|
| 29 |
+
attn_heads: 4
|
| 30 |
+
attn_dim_head: 64
|
| 31 |
+
dim_mults: [1, 2, 4, 8]
|
| 32 |
+
resolution: ${dataset.resolution}
|
| 33 |
+
attn_resolutions: [16, 32, 64, 128]
|
| 34 |
+
use_init_temporal_attn: True
|
| 35 |
+
use_linear_attn: True
|
| 36 |
+
time_emb_type: rotary
|
| 37 |
+
|
| 38 |
+
_name: df_video_worldmemminecraft
|
configurations/dataset/base_dataset.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This will be passed as the cfg to Dataset.__init__(cfg) of your dataset class
|
| 2 |
+
|
| 3 |
+
debug: ${debug} # inherited from configurations/config.yaml
|
configurations/dataset/base_video.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- base_dataset
|
| 3 |
+
|
| 4 |
+
metadata: "data/${dataset.name}/metadata.json"
|
| 5 |
+
data_mean: "data/${dataset.name}/data_mean.npy"
|
| 6 |
+
data_std: "data/${dataset.name}/data_std.npy"
|
| 7 |
+
save_dir: ???
|
| 8 |
+
n_frames: 32
|
| 9 |
+
context_length: 4
|
| 10 |
+
resolution: 128
|
| 11 |
+
observation_shape: [3, "${dataset.resolution}", "${dataset.resolution}"]
|
| 12 |
+
external_cond_dim: 0
|
| 13 |
+
validation_multiplier: 1
|
| 14 |
+
frame_skip: 1
|
configurations/dataset/video_minecraft.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- base_video
|
| 3 |
+
|
| 4 |
+
save_dir: data/minecraft_simple_backforward
|
| 5 |
+
n_frames: 16 # TODO: increase later
|
| 6 |
+
resolution: 128
|
| 7 |
+
data_mean: 0.5
|
| 8 |
+
data_std: 0.5
|
| 9 |
+
action_cond_dim: 25
|
| 10 |
+
context_length: 1
|
| 11 |
+
frame_skip: 1
|
| 12 |
+
validation_multiplier: 1
|
| 13 |
+
|
| 14 |
+
_name: video_minecraft_oasis
|
configurations/experiment/base_experiment.yaml
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
debug: ${debug} # inherited from configurations/config.yaml
|
| 2 |
+
tasks: [main] # tasks to run sequantially, such as [training, test], useful when your project has multiple stages and you want to run only a subset of them.
|
configurations/experiment/base_pytorch.yaml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# inherites from base_experiment.yaml
|
| 2 |
+
# most of the options have docs at https://lightning.ai/docs/pytorch/stable/common/trainer.html
|
| 3 |
+
|
| 4 |
+
defaults:
|
| 5 |
+
- base_experiment
|
| 6 |
+
|
| 7 |
+
tasks: [training] # tasks to run sequantially, change when your project has multiple stages and you want to run only a subset of them.
|
| 8 |
+
num_nodes: 1 # number of gpu servers used in large scale distributed training
|
| 9 |
+
|
| 10 |
+
training:
|
| 11 |
+
precision: 16-mixed # set float precision, 16-mixed is faster while 32 is more stable
|
| 12 |
+
compile: False # whether to compile the model with torch.compile
|
| 13 |
+
lr: 0.001 # learning rate
|
| 14 |
+
batch_size: 16 # training batch size; effective batch size is this number * gpu * nodes iff using distributed training
|
| 15 |
+
max_epochs: 1000 # set to -1 to train forever
|
| 16 |
+
max_steps: -1 # set to -1 to train forever, will override max_epochs
|
| 17 |
+
max_time: null # set to something like "00:12:00:00" to enable
|
| 18 |
+
data:
|
| 19 |
+
num_workers: 4 # number of CPU threads for data preprocessing.
|
| 20 |
+
shuffle: True # whether training data will be shuffled
|
| 21 |
+
optim:
|
| 22 |
+
accumulate_grad_batches: 1 # accumulate gradients for n batches before backprop
|
| 23 |
+
gradient_clip_val: 0 # clip gradients with norm above this value, set to 0 to disable
|
| 24 |
+
checkpointing:
|
| 25 |
+
# these are arguments to pytorch lightning's callback, `ModelCheckpoint` class
|
| 26 |
+
every_n_train_steps: 5000 # save a checkpoint every n train steps
|
| 27 |
+
every_n_epochs: null # mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``
|
| 28 |
+
train_time_interval: null # in format of "00:12:00:00", mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
|
| 29 |
+
enable_version_counter: False # If this is ``False``, later checkpoint will be overwrite previous ones.
|
| 30 |
+
|
| 31 |
+
validation:
|
| 32 |
+
precision: 16-mixed
|
| 33 |
+
compile: False # whether to compile the model with torch.compile
|
| 34 |
+
batch_size: 16 # validation batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training
|
| 35 |
+
val_every_n_step: 2000 # controls how frequent do we run validation, can be float (fraction of epoches) or int (steps) or null (if val_every_n_epoch is set)
|
| 36 |
+
val_every_n_epoch: null # if you want to do validation every n epoches, requires val_every_n_step to be null.
|
| 37 |
+
limit_batch: null # if null, run through validation set. Otherwise limit the number of batches to use for validation.
|
| 38 |
+
inference_mode: True # whether to run validation in inference mode (enable_grad won't work!)
|
| 39 |
+
data:
|
| 40 |
+
num_workers: 4 # number of CPU threads for data preprocessing, for validation.
|
| 41 |
+
shuffle: False # whether validation data will be shuffled
|
| 42 |
+
|
| 43 |
+
test:
|
| 44 |
+
precision: 16-mixed
|
| 45 |
+
compile: False # whether to compile the model with torch.compile
|
| 46 |
+
batch_size: 4 # test batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training
|
| 47 |
+
limit_batch: null # if null, run through test set. Otherwise limit the number of batches to use for test.
|
| 48 |
+
data:
|
| 49 |
+
num_workers: 4 # number of CPU threads for data preprocessing, for test.
|
| 50 |
+
shuffle: False # whether test data will be shuffled
|
configurations/experiment/exp_video.yaml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- base_pytorch
|
| 3 |
+
|
| 4 |
+
tasks: [training]
|
| 5 |
+
|
| 6 |
+
training:
|
| 7 |
+
lr: 2e-5
|
| 8 |
+
precision: 16-mixed
|
| 9 |
+
batch_size: 4
|
| 10 |
+
max_epochs: -1
|
| 11 |
+
max_steps: 2000005
|
| 12 |
+
checkpointing:
|
| 13 |
+
every_n_train_steps: 2500
|
| 14 |
+
optim:
|
| 15 |
+
gradient_clip_val: 1.0
|
| 16 |
+
|
| 17 |
+
validation:
|
| 18 |
+
val_every_n_step: 2500
|
| 19 |
+
val_every_n_epoch: null
|
| 20 |
+
batch_size: 4
|
| 21 |
+
limit_batch: 1
|
| 22 |
+
|
| 23 |
+
test:
|
| 24 |
+
limit_batch: 1
|
| 25 |
+
batch_size: 1
|
| 26 |
+
|
| 27 |
+
logging:
|
| 28 |
+
metrics:
|
| 29 |
+
# - fvd
|
| 30 |
+
# - fid
|
| 31 |
+
# - lpips
|
configurations/huggingface.yaml
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
n_tokens: 3
|
| 2 |
+
pose_cond_dim: 5
|
| 3 |
+
use_plucker: true
|
| 4 |
+
focal_length: 0.35
|
| 5 |
+
customized_validation: true
|
| 6 |
+
memory_condition_length: 8
|
| 7 |
+
log_video: true
|
| 8 |
+
relative_embedding: true
|
| 9 |
+
state_embed_only_on_qk: true
|
| 10 |
+
use_domain_adapter: false
|
| 11 |
+
use_memory_attention: true
|
| 12 |
+
add_timestamp_embedding: true
|
| 13 |
+
use_pose_prediction: true
|
| 14 |
+
require_pose_prediction: true
|
| 15 |
+
is_interactive: true
|
| 16 |
+
diffusion:
|
| 17 |
+
sampling_timesteps: 20
|
| 18 |
+
beta_schedule: sigmoid
|
| 19 |
+
objective: pred_v
|
| 20 |
+
use_fused_snr: True
|
| 21 |
+
cum_snr_decay: 0.96
|
| 22 |
+
clip_noise: 20.
|
| 23 |
+
ddim_sampling_eta: 0.0
|
| 24 |
+
stabilization_level: 15
|
| 25 |
+
schedule_fn_kwargs: {}
|
| 26 |
+
use_snr: False
|
| 27 |
+
use_cum_snr: False
|
| 28 |
+
snr_clip: 5.0
|
| 29 |
+
timesteps: 1000
|
| 30 |
+
# architecture
|
| 31 |
+
architecture:
|
| 32 |
+
network_size: 64
|
| 33 |
+
attn_heads: 4
|
| 34 |
+
attn_dim_head: 64
|
| 35 |
+
dim_mults: [1, 2, 4, 8]
|
| 36 |
+
resolution: ${dataset.resolution}
|
| 37 |
+
attn_resolutions: [16, 32, 64, 128]
|
| 38 |
+
use_init_temporal_attn: True
|
| 39 |
+
use_linear_attn: True
|
| 40 |
+
time_emb_type: rotary
|
| 41 |
+
|
| 42 |
+
weight_decay: 2e-3
|
| 43 |
+
warmup_steps: 10000
|
| 44 |
+
optimizer_beta: [0.9, 0.99]
|
| 45 |
+
action_cond_dim: 25
|
| 46 |
+
n_frames: 8
|
| 47 |
+
frame_skip: 1
|
| 48 |
+
frame_stack: 1
|
| 49 |
+
uncertainty_scale: 1
|
| 50 |
+
guidance_scale: 0.0
|
| 51 |
+
chunk_size: 1 # -1 for full trajectory diffusion, number to specify diffusion chunk size
|
| 52 |
+
scheduling_matrix: full_sequence
|
| 53 |
+
noise_level: random_all
|
| 54 |
+
causal: True
|
| 55 |
+
x_shape: [3, 360, 640]
|
| 56 |
+
context_frames: 1
|
| 57 |
+
diffusion_path: zeqixiao/worldmem_checkpoints/diffusion_only.ckpt
|
| 58 |
+
vae_path: zeqixiao/worldmem_checkpoints/vae_only.ckpt
|
| 59 |
+
pose_predictor_path: zeqixiao/worldmem_checkpoints/pose_prediction_model_only.ckpt
|
| 60 |
+
next_frame_length: 1
|
configurations/training.yaml
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# configuration parsing starts here
|
| 2 |
+
defaults:
|
| 3 |
+
- experiment: exp_video # experiment yaml file name in configurations/experiments folder [fixme]
|
| 4 |
+
- dataset: video_minecraft # dataset yaml file name in configurations/dataset folder [fixme]
|
| 5 |
+
- algorithm: df_video_worldmemminecraft # algorithm yaml file name in configurations/algorithm folder [fixme]
|
| 6 |
+
- cluster: null # optional, cluster yaml file name in configurations/cluster folder. Leave null for local compute
|
| 7 |
+
|
| 8 |
+
debug: false # global debug flag will be passed into configuration of experiment, dataset and algorithm
|
| 9 |
+
|
| 10 |
+
wandb:
|
| 11 |
+
entity: xizaoqu # wandb account name / organization name [fixme]
|
| 12 |
+
project: worldmem # wandb project name; if not provided, defaults to root folder name [fixme]
|
| 13 |
+
mode: online # set wandb logging to online, offline or dryrun
|
| 14 |
+
|
| 15 |
+
resume: null # wandb run id to resume logging and loading checkpoint from
|
| 16 |
+
load: null # wandb run id containing checkpoint or a path to a checkpoint file
|