Commit ·
b47a1ce
0
Parent(s):
Initial commit
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitignore +27 -0
- LICENSE.md +14 -0
- README.md +37 -0
- algorithms/__init__.py +0 -0
- algorithms/common/__init__.py +0 -0
- algorithms/common/base_algo.py +21 -0
- algorithms/common/base_pytorch_algo.py +277 -0
- algorithms/common/metrics/__init__.py +3 -0
- algorithms/common/metrics/fid.py +1 -0
- algorithms/common/metrics/fvd.py +158 -0
- algorithms/common/metrics/lpips.py +1 -0
- algorithms/worldmem/__init__.py +1 -0
- algorithms/worldmem/dememwm/__init__.py +18 -0
- algorithms/worldmem/dememwm/algorithm.py +0 -0
- algorithms/worldmem/dememwm/cache.py +513 -0
- algorithms/worldmem/dememwm/compression.py +260 -0
- algorithms/worldmem/dememwm/diagnostics.py +174 -0
- algorithms/worldmem/dememwm/gates.py +46 -0
- algorithms/worldmem/dememwm/injection.py +83 -0
- algorithms/worldmem/dememwm/labels.py +479 -0
- algorithms/worldmem/dememwm/memory.py +208 -0
- algorithms/worldmem/dememwm/negatives.py +41 -0
- algorithms/worldmem/dememwm/retrieval.py +476 -0
- algorithms/worldmem/dememwm/schedules.py +223 -0
- algorithms/worldmem/dememwm/types.py +98 -0
- algorithms/worldmem/dememwm_memory_dit.py +18 -0
- algorithms/worldmem/df_base.py +307 -0
- algorithms/worldmem/df_video.py +926 -0
- algorithms/worldmem/models/attention.py +342 -0
- algorithms/worldmem/models/diffusion.py +594 -0
- algorithms/worldmem/models/dit.py +899 -0
- algorithms/worldmem/models/pose_prediction.py +42 -0
- algorithms/worldmem/models/rotary_embedding_torch.py +302 -0
- algorithms/worldmem/models/utils.py +163 -0
- algorithms/worldmem/models/vae.py +359 -0
- configurations/algorithm/base_algo.yaml +3 -0
- configurations/algorithm/base_pytorch_algo.yaml +4 -0
- configurations/algorithm/base_video_dit.yaml +36 -0
- configurations/algorithm/dememwm_memory_dit.yaml +103 -0
- configurations/algorithm/df_base.yaml +42 -0
- configurations/dataset/base_dataset.yaml +3 -0
- configurations/dataset/base_video.yaml +14 -0
- configurations/dataset/video_minecraft.yaml +14 -0
- configurations/dataset/video_minecraft_latent.yaml +6 -0
- configurations/experiment/base_experiment.yaml +2 -0
- configurations/experiment/base_pytorch.yaml +51 -0
- configurations/experiment/exp_video.yaml +31 -0
- configurations/training.yaml +18 -0
- datasets/__init__.py +1 -0
- datasets/video/__init__.py +2 -0
.gitignore
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
*$py.class
|
| 4 |
+
|
| 5 |
+
.pytest_cache/
|
| 6 |
+
.mypy_cache/
|
| 7 |
+
.ruff_cache/
|
| 8 |
+
.ipynb_checkpoints/
|
| 9 |
+
|
| 10 |
+
.env
|
| 11 |
+
.venv/
|
| 12 |
+
venv/
|
| 13 |
+
|
| 14 |
+
outputs/
|
| 15 |
+
slurm_logs/
|
| 16 |
+
latest-run
|
| 17 |
+
.wandb_run_id
|
| 18 |
+
.wandb_osh_command_dir/
|
| 19 |
+
wandb/
|
| 20 |
+
|
| 21 |
+
*.log
|
| 22 |
+
*.ckpt
|
| 23 |
+
*.pt
|
| 24 |
+
*.pth
|
| 25 |
+
*.safetensors
|
| 26 |
+
|
| 27 |
+
data/
|
LICENSE.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# S-Lab License 1.0
|
| 2 |
+
|
| 3 |
+
Copyright 2025 S-Lab
|
| 4 |
+
|
| 5 |
+
Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
| 6 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
| 7 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
| 8 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\
|
| 9 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 10 |
+
4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
---
|
| 14 |
+
For the commercial use of the code, please consult Prof. Chen Change Loy (ccloy@ntu.edu.sg)
|
README.md
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DeMemWM
|
| 2 |
+
|
| 3 |
+
This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep this sentence in `README.md` and the `LICENSE` file to credit the author.
|
| 4 |
+
|
| 5 |
+
DeMemWM is a Memory-DiT video prediction project built on the local research template. The primary algorithm entry point is `DeMemWMMinecraft`, registered through the Hydra algorithm config `dememwm_memory_dit`.
|
| 6 |
+
|
| 7 |
+
## Quick Start
|
| 8 |
+
|
| 9 |
+
```bash
|
| 10 |
+
python -m venv .venv
|
| 11 |
+
source .venv/bin/activate
|
| 12 |
+
pip install -r requirements.txt
|
| 13 |
+
python -m pytest tests
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
Run a local offline experiment after setting the dataset path in `configurations/dataset/video_minecraft.yaml`:
|
| 17 |
+
|
| 18 |
+
```bash
|
| 19 |
+
python main.py +name=dememwm_debug algorithm=dememwm_memory_dit wandb.mode=offline
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
Use `resume_ckpt_path=/path/to/checkpoint.ckpt` for deterministic checkpoint resume, or keep `auto_resume=true` to resume from `output_dir/checkpoints` when available.
|
| 23 |
+
|
| 24 |
+
## Layout
|
| 25 |
+
|
| 26 |
+
- `algorithms/worldmem/dememwm/`: DeMemWM memory construction, retrieval, scheduling, diagnostics, and injection code.
|
| 27 |
+
- `algorithms/worldmem/dememwm_memory_dit.py`: primary DeMemWM algorithm class.
|
| 28 |
+
- `configurations/algorithm/dememwm_memory_dit.yaml`: consumed DeMemWM training and evaluation contract.
|
| 29 |
+
- `scripts/`: Slurm and inspection scripts using the DeMemWM naming.
|
| 30 |
+
- `tests/`: static and unit coverage for DeMemWM config, retrieval, compression, schedules, and training behavior.
|
| 31 |
+
|
| 32 |
+
## Reproducibility Notes
|
| 33 |
+
|
| 34 |
+
- Keep `wandb.mode=offline` for local reproducible runs that do not depend on network access.
|
| 35 |
+
- Set `seed=<int>` on the command line to seed Lightning and dataloader workers.
|
| 36 |
+
- Runtime artifacts such as `outputs/`, `slurm_logs/`, Python caches, checkpoints, and local datasets are ignored by git.
|
| 37 |
+
- The default Hydra training config selects `algorithm: dememwm_memory_dit`.
|
algorithms/__init__.py
ADDED
|
File without changes
|
algorithms/common/__init__.py
ADDED
|
File without changes
|
algorithms/common/base_algo.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
from omegaconf import DictConfig
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class BaseAlgo(ABC):
|
| 8 |
+
"""
|
| 9 |
+
A base class for generic algorithms.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
def __init__(self, cfg: DictConfig):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.cfg = cfg
|
| 15 |
+
|
| 16 |
+
@abstractmethod
|
| 17 |
+
def run(*args: Any, **kwargs: Any) -> Any:
|
| 18 |
+
"""
|
| 19 |
+
Run the algorithm.
|
| 20 |
+
"""
|
| 21 |
+
raise NotImplementedError
|
algorithms/common/base_pytorch_algo.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
import warnings
|
| 3 |
+
import random
|
| 4 |
+
from typing import Any, Union, Sequence, Optional
|
| 5 |
+
|
| 6 |
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
| 7 |
+
from omegaconf import DictConfig
|
| 8 |
+
import lightning.pytorch as pl
|
| 9 |
+
import torch
|
| 10 |
+
import numpy as np
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import wandb
|
| 13 |
+
import einops
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class BasePytorchAlgo(pl.LightningModule, ABC):
|
| 17 |
+
"""
|
| 18 |
+
A base class for Pytorch algorithms using Pytorch Lightning.
|
| 19 |
+
See https://lightning.ai/docs/pytorch/stable/starter/introduction.html for more details.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, cfg: DictConfig):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.cfg = cfg
|
| 25 |
+
self._build_model()
|
| 26 |
+
|
| 27 |
+
@abstractmethod
|
| 28 |
+
def _build_model(self):
|
| 29 |
+
"""
|
| 30 |
+
Create all pytorch nn.Modules here.
|
| 31 |
+
"""
|
| 32 |
+
raise NotImplementedError
|
| 33 |
+
|
| 34 |
+
@abstractmethod
|
| 35 |
+
def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
|
| 36 |
+
r"""Here you compute and return the training loss and some additional metrics for e.g. the progress bar or
|
| 37 |
+
logger.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
|
| 41 |
+
batch_idx: The index of this batch.
|
| 42 |
+
dataloader_idx: (only if multiple dataloaders used) The index of the dataloader that produced this batch.
|
| 43 |
+
|
| 44 |
+
Return:
|
| 45 |
+
Any of these options:
|
| 46 |
+
- :class:`~torch.Tensor` - The loss tensor
|
| 47 |
+
- ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
|
| 48 |
+
- ``None`` - Skip to the next batch. This is only supported for automatic optimization.
|
| 49 |
+
This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.
|
| 50 |
+
|
| 51 |
+
In this step you'd normally do the forward pass and calculate the loss for a batch.
|
| 52 |
+
You can also do fancier things like multiple forward passes or something model specific.
|
| 53 |
+
|
| 54 |
+
Example::
|
| 55 |
+
|
| 56 |
+
def training_step(self, batch, batch_idx):
|
| 57 |
+
x, y, z = batch
|
| 58 |
+
out = self.encoder(x)
|
| 59 |
+
loss = self.loss(out, x)
|
| 60 |
+
return loss
|
| 61 |
+
|
| 62 |
+
To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:
|
| 63 |
+
|
| 64 |
+
.. code-block:: python
|
| 65 |
+
|
| 66 |
+
def __init__(self):
|
| 67 |
+
super().__init__()
|
| 68 |
+
self.automatic_optimization = False
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# Multiple optimizers (e.g.: GANs)
|
| 72 |
+
def training_step(self, batch, batch_idx):
|
| 73 |
+
opt1, opt2 = self.optimizers()
|
| 74 |
+
|
| 75 |
+
# do training_step with encoder
|
| 76 |
+
...
|
| 77 |
+
opt1.step()
|
| 78 |
+
# do training_step with decoder
|
| 79 |
+
...
|
| 80 |
+
opt2.step()
|
| 81 |
+
|
| 82 |
+
Note:
|
| 83 |
+
When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
|
| 84 |
+
normalized by ``accumulate_grad_batches`` internally.
|
| 85 |
+
|
| 86 |
+
"""
|
| 87 |
+
return super().training_step(*args, **kwargs)
|
| 88 |
+
|
| 89 |
+
def configure_optimizers(self):
|
| 90 |
+
"""
|
| 91 |
+
Return an optimizer. If you need to use more than one optimizer, refer to pytorch lightning documentation:
|
| 92 |
+
https://lightning.ai/docs/pytorch/stable/common/optimization.html
|
| 93 |
+
"""
|
| 94 |
+
parameters = self.parameters()
|
| 95 |
+
return torch.optim.Adam(parameters, lr=self.cfg.lr)
|
| 96 |
+
|
| 97 |
+
def on_save_checkpoint(self, checkpoint):
|
| 98 |
+
checkpoint["rng_states"] = {
|
| 99 |
+
"python": random.getstate(),
|
| 100 |
+
"numpy": np.random.get_state(),
|
| 101 |
+
"torch": torch.get_rng_state(),
|
| 102 |
+
"cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
def on_load_checkpoint(self, checkpoint):
|
| 106 |
+
rng_states = checkpoint.get("rng_states")
|
| 107 |
+
if rng_states is None:
|
| 108 |
+
if getattr(self, "_strict_resume_state", False):
|
| 109 |
+
raise RuntimeError(
|
| 110 |
+
"Cannot deterministically resume because this checkpoint has no rng_states entry. "
|
| 111 |
+
"Use a checkpoint created after automatic resume support was added, or start a fresh run."
|
| 112 |
+
)
|
| 113 |
+
return
|
| 114 |
+
|
| 115 |
+
random.setstate(rng_states["python"])
|
| 116 |
+
np.random.set_state(rng_states["numpy"])
|
| 117 |
+
torch.set_rng_state(rng_states["torch"])
|
| 118 |
+
if torch.cuda.is_available() and rng_states["cuda"] is not None:
|
| 119 |
+
torch.cuda.set_rng_state_all(rng_states["cuda"])
|
| 120 |
+
|
| 121 |
+
def log_video(
|
| 122 |
+
self,
|
| 123 |
+
key: str,
|
| 124 |
+
video: Union[np.ndarray, torch.Tensor],
|
| 125 |
+
mean: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
|
| 126 |
+
std: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
|
| 127 |
+
fps: int = 5,
|
| 128 |
+
format: str = "mp4",
|
| 129 |
+
):
|
| 130 |
+
"""
|
| 131 |
+
Log video to wandb. WandbLogger in pytorch lightning does not support video logging yet, so we call wandb directly.
|
| 132 |
+
|
| 133 |
+
Args:
|
| 134 |
+
video: a numpy array or tensor, either in form (time, channel, height, width) or in the form
|
| 135 |
+
(batch, time, channel, height, width). The content must be be in 0-255 if under dtype uint8
|
| 136 |
+
or [0, 1] otherwise.
|
| 137 |
+
mean: optional, the mean to unnormalize video tensor, assuming unnormalized data is in [0, 1].
|
| 138 |
+
std: optional, the std to unnormalize video tensor, assuming unnormalized data is in [0, 1].
|
| 139 |
+
key: the name of the video.
|
| 140 |
+
fps: the frame rate of the video.
|
| 141 |
+
format: the format of the video. Can be either "mp4" or "gif".
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
if isinstance(video, torch.Tensor):
|
| 145 |
+
video = video.detach().cpu().numpy()
|
| 146 |
+
|
| 147 |
+
expand_shape = [1] * (len(video.shape) - 2) + [3, 1, 1]
|
| 148 |
+
if std is not None:
|
| 149 |
+
if isinstance(std, (float, int)):
|
| 150 |
+
std = [std] * 3
|
| 151 |
+
if isinstance(std, torch.Tensor):
|
| 152 |
+
std = std.detach().cpu().numpy()
|
| 153 |
+
std = np.array(std).reshape(*expand_shape)
|
| 154 |
+
video = video * std
|
| 155 |
+
if mean is not None:
|
| 156 |
+
if isinstance(mean, (float, int)):
|
| 157 |
+
mean = [mean] * 3
|
| 158 |
+
if isinstance(mean, torch.Tensor):
|
| 159 |
+
mean = mean.detach().cpu().numpy()
|
| 160 |
+
mean = np.array(mean).reshape(*expand_shape)
|
| 161 |
+
video = video + mean
|
| 162 |
+
|
| 163 |
+
if video.dtype != np.uint8:
|
| 164 |
+
video = np.clip(video, a_min=0, a_max=1) * 255
|
| 165 |
+
video = video.astype(np.uint8)
|
| 166 |
+
|
| 167 |
+
self.logger.experiment.log(
|
| 168 |
+
{
|
| 169 |
+
key: wandb.Video(video, fps=fps, format=format),
|
| 170 |
+
},
|
| 171 |
+
step=self.global_step,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
def log_image(
|
| 175 |
+
self,
|
| 176 |
+
key: str,
|
| 177 |
+
image: Union[np.ndarray, torch.Tensor, Image.Image, Sequence[Image.Image]],
|
| 178 |
+
mean: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
|
| 179 |
+
std: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
|
| 180 |
+
**kwargs: Any,
|
| 181 |
+
):
|
| 182 |
+
"""
|
| 183 |
+
Log image(s) using WandbLogger.
|
| 184 |
+
Args:
|
| 185 |
+
key: the name of the video.
|
| 186 |
+
image: a single image or a batch of images. If a batch of images, the shape should be (batch, channel, height, width).
|
| 187 |
+
mean: optional, the mean to unnormalize image tensor, assuming unnormalized data is in [0, 1].
|
| 188 |
+
std: optional, the std to unnormalize tensor, assuming unnormalized data is in [0, 1].
|
| 189 |
+
kwargs: optional, WandbLogger log_image kwargs, such as captions=xxx.
|
| 190 |
+
"""
|
| 191 |
+
if isinstance(image, Image.Image):
|
| 192 |
+
image = [image]
|
| 193 |
+
elif len(image) and not isinstance(image[0], Image.Image):
|
| 194 |
+
if isinstance(image, torch.Tensor):
|
| 195 |
+
image = image.detach().cpu().numpy()
|
| 196 |
+
|
| 197 |
+
if len(image.shape) == 3:
|
| 198 |
+
image = image[None]
|
| 199 |
+
|
| 200 |
+
if image.shape[1] == 3:
|
| 201 |
+
if image.shape[-1] == 3:
|
| 202 |
+
warnings.warn(f"Two channels in shape {image.shape} have size 3, assuming channel first.")
|
| 203 |
+
image = einops.rearrange(image, "b c h w -> b h w c")
|
| 204 |
+
|
| 205 |
+
if std is not None:
|
| 206 |
+
if isinstance(std, (float, int)):
|
| 207 |
+
std = [std] * 3
|
| 208 |
+
if isinstance(std, torch.Tensor):
|
| 209 |
+
std = std.detach().cpu().numpy()
|
| 210 |
+
std = np.array(std)[None, None, None]
|
| 211 |
+
image = image * std
|
| 212 |
+
if mean is not None:
|
| 213 |
+
if isinstance(mean, (float, int)):
|
| 214 |
+
mean = [mean] * 3
|
| 215 |
+
if isinstance(mean, torch.Tensor):
|
| 216 |
+
mean = mean.detach().cpu().numpy()
|
| 217 |
+
mean = np.array(mean)[None, None, None]
|
| 218 |
+
image = image + mean
|
| 219 |
+
|
| 220 |
+
if image.dtype != np.uint8:
|
| 221 |
+
image = np.clip(image, a_min=0.0, a_max=1.0) * 255
|
| 222 |
+
image = image.astype(np.uint8)
|
| 223 |
+
image = [img for img in image]
|
| 224 |
+
|
| 225 |
+
self.logger.log_image(key=key, images=image, **kwargs)
|
| 226 |
+
|
| 227 |
+
def log_gradient_stats(self):
|
| 228 |
+
"""Log gradient statistics such as the mean or std of norm."""
|
| 229 |
+
|
| 230 |
+
with torch.no_grad():
|
| 231 |
+
grad_norms = []
|
| 232 |
+
gpr = [] # gradient-to-parameter ratio
|
| 233 |
+
for param in self.parameters():
|
| 234 |
+
if param.grad is not None:
|
| 235 |
+
grad_norms.append(torch.norm(param.grad).item())
|
| 236 |
+
gpr.append(torch.norm(param.grad) / torch.norm(param))
|
| 237 |
+
if len(grad_norms) == 0:
|
| 238 |
+
return
|
| 239 |
+
grad_norms = torch.tensor(grad_norms)
|
| 240 |
+
gpr = torch.tensor(gpr)
|
| 241 |
+
self.log_dict(
|
| 242 |
+
{
|
| 243 |
+
"train/grad_norm/min": grad_norms.min(),
|
| 244 |
+
"train/grad_norm/max": grad_norms.max(),
|
| 245 |
+
"train/grad_norm/std": grad_norms.std(),
|
| 246 |
+
"train/grad_norm/mean": grad_norms.mean(),
|
| 247 |
+
"train/grad_norm/median": torch.median(grad_norms),
|
| 248 |
+
"train/gpr/min": gpr.min(),
|
| 249 |
+
"train/gpr/max": gpr.max(),
|
| 250 |
+
"train/gpr/std": gpr.std(),
|
| 251 |
+
"train/gpr/mean": gpr.mean(),
|
| 252 |
+
"train/gpr/median": torch.median(gpr),
|
| 253 |
+
}
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
def register_data_mean_std(
|
| 257 |
+
self, mean: Union[str, float, Sequence], std: Union[str, float, Sequence], namespace: str = "data"
|
| 258 |
+
):
|
| 259 |
+
"""
|
| 260 |
+
Register mean and std of data as tensor buffer.
|
| 261 |
+
|
| 262 |
+
Args:
|
| 263 |
+
mean: the mean of data.
|
| 264 |
+
std: the std of data.
|
| 265 |
+
namespace: the namespace of the registered buffer.
|
| 266 |
+
"""
|
| 267 |
+
for k, v in [("mean", mean), ("std", std)]:
|
| 268 |
+
if isinstance(v, str):
|
| 269 |
+
if v.endswith(".npy"):
|
| 270 |
+
v = torch.from_numpy(np.load(v))
|
| 271 |
+
elif v.endswith(".pt"):
|
| 272 |
+
v = torch.load(v)
|
| 273 |
+
else:
|
| 274 |
+
raise ValueError(f"Unsupported file type {v.split('.')[-1]}.")
|
| 275 |
+
else:
|
| 276 |
+
v = torch.tensor(v)
|
| 277 |
+
self.register_buffer(f"{namespace}_{k}", v.float().to(self.device))
|
algorithms/common/metrics/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .fid import FrechetInceptionDistance
|
| 2 |
+
from .lpips import LearnedPerceptualImagePatchSimilarity
|
| 3 |
+
from .fvd import FrechetVideoDistance
|
algorithms/common/metrics/fid.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from torchmetrics.image.fid import FrechetInceptionDistance
|
algorithms/common/metrics/fvd.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adopted from https://github.com/cvpr2022-stylegan-v/stylegan-v
|
| 3 |
+
Verified to be the same as tf version by https://github.com/universome/fvd-comparison
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import io
|
| 7 |
+
import re
|
| 8 |
+
import requests
|
| 9 |
+
import html
|
| 10 |
+
import hashlib
|
| 11 |
+
import urllib
|
| 12 |
+
import urllib.request
|
| 13 |
+
from typing import Any, List, Tuple, Union, Dict
|
| 14 |
+
import scipy
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def open_url(
|
| 22 |
+
url: str,
|
| 23 |
+
num_attempts: int = 10,
|
| 24 |
+
verbose: bool = True,
|
| 25 |
+
return_filename: bool = False,
|
| 26 |
+
) -> Any:
|
| 27 |
+
"""Download the given URL and return a binary-mode file object to access the data."""
|
| 28 |
+
assert num_attempts >= 1
|
| 29 |
+
|
| 30 |
+
# Doesn't look like an URL scheme so interpret it as a local filename.
|
| 31 |
+
if not re.match("^[a-z]+://", url):
|
| 32 |
+
return url if return_filename else open(url, "rb")
|
| 33 |
+
|
| 34 |
+
# Handle file URLs. This code handles unusual file:// patterns that
|
| 35 |
+
# arise on Windows:
|
| 36 |
+
#
|
| 37 |
+
# file:///c:/foo.txt
|
| 38 |
+
#
|
| 39 |
+
# which would translate to a local '/c:/foo.txt' filename that's
|
| 40 |
+
# invalid. Drop the forward slash for such pathnames.
|
| 41 |
+
#
|
| 42 |
+
# If you touch this code path, you should test it on both Linux and
|
| 43 |
+
# Windows.
|
| 44 |
+
#
|
| 45 |
+
# Some internet resources suggest using urllib.request.url2pathname() but
|
| 46 |
+
# but that converts forward slashes to backslashes and this causes
|
| 47 |
+
# its own set of problems.
|
| 48 |
+
if url.startswith("file://"):
|
| 49 |
+
filename = urllib.parse.urlparse(url).path
|
| 50 |
+
if re.match(r"^/[a-zA-Z]:", filename):
|
| 51 |
+
filename = filename[1:]
|
| 52 |
+
return filename if return_filename else open(filename, "rb")
|
| 53 |
+
|
| 54 |
+
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
|
| 55 |
+
|
| 56 |
+
# Download.
|
| 57 |
+
url_name = None
|
| 58 |
+
url_data = None
|
| 59 |
+
with requests.Session() as session:
|
| 60 |
+
if verbose:
|
| 61 |
+
print("Downloading %s ..." % url, end="", flush=True)
|
| 62 |
+
for attempts_left in reversed(range(num_attempts)):
|
| 63 |
+
try:
|
| 64 |
+
with session.get(url) as res:
|
| 65 |
+
res.raise_for_status()
|
| 66 |
+
if len(res.content) == 0:
|
| 67 |
+
raise IOError("No data received")
|
| 68 |
+
|
| 69 |
+
if len(res.content) < 8192:
|
| 70 |
+
content_str = res.content.decode("utf-8")
|
| 71 |
+
if "download_warning" in res.headers.get("Set-Cookie", ""):
|
| 72 |
+
links = [
|
| 73 |
+
html.unescape(link)
|
| 74 |
+
for link in content_str.split('"')
|
| 75 |
+
if "export=download" in link
|
| 76 |
+
]
|
| 77 |
+
if len(links) == 1:
|
| 78 |
+
url = requests.compat.urljoin(url, links[0])
|
| 79 |
+
raise IOError("Google Drive virus checker nag")
|
| 80 |
+
if "Google Drive - Quota exceeded" in content_str:
|
| 81 |
+
raise IOError(
|
| 82 |
+
"Google Drive download quota exceeded -- please try again later"
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
match = re.search(
|
| 86 |
+
r'filename="([^"]*)"',
|
| 87 |
+
res.headers.get("Content-Disposition", ""),
|
| 88 |
+
)
|
| 89 |
+
url_name = match[1] if match else url
|
| 90 |
+
url_data = res.content
|
| 91 |
+
if verbose:
|
| 92 |
+
print(" done")
|
| 93 |
+
break
|
| 94 |
+
except KeyboardInterrupt:
|
| 95 |
+
raise
|
| 96 |
+
except:
|
| 97 |
+
if not attempts_left:
|
| 98 |
+
if verbose:
|
| 99 |
+
print(" failed")
|
| 100 |
+
raise
|
| 101 |
+
if verbose:
|
| 102 |
+
print(".", end="", flush=True)
|
| 103 |
+
|
| 104 |
+
# Return data as file object.
|
| 105 |
+
assert not return_filename
|
| 106 |
+
return io.BytesIO(url_data)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def compute_fvd(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
|
| 110 |
+
mu_gen, sigma_gen = compute_stats(feats_fake)
|
| 111 |
+
mu_real, sigma_real = compute_stats(feats_real)
|
| 112 |
+
|
| 113 |
+
m = np.square(mu_gen - mu_real).sum()
|
| 114 |
+
s, _ = scipy.linalg.sqrtm(
|
| 115 |
+
np.dot(sigma_gen, sigma_real), disp=False
|
| 116 |
+
) # pylint: disable=no-member
|
| 117 |
+
fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
|
| 118 |
+
|
| 119 |
+
return float(fid)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 123 |
+
mu = feats.mean(axis=0) # [d]
|
| 124 |
+
sigma = np.cov(feats, rowvar=False) # [d, d]
|
| 125 |
+
|
| 126 |
+
return mu, sigma
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class FrechetVideoDistance(nn.Module):
|
| 130 |
+
def __init__(self):
|
| 131 |
+
super().__init__()
|
| 132 |
+
detector_url = (
|
| 133 |
+
"https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1"
|
| 134 |
+
)
|
| 135 |
+
# Return raw features before the softmax layer.
|
| 136 |
+
self.detector_kwargs = dict(rescale=False, resize=True, return_features=True)
|
| 137 |
+
with open_url(detector_url, verbose=False) as f:
|
| 138 |
+
self.detector = torch.jit.load(f).eval()
|
| 139 |
+
|
| 140 |
+
@torch.no_grad()
|
| 141 |
+
def compute(self, videos_fake: torch.Tensor, videos_real: torch.Tensor):
|
| 142 |
+
"""
|
| 143 |
+
:param videos_fake: predicted video tensor of shape (frame, batch, channel, height, width)
|
| 144 |
+
:param videos_real: ground-truth observation tensor of shape (frame, batch, channel, height, width)
|
| 145 |
+
:return:
|
| 146 |
+
"""
|
| 147 |
+
n_frames, batch_size, c, h, w = videos_fake.shape
|
| 148 |
+
if n_frames < 2:
|
| 149 |
+
raise ValueError("Video must have more than 1 frame for FVD")
|
| 150 |
+
|
| 151 |
+
videos_fake = videos_fake.permute(1, 2, 0, 3, 4).contiguous()
|
| 152 |
+
videos_real = videos_real.permute(1, 2, 0, 3, 4).contiguous()
|
| 153 |
+
|
| 154 |
+
# detector takes in tensors of shape [batch_size, c, video_len, h, w] with range -1 to 1
|
| 155 |
+
feats_fake = self.detector(videos_fake, **self.detector_kwargs).cpu().numpy()
|
| 156 |
+
feats_real = self.detector(videos_real, **self.detector_kwargs).cpu().numpy()
|
| 157 |
+
|
| 158 |
+
return compute_fvd(feats_fake, feats_real)
|
algorithms/common/metrics/lpips.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
algorithms/worldmem/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .dememwm_memory_dit import DeMemWMMinecraft, DeMemWMMemoryDiTMinecraft
|
algorithms/worldmem/dememwm/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from .types import MemoryRecord, MemorySourceType, MemoryStreamTensors, RevisitRetrievalResult, StreamGateState
|
| 3 |
+
from .memory import CausalMemoryBank, MemoryBankQuery, stack_record_tokens
|
| 4 |
+
from .compression import CausalConv3DDynamicCompressor, latent_patch_tokens, spatial_pool_tokens
|
| 5 |
+
from .retrieval import deterministic_revisit_retrieval
|
| 6 |
+
from .schedules import compute_stream_gates, CurriculumState, resolve_curriculum, DeMemWMCurriculumState, resolve_dememwm_curriculum
|
| 7 |
+
from .gates import RevisitRawGate
|
| 8 |
+
from .cache import StreamingCache, DeMemWMStreamingCache
|
| 9 |
+
from .injection import InjectionAdapter, DeMemWMInjectionAdapter
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
"MemoryRecord", "MemorySourceType", "MemoryStreamTensors", "RevisitRetrievalResult", "StreamGateState",
|
| 13 |
+
"CausalMemoryBank", "MemoryBankQuery", "stack_record_tokens",
|
| 14 |
+
"CausalConv3DDynamicCompressor", "latent_patch_tokens", "spatial_pool_tokens",
|
| 15 |
+
"deterministic_revisit_retrieval", "compute_stream_gates", "CurriculumState", "resolve_curriculum",
|
| 16 |
+
"DeMemWMCurriculumState", "resolve_dememwm_curriculum", "RevisitRawGate",
|
| 17 |
+
"StreamingCache", "DeMemWMStreamingCache", "InjectionAdapter", "DeMemWMInjectionAdapter",
|
| 18 |
+
]
|
algorithms/worldmem/dememwm/algorithm.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
algorithms/worldmem/dememwm/cache.py
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any, Iterable, Optional
|
| 6 |
+
import warnings
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from .memory import CausalMemoryBank
|
| 11 |
+
from .types import MemoryRecord
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class _RawLatentSegment:
|
| 16 |
+
latents: torch.Tensor
|
| 17 |
+
frame_indices: torch.Tensor
|
| 18 |
+
source_is_generated: torch.Tensor
|
| 19 |
+
pose: Optional[torch.Tensor]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class StreamingCache:
|
| 23 |
+
"""Per-video DeMemWM streaming cache with strict no-eviction semantics.
|
| 24 |
+
|
| 25 |
+
The cache is intentionally allowed to grow for the current video. It stores
|
| 26 |
+
detached CPU (or pinned CPU) raw latents plus compressed MemoryRecord objects,
|
| 27 |
+
while DiT readout tensors remain bounded by the caller's manual budgets.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
*,
|
| 33 |
+
enabled: bool = True,
|
| 34 |
+
device: str = "cpu",
|
| 35 |
+
keep_raw_latents: str = "all",
|
| 36 |
+
keep_compressed_records: bool = True,
|
| 37 |
+
keep_prefix_anchors: bool = True,
|
| 38 |
+
eviction_policy: str = "none",
|
| 39 |
+
no_evict: bool = True,
|
| 40 |
+
clear_between_videos: bool = True,
|
| 41 |
+
max_records: Optional[int] = None,
|
| 42 |
+
max_slots: Optional[int] = None,
|
| 43 |
+
on_capacity_exceeded: str = "warn",
|
| 44 |
+
) -> None:
|
| 45 |
+
self.enabled = bool(enabled)
|
| 46 |
+
self.device = str(device or "cpu")
|
| 47 |
+
self.keep_raw_latents = keep_raw_latents
|
| 48 |
+
self.keep_compressed_records = bool(keep_compressed_records)
|
| 49 |
+
self.keep_prefix_anchors = bool(keep_prefix_anchors)
|
| 50 |
+
self.eviction_policy = str(eviction_policy or "none")
|
| 51 |
+
self.no_evict = bool(no_evict)
|
| 52 |
+
self.clear_between_videos = bool(clear_between_videos)
|
| 53 |
+
self.max_records = max_records
|
| 54 |
+
self.max_slots = max_slots
|
| 55 |
+
self.on_capacity_exceeded = str(on_capacity_exceeded or "warn")
|
| 56 |
+
if self.eviction_policy != "none" or not self.no_evict:
|
| 57 |
+
raise ValueError("DeMemWMStreamingCache only supports eviction_policy='none' with no_evict=true")
|
| 58 |
+
if self.device not in {"cpu", "pinned_cpu", "cuda"}:
|
| 59 |
+
raise ValueError("cache.device must be one of: cpu, pinned_cpu, cuda")
|
| 60 |
+
self.reset_count = 0
|
| 61 |
+
self.evictions = 0
|
| 62 |
+
self.capacity_exceeded_count = 0
|
| 63 |
+
self.current_video_id: Any = None
|
| 64 |
+
self._raw_segments: list[_RawLatentSegment] = []
|
| 65 |
+
self._records: dict[str, dict[int, list[MemoryRecord]]] = {"anchor": {}, "revisit": {}}
|
| 66 |
+
self._raw_keys: set[tuple[int, int]] = set()
|
| 67 |
+
self._raw_index: dict[tuple[int, int], tuple[int, int]] = {}
|
| 68 |
+
self._record_keys: set[tuple[str, int, str, int, int, bool]] = set()
|
| 69 |
+
self._batch_size: Optional[int] = None
|
| 70 |
+
# Concat cache: avoids repeated torch.cat across DDIM steps within one chunk.
|
| 71 |
+
# Invalidated whenever new raw segments are added.
|
| 72 |
+
self._raw_concat_version: int = 0
|
| 73 |
+
self._raw_concat_built: int = -1
|
| 74 |
+
self._raw_concat_cache: Optional[tuple] = None # (latents, frame_indices, generated, pose)
|
| 75 |
+
# GPU memory-bank cache: avoids repeated CPU→GPU record transfers across DDIM steps.
|
| 76 |
+
# Invalidated whenever new records are added.
|
| 77 |
+
self._banks_version: int = 0
|
| 78 |
+
self._banks_built_cache: dict[tuple, tuple[int, list[CausalMemoryBank]]] = {}
|
| 79 |
+
|
| 80 |
+
@classmethod
|
| 81 |
+
def from_config(cls, cfg: Any, *, enabled_default: bool = True) -> "StreamingCache":
|
| 82 |
+
def get(name: str, default: Any) -> Any:
|
| 83 |
+
return getattr(cfg, name, default) if cfg is not None else default
|
| 84 |
+
|
| 85 |
+
return cls(
|
| 86 |
+
enabled=bool(get("enabled", enabled_default)),
|
| 87 |
+
device=str(get("device", "cpu")),
|
| 88 |
+
keep_raw_latents=str(get("keep_raw_latents", "all")),
|
| 89 |
+
keep_compressed_records=bool(get("keep_compressed_records", True)),
|
| 90 |
+
keep_prefix_anchors=bool(get("keep_prefix_anchors", True)),
|
| 91 |
+
eviction_policy=str(get("eviction_policy", "none")),
|
| 92 |
+
no_evict=bool(get("no_evict", True)),
|
| 93 |
+
clear_between_videos=bool(get("clear_between_videos", True)),
|
| 94 |
+
max_records=get("max_records", None),
|
| 95 |
+
max_slots=get("max_slots", None),
|
| 96 |
+
on_capacity_exceeded=str(get("on_capacity_exceeded", "warn")),
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
@property
|
| 100 |
+
def batch_size(self) -> int:
|
| 101 |
+
return int(self._batch_size or 0)
|
| 102 |
+
|
| 103 |
+
@property
|
| 104 |
+
def raw_segment_count(self) -> int:
|
| 105 |
+
return len(self._raw_segments)
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def raw_frame_slots(self) -> int:
|
| 109 |
+
return sum(int(seg.latents.shape[0] * seg.latents.shape[1]) for seg in self._raw_segments)
|
| 110 |
+
|
| 111 |
+
@property
|
| 112 |
+
def record_count(self) -> int:
|
| 113 |
+
return sum(len(records) for by_batch in self._records.values() for records in by_batch.values())
|
| 114 |
+
|
| 115 |
+
@property
|
| 116 |
+
def slot_count(self) -> int:
|
| 117 |
+
return sum(record.valid_slots for by_batch in self._records.values() for records in by_batch.values() for record in records)
|
| 118 |
+
|
| 119 |
+
def records_count(self, kind: str | None = None) -> int:
|
| 120 |
+
if kind is None:
|
| 121 |
+
return self.record_count
|
| 122 |
+
return sum(len(records) for records in self._records.get(kind, {}).values())
|
| 123 |
+
|
| 124 |
+
def reset(self, video_id: Any = None) -> None:
|
| 125 |
+
self.current_video_id = video_id
|
| 126 |
+
self._raw_segments.clear()
|
| 127 |
+
self._records = {"anchor": {}, "revisit": {}}
|
| 128 |
+
self._raw_keys.clear()
|
| 129 |
+
self._raw_index.clear()
|
| 130 |
+
self._record_keys.clear()
|
| 131 |
+
self._batch_size = None
|
| 132 |
+
self.evictions = 0
|
| 133 |
+
self.capacity_exceeded_count = 0
|
| 134 |
+
self.reset_count += 1
|
| 135 |
+
self._raw_concat_version += 1
|
| 136 |
+
self._raw_concat_built = -1
|
| 137 |
+
self._raw_concat_cache = None
|
| 138 |
+
self._banks_version += 1
|
| 139 |
+
self._banks_built_cache.clear()
|
| 140 |
+
|
| 141 |
+
def _store_tensor(self, tensor: Optional[torch.Tensor], *, dtype: torch.dtype | None = None) -> Optional[torch.Tensor]:
|
| 142 |
+
if tensor is None:
|
| 143 |
+
return None
|
| 144 |
+
out = tensor.detach()
|
| 145 |
+
if dtype is not None and out.is_floating_point():
|
| 146 |
+
out = out.to(dtype=dtype)
|
| 147 |
+
if self.device in {"cpu", "pinned_cpu"}:
|
| 148 |
+
out = out.to(device="cpu", copy=True)
|
| 149 |
+
if self.device == "pinned_cpu":
|
| 150 |
+
try:
|
| 151 |
+
out = out.pin_memory()
|
| 152 |
+
except RuntimeError:
|
| 153 |
+
# Keep stable CPU behavior if pinning is unavailable in a worker/process.
|
| 154 |
+
pass
|
| 155 |
+
elif self.device == "cuda":
|
| 156 |
+
out = out.clone()
|
| 157 |
+
return out
|
| 158 |
+
|
| 159 |
+
def _metadata_to_storage(self, metadata: dict) -> dict:
|
| 160 |
+
out = {}
|
| 161 |
+
for key, value in dict(metadata or {}).items():
|
| 162 |
+
if torch.is_tensor(value):
|
| 163 |
+
out[key] = self._store_tensor(value)
|
| 164 |
+
elif isinstance(value, dict):
|
| 165 |
+
out[key] = self._metadata_to_storage(value)
|
| 166 |
+
else:
|
| 167 |
+
out[key] = value
|
| 168 |
+
return out
|
| 169 |
+
|
| 170 |
+
def _metadata_to_device(self, metadata: dict, *, device: torch.device, dtype: torch.dtype) -> dict:
|
| 171 |
+
out = {}
|
| 172 |
+
for key, value in dict(metadata or {}).items():
|
| 173 |
+
if torch.is_tensor(value):
|
| 174 |
+
tensor = value.to(device=device)
|
| 175 |
+
out[key] = tensor.to(dtype=dtype) if tensor.is_floating_point() else tensor
|
| 176 |
+
elif isinstance(value, dict):
|
| 177 |
+
out[key] = self._metadata_to_device(value, device=device, dtype=dtype)
|
| 178 |
+
else:
|
| 179 |
+
out[key] = value
|
| 180 |
+
return out
|
| 181 |
+
|
| 182 |
+
def _record_to_storage(self, record: MemoryRecord) -> MemoryRecord:
|
| 183 |
+
return MemoryRecord(
|
| 184 |
+
tokens=self._store_tensor(record.tokens),
|
| 185 |
+
mask=self._store_tensor(record.mask),
|
| 186 |
+
source_start=int(record.source_start),
|
| 187 |
+
source_end=int(record.source_end),
|
| 188 |
+
frame_indices=self._store_tensor(record.frame_indices),
|
| 189 |
+
pose=self._store_tensor(record.pose),
|
| 190 |
+
source_type=record.source_type,
|
| 191 |
+
is_generated=bool(record.is_generated),
|
| 192 |
+
score=None if record.score is None or not torch.is_tensor(record.score) else self._store_tensor(record.score),
|
| 193 |
+
chunk_id=record.chunk_id,
|
| 194 |
+
metadata=self._metadata_to_storage(record.metadata),
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
def _record_to_device(self, record: MemoryRecord, *, device: torch.device, dtype: torch.dtype) -> MemoryRecord:
|
| 198 |
+
return MemoryRecord(
|
| 199 |
+
tokens=record.tokens.to(device=device, dtype=dtype),
|
| 200 |
+
mask=record.mask.to(device=device, dtype=torch.bool),
|
| 201 |
+
source_start=int(record.source_start),
|
| 202 |
+
source_end=int(record.source_end),
|
| 203 |
+
frame_indices=record.frame_indices.to(device=device),
|
| 204 |
+
pose=None if record.pose is None else record.pose.to(device=device),
|
| 205 |
+
source_type=record.source_type,
|
| 206 |
+
is_generated=bool(record.is_generated),
|
| 207 |
+
score=record.score,
|
| 208 |
+
chunk_id=record.chunk_id,
|
| 209 |
+
metadata=self._metadata_to_device(record.metadata, device=device, dtype=dtype),
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
def _check_capacity(self) -> None:
|
| 213 |
+
exceeded = False
|
| 214 |
+
if self.max_records is not None and self.record_count > int(self.max_records):
|
| 215 |
+
exceeded = True
|
| 216 |
+
if self.max_slots is not None and self.slot_count > int(self.max_slots):
|
| 217 |
+
exceeded = True
|
| 218 |
+
if not exceeded:
|
| 219 |
+
return
|
| 220 |
+
self.capacity_exceeded_count += 1
|
| 221 |
+
msg = (
|
| 222 |
+
"DeMemWMStreamingCache capacity exceeded "
|
| 223 |
+
f"records={self.record_count}/{self.max_records}, slots={self.slot_count}/{self.max_slots}; "
|
| 224 |
+
"no eviction performed because no_evict=true"
|
| 225 |
+
)
|
| 226 |
+
if self.on_capacity_exceeded == "error":
|
| 227 |
+
raise RuntimeError(msg)
|
| 228 |
+
if self.on_capacity_exceeded == "warn":
|
| 229 |
+
warnings.warn(msg, RuntimeWarning, stacklevel=2)
|
| 230 |
+
|
| 231 |
+
def add_raw_latents(
|
| 232 |
+
self,
|
| 233 |
+
latents: torch.Tensor,
|
| 234 |
+
frame_indices: torch.Tensor,
|
| 235 |
+
source_is_generated: Optional[torch.Tensor] = None,
|
| 236 |
+
pose: Optional[torch.Tensor] = None,
|
| 237 |
+
) -> None:
|
| 238 |
+
if not self.enabled or self.keep_raw_latents != "all":
|
| 239 |
+
return
|
| 240 |
+
if latents.ndim != 5:
|
| 241 |
+
raise ValueError("cached raw latents must have shape (T,B,C,H,W)")
|
| 242 |
+
T, B = int(latents.shape[0]), int(latents.shape[1])
|
| 243 |
+
if frame_indices.shape != (T, B):
|
| 244 |
+
raise ValueError("cached frame_indices must have shape (T,B)")
|
| 245 |
+
if self._batch_size is None:
|
| 246 |
+
self._batch_size = B
|
| 247 |
+
elif self._batch_size != B:
|
| 248 |
+
raise ValueError("streaming cache batch size changed within a video")
|
| 249 |
+
keep_positions: list[int] = []
|
| 250 |
+
frame_cpu = frame_indices.detach().cpu()
|
| 251 |
+
for t in range(T):
|
| 252 |
+
keys = [(b, int(frame_cpu[t, b].item())) for b in range(B)]
|
| 253 |
+
if any(key not in self._raw_keys for key in keys):
|
| 254 |
+
keep_positions.append(t)
|
| 255 |
+
self._raw_keys.update(keys)
|
| 256 |
+
if not keep_positions:
|
| 257 |
+
return
|
| 258 |
+
pos = torch.as_tensor(keep_positions, dtype=torch.long)
|
| 259 |
+
seg_latents = latents.index_select(0, pos.to(device=latents.device))
|
| 260 |
+
seg_frames = frame_indices.index_select(0, pos.to(device=frame_indices.device))
|
| 261 |
+
if source_is_generated is None:
|
| 262 |
+
seg_generated = torch.zeros(seg_frames.shape, device=seg_frames.device, dtype=torch.bool)
|
| 263 |
+
else:
|
| 264 |
+
seg_generated = source_is_generated.index_select(0, pos.to(device=source_is_generated.device)).bool()
|
| 265 |
+
seg_pose = None if pose is None else pose.index_select(0, pos.to(device=pose.device))
|
| 266 |
+
segment_idx = len(self._raw_segments)
|
| 267 |
+
self._raw_segments.append(
|
| 268 |
+
_RawLatentSegment(
|
| 269 |
+
latents=self._store_tensor(seg_latents),
|
| 270 |
+
frame_indices=self._store_tensor(seg_frames),
|
| 271 |
+
source_is_generated=self._store_tensor(seg_generated),
|
| 272 |
+
pose=self._store_tensor(seg_pose),
|
| 273 |
+
)
|
| 274 |
+
)
|
| 275 |
+
for local_pos, source_pos in enumerate(keep_positions):
|
| 276 |
+
for b in range(B):
|
| 277 |
+
key = (b, int(frame_cpu[source_pos, b].item()))
|
| 278 |
+
self._raw_index.setdefault(key, (segment_idx, local_pos))
|
| 279 |
+
# Invalidate the concat cache — new segment was added.
|
| 280 |
+
self._raw_concat_version += 1
|
| 281 |
+
self._raw_concat_cache = None
|
| 282 |
+
|
| 283 |
+
def add_records(self, kind: str, batch_idx: int, records: Iterable[MemoryRecord]) -> None:
|
| 284 |
+
if not self.enabled or not self.keep_compressed_records:
|
| 285 |
+
return
|
| 286 |
+
if kind not in self._records:
|
| 287 |
+
raise ValueError(f"unsupported cache record kind: {kind}")
|
| 288 |
+
batch_idx = int(batch_idx)
|
| 289 |
+
bucket = self._records[kind].setdefault(batch_idx, [])
|
| 290 |
+
added_any = False
|
| 291 |
+
for record in records:
|
| 292 |
+
if kind == "anchor" and not self.keep_prefix_anchors:
|
| 293 |
+
continue
|
| 294 |
+
key = (
|
| 295 |
+
kind,
|
| 296 |
+
batch_idx,
|
| 297 |
+
str(record.chunk_id or ""),
|
| 298 |
+
int(record.source_start),
|
| 299 |
+
int(record.source_end),
|
| 300 |
+
bool(record.is_generated),
|
| 301 |
+
)
|
| 302 |
+
if key in self._record_keys:
|
| 303 |
+
continue
|
| 304 |
+
self._record_keys.add(key)
|
| 305 |
+
bucket.append(self._record_to_storage(record))
|
| 306 |
+
added_any = True
|
| 307 |
+
if added_any:
|
| 308 |
+
# Invalidate the GPU banks cache — new records were added.
|
| 309 |
+
self._banks_version += 1
|
| 310 |
+
self._banks_built_cache.clear()
|
| 311 |
+
self._check_capacity()
|
| 312 |
+
|
| 313 |
+
def add_memory_banks(self, anchor_banks: list[CausalMemoryBank], revisit_banks: list[CausalMemoryBank]) -> None:
|
| 314 |
+
for batch_idx, bank in enumerate(anchor_banks):
|
| 315 |
+
self.add_records("anchor", batch_idx, bank.records)
|
| 316 |
+
for batch_idx, bank in enumerate(revisit_banks):
|
| 317 |
+
self.add_records("revisit", batch_idx, bank.records)
|
| 318 |
+
|
| 319 |
+
def memory_banks(self, kind: str, *, device: torch.device, dtype: torch.dtype, batch_size: int | None = None) -> list[CausalMemoryBank]:
|
| 320 |
+
if kind not in self._records:
|
| 321 |
+
raise ValueError(f"unsupported cache record kind: {kind}")
|
| 322 |
+
B = int(batch_size or self.batch_size or (max(self._records[kind].keys()) + 1 if self._records[kind] else 0))
|
| 323 |
+
cache_key = (kind, device, dtype, B)
|
| 324 |
+
cached = self._banks_built_cache.get(cache_key)
|
| 325 |
+
if cached is not None and cached[0] == self._banks_version:
|
| 326 |
+
return cached[1]
|
| 327 |
+
banks: list[CausalMemoryBank] = []
|
| 328 |
+
for batch_idx in range(B):
|
| 329 |
+
bank = CausalMemoryBank()
|
| 330 |
+
for record in self._records[kind].get(batch_idx, []):
|
| 331 |
+
bank.add_record(self._record_to_device(record, device=device, dtype=dtype))
|
| 332 |
+
banks.append(bank)
|
| 333 |
+
self._banks_built_cache[cache_key] = (self._banks_version, banks)
|
| 334 |
+
return banks
|
| 335 |
+
|
| 336 |
+
def records_for_batch(self, kind: str, batch_idx: int) -> tuple[MemoryRecord, ...]:
|
| 337 |
+
if kind not in self._records:
|
| 338 |
+
raise ValueError(f"unsupported cache record kind: {kind}")
|
| 339 |
+
return tuple(self._records[kind].get(int(batch_idx), ()))
|
| 340 |
+
|
| 341 |
+
def raw_latents_for_frames(
|
| 342 |
+
self,
|
| 343 |
+
*,
|
| 344 |
+
batch_idx: int,
|
| 345 |
+
frame_indices: torch.Tensor,
|
| 346 |
+
device: torch.device,
|
| 347 |
+
dtype: torch.dtype,
|
| 348 |
+
) -> torch.Tensor:
|
| 349 |
+
frames = frame_indices.detach().cpu().reshape(-1)
|
| 350 |
+
rows = []
|
| 351 |
+
batch_idx = int(batch_idx)
|
| 352 |
+
for frame in frames.tolist():
|
| 353 |
+
key = (batch_idx, int(frame))
|
| 354 |
+
location = self._raw_index.get(key)
|
| 355 |
+
if location is None:
|
| 356 |
+
raise KeyError(f"raw latent for batch={batch_idx}, frame={int(frame)} is not cached")
|
| 357 |
+
segment_idx, local_pos = location
|
| 358 |
+
rows.append(self._raw_segments[segment_idx].latents[local_pos, batch_idx])
|
| 359 |
+
if not rows:
|
| 360 |
+
template = self._raw_segments[0].latents
|
| 361 |
+
return template[:0, batch_idx:batch_idx + 1].to(device=device, dtype=dtype)
|
| 362 |
+
return torch.stack(rows, dim=0).unsqueeze(1).to(device=device, dtype=dtype)
|
| 363 |
+
|
| 364 |
+
def _select_time_positions(
|
| 365 |
+
self,
|
| 366 |
+
frame_indices: torch.Tensor,
|
| 367 |
+
target_frame_indices: Optional[torch.Tensor],
|
| 368 |
+
max_recent_frames: Optional[int],
|
| 369 |
+
exclude_latest_local_frames: int = 0,
|
| 370 |
+
) -> torch.Tensor:
|
| 371 |
+
T, B = frame_indices.shape
|
| 372 |
+
if target_frame_indices is None or max_recent_frames is None or int(max_recent_frames) <= 0:
|
| 373 |
+
return torch.arange(T, dtype=torch.long)
|
| 374 |
+
targets = target_frame_indices.detach().cpu()
|
| 375 |
+
if targets.ndim == 1:
|
| 376 |
+
targets = targets[:, None].expand(-1, B)
|
| 377 |
+
frames = frame_indices.detach().cpu() # (T, B)
|
| 378 |
+
recent = int(max_recent_frames)
|
| 379 |
+
exclude = max(0, int(exclude_latest_local_frames))
|
| 380 |
+
# Vectorized: valid[t_tgt, t_src, b] = True if source position t_src is
|
| 381 |
+
# causally valid for target t_tgt in batch b.
|
| 382 |
+
# frames (T, B) → (1, T, B); targets (T_tgt, B) → (T_tgt, 1, B)
|
| 383 |
+
valid = frames.unsqueeze(0) < (targets.unsqueeze(1) - exclude) # (T_tgt, T, B)
|
| 384 |
+
# For each (t_tgt, b), retain only the last `recent` valid positions.
|
| 385 |
+
# Flip T, cumsum along T (counting from the end), keep where ≤ recent.
|
| 386 |
+
valid_f = valid.flip(1)
|
| 387 |
+
keep_f = (valid_f.long().cumsum(1) <= recent) & valid_f
|
| 388 |
+
# Any position needed by any (t_tgt, b) pair.
|
| 389 |
+
keep_any = keep_f.flip(1).any(dim=0).any(dim=1) # (T,)
|
| 390 |
+
return keep_any.nonzero(as_tuple=False).flatten()
|
| 391 |
+
|
| 392 |
+
def materialize_raw_latents(
|
| 393 |
+
self,
|
| 394 |
+
*,
|
| 395 |
+
device: torch.device,
|
| 396 |
+
dtype: torch.dtype,
|
| 397 |
+
max_recent_frames: Optional[int] = None,
|
| 398 |
+
target_frame_indices: Optional[torch.Tensor] = None,
|
| 399 |
+
exclude_latest_local_frames: int = 0,
|
| 400 |
+
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 401 |
+
if not self._raw_segments:
|
| 402 |
+
return None, None, None, None
|
| 403 |
+
if target_frame_indices is not None and max_recent_frames is not None and int(max_recent_frames) > 0:
|
| 404 |
+
return self._materialize_recent_raw_latents(
|
| 405 |
+
device=device,
|
| 406 |
+
dtype=dtype,
|
| 407 |
+
max_recent_frames=int(max_recent_frames),
|
| 408 |
+
target_frame_indices=target_frame_indices,
|
| 409 |
+
exclude_latest_local_frames=exclude_latest_local_frames,
|
| 410 |
+
)
|
| 411 |
+
# Rebuild the concatenated CPU tensors only when new segments were added.
|
| 412 |
+
if self._raw_concat_cache is None or self._raw_concat_built != self._raw_concat_version:
|
| 413 |
+
latents = torch.cat([seg.latents for seg in self._raw_segments], dim=0)
|
| 414 |
+
frame_indices = torch.cat([seg.frame_indices for seg in self._raw_segments], dim=0)
|
| 415 |
+
generated = torch.cat([seg.source_is_generated for seg in self._raw_segments], dim=0)
|
| 416 |
+
pose: Optional[torch.Tensor] = None
|
| 417 |
+
if all(seg.pose is not None for seg in self._raw_segments):
|
| 418 |
+
pose = torch.cat([seg.pose for seg in self._raw_segments if seg.pose is not None], dim=0)
|
| 419 |
+
self._raw_concat_cache = (latents, frame_indices, generated, pose)
|
| 420 |
+
self._raw_concat_built = self._raw_concat_version
|
| 421 |
+
else:
|
| 422 |
+
latents, frame_indices, generated, pose = self._raw_concat_cache
|
| 423 |
+
pos = self._select_time_positions(frame_indices, target_frame_indices, max_recent_frames, exclude_latest_local_frames)
|
| 424 |
+
if pos.numel() == 0:
|
| 425 |
+
empty_latents = latents[:0].to(device=device, dtype=dtype)
|
| 426 |
+
empty_frames = frame_indices[:0].to(device=device)
|
| 427 |
+
empty_generated = generated[:0].to(device=device, dtype=torch.bool)
|
| 428 |
+
empty_pose = None if pose is None else pose[:0].to(device=device)
|
| 429 |
+
return empty_latents, empty_frames, empty_generated, empty_pose
|
| 430 |
+
latents = latents.index_select(0, pos.to(device=latents.device)).to(device=device, dtype=dtype)
|
| 431 |
+
frame_indices = frame_indices.index_select(0, pos.to(device=frame_indices.device)).to(device=device)
|
| 432 |
+
generated = generated.index_select(0, pos.to(device=generated.device)).to(device=device, dtype=torch.bool)
|
| 433 |
+
if pose is not None:
|
| 434 |
+
pose = pose.index_select(0, pos.to(device=pose.device)).to(device=device)
|
| 435 |
+
return latents, frame_indices, generated, pose
|
| 436 |
+
|
| 437 |
+
def _materialize_recent_raw_latents(
|
| 438 |
+
self,
|
| 439 |
+
*,
|
| 440 |
+
device: torch.device,
|
| 441 |
+
dtype: torch.dtype,
|
| 442 |
+
max_recent_frames: int,
|
| 443 |
+
target_frame_indices: torch.Tensor,
|
| 444 |
+
exclude_latest_local_frames: int = 0,
|
| 445 |
+
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 446 |
+
B = self.batch_size
|
| 447 |
+
targets = target_frame_indices.detach().cpu()
|
| 448 |
+
if targets.ndim == 1:
|
| 449 |
+
targets = targets[:, None].expand(-1, B)
|
| 450 |
+
elif targets.shape[1] == 1 and B > 1:
|
| 451 |
+
targets = targets.expand(-1, B)
|
| 452 |
+
if targets.shape[1] != B:
|
| 453 |
+
raise ValueError("target_frame_indices batch dimension does not match streaming cache")
|
| 454 |
+
|
| 455 |
+
recent = max(0, int(max_recent_frames))
|
| 456 |
+
exclude = max(0, int(exclude_latest_local_frames))
|
| 457 |
+
counts = torch.zeros(targets.shape, dtype=torch.long)
|
| 458 |
+
selected: list[tuple[_RawLatentSegment, int]] = []
|
| 459 |
+
|
| 460 |
+
for segment in reversed(self._raw_segments):
|
| 461 |
+
frames = segment.frame_indices.detach().cpu()
|
| 462 |
+
for local_pos in range(frames.shape[0] - 1, -1, -1):
|
| 463 |
+
valid = frames[local_pos].unsqueeze(0) < (targets - exclude)
|
| 464 |
+
needed = valid & (counts < recent)
|
| 465 |
+
if not needed.any():
|
| 466 |
+
continue
|
| 467 |
+
selected.append((segment, local_pos))
|
| 468 |
+
counts += needed.long()
|
| 469 |
+
if bool((counts >= recent).all().item()):
|
| 470 |
+
break
|
| 471 |
+
if bool((counts >= recent).all().item()):
|
| 472 |
+
break
|
| 473 |
+
|
| 474 |
+
if not selected:
|
| 475 |
+
template = self._raw_segments[0]
|
| 476 |
+
empty_latents = template.latents[:0].to(device=device, dtype=dtype)
|
| 477 |
+
empty_frames = template.frame_indices[:0].to(device=device)
|
| 478 |
+
empty_generated = template.source_is_generated[:0].to(device=device, dtype=torch.bool)
|
| 479 |
+
empty_pose = None if template.pose is None else template.pose[:0].to(device=device)
|
| 480 |
+
return empty_latents, empty_frames, empty_generated, empty_pose
|
| 481 |
+
|
| 482 |
+
selected.reverse()
|
| 483 |
+
latents = torch.stack([segment.latents[local_pos] for segment, local_pos in selected], dim=0).to(device=device, dtype=dtype)
|
| 484 |
+
frame_indices = torch.stack([segment.frame_indices[local_pos] for segment, local_pos in selected], dim=0).to(device=device)
|
| 485 |
+
generated = torch.stack([segment.source_is_generated[local_pos] for segment, local_pos in selected], dim=0).to(device=device, dtype=torch.bool)
|
| 486 |
+
pose = None
|
| 487 |
+
if all(segment.pose is not None for segment, _ in selected):
|
| 488 |
+
pose = torch.stack(
|
| 489 |
+
[segment.pose[local_pos] for segment, local_pos in selected if segment.pose is not None],
|
| 490 |
+
dim=0,
|
| 491 |
+
).to(device=device)
|
| 492 |
+
return latents, frame_indices, generated, pose
|
| 493 |
+
|
| 494 |
+
def diagnostics(self, prefix: str = "cache") -> dict[str, Any]:
|
| 495 |
+
return {
|
| 496 |
+
f"{prefix}_enabled": bool(self.enabled),
|
| 497 |
+
f"{prefix}_records": int(self.record_count),
|
| 498 |
+
f"{prefix}_anchor_records": int(self.records_count("anchor")),
|
| 499 |
+
f"{prefix}_revisit_records": int(self.records_count("revisit")),
|
| 500 |
+
f"{prefix}_slots": int(self.slot_count),
|
| 501 |
+
f"{prefix}_raw_frame_slots": int(self.raw_frame_slots),
|
| 502 |
+
f"{prefix}_raw_segments": int(self.raw_segment_count),
|
| 503 |
+
f"{prefix}_evictions": int(self.evictions),
|
| 504 |
+
f"{prefix}_resets": int(self.reset_count),
|
| 505 |
+
f"{prefix}_capacity_exceeded": int(self.capacity_exceeded_count),
|
| 506 |
+
f"{prefix}_device": self.device,
|
| 507 |
+
f"{prefix}_current_video_id": self.current_video_id,
|
| 508 |
+
f"{prefix}_clear_between_videos": bool(self.clear_between_videos),
|
| 509 |
+
f"{prefix}_no_evict": bool(self.no_evict),
|
| 510 |
+
}
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
DeMemWMStreamingCache = StreamingCache
|
algorithms/worldmem/dememwm/compression.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from typing import Optional
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def latent_patch_tokens(latents: torch.Tensor, patch_size: int) -> torch.Tensor:
|
| 12 |
+
if latents.ndim != 5:
|
| 13 |
+
raise ValueError("latents must have shape (T,B,C,H,W)")
|
| 14 |
+
if patch_size <= 0:
|
| 15 |
+
raise ValueError("patch_size must be positive")
|
| 16 |
+
T, B, C, H, W = latents.shape
|
| 17 |
+
if H % patch_size != 0 or W % patch_size != 0:
|
| 18 |
+
raise ValueError(f"latent H,W=({H},{W}) must be divisible by patch_size={patch_size}")
|
| 19 |
+
flat = latents.reshape(T * B, C, H, W)
|
| 20 |
+
patches = F.unfold(flat, kernel_size=patch_size, stride=patch_size).transpose(1, 2).contiguous()
|
| 21 |
+
return patches.reshape(T, B, patches.shape[1], C * patch_size * patch_size)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def spatial_pool_tokens(
|
| 25 |
+
tokens: torch.Tensor,
|
| 26 |
+
pool_h: int,
|
| 27 |
+
pool_w: int,
|
| 28 |
+
src_h: int,
|
| 29 |
+
src_w: int,
|
| 30 |
+
) -> torch.Tensor:
|
| 31 |
+
"""2D adaptive average pool on a flattened (src_h*src_w, D) token grid.
|
| 32 |
+
Preserves 2D spatial layout. Returns (pool_h*pool_w, D)."""
|
| 33 |
+
if tokens.ndim != 2:
|
| 34 |
+
raise ValueError("tokens must have shape (N, D)")
|
| 35 |
+
D = tokens.shape[-1]
|
| 36 |
+
spatial = tokens.reshape(src_h, src_w, D).permute(2, 0, 1).unsqueeze(0)
|
| 37 |
+
pooled = F.adaptive_avg_pool2d(spatial, (pool_h, pool_w))
|
| 38 |
+
return pooled.squeeze(0).permute(1, 2, 0).reshape(-1, D)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class SpatialConv2DMemoryProjector(nn.Module):
|
| 42 |
+
"""Project latent maps to DiT hidden tokens while preserving the HxW grid."""
|
| 43 |
+
|
| 44 |
+
projects_spatial_latents = True
|
| 45 |
+
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
latent_channels: int,
|
| 49 |
+
dit_hidden_size: int,
|
| 50 |
+
mid_channels: int,
|
| 51 |
+
kernel_size: int = 3,
|
| 52 |
+
):
|
| 53 |
+
super().__init__()
|
| 54 |
+
kernel_size = int(kernel_size)
|
| 55 |
+
if kernel_size <= 0 or kernel_size % 2 == 0:
|
| 56 |
+
raise ValueError("kernel_size must be a positive odd integer")
|
| 57 |
+
self.latent_channels = int(latent_channels)
|
| 58 |
+
self.dit_hidden_size = int(dit_hidden_size)
|
| 59 |
+
self.mid_channels = int(mid_channels)
|
| 60 |
+
self.kernel_size = kernel_size
|
| 61 |
+
self.out_features = self.dit_hidden_size
|
| 62 |
+
self.proj_in = nn.Conv2d(self.latent_channels, self.mid_channels, kernel_size=1)
|
| 63 |
+
self.proj_spatial = nn.Conv2d(
|
| 64 |
+
self.mid_channels,
|
| 65 |
+
self.dit_hidden_size,
|
| 66 |
+
kernel_size=kernel_size,
|
| 67 |
+
padding=kernel_size // 2,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
def forward(self, latents: torch.Tensor) -> torch.Tensor:
|
| 71 |
+
if latents.ndim != 5:
|
| 72 |
+
raise ValueError("latents must have shape (T,B,C,H,W)")
|
| 73 |
+
T, B, C, H, W = latents.shape
|
| 74 |
+
if C != self.latent_channels:
|
| 75 |
+
raise ValueError(f"expected {self.latent_channels} latent channels, got {C}")
|
| 76 |
+
x = latents.reshape(T * B, C, H, W)
|
| 77 |
+
x = self.proj_spatial(self.proj_in(x))
|
| 78 |
+
x = x.reshape(T, B, self.dit_hidden_size, H, W)
|
| 79 |
+
return x.permute(1, 0, 3, 4, 2).reshape(B, T, H * W, self.dit_hidden_size).contiguous()
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class CausalConv3DDynamicCompressor(nn.Module):
|
| 83 |
+
"""Dynamic memory compressor: delta preprocessing + causal Conv3D on raw latents.
|
| 84 |
+
|
| 85 |
+
Replaces ShortTermLatentCompressor (slot cross-attention).
|
| 86 |
+
- Operates directly on (T, C, H, W) raw latents
|
| 87 |
+
- Delta: inp[0]=latent[0], inp[t]=latent[t]-latent[t-1]
|
| 88 |
+
- Causal padding prepends temporal zeros and right-aligns fixed outputs
|
| 89 |
+
- Zero-padded to max_source_frames for fixed output shape
|
| 90 |
+
- No slot cross-attention, no chunking
|
| 91 |
+
"""
|
| 92 |
+
|
| 93 |
+
def __init__(
|
| 94 |
+
self,
|
| 95 |
+
latent_channels: int,
|
| 96 |
+
dit_hidden_size: int,
|
| 97 |
+
patch_size: int = 2,
|
| 98 |
+
conv_kernel_t: int = 3,
|
| 99 |
+
conv_stride_t: int = 2,
|
| 100 |
+
max_source_frames: int = 8,
|
| 101 |
+
exclude_latest_local_frames: int = 4,
|
| 102 |
+
):
|
| 103 |
+
super().__init__()
|
| 104 |
+
self.latent_channels = latent_channels
|
| 105 |
+
self.dit_hidden_size = dit_hidden_size
|
| 106 |
+
self.patch_size = patch_size
|
| 107 |
+
self.conv_kernel_t = conv_kernel_t
|
| 108 |
+
self.conv_stride_t = conv_stride_t
|
| 109 |
+
self.max_source_frames = max_source_frames
|
| 110 |
+
self.exclude_latest_local_frames = int(exclude_latest_local_frames)
|
| 111 |
+
self.causal_pad = self._temporal_left_pad()
|
| 112 |
+
self.conv3d = nn.Conv3d(
|
| 113 |
+
latent_channels, dit_hidden_size,
|
| 114 |
+
kernel_size=(conv_kernel_t, patch_size, patch_size),
|
| 115 |
+
stride=(conv_stride_t, patch_size, patch_size),
|
| 116 |
+
padding=0,
|
| 117 |
+
)
|
| 118 |
+
self.out_norm = nn.LayerNorm(dit_hidden_size)
|
| 119 |
+
self._init_temporal_as_delta()
|
| 120 |
+
|
| 121 |
+
def _init_temporal_as_delta(self) -> None:
|
| 122 |
+
with torch.no_grad():
|
| 123 |
+
self.conv3d.weight.zero_()
|
| 124 |
+
k_t, p = self.conv_kernel_t, self.patch_size
|
| 125 |
+
D_out, D_in = self.conv3d.weight.shape[:2]
|
| 126 |
+
scale = 1.0 / (p * p)
|
| 127 |
+
# Delta preprocessing happens in forward. Initialize every output
|
| 128 |
+
# channel to read a patch-averaged current delta, repeating latent
|
| 129 |
+
# channels across the wider DiT hidden dimension.
|
| 130 |
+
for d in range(D_out):
|
| 131 |
+
self.conv3d.weight[d, d % D_in, k_t - 1, :, :] = scale
|
| 132 |
+
if self.conv3d.bias is not None:
|
| 133 |
+
nn.init.zeros_(self.conv3d.bias)
|
| 134 |
+
|
| 135 |
+
def _temporal_output_count(self) -> int:
|
| 136 |
+
return math.ceil(self.max_source_frames / self.conv_stride_t)
|
| 137 |
+
|
| 138 |
+
def _temporal_left_pad(self) -> int:
|
| 139 |
+
t_out = self._temporal_output_count()
|
| 140 |
+
latest_output_end = (t_out - 1) * self.conv_stride_t + self.conv_kernel_t - 1
|
| 141 |
+
latest_source = self.max_source_frames - 1
|
| 142 |
+
return max(0, latest_output_end - latest_source)
|
| 143 |
+
|
| 144 |
+
def _output_time_indices(self, device: torch.device) -> torch.Tensor:
|
| 145 |
+
t_out = self._temporal_output_count()
|
| 146 |
+
return (
|
| 147 |
+
torch.arange(t_out, device=device, dtype=torch.long) * self.conv_stride_t
|
| 148 |
+
+ self.conv_kernel_t
|
| 149 |
+
- 1
|
| 150 |
+
- self.causal_pad
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
def tokens_per_target(self, H: int, W: int) -> int:
|
| 154 |
+
p = self.patch_size
|
| 155 |
+
T_out = self._temporal_output_count()
|
| 156 |
+
return T_out * (H // p) * (W // p)
|
| 157 |
+
|
| 158 |
+
def forward(
|
| 159 |
+
self,
|
| 160 |
+
latents: torch.Tensor,
|
| 161 |
+
frame_indices: torch.Tensor,
|
| 162 |
+
pose: Optional[torch.Tensor],
|
| 163 |
+
target_frame_indices: torch.Tensor,
|
| 164 |
+
source_is_generated: Optional[torch.Tensor] = None,
|
| 165 |
+
exclude_latest_local_frames: Optional[int] = None,
|
| 166 |
+
) -> tuple[torch.Tensor, torch.Tensor, dict]:
|
| 167 |
+
if latents.ndim != 5:
|
| 168 |
+
raise ValueError("latents must have shape (T_src,B,C,H,W)")
|
| 169 |
+
exclude_latest_local_frames = (
|
| 170 |
+
self.exclude_latest_local_frames
|
| 171 |
+
if exclude_latest_local_frames is None
|
| 172 |
+
else int(exclude_latest_local_frames)
|
| 173 |
+
)
|
| 174 |
+
T_src, B, C, H, W = latents.shape
|
| 175 |
+
p = self.patch_size
|
| 176 |
+
if H % p != 0 or W % p != 0:
|
| 177 |
+
raise ValueError(f"latent H,W=({H},{W}) must be divisible by patch_size={p}")
|
| 178 |
+
if target_frame_indices.ndim == 1:
|
| 179 |
+
target_frame_indices = target_frame_indices[:, None].expand(-1, B)
|
| 180 |
+
T_tgt = target_frame_indices.shape[0]
|
| 181 |
+
device = latents.device
|
| 182 |
+
generated_flags = None if source_is_generated is None else source_is_generated.to(device=device, dtype=torch.bool)
|
| 183 |
+
n_spatial = (H // p) * (W // p)
|
| 184 |
+
T_out = self._temporal_output_count()
|
| 185 |
+
num_slots = T_out * n_spatial
|
| 186 |
+
output_time_idx = self._output_time_indices(device)
|
| 187 |
+
selected_source_count = torch.zeros((B, T_tgt), dtype=torch.long, device=device)
|
| 188 |
+
max_source_frame = torch.full((B, T_tgt), -1, dtype=torch.long, device=device)
|
| 189 |
+
generated_source_fraction = torch.zeros((B, T_tgt), dtype=torch.float32, device=device)
|
| 190 |
+
min_gap = torch.full((B, T_tgt), -1, dtype=torch.long, device=device)
|
| 191 |
+
max_gap = torch.full((B, T_tgt), -1, dtype=torch.long, device=device)
|
| 192 |
+
output_rows, mask_rows = [], []
|
| 193 |
+
for b in range(B):
|
| 194 |
+
src_frames_b = frame_indices[:, b]
|
| 195 |
+
tgt_outputs, tgt_masks = [], []
|
| 196 |
+
for j in range(T_tgt):
|
| 197 |
+
target = int(target_frame_indices[j, b].item())
|
| 198 |
+
valid_idx = (
|
| 199 |
+
src_frames_b < target - exclude_latest_local_frames
|
| 200 |
+
).nonzero(as_tuple=False).flatten()
|
| 201 |
+
if valid_idx.numel() == 0:
|
| 202 |
+
tgt_outputs.append(latents.new_zeros(num_slots, self.dit_hidden_size))
|
| 203 |
+
tgt_masks.append(torch.zeros(num_slots, device=device, dtype=torch.bool))
|
| 204 |
+
continue
|
| 205 |
+
selected_frames = src_frames_b.index_select(0, valid_idx)
|
| 206 |
+
order = torch.argsort(selected_frames)
|
| 207 |
+
valid_idx = valid_idx.index_select(0, order)[-self.max_source_frames:]
|
| 208 |
+
selected_frames = src_frames_b.index_select(0, valid_idx)
|
| 209 |
+
selected_source_count[b, j] = int(selected_frames.numel())
|
| 210 |
+
max_source_frame[b, j] = selected_frames.max()
|
| 211 |
+
gaps = target - selected_frames
|
| 212 |
+
min_gap[b, j] = gaps.min()
|
| 213 |
+
max_gap[b, j] = gaps.max()
|
| 214 |
+
if generated_flags is not None:
|
| 215 |
+
generated = generated_flags.index_select(0, valid_idx)[:, b]
|
| 216 |
+
generated_source_fraction[b, j] = generated.float().mean()
|
| 217 |
+
chunk = latents[valid_idx, b]
|
| 218 |
+
real_mask = torch.ones((chunk.shape[0],), device=device, dtype=torch.bool)
|
| 219 |
+
if chunk.shape[0] < self.max_source_frames:
|
| 220 |
+
pad = chunk.new_zeros(self.max_source_frames - chunk.shape[0], C, H, W)
|
| 221 |
+
chunk = torch.cat([pad, chunk], dim=0)
|
| 222 |
+
real_mask = torch.cat([
|
| 223 |
+
torch.zeros((pad.shape[0],), device=device, dtype=torch.bool),
|
| 224 |
+
real_mask,
|
| 225 |
+
])
|
| 226 |
+
inp = chunk.clone()
|
| 227 |
+
inp[1:] = chunk[1:] - chunk[:-1]
|
| 228 |
+
x = inp.permute(1, 0, 2, 3).unsqueeze(0) # (1,C,T,H,W)
|
| 229 |
+
x = F.pad(x, (0, 0, 0, 0, self.causal_pad, 0)) # left-pad time
|
| 230 |
+
x = self.conv3d(x) # (1,D,T_out,H//p,W//p)
|
| 231 |
+
x = x.squeeze(0).permute(1, 2, 3, 0) # (T_out,H//p,W//p,D)
|
| 232 |
+
x = self.out_norm(x)
|
| 233 |
+
tokens = x.reshape(num_slots, self.dit_hidden_size)
|
| 234 |
+
clamped_time_idx = output_time_idx.clamp(min=0, max=self.max_source_frames - 1)
|
| 235 |
+
temporal_mask = (
|
| 236 |
+
(output_time_idx >= 0)
|
| 237 |
+
& (output_time_idx < self.max_source_frames)
|
| 238 |
+
& real_mask.index_select(0, clamped_time_idx)
|
| 239 |
+
)
|
| 240 |
+
mask = temporal_mask[:, None].expand(T_out, n_spatial).reshape(num_slots)
|
| 241 |
+
tgt_outputs.append(tokens)
|
| 242 |
+
tgt_masks.append(mask)
|
| 243 |
+
output_rows.append(torch.stack(tgt_outputs))
|
| 244 |
+
mask_rows.append(torch.stack(tgt_masks))
|
| 245 |
+
out_tokens = torch.stack(output_rows)
|
| 246 |
+
out_mask = torch.stack(mask_rows)
|
| 247 |
+
diagnostics = {
|
| 248 |
+
"num_dynamic_slots": num_slots,
|
| 249 |
+
"dynamic_T_out": T_out,
|
| 250 |
+
"dynamic_n_spatial": n_spatial,
|
| 251 |
+
"dynamic_temporal_left_pad": self.causal_pad,
|
| 252 |
+
"dynamic_output_time_indices": output_time_idx,
|
| 253 |
+
"selected_source_count": selected_source_count,
|
| 254 |
+
"max_source_frame": max_source_frame,
|
| 255 |
+
"generated_source_fraction": generated_source_fraction,
|
| 256 |
+
"dynamic_min_gap_to_target_per_target": min_gap,
|
| 257 |
+
"dynamic_max_gap_to_target_per_target": max_gap,
|
| 258 |
+
"dynamic_exclude_latest_local_frames": exclude_latest_local_frames,
|
| 259 |
+
}
|
| 260 |
+
return out_tokens, out_mask, diagnostics
|
algorithms/worldmem/dememwm/diagnostics.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from .schedules import EVAL_ABLATION_BRANCH_TO_ID, NOISE_BUCKETS, NOISE_BUCKET_TO_ID, normalize_eval_ablation_branch, normalize_noise_bucket
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
_REVISIT_LABEL_SOURCE = "deterministic_fov_coverage_plucker"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def tensor_valid_fraction(mask: torch.Tensor | None) -> float:
|
| 14 |
+
if mask is None or mask.numel() == 0:
|
| 15 |
+
return 0.0
|
| 16 |
+
return float(mask.detach().bool().float().mean().item())
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def gate_stats(gate: torch.Tensor | float | int | None) -> dict[str, float]:
|
| 20 |
+
if gate is None:
|
| 21 |
+
return {"mean": 0.0, "min": 0.0, "max": 0.0}
|
| 22 |
+
if not torch.is_tensor(gate):
|
| 23 |
+
value = float(gate)
|
| 24 |
+
return {"mean": value, "min": value, "max": value}
|
| 25 |
+
g = gate.detach().float()
|
| 26 |
+
return {"mean": float(g.mean().item()), "min": float(g.min().item()), "max": float(g.max().item())}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def summarize_stream(name: str, tokens: torch.Tensor | None, mask: torch.Tensor | None, gate: torch.Tensor | float | None) -> dict[str, Any]:
|
| 30 |
+
return {f"{name}_tokens_shape": None if tokens is None else tuple(tokens.shape), f"{name}_valid_fraction": tensor_valid_fraction(mask), f"{name}_valid_tokens": 0 if mask is None else int(mask.detach().bool().sum().item()), f"{name}_gate": gate_stats(gate)}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def assert_no_future_sources(target_frame: int, max_source_frame: int | torch.Tensor) -> None:
|
| 34 |
+
max_src = int(max_source_frame.detach().max().item()) if torch.is_tensor(max_source_frame) else int(max_source_frame)
|
| 35 |
+
if max_src >= int(target_frame):
|
| 36 |
+
raise AssertionError(f"DeMemWM memory source {max_src} is not causal for target {target_frame}")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _collect_values(result_diagnostics: list[dict[str, Any]], key: str) -> list[float]:
|
| 40 |
+
values: list[float] = []
|
| 41 |
+
for diag in result_diagnostics:
|
| 42 |
+
for value in diag.get(key, []) or []:
|
| 43 |
+
values.append(float(value))
|
| 44 |
+
return values
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _value_stats(values: list[float], prefix: str) -> dict[str, float]:
|
| 48 |
+
if not values:
|
| 49 |
+
return {f"{prefix}_mean": 0.0, f"{prefix}_min": 0.0, f"{prefix}_max": 0.0}
|
| 50 |
+
return {
|
| 51 |
+
f"{prefix}_mean": float(sum(values) / len(values)),
|
| 52 |
+
f"{prefix}_min": float(min(values)),
|
| 53 |
+
f"{prefix}_max": float(max(values)),
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def summarize_revisit_diagnostics(result_diagnostics: list[dict[str, Any]], valid_revisit_mask: torch.Tensor | None) -> dict[str, Any]:
|
| 58 |
+
target_count = len(result_diagnostics)
|
| 59 |
+
candidate_count = sum(int(diag.get("revisit_candidate_frame_count", diag.get("revisit_candidate_count", diag.get("candidate_count", 0)))) for diag in result_diagnostics)
|
| 60 |
+
candidate_count_mean = float(candidate_count / target_count) if target_count else 0.0
|
| 61 |
+
valid_candidate_label_count = sum(int(diag.get("valid_candidate_label_count", diag.get("valid_candidate_count", 0))) for diag in result_diagnostics)
|
| 62 |
+
pose_preselect_input_count = sum(int(diag.get("revisit_pose_preselect_input_count", 0)) for diag in result_diagnostics)
|
| 63 |
+
pose_preselect_selected_count = sum(int(diag.get("revisit_pose_preselect_selected_count", 0)) for diag in result_diagnostics)
|
| 64 |
+
exact_fov_candidate_count = sum(int(diag.get("revisit_exact_fov_candidate_count", 0)) for diag in result_diagnostics)
|
| 65 |
+
valid_count = sum(int(diag.get("valid_revisit_frame_count", diag.get("valid_revisit_count", diag.get("valid_candidate_count", 0)))) for diag in result_diagnostics)
|
| 66 |
+
valid_count_mean = float(valid_count / target_count) if target_count else 0.0
|
| 67 |
+
valid_target_count = sum(int(diag.get("valid_revisit_target_count", diag.get("high_quality_selected_revisit", 0))) for diag in result_diagnostics)
|
| 68 |
+
selected_count = sum(int(diag.get("revisit_selected_frame_count", diag.get("revisit_selected_count", diag.get("selected_count", 0)))) for diag in result_diagnostics)
|
| 69 |
+
no_valid_count = sum(int(diag.get("no_valid_revisit_count", 0)) for diag in result_diagnostics)
|
| 70 |
+
abstained_count = sum(int(diag.get("revisit_abstained_count", int(bool(diag.get("abstained", False))))) for diag in result_diagnostics)
|
| 71 |
+
selected_gaps = [int(diag["revisit_min_gap_to_target"]) for diag in result_diagnostics if int(diag.get("revisit_min_gap_to_target", -1)) >= 0]
|
| 72 |
+
diagnostics: dict[str, Any] = {
|
| 73 |
+
"revisit_candidate_frame_count": candidate_count_mean,
|
| 74 |
+
"revisit_candidate_count": candidate_count_mean,
|
| 75 |
+
"valid_candidate_label_count": int(valid_candidate_label_count),
|
| 76 |
+
"revisit_pose_preselect_input_count": float(pose_preselect_input_count / target_count) if target_count else 0.0,
|
| 77 |
+
"revisit_pose_preselect_selected_count": float(pose_preselect_selected_count / target_count) if target_count else 0.0,
|
| 78 |
+
"revisit_exact_fov_candidate_count": float(exact_fov_candidate_count / target_count) if target_count else 0.0,
|
| 79 |
+
"valid_revisit_frame_count": valid_count_mean,
|
| 80 |
+
"valid_revisit_count": valid_count_mean,
|
| 81 |
+
"valid_revisit_target_count": int(valid_target_count),
|
| 82 |
+
"no_valid_revisit_count": int(no_valid_count),
|
| 83 |
+
"valid_revisit_mask_fraction": tensor_valid_fraction(valid_revisit_mask),
|
| 84 |
+
"revisit_selected_frame_count": int(selected_count),
|
| 85 |
+
"revisit_selected_count": int(selected_count),
|
| 86 |
+
"revisit_abstained_count": int(abstained_count),
|
| 87 |
+
"revisit_min_gap_to_target": int(min(selected_gaps)) if selected_gaps else -1,
|
| 88 |
+
"revisit_label_source": _REVISIT_LABEL_SOURCE,
|
| 89 |
+
}
|
| 90 |
+
frame_fov_values = _collect_values(result_diagnostics, "frame_fov_overlap_values")
|
| 91 |
+
if not frame_fov_values:
|
| 92 |
+
frame_fov_values = _collect_values(result_diagnostics, "fov_overlap_values")
|
| 93 |
+
diagnostics.update(_value_stats(frame_fov_values, "revisit_frame_fov_overlap"))
|
| 94 |
+
diagnostics.update(_value_stats(frame_fov_values, "revisit_fov_overlap"))
|
| 95 |
+
diagnostics.update(_value_stats(_collect_values(result_diagnostics, "plucker_overlap_values"), "revisit_plucker_overlap"))
|
| 96 |
+
diagnostics.update(_value_stats(_collect_values(result_diagnostics, "best_selected_fov_overlap_values"), "revisit_best_selected_fov_overlap"))
|
| 97 |
+
diagnostics.update(_value_stats(_collect_values(result_diagnostics, "best_selected_plucker_overlap_values"), "revisit_best_selected_plucker_overlap"))
|
| 98 |
+
diagnostics.update(_value_stats(_collect_values(result_diagnostics, "best_selected_gap_frame_values"), "revisit_best_selected_gap_frames"))
|
| 99 |
+
diagnostics.update(_value_stats(_collect_values(result_diagnostics, "best_selected_frame_fov_overlap_values"), "revisit_best_selected_frame_fov_overlap"))
|
| 100 |
+
diagnostics.update(_value_stats(_collect_values(result_diagnostics, "selected_frame_fov_overlap_values"), "revisit_selected_frame_fov_overlap"))
|
| 101 |
+
diagnostics.update(_value_stats(_collect_values(result_diagnostics, "selected_incremental_fov_overlap_values"), "revisit_incremental_fov_overlap"))
|
| 102 |
+
return diagnostics
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def summarize_noise_bucket_diagnostics(
|
| 106 |
+
*,
|
| 107 |
+
noise_bucket: str | None,
|
| 108 |
+
valid_revisit_mask: torch.Tensor | None,
|
| 109 |
+
no_valid_revisit_mask: torch.Tensor | None,
|
| 110 |
+
noise_bucket_ids: torch.Tensor | None = None,
|
| 111 |
+
) -> dict[str, Any]:
|
| 112 |
+
bucket = normalize_noise_bucket(noise_bucket)
|
| 113 |
+
diagnostics: dict[str, Any] = {
|
| 114 |
+
"noise_bucket": bucket,
|
| 115 |
+
"noise_bucket_id": int(NOISE_BUCKET_TO_ID[bucket]),
|
| 116 |
+
}
|
| 117 |
+
for candidate in NOISE_BUCKETS:
|
| 118 |
+
diagnostics[f"noise_bucket_is_{candidate}"] = int(bucket == candidate)
|
| 119 |
+
|
| 120 |
+
valid = torch.zeros(0, dtype=torch.bool) if valid_revisit_mask is None else valid_revisit_mask.detach().bool().reshape(-1).cpu()
|
| 121 |
+
no_valid = torch.zeros_like(valid) if no_valid_revisit_mask is None else no_valid_revisit_mask.detach().bool().reshape(-1).cpu()
|
| 122 |
+
target_count = int(valid.numel())
|
| 123 |
+
diagnostics["noise_bucket_target_count"] = target_count
|
| 124 |
+
if noise_bucket_ids is None:
|
| 125 |
+
target_bucket_ids = torch.full((target_count,), int(NOISE_BUCKET_TO_ID[bucket]), dtype=torch.long)
|
| 126 |
+
else:
|
| 127 |
+
target_bucket_ids = noise_bucket_ids.detach().long().reshape(-1).cpu()
|
| 128 |
+
if int(target_bucket_ids.numel()) != target_count:
|
| 129 |
+
raise ValueError(
|
| 130 |
+
f"noise_bucket_ids has {int(target_bucket_ids.numel())} targets, expected {target_count}"
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
for bucket_name in NOISE_BUCKETS:
|
| 134 |
+
bucket_mask = target_bucket_ids == int(NOISE_BUCKET_TO_ID[bucket_name])
|
| 135 |
+
diagnostics[f"noise_bucket_{bucket_name}_target_count"] = int(bucket_mask.long().sum().item())
|
| 136 |
+
|
| 137 |
+
mask_specs = (
|
| 138 |
+
("valid_revisit", valid),
|
| 139 |
+
("no_valid_revisit", no_valid),
|
| 140 |
+
)
|
| 141 |
+
for mask_name, mask in mask_specs:
|
| 142 |
+
for bucket_name in NOISE_BUCKETS:
|
| 143 |
+
bucket_mask = target_bucket_ids == int(NOISE_BUCKET_TO_ID[bucket_name])
|
| 144 |
+
count = int((mask & bucket_mask).long().sum().item()) if mask.numel() else 0
|
| 145 |
+
diagnostics[f"{mask_name}_noise_bucket_{bucket_name}_count"] = count
|
| 146 |
+
return diagnostics
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def summarize_eval_ablation_diagnostics(
|
| 150 |
+
*,
|
| 151 |
+
enabled: bool,
|
| 152 |
+
branch: str | None,
|
| 153 |
+
valid_revisit_mask: torch.Tensor | None,
|
| 154 |
+
no_valid_revisit_mask: torch.Tensor | None,
|
| 155 |
+
eval_corrupted_revisit_mask: torch.Tensor | None,
|
| 156 |
+
) -> dict[str, Any]:
|
| 157 |
+
branch = normalize_eval_ablation_branch(branch)
|
| 158 |
+
valid = torch.zeros(0, dtype=torch.bool) if valid_revisit_mask is None else valid_revisit_mask.detach().bool().reshape(-1).cpu()
|
| 159 |
+
no_valid = torch.zeros_like(valid) if no_valid_revisit_mask is None else no_valid_revisit_mask.detach().bool().reshape(-1).cpu()
|
| 160 |
+
corrupted = torch.zeros_like(valid) if eval_corrupted_revisit_mask is None else eval_corrupted_revisit_mask.detach().bool().reshape(-1).cpu()
|
| 161 |
+
true_revisit = valid & (~corrupted)
|
| 162 |
+
diagnostics: dict[str, Any] = {
|
| 163 |
+
"eval_ablation_enabled": bool(enabled),
|
| 164 |
+
"eval_ablation_branch": branch,
|
| 165 |
+
"eval_ablation_branch_id": int(EVAL_ABLATION_BRANCH_TO_ID[branch]),
|
| 166 |
+
"eval_bucket_true_revisit_count": int(true_revisit.long().sum().item()),
|
| 167 |
+
"eval_bucket_no_valid_revisit_count": int(no_valid.long().sum().item()),
|
| 168 |
+
"eval_bucket_corrupted_memory_count": int(corrupted.long().sum().item()),
|
| 169 |
+
}
|
| 170 |
+
total = max(int(valid.numel()), 1)
|
| 171 |
+
diagnostics["eval_bucket_true_revisit_fraction"] = float(diagnostics["eval_bucket_true_revisit_count"] / total)
|
| 172 |
+
diagnostics["eval_bucket_no_valid_revisit_fraction"] = float(diagnostics["eval_bucket_no_valid_revisit_count"] / total)
|
| 173 |
+
diagnostics["eval_bucket_corrupted_memory_fraction"] = float(diagnostics["eval_bucket_corrupted_memory_count"] / total)
|
| 174 |
+
return diagnostics
|
algorithms/worldmem/dememwm/gates.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class RevisitRawGate(nn.Module):
|
| 8 |
+
"""Learned per-target revisit gate from selected-revisit quality features.
|
| 9 |
+
|
| 10 |
+
The caller applies validity masking and stage/denoise scaling after this
|
| 11 |
+
module. This module never turns selected revisit validity into a target.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, init_logit: float = 1.0):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.net = nn.Linear(3, 1)
|
| 17 |
+
nn.init.zeros_(self.net.weight)
|
| 18 |
+
nn.init.constant_(self.net.bias, float(init_logit))
|
| 19 |
+
|
| 20 |
+
def forward(
|
| 21 |
+
self,
|
| 22 |
+
*,
|
| 23 |
+
valid_revisit_mask: torch.Tensor,
|
| 24 |
+
best_selected_fov_overlap: torch.Tensor | None = None,
|
| 25 |
+
best_selected_plucker_overlap: torch.Tensor | None = None,
|
| 26 |
+
selected_gap_frames: torch.Tensor | None = None,
|
| 27 |
+
) -> torch.Tensor:
|
| 28 |
+
if valid_revisit_mask.ndim != 2:
|
| 29 |
+
raise ValueError("valid_revisit_mask must have shape (B,T)")
|
| 30 |
+
device = valid_revisit_mask.device
|
| 31 |
+
dtype = torch.float32
|
| 32 |
+
shape = valid_revisit_mask.shape
|
| 33 |
+
|
| 34 |
+
def _feature(value: torch.Tensor | None) -> torch.Tensor:
|
| 35 |
+
if value is None:
|
| 36 |
+
return torch.zeros(shape, device=device, dtype=dtype)
|
| 37 |
+
tensor = value.to(device=device, dtype=dtype)
|
| 38 |
+
if tensor.ndim == 0:
|
| 39 |
+
return torch.full(shape, float(tensor.item()), device=device, dtype=dtype)
|
| 40 |
+
return tensor.expand(shape)
|
| 41 |
+
|
| 42 |
+
fov = _feature(best_selected_fov_overlap).clamp(min=0.0, max=1.0)
|
| 43 |
+
plucker = _feature(best_selected_plucker_overlap).clamp(min=0.0, max=1.0)
|
| 44 |
+
log_age = torch.log1p(_feature(selected_gap_frames).clamp_min(0.0)).clamp(max=8.0) / 8.0
|
| 45 |
+
features = torch.stack([fov, plucker, log_age], dim=-1)
|
| 46 |
+
return torch.sigmoid(self.net(features).squeeze(-1))
|
algorithms/worldmem/dememwm/injection.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from .diagnostics import summarize_stream
|
| 10 |
+
from .types import MemoryStreamTensors
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
@dataclass
|
| 14 |
+
class InjectionAdapter:
|
| 15 |
+
"""Convert DeMemWM stream tensors to Diffusion/DiT Option-C kwargs."""
|
| 16 |
+
|
| 17 |
+
omit_disabled: bool = False
|
| 18 |
+
|
| 19 |
+
def _tokens(self, name: str, tokens: torch.Tensor, device, dtype) -> torch.Tensor:
|
| 20 |
+
if tokens.ndim != 4:
|
| 21 |
+
raise ValueError(f"{name} tokens must have shape (B,T,M,D), got {tuple(tokens.shape)}")
|
| 22 |
+
return tokens.to(device=device, dtype=dtype)
|
| 23 |
+
|
| 24 |
+
def _mask(self, name: str, mask: torch.Tensor, tokens: torch.Tensor, device) -> torch.Tensor:
|
| 25 |
+
if mask.shape != tokens.shape[:3]:
|
| 26 |
+
raise ValueError(f"{name} mask must have shape {tuple(tokens.shape[:3])}, got {tuple(mask.shape)}")
|
| 27 |
+
return mask.to(device=device, dtype=torch.bool)
|
| 28 |
+
|
| 29 |
+
def _gate(self, gate: torch.Tensor | float | int, tokens: torch.Tensor, device, dtype):
|
| 30 |
+
if torch.is_tensor(gate):
|
| 31 |
+
return gate.to(device=device, dtype=dtype)
|
| 32 |
+
return torch.tensor(float(gate), device=device, dtype=dtype)
|
| 33 |
+
|
| 34 |
+
def __call__(self, streams: MemoryStreamTensors, device=None, dtype=None) -> tuple[dict[str, Any], dict[str, Any]]:
|
| 35 |
+
ref = streams.anchor_tokens
|
| 36 |
+
device = device or ref.device
|
| 37 |
+
dtype = dtype or ref.dtype
|
| 38 |
+
anchor_tokens = self._tokens("anchor", streams.anchor_tokens, device, dtype)
|
| 39 |
+
dynamic_tokens = self._tokens("dynamic", streams.dynamic_tokens, device, dtype)
|
| 40 |
+
revisit_tokens = self._tokens("revisit", streams.revisit_tokens, device, dtype)
|
| 41 |
+
anchor_mask = self._mask("anchor", streams.anchor_mask, anchor_tokens, device)
|
| 42 |
+
dynamic_mask = self._mask("dynamic", streams.dynamic_mask, dynamic_tokens, device)
|
| 43 |
+
revisit_mask = self._mask("revisit", streams.revisit_mask, revisit_tokens, device)
|
| 44 |
+
kwargs = {
|
| 45 |
+
"memory_tokens": anchor_tokens,
|
| 46 |
+
"memory_token_mask": anchor_mask,
|
| 47 |
+
"memory_dynamic_tokens": dynamic_tokens,
|
| 48 |
+
"memory_dynamic_mask": dynamic_mask,
|
| 49 |
+
"memory_retrieval_tokens": revisit_tokens,
|
| 50 |
+
"memory_retrieval_mask": revisit_mask,
|
| 51 |
+
"memory_anchor_gate": self._gate(streams.anchor_gate, anchor_tokens, device, dtype),
|
| 52 |
+
"memory_dynamic_gate": self._gate(streams.dynamic_gate, dynamic_tokens, device, dtype),
|
| 53 |
+
"memory_retrieval_gate": self._gate(streams.revisit_gate, revisit_tokens, device, dtype),
|
| 54 |
+
}
|
| 55 |
+
if self.omit_disabled:
|
| 56 |
+
if not anchor_mask.any():
|
| 57 |
+
kwargs["memory_tokens"] = None
|
| 58 |
+
kwargs["memory_token_mask"] = None
|
| 59 |
+
if not dynamic_mask.any():
|
| 60 |
+
kwargs["memory_dynamic_tokens"] = None
|
| 61 |
+
kwargs["memory_dynamic_mask"] = None
|
| 62 |
+
if not revisit_mask.any():
|
| 63 |
+
kwargs["memory_retrieval_tokens"] = None
|
| 64 |
+
kwargs["memory_retrieval_mask"] = None
|
| 65 |
+
diagnostics = dict(streams.diagnostics)
|
| 66 |
+
diagnostics.update(summarize_stream("anchor", anchor_tokens, anchor_mask, kwargs["memory_anchor_gate"]))
|
| 67 |
+
diagnostics.update(summarize_stream("dynamic", dynamic_tokens, dynamic_mask, kwargs["memory_dynamic_gate"]))
|
| 68 |
+
diagnostics.update(summarize_stream("revisit", revisit_tokens, revisit_mask, kwargs["memory_retrieval_gate"]))
|
| 69 |
+
if streams.revisit_gate_raw is not None:
|
| 70 |
+
raw_gate = streams.revisit_gate_raw.to(device=device, dtype=dtype)
|
| 71 |
+
diagnostics["revisit_gate_raw"] = raw_gate
|
| 72 |
+
diagnostics["revisit_gate_raw_mean"] = float(raw_gate.detach().float().mean().item()) if raw_gate.numel() else 0.0
|
| 73 |
+
diagnostics["revisit_gate_raw_min"] = float(raw_gate.detach().float().min().item()) if raw_gate.numel() else 0.0
|
| 74 |
+
diagnostics["revisit_gate_raw_max"] = float(raw_gate.detach().float().max().item()) if raw_gate.numel() else 0.0
|
| 75 |
+
if streams.no_valid_revisit_mask is not None:
|
| 76 |
+
diagnostics["no_valid_revisit_mask"] = streams.no_valid_revisit_mask.to(device=device, dtype=torch.bool)
|
| 77 |
+
max_sources = [v for k, v in streams.diagnostics.items() if k.endswith("max_source_frame")]
|
| 78 |
+
if max_sources:
|
| 79 |
+
diagnostics["max_source_frame"] = max(int(torch.as_tensor(v).max().item()) for v in max_sources)
|
| 80 |
+
return kwargs, diagnostics
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
DeMemWMInjectionAdapter = InjectionAdapter
|
algorithms/worldmem/dememwm/labels.py
ADDED
|
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any, Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from .types import MemoryRecord
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
LABEL_SOURCE = "deterministic_fov_coverage_plucker"
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass(frozen=True)
|
| 16 |
+
class RevisitCandidateLabel:
|
| 17 |
+
record: MemoryRecord
|
| 18 |
+
valid: bool
|
| 19 |
+
gap_valid: bool
|
| 20 |
+
gap_to_target: int
|
| 21 |
+
fov_overlap: Optional[float]
|
| 22 |
+
plucker_overlap: Optional[float]
|
| 23 |
+
primary_overlap: float
|
| 24 |
+
coverage_mask: Optional[torch.Tensor]
|
| 25 |
+
reject_reasons: tuple[str, ...]
|
| 26 |
+
best_frame_index: Optional[int] = None
|
| 27 |
+
best_frame_fov_overlap: Optional[float] = None
|
| 28 |
+
|
| 29 |
+
@property
|
| 30 |
+
def sort_key(self) -> tuple[float, float, float, int, int, str]:
|
| 31 |
+
fov = 0.0 if self.fov_overlap is None else float(self.fov_overlap)
|
| 32 |
+
plucker = 0.0 if self.plucker_overlap is None else float(self.plucker_overlap)
|
| 33 |
+
return (
|
| 34 |
+
-self.primary_overlap,
|
| 35 |
+
-fov,
|
| 36 |
+
-plucker,
|
| 37 |
+
self.gap_to_target,
|
| 38 |
+
int(self.record.source_start),
|
| 39 |
+
str(self.record.chunk_id or ""),
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _as_float_tensor(value) -> Optional[torch.Tensor]:
|
| 44 |
+
if value is None:
|
| 45 |
+
return None
|
| 46 |
+
if torch.is_tensor(value):
|
| 47 |
+
if value.numel() == 0:
|
| 48 |
+
return None
|
| 49 |
+
return value.detach().float()
|
| 50 |
+
tensor = torch.as_tensor(value, dtype=torch.float32)
|
| 51 |
+
if tensor.numel() == 0:
|
| 52 |
+
return None
|
| 53 |
+
return tensor
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _pose_frames(value) -> Optional[torch.Tensor]:
|
| 57 |
+
tensor = _as_float_tensor(value)
|
| 58 |
+
if tensor is None or tensor.shape[-1] < 5:
|
| 59 |
+
return None
|
| 60 |
+
return tensor.reshape(-1, tensor.shape[-1])[:, :5]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _angle_diff_degrees(value: torch.Tensor) -> torch.Tensor:
|
| 64 |
+
diff = value.abs() % 360.0
|
| 65 |
+
return torch.where(diff > 180.0, 360.0 - diff, diff)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _target_fov_points(
|
| 69 |
+
target_pose: torch.Tensor,
|
| 70 |
+
*,
|
| 71 |
+
fov_half_h: float,
|
| 72 |
+
fov_half_v: float,
|
| 73 |
+
yaw_samples: int,
|
| 74 |
+
pitch_samples: int,
|
| 75 |
+
depth_samples: int,
|
| 76 |
+
radius: float,
|
| 77 |
+
) -> torch.Tensor:
|
| 78 |
+
yaw_samples = max(1, int(yaw_samples))
|
| 79 |
+
pitch_samples = max(1, int(pitch_samples))
|
| 80 |
+
depth_samples = max(1, int(depth_samples))
|
| 81 |
+
device = target_pose.device
|
| 82 |
+
dtype = target_pose.dtype
|
| 83 |
+
if yaw_samples == 1:
|
| 84 |
+
yaw_offsets = torch.zeros((1,), device=device, dtype=dtype)
|
| 85 |
+
else:
|
| 86 |
+
yaw_offsets = torch.linspace(-float(fov_half_h), float(fov_half_h), yaw_samples + 2, device=device, dtype=dtype)[1:-1]
|
| 87 |
+
if pitch_samples == 1:
|
| 88 |
+
pitch_offsets = torch.zeros((1,), device=device, dtype=dtype)
|
| 89 |
+
else:
|
| 90 |
+
pitch_offsets = torch.linspace(-float(fov_half_v), float(fov_half_v), pitch_samples + 2, device=device, dtype=dtype)[1:-1]
|
| 91 |
+
if depth_samples == 1:
|
| 92 |
+
depths = torch.full((1,), float(radius), device=device, dtype=dtype)
|
| 93 |
+
else:
|
| 94 |
+
depths = torch.linspace(float(radius) / float(depth_samples), float(radius), depth_samples, device=device, dtype=dtype)
|
| 95 |
+
depth_grid, pitch_grid, yaw_grid = torch.meshgrid(depths, pitch_offsets, yaw_offsets, indexing="ij")
|
| 96 |
+
pitch = torch.deg2rad(target_pose[3] + pitch_grid.reshape(-1))
|
| 97 |
+
yaw = torch.deg2rad(target_pose[4] + yaw_grid.reshape(-1))
|
| 98 |
+
depth = depth_grid.reshape(-1)
|
| 99 |
+
cos_pitch = torch.cos(pitch)
|
| 100 |
+
vectors = torch.stack(
|
| 101 |
+
[
|
| 102 |
+
depth * cos_pitch * torch.sin(yaw),
|
| 103 |
+
depth * torch.sin(pitch),
|
| 104 |
+
depth * cos_pitch * torch.cos(yaw),
|
| 105 |
+
],
|
| 106 |
+
dim=-1,
|
| 107 |
+
)
|
| 108 |
+
return target_pose[:3].reshape(1, 3) + vectors
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def _inside_fov_3d_hv(
|
| 112 |
+
points: torch.Tensor,
|
| 113 |
+
poses: torch.Tensor,
|
| 114 |
+
*,
|
| 115 |
+
fov_half_h: float,
|
| 116 |
+
fov_half_v: float,
|
| 117 |
+
) -> torch.Tensor:
|
| 118 |
+
vectors = points.unsqueeze(0) - poses[:, None, :3]
|
| 119 |
+
x = vectors[..., 0]
|
| 120 |
+
y = vectors[..., 1]
|
| 121 |
+
z = vectors[..., 2]
|
| 122 |
+
azimuth = torch.atan2(x, z) * (180.0 / math.pi)
|
| 123 |
+
elevation = torch.atan2(y, torch.sqrt(x.square() + z.square()).clamp_min(1e-8)) * (180.0 / math.pi)
|
| 124 |
+
diff_azimuth = _angle_diff_degrees(azimuth - poses[:, None, 4])
|
| 125 |
+
diff_elevation = _angle_diff_degrees(elevation - poses[:, None, 3])
|
| 126 |
+
return (diff_azimuth < float(fov_half_h)) & (diff_elevation < float(fov_half_v))
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def fov_coverage_overlap(
|
| 130 |
+
source_pose,
|
| 131 |
+
target_pose,
|
| 132 |
+
*,
|
| 133 |
+
fov_half_h: float = 105.0 / 2.0,
|
| 134 |
+
fov_half_v: float = 75.0 / 2.0,
|
| 135 |
+
yaw_samples: int = 25,
|
| 136 |
+
pitch_samples: int = 20,
|
| 137 |
+
depth_samples: int = 20,
|
| 138 |
+
radius: float = 30.0,
|
| 139 |
+
) -> tuple[Optional[float], Optional[torch.Tensor]]:
|
| 140 |
+
source_poses = _pose_frames(source_pose)
|
| 141 |
+
target_poses = _pose_frames(target_pose)
|
| 142 |
+
if source_poses is None or target_poses is None:
|
| 143 |
+
return None, None
|
| 144 |
+
target = target_poses[-1].to(device=source_poses.device, dtype=source_poses.dtype)
|
| 145 |
+
points = _target_fov_points(
|
| 146 |
+
target,
|
| 147 |
+
fov_half_h=fov_half_h,
|
| 148 |
+
fov_half_v=fov_half_v,
|
| 149 |
+
yaw_samples=yaw_samples,
|
| 150 |
+
pitch_samples=pitch_samples,
|
| 151 |
+
depth_samples=depth_samples,
|
| 152 |
+
radius=radius,
|
| 153 |
+
)
|
| 154 |
+
coverage_mask = _inside_fov_3d_hv(points, source_poses, fov_half_h=fov_half_h, fov_half_v=fov_half_v).any(dim=0)
|
| 155 |
+
return float(coverage_mask.float().mean().item()), coverage_mask.detach()
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _rotation_from_pose(poses: torch.Tensor) -> torch.Tensor:
|
| 159 |
+
pitch = torch.deg2rad(poses[:, 3])
|
| 160 |
+
yaw = torch.deg2rad(poses[:, 4])
|
| 161 |
+
cos_pitch, sin_pitch = torch.cos(pitch), torch.sin(pitch)
|
| 162 |
+
cos_yaw, sin_yaw = torch.cos(yaw), torch.sin(yaw)
|
| 163 |
+
zeros = torch.zeros_like(pitch)
|
| 164 |
+
ones = torch.ones_like(pitch)
|
| 165 |
+
r_pitch = torch.stack(
|
| 166 |
+
[
|
| 167 |
+
ones, zeros, zeros,
|
| 168 |
+
zeros, cos_pitch, -sin_pitch,
|
| 169 |
+
zeros, sin_pitch, cos_pitch,
|
| 170 |
+
],
|
| 171 |
+
dim=-1,
|
| 172 |
+
).reshape(-1, 3, 3)
|
| 173 |
+
r_yaw = torch.stack(
|
| 174 |
+
[
|
| 175 |
+
cos_yaw, zeros, sin_yaw,
|
| 176 |
+
zeros, ones, zeros,
|
| 177 |
+
-sin_yaw, zeros, cos_yaw,
|
| 178 |
+
],
|
| 179 |
+
dim=-1,
|
| 180 |
+
).reshape(-1, 3, 3)
|
| 181 |
+
return torch.matmul(r_yaw, r_pitch)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _plucker_descriptor(
|
| 185 |
+
poses: torch.Tensor,
|
| 186 |
+
*,
|
| 187 |
+
grid_h: int,
|
| 188 |
+
grid_w: int,
|
| 189 |
+
focal_length: float,
|
| 190 |
+
) -> torch.Tensor:
|
| 191 |
+
grid_h = max(1, int(grid_h))
|
| 192 |
+
grid_w = max(1, int(grid_w))
|
| 193 |
+
poses = poses.float()
|
| 194 |
+
device = poses.device
|
| 195 |
+
dtype = torch.float32
|
| 196 |
+
ys, xs = torch.meshgrid(
|
| 197 |
+
torch.linspace(0, grid_h - 1, grid_h, device=device, dtype=dtype),
|
| 198 |
+
torch.linspace(0, grid_w - 1, grid_w, device=device, dtype=dtype),
|
| 199 |
+
indexing="ij",
|
| 200 |
+
)
|
| 201 |
+
fx = float(focal_length) * float(grid_w)
|
| 202 |
+
fy = float(focal_length) * float(grid_h)
|
| 203 |
+
cx = 0.5 * float(grid_w)
|
| 204 |
+
cy = 0.5 * float(grid_h)
|
| 205 |
+
zs = torch.ones_like(xs)
|
| 206 |
+
dirs = torch.stack([-(xs + 0.5 - cx) / fx, -(ys + 0.5 - cy) / fy, zs], dim=-1)
|
| 207 |
+
dirs = dirs.reshape(-1, 3)
|
| 208 |
+
dirs = dirs / dirs.norm(dim=-1, keepdim=True).clamp_min(1e-8)
|
| 209 |
+
rotation = _rotation_from_pose(poses)
|
| 210 |
+
rays_d = torch.matmul(dirs.unsqueeze(0), rotation.transpose(1, 2)).float()
|
| 211 |
+
rays_o = poses[:, None, :3].expand_as(rays_d).float()
|
| 212 |
+
moments = torch.linalg.cross(rays_o, rays_d, dim=-1)
|
| 213 |
+
return torch.cat([moments, rays_d], dim=-1).reshape(poses.shape[0], -1)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def plucker_overlap(
|
| 217 |
+
source_pose,
|
| 218 |
+
target_pose,
|
| 219 |
+
*,
|
| 220 |
+
grid_h: int = 4,
|
| 221 |
+
grid_w: int = 4,
|
| 222 |
+
focal_length: float = 0.35,
|
| 223 |
+
) -> Optional[float]:
|
| 224 |
+
source_poses = _pose_frames(source_pose)
|
| 225 |
+
target_poses = _pose_frames(target_pose)
|
| 226 |
+
if source_poses is None or target_poses is None:
|
| 227 |
+
return None
|
| 228 |
+
target = target_poses[-1:].to(device=source_poses.device, dtype=source_poses.dtype)
|
| 229 |
+
source_desc = _plucker_descriptor(source_poses, grid_h=grid_h, grid_w=grid_w, focal_length=focal_length)
|
| 230 |
+
target_desc = _plucker_descriptor(target, grid_h=grid_h, grid_w=grid_w, focal_length=focal_length)
|
| 231 |
+
diff = source_desc - target_desc
|
| 232 |
+
distance = torch.linalg.vector_norm(diff, dim=-1) / math.sqrt(float(diff.shape[-1]))
|
| 233 |
+
best_distance = float(distance.min().item())
|
| 234 |
+
return float(1.0 / (1.0 + max(best_distance, 0.0)))
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def _passes_threshold(overlap: Optional[float], threshold: Optional[float]) -> bool:
|
| 238 |
+
if threshold is None or overlap is None:
|
| 239 |
+
return True
|
| 240 |
+
return float(overlap) >= float(threshold)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def _batched_pose_overlaps(
|
| 244 |
+
records: list[MemoryRecord],
|
| 245 |
+
*,
|
| 246 |
+
target_pose=None,
|
| 247 |
+
fov_half_h: float = 105.0 / 2.0,
|
| 248 |
+
fov_half_v: float = 75.0 / 2.0,
|
| 249 |
+
fov_yaw_samples: int = 25,
|
| 250 |
+
fov_pitch_samples: int = 20,
|
| 251 |
+
fov_depth_samples: int = 20,
|
| 252 |
+
fov_radius: float = 30.0,
|
| 253 |
+
plucker_grid_h: int = 4,
|
| 254 |
+
plucker_grid_w: int = 4,
|
| 255 |
+
plucker_focal_length: float = 0.35,
|
| 256 |
+
) -> tuple[
|
| 257 |
+
list[Optional[float]],
|
| 258 |
+
list[Optional[float]],
|
| 259 |
+
list[Optional[torch.Tensor]],
|
| 260 |
+
list[Optional[int]],
|
| 261 |
+
list[Optional[float]],
|
| 262 |
+
]:
|
| 263 |
+
fov_overlaps: list[Optional[float]] = [None] * len(records)
|
| 264 |
+
plucker_overlaps: list[Optional[float]] = [None] * len(records)
|
| 265 |
+
coverage_masks: list[Optional[torch.Tensor]] = [None] * len(records)
|
| 266 |
+
best_frame_indices: list[Optional[int]] = [None] * len(records)
|
| 267 |
+
best_frame_fov_overlaps: list[Optional[float]] = [None] * len(records)
|
| 268 |
+
target_poses = _pose_frames(target_pose)
|
| 269 |
+
if not records or target_poses is None:
|
| 270 |
+
return fov_overlaps, plucker_overlaps, coverage_masks, best_frame_indices, best_frame_fov_overlaps
|
| 271 |
+
|
| 272 |
+
pose_records: list[tuple[int, torch.Tensor, torch.Tensor | None]] = []
|
| 273 |
+
device = None
|
| 274 |
+
for record_idx, record in enumerate(records):
|
| 275 |
+
source_poses = _pose_frames(record.pose)
|
| 276 |
+
if source_poses is None:
|
| 277 |
+
continue
|
| 278 |
+
if device is None:
|
| 279 |
+
device = source_poses.device
|
| 280 |
+
frame_values = record.frame_indices.detach().reshape(-1)
|
| 281 |
+
frame_values = frame_values if int(frame_values.numel()) == int(source_poses.shape[0]) else None
|
| 282 |
+
pose_records.append((record_idx, source_poses, frame_values))
|
| 283 |
+
if not pose_records or device is None:
|
| 284 |
+
return fov_overlaps, plucker_overlaps, coverage_masks, best_frame_indices, best_frame_fov_overlaps
|
| 285 |
+
|
| 286 |
+
source_pose_blocks = [poses.to(device=device, dtype=torch.float32) for _, poses, _ in pose_records]
|
| 287 |
+
source_poses = torch.cat(source_pose_blocks, dim=0)
|
| 288 |
+
record_ids = torch.cat(
|
| 289 |
+
[
|
| 290 |
+
torch.full((poses.shape[0],), int(record_idx), device=device, dtype=torch.long)
|
| 291 |
+
for (record_idx, _, _), poses in zip(pose_records, source_pose_blocks)
|
| 292 |
+
],
|
| 293 |
+
dim=0,
|
| 294 |
+
)
|
| 295 |
+
source_frame_blocks = [
|
| 296 |
+
torch.full((poses.shape[0],), -1, device=device, dtype=torch.long)
|
| 297 |
+
if frame_values is None
|
| 298 |
+
else frame_values.to(device=device, dtype=torch.long)
|
| 299 |
+
for (_, _, frame_values), poses in zip(pose_records, source_pose_blocks)
|
| 300 |
+
]
|
| 301 |
+
source_frame_values = torch.cat(source_frame_blocks, dim=0)
|
| 302 |
+
pose_record_indices = [record_idx for record_idx, _, _ in pose_records]
|
| 303 |
+
target = target_poses[-1].to(device=device, dtype=source_poses.dtype)
|
| 304 |
+
|
| 305 |
+
points = _target_fov_points(
|
| 306 |
+
target,
|
| 307 |
+
fov_half_h=fov_half_h,
|
| 308 |
+
fov_half_v=fov_half_v,
|
| 309 |
+
yaw_samples=fov_yaw_samples,
|
| 310 |
+
pitch_samples=fov_pitch_samples,
|
| 311 |
+
depth_samples=fov_depth_samples,
|
| 312 |
+
radius=fov_radius,
|
| 313 |
+
)
|
| 314 |
+
inside = _inside_fov_3d_hv(points, source_poses, fov_half_h=fov_half_h, fov_half_v=fov_half_v)
|
| 315 |
+
per_frame_fov = inside.float().mean(dim=1)
|
| 316 |
+
|
| 317 |
+
source_desc = _plucker_descriptor(source_poses, grid_h=plucker_grid_h, grid_w=plucker_grid_w, focal_length=plucker_focal_length)
|
| 318 |
+
target_desc = _plucker_descriptor(
|
| 319 |
+
target.reshape(1, -1),
|
| 320 |
+
grid_h=plucker_grid_h,
|
| 321 |
+
grid_w=plucker_grid_w,
|
| 322 |
+
focal_length=plucker_focal_length,
|
| 323 |
+
)
|
| 324 |
+
diff = source_desc - target_desc
|
| 325 |
+
distance = torch.linalg.vector_norm(diff, dim=-1) / math.sqrt(float(diff.shape[-1]))
|
| 326 |
+
best_distance = torch.full((len(records),), float("inf"), device=device, dtype=distance.dtype)
|
| 327 |
+
best_distance.scatter_reduce_(0, record_ids, distance, reduce="amin", include_self=True)
|
| 328 |
+
plucker_values = (1.0 / (1.0 + best_distance.clamp_min(0.0))).detach().cpu().tolist()
|
| 329 |
+
|
| 330 |
+
for record_idx in pose_record_indices:
|
| 331 |
+
rows = record_ids == int(record_idx)
|
| 332 |
+
if not rows.any():
|
| 333 |
+
continue
|
| 334 |
+
record_fov = per_frame_fov[rows]
|
| 335 |
+
record_inside = inside[rows]
|
| 336 |
+
best_pose_row = int(torch.argmax(record_fov).item())
|
| 337 |
+
best_score_value = float(record_fov[best_pose_row].item())
|
| 338 |
+
fov_overlaps[record_idx] = best_score_value
|
| 339 |
+
plucker_overlaps[record_idx] = float(plucker_values[record_idx])
|
| 340 |
+
coverage_masks[record_idx] = record_inside[best_pose_row].detach()
|
| 341 |
+
valid_rows = rows & (source_frame_values >= 0)
|
| 342 |
+
if valid_rows.any():
|
| 343 |
+
frame_scores = per_frame_fov[valid_rows].detach().cpu().tolist()
|
| 344 |
+
frame_values = source_frame_values[valid_rows].detach().cpu().tolist()
|
| 345 |
+
best_score, best_frame = max(
|
| 346 |
+
((float(score), int(frame)) for score, frame in zip(frame_scores, frame_values)),
|
| 347 |
+
key=lambda item: (item[0], item[1]),
|
| 348 |
+
)
|
| 349 |
+
best_frame_indices[record_idx] = int(best_frame)
|
| 350 |
+
best_frame_fov_overlaps[record_idx] = float(best_score)
|
| 351 |
+
return fov_overlaps, plucker_overlaps, coverage_masks, best_frame_indices, best_frame_fov_overlaps
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def make_revisit_candidate_labels(
|
| 355 |
+
records: list[MemoryRecord],
|
| 356 |
+
*,
|
| 357 |
+
target_frame: int,
|
| 358 |
+
exclude_local_context_frames: int,
|
| 359 |
+
target_pose=None,
|
| 360 |
+
fov_overlap_threshold: Optional[float] = 0.0,
|
| 361 |
+
plucker_weight: float = 0.1,
|
| 362 |
+
target_video_id: Any = None,
|
| 363 |
+
fov_half_h: float = 105.0 / 2.0,
|
| 364 |
+
fov_half_v: float = 75.0 / 2.0,
|
| 365 |
+
fov_yaw_samples: int = 25,
|
| 366 |
+
fov_pitch_samples: int = 20,
|
| 367 |
+
fov_depth_samples: int = 20,
|
| 368 |
+
fov_radius: float = 30.0,
|
| 369 |
+
plucker_grid_h: int = 4,
|
| 370 |
+
plucker_grid_w: int = 4,
|
| 371 |
+
plucker_focal_length: float = 0.35,
|
| 372 |
+
) -> list[RevisitCandidateLabel]:
|
| 373 |
+
del plucker_weight, target_video_id
|
| 374 |
+
target_frame = int(target_frame)
|
| 375 |
+
exclude_local_context_frames = int(exclude_local_context_frames)
|
| 376 |
+
|
| 377 |
+
score_indices: list[int] = []
|
| 378 |
+
max_source_frames: list[int] = []
|
| 379 |
+
gap_flags: list[bool] = []
|
| 380 |
+
for record_idx, record in enumerate(records):
|
| 381 |
+
max_source_frame = int(record.source_end) - 1
|
| 382 |
+
gap_ok = max_source_frame < target_frame - exclude_local_context_frames
|
| 383 |
+
max_source_frames.append(max_source_frame)
|
| 384 |
+
gap_flags.append(gap_ok)
|
| 385 |
+
if gap_ok:
|
| 386 |
+
score_indices.append(record_idx)
|
| 387 |
+
|
| 388 |
+
fov_overlaps: list[Optional[float]] = [None] * len(records)
|
| 389 |
+
plucker_overlaps: list[Optional[float]] = [None] * len(records)
|
| 390 |
+
coverage_masks: list[Optional[torch.Tensor]] = [None] * len(records)
|
| 391 |
+
best_frame_indices: list[Optional[int]] = [None] * len(records)
|
| 392 |
+
best_frame_fov_overlaps: list[Optional[float]] = [None] * len(records)
|
| 393 |
+
scored_records = [records[record_idx] for record_idx in score_indices]
|
| 394 |
+
scored_fov, scored_plucker, scored_masks, scored_best_frames, scored_best_frame_fov = _batched_pose_overlaps(
|
| 395 |
+
scored_records,
|
| 396 |
+
target_pose=target_pose,
|
| 397 |
+
fov_half_h=fov_half_h,
|
| 398 |
+
fov_half_v=fov_half_v,
|
| 399 |
+
fov_yaw_samples=fov_yaw_samples,
|
| 400 |
+
fov_pitch_samples=fov_pitch_samples,
|
| 401 |
+
fov_depth_samples=fov_depth_samples,
|
| 402 |
+
fov_radius=fov_radius,
|
| 403 |
+
plucker_grid_h=plucker_grid_h,
|
| 404 |
+
plucker_grid_w=plucker_grid_w,
|
| 405 |
+
plucker_focal_length=plucker_focal_length,
|
| 406 |
+
)
|
| 407 |
+
for scored_idx, record_idx in enumerate(score_indices):
|
| 408 |
+
fov_overlaps[record_idx] = scored_fov[scored_idx]
|
| 409 |
+
plucker_overlaps[record_idx] = scored_plucker[scored_idx]
|
| 410 |
+
coverage_masks[record_idx] = scored_masks[scored_idx]
|
| 411 |
+
best_frame_indices[record_idx] = scored_best_frames[scored_idx]
|
| 412 |
+
best_frame_fov_overlaps[record_idx] = scored_best_frame_fov[scored_idx]
|
| 413 |
+
|
| 414 |
+
labels: list[RevisitCandidateLabel] = []
|
| 415 |
+
for record_idx, record in enumerate(records):
|
| 416 |
+
gap = target_frame - max_source_frames[record_idx]
|
| 417 |
+
fov_overlap = fov_overlaps[record_idx]
|
| 418 |
+
reasons: list[str] = []
|
| 419 |
+
if not gap_flags[record_idx]:
|
| 420 |
+
reasons.append("inside_c_short")
|
| 421 |
+
if not _passes_threshold(fov_overlap, fov_overlap_threshold):
|
| 422 |
+
reasons.append("fov_overlap_below_threshold")
|
| 423 |
+
|
| 424 |
+
fov_score = 0.0 if fov_overlap is None else float(fov_overlap)
|
| 425 |
+
labels.append(
|
| 426 |
+
RevisitCandidateLabel(
|
| 427 |
+
record=record,
|
| 428 |
+
valid=not reasons,
|
| 429 |
+
gap_valid=gap_flags[record_idx],
|
| 430 |
+
gap_to_target=gap,
|
| 431 |
+
fov_overlap=fov_overlap,
|
| 432 |
+
plucker_overlap=plucker_overlaps[record_idx],
|
| 433 |
+
primary_overlap=fov_score,
|
| 434 |
+
coverage_mask=coverage_masks[record_idx],
|
| 435 |
+
reject_reasons=tuple(reasons),
|
| 436 |
+
best_frame_index=best_frame_indices[record_idx],
|
| 437 |
+
best_frame_fov_overlap=best_frame_fov_overlaps[record_idx],
|
| 438 |
+
)
|
| 439 |
+
)
|
| 440 |
+
return labels
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def make_revisit_candidate_label(
|
| 444 |
+
record: MemoryRecord,
|
| 445 |
+
*,
|
| 446 |
+
target_frame: int,
|
| 447 |
+
exclude_local_context_frames: int,
|
| 448 |
+
target_pose=None,
|
| 449 |
+
fov_overlap_threshold: Optional[float] = 0.0,
|
| 450 |
+
plucker_weight: float = 0.1,
|
| 451 |
+
target_video_id: Any = None,
|
| 452 |
+
fov_half_h: float = 105.0 / 2.0,
|
| 453 |
+
fov_half_v: float = 75.0 / 2.0,
|
| 454 |
+
fov_yaw_samples: int = 25,
|
| 455 |
+
fov_pitch_samples: int = 20,
|
| 456 |
+
fov_depth_samples: int = 20,
|
| 457 |
+
fov_radius: float = 30.0,
|
| 458 |
+
plucker_grid_h: int = 4,
|
| 459 |
+
plucker_grid_w: int = 4,
|
| 460 |
+
plucker_focal_length: float = 0.35,
|
| 461 |
+
) -> RevisitCandidateLabel:
|
| 462 |
+
return make_revisit_candidate_labels(
|
| 463 |
+
[record],
|
| 464 |
+
target_frame=target_frame,
|
| 465 |
+
exclude_local_context_frames=exclude_local_context_frames,
|
| 466 |
+
target_pose=target_pose,
|
| 467 |
+
fov_overlap_threshold=fov_overlap_threshold,
|
| 468 |
+
plucker_weight=plucker_weight,
|
| 469 |
+
target_video_id=target_video_id,
|
| 470 |
+
fov_half_h=fov_half_h,
|
| 471 |
+
fov_half_v=fov_half_v,
|
| 472 |
+
fov_yaw_samples=fov_yaw_samples,
|
| 473 |
+
fov_pitch_samples=fov_pitch_samples,
|
| 474 |
+
fov_depth_samples=fov_depth_samples,
|
| 475 |
+
fov_radius=fov_radius,
|
| 476 |
+
plucker_grid_h=plucker_grid_h,
|
| 477 |
+
plucker_grid_w=plucker_grid_w,
|
| 478 |
+
plucker_focal_length=plucker_focal_length,
|
| 479 |
+
)[0]
|
algorithms/worldmem/dememwm/memory.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Iterable, Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from .types import MemoryRecord, MemorySourceType
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class MemoryBankQuery:
|
| 14 |
+
target_frame: int
|
| 15 |
+
source_type: Optional[MemorySourceType] = None
|
| 16 |
+
include_generated: bool = True
|
| 17 |
+
max_records: Optional[int] = None
|
| 18 |
+
max_slots: Optional[int] = None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class CausalMemoryBank:
|
| 22 |
+
"""Small causal memory bank for DeMemWM records."""
|
| 23 |
+
|
| 24 |
+
def __init__(self, max_records: Optional[int] = None, max_slots: Optional[int] = None):
|
| 25 |
+
self.max_records = max_records
|
| 26 |
+
self.max_slots = max_slots
|
| 27 |
+
self._records: list[MemoryRecord] = []
|
| 28 |
+
|
| 29 |
+
def __len__(self) -> int:
|
| 30 |
+
return len(self._records)
|
| 31 |
+
|
| 32 |
+
@property
|
| 33 |
+
def records(self) -> tuple[MemoryRecord, ...]:
|
| 34 |
+
return tuple(self._records)
|
| 35 |
+
|
| 36 |
+
def add_record(self, record: MemoryRecord) -> None:
|
| 37 |
+
if record.source_type == MemorySourceType.PREFIX_GT and record.is_generated:
|
| 38 |
+
raise ValueError("generated records cannot be high-trust prefix anchors")
|
| 39 |
+
self._records.append(record)
|
| 40 |
+
if self.max_records is not None and len(self._records) > self.max_records:
|
| 41 |
+
self._records = self._records[-self.max_records:]
|
| 42 |
+
|
| 43 |
+
def add_prefix_anchors(
|
| 44 |
+
self,
|
| 45 |
+
tokens: torch.Tensor,
|
| 46 |
+
mask: torch.Tensor,
|
| 47 |
+
frame_indices: torch.Tensor,
|
| 48 |
+
pose: Optional[torch.Tensor] = None,
|
| 49 |
+
slots_per_anchor: Optional[int] = None,
|
| 50 |
+
) -> None:
|
| 51 |
+
if tokens.ndim == 2:
|
| 52 |
+
tokens = tokens.unsqueeze(0)
|
| 53 |
+
if mask.ndim == 1:
|
| 54 |
+
mask = mask.unsqueeze(0)
|
| 55 |
+
flat_frames = frame_indices.detach().reshape(-1)
|
| 56 |
+
if tokens.shape[0] != flat_frames.numel():
|
| 57 |
+
raise ValueError("tokens first dimension must match number of frame indices")
|
| 58 |
+
for i, frame in enumerate(flat_frames.tolist()):
|
| 59 |
+
rec_tokens = tokens[i]
|
| 60 |
+
rec_mask = mask[i].bool()
|
| 61 |
+
if slots_per_anchor is not None:
|
| 62 |
+
rec_tokens = rec_tokens[:slots_per_anchor]
|
| 63 |
+
rec_mask = rec_mask[:slots_per_anchor]
|
| 64 |
+
self.add_record(
|
| 65 |
+
MemoryRecord(
|
| 66 |
+
tokens=rec_tokens,
|
| 67 |
+
mask=rec_mask,
|
| 68 |
+
source_start=int(frame),
|
| 69 |
+
source_end=int(frame) + 1,
|
| 70 |
+
frame_indices=torch.as_tensor([frame], device=rec_tokens.device),
|
| 71 |
+
pose=None if pose is None else pose[i],
|
| 72 |
+
source_type=MemorySourceType.PREFIX_GT,
|
| 73 |
+
is_generated=False,
|
| 74 |
+
chunk_id=f"prefix_{int(frame)}",
|
| 75 |
+
)
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
def add_chunk_record(
|
| 79 |
+
self,
|
| 80 |
+
tokens: torch.Tensor,
|
| 81 |
+
mask: torch.Tensor,
|
| 82 |
+
frame_indices: torch.Tensor,
|
| 83 |
+
pose: Optional[torch.Tensor] = None,
|
| 84 |
+
source_type: MemorySourceType = MemorySourceType.PREFIX_GT,
|
| 85 |
+
is_generated: bool = False,
|
| 86 |
+
chunk_id: Optional[str] = None,
|
| 87 |
+
metadata: Optional[dict] = None,
|
| 88 |
+
) -> None:
|
| 89 |
+
flat_frames = frame_indices.detach().reshape(-1)
|
| 90 |
+
if flat_frames.numel() == 0:
|
| 91 |
+
raise ValueError("chunk frame_indices must be non-empty")
|
| 92 |
+
if tokens.ndim != 2:
|
| 93 |
+
raise ValueError("chunk tokens must have shape (M,D)")
|
| 94 |
+
if mask.ndim != 1 or mask.shape[0] != tokens.shape[0]:
|
| 95 |
+
raise ValueError("chunk mask must have shape (M,)")
|
| 96 |
+
start = int(flat_frames.min().item())
|
| 97 |
+
end = int(flat_frames.max().item()) + 1
|
| 98 |
+
self.add_record(
|
| 99 |
+
MemoryRecord(
|
| 100 |
+
tokens=tokens,
|
| 101 |
+
mask=mask.bool(),
|
| 102 |
+
source_start=start,
|
| 103 |
+
source_end=end,
|
| 104 |
+
frame_indices=flat_frames.to(device=tokens.device),
|
| 105 |
+
pose=pose,
|
| 106 |
+
source_type=source_type,
|
| 107 |
+
is_generated=bool(is_generated),
|
| 108 |
+
chunk_id=chunk_id or f"{source_type.value}_chunk_{start}_{end}",
|
| 109 |
+
metadata=dict(metadata or {}),
|
| 110 |
+
)
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
def add_frame_record(
|
| 114 |
+
self,
|
| 115 |
+
tokens: torch.Tensor,
|
| 116 |
+
mask: torch.Tensor,
|
| 117 |
+
frame_index: torch.Tensor | int,
|
| 118 |
+
pose: Optional[torch.Tensor] = None,
|
| 119 |
+
source_type: MemorySourceType = MemorySourceType.REVISIT,
|
| 120 |
+
is_generated: bool = False,
|
| 121 |
+
record_id: Optional[str] = None,
|
| 122 |
+
metadata: Optional[dict] = None,
|
| 123 |
+
) -> None:
|
| 124 |
+
frame_tensor = torch.as_tensor([int(torch.as_tensor(frame_index).reshape(-1)[0].item())], device=tokens.device)
|
| 125 |
+
frame = int(frame_tensor.item())
|
| 126 |
+
self.add_record(
|
| 127 |
+
MemoryRecord(
|
| 128 |
+
tokens=tokens,
|
| 129 |
+
mask=mask.bool(),
|
| 130 |
+
source_start=frame,
|
| 131 |
+
source_end=frame + 1,
|
| 132 |
+
frame_indices=frame_tensor,
|
| 133 |
+
pose=pose,
|
| 134 |
+
source_type=source_type,
|
| 135 |
+
is_generated=bool(is_generated),
|
| 136 |
+
chunk_id=record_id or f"{source_type.value}_frame_{frame}",
|
| 137 |
+
metadata=dict(metadata or {}),
|
| 138 |
+
)
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
def add_generated_records(
|
| 142 |
+
self,
|
| 143 |
+
tokens: torch.Tensor,
|
| 144 |
+
mask: torch.Tensor,
|
| 145 |
+
frame_indices: torch.Tensor,
|
| 146 |
+
pose: Optional[torch.Tensor] = None,
|
| 147 |
+
source_type: MemorySourceType = MemorySourceType.GENERATED,
|
| 148 |
+
) -> None:
|
| 149 |
+
if source_type == MemorySourceType.PREFIX_GT:
|
| 150 |
+
raise ValueError("generated frames cannot be added as PREFIX_GT anchors by default")
|
| 151 |
+
if tokens.ndim == 2:
|
| 152 |
+
tokens = tokens.unsqueeze(0)
|
| 153 |
+
if mask.ndim == 1:
|
| 154 |
+
mask = mask.unsqueeze(0)
|
| 155 |
+
flat_frames = frame_indices.detach().reshape(-1)
|
| 156 |
+
for i, frame in enumerate(flat_frames.tolist()):
|
| 157 |
+
self.add_record(
|
| 158 |
+
MemoryRecord(
|
| 159 |
+
tokens=tokens[i],
|
| 160 |
+
mask=mask[i].bool(),
|
| 161 |
+
source_start=int(frame),
|
| 162 |
+
source_end=int(frame) + 1,
|
| 163 |
+
frame_indices=torch.as_tensor([frame], device=tokens.device),
|
| 164 |
+
pose=None if pose is None else pose[i],
|
| 165 |
+
source_type=source_type,
|
| 166 |
+
is_generated=True,
|
| 167 |
+
chunk_id=f"generated_{int(frame)}",
|
| 168 |
+
)
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
def query(self, query: MemoryBankQuery | int, **kwargs) -> list[MemoryRecord]:
|
| 172 |
+
if isinstance(query, int):
|
| 173 |
+
query = MemoryBankQuery(target_frame=query, **kwargs)
|
| 174 |
+
out: list[MemoryRecord] = []
|
| 175 |
+
used_slots = 0
|
| 176 |
+
for record in self._records:
|
| 177 |
+
if int(record.source_end) > int(query.target_frame):
|
| 178 |
+
continue
|
| 179 |
+
if query.source_type is not None and record.source_type != query.source_type:
|
| 180 |
+
continue
|
| 181 |
+
if not query.include_generated and record.is_generated:
|
| 182 |
+
continue
|
| 183 |
+
if query.max_slots is not None and used_slots >= query.max_slots:
|
| 184 |
+
break
|
| 185 |
+
out.append(record)
|
| 186 |
+
if query.max_slots is not None:
|
| 187 |
+
used_slots += record.valid_slots
|
| 188 |
+
if query.max_records is not None and len(out) >= query.max_records:
|
| 189 |
+
break
|
| 190 |
+
if query.max_slots is not None and used_slots >= query.max_slots:
|
| 191 |
+
break
|
| 192 |
+
return out
|
| 193 |
+
|
| 194 |
+
def assert_causal(self, target_frame: int, records: Iterable[MemoryRecord]) -> None:
|
| 195 |
+
offenders = [r.chunk_id or f"[{r.source_start},{r.source_end})" for r in records if int(r.source_end) > int(target_frame)]
|
| 196 |
+
if offenders:
|
| 197 |
+
raise AssertionError(f"future/non-causal memory selected for target {target_frame}: {offenders}")
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def stack_record_tokens(records: list[MemoryRecord], max_slots: int | None = None):
|
| 201 |
+
if not records:
|
| 202 |
+
return None, None
|
| 203 |
+
tokens = torch.cat([r.tokens for r in records], dim=0)
|
| 204 |
+
mask = torch.cat([r.mask.bool() for r in records], dim=0)
|
| 205 |
+
if max_slots is not None:
|
| 206 |
+
tokens = tokens[:max_slots]
|
| 207 |
+
mask = mask[:max_slots]
|
| 208 |
+
return tokens, mask
|
algorithms/worldmem/dememwm/negatives.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from .schedules import EVAL_CORRUPTION_BRANCHES
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _deterministic_noise_like(tokens: torch.Tensor, seed: int) -> torch.Tensor:
|
| 9 |
+
flat = torch.arange(tokens.numel(), device=tokens.device, dtype=torch.float32).reshape(tokens.shape)
|
| 10 |
+
noise = torch.sin(flat + float(seed) * 0.137).to(dtype=tokens.dtype)
|
| 11 |
+
scale = tokens.detach().float().std().to(device=tokens.device).clamp_min(0.05).to(dtype=tokens.dtype)
|
| 12 |
+
return noise * scale
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def apply_revisit_eval_corruption(
|
| 16 |
+
*,
|
| 17 |
+
tokens: torch.Tensor,
|
| 18 |
+
mask: torch.Tensor,
|
| 19 |
+
branch: str,
|
| 20 |
+
target_frame: int,
|
| 21 |
+
) -> tuple[torch.Tensor, bool]:
|
| 22 |
+
if branch not in EVAL_CORRUPTION_BRANCHES or not mask.any():
|
| 23 |
+
return tokens, False
|
| 24 |
+
|
| 25 |
+
corrupted = tokens.clone()
|
| 26 |
+
if branch == "wrong_pose":
|
| 27 |
+
corrupted = -corrupted
|
| 28 |
+
elif branch == "time_shuffle":
|
| 29 |
+
corrupted = torch.flip(corrupted, dims=(0,))
|
| 30 |
+
elif branch == "source_matched_random":
|
| 31 |
+
corrupted = _deterministic_noise_like(corrupted, seed=int(target_frame))
|
| 32 |
+
elif branch == "local_context_overlap_fake_revisit":
|
| 33 |
+
corrupted = torch.roll(corrupted, shifts=1, dims=0)
|
| 34 |
+
elif branch == "pose_shuffle":
|
| 35 |
+
corrupted = torch.roll(corrupted, shifts=1, dims=-1)
|
| 36 |
+
elif branch == "wrong_video":
|
| 37 |
+
corrupted = corrupted.detach().mean().to(dtype=corrupted.dtype).expand_as(corrupted).clone()
|
| 38 |
+
else:
|
| 39 |
+
return tokens, False
|
| 40 |
+
|
| 41 |
+
return corrupted, True
|
algorithms/worldmem/dememwm/retrieval.py
ADDED
|
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
from dataclasses import replace
|
| 5 |
+
from typing import Any, Optional
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from .labels import (
|
| 10 |
+
LABEL_SOURCE,
|
| 11 |
+
RevisitCandidateLabel,
|
| 12 |
+
_inside_fov_3d_hv,
|
| 13 |
+
_plucker_descriptor,
|
| 14 |
+
_target_fov_points,
|
| 15 |
+
)
|
| 16 |
+
from .types import MemoryRecord, RevisitRetrievalResult
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _overlap_values(labels, name: str) -> list[float]:
|
| 20 |
+
values: list[float] = []
|
| 21 |
+
for label in labels:
|
| 22 |
+
value = getattr(label, name)
|
| 23 |
+
if value is not None:
|
| 24 |
+
values.append(float(value))
|
| 25 |
+
return values
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _overlap_stats(values: list[float], prefix: str) -> dict[str, float]:
|
| 29 |
+
if not values:
|
| 30 |
+
return {f"{prefix}_mean": 0.0, f"{prefix}_min": 0.0, f"{prefix}_max": 0.0}
|
| 31 |
+
return {
|
| 32 |
+
f"{prefix}_mean": float(sum(values) / len(values)),
|
| 33 |
+
f"{prefix}_min": float(min(values)),
|
| 34 |
+
f"{prefix}_max": float(max(values)),
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _pose_rows(pose) -> torch.Tensor | None:
|
| 39 |
+
if pose is None:
|
| 40 |
+
return None
|
| 41 |
+
pose_tensor = pose if torch.is_tensor(pose) else torch.as_tensor(pose, dtype=torch.float32)
|
| 42 |
+
if pose_tensor.ndim == 0 or pose_tensor.numel() == 0 or pose_tensor.shape[-1] < 5:
|
| 43 |
+
return None
|
| 44 |
+
return pose_tensor.detach().reshape(-1, pose_tensor.shape[-1])[:, :5].to(dtype=torch.float32)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _pose_forward(poses: torch.Tensor) -> torch.Tensor:
|
| 48 |
+
pitch = torch.deg2rad(poses[:, 3])
|
| 49 |
+
yaw = torch.deg2rad(poses[:, 4])
|
| 50 |
+
cos_pitch = torch.cos(pitch)
|
| 51 |
+
return torch.stack(
|
| 52 |
+
[
|
| 53 |
+
cos_pitch * torch.sin(yaw),
|
| 54 |
+
torch.sin(pitch),
|
| 55 |
+
cos_pitch * torch.cos(yaw),
|
| 56 |
+
],
|
| 57 |
+
dim=-1,
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _single_frame_pose(record: MemoryRecord) -> torch.Tensor | None:
|
| 62 |
+
if int(record.frame_indices.numel()) != 1:
|
| 63 |
+
return None
|
| 64 |
+
pose_rows = _pose_rows(record.pose)
|
| 65 |
+
if pose_rows is None or int(pose_rows.shape[0]) != 1:
|
| 66 |
+
return None
|
| 67 |
+
return pose_rows[0]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _vectorized_frame_candidate_labels(
|
| 71 |
+
records: list[MemoryRecord],
|
| 72 |
+
*,
|
| 73 |
+
target_frame: int,
|
| 74 |
+
target_pose,
|
| 75 |
+
fov_overlap_threshold: Optional[float],
|
| 76 |
+
fov_half_h: float,
|
| 77 |
+
fov_half_v: float,
|
| 78 |
+
fov_yaw_samples: int,
|
| 79 |
+
fov_pitch_samples: int,
|
| 80 |
+
fov_depth_samples: int,
|
| 81 |
+
fov_radius: float,
|
| 82 |
+
plucker_grid_h: int,
|
| 83 |
+
plucker_grid_w: int,
|
| 84 |
+
plucker_focal_length: float,
|
| 85 |
+
pose_preselect_topk: Optional[int],
|
| 86 |
+
) -> tuple[list[RevisitCandidateLabel], dict[str, float | int]]:
|
| 87 |
+
diagnostics: dict[str, float | int] = {
|
| 88 |
+
"revisit_pose_preselect_input_count": len(records),
|
| 89 |
+
"revisit_pose_preselect_scored_count": len(records),
|
| 90 |
+
"revisit_pose_preselect_unscored_count": 0,
|
| 91 |
+
"revisit_pose_preselect_selected_count": len(records),
|
| 92 |
+
"revisit_pose_preselect_min_distance": 0.0,
|
| 93 |
+
"revisit_pose_preselect_max_distance": 0.0,
|
| 94 |
+
"revisit_exact_fov_candidate_count": len(records),
|
| 95 |
+
"revisit_vectorized_frame_scorer_used": 1,
|
| 96 |
+
}
|
| 97 |
+
if not records:
|
| 98 |
+
return [], diagnostics
|
| 99 |
+
|
| 100 |
+
target_poses = _pose_rows(target_pose)
|
| 101 |
+
if target_poses is None:
|
| 102 |
+
raise ValueError("DeMemWM revisit retrieval requires target_pose for frame-level FoV scoring")
|
| 103 |
+
|
| 104 |
+
pose_rows: list[torch.Tensor] = []
|
| 105 |
+
for record in records:
|
| 106 |
+
pose_row = _single_frame_pose(record)
|
| 107 |
+
if pose_row is None:
|
| 108 |
+
raise ValueError(
|
| 109 |
+
"DeMemWM revisit retrieval requires frame-level records with exactly one frame index and one pose row"
|
| 110 |
+
)
|
| 111 |
+
pose_rows.append(pose_row)
|
| 112 |
+
|
| 113 |
+
device = pose_rows[0].device
|
| 114 |
+
if target_poses.is_cuda:
|
| 115 |
+
device = target_poses.device
|
| 116 |
+
source_poses = torch.stack([row.to(device=device, dtype=torch.float32) for row in pose_rows], dim=0)
|
| 117 |
+
target = target_poses[-1].to(device=device, dtype=torch.float32)
|
| 118 |
+
|
| 119 |
+
selected_indices = list(range(len(records)))
|
| 120 |
+
topk = None if pose_preselect_topk is None else int(pose_preselect_topk)
|
| 121 |
+
if topk is not None and topk > 0 and len(records) > topk:
|
| 122 |
+
translation_norm = torch.linalg.vector_norm(source_poses[:, :3] - target[:3], dim=-1) / max(float(fov_radius), 1e-6)
|
| 123 |
+
source_forward = _pose_forward(source_poses)
|
| 124 |
+
target_forward = _pose_forward(target.reshape(1, -1)).squeeze(0)
|
| 125 |
+
dot = (source_forward * target_forward.reshape(1, 3)).sum(dim=-1).clamp(-1.0, 1.0)
|
| 126 |
+
distances = translation_norm + (torch.acos(dot) / math.pi)
|
| 127 |
+
distance_values = [float(value) for value in distances.detach().cpu().tolist()]
|
| 128 |
+
ranked = [
|
| 129 |
+
(
|
| 130 |
+
distance_values[idx],
|
| 131 |
+
-int(record.max_source_frame),
|
| 132 |
+
int(record.source_start),
|
| 133 |
+
str(record.chunk_id or ""),
|
| 134 |
+
idx,
|
| 135 |
+
)
|
| 136 |
+
for idx, record in enumerate(records)
|
| 137 |
+
]
|
| 138 |
+
ranked.sort()
|
| 139 |
+
selected_indices = [idx for *_, idx in ranked[:topk]]
|
| 140 |
+
diagnostics["revisit_pose_preselect_selected_count"] = len(selected_indices)
|
| 141 |
+
diagnostics["revisit_pose_preselect_min_distance"] = float(min(distance_values))
|
| 142 |
+
diagnostics["revisit_pose_preselect_max_distance"] = float(max(distance_values))
|
| 143 |
+
|
| 144 |
+
selected_tensor = torch.tensor(selected_indices, device=device, dtype=torch.long)
|
| 145 |
+
selected_records = [records[idx] for idx in selected_indices]
|
| 146 |
+
selected_poses = source_poses.index_select(0, selected_tensor)
|
| 147 |
+
points = _target_fov_points(
|
| 148 |
+
target,
|
| 149 |
+
fov_half_h=fov_half_h,
|
| 150 |
+
fov_half_v=fov_half_v,
|
| 151 |
+
yaw_samples=fov_yaw_samples,
|
| 152 |
+
pitch_samples=fov_pitch_samples,
|
| 153 |
+
depth_samples=fov_depth_samples,
|
| 154 |
+
radius=fov_radius,
|
| 155 |
+
)
|
| 156 |
+
inside = _inside_fov_3d_hv(points, selected_poses, fov_half_h=fov_half_h, fov_half_v=fov_half_v)
|
| 157 |
+
fov_values = inside.float().mean(dim=1)
|
| 158 |
+
|
| 159 |
+
source_desc = _plucker_descriptor(
|
| 160 |
+
selected_poses,
|
| 161 |
+
grid_h=plucker_grid_h,
|
| 162 |
+
grid_w=plucker_grid_w,
|
| 163 |
+
focal_length=plucker_focal_length,
|
| 164 |
+
)
|
| 165 |
+
target_desc = _plucker_descriptor(
|
| 166 |
+
target.reshape(1, -1),
|
| 167 |
+
grid_h=plucker_grid_h,
|
| 168 |
+
grid_w=plucker_grid_w,
|
| 169 |
+
focal_length=plucker_focal_length,
|
| 170 |
+
)
|
| 171 |
+
diff = source_desc - target_desc
|
| 172 |
+
distances = torch.linalg.vector_norm(diff, dim=-1) / math.sqrt(float(diff.shape[-1]))
|
| 173 |
+
plucker_values = 1.0 / (1.0 + distances.clamp_min(0.0))
|
| 174 |
+
valid_mask = torch.ones_like(fov_values, dtype=torch.bool)
|
| 175 |
+
if fov_overlap_threshold is not None:
|
| 176 |
+
valid_mask = fov_values >= float(fov_overlap_threshold)
|
| 177 |
+
|
| 178 |
+
diagnostics["revisit_exact_fov_candidate_count"] = len(selected_records)
|
| 179 |
+
fov_list = [float(value) for value in fov_values.detach().cpu().tolist()]
|
| 180 |
+
plucker_list = [float(value) for value in plucker_values.detach().cpu().tolist()]
|
| 181 |
+
valid_list = [bool(value) for value in valid_mask.detach().cpu().tolist()]
|
| 182 |
+
|
| 183 |
+
labels: list[RevisitCandidateLabel] = []
|
| 184 |
+
for row_idx, record in enumerate(selected_records):
|
| 185 |
+
fov_overlap = fov_list[row_idx]
|
| 186 |
+
reasons = () if valid_list[row_idx] else ("fov_overlap_below_threshold",)
|
| 187 |
+
gap_to_target = int(target_frame) - (int(record.source_end) - 1)
|
| 188 |
+
labels.append(
|
| 189 |
+
RevisitCandidateLabel(
|
| 190 |
+
record=record,
|
| 191 |
+
valid=valid_list[row_idx],
|
| 192 |
+
gap_valid=True,
|
| 193 |
+
gap_to_target=gap_to_target,
|
| 194 |
+
fov_overlap=fov_overlap,
|
| 195 |
+
plucker_overlap=plucker_list[row_idx],
|
| 196 |
+
primary_overlap=fov_overlap,
|
| 197 |
+
coverage_mask=inside[row_idx].detach(),
|
| 198 |
+
reject_reasons=reasons,
|
| 199 |
+
best_frame_index=int(record.max_source_frame),
|
| 200 |
+
best_frame_fov_overlap=fov_overlap,
|
| 201 |
+
)
|
| 202 |
+
)
|
| 203 |
+
return labels, diagnostics
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def _coverage_gain(label: RevisitCandidateLabel, covered_mask: torch.Tensor | None) -> float:
|
| 207 |
+
mask = label.coverage_mask
|
| 208 |
+
if mask is None or mask.numel() == 0:
|
| 209 |
+
return 0.0 if label.fov_overlap is None else float(label.fov_overlap)
|
| 210 |
+
mask = mask.detach().bool()
|
| 211 |
+
if covered_mask is None or covered_mask.shape != mask.shape:
|
| 212 |
+
return float(mask.float().mean().item())
|
| 213 |
+
return float((mask & ~covered_mask.to(device=mask.device, dtype=torch.bool)).float().mean().item())
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def _coverage_gains(labels: list[RevisitCandidateLabel], covered_mask: torch.Tensor | None) -> list[float]:
|
| 217 |
+
masks = [label.coverage_mask for label in labels]
|
| 218 |
+
valid_masks = [mask for mask in masks if mask is not None and mask.numel() > 0]
|
| 219 |
+
if not valid_masks:
|
| 220 |
+
return [0.0 if label.fov_overlap is None else float(label.fov_overlap) for label in labels]
|
| 221 |
+
|
| 222 |
+
shape = valid_masks[0].shape
|
| 223 |
+
device = valid_masks[0].device
|
| 224 |
+
if any(mask.shape != shape for mask in valid_masks):
|
| 225 |
+
return [_coverage_gain(label, covered_mask) for label in labels]
|
| 226 |
+
|
| 227 |
+
stacked = torch.stack([
|
| 228 |
+
torch.zeros(shape, device=device, dtype=torch.bool)
|
| 229 |
+
if mask is None or mask.numel() == 0
|
| 230 |
+
else mask.detach().to(device=device, dtype=torch.bool)
|
| 231 |
+
for mask in masks
|
| 232 |
+
])
|
| 233 |
+
if covered_mask is None or covered_mask.shape != shape:
|
| 234 |
+
gains = stacked.float().mean(dim=1)
|
| 235 |
+
else:
|
| 236 |
+
covered = covered_mask.to(device=device, dtype=torch.bool)
|
| 237 |
+
gains = (stacked & ~covered).float().mean(dim=1)
|
| 238 |
+
return [float(value) for value in gains.detach().cpu().tolist()]
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def _select_greedy_coverage(
|
| 242 |
+
labels: list[RevisitCandidateLabel],
|
| 243 |
+
*,
|
| 244 |
+
topk: int,
|
| 245 |
+
plucker_weight: float,
|
| 246 |
+
) -> tuple[list[RevisitCandidateLabel], list[float], list[float]]:
|
| 247 |
+
remaining = list(labels)
|
| 248 |
+
selected: list[RevisitCandidateLabel] = []
|
| 249 |
+
selected_scores: list[float] = []
|
| 250 |
+
selected_gains: list[float] = []
|
| 251 |
+
covered_mask: torch.Tensor | None = None
|
| 252 |
+
for _ in range(max(0, int(topk))):
|
| 253 |
+
if not remaining:
|
| 254 |
+
break
|
| 255 |
+
gains = _coverage_gains(remaining, covered_mask)
|
| 256 |
+
ranked = []
|
| 257 |
+
for idx, (label, gain) in enumerate(zip(remaining, gains)):
|
| 258 |
+
plucker = 0.0 if label.plucker_overlap is None else float(label.plucker_overlap)
|
| 259 |
+
fov = 0.0 if label.fov_overlap is None else float(label.fov_overlap)
|
| 260 |
+
plucker_secondary = float(plucker_weight) * plucker
|
| 261 |
+
ranked.append((
|
| 262 |
+
-gain,
|
| 263 |
+
-fov,
|
| 264 |
+
-plucker_secondary,
|
| 265 |
+
label.gap_to_target,
|
| 266 |
+
int(label.record.source_start),
|
| 267 |
+
str(label.record.chunk_id or ""),
|
| 268 |
+
idx,
|
| 269 |
+
gain,
|
| 270 |
+
))
|
| 271 |
+
ranked.sort()
|
| 272 |
+
_, _, _, _, _, _, best_idx, best_gain = ranked[0]
|
| 273 |
+
label = remaining.pop(best_idx)
|
| 274 |
+
selected.append(label)
|
| 275 |
+
selected_scores.append(float(best_gain))
|
| 276 |
+
selected_gains.append(float(best_gain))
|
| 277 |
+
if label.coverage_mask is not None and label.coverage_mask.numel() > 0:
|
| 278 |
+
mask = label.coverage_mask.detach().bool()
|
| 279 |
+
if covered_mask is None or covered_mask.shape != mask.shape:
|
| 280 |
+
covered_mask = torch.zeros_like(mask, dtype=torch.bool)
|
| 281 |
+
covered_mask = covered_mask.to(device=mask.device, dtype=torch.bool) | mask
|
| 282 |
+
return selected, selected_scores, selected_gains
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def _best_selected_label(labels: list[RevisitCandidateLabel]) -> RevisitCandidateLabel | None:
|
| 286 |
+
if not labels:
|
| 287 |
+
return None
|
| 288 |
+
return max(
|
| 289 |
+
labels,
|
| 290 |
+
key=lambda label: (
|
| 291 |
+
0.0 if label.fov_overlap is None else float(label.fov_overlap),
|
| 292 |
+
0.0 if label.plucker_overlap is None else float(label.plucker_overlap),
|
| 293 |
+
-int(label.gap_to_target),
|
| 294 |
+
-int(label.record.source_start),
|
| 295 |
+
str(label.record.chunk_id or ""),
|
| 296 |
+
),
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def _best_selected_frame_label(labels: list[RevisitCandidateLabel]) -> RevisitCandidateLabel | None:
|
| 301 |
+
frame_labels = [label for label in labels if label.best_frame_fov_overlap is not None]
|
| 302 |
+
if not frame_labels:
|
| 303 |
+
return None
|
| 304 |
+
return max(
|
| 305 |
+
frame_labels,
|
| 306 |
+
key=lambda label: (
|
| 307 |
+
float(label.best_frame_fov_overlap),
|
| 308 |
+
0.0 if label.fov_overlap is None else float(label.fov_overlap),
|
| 309 |
+
0.0 if label.plucker_overlap is None else float(label.plucker_overlap),
|
| 310 |
+
-int(label.gap_to_target),
|
| 311 |
+
-int(label.record.source_start),
|
| 312 |
+
str(label.record.chunk_id or ""),
|
| 313 |
+
),
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def _record_with_selected_frame_metadata(
|
| 318 |
+
label: RevisitCandidateLabel,
|
| 319 |
+
*,
|
| 320 |
+
high_quality_fov_threshold: float,
|
| 321 |
+
) -> MemoryRecord:
|
| 322 |
+
metadata = dict(label.record.metadata or {})
|
| 323 |
+
if label.fov_overlap is not None:
|
| 324 |
+
metadata["dememwm_selected_revisit_fov_overlap"] = float(label.fov_overlap)
|
| 325 |
+
if label.plucker_overlap is not None:
|
| 326 |
+
metadata["dememwm_selected_revisit_plucker_overlap"] = float(label.plucker_overlap)
|
| 327 |
+
if label.best_frame_index is not None:
|
| 328 |
+
metadata["dememwm_selected_frame_index"] = int(label.best_frame_index)
|
| 329 |
+
if label.best_frame_fov_overlap is not None:
|
| 330 |
+
frame_fov = float(label.best_frame_fov_overlap)
|
| 331 |
+
metadata["dememwm_selected_frame_fov_overlap"] = frame_fov
|
| 332 |
+
metadata["dememwm_selected_frame_fov_threshold"] = float(high_quality_fov_threshold)
|
| 333 |
+
metadata["dememwm_selected_frame_passes_high_quality"] = bool(frame_fov >= float(high_quality_fov_threshold))
|
| 334 |
+
return replace(label.record, metadata=metadata)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def deterministic_revisit_retrieval(
|
| 338 |
+
records: list[MemoryRecord],
|
| 339 |
+
target_frame: int,
|
| 340 |
+
target_pose: Optional[torch.Tensor] = None,
|
| 341 |
+
target_summary: Optional[torch.Tensor] = None,
|
| 342 |
+
topk: int = 2,
|
| 343 |
+
exclude_local_context_frames: int = 0,
|
| 344 |
+
fov_overlap_threshold: Optional[float] = 0.30,
|
| 345 |
+
high_quality_fov_threshold: float = 0.70,
|
| 346 |
+
plucker_weight: float = 0.1,
|
| 347 |
+
target_video_id: Any = None,
|
| 348 |
+
fov_half_h: float = 105.0 / 2.0,
|
| 349 |
+
fov_half_v: float = 75.0 / 2.0,
|
| 350 |
+
fov_yaw_samples: int = 25,
|
| 351 |
+
fov_pitch_samples: int = 20,
|
| 352 |
+
fov_depth_samples: int = 20,
|
| 353 |
+
fov_radius: float = 30.0,
|
| 354 |
+
plucker_grid_h: int = 4,
|
| 355 |
+
plucker_grid_w: int = 4,
|
| 356 |
+
plucker_focal_length: float = 0.35,
|
| 357 |
+
pose_preselect_topk: Optional[int] = 64,
|
| 358 |
+
**_legacy_scoring_kwargs,
|
| 359 |
+
) -> RevisitRetrievalResult:
|
| 360 |
+
del target_summary, target_video_id
|
| 361 |
+
topk = max(0, int(topk))
|
| 362 |
+
target_frame = int(target_frame)
|
| 363 |
+
exclude_local_context_frames = int(exclude_local_context_frames)
|
| 364 |
+
causal_records = [record for record in records if int(record.source_end) <= target_frame]
|
| 365 |
+
score_records = [
|
| 366 |
+
record
|
| 367 |
+
for record in causal_records
|
| 368 |
+
if int(record.source_end) <= target_frame - exclude_local_context_frames
|
| 369 |
+
]
|
| 370 |
+
labels, pose_preselect_diagnostics = _vectorized_frame_candidate_labels(
|
| 371 |
+
score_records,
|
| 372 |
+
target_frame=target_frame,
|
| 373 |
+
target_pose=target_pose,
|
| 374 |
+
fov_overlap_threshold=fov_overlap_threshold,
|
| 375 |
+
fov_half_h=fov_half_h,
|
| 376 |
+
fov_half_v=fov_half_v,
|
| 377 |
+
fov_yaw_samples=fov_yaw_samples,
|
| 378 |
+
fov_pitch_samples=fov_pitch_samples,
|
| 379 |
+
fov_depth_samples=fov_depth_samples,
|
| 380 |
+
fov_radius=fov_radius,
|
| 381 |
+
plucker_grid_h=plucker_grid_h,
|
| 382 |
+
plucker_grid_w=plucker_grid_w,
|
| 383 |
+
plucker_focal_length=plucker_focal_length,
|
| 384 |
+
pose_preselect_topk=pose_preselect_topk,
|
| 385 |
+
)
|
| 386 |
+
exact_fov_candidate_count = int(pose_preselect_diagnostics["revisit_exact_fov_candidate_count"])
|
| 387 |
+
valid_labels = [label for label in labels if label.valid]
|
| 388 |
+
selected_labels, selected_scores, selected_gains = _select_greedy_coverage(
|
| 389 |
+
valid_labels,
|
| 390 |
+
topk=topk,
|
| 391 |
+
plucker_weight=float(plucker_weight),
|
| 392 |
+
)
|
| 393 |
+
best_selected = _best_selected_label(selected_labels)
|
| 394 |
+
best_selected_frame = _best_selected_frame_label(selected_labels)
|
| 395 |
+
best_selected_fov = 0.0 if best_selected is None or best_selected.fov_overlap is None else float(best_selected.fov_overlap)
|
| 396 |
+
best_selected_plucker = 0.0 if best_selected is None or best_selected.plucker_overlap is None else float(best_selected.plucker_overlap)
|
| 397 |
+
best_selected_gap = -1 if best_selected is None else int(best_selected.gap_to_target)
|
| 398 |
+
best_selected_frame_fov = 0.0 if best_selected_frame is None else float(best_selected_frame.best_frame_fov_overlap)
|
| 399 |
+
best_selected_frame_index = -1 if best_selected_frame is None or best_selected_frame.best_frame_index is None else int(best_selected_frame.best_frame_index)
|
| 400 |
+
high_quality_selected = int(best_selected_frame is not None and best_selected_frame_fov >= float(high_quality_fov_threshold))
|
| 401 |
+
selected_records = [
|
| 402 |
+
_record_with_selected_frame_metadata(label, high_quality_fov_threshold=float(high_quality_fov_threshold))
|
| 403 |
+
for label in selected_labels
|
| 404 |
+
]
|
| 405 |
+
score_device = selected_records[0].tokens.device if selected_records else torch.device("cpu")
|
| 406 |
+
scores = torch.tensor(selected_scores, dtype=torch.float32, device=score_device)
|
| 407 |
+
|
| 408 |
+
fov_values = _overlap_values(valid_labels, "fov_overlap")
|
| 409 |
+
plucker_values = _overlap_values(valid_labels, "plucker_overlap")
|
| 410 |
+
selected_gaps = [label.gap_to_target for label in selected_labels]
|
| 411 |
+
selected_frame_fov_values = [
|
| 412 |
+
float(label.best_frame_fov_overlap)
|
| 413 |
+
for label in selected_labels
|
| 414 |
+
if label.best_frame_fov_overlap is not None
|
| 415 |
+
]
|
| 416 |
+
diagnostics = {
|
| 417 |
+
"target_frame": int(target_frame),
|
| 418 |
+
"candidate_count": len(causal_records),
|
| 419 |
+
"candidate_frame_count": len(causal_records),
|
| 420 |
+
"valid_candidate_count": len(valid_labels),
|
| 421 |
+
"revisit_exact_fov_candidate_count": exact_fov_candidate_count,
|
| 422 |
+
"valid_candidate_frame_count": len(valid_labels),
|
| 423 |
+
"valid_candidate_label_count": len(valid_labels),
|
| 424 |
+
"selected_count": len(selected_records),
|
| 425 |
+
"selected_frame_count": len(selected_records),
|
| 426 |
+
"revisit_candidate_frame_count": len(causal_records),
|
| 427 |
+
"revisit_candidate_count": len(causal_records),
|
| 428 |
+
"valid_revisit_frame_count": len(valid_labels),
|
| 429 |
+
"valid_revisit_count": len(valid_labels),
|
| 430 |
+
"valid_revisit_target_count": high_quality_selected,
|
| 431 |
+
"no_valid_revisit_count": int(len(valid_labels) == 0),
|
| 432 |
+
"valid_revisit_mask": int(len(valid_labels) > 0),
|
| 433 |
+
"revisit_abstained_count": int(len(selected_records) == 0),
|
| 434 |
+
"abstained": bool(len(selected_records) == 0),
|
| 435 |
+
"revisit_selected_frame_count": len(selected_records),
|
| 436 |
+
"revisit_selected_count": len(selected_records),
|
| 437 |
+
"revisit_min_gap_to_target": int(min(selected_gaps)) if selected_gaps else -1,
|
| 438 |
+
"best_selected_fov_overlap": best_selected_fov,
|
| 439 |
+
"best_selected_plucker_overlap": best_selected_plucker,
|
| 440 |
+
"best_selected_gap_frames": best_selected_gap,
|
| 441 |
+
"best_selected_frame_index": best_selected_frame_index,
|
| 442 |
+
"best_selected_frame_fov_overlap": best_selected_frame_fov,
|
| 443 |
+
"best_selected_frame_passes_high_quality": high_quality_selected,
|
| 444 |
+
"high_quality_selected_revisit": high_quality_selected,
|
| 445 |
+
"high_quality_fov_threshold": float(high_quality_fov_threshold),
|
| 446 |
+
"revisit_label_source": LABEL_SOURCE,
|
| 447 |
+
"selected_frame_ids": [int(record.max_source_frame) for record in selected_records],
|
| 448 |
+
"selected_frame_record_ids": [record.chunk_id for record in selected_records],
|
| 449 |
+
"selected_ranges": [(record.source_start, record.source_end) for record in selected_records],
|
| 450 |
+
"frame_fov_overlap_values": fov_values,
|
| 451 |
+
"fov_overlap_values": fov_values,
|
| 452 |
+
"plucker_overlap_values": plucker_values,
|
| 453 |
+
"best_selected_fov_overlap_values": [] if best_selected is None else [best_selected_fov],
|
| 454 |
+
"best_selected_plucker_overlap_values": [] if best_selected is None else [best_selected_plucker],
|
| 455 |
+
"best_selected_gap_frame_values": [] if best_selected is None else [best_selected_gap],
|
| 456 |
+
"best_selected_frame_fov_overlap_values": [] if best_selected_frame is None else [best_selected_frame_fov],
|
| 457 |
+
"selected_frame_fov_overlap_values": selected_frame_fov_values,
|
| 458 |
+
"selected_incremental_fov_overlap_values": selected_gains,
|
| 459 |
+
"selected_revisit_scores": selected_scores,
|
| 460 |
+
**pose_preselect_diagnostics,
|
| 461 |
+
}
|
| 462 |
+
diagnostics.update(_overlap_stats(fov_values, "revisit_frame_fov_overlap"))
|
| 463 |
+
diagnostics.update(_overlap_stats(fov_values, "revisit_fov_overlap"))
|
| 464 |
+
diagnostics.update(_overlap_stats(plucker_values, "revisit_plucker_overlap"))
|
| 465 |
+
diagnostics.update(_overlap_stats(diagnostics["best_selected_fov_overlap_values"], "revisit_best_selected_fov_overlap"))
|
| 466 |
+
diagnostics.update(_overlap_stats(diagnostics["best_selected_plucker_overlap_values"], "revisit_best_selected_plucker_overlap"))
|
| 467 |
+
diagnostics.update(_overlap_stats(diagnostics["best_selected_gap_frame_values"], "revisit_best_selected_gap_frames"))
|
| 468 |
+
diagnostics.update(_overlap_stats(diagnostics["best_selected_frame_fov_overlap_values"], "revisit_best_selected_frame_fov_overlap"))
|
| 469 |
+
diagnostics.update(_overlap_stats(selected_frame_fov_values, "revisit_selected_frame_fov_overlap"))
|
| 470 |
+
diagnostics.update(_overlap_stats(selected_gains, "revisit_incremental_fov_overlap"))
|
| 471 |
+
return RevisitRetrievalResult(
|
| 472 |
+
records=selected_records,
|
| 473 |
+
scores=scores,
|
| 474 |
+
selected_frame_ids=[int(record.max_source_frame) for record in selected_records],
|
| 475 |
+
diagnostics=diagnostics,
|
| 476 |
+
)
|
algorithms/worldmem/dememwm/schedules.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Any
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from .types import StreamGateState
|
| 10 |
+
|
| 11 |
+
NOISE_BUCKETS = ("high", "mid", "low")
|
| 12 |
+
NOISE_BUCKET_TO_ID = {name: idx for idx, name in enumerate(NOISE_BUCKETS)}
|
| 13 |
+
EVAL_ABLATION_BRANCHES = (
|
| 14 |
+
"memory_off",
|
| 15 |
+
"A_only",
|
| 16 |
+
"D_only",
|
| 17 |
+
"A_plus_D",
|
| 18 |
+
"A_plus_D_plus_R_normal",
|
| 19 |
+
"R_forced_off",
|
| 20 |
+
"R_forced_on",
|
| 21 |
+
"wrong_pose",
|
| 22 |
+
"time_shuffle",
|
| 23 |
+
"source_matched_random",
|
| 24 |
+
"pose_shuffle",
|
| 25 |
+
"wrong_video",
|
| 26 |
+
"local_context_overlap_fake_revisit",
|
| 27 |
+
)
|
| 28 |
+
EVAL_ABLATION_BRANCH_TO_ID = {name: idx for idx, name in enumerate(EVAL_ABLATION_BRANCHES)}
|
| 29 |
+
EVAL_CORRUPTION_BRANCHES = (
|
| 30 |
+
"wrong_pose",
|
| 31 |
+
"time_shuffle",
|
| 32 |
+
"source_matched_random",
|
| 33 |
+
"pose_shuffle",
|
| 34 |
+
"wrong_video",
|
| 35 |
+
"local_context_overlap_fake_revisit",
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _clamp01(value: float) -> float:
|
| 41 |
+
return max(0.0, min(1.0, float(value)))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def noise_bucket_from_denoising_fraction(denoising_fraction: float | None) -> str:
|
| 45 |
+
if denoising_fraction is None:
|
| 46 |
+
return "mid"
|
| 47 |
+
frac = _clamp01(float(denoising_fraction))
|
| 48 |
+
if frac < (1.0 / 3.0):
|
| 49 |
+
return "high"
|
| 50 |
+
if frac < (2.0 / 3.0):
|
| 51 |
+
return "mid"
|
| 52 |
+
return "low"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def noise_bucket_from_noise_levels(noise_levels: torch.Tensor | None, timesteps: int | None) -> str:
|
| 56 |
+
if noise_levels is None or timesteps is None or int(timesteps) <= 1:
|
| 57 |
+
return "mid"
|
| 58 |
+
noise_fraction = _clamp01(float(noise_levels.detach().float().mean().item()) / float(int(timesteps) - 1))
|
| 59 |
+
if noise_fraction >= (2.0 / 3.0):
|
| 60 |
+
return "high"
|
| 61 |
+
if noise_fraction >= (1.0 / 3.0):
|
| 62 |
+
return "mid"
|
| 63 |
+
return "low"
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def noise_bucket_ids_from_noise_levels(noise_levels: torch.Tensor | None, timesteps: int | None) -> torch.Tensor | None:
|
| 67 |
+
if noise_levels is None or timesteps is None or int(timesteps) <= 1:
|
| 68 |
+
return None
|
| 69 |
+
noise_fraction = noise_levels.detach().float() / float(int(timesteps) - 1)
|
| 70 |
+
bucket_ids = torch.full_like(noise_levels, NOISE_BUCKET_TO_ID["mid"], dtype=torch.long)
|
| 71 |
+
bucket_ids = torch.where(
|
| 72 |
+
noise_fraction >= (2.0 / 3.0),
|
| 73 |
+
torch.full_like(bucket_ids, NOISE_BUCKET_TO_ID["high"]),
|
| 74 |
+
bucket_ids,
|
| 75 |
+
)
|
| 76 |
+
bucket_ids = torch.where(
|
| 77 |
+
noise_fraction < (1.0 / 3.0),
|
| 78 |
+
torch.full_like(bucket_ids, NOISE_BUCKET_TO_ID["low"]),
|
| 79 |
+
bucket_ids,
|
| 80 |
+
)
|
| 81 |
+
return bucket_ids
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def denoising_fraction_from_noise_levels(noise_levels: torch.Tensor | None, timesteps: int | None) -> float | None:
|
| 85 |
+
if noise_levels is None or timesteps is None or int(timesteps) <= 1:
|
| 86 |
+
return None
|
| 87 |
+
noise_fraction = _clamp01(float(noise_levels.detach().float().mean().item()) / float(int(timesteps) - 1))
|
| 88 |
+
return _clamp01(1.0 - noise_fraction)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def normalize_eval_ablation_branch(branch: str | None) -> str:
|
| 92 |
+
if branch is None:
|
| 93 |
+
return "A_plus_D_plus_R_normal"
|
| 94 |
+
branch = str(branch)
|
| 95 |
+
if branch not in EVAL_ABLATION_BRANCH_TO_ID:
|
| 96 |
+
raise ValueError(f"unknown DeMemWM eval ablation branch: {branch}")
|
| 97 |
+
return branch
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def normalize_noise_bucket(noise_bucket: str | None) -> str:
|
| 101 |
+
if noise_bucket in NOISE_BUCKET_TO_ID:
|
| 102 |
+
return str(noise_bucket)
|
| 103 |
+
return "mid"
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
_STAGE_ENABLES = {
|
| 107 |
+
'stage_1': (True, True, True),
|
| 108 |
+
'stage_2': (True, True, True),
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@dataclass(frozen=True)
|
| 114 |
+
class CurriculumState:
|
| 115 |
+
"""Step-resolved DeMemWM curriculum/freezing state for one continuous run."""
|
| 116 |
+
|
| 117 |
+
global_step: int
|
| 118 |
+
enabled: bool
|
| 119 |
+
stage: str
|
| 120 |
+
anchor_enabled: bool
|
| 121 |
+
dynamic_enabled: bool
|
| 122 |
+
revisit_enabled: bool
|
| 123 |
+
dit_train_state: str
|
| 124 |
+
freeze_vae: bool
|
| 125 |
+
dememwm_lr: float
|
| 126 |
+
memory_adapter_lr: float
|
| 127 |
+
full_dit_lr: float
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
@property
|
| 131 |
+
def dit_full_trainable(self) -> bool:
|
| 132 |
+
return self.dit_train_state == "full"
|
| 133 |
+
|
| 134 |
+
def diagnostics(self) -> dict[str, Any]:
|
| 135 |
+
return {
|
| 136 |
+
"dememwm_global_step": self.global_step,
|
| 137 |
+
"dememwm_curriculum_enabled": self.enabled,
|
| 138 |
+
"dememwm_stage": self.stage,
|
| 139 |
+
"curriculum_anchor_enabled": self.anchor_enabled,
|
| 140 |
+
"curriculum_dynamic_enabled": self.dynamic_enabled,
|
| 141 |
+
"curriculum_revisit_enabled": self.revisit_enabled,
|
| 142 |
+
"dit_train_state": self.dit_train_state,
|
| 143 |
+
"dit_full_trainable": self.dit_full_trainable,
|
| 144 |
+
"freeze_vae": self.freeze_vae,
|
| 145 |
+
"lr_dememwm_modules": self.dememwm_lr,
|
| 146 |
+
"lr_memory_adapters": self.memory_adapter_lr,
|
| 147 |
+
"lr_full_dit": self.full_dit_lr,
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _cfg_get(obj: Any, name: str, default: Any) -> Any:
|
| 152 |
+
return getattr(obj, name, default) if obj is not None else default
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _stage_for_step(curriculum_cfg: Any, step: int) -> str:
|
| 156 |
+
full_start = int(_cfg_get(curriculum_cfg, 'full_stage_start_step', 60000))
|
| 157 |
+
return 'stage_2' if step >= full_start else 'stage_1'
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _dit_train_state(curriculum_cfg: Any, step: int) -> str:
|
| 161 |
+
freeze_cfg = _cfg_get(curriculum_cfg, 'dit_freeze', None)
|
| 162 |
+
freeze_enabled = bool(_cfg_get(freeze_cfg, 'enabled', True))
|
| 163 |
+
full_step = int(_cfg_get(curriculum_cfg, 'full_stage_start_step', 60000))
|
| 164 |
+
if freeze_enabled and step < full_step:
|
| 165 |
+
return 'frozen'
|
| 166 |
+
return 'full'
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def resolve_curriculum(memory_cfg: Any, global_step: int | None = None) -> CurriculumState:
|
| 170 |
+
"""Resolve internal DeMemWM curriculum phase from Lightning global_step.
|
| 171 |
+
|
| 172 |
+
This intentionally supports one continuous training run; stage names are internal
|
| 173 |
+
gates only and do not imply separate jobs/checkpoints.
|
| 174 |
+
"""
|
| 175 |
+
|
| 176 |
+
step = max(0, int(global_step or 0))
|
| 177 |
+
curriculum_cfg = _cfg_get(memory_cfg, "curriculum", None)
|
| 178 |
+
lr_cfg = _cfg_get(curriculum_cfg, "lr", None)
|
| 179 |
+
enabled = bool(_cfg_get(curriculum_cfg, "enabled", False))
|
| 180 |
+
|
| 181 |
+
if enabled:
|
| 182 |
+
stage = _stage_for_step(curriculum_cfg, step)
|
| 183 |
+
dit_state = _dit_train_state(curriculum_cfg, step)
|
| 184 |
+
freeze_vae = bool(_cfg_get(curriculum_cfg, "freeze_vae", True))
|
| 185 |
+
else:
|
| 186 |
+
stage = str(_cfg_get(memory_cfg, "training_stage", "stage_1"))
|
| 187 |
+
dit_state = "full"
|
| 188 |
+
freeze_vae = True
|
| 189 |
+
|
| 190 |
+
if stage not in _STAGE_ENABLES:
|
| 191 |
+
raise ValueError(f"unknown DeMemWM stage: {stage}")
|
| 192 |
+
anchor_on, dynamic_on, revisit_on = _STAGE_ENABLES[stage]
|
| 193 |
+
return CurriculumState(
|
| 194 |
+
global_step=step,
|
| 195 |
+
enabled=enabled,
|
| 196 |
+
stage=stage,
|
| 197 |
+
anchor_enabled=anchor_on,
|
| 198 |
+
dynamic_enabled=dynamic_on,
|
| 199 |
+
revisit_enabled=revisit_on,
|
| 200 |
+
dit_train_state=dit_state,
|
| 201 |
+
freeze_vae=freeze_vae,
|
| 202 |
+
dememwm_lr=float(_cfg_get(lr_cfg, "dememwm_modules", 1.0e-4)),
|
| 203 |
+
memory_adapter_lr=float(_cfg_get(lr_cfg, "memory_adapters", 1.0e-4)),
|
| 204 |
+
full_dit_lr=float(_cfg_get(lr_cfg, "full_dit", 1.0e-5)),
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
DeMemWMCurriculumState = CurriculumState
|
| 209 |
+
resolve_dememwm_curriculum = resolve_curriculum
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def compute_stream_gates(stage: str, denoising_fraction: float | None = None, debug_force_all_streams: bool = False, anchor_gate: float = 1.0, dynamic_gate: float = 1.0, revisit_gate: float = 1.0) -> StreamGateState:
|
| 213 |
+
if debug_force_all_streams:
|
| 214 |
+
return StreamGateState(True, True, True, float(anchor_gate), float(dynamic_gate), float(revisit_gate), stage, "debug_force_all_streams")
|
| 215 |
+
if stage not in _STAGE_ENABLES:
|
| 216 |
+
raise ValueError(f"unknown DeMemWM stage: {stage}")
|
| 217 |
+
a_on, d_on, r_on = _STAGE_ENABLES[stage]
|
| 218 |
+
if denoising_fraction is not None:
|
| 219 |
+
denoising_fraction = max(0.0, min(1.0, float(denoising_fraction)))
|
| 220 |
+
stage_scale = 0.25 + 0.75 * denoising_fraction
|
| 221 |
+
else:
|
| 222 |
+
stage_scale = 1.0
|
| 223 |
+
return StreamGateState(a_on, d_on, r_on, float(anchor_gate) if a_on else 0.0, float(dynamic_gate) * stage_scale if d_on else 0.0, float(revisit_gate) * stage_scale if r_on else 0.0, stage, "stage_schedule")
|
algorithms/worldmem/dememwm/types.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from dataclasses import dataclass, field
|
| 5 |
+
from enum import Enum
|
| 6 |
+
from typing import Any, Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class MemorySourceType(str, Enum):
|
| 12 |
+
PREFIX_GT = "prefix_gt"
|
| 13 |
+
GENERATED = "generated"
|
| 14 |
+
DYNAMIC = "dynamic"
|
| 15 |
+
REVISIT = "revisit"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class MemoryRecord:
|
| 20 |
+
"""One causal DeMemWM memory item.
|
| 21 |
+
|
| 22 |
+
Frame ranges use an inclusive start and exclusive end: [source_start, source_end).
|
| 23 |
+
All source frame indices represented by this record must be strictly smaller
|
| 24 |
+
than a queried target frame unless the caller is explicitly querying an
|
| 25 |
+
already-committed prefix frame.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
tokens: torch.Tensor
|
| 29 |
+
mask: torch.Tensor
|
| 30 |
+
source_start: int
|
| 31 |
+
source_end: int
|
| 32 |
+
frame_indices: torch.Tensor
|
| 33 |
+
pose: Optional[torch.Tensor]
|
| 34 |
+
source_type: MemorySourceType
|
| 35 |
+
is_generated: bool
|
| 36 |
+
score: Optional[float | torch.Tensor] = None
|
| 37 |
+
chunk_id: Optional[str] = None
|
| 38 |
+
metadata: dict[str, Any] = field(default_factory=dict)
|
| 39 |
+
|
| 40 |
+
def __post_init__(self) -> None:
|
| 41 |
+
if self.source_end <= self.source_start:
|
| 42 |
+
raise ValueError("source_end must be greater than source_start")
|
| 43 |
+
if self.tokens.ndim < 2:
|
| 44 |
+
raise ValueError("tokens must include slot and channel dimensions")
|
| 45 |
+
self.mask = self.mask.bool()
|
| 46 |
+
if self.mask.ndim != 1:
|
| 47 |
+
raise ValueError("mask must have shape (M,)")
|
| 48 |
+
if self.mask.shape[0] != self.tokens.shape[0]:
|
| 49 |
+
raise ValueError("mask length must match token slots")
|
| 50 |
+
if self.source_type == MemorySourceType.PREFIX_GT and self.is_generated:
|
| 51 |
+
raise ValueError("generated records cannot be PREFIX_GT anchors")
|
| 52 |
+
if self.frame_indices.numel() == 0:
|
| 53 |
+
raise ValueError("frame_indices cannot be empty")
|
| 54 |
+
|
| 55 |
+
@property
|
| 56 |
+
def max_source_frame(self) -> int:
|
| 57 |
+
return int(self.frame_indices.detach().max().item())
|
| 58 |
+
|
| 59 |
+
@property
|
| 60 |
+
def valid_slots(self) -> int:
|
| 61 |
+
return int(self.mask.detach().sum().item())
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@dataclass
|
| 65 |
+
class MemoryStreamTensors:
|
| 66 |
+
anchor_tokens: torch.Tensor
|
| 67 |
+
anchor_mask: torch.Tensor
|
| 68 |
+
dynamic_tokens: torch.Tensor
|
| 69 |
+
dynamic_mask: torch.Tensor
|
| 70 |
+
revisit_tokens: torch.Tensor
|
| 71 |
+
revisit_mask: torch.Tensor
|
| 72 |
+
anchor_gate: torch.Tensor | float
|
| 73 |
+
dynamic_gate: torch.Tensor | float
|
| 74 |
+
revisit_gate: torch.Tensor | float
|
| 75 |
+
revisit_gate_raw: torch.Tensor | None = None
|
| 76 |
+
valid_revisit_mask: torch.Tensor | None = None
|
| 77 |
+
no_valid_revisit_mask: torch.Tensor | None = None
|
| 78 |
+
diagnostics: dict[str, Any] = field(default_factory=dict)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@dataclass(frozen=True)
|
| 82 |
+
class StreamGateState:
|
| 83 |
+
anchor_enabled: bool
|
| 84 |
+
dynamic_enabled: bool
|
| 85 |
+
revisit_enabled: bool
|
| 86 |
+
anchor_gate: float
|
| 87 |
+
dynamic_gate: float
|
| 88 |
+
revisit_gate: float
|
| 89 |
+
stage: str
|
| 90 |
+
reason: str = ""
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
@dataclass
|
| 94 |
+
class RevisitRetrievalResult:
|
| 95 |
+
records: list[MemoryRecord]
|
| 96 |
+
scores: torch.Tensor
|
| 97 |
+
selected_frame_ids: list[int]
|
| 98 |
+
diagnostics: dict[str, Any]
|
algorithms/worldmem/dememwm_memory_dit.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
from .dememwm.algorithm import MemoryDiTMixin
|
| 5 |
+
from .df_video import BaseVideoDiTMinecraft
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DeMemWMMinecraft(MemoryDiTMixin, BaseVideoDiTMinecraft):
|
| 9 |
+
"""Standalone DeMemWM / Memory-DiT algorithm.
|
| 10 |
+
|
| 11 |
+
Reuses the base video-DiT VAE/diffusion/training infrastructure,
|
| 12 |
+
but owns memory construction/injection. Does not route through the legacy memory method.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
DeMemWMMemoryDiTMinecraft = DeMemWMMinecraft
|
algorithms/worldmem/df_base.py
ADDED
|
@@ -0,0 +1,307 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
|
| 3 |
+
template [repo](https://github.com/buoyancy99/research-template).
|
| 4 |
+
By its MIT license, you must keep the above sentence in `README.md`
|
| 5 |
+
and the `LICENSE` file to credit the author.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Optional
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from omegaconf import DictConfig
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from typing import Any
|
| 15 |
+
from einops import rearrange
|
| 16 |
+
|
| 17 |
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
| 18 |
+
|
| 19 |
+
from algorithms.common.base_pytorch_algo import BasePytorchAlgo
|
| 20 |
+
from .models.diffusion import Diffusion
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DiffusionForcingBase(BasePytorchAlgo):
|
| 24 |
+
def __init__(self, cfg: DictConfig):
|
| 25 |
+
self.cfg = cfg
|
| 26 |
+
self.x_shape = cfg.x_shape
|
| 27 |
+
self.frame_stack = cfg.frame_stack
|
| 28 |
+
self.x_stacked_shape = list(self.x_shape)
|
| 29 |
+
self.x_stacked_shape[0] *= cfg.frame_stack
|
| 30 |
+
self.guidance_scale = cfg.guidance_scale
|
| 31 |
+
self.context_frames = cfg.context_frames
|
| 32 |
+
self.chunk_size = cfg.chunk_size
|
| 33 |
+
self.action_cond_dim = cfg.action_cond_dim
|
| 34 |
+
self.causal = cfg.causal
|
| 35 |
+
|
| 36 |
+
self.uncertainty_scale = cfg.uncertainty_scale
|
| 37 |
+
self.timesteps = cfg.diffusion.timesteps
|
| 38 |
+
self.sampling_timesteps = cfg.diffusion.sampling_timesteps
|
| 39 |
+
self.clip_noise = cfg.diffusion.clip_noise
|
| 40 |
+
|
| 41 |
+
self.cfg.diffusion.cum_snr_decay = self.cfg.diffusion.cum_snr_decay ** (self.frame_stack * cfg.frame_skip)
|
| 42 |
+
|
| 43 |
+
self.validation_step_outputs = []
|
| 44 |
+
super().__init__(cfg)
|
| 45 |
+
|
| 46 |
+
def _build_model(self):
|
| 47 |
+
self.diffusion_model = Diffusion(
|
| 48 |
+
x_shape=self.x_stacked_shape,
|
| 49 |
+
action_cond_dim=self.action_cond_dim,
|
| 50 |
+
is_causal=self.causal,
|
| 51 |
+
cfg=self.cfg.diffusion,
|
| 52 |
+
)
|
| 53 |
+
self.register_data_mean_std(self.cfg.data_mean, self.cfg.data_std)
|
| 54 |
+
|
| 55 |
+
def configure_optimizers(self):
|
| 56 |
+
params = tuple(self.diffusion_model.parameters())
|
| 57 |
+
optimizer_dynamics = torch.optim.AdamW(
|
| 58 |
+
params, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay, betas=self.cfg.optimizer_beta
|
| 59 |
+
)
|
| 60 |
+
return optimizer_dynamics
|
| 61 |
+
|
| 62 |
+
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
|
| 63 |
+
# update params
|
| 64 |
+
optimizer.step(closure=optimizer_closure)
|
| 65 |
+
|
| 66 |
+
# manually warm up lr without a scheduler
|
| 67 |
+
if self.trainer.global_step < self.cfg.warmup_steps:
|
| 68 |
+
lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.cfg.warmup_steps)
|
| 69 |
+
for pg in optimizer.param_groups:
|
| 70 |
+
pg["lr"] = lr_scale * self.cfg.lr
|
| 71 |
+
|
| 72 |
+
def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
|
| 73 |
+
xs, conditions, masks = self._preprocess_batch(batch)
|
| 74 |
+
|
| 75 |
+
rand_length = torch.randint(3,xs.shape[0]-2, (1,))[0].item()
|
| 76 |
+
xs = torch.cat([xs[:rand_length], xs[rand_length-3:rand_length-1]])
|
| 77 |
+
conditions = torch.cat([conditions[:rand_length], conditions[rand_length-3:rand_length-1]])
|
| 78 |
+
masks = torch.cat([masks[:rand_length], masks[rand_length-3:rand_length-1]])
|
| 79 |
+
noise_levels=self._generate_noise_levels(xs)
|
| 80 |
+
noise_levels[:rand_length] = 15 # stable_noise_levels
|
| 81 |
+
noise_levels[rand_length+1:] = 15 # stable_noise_levels
|
| 82 |
+
|
| 83 |
+
xs_pred, loss = self.diffusion_model(xs, conditions, noise_levels=noise_levels)
|
| 84 |
+
loss = self.reweight_loss(loss, masks)
|
| 85 |
+
|
| 86 |
+
# log the loss
|
| 87 |
+
if batch_idx % 20 == 0:
|
| 88 |
+
self.log("training/loss", loss)
|
| 89 |
+
|
| 90 |
+
xs = self._unstack_and_unnormalize(xs)
|
| 91 |
+
xs_pred = self._unstack_and_unnormalize(xs_pred)
|
| 92 |
+
|
| 93 |
+
output_dict = {
|
| 94 |
+
"loss": loss,
|
| 95 |
+
"xs_pred": xs_pred,
|
| 96 |
+
"xs": xs,
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
return output_dict
|
| 100 |
+
|
| 101 |
+
@torch.no_grad()
|
| 102 |
+
def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT:
|
| 103 |
+
xs, conditions, masks = self._preprocess_batch(batch)
|
| 104 |
+
n_frames, batch_size, *_ = xs.shape
|
| 105 |
+
xs_pred = []
|
| 106 |
+
curr_frame = 0
|
| 107 |
+
|
| 108 |
+
# context
|
| 109 |
+
n_context_frames = self.context_frames // self.frame_stack
|
| 110 |
+
xs_pred = xs[:n_context_frames].clone()
|
| 111 |
+
curr_frame += n_context_frames
|
| 112 |
+
|
| 113 |
+
if self.condtion_similar_length:
|
| 114 |
+
n_frames -= self.condtion_similar_length
|
| 115 |
+
|
| 116 |
+
pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
|
| 117 |
+
while curr_frame < n_frames:
|
| 118 |
+
if self.chunk_size > 0:
|
| 119 |
+
horizon = min(n_frames - curr_frame, self.chunk_size)
|
| 120 |
+
else:
|
| 121 |
+
horizon = n_frames - curr_frame
|
| 122 |
+
assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
|
| 123 |
+
scheduling_matrix = self._generate_scheduling_matrix(horizon)
|
| 124 |
+
|
| 125 |
+
chunk = torch.randn((horizon, batch_size, *self.x_stacked_shape), device=self.device)
|
| 126 |
+
chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise)
|
| 127 |
+
xs_pred = torch.cat([xs_pred, chunk], 0)
|
| 128 |
+
|
| 129 |
+
# sliding window: only input the last n_tokens frames
|
| 130 |
+
start_frame = max(0, curr_frame + horizon - self.n_tokens)
|
| 131 |
+
|
| 132 |
+
pbar.set_postfix(
|
| 133 |
+
{
|
| 134 |
+
"start": start_frame,
|
| 135 |
+
"end": curr_frame + horizon,
|
| 136 |
+
}
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
if self.condtion_similar_length:
|
| 140 |
+
xs_pred = torch.cat([xs_pred, xs[curr_frame-self.condtion_similar_length:curr_frame].clone()], 0)
|
| 141 |
+
|
| 142 |
+
for m in range(scheduling_matrix.shape[0] - 1):
|
| 143 |
+
|
| 144 |
+
from_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m]))[
|
| 145 |
+
:, None
|
| 146 |
+
].repeat(batch_size, axis=1)
|
| 147 |
+
to_noise_levels = np.concatenate(
|
| 148 |
+
(
|
| 149 |
+
np.zeros((curr_frame,), dtype=np.int64),
|
| 150 |
+
scheduling_matrix[m + 1],
|
| 151 |
+
)
|
| 152 |
+
)[
|
| 153 |
+
:, None
|
| 154 |
+
].repeat(batch_size, axis=1)
|
| 155 |
+
|
| 156 |
+
if self.condtion_similar_length:
|
| 157 |
+
from_noise_levels = np.concatenate([from_noise_levels, np.array([[0,0,0,0]*self.condtion_similar_length])], axis=0)
|
| 158 |
+
to_noise_levels = np.concatenate([to_noise_levels, np.array([[0,0,0,0]*self.condtion_similar_length])], axis=0)
|
| 159 |
+
|
| 160 |
+
from_noise_levels = torch.from_numpy(from_noise_levels).to(self.device)
|
| 161 |
+
to_noise_levels = torch.from_numpy(to_noise_levels).to(self.device)
|
| 162 |
+
|
| 163 |
+
# update xs_pred by DDIM or DDPM sampling
|
| 164 |
+
# input frames within the sliding window
|
| 165 |
+
|
| 166 |
+
try:
|
| 167 |
+
input_condition = conditions[start_frame : curr_frame + horizon].clone()
|
| 168 |
+
except:
|
| 169 |
+
import pdb;pdb.set_trace()
|
| 170 |
+
if self.condtion_similar_length:
|
| 171 |
+
input_condition = torch.cat([conditions[start_frame : curr_frame + horizon], conditions[-self.condtion_similar_length:]], dim=0)
|
| 172 |
+
xs_pred[start_frame:] = self.diffusion_model.sample_step(
|
| 173 |
+
xs_pred[start_frame:],
|
| 174 |
+
input_condition,
|
| 175 |
+
from_noise_levels[start_frame:],
|
| 176 |
+
to_noise_levels[start_frame:],
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
if self.condtion_similar_length:
|
| 180 |
+
xs_pred = xs_pred[:-self.condtion_similar_length]
|
| 181 |
+
|
| 182 |
+
curr_frame += horizon
|
| 183 |
+
pbar.update(horizon)
|
| 184 |
+
|
| 185 |
+
if self.condtion_similar_length:
|
| 186 |
+
xs = xs[:-self.condtion_similar_length]
|
| 187 |
+
# FIXME: loss
|
| 188 |
+
loss = F.mse_loss(xs_pred, xs, reduction="none")
|
| 189 |
+
loss = self.reweight_loss(loss, masks)
|
| 190 |
+
self.validation_step_outputs.append((xs_pred.detach().cpu(), xs.detach().cpu()))
|
| 191 |
+
|
| 192 |
+
return loss
|
| 193 |
+
|
| 194 |
+
def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
|
| 195 |
+
return self.validation_step(*args, **kwargs, namespace="test")
|
| 196 |
+
|
| 197 |
+
def on_test_epoch_end(self) -> None:
|
| 198 |
+
self.on_validation_epoch_end(namespace="test")
|
| 199 |
+
|
| 200 |
+
def _generate_noise_levels(self, xs: torch.Tensor, masks: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 201 |
+
"""
|
| 202 |
+
Generate noise levels for training.
|
| 203 |
+
"""
|
| 204 |
+
num_frames, batch_size, *_ = xs.shape
|
| 205 |
+
match self.cfg.noise_level:
|
| 206 |
+
case "random_all": # entirely random noise levels
|
| 207 |
+
noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
|
| 208 |
+
case "same":
|
| 209 |
+
noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
|
| 210 |
+
noise_levels[1:] = noise_levels[0]
|
| 211 |
+
|
| 212 |
+
if masks is not None:
|
| 213 |
+
# for frames that are not available, treat as full noise
|
| 214 |
+
discard = torch.all(~rearrange(masks.bool(), "(t fs) b -> t b fs", fs=self.frame_stack), -1)
|
| 215 |
+
noise_levels = torch.where(discard, torch.full_like(noise_levels, self.timesteps - 1), noise_levels)
|
| 216 |
+
|
| 217 |
+
return noise_levels
|
| 218 |
+
|
| 219 |
+
def _generate_scheduling_matrix(self, horizon: int):
|
| 220 |
+
match self.cfg.scheduling_matrix:
|
| 221 |
+
case "pyramid":
|
| 222 |
+
return self._generate_pyramid_scheduling_matrix(horizon, self.uncertainty_scale)
|
| 223 |
+
case "full_sequence":
|
| 224 |
+
return np.arange(self.sampling_timesteps, -1, -1)[:, None].repeat(horizon, axis=1)
|
| 225 |
+
case "autoregressive":
|
| 226 |
+
return self._generate_pyramid_scheduling_matrix(horizon, self.sampling_timesteps)
|
| 227 |
+
case "trapezoid":
|
| 228 |
+
return self._generate_trapezoid_scheduling_matrix(horizon, self.uncertainty_scale)
|
| 229 |
+
|
| 230 |
+
def _generate_pyramid_scheduling_matrix(self, horizon: int, uncertainty_scale: float):
|
| 231 |
+
height = self.sampling_timesteps + int((horizon - 1) * uncertainty_scale) + 1
|
| 232 |
+
scheduling_matrix = np.zeros((height, horizon), dtype=np.int64)
|
| 233 |
+
for m in range(height):
|
| 234 |
+
for t in range(horizon):
|
| 235 |
+
scheduling_matrix[m, t] = self.sampling_timesteps + int(t * uncertainty_scale) - m
|
| 236 |
+
|
| 237 |
+
return np.clip(scheduling_matrix, 0, self.sampling_timesteps)
|
| 238 |
+
|
| 239 |
+
def _generate_trapezoid_scheduling_matrix(self, horizon: int, uncertainty_scale: float):
|
| 240 |
+
height = self.sampling_timesteps + int((horizon + 1) // 2 * uncertainty_scale)
|
| 241 |
+
scheduling_matrix = np.zeros((height, horizon), dtype=np.int64)
|
| 242 |
+
for m in range(height):
|
| 243 |
+
for t in range((horizon + 1) // 2):
|
| 244 |
+
scheduling_matrix[m, t] = self.sampling_timesteps + int(t * uncertainty_scale) - m
|
| 245 |
+
scheduling_matrix[m, -t] = self.sampling_timesteps + int(t * uncertainty_scale) - m
|
| 246 |
+
|
| 247 |
+
return np.clip(scheduling_matrix, 0, self.sampling_timesteps)
|
| 248 |
+
|
| 249 |
+
def reweight_loss(self, loss, weight=None):
|
| 250 |
+
# Note there is another part of loss reweighting (fused_snr) inside the Diffusion class!
|
| 251 |
+
loss = rearrange(loss, "t b (fs c) ... -> t b fs c ...", fs=self.frame_stack)
|
| 252 |
+
if weight is not None:
|
| 253 |
+
expand_dim = len(loss.shape) - len(weight.shape) - 1
|
| 254 |
+
weight = rearrange(
|
| 255 |
+
weight,
|
| 256 |
+
"(t fs) b ... -> t b fs ..." + " 1" * expand_dim,
|
| 257 |
+
fs=self.frame_stack,
|
| 258 |
+
)
|
| 259 |
+
loss = loss * weight
|
| 260 |
+
|
| 261 |
+
return loss.mean()
|
| 262 |
+
|
| 263 |
+
def _preprocess_batch(self, batch):
|
| 264 |
+
xs = batch[0]
|
| 265 |
+
batch_size, n_frames = xs.shape[:2]
|
| 266 |
+
|
| 267 |
+
if n_frames % self.frame_stack != 0:
|
| 268 |
+
raise ValueError("Number of frames must be divisible by frame stack size")
|
| 269 |
+
if self.context_frames % self.frame_stack != 0:
|
| 270 |
+
raise ValueError("Number of context frames must be divisible by frame stack size")
|
| 271 |
+
|
| 272 |
+
masks = torch.ones(n_frames, batch_size).to(xs.device)
|
| 273 |
+
n_frames = n_frames // self.frame_stack
|
| 274 |
+
|
| 275 |
+
if self.action_cond_dim:
|
| 276 |
+
conditions = batch[1]
|
| 277 |
+
conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1)
|
| 278 |
+
conditions = rearrange(conditions, "b (t fs) d -> t b (fs d)", fs=self.frame_stack).contiguous()
|
| 279 |
+
|
| 280 |
+
# f, _, _ = conditions.shape
|
| 281 |
+
# predefined_1 = torch.tensor([0,0,0,1]).to(conditions.device)
|
| 282 |
+
# predefined_2 = torch.tensor([0,0,1,0]).to(conditions.device)
|
| 283 |
+
# conditions[:f//2] = predefined_1
|
| 284 |
+
# conditions[f//2:] = predefined_2
|
| 285 |
+
else:
|
| 286 |
+
conditions = [None for _ in range(n_frames)]
|
| 287 |
+
|
| 288 |
+
xs = self._normalize_x(xs)
|
| 289 |
+
xs = rearrange(xs, "b (t fs) c ... -> t b (fs c) ...", fs=self.frame_stack).contiguous()
|
| 290 |
+
|
| 291 |
+
return xs, conditions, masks
|
| 292 |
+
|
| 293 |
+
def _normalize_x(self, xs):
|
| 294 |
+
shape = [1] * (xs.ndim - self.data_mean.ndim) + list(self.data_mean.shape)
|
| 295 |
+
mean = self.data_mean.reshape(shape)
|
| 296 |
+
std = self.data_std.reshape(shape)
|
| 297 |
+
return (xs - mean) / std
|
| 298 |
+
|
| 299 |
+
def _unnormalize_x(self, xs):
|
| 300 |
+
shape = [1] * (xs.ndim - self.data_mean.ndim) + list(self.data_mean.shape)
|
| 301 |
+
mean = self.data_mean.reshape(shape)
|
| 302 |
+
std = self.data_std.reshape(shape)
|
| 303 |
+
return xs * std + mean
|
| 304 |
+
|
| 305 |
+
def _unstack_and_unnormalize(self, xs):
|
| 306 |
+
xs = rearrange(xs, "t b (fs c) ... -> (t fs) b c ...", fs=self.frame_stack)
|
| 307 |
+
return self._unnormalize_x(xs)
|
algorithms/worldmem/df_video.py
ADDED
|
@@ -0,0 +1,926 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torchvision.transforms.functional as TF
|
| 8 |
+
from torchvision.transforms import InterpolationMode
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from packaging import version as pver
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from omegaconf import DictConfig
|
| 14 |
+
from lightning.pytorch.utilities.types import STEP_OUTPUT
|
| 15 |
+
from algorithms.common.metrics import (
|
| 16 |
+
LearnedPerceptualImagePatchSimilarity,
|
| 17 |
+
)
|
| 18 |
+
from utils.logging_utils import log_video, get_validation_metrics_for_videos
|
| 19 |
+
from .df_base import DiffusionForcingBase
|
| 20 |
+
from .models.vae import VAE_models
|
| 21 |
+
from .models.diffusion import Diffusion
|
| 22 |
+
from .models.pose_prediction import PosePredictionNet
|
| 23 |
+
import glob
|
| 24 |
+
|
| 25 |
+
# Utility Functions
|
| 26 |
+
def euler_to_rotation_matrix(pitch, yaw):
|
| 27 |
+
"""
|
| 28 |
+
Convert pitch and yaw angles (in radians) to a 3x3 rotation matrix.
|
| 29 |
+
Supports batch input.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
pitch (torch.Tensor): Pitch angles in radians.
|
| 33 |
+
yaw (torch.Tensor): Yaw angles in radians.
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
torch.Tensor: Rotation matrix of shape (batch_size, 3, 3).
|
| 37 |
+
"""
|
| 38 |
+
cos_pitch, sin_pitch = torch.cos(pitch), torch.sin(pitch)
|
| 39 |
+
cos_yaw, sin_yaw = torch.cos(yaw), torch.sin(yaw)
|
| 40 |
+
|
| 41 |
+
R_pitch = torch.stack([
|
| 42 |
+
torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch),
|
| 43 |
+
torch.zeros_like(pitch), cos_pitch, -sin_pitch,
|
| 44 |
+
torch.zeros_like(pitch), sin_pitch, cos_pitch
|
| 45 |
+
], dim=-1).reshape(-1, 3, 3)
|
| 46 |
+
|
| 47 |
+
R_yaw = torch.stack([
|
| 48 |
+
cos_yaw, torch.zeros_like(yaw), sin_yaw,
|
| 49 |
+
torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),
|
| 50 |
+
-sin_yaw, torch.zeros_like(yaw), cos_yaw
|
| 51 |
+
], dim=-1).reshape(-1, 3, 3)
|
| 52 |
+
|
| 53 |
+
return torch.matmul(R_yaw, R_pitch)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def euler_to_camera_to_world_matrix(pose):
|
| 57 |
+
"""
|
| 58 |
+
Convert (x, y, z, pitch, yaw) to a 4x4 camera-to-world transformation matrix using torch.
|
| 59 |
+
Supports both (5,) and (f, b, 5) shaped inputs.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
pose (torch.Tensor): Pose tensor of shape (5,) or (f, b, 5).
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
torch.Tensor: Camera-to-world transformation matrix of shape (4, 4).
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
origin_dim = pose.ndim
|
| 69 |
+
if origin_dim == 1:
|
| 70 |
+
pose = pose.unsqueeze(0).unsqueeze(0) # Convert (5,) -> (1, 1, 5)
|
| 71 |
+
elif origin_dim == 2:
|
| 72 |
+
pose = pose.unsqueeze(0)
|
| 73 |
+
|
| 74 |
+
x, y, z, pitch, yaw = pose[..., 0], pose[..., 1], pose[..., 2], pose[..., 3], pose[..., 4]
|
| 75 |
+
pitch, yaw = torch.deg2rad(pitch), torch.deg2rad(yaw)
|
| 76 |
+
|
| 77 |
+
# Compute rotation matrix (batch mode)
|
| 78 |
+
R = euler_to_rotation_matrix(pitch, yaw) # Shape (f*b, 3, 3)
|
| 79 |
+
|
| 80 |
+
# Create the 4x4 transformation matrix
|
| 81 |
+
eye = torch.eye(4, dtype=torch.float32, device=pose.device)
|
| 82 |
+
camera_to_world = eye.repeat(R.shape[0], 1, 1) # Shape (f*b, 4, 4)
|
| 83 |
+
|
| 84 |
+
# Assign rotation
|
| 85 |
+
camera_to_world[:, :3, :3] = R
|
| 86 |
+
|
| 87 |
+
# Assign translation
|
| 88 |
+
camera_to_world[:, :3, 3] = torch.stack([x.reshape(-1), y.reshape(-1), z.reshape(-1)], dim=-1)
|
| 89 |
+
|
| 90 |
+
# Reshape back to (f, b, 4, 4) if needed
|
| 91 |
+
if origin_dim == 3:
|
| 92 |
+
return camera_to_world.view(pose.shape[0], pose.shape[1], 4, 4)
|
| 93 |
+
elif origin_dim == 2:
|
| 94 |
+
return camera_to_world.view(pose.shape[0], 4, 4)
|
| 95 |
+
else:
|
| 96 |
+
return camera_to_world.squeeze(0).squeeze(0) # Convert (1,1,4,4) -> (4,4)
|
| 97 |
+
|
| 98 |
+
def is_inside_fov_3d_hv(points, center, center_pitch, center_yaw, fov_half_h, fov_half_v):
|
| 99 |
+
"""
|
| 100 |
+
Check whether points are within a given 3D field of view (FOV)
|
| 101 |
+
with separately defined horizontal and vertical ranges.
|
| 102 |
+
|
| 103 |
+
The center view direction is specified by pitch and yaw (in degrees).
|
| 104 |
+
|
| 105 |
+
:param points: (N, B, 3) Sample point coordinates
|
| 106 |
+
:param center: (3,) Center coordinates of the FOV
|
| 107 |
+
:param center_pitch: Pitch angle of the center view (in degrees)
|
| 108 |
+
:param center_yaw: Yaw angle of the center view (in degrees)
|
| 109 |
+
:param fov_half_h: Horizontal half-FOV angle (in degrees)
|
| 110 |
+
:param fov_half_v: Vertical half-FOV angle (in degrees)
|
| 111 |
+
:return: Boolean tensor (N, B), indicating whether each point is inside the FOV
|
| 112 |
+
"""
|
| 113 |
+
# Compute vectors relative to the center
|
| 114 |
+
vectors = points - center # shape (N, B, 3)
|
| 115 |
+
x = vectors[..., 0]
|
| 116 |
+
y = vectors[..., 1]
|
| 117 |
+
z = vectors[..., 2]
|
| 118 |
+
|
| 119 |
+
# Compute horizontal angle (yaw): measured with respect to the z-axis as the forward direction,
|
| 120 |
+
# and the x-axis as left-right, resulting in a range of -180 to 180 degrees.
|
| 121 |
+
azimuth = torch.atan2(x, z) * (180 / math.pi)
|
| 122 |
+
|
| 123 |
+
# Compute vertical angle (pitch): measured with respect to the horizontal plane,
|
| 124 |
+
# resulting in a range of -90 to 90 degrees.
|
| 125 |
+
elevation = torch.atan2(y, torch.sqrt(x**2 + z**2)) * (180 / math.pi)
|
| 126 |
+
|
| 127 |
+
# Compute the angular difference from the center view (handling circular angle wrap-around)
|
| 128 |
+
diff_azimuth = (azimuth - center_yaw).abs() % 360
|
| 129 |
+
diff_elevation = (elevation - center_pitch).abs() % 360
|
| 130 |
+
|
| 131 |
+
# Adjust values greater than 180 degrees to the shorter angular difference
|
| 132 |
+
diff_azimuth = torch.where(diff_azimuth > 180, 360 - diff_azimuth, diff_azimuth)
|
| 133 |
+
diff_elevation = torch.where(diff_elevation > 180, 360 - diff_elevation, diff_elevation)
|
| 134 |
+
|
| 135 |
+
# Check if both horizontal and vertical angles are within their respective FOV limits
|
| 136 |
+
return (diff_azimuth < fov_half_h) & (diff_elevation < fov_half_v)
|
| 137 |
+
|
| 138 |
+
def generate_points_in_sphere(n_points, radius):
|
| 139 |
+
# Sample three independent uniform distributions
|
| 140 |
+
samples_r = torch.rand(n_points) # For radius distribution
|
| 141 |
+
samples_phi = torch.rand(n_points) # For azimuthal angle phi
|
| 142 |
+
samples_u = torch.rand(n_points) # For polar angle theta
|
| 143 |
+
|
| 144 |
+
# Apply cube root to ensure uniform volumetric distribution
|
| 145 |
+
r = radius * torch.pow(samples_r, 1/3)
|
| 146 |
+
# Azimuthal angle phi uniformly distributed in [0, 2π]
|
| 147 |
+
phi = 2 * math.pi * samples_phi
|
| 148 |
+
# Convert u to theta to ensure cos(theta) is uniformly distributed
|
| 149 |
+
theta = torch.acos(1 - 2 * samples_u)
|
| 150 |
+
|
| 151 |
+
# Convert spherical coordinates to Cartesian coordinates
|
| 152 |
+
x = r * torch.sin(theta) * torch.cos(phi)
|
| 153 |
+
y = r * torch.sin(theta) * torch.sin(phi)
|
| 154 |
+
z = r * torch.cos(theta)
|
| 155 |
+
|
| 156 |
+
points = torch.stack((x, y, z), dim=1)
|
| 157 |
+
return points
|
| 158 |
+
|
| 159 |
+
def tensor_max_with_number(tensor, number):
|
| 160 |
+
number_tensor = torch.tensor(number, dtype=tensor.dtype, device=tensor.device)
|
| 161 |
+
result = torch.max(tensor, number_tensor)
|
| 162 |
+
return result
|
| 163 |
+
|
| 164 |
+
def custom_meshgrid(*args):
|
| 165 |
+
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
|
| 166 |
+
if pver.parse(torch.__version__) < pver.parse('1.10'):
|
| 167 |
+
return torch.meshgrid(*args)
|
| 168 |
+
else:
|
| 169 |
+
return torch.meshgrid(*args, indexing='ij')
|
| 170 |
+
|
| 171 |
+
def camera_to_world_to_world_to_camera(camera_to_world: torch.Tensor) -> torch.Tensor:
|
| 172 |
+
"""
|
| 173 |
+
Convert Camera-to-World matrices to World-to-Camera matrices for a tensor with shape (f, b, 4, 4).
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
camera_to_world (torch.Tensor): A tensor of shape (f, b, 4, 4), where:
|
| 177 |
+
f = number of frames,
|
| 178 |
+
b = batch size.
|
| 179 |
+
|
| 180 |
+
Returns:
|
| 181 |
+
torch.Tensor: A tensor of shape (f, b, 4, 4) representing the World-to-Camera matrices.
|
| 182 |
+
"""
|
| 183 |
+
# Ensure input is a 4D tensor
|
| 184 |
+
assert camera_to_world.ndim == 4 and camera_to_world.shape[2:] == (4, 4), \
|
| 185 |
+
"Input must be of shape (f, b, 4, 4)"
|
| 186 |
+
|
| 187 |
+
# Extract the rotation (R) and translation (T) parts
|
| 188 |
+
R = camera_to_world[:, :, :3, :3] # Shape: (f, b, 3, 3)
|
| 189 |
+
T = camera_to_world[:, :, :3, 3] # Shape: (f, b, 3)
|
| 190 |
+
|
| 191 |
+
# Initialize an identity matrix for the output
|
| 192 |
+
world_to_camera = torch.eye(4, device=camera_to_world.device).unsqueeze(0).unsqueeze(0)
|
| 193 |
+
world_to_camera = world_to_camera.repeat(camera_to_world.size(0), camera_to_world.size(1), 1, 1) # Shape: (f, b, 4, 4)
|
| 194 |
+
|
| 195 |
+
# Compute the rotation (transpose of R)
|
| 196 |
+
world_to_camera[:, :, :3, :3] = R.transpose(2, 3)
|
| 197 |
+
|
| 198 |
+
# Compute the translation (-R^T * T)
|
| 199 |
+
world_to_camera[:, :, :3, 3] = -torch.matmul(R.transpose(2, 3), T.unsqueeze(-1)).squeeze(-1)
|
| 200 |
+
|
| 201 |
+
return world_to_camera.to(camera_to_world.dtype)
|
| 202 |
+
|
| 203 |
+
def convert_to_plucker(poses, curr_frame, focal_length, image_width, image_height):
|
| 204 |
+
|
| 205 |
+
intrinsic = np.asarray([focal_length * image_width,
|
| 206 |
+
focal_length * image_height,
|
| 207 |
+
0.5 * image_width,
|
| 208 |
+
0.5 * image_height], dtype=np.float32)
|
| 209 |
+
|
| 210 |
+
c2ws = get_relative_pose(poses, zero_first_frame_scale=curr_frame)
|
| 211 |
+
c2ws = rearrange(c2ws, "t b m n -> b t m n")
|
| 212 |
+
|
| 213 |
+
K = torch.as_tensor(intrinsic, device=poses.device, dtype=poses.dtype).repeat(c2ws.shape[0],c2ws.shape[1],1) # [B, F, 4]
|
| 214 |
+
plucker_embedding = ray_condition(K, c2ws, image_height, image_width, device=c2ws.device)
|
| 215 |
+
plucker_embedding = rearrange(plucker_embedding, "b t h w d -> t b h w d").contiguous()
|
| 216 |
+
|
| 217 |
+
return plucker_embedding
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def get_relative_pose(abs_c2ws, zero_first_frame_scale):
|
| 221 |
+
abs_w2cs = camera_to_world_to_world_to_camera(abs_c2ws)
|
| 222 |
+
target_cam_c2w = torch.tensor([
|
| 223 |
+
[1, 0, 0, 0],
|
| 224 |
+
[0, 1, 0, 0],
|
| 225 |
+
[0, 0, 1, 0],
|
| 226 |
+
[0, 0, 0, 1]
|
| 227 |
+
]).to(abs_c2ws.device).to(abs_c2ws.dtype)
|
| 228 |
+
abs2rel = target_cam_c2w @ abs_w2cs[zero_first_frame_scale]
|
| 229 |
+
ret_poses = [abs2rel @ abs_c2w for abs_c2w in abs_c2ws]
|
| 230 |
+
ret_poses = torch.stack(ret_poses)
|
| 231 |
+
return ret_poses
|
| 232 |
+
|
| 233 |
+
def ray_condition(K, c2w, H, W, device):
|
| 234 |
+
# c2w: B, V, 4, 4
|
| 235 |
+
# K: B, V, 4
|
| 236 |
+
|
| 237 |
+
B = K.shape[0]
|
| 238 |
+
|
| 239 |
+
j, i = custom_meshgrid(
|
| 240 |
+
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
|
| 241 |
+
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
|
| 242 |
+
)
|
| 243 |
+
i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
| 244 |
+
j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
|
| 245 |
+
|
| 246 |
+
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
|
| 247 |
+
|
| 248 |
+
zs = torch.ones_like(i, device=device, dtype=c2w.dtype) # [B, HxW]
|
| 249 |
+
xs = -(i - cx) / fx * zs
|
| 250 |
+
ys = -(j - cy) / fy * zs
|
| 251 |
+
|
| 252 |
+
zs = zs.expand_as(ys)
|
| 253 |
+
|
| 254 |
+
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
|
| 255 |
+
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
|
| 256 |
+
|
| 257 |
+
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
|
| 258 |
+
rays_o = c2w[..., :3, 3] # B, V, 3
|
| 259 |
+
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
|
| 260 |
+
# c2w @ dirctions
|
| 261 |
+
rays_dxo = torch.linalg.cross(rays_o, rays_d)
|
| 262 |
+
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
|
| 263 |
+
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
|
| 264 |
+
|
| 265 |
+
return plucker
|
| 266 |
+
|
| 267 |
+
def random_transform(tensor):
|
| 268 |
+
"""
|
| 269 |
+
Apply the same random translation, rotation, and scaling to all frames in the batch.
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
tensor (torch.Tensor): Input tensor of shape (F, B, 3, H, W).
|
| 273 |
+
|
| 274 |
+
Returns:
|
| 275 |
+
torch.Tensor: Transformed tensor of shape (F, B, 3, H, W).
|
| 276 |
+
"""
|
| 277 |
+
if tensor.ndim != 5:
|
| 278 |
+
raise ValueError("Input tensor must have shape (F, B, 3, H, W)")
|
| 279 |
+
|
| 280 |
+
F, B, C, H, W = tensor.shape
|
| 281 |
+
|
| 282 |
+
# Generate random transformation parameters
|
| 283 |
+
max_translate = 0.2 # Translate up to 20% of width/height
|
| 284 |
+
max_rotate = 30 # Rotate up to 30 degrees
|
| 285 |
+
max_scale = 0.2 # Scale change by up to +/- 20%
|
| 286 |
+
|
| 287 |
+
translate_x = random.uniform(-max_translate, max_translate) * W
|
| 288 |
+
translate_y = random.uniform(-max_translate, max_translate) * H
|
| 289 |
+
rotate_angle = random.uniform(-max_rotate, max_rotate)
|
| 290 |
+
scale_factor = 1 + random.uniform(-max_scale, max_scale)
|
| 291 |
+
|
| 292 |
+
# Apply the same transformation to all frames and batches
|
| 293 |
+
|
| 294 |
+
tensor = tensor.reshape(F*B, C, H, W)
|
| 295 |
+
transformed_tensor = TF.affine(
|
| 296 |
+
tensor,
|
| 297 |
+
angle=rotate_angle,
|
| 298 |
+
translate=(translate_x, translate_y),
|
| 299 |
+
scale=scale_factor,
|
| 300 |
+
shear=(0, 0),
|
| 301 |
+
interpolation=InterpolationMode.BILINEAR,
|
| 302 |
+
fill=0
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
transformed_tensor = transformed_tensor.reshape(F, B, C, H, W)
|
| 306 |
+
return transformed_tensor
|
| 307 |
+
|
| 308 |
+
def save_tensor_as_png(tensor, file_path):
|
| 309 |
+
"""
|
| 310 |
+
Save a 3*H*W tensor as a PNG image.
|
| 311 |
+
|
| 312 |
+
Args:
|
| 313 |
+
tensor (torch.Tensor): Input tensor of shape (3, H, W).
|
| 314 |
+
file_path (str): Path to save the PNG file.
|
| 315 |
+
"""
|
| 316 |
+
if tensor.ndim != 3 or tensor.shape[0] != 3:
|
| 317 |
+
raise ValueError("Input tensor must have shape (3, H, W)")
|
| 318 |
+
|
| 319 |
+
# Convert tensor to PIL Image
|
| 320 |
+
image = TF.to_pil_image(tensor)
|
| 321 |
+
|
| 322 |
+
# Save image
|
| 323 |
+
image.save(file_path)
|
| 324 |
+
|
| 325 |
+
class BaseVideoDiTMinecraft(DiffusionForcingBase):
|
| 326 |
+
"""
|
| 327 |
+
Video generation for MineCraft with memory.
|
| 328 |
+
"""
|
| 329 |
+
|
| 330 |
+
def __init__(self, cfg: DictConfig):
|
| 331 |
+
"""
|
| 332 |
+
Initialize the base video-DiT Minecraft class with the given configuration.
|
| 333 |
+
|
| 334 |
+
Args:
|
| 335 |
+
cfg (DictConfig): Configuration object.
|
| 336 |
+
"""
|
| 337 |
+
self.n_tokens = cfg.n_frames // cfg.frame_stack # number of max tokens for the model
|
| 338 |
+
self.n_frames = cfg.n_frames
|
| 339 |
+
if hasattr(cfg, "n_tokens"):
|
| 340 |
+
self.n_tokens = cfg.n_tokens // cfg.frame_stack
|
| 341 |
+
self.memory_condition_length = cfg.memory_condition_length
|
| 342 |
+
self.pose_cond_dim = getattr(cfg, "pose_cond_dim", 5)
|
| 343 |
+
|
| 344 |
+
self.use_plucker = getattr(cfg, "use_plucker", True)
|
| 345 |
+
self.relative_embedding = getattr(cfg, "relative_embedding", True)
|
| 346 |
+
self.state_embed_only_on_qk = getattr(cfg, "state_embed_only_on_qk", True)
|
| 347 |
+
self.use_memory_attention = getattr(cfg, "use_memory_attention", True)
|
| 348 |
+
self.add_timestamp_embedding = getattr(cfg, "add_timestamp_embedding", True)
|
| 349 |
+
self.ref_mode = getattr(cfg, "ref_mode", 'sequential')
|
| 350 |
+
self.log_curve = getattr(cfg, "log_curve", False)
|
| 351 |
+
self.focal_length = getattr(cfg, "focal_length", 0.35)
|
| 352 |
+
self.log_video = cfg.log_video
|
| 353 |
+
self.save_local = getattr(cfg, "save_local", True)
|
| 354 |
+
self.local_save_dir = getattr(cfg, "local_save_dir", None)
|
| 355 |
+
self.lpips_batch_size = getattr(cfg, "lpips_batch_size", 16)
|
| 356 |
+
self.next_frame_length = getattr(cfg, "next_frame_length", 1)
|
| 357 |
+
self.require_pose_prediction = getattr(cfg, "require_pose_prediction", False)
|
| 358 |
+
|
| 359 |
+
super().__init__(cfg)
|
| 360 |
+
|
| 361 |
+
def _build_model(self):
|
| 362 |
+
|
| 363 |
+
self.diffusion_model = Diffusion(
|
| 364 |
+
reference_length=self.memory_condition_length,
|
| 365 |
+
x_shape=self.x_stacked_shape,
|
| 366 |
+
action_cond_dim=self.action_cond_dim,
|
| 367 |
+
pose_cond_dim=self.pose_cond_dim,
|
| 368 |
+
is_causal=self.causal,
|
| 369 |
+
cfg=self.cfg.diffusion,
|
| 370 |
+
is_dit=True,
|
| 371 |
+
use_plucker=self.use_plucker,
|
| 372 |
+
relative_embedding=self.relative_embedding,
|
| 373 |
+
state_embed_only_on_qk=self.state_embed_only_on_qk,
|
| 374 |
+
use_memory_attention=self.use_memory_attention,
|
| 375 |
+
add_timestamp_embedding=self.add_timestamp_embedding,
|
| 376 |
+
ref_mode=self.ref_mode
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
self.validation_lpips_model = LearnedPerceptualImagePatchSimilarity()
|
| 380 |
+
vae = VAE_models["vit-l-20-shallow-encoder"]()
|
| 381 |
+
self.vae = vae.eval()
|
| 382 |
+
|
| 383 |
+
if self.require_pose_prediction:
|
| 384 |
+
self.pose_prediction_model = PosePredictionNet()
|
| 385 |
+
|
| 386 |
+
def _generate_noise_levels(self, xs: torch.Tensor, masks = None) -> torch.Tensor:
|
| 387 |
+
"""
|
| 388 |
+
Generate noise levels for training.
|
| 389 |
+
"""
|
| 390 |
+
num_frames, batch_size, *_ = xs.shape
|
| 391 |
+
match self.cfg.noise_level:
|
| 392 |
+
case "random_all": # entirely random noise levels
|
| 393 |
+
noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
|
| 394 |
+
case "same":
|
| 395 |
+
noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
|
| 396 |
+
noise_levels[1:] = noise_levels[0]
|
| 397 |
+
|
| 398 |
+
if masks is not None:
|
| 399 |
+
# for frames that are not available, treat as full noise
|
| 400 |
+
discard = torch.all(~rearrange(masks.bool(), "(t fs) b -> t b fs", fs=self.frame_stack), -1)
|
| 401 |
+
noise_levels = torch.where(discard, torch.full_like(noise_levels, self.timesteps - 1), noise_levels)
|
| 402 |
+
|
| 403 |
+
return noise_levels
|
| 404 |
+
|
| 405 |
+
def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
|
| 406 |
+
"""
|
| 407 |
+
Perform a single training step.
|
| 408 |
+
|
| 409 |
+
This function processes the input batch,
|
| 410 |
+
encodes the input frames, generates noise levels, and computes the loss using the diffusion model.
|
| 411 |
+
|
| 412 |
+
Args:
|
| 413 |
+
batch: Input batch of data containing frames, conditions, poses, etc.
|
| 414 |
+
batch_idx: Index of the current batch.
|
| 415 |
+
|
| 416 |
+
Returns:
|
| 417 |
+
dict: A dictionary containing the training loss.
|
| 418 |
+
"""
|
| 419 |
+
xs, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
|
| 420 |
+
|
| 421 |
+
if self.use_plucker:
|
| 422 |
+
if self.relative_embedding:
|
| 423 |
+
input_pose_condition = []
|
| 424 |
+
frame_idx_list = []
|
| 425 |
+
for i in range(self.n_frames):
|
| 426 |
+
input_pose_condition.append(
|
| 427 |
+
convert_to_plucker(
|
| 428 |
+
torch.cat([c2w_mat[i:i + 1], c2w_mat[-self.memory_condition_length:]]).clone(),
|
| 429 |
+
0,
|
| 430 |
+
focal_length=self.focal_length,
|
| 431 |
+
image_height=xs.shape[-2],image_width=xs.shape[-1]
|
| 432 |
+
).to(xs.dtype)
|
| 433 |
+
)
|
| 434 |
+
frame_idx_list.append(
|
| 435 |
+
torch.cat([
|
| 436 |
+
frame_idx[i:i + 1] - frame_idx[i:i + 1],
|
| 437 |
+
frame_idx[-self.memory_condition_length:] - frame_idx[i:i + 1]
|
| 438 |
+
]).clone()
|
| 439 |
+
)
|
| 440 |
+
input_pose_condition = torch.cat(input_pose_condition)
|
| 441 |
+
frame_idx_list = torch.cat(frame_idx_list)
|
| 442 |
+
else:
|
| 443 |
+
input_pose_condition = convert_to_plucker(
|
| 444 |
+
c2w_mat, 0, focal_length=self.focal_length
|
| 445 |
+
).to(xs.dtype)
|
| 446 |
+
frame_idx_list = frame_idx
|
| 447 |
+
else:
|
| 448 |
+
input_pose_condition = pose_conditions.to(xs.dtype)
|
| 449 |
+
frame_idx_list = None
|
| 450 |
+
|
| 451 |
+
xs = self.encode(xs)
|
| 452 |
+
|
| 453 |
+
noise_levels = self._generate_noise_levels(xs)
|
| 454 |
+
|
| 455 |
+
if self.memory_condition_length:
|
| 456 |
+
noise_levels[-self.memory_condition_length:] = self.diffusion_model.stabilization_level
|
| 457 |
+
conditions[-self.memory_condition_length:] *= 0
|
| 458 |
+
|
| 459 |
+
_, loss = self.diffusion_model(
|
| 460 |
+
xs,
|
| 461 |
+
conditions,
|
| 462 |
+
input_pose_condition,
|
| 463 |
+
noise_levels=noise_levels,
|
| 464 |
+
reference_length=self.memory_condition_length,
|
| 465 |
+
frame_idx=frame_idx_list
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
if self.memory_condition_length:
|
| 469 |
+
loss = loss[:-self.memory_condition_length]
|
| 470 |
+
|
| 471 |
+
loss = self.reweight_loss(loss, None)
|
| 472 |
+
|
| 473 |
+
if batch_idx % 20 == 0:
|
| 474 |
+
self.log("training/loss", loss.cpu())
|
| 475 |
+
|
| 476 |
+
return {"loss": loss}
|
| 477 |
+
|
| 478 |
+
def on_validation_epoch_end(self, namespace="validation") -> None:
|
| 479 |
+
if not self.validation_step_outputs:
|
| 480 |
+
return
|
| 481 |
+
|
| 482 |
+
xs_pred = []
|
| 483 |
+
xs = []
|
| 484 |
+
for pred, gt in self.validation_step_outputs:
|
| 485 |
+
xs_pred.append(pred)
|
| 486 |
+
xs.append(gt)
|
| 487 |
+
|
| 488 |
+
xs_pred = torch.cat(xs_pred, 1)
|
| 489 |
+
if gt is not None:
|
| 490 |
+
xs = torch.cat(xs, 1)
|
| 491 |
+
else:
|
| 492 |
+
xs = None
|
| 493 |
+
|
| 494 |
+
if self.logger and self.log_video:
|
| 495 |
+
log_video(
|
| 496 |
+
xs_pred,
|
| 497 |
+
xs,
|
| 498 |
+
step=None if namespace == "test" else self.global_step,
|
| 499 |
+
namespace=namespace + "_vis",
|
| 500 |
+
context_frames=self.context_frames,
|
| 501 |
+
logger=self.logger.experiment,
|
| 502 |
+
save_local=self.save_local,
|
| 503 |
+
local_save_dir=self.local_save_dir,
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
if xs is not None:
|
| 507 |
+
# Move data to the same device as LPIPS model for metric calculation
|
| 508 |
+
device = next(self.validation_lpips_model.parameters()).device
|
| 509 |
+
xs_pred_device = xs_pred.to(device)
|
| 510 |
+
xs_device = xs.to(device)
|
| 511 |
+
|
| 512 |
+
metric_dict = get_validation_metrics_for_videos(
|
| 513 |
+
xs_pred_device, xs_device,
|
| 514 |
+
lpips_model=self.validation_lpips_model,
|
| 515 |
+
lpips_batch_size=self.lpips_batch_size)
|
| 516 |
+
|
| 517 |
+
self.log_dict(
|
| 518 |
+
{"mse": metric_dict['mse'],
|
| 519 |
+
"psnr": metric_dict['psnr'],
|
| 520 |
+
"lpips": metric_dict['lpips']},
|
| 521 |
+
sync_dist=True
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
if self.log_curve:
|
| 525 |
+
psnr_values = metric_dict['frame_wise_psnr'].cpu().tolist()
|
| 526 |
+
frames = list(range(len(psnr_values)))
|
| 527 |
+
line_plot = wandb.plot.line_series(
|
| 528 |
+
xs = frames,
|
| 529 |
+
ys = [psnr_values],
|
| 530 |
+
keys = ["PSNR"],
|
| 531 |
+
title = "Frame-wise PSNR",
|
| 532 |
+
xname = "Frame index"
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
self.logger.experiment.log({"frame_wise_psnr_plot": line_plot})
|
| 536 |
+
|
| 537 |
+
self.validation_step_outputs.clear()
|
| 538 |
+
|
| 539 |
+
def _preprocess_batch(self, batch):
|
| 540 |
+
|
| 541 |
+
xs, conditions, pose_conditions, frame_index = batch
|
| 542 |
+
|
| 543 |
+
if self.action_cond_dim:
|
| 544 |
+
conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1)
|
| 545 |
+
conditions = rearrange(conditions, "b t d -> t b d").contiguous()
|
| 546 |
+
else:
|
| 547 |
+
raise NotImplementedError("Only support external cond.")
|
| 548 |
+
|
| 549 |
+
pose_conditions = rearrange(pose_conditions, "b t d -> t b d").contiguous()
|
| 550 |
+
c2w_mat = euler_to_camera_to_world_matrix(pose_conditions)
|
| 551 |
+
xs = rearrange(xs, "b t c ... -> t b c ...").contiguous()
|
| 552 |
+
frame_index = rearrange(frame_index, "b t -> t b").contiguous()
|
| 553 |
+
|
| 554 |
+
return xs, conditions, pose_conditions, c2w_mat, frame_index
|
| 555 |
+
|
| 556 |
+
def encode(self, x):
|
| 557 |
+
# vae encoding
|
| 558 |
+
T = x.shape[0]
|
| 559 |
+
H, W = x.shape[-2:]
|
| 560 |
+
scaling_factor = 0.07843137255
|
| 561 |
+
|
| 562 |
+
x = rearrange(x, "t b c h w -> (t b) c h w")
|
| 563 |
+
with torch.no_grad():
|
| 564 |
+
x = self.vae.encode(x * 2 - 1).mean * scaling_factor
|
| 565 |
+
x = rearrange(x, "(t b) (h w) c -> t b c h w", t=T, h=H // self.vae.patch_size, w=W // self.vae.patch_size)
|
| 566 |
+
return x
|
| 567 |
+
|
| 568 |
+
def decode(self, x):
|
| 569 |
+
total_frames = x.shape[0]
|
| 570 |
+
scaling_factor = 0.07843137255
|
| 571 |
+
x = rearrange(x, "t b c h w -> (t b) (h w) c")
|
| 572 |
+
with torch.no_grad():
|
| 573 |
+
x = (self.vae.decode(x / scaling_factor) + 1) / 2
|
| 574 |
+
x = rearrange(x, "(t b) c h w-> t b c h w", t=total_frames)
|
| 575 |
+
return x
|
| 576 |
+
|
| 577 |
+
def _generate_condition_indices(self, curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, horizon):
|
| 578 |
+
"""
|
| 579 |
+
Generate indices for condition similarity based on the current frame and pose conditions.
|
| 580 |
+
"""
|
| 581 |
+
if curr_frame < memory_condition_length:
|
| 582 |
+
random_idx = [i for i in range(curr_frame)] + [0] * (memory_condition_length - curr_frame)
|
| 583 |
+
random_idx = np.repeat(np.array(random_idx)[:, None], xs_pred.shape[1], -1)
|
| 584 |
+
else:
|
| 585 |
+
# Generate points in a sphere and filter based on field of view
|
| 586 |
+
num_samples = 10000
|
| 587 |
+
radius = 30
|
| 588 |
+
points = generate_points_in_sphere(num_samples, radius).to(pose_conditions.device)
|
| 589 |
+
points = points[:, None].repeat(1, pose_conditions.shape[1], 1)
|
| 590 |
+
points += pose_conditions[curr_frame, :, :3][None]
|
| 591 |
+
fov_half_h = torch.tensor(105 / 2, device=pose_conditions.device)
|
| 592 |
+
fov_half_v = torch.tensor(75 / 2, device=pose_conditions.device)
|
| 593 |
+
|
| 594 |
+
# in_fov1 = is_inside_fov_3d_hv(
|
| 595 |
+
# points, pose_conditions[curr_frame, :, :3],
|
| 596 |
+
# pose_conditions[curr_frame, :, -2], pose_conditions[curr_frame, :, -1],
|
| 597 |
+
# fov_half_h, fov_half_v
|
| 598 |
+
# )
|
| 599 |
+
|
| 600 |
+
in_fov1 = torch.stack([
|
| 601 |
+
is_inside_fov_3d_hv(points, pc[:, :3], pc[:, -2], pc[:, -1], fov_half_h, fov_half_v)
|
| 602 |
+
for pc in pose_conditions[curr_frame:curr_frame+horizon]
|
| 603 |
+
])
|
| 604 |
+
|
| 605 |
+
in_fov1 = torch.sum(in_fov1, 0) > 0
|
| 606 |
+
|
| 607 |
+
# Compute overlap ratios and select indices
|
| 608 |
+
in_fov_list = torch.stack([
|
| 609 |
+
is_inside_fov_3d_hv(points, pc[:, :3], pc[:, -2], pc[:, -1], fov_half_h, fov_half_v)
|
| 610 |
+
for pc in pose_conditions[:curr_frame]
|
| 611 |
+
])
|
| 612 |
+
|
| 613 |
+
random_idx = []
|
| 614 |
+
for _ in range(memory_condition_length):
|
| 615 |
+
overlap_ratio = ((in_fov1.bool() & in_fov_list).sum(1)) / in_fov1.sum()
|
| 616 |
+
|
| 617 |
+
confidence = overlap_ratio + (curr_frame - frame_idx[:curr_frame]) / curr_frame * (-0.2)
|
| 618 |
+
|
| 619 |
+
if len(random_idx) > 0:
|
| 620 |
+
confidence[torch.cat(random_idx)] = -1e10
|
| 621 |
+
_, r_idx = torch.topk(confidence, k=1, dim=0)
|
| 622 |
+
random_idx.append(r_idx[0])
|
| 623 |
+
|
| 624 |
+
# choice 1: directly remove overlapping region
|
| 625 |
+
occupied_mask = in_fov_list[r_idx[0, range(in_fov1.shape[-1])], :, range(in_fov1.shape[-1])].permute(1,0)
|
| 626 |
+
in_fov1 = in_fov1 & ~occupied_mask
|
| 627 |
+
|
| 628 |
+
# choice 2: apply similarity filter
|
| 629 |
+
# cos_sim = F.cosine_similarity(xs_pred.to(r_idx.device)[r_idx[:, range(in_fov1.shape[1])],
|
| 630 |
+
# range(in_fov1.shape[1])], xs_pred.to(r_idx.device)[:curr_frame], dim=2)
|
| 631 |
+
# cos_sim = cos_sim.mean((-2,-1))
|
| 632 |
+
|
| 633 |
+
# mask_sim = cos_sim>0.9
|
| 634 |
+
# in_fov_list = in_fov_list & ~mask_sim[:,None].to(in_fov_list.device)
|
| 635 |
+
|
| 636 |
+
random_idx = torch.stack(random_idx).cpu()
|
| 637 |
+
|
| 638 |
+
return random_idx
|
| 639 |
+
|
| 640 |
+
def _prepare_conditions(self,
|
| 641 |
+
start_frame, curr_frame, horizon, conditions,
|
| 642 |
+
pose_conditions, c2w_mat, frame_idx, random_idx,
|
| 643 |
+
image_width, image_height):
|
| 644 |
+
"""
|
| 645 |
+
Prepare input conditions and pose conditions for sampling.
|
| 646 |
+
"""
|
| 647 |
+
|
| 648 |
+
padding = torch.zeros((len(random_idx),) + conditions.shape[1:], device=conditions.device, dtype=conditions.dtype)
|
| 649 |
+
input_condition = torch.cat([conditions[start_frame:curr_frame + horizon], padding], dim=0)
|
| 650 |
+
|
| 651 |
+
batch_size = conditions.shape[1]
|
| 652 |
+
|
| 653 |
+
if self.use_plucker:
|
| 654 |
+
if self.relative_embedding:
|
| 655 |
+
frame_idx_list = []
|
| 656 |
+
input_pose_condition = []
|
| 657 |
+
for i in range(start_frame, curr_frame + horizon):
|
| 658 |
+
input_pose_condition.append(convert_to_plucker(torch.cat([c2w_mat[i:i+1],c2w_mat[random_idx[:,range(batch_size)], range(batch_size)]]).clone(), 0, focal_length=self.focal_length,
|
| 659 |
+
image_width=image_width, image_height=image_height).to(conditions.dtype))
|
| 660 |
+
frame_idx_list.append(torch.cat([frame_idx[i:i+1]-frame_idx[i:i+1], frame_idx[random_idx[:,range(batch_size)], range(batch_size)]-frame_idx[i:i+1]]))
|
| 661 |
+
input_pose_condition = torch.cat(input_pose_condition)
|
| 662 |
+
frame_idx_list = torch.cat(frame_idx_list)
|
| 663 |
+
|
| 664 |
+
else:
|
| 665 |
+
input_pose_condition = torch.cat([c2w_mat[start_frame : curr_frame + horizon], c2w_mat[random_idx[:,range(batch_size)], range(batch_size)]], dim=0).clone()
|
| 666 |
+
input_pose_condition = convert_to_plucker(input_pose_condition, 0, focal_length=self.focal_length)
|
| 667 |
+
frame_idx_list = None
|
| 668 |
+
else:
|
| 669 |
+
input_pose_condition = torch.cat([pose_conditions[start_frame : curr_frame + horizon], pose_conditions[random_idx[:,range(batch_size)], range(batch_size)]], dim=0).clone()
|
| 670 |
+
frame_idx_list = None
|
| 671 |
+
|
| 672 |
+
return input_condition, input_pose_condition, frame_idx_list
|
| 673 |
+
|
| 674 |
+
def _prepare_noise_levels(self, scheduling_matrix, m, curr_frame, batch_size, memory_condition_length):
|
| 675 |
+
"""
|
| 676 |
+
Prepare noise levels for the current sampling step.
|
| 677 |
+
"""
|
| 678 |
+
from_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m]))[:, None].repeat(batch_size, axis=1)
|
| 679 |
+
to_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m + 1]))[:, None].repeat(batch_size, axis=1)
|
| 680 |
+
if memory_condition_length:
|
| 681 |
+
from_noise_levels = np.concatenate([from_noise_levels, np.zeros((memory_condition_length, from_noise_levels.shape[-1]), dtype=np.int32)], axis=0)
|
| 682 |
+
to_noise_levels = np.concatenate([to_noise_levels, np.zeros((memory_condition_length, from_noise_levels.shape[-1]), dtype=np.int32)], axis=0)
|
| 683 |
+
from_noise_levels = torch.from_numpy(from_noise_levels).to(self.device)
|
| 684 |
+
to_noise_levels = torch.from_numpy(to_noise_levels).to(self.device)
|
| 685 |
+
return from_noise_levels, to_noise_levels
|
| 686 |
+
|
| 687 |
+
def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT:
|
| 688 |
+
"""
|
| 689 |
+
Perform a single validation step.
|
| 690 |
+
|
| 691 |
+
This function processes the input batch, encodes frames, generates predictions using a sliding window approach,
|
| 692 |
+
and handles condition similarity logic for sampling. The results are decoded and stored for evaluation.
|
| 693 |
+
|
| 694 |
+
Args:
|
| 695 |
+
batch: Input batch of data containing frames, conditions, poses, etc.
|
| 696 |
+
batch_idx: Index of the current batch.
|
| 697 |
+
namespace: Namespace for logging (default: "validation").
|
| 698 |
+
|
| 699 |
+
Returns:
|
| 700 |
+
None: Appends the predicted and ground truth frames to `self.validation_step_outputs`.
|
| 701 |
+
"""
|
| 702 |
+
# Preprocess the input batch
|
| 703 |
+
memory_condition_length = self.memory_condition_length
|
| 704 |
+
xs_raw, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
# Encode frames in chunks if necessary
|
| 708 |
+
total_frame = xs_raw.shape[0]
|
| 709 |
+
if total_frame > 10:
|
| 710 |
+
xs = torch.cat([
|
| 711 |
+
self.encode(xs_raw[int(total_frame * i / 10):int(total_frame * (i + 1) / 10)]).cpu()
|
| 712 |
+
for i in range(10)
|
| 713 |
+
])
|
| 714 |
+
else:
|
| 715 |
+
xs = self.encode(xs_raw).cpu()
|
| 716 |
+
|
| 717 |
+
n_frames, batch_size, *_ = xs.shape
|
| 718 |
+
curr_frame = 0
|
| 719 |
+
|
| 720 |
+
# Initialize context frames
|
| 721 |
+
n_context_frames = self.context_frames // self.frame_stack
|
| 722 |
+
xs_pred = xs[:n_context_frames].clone()
|
| 723 |
+
curr_frame += n_context_frames
|
| 724 |
+
|
| 725 |
+
# Progress bar for sampling
|
| 726 |
+
pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
|
| 727 |
+
|
| 728 |
+
while curr_frame < n_frames:
|
| 729 |
+
# Determine the horizon for the current chunk
|
| 730 |
+
horizon = min(n_frames - curr_frame, self.chunk_size) if self.chunk_size > 0 else n_frames - curr_frame
|
| 731 |
+
assert horizon <= self.n_tokens, "Horizon exceeds the number of tokens."
|
| 732 |
+
|
| 733 |
+
# Generate scheduling matrix and initialize noise
|
| 734 |
+
scheduling_matrix = self._generate_scheduling_matrix(horizon)
|
| 735 |
+
chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:]))
|
| 736 |
+
chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise).to(xs_pred.device)
|
| 737 |
+
xs_pred = torch.cat([xs_pred, chunk], 0)
|
| 738 |
+
|
| 739 |
+
# Sliding window: only input the last `n_tokens` frames
|
| 740 |
+
start_frame = max(0, curr_frame + horizon - self.n_tokens)
|
| 741 |
+
pbar.set_postfix({"start": start_frame, "end": curr_frame + horizon})
|
| 742 |
+
|
| 743 |
+
# Handle condition similarity logic
|
| 744 |
+
if memory_condition_length:
|
| 745 |
+
random_idx = self._generate_condition_indices(
|
| 746 |
+
curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, horizon
|
| 747 |
+
)
|
| 748 |
+
|
| 749 |
+
xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
|
| 750 |
+
|
| 751 |
+
# Prepare input conditions and pose conditions
|
| 752 |
+
input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions(
|
| 753 |
+
start_frame, curr_frame, horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx,
|
| 754 |
+
image_width=xs_raw.shape[-1], image_height=xs_raw.shape[-2]
|
| 755 |
+
)
|
| 756 |
+
|
| 757 |
+
# Perform sampling for each step in the scheduling matrix
|
| 758 |
+
for m in range(scheduling_matrix.shape[0] - 1):
|
| 759 |
+
from_noise_levels, to_noise_levels = self._prepare_noise_levels(
|
| 760 |
+
scheduling_matrix, m, curr_frame, batch_size, memory_condition_length
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
xs_pred[start_frame:] = self.diffusion_model.sample_step(
|
| 764 |
+
xs_pred[start_frame:].to(input_condition.device),
|
| 765 |
+
input_condition,
|
| 766 |
+
input_pose_condition,
|
| 767 |
+
from_noise_levels[start_frame:],
|
| 768 |
+
to_noise_levels[start_frame:],
|
| 769 |
+
current_frame=curr_frame,
|
| 770 |
+
mode="validation",
|
| 771 |
+
reference_length=memory_condition_length,
|
| 772 |
+
frame_idx=frame_idx_list
|
| 773 |
+
).cpu()
|
| 774 |
+
|
| 775 |
+
# Remove condition similarity frames if applicable
|
| 776 |
+
if memory_condition_length:
|
| 777 |
+
xs_pred = xs_pred[:-memory_condition_length]
|
| 778 |
+
|
| 779 |
+
curr_frame += horizon
|
| 780 |
+
pbar.update(horizon)
|
| 781 |
+
|
| 782 |
+
# Decode predictions and ground truth
|
| 783 |
+
xs_pred = self.decode(xs_pred[n_context_frames:].to(conditions.device))
|
| 784 |
+
xs_decode = self.decode(xs[n_context_frames:].to(conditions.device))
|
| 785 |
+
|
| 786 |
+
# Store results for evaluation (move to CPU to save GPU memory)
|
| 787 |
+
self.validation_step_outputs.append((xs_pred.detach().cpu(), xs_decode.detach().cpu()))
|
| 788 |
+
return
|
| 789 |
+
|
| 790 |
+
@torch.no_grad()
|
| 791 |
+
def interactive(self, first_frame, new_actions, first_pose, device,
|
| 792 |
+
memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx):
|
| 793 |
+
|
| 794 |
+
memory_condition_length = self.memory_condition_length
|
| 795 |
+
|
| 796 |
+
if memory_latent_frames is None:
|
| 797 |
+
first_frame = torch.from_numpy(first_frame)
|
| 798 |
+
new_actions = torch.from_numpy(new_actions)
|
| 799 |
+
first_pose = torch.from_numpy(first_pose)
|
| 800 |
+
first_frame_encode = self.encode(first_frame[None, None].to(device))
|
| 801 |
+
memory_latent_frames = first_frame_encode.cpu()
|
| 802 |
+
memory_actions = new_actions[None, None].to(device)
|
| 803 |
+
memory_poses = first_pose[None, None].to(device)
|
| 804 |
+
new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
|
| 805 |
+
memory_c2w = new_c2w_mat[None, None].to(device)
|
| 806 |
+
memory_frame_idx = torch.tensor([[0]]).to(device)
|
| 807 |
+
return first_frame.cpu().numpy(), memory_latent_frames.cpu().numpy(), memory_actions.cpu().numpy(), memory_poses.cpu().numpy(), memory_c2w.cpu().numpy(), memory_frame_idx.cpu().numpy()
|
| 808 |
+
else:
|
| 809 |
+
memory_latent_frames = torch.from_numpy(memory_latent_frames)
|
| 810 |
+
memory_actions = torch.from_numpy(memory_actions).to(device)
|
| 811 |
+
memory_poses = torch.from_numpy(memory_poses).to(device)
|
| 812 |
+
memory_c2w = torch.from_numpy(memory_c2w).to(device)
|
| 813 |
+
memory_frame_idx = torch.from_numpy(memory_frame_idx).to(device)
|
| 814 |
+
new_actions = new_actions.to(device)
|
| 815 |
+
|
| 816 |
+
curr_frame = 0
|
| 817 |
+
batch_size = 1
|
| 818 |
+
horizon = self.next_frame_length
|
| 819 |
+
n_frames = curr_frame + horizon
|
| 820 |
+
# context
|
| 821 |
+
n_context_frames = len(memory_latent_frames)
|
| 822 |
+
xs_pred = memory_latent_frames[:n_context_frames].clone()
|
| 823 |
+
curr_frame += n_context_frames
|
| 824 |
+
|
| 825 |
+
pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
|
| 826 |
+
|
| 827 |
+
new_pose_condition_list = []
|
| 828 |
+
last_frame = xs_pred[-1].clone()
|
| 829 |
+
last_pose_condition = memory_poses[-1].clone()
|
| 830 |
+
curr_actions = new_actions.clone()
|
| 831 |
+
for hi in range(len(new_actions)):
|
| 832 |
+
last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
|
| 833 |
+
new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None, hi], last_pose_condition)
|
| 834 |
+
new_pose_condition_offset[:,3:] = torch.round(new_pose_condition_offset[:,3:])
|
| 835 |
+
new_pose_condition = last_pose_condition + new_pose_condition_offset
|
| 836 |
+
new_pose_condition[:,3:] = new_pose_condition[:,3:] * 15
|
| 837 |
+
new_pose_condition[:,3:] %= 360
|
| 838 |
+
last_pose_condition = new_pose_condition.clone()
|
| 839 |
+
new_pose_condition_list.append(new_pose_condition[None])
|
| 840 |
+
new_pose_condition_list = torch.cat(new_pose_condition_list, 0)
|
| 841 |
+
|
| 842 |
+
ai = 0
|
| 843 |
+
while ai < len(new_actions):
|
| 844 |
+
next_horizon = min(horizon, len(new_actions) - ai)
|
| 845 |
+
last_frame = xs_pred[-1].clone()
|
| 846 |
+
curr_actions = new_actions[ai:ai+next_horizon].clone()
|
| 847 |
+
|
| 848 |
+
new_pose_condition = new_pose_condition_list[ai:ai+next_horizon].clone()
|
| 849 |
+
|
| 850 |
+
new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)
|
| 851 |
+
memory_poses = torch.cat([memory_poses, new_pose_condition])
|
| 852 |
+
memory_actions = torch.cat([memory_actions, curr_actions[:, None]])
|
| 853 |
+
memory_c2w = torch.cat([memory_c2w, new_c2w_mat])
|
| 854 |
+
new_indices = memory_frame_idx[-1,0] + torch.arange(next_horizon, device=memory_frame_idx.device) + 1
|
| 855 |
+
|
| 856 |
+
memory_frame_idx = torch.cat([memory_frame_idx, new_indices[:, None]])
|
| 857 |
+
|
| 858 |
+
conditions = memory_actions.clone()
|
| 859 |
+
pose_conditions = memory_poses.clone()
|
| 860 |
+
c2w_mat = memory_c2w .clone()
|
| 861 |
+
frame_idx = memory_frame_idx.clone()
|
| 862 |
+
|
| 863 |
+
# generation on frame
|
| 864 |
+
scheduling_matrix = self._generate_scheduling_matrix(next_horizon)
|
| 865 |
+
chunk = torch.randn((next_horizon, batch_size, *xs_pred.shape[2:])).to(xs_pred.device)
|
| 866 |
+
chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise)
|
| 867 |
+
|
| 868 |
+
xs_pred = torch.cat([xs_pred, chunk], 0)
|
| 869 |
+
|
| 870 |
+
# sliding window: only input the last n_tokens frames
|
| 871 |
+
start_frame = max(0, curr_frame - self.n_tokens)
|
| 872 |
+
|
| 873 |
+
pbar.set_postfix(
|
| 874 |
+
{
|
| 875 |
+
"start": start_frame,
|
| 876 |
+
"end": curr_frame + next_horizon,
|
| 877 |
+
}
|
| 878 |
+
)
|
| 879 |
+
|
| 880 |
+
# Handle condition similarity logic
|
| 881 |
+
if memory_condition_length:
|
| 882 |
+
random_idx = self._generate_condition_indices(
|
| 883 |
+
curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, next_horizon
|
| 884 |
+
)
|
| 885 |
+
|
| 886 |
+
# random_idx = np.unique(random_idx)[:, None]
|
| 887 |
+
# memory_condition_length = len(random_idx)
|
| 888 |
+
xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
|
| 889 |
+
|
| 890 |
+
# Prepare input conditions and pose conditions
|
| 891 |
+
input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions(
|
| 892 |
+
start_frame, curr_frame, next_horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx,
|
| 893 |
+
image_width=first_frame.shape[-1], image_height=first_frame.shape[-2]
|
| 894 |
+
)
|
| 895 |
+
|
| 896 |
+
# Perform sampling for each step in the scheduling matrix
|
| 897 |
+
for m in range(scheduling_matrix.shape[0] - 1):
|
| 898 |
+
from_noise_levels, to_noise_levels = self._prepare_noise_levels(
|
| 899 |
+
scheduling_matrix, m, curr_frame, batch_size, memory_condition_length
|
| 900 |
+
)
|
| 901 |
+
|
| 902 |
+
xs_pred[start_frame:] = self.diffusion_model.sample_step(
|
| 903 |
+
xs_pred[start_frame:].to(input_condition.device),
|
| 904 |
+
input_condition,
|
| 905 |
+
input_pose_condition,
|
| 906 |
+
from_noise_levels[start_frame:],
|
| 907 |
+
to_noise_levels[start_frame:],
|
| 908 |
+
current_frame=curr_frame,
|
| 909 |
+
mode="validation",
|
| 910 |
+
reference_length=memory_condition_length,
|
| 911 |
+
frame_idx=frame_idx_list
|
| 912 |
+
).cpu()
|
| 913 |
+
|
| 914 |
+
|
| 915 |
+
if memory_condition_length:
|
| 916 |
+
xs_pred = xs_pred[:-memory_condition_length]
|
| 917 |
+
|
| 918 |
+
curr_frame += next_horizon
|
| 919 |
+
pbar.update(next_horizon)
|
| 920 |
+
ai += next_horizon
|
| 921 |
+
|
| 922 |
+
memory_latent_frames = torch.cat([memory_latent_frames, xs_pred[n_context_frames:]])
|
| 923 |
+
xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
|
| 924 |
+
|
| 925 |
+
return xs_pred.cpu().numpy(), memory_latent_frames.cpu().numpy(), memory_actions.cpu().numpy(), \
|
| 926 |
+
memory_poses.cpu().numpy(), memory_c2w.cpu().numpy(), memory_frame_idx.cpu().numpy()
|
algorithms/worldmem/models/attention.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Based on https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/attention.py
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from collections import namedtuple
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
from .rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
class TemporalAxialAttention(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
dim: int,
|
| 18 |
+
heads: int,
|
| 19 |
+
dim_head: int,
|
| 20 |
+
reference_length: int,
|
| 21 |
+
rotary_emb: RotaryEmbedding,
|
| 22 |
+
is_causal: bool = True,
|
| 23 |
+
is_temporal_independent: bool = False,
|
| 24 |
+
use_domain_adapter = False
|
| 25 |
+
):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.inner_dim = dim_head * heads
|
| 28 |
+
self.heads = heads
|
| 29 |
+
self.head_dim = dim_head
|
| 30 |
+
self.inner_dim = dim_head * heads
|
| 31 |
+
self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
|
| 32 |
+
|
| 33 |
+
self.use_domain_adapter = use_domain_adapter
|
| 34 |
+
if self.use_domain_adapter:
|
| 35 |
+
lora_rank = 8
|
| 36 |
+
self.lora_A = nn.Linear(dim, lora_rank, bias=False)
|
| 37 |
+
self.lora_B = nn.Linear(lora_rank, self.inner_dim * 3, bias=False)
|
| 38 |
+
|
| 39 |
+
self.to_out = nn.Linear(self.inner_dim, dim)
|
| 40 |
+
|
| 41 |
+
self.rotary_emb = rotary_emb
|
| 42 |
+
self.is_causal = is_causal
|
| 43 |
+
self.is_temporal_independent = is_temporal_independent
|
| 44 |
+
|
| 45 |
+
self.reference_length = reference_length
|
| 46 |
+
|
| 47 |
+
def forward(self, x: torch.Tensor):
|
| 48 |
+
B, T, H, W, D = x.shape
|
| 49 |
+
|
| 50 |
+
# if T>=9:
|
| 51 |
+
# try:
|
| 52 |
+
# # x = torch.cat([x[:,:-1],x[:,16-T:17-T],x[:,-1:]], dim=1)
|
| 53 |
+
# x = torch.cat([x[:,16-T:17-T],x], dim=1)
|
| 54 |
+
# except:
|
| 55 |
+
# import pdb;pdb.set_trace()
|
| 56 |
+
# print("="*50)
|
| 57 |
+
# print(x.shape)
|
| 58 |
+
|
| 59 |
+
B, T, H, W, D = x.shape
|
| 60 |
+
|
| 61 |
+
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
| 62 |
+
|
| 63 |
+
if self.use_domain_adapter:
|
| 64 |
+
q_lora, k_lora, v_lora = self.lora_B(self.lora_A(x)).chunk(3, dim=-1)
|
| 65 |
+
q = q+q_lora
|
| 66 |
+
k = k+k_lora
|
| 67 |
+
v = v+v_lora
|
| 68 |
+
|
| 69 |
+
q = rearrange(q, "B T H W (h d) -> (B H W) h T d", h=self.heads)
|
| 70 |
+
k = rearrange(k, "B T H W (h d) -> (B H W) h T d", h=self.heads)
|
| 71 |
+
v = rearrange(v, "B T H W (h d) -> (B H W) h T d", h=self.heads)
|
| 72 |
+
|
| 73 |
+
q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
|
| 74 |
+
k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)
|
| 75 |
+
|
| 76 |
+
q, k, v = map(lambda t: t.contiguous(), (q, k, v))
|
| 77 |
+
|
| 78 |
+
if self.is_temporal_independent:
|
| 79 |
+
attn_bias = torch.ones((T, T), dtype=q.dtype, device=q.device)
|
| 80 |
+
attn_bias = attn_bias.masked_fill(attn_bias == 1, float('-inf'))
|
| 81 |
+
attn_bias[range(T), range(T)] = 0
|
| 82 |
+
elif self.is_causal:
|
| 83 |
+
attn_bias = torch.triu(torch.ones((T, T), dtype=q.dtype, device=q.device), diagonal=1)
|
| 84 |
+
attn_bias = attn_bias.masked_fill(attn_bias == 1, float('-inf'))
|
| 85 |
+
attn_bias[(T-self.reference_length):] = float('-inf')
|
| 86 |
+
attn_bias[range(T), range(T)] = 0
|
| 87 |
+
else:
|
| 88 |
+
attn_bias = None
|
| 89 |
+
|
| 90 |
+
try:
|
| 91 |
+
x = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_bias)
|
| 92 |
+
except:
|
| 93 |
+
import pdb;pdb.set_trace()
|
| 94 |
+
|
| 95 |
+
x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
|
| 96 |
+
x = x.to(q.dtype)
|
| 97 |
+
|
| 98 |
+
# linear proj
|
| 99 |
+
x = self.to_out(x)
|
| 100 |
+
|
| 101 |
+
# if T>=10:
|
| 102 |
+
# try:
|
| 103 |
+
# # x = torch.cat([x[:,:-2],x[:,-1:]], dim=1)
|
| 104 |
+
# x = x[:,1:]
|
| 105 |
+
# except:
|
| 106 |
+
# import pdb;pdb.set_trace()
|
| 107 |
+
# print(x.shape)
|
| 108 |
+
return x
|
| 109 |
+
|
| 110 |
+
class SpatialAxialAttention(nn.Module):
|
| 111 |
+
def __init__(
|
| 112 |
+
self,
|
| 113 |
+
dim: int,
|
| 114 |
+
heads: int,
|
| 115 |
+
dim_head: int,
|
| 116 |
+
rotary_emb: RotaryEmbedding,
|
| 117 |
+
use_domain_adapter = False
|
| 118 |
+
):
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.inner_dim = dim_head * heads
|
| 121 |
+
self.heads = heads
|
| 122 |
+
self.head_dim = dim_head
|
| 123 |
+
self.inner_dim = dim_head * heads
|
| 124 |
+
self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
|
| 125 |
+
self.use_domain_adapter = use_domain_adapter
|
| 126 |
+
if self.use_domain_adapter:
|
| 127 |
+
lora_rank = 8
|
| 128 |
+
self.lora_A = nn.Linear(dim, lora_rank, bias=False)
|
| 129 |
+
self.lora_B = nn.Linear(lora_rank, self.inner_dim * 3, bias=False)
|
| 130 |
+
|
| 131 |
+
self.to_out = nn.Linear(self.inner_dim, dim)
|
| 132 |
+
|
| 133 |
+
self.rotary_emb = rotary_emb
|
| 134 |
+
|
| 135 |
+
def forward(self, x: torch.Tensor):
|
| 136 |
+
B, T, H, W, D = x.shape
|
| 137 |
+
|
| 138 |
+
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
| 139 |
+
|
| 140 |
+
if self.use_domain_adapter:
|
| 141 |
+
q_lora, k_lora, v_lora = self.lora_B(self.lora_A(x)).chunk(3, dim=-1)
|
| 142 |
+
q = q+q_lora
|
| 143 |
+
k = k+k_lora
|
| 144 |
+
v = v+v_lora
|
| 145 |
+
|
| 146 |
+
q = rearrange(q, "B T H W (h d) -> (B T) h H W d", h=self.heads)
|
| 147 |
+
k = rearrange(k, "B T H W (h d) -> (B T) h H W d", h=self.heads)
|
| 148 |
+
v = rearrange(v, "B T H W (h d) -> (B T) h H W d", h=self.heads)
|
| 149 |
+
|
| 150 |
+
freqs = self.rotary_emb.get_axial_freqs(H, W)
|
| 151 |
+
q = apply_rotary_emb(freqs, q)
|
| 152 |
+
k = apply_rotary_emb(freqs, k)
|
| 153 |
+
|
| 154 |
+
# prepare for attn
|
| 155 |
+
q = rearrange(q, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
|
| 156 |
+
k = rearrange(k, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
|
| 157 |
+
v = rearrange(v, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
|
| 158 |
+
|
| 159 |
+
x = F.scaled_dot_product_attention(query=q, key=k, value=v, is_causal=False)
|
| 160 |
+
|
| 161 |
+
x = rearrange(x, "(B T) h (H W) d -> B T H W (h d)", B=B, H=H, W=W)
|
| 162 |
+
x = x.to(q.dtype)
|
| 163 |
+
|
| 164 |
+
# linear proj
|
| 165 |
+
x = self.to_out(x)
|
| 166 |
+
return x
|
| 167 |
+
|
| 168 |
+
class MemTemporalAxialAttention(nn.Module):
|
| 169 |
+
def __init__(
|
| 170 |
+
self,
|
| 171 |
+
dim: int,
|
| 172 |
+
heads: int,
|
| 173 |
+
dim_head: int,
|
| 174 |
+
rotary_emb: RotaryEmbedding,
|
| 175 |
+
is_causal: bool = True,
|
| 176 |
+
):
|
| 177 |
+
super().__init__()
|
| 178 |
+
self.inner_dim = dim_head * heads
|
| 179 |
+
self.heads = heads
|
| 180 |
+
self.head_dim = dim_head
|
| 181 |
+
self.inner_dim = dim_head * heads
|
| 182 |
+
self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
|
| 183 |
+
self.to_out = nn.Linear(self.inner_dim, dim)
|
| 184 |
+
|
| 185 |
+
self.rotary_emb = rotary_emb
|
| 186 |
+
self.is_causal = is_causal
|
| 187 |
+
|
| 188 |
+
self.reference_length = 3
|
| 189 |
+
|
| 190 |
+
def forward(self, x: torch.Tensor):
|
| 191 |
+
B, T, H, W, D = x.shape
|
| 192 |
+
|
| 193 |
+
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
q = rearrange(q, "B T H W (h d) -> (B H W) h T d", h=self.heads)
|
| 197 |
+
k = rearrange(k, "B T H W (h d) -> (B H W) h T d", h=self.heads)
|
| 198 |
+
v = rearrange(v, "B T H W (h d) -> (B H W) h T d", h=self.heads)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
|
| 203 |
+
# k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)
|
| 204 |
+
|
| 205 |
+
q, k, v = map(lambda t: t.contiguous(), (q, k, v))
|
| 206 |
+
|
| 207 |
+
# if T == 21000:
|
| 208 |
+
# # 手动计算缩放点积分数
|
| 209 |
+
# _, _, _, d_k = q.shape
|
| 210 |
+
# scores = torch.einsum("b h n d, b h m d -> b h n m", q, k) / (d_k ** 0.5) # Shape: (B, T_q, T_k)
|
| 211 |
+
|
| 212 |
+
# # 计算注意力图 (Attention Map)
|
| 213 |
+
# attention_map = F.softmax(scores, dim=-1) # Shape: (B, T_q, T_k)
|
| 214 |
+
# b_, h_, n_, m_ = attention_map.shape
|
| 215 |
+
# attention_map = attention_map.reshape(1, int(np.sqrt(b_/1)), int(np.sqrt(b_/1)), h_, n_, m_)
|
| 216 |
+
# attention_map = attention_map.mean(3)
|
| 217 |
+
|
| 218 |
+
# attn_bias = torch.zeros((T, T), dtype=q.dtype, device=q.device)
|
| 219 |
+
# T_origin = T - self.reference_length
|
| 220 |
+
# attn_bias[:T_origin, T_origin:] = 1
|
| 221 |
+
# attn_bias[range(T), range(T)] = 1
|
| 222 |
+
|
| 223 |
+
# attention_map = attention_map * attn_bias
|
| 224 |
+
|
| 225 |
+
# # print 注意力图
|
| 226 |
+
# import matplotlib.pyplot as plt
|
| 227 |
+
# fig, axes = plt.subplots(21000, 21000, figsize=(9, 9)) # 调整figsize以适配图像大小
|
| 228 |
+
|
| 229 |
+
# # 遍历3*3维度
|
| 230 |
+
# for i in range(21000):
|
| 231 |
+
# for j in range(21000):
|
| 232 |
+
# # 取出第(i, j)个子图像
|
| 233 |
+
# img = attention_map[0, :, :, i, j].cpu().numpy()
|
| 234 |
+
# axes[i, j].imshow(img, cmap='viridis') # 可以自定义cmap
|
| 235 |
+
# axes[i, j].axis('off') # 隐藏坐标轴
|
| 236 |
+
|
| 237 |
+
# # 调整子图间距
|
| 238 |
+
# plt.tight_layout()
|
| 239 |
+
# plt.savefig('attention_map.png')
|
| 240 |
+
# import pdb; pdb.set_trace()
|
| 241 |
+
# plt.close()
|
| 242 |
+
|
| 243 |
+
attn_bias = torch.zeros((T, T), dtype=q.dtype, device=q.device)
|
| 244 |
+
attn_bias = attn_bias.masked_fill(attn_bias == 0, float('-inf'))
|
| 245 |
+
T_origin = T - self.reference_length
|
| 246 |
+
attn_bias[:T_origin, T_origin:] = 0
|
| 247 |
+
attn_bias[range(T), range(T)] = 0
|
| 248 |
+
|
| 249 |
+
# if T==121000:
|
| 250 |
+
# import pdb;pdb.set_trace()
|
| 251 |
+
|
| 252 |
+
try:
|
| 253 |
+
x = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_bias)
|
| 254 |
+
except:
|
| 255 |
+
import pdb;pdb.set_trace()
|
| 256 |
+
|
| 257 |
+
x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
|
| 258 |
+
x = x.to(q.dtype)
|
| 259 |
+
|
| 260 |
+
# linear proj
|
| 261 |
+
x = self.to_out(x)
|
| 262 |
+
return x
|
| 263 |
+
|
| 264 |
+
class MemFullAttention(nn.Module):
|
| 265 |
+
def __init__(
|
| 266 |
+
self,
|
| 267 |
+
dim: int,
|
| 268 |
+
heads: int,
|
| 269 |
+
dim_head: int,
|
| 270 |
+
reference_length: int,
|
| 271 |
+
rotary_emb: RotaryEmbedding,
|
| 272 |
+
is_causal: bool = True
|
| 273 |
+
):
|
| 274 |
+
super().__init__()
|
| 275 |
+
self.inner_dim = dim_head * heads
|
| 276 |
+
self.heads = heads
|
| 277 |
+
self.head_dim = dim_head
|
| 278 |
+
self.inner_dim = dim_head * heads
|
| 279 |
+
self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
|
| 280 |
+
self.to_out = nn.Linear(self.inner_dim, dim)
|
| 281 |
+
|
| 282 |
+
self.rotary_emb = rotary_emb
|
| 283 |
+
self.is_causal = is_causal
|
| 284 |
+
|
| 285 |
+
self.reference_length = reference_length
|
| 286 |
+
|
| 287 |
+
self.store = None
|
| 288 |
+
|
| 289 |
+
def forward(self, x: torch.Tensor, relative_embedding=False,
|
| 290 |
+
extra_condition=None,
|
| 291 |
+
state_embed_only_on_qk=False,
|
| 292 |
+
reference_length=None):
|
| 293 |
+
|
| 294 |
+
B, T, H, W, D = x.shape
|
| 295 |
+
|
| 296 |
+
if state_embed_only_on_qk:
|
| 297 |
+
q, k, _ = self.to_qkv(x+extra_condition).chunk(3, dim=-1)
|
| 298 |
+
_, _, v = self.to_qkv(x).chunk(3, dim=-1)
|
| 299 |
+
else:
|
| 300 |
+
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
| 301 |
+
|
| 302 |
+
if relative_embedding:
|
| 303 |
+
length = reference_length+1
|
| 304 |
+
n_frames = T // length
|
| 305 |
+
x = x.reshape(B, n_frames, length, H, W, D)
|
| 306 |
+
|
| 307 |
+
x_list = []
|
| 308 |
+
|
| 309 |
+
for i in range(n_frames):
|
| 310 |
+
if i == n_frames-1:
|
| 311 |
+
q_i = rearrange(q[:, i*length:], "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 312 |
+
k_i = rearrange(k[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 313 |
+
v_i = rearrange(v[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 314 |
+
else:
|
| 315 |
+
q_i = rearrange(q[:, i*length:i*length+1], "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 316 |
+
k_i = rearrange(k[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 317 |
+
v_i = rearrange(v[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 318 |
+
|
| 319 |
+
q_i, k_i, v_i = map(lambda t: t.contiguous(), (q_i, k_i, v_i))
|
| 320 |
+
x_i = F.scaled_dot_product_attention(query=q_i, key=k_i, value=v_i)
|
| 321 |
+
x_i = rearrange(x_i, "B h (T H W) d -> B T H W (h d)", B=B, H=H, W=W)
|
| 322 |
+
x_i = x_i.to(q.dtype)
|
| 323 |
+
x_list.append(x_i)
|
| 324 |
+
|
| 325 |
+
x = torch.cat(x_list, dim=1)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
else:
|
| 329 |
+
T_ = T - reference_length
|
| 330 |
+
q = rearrange(q, "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 331 |
+
k = rearrange(k[:, T_:], "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 332 |
+
v = rearrange(v[:, T_:], "B T H W (h d) -> B h (T H W) d", h=self.heads)
|
| 333 |
+
|
| 334 |
+
q, k, v = map(lambda t: t.contiguous(), (q, k, v))
|
| 335 |
+
x = F.scaled_dot_product_attention(query=q, key=k, value=v)
|
| 336 |
+
x = rearrange(x, "B h (T H W) d -> B T H W (h d)", B=B, H=H, W=W)
|
| 337 |
+
x = x.to(q.dtype)
|
| 338 |
+
|
| 339 |
+
# linear proj
|
| 340 |
+
x = self.to_out(x)
|
| 341 |
+
|
| 342 |
+
return x
|
algorithms/worldmem/models/diffusion.py
ADDED
|
@@ -0,0 +1,594 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Callable
|
| 2 |
+
from collections import namedtuple
|
| 3 |
+
from omegaconf import DictConfig
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from .utils import linear_beta_schedule, cosine_beta_schedule, sigmoid_beta_schedule, extract
|
| 9 |
+
from .dit import DiT_models
|
| 10 |
+
|
| 11 |
+
ModelPrediction = namedtuple("ModelPrediction", ["pred_noise", "pred_x_start", "model_out"])
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Diffusion(nn.Module):
|
| 15 |
+
# Special thanks to lucidrains for the implementation of the base Diffusion model
|
| 16 |
+
# https://github.com/lucidrains/denoising-diffusion-pytorch
|
| 17 |
+
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
x_shape: torch.Size,
|
| 21 |
+
reference_length: int,
|
| 22 |
+
action_cond_dim: int,
|
| 23 |
+
pose_cond_dim,
|
| 24 |
+
is_causal: bool,
|
| 25 |
+
cfg: DictConfig,
|
| 26 |
+
is_dit: bool=False,
|
| 27 |
+
use_plucker=False,
|
| 28 |
+
relative_embedding=False,
|
| 29 |
+
state_embed_only_on_qk=False,
|
| 30 |
+
use_memory_attention=False,
|
| 31 |
+
add_timestamp_embedding=False,
|
| 32 |
+
memory_token_cross_attention=False,
|
| 33 |
+
memory_cross_attn_layers=None,
|
| 34 |
+
ref_mode='sequential'
|
| 35 |
+
):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.cfg = cfg
|
| 38 |
+
|
| 39 |
+
self.x_shape = x_shape
|
| 40 |
+
self.action_cond_dim = action_cond_dim
|
| 41 |
+
self.timesteps = cfg.timesteps
|
| 42 |
+
self.sampling_timesteps = cfg.sampling_timesteps
|
| 43 |
+
self.beta_schedule = cfg.beta_schedule
|
| 44 |
+
self.schedule_fn_kwargs = cfg.schedule_fn_kwargs
|
| 45 |
+
self.objective = cfg.objective
|
| 46 |
+
self.use_fused_snr = cfg.use_fused_snr
|
| 47 |
+
self.snr_clip = cfg.snr_clip
|
| 48 |
+
self.cum_snr_decay = cfg.cum_snr_decay
|
| 49 |
+
self.ddim_sampling_eta = cfg.ddim_sampling_eta
|
| 50 |
+
self.clip_noise = cfg.clip_noise
|
| 51 |
+
self.arch = cfg.architecture
|
| 52 |
+
self.stabilization_level = cfg.stabilization_level
|
| 53 |
+
self.is_causal = is_causal
|
| 54 |
+
self.is_dit = is_dit
|
| 55 |
+
self.reference_length = reference_length
|
| 56 |
+
self.pose_cond_dim = pose_cond_dim
|
| 57 |
+
self.use_plucker = use_plucker
|
| 58 |
+
self.relative_embedding = relative_embedding
|
| 59 |
+
self.state_embed_only_on_qk = state_embed_only_on_qk
|
| 60 |
+
self.use_memory_attention = use_memory_attention
|
| 61 |
+
self.add_timestamp_embedding = add_timestamp_embedding
|
| 62 |
+
self.memory_token_cross_attention = memory_token_cross_attention
|
| 63 |
+
self.memory_cross_attn_layers = memory_cross_attn_layers
|
| 64 |
+
self.ref_mode = ref_mode
|
| 65 |
+
if self.use_memory_attention:
|
| 66 |
+
raise ValueError(
|
| 67 |
+
"WorldMem reference-frame use_memory_attention has been removed from DiT. "
|
| 68 |
+
"Use memory_token_cross_attention=True for compact memory tokens."
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
self._build_model()
|
| 72 |
+
self._build_buffer()
|
| 73 |
+
|
| 74 |
+
def _build_model(self):
|
| 75 |
+
x_channel = self.x_shape[0]
|
| 76 |
+
if self.is_dit:
|
| 77 |
+
self.model = DiT_models["DiT-S/2"](action_cond_dim=self.action_cond_dim,
|
| 78 |
+
reference_length=self.reference_length,
|
| 79 |
+
memory_token_cross_attention=self.memory_token_cross_attention,
|
| 80 |
+
memory_cross_attn_layers=self.memory_cross_attn_layers,
|
| 81 |
+
ref_mode=self.ref_mode)
|
| 82 |
+
else:
|
| 83 |
+
raise NotImplementedError
|
| 84 |
+
|
| 85 |
+
def _build_buffer(self):
|
| 86 |
+
if self.beta_schedule == "linear":
|
| 87 |
+
beta_schedule_fn = linear_beta_schedule
|
| 88 |
+
elif self.beta_schedule == "cosine":
|
| 89 |
+
beta_schedule_fn = cosine_beta_schedule
|
| 90 |
+
elif self.beta_schedule == "sigmoid":
|
| 91 |
+
beta_schedule_fn = sigmoid_beta_schedule
|
| 92 |
+
else:
|
| 93 |
+
raise ValueError(f"unknown beta schedule {self.beta_schedule}")
|
| 94 |
+
|
| 95 |
+
betas = beta_schedule_fn(self.timesteps, **self.schedule_fn_kwargs)
|
| 96 |
+
|
| 97 |
+
alphas = 1.0 - betas
|
| 98 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
| 99 |
+
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
|
| 100 |
+
|
| 101 |
+
# sampling related parameters
|
| 102 |
+
assert self.sampling_timesteps <= self.timesteps
|
| 103 |
+
self.is_ddim_sampling = self.sampling_timesteps < self.timesteps
|
| 104 |
+
|
| 105 |
+
# helper function to register buffer from float64 to float32
|
| 106 |
+
register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
|
| 107 |
+
|
| 108 |
+
register_buffer("betas", betas)
|
| 109 |
+
register_buffer("alphas_cumprod", alphas_cumprod)
|
| 110 |
+
register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
|
| 111 |
+
|
| 112 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
| 113 |
+
|
| 114 |
+
register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
|
| 115 |
+
register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
|
| 116 |
+
register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
|
| 117 |
+
register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
|
| 118 |
+
register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))
|
| 119 |
+
|
| 120 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
| 121 |
+
|
| 122 |
+
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
| 123 |
+
|
| 124 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
| 125 |
+
|
| 126 |
+
register_buffer("posterior_variance", posterior_variance)
|
| 127 |
+
|
| 128 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
| 129 |
+
|
| 130 |
+
register_buffer(
|
| 131 |
+
"posterior_log_variance_clipped",
|
| 132 |
+
torch.log(posterior_variance.clamp(min=1e-20)),
|
| 133 |
+
)
|
| 134 |
+
register_buffer(
|
| 135 |
+
"posterior_mean_coef1",
|
| 136 |
+
betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
|
| 137 |
+
)
|
| 138 |
+
register_buffer(
|
| 139 |
+
"posterior_mean_coef2",
|
| 140 |
+
(1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# calculate p2 reweighting
|
| 144 |
+
|
| 145 |
+
# register_buffer(
|
| 146 |
+
# "p2_loss_weight",
|
| 147 |
+
# (self.p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))
|
| 148 |
+
# ** -self.p2_loss_weight_gamma,
|
| 149 |
+
# )
|
| 150 |
+
|
| 151 |
+
# derive loss weight
|
| 152 |
+
# https://arxiv.org/abs/2303.09556
|
| 153 |
+
# snr: signal noise ratio
|
| 154 |
+
snr = alphas_cumprod / (1 - alphas_cumprod)
|
| 155 |
+
clipped_snr = snr.clone()
|
| 156 |
+
clipped_snr.clamp_(max=self.snr_clip)
|
| 157 |
+
|
| 158 |
+
register_buffer("clipped_snr", clipped_snr)
|
| 159 |
+
register_buffer("snr", snr)
|
| 160 |
+
|
| 161 |
+
def add_shape_channels(self, x):
|
| 162 |
+
return rearrange(x, f"... -> ...{' 1' * len(self.x_shape)}")
|
| 163 |
+
|
| 164 |
+
def model_predictions(self, x, t, action_cond=None, current_frame=None,
|
| 165 |
+
pose_cond=None, mode="training", reference_length=None, frame_idx=None,
|
| 166 |
+
memory_tokens=None, memory_token_mask=None, memory_retrieval_tokens=None, memory_retrieval_mask=None,
|
| 167 |
+
**memory_kwargs):
|
| 168 |
+
x = x.permute(1,0,2,3,4)
|
| 169 |
+
action_cond = action_cond.permute(1,0,2)
|
| 170 |
+
if pose_cond is not None and pose_cond[0] is not None:
|
| 171 |
+
try:
|
| 172 |
+
pose_cond = pose_cond.permute(1,0,2)
|
| 173 |
+
except:
|
| 174 |
+
pass
|
| 175 |
+
t = t.permute(1,0)
|
| 176 |
+
model_output = self.model(x, t, action_cond, current_frame=current_frame, pose_cond=pose_cond,
|
| 177 |
+
mode=mode, reference_length=reference_length, frame_idx=frame_idx,
|
| 178 |
+
memory_tokens=memory_tokens, memory_token_mask=memory_token_mask,
|
| 179 |
+
memory_retrieval_tokens=memory_retrieval_tokens,
|
| 180 |
+
memory_retrieval_mask=memory_retrieval_mask, **memory_kwargs)
|
| 181 |
+
model_output = model_output.permute(1,0,2,3,4)
|
| 182 |
+
x = x.permute(1,0,2,3,4)
|
| 183 |
+
t = t.permute(1,0)
|
| 184 |
+
|
| 185 |
+
if self.objective == "pred_noise":
|
| 186 |
+
pred_noise = torch.clamp(model_output, -self.clip_noise, self.clip_noise)
|
| 187 |
+
x_start = self.predict_start_from_noise(x, t, pred_noise)
|
| 188 |
+
|
| 189 |
+
elif self.objective == "pred_x0":
|
| 190 |
+
x_start = model_output
|
| 191 |
+
pred_noise = self.predict_noise_from_start(x, t, x_start)
|
| 192 |
+
|
| 193 |
+
elif self.objective == "pred_v":
|
| 194 |
+
v = model_output
|
| 195 |
+
x_start = self.predict_start_from_v(x, t, v)
|
| 196 |
+
pred_noise = self.predict_noise_from_start(x, t, x_start)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
return ModelPrediction(pred_noise, x_start, model_output)
|
| 200 |
+
|
| 201 |
+
def predict_start_from_noise(self, x_t, t, noise):
|
| 202 |
+
return (
|
| 203 |
+
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
| 204 |
+
- extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
def predict_noise_from_start(self, x_t, t, x0):
|
| 208 |
+
return (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / extract(
|
| 209 |
+
self.sqrt_recipm1_alphas_cumprod, t, x_t.shape
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
def predict_v(self, x_start, t, noise):
|
| 213 |
+
return (
|
| 214 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise
|
| 215 |
+
- extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
def predict_start_from_v(self, x_t, t, v):
|
| 219 |
+
return (
|
| 220 |
+
extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
|
| 221 |
+
- extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
def q_mean_variance(self, x_start, t):
|
| 225 |
+
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
| 226 |
+
variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape)
|
| 227 |
+
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
| 228 |
+
return mean, variance, log_variance
|
| 229 |
+
|
| 230 |
+
def q_posterior(self, x_start, x_t, t):
|
| 231 |
+
posterior_mean = (
|
| 232 |
+
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
| 233 |
+
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
| 234 |
+
)
|
| 235 |
+
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
| 236 |
+
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
| 237 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
| 238 |
+
|
| 239 |
+
def q_sample(self, x_start, t, noise=None):
|
| 240 |
+
if noise is None:
|
| 241 |
+
noise = torch.randn_like(x_start)
|
| 242 |
+
noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
|
| 243 |
+
return (
|
| 244 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
| 245 |
+
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
def p_mean_variance(
|
| 249 |
+
self,
|
| 250 |
+
x,
|
| 251 |
+
t,
|
| 252 |
+
action_cond=None,
|
| 253 |
+
pose_cond=None,
|
| 254 |
+
reference_length=None,
|
| 255 |
+
frame_idx=None,
|
| 256 |
+
memory_tokens=None,
|
| 257 |
+
memory_token_mask=None,
|
| 258 |
+
memory_retrieval_tokens=None,
|
| 259 |
+
memory_retrieval_mask=None,
|
| 260 |
+
**memory_kwargs,
|
| 261 |
+
):
|
| 262 |
+
model_pred = self.model_predictions(x=x, t=t, action_cond=action_cond,
|
| 263 |
+
pose_cond=pose_cond, reference_length=reference_length,
|
| 264 |
+
frame_idx=frame_idx, memory_tokens=memory_tokens, memory_token_mask=memory_token_mask,
|
| 265 |
+
memory_retrieval_tokens=memory_retrieval_tokens,
|
| 266 |
+
memory_retrieval_mask=memory_retrieval_mask, **memory_kwargs)
|
| 267 |
+
x_start = model_pred.pred_x_start
|
| 268 |
+
return self.q_posterior(x_start=x_start, x_t=x, t=t)
|
| 269 |
+
|
| 270 |
+
def compute_loss_weights(self, noise_levels: torch.Tensor):
|
| 271 |
+
|
| 272 |
+
snr = self.snr[noise_levels]
|
| 273 |
+
clipped_snr = self.clipped_snr[noise_levels]
|
| 274 |
+
normalized_clipped_snr = clipped_snr / self.snr_clip
|
| 275 |
+
normalized_snr = snr / self.snr_clip
|
| 276 |
+
|
| 277 |
+
if not self.use_fused_snr:
|
| 278 |
+
# min SNR reweighting
|
| 279 |
+
match self.objective:
|
| 280 |
+
case "pred_noise":
|
| 281 |
+
return clipped_snr / snr
|
| 282 |
+
case "pred_x0":
|
| 283 |
+
return clipped_snr
|
| 284 |
+
case "pred_v":
|
| 285 |
+
return clipped_snr / (snr + 1)
|
| 286 |
+
|
| 287 |
+
cum_snr = torch.zeros_like(normalized_snr)
|
| 288 |
+
for t in range(0, noise_levels.shape[0]):
|
| 289 |
+
if t == 0:
|
| 290 |
+
cum_snr[t] = normalized_clipped_snr[t]
|
| 291 |
+
else:
|
| 292 |
+
cum_snr[t] = self.cum_snr_decay * cum_snr[t - 1] + (1 - self.cum_snr_decay) * normalized_clipped_snr[t]
|
| 293 |
+
|
| 294 |
+
cum_snr = F.pad(cum_snr[:-1], (0, 0, 1, 0), value=0.0)
|
| 295 |
+
clipped_fused_snr = 1 - (1 - cum_snr * self.cum_snr_decay) * (1 - normalized_clipped_snr)
|
| 296 |
+
fused_snr = 1 - (1 - cum_snr * self.cum_snr_decay) * (1 - normalized_snr)
|
| 297 |
+
|
| 298 |
+
match self.objective:
|
| 299 |
+
case "pred_noise":
|
| 300 |
+
return clipped_fused_snr / fused_snr
|
| 301 |
+
case "pred_x0":
|
| 302 |
+
return clipped_fused_snr * self.snr_clip
|
| 303 |
+
case "pred_v":
|
| 304 |
+
return clipped_fused_snr * self.snr_clip / (fused_snr * self.snr_clip + 1)
|
| 305 |
+
case _:
|
| 306 |
+
raise ValueError(f"unknown objective {self.objective}")
|
| 307 |
+
|
| 308 |
+
def forward(
|
| 309 |
+
self,
|
| 310 |
+
x: torch.Tensor,
|
| 311 |
+
action_cond: Optional[torch.Tensor],
|
| 312 |
+
pose_cond,
|
| 313 |
+
noise_levels: torch.Tensor,
|
| 314 |
+
reference_length,
|
| 315 |
+
frame_idx=None,
|
| 316 |
+
memory_tokens=None,
|
| 317 |
+
memory_token_mask=None,
|
| 318 |
+
memory_retrieval_tokens=None,
|
| 319 |
+
memory_retrieval_mask=None,
|
| 320 |
+
**memory_kwargs,
|
| 321 |
+
):
|
| 322 |
+
noise = torch.randn_like(x)
|
| 323 |
+
noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
|
| 324 |
+
|
| 325 |
+
noised_x = self.q_sample(x_start=x, t=noise_levels, noise=noise)
|
| 326 |
+
|
| 327 |
+
model_pred = self.model_predictions(x=noised_x, t=noise_levels, action_cond=action_cond,
|
| 328 |
+
pose_cond=pose_cond,reference_length=reference_length, frame_idx=frame_idx,
|
| 329 |
+
memory_tokens=memory_tokens, memory_token_mask=memory_token_mask,
|
| 330 |
+
memory_retrieval_tokens=memory_retrieval_tokens,
|
| 331 |
+
memory_retrieval_mask=memory_retrieval_mask, **memory_kwargs)
|
| 332 |
+
|
| 333 |
+
pred = model_pred.model_out
|
| 334 |
+
x_pred = model_pred.pred_x_start
|
| 335 |
+
|
| 336 |
+
if self.objective == "pred_noise":
|
| 337 |
+
target = noise
|
| 338 |
+
elif self.objective == "pred_x0":
|
| 339 |
+
target = x
|
| 340 |
+
elif self.objective == "pred_v":
|
| 341 |
+
target = self.predict_v(x, noise_levels, noise)
|
| 342 |
+
else:
|
| 343 |
+
raise ValueError(f"unknown objective {self.objective}")
|
| 344 |
+
|
| 345 |
+
# 训练的时候每个frame随便给噪声
|
| 346 |
+
loss = F.mse_loss(pred, target.detach(), reduction="none")
|
| 347 |
+
loss_weight = self.compute_loss_weights(noise_levels)
|
| 348 |
+
|
| 349 |
+
loss_weight = loss_weight.view(*loss_weight.shape, *((1,) * (loss.ndim - 2)))
|
| 350 |
+
|
| 351 |
+
loss = loss * loss_weight
|
| 352 |
+
|
| 353 |
+
return x_pred, loss
|
| 354 |
+
|
| 355 |
+
def sample_step(
|
| 356 |
+
self,
|
| 357 |
+
x: torch.Tensor,
|
| 358 |
+
action_cond: Optional[torch.Tensor],
|
| 359 |
+
pose_cond,
|
| 360 |
+
curr_noise_level: torch.Tensor,
|
| 361 |
+
next_noise_level: torch.Tensor,
|
| 362 |
+
guidance_fn: Optional[Callable] = None,
|
| 363 |
+
current_frame=None,
|
| 364 |
+
mode="training",
|
| 365 |
+
reference_length=None,
|
| 366 |
+
frame_idx=None,
|
| 367 |
+
memory_tokens=None,
|
| 368 |
+
memory_token_mask=None,
|
| 369 |
+
memory_retrieval_tokens=None,
|
| 370 |
+
memory_retrieval_mask=None,
|
| 371 |
+
**memory_kwargs,
|
| 372 |
+
):
|
| 373 |
+
real_steps = torch.linspace(-1, self.timesteps - 1, steps=self.sampling_timesteps + 1, device=x.device).long()
|
| 374 |
+
|
| 375 |
+
# convert noise levels (0 ~ sampling_timesteps) to real noise levels (-1 ~ timesteps - 1)
|
| 376 |
+
curr_noise_level = real_steps[curr_noise_level]
|
| 377 |
+
next_noise_level = real_steps[next_noise_level]
|
| 378 |
+
|
| 379 |
+
if self.is_ddim_sampling:
|
| 380 |
+
return self.ddim_sample_step(
|
| 381 |
+
x=x,
|
| 382 |
+
action_cond=action_cond,
|
| 383 |
+
pose_cond=pose_cond,
|
| 384 |
+
curr_noise_level=curr_noise_level,
|
| 385 |
+
next_noise_level=next_noise_level,
|
| 386 |
+
guidance_fn=guidance_fn,
|
| 387 |
+
current_frame=current_frame,
|
| 388 |
+
mode=mode,
|
| 389 |
+
reference_length=reference_length,
|
| 390 |
+
frame_idx=frame_idx,
|
| 391 |
+
memory_tokens=memory_tokens,
|
| 392 |
+
memory_token_mask=memory_token_mask,
|
| 393 |
+
memory_retrieval_tokens=memory_retrieval_tokens,
|
| 394 |
+
memory_retrieval_mask=memory_retrieval_mask,
|
| 395 |
+
**memory_kwargs,
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
# FIXME: temporary code for checking ddpm sampling
|
| 399 |
+
assert torch.all(
|
| 400 |
+
(curr_noise_level - 1 == next_noise_level) | ((curr_noise_level == -1) & (next_noise_level == -1))
|
| 401 |
+
), "Wrong noise level given for ddpm sampling."
|
| 402 |
+
|
| 403 |
+
assert (
|
| 404 |
+
self.sampling_timesteps == self.timesteps
|
| 405 |
+
), "sampling_timesteps should be equal to timesteps for ddpm sampling."
|
| 406 |
+
|
| 407 |
+
return self.ddpm_sample_step(
|
| 408 |
+
x=x,
|
| 409 |
+
action_cond=action_cond,
|
| 410 |
+
pose_cond=pose_cond,
|
| 411 |
+
curr_noise_level=curr_noise_level,
|
| 412 |
+
guidance_fn=guidance_fn,
|
| 413 |
+
reference_length=reference_length,
|
| 414 |
+
frame_idx=frame_idx,
|
| 415 |
+
memory_tokens=memory_tokens,
|
| 416 |
+
memory_token_mask=memory_token_mask,
|
| 417 |
+
memory_retrieval_tokens=memory_retrieval_tokens,
|
| 418 |
+
memory_retrieval_mask=memory_retrieval_mask,
|
| 419 |
+
**memory_kwargs,
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
def ddpm_sample_step(
|
| 423 |
+
self,
|
| 424 |
+
x: torch.Tensor,
|
| 425 |
+
action_cond: Optional[torch.Tensor],
|
| 426 |
+
pose_cond,
|
| 427 |
+
curr_noise_level: torch.Tensor,
|
| 428 |
+
guidance_fn: Optional[Callable] = None,
|
| 429 |
+
reference_length=None,
|
| 430 |
+
frame_idx=None,
|
| 431 |
+
memory_tokens=None,
|
| 432 |
+
memory_token_mask=None,
|
| 433 |
+
memory_retrieval_tokens=None,
|
| 434 |
+
memory_retrieval_mask=None,
|
| 435 |
+
**memory_kwargs,
|
| 436 |
+
):
|
| 437 |
+
clipped_curr_noise_level = torch.where(
|
| 438 |
+
curr_noise_level < 0,
|
| 439 |
+
torch.full_like(curr_noise_level, self.stabilization_level - 1, dtype=torch.long),
|
| 440 |
+
curr_noise_level,
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
# treating as stabilization would require us to scale with sqrt of alpha_cum
|
| 444 |
+
orig_x = x.clone().detach()
|
| 445 |
+
scaled_context = self.q_sample(
|
| 446 |
+
x,
|
| 447 |
+
clipped_curr_noise_level,
|
| 448 |
+
noise=torch.zeros_like(x),
|
| 449 |
+
)
|
| 450 |
+
x = torch.where(self.add_shape_channels(curr_noise_level < 0), scaled_context, orig_x)
|
| 451 |
+
|
| 452 |
+
if guidance_fn is not None:
|
| 453 |
+
raise NotImplementedError("Guidance function is not implemented for ddpm sampling yet.")
|
| 454 |
+
|
| 455 |
+
else:
|
| 456 |
+
model_mean, _, model_log_variance = self.p_mean_variance(
|
| 457 |
+
x=x,
|
| 458 |
+
t=clipped_curr_noise_level,
|
| 459 |
+
action_cond=action_cond,
|
| 460 |
+
pose_cond=pose_cond,
|
| 461 |
+
reference_length=reference_length,
|
| 462 |
+
frame_idx=frame_idx,
|
| 463 |
+
memory_tokens=memory_tokens,
|
| 464 |
+
memory_token_mask=memory_token_mask,
|
| 465 |
+
memory_retrieval_tokens=memory_retrieval_tokens,
|
| 466 |
+
memory_retrieval_mask=memory_retrieval_mask,
|
| 467 |
+
**memory_kwargs,
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
noise = torch.where(
|
| 471 |
+
self.add_shape_channels(clipped_curr_noise_level > 0),
|
| 472 |
+
torch.randn_like(x),
|
| 473 |
+
0,
|
| 474 |
+
)
|
| 475 |
+
noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
|
| 476 |
+
x_pred = model_mean + torch.exp(0.5 * model_log_variance) * noise
|
| 477 |
+
|
| 478 |
+
# only update frames where the noise level decreases
|
| 479 |
+
return torch.where(self.add_shape_channels(curr_noise_level == -1), orig_x, x_pred)
|
| 480 |
+
|
| 481 |
+
def ddim_sample_step(
|
| 482 |
+
self,
|
| 483 |
+
x: torch.Tensor,
|
| 484 |
+
action_cond: Optional[torch.Tensor],
|
| 485 |
+
pose_cond,
|
| 486 |
+
curr_noise_level: torch.Tensor,
|
| 487 |
+
next_noise_level: torch.Tensor,
|
| 488 |
+
guidance_fn: Optional[Callable] = None,
|
| 489 |
+
current_frame=None,
|
| 490 |
+
mode="training",
|
| 491 |
+
reference_length=None,
|
| 492 |
+
frame_idx=None,
|
| 493 |
+
memory_tokens=None,
|
| 494 |
+
memory_token_mask=None,
|
| 495 |
+
memory_retrieval_tokens=None,
|
| 496 |
+
memory_retrieval_mask=None,
|
| 497 |
+
**memory_kwargs,
|
| 498 |
+
):
|
| 499 |
+
# convert noise level -1 to self.stabilization_level - 1
|
| 500 |
+
clipped_curr_noise_level = torch.where(
|
| 501 |
+
curr_noise_level < 0,
|
| 502 |
+
torch.full_like(curr_noise_level, self.stabilization_level - 1, dtype=torch.long),
|
| 503 |
+
curr_noise_level,
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
# treating as stabilization would require us to scale with sqrt of alpha_cum
|
| 507 |
+
orig_x = x.clone().detach()
|
| 508 |
+
scaled_context = self.q_sample(
|
| 509 |
+
x,
|
| 510 |
+
clipped_curr_noise_level,
|
| 511 |
+
noise=torch.zeros_like(x),
|
| 512 |
+
)
|
| 513 |
+
x = torch.where(self.add_shape_channels(curr_noise_level < 0), scaled_context, orig_x)
|
| 514 |
+
|
| 515 |
+
alpha = self.alphas_cumprod[clipped_curr_noise_level]
|
| 516 |
+
alpha_next = torch.where(
|
| 517 |
+
next_noise_level < 0,
|
| 518 |
+
torch.ones_like(next_noise_level),
|
| 519 |
+
self.alphas_cumprod[next_noise_level],
|
| 520 |
+
)
|
| 521 |
+
sigma = torch.where(
|
| 522 |
+
next_noise_level < 0,
|
| 523 |
+
torch.zeros_like(next_noise_level),
|
| 524 |
+
self.ddim_sampling_eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt(),
|
| 525 |
+
)
|
| 526 |
+
c = (1 - alpha_next - sigma**2).sqrt()
|
| 527 |
+
|
| 528 |
+
alpha_next = self.add_shape_channels(alpha_next)
|
| 529 |
+
c = self.add_shape_channels(c)
|
| 530 |
+
sigma = self.add_shape_channels(sigma)
|
| 531 |
+
|
| 532 |
+
if guidance_fn is not None:
|
| 533 |
+
with torch.enable_grad():
|
| 534 |
+
x = x.detach().requires_grad_()
|
| 535 |
+
|
| 536 |
+
model_pred = self.model_predictions(
|
| 537 |
+
x=x,
|
| 538 |
+
t=clipped_curr_noise_level,
|
| 539 |
+
action_cond=action_cond,
|
| 540 |
+
pose_cond=pose_cond,
|
| 541 |
+
current_frame=current_frame,
|
| 542 |
+
mode=mode,
|
| 543 |
+
reference_length=reference_length,
|
| 544 |
+
frame_idx=frame_idx,
|
| 545 |
+
memory_tokens=memory_tokens,
|
| 546 |
+
memory_token_mask=memory_token_mask,
|
| 547 |
+
memory_retrieval_tokens=memory_retrieval_tokens,
|
| 548 |
+
memory_retrieval_mask=memory_retrieval_mask,
|
| 549 |
+
**memory_kwargs,
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
guidance_loss = guidance_fn(model_pred.pred_x_start)
|
| 553 |
+
grad = -torch.autograd.grad(
|
| 554 |
+
guidance_loss,
|
| 555 |
+
x,
|
| 556 |
+
)[0]
|
| 557 |
+
|
| 558 |
+
pred_noise = model_pred.pred_noise + (1 - alpha_next).sqrt() * grad
|
| 559 |
+
x_start = self.predict_start_from_noise(x, clipped_curr_noise_level, pred_noise)
|
| 560 |
+
|
| 561 |
+
else:
|
| 562 |
+
# print(clipped_curr_noise_level)
|
| 563 |
+
model_pred = self.model_predictions(
|
| 564 |
+
x=x,
|
| 565 |
+
t=clipped_curr_noise_level,
|
| 566 |
+
action_cond=action_cond,
|
| 567 |
+
pose_cond=pose_cond,
|
| 568 |
+
current_frame=current_frame,
|
| 569 |
+
mode=mode,
|
| 570 |
+
reference_length=reference_length,
|
| 571 |
+
frame_idx=frame_idx,
|
| 572 |
+
memory_tokens=memory_tokens,
|
| 573 |
+
memory_token_mask=memory_token_mask,
|
| 574 |
+
memory_retrieval_tokens=memory_retrieval_tokens,
|
| 575 |
+
memory_retrieval_mask=memory_retrieval_mask,
|
| 576 |
+
**memory_kwargs,
|
| 577 |
+
)
|
| 578 |
+
x_start = model_pred.pred_x_start
|
| 579 |
+
pred_noise = model_pred.pred_noise
|
| 580 |
+
|
| 581 |
+
noise = torch.randn_like(x)
|
| 582 |
+
noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
|
| 583 |
+
|
| 584 |
+
x_pred = x_start * alpha_next.sqrt() + pred_noise * c + sigma * noise
|
| 585 |
+
|
| 586 |
+
# only update frames where the noise level decreases
|
| 587 |
+
mask = curr_noise_level == next_noise_level
|
| 588 |
+
x_pred = torch.where(
|
| 589 |
+
self.add_shape_channels(mask),
|
| 590 |
+
orig_x,
|
| 591 |
+
x_pred,
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
return x_pred
|
algorithms/worldmem/models/dit.py
ADDED
|
@@ -0,0 +1,899 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
References:
|
| 3 |
+
- DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
|
| 4 |
+
- Diffusion Forcing: https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/unet3d.py
|
| 5 |
+
- Latte: https://github.com/Vchitect/Latte/blob/main/models/latte.py
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Optional, Literal
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
from .rotary_embedding_torch import RotaryEmbedding
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
from .attention import SpatialAxialAttention, TemporalAxialAttention
|
| 14 |
+
from timm.models.vision_transformer import Mlp
|
| 15 |
+
from timm.layers.helpers import to_2tuple
|
| 16 |
+
import math
|
| 17 |
+
from collections import namedtuple
|
| 18 |
+
from typing import Optional, Callable
|
| 19 |
+
|
| 20 |
+
def modulate(x, shift, scale):
|
| 21 |
+
fixed_dims = [1] * len(shift.shape[1:])
|
| 22 |
+
shift = shift.repeat(x.shape[0] // shift.shape[0], *fixed_dims)
|
| 23 |
+
scale = scale.repeat(x.shape[0] // scale.shape[0], *fixed_dims)
|
| 24 |
+
while shift.dim() < x.dim():
|
| 25 |
+
shift = shift.unsqueeze(-2)
|
| 26 |
+
scale = scale.unsqueeze(-2)
|
| 27 |
+
return x * (1 + scale) + shift
|
| 28 |
+
|
| 29 |
+
def gate(x, g):
|
| 30 |
+
fixed_dims = [1] * len(g.shape[1:])
|
| 31 |
+
g = g.repeat(x.shape[0] // g.shape[0], *fixed_dims)
|
| 32 |
+
while g.dim() < x.dim():
|
| 33 |
+
g = g.unsqueeze(-2)
|
| 34 |
+
return g * x
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class PatchEmbed(nn.Module):
|
| 38 |
+
"""2D Image to Patch Embedding"""
|
| 39 |
+
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
img_height=256,
|
| 43 |
+
img_width=256,
|
| 44 |
+
patch_size=16,
|
| 45 |
+
in_chans=3,
|
| 46 |
+
embed_dim=768,
|
| 47 |
+
norm_layer=None,
|
| 48 |
+
flatten=True,
|
| 49 |
+
):
|
| 50 |
+
super().__init__()
|
| 51 |
+
img_size = (img_height, img_width)
|
| 52 |
+
patch_size = to_2tuple(patch_size)
|
| 53 |
+
self.img_size = img_size
|
| 54 |
+
self.patch_size = patch_size
|
| 55 |
+
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
| 56 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
| 57 |
+
self.flatten = flatten
|
| 58 |
+
|
| 59 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
| 60 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 61 |
+
|
| 62 |
+
def forward(self, x, random_sample=False):
|
| 63 |
+
B, C, H, W = x.shape
|
| 64 |
+
assert random_sample or (H == self.img_size[0] and W == self.img_size[1]), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
| 65 |
+
|
| 66 |
+
x = self.proj(x)
|
| 67 |
+
if self.flatten:
|
| 68 |
+
x = rearrange(x, "B C H W -> B (H W) C")
|
| 69 |
+
else:
|
| 70 |
+
x = rearrange(x, "B C H W -> B H W C")
|
| 71 |
+
x = self.norm(x)
|
| 72 |
+
return x
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class TimestepEmbedder(nn.Module):
|
| 76 |
+
"""
|
| 77 |
+
Embeds scalar timesteps into vector representations.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
def __init__(self, hidden_size, frequency_embedding_size=256, freq_type='time_step'):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.mlp = nn.Sequential(
|
| 83 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True), # hidden_size is diffusion model hidden size
|
| 84 |
+
nn.SiLU(),
|
| 85 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 86 |
+
)
|
| 87 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 88 |
+
self.freq_type = freq_type
|
| 89 |
+
|
| 90 |
+
@staticmethod
|
| 91 |
+
def timestep_embedding(t, dim, max_period=10000, freq_type='time_step'):
|
| 92 |
+
"""
|
| 93 |
+
Create sinusoidal timestep embeddings.
|
| 94 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
| 95 |
+
These may be fractional.
|
| 96 |
+
:param dim: the dimension of the output.
|
| 97 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 98 |
+
:return: an (N, D) Tensor of positional embeddings.
|
| 99 |
+
"""
|
| 100 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
| 101 |
+
half = dim // 2
|
| 102 |
+
|
| 103 |
+
if freq_type == 'time_step':
|
| 104 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
| 105 |
+
elif freq_type == 'spatial': # ~(-5 5)
|
| 106 |
+
freqs = torch.linspace(1.0, half, half).to(device=t.device) * torch.pi
|
| 107 |
+
elif freq_type == 'angle': # 0-360
|
| 108 |
+
freqs = torch.linspace(1.0, half, half).to(device=t.device) * torch.pi / 180
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
args = t[:, None].float() * freqs[None]
|
| 112 |
+
|
| 113 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 114 |
+
if dim % 2:
|
| 115 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 116 |
+
return embedding
|
| 117 |
+
|
| 118 |
+
def forward(self, t):
|
| 119 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size, freq_type=self.freq_type)
|
| 120 |
+
t_emb = self.mlp(t_freq)
|
| 121 |
+
return t_emb
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class FinalLayer(nn.Module):
|
| 125 |
+
"""
|
| 126 |
+
The final layer of DiT.
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
def __init__(self, hidden_size, patch_size, out_channels):
|
| 130 |
+
super().__init__()
|
| 131 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 132 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
| 133 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
| 134 |
+
|
| 135 |
+
def forward(self, x, c):
|
| 136 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
| 137 |
+
x = modulate(self.norm_final(x), shift, scale)
|
| 138 |
+
x = self.linear(x)
|
| 139 |
+
return x
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
MEMORY_TYPE_NAMES = ("anchor", "dynamic", "revisit")
|
| 143 |
+
MEMORY_TYPE_ANCHOR = 0
|
| 144 |
+
MEMORY_TYPE_DYNAMIC = 1
|
| 145 |
+
MEMORY_TYPE_REVISIT = 2
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class MemoryTokenCrossAttention(nn.Module):
|
| 149 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, num_memory_types=3):
|
| 150 |
+
super().__init__()
|
| 151 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 152 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
| 153 |
+
self.num_heads = num_heads
|
| 154 |
+
self.num_memory_types = num_memory_types
|
| 155 |
+
self.norm_q = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 156 |
+
self.norm_mem = nn.LayerNorm(hidden_size, eps=1e-6)
|
| 157 |
+
self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
|
| 158 |
+
self.norm_mlp = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 159 |
+
self.mlp = Mlp(
|
| 160 |
+
in_features=hidden_size,
|
| 161 |
+
hidden_features=mlp_hidden_dim,
|
| 162 |
+
act_layer=approx_gelu,
|
| 163 |
+
drop=0,
|
| 164 |
+
)
|
| 165 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
|
| 166 |
+
self.memory_type_embed = nn.Embedding(num_memory_types, hidden_size)
|
| 167 |
+
self.memory_type_scale = nn.Parameter(torch.ones(num_memory_types, hidden_size))
|
| 168 |
+
self.memory_type_gate = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, num_memory_types, bias=True))
|
| 169 |
+
self.last_gate_mean = None
|
| 170 |
+
self.last_delta_ratio = None
|
| 171 |
+
self.last_valid_fraction = None
|
| 172 |
+
self.last_type_gate_mean = None
|
| 173 |
+
for type_name in MEMORY_TYPE_NAMES[:num_memory_types]:
|
| 174 |
+
setattr(self, f"last_type_gate_{type_name}_mean", None)
|
| 175 |
+
nn.init.normal_(self.memory_type_embed.weight, std=0.02)
|
| 176 |
+
self.reset_identity_init()
|
| 177 |
+
|
| 178 |
+
def reset_identity_init(self):
|
| 179 |
+
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
| 180 |
+
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
| 181 |
+
nn.init.constant_(self.memory_type_gate[-1].weight, 0)
|
| 182 |
+
nn.init.constant_(self.memory_type_gate[-1].bias, 0)
|
| 183 |
+
|
| 184 |
+
def _attend(self, query, memory_tokens, memory_token_mask=None, memory_token_gate=None):
|
| 185 |
+
if memory_token_mask is None and memory_token_gate is None:
|
| 186 |
+
out, _ = self.attn(query, memory_tokens, memory_tokens, need_weights=False)
|
| 187 |
+
return out, None
|
| 188 |
+
|
| 189 |
+
if memory_token_mask is None:
|
| 190 |
+
memory_token_mask = torch.ones(
|
| 191 |
+
memory_tokens.shape[:2],
|
| 192 |
+
device=memory_tokens.device,
|
| 193 |
+
dtype=torch.bool,
|
| 194 |
+
)
|
| 195 |
+
else:
|
| 196 |
+
memory_token_mask = memory_token_mask.bool()
|
| 197 |
+
gate_tensor = None
|
| 198 |
+
if memory_token_gate is not None:
|
| 199 |
+
if tuple(memory_token_gate.shape) != tuple(memory_tokens.shape[:2]):
|
| 200 |
+
raise ValueError(
|
| 201 |
+
f"memory_token_gate must have shape {tuple(memory_tokens.shape[:2])}, "
|
| 202 |
+
f"got {tuple(memory_token_gate.shape)}"
|
| 203 |
+
)
|
| 204 |
+
gate_tensor = memory_token_gate.to(device=memory_tokens.device, dtype=query.dtype)
|
| 205 |
+
memory_token_mask = memory_token_mask & (gate_tensor > 0)
|
| 206 |
+
valid_rows = memory_token_mask.any(dim=1)
|
| 207 |
+
out = torch.zeros_like(query)
|
| 208 |
+
if valid_rows.any():
|
| 209 |
+
attn_mask = None
|
| 210 |
+
key_padding_mask = ~memory_token_mask[valid_rows]
|
| 211 |
+
if gate_tensor is not None:
|
| 212 |
+
gate_bias = torch.log(gate_tensor[valid_rows].clamp_min(1.0e-6))
|
| 213 |
+
gate_bias = gate_bias[:, None, :].expand(-1, query.shape[1], -1)
|
| 214 |
+
attn_mask = gate_bias.repeat_interleave(self.num_heads, dim=0)
|
| 215 |
+
float_padding_mask = torch.zeros_like(gate_tensor[valid_rows], dtype=query.dtype)
|
| 216 |
+
key_padding_mask = float_padding_mask.masked_fill(key_padding_mask, float("-inf"))
|
| 217 |
+
attended, _ = self.attn(
|
| 218 |
+
query[valid_rows],
|
| 219 |
+
memory_tokens[valid_rows],
|
| 220 |
+
memory_tokens[valid_rows],
|
| 221 |
+
key_padding_mask=key_padding_mask,
|
| 222 |
+
attn_mask=attn_mask,
|
| 223 |
+
need_weights=False,
|
| 224 |
+
)
|
| 225 |
+
out[valid_rows] = attended.to(out.dtype)
|
| 226 |
+
return out, valid_rows
|
| 227 |
+
|
| 228 |
+
def _apply_memory_type(self, memory_tokens, memory_type_ids):
|
| 229 |
+
if memory_type_ids is None:
|
| 230 |
+
return memory_tokens
|
| 231 |
+
memory_type_ids = memory_type_ids.to(device=memory_tokens.device, dtype=torch.long)
|
| 232 |
+
type_embed = self.memory_type_embed(memory_type_ids).to(memory_tokens.dtype)
|
| 233 |
+
type_scale = self.memory_type_scale[memory_type_ids].to(memory_tokens.dtype)
|
| 234 |
+
while type_embed.dim() < memory_tokens.dim():
|
| 235 |
+
type_embed = type_embed.unsqueeze(0)
|
| 236 |
+
type_scale = type_scale.unsqueeze(0)
|
| 237 |
+
return memory_tokens * type_scale + type_embed
|
| 238 |
+
|
| 239 |
+
def _store_type_gate_diagnostics(self, stage_gate):
|
| 240 |
+
with torch.no_grad():
|
| 241 |
+
detached = stage_gate.detach().float()
|
| 242 |
+
self.last_type_gate_mean = detached.mean()
|
| 243 |
+
for type_idx, type_name in enumerate(MEMORY_TYPE_NAMES[: self.num_memory_types]):
|
| 244 |
+
setattr(self, f"last_type_gate_{type_name}_mean", detached[..., type_idx].mean())
|
| 245 |
+
|
| 246 |
+
def _type_stage_gate(self, c, memory_tokens, memory_type_ids):
|
| 247 |
+
if memory_type_ids is None:
|
| 248 |
+
return None
|
| 249 |
+
memory_type_ids = memory_type_ids.to(device=memory_tokens.device, dtype=torch.long)
|
| 250 |
+
stage_gate = torch.sigmoid(self.memory_type_gate(c)).to(memory_tokens.dtype)
|
| 251 |
+
self._store_type_gate_diagnostics(stage_gate)
|
| 252 |
+
if memory_tokens.dim() == 4:
|
| 253 |
+
batch_size, num_frames, num_tokens = memory_tokens.shape[:3]
|
| 254 |
+
if memory_type_ids.dim() == 1:
|
| 255 |
+
gather_ids = memory_type_ids.view(1, 1, num_tokens).expand(batch_size, num_frames, num_tokens)
|
| 256 |
+
elif tuple(memory_type_ids.shape) == (batch_size, num_frames, num_tokens):
|
| 257 |
+
gather_ids = memory_type_ids
|
| 258 |
+
else:
|
| 259 |
+
raise ValueError(
|
| 260 |
+
"rank-4 memory_type_ids must have shape (M,) or (B,T,M), "
|
| 261 |
+
f"got {tuple(memory_type_ids.shape)}"
|
| 262 |
+
)
|
| 263 |
+
return torch.gather(stage_gate, dim=-1, index=gather_ids)
|
| 264 |
+
if memory_tokens.dim() == 3:
|
| 265 |
+
batch_size, num_tokens = memory_tokens.shape[:2]
|
| 266 |
+
if memory_type_ids.dim() != 1:
|
| 267 |
+
raise ValueError("rank-3 memory_type_ids must have shape (M,)")
|
| 268 |
+
gather_ids = memory_type_ids.view(1, 1, num_tokens).expand(batch_size, stage_gate.shape[1], num_tokens)
|
| 269 |
+
return torch.gather(stage_gate, dim=-1, index=gather_ids).mean(dim=1)
|
| 270 |
+
raise ValueError(f"memory_tokens must be rank 3 or 4, got rank {memory_tokens.dim()}")
|
| 271 |
+
|
| 272 |
+
def _combine_memory_gate(self, memory_tokens, memory_token_gate, type_stage_gate):
|
| 273 |
+
combined_gate = type_stage_gate
|
| 274 |
+
if memory_token_gate is not None:
|
| 275 |
+
if tuple(memory_token_gate.shape) != tuple(memory_tokens.shape[:-1]):
|
| 276 |
+
raise ValueError(
|
| 277 |
+
f"memory_token_gate must have shape {tuple(memory_tokens.shape[:-1])}, "
|
| 278 |
+
f"got {tuple(memory_token_gate.shape)}"
|
| 279 |
+
)
|
| 280 |
+
stream_gate = memory_token_gate.to(device=memory_tokens.device, dtype=memory_tokens.dtype)
|
| 281 |
+
combined_gate = stream_gate if combined_gate is None else combined_gate * stream_gate
|
| 282 |
+
return combined_gate
|
| 283 |
+
|
| 284 |
+
def _valid_mask(self, valid_rows, batch_size, num_frames, dtype, device):
|
| 285 |
+
if valid_rows is None:
|
| 286 |
+
return None
|
| 287 |
+
valid_rows = valid_rows.to(device=device, dtype=dtype)
|
| 288 |
+
if valid_rows.numel() == batch_size:
|
| 289 |
+
return valid_rows.view(batch_size, 1, 1, 1, 1)
|
| 290 |
+
if valid_rows.numel() == batch_size * num_frames:
|
| 291 |
+
return rearrange(valid_rows, "(b t) -> b t", b=batch_size, t=num_frames)[:, :, None, None, None]
|
| 292 |
+
raise ValueError(f"valid_rows has incompatible shape: {tuple(valid_rows.shape)}")
|
| 293 |
+
|
| 294 |
+
def _gate_valid_mask(self, valid_rows, batch_size, num_frames, dtype, device):
|
| 295 |
+
if valid_rows is None:
|
| 296 |
+
return None
|
| 297 |
+
valid_rows = valid_rows.to(device=device, dtype=dtype)
|
| 298 |
+
if valid_rows.numel() == batch_size:
|
| 299 |
+
return valid_rows.view(batch_size, 1, 1)
|
| 300 |
+
if valid_rows.numel() == batch_size * num_frames:
|
| 301 |
+
return rearrange(valid_rows, "(b t) -> b t", b=batch_size, t=num_frames)[:, :, None]
|
| 302 |
+
raise ValueError(f"valid_rows has incompatible shape: {tuple(valid_rows.shape)}")
|
| 303 |
+
|
| 304 |
+
def _residual_gate(self, residual_gate, batch_size, num_frames, dtype, device):
|
| 305 |
+
if residual_gate is None:
|
| 306 |
+
return None
|
| 307 |
+
if not torch.is_tensor(residual_gate):
|
| 308 |
+
return torch.tensor(float(residual_gate), dtype=dtype, device=device).view(1, 1, 1, 1, 1)
|
| 309 |
+
gate_tensor = residual_gate.to(device=device, dtype=dtype)
|
| 310 |
+
if gate_tensor.dim() == 0:
|
| 311 |
+
gate_tensor = gate_tensor.view(1, 1, 1, 1, 1)
|
| 312 |
+
elif gate_tensor.dim() == 1:
|
| 313 |
+
if gate_tensor.numel() == batch_size:
|
| 314 |
+
gate_tensor = gate_tensor.view(batch_size, 1, 1, 1, 1)
|
| 315 |
+
elif gate_tensor.numel() == batch_size * num_frames:
|
| 316 |
+
gate_tensor = rearrange(gate_tensor, "(b t) -> b t", b=batch_size, t=num_frames)[:, :, None, None, None]
|
| 317 |
+
else:
|
| 318 |
+
raise ValueError(f"residual_gate has incompatible shape: {tuple(gate_tensor.shape)}")
|
| 319 |
+
elif gate_tensor.dim() == 2:
|
| 320 |
+
if tuple(gate_tensor.shape) != (batch_size, num_frames):
|
| 321 |
+
raise ValueError(f"residual_gate must have shape (B,T), got {tuple(gate_tensor.shape)}")
|
| 322 |
+
gate_tensor = gate_tensor[:, :, None, None, None]
|
| 323 |
+
elif gate_tensor.dim() == 3:
|
| 324 |
+
if tuple(gate_tensor.shape[:2]) != (batch_size, num_frames):
|
| 325 |
+
raise ValueError(f"residual_gate must start with (B,T), got {tuple(gate_tensor.shape)}")
|
| 326 |
+
gate_tensor = gate_tensor[:, :, :, None, None]
|
| 327 |
+
else:
|
| 328 |
+
while gate_tensor.dim() < 5:
|
| 329 |
+
gate_tensor = gate_tensor.unsqueeze(-1)
|
| 330 |
+
return gate_tensor
|
| 331 |
+
|
| 332 |
+
def _store_diagnostics(self, output, base, gate_msa, gate_mlp, valid_rows):
|
| 333 |
+
with torch.no_grad():
|
| 334 |
+
batch_size, num_frames = base.shape[:2]
|
| 335 |
+
gate_values = torch.cat(
|
| 336 |
+
[gate_msa.detach().float().abs(), gate_mlp.detach().float().abs()],
|
| 337 |
+
dim=-1,
|
| 338 |
+
)
|
| 339 |
+
gate_mask = self._gate_valid_mask(
|
| 340 |
+
valid_rows,
|
| 341 |
+
batch_size,
|
| 342 |
+
num_frames,
|
| 343 |
+
dtype=gate_values.dtype,
|
| 344 |
+
device=gate_values.device,
|
| 345 |
+
)
|
| 346 |
+
if gate_mask is not None:
|
| 347 |
+
gate_values = gate_values * gate_mask
|
| 348 |
+
self.last_valid_fraction = valid_rows.detach().float().mean()
|
| 349 |
+
valid_count = (gate_mask.sum() * gate_values.shape[-1]).clamp_min(1.0)
|
| 350 |
+
self.last_gate_mean = gate_values.sum() / valid_count
|
| 351 |
+
else:
|
| 352 |
+
self.last_valid_fraction = base.detach().new_tensor(1.0, dtype=torch.float32)
|
| 353 |
+
self.last_gate_mean = gate_values.mean()
|
| 354 |
+
|
| 355 |
+
delta_norm = (output.detach().float() - base.detach().float()).norm()
|
| 356 |
+
base_norm = base.detach().float().norm()
|
| 357 |
+
self.last_delta_ratio = delta_norm / (base_norm + 1e-6)
|
| 358 |
+
|
| 359 |
+
def forward(
|
| 360 |
+
self,
|
| 361 |
+
x,
|
| 362 |
+
c,
|
| 363 |
+
memory_tokens,
|
| 364 |
+
memory_token_mask=None,
|
| 365 |
+
residual_base=None,
|
| 366 |
+
return_delta=False,
|
| 367 |
+
residual_gate=None,
|
| 368 |
+
memory_type_ids=None,
|
| 369 |
+
memory_token_gate=None,
|
| 370 |
+
):
|
| 371 |
+
B, T, H, W, D = x.shape
|
| 372 |
+
if residual_base is None:
|
| 373 |
+
residual_base = x
|
| 374 |
+
m_shift_msa, m_scale_msa, m_gate_msa, m_shift_mlp, m_scale_mlp, m_gate_mlp = (
|
| 375 |
+
self.adaLN_modulation(c).chunk(6, dim=-1)
|
| 376 |
+
)
|
| 377 |
+
query_source = modulate(self.norm_q(x), m_shift_msa, m_scale_msa)
|
| 378 |
+
type_stage_gate = self._type_stage_gate(c, memory_tokens, memory_type_ids)
|
| 379 |
+
effective_token_gate = self._combine_memory_gate(memory_tokens, memory_token_gate, type_stage_gate)
|
| 380 |
+
if memory_tokens.dim() == 3:
|
| 381 |
+
query = rearrange(query_source, "b t h w d -> b (t h w) d")
|
| 382 |
+
memory_tokens = self._apply_memory_type(self.norm_mem(memory_tokens), memory_type_ids)
|
| 383 |
+
valid_rows = None
|
| 384 |
+
if memory_token_mask is not None:
|
| 385 |
+
if tuple(memory_token_mask.shape) != tuple(memory_tokens.shape[:2]):
|
| 386 |
+
raise ValueError(
|
| 387 |
+
f"legacy memory mask must have shape {tuple(memory_tokens.shape[:2])}, "
|
| 388 |
+
f"got {tuple(memory_token_mask.shape)}"
|
| 389 |
+
)
|
| 390 |
+
out, valid_rows = self._attend(
|
| 391 |
+
query,
|
| 392 |
+
memory_tokens,
|
| 393 |
+
memory_token_mask=memory_token_mask,
|
| 394 |
+
memory_token_gate=effective_token_gate,
|
| 395 |
+
)
|
| 396 |
+
out = rearrange(out, "b (t h w) d -> b t h w d", t=T, h=H, w=W)
|
| 397 |
+
elif memory_tokens.dim() == 4:
|
| 398 |
+
assert memory_tokens.shape[:2] == (B, T), (
|
| 399 |
+
f"per-frame memory tokens must have shape (B, T, M, D), got {tuple(memory_tokens.shape)}"
|
| 400 |
+
)
|
| 401 |
+
query = rearrange(query_source, "b t h w d -> (b t) (h w) d")
|
| 402 |
+
memory_tokens = self._apply_memory_type(self.norm_mem(memory_tokens), memory_type_ids)
|
| 403 |
+
memory_tokens = rearrange(memory_tokens, "b t m d -> (b t) m d")
|
| 404 |
+
if effective_token_gate is not None:
|
| 405 |
+
effective_token_gate = rearrange(effective_token_gate, "b t m -> (b t) m")
|
| 406 |
+
valid_rows = None
|
| 407 |
+
if memory_token_mask is not None:
|
| 408 |
+
expected_mask_shape = (B, T, memory_tokens.shape[1])
|
| 409 |
+
if tuple(memory_token_mask.shape) != expected_mask_shape:
|
| 410 |
+
raise ValueError(
|
| 411 |
+
f"per-frame memory mask must have shape {expected_mask_shape}, "
|
| 412 |
+
f"got {tuple(memory_token_mask.shape)}"
|
| 413 |
+
)
|
| 414 |
+
memory_token_mask = rearrange(memory_token_mask.bool(), "b t m -> (b t) m")
|
| 415 |
+
out, valid_rows = self._attend(
|
| 416 |
+
query,
|
| 417 |
+
memory_tokens,
|
| 418 |
+
memory_token_mask=memory_token_mask,
|
| 419 |
+
memory_token_gate=effective_token_gate,
|
| 420 |
+
)
|
| 421 |
+
out = rearrange(out, "(b t) (h w) d -> b t h w d", b=B, t=T, h=H, w=W)
|
| 422 |
+
else:
|
| 423 |
+
raise ValueError(f"memory_tokens must be rank 3 or 4, got rank {memory_tokens.dim()}")
|
| 424 |
+
|
| 425 |
+
valid_mask = self._valid_mask(valid_rows, B, T, dtype=out.dtype, device=out.device)
|
| 426 |
+
residual_gate_tensor = self._residual_gate(residual_gate, B, T, dtype=out.dtype, device=out.device)
|
| 427 |
+
attn_delta = gate(out, m_gate_msa)
|
| 428 |
+
if valid_mask is not None:
|
| 429 |
+
attn_delta = attn_delta * valid_mask
|
| 430 |
+
if residual_gate_tensor is not None:
|
| 431 |
+
attn_delta = attn_delta * residual_gate_tensor
|
| 432 |
+
output = residual_base + attn_delta
|
| 433 |
+
|
| 434 |
+
mlp_delta = gate(self.mlp(modulate(self.norm_mlp(output), m_shift_mlp, m_scale_mlp)), m_gate_mlp)
|
| 435 |
+
if valid_mask is not None:
|
| 436 |
+
mlp_delta = mlp_delta * valid_mask
|
| 437 |
+
if residual_gate_tensor is not None:
|
| 438 |
+
mlp_delta = mlp_delta * residual_gate_tensor
|
| 439 |
+
output = output + mlp_delta
|
| 440 |
+
self._store_diagnostics(output, residual_base, m_gate_msa, m_gate_mlp, valid_rows)
|
| 441 |
+
if return_delta:
|
| 442 |
+
return attn_delta + mlp_delta
|
| 443 |
+
return output
|
| 444 |
+
|
| 445 |
+
class SpatioTemporalDiTBlock(nn.Module):
|
| 446 |
+
def __init__(
|
| 447 |
+
self,
|
| 448 |
+
hidden_size,
|
| 449 |
+
num_heads,
|
| 450 |
+
reference_length,
|
| 451 |
+
mlp_ratio=4.0,
|
| 452 |
+
is_causal=True,
|
| 453 |
+
spatial_rotary_emb: Optional[RotaryEmbedding] = None,
|
| 454 |
+
temporal_rotary_emb: Optional[RotaryEmbedding] = None,
|
| 455 |
+
use_memory_token_cross_attention=False,
|
| 456 |
+
ref_mode='sequential'
|
| 457 |
+
):
|
| 458 |
+
super().__init__()
|
| 459 |
+
self.is_causal = is_causal
|
| 460 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 461 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
| 462 |
+
|
| 463 |
+
self.s_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 464 |
+
self.s_attn = SpatialAxialAttention(
|
| 465 |
+
hidden_size,
|
| 466 |
+
heads=num_heads,
|
| 467 |
+
dim_head=hidden_size // num_heads,
|
| 468 |
+
rotary_emb=spatial_rotary_emb
|
| 469 |
+
)
|
| 470 |
+
self.s_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 471 |
+
self.s_mlp = Mlp(
|
| 472 |
+
in_features=hidden_size,
|
| 473 |
+
hidden_features=mlp_hidden_dim,
|
| 474 |
+
act_layer=approx_gelu,
|
| 475 |
+
drop=0,
|
| 476 |
+
)
|
| 477 |
+
self.s_adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
|
| 478 |
+
|
| 479 |
+
self.t_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 480 |
+
self.t_attn = TemporalAxialAttention(
|
| 481 |
+
hidden_size,
|
| 482 |
+
heads=num_heads,
|
| 483 |
+
dim_head=hidden_size // num_heads,
|
| 484 |
+
is_causal=is_causal,
|
| 485 |
+
rotary_emb=temporal_rotary_emb,
|
| 486 |
+
reference_length=reference_length
|
| 487 |
+
)
|
| 488 |
+
self.t_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 489 |
+
self.t_mlp = Mlp(
|
| 490 |
+
in_features=hidden_size,
|
| 491 |
+
hidden_features=mlp_hidden_dim,
|
| 492 |
+
act_layer=approx_gelu,
|
| 493 |
+
drop=0,
|
| 494 |
+
)
|
| 495 |
+
self.t_adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
|
| 496 |
+
|
| 497 |
+
self.reference_length = reference_length
|
| 498 |
+
self.use_memory_token_cross_attention = use_memory_token_cross_attention
|
| 499 |
+
if self.use_memory_token_cross_attention:
|
| 500 |
+
self.memory_token_cross_attn = MemoryTokenCrossAttention(hidden_size, num_heads, mlp_ratio=mlp_ratio)
|
| 501 |
+
|
| 502 |
+
self.ref_mode = ref_mode
|
| 503 |
+
|
| 504 |
+
if self.ref_mode == 'parallel':
|
| 505 |
+
self.parallel_map = nn.Linear(hidden_size, hidden_size)
|
| 506 |
+
|
| 507 |
+
def _expand_memory_stream(self, tokens, mask, stream_gate, type_idx, batch_size, num_frames):
|
| 508 |
+
if tokens is None or tokens.shape[-2] == 0:
|
| 509 |
+
return None
|
| 510 |
+
if tokens.dim() == 3:
|
| 511 |
+
if tokens.shape[0] != batch_size:
|
| 512 |
+
raise ValueError(f"rank-3 memory tokens must start with B={batch_size}, got {tuple(tokens.shape)}")
|
| 513 |
+
tokens = tokens[:, None].expand(-1, num_frames, -1, -1)
|
| 514 |
+
if mask is None:
|
| 515 |
+
mask = torch.ones(tokens.shape[:3], device=tokens.device, dtype=torch.bool)
|
| 516 |
+
elif mask.dim() == 2:
|
| 517 |
+
mask = mask[:, None].expand(-1, num_frames, -1)
|
| 518 |
+
elif mask.dim() != 3:
|
| 519 |
+
raise ValueError(f"rank-3 stream mask must have rank 2 or 3, got {tuple(mask.shape)}")
|
| 520 |
+
elif tokens.dim() == 4:
|
| 521 |
+
if tuple(tokens.shape[:2]) != (batch_size, num_frames):
|
| 522 |
+
raise ValueError(
|
| 523 |
+
f"rank-4 memory tokens must start with (B,T)={(batch_size, num_frames)}, "
|
| 524 |
+
f"got {tuple(tokens.shape)}"
|
| 525 |
+
)
|
| 526 |
+
if mask is None:
|
| 527 |
+
mask = torch.ones(tokens.shape[:3], device=tokens.device, dtype=torch.bool)
|
| 528 |
+
elif mask.dim() != 3:
|
| 529 |
+
raise ValueError(f"rank-4 stream mask must have rank 3, got {tuple(mask.shape)}")
|
| 530 |
+
else:
|
| 531 |
+
raise ValueError(f"memory stream tokens must be rank 3 or 4, got rank {tokens.dim()}")
|
| 532 |
+
if tuple(mask.shape) != tuple(tokens.shape[:3]):
|
| 533 |
+
raise ValueError(f"memory stream mask must have shape {tuple(tokens.shape[:3])}, got {tuple(mask.shape)}")
|
| 534 |
+
gate_tensor = self._expand_memory_stream_gate(stream_gate, tokens)
|
| 535 |
+
type_ids = torch.full((tokens.shape[2],), int(type_idx), device=tokens.device, dtype=torch.long)
|
| 536 |
+
return tokens, mask.to(device=tokens.device, dtype=torch.bool), gate_tensor, type_ids
|
| 537 |
+
|
| 538 |
+
def _expand_memory_stream_gate(self, stream_gate, tokens):
|
| 539 |
+
batch_size, num_frames, num_tokens = tokens.shape[:3]
|
| 540 |
+
if stream_gate is None:
|
| 541 |
+
return torch.ones(tokens.shape[:3], device=tokens.device, dtype=tokens.dtype)
|
| 542 |
+
if not torch.is_tensor(stream_gate):
|
| 543 |
+
return torch.full(tokens.shape[:3], float(stream_gate), device=tokens.device, dtype=tokens.dtype)
|
| 544 |
+
gate_tensor = stream_gate.to(device=tokens.device, dtype=tokens.dtype)
|
| 545 |
+
if gate_tensor.dim() == 0:
|
| 546 |
+
return gate_tensor.view(1, 1, 1).expand(batch_size, num_frames, num_tokens)
|
| 547 |
+
if gate_tensor.dim() == 1:
|
| 548 |
+
if gate_tensor.numel() != batch_size:
|
| 549 |
+
raise ValueError(f"rank-1 memory gate must have B={batch_size} values, got {tuple(gate_tensor.shape)}")
|
| 550 |
+
return gate_tensor.view(batch_size, 1, 1).expand(batch_size, num_frames, num_tokens)
|
| 551 |
+
if gate_tensor.dim() == 2:
|
| 552 |
+
if tuple(gate_tensor.shape) == (batch_size, num_frames):
|
| 553 |
+
return gate_tensor[:, :, None].expand(batch_size, num_frames, num_tokens)
|
| 554 |
+
if tuple(gate_tensor.shape) == (batch_size, num_tokens):
|
| 555 |
+
return gate_tensor[:, None, :].expand(batch_size, num_frames, num_tokens)
|
| 556 |
+
raise ValueError(
|
| 557 |
+
f"rank-2 memory gate must have shape (B,T) or (B,M), got {tuple(gate_tensor.shape)}"
|
| 558 |
+
)
|
| 559 |
+
if gate_tensor.dim() == 3:
|
| 560 |
+
if tuple(gate_tensor.shape) == (batch_size, num_frames, 1):
|
| 561 |
+
return gate_tensor.expand(batch_size, num_frames, num_tokens)
|
| 562 |
+
if tuple(gate_tensor.shape) == (batch_size, num_frames, num_tokens):
|
| 563 |
+
return gate_tensor
|
| 564 |
+
raise ValueError(
|
| 565 |
+
f"rank-3 memory gate must have shape (B,T,1) or (B,T,M), got {tuple(gate_tensor.shape)}"
|
| 566 |
+
)
|
| 567 |
+
raise ValueError(f"memory gate rank must be <=3, got rank {gate_tensor.dim()}")
|
| 568 |
+
|
| 569 |
+
def _pack_typed_memory_streams(
|
| 570 |
+
self,
|
| 571 |
+
batch_size,
|
| 572 |
+
num_frames,
|
| 573 |
+
memory_tokens=None,
|
| 574 |
+
memory_token_mask=None,
|
| 575 |
+
memory_dynamic_tokens=None,
|
| 576 |
+
memory_dynamic_mask=None,
|
| 577 |
+
memory_retrieval_tokens=None,
|
| 578 |
+
memory_retrieval_mask=None,
|
| 579 |
+
memory_anchor_gate=None,
|
| 580 |
+
memory_dynamic_gate=None,
|
| 581 |
+
memory_retrieval_gate=None,
|
| 582 |
+
):
|
| 583 |
+
streams = []
|
| 584 |
+
for tokens, mask, stream_gate, type_idx in (
|
| 585 |
+
(memory_tokens, memory_token_mask, memory_anchor_gate, MEMORY_TYPE_ANCHOR),
|
| 586 |
+
(memory_dynamic_tokens, memory_dynamic_mask, memory_dynamic_gate, MEMORY_TYPE_DYNAMIC),
|
| 587 |
+
(memory_retrieval_tokens, memory_retrieval_mask, memory_retrieval_gate, MEMORY_TYPE_REVISIT),
|
| 588 |
+
):
|
| 589 |
+
expanded = self._expand_memory_stream(tokens, mask, stream_gate, type_idx, batch_size, num_frames)
|
| 590 |
+
if expanded is not None:
|
| 591 |
+
streams.append(expanded)
|
| 592 |
+
if not streams:
|
| 593 |
+
return None
|
| 594 |
+
packed_tokens = torch.cat([item[0] for item in streams], dim=2)
|
| 595 |
+
packed_mask = torch.cat([item[1] for item in streams], dim=2)
|
| 596 |
+
packed_gate = torch.cat([item[2] for item in streams], dim=2)
|
| 597 |
+
packed_type_ids = torch.cat([item[3] for item in streams], dim=0)
|
| 598 |
+
valid_gate = packed_gate.masked_fill(~packed_mask, 0)
|
| 599 |
+
residual_gate = valid_gate.max(dim=2).values
|
| 600 |
+
return packed_tokens, packed_mask, packed_gate, packed_type_ids, residual_gate
|
| 601 |
+
|
| 602 |
+
def forward(self, x, c, current_frame=None, timestep=None, is_last_block=False,
|
| 603 |
+
pose_cond=None, mode="training", c_action_cond=None, reference_length=None,
|
| 604 |
+
memory_tokens=None, memory_token_mask=None, memory_dynamic_tokens=None, memory_dynamic_mask=None,
|
| 605 |
+
memory_retrieval_tokens=None, memory_retrieval_mask=None, memory_anchor_gate=None,
|
| 606 |
+
memory_dynamic_gate=None, memory_retrieval_gate=None):
|
| 607 |
+
B, T, H, W, D = x.shape
|
| 608 |
+
|
| 609 |
+
# spatial block
|
| 610 |
+
|
| 611 |
+
s_shift_msa, s_scale_msa, s_gate_msa, s_shift_mlp, s_scale_mlp, s_gate_mlp = self.s_adaLN_modulation(c).chunk(6, dim=-1)
|
| 612 |
+
x = x + gate(self.s_attn(modulate(self.s_norm1(x), s_shift_msa, s_scale_msa)), s_gate_msa)
|
| 613 |
+
x = x + gate(self.s_mlp(modulate(self.s_norm2(x), s_shift_mlp, s_scale_mlp)), s_gate_mlp)
|
| 614 |
+
|
| 615 |
+
# temporal block
|
| 616 |
+
if c_action_cond is not None:
|
| 617 |
+
t_shift_msa, t_scale_msa, t_gate_msa, t_shift_mlp, t_scale_mlp, t_gate_mlp = self.t_adaLN_modulation(c_action_cond).chunk(6, dim=-1)
|
| 618 |
+
else:
|
| 619 |
+
t_shift_msa, t_scale_msa, t_gate_msa, t_shift_mlp, t_scale_mlp, t_gate_mlp = self.t_adaLN_modulation(c).chunk(6, dim=-1)
|
| 620 |
+
|
| 621 |
+
x_t = x + gate(self.t_attn(modulate(self.t_norm1(x), t_shift_msa, t_scale_msa)), t_gate_msa)
|
| 622 |
+
x_t = x_t + gate(self.t_mlp(modulate(self.t_norm2(x_t), t_shift_mlp, t_scale_mlp)), t_gate_mlp)
|
| 623 |
+
|
| 624 |
+
if self.ref_mode == 'sequential':
|
| 625 |
+
x = x_t
|
| 626 |
+
|
| 627 |
+
if self.use_memory_token_cross_attention:
|
| 628 |
+
memory_base = x
|
| 629 |
+
packed_memory = self._pack_typed_memory_streams(
|
| 630 |
+
B,
|
| 631 |
+
T,
|
| 632 |
+
memory_tokens=memory_tokens,
|
| 633 |
+
memory_token_mask=memory_token_mask,
|
| 634 |
+
memory_dynamic_tokens=memory_dynamic_tokens,
|
| 635 |
+
memory_dynamic_mask=memory_dynamic_mask,
|
| 636 |
+
memory_retrieval_tokens=memory_retrieval_tokens,
|
| 637 |
+
memory_retrieval_mask=memory_retrieval_mask,
|
| 638 |
+
memory_anchor_gate=memory_anchor_gate,
|
| 639 |
+
memory_dynamic_gate=memory_dynamic_gate,
|
| 640 |
+
memory_retrieval_gate=memory_retrieval_gate,
|
| 641 |
+
)
|
| 642 |
+
if packed_memory is not None:
|
| 643 |
+
packed_tokens, packed_mask, packed_gate, packed_type_ids, residual_gate = packed_memory
|
| 644 |
+
x = self.memory_token_cross_attn(
|
| 645 |
+
memory_base,
|
| 646 |
+
c,
|
| 647 |
+
packed_tokens,
|
| 648 |
+
packed_mask,
|
| 649 |
+
residual_gate=residual_gate,
|
| 650 |
+
memory_type_ids=packed_type_ids,
|
| 651 |
+
memory_token_gate=packed_gate,
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
if self.ref_mode == 'parallel':
|
| 655 |
+
x = x_t + self.parallel_map(x)
|
| 656 |
+
|
| 657 |
+
return x
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
class DiT(nn.Module):
|
| 661 |
+
"""
|
| 662 |
+
Diffusion model with a Transformer backbone.
|
| 663 |
+
"""
|
| 664 |
+
|
| 665 |
+
def __init__(
|
| 666 |
+
self,
|
| 667 |
+
input_h=18,
|
| 668 |
+
input_w=32,
|
| 669 |
+
patch_size=2,
|
| 670 |
+
in_channels=16,
|
| 671 |
+
hidden_size=1024,
|
| 672 |
+
depth=12,
|
| 673 |
+
num_heads=16,
|
| 674 |
+
mlp_ratio=4.0,
|
| 675 |
+
action_cond_dim=25,
|
| 676 |
+
max_frames=32,
|
| 677 |
+
reference_length=8,
|
| 678 |
+
memory_token_cross_attention=False,
|
| 679 |
+
memory_cross_attn_layers=None,
|
| 680 |
+
ref_mode='sequential'
|
| 681 |
+
):
|
| 682 |
+
super().__init__()
|
| 683 |
+
self.in_channels = in_channels
|
| 684 |
+
self.out_channels = in_channels
|
| 685 |
+
self.patch_size = patch_size
|
| 686 |
+
self.num_heads = num_heads
|
| 687 |
+
self.max_frames = max_frames
|
| 688 |
+
|
| 689 |
+
self.x_embedder = PatchEmbed(input_h, input_w, patch_size, in_channels, hidden_size, flatten=False)
|
| 690 |
+
self.t_embedder = TimestepEmbedder(hidden_size)
|
| 691 |
+
|
| 692 |
+
self.spatial_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads // 2, freqs_for="pixel", max_freq=256)
|
| 693 |
+
self.temporal_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads)
|
| 694 |
+
|
| 695 |
+
self.external_cond = nn.Linear(action_cond_dim, hidden_size) if action_cond_dim > 0 else nn.Identity()
|
| 696 |
+
if memory_cross_attn_layers is None:
|
| 697 |
+
memory_cross_attn_layer_set = None
|
| 698 |
+
else:
|
| 699 |
+
memory_cross_attn_layer_set = {int(layer_idx) for layer_idx in memory_cross_attn_layers}
|
| 700 |
+
invalid_layers = sorted(
|
| 701 |
+
layer_idx for layer_idx in memory_cross_attn_layer_set if layer_idx < 0 or layer_idx >= depth
|
| 702 |
+
)
|
| 703 |
+
if invalid_layers:
|
| 704 |
+
raise ValueError(
|
| 705 |
+
f"memory_cross_attn_layers contains invalid indices {invalid_layers} for depth={depth}"
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
self.blocks = nn.ModuleList(
|
| 709 |
+
[
|
| 710 |
+
SpatioTemporalDiTBlock(
|
| 711 |
+
hidden_size,
|
| 712 |
+
num_heads,
|
| 713 |
+
mlp_ratio=mlp_ratio,
|
| 714 |
+
is_causal=True,
|
| 715 |
+
reference_length=reference_length,
|
| 716 |
+
spatial_rotary_emb=self.spatial_rotary_emb,
|
| 717 |
+
temporal_rotary_emb=self.temporal_rotary_emb,
|
| 718 |
+
use_memory_token_cross_attention=memory_token_cross_attention
|
| 719 |
+
and (memory_cross_attn_layer_set is None or block_idx in memory_cross_attn_layer_set),
|
| 720 |
+
ref_mode=ref_mode
|
| 721 |
+
)
|
| 722 |
+
for block_idx in range(depth)
|
| 723 |
+
]
|
| 724 |
+
)
|
| 725 |
+
self.memory_token_cross_attention = memory_token_cross_attention
|
| 726 |
+
self.memory_cross_attn_layers = (
|
| 727 |
+
None if memory_cross_attn_layer_set is None else tuple(sorted(memory_cross_attn_layer_set))
|
| 728 |
+
)
|
| 729 |
+
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
|
| 730 |
+
self.initialize_weights()
|
| 731 |
+
|
| 732 |
+
def initialize_weights(self):
|
| 733 |
+
# Initialize transformer layers:
|
| 734 |
+
def _basic_init(module):
|
| 735 |
+
if isinstance(module, nn.Linear):
|
| 736 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 737 |
+
if module.bias is not None:
|
| 738 |
+
nn.init.constant_(module.bias, 0)
|
| 739 |
+
|
| 740 |
+
self.apply(_basic_init)
|
| 741 |
+
|
| 742 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
| 743 |
+
w = self.x_embedder.proj.weight.data
|
| 744 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 745 |
+
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
| 746 |
+
|
| 747 |
+
# Initialize timestep embedding MLP:
|
| 748 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
| 749 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
| 750 |
+
|
| 751 |
+
# Zero-out adaLN modulation layers in DiT blocks:
|
| 752 |
+
for block in self.blocks:
|
| 753 |
+
nn.init.constant_(block.s_adaLN_modulation[-1].weight, 0)
|
| 754 |
+
nn.init.constant_(block.s_adaLN_modulation[-1].bias, 0)
|
| 755 |
+
nn.init.constant_(block.t_adaLN_modulation[-1].weight, 0)
|
| 756 |
+
nn.init.constant_(block.t_adaLN_modulation[-1].bias, 0)
|
| 757 |
+
|
| 758 |
+
# Zero-out output layers:
|
| 759 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
| 760 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
| 761 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
| 762 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
| 763 |
+
|
| 764 |
+
if self.memory_token_cross_attention:
|
| 765 |
+
for block in self.blocks:
|
| 766 |
+
memory_adapter = getattr(block, "memory_token_cross_attn", None)
|
| 767 |
+
if memory_adapter is not None:
|
| 768 |
+
memory_adapter.reset_identity_init()
|
| 769 |
+
|
| 770 |
+
def memory_adapter_delta_diagnostics(self):
|
| 771 |
+
diagnostics = {}
|
| 772 |
+
ratios = []
|
| 773 |
+
type_gate_values = {type_name: [] for type_name in MEMORY_TYPE_NAMES}
|
| 774 |
+
shared_type_gate_values = []
|
| 775 |
+
for block in self.blocks:
|
| 776 |
+
adapter = getattr(block, "memory_token_cross_attn", None)
|
| 777 |
+
if adapter is None:
|
| 778 |
+
continue
|
| 779 |
+
ratio = getattr(adapter, "last_delta_ratio", None)
|
| 780 |
+
if ratio is not None:
|
| 781 |
+
ratios.append(torch.as_tensor(ratio).detach().float())
|
| 782 |
+
type_gate = getattr(adapter, "last_type_gate_mean", None)
|
| 783 |
+
if type_gate is not None:
|
| 784 |
+
shared_type_gate_values.append(torch.as_tensor(type_gate).detach().float())
|
| 785 |
+
for type_name in MEMORY_TYPE_NAMES:
|
| 786 |
+
value = getattr(adapter, f"last_type_gate_{type_name}_mean", None)
|
| 787 |
+
if value is not None:
|
| 788 |
+
type_gate_values[type_name].append(torch.as_tensor(value).detach().float())
|
| 789 |
+
if ratios:
|
| 790 |
+
values = torch.stack(ratios)
|
| 791 |
+
diagnostics["memory_adapter_delta_ratio_max"] = float(values.max().item())
|
| 792 |
+
diagnostics["memory_adapter_delta_ratio_mean"] = float(values.mean().item())
|
| 793 |
+
if shared_type_gate_values:
|
| 794 |
+
values = torch.stack(shared_type_gate_values)
|
| 795 |
+
diagnostics["memory_adapter_type_gate_mean"] = float(values.mean().item())
|
| 796 |
+
for type_name, values_list in type_gate_values.items():
|
| 797 |
+
if values_list:
|
| 798 |
+
values = torch.stack(values_list)
|
| 799 |
+
diagnostics[f"memory_adapter_type_gate_{type_name}_mean"] = float(values.mean().item())
|
| 800 |
+
return diagnostics
|
| 801 |
+
|
| 802 |
+
def unpatchify(self, x):
|
| 803 |
+
"""
|
| 804 |
+
x: (N, H, W, patch_size**2 * C)
|
| 805 |
+
imgs: (N, H, W, C)
|
| 806 |
+
"""
|
| 807 |
+
c = self.out_channels
|
| 808 |
+
p = self.x_embedder.patch_size[0]
|
| 809 |
+
h = x.shape[1]
|
| 810 |
+
w = x.shape[2]
|
| 811 |
+
|
| 812 |
+
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
| 813 |
+
x = torch.einsum("nhwpqc->nchpwq", x)
|
| 814 |
+
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
| 815 |
+
return imgs
|
| 816 |
+
|
| 817 |
+
def forward(
|
| 818 |
+
self,
|
| 819 |
+
x,
|
| 820 |
+
t,
|
| 821 |
+
action_cond=None,
|
| 822 |
+
pose_cond=None,
|
| 823 |
+
current_frame=None,
|
| 824 |
+
mode=None,
|
| 825 |
+
reference_length=None,
|
| 826 |
+
frame_idx=None,
|
| 827 |
+
memory_tokens=None,
|
| 828 |
+
memory_token_mask=None,
|
| 829 |
+
memory_dynamic_tokens=None,
|
| 830 |
+
memory_dynamic_mask=None,
|
| 831 |
+
memory_retrieval_tokens=None,
|
| 832 |
+
memory_retrieval_mask=None,
|
| 833 |
+
memory_anchor_gate=None,
|
| 834 |
+
memory_dynamic_gate=None,
|
| 835 |
+
memory_retrieval_gate=None,
|
| 836 |
+
):
|
| 837 |
+
"""
|
| 838 |
+
Forward pass of DiT.
|
| 839 |
+
x: (B, T, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
| 840 |
+
t: (B, T,) tensor of diffusion timesteps
|
| 841 |
+
"""
|
| 842 |
+
|
| 843 |
+
B, T, C, H, W = x.shape
|
| 844 |
+
|
| 845 |
+
# add spatial embeddings
|
| 846 |
+
x = rearrange(x, "b t c h w -> (b t) c h w")
|
| 847 |
+
|
| 848 |
+
x = self.x_embedder(x) # (B*T, C, H, W) -> (B*T, H/2, W/2, D) , C = 16, D = d_model
|
| 849 |
+
# restore shape
|
| 850 |
+
x = rearrange(x, "(b t) h w d -> b t h w d", t=T)
|
| 851 |
+
# embed noise steps
|
| 852 |
+
t = rearrange(t, "b t -> (b t)")
|
| 853 |
+
|
| 854 |
+
c_t = self.t_embedder(t) # (N, D)
|
| 855 |
+
c = c_t.clone()
|
| 856 |
+
c = rearrange(c, "(b t) d -> b t d", t=T)
|
| 857 |
+
|
| 858 |
+
if torch.is_tensor(action_cond):
|
| 859 |
+
c_action_cond = c + self.external_cond(action_cond)
|
| 860 |
+
else:
|
| 861 |
+
c_action_cond = None
|
| 862 |
+
|
| 863 |
+
for i, block in enumerate(self.blocks):
|
| 864 |
+
x = block(x, c, current_frame=current_frame, timestep=t, is_last_block= (i+1 == len(self.blocks)),
|
| 865 |
+
mode=mode, c_action_cond=c_action_cond, reference_length=reference_length,
|
| 866 |
+
memory_tokens=memory_tokens, memory_token_mask=memory_token_mask,
|
| 867 |
+
memory_dynamic_tokens=memory_dynamic_tokens, memory_dynamic_mask=memory_dynamic_mask,
|
| 868 |
+
memory_retrieval_tokens=memory_retrieval_tokens, memory_retrieval_mask=memory_retrieval_mask,
|
| 869 |
+
memory_anchor_gate=memory_anchor_gate, memory_dynamic_gate=memory_dynamic_gate,
|
| 870 |
+
memory_retrieval_gate=memory_retrieval_gate) # (N, T, H, W, D)
|
| 871 |
+
x = self.final_layer(x, c) # (N, T, H, W, patch_size ** 2 * out_channels)
|
| 872 |
+
# unpatchify
|
| 873 |
+
x = rearrange(x, "b t h w d -> (b t) h w d")
|
| 874 |
+
x = self.unpatchify(x) # (N, out_channels, H, W)
|
| 875 |
+
x = rearrange(x, "(b t) c h w -> b t c h w", t=T)
|
| 876 |
+
return x
|
| 877 |
+
|
| 878 |
+
|
| 879 |
+
def DiT_S_2(
|
| 880 |
+
action_cond_dim,
|
| 881 |
+
reference_length,
|
| 882 |
+
ref_mode,
|
| 883 |
+
memory_token_cross_attention=False,
|
| 884 |
+
memory_cross_attn_layers=None,
|
| 885 |
+
):
|
| 886 |
+
return DiT(
|
| 887 |
+
patch_size=2,
|
| 888 |
+
hidden_size=1024,
|
| 889 |
+
depth=16,
|
| 890 |
+
num_heads=16,
|
| 891 |
+
action_cond_dim=action_cond_dim,
|
| 892 |
+
reference_length=reference_length,
|
| 893 |
+
memory_token_cross_attention=memory_token_cross_attention,
|
| 894 |
+
memory_cross_attn_layers=memory_cross_attn_layers,
|
| 895 |
+
ref_mode=ref_mode
|
| 896 |
+
)
|
| 897 |
+
|
| 898 |
+
|
| 899 |
+
DiT_models = {"DiT-S/2": DiT_S_2}
|
algorithms/worldmem/models/pose_prediction.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class PosePredictionNet(nn.Module):
|
| 6 |
+
def __init__(self, img_channels=16, img_feat_dim=256, pose_dim=5, action_dim=25, hidden_dim=128):
|
| 7 |
+
super(PosePredictionNet, self).__init__()
|
| 8 |
+
|
| 9 |
+
self.cnn = nn.Sequential(
|
| 10 |
+
nn.Conv2d(img_channels, 32, kernel_size=3, stride=2, padding=1),
|
| 11 |
+
nn.ReLU(),
|
| 12 |
+
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
|
| 13 |
+
nn.ReLU(),
|
| 14 |
+
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
| 15 |
+
nn.ReLU(),
|
| 16 |
+
nn.AdaptiveAvgPool2d((1, 1))
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
self.fc_img = nn.Linear(128, img_feat_dim)
|
| 20 |
+
|
| 21 |
+
self.mlp_motion = nn.Sequential(
|
| 22 |
+
nn.Linear(pose_dim + action_dim, hidden_dim),
|
| 23 |
+
nn.ReLU(),
|
| 24 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 25 |
+
nn.ReLU()
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
self.fc_out = nn.Sequential(
|
| 29 |
+
nn.Linear(img_feat_dim + hidden_dim, hidden_dim),
|
| 30 |
+
nn.ReLU(),
|
| 31 |
+
nn.Linear(hidden_dim, pose_dim)
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
def forward(self, img, action, pose):
|
| 35 |
+
img_feat = self.cnn(img).view(img.size(0), -1)
|
| 36 |
+
img_feat = self.fc_img(img_feat)
|
| 37 |
+
|
| 38 |
+
motion_feat = self.mlp_motion(torch.cat([pose, action], dim=1))
|
| 39 |
+
fused_feat = torch.cat([img_feat, motion_feat], dim=1)
|
| 40 |
+
pose_next_pred = self.fc_out(fused_feat)
|
| 41 |
+
|
| 42 |
+
return pose_next_pred
|
algorithms/worldmem/models/rotary_embedding_torch.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapted from https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
from math import pi, log
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch.nn import Module, ModuleList
|
| 10 |
+
from torch.amp import autocast
|
| 11 |
+
from torch import nn, einsum, broadcast_tensors, Tensor
|
| 12 |
+
|
| 13 |
+
from einops import rearrange, repeat
|
| 14 |
+
|
| 15 |
+
from typing import Literal
|
| 16 |
+
|
| 17 |
+
# helper functions
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def exists(val):
|
| 21 |
+
return val is not None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def default(val, d):
|
| 25 |
+
return val if exists(val) else d
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# broadcat, as tortoise-tts was using it
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def broadcat(tensors, dim=-1):
|
| 32 |
+
broadcasted_tensors = broadcast_tensors(*tensors)
|
| 33 |
+
return torch.cat(broadcasted_tensors, dim=dim)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# rotary embedding helper functions
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def rotate_half(x):
|
| 40 |
+
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
| 41 |
+
x1, x2 = x.unbind(dim=-1)
|
| 42 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 43 |
+
return rearrange(x, "... d r -> ... (d r)")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@autocast("cuda", enabled=False)
|
| 47 |
+
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
|
| 48 |
+
dtype = t.dtype
|
| 49 |
+
|
| 50 |
+
if t.ndim == 3:
|
| 51 |
+
seq_len = t.shape[seq_dim]
|
| 52 |
+
freqs = freqs[-seq_len:]
|
| 53 |
+
|
| 54 |
+
rot_dim = freqs.shape[-1]
|
| 55 |
+
end_index = start_index + rot_dim
|
| 56 |
+
|
| 57 |
+
assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
|
| 58 |
+
|
| 59 |
+
# Split t into three parts: left, middle (to be transformed), and right
|
| 60 |
+
t_left = t[..., :start_index]
|
| 61 |
+
t_middle = t[..., start_index:end_index]
|
| 62 |
+
t_right = t[..., end_index:]
|
| 63 |
+
|
| 64 |
+
# Apply rotary embeddings without modifying t in place
|
| 65 |
+
t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
|
| 66 |
+
|
| 67 |
+
out = torch.cat((t_left, t_transformed, t_right), dim=-1)
|
| 68 |
+
|
| 69 |
+
return out.type(dtype)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# learned rotation helpers
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
|
| 76 |
+
if exists(freq_ranges):
|
| 77 |
+
rotations = einsum("..., f -> ... f", rotations, freq_ranges)
|
| 78 |
+
rotations = rearrange(rotations, "... r f -> ... (r f)")
|
| 79 |
+
|
| 80 |
+
rotations = repeat(rotations, "... n -> ... (n r)", r=2)
|
| 81 |
+
return apply_rotary_emb(rotations, t, start_index=start_index)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# classes
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class RotaryEmbedding(Module):
|
| 88 |
+
def __init__(
|
| 89 |
+
self,
|
| 90 |
+
dim,
|
| 91 |
+
custom_freqs: Tensor | None = None,
|
| 92 |
+
freqs_for: Literal["lang", "pixel", "constant"] = "lang",
|
| 93 |
+
theta=10000,
|
| 94 |
+
max_freq=10,
|
| 95 |
+
num_freqs=1,
|
| 96 |
+
learned_freq=False,
|
| 97 |
+
use_xpos=False,
|
| 98 |
+
xpos_scale_base=512,
|
| 99 |
+
interpolate_factor=1.0,
|
| 100 |
+
theta_rescale_factor=1.0,
|
| 101 |
+
seq_before_head_dim=False,
|
| 102 |
+
cache_if_possible=True,
|
| 103 |
+
cache_max_seq_len=8192,
|
| 104 |
+
):
|
| 105 |
+
super().__init__()
|
| 106 |
+
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
| 107 |
+
# has some connection to NTK literature
|
| 108 |
+
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
| 109 |
+
|
| 110 |
+
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
| 111 |
+
|
| 112 |
+
self.freqs_for = freqs_for
|
| 113 |
+
|
| 114 |
+
if exists(custom_freqs):
|
| 115 |
+
freqs = custom_freqs
|
| 116 |
+
elif freqs_for == "lang":
|
| 117 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 118 |
+
elif freqs_for == "pixel":
|
| 119 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
| 120 |
+
elif freqs_for == "spacetime":
|
| 121 |
+
time_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
| 122 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
| 123 |
+
elif freqs_for == "constant":
|
| 124 |
+
freqs = torch.ones(num_freqs).float()
|
| 125 |
+
|
| 126 |
+
if freqs_for == "spacetime":
|
| 127 |
+
self.time_freqs = nn.Parameter(time_freqs, requires_grad=learned_freq)
|
| 128 |
+
self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)
|
| 129 |
+
|
| 130 |
+
self.cache_if_possible = cache_if_possible
|
| 131 |
+
self.cache_max_seq_len = cache_max_seq_len
|
| 132 |
+
|
| 133 |
+
self.register_buffer("cached_freqs", torch.zeros(cache_max_seq_len, dim), persistent=False)
|
| 134 |
+
self.register_buffer("cached_freqs_seq_len", torch.tensor(0), persistent=False)
|
| 135 |
+
|
| 136 |
+
self.learned_freq = learned_freq
|
| 137 |
+
|
| 138 |
+
# dummy for device
|
| 139 |
+
|
| 140 |
+
self.register_buffer("dummy", torch.tensor(0), persistent=False)
|
| 141 |
+
|
| 142 |
+
# default sequence dimension
|
| 143 |
+
|
| 144 |
+
self.seq_before_head_dim = seq_before_head_dim
|
| 145 |
+
self.default_seq_dim = -3 if seq_before_head_dim else -2
|
| 146 |
+
|
| 147 |
+
# interpolation factors
|
| 148 |
+
|
| 149 |
+
assert interpolate_factor >= 1.0
|
| 150 |
+
self.interpolate_factor = interpolate_factor
|
| 151 |
+
|
| 152 |
+
# xpos
|
| 153 |
+
|
| 154 |
+
self.use_xpos = use_xpos
|
| 155 |
+
|
| 156 |
+
if not use_xpos:
|
| 157 |
+
return
|
| 158 |
+
|
| 159 |
+
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
| 160 |
+
self.scale_base = xpos_scale_base
|
| 161 |
+
|
| 162 |
+
self.register_buffer("scale", scale, persistent=False)
|
| 163 |
+
self.register_buffer("cached_scales", torch.zeros(cache_max_seq_len, dim), persistent=False)
|
| 164 |
+
self.register_buffer("cached_scales_seq_len", torch.tensor(0), persistent=False)
|
| 165 |
+
|
| 166 |
+
# add apply_rotary_emb as static method
|
| 167 |
+
|
| 168 |
+
self.apply_rotary_emb = staticmethod(apply_rotary_emb)
|
| 169 |
+
|
| 170 |
+
@property
|
| 171 |
+
def device(self):
|
| 172 |
+
return self.dummy.device
|
| 173 |
+
|
| 174 |
+
def get_seq_pos(self, seq_len, device, dtype, offset=0):
|
| 175 |
+
return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor
|
| 176 |
+
|
| 177 |
+
def rotate_queries_or_keys(self, t, freqs, seq_dim=None, offset=0, scale=None):
|
| 178 |
+
seq_dim = default(seq_dim, self.default_seq_dim)
|
| 179 |
+
|
| 180 |
+
assert not self.use_xpos or exists(scale), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"
|
| 181 |
+
|
| 182 |
+
device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
|
| 183 |
+
|
| 184 |
+
seq = self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset)
|
| 185 |
+
|
| 186 |
+
seq_freqs = self.forward(seq, freqs, seq_len=seq_len, offset=offset)
|
| 187 |
+
|
| 188 |
+
if seq_dim == -3:
|
| 189 |
+
seq_freqs = rearrange(seq_freqs, "n d -> n 1 d")
|
| 190 |
+
|
| 191 |
+
return apply_rotary_emb(seq_freqs, t, scale=default(scale, 1.0), seq_dim=seq_dim)
|
| 192 |
+
|
| 193 |
+
def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
|
| 194 |
+
dtype, device, seq_dim = (
|
| 195 |
+
q.dtype,
|
| 196 |
+
q.device,
|
| 197 |
+
default(seq_dim, self.default_seq_dim),
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
|
| 201 |
+
assert q_len <= k_len
|
| 202 |
+
|
| 203 |
+
q_scale = k_scale = 1.0
|
| 204 |
+
|
| 205 |
+
if self.use_xpos:
|
| 206 |
+
seq = self.get_seq_pos(k_len, dtype=dtype, device=device)
|
| 207 |
+
|
| 208 |
+
q_scale = self.get_scale(seq[-q_len:]).type(dtype)
|
| 209 |
+
k_scale = self.get_scale(seq).type(dtype)
|
| 210 |
+
|
| 211 |
+
rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, scale=q_scale, offset=k_len - q_len + offset)
|
| 212 |
+
rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, scale=k_scale**-1)
|
| 213 |
+
|
| 214 |
+
rotated_q = rotated_q.type(q.dtype)
|
| 215 |
+
rotated_k = rotated_k.type(k.dtype)
|
| 216 |
+
|
| 217 |
+
return rotated_q, rotated_k
|
| 218 |
+
|
| 219 |
+
def rotate_queries_and_keys(self, q, k, freqs, seq_dim=None):
|
| 220 |
+
seq_dim = default(seq_dim, self.default_seq_dim)
|
| 221 |
+
|
| 222 |
+
assert self.use_xpos
|
| 223 |
+
device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
|
| 224 |
+
|
| 225 |
+
seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)
|
| 226 |
+
|
| 227 |
+
seq_freqs = self.forward(seq, freqs, seq_len=seq_len)
|
| 228 |
+
scale = self.get_scale(seq, seq_len=seq_len).to(dtype)
|
| 229 |
+
|
| 230 |
+
if seq_dim == -3:
|
| 231 |
+
seq_freqs = rearrange(seq_freqs, "n d -> n 1 d")
|
| 232 |
+
scale = rearrange(scale, "n d -> n 1 d")
|
| 233 |
+
|
| 234 |
+
rotated_q = apply_rotary_emb(seq_freqs, q, scale=scale, seq_dim=seq_dim)
|
| 235 |
+
rotated_k = apply_rotary_emb(seq_freqs, k, scale=scale**-1, seq_dim=seq_dim)
|
| 236 |
+
|
| 237 |
+
rotated_q = rotated_q.type(q.dtype)
|
| 238 |
+
rotated_k = rotated_k.type(k.dtype)
|
| 239 |
+
|
| 240 |
+
return rotated_q, rotated_k
|
| 241 |
+
|
| 242 |
+
def get_scale(self, t: Tensor, seq_len: int | None = None, offset=0):
|
| 243 |
+
assert self.use_xpos
|
| 244 |
+
|
| 245 |
+
should_cache = self.cache_if_possible and exists(seq_len) and (offset + seq_len) <= self.cache_max_seq_len
|
| 246 |
+
|
| 247 |
+
if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales_seq_len.item():
|
| 248 |
+
return self.cached_scales[offset : (offset + seq_len)]
|
| 249 |
+
|
| 250 |
+
scale = 1.0
|
| 251 |
+
if self.use_xpos:
|
| 252 |
+
power = (t - len(t) // 2) / self.scale_base
|
| 253 |
+
scale = self.scale ** rearrange(power, "n -> n 1")
|
| 254 |
+
scale = repeat(scale, "n d -> n (d r)", r=2)
|
| 255 |
+
|
| 256 |
+
if should_cache and offset == 0:
|
| 257 |
+
self.cached_scales[:seq_len] = scale.detach()
|
| 258 |
+
self.cached_scales_seq_len.copy_(seq_len)
|
| 259 |
+
|
| 260 |
+
return scale
|
| 261 |
+
|
| 262 |
+
def get_axial_freqs(self, *dims):
|
| 263 |
+
Colon = slice(None)
|
| 264 |
+
all_freqs = []
|
| 265 |
+
|
| 266 |
+
for ind, dim in enumerate(dims):
|
| 267 |
+
# only allow pixel freqs for last two dimensions
|
| 268 |
+
use_pixel = (self.freqs_for == "pixel" or self.freqs_for == "spacetime") and ind >= len(dims) - 2
|
| 269 |
+
if use_pixel:
|
| 270 |
+
pos = torch.linspace(-1, 1, steps=dim, device=self.device)
|
| 271 |
+
else:
|
| 272 |
+
pos = torch.arange(dim, device=self.device)
|
| 273 |
+
|
| 274 |
+
if self.freqs_for == "spacetime" and not use_pixel:
|
| 275 |
+
seq_freqs = self.forward(pos, self.time_freqs, seq_len=dim)
|
| 276 |
+
else:
|
| 277 |
+
seq_freqs = self.forward(pos, self.freqs, seq_len=dim)
|
| 278 |
+
|
| 279 |
+
all_axis = [None] * len(dims)
|
| 280 |
+
all_axis[ind] = Colon
|
| 281 |
+
|
| 282 |
+
new_axis_slice = (Ellipsis, *all_axis, Colon)
|
| 283 |
+
all_freqs.append(seq_freqs[new_axis_slice])
|
| 284 |
+
|
| 285 |
+
all_freqs = broadcast_tensors(*all_freqs)
|
| 286 |
+
return torch.cat(all_freqs, dim=-1)
|
| 287 |
+
|
| 288 |
+
@autocast("cuda", enabled=False)
|
| 289 |
+
def forward(self, t: Tensor, freqs: Tensor, seq_len=None, offset=0):
|
| 290 |
+
should_cache = self.cache_if_possible and not self.learned_freq and exists(seq_len) and self.freqs_for != "pixel" and (offset + seq_len) <= self.cache_max_seq_len
|
| 291 |
+
|
| 292 |
+
if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs_seq_len.item():
|
| 293 |
+
return self.cached_freqs[offset : (offset + seq_len)].detach()
|
| 294 |
+
|
| 295 |
+
freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
|
| 296 |
+
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
|
| 297 |
+
|
| 298 |
+
if should_cache and offset == 0:
|
| 299 |
+
self.cached_freqs[:seq_len] = freqs.detach()
|
| 300 |
+
self.cached_freqs_seq_len.copy_(seq_len)
|
| 301 |
+
|
| 302 |
+
return freqs
|
algorithms/worldmem/models/utils.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapted from https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/utils.py
|
| 3 |
+
Action format derived from VPT https://github.com/openai/Video-Pre-Training
|
| 4 |
+
Adapted from https://github.com/etched-ai/open-oasis/blob/master/utils.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import math
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torchvision.io import read_image, read_video
|
| 11 |
+
from torchvision.transforms.functional import resize
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
from typing import Mapping, Sequence
|
| 14 |
+
from einops import rearrange, parse_shape
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def exists(val):
|
| 18 |
+
return val is not None
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def default(val, d):
|
| 22 |
+
if exists(val):
|
| 23 |
+
return val
|
| 24 |
+
return d() if callable(d) else d
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def extract(a, t, x_shape):
|
| 28 |
+
f, b = t.shape
|
| 29 |
+
out = a[t]
|
| 30 |
+
return out.reshape(f, b, *((1,) * (len(x_shape) - 2)))
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def linear_beta_schedule(timesteps):
|
| 34 |
+
"""
|
| 35 |
+
linear schedule, proposed in original ddpm paper
|
| 36 |
+
"""
|
| 37 |
+
scale = 1000 / timesteps
|
| 38 |
+
beta_start = scale * 0.0001
|
| 39 |
+
beta_end = scale * 0.02
|
| 40 |
+
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def cosine_beta_schedule(timesteps, s=0.008):
|
| 44 |
+
"""
|
| 45 |
+
cosine schedule
|
| 46 |
+
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
| 47 |
+
"""
|
| 48 |
+
steps = timesteps + 1
|
| 49 |
+
t = torch.linspace(0, timesteps, steps, dtype=torch.float64) / timesteps
|
| 50 |
+
alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
|
| 51 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
| 52 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
| 53 |
+
return torch.clip(betas, 0, 0.999)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
|
| 58 |
+
"""
|
| 59 |
+
sigmoid schedule
|
| 60 |
+
proposed in https://arxiv.org/abs/2212.11972 - Figure 8
|
| 61 |
+
better for images > 64x64, when used during training
|
| 62 |
+
"""
|
| 63 |
+
steps = timesteps + 1
|
| 64 |
+
t = torch.linspace(0, timesteps, steps, dtype=torch.float64) / timesteps
|
| 65 |
+
v_start = torch.tensor(start / tau).sigmoid()
|
| 66 |
+
v_end = torch.tensor(end / tau).sigmoid()
|
| 67 |
+
alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
|
| 68 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
| 69 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
| 70 |
+
return torch.clip(betas, 0, 0.999)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
ACTION_KEYS = [
|
| 74 |
+
"inventory",
|
| 75 |
+
"ESC",
|
| 76 |
+
"hotbar.1",
|
| 77 |
+
"hotbar.2",
|
| 78 |
+
"hotbar.3",
|
| 79 |
+
"hotbar.4",
|
| 80 |
+
"hotbar.5",
|
| 81 |
+
"hotbar.6",
|
| 82 |
+
"hotbar.7",
|
| 83 |
+
"hotbar.8",
|
| 84 |
+
"hotbar.9",
|
| 85 |
+
"forward",
|
| 86 |
+
"back",
|
| 87 |
+
"left",
|
| 88 |
+
"right",
|
| 89 |
+
"cameraX",
|
| 90 |
+
"cameraY",
|
| 91 |
+
"jump",
|
| 92 |
+
"sneak",
|
| 93 |
+
"sprint",
|
| 94 |
+
"swapHands",
|
| 95 |
+
"attack",
|
| 96 |
+
"use",
|
| 97 |
+
"pickItem",
|
| 98 |
+
"drop",
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
|
| 103 |
+
actions_one_hot = torch.zeros(len(actions), len(ACTION_KEYS))
|
| 104 |
+
for i, current_actions in enumerate(actions):
|
| 105 |
+
for j, action_key in enumerate(ACTION_KEYS):
|
| 106 |
+
if action_key.startswith("camera"):
|
| 107 |
+
if action_key == "cameraX":
|
| 108 |
+
value = current_actions["camera"][0]
|
| 109 |
+
elif action_key == "cameraY":
|
| 110 |
+
value = current_actions["camera"][1]
|
| 111 |
+
else:
|
| 112 |
+
raise ValueError(f"Unknown camera action key: {action_key}")
|
| 113 |
+
max_val = 20
|
| 114 |
+
bin_size = 0.5
|
| 115 |
+
num_buckets = int(max_val / bin_size)
|
| 116 |
+
value = (value - num_buckets) / num_buckets
|
| 117 |
+
assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}"
|
| 118 |
+
else:
|
| 119 |
+
value = current_actions[action_key]
|
| 120 |
+
assert 0 <= value <= 1, f"Action value must be in [0, 1] got {value}"
|
| 121 |
+
actions_one_hot[i, j] = value
|
| 122 |
+
|
| 123 |
+
return actions_one_hot
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
IMAGE_EXTENSIONS = {"png", "jpg", "jpeg"}
|
| 127 |
+
VIDEO_EXTENSIONS = {"mp4"}
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def load_prompt(path, video_offset=None, n_prompt_frames=1):
|
| 131 |
+
if path.lower().split(".")[-1] in IMAGE_EXTENSIONS:
|
| 132 |
+
print("prompt is image; ignoring video_offset and n_prompt_frames")
|
| 133 |
+
prompt = read_image(path)
|
| 134 |
+
# add frame dimension
|
| 135 |
+
prompt = rearrange(prompt, "c h w -> 1 c h w")
|
| 136 |
+
elif path.lower().split(".")[-1] in VIDEO_EXTENSIONS:
|
| 137 |
+
prompt = read_video(path, pts_unit="sec")[0]
|
| 138 |
+
if video_offset is not None:
|
| 139 |
+
prompt = prompt[video_offset:]
|
| 140 |
+
prompt = prompt[:n_prompt_frames]
|
| 141 |
+
else:
|
| 142 |
+
raise ValueError(f"unrecognized prompt file extension; expected one in {IMAGE_EXTENSIONS} or {VIDEO_EXTENSIONS}")
|
| 143 |
+
assert prompt.shape[0] == n_prompt_frames, f"input prompt {path} had less than n_prompt_frames={n_prompt_frames} frames"
|
| 144 |
+
prompt = resize(prompt, (360, 640))
|
| 145 |
+
# add batch dimension
|
| 146 |
+
prompt = rearrange(prompt, "t c h w -> 1 t c h w")
|
| 147 |
+
prompt = prompt.float() / 255.0
|
| 148 |
+
return prompt
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def load_actions(path, action_offset=None):
|
| 152 |
+
if path.endswith(".actions.pt"):
|
| 153 |
+
actions = one_hot_actions(torch.load(path))
|
| 154 |
+
elif path.endswith(".one_hot_actions.pt"):
|
| 155 |
+
actions = torch.load(path, weights_only=True)
|
| 156 |
+
else:
|
| 157 |
+
raise ValueError("unrecognized action file extension; expected '*.actions.pt' or '*.one_hot_actions.pt'")
|
| 158 |
+
if action_offset is not None:
|
| 159 |
+
actions = actions[action_offset:]
|
| 160 |
+
actions = torch.cat([torch.zeros_like(actions[:1]), actions], dim=0)
|
| 161 |
+
# add batch dimension
|
| 162 |
+
actions = rearrange(actions, "t d -> 1 t d")
|
| 163 |
+
return actions
|
algorithms/worldmem/models/vae.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
References:
|
| 3 |
+
- VQGAN: https://github.com/CompVis/taming-transformers
|
| 4 |
+
- MAE: https://github.com/facebookresearch/mae
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import math
|
| 9 |
+
import functools
|
| 10 |
+
from collections import namedtuple
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from einops import rearrange
|
| 15 |
+
from timm.models.vision_transformer import Mlp
|
| 16 |
+
from timm.layers.helpers import to_2tuple
|
| 17 |
+
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
|
| 18 |
+
from .dit import PatchEmbed
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DiagonalGaussianDistribution(object):
|
| 22 |
+
def __init__(self, parameters, deterministic=False, dim=1):
|
| 23 |
+
self.parameters = parameters
|
| 24 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
|
| 25 |
+
if dim == 1:
|
| 26 |
+
self.dims = [1, 2, 3]
|
| 27 |
+
elif dim == 2:
|
| 28 |
+
self.dims = [1, 2]
|
| 29 |
+
else:
|
| 30 |
+
raise NotImplementedError
|
| 31 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
| 32 |
+
self.deterministic = deterministic
|
| 33 |
+
self.std = torch.exp(0.5 * self.logvar)
|
| 34 |
+
self.var = torch.exp(self.logvar)
|
| 35 |
+
if self.deterministic:
|
| 36 |
+
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
| 37 |
+
|
| 38 |
+
def sample(self):
|
| 39 |
+
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
|
| 40 |
+
return x
|
| 41 |
+
|
| 42 |
+
def mode(self):
|
| 43 |
+
return self.mean
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Attention(nn.Module):
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
dim,
|
| 50 |
+
num_heads,
|
| 51 |
+
frame_height,
|
| 52 |
+
frame_width,
|
| 53 |
+
qkv_bias=False,
|
| 54 |
+
):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.num_heads = num_heads
|
| 57 |
+
head_dim = dim // num_heads
|
| 58 |
+
self.frame_height = frame_height
|
| 59 |
+
self.frame_width = frame_width
|
| 60 |
+
|
| 61 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 62 |
+
self.proj = nn.Linear(dim, dim)
|
| 63 |
+
|
| 64 |
+
rotary_freqs = RotaryEmbedding(
|
| 65 |
+
dim=head_dim // 4,
|
| 66 |
+
freqs_for="pixel",
|
| 67 |
+
max_freq=frame_height * frame_width,
|
| 68 |
+
).get_axial_freqs(frame_height, frame_width)
|
| 69 |
+
self.register_buffer("rotary_freqs", rotary_freqs, persistent=False)
|
| 70 |
+
|
| 71 |
+
def forward(self, x):
|
| 72 |
+
B, N, C = x.shape
|
| 73 |
+
assert N == self.frame_height * self.frame_width
|
| 74 |
+
|
| 75 |
+
q, k, v = self.qkv(x).chunk(3, dim=-1)
|
| 76 |
+
|
| 77 |
+
q = rearrange(
|
| 78 |
+
q,
|
| 79 |
+
"b (H W) (h d) -> b h H W d",
|
| 80 |
+
H=self.frame_height,
|
| 81 |
+
W=self.frame_width,
|
| 82 |
+
h=self.num_heads,
|
| 83 |
+
)
|
| 84 |
+
k = rearrange(
|
| 85 |
+
k,
|
| 86 |
+
"b (H W) (h d) -> b h H W d",
|
| 87 |
+
H=self.frame_height,
|
| 88 |
+
W=self.frame_width,
|
| 89 |
+
h=self.num_heads,
|
| 90 |
+
)
|
| 91 |
+
v = rearrange(
|
| 92 |
+
v,
|
| 93 |
+
"b (H W) (h d) -> b h H W d",
|
| 94 |
+
H=self.frame_height,
|
| 95 |
+
W=self.frame_width,
|
| 96 |
+
h=self.num_heads,
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
q = apply_rotary_emb(self.rotary_freqs, q)
|
| 100 |
+
k = apply_rotary_emb(self.rotary_freqs, k)
|
| 101 |
+
|
| 102 |
+
q = rearrange(q, "b h H W d -> b h (H W) d")
|
| 103 |
+
k = rearrange(k, "b h H W d -> b h (H W) d")
|
| 104 |
+
v = rearrange(v, "b h H W d -> b h (H W) d")
|
| 105 |
+
|
| 106 |
+
x = F.scaled_dot_product_attention(q, k, v)
|
| 107 |
+
x = rearrange(x, "b h N d -> b N (h d)")
|
| 108 |
+
|
| 109 |
+
x = self.proj(x)
|
| 110 |
+
return x
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class AttentionBlock(nn.Module):
|
| 114 |
+
def __init__(
|
| 115 |
+
self,
|
| 116 |
+
dim,
|
| 117 |
+
num_heads,
|
| 118 |
+
frame_height,
|
| 119 |
+
frame_width,
|
| 120 |
+
mlp_ratio=4.0,
|
| 121 |
+
qkv_bias=False,
|
| 122 |
+
attn_causal=False,
|
| 123 |
+
act_layer=nn.GELU,
|
| 124 |
+
norm_layer=nn.LayerNorm,
|
| 125 |
+
):
|
| 126 |
+
super().__init__()
|
| 127 |
+
self.norm1 = norm_layer(dim)
|
| 128 |
+
self.attn = Attention(
|
| 129 |
+
dim,
|
| 130 |
+
num_heads,
|
| 131 |
+
frame_height,
|
| 132 |
+
frame_width,
|
| 133 |
+
qkv_bias=qkv_bias,
|
| 134 |
+
)
|
| 135 |
+
self.norm2 = norm_layer(dim)
|
| 136 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 137 |
+
self.mlp = Mlp(
|
| 138 |
+
in_features=dim,
|
| 139 |
+
hidden_features=mlp_hidden_dim,
|
| 140 |
+
act_layer=act_layer,
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
def forward(self, x):
|
| 144 |
+
x = x + self.attn(self.norm1(x))
|
| 145 |
+
x = x + self.mlp(self.norm2(x))
|
| 146 |
+
return x
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class AutoencoderKL(nn.Module):
|
| 150 |
+
def __init__(
|
| 151 |
+
self,
|
| 152 |
+
latent_dim,
|
| 153 |
+
input_height=256,
|
| 154 |
+
input_width=256,
|
| 155 |
+
patch_size=16,
|
| 156 |
+
enc_dim=768,
|
| 157 |
+
enc_depth=6,
|
| 158 |
+
enc_heads=12,
|
| 159 |
+
dec_dim=768,
|
| 160 |
+
dec_depth=6,
|
| 161 |
+
dec_heads=12,
|
| 162 |
+
mlp_ratio=4.0,
|
| 163 |
+
norm_layer=functools.partial(nn.LayerNorm, eps=1e-6),
|
| 164 |
+
use_variational=True,
|
| 165 |
+
**kwargs,
|
| 166 |
+
):
|
| 167 |
+
super().__init__()
|
| 168 |
+
self.input_height = input_height
|
| 169 |
+
self.input_width = input_width
|
| 170 |
+
self.patch_size = patch_size
|
| 171 |
+
self.seq_h = input_height // patch_size
|
| 172 |
+
self.seq_w = input_width // patch_size
|
| 173 |
+
self.seq_len = self.seq_h * self.seq_w
|
| 174 |
+
self.patch_dim = 3 * patch_size**2
|
| 175 |
+
|
| 176 |
+
self.latent_dim = latent_dim
|
| 177 |
+
self.enc_dim = enc_dim
|
| 178 |
+
self.dec_dim = dec_dim
|
| 179 |
+
|
| 180 |
+
# patch
|
| 181 |
+
self.patch_embed = PatchEmbed(input_height, input_width, patch_size, 3, enc_dim)
|
| 182 |
+
|
| 183 |
+
# encoder
|
| 184 |
+
self.encoder = nn.ModuleList(
|
| 185 |
+
[
|
| 186 |
+
AttentionBlock(
|
| 187 |
+
enc_dim,
|
| 188 |
+
enc_heads,
|
| 189 |
+
self.seq_h,
|
| 190 |
+
self.seq_w,
|
| 191 |
+
mlp_ratio,
|
| 192 |
+
qkv_bias=True,
|
| 193 |
+
norm_layer=norm_layer,
|
| 194 |
+
)
|
| 195 |
+
for i in range(enc_depth)
|
| 196 |
+
]
|
| 197 |
+
)
|
| 198 |
+
self.enc_norm = norm_layer(enc_dim)
|
| 199 |
+
|
| 200 |
+
# bottleneck
|
| 201 |
+
self.use_variational = use_variational
|
| 202 |
+
mult = 2 if self.use_variational else 1
|
| 203 |
+
self.quant_conv = nn.Linear(enc_dim, mult * latent_dim)
|
| 204 |
+
self.post_quant_conv = nn.Linear(latent_dim, dec_dim)
|
| 205 |
+
|
| 206 |
+
# decoder
|
| 207 |
+
self.decoder = nn.ModuleList(
|
| 208 |
+
[
|
| 209 |
+
AttentionBlock(
|
| 210 |
+
dec_dim,
|
| 211 |
+
dec_heads,
|
| 212 |
+
self.seq_h,
|
| 213 |
+
self.seq_w,
|
| 214 |
+
mlp_ratio,
|
| 215 |
+
qkv_bias=True,
|
| 216 |
+
norm_layer=norm_layer,
|
| 217 |
+
)
|
| 218 |
+
for i in range(dec_depth)
|
| 219 |
+
]
|
| 220 |
+
)
|
| 221 |
+
self.dec_norm = norm_layer(dec_dim)
|
| 222 |
+
self.predictor = nn.Linear(dec_dim, self.patch_dim) # decoder to patch
|
| 223 |
+
|
| 224 |
+
# initialize this weight first
|
| 225 |
+
self.initialize_weights()
|
| 226 |
+
|
| 227 |
+
def initialize_weights(self):
|
| 228 |
+
# initialization
|
| 229 |
+
# initialize nn.Linear and nn.LayerNorm
|
| 230 |
+
self.apply(self._init_weights)
|
| 231 |
+
|
| 232 |
+
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
|
| 233 |
+
w = self.patch_embed.proj.weight.data
|
| 234 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 235 |
+
|
| 236 |
+
def _init_weights(self, m):
|
| 237 |
+
if isinstance(m, nn.Linear):
|
| 238 |
+
# we use xavier_uniform following official JAX ViT:
|
| 239 |
+
nn.init.xavier_uniform_(m.weight)
|
| 240 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 241 |
+
nn.init.constant_(m.bias, 0.0)
|
| 242 |
+
elif isinstance(m, nn.LayerNorm):
|
| 243 |
+
nn.init.constant_(m.bias, 0.0)
|
| 244 |
+
nn.init.constant_(m.weight, 1.0)
|
| 245 |
+
|
| 246 |
+
def patchify(self, x):
|
| 247 |
+
# patchify
|
| 248 |
+
bsz, _, h, w = x.shape
|
| 249 |
+
x = x.reshape(
|
| 250 |
+
bsz,
|
| 251 |
+
3,
|
| 252 |
+
self.seq_h,
|
| 253 |
+
self.patch_size,
|
| 254 |
+
self.seq_w,
|
| 255 |
+
self.patch_size,
|
| 256 |
+
).permute([0, 1, 3, 5, 2, 4]) # [b, c, h, p, w, p] --> [b, c, p, p, h, w]
|
| 257 |
+
x = x.reshape(bsz, self.patch_dim, self.seq_h, self.seq_w) # --> [b, cxpxp, h, w]
|
| 258 |
+
x = x.permute([0, 2, 3, 1]).reshape(bsz, self.seq_len, self.patch_dim) # --> [b, hxw, cxpxp]
|
| 259 |
+
return x
|
| 260 |
+
|
| 261 |
+
def unpatchify(self, x):
|
| 262 |
+
bsz = x.shape[0]
|
| 263 |
+
# unpatchify
|
| 264 |
+
x = x.reshape(bsz, self.seq_h, self.seq_w, self.patch_dim).permute([0, 3, 1, 2]) # [b, h, w, cxpxp] --> [b, cxpxp, h, w]
|
| 265 |
+
x = x.reshape(
|
| 266 |
+
bsz,
|
| 267 |
+
3,
|
| 268 |
+
self.patch_size,
|
| 269 |
+
self.patch_size,
|
| 270 |
+
self.seq_h,
|
| 271 |
+
self.seq_w,
|
| 272 |
+
).permute([0, 1, 4, 2, 5, 3]) # [b, c, p, p, h, w] --> [b, c, h, p, w, p]
|
| 273 |
+
x = x.reshape(
|
| 274 |
+
bsz,
|
| 275 |
+
3,
|
| 276 |
+
self.input_height,
|
| 277 |
+
self.input_width,
|
| 278 |
+
) # [b, c, hxp, wxp]
|
| 279 |
+
return x
|
| 280 |
+
|
| 281 |
+
def encode(self, x):
|
| 282 |
+
# patchify
|
| 283 |
+
x = self.patch_embed(x)
|
| 284 |
+
|
| 285 |
+
# encoder
|
| 286 |
+
for blk in self.encoder:
|
| 287 |
+
x = blk(x)
|
| 288 |
+
x = self.enc_norm(x)
|
| 289 |
+
|
| 290 |
+
# bottleneck
|
| 291 |
+
moments = self.quant_conv(x)
|
| 292 |
+
if not self.use_variational:
|
| 293 |
+
moments = torch.cat((moments, torch.zeros_like(moments)), 2)
|
| 294 |
+
posterior = DiagonalGaussianDistribution(moments, deterministic=(not self.use_variational), dim=2)
|
| 295 |
+
return posterior
|
| 296 |
+
|
| 297 |
+
def decode(self, z):
|
| 298 |
+
# bottleneck
|
| 299 |
+
z = self.post_quant_conv(z)
|
| 300 |
+
|
| 301 |
+
# decoder
|
| 302 |
+
for blk in self.decoder:
|
| 303 |
+
z = blk(z)
|
| 304 |
+
z = self.dec_norm(z)
|
| 305 |
+
|
| 306 |
+
# predictor
|
| 307 |
+
z = self.predictor(z)
|
| 308 |
+
|
| 309 |
+
# unpatchify
|
| 310 |
+
dec = self.unpatchify(z)
|
| 311 |
+
return dec
|
| 312 |
+
|
| 313 |
+
def autoencode(self, input, sample_posterior=True):
|
| 314 |
+
posterior = self.encode(input)
|
| 315 |
+
if self.use_variational and sample_posterior:
|
| 316 |
+
z = posterior.sample()
|
| 317 |
+
else:
|
| 318 |
+
z = posterior.mode()
|
| 319 |
+
dec = self.decode(z)
|
| 320 |
+
return dec, posterior, z
|
| 321 |
+
|
| 322 |
+
def get_input(self, batch, k):
|
| 323 |
+
x = batch[k]
|
| 324 |
+
if len(x.shape) == 3:
|
| 325 |
+
x = x[..., None]
|
| 326 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
| 327 |
+
return x
|
| 328 |
+
|
| 329 |
+
def forward(self, inputs, labels, split="train"):
|
| 330 |
+
rec, post, latent = self.autoencode(inputs)
|
| 331 |
+
return rec, post, latent
|
| 332 |
+
|
| 333 |
+
def get_last_layer(self):
|
| 334 |
+
return self.predictor.weight
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def ViT_L_20_Shallow_Encoder(**kwargs):
|
| 338 |
+
if "latent_dim" in kwargs:
|
| 339 |
+
latent_dim = kwargs.pop("latent_dim")
|
| 340 |
+
else:
|
| 341 |
+
latent_dim = 16
|
| 342 |
+
return AutoencoderKL(
|
| 343 |
+
latent_dim=latent_dim,
|
| 344 |
+
patch_size=20,
|
| 345 |
+
enc_dim=1024,
|
| 346 |
+
enc_depth=6,
|
| 347 |
+
enc_heads=16,
|
| 348 |
+
dec_dim=1024,
|
| 349 |
+
dec_depth=12,
|
| 350 |
+
dec_heads=16,
|
| 351 |
+
input_height=360,
|
| 352 |
+
input_width=640,
|
| 353 |
+
**kwargs,
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
VAE_models = {
|
| 358 |
+
"vit-l-20-shallow-encoder": ViT_L_20_Shallow_Encoder,
|
| 359 |
+
}
|
configurations/algorithm/base_algo.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This will be passed as the cfg to Algo.__init__(cfg) of your algorithm class
|
| 2 |
+
|
| 3 |
+
debug: ${debug} # inherited from configurations/config.yaml
|
configurations/algorithm/base_pytorch_algo.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- base_algo # inherits from configurations/algorithm/base_algo.yaml
|
| 3 |
+
|
| 4 |
+
lr: ${experiment.training.lr}
|
configurations/algorithm/base_video_dit.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- df_base
|
| 3 |
+
|
| 4 |
+
n_frames: ${dataset.n_frames}
|
| 5 |
+
frame_skip: ${dataset.frame_skip}
|
| 6 |
+
metadata: ${dataset.metadata}
|
| 7 |
+
|
| 8 |
+
# training hyperparameters
|
| 9 |
+
weight_decay: 2e-3
|
| 10 |
+
warmup_steps: 1000
|
| 11 |
+
optimizer_beta: [0.9, 0.99]
|
| 12 |
+
action_cond_dim: 25
|
| 13 |
+
use_plucker: true
|
| 14 |
+
|
| 15 |
+
diffusion:
|
| 16 |
+
# training
|
| 17 |
+
beta_schedule: sigmoid
|
| 18 |
+
objective: pred_v
|
| 19 |
+
use_fused_snr: True
|
| 20 |
+
cum_snr_decay: 0.96
|
| 21 |
+
clip_noise: 20.
|
| 22 |
+
# sampling
|
| 23 |
+
sampling_timesteps: 20
|
| 24 |
+
ddim_sampling_eta: 0.0
|
| 25 |
+
stabilization_level: 15
|
| 26 |
+
# architecture
|
| 27 |
+
architecture:
|
| 28 |
+
network_size: 64
|
| 29 |
+
attn_heads: 4
|
| 30 |
+
attn_dim_head: 64
|
| 31 |
+
dim_mults: [1, 2, 4, 8]
|
| 32 |
+
resolution: ${dataset.resolution}
|
| 33 |
+
attn_resolutions: [16, 32, 64, 128]
|
| 34 |
+
use_init_temporal_attn: True
|
| 35 |
+
use_linear_attn: True
|
| 36 |
+
time_emb_type: rotary
|
configurations/algorithm/dememwm_memory_dit.yaml
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
defaults:
|
| 3 |
+
- base_video_dit
|
| 4 |
+
- _self_
|
| 5 |
+
|
| 6 |
+
_name: dememwm_memory_dit
|
| 7 |
+
|
| 8 |
+
# Standalone Memory-DiT path. Do not route through old SSM-memory config.
|
| 9 |
+
memory_token_cross_attention: true
|
| 10 |
+
memory_cross_attn_layers: null
|
| 11 |
+
memory_condition_length: 0
|
| 12 |
+
pose_cond_dim: 5
|
| 13 |
+
log_video: false
|
| 14 |
+
|
| 15 |
+
dememwm:
|
| 16 |
+
enabled: true
|
| 17 |
+
training_stage: stage_1 # fallback only when curriculum.enabled=false
|
| 18 |
+
debug_force_all_streams: false
|
| 19 |
+
curriculum:
|
| 20 |
+
enabled: true
|
| 21 |
+
full_stage_start_step: 60000
|
| 22 |
+
freeze_vae: true
|
| 23 |
+
dit_freeze:
|
| 24 |
+
enabled: true
|
| 25 |
+
lr:
|
| 26 |
+
dememwm_modules: 1.0e-4
|
| 27 |
+
memory_adapters: 1.0e-4
|
| 28 |
+
full_dit: 1.0e-5
|
| 29 |
+
# Current Conv2D memory projectors preserve latent H,W=(18,32).
|
| 30 |
+
# Pool sizes are resolved from projected spatial grid size and downsample ratios.
|
| 31 |
+
token_patch_size: 2
|
| 32 |
+
anchor:
|
| 33 |
+
enabled: true
|
| 34 |
+
anchor_indices: [0, 1, 2, 3]
|
| 35 |
+
allow_generated_as_anchor: false
|
| 36 |
+
diverse_selection: true
|
| 37 |
+
compress:
|
| 38 |
+
downsample_ratio: 4
|
| 39 |
+
dynamic:
|
| 40 |
+
enabled: true
|
| 41 |
+
exclude_latest_local_frames: 4
|
| 42 |
+
recent_frames: 8
|
| 43 |
+
conv_kernel_t: 3
|
| 44 |
+
conv_stride_t: 2
|
| 45 |
+
revisit:
|
| 46 |
+
enabled: true
|
| 47 |
+
deterministic_pose_retrieval: true
|
| 48 |
+
fov_overlap_threshold: 0.30
|
| 49 |
+
high_quality_fov_threshold: 0.70
|
| 50 |
+
plucker_weight: 0.10
|
| 51 |
+
max_frames: 2
|
| 52 |
+
# FoV geometry for coverage-based retrieval scoring.
|
| 53 |
+
# fov_half_h/v: half-angles (degrees) of the horizontal/vertical field of view.
|
| 54 |
+
# fov_radius: world-space radius of the sample sphere.
|
| 55 |
+
# fov_{yaw,pitch,depth}_samples: grid resolution for FoV point sampling.
|
| 56 |
+
fov_half_h: 52.5 # 105 deg total horizontal FoV
|
| 57 |
+
fov_half_v: 37.5 # 75 deg total vertical FoV
|
| 58 |
+
fov_radius: 30.0
|
| 59 |
+
fov_yaw_samples: 25
|
| 60 |
+
fov_pitch_samples: 20
|
| 61 |
+
fov_depth_samples: 20
|
| 62 |
+
pose_preselect_topk: 64
|
| 63 |
+
# Plucker descriptor grid for secondary pose-similarity scoring.
|
| 64 |
+
plucker_grid_h: 4
|
| 65 |
+
plucker_grid_w: 4
|
| 66 |
+
plucker_focal_length: 0.35
|
| 67 |
+
compress:
|
| 68 |
+
downsample_ratio: 4
|
| 69 |
+
stage_policy:
|
| 70 |
+
noise_bucket_logging: true
|
| 71 |
+
eval_ablation:
|
| 72 |
+
enabled: false
|
| 73 |
+
branch: A_plus_D_plus_R_normal
|
| 74 |
+
generated_history_proxy:
|
| 75 |
+
enabled: false
|
| 76 |
+
start_step: 0
|
| 77 |
+
ramp_steps: 1
|
| 78 |
+
max_prob: 0.0
|
| 79 |
+
noise_std: 0.25
|
| 80 |
+
dropout_prob: 0.0
|
| 81 |
+
injection:
|
| 82 |
+
dit_hidden_size: 1024
|
| 83 |
+
anchor_gate: 1.0
|
| 84 |
+
dynamic_gate: 1.0
|
| 85 |
+
revisit_gate: 1.0
|
| 86 |
+
cache:
|
| 87 |
+
enabled: true
|
| 88 |
+
device: cpu
|
| 89 |
+
keep_raw_latents: all
|
| 90 |
+
keep_compressed_records: true
|
| 91 |
+
keep_prefix_anchors: true
|
| 92 |
+
eviction_policy: none
|
| 93 |
+
no_evict: true
|
| 94 |
+
clear_between_videos: true
|
| 95 |
+
max_records: null
|
| 96 |
+
max_slots: null
|
| 97 |
+
on_capacity_exceeded: warn
|
| 98 |
+
checkpoint:
|
| 99 |
+
strict_dememwm_eval_load: true
|
| 100 |
+
|
| 101 |
+
diffusion:
|
| 102 |
+
architecture:
|
| 103 |
+
network_size: 64
|
configurations/algorithm/df_base.yaml
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- base_pytorch_algo
|
| 3 |
+
|
| 4 |
+
# dataset-dependent configurations
|
| 5 |
+
x_shape: ${dataset.observation_shape}
|
| 6 |
+
frame_stack: 1
|
| 7 |
+
frame_skip: 1
|
| 8 |
+
data_mean: ${dataset.data_mean}
|
| 9 |
+
data_std: ${dataset.data_std}
|
| 10 |
+
external_cond_dim: 0 #${dataset.action_dim}
|
| 11 |
+
context_frames: ${dataset.context_length}
|
| 12 |
+
# training hyperparameters
|
| 13 |
+
weight_decay: 1e-4
|
| 14 |
+
warmup_steps: 10000
|
| 15 |
+
optimizer_beta: [0.9, 0.999]
|
| 16 |
+
# diffusion-related
|
| 17 |
+
uncertainty_scale: 1
|
| 18 |
+
guidance_scale: 0.0
|
| 19 |
+
chunk_size: 1 # -1 for full trajectory diffusion, number to specify diffusion chunk size
|
| 20 |
+
scheduling_matrix: autoregressive
|
| 21 |
+
noise_level: random_all
|
| 22 |
+
causal: True
|
| 23 |
+
|
| 24 |
+
diffusion:
|
| 25 |
+
# training
|
| 26 |
+
objective: pred_x0
|
| 27 |
+
beta_schedule: cosine
|
| 28 |
+
schedule_fn_kwargs: {}
|
| 29 |
+
clip_noise: 20.0
|
| 30 |
+
use_snr: False
|
| 31 |
+
use_cum_snr: False
|
| 32 |
+
use_fused_snr: False
|
| 33 |
+
snr_clip: 5.0
|
| 34 |
+
cum_snr_decay: 0.98
|
| 35 |
+
timesteps: 1000
|
| 36 |
+
# sampling
|
| 37 |
+
sampling_timesteps: 50 # fixme, numer of diffusion steps, should be increased
|
| 38 |
+
ddim_sampling_eta: 1.0
|
| 39 |
+
stabilization_level: 10
|
| 40 |
+
# architecture
|
| 41 |
+
architecture:
|
| 42 |
+
network_size: 64
|
configurations/dataset/base_dataset.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This will be passed as the cfg to Dataset.__init__(cfg) of your dataset class
|
| 2 |
+
|
| 3 |
+
debug: ${debug} # inherited from configurations/config.yaml
|
configurations/dataset/base_video.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- base_dataset
|
| 3 |
+
|
| 4 |
+
metadata: "data/${dataset.name}/metadata.json"
|
| 5 |
+
data_mean: "data/${dataset.name}/data_mean.npy"
|
| 6 |
+
data_std: "data/${dataset.name}/data_std.npy"
|
| 7 |
+
save_dir: ???
|
| 8 |
+
n_frames: 32
|
| 9 |
+
context_length: 4
|
| 10 |
+
resolution: 128
|
| 11 |
+
observation_shape: [3, "${dataset.resolution}", "${dataset.resolution}"]
|
| 12 |
+
external_cond_dim: 0
|
| 13 |
+
validation_multiplier: 1
|
| 14 |
+
frame_skip: 1
|
configurations/dataset/video_minecraft.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- base_video
|
| 3 |
+
|
| 4 |
+
save_dir: data/minecraft_simple_backforward
|
| 5 |
+
n_frames: 16 # TODO: increase later
|
| 6 |
+
resolution: 128
|
| 7 |
+
data_mean: 0.5
|
| 8 |
+
data_std: 0.5
|
| 9 |
+
action_cond_dim: 25
|
| 10 |
+
context_length: 1
|
| 11 |
+
frame_skip: 1
|
| 12 |
+
validation_multiplier: 1
|
| 13 |
+
|
| 14 |
+
_name: video_minecraft_oasis
|
configurations/dataset/video_minecraft_latent.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- video_minecraft
|
| 3 |
+
|
| 4 |
+
precomputed_feature_dir: /share_1/users/bonan_ding/worldmem_data/minecraft/vae_features
|
| 5 |
+
|
| 6 |
+
_name: video_minecraft_latent
|
configurations/experiment/base_experiment.yaml
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
debug: ${debug} # inherited from configurations/config.yaml
|
| 2 |
+
tasks: [main] # tasks to run sequantially, such as [training, test], useful when your project has multiple stages and you want to run only a subset of them.
|
configurations/experiment/base_pytorch.yaml
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# inherites from base_experiment.yaml
|
| 2 |
+
# most of the options have docs at https://lightning.ai/docs/pytorch/stable/common/trainer.html
|
| 3 |
+
|
| 4 |
+
defaults:
|
| 5 |
+
- base_experiment
|
| 6 |
+
|
| 7 |
+
tasks: [training] # tasks to run sequantially, change when your project has multiple stages and you want to run only a subset of them.
|
| 8 |
+
num_nodes: 1 # number of gpu servers used in large scale distributed training
|
| 9 |
+
|
| 10 |
+
training:
|
| 11 |
+
precision: 16-mixed # set float precision, 16-mixed is faster while 32 is more stable
|
| 12 |
+
compile: False # whether to compile the model with torch.compile
|
| 13 |
+
lr: 0.001 # learning rate
|
| 14 |
+
batch_size: 16 # training batch size; effective batch size is this number * gpu * nodes iff using distributed training
|
| 15 |
+
max_epochs: 1000 # set to -1 to train forever
|
| 16 |
+
max_steps: -1 # set to -1 to train forever, will override max_epochs
|
| 17 |
+
max_time: null # set to something like "00:12:00:00" to enable
|
| 18 |
+
data:
|
| 19 |
+
num_workers: 4 # number of CPU threads for data preprocessing.
|
| 20 |
+
shuffle: True # whether training data will be shuffled
|
| 21 |
+
optim:
|
| 22 |
+
accumulate_grad_batches: 1 # accumulate gradients for n batches before backprop
|
| 23 |
+
gradient_clip_val: 1.0 # clip gradients with norm above this value, set to 0 to disable
|
| 24 |
+
checkpointing:
|
| 25 |
+
# these are arguments to pytorch lightning's callback, `ModelCheckpoint` class
|
| 26 |
+
every_n_train_steps: 5000 # save a checkpoint every n train steps
|
| 27 |
+
every_n_epochs: null # mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``
|
| 28 |
+
train_time_interval: null # in format of "00:12:00:00", mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
|
| 29 |
+
save_last: True # keep last.ckpt for automatic interrupted-run resume
|
| 30 |
+
enable_version_counter: False # If this is ``False``, later checkpoint will be overwrite previous ones.
|
| 31 |
+
|
| 32 |
+
validation:
|
| 33 |
+
precision: 16-mixed
|
| 34 |
+
compile: False # whether to compile the model with torch.compile
|
| 35 |
+
batch_size: 16 # validation batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training
|
| 36 |
+
val_every_n_step: 2000 # controls how frequent do we run validation, can be float (fraction of epoches) or int (steps) or null (if val_every_n_epoch is set)
|
| 37 |
+
val_every_n_epoch: null # if you want to do validation every n epoches, requires val_every_n_step to be null.
|
| 38 |
+
limit_batch: null # if null, run through validation set. Otherwise limit the number of batches to use for validation.
|
| 39 |
+
inference_mode: True # whether to run validation in inference mode (enable_grad won't work!)
|
| 40 |
+
data:
|
| 41 |
+
num_workers: 4 # number of CPU threads for data preprocessing, for validation.
|
| 42 |
+
shuffle: False # whether validation data will be shuffled
|
| 43 |
+
|
| 44 |
+
test:
|
| 45 |
+
precision: 16-mixed
|
| 46 |
+
compile: False # whether to compile the model with torch.compile
|
| 47 |
+
batch_size: 4 # test batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training
|
| 48 |
+
limit_batch: null # if null, run through test set. Otherwise limit the number of batches to use for test.
|
| 49 |
+
data:
|
| 50 |
+
num_workers: 4 # number of CPU threads for data preprocessing, for test.
|
| 51 |
+
shuffle: False # whether test data will be shuffled
|
configurations/experiment/exp_video.yaml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- base_pytorch
|
| 3 |
+
|
| 4 |
+
tasks: [training]
|
| 5 |
+
|
| 6 |
+
training:
|
| 7 |
+
lr: 2e-5
|
| 8 |
+
precision: 16-mixed
|
| 9 |
+
batch_size: 4
|
| 10 |
+
max_epochs: -1
|
| 11 |
+
max_steps: 2000005
|
| 12 |
+
checkpointing:
|
| 13 |
+
every_n_train_steps: 2500
|
| 14 |
+
optim:
|
| 15 |
+
gradient_clip_val: 1.0
|
| 16 |
+
|
| 17 |
+
validation:
|
| 18 |
+
val_every_n_step: 2500
|
| 19 |
+
val_every_n_epoch: null
|
| 20 |
+
batch_size: 4
|
| 21 |
+
limit_batch: 1
|
| 22 |
+
|
| 23 |
+
test:
|
| 24 |
+
limit_batch: 1
|
| 25 |
+
batch_size: 1
|
| 26 |
+
|
| 27 |
+
logging:
|
| 28 |
+
metrics:
|
| 29 |
+
# - fvd
|
| 30 |
+
# - fid
|
| 31 |
+
# - lpips
|
configurations/training.yaml
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# configuration parsing starts here
|
| 2 |
+
defaults:
|
| 3 |
+
- experiment: exp_video # experiment yaml file name in configurations/experiments folder [fixme]
|
| 4 |
+
- dataset: video_minecraft # dataset yaml file name in configurations/dataset folder [fixme]
|
| 5 |
+
- algorithm: dememwm_memory_dit # algorithm yaml file name in configurations/algorithm folder [fixme]
|
| 6 |
+
- cluster: null # optional, cluster yaml file name in configurations/cluster folder. Leave null for local compute
|
| 7 |
+
|
| 8 |
+
debug: false # global debug flag will be passed into configuration of experiment, dataset and algorithm
|
| 9 |
+
|
| 10 |
+
wandb:
|
| 11 |
+
entity: turlin # wandb account name / organization name [fixme]
|
| 12 |
+
project: DeMemWM # wandb project name; if not provided, defaults to root folder name [fixme]
|
| 13 |
+
mode: offline # set wandb logging to online, offline or dryrun
|
| 14 |
+
|
| 15 |
+
resume: null # wandb run id to resume logging and loading checkpoint from
|
| 16 |
+
load: null # wandb run id containing checkpoint or a path to a checkpoint file
|
| 17 |
+
auto_resume: true # automatically resume training from output_dir/checkpoints when available
|
| 18 |
+
resume_ckpt_path: null # explicit full Lightning checkpoint path for deterministic training resume
|
datasets/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .video import MinecraftVideoDataset, MinecraftVideoLatentDataset
|
datasets/video/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .minecraft_video_dataset import MinecraftVideoDataset
|
| 2 |
+
from .minecraft_video_latent_dataset import MinecraftVideoLatentDataset
|