BonanDing committed on
Commit
b47a1ce
·
0 Parent(s):

Initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +27 -0
  2. LICENSE.md +14 -0
  3. README.md +37 -0
  4. algorithms/__init__.py +0 -0
  5. algorithms/common/__init__.py +0 -0
  6. algorithms/common/base_algo.py +21 -0
  7. algorithms/common/base_pytorch_algo.py +277 -0
  8. algorithms/common/metrics/__init__.py +3 -0
  9. algorithms/common/metrics/fid.py +1 -0
  10. algorithms/common/metrics/fvd.py +158 -0
  11. algorithms/common/metrics/lpips.py +1 -0
  12. algorithms/worldmem/__init__.py +1 -0
  13. algorithms/worldmem/dememwm/__init__.py +18 -0
  14. algorithms/worldmem/dememwm/algorithm.py +0 -0
  15. algorithms/worldmem/dememwm/cache.py +513 -0
  16. algorithms/worldmem/dememwm/compression.py +260 -0
  17. algorithms/worldmem/dememwm/diagnostics.py +174 -0
  18. algorithms/worldmem/dememwm/gates.py +46 -0
  19. algorithms/worldmem/dememwm/injection.py +83 -0
  20. algorithms/worldmem/dememwm/labels.py +479 -0
  21. algorithms/worldmem/dememwm/memory.py +208 -0
  22. algorithms/worldmem/dememwm/negatives.py +41 -0
  23. algorithms/worldmem/dememwm/retrieval.py +476 -0
  24. algorithms/worldmem/dememwm/schedules.py +223 -0
  25. algorithms/worldmem/dememwm/types.py +98 -0
  26. algorithms/worldmem/dememwm_memory_dit.py +18 -0
  27. algorithms/worldmem/df_base.py +307 -0
  28. algorithms/worldmem/df_video.py +926 -0
  29. algorithms/worldmem/models/attention.py +342 -0
  30. algorithms/worldmem/models/diffusion.py +594 -0
  31. algorithms/worldmem/models/dit.py +899 -0
  32. algorithms/worldmem/models/pose_prediction.py +42 -0
  33. algorithms/worldmem/models/rotary_embedding_torch.py +302 -0
  34. algorithms/worldmem/models/utils.py +163 -0
  35. algorithms/worldmem/models/vae.py +359 -0
  36. configurations/algorithm/base_algo.yaml +3 -0
  37. configurations/algorithm/base_pytorch_algo.yaml +4 -0
  38. configurations/algorithm/base_video_dit.yaml +36 -0
  39. configurations/algorithm/dememwm_memory_dit.yaml +103 -0
  40. configurations/algorithm/df_base.yaml +42 -0
  41. configurations/dataset/base_dataset.yaml +3 -0
  42. configurations/dataset/base_video.yaml +14 -0
  43. configurations/dataset/video_minecraft.yaml +14 -0
  44. configurations/dataset/video_minecraft_latent.yaml +6 -0
  45. configurations/experiment/base_experiment.yaml +2 -0
  46. configurations/experiment/base_pytorch.yaml +51 -0
  47. configurations/experiment/exp_video.yaml +31 -0
  48. configurations/training.yaml +18 -0
  49. datasets/__init__.py +1 -0
  50. datasets/video/__init__.py +2 -0
.gitignore ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+
5
+ .pytest_cache/
6
+ .mypy_cache/
7
+ .ruff_cache/
8
+ .ipynb_checkpoints/
9
+
10
+ .env
11
+ .venv/
12
+ venv/
13
+
14
+ outputs/
15
+ slurm_logs/
16
+ latest-run
17
+ .wandb_run_id
18
+ .wandb_osh_command_dir/
19
+ wandb/
20
+
21
+ *.log
22
+ *.ckpt
23
+ *.pt
24
+ *.pth
25
+ *.safetensors
26
+
27
+ data/
LICENSE.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # S-Lab License 1.0
2
+
3
+ Copyright 2025 S-Lab
4
+
5
+ Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
6
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\
9
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
10
+ 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
11
+
12
+
13
+ ---
14
+ For the commercial use of the code, please consult Prof. Chen Change Loy (ccloy@ntu.edu.sg)
README.md ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DeMemWM
2
+
3
+ This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep this sentence in `README.md` and the `LICENSE` file to credit the author.
4
+
5
+ DeMemWM is a Memory-DiT video prediction project built on the local research template. The primary algorithm entry point is `DeMemWMMinecraft`, registered through the Hydra algorithm config `dememwm_memory_dit`.
6
+
7
+ ## Quick Start
8
+
9
+ ```bash
10
+ python -m venv .venv
11
+ source .venv/bin/activate
12
+ pip install -r requirements.txt
13
+ python -m pytest tests
14
+ ```
15
+
16
+ Run a local offline experiment after setting the dataset path in `configurations/dataset/video_minecraft.yaml`:
17
+
18
+ ```bash
19
+ python main.py +name=dememwm_debug algorithm=dememwm_memory_dit wandb.mode=offline
20
+ ```
21
+
22
+ Use `resume_ckpt_path=/path/to/checkpoint.ckpt` for deterministic checkpoint resume, or keep `auto_resume=true` to resume from `output_dir/checkpoints` when available.
23
+
24
+ ## Layout
25
+
26
+ - `algorithms/worldmem/dememwm/`: DeMemWM memory construction, retrieval, scheduling, diagnostics, and injection code.
27
+ - `algorithms/worldmem/dememwm_memory_dit.py`: primary DeMemWM algorithm class.
28
+ - `configurations/algorithm/dememwm_memory_dit.yaml`: consumed DeMemWM training and evaluation contract.
29
+ - `scripts/`: Slurm and inspection scripts using the DeMemWM naming.
30
+ - `tests/`: static and unit coverage for DeMemWM config, retrieval, compression, schedules, and training behavior.
31
+
32
+ ## Reproducibility Notes
33
+
34
+ - Keep `wandb.mode=offline` for local reproducible runs that do not depend on network access.
35
+ - Set `seed=<int>` on the command line to seed Lightning and dataloader workers.
36
+ - Runtime artifacts such as `outputs/`, `slurm_logs/`, Python caches, checkpoints, and local datasets are ignored by git.
37
+ - The default Hydra training config selects `algorithm: dememwm_memory_dit`.
algorithms/__init__.py ADDED
File without changes
algorithms/common/__init__.py ADDED
File without changes
algorithms/common/base_algo.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ from omegaconf import DictConfig
5
+
6
+
7
class BaseAlgo(ABC):
    """
    A base class for generic algorithms.

    Subclasses receive their configuration object in the constructor and must
    implement :meth:`run`.
    """

    def __init__(self, cfg: DictConfig):
        """Keep the configuration object on ``self.cfg``."""
        super().__init__()
        self.cfg = cfg

    @abstractmethod
    def run(self, *args: Any, **kwargs: Any) -> Any:
        """
        Run the algorithm.

        The signature is intentionally open: subclasses decide which
        positional and keyword arguments they accept and what they return.
        ``self`` is declared explicitly so that the instance is no longer
        silently swallowed as the first element of ``*args``.
        """
        raise NotImplementedError
algorithms/common/base_pytorch_algo.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ import warnings
3
+ import random
4
+ from typing import Any, Union, Sequence, Optional
5
+
6
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
7
+ from omegaconf import DictConfig
8
+ import lightning.pytorch as pl
9
+ import torch
10
+ import numpy as np
11
+ from PIL import Image
12
+ import wandb
13
+ import einops
14
+
15
+
16
+ class BasePytorchAlgo(pl.LightningModule, ABC):
17
+ """
18
+ A base class for Pytorch algorithms using Pytorch Lightning.
19
+ See https://lightning.ai/docs/pytorch/stable/starter/introduction.html for more details.
20
+ """
21
+
22
def __init__(self, cfg: DictConfig):
    """Store ``cfg`` on the module and build the model.

    Args:
        cfg: configuration object, kept as ``self.cfg``.
    """
    super().__init__()
    self.cfg = cfg
    # Subclass hook; it runs during construction, after ``self.cfg`` is set.
    self._build_model()
26
+
27
@abstractmethod
def _build_model(self):
    """
    Create all pytorch nn.Modules here.

    Called once from ``__init__`` after ``self.cfg`` has been assigned.
    Subclasses must override it; this base implementation always raises.

    Raises:
        NotImplementedError: always, when not overridden.
    """
    raise NotImplementedError
33
+
34
@abstractmethod
def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
    r"""Here you compute and return the training loss and some additional metrics for e.g. the progress bar or
    logger.

    Args:
        batch: The output of your data iterable, normally a :class:`~torch.utils.data.DataLoader`.
        batch_idx: The index of this batch.
        dataloader_idx: (only if multiple dataloaders used) The index of the dataloader that produced this batch.

    Return:
        Any of these options:
        - :class:`~torch.Tensor` - The loss tensor
        - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``.
        - ``None`` - Skip to the next batch. This is only supported for automatic optimization.
            This is not supported for multi-GPU, TPU, IPU, or DeepSpeed.

    In this step you'd normally do the forward pass and calculate the loss for a batch.
    You can also do fancier things like multiple forward passes or something model specific.

    Example::

        def training_step(self, batch, batch_idx):
            x, y, z = batch
            out = self.encoder(x)
            loss = self.loss(out, x)
            return loss

    To use multiple optimizers, you can switch to 'manual optimization' and control their stepping:

    .. code-block:: python

        def __init__(self):
            super().__init__()
            self.automatic_optimization = False


        # Multiple optimizers (e.g.: GANs)
        def training_step(self, batch, batch_idx):
            opt1, opt2 = self.optimizers()

            # do training_step with encoder
            ...
            opt1.step()
            # do training_step with decoder
            ...
            opt2.step()

    Note:
        When ``accumulate_grad_batches`` > 1, the loss returned here will be automatically
        normalized by ``accumulate_grad_batches`` internally.

    """
    # Declared abstract so subclasses must override it; a subclass that calls
    # super().training_step(...) is forwarded to the parent class implementation.
    return super().training_step(*args, **kwargs)
88
+
89
def configure_optimizers(self):
    """
    Build the default optimizer: Adam over every parameter of this module,
    with the learning rate read from ``self.cfg.lr``.

    For setups with more than one optimizer, see the pytorch lightning documentation:
    https://lightning.ai/docs/pytorch/stable/common/optimization.html
    """
    return torch.optim.Adam(self.parameters(), lr=self.cfg.lr)
96
+
97
def on_save_checkpoint(self, checkpoint):
    """Snapshot the global RNG states into ``checkpoint["rng_states"]``.

    The entry holds the Python ``random``, NumPy and torch CPU generator
    states, plus the states of all CUDA devices when CUDA is available
    (``None`` otherwise).
    """
    cuda_states = None
    if torch.cuda.is_available():
        cuda_states = torch.cuda.get_rng_state_all()

    snapshot = {}
    snapshot["python"] = random.getstate()
    snapshot["numpy"] = np.random.get_state()
    snapshot["torch"] = torch.get_rng_state()
    snapshot["cuda"] = cuda_states
    checkpoint["rng_states"] = snapshot
104
+
105
def on_load_checkpoint(self, checkpoint):
    """Restore the RNG states stored under ``checkpoint["rng_states"]``.

    When the entry is absent nothing is restored; if the instance has a truthy
    ``_strict_resume_state`` attribute a ``RuntimeError`` is raised instead.
    CUDA states are only restored when CUDA is available and a CUDA snapshot
    was stored.
    """
    saved = checkpoint.get("rng_states")
    if saved is not None:
        random.setstate(saved["python"])
        np.random.set_state(saved["numpy"])
        torch.set_rng_state(saved["torch"])
        if torch.cuda.is_available() and saved["cuda"] is not None:
            torch.cuda.set_rng_state_all(saved["cuda"])
        return

    strict = getattr(self, "_strict_resume_state", False)
    if strict:
        raise RuntimeError(
            "Cannot deterministically resume because this checkpoint has no rng_states entry. "
            "Use a checkpoint created after automatic resume support was added, or start a fresh run."
        )
120
+
121
def log_video(
    self,
    key: str,
    video: Union[np.ndarray, torch.Tensor],
    mean: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
    std: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
    fps: int = 5,
    format: str = "mp4",
):
    """
    Log video to wandb. WandbLogger in pytorch lightning does not support video logging yet, so we call wandb directly.

    Args:
        video: a numpy array or tensor, either in form (time, channel, height, width) or in the form
            (batch, time, channel, height, width). The content must be in 0-255 if under dtype uint8
            or [0, 1] otherwise.
        mean: optional, the mean to unnormalize video tensor, assuming unnormalized data is in [0, 1].
            Either a scalar or one value per channel.
        std: optional, the std to unnormalize video tensor, assuming unnormalized data is in [0, 1].
            Either a scalar or one value per channel.
        key: the name of the video.
        fps: the frame rate of the video.
        format: the format of the video. Can be either "mp4" or "gif".
    """

    def _channel_stat(value):
        # Shape the statistic as (C, 1, 1) (or (1, 1, 1) for a scalar) so it
        # broadcasts against the trailing (channel, height, width) axes and
        # leaves the rank of ``video`` unchanged. The previous reshape target
        # had one axis too many, which turned a 5-D batch into a 6-D array.
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().numpy()
        return np.asarray(value).reshape(-1, 1, 1)

    if isinstance(video, torch.Tensor):
        video = video.detach().cpu().numpy()

    if std is not None:
        video = video * _channel_stat(std)
    if mean is not None:
        video = video + _channel_stat(mean)

    # Anything that is not already uint8 is treated as [0, 1] data.
    if video.dtype != np.uint8:
        video = np.clip(video, a_min=0, a_max=1) * 255
        video = video.astype(np.uint8)

    self.logger.experiment.log(
        {
            key: wandb.Video(video, fps=fps, format=format),
        },
        step=self.global_step,
    )
173
+
174
+ def log_image(
175
+ self,
176
+ key: str,
177
+ image: Union[np.ndarray, torch.Tensor, Image.Image, Sequence[Image.Image]],
178
+ mean: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
179
+ std: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
180
+ **kwargs: Any,
181
+ ):
182
+ """
183
+ Log image(s) using WandbLogger.
184
+ Args:
185
+ key: the name of the video.
186
+ image: a single image or a batch of images. If a batch of images, the shape should be (batch, channel, height, width).
187
+ mean: optional, the mean to unnormalize image tensor, assuming unnormalized data is in [0, 1].
188
+ std: optional, the std to unnormalize tensor, assuming unnormalized data is in [0, 1].
189
+ kwargs: optional, WandbLogger log_image kwargs, such as captions=xxx.
190
+ """
191
+ if isinstance(image, Image.Image):
192
+ image = [image]
193
+ elif len(image) and not isinstance(image[0], Image.Image):
194
+ if isinstance(image, torch.Tensor):
195
+ image = image.detach().cpu().numpy()
196
+
197
+ if len(image.shape) == 3:
198
+ image = image[None]
199
+
200
+ if image.shape[1] == 3:
201
+ if image.shape[-1] == 3:
202
+ warnings.warn(f"Two channels in shape {image.shape} have size 3, assuming channel first.")
203
+ image = einops.rearrange(image, "b c h w -> b h w c")
204
+
205
+ if std is not None:
206
+ if isinstance(std, (float, int)):
207
+ std = [std] * 3
208
+ if isinstance(std, torch.Tensor):
209
+ std = std.detach().cpu().numpy()
210
+ std = np.array(std)[None, None, None]
211
+ image = image * std
212
+ if mean is not None:
213
+ if isinstance(mean, (float, int)):
214
+ mean = [mean] * 3
215
+ if isinstance(mean, torch.Tensor):
216
+ mean = mean.detach().cpu().numpy()
217
+ mean = np.array(mean)[None, None, None]
218
+ image = image + mean
219
+
220
+ if image.dtype != np.uint8:
221
+ image = np.clip(image, a_min=0.0, a_max=1.0) * 255
222
+ image = image.astype(np.uint8)
223
+ image = [img for img in image]
224
+
225
+ self.logger.log_image(key=key, images=image, **kwargs)
226
+
227
def log_gradient_stats(self):
    """Log gradient statistics such as the mean or std of norm.

    Two groups of statistics are logged through ``self.log_dict``:
    ``train/grad_norm/*`` over the gradient norm of every parameter that has a
    gradient, and ``train/gpr/*`` over the gradient-to-parameter norm ratio.
    Parameters whose own norm is zero (or not a number) are left out of the
    ratio statistics because the ratio is undefined for them; including them
    would turn min/max/mean into inf or nan. Nothing is logged when no
    parameter has a gradient.
    """
    grad_norms = []
    ratios = []  # gradient-to-parameter ratio
    with torch.no_grad():
        for param in self.parameters():
            if param.grad is None:
                continue
            grad_norm = torch.norm(param.grad).item()
            param_norm = torch.norm(param).item()
            grad_norms.append(grad_norm)
            if param_norm > 0.0:
                ratios.append(grad_norm / param_norm)

    if not grad_norms:
        return

    def _summarize(prefix, values):
        # Note: std of a single value is nan (unbiased estimator).
        stacked = torch.tensor(values)
        return {
            f"{prefix}/min": stacked.min(),
            f"{prefix}/max": stacked.max(),
            f"{prefix}/std": stacked.std(),
            f"{prefix}/mean": stacked.mean(),
            f"{prefix}/median": torch.median(stacked),
        }

    stats = _summarize("train/grad_norm", grad_norms)
    if ratios:
        stats.update(_summarize("train/gpr", ratios))
    self.log_dict(stats)
255
+
256
def register_data_mean_std(
    self, mean: Union[str, float, Sequence], std: Union[str, float, Sequence], namespace: str = "data"
):
    """
    Register mean and std of data as tensor buffer.

    Each statistic is either given inline (number or sequence) or as a path to
    a ``.npy`` / ``.pt`` file. The resulting float tensors are registered as
    ``{namespace}_mean`` and ``{namespace}_std`` on ``self.device``.

    Args:
        mean: the mean of data.
        std: the std of data.
        namespace: the namespace of the registered buffer.

    Raises:
        ValueError: if a path has an extension other than ``.npy`` or ``.pt``.
    """

    def _to_tensor(value):
        if not isinstance(value, str):
            return torch.tensor(value)
        if value.endswith(".npy"):
            return torch.from_numpy(np.load(value))
        if value.endswith(".pt"):
            return torch.load(value)
        raise ValueError(f"Unsupported file type {value.split('.')[-1]}.")

    for suffix, raw in (("mean", mean), ("std", std)):
        tensor = _to_tensor(raw)
        self.register_buffer(f"{namespace}_{suffix}", tensor.float().to(self.device))
algorithms/common/metrics/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .fid import FrechetInceptionDistance
2
+ from .lpips import LearnedPerceptualImagePatchSimilarity
3
+ from .fvd import FrechetVideoDistance
algorithms/common/metrics/fid.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from torchmetrics.image.fid import FrechetInceptionDistance
algorithms/common/metrics/fvd.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adopted from https://github.com/cvpr2022-stylegan-v/stylegan-v
3
+ Verified to be the same as tf version by https://github.com/universome/fvd-comparison
4
+ """
5
+
6
+ import io
7
+ import re
8
+ import requests
9
+ import html
10
+ import hashlib
11
+ import urllib
12
+ import urllib.request
13
+ from typing import Any, List, Tuple, Union, Dict
14
+ import scipy
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import numpy as np
19
+
20
+
21
def open_url(
    url: str,
    num_attempts: int = 10,
    verbose: bool = True,
    return_filename: bool = False,
) -> Any:
    """Download the given URL and return a binary-mode file object to access the data.

    Args:
        url: a remote URL, a ``file://`` URL, or a plain local path.
        num_attempts: how many times a remote download is tried before the
            last error is re-raised.
        verbose: print download progress to stdout.
        return_filename: return the local filename instead of an open file.
            Only possible for local paths and ``file://`` URLs, because
            downloaded data is kept in memory and never written to disk.

    Returns:
        An open binary file object (``io.BytesIO`` for downloads), or the
        filename when ``return_filename`` is set.

    Raises:
        ValueError: if ``num_attempts`` is smaller than 1, or if
            ``return_filename`` is requested for a remote URL.
        Exception: whatever the last failed download attempt raised.
    """
    if num_attempts < 1:
        raise ValueError("num_attempts must be at least 1")

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match("^[a-z]+://", url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname()
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith("file://"):
        filename = urllib.parse.urlparse(url).path
        if re.match(r"^/[a-zA-Z]:", filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    # Remote data only ever lives in memory, so there is no filename to hand
    # back. Fail before spending time on the download.
    if return_filename:
        raise ValueError("return_filename is only supported for local paths and file:// URLs")

    # Download.
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    # Small responses may be a Google Drive interstitial page
                    # instead of the requested file.
                    if len(res.content) < 8192:
                        # Decode leniently: a small binary payload is a valid
                        # download and must not fail on invalid UTF-8.
                        content_str = res.content.decode("utf-8", errors="replace")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [
                                html.unescape(link)
                                for link in content_str.split('"')
                                if "export=download" in link
                            ]
                            if len(links) == 1:
                                # Retry against the confirmation link.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError(
                                "Google Drive download quota exceeded -- please try again later"
                            )

                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except Exception:
                # KeyboardInterrupt / SystemExit are not Exception subclasses
                # and therefore propagate immediately instead of being retried.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Return data as file object.
    return io.BytesIO(url_data)
107
+
108
+
109
def compute_fvd(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
    """Return the Frechet distance between two sets of feature vectors.

    Both inputs are arrays with one feature vector per row. Each set is
    summarised by its mean and covariance (see ``compute_stats``) and the
    distance is ``|mu_f - mu_r|^2 + Tr(S_f + S_r - 2 * sqrt(S_f S_r))``.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.linalg`` is loaded on every SciPy version.
    import scipy.linalg

    mu_gen, sigma_gen = compute_stats(feats_fake)
    mu_real, sigma_real = compute_stats(feats_real)

    mean_term = np.square(mu_gen - mu_real).sum()
    # Called without ``disp`` so the same call form works on SciPy versions
    # that deprecate or drop that argument; it returns only the matrix.
    cov_sqrt = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real))
    # sqrtm may return a complex matrix; only the real part is meaningful here.
    fvd = np.real(mean_term + np.trace(sigma_gen + sigma_real - cov_sqrt * 2))

    return float(fvd)


def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return the per-dimension mean and the covariance of row-wise features."""
    mu = feats.mean(axis=0)  # [d]
    sigma = np.cov(feats, rowvar=False)  # [d, d]

    return mu, sigma
127
+
128
+
129
class FrechetVideoDistance(nn.Module):
    """FVD metric backed by a TorchScript detector fetched through ``open_url``."""

    def __init__(self):
        super().__init__()
        # Return raw features before the softmax layer.
        self.detector_kwargs = dict(rescale=False, resize=True, return_features=True)
        with open_url(
            "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1",
            verbose=False,
        ) as stream:
            self.detector = torch.jit.load(stream).eval()

    @torch.no_grad()
    def compute(self, videos_fake: torch.Tensor, videos_real: torch.Tensor):
        """
        :param videos_fake: predicted video tensor of shape (frame, batch, channel, height, width)
        :param videos_real: ground-truth observation tensor of shape (frame, batch, channel, height, width)
        :return: the FVD between the detector features of both sets, as computed by ``compute_fvd``
        """
        # The unpacking also enforces that the prediction tensor is 5-D.
        n_frames, batch_size, c, h, w = videos_fake.shape
        if n_frames < 2:
            raise ValueError("Video must have more than 1 frame for FVD")

        # detector takes in tensors of shape [batch_size, c, video_len, h, w] with range -1 to 1
        fake_clips, real_clips = (
            clip.permute(1, 2, 0, 3, 4).contiguous() for clip in (videos_fake, videos_real)
        )

        feats_fake = self.detector(fake_clips, **self.detector_kwargs).cpu().numpy()
        feats_real = self.detector(real_clips, **self.detector_kwargs).cpu().numpy()

        return compute_fvd(feats_fake, feats_real)
algorithms/common/metrics/lpips.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
algorithms/worldmem/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .dememwm_memory_dit import DeMemWMMinecraft, DeMemWMMemoryDiTMinecraft
algorithms/worldmem/dememwm/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from .types import MemoryRecord, MemorySourceType, MemoryStreamTensors, RevisitRetrievalResult, StreamGateState
3
+ from .memory import CausalMemoryBank, MemoryBankQuery, stack_record_tokens
4
+ from .compression import CausalConv3DDynamicCompressor, latent_patch_tokens, spatial_pool_tokens
5
+ from .retrieval import deterministic_revisit_retrieval
6
+ from .schedules import compute_stream_gates, CurriculumState, resolve_curriculum, DeMemWMCurriculumState, resolve_dememwm_curriculum
7
+ from .gates import RevisitRawGate
8
+ from .cache import StreamingCache, DeMemWMStreamingCache
9
+ from .injection import InjectionAdapter, DeMemWMInjectionAdapter
10
+
11
+ __all__ = [
12
+ "MemoryRecord", "MemorySourceType", "MemoryStreamTensors", "RevisitRetrievalResult", "StreamGateState",
13
+ "CausalMemoryBank", "MemoryBankQuery", "stack_record_tokens",
14
+ "CausalConv3DDynamicCompressor", "latent_patch_tokens", "spatial_pool_tokens",
15
+ "deterministic_revisit_retrieval", "compute_stream_gates", "CurriculumState", "resolve_curriculum",
16
+ "DeMemWMCurriculumState", "resolve_dememwm_curriculum", "RevisitRawGate",
17
+ "StreamingCache", "DeMemWMStreamingCache", "InjectionAdapter", "DeMemWMInjectionAdapter",
18
+ ]
algorithms/worldmem/dememwm/algorithm.py ADDED
The diff for this file is too large to render. See raw diff
 
algorithms/worldmem/dememwm/cache.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass
5
+ from typing import Any, Iterable, Optional
6
+ import warnings
7
+
8
+ import torch
9
+
10
+ from .memory import CausalMemoryBank
11
+ from .types import MemoryRecord
12
+
13
+
14
@dataclass
class _RawLatentSegment:
    """One raw-latent segment as kept in ``StreamingCache._raw_segments``."""

    # Latent tensor of the segment. ``StreamingCache.raw_frame_slots`` counts
    # ``shape[0] * shape[1]`` slots per segment, so it has at least two leading
    # axes (presumably batch and time in some order - TODO confirm).
    latents: torch.Tensor
    # presumably the frame index of each stored latent - TODO confirm layout
    frame_indices: torch.Tensor
    # presumably flags latents that were generated rather than observed - TODO confirm
    source_is_generated: torch.Tensor
    # Optional pose tensor accompanying the segment; may be None.
    pose: Optional[torch.Tensor]
20
+
21
+
22
+ class StreamingCache:
23
+ """Per-video DeMemWM streaming cache with strict no-eviction semantics.
24
+
25
+ The cache is intentionally allowed to grow for the current video. It stores
26
+ detached CPU (or pinned CPU) raw latents plus compressed MemoryRecord objects,
27
+ while DiT readout tensors remain bounded by the caller's manual budgets.
28
+ """
29
+
30
def __init__(
    self,
    *,
    enabled: bool = True,
    device: str = "cpu",
    keep_raw_latents: str = "all",
    keep_compressed_records: bool = True,
    keep_prefix_anchors: bool = True,
    eviction_policy: str = "none",
    no_evict: bool = True,
    clear_between_videos: bool = True,
    max_records: Optional[int] = None,
    max_slots: Optional[int] = None,
    on_capacity_exceeded: str = "warn",
) -> None:
    """Store the cache options and initialise empty per-video state.

    All arguments are keyword-only and are stored on the instance after
    light normalisation (``bool(...)`` for flags, ``str(...)`` for names).
    Falsy values of ``device``, ``eviction_policy`` and
    ``on_capacity_exceeded`` fall back to ``"cpu"``, ``"none"`` and
    ``"warn"`` respectively. ``keep_raw_latents``, ``max_records`` and
    ``max_slots`` are stored unchanged.

    Raises:
        ValueError: if ``eviction_policy`` is not ``"none"`` or ``no_evict``
            is false, or if ``device`` is not one of ``cpu``, ``pinned_cpu``
            or ``cuda``.
    """
    self.enabled = bool(enabled)
    self.device = str(device or "cpu")
    self.keep_raw_latents = keep_raw_latents
    self.keep_compressed_records = bool(keep_compressed_records)
    self.keep_prefix_anchors = bool(keep_prefix_anchors)
    self.eviction_policy = str(eviction_policy or "none")
    self.no_evict = bool(no_evict)
    self.clear_between_videos = bool(clear_between_videos)
    self.max_records = max_records
    self.max_slots = max_slots
    self.on_capacity_exceeded = str(on_capacity_exceeded or "warn")
    # Validation runs after the option attributes above have been assigned.
    if self.eviction_policy != "none" or not self.no_evict:
        raise ValueError("DeMemWMStreamingCache only supports eviction_policy='none' with no_evict=true")
    if self.device not in {"cpu", "pinned_cpu", "cuda"}:
        raise ValueError("cache.device must be one of: cpu, pinned_cpu, cuda")
    # Counters.
    self.reset_count = 0
    self.evictions = 0
    self.capacity_exceeded_count = 0
    # Per-video state; starts empty.
    self.current_video_id: Any = None
    self._raw_segments: list[_RawLatentSegment] = []
    # Records are grouped by kind ("anchor" / "revisit") and then by an int key
    # (presumably the batch index - TODO confirm).
    self._records: dict[str, dict[int, list[MemoryRecord]]] = {"anchor": {}, "revisit": {}}
    self._raw_keys: set[tuple[int, int]] = set()
    self._raw_index: dict[tuple[int, int], tuple[int, int]] = {}
    self._record_keys: set[tuple[str, int, str, int, int, bool]] = set()
    self._batch_size: Optional[int] = None
    # Concat cache: avoids repeated torch.cat across DDIM steps within one chunk.
    # Invalidated whenever new raw segments are added.
    self._raw_concat_version: int = 0
    self._raw_concat_built: int = -1
    self._raw_concat_cache: Optional[tuple] = None  # (latents, frame_indices, generated, pose)
    # GPU memory-bank cache: avoids repeated CPU→GPU record transfers across DDIM steps.
    # Invalidated whenever new records are added.
    self._banks_version: int = 0
    self._banks_built_cache: dict[tuple, tuple[int, list[CausalMemoryBank]]] = {}
79
+
80
@classmethod
def from_config(cls, cfg: Any, *, enabled_default: bool = True) -> "StreamingCache":
    """Build a cache from a config object with attribute access.

    Missing attributes (or ``cfg is None``) fall back to the constructor
    defaults, except ``enabled`` which falls back to ``enabled_default``.
    String options that are present but set to ``None`` also use their
    default: previously they were converted with ``str(None)`` into the
    literal ``"None"``, which the constructor then rejected.

    Args:
        cfg: object whose attributes carry the cache options, or ``None``.
        enabled_default: value of ``enabled`` when ``cfg`` does not set it.
    """

    def get(name: str, default: Any) -> Any:
        return getattr(cfg, name, default) if cfg is not None else default

    def get_str(name: str, default: str) -> str:
        value = get(name, default)
        return default if value is None else str(value)

    return cls(
        enabled=bool(get("enabled", enabled_default)),
        device=get_str("device", "cpu"),
        keep_raw_latents=get_str("keep_raw_latents", "all"),
        keep_compressed_records=bool(get("keep_compressed_records", True)),
        keep_prefix_anchors=bool(get("keep_prefix_anchors", True)),
        eviction_policy=get_str("eviction_policy", "none"),
        no_evict=bool(get("no_evict", True)),
        clear_between_videos=bool(get("clear_between_videos", True)),
        max_records=get("max_records", None),
        max_slots=get("max_slots", None),
        on_capacity_exceeded=get_str("on_capacity_exceeded", "warn"),
    )
98
+
99
@property
def batch_size(self) -> int:
    """Batch size recorded for the current video, or 0 when none is set."""
    recorded = self._batch_size
    if not recorded:
        return 0
    return int(recorded)
102
+
103
@property
def raw_segment_count(self) -> int:
    """Number of raw latent segments currently held."""
    segments = self._raw_segments
    return len(segments)
106
+
107
@property
def raw_frame_slots(self) -> int:
    """Total of ``shape[0] * shape[1]`` over the latents of all raw segments."""
    total = 0
    for segment in self._raw_segments:
        shape = segment.latents.shape
        total += int(shape[0] * shape[1])
    return total
110
+
111
@property
def record_count(self) -> int:
    """Number of stored records across every kind and every inner key."""
    count = 0
    for by_batch in self._records.values():
        for records in by_batch.values():
            count += len(records)
    return count
114
+
115
@property
def slot_count(self) -> int:
    """Sum of ``valid_slots`` over every stored record."""
    total = 0
    for by_batch in self._records.values():
        for records in by_batch.values():
            for record in records:
                total = total + record.valid_slots
    return total
118
+
119
def records_count(self, kind: str | None = None) -> int:
    """Count stored records, optionally restricted to one record kind.

    With ``kind=None`` this is ``self.record_count``; an unknown kind
    yields 0.
    """
    if kind is not None:
        per_key = self._records.get(kind, {})
        return sum(len(records) for records in per_key.values())
    return self.record_count
123
+
124
def reset(self, video_id: Any = None) -> None:
    """Drop all cached state and start tracking ``video_id``.

    Containers are emptied, per-video counters return to zero, and both
    derived caches (raw concat, built banks) are invalidated.
    ``reset_count`` is incremented rather than cleared.
    """
    self.current_video_id = video_id

    # Stored content and its lookup structures.
    for container in (self._raw_segments, self._raw_keys, self._raw_index, self._record_keys):
        container.clear()
    self._records = {"anchor": {}, "revisit": {}}
    self._batch_size = None

    # Per-video statistics.
    self.evictions = 0
    self.capacity_exceeded_count = 0
    self.reset_count += 1

    # Bump versions so results built before the reset are never reused.
    self._raw_concat_version += 1
    self._raw_concat_built = -1
    self._raw_concat_cache = None
    self._banks_version += 1
    self._banks_built_cache.clear()
140
+
141
def _store_tensor(self, tensor: Optional[torch.Tensor], *, dtype: torch.dtype | None = None) -> Optional[torch.Tensor]:
    """Return a detached private copy of ``tensor`` for long-term storage.

    Args:
        tensor: Tensor to store, or ``None`` (returned unchanged).
        dtype: Optional dtype applied to floating-point tensors only.

    Returns:
        A tensor that never shares storage with ``tensor``. For the
        ``"cpu"`` and ``"pinned_cpu"`` devices the copy lives on the CPU
        (pinned when possible); for any other device setting the copy
        stays on the device the input is already on.
    """
    if tensor is None:
        return None
    out = tensor.detach()
    if dtype is not None and out.is_floating_point():
        out = out.to(dtype=dtype)
    if self.device in {"cpu", "pinned_cpu"}:
        out = out.to(device="cpu", copy=True)
        if self.device == "pinned_cpu":
            try:
                out = out.pin_memory()
            except RuntimeError:
                # Keep stable CPU behavior if pinning is unavailable in a worker/process.
                pass
    else:
        # ``detach()`` shares storage with the input, so always clone. The
        # previous ``== "cuda"`` test let other device strings such as
        # "cuda:0" return an alias that later in-place edits by the caller
        # could silently corrupt.
        out = out.clone()
    return out
158
+
159
def _metadata_to_storage(self, metadata: dict) -> dict:
    """Copy a metadata mapping, passing tensors through ``_store_tensor``.

    Nested dicts are converted recursively; any other value is kept
    as-is. A falsy ``metadata`` yields an empty dict.
    """
    stored = {}
    for name, item in dict(metadata or {}).items():
        if torch.is_tensor(item):
            item = self._store_tensor(item)
        elif isinstance(item, dict):
            item = self._metadata_to_storage(item)
        stored[name] = item
    return stored
169
+
170
def _metadata_to_device(self, metadata: dict, *, device: torch.device, dtype: torch.dtype) -> dict:
    """Copy a metadata mapping with its tensors moved to ``device``.

    Floating-point tensors are additionally cast to ``dtype``; other
    tensors keep their dtype. Nested dicts are handled recursively and
    non-tensor values are kept as-is.
    """
    moved = {}
    for name, item in dict(metadata or {}).items():
        if torch.is_tensor(item):
            item = item.to(device=device)
            if item.is_floating_point():
                item = item.to(dtype=dtype)
        elif isinstance(item, dict):
            item = self._metadata_to_device(item, device=device, dtype=dtype)
        moved[name] = item
    return moved
181
+
182
def _record_to_storage(self, record: MemoryRecord) -> MemoryRecord:
    """Build a storage-side copy of ``record``.

    Tensor fields go through ``_store_tensor`` and metadata through
    ``_metadata_to_storage``. A ``score`` that is not a tensor is stored
    as ``None``.
    """
    raw_score = record.score
    fields = {
        "tokens": self._store_tensor(record.tokens),
        "mask": self._store_tensor(record.mask),
        "source_start": int(record.source_start),
        "source_end": int(record.source_end),
        "frame_indices": self._store_tensor(record.frame_indices),
        "pose": self._store_tensor(record.pose),
        "source_type": record.source_type,
        "is_generated": bool(record.is_generated),
        "score": self._store_tensor(raw_score) if torch.is_tensor(raw_score) else None,
        "chunk_id": record.chunk_id,
        "metadata": self._metadata_to_storage(record.metadata),
    }
    return MemoryRecord(**fields)
196
+
197
def _record_to_device(self, record: MemoryRecord, *, device: torch.device, dtype: torch.dtype) -> MemoryRecord:
    """Build a copy of a stored record with its tensors on ``device``.

    ``tokens`` is cast to ``dtype`` and ``mask`` to bool; ``frame_indices``
    and ``pose`` keep their dtype. Metadata goes through
    ``_metadata_to_device``.

    Args:
        record: Record as kept in the cache.
        device: Target device for every tensor field.
        dtype: Target dtype for ``tokens`` and floating-point metadata.
    """
    # The storage path keeps tensor scores on the cache device, so move them
    # together with the other tensor fields; previously the score was passed
    # through and could stay on a different device than ``tokens``.
    score = record.score
    if torch.is_tensor(score):
        score = score.to(device=device)
    return MemoryRecord(
        tokens=record.tokens.to(device=device, dtype=dtype),
        mask=record.mask.to(device=device, dtype=torch.bool),
        source_start=int(record.source_start),
        source_end=int(record.source_end),
        frame_indices=record.frame_indices.to(device=device),
        pose=None if record.pose is None else record.pose.to(device=device),
        source_type=record.source_type,
        is_generated=bool(record.is_generated),
        score=score,
        chunk_id=record.chunk_id,
        metadata=self._metadata_to_device(record.metadata, device=device, dtype=dtype),
    )
211
+
212
def _check_capacity(self) -> None:
    """React when the record or slot count is above its configured limit.

    A limit of ``None`` is not checked. When a limit is exceeded the
    ``capacity_exceeded_count`` counter is incremented and, depending on
    ``on_capacity_exceeded``, a ``RuntimeError`` is raised ("error"), a
    ``RuntimeWarning`` is issued ("warn"), or nothing else happens.

    Raises:
        RuntimeError: If a limit is exceeded and the policy is "error".
    """
    over_records = self.max_records is not None and self.record_count > int(self.max_records)
    over_slots = self.max_slots is not None and self.slot_count > int(self.max_slots)
    if not (over_records or over_slots):
        return

    self.capacity_exceeded_count += 1
    usage = f"records={self.record_count}/{self.max_records}, slots={self.slot_count}/{self.max_slots}; "
    msg = "DeMemWMStreamingCache capacity exceeded " + usage + "no eviction performed because no_evict=true"

    policy = self.on_capacity_exceeded
    if policy == "error":
        raise RuntimeError(msg)
    if policy == "warn":
        warnings.warn(msg, RuntimeWarning, stacklevel=2)
230
+
231
def add_raw_latents(
    self,
    latents: torch.Tensor,
    frame_indices: torch.Tensor,
    source_is_generated: Optional[torch.Tensor] = None,
    pose: Optional[torch.Tensor] = None,
) -> None:
    """Append not-yet-cached time steps of raw latents as a new segment.

    A time step is kept when at least one of its ``(batch, frame)`` keys
    has not been seen before; steps whose keys are all known are dropped.
    Nothing happens when the cache is disabled or ``keep_raw_latents`` is
    not ``"all"``.

    Args:
        latents: Tensor of shape (T, B, C, H, W).
        frame_indices: Tensor of shape (T, B) with the frame index of
            every latent.
        source_is_generated: Optional per-frame flags, selected along
            dim 0 like ``latents``; defaults to all ``False``.
        pose: Optional pose tensor, selected along dim 0 like ``latents``.

    Raises:
        ValueError: On a wrong ``latents`` rank, a ``frame_indices`` shape
            other than (T, B), or a batch size that differs from the one
            already recorded.
    """
    if not (self.enabled and self.keep_raw_latents == "all"):
        return
    if latents.ndim != 5:
        raise ValueError("cached raw latents must have shape (T,B,C,H,W)")
    n_steps, n_batch = int(latents.shape[0]), int(latents.shape[1])
    if frame_indices.shape != (n_steps, n_batch):
        raise ValueError("cached frame_indices must have shape (T,B)")
    if self._batch_size is None:
        self._batch_size = n_batch
    elif self._batch_size != n_batch:
        raise ValueError("streaming cache batch size changed within a video")

    # One host transfer for all indices, then plain Python keys per step.
    frame_rows = frame_indices.detach().cpu().tolist()
    keys_per_step = [[(b, int(frame)) for b, frame in enumerate(row)] for row in frame_rows]

    fresh_steps: list[int] = []
    for step, keys in enumerate(keys_per_step):
        if any(key not in self._raw_keys for key in keys):
            fresh_steps.append(step)
            self._raw_keys.update(keys)
    if not fresh_steps:
        return

    pos = torch.as_tensor(fresh_steps, dtype=torch.long)
    seg_latents = latents.index_select(0, pos.to(device=latents.device))
    seg_frames = frame_indices.index_select(0, pos.to(device=frame_indices.device))
    if source_is_generated is not None:
        seg_generated = source_is_generated.index_select(0, pos.to(device=source_is_generated.device)).bool()
    else:
        seg_generated = torch.zeros(seg_frames.shape, device=seg_frames.device, dtype=torch.bool)
    seg_pose = pose.index_select(0, pos.to(device=pose.device)) if pose is not None else None

    segment_idx = len(self._raw_segments)
    segment = _RawLatentSegment(
        latents=self._store_tensor(seg_latents),
        frame_indices=self._store_tensor(seg_frames),
        source_is_generated=self._store_tensor(seg_generated),
        pose=self._store_tensor(seg_pose),
    )
    self._raw_segments.append(segment)

    # First writer wins: a key that already has a location keeps it.
    for local_pos, step in enumerate(fresh_steps):
        for key in keys_per_step[step]:
            self._raw_index.setdefault(key, (segment_idx, local_pos))

    # Invalidate the concat cache: a new segment was added.
    self._raw_concat_version += 1
    self._raw_concat_cache = None
282
+
283
def add_records(self, kind: str, batch_idx: int, records: Iterable[MemoryRecord]) -> None:
    """Store compressed records of ``kind`` for one batch element.

    Records are de-duplicated on (kind, batch index, chunk id, source
    start, source end, generated flag). Anchor records are skipped when
    ``keep_prefix_anchors`` is off. If anything was added, the built-banks
    cache is invalidated and the capacity check runs.

    Raises:
        ValueError: If ``kind`` is not a known record kind.
    """
    if not (self.enabled and self.keep_compressed_records):
        return
    if kind not in self._records:
        raise ValueError(f"unsupported cache record kind: {kind}")

    batch_idx = int(batch_idx)
    bucket = self._records[kind].setdefault(batch_idx, [])
    size_before = len(bucket)
    skip_all = kind == "anchor" and not self.keep_prefix_anchors

    for record in records:
        if skip_all:
            continue
        identity = (
            kind,
            batch_idx,
            str(record.chunk_id or ""),
            int(record.source_start),
            int(record.source_end),
            bool(record.is_generated),
        )
        if identity not in self._record_keys:
            self._record_keys.add(identity)
            bucket.append(self._record_to_storage(record))

    if len(bucket) > size_before:
        # Invalidate the GPU banks cache: new records were added.
        self._banks_version += 1
        self._banks_built_cache.clear()
        self._check_capacity()
312
+
313
def add_memory_banks(self, anchor_banks: list[CausalMemoryBank], revisit_banks: list[CausalMemoryBank]) -> None:
    """Store the records of per-batch anchor banks, then revisit banks.

    The position of a bank in its list is used as the batch index.
    """
    for kind, banks in (("anchor", anchor_banks), ("revisit", revisit_banks)):
        for batch_idx, bank in enumerate(banks):
            self.add_records(kind, batch_idx, bank.records)
318
+
319
def memory_banks(self, kind: str, *, device: torch.device, dtype: torch.dtype, batch_size: int | None = None) -> list[CausalMemoryBank]:
    """Return one memory bank per batch element, built from stored records.

    Results are memoised per (kind, device, dtype, batch size) and reused
    while the banks version is unchanged, so repeated calls do not move
    the records again.

    Args:
        kind: Record kind to materialise.
        device: Device for the record tensors.
        dtype: Dtype passed to the per-record conversion.
        batch_size: Number of banks; falls back to the cache batch size,
            then to the highest stored batch index + 1, then 0.

    Raises:
        ValueError: If ``kind`` is not a known record kind.
    """
    if kind not in self._records:
        raise ValueError(f"unsupported cache record kind: {kind}")
    per_batch = self._records[kind]
    inferred = max(per_batch.keys()) + 1 if per_batch else 0
    n_banks = int(batch_size or self.batch_size or inferred)

    cache_key = (kind, device, dtype, n_banks)
    entry = self._banks_built_cache.get(cache_key)
    if entry is not None:
        built_version, built_banks = entry[0], entry[1]
        if built_version == self._banks_version:
            return built_banks

    banks: list[CausalMemoryBank] = []
    for batch_idx in range(n_banks):
        bank = CausalMemoryBank()
        for stored in per_batch.get(batch_idx, []):
            bank.add_record(self._record_to_device(stored, device=device, dtype=dtype))
        banks.append(bank)
    self._banks_built_cache[cache_key] = (self._banks_version, banks)
    return banks
335
+
336
def records_for_batch(self, kind: str, batch_idx: int) -> tuple[MemoryRecord, ...]:
    """Return the stored records of ``kind`` for one batch index as a tuple.

    An unseen batch index yields an empty tuple.

    Raises:
        ValueError: If ``kind`` is not a known record kind.
    """
    try:
        per_batch = self._records[kind]
    except KeyError:
        raise ValueError(f"unsupported cache record kind: {kind}") from None
    return tuple(per_batch.get(int(batch_idx), ()))
340
+
341
def raw_latents_for_frames(
    self,
    *,
    batch_idx: int,
    frame_indices: torch.Tensor,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Gather cached raw latents of one batch element for given frames.

    Args:
        batch_idx: Batch element to read.
        frame_indices: Frame indices to fetch; flattened before lookup.
        device: Device of the returned tensor.
        dtype: Dtype of the returned tensor.

    Returns:
        The latents stacked along dim 0 in request order, with a singleton
        dim inserted at position 1. With no requested frames, an empty
        slice of the first segment is returned instead.

    Raises:
        KeyError: If a requested frame is not cached for ``batch_idx``.
    """
    batch_idx = int(batch_idx)
    wanted = frame_indices.detach().cpu().reshape(-1).tolist()

    picked = []
    for frame in wanted:
        location = self._raw_index.get((batch_idx, int(frame)))
        if location is None:
            raise KeyError(f"raw latent for batch={batch_idx}, frame={int(frame)} is not cached")
        segment_idx, local_pos = location
        picked.append(self._raw_segments[segment_idx].latents[local_pos, batch_idx])

    if picked:
        return torch.stack(picked, dim=0).unsqueeze(1).to(device=device, dtype=dtype)
    template = self._raw_segments[0].latents
    return template[:0, batch_idx:batch_idx + 1].to(device=device, dtype=dtype)
363
+
364
def _select_time_positions(
    self,
    frame_indices: torch.Tensor,
    target_frame_indices: Optional[torch.Tensor],
    max_recent_frames: Optional[int],
    exclude_latest_local_frames: int = 0,
) -> torch.Tensor:
    """Pick the time positions needed by at least one (target, batch) pair.

    A source position is causally valid for a target when its frame index
    is below ``target - exclude_latest_local_frames``. For every pair only
    the last ``max_recent_frames`` valid positions are retained, and the
    union over all pairs is returned.

    Args:
        frame_indices: (T, B) frame index of every stored position.
        target_frame_indices: (T_tgt,) or (T_tgt, B) target frames, or
            ``None`` for no restriction.
        max_recent_frames: Per-pair budget; ``None`` or a value <= 0
            disables the restriction.
        exclude_latest_local_frames: Causal margin; negatives count as 0.

    Returns:
        1-D long tensor of selected positions in ascending order; all
        positions when no restriction applies.
    """
    n_src, n_batch = frame_indices.shape
    unrestricted = (
        target_frame_indices is None
        or max_recent_frames is None
        or int(max_recent_frames) <= 0
    )
    if unrestricted:
        return torch.arange(n_src, dtype=torch.long)

    targets = target_frame_indices.detach().cpu()
    if targets.ndim == 1:
        targets = targets[:, None].expand(-1, n_batch)
    sources = frame_indices.detach().cpu()  # (T, B)
    budget = int(max_recent_frames)
    margin = max(0, int(exclude_latest_local_frames))

    # causal[t_tgt, t_src, b]: source t_src may be used for target t_tgt.
    cutoff = targets.unsqueeze(1) - margin  # (T_tgt, 1, B)
    causal = sources.unsqueeze(0) < cutoff  # (T_tgt, T, B)

    # Rank valid positions from the newest backwards and keep the first
    # `budget` of them for each (t_tgt, b).
    newest_first = causal.flip(1)
    rank = newest_first.long().cumsum(1)
    kept = ((rank <= budget) & newest_first).flip(1)

    needed = kept.any(dim=0).any(dim=1)  # (T,)
    return torch.nonzero(needed, as_tuple=False).flatten()
391
+
392
def materialize_raw_latents(
    self,
    *,
    device: torch.device,
    dtype: torch.dtype,
    max_recent_frames: Optional[int] = None,
    target_frame_indices: Optional[torch.Tensor] = None,
    exclude_latest_local_frames: int = 0,
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Return cached raw latents and their side data on ``device``.

    With both ``target_frame_indices`` and a positive
    ``max_recent_frames`` the work is delegated to
    ``_materialize_recent_raw_latents``. Otherwise all segments are
    concatenated along dim 0 (the concatenation is memoised until a new
    segment is added) and filtered by ``_select_time_positions``.

    Returns:
        ``(latents, frame_indices, generated, pose)``. All four are
        ``None`` when nothing is cached; ``pose`` is ``None`` unless
        every segment carries a pose.
    """
    if not self._raw_segments:
        return None, None, None, None

    windowed = (
        target_frame_indices is not None
        and max_recent_frames is not None
        and int(max_recent_frames) > 0
    )
    if windowed:
        return self._materialize_recent_raw_latents(
            device=device,
            dtype=dtype,
            max_recent_frames=int(max_recent_frames),
            target_frame_indices=target_frame_indices,
            exclude_latest_local_frames=exclude_latest_local_frames,
        )

    # Rebuild the concatenated tensors only when new segments were added.
    stale = self._raw_concat_cache is None or self._raw_concat_built != self._raw_concat_version
    if stale:
        segments = self._raw_segments
        latents = torch.cat([s.latents for s in segments], dim=0)
        frame_indices = torch.cat([s.frame_indices for s in segments], dim=0)
        generated = torch.cat([s.source_is_generated for s in segments], dim=0)
        poses = [s.pose for s in segments]
        pose: Optional[torch.Tensor] = None
        if all(p is not None for p in poses):
            pose = torch.cat(poses, dim=0)
        self._raw_concat_cache = (latents, frame_indices, generated, pose)
        self._raw_concat_built = self._raw_concat_version
    else:
        latents, frame_indices, generated, pose = self._raw_concat_cache

    pos = self._select_time_positions(frame_indices, target_frame_indices, max_recent_frames, exclude_latest_local_frames)
    if pos.numel() == 0:
        return (
            latents[:0].to(device=device, dtype=dtype),
            frame_indices[:0].to(device=device),
            generated[:0].to(device=device, dtype=torch.bool),
            None if pose is None else pose[:0].to(device=device),
        )

    def take(source: torch.Tensor) -> torch.Tensor:
        return source.index_select(0, pos.to(device=source.device))

    latents = take(latents).to(device=device, dtype=dtype)
    frame_indices = take(frame_indices).to(device=device)
    generated = take(generated).to(device=device, dtype=torch.bool)
    if pose is not None:
        pose = take(pose).to(device=device)
    return latents, frame_indices, generated, pose
436
+
437
def _materialize_recent_raw_latents(
    self,
    *,
    device: torch.device,
    dtype: torch.dtype,
    max_recent_frames: int,
    target_frame_indices: torch.Tensor,
    exclude_latest_local_frames: int = 0,
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Collect only the most recent causally valid time steps per target.

    Segments and their positions are walked from newest to oldest. A step
    is taken when some (target, batch) pair still has budget and the
    step's frame index is below ``target - exclude_latest_local_frames``.
    The walk stops once every pair has ``max_recent_frames`` steps.

    Returns:
        ``(latents, frame_indices, generated, pose)`` stacked in stored
        order on ``device``; empty slices when nothing qualifies.
        ``pose`` is ``None`` unless every contributing segment has one.

    Raises:
        ValueError: If the targets' batch dimension does not match the
            cache batch size.
    """
    n_batch = self.batch_size
    targets = target_frame_indices.detach().cpu()
    if targets.ndim == 1:
        targets = targets[:, None].expand(-1, n_batch)
    elif targets.shape[1] == 1 and n_batch > 1:
        targets = targets.expand(-1, n_batch)
    if targets.shape[1] != n_batch:
        raise ValueError("target_frame_indices batch dimension does not match streaming cache")

    budget = max(0, int(max_recent_frames))
    margin = max(0, int(exclude_latest_local_frames))
    cutoff = targets - margin
    taken = torch.zeros(targets.shape, dtype=torch.long)
    picks: list[tuple[_RawLatentSegment, int]] = []

    def budget_filled() -> bool:
        return bool((taken >= budget).all().item())

    for segment in reversed(self._raw_segments):
        seg_frames = segment.frame_indices.detach().cpu()
        for local_pos in reversed(range(seg_frames.shape[0])):
            causal = seg_frames[local_pos].unsqueeze(0) < cutoff
            wanted = causal & (taken < budget)
            if not wanted.any():
                continue
            picks.append((segment, local_pos))
            taken += wanted.long()
            if budget_filled():
                break
        if budget_filled():
            break

    if not picks:
        template = self._raw_segments[0]
        return (
            template.latents[:0].to(device=device, dtype=dtype),
            template.frame_indices[:0].to(device=device),
            template.source_is_generated[:0].to(device=device, dtype=torch.bool),
            None if template.pose is None else template.pose[:0].to(device=device),
        )

    # Steps were gathered newest-first; restore stored (oldest-first) order.
    picks.reverse()
    latents = torch.stack([seg.latents[i] for seg, i in picks], dim=0).to(device=device, dtype=dtype)
    frame_indices = torch.stack([seg.frame_indices[i] for seg, i in picks], dim=0).to(device=device)
    generated = torch.stack([seg.source_is_generated[i] for seg, i in picks], dim=0).to(device=device, dtype=torch.bool)
    pose = None
    if all(seg.pose is not None for seg, _ in picks):
        pose = torch.stack([seg.pose[i] for seg, i in picks], dim=0).to(device=device)
    return latents, frame_indices, generated, pose
493
+
494
def diagnostics(self, prefix: str = "cache") -> dict[str, Any]:
    """Return a flat snapshot of cache counters and settings.

    Every key has the form ``"<prefix>_<name>"``.
    """
    values = (
        ("enabled", bool(self.enabled)),
        ("records", int(self.record_count)),
        ("anchor_records", int(self.records_count("anchor"))),
        ("revisit_records", int(self.records_count("revisit"))),
        ("slots", int(self.slot_count)),
        ("raw_frame_slots", int(self.raw_frame_slots)),
        ("raw_segments", int(self.raw_segment_count)),
        ("evictions", int(self.evictions)),
        ("resets", int(self.reset_count)),
        ("capacity_exceeded", int(self.capacity_exceeded_count)),
        ("device", self.device),
        ("current_video_id", self.current_video_id),
        ("clear_between_videos", bool(self.clear_between_videos)),
        ("no_evict", bool(self.no_evict)),
    )
    return {f"{prefix}_{name}": value for name, value in values}
511
+
512
+
513
+ # Alias: the same class is also exposed under the DeMemWM-prefixed name.
+ DeMemWMStreamingCache = StreamingCache
algorithms/worldmem/dememwm/compression.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+
11
def latent_patch_tokens(latents: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Cut latent maps into non-overlapping square patches and flatten them.

    Args:
        latents: Tensor of shape (T, B, C, H, W).
        patch_size: Side length of a patch; must divide both H and W.

    Returns:
        Tensor of shape (T, B, N, C * patch_size**2) where N is the
        number of patches per map.

    Raises:
        ValueError: On a wrong rank, a non-positive ``patch_size``, or
            H/W not divisible by ``patch_size``.
    """
    if latents.ndim != 5:
        raise ValueError("latents must have shape (T,B,C,H,W)")
    if patch_size <= 0:
        raise ValueError("patch_size must be positive")
    n_t, n_b, n_c, height, width = latents.shape
    if height % patch_size != 0 or width % patch_size != 0:
        raise ValueError(f"latent H,W=({height},{width}) must be divisible by patch_size={patch_size}")

    maps = latents.reshape(n_t * n_b, n_c, height, width)
    # unfold returns (N_maps, C*p*p, L); put the patch axis before features.
    columns = F.unfold(maps, kernel_size=patch_size, stride=patch_size)
    tokens = columns.transpose(1, 2).contiguous()
    feature_dim = n_c * patch_size * patch_size
    return tokens.reshape(n_t, n_b, tokens.shape[1], feature_dim)
22
+
23
+
24
def spatial_pool_tokens(
    tokens: torch.Tensor,
    pool_h: int,
    pool_w: int,
    src_h: int,
    src_w: int,
) -> torch.Tensor:
    """Adaptive 2D average pooling over a flattened token grid.

    Args:
        tokens: (src_h * src_w, D) tokens laid out row-major on the grid.
        pool_h: Output grid height.
        pool_w: Output grid width.
        src_h: Input grid height.
        src_w: Input grid width.

    Returns:
        (pool_h * pool_w, D) pooled tokens, row-major on the output grid.

    Raises:
        ValueError: If ``tokens`` is not 2-D.
    """
    if tokens.ndim != 2:
        raise ValueError("tokens must have shape (N, D)")
    feature_dim = tokens.shape[-1]
    # (H, W, D) -> (1, D, H, W), the layout adaptive_avg_pool2d expects.
    grid = tokens.reshape(src_h, src_w, feature_dim)
    channels_first = grid.permute(2, 0, 1).unsqueeze(0)
    pooled = F.adaptive_avg_pool2d(channels_first, (pool_h, pool_w))
    channels_last = pooled.squeeze(0).permute(1, 2, 0)
    return channels_last.reshape(-1, feature_dim)
39
+
40
+
41
class SpatialConv2DMemoryProjector(nn.Module):
    """Map latent maps to DiT-width tokens, one token per spatial location.

    A 1x1 convolution lifts the latent channels to ``mid_channels``; a
    second convolution with "same" padding produces ``dit_hidden_size``
    features, so the HxW grid is preserved.
    """

    projects_spatial_latents = True

    def __init__(
        self,
        latent_channels: int,
        dit_hidden_size: int,
        mid_channels: int,
        kernel_size: int = 3,
    ):
        """Create the two projection layers.

        Raises:
            ValueError: If ``kernel_size`` is not a positive odd integer.
        """
        super().__init__()
        kernel_size = int(kernel_size)
        is_positive_odd = kernel_size > 0 and kernel_size % 2 == 1
        if not is_positive_odd:
            raise ValueError("kernel_size must be a positive odd integer")

        self.latent_channels = int(latent_channels)
        self.dit_hidden_size = int(dit_hidden_size)
        self.mid_channels = int(mid_channels)
        self.kernel_size = kernel_size
        self.out_features = self.dit_hidden_size

        self.proj_in = nn.Conv2d(self.latent_channels, self.mid_channels, kernel_size=1)
        # Half-kernel padding keeps H and W unchanged for an odd kernel.
        self.proj_spatial = nn.Conv2d(
            self.mid_channels,
            self.dit_hidden_size,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        """Project (T, B, C, H, W) latents to (B, T, H*W, dit_hidden_size) tokens.

        Raises:
            ValueError: On a wrong rank or an unexpected channel count.
        """
        if latents.ndim != 5:
            raise ValueError("latents must have shape (T,B,C,H,W)")
        n_t, n_b, n_c, height, width = latents.shape
        if n_c != self.latent_channels:
            raise ValueError(f"expected {self.latent_channels} latent channels, got {n_c}")

        frames = latents.reshape(n_t * n_b, n_c, height, width)
        features = self.proj_spatial(self.proj_in(frames))
        features = features.reshape(n_t, n_b, self.dit_hidden_size, height, width)
        # (T, B, D, H, W) -> (B, T, H, W, D), then flatten the grid.
        batch_first = features.permute(1, 0, 3, 4, 2)
        return batch_first.reshape(n_b, n_t, height * width, self.dit_hidden_size).contiguous()
80
+
81
+
82
class CausalConv3DDynamicCompressor(nn.Module):
    """Dynamic memory compressor: delta preprocessing + causal Conv3D on raw latents.

    Replaces ShortTermLatentCompressor (slot cross-attention).
    - Works directly on raw latent frames of shape (T, C, H, W)
    - Delta input: the first frame is kept, later frames become
      frame[t] - frame[t-1]
    - Time is zero-padded on the left so the fixed outputs are right-aligned
    - Windows shorter than max_source_frames are zero-padded at the front,
      giving a fixed output shape
    - No slot cross-attention, no chunking
    """

    def __init__(
        self,
        latent_channels: int,
        dit_hidden_size: int,
        patch_size: int = 2,
        conv_kernel_t: int = 3,
        conv_stride_t: int = 2,
        max_source_frames: int = 8,
        exclude_latest_local_frames: int = 4,
    ):
        """Store the configuration and build the Conv3D + LayerNorm stack."""
        super().__init__()
        self.latent_channels = latent_channels
        self.dit_hidden_size = dit_hidden_size
        self.patch_size = patch_size
        self.conv_kernel_t = conv_kernel_t
        self.conv_stride_t = conv_stride_t
        self.max_source_frames = max_source_frames
        self.exclude_latest_local_frames = int(exclude_latest_local_frames)
        self.causal_pad = self._temporal_left_pad()
        # Spatial kernel == spatial stride == patch size: one output per patch.
        self.conv3d = nn.Conv3d(
            latent_channels,
            dit_hidden_size,
            kernel_size=(conv_kernel_t, patch_size, patch_size),
            stride=(conv_stride_t, patch_size, patch_size),
            padding=0,
        )
        self.out_norm = nn.LayerNorm(dit_hidden_size)
        self._init_temporal_as_delta()

    def _init_temporal_as_delta(self) -> None:
        """Initialise the conv to output a patch average of its last time tap.

        Delta preprocessing happens in ``forward``. Output channel ``d``
        reads input channel ``d % C_in``, so latent channels repeat across
        the wider hidden dimension. The bias is zeroed.
        """
        weight = self.conv3d.weight
        n_out, n_in = weight.shape[:2]
        patch = self.patch_size
        fill = 1.0 / (patch * patch)
        last_tap = self.conv_kernel_t - 1
        with torch.no_grad():
            weight.zero_()
            for out_ch in range(n_out):
                weight[out_ch, out_ch % n_in, last_tap, :, :] = fill
            if self.conv3d.bias is not None:
                nn.init.zeros_(self.conv3d.bias)

    def _temporal_output_count(self) -> int:
        """Number of temporal outputs: ceil(max_source_frames / stride)."""
        ratio = self.max_source_frames / self.conv_stride_t
        return math.ceil(ratio)

    def _temporal_left_pad(self) -> int:
        """Left padding that lets the last conv window end on the last source frame."""
        n_out = self._temporal_output_count()
        last_window_end = (n_out - 1) * self.conv_stride_t + self.conv_kernel_t - 1
        last_source = self.max_source_frames - 1
        return max(0, last_window_end - last_source)

    def _output_time_indices(self, device: torch.device) -> torch.Tensor:
        """Unpadded source position at which each temporal output's window ends."""
        n_out = self._temporal_output_count()
        offset = self.conv_kernel_t - 1 - self.causal_pad
        steps = torch.arange(n_out, device=device, dtype=torch.long)
        return steps * self.conv_stride_t + offset

    def tokens_per_target(self, H: int, W: int) -> int:
        """Number of memory tokens emitted per target for an HxW latent map."""
        patch = self.patch_size
        per_frame = (H // patch) * (W // patch)
        return self._temporal_output_count() * per_frame

    def _encode_window(
        self,
        chunk: torch.Tensor,
        output_time_idx: torch.Tensor,
        n_spatial: int,
        num_slots: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compress one (T_sel, C, H, W) window into tokens and a validity mask."""
        device = chunk.device
        window = self.max_source_frames
        real_mask = torch.ones((chunk.shape[0],), device=device, dtype=torch.bool)
        if chunk.shape[0] < window:
            # Front-pad with zero frames so real frames stay right-aligned.
            pad = chunk.new_zeros(window - chunk.shape[0], *chunk.shape[1:])
            chunk = torch.cat([pad, chunk], dim=0)
            real_mask = torch.cat([
                torch.zeros((pad.shape[0],), device=device, dtype=torch.bool),
                real_mask,
            ])

        # Temporal deltas; the first row is kept as-is.
        deltas = torch.cat([chunk[:1], chunk[1:] - chunk[:-1]], dim=0)
        x = deltas.permute(1, 0, 2, 3).unsqueeze(0)  # (1,C,T,H,W)
        x = F.pad(x, (0, 0, 0, 0, self.causal_pad, 0))  # left-pad time
        x = self.conv3d(x)  # (1,D,T_out,H//p,W//p)
        x = x.squeeze(0).permute(1, 2, 3, 0)  # (T_out,H//p,W//p,D)
        tokens = self.out_norm(x).reshape(num_slots, self.dit_hidden_size)

        # A temporal output is valid when its window ends on a real frame.
        in_range = (output_time_idx >= 0) & (output_time_idx < window)
        clamped = output_time_idx.clamp(min=0, max=window - 1)
        temporal_mask = in_range & real_mask.index_select(0, clamped)
        n_out = output_time_idx.shape[0]
        mask = temporal_mask[:, None].expand(n_out, n_spatial).reshape(num_slots)
        return tokens, mask

    def forward(
        self,
        latents: torch.Tensor,
        frame_indices: torch.Tensor,
        pose: Optional[torch.Tensor],
        target_frame_indices: torch.Tensor,
        source_is_generated: Optional[torch.Tensor] = None,
        exclude_latest_local_frames: Optional[int] = None,
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        """Compress, per target frame, the latest causally valid source latents.

        Args:
            latents: (T_src, B, C, H, W) source latents.
            frame_indices: Frame index per source latent, indexed [t, b].
            pose: Accepted for interface compatibility; not used here.
            target_frame_indices: (T_tgt,) or (T_tgt, B) target frames.
            source_is_generated: Optional flags indexed like
                ``frame_indices``; only used for diagnostics.
            exclude_latest_local_frames: Overrides the module default
                when given. Sources must have a frame index below
                ``target - exclude_latest_local_frames``.

        Returns:
            ``(tokens, mask, diagnostics)`` with tokens of shape
            (B, T_tgt, num_slots, dit_hidden_size) and a bool mask of
            shape (B, T_tgt, num_slots). Targets without valid sources
            get zero tokens and an all-False mask.

        Raises:
            ValueError: On a wrong latent rank or H/W not divisible by
                the patch size.
        """
        if latents.ndim != 5:
            raise ValueError("latents must have shape (T_src,B,C,H,W)")
        if exclude_latest_local_frames is None:
            exclude_latest_local_frames = self.exclude_latest_local_frames
        else:
            exclude_latest_local_frames = int(exclude_latest_local_frames)

        _, B, _, H, W = latents.shape
        p = self.patch_size
        if H % p != 0 or W % p != 0:
            raise ValueError(f"latent H,W=({H},{W}) must be divisible by patch_size={p}")
        if target_frame_indices.ndim == 1:
            target_frame_indices = target_frame_indices[:, None].expand(-1, B)
        T_tgt = target_frame_indices.shape[0]
        device = latents.device
        generated_flags = None
        if source_is_generated is not None:
            generated_flags = source_is_generated.to(device=device, dtype=torch.bool)

        n_spatial = (H // p) * (W // p)
        T_out = self._temporal_output_count()
        num_slots = T_out * n_spatial
        output_time_idx = self._output_time_indices(device)

        # Per-(batch, target) statistics; -1 marks "no source selected".
        selected_source_count = torch.zeros((B, T_tgt), dtype=torch.long, device=device)
        max_source_frame = torch.full((B, T_tgt), -1, dtype=torch.long, device=device)
        generated_source_fraction = torch.zeros((B, T_tgt), dtype=torch.float32, device=device)
        min_gap = torch.full((B, T_tgt), -1, dtype=torch.long, device=device)
        max_gap = torch.full((B, T_tgt), -1, dtype=torch.long, device=device)

        token_rows, mask_rows = [], []
        for b in range(B):
            src_frames_b = frame_indices[:, b]
            row_tokens, row_masks = [], []
            for j in range(T_tgt):
                target = int(target_frame_indices[j, b].item())
                causal = src_frames_b < target - exclude_latest_local_frames
                valid_idx = causal.nonzero(as_tuple=False).flatten()
                if valid_idx.numel() == 0:
                    row_tokens.append(latents.new_zeros(num_slots, self.dit_hidden_size))
                    row_masks.append(torch.zeros(num_slots, device=device, dtype=torch.bool))
                    continue

                # Sort by frame index and keep the newest window.
                order = torch.argsort(src_frames_b.index_select(0, valid_idx))
                valid_idx = valid_idx.index_select(0, order)[-self.max_source_frames:]
                selected_frames = src_frames_b.index_select(0, valid_idx)

                selected_source_count[b, j] = int(selected_frames.numel())
                max_source_frame[b, j] = selected_frames.max()
                gaps = target - selected_frames
                min_gap[b, j] = gaps.min()
                max_gap[b, j] = gaps.max()
                if generated_flags is not None:
                    generated = generated_flags.index_select(0, valid_idx)[:, b]
                    generated_source_fraction[b, j] = generated.float().mean()

                tokens, mask = self._encode_window(latents[valid_idx, b], output_time_idx, n_spatial, num_slots)
                row_tokens.append(tokens)
                row_masks.append(mask)
            token_rows.append(torch.stack(row_tokens))
            mask_rows.append(torch.stack(row_masks))

        out_tokens = torch.stack(token_rows)
        out_mask = torch.stack(mask_rows)
        diagnostics = {
            "num_dynamic_slots": num_slots,
            "dynamic_T_out": T_out,
            "dynamic_n_spatial": n_spatial,
            "dynamic_temporal_left_pad": self.causal_pad,
            "dynamic_output_time_indices": output_time_idx,
            "selected_source_count": selected_source_count,
            "max_source_frame": max_source_frame,
            "generated_source_fraction": generated_source_fraction,
            "dynamic_min_gap_to_target_per_target": min_gap,
            "dynamic_max_gap_to_target_per_target": max_gap,
            "dynamic_exclude_latest_local_frames": exclude_latest_local_frames,
        }
        return out_tokens, out_mask, diagnostics
algorithms/worldmem/dememwm/diagnostics.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ import torch
6
+
7
+ from .schedules import EVAL_ABLATION_BRANCH_TO_ID, NOISE_BUCKETS, NOISE_BUCKET_TO_ID, normalize_eval_ablation_branch, normalize_noise_bucket
8
+
9
+
10
+ _REVISIT_LABEL_SOURCE = "deterministic_fov_coverage_plucker"
11
+
12
+
13
def tensor_valid_fraction(mask: torch.Tensor | None) -> float:
    """Return the fraction of truthy entries in ``mask``.

    A missing or empty mask counts as 0.0.
    """
    if mask is None:
        return 0.0
    if mask.numel() == 0:
        return 0.0
    as_bool = mask.detach().bool()
    return float(as_bool.float().mean().item())
17
+
18
+
19
def gate_stats(gate: torch.Tensor | float | int | None) -> dict[str, float]:
    """Summarise a gate value as its mean, minimum and maximum.

    Args:
        gate: A tensor, a plain number, or ``None``.

    Returns:
        Dict with float entries ``"mean"``, ``"min"`` and ``"max"``. All
        three are 0.0 for ``None`` and for an empty tensor; a plain
        number is reported as all three statistics.
    """
    if gate is None:
        return {"mean": 0.0, "min": 0.0, "max": 0.0}
    if not torch.is_tensor(gate):
        value = float(gate)
        return {"mean": value, "min": value, "max": value}
    g = gate.detach().float()
    if g.numel() == 0:
        # min()/max() raise on an empty tensor and mean() is NaN; report
        # zeros instead, the same way an absent gate is reported.
        return {"mean": 0.0, "min": 0.0, "max": 0.0}
    return {"mean": float(g.mean().item()), "min": float(g.min().item()), "max": float(g.max().item())}
27
+
28
+
29
def summarize_stream(name: str, tokens: torch.Tensor | None, mask: torch.Tensor | None, gate: torch.Tensor | float | None) -> dict[str, Any]:
    """Describe one memory stream under keys prefixed with ``name``.

    Reports the token tensor shape (``None`` when absent), the fraction
    and number of truthy mask entries, and the gate statistics.
    """
    shape = tuple(tokens.shape) if tokens is not None else None
    valid_fraction = tensor_valid_fraction(mask)
    valid_tokens = int(mask.detach().bool().sum().item()) if mask is not None else 0
    return {
        f"{name}_tokens_shape": shape,
        f"{name}_valid_fraction": valid_fraction,
        f"{name}_valid_tokens": valid_tokens,
        f"{name}_gate": gate_stats(gate),
    }
31
+
32
+
33
def assert_no_future_sources(target_frame: int, max_source_frame: int | torch.Tensor) -> None:
    """Check that every memory source precedes the target frame.

    Args:
        target_frame: Frame index being generated.
        max_source_frame: Largest source frame index, or a tensor of
            source frame indices (its maximum is used).

    Raises:
        AssertionError: If the largest source index is >= ``target_frame``.
    """
    if torch.is_tensor(max_source_frame):
        newest = int(max_source_frame.detach().max().item())
    else:
        newest = int(max_source_frame)
    if newest < int(target_frame):
        return
    raise AssertionError(f"DeMemWM memory source {newest} is not causal for target {target_frame}")
37
+
38
+
39
def _collect_values(result_diagnostics: list[dict[str, Any]], key: str) -> list[float]:
    """Flatten the sequences stored under ``key`` in every diagnostics dict.

    Dicts without the key, or with a falsy value under it, contribute
    nothing. Every collected item is converted to ``float``.
    """
    return [
        float(item)
        for diag in result_diagnostics
        for item in (diag.get(key, []) or [])
    ]
45
+
46
+
47
def _value_stats(values: list[float], prefix: str) -> dict[str, float]:
    """Return mean/min/max of ``values`` under ``<prefix>_mean|min|max``.

    An empty list yields 0.0 for all three entries.
    """
    if values:
        mean = float(sum(values) / len(values))
        lowest = float(min(values))
        highest = float(max(values))
    else:
        mean = lowest = highest = 0.0
    return {
        f"{prefix}_mean": mean,
        f"{prefix}_min": lowest,
        f"{prefix}_max": highest,
    }
55
+
56
+
57
def summarize_revisit_diagnostics(result_diagnostics: list[dict[str, Any]], valid_revisit_mask: torch.Tensor | None) -> dict[str, Any]:
    """Aggregate per-target revisit diagnostics into one flat summary dict.

    Each entry of ``result_diagnostics`` describes one target. Counters are
    read under their current key, falling back to older aliases when the
    current key is absent. Candidate, pose-preselect, exact-FOV and valid
    counts are averaged per target; label, valid-target, no-valid, selected
    and abstained counts are summed. Overlap value lists are reduced to
    mean/min/max statistics, and ``revisit_min_gap_to_target`` is the smallest
    non-negative reported gap (``-1`` when there is none).
    """
    target_count = len(result_diagnostics)

    def _first_present(diag: dict[str, Any], keys: tuple[str, ...]) -> Any:
        # Mirrors chained ``dict.get`` fallbacks: the first key present wins.
        for key in keys:
            if key in diag:
                return diag[key]
        return 0

    def _summed(*keys: str) -> int:
        return sum(int(_first_present(diag, keys)) for diag in result_diagnostics)

    def _per_target(total: int) -> float:
        return float(total / target_count) if target_count else 0.0

    candidate_mean = _per_target(
        _summed("revisit_candidate_frame_count", "revisit_candidate_count", "candidate_count")
    )
    valid_mean = _per_target(
        _summed("valid_revisit_frame_count", "valid_revisit_count", "valid_candidate_count")
    )
    selected_total = _summed("revisit_selected_frame_count", "revisit_selected_count", "selected_count")
    abstained_total = sum(
        int(diag.get("revisit_abstained_count", int(bool(diag.get("abstained", False)))))
        for diag in result_diagnostics
    )
    reported_gaps = [
        int(diag["revisit_min_gap_to_target"])
        for diag in result_diagnostics
        if int(diag.get("revisit_min_gap_to_target", -1)) >= 0
    ]

    summary: dict[str, Any] = {
        "revisit_candidate_frame_count": candidate_mean,
        "revisit_candidate_count": candidate_mean,
        "valid_candidate_label_count": int(_summed("valid_candidate_label_count", "valid_candidate_count")),
        "revisit_pose_preselect_input_count": _per_target(_summed("revisit_pose_preselect_input_count")),
        "revisit_pose_preselect_selected_count": _per_target(_summed("revisit_pose_preselect_selected_count")),
        "revisit_exact_fov_candidate_count": _per_target(_summed("revisit_exact_fov_candidate_count")),
        "valid_revisit_frame_count": valid_mean,
        "valid_revisit_count": valid_mean,
        "valid_revisit_target_count": int(_summed("valid_revisit_target_count", "high_quality_selected_revisit")),
        "no_valid_revisit_count": int(_summed("no_valid_revisit_count")),
        "valid_revisit_mask_fraction": tensor_valid_fraction(valid_revisit_mask),
        "revisit_selected_frame_count": int(selected_total),
        "revisit_selected_count": int(selected_total),
        "revisit_abstained_count": int(abstained_total),
        "revisit_min_gap_to_target": int(min(reported_gaps)) if reported_gaps else -1,
        "revisit_label_source": _REVISIT_LABEL_SOURCE,
    }

    # Per-frame FOV overlaps fall back to the older key when the new one is empty,
    # and are published under both the new and the old statistic prefix.
    frame_fov_values = (
        _collect_values(result_diagnostics, "frame_fov_overlap_values")
        or _collect_values(result_diagnostics, "fov_overlap_values")
    )
    for prefix in ("revisit_frame_fov_overlap", "revisit_fov_overlap"):
        summary.update(_value_stats(frame_fov_values, prefix))

    stat_sources = (
        ("plucker_overlap_values", "revisit_plucker_overlap"),
        ("best_selected_fov_overlap_values", "revisit_best_selected_fov_overlap"),
        ("best_selected_plucker_overlap_values", "revisit_best_selected_plucker_overlap"),
        ("best_selected_gap_frame_values", "revisit_best_selected_gap_frames"),
        ("best_selected_frame_fov_overlap_values", "revisit_best_selected_frame_fov_overlap"),
        ("selected_frame_fov_overlap_values", "revisit_selected_frame_fov_overlap"),
        ("selected_incremental_fov_overlap_values", "revisit_incremental_fov_overlap"),
    )
    for source_key, prefix in stat_sources:
        summary.update(_value_stats(_collect_values(result_diagnostics, source_key), prefix))
    return summary
103
+
104
+
105
def summarize_noise_bucket_diagnostics(
    *,
    noise_bucket: str | None,
    valid_revisit_mask: torch.Tensor | None,
    no_valid_revisit_mask: torch.Tensor | None,
    noise_bucket_ids: torch.Tensor | None = None,
) -> dict[str, Any]:
    """Report noise-bucket membership and per-bucket revisit counts.

    Args:
        noise_bucket: Bucket name, normalized with ``normalize_noise_bucket``.
        valid_revisit_mask: Per-target validity mask; flattened, and its
            length defines the number of targets (``None`` means no targets).
        no_valid_revisit_mask: Per-target "no valid revisit" mask; all-false
            when ``None``.
        noise_bucket_ids: Optional per-target bucket ids. When omitted, every
            target is assigned to ``noise_bucket``.

    Returns:
        A dict with the bucket name and id, one-hot ``noise_bucket_is_*``
        flags, target counts per bucket, and per-bucket counts of both masks.

    Raises:
        ValueError: If ``noise_bucket_ids`` does not have one entry per target.
    """
    bucket = normalize_noise_bucket(noise_bucket)
    bucket_id = int(NOISE_BUCKET_TO_ID[bucket])
    diagnostics: dict[str, Any] = {"noise_bucket": bucket, "noise_bucket_id": bucket_id}
    diagnostics.update(
        {f"noise_bucket_is_{candidate}": int(bucket == candidate) for candidate in NOISE_BUCKETS}
    )

    if valid_revisit_mask is None:
        valid = torch.zeros(0, dtype=torch.bool)
    else:
        valid = valid_revisit_mask.detach().bool().reshape(-1).cpu()
    if no_valid_revisit_mask is None:
        no_valid = torch.zeros_like(valid)
    else:
        no_valid = no_valid_revisit_mask.detach().bool().reshape(-1).cpu()

    target_count = int(valid.numel())
    diagnostics["noise_bucket_target_count"] = target_count

    if noise_bucket_ids is not None:
        target_bucket_ids = noise_bucket_ids.detach().long().reshape(-1).cpu()
        if int(target_bucket_ids.numel()) != target_count:
            raise ValueError(
                f"noise_bucket_ids has {int(target_bucket_ids.numel())} targets, expected {target_count}"
            )
    else:
        target_bucket_ids = torch.full((target_count,), bucket_id, dtype=torch.long)

    # One membership mask per bucket, shared by all count loops below.
    membership = [
        (bucket_name, target_bucket_ids == int(NOISE_BUCKET_TO_ID[bucket_name]))
        for bucket_name in NOISE_BUCKETS
    ]
    for bucket_name, bucket_mask in membership:
        diagnostics[f"noise_bucket_{bucket_name}_target_count"] = int(bucket_mask.long().sum().item())

    for mask_name, mask in (("valid_revisit", valid), ("no_valid_revisit", no_valid)):
        for bucket_name, bucket_mask in membership:
            if mask.numel():
                count = int((mask & bucket_mask).long().sum().item())
            else:
                count = 0
            diagnostics[f"{mask_name}_noise_bucket_{bucket_name}_count"] = count
    return diagnostics
147
+
148
+
149
def summarize_eval_ablation_diagnostics(
    *,
    enabled: bool,
    branch: str | None,
    valid_revisit_mask: torch.Tensor | None,
    no_valid_revisit_mask: torch.Tensor | None,
    eval_corrupted_revisit_mask: torch.Tensor | None,
) -> dict[str, Any]:
    """Summarize the evaluation-ablation state and its per-bucket target counts.

    Masks are flattened to one boolean per target; a missing mask counts as
    all-false (and a missing ``valid_revisit_mask`` as zero targets). A target
    is a "true revisit" when it is valid and not corrupted. Fractions divide
    by the number of targets, using ``1`` as the denominator when there are
    none.
    """
    branch_name = normalize_eval_ablation_branch(branch)

    def _flat(mask: torch.Tensor) -> torch.Tensor:
        return mask.detach().bool().reshape(-1).cpu()

    if valid_revisit_mask is None:
        valid = torch.zeros(0, dtype=torch.bool)
    else:
        valid = _flat(valid_revisit_mask)
    no_valid = torch.zeros_like(valid) if no_valid_revisit_mask is None else _flat(no_valid_revisit_mask)
    corrupted = torch.zeros_like(valid) if eval_corrupted_revisit_mask is None else _flat(eval_corrupted_revisit_mask)

    bucket_counts = {
        "true_revisit": int((valid & (~corrupted)).long().sum().item()),
        "no_valid_revisit": int(no_valid.long().sum().item()),
        "corrupted_memory": int(corrupted.long().sum().item()),
    }

    diagnostics: dict[str, Any] = {
        "eval_ablation_enabled": bool(enabled),
        "eval_ablation_branch": branch_name,
        "eval_ablation_branch_id": int(EVAL_ABLATION_BRANCH_TO_ID[branch_name]),
    }
    for label, count in bucket_counts.items():
        diagnostics[f"eval_bucket_{label}_count"] = count

    denominator = max(int(valid.numel()), 1)
    for label, count in bucket_counts.items():
        diagnostics[f"eval_bucket_{label}_fraction"] = float(count / denominator)
    return diagnostics
algorithms/worldmem/dememwm/gates.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
class RevisitRawGate(nn.Module):
    """Learned per-target revisit gate driven by selected-revisit quality features.

    A single linear layer maps three features (FOV overlap, Plucker overlap and
    a normalized log age) to a logit that is squashed with a sigmoid. The
    weights start at zero, so the initial gate equals ``sigmoid(init_logit)``
    for every target.

    Validity masking and stage/denoise scaling are left to the caller: inside
    this module ``valid_revisit_mask`` only provides the output shape and
    device, and revisit validity is never used as a training target here.
    """

    def __init__(self, init_logit: float = 1.0):
        super().__init__()
        self.net = nn.Linear(3, 1)
        nn.init.zeros_(self.net.weight)
        nn.init.constant_(self.net.bias, float(init_logit))

    def forward(
        self,
        *,
        valid_revisit_mask: torch.Tensor,
        best_selected_fov_overlap: torch.Tensor | None = None,
        best_selected_plucker_overlap: torch.Tensor | None = None,
        selected_gap_frames: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return a float32 gate in ``(0, 1)`` with the shape of ``valid_revisit_mask``.

        Missing features are treated as zeros; scalar tensors are broadcast to
        every target; other tensors must be expandable to ``(B, T)``.

        Raises:
            ValueError: If ``valid_revisit_mask`` is not two-dimensional.
        """
        if valid_revisit_mask.ndim != 2:
            raise ValueError("valid_revisit_mask must have shape (B,T)")
        target_shape = valid_revisit_mask.shape
        factory = {"device": valid_revisit_mask.device, "dtype": torch.float32}

        def _broadcast(value: torch.Tensor | None) -> torch.Tensor:
            if value is None:
                return torch.zeros(target_shape, **factory)
            as_float = value.to(**factory)
            if as_float.ndim == 0:
                return torch.full(target_shape, float(as_float.item()), **factory)
            return as_float.expand(target_shape)

        fov = _broadcast(best_selected_fov_overlap).clamp(min=0.0, max=1.0)
        plucker = _broadcast(best_selected_plucker_overlap).clamp(min=0.0, max=1.0)
        # log1p of the non-negative gap, capped at 8 and rescaled into [0, 1].
        gap = _broadcast(selected_gap_frames).clamp_min(0.0)
        log_age = torch.log1p(gap).clamp(max=8.0) / 8.0
        logits = self.net(torch.stack([fov, plucker, log_age], dim=-1))
        return torch.sigmoid(logits.squeeze(-1))
algorithms/worldmem/dememwm/injection.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass
5
+ from typing import Any
6
+
7
+ import torch
8
+
9
+ from .diagnostics import summarize_stream
10
+ from .types import MemoryStreamTensors
11
+
12
+
13
@dataclass
class InjectionAdapter:
    """Convert DeMemWM stream tensors to Diffusion/DiT Option-C kwargs."""

    # When true, a stream whose mask has no valid entry is passed as ``None``
    # tokens/mask instead of all-masked tensors; its gate entry is kept.
    omit_disabled: bool = False

    def _tokens(self, name: str, tokens: torch.Tensor, device, dtype) -> torch.Tensor:
        """Validate that ``tokens`` is 4-D and move it to ``device``/``dtype``."""
        if tokens.ndim == 4:
            return tokens.to(device=device, dtype=dtype)
        raise ValueError(f"{name} tokens must have shape (B,T,M,D), got {tuple(tokens.shape)}")

    def _mask(self, name: str, mask: torch.Tensor, tokens: torch.Tensor, device) -> torch.Tensor:
        """Validate ``mask`` against the first three token dims and cast it to bool."""
        if mask.shape == tokens.shape[:3]:
            return mask.to(device=device, dtype=torch.bool)
        raise ValueError(f"{name} mask must have shape {tuple(tokens.shape[:3])}, got {tuple(mask.shape)}")

    def _gate(self, gate: torch.Tensor | float | int, tokens: torch.Tensor, device, dtype):
        """Return ``gate`` as a tensor on ``device``/``dtype``; ``tokens`` is unused."""
        if not torch.is_tensor(gate):
            return torch.tensor(float(gate), device=device, dtype=dtype)
        return gate.to(device=device, dtype=dtype)

    def __call__(self, streams: MemoryStreamTensors, device=None, dtype=None) -> tuple[dict[str, Any], dict[str, Any]]:
        """Build the model kwargs and a diagnostics dict for ``streams``.

        ``device`` and ``dtype`` default to those of the anchor tokens.

        Returns:
            ``(kwargs, diagnostics)``. ``kwargs`` holds tokens, masks and gates
            for the anchor, dynamic and revisit streams. ``diagnostics`` is a
            copy of ``streams.diagnostics`` extended with per-stream
            summaries, raw-gate statistics, the no-valid-revisit mask and the
            overall ``max_source_frame`` when those inputs are available.
        """
        reference = streams.anchor_tokens
        device = device or reference.device
        dtype = dtype or reference.dtype

        # (stream name, raw tokens, raw mask, raw gate, tokens key, mask key, gate key)
        layout = (
            ("anchor", streams.anchor_tokens, streams.anchor_mask, streams.anchor_gate,
             "memory_tokens", "memory_token_mask", "memory_anchor_gate"),
            ("dynamic", streams.dynamic_tokens, streams.dynamic_mask, streams.dynamic_gate,
             "memory_dynamic_tokens", "memory_dynamic_mask", "memory_dynamic_gate"),
            ("revisit", streams.revisit_tokens, streams.revisit_mask, streams.revisit_gate,
             "memory_retrieval_tokens", "memory_retrieval_mask", "memory_retrieval_gate"),
        )

        # All tokens are validated before any mask, matching the order of raised errors.
        tokens = {
            name: self._tokens(name, raw_tokens, device, dtype)
            for name, raw_tokens, *_ in layout
        }
        masks = {
            name: self._mask(name, raw_mask, tokens[name], device)
            for name, _, raw_mask, *_ in layout
        }

        kwargs: dict[str, Any] = {}
        for name, _, _, _, tokens_key, mask_key, _ in layout:
            kwargs[tokens_key] = tokens[name]
            kwargs[mask_key] = masks[name]
        for name, _, _, raw_gate_value, _, _, gate_key in layout:
            kwargs[gate_key] = self._gate(raw_gate_value, tokens[name], device, dtype)

        if self.omit_disabled:
            for name, _, _, _, tokens_key, mask_key, _ in layout:
                if not masks[name].any():
                    kwargs[tokens_key] = None
                    kwargs[mask_key] = None

        diagnostics = dict(streams.diagnostics)
        for name, _, _, _, _, _, gate_key in layout:
            diagnostics.update(summarize_stream(name, tokens[name], masks[name], kwargs[gate_key]))

        if streams.revisit_gate_raw is not None:
            raw_gate = streams.revisit_gate_raw.to(device=device, dtype=dtype)
            diagnostics["revisit_gate_raw"] = raw_gate
            has_values = bool(raw_gate.numel())
            raw_values = raw_gate.detach().float()
            diagnostics["revisit_gate_raw_mean"] = float(raw_values.mean().item()) if has_values else 0.0
            diagnostics["revisit_gate_raw_min"] = float(raw_values.min().item()) if has_values else 0.0
            diagnostics["revisit_gate_raw_max"] = float(raw_values.max().item()) if has_values else 0.0

        if streams.no_valid_revisit_mask is not None:
            diagnostics["no_valid_revisit_mask"] = streams.no_valid_revisit_mask.to(device=device, dtype=torch.bool)

        # Collapse every "*max_source_frame" entry into a single overall maximum.
        max_sources = [
            value
            for key, value in streams.diagnostics.items()
            if key.endswith("max_source_frame")
        ]
        if max_sources:
            diagnostics["max_source_frame"] = max(
                int(torch.as_tensor(value).max().item()) for value in max_sources
            )
        return kwargs, diagnostics


DeMemWMInjectionAdapter = InjectionAdapter
algorithms/worldmem/dememwm/labels.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import Any, Optional
6
+
7
+ import torch
8
+
9
+ from .types import MemoryRecord
10
+
11
+
12
+ LABEL_SOURCE = "deterministic_fov_coverage_plucker"
13
+
14
+
15
@dataclass(frozen=True)
class RevisitCandidateLabel:
    """Immutable label describing one memory record as a revisit candidate."""

    record: MemoryRecord
    # True when no reject reason was recorded for this candidate.
    valid: bool
    # True when the record ends far enough before the target frame.
    gap_valid: bool
    # Target frame minus the record's newest source frame.
    gap_to_target: int
    fov_overlap: Optional[float]
    plucker_overlap: Optional[float]
    # Score used first when ranking candidates.
    primary_overlap: float
    coverage_mask: Optional[torch.Tensor]
    reject_reasons: tuple[str, ...]
    best_frame_index: Optional[int] = None
    best_frame_fov_overlap: Optional[float] = None

    @property
    def sort_key(self) -> tuple[float, float, float, int, int, str]:
        """Ascending sort key that ranks the most useful candidates first.

        Higher primary, FOV and Plucker overlaps sort earlier (missing
        overlaps count as ``0.0``); ties are broken by smaller gap, then
        earlier ``source_start``, then ``chunk_id`` as a string.
        """
        fov_rank = float(self.fov_overlap) if self.fov_overlap is not None else 0.0
        plucker_rank = float(self.plucker_overlap) if self.plucker_overlap is not None else 0.0
        record = self.record
        return (
            -self.primary_overlap,
            -fov_rank,
            -plucker_rank,
            self.gap_to_target,
            int(record.source_start),
            str(record.chunk_id or ""),
        )
41
+
42
+
43
+ def _as_float_tensor(value) -> Optional[torch.Tensor]:
44
+ if value is None:
45
+ return None
46
+ if torch.is_tensor(value):
47
+ if value.numel() == 0:
48
+ return None
49
+ return value.detach().float()
50
+ tensor = torch.as_tensor(value, dtype=torch.float32)
51
+ if tensor.numel() == 0:
52
+ return None
53
+ return tensor
54
+
55
+
56
def _pose_frames(value) -> Optional[torch.Tensor]:
    """Return poses as a ``(N, 5)`` float tensor, or ``None`` if unusable.

    The input is flattened to rows along its last dimension and only the first
    five columns are kept; the overlap helpers in this module read columns
    0-2 as the position, column 3 as pitch and column 4 as yaw (degrees).

    ``None`` is returned for missing or empty input, for a 0-dim scalar, and
    when the last dimension has fewer than five entries.
    """
    tensor = _as_float_tensor(value)
    # A 0-dim tensor has no last dimension: treat it as "no pose" rather than
    # letting ``shape[-1]`` raise an IndexError.
    if tensor is None or tensor.ndim == 0 or tensor.shape[-1] < 5:
        return None
    return tensor.reshape(-1, tensor.shape[-1])[:, :5]
61
+
62
+
63
+ def _angle_diff_degrees(value: torch.Tensor) -> torch.Tensor:
64
+ diff = value.abs() % 360.0
65
+ return torch.where(diff > 180.0, 360.0 - diff, diff)
66
+
67
+
68
+ def _target_fov_points(
69
+ target_pose: torch.Tensor,
70
+ *,
71
+ fov_half_h: float,
72
+ fov_half_v: float,
73
+ yaw_samples: int,
74
+ pitch_samples: int,
75
+ depth_samples: int,
76
+ radius: float,
77
+ ) -> torch.Tensor:
78
+ yaw_samples = max(1, int(yaw_samples))
79
+ pitch_samples = max(1, int(pitch_samples))
80
+ depth_samples = max(1, int(depth_samples))
81
+ device = target_pose.device
82
+ dtype = target_pose.dtype
83
+ if yaw_samples == 1:
84
+ yaw_offsets = torch.zeros((1,), device=device, dtype=dtype)
85
+ else:
86
+ yaw_offsets = torch.linspace(-float(fov_half_h), float(fov_half_h), yaw_samples + 2, device=device, dtype=dtype)[1:-1]
87
+ if pitch_samples == 1:
88
+ pitch_offsets = torch.zeros((1,), device=device, dtype=dtype)
89
+ else:
90
+ pitch_offsets = torch.linspace(-float(fov_half_v), float(fov_half_v), pitch_samples + 2, device=device, dtype=dtype)[1:-1]
91
+ if depth_samples == 1:
92
+ depths = torch.full((1,), float(radius), device=device, dtype=dtype)
93
+ else:
94
+ depths = torch.linspace(float(radius) / float(depth_samples), float(radius), depth_samples, device=device, dtype=dtype)
95
+ depth_grid, pitch_grid, yaw_grid = torch.meshgrid(depths, pitch_offsets, yaw_offsets, indexing="ij")
96
+ pitch = torch.deg2rad(target_pose[3] + pitch_grid.reshape(-1))
97
+ yaw = torch.deg2rad(target_pose[4] + yaw_grid.reshape(-1))
98
+ depth = depth_grid.reshape(-1)
99
+ cos_pitch = torch.cos(pitch)
100
+ vectors = torch.stack(
101
+ [
102
+ depth * cos_pitch * torch.sin(yaw),
103
+ depth * torch.sin(pitch),
104
+ depth * cos_pitch * torch.cos(yaw),
105
+ ],
106
+ dim=-1,
107
+ )
108
+ return target_pose[:3].reshape(1, 3) + vectors
109
+
110
+
111
def _inside_fov_3d_hv(
    points: torch.Tensor,
    poses: torch.Tensor,
    *,
    fov_half_h: float,
    fov_half_v: float,
) -> torch.Tensor:
    """Test which points fall inside each pose's horizontal and vertical FOV.

    Args:
        points: ``(P, 3)`` world-space points.
        poses: ``(N, 5)`` poses read as ``[x, y, z, pitch_deg, yaw_deg]``.
        fov_half_h: Half horizontal field of view in degrees.
        fov_half_v: Half vertical field of view in degrees.

    Returns:
        Boolean ``(N, P)`` tensor that is true where both the azimuth and the
        elevation of a point, relative to the pose's yaw and pitch, are
        strictly inside the respective half angles.
    """
    rad_to_deg = 180.0 / math.pi
    offsets = points.unsqueeze(0) - poses[:, None, :3]
    dx, dy, dz = offsets[..., 0], offsets[..., 1], offsets[..., 2]

    azimuth = torch.atan2(dx, dz) * rad_to_deg
    # Clamp the horizontal distance so a point straight above/below stays finite.
    horizontal = torch.sqrt(dx.square() + dz.square()).clamp_min(1e-8)
    elevation = torch.atan2(dy, horizontal) * rad_to_deg

    within_horizontal = _angle_diff_degrees(azimuth - poses[:, None, 4]) < float(fov_half_h)
    within_vertical = _angle_diff_degrees(elevation - poses[:, None, 3]) < float(fov_half_v)
    return within_horizontal & within_vertical
127
+
128
+
129
def fov_coverage_overlap(
    source_pose,
    target_pose,
    *,
    fov_half_h: float = 105.0 / 2.0,
    fov_half_v: float = 75.0 / 2.0,
    yaw_samples: int = 25,
    pitch_samples: int = 20,
    depth_samples: int = 20,
    radius: float = 30.0,
) -> tuple[Optional[float], Optional[torch.Tensor]]:
    """Measure how much of the target's view frustum the source poses cover.

    Points are sampled inside the field of view of the last target pose; a
    point counts as covered when at least one source pose sees it.

    Returns:
        ``(fraction, mask)``: the covered fraction of sampled points and the
        detached per-point boolean mask, or ``(None, None)`` when either pose
        input is unusable.
    """
    sources = _pose_frames(source_pose)
    targets = _pose_frames(target_pose)
    if sources is None or targets is None:
        return None, None

    last_target = targets[-1].to(device=sources.device, dtype=sources.dtype)
    sample_points = _target_fov_points(
        last_target,
        fov_half_h=fov_half_h,
        fov_half_v=fov_half_v,
        yaw_samples=yaw_samples,
        pitch_samples=pitch_samples,
        depth_samples=depth_samples,
        radius=radius,
    )
    seen_by_source = _inside_fov_3d_hv(sample_points, sources, fov_half_h=fov_half_h, fov_half_v=fov_half_v)
    covered = seen_by_source.any(dim=0)
    fraction = float(covered.float().mean().item())
    return fraction, covered.detach()
156
+
157
+
158
+ def _rotation_from_pose(poses: torch.Tensor) -> torch.Tensor:
159
+ pitch = torch.deg2rad(poses[:, 3])
160
+ yaw = torch.deg2rad(poses[:, 4])
161
+ cos_pitch, sin_pitch = torch.cos(pitch), torch.sin(pitch)
162
+ cos_yaw, sin_yaw = torch.cos(yaw), torch.sin(yaw)
163
+ zeros = torch.zeros_like(pitch)
164
+ ones = torch.ones_like(pitch)
165
+ r_pitch = torch.stack(
166
+ [
167
+ ones, zeros, zeros,
168
+ zeros, cos_pitch, -sin_pitch,
169
+ zeros, sin_pitch, cos_pitch,
170
+ ],
171
+ dim=-1,
172
+ ).reshape(-1, 3, 3)
173
+ r_yaw = torch.stack(
174
+ [
175
+ cos_yaw, zeros, sin_yaw,
176
+ zeros, ones, zeros,
177
+ -sin_yaw, zeros, cos_yaw,
178
+ ],
179
+ dim=-1,
180
+ ).reshape(-1, 3, 3)
181
+ return torch.matmul(r_yaw, r_pitch)
182
+
183
+
184
def _plucker_descriptor(
    poses: torch.Tensor,
    *,
    grid_h: int,
    grid_w: int,
    focal_length: float,
) -> torch.Tensor:
    """Describe each pose by the Plucker coordinates of a small grid of rays.

    A ``grid_h x grid_w`` pixel grid (each at least 1) is turned into unit ray
    directions using focal lengths of ``focal_length`` times the grid size and
    a principal point at the grid centre. The rays are rotated by the pose
    rotation and combined with the pose position into ``(moment, direction)``
    pairs, where ``moment = origin x direction``.

    Returns:
        Float32 tensor of shape ``(N, grid_h * grid_w * 6)``.
    """
    rows = max(1, int(grid_h))
    cols = max(1, int(grid_w))
    poses = poses.float()
    device = poses.device
    dtype = torch.float32

    pixel_rows = torch.linspace(0, rows - 1, rows, device=device, dtype=dtype)
    pixel_cols = torch.linspace(0, cols - 1, cols, device=device, dtype=dtype)
    ys, xs = torch.meshgrid(pixel_rows, pixel_cols, indexing="ij")

    fx = float(focal_length) * float(cols)
    fy = float(focal_length) * float(rows)
    cx = 0.5 * float(cols)
    cy = 0.5 * float(rows)

    # Camera-space directions through pixel centres, normalized to unit length.
    camera_dirs = torch.stack(
        [-(xs + 0.5 - cx) / fx, -(ys + 0.5 - cy) / fy, torch.ones_like(xs)],
        dim=-1,
    ).reshape(-1, 3)
    camera_dirs = camera_dirs / camera_dirs.norm(dim=-1, keepdim=True).clamp_min(1e-8)

    rotation = _rotation_from_pose(poses)
    rays_d = torch.matmul(camera_dirs.unsqueeze(0), rotation.transpose(1, 2)).float()
    rays_o = poses[:, None, :3].expand_as(rays_d).float()
    moments = torch.linalg.cross(rays_o, rays_d, dim=-1)
    descriptor = torch.cat([moments, rays_d], dim=-1)
    return descriptor.reshape(poses.shape[0], -1)
214
+
215
+
216
def plucker_overlap(
    source_pose,
    target_pose,
    *,
    grid_h: int = 4,
    grid_w: int = 4,
    focal_length: float = 0.35,
) -> Optional[float]:
    """Score how close the source poses are to the last target pose.

    The score is ``1 / (1 + d)``, where ``d`` is the smallest per-dimension
    normalized distance between the target's Plucker descriptor and any
    source frame's descriptor; identical poses therefore score ``1.0``.

    Returns:
        The similarity score, or ``None`` when either pose input is unusable.
    """
    sources = _pose_frames(source_pose)
    targets = _pose_frames(target_pose)
    if sources is None or targets is None:
        return None

    last_target = targets[-1:].to(device=sources.device, dtype=sources.dtype)
    descriptor_args = {"grid_h": grid_h, "grid_w": grid_w, "focal_length": focal_length}
    source_desc = _plucker_descriptor(sources, **descriptor_args)
    target_desc = _plucker_descriptor(last_target, **descriptor_args)

    delta = source_desc - target_desc
    per_frame_distance = torch.linalg.vector_norm(delta, dim=-1) / math.sqrt(float(delta.shape[-1]))
    closest = float(per_frame_distance.min().item())
    return float(1.0 / (1.0 + max(closest, 0.0)))
235
+
236
+
237
+ def _passes_threshold(overlap: Optional[float], threshold: Optional[float]) -> bool:
238
+ if threshold is None or overlap is None:
239
+ return True
240
+ return float(overlap) >= float(threshold)
241
+
242
+
243
def _batched_pose_overlaps(
    records: list[MemoryRecord],
    *,
    target_pose=None,
    fov_half_h: float = 105.0 / 2.0,
    fov_half_v: float = 75.0 / 2.0,
    fov_yaw_samples: int = 25,
    fov_pitch_samples: int = 20,
    fov_depth_samples: int = 20,
    fov_radius: float = 30.0,
    plucker_grid_h: int = 4,
    plucker_grid_w: int = 4,
    plucker_focal_length: float = 0.35,
) -> tuple[
    list[Optional[float]],
    list[Optional[float]],
    list[Optional[torch.Tensor]],
    list[Optional[int]],
    list[Optional[float]],
]:
    """Score all records against the last target pose in one batched pass.

    The poses of every record are concatenated so that the FOV test and the
    Plucker descriptors are computed once for all frames, then reduced back to
    one result per record.

    Returns:
        Five lists aligned with ``records``: best per-frame FOV overlap,
        Plucker similarity, coverage mask of the best frame, source frame
        index of the best frame, and that frame's FOV overlap. Entries stay
        ``None`` for records without a usable pose (all entries when the target
        pose is unusable). The last two are only filled when the record's
        ``frame_indices`` has one entry per pose; ties on the score go to the
        largest frame index.
    """
    record_count = len(records)
    fov_overlaps: list[Optional[float]] = [None] * record_count
    plucker_overlaps: list[Optional[float]] = [None] * record_count
    coverage_masks: list[Optional[torch.Tensor]] = [None] * record_count
    best_frame_indices: list[Optional[int]] = [None] * record_count
    best_frame_fov_overlaps: list[Optional[float]] = [None] * record_count
    outputs = (fov_overlaps, plucker_overlaps, coverage_masks, best_frame_indices, best_frame_fov_overlaps)

    target_poses = _pose_frames(target_pose)
    if target_poses is None or not records:
        return outputs

    # Each entry: (index into ``records``, poses, per-pose frame ids or None).
    entries: list[tuple[int, torch.Tensor, torch.Tensor | None]] = []
    device = None
    for record_idx, record in enumerate(records):
        poses = _pose_frames(record.pose)
        if poses is None:
            continue
        if device is None:
            # Everything is gathered on the device of the first usable record.
            device = poses.device
        frame_ids = record.frame_indices.detach().reshape(-1)
        if int(frame_ids.numel()) != int(poses.shape[0]):
            # Frame ids that do not line up with the poses cannot be trusted.
            frame_ids = None
        entries.append((record_idx, poses, frame_ids))
    if device is None or not entries:
        return outputs

    pose_blocks = [poses.to(device=device, dtype=torch.float32) for _, poses, _ in entries]
    source_poses = torch.cat(pose_blocks, dim=0)
    record_ids = torch.cat(
        [
            torch.full((block.shape[0],), int(record_idx), device=device, dtype=torch.long)
            for (record_idx, _, _), block in zip(entries, pose_blocks)
        ],
        dim=0,
    )
    frame_blocks = []
    for (_, _, frame_ids), block in zip(entries, pose_blocks):
        if frame_ids is None:
            # -1 marks rows without a known source frame index.
            frame_blocks.append(torch.full((block.shape[0],), -1, device=device, dtype=torch.long))
        else:
            frame_blocks.append(frame_ids.to(device=device, dtype=torch.long))
    source_frame_values = torch.cat(frame_blocks, dim=0)
    target = target_poses[-1].to(device=device, dtype=source_poses.dtype)

    points = _target_fov_points(
        target,
        fov_half_h=fov_half_h,
        fov_half_v=fov_half_v,
        yaw_samples=fov_yaw_samples,
        pitch_samples=fov_pitch_samples,
        depth_samples=fov_depth_samples,
        radius=fov_radius,
    )
    inside = _inside_fov_3d_hv(points, source_poses, fov_half_h=fov_half_h, fov_half_v=fov_half_v)
    per_frame_fov = inside.float().mean(dim=1)

    plucker_args = {
        "grid_h": plucker_grid_h,
        "grid_w": plucker_grid_w,
        "focal_length": plucker_focal_length,
    }
    source_desc = _plucker_descriptor(source_poses, **plucker_args)
    target_desc = _plucker_descriptor(target.reshape(1, -1), **plucker_args)
    delta = source_desc - target_desc
    distance = torch.linalg.vector_norm(delta, dim=-1) / math.sqrt(float(delta.shape[-1]))
    # Minimum distance per record; records without poses keep +inf (similarity 0).
    best_distance = torch.full((record_count,), float("inf"), device=device, dtype=distance.dtype)
    best_distance.scatter_reduce_(0, record_ids, distance, reduce="amin", include_self=True)
    plucker_values = (1.0 / (1.0 + best_distance.clamp_min(0.0))).detach().cpu().tolist()

    for record_idx, _, _ in entries:
        rows = record_ids == int(record_idx)
        if not rows.any():
            continue
        record_fov = per_frame_fov[rows]
        best_row = int(torch.argmax(record_fov).item())
        fov_overlaps[record_idx] = float(record_fov[best_row].item())
        plucker_overlaps[record_idx] = float(plucker_values[record_idx])
        coverage_masks[record_idx] = inside[rows][best_row].detach()

        framed_rows = rows & (source_frame_values >= 0)
        if not framed_rows.any():
            continue
        scores = per_frame_fov[framed_rows].detach().cpu().tolist()
        frames = source_frame_values[framed_rows].detach().cpu().tolist()
        # Tuple ordering picks the highest score, then the largest frame index.
        best_score, best_frame = max((float(score), int(frame)) for score, frame in zip(scores, frames))
        best_frame_indices[record_idx] = int(best_frame)
        best_frame_fov_overlaps[record_idx] = float(best_score)
    return outputs
352
+
353
+
354
def make_revisit_candidate_labels(
    records: list[MemoryRecord],
    *,
    target_frame: int,
    exclude_local_context_frames: int,
    target_pose=None,
    fov_overlap_threshold: Optional[float] = 0.0,
    plucker_weight: float = 0.1,
    target_video_id: Any = None,
    fov_half_h: float = 105.0 / 2.0,
    fov_half_v: float = 75.0 / 2.0,
    fov_yaw_samples: int = 25,
    fov_pitch_samples: int = 20,
    fov_depth_samples: int = 20,
    fov_radius: float = 30.0,
    plucker_grid_h: int = 4,
    plucker_grid_w: int = 4,
    plucker_focal_length: float = 0.35,
) -> list[RevisitCandidateLabel]:
    """Label every memory record as a valid or rejected revisit candidate.

    A record passes the gap test when its newest source frame
    (``source_end - 1``) lies before
    ``target_frame - exclude_local_context_frames``; only those records are
    pose-scored. A record is valid when it passes the gap test and its FOV
    overlap is not below ``fov_overlap_threshold`` (a missing overlap or
    threshold does not reject).

    ``plucker_weight`` and ``target_video_id`` are accepted but ignored.

    Returns:
        One ``RevisitCandidateLabel`` per record, in input order. Rejections
        carry ``"inside_c_short"`` and/or ``"fov_overlap_below_threshold"``.
    """
    del plucker_weight, target_video_id
    target_frame = int(target_frame)
    newest_allowed = target_frame - int(exclude_local_context_frames)

    last_source_frames = [int(record.source_end) - 1 for record in records]
    gap_flags = [last_frame < newest_allowed for last_frame in last_source_frames]
    score_indices = [record_idx for record_idx, gap_ok in enumerate(gap_flags) if gap_ok]

    scored_columns = _batched_pose_overlaps(
        [records[record_idx] for record_idx in score_indices],
        target_pose=target_pose,
        fov_half_h=fov_half_h,
        fov_half_v=fov_half_v,
        fov_yaw_samples=fov_yaw_samples,
        fov_pitch_samples=fov_pitch_samples,
        fov_depth_samples=fov_depth_samples,
        fov_radius=fov_radius,
        plucker_grid_h=plucker_grid_h,
        plucker_grid_w=plucker_grid_w,
        plucker_focal_length=plucker_focal_length,
    )
    # Map scores back to positions in ``records``; unscored records get Nones.
    scores_by_record = {
        record_idx: tuple(column[position] for column in scored_columns)
        for position, record_idx in enumerate(score_indices)
    }
    unscored = (None, None, None, None, None)

    labels: list[RevisitCandidateLabel] = []
    for record_idx, record in enumerate(records):
        fov_overlap, plucker_score, coverage_mask, best_frame, best_frame_fov = scores_by_record.get(
            record_idx, unscored
        )
        gap_ok = gap_flags[record_idx]

        reasons: list[str] = []
        if not gap_ok:
            reasons.append("inside_c_short")
        if not _passes_threshold(fov_overlap, fov_overlap_threshold):
            reasons.append("fov_overlap_below_threshold")

        labels.append(
            RevisitCandidateLabel(
                record=record,
                valid=not reasons,
                gap_valid=gap_ok,
                gap_to_target=target_frame - last_source_frames[record_idx],
                fov_overlap=fov_overlap,
                plucker_overlap=plucker_score,
                primary_overlap=float(fov_overlap) if fov_overlap is not None else 0.0,
                coverage_mask=coverage_mask,
                reject_reasons=tuple(reasons),
                best_frame_index=best_frame,
                best_frame_fov_overlap=best_frame_fov,
            )
        )
    return labels
441
+
442
+
443
def make_revisit_candidate_label(
    record: MemoryRecord,
    *,
    target_frame: int,
    exclude_local_context_frames: int,
    target_pose=None,
    fov_overlap_threshold: Optional[float] = 0.0,
    plucker_weight: float = 0.1,
    target_video_id: Any = None,
    fov_half_h: float = 105.0 / 2.0,
    fov_half_v: float = 75.0 / 2.0,
    fov_yaw_samples: int = 25,
    fov_pitch_samples: int = 20,
    fov_depth_samples: int = 20,
    fov_radius: float = 30.0,
    plucker_grid_h: int = 4,
    plucker_grid_w: int = 4,
    plucker_focal_length: float = 0.35,
) -> RevisitCandidateLabel:
    """Label a single record by delegating to ``make_revisit_candidate_labels``."""
    options = {
        "target_frame": target_frame,
        "exclude_local_context_frames": exclude_local_context_frames,
        "target_pose": target_pose,
        "fov_overlap_threshold": fov_overlap_threshold,
        "plucker_weight": plucker_weight,
        "target_video_id": target_video_id,
        "fov_half_h": fov_half_h,
        "fov_half_v": fov_half_v,
        "fov_yaw_samples": fov_yaw_samples,
        "fov_pitch_samples": fov_pitch_samples,
        "fov_depth_samples": fov_depth_samples,
        "fov_radius": fov_radius,
        "plucker_grid_h": plucker_grid_h,
        "plucker_grid_w": plucker_grid_w,
        "plucker_focal_length": plucker_focal_length,
    }
    (label,) = make_revisit_candidate_labels([record], **options)
    return label
algorithms/worldmem/dememwm/memory.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass
5
+ from typing import Iterable, Optional
6
+
7
+ import torch
8
+
9
+ from .types import MemoryRecord, MemorySourceType
10
+
11
+
12
@dataclass
class MemoryBankQuery:
    """Filter and budget options for ``CausalMemoryBank.query``."""

    # Frame being queried for; records whose ``source_end`` exceeds it are skipped.
    target_frame: int
    # When set, only records with exactly this source type are returned.
    source_type: MemorySourceType | None = None
    # When False, records flagged ``is_generated`` are skipped.
    include_generated: bool = True
    # Upper bound on the number of returned records (None = unlimited).
    max_records: int | None = None
    # Budget on the summed ``valid_slots`` of returned records (None = unlimited).
    max_slots: int | None = None
19
+
20
+
21
class CausalMemoryBank:
    """Small causal memory bank for DeMemWM records.

    Records are kept in insertion order. When ``max_records`` is set, the
    oldest records are dropped once the bank grows past that size. ``query``
    only returns records whose ``source_end`` does not exceed the target frame.
    """

    def __init__(self, max_records: Optional[int] = None, max_slots: Optional[int] = None):
        self.max_records = max_records
        # Stored as given; no method of this class reads it.
        self.max_slots = max_slots
        self._records: list[MemoryRecord] = []

    def __len__(self) -> int:
        return len(self._records)

    @property
    def records(self) -> tuple[MemoryRecord, ...]:
        """Immutable snapshot of the stored records, oldest first."""
        return tuple(self._records)

    def add_record(self, record: MemoryRecord) -> None:
        """Append ``record`` and evict the oldest entries beyond ``max_records``.

        Raises:
            ValueError: if the record is flagged as generated but claims the
                ``PREFIX_GT`` source type.
        """
        if record.source_type == MemorySourceType.PREFIX_GT and record.is_generated:
            raise ValueError("generated records cannot be high-trust prefix anchors")
        self._records.append(record)
        if self.max_records is not None:
            # Trim by overflow count rather than with ``[-max_records:]``: a
            # capacity of 0 would turn that slice into ``[-0:]``, which keeps
            # the whole list. Non-positive capacities are treated as 0.
            overflow = len(self._records) - max(0, self.max_records)
            if overflow > 0:
                self._records = self._records[overflow:]

    def add_prefix_anchors(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor,
        frame_indices: torch.Tensor,
        pose: Optional[torch.Tensor] = None,
        slots_per_anchor: Optional[int] = None,
    ) -> None:
        """Add one non-generated ``PREFIX_GT`` record per frame index.

        A 2-D ``tokens`` / 1-D ``mask`` is treated as a single anchor. When
        ``slots_per_anchor`` is given, each anchor keeps only its leading
        slots.

        Raises:
            ValueError: if the number of token rows differs from the number of
                frame indices.
        """
        if tokens.ndim == 2:
            tokens = tokens.unsqueeze(0)
        if mask.ndim == 1:
            mask = mask.unsqueeze(0)
        flat_frames = frame_indices.detach().reshape(-1)
        if tokens.shape[0] != flat_frames.numel():
            raise ValueError("tokens first dimension must match number of frame indices")
        for i, frame in enumerate(flat_frames.tolist()):
            rec_tokens = tokens[i]
            rec_mask = mask[i].bool()
            if slots_per_anchor is not None:
                rec_tokens = rec_tokens[:slots_per_anchor]
                rec_mask = rec_mask[:slots_per_anchor]
            self.add_record(
                MemoryRecord(
                    tokens=rec_tokens,
                    mask=rec_mask,
                    source_start=int(frame),
                    source_end=int(frame) + 1,
                    frame_indices=torch.as_tensor([frame], device=rec_tokens.device),
                    pose=None if pose is None else pose[i],
                    source_type=MemorySourceType.PREFIX_GT,
                    is_generated=False,
                    chunk_id=f"prefix_{int(frame)}",
                )
            )

    def add_chunk_record(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor,
        frame_indices: torch.Tensor,
        pose: Optional[torch.Tensor] = None,
        source_type: MemorySourceType = MemorySourceType.PREFIX_GT,
        is_generated: bool = False,
        chunk_id: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> None:
        """Add a single record summarising a multi-frame chunk.

        The record spans ``[min(frame_indices), max(frame_indices) + 1)``. A
        default ``chunk_id`` is derived from the source type and that span.

        Raises:
            ValueError: if ``frame_indices`` is empty, ``tokens`` is not 2-D,
                or ``mask`` is not a 1-D tensor with one entry per token row.
        """
        flat_frames = frame_indices.detach().reshape(-1)
        if flat_frames.numel() == 0:
            raise ValueError("chunk frame_indices must be non-empty")
        if tokens.ndim != 2:
            raise ValueError("chunk tokens must have shape (M,D)")
        if mask.ndim != 1 or mask.shape[0] != tokens.shape[0]:
            raise ValueError("chunk mask must have shape (M,)")
        start = int(flat_frames.min().item())
        end = int(flat_frames.max().item()) + 1
        self.add_record(
            MemoryRecord(
                tokens=tokens,
                mask=mask.bool(),
                source_start=start,
                source_end=end,
                frame_indices=flat_frames.to(device=tokens.device),
                pose=pose,
                source_type=source_type,
                is_generated=bool(is_generated),
                chunk_id=chunk_id or f"{source_type.value}_chunk_{start}_{end}",
                metadata=dict(metadata or {}),
            )
        )

    def add_frame_record(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor,
        frame_index: torch.Tensor | int,
        pose: Optional[torch.Tensor] = None,
        source_type: MemorySourceType = MemorySourceType.REVISIT,
        is_generated: bool = False,
        record_id: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> None:
        """Add a record covering exactly one frame.

        ``frame_index`` may be an int or a tensor; only its first element is
        used. ``record_id`` (or a default built from the source type and the
        frame) is stored as the record's ``chunk_id``.
        """
        frame_tensor = torch.as_tensor([int(torch.as_tensor(frame_index).reshape(-1)[0].item())], device=tokens.device)
        frame = int(frame_tensor.item())
        self.add_record(
            MemoryRecord(
                tokens=tokens,
                mask=mask.bool(),
                source_start=frame,
                source_end=frame + 1,
                frame_indices=frame_tensor,
                pose=pose,
                source_type=source_type,
                is_generated=bool(is_generated),
                chunk_id=record_id or f"{source_type.value}_frame_{frame}",
                metadata=dict(metadata or {}),
            )
        )

    def add_generated_records(
        self,
        tokens: torch.Tensor,
        mask: torch.Tensor,
        frame_indices: torch.Tensor,
        pose: Optional[torch.Tensor] = None,
        source_type: MemorySourceType = MemorySourceType.GENERATED,
    ) -> None:
        """Add one ``is_generated=True`` record per frame index.

        A 2-D ``tokens`` / 1-D ``mask`` is treated as a single frame.

        Raises:
            ValueError: if ``source_type`` is ``PREFIX_GT``, or if ``tokens`` or
                ``mask`` has fewer rows than there are frame indices.
        """
        if source_type == MemorySourceType.PREFIX_GT:
            raise ValueError("generated frames cannot be added as PREFIX_GT anchors by default")
        if tokens.ndim == 2:
            tokens = tokens.unsqueeze(0)
        if mask.ndim == 1:
            mask = mask.unsqueeze(0)
        flat_frames = frame_indices.detach().reshape(-1)
        # Validate before inserting anything so a short batch cannot leave the
        # bank partially updated. Extra rows are still tolerated and ignored.
        num_frames = flat_frames.numel()
        if tokens.shape[0] < num_frames or mask.shape[0] < num_frames:
            raise ValueError("tokens and mask must provide one entry per generated frame index")
        for i, frame in enumerate(flat_frames.tolist()):
            self.add_record(
                MemoryRecord(
                    tokens=tokens[i],
                    mask=mask[i].bool(),
                    source_start=int(frame),
                    source_end=int(frame) + 1,
                    frame_indices=torch.as_tensor([frame], device=tokens.device),
                    pose=None if pose is None else pose[i],
                    source_type=source_type,
                    is_generated=True,
                    chunk_id=f"generated_{int(frame)}",
                )
            )

    def query(self, query: MemoryBankQuery | int, **kwargs) -> list[MemoryRecord]:
        """Return stored records usable for the queried target frame.

        ``query`` is either a ``MemoryBankQuery`` or a bare target frame, in
        which case ``kwargs`` supply the remaining query fields (``kwargs`` are
        ignored when a ``MemoryBankQuery`` is passed). Records are scanned in
        insertion order and the scan stops once ``max_records`` or the
        ``max_slots`` budget is reached. The record that crosses the slot
        budget is still included.
        """
        if isinstance(query, int):
            query = MemoryBankQuery(target_frame=query, **kwargs)
        out: list[MemoryRecord] = []
        used_slots = 0
        for record in self._records:
            # Causality: skip anything that extends past the target frame.
            if int(record.source_end) > int(query.target_frame):
                continue
            if query.source_type is not None and record.source_type != query.source_type:
                continue
            if not query.include_generated and record.is_generated:
                continue
            # Catches a non-positive slot budget before anything is returned.
            if query.max_slots is not None and used_slots >= query.max_slots:
                break
            out.append(record)
            if query.max_slots is not None:
                used_slots += record.valid_slots
            if query.max_records is not None and len(out) >= query.max_records:
                break
            if query.max_slots is not None and used_slots >= query.max_slots:
                break
        return out

    def assert_causal(self, target_frame: int, records: Iterable[MemoryRecord]) -> None:
        """Raise ``AssertionError`` if any record ends after ``target_frame``.

        Offenders are reported by ``chunk_id`` when available, otherwise by
        their ``[source_start, source_end)`` range.
        """
        offenders = [r.chunk_id or f"[{r.source_start},{r.source_end})" for r in records if int(r.source_end) > int(target_frame)]
        if offenders:
            raise AssertionError(f"future/non-causal memory selected for target {target_frame}: {offenders}")
198
+
199
+
200
def stack_record_tokens(records: list[MemoryRecord], max_slots: int | None = None):
    """Concatenate the tokens and masks of ``records`` along dimension 0.

    Returns ``(None, None)`` when ``records`` is empty. Otherwise returns the
    concatenated tokens and the concatenated boolean mask, both truncated to
    the first ``max_slots`` rows when ``max_slots`` is given.
    """
    if not records:
        return None, None

    token_parts = []
    mask_parts = []
    for rec in records:
        token_parts.append(rec.tokens)
        mask_parts.append(rec.mask.bool())

    stacked_tokens = torch.cat(token_parts, dim=0)
    stacked_mask = torch.cat(mask_parts, dim=0)
    if max_slots is None:
        return stacked_tokens, stacked_mask
    return stacked_tokens[:max_slots], stacked_mask[:max_slots]
algorithms/worldmem/dememwm/negatives.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import torch
4
+
5
+ from .schedules import EVAL_CORRUPTION_BRANCHES
6
+
7
+
8
def _deterministic_noise_like(tokens: torch.Tensor, seed: int) -> torch.Tensor:
    """Return reproducible pseudo-noise with the shape, dtype and device of ``tokens``.

    The pattern is ``sin(position + seed * 0.137)`` over the flattened element
    positions, scaled by the standard deviation of ``tokens`` with a floor of
    0.05. The same ``tokens`` shape and ``seed`` always give the same pattern.
    """
    flat = torch.arange(tokens.numel(), device=tokens.device, dtype=torch.float32).reshape(tokens.shape)
    noise = torch.sin(flat + float(seed) * 0.137).to(dtype=tokens.dtype)
    if tokens.numel() > 1:
        spread = tokens.detach().float().std()
    else:
        # std() of fewer than two elements is NaN and clamp_min() would keep
        # the NaN; fall back to zero spread so the 0.05 floor applies.
        spread = torch.zeros((), device=tokens.device, dtype=torch.float32)
    scale = spread.to(device=tokens.device).clamp_min(0.05).to(dtype=tokens.dtype)
    return noise * scale
13
+
14
+
15
def apply_revisit_eval_corruption(
    *,
    tokens: torch.Tensor,
    mask: torch.Tensor,
    branch: str,
    target_frame: int,
) -> tuple[torch.Tensor, bool]:
    """Apply the token corruption associated with an evaluation ``branch``.

    Returns ``(tokens, False)`` untouched when ``branch`` is not one of
    ``EVAL_CORRUPTION_BRANCHES``, when ``mask`` has no true entry, or when the
    branch has no transform defined here. Otherwise returns
    ``(corrupted_copy, True)``. The transform is applied to a clone, so the
    input tensor is never modified; ``mask`` is only used for the emptiness
    check.
    """
    if branch not in EVAL_CORRUPTION_BRANCHES or not mask.any():
        return tokens, False

    # One transform per corruption branch; each receives a private clone.
    transforms = {
        "wrong_pose": lambda t: -t,
        "time_shuffle": lambda t: torch.flip(t, dims=(0,)),
        "source_matched_random": lambda t: _deterministic_noise_like(t, seed=int(target_frame)),
        "local_context_overlap_fake_revisit": lambda t: torch.roll(t, shifts=1, dims=0),
        "pose_shuffle": lambda t: torch.roll(t, shifts=1, dims=-1),
        "wrong_video": lambda t: t.detach().mean().to(dtype=t.dtype).expand_as(t).clone(),
    }
    transform = transforms.get(branch)
    if transform is None:
        return tokens, False
    return transform(tokens.clone()), True
algorithms/worldmem/dememwm/retrieval.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from dataclasses import replace
5
+ from typing import Any, Optional
6
+
7
+ import torch
8
+
9
+ from .labels import (
10
+ LABEL_SOURCE,
11
+ RevisitCandidateLabel,
12
+ _inside_fov_3d_hv,
13
+ _plucker_descriptor,
14
+ _target_fov_points,
15
+ )
16
+ from .types import MemoryRecord, RevisitRetrievalResult
17
+
18
+
19
def _overlap_values(labels, name: str) -> list[float]:
    """Collect attribute ``name`` from each label as a float, skipping ``None``."""
    raw_values = (getattr(label, name) for label in labels)
    return [float(raw) for raw in raw_values if raw is not None]
26
+
27
+
28
def _overlap_stats(values: list[float], prefix: str) -> dict[str, float]:
    """Summarise ``values`` as ``{prefix}_mean``, ``{prefix}_min`` and ``{prefix}_max``.

    All three entries are 0.0 when ``values`` is empty.
    """
    if values:
        mean_value = float(sum(values) / len(values))
        lowest = float(min(values))
        highest = float(max(values))
    else:
        mean_value = lowest = highest = 0.0
    return {
        f"{prefix}_mean": mean_value,
        f"{prefix}_min": lowest,
        f"{prefix}_max": highest,
    }
36
+
37
+
38
def _pose_rows(pose) -> torch.Tensor | None:
    """Normalise ``pose`` to a float32 tensor of shape ``(N, 5)``.

    Non-tensor input is converted with ``torch.as_tensor``. Returns ``None``
    when ``pose`` is ``None``, is a scalar, is empty, or has fewer than five
    entries along its last dimension. Otherwise all leading dimensions are
    flattened into rows and only the first five columns are kept.
    """
    if pose is None:
        return None
    if torch.is_tensor(pose):
        pose_tensor = pose
    else:
        pose_tensor = torch.as_tensor(pose, dtype=torch.float32)
    if pose_tensor.ndim == 0 or pose_tensor.numel() == 0:
        return None
    width = pose_tensor.shape[-1]
    if width < 5:
        return None
    rows = pose_tensor.detach().reshape(-1, width)
    return rows[:, :5].to(dtype=torch.float32)
45
+
46
+
47
def _pose_forward(poses: torch.Tensor) -> torch.Tensor:
    """Return one unit forward vector per pose row, shape ``(N, 3)``.

    Column 3 is read as pitch and column 4 as yaw, both converted from
    degrees. The result is ``(cos(pitch) * sin(yaw), sin(pitch),
    cos(pitch) * cos(yaw))``.
    """
    pitch_rad = torch.deg2rad(poses[:, 3])
    yaw_rad = torch.deg2rad(poses[:, 4])
    horizontal = torch.cos(pitch_rad)
    x_component = horizontal * torch.sin(yaw_rad)
    y_component = torch.sin(pitch_rad)
    z_component = horizontal * torch.cos(yaw_rad)
    return torch.stack([x_component, y_component, z_component], dim=-1)
59
+
60
+
61
def _single_frame_pose(record: MemoryRecord) -> torch.Tensor | None:
    """Return the record's single 5-value pose row, or ``None``.

    ``None`` is returned unless the record has exactly one frame index and its
    pose normalises (via ``_pose_rows``) to exactly one row.
    """
    if int(record.frame_indices.numel()) == 1:
        candidate_rows = _pose_rows(record.pose)
        if candidate_rows is not None and int(candidate_rows.shape[0]) == 1:
            return candidate_rows[0]
    return None
68
+
69
+
70
def _vectorized_frame_candidate_labels(
    records: list[MemoryRecord],
    *,
    target_frame: int,
    target_pose,
    fov_overlap_threshold: Optional[float],
    fov_half_h: float,
    fov_half_v: float,
    fov_yaw_samples: int,
    fov_pitch_samples: int,
    fov_depth_samples: int,
    fov_radius: float,
    plucker_grid_h: int,
    plucker_grid_w: int,
    plucker_focal_length: float,
    pose_preselect_topk: Optional[int],
) -> tuple[list[RevisitCandidateLabel], dict[str, float | int]]:
    """Score single-frame records against the target pose in one batched pass.

    Steps:
      1. Optionally keep only the ``pose_preselect_topk`` records closest to
         the target pose (normalised translation distance plus normalised
         angle between forward vectors).
      2. For the kept records compute FoV overlap (fraction of sampled target
         FoV points inside each source FoV) and a Plücker-descriptor
         similarity ``1 / (1 + distance)``.
      3. Emit one ``RevisitCandidateLabel`` per kept record; a label is valid
         when its FoV overlap reaches ``fov_overlap_threshold`` (always valid
         when the threshold is ``None``).

    Returns the labels together with pre-selection diagnostics.

    Raises:
        ValueError: if ``target_pose`` cannot be normalised to pose rows, or a
            record does not carry exactly one frame index and one pose row.
    """
    record_count = len(records)
    diagnostics: dict[str, float | int] = {
        "revisit_pose_preselect_input_count": record_count,
        "revisit_pose_preselect_scored_count": record_count,
        "revisit_pose_preselect_unscored_count": 0,
        "revisit_pose_preselect_selected_count": record_count,
        "revisit_pose_preselect_min_distance": 0.0,
        "revisit_pose_preselect_max_distance": 0.0,
        "revisit_exact_fov_candidate_count": record_count,
        "revisit_vectorized_frame_scorer_used": 1,
    }
    if not records:
        return [], diagnostics

    target_poses = _pose_rows(target_pose)
    if target_poses is None:
        raise ValueError("DeMemWM revisit retrieval requires target_pose for frame-level FoV scoring")

    per_record_pose: list[torch.Tensor] = []
    for record in records:
        pose_row = _single_frame_pose(record)
        if pose_row is None:
            raise ValueError(
                "DeMemWM revisit retrieval requires frame-level records with exactly one frame index and one pose row"
            )
        per_record_pose.append(pose_row)

    # Prefer the target's device when it lives on CUDA, else the first record's.
    device = target_poses.device if target_poses.is_cuda else per_record_pose[0].device
    source_poses = torch.stack(
        [row.to(device=device, dtype=torch.float32) for row in per_record_pose],
        dim=0,
    )
    # Only the last target pose row is used.
    target = target_poses[-1].to(device=device, dtype=torch.float32)

    selected_indices = list(range(record_count))
    topk = None if pose_preselect_topk is None else int(pose_preselect_topk)
    if topk is not None and 0 < topk < record_count:
        radius = max(float(fov_radius), 1e-6)
        translation_norm = torch.linalg.vector_norm(source_poses[:, :3] - target[:3], dim=-1) / radius
        source_forward = _pose_forward(source_poses)
        target_forward = _pose_forward(target.reshape(1, -1)).squeeze(0)
        dot = (source_forward * target_forward.reshape(1, 3)).sum(dim=-1).clamp(-1.0, 1.0)
        pose_distances = translation_norm + (torch.acos(dot) / math.pi)
        distance_values = [float(value) for value in pose_distances.detach().cpu().tolist()]
        # Closest first; ties prefer the later source frame, then the lower
        # source_start, chunk id and original position.
        ranked = sorted(
            (
                distance_values[idx],
                -int(record.max_source_frame),
                int(record.source_start),
                str(record.chunk_id or ""),
                idx,
            )
            for idx, record in enumerate(records)
        )
        selected_indices = [entry[-1] for entry in ranked[:topk]]
        diagnostics["revisit_pose_preselect_selected_count"] = len(selected_indices)
        diagnostics["revisit_pose_preselect_min_distance"] = float(min(distance_values))
        diagnostics["revisit_pose_preselect_max_distance"] = float(max(distance_values))

    selected_tensor = torch.tensor(selected_indices, device=device, dtype=torch.long)
    selected_records = [records[idx] for idx in selected_indices]
    selected_poses = source_poses.index_select(0, selected_tensor)

    points = _target_fov_points(
        target,
        fov_half_h=fov_half_h,
        fov_half_v=fov_half_v,
        yaw_samples=fov_yaw_samples,
        pitch_samples=fov_pitch_samples,
        depth_samples=fov_depth_samples,
        radius=fov_radius,
    )
    inside = _inside_fov_3d_hv(points, selected_poses, fov_half_h=fov_half_h, fov_half_v=fov_half_v)
    fov_values = inside.float().mean(dim=1)

    plucker_options = dict(
        grid_h=plucker_grid_h,
        grid_w=plucker_grid_w,
        focal_length=plucker_focal_length,
    )
    source_desc = _plucker_descriptor(selected_poses, **plucker_options)
    target_desc = _plucker_descriptor(target.reshape(1, -1), **plucker_options)
    diff = source_desc - target_desc
    plucker_distances = torch.linalg.vector_norm(diff, dim=-1) / math.sqrt(float(diff.shape[-1]))
    plucker_values = 1.0 / (1.0 + plucker_distances.clamp_min(0.0))

    if fov_overlap_threshold is None:
        valid_mask = torch.ones_like(fov_values, dtype=torch.bool)
    else:
        valid_mask = fov_values >= float(fov_overlap_threshold)

    diagnostics["revisit_exact_fov_candidate_count"] = len(selected_records)
    fov_list = [float(value) for value in fov_values.detach().cpu().tolist()]
    plucker_list = [float(value) for value in plucker_values.detach().cpu().tolist()]
    valid_list = [bool(value) for value in valid_mask.detach().cpu().tolist()]

    labels: list[RevisitCandidateLabel] = []
    for row_idx, record in enumerate(selected_records):
        fov_overlap = fov_list[row_idx]
        is_valid = valid_list[row_idx]
        labels.append(
            RevisitCandidateLabel(
                record=record,
                valid=is_valid,
                gap_valid=True,
                gap_to_target=int(target_frame) - (int(record.source_end) - 1),
                fov_overlap=fov_overlap,
                plucker_overlap=plucker_list[row_idx],
                primary_overlap=fov_overlap,
                coverage_mask=inside[row_idx].detach(),
                reject_reasons=() if is_valid else ("fov_overlap_below_threshold",),
                best_frame_index=int(record.max_source_frame),
                best_frame_fov_overlap=fov_overlap,
            )
        )
    return labels, diagnostics
204
+
205
+
206
def _coverage_gain(label: RevisitCandidateLabel, covered_mask: torch.Tensor | None) -> float:
    """Fraction of the label's coverage mask not already in ``covered_mask``.

    Falls back to the label's ``fov_overlap`` (0.0 when that is ``None``) if
    the label has no usable coverage mask. When ``covered_mask`` is ``None``
    or has a different shape, the full mask fraction is returned.
    """
    own_mask = label.coverage_mask
    if own_mask is None or own_mask.numel() == 0:
        if label.fov_overlap is None:
            return 0.0
        return float(label.fov_overlap)

    own_mask = own_mask.detach().bool()
    if covered_mask is not None and covered_mask.shape == own_mask.shape:
        already_covered = covered_mask.to(device=own_mask.device, dtype=torch.bool)
        own_mask = own_mask & ~already_covered
    return float(own_mask.float().mean().item())
214
+
215
+
216
def _coverage_gains(labels: list[RevisitCandidateLabel], covered_mask: torch.Tensor | None) -> list[float]:
    """Compute the incremental coverage of every label in one batched pass.

    * No label has a usable mask: each gain is the label's ``fov_overlap``
      (0.0 when ``None``).
    * Usable masks disagree in shape: fall back to per-label
      ``_coverage_gain``.
    * Otherwise masks are stacked (labels without a usable mask count as
      all-False) and the per-row mean of the not-yet-covered entries is
      returned; ``covered_mask`` is ignored when ``None`` or mis-shaped.
    """
    masks = [label.coverage_mask for label in labels]
    usable = [m for m in masks if m is not None and m.numel() > 0]
    if not usable:
        return [0.0 if label.fov_overlap is None else float(label.fov_overlap) for label in labels]

    reference = usable[0]
    shape = reference.shape
    device = reference.device
    if any(m.shape != shape for m in usable):
        return [_coverage_gain(label, covered_mask) for label in labels]

    rows = []
    for m in masks:
        if m is None or m.numel() == 0:
            rows.append(torch.zeros(shape, device=device, dtype=torch.bool))
        else:
            rows.append(m.detach().to(device=device, dtype=torch.bool))
    stacked = torch.stack(rows)

    if covered_mask is not None and covered_mask.shape == shape:
        stacked = stacked & ~covered_mask.to(device=device, dtype=torch.bool)
    gains = stacked.float().mean(dim=1)
    return [float(value) for value in gains.detach().cpu().tolist()]
239
+
240
+
241
def _select_greedy_coverage(
    labels: list[RevisitCandidateLabel],
    *,
    topk: int,
    plucker_weight: float,
) -> tuple[list[RevisitCandidateLabel], list[float], list[float]]:
    """Greedily pick up to ``topk`` labels by incremental coverage.

    Each round ranks the remaining labels by (largest coverage gain, largest
    FoV overlap, largest weighted Plücker overlap, smallest gap to target,
    smallest ``source_start``, ``chunk_id``, position) and takes the best one;
    its coverage mask is then merged into the running covered mask.

    Returns ``(selected_labels, scores, gains)``; the score recorded for each
    pick is its coverage gain, so both lists hold the same values.
    """
    pool = list(labels)
    picked: list[RevisitCandidateLabel] = []
    picked_scores: list[float] = []
    picked_gains: list[float] = []
    covered: torch.Tensor | None = None
    budget = max(0, int(topk))

    while pool and len(picked) < budget:
        gains = _coverage_gains(pool, covered)
        ordering = sorted(
            (
                -gain,
                -(0.0 if label.fov_overlap is None else float(label.fov_overlap)),
                -(float(plucker_weight) * (0.0 if label.plucker_overlap is None else float(label.plucker_overlap))),
                label.gap_to_target,
                int(label.record.source_start),
                str(label.record.chunk_id or ""),
                position,
                gain,
            )
            for position, (label, gain) in enumerate(zip(pool, gains))
        )
        *_, best_position, best_gain = ordering[0]
        winner = pool.pop(best_position)
        picked.append(winner)
        picked_scores.append(float(best_gain))
        picked_gains.append(float(best_gain))

        winner_mask = winner.coverage_mask
        if winner_mask is not None and winner_mask.numel() > 0:
            winner_mask = winner_mask.detach().bool()
            # Restart the running mask when there is none or its shape differs.
            if covered is None or covered.shape != winner_mask.shape:
                covered = torch.zeros_like(winner_mask, dtype=torch.bool)
            covered = covered.to(device=winner_mask.device, dtype=torch.bool) | winner_mask

    return picked, picked_scores, picked_gains
283
+
284
+
285
def _best_selected_label(labels: list[RevisitCandidateLabel]) -> RevisitCandidateLabel | None:
    """Return the label with the highest FoV overlap, or ``None`` if empty.

    Ties are broken by higher Plücker overlap, then smaller gap to target,
    then smaller ``source_start``, then larger ``chunk_id`` string. Missing
    overlaps count as 0.0.
    """
    if not labels:
        return None

    def preference(label: RevisitCandidateLabel):
        fov = 0.0 if label.fov_overlap is None else float(label.fov_overlap)
        plucker = 0.0 if label.plucker_overlap is None else float(label.plucker_overlap)
        return (
            fov,
            plucker,
            -int(label.gap_to_target),
            -int(label.record.source_start),
            str(label.record.chunk_id or ""),
        )

    return max(labels, key=preference)
298
+
299
+
300
def _best_selected_frame_label(labels: list[RevisitCandidateLabel]) -> RevisitCandidateLabel | None:
    """Return the label with the highest ``best_frame_fov_overlap``.

    Labels without a ``best_frame_fov_overlap`` are ignored; ``None`` is
    returned when none remain. Ties are broken by higher FoV overlap, higher
    Plücker overlap, smaller gap to target, smaller ``source_start``, then
    larger ``chunk_id`` string.
    """
    with_frame_score = [label for label in labels if label.best_frame_fov_overlap is not None]
    if not with_frame_score:
        return None

    def preference(label: RevisitCandidateLabel):
        fov = 0.0 if label.fov_overlap is None else float(label.fov_overlap)
        plucker = 0.0 if label.plucker_overlap is None else float(label.plucker_overlap)
        return (
            float(label.best_frame_fov_overlap),
            fov,
            plucker,
            -int(label.gap_to_target),
            -int(label.record.source_start),
            str(label.record.chunk_id or ""),
        )

    return max(with_frame_score, key=preference)
315
+
316
+
317
def _record_with_selected_frame_metadata(
    label: RevisitCandidateLabel,
    *,
    high_quality_fov_threshold: float,
) -> MemoryRecord:
    """Return a copy of ``label.record`` annotated with selection metadata.

    The copy (made with ``dataclasses.replace``) carries a fresh metadata dict
    extended with whichever of the label's overlap / frame fields are not
    ``None``. When a best-frame FoV overlap is present, the threshold and
    whether the overlap reaches it are recorded too. The original record and
    its metadata dict are left untouched.
    """
    enriched = dict(label.record.metadata or {})

    optional_entries = (
        ("dememwm_selected_revisit_fov_overlap", label.fov_overlap, float),
        ("dememwm_selected_revisit_plucker_overlap", label.plucker_overlap, float),
        ("dememwm_selected_frame_index", label.best_frame_index, int),
    )
    for key, value, cast in optional_entries:
        if value is not None:
            enriched[key] = cast(value)

    if label.best_frame_fov_overlap is not None:
        threshold = float(high_quality_fov_threshold)
        frame_fov = float(label.best_frame_fov_overlap)
        enriched["dememwm_selected_frame_fov_overlap"] = frame_fov
        enriched["dememwm_selected_frame_fov_threshold"] = threshold
        enriched["dememwm_selected_frame_passes_high_quality"] = bool(frame_fov >= threshold)

    return replace(label.record, metadata=enriched)
335
+
336
+
337
def deterministic_revisit_retrieval(
    records: list[MemoryRecord],
    target_frame: int,
    target_pose: Optional[torch.Tensor] = None,
    target_summary: Optional[torch.Tensor] = None,
    topk: int = 2,
    exclude_local_context_frames: int = 0,
    fov_overlap_threshold: Optional[float] = 0.30,
    high_quality_fov_threshold: float = 0.70,
    plucker_weight: float = 0.1,
    target_video_id: Any = None,
    fov_half_h: float = 105.0 / 2.0,
    fov_half_v: float = 75.0 / 2.0,
    fov_yaw_samples: int = 25,
    fov_pitch_samples: int = 20,
    fov_depth_samples: int = 20,
    fov_radius: float = 30.0,
    plucker_grid_h: int = 4,
    plucker_grid_w: int = 4,
    plucker_focal_length: float = 0.35,
    pose_preselect_topk: Optional[int] = 64,
    **_legacy_scoring_kwargs,
) -> RevisitRetrievalResult:
    """Select up to ``topk`` causal revisit records for ``target_frame``.

    Pipeline:
      1. Keep records that end at or before ``target_frame`` (causal), and of
         those score only the ones ending at least
         ``exclude_local_context_frames`` frames before the target.
      2. Label the scored records with ``_vectorized_frame_candidate_labels``
         (FoV / Plücker overlap against ``target_pose``).
      3. Greedily select valid labels by incremental coverage.
      4. Return the selected records (with selection metadata attached), their
         scores as a float32 tensor, the selected frame ids and a diagnostics
         dict.

    ``target_summary``, ``target_video_id`` and any extra keyword arguments
    are accepted but not used.
    """
    # Accepted for interface compatibility only; never read.
    del target_summary, target_video_id
    topk = max(0, int(topk))
    target_frame = int(target_frame)
    exclude_local_context_frames = int(exclude_local_context_frames)
    # Causality filter: a record must end no later than the target frame.
    causal_records = [record for record in records if int(record.source_end) <= target_frame]
    # Additionally drop records that fall inside the excluded local-context window.
    score_records = [
        record
        for record in causal_records
        if int(record.source_end) <= target_frame - exclude_local_context_frames
    ]
    labels, pose_preselect_diagnostics = _vectorized_frame_candidate_labels(
        score_records,
        target_frame=target_frame,
        target_pose=target_pose,
        fov_overlap_threshold=fov_overlap_threshold,
        fov_half_h=fov_half_h,
        fov_half_v=fov_half_v,
        fov_yaw_samples=fov_yaw_samples,
        fov_pitch_samples=fov_pitch_samples,
        fov_depth_samples=fov_depth_samples,
        fov_radius=fov_radius,
        plucker_grid_h=plucker_grid_h,
        plucker_grid_w=plucker_grid_w,
        plucker_focal_length=plucker_focal_length,
        pose_preselect_topk=pose_preselect_topk,
    )
    exact_fov_candidate_count = int(pose_preselect_diagnostics["revisit_exact_fov_candidate_count"])
    valid_labels = [label for label in labels if label.valid]
    selected_labels, selected_scores, selected_gains = _select_greedy_coverage(
        valid_labels,
        topk=topk,
        plucker_weight=float(plucker_weight),
    )
    # Summary values for the best selected label; 0.0 / -1 stand in when nothing was selected.
    best_selected = _best_selected_label(selected_labels)
    best_selected_frame = _best_selected_frame_label(selected_labels)
    best_selected_fov = 0.0 if best_selected is None or best_selected.fov_overlap is None else float(best_selected.fov_overlap)
    best_selected_plucker = 0.0 if best_selected is None or best_selected.plucker_overlap is None else float(best_selected.plucker_overlap)
    best_selected_gap = -1 if best_selected is None else int(best_selected.gap_to_target)
    best_selected_frame_fov = 0.0 if best_selected_frame is None else float(best_selected_frame.best_frame_fov_overlap)
    best_selected_frame_index = -1 if best_selected_frame is None or best_selected_frame.best_frame_index is None else int(best_selected_frame.best_frame_index)
    # 1 when the best selected frame reaches the high-quality FoV threshold, else 0.
    high_quality_selected = int(best_selected_frame is not None and best_selected_frame_fov >= float(high_quality_fov_threshold))
    selected_records = [
        _record_with_selected_frame_metadata(label, high_quality_fov_threshold=float(high_quality_fov_threshold))
        for label in selected_labels
    ]
    # Scores live on the device of the first selected record's tokens (CPU when nothing was selected).
    score_device = selected_records[0].tokens.device if selected_records else torch.device("cpu")
    scores = torch.tensor(selected_scores, dtype=torch.float32, device=score_device)

    fov_values = _overlap_values(valid_labels, "fov_overlap")
    plucker_values = _overlap_values(valid_labels, "plucker_overlap")
    selected_gaps = [label.gap_to_target for label in selected_labels]
    selected_frame_fov_values = [
        float(label.best_frame_fov_overlap)
        for label in selected_labels
        if label.best_frame_fov_overlap is not None
    ]
    # Several keys below carry the same value under different names.
    diagnostics = {
        "target_frame": int(target_frame),
        "candidate_count": len(causal_records),
        "candidate_frame_count": len(causal_records),
        "valid_candidate_count": len(valid_labels),
        "revisit_exact_fov_candidate_count": exact_fov_candidate_count,
        "valid_candidate_frame_count": len(valid_labels),
        "valid_candidate_label_count": len(valid_labels),
        "selected_count": len(selected_records),
        "selected_frame_count": len(selected_records),
        "revisit_candidate_frame_count": len(causal_records),
        "revisit_candidate_count": len(causal_records),
        "valid_revisit_frame_count": len(valid_labels),
        "valid_revisit_count": len(valid_labels),
        "valid_revisit_target_count": high_quality_selected,
        "no_valid_revisit_count": int(len(valid_labels) == 0),
        "valid_revisit_mask": int(len(valid_labels) > 0),
        "revisit_abstained_count": int(len(selected_records) == 0),
        "abstained": bool(len(selected_records) == 0),
        "revisit_selected_frame_count": len(selected_records),
        "revisit_selected_count": len(selected_records),
        "revisit_min_gap_to_target": int(min(selected_gaps)) if selected_gaps else -1,
        "best_selected_fov_overlap": best_selected_fov,
        "best_selected_plucker_overlap": best_selected_plucker,
        "best_selected_gap_frames": best_selected_gap,
        "best_selected_frame_index": best_selected_frame_index,
        "best_selected_frame_fov_overlap": best_selected_frame_fov,
        "best_selected_frame_passes_high_quality": high_quality_selected,
        "high_quality_selected_revisit": high_quality_selected,
        "high_quality_fov_threshold": float(high_quality_fov_threshold),
        "revisit_label_source": LABEL_SOURCE,
        "selected_frame_ids": [int(record.max_source_frame) for record in selected_records],
        "selected_frame_record_ids": [record.chunk_id for record in selected_records],
        "selected_ranges": [(record.source_start, record.source_end) for record in selected_records],
        "frame_fov_overlap_values": fov_values,
        "fov_overlap_values": fov_values,
        "plucker_overlap_values": plucker_values,
        "best_selected_fov_overlap_values": [] if best_selected is None else [best_selected_fov],
        "best_selected_plucker_overlap_values": [] if best_selected is None else [best_selected_plucker],
        "best_selected_gap_frame_values": [] if best_selected is None else [best_selected_gap],
        "best_selected_frame_fov_overlap_values": [] if best_selected_frame is None else [best_selected_frame_fov],
        "selected_frame_fov_overlap_values": selected_frame_fov_values,
        "selected_incremental_fov_overlap_values": selected_gains,
        "selected_revisit_scores": selected_scores,
        # Unpacked last, so the scorer's diagnostics win on any shared key.
        **pose_preselect_diagnostics,
    }
    # Add <prefix>_mean / _min / _max summaries for each value list.
    diagnostics.update(_overlap_stats(fov_values, "revisit_frame_fov_overlap"))
    diagnostics.update(_overlap_stats(fov_values, "revisit_fov_overlap"))
    diagnostics.update(_overlap_stats(plucker_values, "revisit_plucker_overlap"))
    diagnostics.update(_overlap_stats(diagnostics["best_selected_fov_overlap_values"], "revisit_best_selected_fov_overlap"))
    diagnostics.update(_overlap_stats(diagnostics["best_selected_plucker_overlap_values"], "revisit_best_selected_plucker_overlap"))
    diagnostics.update(_overlap_stats(diagnostics["best_selected_gap_frame_values"], "revisit_best_selected_gap_frames"))
    diagnostics.update(_overlap_stats(diagnostics["best_selected_frame_fov_overlap_values"], "revisit_best_selected_frame_fov_overlap"))
    diagnostics.update(_overlap_stats(selected_frame_fov_values, "revisit_selected_frame_fov_overlap"))
    diagnostics.update(_overlap_stats(selected_gains, "revisit_incremental_fov_overlap"))
    return RevisitRetrievalResult(
        records=selected_records,
        scores=scores,
        selected_frame_ids=[int(record.max_source_frame) for record in selected_records],
        diagnostics=diagnostics,
    )
algorithms/worldmem/dememwm/schedules.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass
5
+ from typing import Any
6
+
7
+ import torch
8
+
9
+ from .types import StreamGateState
10
+
11
# Noise buckets, listed in the order that defines their integer ids.
NOISE_BUCKETS = ("high", "mid", "low")
NOISE_BUCKET_TO_ID = dict(zip(NOISE_BUCKETS, range(len(NOISE_BUCKETS))))

# Evaluation branches that corrupt the memory input.
EVAL_CORRUPTION_BRANCHES = (
    "wrong_pose",
    "time_shuffle",
    "source_matched_random",
    "pose_shuffle",
    "wrong_video",
    "local_context_overlap_fake_revisit",
)

# Every evaluation ablation branch: the stream on/off branches followed by
# the corruption branches, in the order that defines their integer ids.
EVAL_ABLATION_BRANCHES = (
    "memory_off",
    "A_only",
    "D_only",
    "A_plus_D",
    "A_plus_D_plus_R_normal",
    "R_forced_off",
    "R_forced_on",
) + EVAL_CORRUPTION_BRANCHES
EVAL_ABLATION_BRANCH_TO_ID = dict(
    zip(EVAL_ABLATION_BRANCHES, range(len(EVAL_ABLATION_BRANCHES)))
)
37
+
38
+
39
+
40
def _clamp01(value: float) -> float:
    """Return ``value`` converted to float and limited to the range [0, 1]."""
    capped = min(1.0, float(value))
    return max(0.0, capped)
42
+
43
+
44
def noise_bucket_from_denoising_fraction(denoising_fraction: float | None) -> str:
    """Name the noise bucket for a denoising progress fraction.

    The fraction is clamped to [0, 1]. Values below one third map to
    ``"high"``, values below two thirds to ``"mid"``, and the rest to
    ``"low"``. ``None`` maps to ``"mid"``.
    """
    if denoising_fraction is None:
        return "mid"

    progress = _clamp01(float(denoising_fraction))
    one_third = 1.0 / 3.0
    two_thirds = 2.0 / 3.0
    if progress >= two_thirds:
        return "low"
    return "high" if progress < one_third else "mid"
53
+
54
+
55
def noise_bucket_from_noise_levels(noise_levels: torch.Tensor | None, timesteps: int | None) -> str:
    """Name the noise bucket for the mean of ``noise_levels``.

    The mean noise level is divided by ``timesteps - 1`` and clamped to
    [0, 1]. A fraction of at least two thirds is ``"high"``, at least one
    third is ``"mid"``, otherwise ``"low"``. Returns ``"mid"`` when either
    argument is ``None`` or ``timesteps`` is at most 1.
    """
    if noise_levels is None or timesteps is None:
        return "mid"
    if int(timesteps) <= 1:
        return "mid"

    mean_level = float(noise_levels.detach().float().mean().item())
    max_level = float(int(timesteps) - 1)
    noise_fraction = _clamp01(mean_level / max_level)

    if noise_fraction < (1.0 / 3.0):
        return "low"
    if noise_fraction < (2.0 / 3.0):
        return "mid"
    return "high"
64
+
65
+
66
def noise_bucket_ids_from_noise_levels(noise_levels: torch.Tensor | None, timesteps: int | None) -> torch.Tensor | None:
    """Map each noise level to its bucket id from ``NOISE_BUCKET_TO_ID``.

    Each element is divided by ``timesteps - 1``. Fractions of at least two
    thirds get the ``"high"`` id, fractions below one third get the ``"low"``
    id, and everything else gets the ``"mid"`` id. The result is a long
    tensor shaped like ``noise_levels``. Returns ``None`` when either
    argument is ``None`` or ``timesteps`` is at most 1.
    """
    if noise_levels is None or timesteps is None or int(timesteps) <= 1:
        return None

    max_level = float(int(timesteps) - 1)
    noise_fraction = noise_levels.detach().float() / max_level

    # Start with every element in the middle bucket, then overwrite both ends.
    bucket_ids = torch.full_like(noise_levels, NOISE_BUCKET_TO_ID["mid"], dtype=torch.long)
    bucket_ids = bucket_ids.masked_fill(noise_fraction >= (2.0 / 3.0), NOISE_BUCKET_TO_ID["high"])
    bucket_ids = bucket_ids.masked_fill(noise_fraction < (1.0 / 3.0), NOISE_BUCKET_TO_ID["low"])
    return bucket_ids
82
+
83
+
84
def denoising_fraction_from_noise_levels(noise_levels: torch.Tensor | None, timesteps: int | None) -> float | None:
    """Return one minus the clamped mean noise fraction of ``noise_levels``.

    The mean noise level is divided by ``timesteps - 1`` and clamped to
    [0, 1] before being subtracted from 1. Returns ``None`` when either
    argument is ``None`` or ``timesteps`` is at most 1.
    """
    if noise_levels is None or timesteps is None:
        return None
    if int(timesteps) <= 1:
        return None

    mean_level = float(noise_levels.detach().float().mean().item())
    noise_fraction = _clamp01(mean_level / float(int(timesteps) - 1))
    return _clamp01(1.0 - noise_fraction)
89
+
90
+
91
def normalize_eval_ablation_branch(branch: str | None) -> str:
    """Validate an eval ablation branch name.

    ``None`` resolves to ``"A_plus_D_plus_R_normal"``. Any other value is
    converted to ``str`` and must be a key of ``EVAL_ABLATION_BRANCH_TO_ID``.

    Raises:
        ValueError: if the name is not a known branch.
    """
    if branch is None:
        return "A_plus_D_plus_R_normal"

    branch = str(branch)
    if branch in EVAL_ABLATION_BRANCH_TO_ID:
        return branch
    raise ValueError(f"unknown DeMemWM eval ablation branch: {branch}")
98
+
99
+
100
def normalize_noise_bucket(noise_bucket: str | None) -> str:
    """Return ``noise_bucket`` if it is a known bucket name, else ``"mid"``."""
    is_known = noise_bucket in NOISE_BUCKET_TO_ID
    return str(noise_bucket) if is_known else "mid"
104
+
105
+
106
# Per-stage stream switches, unpacked by callers as (anchor, dynamic, revisit).
_STAGE_ENABLES = {
    "stage_1": (True, True, True),
    "stage_2": (True, True, True),
}
110
+
111
+
112
+
113
@dataclass(frozen=True)
class CurriculumState:
    """Immutable snapshot of the DeMemWM curriculum resolved for one step."""

    global_step: int
    enabled: bool
    stage: str
    anchor_enabled: bool
    dynamic_enabled: bool
    revisit_enabled: bool
    dit_train_state: str
    freeze_vae: bool
    dememwm_lr: float
    memory_adapter_lr: float
    full_dit_lr: float

    @property
    def dit_full_trainable(self) -> bool:
        """Whether ``dit_train_state`` equals ``"full"``."""
        return self.dit_train_state == "full"

    def diagnostics(self) -> dict[str, Any]:
        """Return the state as a flat dict keyed by diagnostic names."""
        report: dict[str, Any] = {}
        report["dememwm_global_step"] = self.global_step
        report["dememwm_curriculum_enabled"] = self.enabled
        report["dememwm_stage"] = self.stage
        report["curriculum_anchor_enabled"] = self.anchor_enabled
        report["curriculum_dynamic_enabled"] = self.dynamic_enabled
        report["curriculum_revisit_enabled"] = self.revisit_enabled
        report["dit_train_state"] = self.dit_train_state
        report["dit_full_trainable"] = self.dit_full_trainable
        report["freeze_vae"] = self.freeze_vae
        report["lr_dememwm_modules"] = self.dememwm_lr
        report["lr_memory_adapters"] = self.memory_adapter_lr
        report["lr_full_dit"] = self.full_dit_lr
        return report
149
+
150
+
151
def _cfg_get(obj: Any, name: str, default: Any) -> Any:
    """Read attribute ``name`` from ``obj``, or ``default`` if absent or ``obj`` is ``None``."""
    if obj is None:
        return default
    return getattr(obj, name, default)
153
+
154
+
155
def _stage_for_step(curriculum_cfg: Any, step: int) -> str:
    """Return ``'stage_2'`` once ``step`` reaches ``full_stage_start_step``, else ``'stage_1'``.

    ``full_stage_start_step`` is read from ``curriculum_cfg`` and defaults to 60000.
    """
    full_start = int(_cfg_get(curriculum_cfg, 'full_stage_start_step', 60000))
    if step >= full_start:
        return 'stage_2'
    return 'stage_1'
158
+
159
+
160
def _dit_train_state(curriculum_cfg: Any, step: int) -> str:
    """Return ``'frozen'`` or ``'full'`` for the DiT at ``step``.

    The result is ``'frozen'`` only when ``dit_freeze.enabled`` is truthy
    (default ``True``) and ``step`` is below ``full_stage_start_step``
    (default 60000). Otherwise it is ``'full'``.
    """
    freeze_cfg = _cfg_get(curriculum_cfg, 'dit_freeze', None)
    freeze_enabled = bool(_cfg_get(freeze_cfg, 'enabled', True))
    full_step = int(_cfg_get(curriculum_cfg, 'full_stage_start_step', 60000))

    still_frozen = freeze_enabled and step < full_step
    return 'frozen' if still_frozen else 'full'
167
+
168
+
169
def resolve_curriculum(memory_cfg: Any, global_step: int | None = None) -> CurriculumState:
    """Build the ``CurriculumState`` for ``global_step`` from ``memory_cfg``.

    When ``memory_cfg.curriculum.enabled`` is truthy, the stage and the DiT
    train state are derived from the step. Otherwise the stage comes from
    ``memory_cfg.training_stage`` (default ``"stage_1"``), the DiT state is
    ``"full"`` and the VAE is frozen. Learning rates are read from
    ``curriculum.lr`` with built-in defaults. A missing or negative step is
    treated as 0.

    Raises:
        ValueError: if the resolved stage is not a key of ``_STAGE_ENABLES``.
    """
    step = max(0, int(global_step or 0))
    curriculum_cfg = _cfg_get(memory_cfg, "curriculum", None)
    lr_cfg = _cfg_get(curriculum_cfg, "lr", None)
    enabled = bool(_cfg_get(curriculum_cfg, "enabled", False))

    if not enabled:
        stage = str(_cfg_get(memory_cfg, "training_stage", "stage_1"))
        dit_state, freeze_vae = "full", True
    else:
        stage = _stage_for_step(curriculum_cfg, step)
        dit_state = _dit_train_state(curriculum_cfg, step)
        freeze_vae = bool(_cfg_get(curriculum_cfg, "freeze_vae", True))

    stream_flags = _STAGE_ENABLES.get(stage)
    if stream_flags is None:
        raise ValueError(f"unknown DeMemWM stage: {stage}")
    anchor_on, dynamic_on, revisit_on = stream_flags

    fields = {
        "global_step": step,
        "enabled": enabled,
        "stage": stage,
        "anchor_enabled": anchor_on,
        "dynamic_enabled": dynamic_on,
        "revisit_enabled": revisit_on,
        "dit_train_state": dit_state,
        "freeze_vae": freeze_vae,
        "dememwm_lr": float(_cfg_get(lr_cfg, "dememwm_modules", 1.0e-4)),
        "memory_adapter_lr": float(_cfg_get(lr_cfg, "memory_adapters", 1.0e-4)),
        "full_dit_lr": float(_cfg_get(lr_cfg, "full_dit", 1.0e-5)),
    }
    return CurriculumState(**fields)


# Additional names bound to the same class and function.
DeMemWMCurriculumState = CurriculumState
resolve_dememwm_curriculum = resolve_curriculum
210
+
211
+
212
def compute_stream_gates(
    stage: str,
    denoising_fraction: float | None = None,
    debug_force_all_streams: bool = False,
    anchor_gate: float = 1.0,
    dynamic_gate: float = 1.0,
    revisit_gate: float = 1.0,
) -> StreamGateState:
    """Resolve the on/off flags and gate strengths of the three memory streams.

    With ``debug_force_all_streams`` every stream is enabled with its raw
    gate value and ``stage`` is not validated. Otherwise the flags come from
    ``_STAGE_ENABLES[stage]``. The dynamic and revisit gates are multiplied
    by ``0.25 + 0.75 * denoising_fraction`` (fraction clamped to [0, 1]; the
    factor is 1.0 when the fraction is ``None``). The anchor gate is not
    scaled. A disabled stream gets gate 0.0.

    Raises:
        ValueError: if ``stage`` is unknown and the debug override is off.
    """
    if debug_force_all_streams:
        return StreamGateState(
            anchor_enabled=True,
            dynamic_enabled=True,
            revisit_enabled=True,
            anchor_gate=float(anchor_gate),
            dynamic_gate=float(dynamic_gate),
            revisit_gate=float(revisit_gate),
            stage=stage,
            reason="debug_force_all_streams",
        )

    if stage not in _STAGE_ENABLES:
        raise ValueError(f"unknown DeMemWM stage: {stage}")
    anchor_on, dynamic_on, revisit_on = _STAGE_ENABLES[stage]

    if denoising_fraction is None:
        stage_scale = 1.0
    else:
        stage_scale = 0.25 + 0.75 * _clamp01(denoising_fraction)

    resolved_anchor = float(anchor_gate) if anchor_on else 0.0
    resolved_dynamic = float(dynamic_gate) * stage_scale if dynamic_on else 0.0
    resolved_revisit = float(revisit_gate) * stage_scale if revisit_on else 0.0
    return StreamGateState(
        anchor_enabled=anchor_on,
        dynamic_enabled=dynamic_on,
        revisit_enabled=revisit_on,
        anchor_gate=resolved_anchor,
        dynamic_gate=resolved_dynamic,
        revisit_gate=resolved_revisit,
        stage=stage,
        reason="stage_schedule",
    )
algorithms/worldmem/dememwm/types.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass, field
5
+ from enum import Enum
6
+ from typing import Any, Optional
7
+
8
+ import torch
9
+
10
+
11
class MemorySourceType(str, Enum):
    """String-valued tag naming where a memory record came from."""

    PREFIX_GT = "prefix_gt"
    GENERATED = "generated"
    DYNAMIC = "dynamic"
    REVISIT = "revisit"
16
+
17
+
18
@dataclass
class MemoryRecord:
    """A single causal memory entry of DeMemWM.

    The covered frames form the half-open range ``[source_start, source_end)``.
    Every source frame index held by a record must be strictly less than the
    target frame being queried, unless the caller is deliberately querying a
    prefix frame that has already been committed.
    """

    tokens: torch.Tensor
    mask: torch.Tensor
    source_start: int
    source_end: int
    frame_indices: torch.Tensor
    pose: Optional[torch.Tensor]
    source_type: MemorySourceType
    is_generated: bool
    score: Optional[float | torch.Tensor] = None
    chunk_id: Optional[str] = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate the record and convert ``mask`` to a boolean tensor.

        Raises:
            ValueError: on an empty frame range, tokens with fewer than two
                dimensions, a mask that is not 1-D or does not match the
                first token dimension, a generated ``PREFIX_GT`` record, or
                empty ``frame_indices``.
        """
        if self.source_start >= self.source_end:
            raise ValueError("source_end must be greater than source_start")

        token_rank = self.tokens.ndim
        if token_rank < 2:
            raise ValueError("tokens must include slot and channel dimensions")

        # The mask is stored as bool whatever dtype the caller passed.
        slot_mask = self.mask.bool()
        self.mask = slot_mask
        if slot_mask.ndim != 1:
            raise ValueError("mask must have shape (M,)")
        if slot_mask.shape[0] != self.tokens.shape[0]:
            raise ValueError("mask length must match token slots")

        if self.source_type == MemorySourceType.PREFIX_GT and self.is_generated:
            raise ValueError("generated records cannot be PREFIX_GT anchors")

        if self.frame_indices.numel() == 0:
            raise ValueError("frame_indices cannot be empty")

    @property
    def max_source_frame(self) -> int:
        """Largest value in ``frame_indices`` as a Python int."""
        newest = self.frame_indices.detach().max()
        return int(newest.item())

    @property
    def valid_slots(self) -> int:
        """Number of ``True`` entries in ``mask`` as a Python int."""
        active = self.mask.detach().sum()
        return int(active.item())
62
+
63
+
64
@dataclass
class MemoryStreamTensors:
    """Token, mask and gate values for the anchor, dynamic and revisit streams."""

    # Tokens and mask for each stream.
    anchor_tokens: torch.Tensor
    anchor_mask: torch.Tensor
    dynamic_tokens: torch.Tensor
    dynamic_mask: torch.Tensor
    revisit_tokens: torch.Tensor
    revisit_mask: torch.Tensor

    # Gate for each stream, as a tensor or a plain float.
    anchor_gate: torch.Tensor | float
    dynamic_gate: torch.Tensor | float
    revisit_gate: torch.Tensor | float

    # Optional revisit-specific extras; all default to None.
    revisit_gate_raw: torch.Tensor | None = None
    valid_revisit_mask: torch.Tensor | None = None
    no_valid_revisit_mask: torch.Tensor | None = None

    # Free-form diagnostics; each instance gets its own dict.
    diagnostics: dict[str, Any] = field(default_factory=dict)
79
+
80
+
81
@dataclass(frozen=True)
class StreamGateState:
    """Immutable on/off flags and gate values for the three memory streams."""

    # Whether each stream is switched on.
    anchor_enabled: bool
    dynamic_enabled: bool
    revisit_enabled: bool

    # Gate value for each stream.
    anchor_gate: float
    dynamic_gate: float
    revisit_gate: float

    # Stage name the gates were computed for, and a short reason tag.
    stage: str
    reason: str = ""
91
+
92
+
93
@dataclass
class RevisitRetrievalResult:
    """Bundle returned by revisit retrieval: records, scores, frame ids, diagnostics."""

    records: list[MemoryRecord]
    scores: torch.Tensor
    selected_frame_ids: list[int]
    diagnostics: dict[str, Any]
algorithms/worldmem/dememwm_memory_dit.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+
4
+ from .dememwm.algorithm import MemoryDiTMixin
5
+ from .df_video import BaseVideoDiTMinecraft
6
+
7
+
8
class DeMemWMMinecraft(MemoryDiTMixin, BaseVideoDiTMinecraft):
    """Standalone DeMemWM / Memory-DiT algorithm.

    Combines ``MemoryDiTMixin`` with ``BaseVideoDiTMinecraft``: the base
    video-DiT VAE/diffusion/training infrastructure is reused, while memory
    construction and injection are owned by this algorithm. It does not
    route through the legacy memory method.
    """


# Second name bound to the same class.
DeMemWMMemoryDiTMinecraft = DeMemWMMinecraft
algorithms/worldmem/df_base.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research
3
+ template [repo](https://github.com/buoyancy99/research-template).
4
+ By its MIT license, you must keep the above sentence in `README.md`
5
+ and the `LICENSE` file to credit the author.
6
+ """
7
+
8
+ from typing import Optional
9
+ from tqdm import tqdm
10
+ from omegaconf import DictConfig
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from typing import Any
15
+ from einops import rearrange
16
+
17
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
18
+
19
+ from algorithms.common.base_pytorch_algo import BasePytorchAlgo
20
+ from .models.diffusion import Diffusion
21
+
22
+
23
+ class DiffusionForcingBase(BasePytorchAlgo):
24
+ def __init__(self, cfg: DictConfig):
25
+ self.cfg = cfg
26
+ self.x_shape = cfg.x_shape
27
+ self.frame_stack = cfg.frame_stack
28
+ self.x_stacked_shape = list(self.x_shape)
29
+ self.x_stacked_shape[0] *= cfg.frame_stack
30
+ self.guidance_scale = cfg.guidance_scale
31
+ self.context_frames = cfg.context_frames
32
+ self.chunk_size = cfg.chunk_size
33
+ self.action_cond_dim = cfg.action_cond_dim
34
+ self.causal = cfg.causal
35
+
36
+ self.uncertainty_scale = cfg.uncertainty_scale
37
+ self.timesteps = cfg.diffusion.timesteps
38
+ self.sampling_timesteps = cfg.diffusion.sampling_timesteps
39
+ self.clip_noise = cfg.diffusion.clip_noise
40
+
41
+ self.cfg.diffusion.cum_snr_decay = self.cfg.diffusion.cum_snr_decay ** (self.frame_stack * cfg.frame_skip)
42
+
43
+ self.validation_step_outputs = []
44
+ super().__init__(cfg)
45
+
46
+ def _build_model(self):
47
+ self.diffusion_model = Diffusion(
48
+ x_shape=self.x_stacked_shape,
49
+ action_cond_dim=self.action_cond_dim,
50
+ is_causal=self.causal,
51
+ cfg=self.cfg.diffusion,
52
+ )
53
+ self.register_data_mean_std(self.cfg.data_mean, self.cfg.data_std)
54
+
55
+ def configure_optimizers(self):
56
+ params = tuple(self.diffusion_model.parameters())
57
+ optimizer_dynamics = torch.optim.AdamW(
58
+ params, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay, betas=self.cfg.optimizer_beta
59
+ )
60
+ return optimizer_dynamics
61
+
62
+ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
63
+ # update params
64
+ optimizer.step(closure=optimizer_closure)
65
+
66
+ # manually warm up lr without a scheduler
67
+ if self.trainer.global_step < self.cfg.warmup_steps:
68
+ lr_scale = min(1.0, float(self.trainer.global_step + 1) / self.cfg.warmup_steps)
69
+ for pg in optimizer.param_groups:
70
+ pg["lr"] = lr_scale * self.cfg.lr
71
+
72
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
73
+ xs, conditions, masks = self._preprocess_batch(batch)
74
+
75
+ rand_length = torch.randint(3,xs.shape[0]-2, (1,))[0].item()
76
+ xs = torch.cat([xs[:rand_length], xs[rand_length-3:rand_length-1]])
77
+ conditions = torch.cat([conditions[:rand_length], conditions[rand_length-3:rand_length-1]])
78
+ masks = torch.cat([masks[:rand_length], masks[rand_length-3:rand_length-1]])
79
+ noise_levels=self._generate_noise_levels(xs)
80
+ noise_levels[:rand_length] = 15 # stable_noise_levels
81
+ noise_levels[rand_length+1:] = 15 # stable_noise_levels
82
+
83
+ xs_pred, loss = self.diffusion_model(xs, conditions, noise_levels=noise_levels)
84
+ loss = self.reweight_loss(loss, masks)
85
+
86
+ # log the loss
87
+ if batch_idx % 20 == 0:
88
+ self.log("training/loss", loss)
89
+
90
+ xs = self._unstack_and_unnormalize(xs)
91
+ xs_pred = self._unstack_and_unnormalize(xs_pred)
92
+
93
+ output_dict = {
94
+ "loss": loss,
95
+ "xs_pred": xs_pred,
96
+ "xs": xs,
97
+ }
98
+
99
+ return output_dict
100
+
101
+ @torch.no_grad()
102
+ def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT:
103
+ xs, conditions, masks = self._preprocess_batch(batch)
104
+ n_frames, batch_size, *_ = xs.shape
105
+ xs_pred = []
106
+ curr_frame = 0
107
+
108
+ # context
109
+ n_context_frames = self.context_frames // self.frame_stack
110
+ xs_pred = xs[:n_context_frames].clone()
111
+ curr_frame += n_context_frames
112
+
113
+ if self.condtion_similar_length:
114
+ n_frames -= self.condtion_similar_length
115
+
116
+ pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
117
+ while curr_frame < n_frames:
118
+ if self.chunk_size > 0:
119
+ horizon = min(n_frames - curr_frame, self.chunk_size)
120
+ else:
121
+ horizon = n_frames - curr_frame
122
+ assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
123
+ scheduling_matrix = self._generate_scheduling_matrix(horizon)
124
+
125
+ chunk = torch.randn((horizon, batch_size, *self.x_stacked_shape), device=self.device)
126
+ chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise)
127
+ xs_pred = torch.cat([xs_pred, chunk], 0)
128
+
129
+ # sliding window: only input the last n_tokens frames
130
+ start_frame = max(0, curr_frame + horizon - self.n_tokens)
131
+
132
+ pbar.set_postfix(
133
+ {
134
+ "start": start_frame,
135
+ "end": curr_frame + horizon,
136
+ }
137
+ )
138
+
139
+ if self.condtion_similar_length:
140
+ xs_pred = torch.cat([xs_pred, xs[curr_frame-self.condtion_similar_length:curr_frame].clone()], 0)
141
+
142
+ for m in range(scheduling_matrix.shape[0] - 1):
143
+
144
+ from_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m]))[
145
+ :, None
146
+ ].repeat(batch_size, axis=1)
147
+ to_noise_levels = np.concatenate(
148
+ (
149
+ np.zeros((curr_frame,), dtype=np.int64),
150
+ scheduling_matrix[m + 1],
151
+ )
152
+ )[
153
+ :, None
154
+ ].repeat(batch_size, axis=1)
155
+
156
+ if self.condtion_similar_length:
157
+ from_noise_levels = np.concatenate([from_noise_levels, np.array([[0,0,0,0]*self.condtion_similar_length])], axis=0)
158
+ to_noise_levels = np.concatenate([to_noise_levels, np.array([[0,0,0,0]*self.condtion_similar_length])], axis=0)
159
+
160
+ from_noise_levels = torch.from_numpy(from_noise_levels).to(self.device)
161
+ to_noise_levels = torch.from_numpy(to_noise_levels).to(self.device)
162
+
163
+ # update xs_pred by DDIM or DDPM sampling
164
+ # input frames within the sliding window
165
+
166
+ try:
167
+ input_condition = conditions[start_frame : curr_frame + horizon].clone()
168
+ except:
169
+ import pdb;pdb.set_trace()
170
+ if self.condtion_similar_length:
171
+ input_condition = torch.cat([conditions[start_frame : curr_frame + horizon], conditions[-self.condtion_similar_length:]], dim=0)
172
+ xs_pred[start_frame:] = self.diffusion_model.sample_step(
173
+ xs_pred[start_frame:],
174
+ input_condition,
175
+ from_noise_levels[start_frame:],
176
+ to_noise_levels[start_frame:],
177
+ )
178
+
179
+ if self.condtion_similar_length:
180
+ xs_pred = xs_pred[:-self.condtion_similar_length]
181
+
182
+ curr_frame += horizon
183
+ pbar.update(horizon)
184
+
185
+ if self.condtion_similar_length:
186
+ xs = xs[:-self.condtion_similar_length]
187
+ # FIXME: loss
188
+ loss = F.mse_loss(xs_pred, xs, reduction="none")
189
+ loss = self.reweight_loss(loss, masks)
190
+ self.validation_step_outputs.append((xs_pred.detach().cpu(), xs.detach().cpu()))
191
+
192
+ return loss
193
+
194
+ def test_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
195
+ return self.validation_step(*args, **kwargs, namespace="test")
196
+
197
+ def on_test_epoch_end(self) -> None:
198
+ self.on_validation_epoch_end(namespace="test")
199
+
200
+ def _generate_noise_levels(self, xs: torch.Tensor, masks: Optional[torch.Tensor] = None) -> torch.Tensor:
201
+ """
202
+ Generate noise levels for training.
203
+ """
204
+ num_frames, batch_size, *_ = xs.shape
205
+ match self.cfg.noise_level:
206
+ case "random_all": # entirely random noise levels
207
+ noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
208
+ case "same":
209
+ noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
210
+ noise_levels[1:] = noise_levels[0]
211
+
212
+ if masks is not None:
213
+ # for frames that are not available, treat as full noise
214
+ discard = torch.all(~rearrange(masks.bool(), "(t fs) b -> t b fs", fs=self.frame_stack), -1)
215
+ noise_levels = torch.where(discard, torch.full_like(noise_levels, self.timesteps - 1), noise_levels)
216
+
217
+ return noise_levels
218
+
219
+ def _generate_scheduling_matrix(self, horizon: int):
220
+ match self.cfg.scheduling_matrix:
221
+ case "pyramid":
222
+ return self._generate_pyramid_scheduling_matrix(horizon, self.uncertainty_scale)
223
+ case "full_sequence":
224
+ return np.arange(self.sampling_timesteps, -1, -1)[:, None].repeat(horizon, axis=1)
225
+ case "autoregressive":
226
+ return self._generate_pyramid_scheduling_matrix(horizon, self.sampling_timesteps)
227
+ case "trapezoid":
228
+ return self._generate_trapezoid_scheduling_matrix(horizon, self.uncertainty_scale)
229
+
230
+ def _generate_pyramid_scheduling_matrix(self, horizon: int, uncertainty_scale: float):
231
+ height = self.sampling_timesteps + int((horizon - 1) * uncertainty_scale) + 1
232
+ scheduling_matrix = np.zeros((height, horizon), dtype=np.int64)
233
+ for m in range(height):
234
+ for t in range(horizon):
235
+ scheduling_matrix[m, t] = self.sampling_timesteps + int(t * uncertainty_scale) - m
236
+
237
+ return np.clip(scheduling_matrix, 0, self.sampling_timesteps)
238
+
239
+ def _generate_trapezoid_scheduling_matrix(self, horizon: int, uncertainty_scale: float):
240
+ height = self.sampling_timesteps + int((horizon + 1) // 2 * uncertainty_scale)
241
+ scheduling_matrix = np.zeros((height, horizon), dtype=np.int64)
242
+ for m in range(height):
243
+ for t in range((horizon + 1) // 2):
244
+ scheduling_matrix[m, t] = self.sampling_timesteps + int(t * uncertainty_scale) - m
245
+ scheduling_matrix[m, -t] = self.sampling_timesteps + int(t * uncertainty_scale) - m
246
+
247
+ return np.clip(scheduling_matrix, 0, self.sampling_timesteps)
248
+
249
+ def reweight_loss(self, loss, weight=None):
250
+ # Note there is another part of loss reweighting (fused_snr) inside the Diffusion class!
251
+ loss = rearrange(loss, "t b (fs c) ... -> t b fs c ...", fs=self.frame_stack)
252
+ if weight is not None:
253
+ expand_dim = len(loss.shape) - len(weight.shape) - 1
254
+ weight = rearrange(
255
+ weight,
256
+ "(t fs) b ... -> t b fs ..." + " 1" * expand_dim,
257
+ fs=self.frame_stack,
258
+ )
259
+ loss = loss * weight
260
+
261
+ return loss.mean()
262
+
263
+ def _preprocess_batch(self, batch):
264
+ xs = batch[0]
265
+ batch_size, n_frames = xs.shape[:2]
266
+
267
+ if n_frames % self.frame_stack != 0:
268
+ raise ValueError("Number of frames must be divisible by frame stack size")
269
+ if self.context_frames % self.frame_stack != 0:
270
+ raise ValueError("Number of context frames must be divisible by frame stack size")
271
+
272
+ masks = torch.ones(n_frames, batch_size).to(xs.device)
273
+ n_frames = n_frames // self.frame_stack
274
+
275
+ if self.action_cond_dim:
276
+ conditions = batch[1]
277
+ conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1)
278
+ conditions = rearrange(conditions, "b (t fs) d -> t b (fs d)", fs=self.frame_stack).contiguous()
279
+
280
+ # f, _, _ = conditions.shape
281
+ # predefined_1 = torch.tensor([0,0,0,1]).to(conditions.device)
282
+ # predefined_2 = torch.tensor([0,0,1,0]).to(conditions.device)
283
+ # conditions[:f//2] = predefined_1
284
+ # conditions[f//2:] = predefined_2
285
+ else:
286
+ conditions = [None for _ in range(n_frames)]
287
+
288
+ xs = self._normalize_x(xs)
289
+ xs = rearrange(xs, "b (t fs) c ... -> t b (fs c) ...", fs=self.frame_stack).contiguous()
290
+
291
+ return xs, conditions, masks
292
+
293
+ def _normalize_x(self, xs):
294
+ shape = [1] * (xs.ndim - self.data_mean.ndim) + list(self.data_mean.shape)
295
+ mean = self.data_mean.reshape(shape)
296
+ std = self.data_std.reshape(shape)
297
+ return (xs - mean) / std
298
+
299
+ def _unnormalize_x(self, xs):
300
+ shape = [1] * (xs.ndim - self.data_mean.ndim) + list(self.data_mean.shape)
301
+ mean = self.data_mean.reshape(shape)
302
+ std = self.data_std.reshape(shape)
303
+ return xs * std + mean
304
+
305
+ def _unstack_and_unnormalize(self, xs):
306
+ xs = rearrange(xs, "t b (fs c) ... -> (t fs) b c ...", fs=self.frame_stack)
307
+ return self._unnormalize_x(xs)
algorithms/worldmem/df_video.py ADDED
@@ -0,0 +1,926 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torchvision.transforms.functional as TF
8
+ from torchvision.transforms import InterpolationMode
9
+ from PIL import Image
10
+ from packaging import version as pver
11
+ from einops import rearrange
12
+ from tqdm import tqdm
13
+ from omegaconf import DictConfig
14
+ from lightning.pytorch.utilities.types import STEP_OUTPUT
15
+ from algorithms.common.metrics import (
16
+ LearnedPerceptualImagePatchSimilarity,
17
+ )
18
+ from utils.logging_utils import log_video, get_validation_metrics_for_videos
19
+ from .df_base import DiffusionForcingBase
20
+ from .models.vae import VAE_models
21
+ from .models.diffusion import Diffusion
22
+ from .models.pose_prediction import PosePredictionNet
23
+ import glob
24
+
25
+ # Utility Functions
26
+ def euler_to_rotation_matrix(pitch, yaw):
27
+ """
28
+ Convert pitch and yaw angles (in radians) to a 3x3 rotation matrix.
29
+ Supports batch input.
30
+
31
+ Args:
32
+ pitch (torch.Tensor): Pitch angles in radians.
33
+ yaw (torch.Tensor): Yaw angles in radians.
34
+
35
+ Returns:
36
+ torch.Tensor: Rotation matrix of shape (batch_size, 3, 3).
37
+ """
38
+ cos_pitch, sin_pitch = torch.cos(pitch), torch.sin(pitch)
39
+ cos_yaw, sin_yaw = torch.cos(yaw), torch.sin(yaw)
40
+
41
+ R_pitch = torch.stack([
42
+ torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch),
43
+ torch.zeros_like(pitch), cos_pitch, -sin_pitch,
44
+ torch.zeros_like(pitch), sin_pitch, cos_pitch
45
+ ], dim=-1).reshape(-1, 3, 3)
46
+
47
+ R_yaw = torch.stack([
48
+ cos_yaw, torch.zeros_like(yaw), sin_yaw,
49
+ torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),
50
+ -sin_yaw, torch.zeros_like(yaw), cos_yaw
51
+ ], dim=-1).reshape(-1, 3, 3)
52
+
53
+ return torch.matmul(R_yaw, R_pitch)
54
+
55
+
56
+ def euler_to_camera_to_world_matrix(pose):
57
+ """
58
+ Convert (x, y, z, pitch, yaw) to a 4x4 camera-to-world transformation matrix using torch.
59
+ Supports both (5,) and (f, b, 5) shaped inputs.
60
+
61
+ Args:
62
+ pose (torch.Tensor): Pose tensor of shape (5,) or (f, b, 5).
63
+
64
+ Returns:
65
+ torch.Tensor: Camera-to-world transformation matrix of shape (4, 4).
66
+ """
67
+
68
+ origin_dim = pose.ndim
69
+ if origin_dim == 1:
70
+ pose = pose.unsqueeze(0).unsqueeze(0) # Convert (5,) -> (1, 1, 5)
71
+ elif origin_dim == 2:
72
+ pose = pose.unsqueeze(0)
73
+
74
+ x, y, z, pitch, yaw = pose[..., 0], pose[..., 1], pose[..., 2], pose[..., 3], pose[..., 4]
75
+ pitch, yaw = torch.deg2rad(pitch), torch.deg2rad(yaw)
76
+
77
+ # Compute rotation matrix (batch mode)
78
+ R = euler_to_rotation_matrix(pitch, yaw) # Shape (f*b, 3, 3)
79
+
80
+ # Create the 4x4 transformation matrix
81
+ eye = torch.eye(4, dtype=torch.float32, device=pose.device)
82
+ camera_to_world = eye.repeat(R.shape[0], 1, 1) # Shape (f*b, 4, 4)
83
+
84
+ # Assign rotation
85
+ camera_to_world[:, :3, :3] = R
86
+
87
+ # Assign translation
88
+ camera_to_world[:, :3, 3] = torch.stack([x.reshape(-1), y.reshape(-1), z.reshape(-1)], dim=-1)
89
+
90
+ # Reshape back to (f, b, 4, 4) if needed
91
+ if origin_dim == 3:
92
+ return camera_to_world.view(pose.shape[0], pose.shape[1], 4, 4)
93
+ elif origin_dim == 2:
94
+ return camera_to_world.view(pose.shape[0], 4, 4)
95
+ else:
96
+ return camera_to_world.squeeze(0).squeeze(0) # Convert (1,1,4,4) -> (4,4)
97
+
98
+ def is_inside_fov_3d_hv(points, center, center_pitch, center_yaw, fov_half_h, fov_half_v):
99
+ """
100
+ Check whether points are within a given 3D field of view (FOV)
101
+ with separately defined horizontal and vertical ranges.
102
+
103
+ The center view direction is specified by pitch and yaw (in degrees).
104
+
105
+ :param points: (N, B, 3) Sample point coordinates
106
+ :param center: (3,) Center coordinates of the FOV
107
+ :param center_pitch: Pitch angle of the center view (in degrees)
108
+ :param center_yaw: Yaw angle of the center view (in degrees)
109
+ :param fov_half_h: Horizontal half-FOV angle (in degrees)
110
+ :param fov_half_v: Vertical half-FOV angle (in degrees)
111
+ :return: Boolean tensor (N, B), indicating whether each point is inside the FOV
112
+ """
113
+ # Compute vectors relative to the center
114
+ vectors = points - center # shape (N, B, 3)
115
+ x = vectors[..., 0]
116
+ y = vectors[..., 1]
117
+ z = vectors[..., 2]
118
+
119
+ # Compute horizontal angle (yaw): measured with respect to the z-axis as the forward direction,
120
+ # and the x-axis as left-right, resulting in a range of -180 to 180 degrees.
121
+ azimuth = torch.atan2(x, z) * (180 / math.pi)
122
+
123
+ # Compute vertical angle (pitch): measured with respect to the horizontal plane,
124
+ # resulting in a range of -90 to 90 degrees.
125
+ elevation = torch.atan2(y, torch.sqrt(x**2 + z**2)) * (180 / math.pi)
126
+
127
+ # Compute the angular difference from the center view (handling circular angle wrap-around)
128
+ diff_azimuth = (azimuth - center_yaw).abs() % 360
129
+ diff_elevation = (elevation - center_pitch).abs() % 360
130
+
131
+ # Adjust values greater than 180 degrees to the shorter angular difference
132
+ diff_azimuth = torch.where(diff_azimuth > 180, 360 - diff_azimuth, diff_azimuth)
133
+ diff_elevation = torch.where(diff_elevation > 180, 360 - diff_elevation, diff_elevation)
134
+
135
+ # Check if both horizontal and vertical angles are within their respective FOV limits
136
+ return (diff_azimuth < fov_half_h) & (diff_elevation < fov_half_v)
137
+
138
def generate_points_in_sphere(n_points, radius):
    """
    Draw ``n_points`` random points inside a ball of the given radius,
    centred at the origin.

    :param n_points: Number of points to draw
    :param radius: Radius of the ball
    :return: Tensor of shape (n_points, 3) holding (x, y, z) per row
    """
    # Three independent uniform draws; the order (radius, azimuth, polar) is
    # kept fixed so that seeded runs reproduce the same points.
    u_radius = torch.rand(n_points)
    u_azimuth = torch.rand(n_points)
    u_polar = torch.rand(n_points)

    # Cube root makes the density uniform in volume instead of in radius.
    dist = radius * torch.pow(u_radius, 1/3)
    # Azimuth uniform over the full circle.
    azimuth = 2 * math.pi * u_azimuth
    # Polar angle chosen so that cos(polar) is uniform in [-1, 1].
    polar = torch.acos(1 - 2 * u_polar)

    # Spherical -> Cartesian.
    planar = dist * torch.sin(polar)
    coords = (
        planar * torch.cos(azimuth),
        planar * torch.sin(azimuth),
        dist * torch.cos(polar),
    )
    return torch.stack(coords, dim=1)
158
+
159
def tensor_max_with_number(tensor, number):
    """
    Return the element-wise maximum of ``tensor`` and the scalar ``number``.

    The scalar is materialised with the tensor's dtype and device before the
    comparison, so the result keeps the dtype/device of ``tensor``.
    """
    lower_bound = torch.tensor(number, dtype=tensor.dtype, device=tensor.device)
    return torch.maximum(tensor, lower_bound)
163
+
164
def custom_meshgrid(*args):
    """
    Version-tolerant wrapper around ``torch.meshgrid`` using 'ij' indexing.

    ``pver`` is a module-level alias imported outside this function
    (presumably ``packaging.version`` -- confirm against the file's imports).
    """
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    accepts_indexing_kwarg = pver.parse(torch.__version__) >= pver.parse('1.10')
    if accepts_indexing_kwarg:
        return torch.meshgrid(*args, indexing='ij')
    # Releases before 1.10 have no ``indexing`` keyword.
    return torch.meshgrid(*args)
170
+
171
def camera_to_world_to_world_to_camera(camera_to_world: torch.Tensor) -> torch.Tensor:
    """
    Invert rigid Camera-to-World matrices into World-to-Camera matrices.

    The inverse is built analytically as ``[R^T | -R^T T]``, which is only
    valid when the upper-left 3x3 block is a rotation (orthonormal).

    Args:
        camera_to_world (torch.Tensor): A tensor of shape (f, b, 4, 4), where:
            f = number of frames,
            b = batch size.

    Returns:
        torch.Tensor: A tensor of shape (f, b, 4, 4), with the same dtype and
        device as the input, holding the World-to-Camera matrices.

    Raises:
        AssertionError: If the input is not of shape (f, b, 4, 4).
    """
    # Ensure input is a 4D tensor of 4x4 matrices
    assert camera_to_world.ndim == 4 and camera_to_world.shape[2:] == (4, 4), \
        "Input must be of shape (f, b, 4, 4)"

    # Split into rotation (f, b, 3, 3) and translation (f, b, 3)
    rotation_t = camera_to_world[:, :, :3, :3].transpose(2, 3)
    translation = camera_to_world[:, :, :3, 3]

    # Allocate the result directly in the input dtype. Building it in the
    # default float32 and casting afterwards would silently truncate float64
    # poses to single precision.
    world_to_camera = torch.eye(
        4, device=camera_to_world.device, dtype=camera_to_world.dtype
    ).repeat(camera_to_world.size(0), camera_to_world.size(1), 1, 1)

    # Inverse rotation is the transpose
    world_to_camera[:, :, :3, :3] = rotation_t

    # Inverse translation is -R^T * T
    world_to_camera[:, :, :3, 3] = -torch.matmul(rotation_t, translation.unsqueeze(-1)).squeeze(-1)

    return world_to_camera
202
+
203
def convert_to_plucker(poses, curr_frame, focal_length, image_width, image_height):
    """
    Turn camera-to-world poses into a per-pixel Plucker ray embedding.

    Poses are first re-expressed relative to the frame at index ``curr_frame``
    (via ``get_relative_pose``), then rays are generated by ``ray_condition``.

    :param poses: Camera-to-world matrices, shape (t, b, 4, 4)
    :param curr_frame: Index along the first axis of the reference frame
    :param focal_length: Focal length as a fraction of the image size
    :param image_width: Image width in pixels
    :param image_height: Image height in pixels
    :return: Contiguous tensor laid out as (t, b, h, w, d)
    """
    # Pinhole intrinsics packed as (fx, fy, cx, cy), in pixels.
    fx_fy_cx_cy = np.asarray(
        [
            focal_length * image_width,
            focal_length * image_height,
            0.5 * image_width,
            0.5 * image_height,
        ],
        dtype=np.float32,
    )

    relative_c2ws = get_relative_pose(poses, zero_first_frame_scale=curr_frame)
    relative_c2ws = rearrange(relative_c2ws, "t b m n -> b t m n")
    n_batch, n_views = relative_c2ws.shape[0], relative_c2ws.shape[1]

    # Same intrinsics for every (batch, view) pair: [B, F, 4]
    intrinsics = torch.as_tensor(fx_fy_cx_cy, device=poses.device, dtype=poses.dtype)
    intrinsics = intrinsics.repeat(n_batch, n_views, 1)

    embedding = ray_condition(intrinsics, relative_c2ws, image_height, image_width, device=relative_c2ws.device)
    return rearrange(embedding, "b t h w d -> t b h w d").contiguous()
218
+
219
+
220
def get_relative_pose(abs_c2ws, zero_first_frame_scale):
    """
    Re-express absolute camera-to-world poses relative to one reference frame.

    :param abs_c2ws: Absolute camera-to-world matrices, shape (f, b, 4, 4)
    :param zero_first_frame_scale: Index (along the first axis) of the frame
        that becomes the reference; its own relative pose is the identity
    :return: Tensor of shape (f, b, 4, 4) with the relative poses
    """
    abs_w2cs = camera_to_world_to_world_to_camera(abs_c2ws)

    # The reference camera is placed at the canonical (identity) pose.
    canonical_c2w = torch.eye(4).to(abs_c2ws.device).to(abs_c2ws.dtype)

    # Maps absolute world coordinates into the reference camera's frame.
    to_reference = canonical_c2w @ abs_w2cs[zero_first_frame_scale]

    return torch.stack([to_reference @ frame_c2w for frame_c2w in abs_c2ws])
232
+
233
def ray_condition(K, c2w, H, W, device):
    """
    Build per-pixel Plucker coordinates (moment, direction) for every camera.

    :param K: Intrinsics (fx, fy, cx, cy), shape (B, V, 4)
    :param c2w: Camera-to-world matrices, shape (B, V, 4, 4)
    :param H: Image height in pixels
    :param W: Image width in pixels
    :param device: Device on which the pixel grid is created
    :return: Tensor of shape (B, V, H, W, 6): cross(origin, direction)
             followed by the unit ray direction
    """
    B = K.shape[0]
    V = c2w.shape[1]
    n_pixels = H * W

    rows, cols = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    # Shift by 0.5 so rays go through pixel centres.
    u = cols.reshape([1, 1, n_pixels]).expand([B, 1, n_pixels]) + 0.5  # [B, 1, HxW]
    v = rows.reshape([1, 1, n_pixels]).expand([B, 1, n_pixels]) + 0.5  # [B, 1, HxW]

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # each [B, V, 1]

    # Unit-depth camera-space directions; x and y are negated.
    depth = torch.ones_like(u, device=device, dtype=c2w.dtype)  # [B, 1, HxW]
    x_cam = -(u - cx) / fx * depth  # [B, V, HxW]
    y_cam = -(v - cy) / fy * depth  # [B, V, HxW]
    depth = depth.expand_as(y_cam)

    directions = torch.stack((x_cam, y_cam, depth), dim=-1)  # [B, V, HxW, 3]
    directions = directions / directions.norm(dim=-1, keepdim=True)

    # Rotate directions into the world frame (row-vector form: d @ R^T).
    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # [B, V, HxW, 3]
    # Every pixel of a camera shares that camera's origin.
    rays_o = c2w[..., :3, 3][:, :, None].expand_as(rays_d)  # [B, V, HxW, 3]

    moment = torch.linalg.cross(rays_o, rays_d)
    plucker = torch.cat([moment, rays_d], dim=-1)  # [B, V, HxW, 6]
    return plucker.reshape(B, V, H, W, 6)
266
+
267
def random_transform(tensor):
    """
    Warp every frame of every batch element with one shared random affine map
    (translation, rotation and scaling are drawn once per call).

    Args:
        tensor (torch.Tensor): Input tensor of shape (F, B, 3, H, W).

    Returns:
        torch.Tensor: Transformed tensor of shape (F, B, 3, H, W).

    Raises:
        ValueError: If the input is not 5-dimensional.
    """
    if tensor.ndim != 5:
        raise ValueError("Input tensor must have shape (F, B, 3, H, W)")

    n_frames, n_batch, n_channels, height, width = tensor.shape

    # Limits of the random perturbation
    max_translate = 0.2  # up to 20% of width/height
    max_rotate = 30      # up to 30 degrees
    max_scale = 0.2      # scale change of up to +/- 20%

    # One draw per parameter; the draw order is part of the seeded behaviour.
    shift_x = random.uniform(-max_translate, max_translate) * width
    shift_y = random.uniform(-max_translate, max_translate) * height
    angle = random.uniform(-max_rotate, max_rotate)
    zoom = 1 + random.uniform(-max_scale, max_scale)

    # Fold frames and batch together so a single call warps them identically.
    flat = tensor.reshape(n_frames * n_batch, n_channels, height, width)
    warped = TF.affine(
        flat,
        angle=angle,
        translate=(shift_x, shift_y),
        scale=zoom,
        shear=(0, 0),
        interpolation=InterpolationMode.BILINEAR,
        fill=0
    )

    return warped.reshape(n_frames, n_batch, n_channels, height, width)
307
+
308
def save_tensor_as_png(tensor, file_path):
    """
    Write a 3-channel (3, H, W) tensor to disk as an image file.

    Args:
        tensor (torch.Tensor): Input tensor of shape (3, H, W).
        file_path (str): Destination path of the PNG file.

    Raises:
        ValueError: If the tensor is not 3-dimensional with 3 leading channels.
    """
    is_three_channel_image = tensor.ndim == 3 and tensor.shape[0] == 3
    if not is_three_channel_image:
        raise ValueError("Input tensor must have shape (3, H, W)")

    # Convert to a PIL image and let PIL handle the encoding.
    TF.to_pil_image(tensor).save(file_path)
324
+
325
+ class BaseVideoDiTMinecraft(DiffusionForcingBase):
326
+ """
327
+ Video generation for MineCraft with memory.
328
+ """
329
+
330
+ def __init__(self, cfg: DictConfig):
331
+ """
332
+ Initialize the base video-DiT Minecraft class with the given configuration.
333
+
334
+ Args:
335
+ cfg (DictConfig): Configuration object.
336
+ """
337
+ self.n_tokens = cfg.n_frames // cfg.frame_stack # number of max tokens for the model
338
+ self.n_frames = cfg.n_frames
339
+ if hasattr(cfg, "n_tokens"):
340
+ self.n_tokens = cfg.n_tokens // cfg.frame_stack
341
+ self.memory_condition_length = cfg.memory_condition_length
342
+ self.pose_cond_dim = getattr(cfg, "pose_cond_dim", 5)
343
+
344
+ self.use_plucker = getattr(cfg, "use_plucker", True)
345
+ self.relative_embedding = getattr(cfg, "relative_embedding", True)
346
+ self.state_embed_only_on_qk = getattr(cfg, "state_embed_only_on_qk", True)
347
+ self.use_memory_attention = getattr(cfg, "use_memory_attention", True)
348
+ self.add_timestamp_embedding = getattr(cfg, "add_timestamp_embedding", True)
349
+ self.ref_mode = getattr(cfg, "ref_mode", 'sequential')
350
+ self.log_curve = getattr(cfg, "log_curve", False)
351
+ self.focal_length = getattr(cfg, "focal_length", 0.35)
352
+ self.log_video = cfg.log_video
353
+ self.save_local = getattr(cfg, "save_local", True)
354
+ self.local_save_dir = getattr(cfg, "local_save_dir", None)
355
+ self.lpips_batch_size = getattr(cfg, "lpips_batch_size", 16)
356
+ self.next_frame_length = getattr(cfg, "next_frame_length", 1)
357
+ self.require_pose_prediction = getattr(cfg, "require_pose_prediction", False)
358
+
359
+ super().__init__(cfg)
360
+
361
+ def _build_model(self):
362
+
363
+ self.diffusion_model = Diffusion(
364
+ reference_length=self.memory_condition_length,
365
+ x_shape=self.x_stacked_shape,
366
+ action_cond_dim=self.action_cond_dim,
367
+ pose_cond_dim=self.pose_cond_dim,
368
+ is_causal=self.causal,
369
+ cfg=self.cfg.diffusion,
370
+ is_dit=True,
371
+ use_plucker=self.use_plucker,
372
+ relative_embedding=self.relative_embedding,
373
+ state_embed_only_on_qk=self.state_embed_only_on_qk,
374
+ use_memory_attention=self.use_memory_attention,
375
+ add_timestamp_embedding=self.add_timestamp_embedding,
376
+ ref_mode=self.ref_mode
377
+ )
378
+
379
+ self.validation_lpips_model = LearnedPerceptualImagePatchSimilarity()
380
+ vae = VAE_models["vit-l-20-shallow-encoder"]()
381
+ self.vae = vae.eval()
382
+
383
+ if self.require_pose_prediction:
384
+ self.pose_prediction_model = PosePredictionNet()
385
+
386
+ def _generate_noise_levels(self, xs: torch.Tensor, masks = None) -> torch.Tensor:
387
+ """
388
+ Generate noise levels for training.
389
+ """
390
+ num_frames, batch_size, *_ = xs.shape
391
+ match self.cfg.noise_level:
392
+ case "random_all": # entirely random noise levels
393
+ noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
394
+ case "same":
395
+ noise_levels = torch.randint(0, self.timesteps, (num_frames, batch_size), device=xs.device)
396
+ noise_levels[1:] = noise_levels[0]
397
+
398
+ if masks is not None:
399
+ # for frames that are not available, treat as full noise
400
+ discard = torch.all(~rearrange(masks.bool(), "(t fs) b -> t b fs", fs=self.frame_stack), -1)
401
+ noise_levels = torch.where(discard, torch.full_like(noise_levels, self.timesteps - 1), noise_levels)
402
+
403
+ return noise_levels
404
+
405
+ def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
406
+ """
407
+ Perform a single training step.
408
+
409
+ This function processes the input batch,
410
+ encodes the input frames, generates noise levels, and computes the loss using the diffusion model.
411
+
412
+ Args:
413
+ batch: Input batch of data containing frames, conditions, poses, etc.
414
+ batch_idx: Index of the current batch.
415
+
416
+ Returns:
417
+ dict: A dictionary containing the training loss.
418
+ """
419
+ xs, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
420
+
421
+ if self.use_plucker:
422
+ if self.relative_embedding:
423
+ input_pose_condition = []
424
+ frame_idx_list = []
425
+ for i in range(self.n_frames):
426
+ input_pose_condition.append(
427
+ convert_to_plucker(
428
+ torch.cat([c2w_mat[i:i + 1], c2w_mat[-self.memory_condition_length:]]).clone(),
429
+ 0,
430
+ focal_length=self.focal_length,
431
+ image_height=xs.shape[-2],image_width=xs.shape[-1]
432
+ ).to(xs.dtype)
433
+ )
434
+ frame_idx_list.append(
435
+ torch.cat([
436
+ frame_idx[i:i + 1] - frame_idx[i:i + 1],
437
+ frame_idx[-self.memory_condition_length:] - frame_idx[i:i + 1]
438
+ ]).clone()
439
+ )
440
+ input_pose_condition = torch.cat(input_pose_condition)
441
+ frame_idx_list = torch.cat(frame_idx_list)
442
+ else:
443
+ input_pose_condition = convert_to_plucker(
444
+ c2w_mat, 0, focal_length=self.focal_length
445
+ ).to(xs.dtype)
446
+ frame_idx_list = frame_idx
447
+ else:
448
+ input_pose_condition = pose_conditions.to(xs.dtype)
449
+ frame_idx_list = None
450
+
451
+ xs = self.encode(xs)
452
+
453
+ noise_levels = self._generate_noise_levels(xs)
454
+
455
+ if self.memory_condition_length:
456
+ noise_levels[-self.memory_condition_length:] = self.diffusion_model.stabilization_level
457
+ conditions[-self.memory_condition_length:] *= 0
458
+
459
+ _, loss = self.diffusion_model(
460
+ xs,
461
+ conditions,
462
+ input_pose_condition,
463
+ noise_levels=noise_levels,
464
+ reference_length=self.memory_condition_length,
465
+ frame_idx=frame_idx_list
466
+ )
467
+
468
+ if self.memory_condition_length:
469
+ loss = loss[:-self.memory_condition_length]
470
+
471
+ loss = self.reweight_loss(loss, None)
472
+
473
+ if batch_idx % 20 == 0:
474
+ self.log("training/loss", loss.cpu())
475
+
476
+ return {"loss": loss}
477
+
478
+ def on_validation_epoch_end(self, namespace="validation") -> None:
479
+ if not self.validation_step_outputs:
480
+ return
481
+
482
+ xs_pred = []
483
+ xs = []
484
+ for pred, gt in self.validation_step_outputs:
485
+ xs_pred.append(pred)
486
+ xs.append(gt)
487
+
488
+ xs_pred = torch.cat(xs_pred, 1)
489
+ if gt is not None:
490
+ xs = torch.cat(xs, 1)
491
+ else:
492
+ xs = None
493
+
494
+ if self.logger and self.log_video:
495
+ log_video(
496
+ xs_pred,
497
+ xs,
498
+ step=None if namespace == "test" else self.global_step,
499
+ namespace=namespace + "_vis",
500
+ context_frames=self.context_frames,
501
+ logger=self.logger.experiment,
502
+ save_local=self.save_local,
503
+ local_save_dir=self.local_save_dir,
504
+ )
505
+
506
+ if xs is not None:
507
+ # Move data to the same device as LPIPS model for metric calculation
508
+ device = next(self.validation_lpips_model.parameters()).device
509
+ xs_pred_device = xs_pred.to(device)
510
+ xs_device = xs.to(device)
511
+
512
+ metric_dict = get_validation_metrics_for_videos(
513
+ xs_pred_device, xs_device,
514
+ lpips_model=self.validation_lpips_model,
515
+ lpips_batch_size=self.lpips_batch_size)
516
+
517
+ self.log_dict(
518
+ {"mse": metric_dict['mse'],
519
+ "psnr": metric_dict['psnr'],
520
+ "lpips": metric_dict['lpips']},
521
+ sync_dist=True
522
+ )
523
+
524
+ if self.log_curve:
525
+ psnr_values = metric_dict['frame_wise_psnr'].cpu().tolist()
526
+ frames = list(range(len(psnr_values)))
527
+ line_plot = wandb.plot.line_series(
528
+ xs = frames,
529
+ ys = [psnr_values],
530
+ keys = ["PSNR"],
531
+ title = "Frame-wise PSNR",
532
+ xname = "Frame index"
533
+ )
534
+
535
+ self.logger.experiment.log({"frame_wise_psnr_plot": line_plot})
536
+
537
+ self.validation_step_outputs.clear()
538
+
539
+ def _preprocess_batch(self, batch):
540
+
541
+ xs, conditions, pose_conditions, frame_index = batch
542
+
543
+ if self.action_cond_dim:
544
+ conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1)
545
+ conditions = rearrange(conditions, "b t d -> t b d").contiguous()
546
+ else:
547
+ raise NotImplementedError("Only support external cond.")
548
+
549
+ pose_conditions = rearrange(pose_conditions, "b t d -> t b d").contiguous()
550
+ c2w_mat = euler_to_camera_to_world_matrix(pose_conditions)
551
+ xs = rearrange(xs, "b t c ... -> t b c ...").contiguous()
552
+ frame_index = rearrange(frame_index, "b t -> t b").contiguous()
553
+
554
+ return xs, conditions, pose_conditions, c2w_mat, frame_index
555
+
556
+ def encode(self, x):
557
+ # vae encoding
558
+ T = x.shape[0]
559
+ H, W = x.shape[-2:]
560
+ scaling_factor = 0.07843137255
561
+
562
+ x = rearrange(x, "t b c h w -> (t b) c h w")
563
+ with torch.no_grad():
564
+ x = self.vae.encode(x * 2 - 1).mean * scaling_factor
565
+ x = rearrange(x, "(t b) (h w) c -> t b c h w", t=T, h=H // self.vae.patch_size, w=W // self.vae.patch_size)
566
+ return x
567
+
568
+ def decode(self, x):
569
+ total_frames = x.shape[0]
570
+ scaling_factor = 0.07843137255
571
+ x = rearrange(x, "t b c h w -> (t b) (h w) c")
572
+ with torch.no_grad():
573
+ x = (self.vae.decode(x / scaling_factor) + 1) / 2
574
+ x = rearrange(x, "(t b) c h w-> t b c h w", t=total_frames)
575
+ return x
576
+
577
+ def _generate_condition_indices(self, curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, horizon):
578
+ """
579
+ Generate indices for condition similarity based on the current frame and pose conditions.
580
+ """
581
+ if curr_frame < memory_condition_length:
582
+ random_idx = [i for i in range(curr_frame)] + [0] * (memory_condition_length - curr_frame)
583
+ random_idx = np.repeat(np.array(random_idx)[:, None], xs_pred.shape[1], -1)
584
+ else:
585
+ # Generate points in a sphere and filter based on field of view
586
+ num_samples = 10000
587
+ radius = 30
588
+ points = generate_points_in_sphere(num_samples, radius).to(pose_conditions.device)
589
+ points = points[:, None].repeat(1, pose_conditions.shape[1], 1)
590
+ points += pose_conditions[curr_frame, :, :3][None]
591
+ fov_half_h = torch.tensor(105 / 2, device=pose_conditions.device)
592
+ fov_half_v = torch.tensor(75 / 2, device=pose_conditions.device)
593
+
594
+ # in_fov1 = is_inside_fov_3d_hv(
595
+ # points, pose_conditions[curr_frame, :, :3],
596
+ # pose_conditions[curr_frame, :, -2], pose_conditions[curr_frame, :, -1],
597
+ # fov_half_h, fov_half_v
598
+ # )
599
+
600
+ in_fov1 = torch.stack([
601
+ is_inside_fov_3d_hv(points, pc[:, :3], pc[:, -2], pc[:, -1], fov_half_h, fov_half_v)
602
+ for pc in pose_conditions[curr_frame:curr_frame+horizon]
603
+ ])
604
+
605
+ in_fov1 = torch.sum(in_fov1, 0) > 0
606
+
607
+ # Compute overlap ratios and select indices
608
+ in_fov_list = torch.stack([
609
+ is_inside_fov_3d_hv(points, pc[:, :3], pc[:, -2], pc[:, -1], fov_half_h, fov_half_v)
610
+ for pc in pose_conditions[:curr_frame]
611
+ ])
612
+
613
+ random_idx = []
614
+ for _ in range(memory_condition_length):
615
+ overlap_ratio = ((in_fov1.bool() & in_fov_list).sum(1)) / in_fov1.sum()
616
+
617
+ confidence = overlap_ratio + (curr_frame - frame_idx[:curr_frame]) / curr_frame * (-0.2)
618
+
619
+ if len(random_idx) > 0:
620
+ confidence[torch.cat(random_idx)] = -1e10
621
+ _, r_idx = torch.topk(confidence, k=1, dim=0)
622
+ random_idx.append(r_idx[0])
623
+
624
+ # choice 1: directly remove overlapping region
625
+ occupied_mask = in_fov_list[r_idx[0, range(in_fov1.shape[-1])], :, range(in_fov1.shape[-1])].permute(1,0)
626
+ in_fov1 = in_fov1 & ~occupied_mask
627
+
628
+ # choice 2: apply similarity filter
629
+ # cos_sim = F.cosine_similarity(xs_pred.to(r_idx.device)[r_idx[:, range(in_fov1.shape[1])],
630
+ # range(in_fov1.shape[1])], xs_pred.to(r_idx.device)[:curr_frame], dim=2)
631
+ # cos_sim = cos_sim.mean((-2,-1))
632
+
633
+ # mask_sim = cos_sim>0.9
634
+ # in_fov_list = in_fov_list & ~mask_sim[:,None].to(in_fov_list.device)
635
+
636
+ random_idx = torch.stack(random_idx).cpu()
637
+
638
+ return random_idx
639
+
640
+ def _prepare_conditions(self,
641
+ start_frame, curr_frame, horizon, conditions,
642
+ pose_conditions, c2w_mat, frame_idx, random_idx,
643
+ image_width, image_height):
644
+ """
645
+ Prepare input conditions and pose conditions for sampling.
646
+ """
647
+
648
+ padding = torch.zeros((len(random_idx),) + conditions.shape[1:], device=conditions.device, dtype=conditions.dtype)
649
+ input_condition = torch.cat([conditions[start_frame:curr_frame + horizon], padding], dim=0)
650
+
651
+ batch_size = conditions.shape[1]
652
+
653
+ if self.use_plucker:
654
+ if self.relative_embedding:
655
+ frame_idx_list = []
656
+ input_pose_condition = []
657
+ for i in range(start_frame, curr_frame + horizon):
658
+ input_pose_condition.append(convert_to_plucker(torch.cat([c2w_mat[i:i+1],c2w_mat[random_idx[:,range(batch_size)], range(batch_size)]]).clone(), 0, focal_length=self.focal_length,
659
+ image_width=image_width, image_height=image_height).to(conditions.dtype))
660
+ frame_idx_list.append(torch.cat([frame_idx[i:i+1]-frame_idx[i:i+1], frame_idx[random_idx[:,range(batch_size)], range(batch_size)]-frame_idx[i:i+1]]))
661
+ input_pose_condition = torch.cat(input_pose_condition)
662
+ frame_idx_list = torch.cat(frame_idx_list)
663
+
664
+ else:
665
+ input_pose_condition = torch.cat([c2w_mat[start_frame : curr_frame + horizon], c2w_mat[random_idx[:,range(batch_size)], range(batch_size)]], dim=0).clone()
666
+ input_pose_condition = convert_to_plucker(input_pose_condition, 0, focal_length=self.focal_length)
667
+ frame_idx_list = None
668
+ else:
669
+ input_pose_condition = torch.cat([pose_conditions[start_frame : curr_frame + horizon], pose_conditions[random_idx[:,range(batch_size)], range(batch_size)]], dim=0).clone()
670
+ frame_idx_list = None
671
+
672
+ return input_condition, input_pose_condition, frame_idx_list
673
+
674
+ def _prepare_noise_levels(self, scheduling_matrix, m, curr_frame, batch_size, memory_condition_length):
675
+ """
676
+ Prepare noise levels for the current sampling step.
677
+ """
678
+ from_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m]))[:, None].repeat(batch_size, axis=1)
679
+ to_noise_levels = np.concatenate((np.zeros((curr_frame,), dtype=np.int64), scheduling_matrix[m + 1]))[:, None].repeat(batch_size, axis=1)
680
+ if memory_condition_length:
681
+ from_noise_levels = np.concatenate([from_noise_levels, np.zeros((memory_condition_length, from_noise_levels.shape[-1]), dtype=np.int32)], axis=0)
682
+ to_noise_levels = np.concatenate([to_noise_levels, np.zeros((memory_condition_length, from_noise_levels.shape[-1]), dtype=np.int32)], axis=0)
683
+ from_noise_levels = torch.from_numpy(from_noise_levels).to(self.device)
684
+ to_noise_levels = torch.from_numpy(to_noise_levels).to(self.device)
685
+ return from_noise_levels, to_noise_levels
686
+
687
+ def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT:
688
+ """
689
+ Perform a single validation step.
690
+
691
+ This function processes the input batch, encodes frames, generates predictions using a sliding window approach,
692
+ and handles condition similarity logic for sampling. The results are decoded and stored for evaluation.
693
+
694
+ Args:
695
+ batch: Input batch of data containing frames, conditions, poses, etc.
696
+ batch_idx: Index of the current batch.
697
+ namespace: Namespace for logging (default: "validation").
698
+
699
+ Returns:
700
+ None: Appends the predicted and ground truth frames to `self.validation_step_outputs`.
701
+ """
702
+ # Preprocess the input batch
703
+ memory_condition_length = self.memory_condition_length
704
+ xs_raw, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)
705
+
706
+
707
+ # Encode frames in chunks if necessary
708
+ total_frame = xs_raw.shape[0]
709
+ if total_frame > 10:
710
+ xs = torch.cat([
711
+ self.encode(xs_raw[int(total_frame * i / 10):int(total_frame * (i + 1) / 10)]).cpu()
712
+ for i in range(10)
713
+ ])
714
+ else:
715
+ xs = self.encode(xs_raw).cpu()
716
+
717
+ n_frames, batch_size, *_ = xs.shape
718
+ curr_frame = 0
719
+
720
+ # Initialize context frames
721
+ n_context_frames = self.context_frames // self.frame_stack
722
+ xs_pred = xs[:n_context_frames].clone()
723
+ curr_frame += n_context_frames
724
+
725
+ # Progress bar for sampling
726
+ pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
727
+
728
+ while curr_frame < n_frames:
729
+ # Determine the horizon for the current chunk
730
+ horizon = min(n_frames - curr_frame, self.chunk_size) if self.chunk_size > 0 else n_frames - curr_frame
731
+ assert horizon <= self.n_tokens, "Horizon exceeds the number of tokens."
732
+
733
+ # Generate scheduling matrix and initialize noise
734
+ scheduling_matrix = self._generate_scheduling_matrix(horizon)
735
+ chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:]))
736
+ chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise).to(xs_pred.device)
737
+ xs_pred = torch.cat([xs_pred, chunk], 0)
738
+
739
+ # Sliding window: only input the last `n_tokens` frames
740
+ start_frame = max(0, curr_frame + horizon - self.n_tokens)
741
+ pbar.set_postfix({"start": start_frame, "end": curr_frame + horizon})
742
+
743
+ # Handle condition similarity logic
744
+ if memory_condition_length:
745
+ random_idx = self._generate_condition_indices(
746
+ curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, horizon
747
+ )
748
+
749
+ xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
750
+
751
+ # Prepare input conditions and pose conditions
752
+ input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions(
753
+ start_frame, curr_frame, horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx,
754
+ image_width=xs_raw.shape[-1], image_height=xs_raw.shape[-2]
755
+ )
756
+
757
+ # Perform sampling for each step in the scheduling matrix
758
+ for m in range(scheduling_matrix.shape[0] - 1):
759
+ from_noise_levels, to_noise_levels = self._prepare_noise_levels(
760
+ scheduling_matrix, m, curr_frame, batch_size, memory_condition_length
761
+ )
762
+
763
+ xs_pred[start_frame:] = self.diffusion_model.sample_step(
764
+ xs_pred[start_frame:].to(input_condition.device),
765
+ input_condition,
766
+ input_pose_condition,
767
+ from_noise_levels[start_frame:],
768
+ to_noise_levels[start_frame:],
769
+ current_frame=curr_frame,
770
+ mode="validation",
771
+ reference_length=memory_condition_length,
772
+ frame_idx=frame_idx_list
773
+ ).cpu()
774
+
775
+ # Remove condition similarity frames if applicable
776
+ if memory_condition_length:
777
+ xs_pred = xs_pred[:-memory_condition_length]
778
+
779
+ curr_frame += horizon
780
+ pbar.update(horizon)
781
+
782
+ # Decode predictions and ground truth
783
+ xs_pred = self.decode(xs_pred[n_context_frames:].to(conditions.device))
784
+ xs_decode = self.decode(xs[n_context_frames:].to(conditions.device))
785
+
786
+ # Store results for evaluation (move to CPU to save GPU memory)
787
+ self.validation_step_outputs.append((xs_pred.detach().cpu(), xs_decode.detach().cpu()))
788
+ return
789
+
790
+ @torch.no_grad()
791
+ def interactive(self, first_frame, new_actions, first_pose, device,
792
+ memory_latent_frames, memory_actions, memory_poses, memory_c2w, memory_frame_idx):
793
+
794
+ memory_condition_length = self.memory_condition_length
795
+
796
+ if memory_latent_frames is None:
797
+ first_frame = torch.from_numpy(first_frame)
798
+ new_actions = torch.from_numpy(new_actions)
799
+ first_pose = torch.from_numpy(first_pose)
800
+ first_frame_encode = self.encode(first_frame[None, None].to(device))
801
+ memory_latent_frames = first_frame_encode.cpu()
802
+ memory_actions = new_actions[None, None].to(device)
803
+ memory_poses = first_pose[None, None].to(device)
804
+ new_c2w_mat = euler_to_camera_to_world_matrix(first_pose)
805
+ memory_c2w = new_c2w_mat[None, None].to(device)
806
+ memory_frame_idx = torch.tensor([[0]]).to(device)
807
+ return first_frame.cpu().numpy(), memory_latent_frames.cpu().numpy(), memory_actions.cpu().numpy(), memory_poses.cpu().numpy(), memory_c2w.cpu().numpy(), memory_frame_idx.cpu().numpy()
808
+ else:
809
+ memory_latent_frames = torch.from_numpy(memory_latent_frames)
810
+ memory_actions = torch.from_numpy(memory_actions).to(device)
811
+ memory_poses = torch.from_numpy(memory_poses).to(device)
812
+ memory_c2w = torch.from_numpy(memory_c2w).to(device)
813
+ memory_frame_idx = torch.from_numpy(memory_frame_idx).to(device)
814
+ new_actions = new_actions.to(device)
815
+
816
+ curr_frame = 0
817
+ batch_size = 1
818
+ horizon = self.next_frame_length
819
+ n_frames = curr_frame + horizon
820
+ # context
821
+ n_context_frames = len(memory_latent_frames)
822
+ xs_pred = memory_latent_frames[:n_context_frames].clone()
823
+ curr_frame += n_context_frames
824
+
825
+ pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
826
+
827
+ new_pose_condition_list = []
828
+ last_frame = xs_pred[-1].clone()
829
+ last_pose_condition = memory_poses[-1].clone()
830
+ curr_actions = new_actions.clone()
831
+ for hi in range(len(new_actions)):
832
+ last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15
833
+ new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None, hi], last_pose_condition)
834
+ new_pose_condition_offset[:,3:] = torch.round(new_pose_condition_offset[:,3:])
835
+ new_pose_condition = last_pose_condition + new_pose_condition_offset
836
+ new_pose_condition[:,3:] = new_pose_condition[:,3:] * 15
837
+ new_pose_condition[:,3:] %= 360
838
+ last_pose_condition = new_pose_condition.clone()
839
+ new_pose_condition_list.append(new_pose_condition[None])
840
+ new_pose_condition_list = torch.cat(new_pose_condition_list, 0)
841
+
842
+ ai = 0
843
+ while ai < len(new_actions):
844
+ next_horizon = min(horizon, len(new_actions) - ai)
845
+ last_frame = xs_pred[-1].clone()
846
+ curr_actions = new_actions[ai:ai+next_horizon].clone()
847
+
848
+ new_pose_condition = new_pose_condition_list[ai:ai+next_horizon].clone()
849
+
850
+ new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition)
851
+ memory_poses = torch.cat([memory_poses, new_pose_condition])
852
+ memory_actions = torch.cat([memory_actions, curr_actions[:, None]])
853
+ memory_c2w = torch.cat([memory_c2w, new_c2w_mat])
854
+ new_indices = memory_frame_idx[-1,0] + torch.arange(next_horizon, device=memory_frame_idx.device) + 1
855
+
856
+ memory_frame_idx = torch.cat([memory_frame_idx, new_indices[:, None]])
857
+
858
+ conditions = memory_actions.clone()
859
+ pose_conditions = memory_poses.clone()
860
+ c2w_mat = memory_c2w .clone()
861
+ frame_idx = memory_frame_idx.clone()
862
+
863
+ # generation on frame
864
+ scheduling_matrix = self._generate_scheduling_matrix(next_horizon)
865
+ chunk = torch.randn((next_horizon, batch_size, *xs_pred.shape[2:])).to(xs_pred.device)
866
+ chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise)
867
+
868
+ xs_pred = torch.cat([xs_pred, chunk], 0)
869
+
870
+ # sliding window: only input the last n_tokens frames
871
+ start_frame = max(0, curr_frame - self.n_tokens)
872
+
873
+ pbar.set_postfix(
874
+ {
875
+ "start": start_frame,
876
+ "end": curr_frame + next_horizon,
877
+ }
878
+ )
879
+
880
+ # Handle condition similarity logic
881
+ if memory_condition_length:
882
+ random_idx = self._generate_condition_indices(
883
+ curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, next_horizon
884
+ )
885
+
886
+ # random_idx = np.unique(random_idx)[:, None]
887
+ # memory_condition_length = len(random_idx)
888
+ xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0)
889
+
890
+ # Prepare input conditions and pose conditions
891
+ input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions(
892
+ start_frame, curr_frame, next_horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx,
893
+ image_width=first_frame.shape[-1], image_height=first_frame.shape[-2]
894
+ )
895
+
896
+ # Perform sampling for each step in the scheduling matrix
897
+ for m in range(scheduling_matrix.shape[0] - 1):
898
+ from_noise_levels, to_noise_levels = self._prepare_noise_levels(
899
+ scheduling_matrix, m, curr_frame, batch_size, memory_condition_length
900
+ )
901
+
902
+ xs_pred[start_frame:] = self.diffusion_model.sample_step(
903
+ xs_pred[start_frame:].to(input_condition.device),
904
+ input_condition,
905
+ input_pose_condition,
906
+ from_noise_levels[start_frame:],
907
+ to_noise_levels[start_frame:],
908
+ current_frame=curr_frame,
909
+ mode="validation",
910
+ reference_length=memory_condition_length,
911
+ frame_idx=frame_idx_list
912
+ ).cpu()
913
+
914
+
915
+ if memory_condition_length:
916
+ xs_pred = xs_pred[:-memory_condition_length]
917
+
918
+ curr_frame += next_horizon
919
+ pbar.update(next_horizon)
920
+ ai += next_horizon
921
+
922
+ memory_latent_frames = torch.cat([memory_latent_frames, xs_pred[n_context_frames:]])
923
+ xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu()
924
+
925
+ return xs_pred.cpu().numpy(), memory_latent_frames.cpu().numpy(), memory_actions.cpu().numpy(), \
926
+ memory_poses.cpu().numpy(), memory_c2w.cpu().numpy(), memory_frame_idx.cpu().numpy()
algorithms/worldmem/models/attention.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Based on https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/attention.py
3
+ """
4
+
5
+ from typing import Optional
6
+ from collections import namedtuple
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from einops import rearrange
11
+ from .rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
12
+ import numpy as np
13
+
14
class TemporalAxialAttention(nn.Module):
    """Multi-head self-attention along the time axis of a ``(B, T, H, W, D)`` tensor.

    Every spatial location of every batch element is folded into the batch
    dimension, so attention only mixes information across the ``T`` axis.
    Queries and keys are rotated with the supplied rotary embedding.

    Args:
        dim: Channel size of the input and output features.
        heads: Number of attention heads.
        dim_head: Channel size of each head.
        reference_length: Number of trailing time steps treated as reference
            frames. In causal mode those rows are masked so that each of them
            only attends to itself.
        rotary_emb: Rotary embedding used to rotate queries and keys.
        is_causal: Use a causal (lower-triangular) mask. Ignored when
            ``is_temporal_independent`` is true.
        is_temporal_independent: Restrict every time step to attend to itself
            only (diagonal mask).
        use_domain_adapter: Add a rank-8 LoRA-style correction to the q/k/v
            projection.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        reference_length: int,
        rotary_emb: RotaryEmbedding,
        is_causal: bool = True,
        is_temporal_independent: bool = False,
        use_domain_adapter=False,
    ):
        super().__init__()
        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = dim_head * heads
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)

        self.use_domain_adapter = use_domain_adapter
        if self.use_domain_adapter:
            lora_rank = 8
            self.lora_A = nn.Linear(dim, lora_rank, bias=False)
            self.lora_B = nn.Linear(lora_rank, self.inner_dim * 3, bias=False)

        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb
        self.is_causal = is_causal
        self.is_temporal_independent = is_temporal_independent

        self.reference_length = reference_length

    def forward(self, x: torch.Tensor):
        """Apply temporal attention.

        Args:
            x: Tensor unpacked as ``(B, T, H, W, D)``.

        Returns:
            Tensor with the same leading ``(B, T, H, W)`` layout and ``dim``
            output channels.
        """
        B, T, H, W, _ = x.shape

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        if self.use_domain_adapter:
            q_lora, k_lora, v_lora = self.lora_B(self.lora_A(x)).chunk(3, dim=-1)
            q = q + q_lora
            k = k + k_lora
            v = v + v_lora

        # Fold the spatial grid into the batch axis: attention runs over T only.
        q = rearrange(q, "B T H W (h d) -> (B H W) h T d", h=self.heads)
        k = rearrange(k, "B T H W (h d) -> (B H W) h T d", h=self.heads)
        v = rearrange(v, "B T H W (h d) -> (B H W) h T d", h=self.heads)

        q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
        k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)

        q, k, v = map(lambda t: t.contiguous(), (q, k, v))

        if self.is_temporal_independent:
            # Everything is masked except the diagonal: each step sees itself only.
            attn_bias = torch.ones((T, T), dtype=q.dtype, device=q.device)
            attn_bias = attn_bias.masked_fill(attn_bias == 1, float('-inf'))
            attn_bias[range(T), range(T)] = 0
        elif self.is_causal:
            # Standard causal mask: -inf strictly above the diagonal.
            attn_bias = torch.triu(torch.ones((T, T), dtype=q.dtype, device=q.device), diagonal=1)
            attn_bias = attn_bias.masked_fill(attn_bias == 1, float('-inf'))
            # The trailing reference rows are fully masked ...
            attn_bias[(T - self.reference_length):] = float('-inf')
            # ... and then re-opened on the diagonal so no row is all -inf.
            attn_bias[range(T), range(T)] = 0
        else:
            attn_bias = None

        # Errors are deliberately allowed to propagate: the previous bare
        # ``except`` dropped into pdb, which hangs non-interactive runs and,
        # when continued, silently projected the un-attended input.
        x = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_bias)

        x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
        x = x.to(q.dtype)

        # linear proj
        x = self.to_out(x)
        return x
109
+
110
class SpatialAxialAttention(nn.Module):
    """Multi-head self-attention over the ``H x W`` grid of each frame.

    The time axis is folded into the batch, so every frame is attended
    independently. Queries and keys receive axial rotary position embeddings
    obtained from ``rotary_emb.get_axial_freqs(H, W)``.

    Args:
        dim: Channel size of the input and output features.
        heads: Number of attention heads.
        dim_head: Channel size of each head.
        rotary_emb: Rotary embedding providing the axial frequencies.
        use_domain_adapter: Add a rank-8 LoRA-style correction to the q/k/v
            projection.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        rotary_emb: RotaryEmbedding,
        use_domain_adapter = False
    ):
        super().__init__()
        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = heads * dim_head

        qkv_width = 3 * self.inner_dim
        self.to_qkv = nn.Linear(dim, qkv_width, bias=False)

        self.use_domain_adapter = use_domain_adapter
        if use_domain_adapter:
            lora_rank = 8
            self.lora_A = nn.Linear(dim, lora_rank, bias=False)
            self.lora_B = nn.Linear(lora_rank, qkv_width, bias=False)

        self.to_out = nn.Linear(self.inner_dim, dim)
        self.rotary_emb = rotary_emb

    def forward(self, x: torch.Tensor):
        """Attend over the spatial positions of every frame.

        Args:
            x: Tensor unpacked as ``(B, T, H, W, D)``.

        Returns:
            Tensor with the same leading ``(B, T, H, W)`` layout and ``dim``
            output channels.
        """
        B, T, H, W, D = x.shape

        qkv = self.to_qkv(x)
        if self.use_domain_adapter:
            # Adding the low-rank delta before the split is element-wise
            # identical to adding it to q, k and v separately.
            qkv = qkv + self.lora_B(self.lora_A(x))

        split_heads = "B T H W (h d) -> (B T) h H W d"
        q, k, v = (rearrange(part, split_heads, h=self.heads) for part in qkv.chunk(3, dim=-1))

        # Rotary embeddings are applied while H and W are still separate axes.
        freqs = self.rotary_emb.get_axial_freqs(H, W)
        q = apply_rotary_emb(freqs, q)
        k = apply_rotary_emb(freqs, k)

        # prepare for attn: flatten the grid into one token axis
        flatten_grid = "(B T) h H W d -> (B T) h (H W) d"
        q, k, v = (rearrange(part, flatten_grid, B=B, T=T, h=self.heads) for part in (q, k, v))

        attended = F.scaled_dot_product_attention(query=q, key=k, value=v, is_causal=False)

        attended = rearrange(attended, "(B T) h (H W) d -> B T H W (h d)", B=B, H=H, W=W)
        attended = attended.to(q.dtype)

        # linear proj
        return self.to_out(attended)
167
+
168
class MemTemporalAxialAttention(nn.Module):
    """Temporal attention in which regular frames read from trailing memory frames.

    The last ``reference_length`` time steps are treated as memory frames.
    The additive mask lets each of the first ``T - reference_length`` steps
    attend to every memory frame and to itself; memory frames attend only to
    themselves. No rotary embedding is applied to queries or keys in
    ``forward``.

    Args:
        dim: Channel size of the input and output features.
        heads: Number of attention heads.
        dim_head: Channel size of each head.
        rotary_emb: Stored on the module; not used by ``forward``.
        is_causal: Stored on the module; not used by ``forward``.
        reference_length: Number of trailing memory frames. Defaults to 3,
            the value that was previously hard-coded.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        rotary_emb: RotaryEmbedding,
        is_causal: bool = True,
        reference_length: int = 3,
    ):
        super().__init__()
        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = dim_head * heads
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb
        self.is_causal = is_causal

        self.reference_length = reference_length

    def forward(self, x: torch.Tensor):
        """Apply memory-aware temporal attention.

        Args:
            x: Tensor unpacked as ``(B, T, H, W, D)``.

        Returns:
            Tensor with the same leading ``(B, T, H, W)`` layout and ``dim``
            output channels.
        """
        B, T, H, W, _ = x.shape

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # Fold the spatial grid into the batch axis: attention runs over T only.
        q = rearrange(q, "B T H W (h d) -> (B H W) h T d", h=self.heads)
        k = rearrange(k, "B T H W (h d) -> (B H W) h T d", h=self.heads)
        v = rearrange(v, "B T H W (h d) -> (B H W) h T d", h=self.heads)

        q, k, v = map(lambda t: t.contiguous(), (q, k, v))

        # Start fully masked, then open (a) regular-frame -> memory-frame
        # attention and (b) the diagonal so that no row is entirely -inf.
        attn_bias = torch.full((T, T), float('-inf'), dtype=q.dtype, device=q.device)
        T_origin = T - self.reference_length
        attn_bias[:T_origin, T_origin:] = 0
        attn_bias[range(T), range(T)] = 0

        # Errors propagate to the caller; the previous bare ``except`` that
        # dropped into pdb would hang any non-interactive run.
        x = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_bias)

        x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
        x = x.to(q.dtype)

        # linear proj
        x = self.to_out(x)
        return x
263
+
264
class MemFullAttention(nn.Module):
    """Full spatio-temporal attention from frames to reference (memory) frames.

    All ``T * H * W`` tokens of the selected frames are flattened into one
    sequence; keys and values are taken from reference frames only.

    Args:
        dim: Channel size of the input and output features.
        heads: Number of attention heads.
        dim_head: Channel size of each head.
        reference_length: Default number of reference frames, used whenever
            ``forward`` is called without an explicit ``reference_length``.
        rotary_emb: Stored on the module; not used by ``forward``.
        is_causal: Stored on the module; not used by ``forward``.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        reference_length: int,
        rotary_emb: RotaryEmbedding,
        is_causal: bool = True
    ):
        super().__init__()
        self.heads = heads
        self.head_dim = dim_head
        self.inner_dim = dim_head * heads
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb
        self.is_causal = is_causal

        self.reference_length = reference_length

        self.store = None

    def forward(self, x: torch.Tensor, relative_embedding=False,
                extra_condition=None,
                state_embed_only_on_qk=False,
                reference_length=None):
        """Attend from frames to their reference frames.

        Args:
            x: Tensor unpacked as ``(B, T, H, W, D)``.
            relative_embedding: When true, the time axis is read as
                consecutive groups of ``reference_length + 1`` frames: one
                target frame followed by its reference frames. Each target
                attends to its own group's references. For the last group
                all of its frames are used as queries, so the output holds
                ``n_groups - 1 + reference_length + 1`` frames along time.
            extra_condition: Tensor added to ``x`` before computing queries
                and keys; must be provided when ``state_embed_only_on_qk`` is
                true.
            state_embed_only_on_qk: Compute q/k from ``x + extra_condition``
                and v from ``x`` alone.
            reference_length: Number of reference frames; falls back to the
                value given to the constructor when ``None``.

        Returns:
            The attended features after the output projection.

        Raises:
            ValueError: If ``relative_embedding`` is true and ``T`` is not a
                multiple of ``reference_length + 1``.
        """
        B, T, H, W, _ = x.shape

        # Previously a missing argument crashed on ``None + 1`` / ``T - None``
        # even though the constructor stores a default.
        if reference_length is None:
            reference_length = self.reference_length

        if state_embed_only_on_qk:
            q, k, _ = self.to_qkv(x + extra_condition).chunk(3, dim=-1)
            _, _, v = self.to_qkv(x).chunk(3, dim=-1)
        else:
            q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        flatten = "B T H W (h d) -> B h (T H W) d"

        if relative_embedding:
            length = reference_length + 1
            n_frames = T // length
            if n_frames * length != T:
                raise ValueError(
                    f"time dimension {T} is not a multiple of reference_length + 1 = {length}"
                )

            x_list = []
            for i in range(n_frames):
                start = i * length
                # The last group queries with all of its frames; every other
                # group queries with its first (target) frame only.
                q_stop = None if i == n_frames - 1 else start + 1
                q_i = rearrange(q[:, start:q_stop], flatten, h=self.heads)
                # Keys/values are always the group's reference frames.
                k_i = rearrange(k[:, start + 1:start + length], flatten, h=self.heads)
                v_i = rearrange(v[:, start + 1:start + length], flatten, h=self.heads)

                q_i, k_i, v_i = map(lambda t: t.contiguous(), (q_i, k_i, v_i))
                x_i = F.scaled_dot_product_attention(query=q_i, key=k_i, value=v_i)
                x_i = rearrange(x_i, "B h (T H W) d -> B T H W (h d)", B=B, H=H, W=W)
                x_i = x_i.to(q.dtype)
                x_list.append(x_i)

            x = torch.cat(x_list, dim=1)

        else:
            # Every frame queries; only the trailing reference frames act as
            # keys and values.
            T_ = T - reference_length
            q = rearrange(q, flatten, h=self.heads)
            k = rearrange(k[:, T_:], flatten, h=self.heads)
            v = rearrange(v[:, T_:], flatten, h=self.heads)

            q, k, v = map(lambda t: t.contiguous(), (q, k, v))
            x = F.scaled_dot_product_attention(query=q, key=k, value=v)
            x = rearrange(x, "B h (T H W) d -> B T H W (h d)", B=B, H=H, W=W)
            x = x.to(q.dtype)

        # linear proj
        x = self.to_out(x)

        return x
algorithms/worldmem/models/diffusion.py ADDED
@@ -0,0 +1,594 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Callable
2
+ from collections import namedtuple
3
+ from omegaconf import DictConfig
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from einops import rearrange
8
+ from .utils import linear_beta_schedule, cosine_beta_schedule, sigmoid_beta_schedule, extract
9
+ from .dit import DiT_models
10
+
11
+ ModelPrediction = namedtuple("ModelPrediction", ["pred_noise", "pred_x_start", "model_out"])
12
+
13
+
14
+ class Diffusion(nn.Module):
15
+ # Special thanks to lucidrains for the implementation of the base Diffusion model
16
+ # https://github.com/lucidrains/denoising-diffusion-pytorch
17
+
18
+ def __init__(
19
+ self,
20
+ x_shape: torch.Size,
21
+ reference_length: int,
22
+ action_cond_dim: int,
23
+ pose_cond_dim,
24
+ is_causal: bool,
25
+ cfg: DictConfig,
26
+ is_dit: bool=False,
27
+ use_plucker=False,
28
+ relative_embedding=False,
29
+ state_embed_only_on_qk=False,
30
+ use_memory_attention=False,
31
+ add_timestamp_embedding=False,
32
+ memory_token_cross_attention=False,
33
+ memory_cross_attn_layers=None,
34
+ ref_mode='sequential'
35
+ ):
36
+ super().__init__()
37
+ self.cfg = cfg
38
+
39
+ self.x_shape = x_shape
40
+ self.action_cond_dim = action_cond_dim
41
+ self.timesteps = cfg.timesteps
42
+ self.sampling_timesteps = cfg.sampling_timesteps
43
+ self.beta_schedule = cfg.beta_schedule
44
+ self.schedule_fn_kwargs = cfg.schedule_fn_kwargs
45
+ self.objective = cfg.objective
46
+ self.use_fused_snr = cfg.use_fused_snr
47
+ self.snr_clip = cfg.snr_clip
48
+ self.cum_snr_decay = cfg.cum_snr_decay
49
+ self.ddim_sampling_eta = cfg.ddim_sampling_eta
50
+ self.clip_noise = cfg.clip_noise
51
+ self.arch = cfg.architecture
52
+ self.stabilization_level = cfg.stabilization_level
53
+ self.is_causal = is_causal
54
+ self.is_dit = is_dit
55
+ self.reference_length = reference_length
56
+ self.pose_cond_dim = pose_cond_dim
57
+ self.use_plucker = use_plucker
58
+ self.relative_embedding = relative_embedding
59
+ self.state_embed_only_on_qk = state_embed_only_on_qk
60
+ self.use_memory_attention = use_memory_attention
61
+ self.add_timestamp_embedding = add_timestamp_embedding
62
+ self.memory_token_cross_attention = memory_token_cross_attention
63
+ self.memory_cross_attn_layers = memory_cross_attn_layers
64
+ self.ref_mode = ref_mode
65
+ if self.use_memory_attention:
66
+ raise ValueError(
67
+ "WorldMem reference-frame use_memory_attention has been removed from DiT. "
68
+ "Use memory_token_cross_attention=True for compact memory tokens."
69
+ )
70
+
71
+ self._build_model()
72
+ self._build_buffer()
73
+
74
+ def _build_model(self):
75
+ x_channel = self.x_shape[0]
76
+ if self.is_dit:
77
+ self.model = DiT_models["DiT-S/2"](action_cond_dim=self.action_cond_dim,
78
+ reference_length=self.reference_length,
79
+ memory_token_cross_attention=self.memory_token_cross_attention,
80
+ memory_cross_attn_layers=self.memory_cross_attn_layers,
81
+ ref_mode=self.ref_mode)
82
+ else:
83
+ raise NotImplementedError
84
+
85
+ def _build_buffer(self):
86
+ if self.beta_schedule == "linear":
87
+ beta_schedule_fn = linear_beta_schedule
88
+ elif self.beta_schedule == "cosine":
89
+ beta_schedule_fn = cosine_beta_schedule
90
+ elif self.beta_schedule == "sigmoid":
91
+ beta_schedule_fn = sigmoid_beta_schedule
92
+ else:
93
+ raise ValueError(f"unknown beta schedule {self.beta_schedule}")
94
+
95
+ betas = beta_schedule_fn(self.timesteps, **self.schedule_fn_kwargs)
96
+
97
+ alphas = 1.0 - betas
98
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
99
+ alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
100
+
101
+ # sampling related parameters
102
+ assert self.sampling_timesteps <= self.timesteps
103
+ self.is_ddim_sampling = self.sampling_timesteps < self.timesteps
104
+
105
+ # helper function to register buffer from float64 to float32
106
+ register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
107
+
108
+ register_buffer("betas", betas)
109
+ register_buffer("alphas_cumprod", alphas_cumprod)
110
+ register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
111
+
112
+ # calculations for diffusion q(x_t | x_{t-1}) and others
113
+
114
+ register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
115
+ register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
116
+ register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
117
+ register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
118
+ register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))
119
+
120
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
121
+
122
+ posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
123
+
124
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
125
+
126
+ register_buffer("posterior_variance", posterior_variance)
127
+
128
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
129
+
130
+ register_buffer(
131
+ "posterior_log_variance_clipped",
132
+ torch.log(posterior_variance.clamp(min=1e-20)),
133
+ )
134
+ register_buffer(
135
+ "posterior_mean_coef1",
136
+ betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
137
+ )
138
+ register_buffer(
139
+ "posterior_mean_coef2",
140
+ (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod),
141
+ )
142
+
143
+ # calculate p2 reweighting
144
+
145
+ # register_buffer(
146
+ # "p2_loss_weight",
147
+ # (self.p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))
148
+ # ** -self.p2_loss_weight_gamma,
149
+ # )
150
+
151
+ # derive loss weight
152
+ # https://arxiv.org/abs/2303.09556
153
+ # snr: signal noise ratio
154
+ snr = alphas_cumprod / (1 - alphas_cumprod)
155
+ clipped_snr = snr.clone()
156
+ clipped_snr.clamp_(max=self.snr_clip)
157
+
158
+ register_buffer("clipped_snr", clipped_snr)
159
+ register_buffer("snr", snr)
160
+
161
+ def add_shape_channels(self, x):
162
+ return rearrange(x, f"... -> ...{' 1' * len(self.x_shape)}")
163
+
164
+ def model_predictions(self, x, t, action_cond=None, current_frame=None,
165
+ pose_cond=None, mode="training", reference_length=None, frame_idx=None,
166
+ memory_tokens=None, memory_token_mask=None, memory_retrieval_tokens=None, memory_retrieval_mask=None,
167
+ **memory_kwargs):
168
+ x = x.permute(1,0,2,3,4)
169
+ action_cond = action_cond.permute(1,0,2)
170
+ if pose_cond is not None and pose_cond[0] is not None:
171
+ try:
172
+ pose_cond = pose_cond.permute(1,0,2)
173
+ except:
174
+ pass
175
+ t = t.permute(1,0)
176
+ model_output = self.model(x, t, action_cond, current_frame=current_frame, pose_cond=pose_cond,
177
+ mode=mode, reference_length=reference_length, frame_idx=frame_idx,
178
+ memory_tokens=memory_tokens, memory_token_mask=memory_token_mask,
179
+ memory_retrieval_tokens=memory_retrieval_tokens,
180
+ memory_retrieval_mask=memory_retrieval_mask, **memory_kwargs)
181
+ model_output = model_output.permute(1,0,2,3,4)
182
+ x = x.permute(1,0,2,3,4)
183
+ t = t.permute(1,0)
184
+
185
+ if self.objective == "pred_noise":
186
+ pred_noise = torch.clamp(model_output, -self.clip_noise, self.clip_noise)
187
+ x_start = self.predict_start_from_noise(x, t, pred_noise)
188
+
189
+ elif self.objective == "pred_x0":
190
+ x_start = model_output
191
+ pred_noise = self.predict_noise_from_start(x, t, x_start)
192
+
193
+ elif self.objective == "pred_v":
194
+ v = model_output
195
+ x_start = self.predict_start_from_v(x, t, v)
196
+ pred_noise = self.predict_noise_from_start(x, t, x_start)
197
+
198
+
199
+ return ModelPrediction(pred_noise, x_start, model_output)
200
+
201
+ def predict_start_from_noise(self, x_t, t, noise):
202
+ return (
203
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
204
+ - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
205
+ )
206
+
207
+ def predict_noise_from_start(self, x_t, t, x0):
208
+ return (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / extract(
209
+ self.sqrt_recipm1_alphas_cumprod, t, x_t.shape
210
+ )
211
+
212
+ def predict_v(self, x_start, t, noise):
213
+ return (
214
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise
215
+ - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
216
+ )
217
+
218
+ def predict_start_from_v(self, x_t, t, v):
219
+ return (
220
+ extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
221
+ - extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
222
+ )
223
+
224
+ def q_mean_variance(self, x_start, t):
225
+ mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
226
+ variance = extract(1.0 - self.alphas_cumprod, t, x_start.shape)
227
+ log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
228
+ return mean, variance, log_variance
229
+
230
+ def q_posterior(self, x_start, x_t, t):
231
+ posterior_mean = (
232
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
233
+ + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
234
+ )
235
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
236
+ posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
237
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
238
+
239
+ def q_sample(self, x_start, t, noise=None):
240
+ if noise is None:
241
+ noise = torch.randn_like(x_start)
242
+ noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
243
+ return (
244
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
245
+ + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
246
+ )
247
+
248
+ def p_mean_variance(
249
+ self,
250
+ x,
251
+ t,
252
+ action_cond=None,
253
+ pose_cond=None,
254
+ reference_length=None,
255
+ frame_idx=None,
256
+ memory_tokens=None,
257
+ memory_token_mask=None,
258
+ memory_retrieval_tokens=None,
259
+ memory_retrieval_mask=None,
260
+ **memory_kwargs,
261
+ ):
262
+ model_pred = self.model_predictions(x=x, t=t, action_cond=action_cond,
263
+ pose_cond=pose_cond, reference_length=reference_length,
264
+ frame_idx=frame_idx, memory_tokens=memory_tokens, memory_token_mask=memory_token_mask,
265
+ memory_retrieval_tokens=memory_retrieval_tokens,
266
+ memory_retrieval_mask=memory_retrieval_mask, **memory_kwargs)
267
+ x_start = model_pred.pred_x_start
268
+ return self.q_posterior(x_start=x_start, x_t=x, t=t)
269
+
270
+ def compute_loss_weights(self, noise_levels: torch.Tensor):
271
+
272
+ snr = self.snr[noise_levels]
273
+ clipped_snr = self.clipped_snr[noise_levels]
274
+ normalized_clipped_snr = clipped_snr / self.snr_clip
275
+ normalized_snr = snr / self.snr_clip
276
+
277
+ if not self.use_fused_snr:
278
+ # min SNR reweighting
279
+ match self.objective:
280
+ case "pred_noise":
281
+ return clipped_snr / snr
282
+ case "pred_x0":
283
+ return clipped_snr
284
+ case "pred_v":
285
+ return clipped_snr / (snr + 1)
286
+
287
+ cum_snr = torch.zeros_like(normalized_snr)
288
+ for t in range(0, noise_levels.shape[0]):
289
+ if t == 0:
290
+ cum_snr[t] = normalized_clipped_snr[t]
291
+ else:
292
+ cum_snr[t] = self.cum_snr_decay * cum_snr[t - 1] + (1 - self.cum_snr_decay) * normalized_clipped_snr[t]
293
+
294
+ cum_snr = F.pad(cum_snr[:-1], (0, 0, 1, 0), value=0.0)
295
+ clipped_fused_snr = 1 - (1 - cum_snr * self.cum_snr_decay) * (1 - normalized_clipped_snr)
296
+ fused_snr = 1 - (1 - cum_snr * self.cum_snr_decay) * (1 - normalized_snr)
297
+
298
+ match self.objective:
299
+ case "pred_noise":
300
+ return clipped_fused_snr / fused_snr
301
+ case "pred_x0":
302
+ return clipped_fused_snr * self.snr_clip
303
+ case "pred_v":
304
+ return clipped_fused_snr * self.snr_clip / (fused_snr * self.snr_clip + 1)
305
+ case _:
306
+ raise ValueError(f"unknown objective {self.objective}")
307
+
308
+ def forward(
309
+ self,
310
+ x: torch.Tensor,
311
+ action_cond: Optional[torch.Tensor],
312
+ pose_cond,
313
+ noise_levels: torch.Tensor,
314
+ reference_length,
315
+ frame_idx=None,
316
+ memory_tokens=None,
317
+ memory_token_mask=None,
318
+ memory_retrieval_tokens=None,
319
+ memory_retrieval_mask=None,
320
+ **memory_kwargs,
321
+ ):
322
+ noise = torch.randn_like(x)
323
+ noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
324
+
325
+ noised_x = self.q_sample(x_start=x, t=noise_levels, noise=noise)
326
+
327
+ model_pred = self.model_predictions(x=noised_x, t=noise_levels, action_cond=action_cond,
328
+ pose_cond=pose_cond,reference_length=reference_length, frame_idx=frame_idx,
329
+ memory_tokens=memory_tokens, memory_token_mask=memory_token_mask,
330
+ memory_retrieval_tokens=memory_retrieval_tokens,
331
+ memory_retrieval_mask=memory_retrieval_mask, **memory_kwargs)
332
+
333
+ pred = model_pred.model_out
334
+ x_pred = model_pred.pred_x_start
335
+
336
+ if self.objective == "pred_noise":
337
+ target = noise
338
+ elif self.objective == "pred_x0":
339
+ target = x
340
+ elif self.objective == "pred_v":
341
+ target = self.predict_v(x, noise_levels, noise)
342
+ else:
343
+ raise ValueError(f"unknown objective {self.objective}")
344
+
345
+ # 训练的时候每个frame随便给噪声
346
+ loss = F.mse_loss(pred, target.detach(), reduction="none")
347
+ loss_weight = self.compute_loss_weights(noise_levels)
348
+
349
+ loss_weight = loss_weight.view(*loss_weight.shape, *((1,) * (loss.ndim - 2)))
350
+
351
+ loss = loss * loss_weight
352
+
353
+ return x_pred, loss
354
+
355
+ def sample_step(
356
+ self,
357
+ x: torch.Tensor,
358
+ action_cond: Optional[torch.Tensor],
359
+ pose_cond,
360
+ curr_noise_level: torch.Tensor,
361
+ next_noise_level: torch.Tensor,
362
+ guidance_fn: Optional[Callable] = None,
363
+ current_frame=None,
364
+ mode="training",
365
+ reference_length=None,
366
+ frame_idx=None,
367
+ memory_tokens=None,
368
+ memory_token_mask=None,
369
+ memory_retrieval_tokens=None,
370
+ memory_retrieval_mask=None,
371
+ **memory_kwargs,
372
+ ):
373
+ real_steps = torch.linspace(-1, self.timesteps - 1, steps=self.sampling_timesteps + 1, device=x.device).long()
374
+
375
+ # convert noise levels (0 ~ sampling_timesteps) to real noise levels (-1 ~ timesteps - 1)
376
+ curr_noise_level = real_steps[curr_noise_level]
377
+ next_noise_level = real_steps[next_noise_level]
378
+
379
+ if self.is_ddim_sampling:
380
+ return self.ddim_sample_step(
381
+ x=x,
382
+ action_cond=action_cond,
383
+ pose_cond=pose_cond,
384
+ curr_noise_level=curr_noise_level,
385
+ next_noise_level=next_noise_level,
386
+ guidance_fn=guidance_fn,
387
+ current_frame=current_frame,
388
+ mode=mode,
389
+ reference_length=reference_length,
390
+ frame_idx=frame_idx,
391
+ memory_tokens=memory_tokens,
392
+ memory_token_mask=memory_token_mask,
393
+ memory_retrieval_tokens=memory_retrieval_tokens,
394
+ memory_retrieval_mask=memory_retrieval_mask,
395
+ **memory_kwargs,
396
+ )
397
+
398
+ # FIXME: temporary code for checking ddpm sampling
399
+ assert torch.all(
400
+ (curr_noise_level - 1 == next_noise_level) | ((curr_noise_level == -1) & (next_noise_level == -1))
401
+ ), "Wrong noise level given for ddpm sampling."
402
+
403
+ assert (
404
+ self.sampling_timesteps == self.timesteps
405
+ ), "sampling_timesteps should be equal to timesteps for ddpm sampling."
406
+
407
+ return self.ddpm_sample_step(
408
+ x=x,
409
+ action_cond=action_cond,
410
+ pose_cond=pose_cond,
411
+ curr_noise_level=curr_noise_level,
412
+ guidance_fn=guidance_fn,
413
+ reference_length=reference_length,
414
+ frame_idx=frame_idx,
415
+ memory_tokens=memory_tokens,
416
+ memory_token_mask=memory_token_mask,
417
+ memory_retrieval_tokens=memory_retrieval_tokens,
418
+ memory_retrieval_mask=memory_retrieval_mask,
419
+ **memory_kwargs,
420
+ )
421
+
422
+ def ddpm_sample_step(
423
+ self,
424
+ x: torch.Tensor,
425
+ action_cond: Optional[torch.Tensor],
426
+ pose_cond,
427
+ curr_noise_level: torch.Tensor,
428
+ guidance_fn: Optional[Callable] = None,
429
+ reference_length=None,
430
+ frame_idx=None,
431
+ memory_tokens=None,
432
+ memory_token_mask=None,
433
+ memory_retrieval_tokens=None,
434
+ memory_retrieval_mask=None,
435
+ **memory_kwargs,
436
+ ):
437
+ clipped_curr_noise_level = torch.where(
438
+ curr_noise_level < 0,
439
+ torch.full_like(curr_noise_level, self.stabilization_level - 1, dtype=torch.long),
440
+ curr_noise_level,
441
+ )
442
+
443
+ # treating as stabilization would require us to scale with sqrt of alpha_cum
444
+ orig_x = x.clone().detach()
445
+ scaled_context = self.q_sample(
446
+ x,
447
+ clipped_curr_noise_level,
448
+ noise=torch.zeros_like(x),
449
+ )
450
+ x = torch.where(self.add_shape_channels(curr_noise_level < 0), scaled_context, orig_x)
451
+
452
+ if guidance_fn is not None:
453
+ raise NotImplementedError("Guidance function is not implemented for ddpm sampling yet.")
454
+
455
+ else:
456
+ model_mean, _, model_log_variance = self.p_mean_variance(
457
+ x=x,
458
+ t=clipped_curr_noise_level,
459
+ action_cond=action_cond,
460
+ pose_cond=pose_cond,
461
+ reference_length=reference_length,
462
+ frame_idx=frame_idx,
463
+ memory_tokens=memory_tokens,
464
+ memory_token_mask=memory_token_mask,
465
+ memory_retrieval_tokens=memory_retrieval_tokens,
466
+ memory_retrieval_mask=memory_retrieval_mask,
467
+ **memory_kwargs,
468
+ )
469
+
470
+ noise = torch.where(
471
+ self.add_shape_channels(clipped_curr_noise_level > 0),
472
+ torch.randn_like(x),
473
+ 0,
474
+ )
475
+ noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
476
+ x_pred = model_mean + torch.exp(0.5 * model_log_variance) * noise
477
+
478
+ # only update frames where the noise level decreases
479
+ return torch.where(self.add_shape_channels(curr_noise_level == -1), orig_x, x_pred)
480
+
481
+ def ddim_sample_step(
482
+ self,
483
+ x: torch.Tensor,
484
+ action_cond: Optional[torch.Tensor],
485
+ pose_cond,
486
+ curr_noise_level: torch.Tensor,
487
+ next_noise_level: torch.Tensor,
488
+ guidance_fn: Optional[Callable] = None,
489
+ current_frame=None,
490
+ mode="training",
491
+ reference_length=None,
492
+ frame_idx=None,
493
+ memory_tokens=None,
494
+ memory_token_mask=None,
495
+ memory_retrieval_tokens=None,
496
+ memory_retrieval_mask=None,
497
+ **memory_kwargs,
498
+ ):
499
+ # convert noise level -1 to self.stabilization_level - 1
500
+ clipped_curr_noise_level = torch.where(
501
+ curr_noise_level < 0,
502
+ torch.full_like(curr_noise_level, self.stabilization_level - 1, dtype=torch.long),
503
+ curr_noise_level,
504
+ )
505
+
506
+ # treating as stabilization would require us to scale with sqrt of alpha_cum
507
+ orig_x = x.clone().detach()
508
+ scaled_context = self.q_sample(
509
+ x,
510
+ clipped_curr_noise_level,
511
+ noise=torch.zeros_like(x),
512
+ )
513
+ x = torch.where(self.add_shape_channels(curr_noise_level < 0), scaled_context, orig_x)
514
+
515
+ alpha = self.alphas_cumprod[clipped_curr_noise_level]
516
+ alpha_next = torch.where(
517
+ next_noise_level < 0,
518
+ torch.ones_like(next_noise_level),
519
+ self.alphas_cumprod[next_noise_level],
520
+ )
521
+ sigma = torch.where(
522
+ next_noise_level < 0,
523
+ torch.zeros_like(next_noise_level),
524
+ self.ddim_sampling_eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt(),
525
+ )
526
+ c = (1 - alpha_next - sigma**2).sqrt()
527
+
528
+ alpha_next = self.add_shape_channels(alpha_next)
529
+ c = self.add_shape_channels(c)
530
+ sigma = self.add_shape_channels(sigma)
531
+
532
+ if guidance_fn is not None:
533
+ with torch.enable_grad():
534
+ x = x.detach().requires_grad_()
535
+
536
+ model_pred = self.model_predictions(
537
+ x=x,
538
+ t=clipped_curr_noise_level,
539
+ action_cond=action_cond,
540
+ pose_cond=pose_cond,
541
+ current_frame=current_frame,
542
+ mode=mode,
543
+ reference_length=reference_length,
544
+ frame_idx=frame_idx,
545
+ memory_tokens=memory_tokens,
546
+ memory_token_mask=memory_token_mask,
547
+ memory_retrieval_tokens=memory_retrieval_tokens,
548
+ memory_retrieval_mask=memory_retrieval_mask,
549
+ **memory_kwargs,
550
+ )
551
+
552
+ guidance_loss = guidance_fn(model_pred.pred_x_start)
553
+ grad = -torch.autograd.grad(
554
+ guidance_loss,
555
+ x,
556
+ )[0]
557
+
558
+ pred_noise = model_pred.pred_noise + (1 - alpha_next).sqrt() * grad
559
+ x_start = self.predict_start_from_noise(x, clipped_curr_noise_level, pred_noise)
560
+
561
+ else:
562
+ # print(clipped_curr_noise_level)
563
+ model_pred = self.model_predictions(
564
+ x=x,
565
+ t=clipped_curr_noise_level,
566
+ action_cond=action_cond,
567
+ pose_cond=pose_cond,
568
+ current_frame=current_frame,
569
+ mode=mode,
570
+ reference_length=reference_length,
571
+ frame_idx=frame_idx,
572
+ memory_tokens=memory_tokens,
573
+ memory_token_mask=memory_token_mask,
574
+ memory_retrieval_tokens=memory_retrieval_tokens,
575
+ memory_retrieval_mask=memory_retrieval_mask,
576
+ **memory_kwargs,
577
+ )
578
+ x_start = model_pred.pred_x_start
579
+ pred_noise = model_pred.pred_noise
580
+
581
+ noise = torch.randn_like(x)
582
+ noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)
583
+
584
+ x_pred = x_start * alpha_next.sqrt() + pred_noise * c + sigma * noise
585
+
586
+ # only update frames where the noise level decreases
587
+ mask = curr_noise_level == next_noise_level
588
+ x_pred = torch.where(
589
+ self.add_shape_channels(mask),
590
+ orig_x,
591
+ x_pred,
592
+ )
593
+
594
+ return x_pred
algorithms/worldmem/models/dit.py ADDED
@@ -0,0 +1,899 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ References:
3
+ - DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
4
+ - Diffusion Forcing: https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/unet3d.py
5
+ - Latte: https://github.com/Vchitect/Latte/blob/main/models/latte.py
6
+ """
7
+
8
+ from typing import Optional, Literal
9
+ import torch
10
+ from torch import nn
11
+ from .rotary_embedding_torch import RotaryEmbedding
12
+ from einops import rearrange
13
+ from .attention import SpatialAxialAttention, TemporalAxialAttention
14
+ from timm.models.vision_transformer import Mlp
15
+ from timm.layers.helpers import to_2tuple
16
+ import math
17
+ from collections import namedtuple
18
+ from typing import Optional, Callable
19
+
20
def modulate(x, shift, scale):
    """Apply an adaLN-style affine modulation: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` are tiled along dim 0 by ``x.shape[0] // their
    leading size`` and then receive singleton axes just before their last dim
    until ``shift`` has as many dims as ``x``, so that they broadcast over the
    intermediate axes of ``x``.
    """
    # Only dim 0 is tiled; every trailing dim of the modulation is kept as-is.
    keep_trailing = (1,) * (shift.dim() - 1)
    shift = shift.repeat(x.shape[0] // shift.shape[0], *keep_trailing)
    scale = scale.repeat(x.shape[0] // scale.shape[0], *keep_trailing)

    # Insert broadcast axes in front of the feature dim.
    for _ in range(x.dim() - shift.dim()):
        shift = shift.unsqueeze(-2)
        scale = scale.unsqueeze(-2)

    return x * (1 + scale) + shift
28
+
29
def gate(x, g):
    """Scale ``x`` elementwise by the gate ``g``.

    ``g`` is tiled along dim 0 by ``x.shape[0] // g.shape[0]`` and then given
    singleton axes just before its last dim until it has as many dims as
    ``x``, so that it broadcasts over the intermediate axes of ``x``.
    """
    tile = (x.shape[0] // g.shape[0],) + (1,) * (g.dim() - 1)
    g = g.repeat(*tile)
    for _ in range(x.dim() - g.dim()):
        g = g.unsqueeze(-2)
    return g * x
35
+
36
+
37
class PatchEmbed(nn.Module):
    """Embed a 2D image as a grid of non-overlapping patches.

    A strided convolution whose kernel and stride both equal the patch size
    projects every patch to ``embed_dim`` channels.
    """

    def __init__(
        self,
        img_height=256,
        img_width=256,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
    ):
        super().__init__()
        patch_hw = to_2tuple(patch_size)
        rows = img_height // patch_hw[0]
        cols = img_width // patch_hw[1]

        self.img_size = (img_height, img_width)
        self.patch_size = patch_hw
        self.grid_size = (rows, cols)
        self.num_patches = rows * cols
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw)
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x, random_sample=False):
        """Project ``x`` (B, C, H, W) to patch embeddings.

        Returns ``(B, H'*W', C')`` when ``flatten`` is set, else
        ``(B, H', W', C')``. The input-size check is skipped when
        ``random_sample`` is truthy.
        """
        _, _, H, W = x.shape
        if not random_sample:
            assert H == self.img_size[0] and W == self.img_size[1], f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        layout = "B C H W -> B (H W) C" if self.flatten else "B C H W -> B H W C"
        patches = rearrange(self.proj(x), layout)
        return self.norm(patches)
73
+
74
+
75
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A sinusoidal feature vector of size ``frequency_embedding_size`` is built
    from the input and passed through a two-layer MLP producing
    ``hidden_size`` features. ``freq_type`` selects the frequency bank used
    for the sinusoidal features (see :meth:`timestep_embedding`).
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, freq_type='time_step'):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),  # hidden_size is diffusion model hidden size
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.freq_type = freq_type

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000, freq_type='time_step'):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings
                           (only used when ``freq_type == 'time_step'``).
        :param freq_type: frequency bank to use:
                          ``'time_step'`` - geometric frequencies from 1 down
                          towards ``1 / max_period``;
                          ``'spatial'`` - ``k * pi`` for ``k = 1..dim // 2``;
                          ``'angle'`` - ``k * pi / 180`` for ``k = 1..dim // 2``.
        :return: an (N, D) Tensor of positional embeddings.
        :raises ValueError: if ``freq_type`` is not one of the values above.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2

        if freq_type == 'time_step':
            freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
        elif freq_type == 'spatial':  # ~(-5 5)
            freqs = torch.linspace(1.0, half, half).to(device=t.device) * torch.pi
        elif freq_type == 'angle':  # 0-360
            freqs = torch.linspace(1.0, half, half).to(device=t.device) * torch.pi / 180
        else:
            # Previously an unknown value fell through and crashed below with
            # an UnboundLocalError on `freqs`; fail with a clear message instead.
            raise ValueError(
                f"freq_type must be 'time_step', 'spatial' or 'angle', got {freq_type!r}"
            )

        args = t[:, None].float() * freqs[None]

        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd output size: pad one zero column so the result is (N, dim).
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        """Return the MLP embedding of the sinusoidal features of ``t``."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size, freq_type=self.freq_type)
        t_emb = self.mlp(t_freq)
        return t_emb
122
+
123
+
124
class FinalLayer(nn.Module):
    """
    Output head of DiT: adaLN-modulated LayerNorm followed by a linear
    projection to ``patch_size * patch_size * out_channels`` values per token.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        values_per_token = patch_size * patch_size * out_channels
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, values_per_token, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Modulate the normalised ``x`` with shift/scale derived from ``c`` and project."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))
140
+
141
+
142
# Memory stream types. The integer ids index MEMORY_TYPE_NAMES and are used as
# type ids for the packed memory tokens.
MEMORY_TYPE_NAMES = ("anchor", "dynamic", "revisit")
MEMORY_TYPE_ANCHOR, MEMORY_TYPE_DYNAMIC, MEMORY_TYPE_REVISIT = range(3)
146
+
147
+
148
class MemoryTokenCrossAttention(nn.Module):
    """Gated cross-attention from video tokens to external memory tokens.

    Queries come from the (adaLN-modulated) video features; keys and values
    are memory tokens that may carry a per-token validity mask, a per-token
    soft gate, and a per-token memory-type id. The attention output and a
    following MLP are each added to the residual through adaLN gates.

    ``reset_identity_init`` zeroes the last adaLN linear, so right after
    construction both gates are 0 and ``forward`` returns ``residual_base``
    unchanged.

    After each ``forward`` the ``last_*`` attributes hold detached scalar
    diagnostics (mean gate magnitude, residual delta ratio, fraction of rows
    with usable memory, and per-type stage gates).
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, num_memory_types=3):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.num_heads = num_heads
        self.num_memory_types = num_memory_types
        self.norm_q = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm_mem = nn.LayerNorm(hidden_size, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm_mlp = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))
        # Per-type additive embedding, multiplicative scale and conditioning-driven gate.
        self.memory_type_embed = nn.Embedding(num_memory_types, hidden_size)
        self.memory_type_scale = nn.Parameter(torch.ones(num_memory_types, hidden_size))
        self.memory_type_gate = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, num_memory_types, bias=True))
        # Diagnostics, refreshed on every forward pass.
        self.last_gate_mean = None
        self.last_delta_ratio = None
        self.last_valid_fraction = None
        self.last_type_gate_mean = None
        for type_name in MEMORY_TYPE_NAMES[:num_memory_types]:
            setattr(self, f"last_type_gate_{type_name}_mean", None)
        nn.init.normal_(self.memory_type_embed.weight, std=0.02)
        self.reset_identity_init()

    def reset_identity_init(self):
        """Zero the final adaLN and type-gate linears.

        With zero adaLN output the residual gates are 0 (block is an identity
        on ``residual_base``); with zero type-gate logits every type gate is
        ``sigmoid(0) = 0.5``.
        """
        nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.memory_type_gate[-1].weight, 0)
        nn.init.constant_(self.memory_type_gate[-1].bias, 0)

    def _attend(self, query, memory_tokens, memory_token_mask=None, memory_token_gate=None):
        """Run multi-head attention of ``query`` over ``memory_tokens``.

        :param query: ``(N, L, D)`` queries.
        :param memory_tokens: ``(N, M, D)`` keys/values.
        :param memory_token_mask: optional ``(N, M)`` mask, truthy = usable token.
        :param memory_token_gate: optional ``(N, M)`` soft gate; tokens with a
            gate ``<= 0`` are treated as masked, the rest bias the attention
            logits by ``log(gate)``.
        :return: ``(out, valid_rows)``. ``valid_rows`` is ``None`` when neither
            mask nor gate is given; otherwise it is an ``(N,)`` bool tensor
            marking rows with at least one usable token. Rows without any
            usable token get an all-zero output and are not attended at all.
        :raises ValueError: if ``memory_token_gate`` is not ``(N, M)``.
        """
        if memory_token_mask is None and memory_token_gate is None:
            out, _ = self.attn(query, memory_tokens, memory_tokens, need_weights=False)
            return out, None

        if memory_token_mask is None:
            memory_token_mask = torch.ones(
                memory_tokens.shape[:2],
                device=memory_tokens.device,
                dtype=torch.bool,
            )
        else:
            memory_token_mask = memory_token_mask.bool()
        gate_tensor = None
        if memory_token_gate is not None:
            if tuple(memory_token_gate.shape) != tuple(memory_tokens.shape[:2]):
                raise ValueError(
                    f"memory_token_gate must have shape {tuple(memory_tokens.shape[:2])}, "
                    f"got {tuple(memory_token_gate.shape)}"
                )
            gate_tensor = memory_token_gate.to(device=memory_tokens.device, dtype=query.dtype)
            memory_token_mask = memory_token_mask & (gate_tensor > 0)
        valid_rows = memory_token_mask.any(dim=1)
        out = torch.zeros_like(query)
        if valid_rows.any():
            # Only attend for rows that have at least one usable key; a fully
            # masked row would otherwise softmax over nothing.
            attn_mask = None
            key_padding_mask = ~memory_token_mask[valid_rows]
            if gate_tensor is not None:
                # Adding log(gate) to the logits multiplies the unnormalised
                # attention weight of each key by its gate.
                gate_bias = torch.log(gate_tensor[valid_rows].clamp_min(1.0e-6))
                gate_bias = gate_bias[:, None, :].expand(-1, query.shape[1], -1)
                # One copy of the bias per head, batch-major.
                attn_mask = gate_bias.repeat_interleave(self.num_heads, dim=0)
                # Use a float padding mask so it has the same kind as the float attn_mask.
                float_padding_mask = torch.zeros_like(gate_tensor[valid_rows], dtype=query.dtype)
                key_padding_mask = float_padding_mask.masked_fill(key_padding_mask, float("-inf"))
            attended, _ = self.attn(
                query[valid_rows],
                memory_tokens[valid_rows],
                memory_tokens[valid_rows],
                key_padding_mask=key_padding_mask,
                attn_mask=attn_mask,
                need_weights=False,
            )
            out[valid_rows] = attended.to(out.dtype)
        return out, valid_rows

    def _apply_memory_type(self, memory_tokens, memory_type_ids):
        """Scale and shift ``memory_tokens`` by their per-type parameters.

        Returns the tokens unchanged when ``memory_type_ids`` is ``None``.
        Leading singleton axes are added to the per-type tensors until they
        broadcast against ``memory_tokens``.
        """
        if memory_type_ids is None:
            return memory_tokens
        memory_type_ids = memory_type_ids.to(device=memory_tokens.device, dtype=torch.long)
        type_embed = self.memory_type_embed(memory_type_ids).to(memory_tokens.dtype)
        type_scale = self.memory_type_scale[memory_type_ids].to(memory_tokens.dtype)
        while type_embed.dim() < memory_tokens.dim():
            type_embed = type_embed.unsqueeze(0)
            type_scale = type_scale.unsqueeze(0)
        return memory_tokens * type_scale + type_embed

    def _store_type_gate_diagnostics(self, stage_gate):
        """Record the overall and per-type mean of ``stage_gate`` (last dim = type)."""
        with torch.no_grad():
            detached = stage_gate.detach().float()
            self.last_type_gate_mean = detached.mean()
            for type_idx, type_name in enumerate(MEMORY_TYPE_NAMES[: self.num_memory_types]):
                setattr(self, f"last_type_gate_{type_name}_mean", detached[..., type_idx].mean())

    def _type_stage_gate(self, c, memory_tokens, memory_type_ids):
        """Compute a per-token gate from the conditioning and each token's type.

        ``sigmoid(memory_type_gate(c))`` gives one gate per memory type; this
        picks, for every memory token, the gate of its type.

        :return: ``None`` if ``memory_type_ids`` is ``None``; ``(B, T, M)`` for
            rank-4 ``memory_tokens``; ``(B, M)`` (mean over dim 1 of the gate)
            for rank-3 ``memory_tokens``.
        :raises ValueError: on unsupported ranks/shapes.
        """
        if memory_type_ids is None:
            return None
        memory_type_ids = memory_type_ids.to(device=memory_tokens.device, dtype=torch.long)
        # NOTE(review): the gathers below assume c is (B, T, hidden) - confirm against caller.
        stage_gate = torch.sigmoid(self.memory_type_gate(c)).to(memory_tokens.dtype)
        self._store_type_gate_diagnostics(stage_gate)
        if memory_tokens.dim() == 4:
            batch_size, num_frames, num_tokens = memory_tokens.shape[:3]
            if memory_type_ids.dim() == 1:
                gather_ids = memory_type_ids.view(1, 1, num_tokens).expand(batch_size, num_frames, num_tokens)
            elif tuple(memory_type_ids.shape) == (batch_size, num_frames, num_tokens):
                gather_ids = memory_type_ids
            else:
                raise ValueError(
                    "rank-4 memory_type_ids must have shape (M,) or (B,T,M), "
                    f"got {tuple(memory_type_ids.shape)}"
                )
            return torch.gather(stage_gate, dim=-1, index=gather_ids)
        if memory_tokens.dim() == 3:
            batch_size, num_tokens = memory_tokens.shape[:2]
            if memory_type_ids.dim() != 1:
                raise ValueError("rank-3 memory_type_ids must have shape (M,)")
            gather_ids = memory_type_ids.view(1, 1, num_tokens).expand(batch_size, stage_gate.shape[1], num_tokens)
            # Shared memory has no frame axis, so average the gate over dim 1.
            return torch.gather(stage_gate, dim=-1, index=gather_ids).mean(dim=1)
        raise ValueError(f"memory_tokens must be rank 3 or 4, got rank {memory_tokens.dim()}")

    def _combine_memory_gate(self, memory_tokens, memory_token_gate, type_stage_gate):
        """Multiply the caller's per-token gate with the type-stage gate.

        Either input may be ``None``; the result is ``None`` only if both are.

        :raises ValueError: if ``memory_token_gate`` does not match
            ``memory_tokens.shape[:-1]``.
        """
        combined_gate = type_stage_gate
        if memory_token_gate is not None:
            if tuple(memory_token_gate.shape) != tuple(memory_tokens.shape[:-1]):
                raise ValueError(
                    f"memory_token_gate must have shape {tuple(memory_tokens.shape[:-1])}, "
                    f"got {tuple(memory_token_gate.shape)}"
                )
            stream_gate = memory_token_gate.to(device=memory_tokens.device, dtype=memory_tokens.dtype)
            combined_gate = stream_gate if combined_gate is None else combined_gate * stream_gate
        return combined_gate

    def _valid_mask(self, valid_rows, batch_size, num_frames, dtype, device):
        """Reshape ``valid_rows`` to broadcast against a ``(B, T, H, W, D)`` tensor.

        Accepts ``B`` or ``B*T`` entries; returns ``None`` for ``None`` input.
        """
        if valid_rows is None:
            return None
        valid_rows = valid_rows.to(device=device, dtype=dtype)
        if valid_rows.numel() == batch_size:
            return valid_rows.view(batch_size, 1, 1, 1, 1)
        if valid_rows.numel() == batch_size * num_frames:
            return rearrange(valid_rows, "(b t) -> b t", b=batch_size, t=num_frames)[:, :, None, None, None]
        raise ValueError(f"valid_rows has incompatible shape: {tuple(valid_rows.shape)}")

    def _gate_valid_mask(self, valid_rows, batch_size, num_frames, dtype, device):
        """Same as ``_valid_mask`` but shaped ``(B, 1, 1)`` / ``(B, T, 1)`` for rank-3 gate tensors."""
        if valid_rows is None:
            return None
        valid_rows = valid_rows.to(device=device, dtype=dtype)
        if valid_rows.numel() == batch_size:
            return valid_rows.view(batch_size, 1, 1)
        if valid_rows.numel() == batch_size * num_frames:
            return rearrange(valid_rows, "(b t) -> b t", b=batch_size, t=num_frames)[:, :, None]
        raise ValueError(f"valid_rows has incompatible shape: {tuple(valid_rows.shape)}")

    def _residual_gate(self, residual_gate, batch_size, num_frames, dtype, device):
        """Normalise ``residual_gate`` to a 5-D tensor broadcastable to ``(B, T, H, W, D)``.

        Accepts ``None`` (returned as ``None``), a Python scalar, or a tensor
        that is 0-D, ``(B,)``, ``(B*T,)``, ``(B, T)``, ``(B, T, X)``, or of
        rank >= 4 (padded with trailing singleton axes up to rank 5).

        :raises ValueError: for rank-1/2/3 tensors whose shape does not match.
        """
        if residual_gate is None:
            return None
        if not torch.is_tensor(residual_gate):
            return torch.tensor(float(residual_gate), dtype=dtype, device=device).view(1, 1, 1, 1, 1)
        gate_tensor = residual_gate.to(device=device, dtype=dtype)
        if gate_tensor.dim() == 0:
            gate_tensor = gate_tensor.view(1, 1, 1, 1, 1)
        elif gate_tensor.dim() == 1:
            if gate_tensor.numel() == batch_size:
                gate_tensor = gate_tensor.view(batch_size, 1, 1, 1, 1)
            elif gate_tensor.numel() == batch_size * num_frames:
                gate_tensor = rearrange(gate_tensor, "(b t) -> b t", b=batch_size, t=num_frames)[:, :, None, None, None]
            else:
                raise ValueError(f"residual_gate has incompatible shape: {tuple(gate_tensor.shape)}")
        elif gate_tensor.dim() == 2:
            if tuple(gate_tensor.shape) != (batch_size, num_frames):
                raise ValueError(f"residual_gate must have shape (B,T), got {tuple(gate_tensor.shape)}")
            gate_tensor = gate_tensor[:, :, None, None, None]
        elif gate_tensor.dim() == 3:
            if tuple(gate_tensor.shape[:2]) != (batch_size, num_frames):
                raise ValueError(f"residual_gate must start with (B,T), got {tuple(gate_tensor.shape)}")
            gate_tensor = gate_tensor[:, :, :, None, None]
        else:
            while gate_tensor.dim() < 5:
                gate_tensor = gate_tensor.unsqueeze(-1)
        return gate_tensor

    def _store_diagnostics(self, output, base, gate_msa, gate_mlp, valid_rows):
        """Record scalar diagnostics for the last forward pass.

        * ``last_gate_mean``: mean absolute adaLN gate over rows that had
          usable memory (all rows when ``valid_rows`` is ``None``).
        * ``last_valid_fraction``: fraction of rows with usable memory.
        * ``last_delta_ratio``: ``||output - base|| / (||base|| + 1e-6)``.
        """
        with torch.no_grad():
            batch_size, num_frames = base.shape[:2]
            gate_values = torch.cat(
                [gate_msa.detach().float().abs(), gate_mlp.detach().float().abs()],
                dim=-1,
            )
            gate_mask = self._gate_valid_mask(
                valid_rows,
                batch_size,
                num_frames,
                dtype=gate_values.dtype,
                device=gate_values.device,
            )
            if gate_mask is not None:
                gate_values = gate_values * gate_mask
                self.last_valid_fraction = valid_rows.detach().float().mean()
                # Count every gate entry the mask keeps. The mask may be
                # (B, 1, 1) while the gates are (B, T, 2D), so it must be
                # broadcast before summing; counting only mask.sum() * 2D
                # would inflate the mean by a factor of T.
                valid_count = gate_mask.expand_as(gate_values).sum().clamp_min(1.0)
                self.last_gate_mean = gate_values.sum() / valid_count
            else:
                self.last_valid_fraction = base.detach().new_tensor(1.0, dtype=torch.float32)
                self.last_gate_mean = gate_values.mean()

            delta_norm = (output.detach().float() - base.detach().float()).norm()
            base_norm = base.detach().float().norm()
            self.last_delta_ratio = delta_norm / (base_norm + 1e-6)

    def forward(
        self,
        x,
        c,
        memory_tokens,
        memory_token_mask=None,
        residual_base=None,
        return_delta=False,
        residual_gate=None,
        memory_type_ids=None,
        memory_token_gate=None,
    ):
        """Attend from ``x`` to ``memory_tokens`` and apply the gated residual update.

        :param x: ``(B, T, H, W, D)`` video features used to form the queries.
        :param c: conditioning fed to the adaLN and type-gate heads.
        :param memory_tokens: ``(B, M, D)`` memory shared by all frames (every
            position of a sample attends to the same tokens) or
            ``(B, T, M, D)`` per-frame memory.
        :param memory_token_mask: optional validity mask, ``(B, M)`` or
            ``(B, T, M)`` to match ``memory_tokens``.
        :param residual_base: tensor the deltas are added to; defaults to ``x``.
        :param return_delta: if true, return only ``attn_delta + mlp_delta``.
        :param residual_gate: optional extra multiplier on both deltas (see
            ``_residual_gate`` for accepted shapes).
        :param memory_type_ids: optional per-token memory-type ids.
        :param memory_token_gate: optional soft gate with shape
            ``memory_tokens.shape[:-1]``.
        :return: the updated features, or the summed delta if ``return_delta``.
        :raises ValueError: on rank/shape mismatches of the memory inputs.
        """
        B, T, H, W, D = x.shape
        if residual_base is None:
            residual_base = x
        m_shift_msa, m_scale_msa, m_gate_msa, m_shift_mlp, m_scale_mlp, m_gate_mlp = (
            self.adaLN_modulation(c).chunk(6, dim=-1)
        )
        query_source = modulate(self.norm_q(x), m_shift_msa, m_scale_msa)
        type_stage_gate = self._type_stage_gate(c, memory_tokens, memory_type_ids)
        effective_token_gate = self._combine_memory_gate(memory_tokens, memory_token_gate, type_stage_gate)
        if memory_tokens.dim() == 3:
            # Shared memory: flatten time and space into one query sequence per sample.
            query = rearrange(query_source, "b t h w d -> b (t h w) d")
            memory_tokens = self._apply_memory_type(self.norm_mem(memory_tokens), memory_type_ids)
            valid_rows = None
            if memory_token_mask is not None:
                if tuple(memory_token_mask.shape) != tuple(memory_tokens.shape[:2]):
                    raise ValueError(
                        f"legacy memory mask must have shape {tuple(memory_tokens.shape[:2])}, "
                        f"got {tuple(memory_token_mask.shape)}"
                    )
            out, valid_rows = self._attend(
                query,
                memory_tokens,
                memory_token_mask=memory_token_mask,
                memory_token_gate=effective_token_gate,
            )
            out = rearrange(out, "b (t h w) d -> b t h w d", t=T, h=H, w=W)
        elif memory_tokens.dim() == 4:
            assert memory_tokens.shape[:2] == (B, T), (
                f"per-frame memory tokens must have shape (B, T, M, D), got {tuple(memory_tokens.shape)}"
            )
            # Per-frame memory: fold time into the batch so each frame attends to its own tokens.
            query = rearrange(query_source, "b t h w d -> (b t) (h w) d")
            memory_tokens = self._apply_memory_type(self.norm_mem(memory_tokens), memory_type_ids)
            memory_tokens = rearrange(memory_tokens, "b t m d -> (b t) m d")
            if effective_token_gate is not None:
                effective_token_gate = rearrange(effective_token_gate, "b t m -> (b t) m")
            valid_rows = None
            if memory_token_mask is not None:
                expected_mask_shape = (B, T, memory_tokens.shape[1])
                if tuple(memory_token_mask.shape) != expected_mask_shape:
                    raise ValueError(
                        f"per-frame memory mask must have shape {expected_mask_shape}, "
                        f"got {tuple(memory_token_mask.shape)}"
                    )
                memory_token_mask = rearrange(memory_token_mask.bool(), "b t m -> (b t) m")
            out, valid_rows = self._attend(
                query,
                memory_tokens,
                memory_token_mask=memory_token_mask,
                memory_token_gate=effective_token_gate,
            )
            out = rearrange(out, "(b t) (h w) d -> b t h w d", b=B, t=T, h=H, w=W)
        else:
            raise ValueError(f"memory_tokens must be rank 3 or 4, got rank {memory_tokens.dim()}")

        valid_mask = self._valid_mask(valid_rows, B, T, dtype=out.dtype, device=out.device)
        residual_gate_tensor = self._residual_gate(residual_gate, B, T, dtype=out.dtype, device=out.device)
        attn_delta = gate(out, m_gate_msa)
        if valid_mask is not None:
            attn_delta = attn_delta * valid_mask
        if residual_gate_tensor is not None:
            attn_delta = attn_delta * residual_gate_tensor
        output = residual_base + attn_delta

        mlp_delta = gate(self.mlp(modulate(self.norm_mlp(output), m_shift_mlp, m_scale_mlp)), m_gate_mlp)
        # Rows without usable memory must stay untouched, so the MLP update is masked too.
        if valid_mask is not None:
            mlp_delta = mlp_delta * valid_mask
        if residual_gate_tensor is not None:
            mlp_delta = mlp_delta * residual_gate_tensor
        output = output + mlp_delta
        self._store_diagnostics(output, residual_base, m_gate_msa, m_gate_mlp, valid_rows)
        if return_delta:
            return attn_delta + mlp_delta
        return output
444
+
445
class SpatioTemporalDiTBlock(nn.Module):
    """DiT block with a spatial stage, a temporal stage and optional memory cross-attention.

    Each stage is an adaLN-modulated attention + MLP pair with gated
    residuals. With ``ref_mode == 'sequential'`` the temporal output replaces
    ``x`` before the memory cross-attention; with ``ref_mode == 'parallel'``
    the memory cross-attention runs on the spatial output and the result is
    mapped through ``parallel_map`` and added to the temporal output.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        reference_length,
        mlp_ratio=4.0,
        is_causal=True,
        spatial_rotary_emb: Optional[RotaryEmbedding] = None,
        temporal_rotary_emb: Optional[RotaryEmbedding] = None,
        use_memory_token_cross_attention=False,
        ref_mode='sequential'
    ):
        super().__init__()
        self.is_causal = is_causal
        head_dim = hidden_size // num_heads
        mlp_width = int(hidden_size * mlp_ratio)

        def tanh_gelu():
            return nn.GELU(approximate="tanh")

        def plain_norm():
            return nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        def feed_forward():
            return Mlp(
                in_features=hidden_size,
                hidden_features=mlp_width,
                act_layer=tanh_gelu,
                drop=0,
            )

        def adaln_head():
            return nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True))

        # Spatial stage (modules are created in the same order as they are registered).
        self.s_norm1 = plain_norm()
        self.s_attn = SpatialAxialAttention(
            hidden_size,
            heads=num_heads,
            dim_head=head_dim,
            rotary_emb=spatial_rotary_emb
        )
        self.s_norm2 = plain_norm()
        self.s_mlp = feed_forward()
        self.s_adaLN_modulation = adaln_head()

        # Temporal stage.
        self.t_norm1 = plain_norm()
        self.t_attn = TemporalAxialAttention(
            hidden_size,
            heads=num_heads,
            dim_head=head_dim,
            is_causal=is_causal,
            rotary_emb=temporal_rotary_emb,
            reference_length=reference_length
        )
        self.t_norm2 = plain_norm()
        self.t_mlp = feed_forward()
        self.t_adaLN_modulation = adaln_head()

        self.reference_length = reference_length
        self.use_memory_token_cross_attention = use_memory_token_cross_attention
        if use_memory_token_cross_attention:
            self.memory_token_cross_attn = MemoryTokenCrossAttention(hidden_size, num_heads, mlp_ratio=mlp_ratio)

        self.ref_mode = ref_mode
        if ref_mode == 'parallel':
            self.parallel_map = nn.Linear(hidden_size, hidden_size)

    def _expand_memory_stream(self, tokens, mask, stream_gate, type_idx, batch_size, num_frames):
        """Bring one memory stream to per-frame form.

        Returns ``None`` for a missing or empty stream, otherwise the tuple
        ``(tokens (B,T,M,D), bool mask (B,T,M), gate (B,T,M), type ids (M,))``.
        Rank-3 tokens/rank-2 masks are repeated over the frame axis; a missing
        mask becomes all-true.
        """
        if tokens is None or tokens.shape[-2] == 0:
            return None

        if tokens.dim() not in (3, 4):
            raise ValueError(f"memory stream tokens must be rank 3 or 4, got rank {tokens.dim()}")

        if tokens.dim() == 3:
            if tokens.shape[0] != batch_size:
                raise ValueError(f"rank-3 memory tokens must start with B={batch_size}, got {tuple(tokens.shape)}")
            tokens = tokens[:, None].expand(-1, num_frames, -1, -1)
            if mask is not None:
                if mask.dim() == 2:
                    mask = mask[:, None].expand(-1, num_frames, -1)
                elif mask.dim() != 3:
                    raise ValueError(f"rank-3 stream mask must have rank 2 or 3, got {tuple(mask.shape)}")
        else:
            if tuple(tokens.shape[:2]) != (batch_size, num_frames):
                raise ValueError(
                    f"rank-4 memory tokens must start with (B,T)={(batch_size, num_frames)}, "
                    f"got {tuple(tokens.shape)}"
                )
            if mask is not None and mask.dim() != 3:
                raise ValueError(f"rank-4 stream mask must have rank 3, got {tuple(mask.shape)}")

        per_frame_shape = tokens.shape[:3]
        if mask is None:
            mask = torch.ones(per_frame_shape, device=tokens.device, dtype=torch.bool)
        if tuple(mask.shape) != tuple(per_frame_shape):
            raise ValueError(f"memory stream mask must have shape {tuple(tokens.shape[:3])}, got {tuple(mask.shape)}")

        gate_tensor = self._expand_memory_stream_gate(stream_gate, tokens)
        type_ids = torch.full((tokens.shape[2],), int(type_idx), device=tokens.device, dtype=torch.long)
        bool_mask = mask.to(device=tokens.device, dtype=torch.bool)
        return tokens, bool_mask, gate_tensor, type_ids

    def _expand_memory_stream_gate(self, stream_gate, tokens):
        """Broadcast ``stream_gate`` to the ``(B, T, M)`` prefix of ``tokens``.

        Accepts ``None`` (all ones), a Python scalar, or a tensor shaped
        ``()``, ``(B,)``, ``(B,T)``, ``(B,M)``, ``(B,T,1)`` or ``(B,T,M)``.
        A rank-2 gate is matched against ``(B,T)`` before ``(B,M)``.
        """
        batch_size, num_frames, num_tokens = tokens.shape[:3]
        target = (batch_size, num_frames, num_tokens)

        if stream_gate is None:
            return torch.ones(tokens.shape[:3], device=tokens.device, dtype=tokens.dtype)
        if not torch.is_tensor(stream_gate):
            return torch.full(tokens.shape[:3], float(stream_gate), device=tokens.device, dtype=tokens.dtype)

        gate_tensor = stream_gate.to(device=tokens.device, dtype=tokens.dtype)
        rank = gate_tensor.dim()
        shape = tuple(gate_tensor.shape)

        if rank == 0:
            return gate_tensor.view(1, 1, 1).expand(*target)

        if rank == 1:
            if gate_tensor.numel() != batch_size:
                raise ValueError(f"rank-1 memory gate must have B={batch_size} values, got {tuple(gate_tensor.shape)}")
            return gate_tensor.view(batch_size, 1, 1).expand(*target)

        if rank == 2:
            if shape == (batch_size, num_frames):
                return gate_tensor[:, :, None].expand(*target)
            if shape == (batch_size, num_tokens):
                return gate_tensor[:, None, :].expand(*target)
            raise ValueError(
                f"rank-2 memory gate must have shape (B,T) or (B,M), got {tuple(gate_tensor.shape)}"
            )

        if rank == 3:
            if shape == (batch_size, num_frames, 1):
                return gate_tensor.expand(*target)
            if shape == target:
                return gate_tensor
            raise ValueError(
                f"rank-3 memory gate must have shape (B,T,1) or (B,T,M), got {tuple(gate_tensor.shape)}"
            )

        raise ValueError(f"memory gate rank must be <=3, got rank {gate_tensor.dim()}")

    def _pack_typed_memory_streams(
        self,
        batch_size,
        num_frames,
        memory_tokens=None,
        memory_token_mask=None,
        memory_dynamic_tokens=None,
        memory_dynamic_mask=None,
        memory_retrieval_tokens=None,
        memory_retrieval_mask=None,
        memory_anchor_gate=None,
        memory_dynamic_gate=None,
        memory_retrieval_gate=None,
    ):
        """Concatenate the anchor, dynamic and retrieval streams along the token axis.

        Returns ``None`` when every stream is missing or empty, otherwise
        ``(tokens, mask, gate, type_ids, residual_gate)`` where
        ``residual_gate`` is, per frame, the largest gate among unmasked tokens.
        """
        stream_specs = (
            (memory_tokens, memory_token_mask, memory_anchor_gate, MEMORY_TYPE_ANCHOR),
            (memory_dynamic_tokens, memory_dynamic_mask, memory_dynamic_gate, MEMORY_TYPE_DYNAMIC),
            (memory_retrieval_tokens, memory_retrieval_mask, memory_retrieval_gate, MEMORY_TYPE_REVISIT),
        )
        expanded_streams = (
            self._expand_memory_stream(tokens, mask, stream_gate, type_idx, batch_size, num_frames)
            for tokens, mask, stream_gate, type_idx in stream_specs
        )
        streams = [stream for stream in expanded_streams if stream is not None]
        if not streams:
            return None

        token_parts, mask_parts, gate_parts, type_parts = zip(*streams)
        packed_tokens = torch.cat(token_parts, dim=2)
        packed_mask = torch.cat(mask_parts, dim=2)
        packed_gate = torch.cat(gate_parts, dim=2)
        packed_type_ids = torch.cat(type_parts, dim=0)

        # Masked tokens must not contribute to the per-frame residual gate.
        residual_gate = packed_gate.masked_fill(~packed_mask, 0).max(dim=2).values
        return packed_tokens, packed_mask, packed_gate, packed_type_ids, residual_gate

    def forward(self, x, c, current_frame=None, timestep=None, is_last_block=False,
                pose_cond=None, mode="training", c_action_cond=None, reference_length=None,
                memory_tokens=None, memory_token_mask=None, memory_dynamic_tokens=None, memory_dynamic_mask=None,
                memory_retrieval_tokens=None, memory_retrieval_mask=None, memory_anchor_gate=None,
                memory_dynamic_gate=None, memory_retrieval_gate=None):
        """Run the spatial stage, the temporal stage and the optional memory cross-attention.

        ``x`` is ``(B, T, H, W, D)``. The temporal stage is conditioned on
        ``c_action_cond`` when it is given, otherwise on ``c``. Several
        keyword arguments (``current_frame``, ``timestep``, ``is_last_block``,
        ``pose_cond``, ``mode``, ``reference_length``) are accepted but not
        used by this method.
        """
        B, T, _, _, _ = x.shape

        # spatial block
        (s_shift_msa, s_scale_msa, s_gate_msa,
         s_shift_mlp, s_scale_mlp, s_gate_mlp) = self.s_adaLN_modulation(c).chunk(6, dim=-1)
        spatial_attn = self.s_attn(modulate(self.s_norm1(x), s_shift_msa, s_scale_msa))
        x = x + gate(spatial_attn, s_gate_msa)
        spatial_mlp = self.s_mlp(modulate(self.s_norm2(x), s_shift_mlp, s_scale_mlp))
        x = x + gate(spatial_mlp, s_gate_mlp)

        # temporal block
        temporal_cond = c_action_cond if c_action_cond is not None else c
        (t_shift_msa, t_scale_msa, t_gate_msa,
         t_shift_mlp, t_scale_mlp, t_gate_mlp) = self.t_adaLN_modulation(temporal_cond).chunk(6, dim=-1)
        temporal_attn = self.t_attn(modulate(self.t_norm1(x), t_shift_msa, t_scale_msa))
        x_t = x + gate(temporal_attn, t_gate_msa)
        temporal_mlp = self.t_mlp(modulate(self.t_norm2(x_t), t_shift_mlp, t_scale_mlp))
        x_t = x_t + gate(temporal_mlp, t_gate_mlp)

        if self.ref_mode == 'sequential':
            x = x_t

        if self.use_memory_token_cross_attention:
            memory_base = x
            packed_memory = self._pack_typed_memory_streams(
                B,
                T,
                memory_tokens=memory_tokens,
                memory_token_mask=memory_token_mask,
                memory_dynamic_tokens=memory_dynamic_tokens,
                memory_dynamic_mask=memory_dynamic_mask,
                memory_retrieval_tokens=memory_retrieval_tokens,
                memory_retrieval_mask=memory_retrieval_mask,
                memory_anchor_gate=memory_anchor_gate,
                memory_dynamic_gate=memory_dynamic_gate,
                memory_retrieval_gate=memory_retrieval_gate,
            )
            if packed_memory is not None:
                packed_tokens, packed_mask, packed_gate, packed_type_ids, residual_gate = packed_memory
                x = self.memory_token_cross_attn(
                    memory_base,
                    c,
                    packed_tokens,
                    packed_mask,
                    residual_gate=residual_gate,
                    memory_type_ids=packed_type_ids,
                    memory_token_gate=packed_gate,
                )

        if self.ref_mode == 'parallel':
            x = x_t + self.parallel_map(x)

        return x
658
+
659
+
660
class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    The model patch-embeds every frame, embeds the per-frame diffusion
    timestep, runs a stack of ``SpatioTemporalDiTBlock`` modules and projects
    the result back to ``in_channels`` channels per pixel. Selected blocks can
    additionally carry a memory-token cross-attention adapter.
    """

    def __init__(
        self,
        input_h=18,
        input_w=32,
        patch_size=2,
        in_channels=16,
        hidden_size=1024,
        depth=12,
        num_heads=16,
        mlp_ratio=4.0,
        action_cond_dim=25,
        max_frames=32,
        reference_length=8,
        memory_token_cross_attention=False,
        memory_cross_attn_layers=None,
        ref_mode='sequential'
    ):
        """
        Build the embedders, rotary embeddings, transformer blocks and head.

        memory_token_cross_attention: master switch for the memory adapters.
        memory_cross_attn_layers: optional iterable of block indices that
            receive an adapter; ``None`` means every block (when the master
            switch is on).

        Raises:
            ValueError: if ``memory_cross_attn_layers`` contains an index
                outside ``[0, depth)``.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.max_frames = max_frames

        self.x_embedder = PatchEmbed(input_h, input_w, patch_size, in_channels, hidden_size, flatten=False)
        self.t_embedder = TimestepEmbedder(hidden_size)

        # Rotary embeddings are created once here and shared by every block.
        self.spatial_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads // 2, freqs_for="pixel", max_freq=256)
        self.temporal_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads)

        # With action_cond_dim <= 0 there is no projection to learn.
        self.external_cond = nn.Linear(action_cond_dim, hidden_size) if action_cond_dim > 0 else nn.Identity()
        if memory_cross_attn_layers is None:
            memory_cross_attn_layer_set = None
        else:
            # Normalise to a set of ints and reject out-of-range block indices.
            memory_cross_attn_layer_set = {int(layer_idx) for layer_idx in memory_cross_attn_layers}
            invalid_layers = sorted(
                layer_idx for layer_idx in memory_cross_attn_layer_set if layer_idx < 0 or layer_idx >= depth
            )
            if invalid_layers:
                raise ValueError(
                    f"memory_cross_attn_layers contains invalid indices {invalid_layers} for depth={depth}"
                )

        self.blocks = nn.ModuleList(
            [
                SpatioTemporalDiTBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    is_causal=True,
                    reference_length=reference_length,
                    spatial_rotary_emb=self.spatial_rotary_emb,
                    temporal_rotary_emb=self.temporal_rotary_emb,
                    use_memory_token_cross_attention=memory_token_cross_attention
                    and (memory_cross_attn_layer_set is None or block_idx in memory_cross_attn_layer_set),
                    ref_mode=ref_mode
                )
                for block_idx in range(depth)
            ]
        )
        self.memory_token_cross_attention = memory_token_cross_attention
        # Stored as a sorted tuple (or None) so the configuration is immutable.
        self.memory_cross_attn_layers = (
            None if memory_cross_attn_layer_set is None else tuple(sorted(memory_cross_attn_layer_set))
        )
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Apply the weight initialisation scheme to all sub-modules."""
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.s_adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.s_adaLN_modulation[-1].bias, 0)
            nn.init.constant_(block.t_adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.t_adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

        # Runs after the generic init above, so the adapters' own reset has the
        # final word. NOTE(review): presumably this makes each adapter start as
        # an identity mapping - confirm against reset_identity_init().
        if self.memory_token_cross_attention:
            for block in self.blocks:
                memory_adapter = getattr(block, "memory_token_cross_attn", None)
                if memory_adapter is not None:
                    memory_adapter.reset_identity_init()

    def memory_adapter_delta_diagnostics(self):
        """
        Aggregate the diagnostics recorded on the blocks' memory adapters.

        Reads ``last_delta_ratio``, ``last_type_gate_mean`` and
        ``last_type_gate_<type>_mean`` from every block that has a
        ``memory_token_cross_attn`` attribute, skipping values that are missing
        or ``None``.

        Returns a dict of Python floats; a key is present only when at least
        one adapter reported the corresponding value (so the dict may be empty).
        """
        diagnostics = {}
        ratios = []
        type_gate_values = {type_name: [] for type_name in MEMORY_TYPE_NAMES}
        shared_type_gate_values = []
        for block in self.blocks:
            adapter = getattr(block, "memory_token_cross_attn", None)
            if adapter is None:
                continue
            ratio = getattr(adapter, "last_delta_ratio", None)
            if ratio is not None:
                ratios.append(torch.as_tensor(ratio).detach().float())
            type_gate = getattr(adapter, "last_type_gate_mean", None)
            if type_gate is not None:
                shared_type_gate_values.append(torch.as_tensor(type_gate).detach().float())
            for type_name in MEMORY_TYPE_NAMES:
                value = getattr(adapter, f"last_type_gate_{type_name}_mean", None)
                if value is not None:
                    type_gate_values[type_name].append(torch.as_tensor(value).detach().float())
        if ratios:
            values = torch.stack(ratios)
            diagnostics["memory_adapter_delta_ratio_max"] = float(values.max().item())
            diagnostics["memory_adapter_delta_ratio_mean"] = float(values.mean().item())
        if shared_type_gate_values:
            values = torch.stack(shared_type_gate_values)
            diagnostics["memory_adapter_type_gate_mean"] = float(values.mean().item())
        for type_name, values_list in type_gate_values.items():
            if values_list:
                values = torch.stack(values_list)
                diagnostics[f"memory_adapter_type_gate_{type_name}_mean"] = float(values.mean().item())
        return diagnostics

    def unpatchify(self, x):
        """
        x: (N, H, W, patch_size**2 * C)
        imgs: (N, C, H * patch_size, W * patch_size)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = x.shape[1]
        w = x.shape[2]

        # Split the channel axis into (p, p, c), then interleave the patch
        # offsets with the grid axes so each patch lands at its pixel location.
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs

    def forward(
        self,
        x,
        t,
        action_cond=None,
        pose_cond=None,
        current_frame=None,
        mode=None,
        reference_length=None,
        frame_idx=None,
        memory_tokens=None,
        memory_token_mask=None,
        memory_dynamic_tokens=None,
        memory_dynamic_mask=None,
        memory_retrieval_tokens=None,
        memory_retrieval_mask=None,
        memory_anchor_gate=None,
        memory_dynamic_gate=None,
        memory_retrieval_gate=None,
    ):
        """
        Forward pass of DiT.
        x: (B, T, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (B, T,) tensor of diffusion timesteps
        action_cond: optional tensor; when it is a tensor it is projected by
            ``external_cond`` and added to the timestep embedding.
        memory_*: forwarded unchanged to every block.

        ``pose_cond`` and ``frame_idx`` are accepted but not referenced in this
        method body.

        Returns a (B, T, C, H, W) tensor.
        """

        B, T, C, H, W = x.shape

        # add spatial embeddings
        x = rearrange(x, "b t c h w -> (b t) c h w")

        x = self.x_embedder(x)  # (B*T, C, H, W) -> (B*T, H/2, W/2, D) , C = 16, D = d_model
        # restore shape
        x = rearrange(x, "(b t) h w d -> b t h w d", t=T)
        # embed noise steps
        t = rearrange(t, "b t -> (b t)")

        c_t = self.t_embedder(t)  # (N, D)
        c = c_t.clone()
        c = rearrange(c, "(b t) d -> b t d", t=T)

        # Action conditioning is optional: non-tensor values (e.g. None) disable it.
        if torch.is_tensor(action_cond):
            c_action_cond = c + self.external_cond(action_cond)
        else:
            c_action_cond = None

        for i, block in enumerate(self.blocks):
            x = block(x, c, current_frame=current_frame, timestep=t, is_last_block= (i+1 == len(self.blocks)),
                      mode=mode, c_action_cond=c_action_cond, reference_length=reference_length,
                      memory_tokens=memory_tokens, memory_token_mask=memory_token_mask,
                      memory_dynamic_tokens=memory_dynamic_tokens, memory_dynamic_mask=memory_dynamic_mask,
                      memory_retrieval_tokens=memory_retrieval_tokens, memory_retrieval_mask=memory_retrieval_mask,
                      memory_anchor_gate=memory_anchor_gate, memory_dynamic_gate=memory_dynamic_gate,
                      memory_retrieval_gate=memory_retrieval_gate)  # (N, T, H, W, D)
        x = self.final_layer(x, c)  # (N, T, H, W, patch_size ** 2 * out_channels)
        # unpatchify
        x = rearrange(x, "b t h w d -> (b t) h w d")
        x = self.unpatchify(x)  # (N, out_channels, H, W)
        x = rearrange(x, "(b t) c h w -> b t c h w", t=T)
        return x
877
+
878
+
879
def DiT_S_2(
    action_cond_dim,
    reference_length,
    ref_mode,
    memory_token_cross_attention=False,
    memory_cross_attn_layers=None,
):
    """Build the "DiT-S/2" configuration: patch size 2, width 1024, 16 blocks, 16 heads.

    All arguments are forwarded unchanged to the ``DiT`` constructor.
    """
    # Fixed architecture hyper-parameters for this preset.
    architecture = dict(patch_size=2, hidden_size=1024, depth=16, num_heads=16)
    # Caller-controlled conditioning / memory options.
    options = dict(
        action_cond_dim=action_cond_dim,
        reference_length=reference_length,
        memory_token_cross_attention=memory_token_cross_attention,
        memory_cross_attn_layers=memory_cross_attn_layers,
        ref_mode=ref_mode,
    )
    return DiT(**architecture, **options)
897
+
898
+
899
# Registry mapping a model name to the factory function that builds it.
DiT_models = {"DiT-S/2": DiT_S_2}
algorithms/worldmem/models/pose_prediction.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class PosePredictionNet(nn.Module):
    """Predict the next pose from an image, an action vector and the current pose.

    The image goes through a small strided CNN with global average pooling,
    the concatenated (pose, action) vector goes through an MLP, and the two
    feature vectors are fused by a final MLP that outputs ``pose_dim`` values.
    """

    def __init__(self, img_channels=16, img_feat_dim=256, pose_dim=5, action_dim=25, hidden_dim=128):
        super().__init__()

        # Three stride-2 conv stages (32 -> 64 -> 128 channels), then pool to 1x1.
        stages = []
        in_ch = img_channels
        for out_ch in (32, 64, 128):
            stages.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1))
            stages.append(nn.ReLU())
            in_ch = out_ch
        stages.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.cnn = nn.Sequential(*stages)

        # Projects the pooled 128-d CNN feature to the image feature size.
        self.fc_img = nn.Linear(128, img_feat_dim)

        # Encodes the concatenated (pose, action) vector.
        motion_in = pose_dim + action_dim
        self.mlp_motion = nn.Sequential(
            nn.Linear(motion_in, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        # Fuses image and motion features into the predicted pose.
        fused_in = img_feat_dim + hidden_dim
        self.fc_out = nn.Sequential(
            nn.Linear(fused_in, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, pose_dim),
        )

    def forward(self, img, action, pose):
        """Return the predicted next pose, one row per batch element."""
        batch = img.size(0)
        pooled = self.cnn(img).view(batch, -1)
        image_features = self.fc_img(pooled)

        motion_input = torch.cat([pose, action], dim=1)
        motion_features = self.mlp_motion(motion_input)

        fused = torch.cat([image_features, motion_features], dim=1)
        return self.fc_out(fused)
algorithms/worldmem/models/rotary_embedding_torch.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
3
+ """
4
+
5
+ from __future__ import annotations
6
+ from math import pi, log
7
+
8
+ import torch
9
+ from torch.nn import Module, ModuleList
10
+ from torch.amp import autocast
11
+ from torch import nn, einsum, broadcast_tensors, Tensor
12
+
13
+ from einops import rearrange, repeat
14
+
15
+ from typing import Literal
16
+
17
+ # helper functions
18
+
19
+
20
def exists(val):
    """Return True for any value other than ``None`` (falsy values still count)."""
    if val is None:
        return False
    return True
22
+
23
+
24
def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if val is None:
        return d
    return val
26
+
27
+
28
+ # broadcat, as tortoise-tts was using it
29
+
30
+
31
def broadcat(tensors, dim=-1):
    """Broadcast ``tensors`` to a common shape, then concatenate them along ``dim``."""
    expanded = torch.broadcast_tensors(*tensors)
    return torch.cat(expanded, dim=dim)
34
+
35
+
36
+ # rotary embedding helper functions
37
+
38
+
39
def rotate_half(x):
    """Rotate adjacent feature pairs of the last axis: (a, b) becomes (-b, a)."""
    pairs = rearrange(x, "... (d r) -> ... d r", r=2)
    first, second = pairs.unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1)
    return rearrange(rotated, "... d r -> ... (d r)")
44
+
45
+
46
@autocast("cuda", enabled=False)
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
    """Rotate a slice of the last axis of ``t`` by the angles in ``freqs``.

    Features ``[start_index, start_index + freqs.shape[-1])`` are rotated and
    multiplied by ``scale``; features outside that range pass through
    unchanged. The result is cast back to the dtype ``t`` had on entry.
    """
    original_dtype = t.dtype

    # For 3-D inputs keep only the trailing rows of the frequency table so it
    # lines up with the sequence length of ``t``.
    if t.ndim == 3:
        freqs = freqs[-t.shape[seq_dim]:]

    rot_dim = freqs.shape[-1]
    stop_index = start_index + rot_dim

    assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"

    # Three-way split: untouched head, rotated middle, untouched tail.
    head = t[..., :start_index]
    middle = t[..., start_index:stop_index]
    tail = t[..., stop_index:]

    # Out-of-place rotation: x * cos + rotate_half(x) * sin.
    cos_term = middle * freqs.cos() * scale
    sin_term = rotate_half(middle) * freqs.sin() * scale
    rotated = cos_term + sin_term

    return torch.cat((head, rotated, tail), dim=-1).type(original_dtype)
70
+
71
+
72
+ # learned rotation helpers
73
+
74
+
75
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
    """Apply rotation angles ``rotations`` to ``t`` via ``apply_rotary_emb``.

    When ``freq_ranges`` is given, the angles are first multiplied by every
    frequency range and the two trailing axes are merged. Each angle is then
    duplicated so that it covers one feature pair.
    """
    if freq_ranges is not None:
        rotations = einsum("..., f -> ... f", rotations, freq_ranges)
        rotations = rearrange(rotations, "... r f -> ... (r f)")

    paired_angles = repeat(rotations, "... n -> ... (n r)", r=2)
    return apply_rotary_emb(paired_angles, t, start_index=start_index)
82
+
83
+
84
+ # classes
85
+
86
+
87
class RotaryEmbedding(Module):
    """Rotary position embedding with optional xpos scaling and axial frequencies.

    Unlike the upstream implementation, ``forward`` and the ``rotate_*``
    helpers take the frequency tensor explicitly, so one module can serve
    several frequency tables (``freqs`` and, in "spacetime" mode,
    ``time_freqs``).
    """

    def __init__(
        self,
        dim,
        custom_freqs: Tensor | None = None,
        freqs_for: Literal["lang", "pixel", "spacetime", "constant"] = "lang",
        theta=10000,
        max_freq=10,
        num_freqs=1,
        learned_freq=False,
        use_xpos=False,
        xpos_scale_base=512,
        interpolate_factor=1.0,
        theta_rescale_factor=1.0,
        seq_before_head_dim=False,
        cache_if_possible=True,
        cache_max_seq_len=8192,
    ):
        """
        Args:
            dim: number of rotated features per position.
            custom_freqs: explicit frequency tensor; overrides ``freqs_for``
                for the main ``freqs`` table.
            freqs_for: which built-in frequency layout to use.
            learned_freq: make the frequency tables trainable (disables caching).
            use_xpos: enable xpos length-extrapolation scaling.
            cache_if_possible / cache_max_seq_len: control the frequency cache.

        Raises:
            ValueError: if ``freqs_for`` is not a known layout and no
                ``custom_freqs`` is given.
        """
        super().__init__()
        # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
        # has some connection to NTK literature
        # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

        theta *= theta_rescale_factor ** (dim / (dim - 2))

        self.freqs_for = freqs_for

        if exists(custom_freqs):
            freqs = custom_freqs
        elif freqs_for == "lang":
            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        elif freqs_for in ("pixel", "spacetime"):
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        elif freqs_for == "constant":
            freqs = torch.ones(num_freqs).float()
        else:
            # Previously an unknown value fell through and crashed later with
            # an unbound local; fail early with a clear message instead.
            raise ValueError(f"unknown freqs_for value: {freqs_for!r}")

        if freqs_for == "spacetime":
            # Temporal axes use language-style frequencies. Computed here (not
            # in the chain above) so it also exists when custom_freqs is given.
            time_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
            self.time_freqs = nn.Parameter(time_freqs, requires_grad=learned_freq)
        self.freqs = nn.Parameter(freqs, requires_grad=learned_freq)

        self.cache_if_possible = cache_if_possible
        self.cache_max_seq_len = cache_max_seq_len

        self.register_buffer("cached_freqs", torch.zeros(cache_max_seq_len, dim), persistent=False)
        self.register_buffer("cached_freqs_seq_len", torch.tensor(0), persistent=False)

        self.learned_freq = learned_freq

        # dummy for device

        self.register_buffer("dummy", torch.tensor(0), persistent=False)

        # default sequence dimension

        self.seq_before_head_dim = seq_before_head_dim
        self.default_seq_dim = -3 if seq_before_head_dim else -2

        # interpolation factors

        assert interpolate_factor >= 1.0
        self.interpolate_factor = interpolate_factor

        # xpos

        self.use_xpos = use_xpos

        if not use_xpos:
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.scale_base = xpos_scale_base

        self.register_buffer("scale", scale, persistent=False)
        self.register_buffer("cached_scales", torch.zeros(cache_max_seq_len, dim), persistent=False)
        self.register_buffer("cached_scales_seq_len", torch.tensor(0), persistent=False)

        # add apply_rotary_emb as static method
        # NOTE: as in the original, this is only reached when use_xpos is True.

        self.apply_rotary_emb = staticmethod(apply_rotary_emb)

    @property
    def device(self):
        """Device the module's buffers currently live on."""
        return self.dummy.device

    def get_seq_pos(self, seq_len, device, dtype, offset=0):
        """Return positions ``offset .. offset + seq_len - 1`` divided by ``interpolate_factor``."""
        return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor

    def rotate_queries_or_keys(self, t, freqs, seq_dim=None, offset=0, scale=None):
        """Rotate ``t`` along ``seq_dim`` using the frequency table ``freqs``.

        With xpos enabled an explicit ``scale`` is required.
        """
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert not self.use_xpos or exists(scale), "you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings"

        device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset)

        seq_freqs = self.forward(seq, freqs, seq_len=seq_len, offset=offset)

        if seq_dim == -3:
            seq_freqs = rearrange(seq_freqs, "n d -> n 1 d")

        return apply_rotary_emb(seq_freqs, t, scale=default(scale, 1.0), seq_dim=seq_dim)

    def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
        """Rotate queries that correspond to the last ``q_len`` positions of ``k``.

        Uses the module's own ``freqs`` table for both tensors.
        """
        dtype, device, seq_dim = (
            q.dtype,
            q.device,
            default(seq_dim, self.default_seq_dim),
        )

        q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
        assert q_len <= k_len

        q_scale = k_scale = 1.0

        if self.use_xpos:
            seq = self.get_seq_pos(k_len, dtype=dtype, device=device)

            q_scale = self.get_scale(seq[-q_len:]).type(dtype)
            k_scale = self.get_scale(seq).type(dtype)

        # rotate_queries_or_keys requires the frequency table as its second
        # argument; the original call omitted it and raised TypeError.
        rotated_q = self.rotate_queries_or_keys(
            q, self.freqs, seq_dim=seq_dim, scale=q_scale, offset=k_len - q_len + offset
        )
        rotated_k = self.rotate_queries_or_keys(k, self.freqs, seq_dim=seq_dim, scale=k_scale**-1)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def rotate_queries_and_keys(self, q, k, freqs, seq_dim=None):
        """Rotate ``q`` and ``k`` with xpos scaling (requires ``use_xpos``)."""
        seq_dim = default(seq_dim, self.default_seq_dim)

        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        seq = self.get_seq_pos(seq_len, dtype=dtype, device=device)

        seq_freqs = self.forward(seq, freqs, seq_len=seq_len)
        scale = self.get_scale(seq, seq_len=seq_len).to(dtype)

        if seq_dim == -3:
            seq_freqs = rearrange(seq_freqs, "n d -> n 1 d")
            scale = rearrange(scale, "n d -> n 1 d")

        # Queries are scaled by ``scale`` and keys by its inverse.
        rotated_q = apply_rotary_emb(seq_freqs, q, scale=scale, seq_dim=seq_dim)
        rotated_k = apply_rotary_emb(seq_freqs, k, scale=scale**-1, seq_dim=seq_dim)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(self, t: Tensor, seq_len: int | None = None, offset=0):
        """Return the xpos scale for positions ``t`` (cached when ``seq_len`` is given)."""
        assert self.use_xpos

        should_cache = self.cache_if_possible and exists(seq_len) and (offset + seq_len) <= self.cache_max_seq_len

        if should_cache and exists(self.cached_scales) and (seq_len + offset) <= self.cached_scales_seq_len.item():
            return self.cached_scales[offset : (offset + seq_len)]

        scale = 1.0
        if self.use_xpos:
            power = (t - len(t) // 2) / self.scale_base
            scale = self.scale ** rearrange(power, "n -> n 1")
            scale = repeat(scale, "n d -> n (d r)", r=2)

        if should_cache and offset == 0:
            self.cached_scales[:seq_len] = scale.detach()
            self.cached_scales_seq_len.copy_(seq_len)

        return scale

    def get_axial_freqs(self, *dims):
        """Build a broadcast frequency tensor with one block of features per axis in ``dims``."""
        Colon = slice(None)
        all_freqs = []

        for ind, dim in enumerate(dims):
            # only allow pixel freqs for last two dimensions
            use_pixel = (self.freqs_for == "pixel" or self.freqs_for == "spacetime") and ind >= len(dims) - 2
            if use_pixel:
                pos = torch.linspace(-1, 1, steps=dim, device=self.device)
            else:
                pos = torch.arange(dim, device=self.device)

            if self.freqs_for == "spacetime" and not use_pixel:
                seq_freqs = self.forward(pos, self.time_freqs, seq_len=dim)
            else:
                seq_freqs = self.forward(pos, self.freqs, seq_len=dim)

            # Place this axis' frequencies on its own dimension, with
            # singleton dimensions for every other axis.
            all_axis = [None] * len(dims)
            all_axis[ind] = Colon

            new_axis_slice = (Ellipsis, *all_axis, Colon)
            all_freqs.append(seq_freqs[new_axis_slice])

        all_freqs = broadcast_tensors(*all_freqs)
        return torch.cat(all_freqs, dim=-1)

    @autocast("cuda", enabled=False)
    def forward(self, t: Tensor, freqs: Tensor, seq_len=None, offset=0):
        """Return the rotation angles ``t (outer) freqs`` with every angle duplicated pairwise.

        Results are cached only for the module's own ``freqs`` table in the
        integer-position layouts. The cache is keyed by length alone, so it
        must not be shared between different frequency tables
        (``time_freqs``) or used with ``linspace`` positions
        ("pixel" / "spacetime").
        """
        should_cache = (
            self.cache_if_possible
            and not self.learned_freq
            and exists(seq_len)
            and self.freqs_for not in ("pixel", "spacetime")
            and freqs is self.freqs
            and (offset + seq_len) <= self.cache_max_seq_len
        )

        if should_cache and exists(self.cached_freqs) and (offset + seq_len) <= self.cached_freqs_seq_len.item():
            return self.cached_freqs[offset : (offset + seq_len)].detach()

        freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)

        if should_cache and offset == 0:
            self.cached_freqs[:seq_len] = freqs.detach()
            self.cached_freqs_seq_len.copy_(seq_len)

        return freqs
algorithms/worldmem/models/utils.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/utils.py
3
+ Action format derived from VPT https://github.com/openai/Video-Pre-Training
4
+ Adapted from https://github.com/etched-ai/open-oasis/blob/master/utils.py
5
+ """
6
+
7
+ import math
8
+ import torch
9
+ from torch import nn
10
+ from torchvision.io import read_image, read_video
11
+ from torchvision.transforms.functional import resize
12
+ from einops import rearrange
13
+ from typing import Mapping, Sequence
14
+ from einops import rearrange, parse_shape
15
+
16
+
17
def exists(val):
    """Tell whether ``val`` holds a value, i.e. is not ``None``."""
    return not (val is None)
19
+
20
+
21
def default(val, d):
    """Return ``val`` if it is not ``None``; otherwise the fallback ``d``.

    A callable fallback is invoked (with no arguments) and its result returned.
    """
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
25
+
26
+
27
def extract(a, t, x_shape):
    """Gather ``a[t]`` for a 2-D index tensor ``t`` and pad it with singleton axes.

    The result has shape ``t.shape + (1,) * (len(x_shape) - 2)`` so it
    broadcasts against a tensor of shape ``x_shape``.
    """
    rows, cols = t.shape
    padding = (1,) * (len(x_shape) - 2)
    gathered = a[t]
    return gathered.reshape(rows, cols, *padding)
31
+
32
+
33
def linear_beta_schedule(timesteps):
    """
    Linear beta schedule from the original DDPM paper.

    The endpoints 1e-4 and 2e-2 are rescaled by ``1000 / timesteps``.
    """
    factor = 1000 / timesteps
    first_beta = factor * 0.0001
    last_beta = factor * 0.02
    return torch.linspace(first_beta, last_beta, timesteps, dtype=torch.float64)
41
+
42
+
43
def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine beta schedule,
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    Betas are derived from a squared-cosine cumulative alpha curve and clipped
    to [0, 0.999].
    """
    num_points = timesteps + 1
    grid = torch.linspace(0, timesteps, num_points, dtype=torch.float64) / timesteps
    cumulative = torch.cos((grid + s) / (1 + s) * math.pi * 0.5) ** 2
    cumulative = cumulative / cumulative[0]
    step_ratio = cumulative[1:] / cumulative[:-1]
    return torch.clip(1 - step_ratio, 0, 0.999)
54
+
55
+
56
+
57
def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
    """
    Sigmoid beta schedule,
    proposed in https://arxiv.org/abs/2212.11972 - Figure 8;
    better for images > 64x64, when used during training.

    ``clamp_min`` is accepted but not used by this implementation.
    """
    num_points = timesteps + 1
    grid = torch.linspace(0, timesteps, num_points, dtype=torch.float64) / timesteps
    sig_start = torch.tensor(start / tau).sigmoid()
    sig_end = torch.tensor(end / tau).sigmoid()
    curve = ((grid * (end - start) + start) / tau).sigmoid()
    cumulative = (-curve + sig_end) / (sig_end - sig_start)
    cumulative = cumulative / cumulative[0]
    step_ratio = cumulative[1:] / cumulative[:-1]
    return torch.clip(1 - step_ratio, 0, 0.999)
71
+
72
+
73
# Ordered action names; the position of a key in this list is the column
# index that ``one_hot_actions`` writes to. The two camera entries are read
# from the single "camera" field of an action record.
ACTION_KEYS = [
    "inventory",
    "ESC",
    "hotbar.1",
    "hotbar.2",
    "hotbar.3",
    "hotbar.4",
    "hotbar.5",
    "hotbar.6",
    "hotbar.7",
    "hotbar.8",
    "hotbar.9",
    "forward",
    "back",
    "left",
    "right",
    "cameraX",
    "cameraY",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]
100
+
101
+
102
def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
    """Encode a sequence of action records as a ``(len(actions), len(ACTION_KEYS))`` tensor.

    Non-camera keys are copied as-is and must lie in [0, 1]. The two camera
    keys are read from ``record["camera"][0]`` / ``[1]`` and mapped from
    bucket indices to [-1, 1].
    """
    # Camera bucketing constants: 20 / 0.5 = 40 buckets on each side of zero.
    max_val = 20
    bin_size = 0.5
    num_buckets = int(max_val / bin_size)
    camera_axis = {"cameraX": 0, "cameraY": 1}

    encoded = torch.zeros(len(actions), len(ACTION_KEYS))
    for row, record in enumerate(actions):
        for col, action_key in enumerate(ACTION_KEYS):
            if not action_key.startswith("camera"):
                value = record[action_key]
                assert 0 <= value <= 1, f"Action value must be in [0, 1] got {value}"
            else:
                if action_key not in camera_axis:
                    raise ValueError(f"Unknown camera action key: {action_key}")
                value = record["camera"][camera_axis[action_key]]
                # Re-centre the bucket index and normalise to [-1, 1].
                value = (value - num_buckets) / num_buckets
                assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}"
            encoded[row, col] = value

    return encoded
124
+
125
+
126
# File extensions (lower-case, without the dot) that ``load_prompt`` accepts
# as still images and as videos respectively.
IMAGE_EXTENSIONS = {"png", "jpg", "jpeg"}
VIDEO_EXTENSIONS = {"mp4"}
128
+
129
+
130
def load_prompt(path, video_offset=None, n_prompt_frames=1):
    """Load an image or video prompt as a float tensor of shape (1, T, C, 360, 640).

    Args:
        path: file path; the extension selects the image or video loader.
        video_offset: optional number of leading video frames to skip.
        n_prompt_frames: number of frames the prompt must contain.

    Returns:
        Tensor scaled to [0, 1] with a leading batch axis of size 1.

    Raises:
        ValueError: if the file extension is not recognised.
        AssertionError: if the prompt does not have ``n_prompt_frames`` frames.
    """
    extension = path.lower().split(".")[-1]
    if extension in IMAGE_EXTENSIONS:
        print("prompt is image; ignoring video_offset and n_prompt_frames")
        prompt = read_image(path)
        # add frame dimension
        prompt = rearrange(prompt, "c h w -> 1 c h w")
    elif extension in VIDEO_EXTENSIONS:
        prompt = read_video(path, pts_unit="sec")[0]
        # read_video returns frames channel-last, (T, H, W, C), by default,
        # while resize() and the rearrange below expect (T, C, H, W).
        prompt = rearrange(prompt, "t h w c -> t c h w")
        if video_offset is not None:
            prompt = prompt[video_offset:]
        prompt = prompt[:n_prompt_frames]
    else:
        raise ValueError(f"unrecognized prompt file extension; expected one in {IMAGE_EXTENSIONS} or {VIDEO_EXTENSIONS}")
    assert prompt.shape[0] == n_prompt_frames, f"input prompt {path} had less than n_prompt_frames={n_prompt_frames} frames"
    prompt = resize(prompt, (360, 640))
    # add batch dimension
    prompt = rearrange(prompt, "t c h w -> 1 t c h w")
    prompt = prompt.float() / 255.0
    return prompt
149
+
150
+
151
def load_actions(path, action_offset=None):
    """Load an action file and return it as a ``(1, T + 1, D)`` tensor.

    ``*.actions.pt`` files hold raw action records that are encoded with
    ``one_hot_actions``; ``*.one_hot_actions.pt`` files hold an already
    encoded tensor. An all-zero action is prepended to the sequence.

    Raises:
        ValueError: if the file name has neither suffix.
    """
    is_raw = path.endswith(".actions.pt")
    is_encoded = path.endswith(".one_hot_actions.pt")
    if not (is_raw or is_encoded):
        raise ValueError("unrecognized action file extension; expected '*.actions.pt' or '*.one_hot_actions.pt'")

    if is_raw:
        # NOTE(review): this call does not pass weights_only, so depending on
        # the installed torch version it may fully unpickle the file - only
        # load trusted files here.
        actions = one_hot_actions(torch.load(path))
    else:
        actions = torch.load(path, weights_only=True)

    if action_offset is not None:
        actions = actions[action_offset:]

    # Prepend one zero action in front of the sequence.
    zero_action = torch.zeros_like(actions[:1])
    actions = torch.cat([zero_action, actions], dim=0)
    # add batch dimension
    return rearrange(actions, "t d -> 1 t d")
algorithms/worldmem/models/vae.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ References:
3
+ - VQGAN: https://github.com/CompVis/taming-transformers
4
+ - MAE: https://github.com/facebookresearch/mae
5
+ """
6
+
7
+ import numpy as np
8
+ import math
9
+ import functools
10
+ from collections import namedtuple
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from einops import rearrange
15
+ from timm.models.vision_transformer import Mlp
16
+ from timm.layers.helpers import to_2tuple
17
+ from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
18
+ from .dit import PatchEmbed
19
+
20
+
21
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian whose mean and log-variance are stacked in one tensor.

    ``parameters`` is split into two halves along ``dim`` (mean first, then
    log-variance). The log-variance is clamped to [-30, 20]. In deterministic
    mode the standard deviation and variance are zero, so ``sample()`` returns
    the mean.
    """

    def __init__(self, parameters, deterministic=False, dim=1):
        self.parameters = parameters
        mean, logvar = torch.chunk(parameters, 2, dim=dim)
        self.mean, self.logvar = mean, logvar

        # Only splits along dim 1 or dim 2 are supported.
        supported_dims = {1: [1, 2, 3], 2: [1, 2]}
        if dim not in supported_dims:
            raise NotImplementedError
        self.dims = supported_dims[dim]

        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # std and var intentionally share one zero tensor.
            zeros = torch.zeros_like(self.mean).to(device=self.parameters.device)
            self.var = self.std = zeros

    def sample(self):
        """Draw one sample: mean + std * standard normal noise."""
        noise = torch.randn(self.mean.shape).to(device=self.parameters.device)
        return self.mean + self.std * noise

    def mode(self):
        """Return the distribution's mode, which is its mean."""
        return self.mean
44
+
45
+
46
class Attention(nn.Module):
    """Multi-head self-attention over a fixed ``frame_height x frame_width`` token grid.

    Queries and keys receive axial rotary embeddings that are precomputed for
    the grid and stored in a non-persistent buffer.
    """

    def __init__(
        self,
        dim,
        num_heads,
        frame_height,
        frame_width,
        qkv_bias=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.frame_height = frame_height
        self.frame_width = frame_width
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        # Axial (height, width) rotary frequencies, computed once for the grid.
        rotary = RotaryEmbedding(
            dim=head_dim // 4,
            freqs_for="pixel",
            max_freq=frame_height * frame_width,
        )
        grid_freqs = rotary.get_axial_freqs(frame_height, frame_width)
        self.register_buffer("rotary_freqs", grid_freqs, persistent=False)

    def forward(self, x):
        """Attend over the token grid; ``x`` is ``(B, H*W, dim)`` and so is the result."""
        B, N, C = x.shape
        assert N == self.frame_height * self.frame_width

        # Split heads and unfold the token axis into the 2-D grid.
        grid = dict(H=self.frame_height, W=self.frame_width, h=self.num_heads)
        q, k, v = (
            rearrange(part, "b (H W) (h d) -> b h H W d", **grid)
            for part in self.qkv(x).chunk(3, dim=-1)
        )

        # Positional rotation applies to queries and keys only.
        q = apply_rotary_emb(self.rotary_freqs, q)
        k = apply_rotary_emb(self.rotary_freqs, k)

        # Flatten the grid back into a token axis for attention.
        q, k, v = (rearrange(part, "b h H W d -> b h (H W) d") for part in (q, k, v))

        attended = F.scaled_dot_product_attention(q, k, v)
        attended = rearrange(attended, "b h N d -> b N (h d)")

        return self.proj(attended)
111
+
112
+
113
class AttentionBlock(nn.Module):
    """Pre-norm transformer block: grid self-attention followed by an MLP, each with a residual.

    ``attn_causal`` is accepted but not used by this implementation.
    """

    def __init__(
        self,
        dim,
        num_heads,
        frame_height,
        frame_width,
        mlp_ratio=4.0,
        qkv_bias=False,
        attn_causal=False,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        # Attention sub-layer.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads, frame_height, frame_width, qkv_bias=qkv_bias)
        # Feed-forward sub-layer; hidden width is ``dim * mlp_ratio``.
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
        )

    def forward(self, x):
        """Apply both residual sub-layers and return a tensor shaped like ``x``."""
        after_attention = x + self.attn(self.norm1(x))
        return after_attention + self.mlp(self.norm2(after_attention))
147
+
148
+
149
class AutoencoderKL(nn.Module):
    """Transformer autoencoder over image patches with an optional KL bottleneck.

    Images are embedded by ``PatchEmbed``, processed by a stack of
    ``AttentionBlock`` layers, projected to a latent by ``quant_conv``, and
    decoded back to pixels by a second stack of ``AttentionBlock`` layers plus
    a linear ``predictor`` whose output is re-assembled by ``unpatchify``.

    Args:
        latent_dim: Channel width of the latent.
        input_height: Image height; ``input_height // patch_size`` gives the
            token-grid height ``seq_h``.
        input_width: Image width; ``input_width // patch_size`` gives the
            token-grid width ``seq_w``.
        patch_size: Side length of a square patch.
        enc_dim: Encoder token width.
        enc_depth: Number of encoder blocks.
        enc_heads: Attention heads per encoder block.
        dec_dim: Decoder token width.
        dec_depth: Number of decoder blocks.
        dec_heads: Attention heads per decoder block.
        mlp_ratio: MLP hidden width as a multiple of the block width.
        norm_layer: Callable that builds a normalization layer from a width.
        use_variational: When true ``quant_conv`` emits ``2 * latent_dim``
            channels; otherwise ``latent_dim`` channels that ``encode`` pads
            with zeros.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        latent_dim,
        input_height=256,
        input_width=256,
        patch_size=16,
        enc_dim=768,
        enc_depth=6,
        enc_heads=12,
        dec_dim=768,
        dec_depth=6,
        dec_heads=12,
        mlp_ratio=4.0,
        norm_layer=functools.partial(nn.LayerNorm, eps=1e-6),
        use_variational=True,
        **kwargs,
    ):
        super().__init__()
        self.input_height = input_height
        self.input_width = input_width
        self.patch_size = patch_size
        self.seq_h = input_height // patch_size
        self.seq_w = input_width // patch_size
        self.seq_len = self.seq_h * self.seq_w
        # Flattened size of one RGB patch.
        self.patch_dim = 3 * patch_size**2

        self.latent_dim = latent_dim
        self.enc_dim = enc_dim
        self.dec_dim = dec_dim

        # patch
        self.patch_embed = PatchEmbed(input_height, input_width, patch_size, 3, enc_dim)

        # encoder
        self.encoder = nn.ModuleList(
            [
                AttentionBlock(
                    enc_dim,
                    enc_heads,
                    self.seq_h,
                    self.seq_w,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for _ in range(enc_depth)
            ]
        )
        self.enc_norm = norm_layer(enc_dim)

        # bottleneck
        self.use_variational = use_variational
        mult = 2 if self.use_variational else 1
        self.quant_conv = nn.Linear(enc_dim, mult * latent_dim)
        self.post_quant_conv = nn.Linear(latent_dim, dec_dim)

        # decoder
        self.decoder = nn.ModuleList(
            [
                AttentionBlock(
                    dec_dim,
                    dec_heads,
                    self.seq_h,
                    self.seq_w,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for _ in range(dec_depth)
            ]
        )
        self.dec_norm = norm_layer(dec_dim)
        self.predictor = nn.Linear(dec_dim, self.patch_dim)  # decoder to patch

        # Initialize all weights once every sub-module exists.
        self.initialize_weights()

    def initialize_weights(self):
        """Initialize every ``nn.Linear``/``nn.LayerNorm`` and the patch projection."""
        self.apply(self._init_weights)

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    def _init_weights(self, m):
        """Per-module initializer used with ``nn.Module.apply``.

        Linear layers get Xavier-uniform weights and zero bias; LayerNorm
        layers get unit weight and zero bias. Parameters that are ``None``
        (e.g. ``bias=False`` or ``elementwise_affine=False``) are skipped.
        """
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.LayerNorm):
            # A custom norm_layer may disable the affine parameters, in which
            # case weight and/or bias are None and must not be initialized.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def patchify(self, x):
        """Rearrange ``[b, 3, H, W]`` images into ``[b, seq_len, patch_dim]`` patches."""
        # Unpacking keeps the original requirement that x is 4-D.
        bsz, _, _, _ = x.shape
        x = x.reshape(
            bsz,
            3,
            self.seq_h,
            self.patch_size,
            self.seq_w,
            self.patch_size,
        ).permute([0, 1, 3, 5, 2, 4])  # [b, c, h, p, w, p] --> [b, c, p, p, h, w]
        x = x.reshape(bsz, self.patch_dim, self.seq_h, self.seq_w)  # --> [b, cxpxp, h, w]
        x = x.permute([0, 2, 3, 1]).reshape(bsz, self.seq_len, self.patch_dim)  # --> [b, hxw, cxpxp]
        return x

    def unpatchify(self, x):
        """Inverse of ``patchify``: ``[b, seq_len, patch_dim]`` back to ``[b, 3, H, W]``."""
        bsz = x.shape[0]
        x = x.reshape(bsz, self.seq_h, self.seq_w, self.patch_dim).permute(
            [0, 3, 1, 2]
        )  # [b, h, w, cxpxp] --> [b, cxpxp, h, w]
        x = x.reshape(
            bsz,
            3,
            self.patch_size,
            self.patch_size,
            self.seq_h,
            self.seq_w,
        ).permute([0, 1, 4, 2, 5, 3])  # [b, c, p, p, h, w] --> [b, c, h, p, w, p]
        x = x.reshape(
            bsz,
            3,
            self.input_height,
            self.input_width,
        )  # [b, c, hxp, wxp]
        return x

    def encode(self, x):
        """Encode images and return a ``DiagonalGaussianDistribution`` posterior.

        In the non-variational case the ``quant_conv`` output is concatenated
        with zeros along dim 2 and the distribution is built with
        ``deterministic=True`` (presumably the zeros fill the log-variance
        half; confirm against ``DiagonalGaussianDistribution``).
        """
        # patchify
        x = self.patch_embed(x)

        # encoder
        for blk in self.encoder:
            x = blk(x)
        x = self.enc_norm(x)

        # bottleneck
        moments = self.quant_conv(x)
        if not self.use_variational:
            moments = torch.cat((moments, torch.zeros_like(moments)), 2)
        posterior = DiagonalGaussianDistribution(moments, deterministic=(not self.use_variational), dim=2)
        return posterior

    def decode(self, z):
        """Decode a latent ``z`` (last dim ``latent_dim``) into a ``[b, 3, H, W]`` image."""
        # bottleneck
        z = self.post_quant_conv(z)

        # decoder
        for blk in self.decoder:
            z = blk(z)
        z = self.dec_norm(z)

        # predictor
        z = self.predictor(z)

        # unpatchify
        dec = self.unpatchify(z)
        return dec

    def autoencode(self, input, sample_posterior=True):
        """Encode then decode ``input``.

        The latent is sampled from the posterior only when the model is
        variational and ``sample_posterior`` is true; otherwise the posterior
        mode is used.

        Returns:
            Tuple ``(reconstruction, posterior, latent)``.
        """
        posterior = self.encode(input)
        if self.use_variational and sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior, z

    def get_input(self, batch, k):
        """Fetch ``batch[k]`` and return it as a contiguous float tensor, last axis moved to position 1.

        A 3-D entry first gets a trailing singleton axis appended.
        """
        x = batch[k]
        if x.ndim == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        return x

    def forward(self, inputs, labels, split="train"):
        """Run ``autoencode`` on ``inputs``; ``labels`` and ``split`` are unused."""
        rec, post, latent = self.autoencode(inputs)
        return rec, post, latent

    def get_last_layer(self):
        """Return the weight of the final ``predictor`` projection."""
        return self.predictor.weight
335
+
336
+
337
def ViT_L_20_Shallow_Encoder(**kwargs):
    """Build the ``AutoencoderKL`` preset for 360x640 inputs with 20-pixel patches.

    The preset uses a 6-block encoder and a 12-block decoder, both 1024 wide
    with 16 heads.

    Args:
        **kwargs: Extra keyword arguments forwarded to ``AutoencoderKL``.
            ``latent_dim`` defaults to 16 when not supplied. Passing a keyword
            that the preset already fixes (e.g. ``patch_size``) raises
            ``TypeError`` because it would be given twice.
    """
    latent_dim = kwargs.pop("latent_dim", 16)
    return AutoencoderKL(
        latent_dim=latent_dim,
        patch_size=20,
        enc_dim=1024,
        enc_depth=6,
        enc_heads=16,
        dec_dim=1024,
        dec_depth=12,
        dec_heads=16,
        input_height=360,
        input_width=640,
        **kwargs,
    )


# Registry mapping a model name to the factory that builds it.
VAE_models = {
    "vit-l-20-shallow-encoder": ViT_L_20_Shallow_Encoder,
}
configurations/algorithm/base_algo.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # This will be passed as the cfg to Algo.__init__(cfg) of your algorithm class
2
+
3
+ debug: ${debug} # inherited from configurations/config.yaml
configurations/algorithm/base_pytorch_algo.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ defaults:
2
+ - base_algo # inherits from configurations/algorithm/base_algo.yaml
3
+
4
+ lr: ${experiment.training.lr}
configurations/algorithm/base_video_dit.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - df_base
3
+
4
+ n_frames: ${dataset.n_frames}
5
+ frame_skip: ${dataset.frame_skip}
6
+ metadata: ${dataset.metadata}
7
+
8
+ # training hyperparameters
9
+ weight_decay: 2e-3
10
+ warmup_steps: 1000
11
+ optimizer_beta: [0.9, 0.99]
12
+ action_cond_dim: 25
13
+ use_plucker: true
14
+
15
+ diffusion:
16
+ # training
17
+ beta_schedule: sigmoid
18
+ objective: pred_v
19
+ use_fused_snr: True
20
+ cum_snr_decay: 0.96
21
+ clip_noise: 20.
22
+ # sampling
23
+ sampling_timesteps: 20
24
+ ddim_sampling_eta: 0.0
25
+ stabilization_level: 15
26
+ # architecture
27
+ architecture:
28
+ network_size: 64
29
+ attn_heads: 4
30
+ attn_dim_head: 64
31
+ dim_mults: [1, 2, 4, 8]
32
+ resolution: ${dataset.resolution}
33
+ attn_resolutions: [16, 32, 64, 128]
34
+ use_init_temporal_attn: True
35
+ use_linear_attn: True
36
+ time_emb_type: rotary
configurations/algorithm/dememwm_memory_dit.yaml ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ defaults:
3
+ - base_video_dit
4
+ - _self_
5
+
6
+ _name: dememwm_memory_dit
7
+
8
+ # Standalone Memory-DiT path. Do not route through old SSM-memory config.
9
+ memory_token_cross_attention: true
10
+ memory_cross_attn_layers: null
11
+ memory_condition_length: 0
12
+ pose_cond_dim: 5
13
+ log_video: false
14
+
15
+ dememwm:
16
+ enabled: true
17
+ training_stage: stage_1 # fallback only when curriculum.enabled=false
18
+ debug_force_all_streams: false
19
+ curriculum:
20
+ enabled: true
21
+ full_stage_start_step: 60000
22
+ freeze_vae: true
23
+ dit_freeze:
24
+ enabled: true
25
+ lr:
26
+ dememwm_modules: 1.0e-4
27
+ memory_adapters: 1.0e-4
28
+ full_dit: 1.0e-5
29
+ # Current Conv2D memory projectors preserve latent H,W=(18,32).
30
+ # Pool sizes are resolved from projected spatial grid size and downsample ratios.
31
+ token_patch_size: 2
32
+ anchor:
33
+ enabled: true
34
+ anchor_indices: [0, 1, 2, 3]
35
+ allow_generated_as_anchor: false
36
+ diverse_selection: true
37
+ compress:
38
+ downsample_ratio: 4
39
+ dynamic:
40
+ enabled: true
41
+ exclude_latest_local_frames: 4
42
+ recent_frames: 8
43
+ conv_kernel_t: 3
44
+ conv_stride_t: 2
45
+ revisit:
46
+ enabled: true
47
+ deterministic_pose_retrieval: true
48
+ fov_overlap_threshold: 0.30
49
+ high_quality_fov_threshold: 0.70
50
+ plucker_weight: 0.10
51
+ max_frames: 2
52
+ # FoV geometry for coverage-based retrieval scoring.
53
+ # fov_half_h/v: half-angles (degrees) of the horizontal/vertical field of view.
54
+ # fov_radius: world-space radius of the sample sphere.
55
+ # fov_{yaw,pitch,depth}_samples: grid resolution for FoV point sampling.
56
+ fov_half_h: 52.5 # 105 deg total horizontal FoV
57
+ fov_half_v: 37.5 # 75 deg total vertical FoV
58
+ fov_radius: 30.0
59
+ fov_yaw_samples: 25
60
+ fov_pitch_samples: 20
61
+ fov_depth_samples: 20
62
+ pose_preselect_topk: 64
63
+ # Plucker descriptor grid for secondary pose-similarity scoring.
64
+ plucker_grid_h: 4
65
+ plucker_grid_w: 4
66
+ plucker_focal_length: 0.35
67
+ compress:
68
+ downsample_ratio: 4
69
+ stage_policy:
70
+ noise_bucket_logging: true
71
+ eval_ablation:
72
+ enabled: false
73
+ branch: A_plus_D_plus_R_normal
74
+ generated_history_proxy:
75
+ enabled: false
76
+ start_step: 0
77
+ ramp_steps: 1
78
+ max_prob: 0.0
79
+ noise_std: 0.25
80
+ dropout_prob: 0.0
81
+ injection:
82
+ dit_hidden_size: 1024
83
+ anchor_gate: 1.0
84
+ dynamic_gate: 1.0
85
+ revisit_gate: 1.0
86
+ cache:
87
+ enabled: true
88
+ device: cpu
89
+ keep_raw_latents: all
90
+ keep_compressed_records: true
91
+ keep_prefix_anchors: true
92
+ eviction_policy: none
93
+ no_evict: true
94
+ clear_between_videos: true
95
+ max_records: null
96
+ max_slots: null
97
+ on_capacity_exceeded: warn
98
+ checkpoint:
99
+ strict_dememwm_eval_load: true
100
+
101
+ diffusion:
102
+ architecture:
103
+ network_size: 64
configurations/algorithm/df_base.yaml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base_pytorch_algo
3
+
4
+ # dataset-dependent configurations
5
+ x_shape: ${dataset.observation_shape}
6
+ frame_stack: 1
7
+ frame_skip: 1
8
+ data_mean: ${dataset.data_mean}
9
+ data_std: ${dataset.data_std}
10
+ external_cond_dim: 0 #${dataset.action_dim}
11
+ context_frames: ${dataset.context_length}
12
+ # training hyperparameters
13
+ weight_decay: 1e-4
14
+ warmup_steps: 10000
15
+ optimizer_beta: [0.9, 0.999]
16
+ # diffusion-related
17
+ uncertainty_scale: 1
18
+ guidance_scale: 0.0
19
+ chunk_size: 1 # -1 for full trajectory diffusion, number to specify diffusion chunk size
20
+ scheduling_matrix: autoregressive
21
+ noise_level: random_all
22
+ causal: True
23
+
24
+ diffusion:
25
+ # training
26
+ objective: pred_x0
27
+ beta_schedule: cosine
28
+ schedule_fn_kwargs: {}
29
+ clip_noise: 20.0
30
+ use_snr: False
31
+ use_cum_snr: False
32
+ use_fused_snr: False
33
+ snr_clip: 5.0
34
+ cum_snr_decay: 0.98
35
+ timesteps: 1000
36
+ # sampling
37
+ sampling_timesteps: 50 # fixme, number of diffusion steps, should be increased
38
+ ddim_sampling_eta: 1.0
39
+ stabilization_level: 10
40
+ # architecture
41
+ architecture:
42
+ network_size: 64
configurations/dataset/base_dataset.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # This will be passed as the cfg to Dataset.__init__(cfg) of your dataset class
2
+
3
+ debug: ${debug} # inherited from configurations/config.yaml
configurations/dataset/base_video.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base_dataset
3
+
4
+ metadata: "data/${dataset.name}/metadata.json"
5
+ data_mean: "data/${dataset.name}/data_mean.npy"
6
+ data_std: "data/${dataset.name}/data_std.npy"
7
+ save_dir: ???
8
+ n_frames: 32
9
+ context_length: 4
10
+ resolution: 128
11
+ observation_shape: [3, "${dataset.resolution}", "${dataset.resolution}"]
12
+ external_cond_dim: 0
13
+ validation_multiplier: 1
14
+ frame_skip: 1
configurations/dataset/video_minecraft.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base_video
3
+
4
+ save_dir: data/minecraft_simple_backforward
5
+ n_frames: 16 # TODO: increase later
6
+ resolution: 128
7
+ data_mean: 0.5
8
+ data_std: 0.5
9
+ action_cond_dim: 25
10
+ context_length: 1
11
+ frame_skip: 1
12
+ validation_multiplier: 1
13
+
14
+ _name: video_minecraft_oasis
configurations/dataset/video_minecraft_latent.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ defaults:
2
+ - video_minecraft
3
+
4
+ precomputed_feature_dir: /share_1/users/bonan_ding/worldmem_data/minecraft/vae_features
5
+
6
+ _name: video_minecraft_latent
configurations/experiment/base_experiment.yaml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ debug: ${debug} # inherited from configurations/config.yaml
2
+ tasks: [main] # tasks to run sequentially, such as [training, test], useful when your project has multiple stages and you want to run only a subset of them.
configurations/experiment/base_pytorch.yaml ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inherits from base_experiment.yaml
2
+ # most of the options have docs at https://lightning.ai/docs/pytorch/stable/common/trainer.html
3
+
4
+ defaults:
5
+ - base_experiment
6
+
7
+ tasks: [training] # tasks to run sequentially, change when your project has multiple stages and you want to run only a subset of them.
8
+ num_nodes: 1 # number of gpu servers used in large scale distributed training
9
+
10
+ training:
11
+ precision: 16-mixed # set float precision, 16-mixed is faster while 32 is more stable
12
+ compile: False # whether to compile the model with torch.compile
13
+ lr: 0.001 # learning rate
14
+ batch_size: 16 # training batch size; effective batch size is this number * gpu * nodes iff using distributed training
15
+ max_epochs: 1000 # set to -1 to train forever
16
+ max_steps: -1 # set to -1 to train forever, will override max_epochs
17
+ max_time: null # set to something like "00:12:00:00" to enable
18
+ data:
19
+ num_workers: 4 # number of CPU threads for data preprocessing.
20
+ shuffle: True # whether training data will be shuffled
21
+ optim:
22
+ accumulate_grad_batches: 1 # accumulate gradients for n batches before each optimizer step
23
+ gradient_clip_val: 1.0 # clip gradients with norm above this value, set to 0 to disable
24
+ checkpointing:
25
+ # these are arguments to pytorch lightning's callback, `ModelCheckpoint` class
26
+ every_n_train_steps: 5000 # save a checkpoint every n train steps
27
+ every_n_epochs: null # mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``
28
+ train_time_interval: null # in format of "00:12:00:00", mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
29
+ save_last: True # keep last.ckpt for automatic interrupted-run resume
30
+ enable_version_counter: False # If this is ``False``, later checkpoints will overwrite previous ones.
31
+
32
+ validation:
33
+ precision: 16-mixed
34
+ compile: False # whether to compile the model with torch.compile
35
+ batch_size: 16 # validation batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training
36
+ val_every_n_step: 2000 # controls how frequently validation runs; can be a float (fraction of an epoch), an int (steps), or null (if val_every_n_epoch is set)
37
+ val_every_n_epoch: null # set this to run validation every n epochs; requires val_every_n_step to be null.
38
+ limit_batch: null # if null, run through validation set. Otherwise limit the number of batches to use for validation.
39
+ inference_mode: True # whether to run validation in inference mode (enable_grad won't work!)
40
+ data:
41
+ num_workers: 4 # number of CPU threads for data preprocessing, for validation.
42
+ shuffle: False # whether validation data will be shuffled
43
+
44
+ test:
45
+ precision: 16-mixed
46
+ compile: False # whether to compile the model with torch.compile
47
+ batch_size: 4 # test batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training
48
+ limit_batch: null # if null, run through test set. Otherwise limit the number of batches to use for test.
49
+ data:
50
+ num_workers: 4 # number of CPU threads for data preprocessing, for test.
51
+ shuffle: False # whether test data will be shuffled
configurations/experiment/exp_video.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base_pytorch
3
+
4
+ tasks: [training]
5
+
6
+ training:
7
+ lr: 2e-5
8
+ precision: 16-mixed
9
+ batch_size: 4
10
+ max_epochs: -1
11
+ max_steps: 2000005
12
+ checkpointing:
13
+ every_n_train_steps: 2500
14
+ optim:
15
+ gradient_clip_val: 1.0
16
+
17
+ validation:
18
+ val_every_n_step: 2500
19
+ val_every_n_epoch: null
20
+ batch_size: 4
21
+ limit_batch: 1
22
+
23
+ test:
24
+ limit_batch: 1
25
+ batch_size: 1
26
+
27
+ logging:
28
+ metrics:
29
+ # - fvd
30
+ # - fid
31
+ # - lpips
configurations/training.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # configuration parsing starts here
2
+ defaults:
3
+ - experiment: exp_video # experiment yaml file name in configurations/experiment folder [fixme]
4
+ - dataset: video_minecraft # dataset yaml file name in configurations/dataset folder [fixme]
5
+ - algorithm: dememwm_memory_dit # algorithm yaml file name in configurations/algorithm folder [fixme]
6
+ - cluster: null # optional, cluster yaml file name in configurations/cluster folder. Leave null for local compute
7
+
8
+ debug: false # global debug flag will be passed into configuration of experiment, dataset and algorithm
9
+
10
+ wandb:
11
+ entity: turlin # wandb account name / organization name [fixme]
12
+ project: DeMemWM # wandb project name; if not provided, defaults to root folder name [fixme]
13
+ mode: offline # set wandb logging to online, offline or dryrun
14
+
15
+ resume: null # wandb run id to resume logging and loading checkpoint from
16
+ load: null # wandb run id containing checkpoint or a path to a checkpoint file
17
+ auto_resume: true # automatically resume training from output_dir/checkpoints when available
18
+ resume_ckpt_path: null # explicit full Lightning checkpoint path for deterministic training resume
datasets/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .video import MinecraftVideoDataset, MinecraftVideoLatentDataset
datasets/video/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .minecraft_video_dataset import MinecraftVideoDataset
2
+ from .minecraft_video_latent_dataset import MinecraftVideoLatentDataset