diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..aad6d9b6871b343258781b60f30eedbbc13186ca --- /dev/null +++ b/.gitignore @@ -0,0 +1,27 @@ +__pycache__/ +*.py[cod] +*$py.class + +.pytest_cache/ +.mypy_cache/ +.ruff_cache/ +.ipynb_checkpoints/ + +.env +.venv/ +venv/ + +outputs/ +slurm_logs/ +latest-run +.wandb_run_id +.wandb_osh_command_dir/ +wandb/ + +*.log +*.ckpt +*.pt +*.pth +*.safetensors + +data/ diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..f6534a33dca6cdcdc3892347e0959217530359aa --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,14 @@ +# S-Lab License 1.0 + +Copyright 2025 S-Lab + +Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. + + +--- +For the commercial use of the code, please consult Prof. Chen Change Loy (ccloy@ntu.edu.sg) diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2f48b5699725629ccb946b9822effc9fca627ddd --- /dev/null +++ b/README.md @@ -0,0 +1,37 @@ +# DeMemWM + +This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research template [repo](https://github.com/buoyancy99/research-template). By its MIT license, you must keep this sentence in `README.md` and the `LICENSE` file to credit the author. + +DeMemWM is a Memory-DiT video prediction project built on the local research template. The primary algorithm entry point is `DeMemWMMinecraft`, registered through the Hydra algorithm config `dememwm_memory_dit`. + +## Quick Start + +```bash +python -m venv .venv +source .venv/bin/activate +pip install -r requirements.txt +python -m pytest tests +``` + +Run a local offline experiment after setting the dataset path in `configurations/dataset/video_minecraft.yaml`: + +```bash +python main.py +name=dememwm_debug algorithm=dememwm_memory_dit wandb.mode=offline +``` + +Use `resume_ckpt_path=/path/to/checkpoint.ckpt` for deterministic checkpoint resume, or keep `auto_resume=true` to resume from `output_dir/checkpoints` when available. 
class BaseAlgo(ABC):
    """Abstract base class for generic algorithms.

    The constructor only stores the configuration object; concrete
    subclasses provide the actual behaviour by implementing :meth:`run`.
    """

    def __init__(self, cfg: DictConfig):
        """Store the algorithm configuration on ``self.cfg``.

        Args:
            cfg: configuration object for this algorithm.
        """
        super().__init__()
        self.cfg = cfg

    @abstractmethod
    def run(self, *args: Any, **kwargs: Any) -> Any:
        """Run the algorithm.

        ``self`` is declared explicitly so the instance is not silently
        captured as the first element of ``*args``.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError
def on_save_checkpoint(self, checkpoint):
    """Attach a snapshot of every random-number-generator state to the checkpoint.

    The snapshot is stored under ``checkpoint["rng_states"]`` with the keys
    ``"python"``, ``"numpy"``, ``"torch"`` and ``"cuda"``. The ``"cuda"``
    entry is ``None`` when CUDA is not available.

    Args:
        checkpoint: the checkpoint dictionary being saved; mutated in place.
    """
    snapshot = {}
    snapshot["python"] = random.getstate()
    snapshot["numpy"] = np.random.get_state()
    snapshot["torch"] = torch.get_rng_state()
    if torch.cuda.is_available():
        snapshot["cuda"] = torch.cuda.get_rng_state_all()
    else:
        snapshot["cuda"] = None
    checkpoint["rng_states"] = snapshot
def log_video(
    self,
    key: str,
    video: Union[np.ndarray, torch.Tensor],
    mean: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
    std: Union[np.ndarray, torch.Tensor, Sequence, float] = None,
    fps: int = 5,
    format: str = "mp4",
):
    """
    Log video to wandb. WandbLogger in pytorch lightning does not support video logging yet, so we call wandb directly.

    Args:
        key: the name of the video.
        video: a numpy array or tensor, either in form (time, channel, height, width) or in the form
            (batch, time, channel, height, width). The content must be in 0-255 if under dtype uint8
            or [0, 1] otherwise.
        mean: optional, the mean to unnormalize video tensor, assuming unnormalized data is in [0, 1].
            A scalar is repeated for 3 channels.
        std: optional, the std to unnormalize video tensor, assuming unnormalized data is in [0, 1].
            A scalar is repeated for 3 channels.
        fps: the frame rate of the video.
        format: the format of the video. Can be either "mp4" or "gif".
    """

    if isinstance(video, torch.Tensor):
        video = video.detach().cpu().numpy()

    # Per-channel statistics must broadcast against (..., channel, height, width):
    # one singleton axis per leading (batch/time) axis, then the 3-channel axis,
    # then singleton height/width axes. That is ndim - 3 leading ones; using
    # ndim - 2 prepends an extra axis, which turns a 5-D batch into a 6-D array
    # that wandb.Video cannot unpack.
    expand_shape = [1] * (video.ndim - 3) + [3, 1, 1]
    if std is not None:
        if isinstance(std, (float, int)):
            std = [std] * 3
        if isinstance(std, torch.Tensor):
            std = std.detach().cpu().numpy()
        std = np.array(std).reshape(*expand_shape)
        video = video * std
    if mean is not None:
        if isinstance(mean, (float, int)):
            mean = [mean] * 3
        if isinstance(mean, torch.Tensor):
            mean = mean.detach().cpu().numpy()
        mean = np.array(mean).reshape(*expand_shape)
        video = video + mean

    # Non-uint8 input is treated as [0, 1] data: clip, then rescale to 0-255.
    if video.dtype != np.uint8:
        video = np.clip(video, a_min=0, a_max=1) * 255
        video = video.astype(np.uint8)

    self.logger.experiment.log(
        {
            key: wandb.Video(video, fps=fps, format=format),
        },
        step=self.global_step,
    )
def log_gradient_stats(self):
    """Log summary statistics of gradient norms and gradient-to-parameter ratios.

    For every parameter that currently has a gradient, the gradient norm is
    collected; min/max/std/mean/median are logged under ``train/grad_norm/*``.
    The gradient-to-parameter ratio (``train/gpr/*``) is collected only for
    parameters whose own norm is strictly positive: for an all-zero parameter
    the ratio is undefined, and including it would turn every gpr statistic
    into ``inf``/``nan``. Nothing is logged when no parameter has a gradient,
    and the gpr group is omitted when no parameter has a positive norm.
    """

    def _summarize(prefix, values):
        # Build the five summary entries for one list of Python floats.
        stats = torch.tensor(values)
        # std() of a single element is NaN; report zero spread instead.
        spread = stats.std() if stats.numel() > 1 else torch.zeros(())
        return {
            f"{prefix}/min": stats.min(),
            f"{prefix}/max": stats.max(),
            f"{prefix}/std": spread,
            f"{prefix}/mean": stats.mean(),
            f"{prefix}/median": torch.median(stats),
        }

    with torch.no_grad():
        grad_norms = []
        gpr = []  # gradient-to-parameter ratio
        for param in self.parameters():
            if param.grad is None:
                continue
            grad_norm = torch.norm(param.grad).item()
            param_norm = torch.norm(param).item()
            grad_norms.append(grad_norm)
            if param_norm > 0.0:
                gpr.append(grad_norm / param_norm)
        if not grad_norms:
            return
        metrics = _summarize("train/grad_norm", grad_norms)
        if gpr:
            metrics.update(_summarize("train/gpr", gpr))
        self.log_dict(metrics)
register_data_mean_std( + self, mean: Union[str, float, Sequence], std: Union[str, float, Sequence], namespace: str = "data" + ): + """ + Register mean and std of data as tensor buffer. + + Args: + mean: the mean of data. + std: the std of data. + namespace: the namespace of the registered buffer. + """ + for k, v in [("mean", mean), ("std", std)]: + if isinstance(v, str): + if v.endswith(".npy"): + v = torch.from_numpy(np.load(v)) + elif v.endswith(".pt"): + v = torch.load(v) + else: + raise ValueError(f"Unsupported file type {v.split('.')[-1]}.") + else: + v = torch.tensor(v) + self.register_buffer(f"{namespace}_{k}", v.float().to(self.device)) diff --git a/algorithms/common/metrics/__init__.py b/algorithms/common/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..61c0d28943dd29d7aeb4b1121939b573f7989e9b --- /dev/null +++ b/algorithms/common/metrics/__init__.py @@ -0,0 +1,3 @@ +from .fid import FrechetInceptionDistance +from .lpips import LearnedPerceptualImagePatchSimilarity +from .fvd import FrechetVideoDistance diff --git a/algorithms/common/metrics/fid.py b/algorithms/common/metrics/fid.py new file mode 100644 index 0000000000000000000000000000000000000000..428a621a58807767650101026576335090d10fc0 --- /dev/null +++ b/algorithms/common/metrics/fid.py @@ -0,0 +1 @@ +from torchmetrics.image.fid import FrechetInceptionDistance diff --git a/algorithms/common/metrics/fvd.py b/algorithms/common/metrics/fvd.py new file mode 100644 index 0000000000000000000000000000000000000000..a502055eff0b19ab8724d1d5cbee38ab85a8ee7c --- /dev/null +++ b/algorithms/common/metrics/fvd.py @@ -0,0 +1,158 @@ +""" +Adopted from https://github.com/cvpr2022-stylegan-v/stylegan-v +Verified to be the same as tf version by https://github.com/universome/fvd-comparison +""" + +import io +import re +import requests +import html +import hashlib +import urllib +import urllib.request +from typing import Any, List, Tuple, Union, Dict +import scipy + +import torch 
def open_url(
    url: str,
    num_attempts: int = 10,
    verbose: bool = True,
    return_filename: bool = False,
) -> Any:
    """Open a local path, ``file://`` URL, or remote URL for binary reading.

    Args:
        url: a plain filesystem path, a ``file://`` URL, or a remote URL
            fetched through ``requests``.
        num_attempts: maximum number of download attempts for a remote URL;
            must be at least 1.
        verbose: print download progress to stdout.
        return_filename: return the resolved local filename instead of a file
            object. Only supported for local paths and ``file://`` URLs.

    Returns:
        The filename when ``return_filename`` is true; otherwise an open
        binary file object for local inputs, or an ``io.BytesIO`` holding the
        downloaded bytes for remote inputs.

    Raises:
        AssertionError: if ``num_attempts < 1`` or if ``return_filename`` is
            requested for a remote URL.
        Exception: the error of the last failed download attempt, re-raised
            unchanged.
    """
    assert num_attempts >= 1

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match("^[a-z]+://", url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname(),
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith("file://"):
        filename = urllib.parse.urlparse(url).path
        if re.match(r"^/[a-zA-Z]:", filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    # Downloads are kept in memory only, so there is no filename to hand back.
    # Fail before touching the network rather than after a wasted download.
    assert not return_filename

    # Download.
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    if len(res.content) < 8192:
                        # Small payloads may be an HTML interstitial instead of
                        # the file. Decode leniently: a genuine small binary
                        # file is not valid UTF-8 and must not count as a
                        # failed attempt.
                        content_str = res.content.decode("utf-8", errors="replace")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [
                                html.unescape(link)
                                for link in content_str.split('"')
                                if "export=download" in link
                            ]
                            if len(links) == 1:
                                # Retry against the confirmed-download link.
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError(
                                "Google Drive download quota exceeded -- please try again later"
                            )

                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except Exception:
                # Only ordinary errors are retried; KeyboardInterrupt and
                # SystemExit are BaseException subclasses and propagate
                # immediately.
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Return data as file object.
    return io.BytesIO(url_data)
@torch.no_grad()
def compute(self, videos_fake: torch.Tensor, videos_real: torch.Tensor):
    """
    Compute the Frechet Video Distance between predicted and ground-truth videos.

    :param videos_fake: predicted video tensor of shape (frame, batch, channel, height, width)
    :param videos_real: ground-truth observation tensor of shape (frame, batch, channel, height, width)
    :return: the value returned by ``compute_fvd`` on the detector features of both inputs
    :raises ValueError: if ``videos_fake`` has fewer than 2 frames
    """
    # Unpacking also enforces that the prediction tensor is 5-dimensional.
    frame_count, _batch, _channels, _height, _width = videos_fake.shape
    if frame_count < 2:
        raise ValueError("Video must have more than 1 frame for FVD")

    # detector takes in tensors of shape [batch_size, c, video_len, h, w] with range -1 to 1
    fake_clip, real_clip = [
        clip.permute(1, 2, 0, 3, 4).contiguous() for clip in (videos_fake, videos_real)
    ]
    fake_features, real_features = [
        self.detector(clip, **self.detector_kwargs).cpu().numpy()
        for clip in (fake_clip, real_clip)
    ]
    return compute_fvd(fake_features, real_features)
+1,18 @@ + +from .types import MemoryRecord, MemorySourceType, MemoryStreamTensors, RevisitRetrievalResult, StreamGateState +from .memory import CausalMemoryBank, MemoryBankQuery, stack_record_tokens +from .compression import CausalConv3DDynamicCompressor, latent_patch_tokens, spatial_pool_tokens +from .retrieval import deterministic_revisit_retrieval +from .schedules import compute_stream_gates, CurriculumState, resolve_curriculum, DeMemWMCurriculumState, resolve_dememwm_curriculum +from .gates import RevisitRawGate +from .cache import StreamingCache, DeMemWMStreamingCache +from .injection import InjectionAdapter, DeMemWMInjectionAdapter + +__all__ = [ + "MemoryRecord", "MemorySourceType", "MemoryStreamTensors", "RevisitRetrievalResult", "StreamGateState", + "CausalMemoryBank", "MemoryBankQuery", "stack_record_tokens", + "CausalConv3DDynamicCompressor", "latent_patch_tokens", "spatial_pool_tokens", + "deterministic_revisit_retrieval", "compute_stream_gates", "CurriculumState", "resolve_curriculum", + "DeMemWMCurriculumState", "resolve_dememwm_curriculum", "RevisitRawGate", + "StreamingCache", "DeMemWMStreamingCache", "InjectionAdapter", "DeMemWMInjectionAdapter", +] diff --git a/algorithms/worldmem/dememwm/algorithm.py b/algorithms/worldmem/dememwm/algorithm.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb8fd6b45dbc0e00c8ae5d7568fe83c9442d1b7 --- /dev/null +++ b/algorithms/worldmem/dememwm/algorithm.py @@ -0,0 +1,2464 @@ + +from __future__ import annotations + +import math +from dataclasses import replace +from typing import Iterable + +import torch +from einops import rearrange + +from .cache import StreamingCache +from .compression import CausalConv3DDynamicCompressor, SpatialConv2DMemoryProjector, latent_patch_tokens, spatial_pool_tokens +from .diagnostics import summarize_eval_ablation_diagnostics, summarize_noise_bucket_diagnostics, summarize_revisit_diagnostics +from .injection import InjectionAdapter +from .memory import 
CausalMemoryBank, MemoryBankQuery, stack_record_tokens +from .negatives import apply_revisit_eval_corruption +from .retrieval import deterministic_revisit_retrieval +from .schedules import EVAL_CORRUPTION_BRANCHES, compute_stream_gates, denoising_fraction_from_noise_levels, noise_bucket_from_denoising_fraction, noise_bucket_from_noise_levels, noise_bucket_ids_from_noise_levels, normalize_eval_ablation_branch, resolve_curriculum +from .types import MemoryRecord, MemorySourceType, MemoryStreamTensors + + +class MemoryDiTMixin: + """Standalone DeMemWM / Memory-DiT mixin. + + Reuses the base video-DiT infrastructure while keeping memory construction and + injection under the standalone `dememwm` package. Legacy memory-method files + are not part of this path. + """ + + strict_key_prefixes = ( + "dememwm_dynamic_compressor.", + "dememwm_anchor_proj.", + "dememwm_revisit_proj.", + "dememwm_revisit_gate.", + ) + strict_key_substrings = ( + ".memory_token_cross_attn.", + ) + _TRAIN_DIAGNOSTIC_LOG_KEYS = frozenset({ + "revisit_candidate_frame_count", + "revisit_pose_preselect_input_count", + "revisit_pose_preselect_selected_count", + "revisit_exact_fov_candidate_count", + "valid_revisit_frame_count", + "valid_revisit_target_count", + "no_valid_revisit_count", + "revisit_selected_frame_count", + "revisit_frame_fov_overlap_mean", + "revisit_best_selected_frame_fov_overlap_mean", + "revisit_best_selected_plucker_overlap_mean", + "revisit_best_selected_gap_frames_mean", + "revisit_gate_raw", + "revisit_gate_eff", + "revisit_learned_gate_mean", + "revisit_effective_gate_mean", + "generated_history_proxy_prob", + "noise_bucket_target_count", + "noise_bucket_high_target_count", + "noise_bucket_mid_target_count", + "noise_bucket_low_target_count", + }) + _VALIDATION_DIAGNOSTIC_LOG_KEYS = _TRAIN_DIAGNOSTIC_LOG_KEYS | frozenset({ + "cache_records", + "cache_slots", + }) + + def _memory_cfg(self): + return getattr(self.cfg, "dememwm", None) + + def _cfg_get(self, obj, name, default): 
+ if obj is None: + return default + if isinstance(obj, dict): + return obj.get(name, default) + return getattr(obj, name, default) + + def _cfg_has(self, obj, name: str) -> bool: + if obj is None: + return False + if isinstance(obj, dict): + return name in obj + try: + getattr(obj, name) + return True + except Exception: + return False + + def _stage_policy_cfg(self): + return self._cfg_get(self._memory_cfg(), "stage_policy", None) + + def _eval_ablation_cfg(self): + return self._cfg_get(self._memory_cfg(), "eval_ablation", None) + + def _generated_history_proxy_cfg(self): + return self._cfg_get(self._memory_cfg(), "generated_history_proxy", None) + + def _eval_ablation_state(self) -> tuple[bool, str]: + cfg = self._eval_ablation_cfg() + enabled = bool(self._cfg_get(cfg, "enabled", False)) + branch = normalize_eval_ablation_branch(self._cfg_get(cfg, "branch", "A_plus_D_plus_R_normal")) + return enabled, branch + + def _effective_gate_state(self, denoising_fraction: float | None = None, noise_bucket: str | None = None) -> dict: + memory_cfg = self._memory_cfg() + anchor_cfg = self._cfg_get(memory_cfg, "anchor", None) + dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None) + revisit_cfg = self._cfg_get(memory_cfg, "revisit", None) + injection_cfg = self._cfg_get(memory_cfg, "injection", None) + anchor_config_enabled = self._stream_enabled(anchor_cfg) + dynamic_config_enabled = self._stream_enabled(dynamic_cfg) + revisit_config_enabled = self._stream_enabled(revisit_cfg) + curriculum_state = self._curriculum_state() + eval_ablation_enabled, eval_ablation_branch = self._eval_ablation_state() + debug_force = bool(self._cfg_get(memory_cfg, "debug_force_all_streams", False)) + resolved_noise_bucket = noise_bucket or noise_bucket_from_denoising_fraction(denoising_fraction) + gates = compute_stream_gates( + curriculum_state.stage, + denoising_fraction=denoising_fraction, + debug_force_all_streams=debug_force, + anchor_gate=float(self._cfg_get(injection_cfg, 
"anchor_gate", 1.0)), + dynamic_gate=float(self._cfg_get(injection_cfg, "dynamic_gate", 1.0)), + revisit_gate=float(self._cfg_get(injection_cfg, "revisit_gate", 1.0)), + ) + anchor_effective_enabled = bool(gates.anchor_enabled and anchor_config_enabled) + dynamic_effective_enabled = bool(gates.dynamic_enabled and dynamic_config_enabled) + revisit_stage_config_enabled = bool(gates.revisit_enabled and revisit_config_enabled) + if eval_ablation_enabled: + if eval_ablation_branch == "memory_off": + anchor_effective_enabled = False + dynamic_effective_enabled = False + revisit_stage_config_enabled = False + elif eval_ablation_branch == "A_only": + dynamic_effective_enabled = False + revisit_stage_config_enabled = False + elif eval_ablation_branch == "D_only": + anchor_effective_enabled = False + revisit_stage_config_enabled = False + elif eval_ablation_branch == "A_plus_D": + revisit_stage_config_enabled = False + return { + "curriculum_state": curriculum_state, + "gates": gates, + "resolved_noise_bucket": resolved_noise_bucket, + "anchor_config_enabled": anchor_config_enabled, + "dynamic_config_enabled": dynamic_config_enabled, + "revisit_config_enabled": revisit_config_enabled, + "anchor_effective_enabled": anchor_effective_enabled, + "dynamic_effective_enabled": dynamic_effective_enabled, + "revisit_stage_config_enabled": revisit_stage_config_enabled, + "eval_ablation_enabled": eval_ablation_enabled, + "eval_ablation_branch": eval_ablation_branch, + "force_revisit_off": bool(eval_ablation_enabled and eval_ablation_branch == "R_forced_off"), + "force_revisit_on": bool(eval_ablation_enabled and eval_ablation_branch == "R_forced_on"), + } + + def _validate_config_contract(self) -> dict: + if bool(getattr(self, "_dememwm_contract_validated", False)): + return getattr(self, "_last_dememwm_config_diagnostics", {}) + memory_cfg = self._memory_cfg() + if memory_cfg is None: + self._dememwm_contract_validated = True + self._last_dememwm_config_diagnostics = {} + return {} + + 
stale_sections = [name for name in ("ablation", "memory", "loss", "abstention") if self._cfg_has(memory_cfg, name)] + if stale_sections: + raise ValueError(f"stale DeMemWM config sections are not part of the final contract: {stale_sections}") + ratio_fields = [ + name + for name in ("anchor_ratio", "dynamic_ratio", "revisit_ratio", "revisit_max_ratio") + if self._cfg_has(memory_cfg, name) + ] + if ratio_fields: + raise ValueError(f"standalone DeMemWM uses fixed manual token budgets, not ratio fields: {ratio_fields}") + + anchor_cfg = self._cfg_get(memory_cfg, "anchor", None) + dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None) + revisit_cfg = self._cfg_get(memory_cfg, "revisit", None) + stale_nested = [] + for section_name, section_cfg, field_names in ( + ("anchor", anchor_cfg, ("policy", "topk", "pin_prefix")), + ("dynamic", dynamic_cfg, ("include_generated_recent",)), + ("revisit", revisit_cfg, ("deterministic_only", "min_age_frames", "min_gap_frames", "topk", "max_chunks", "chunk_frames", "min_score", "time_weight", "pose_weight", "latent_weight", "pose_overlap_threshold", "action_overlap_threshold", "generated_penalty", "force_gate_zero_when_invalid")), + ): + stale_nested.extend( + f"{section_name}.{field_name}" for field_name in field_names if self._cfg_has(section_cfg, field_name) + ) + if stale_nested: + raise ValueError(f"stale DeMemWM config fields are not part of the final contract: {stale_nested}") + + exclude_latest_local_frames = int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4)) + if exclude_latest_local_frames < 0: + raise ValueError("dememwm.dynamic.exclude_latest_local_frames must be non-negative") + if not bool(self._cfg_get(revisit_cfg, "deterministic_pose_retrieval", True)): + raise ValueError("final DeMemWM requires deterministic FOV/Plucker revisit retrieval") + fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30) + if fov_overlap_threshold is not None: + fov_overlap_threshold = 
float(fov_overlap_threshold) + if fov_overlap_threshold < 0.0: + raise ValueError("dememwm.revisit.fov_overlap_threshold must be non-negative") + high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70)) + if high_quality_fov_threshold < 0.0: + raise ValueError("dememwm.revisit.high_quality_fov_threshold must be non-negative") + plucker_weight = float(self._cfg_get(revisit_cfg, "plucker_weight", 0.10)) + if plucker_weight < 0.0: + raise ValueError("dememwm.revisit.plucker_weight must be non-negative") + for field_name, default in ( + ("fov_half_h", 52.5), + ("fov_half_v", 37.5), + ("fov_radius", 30.0), + ("plucker_focal_length", 0.35), + ): + value = float(self._cfg_get(revisit_cfg, field_name, default)) + if value <= 0.0: + raise ValueError(f"dememwm.revisit.{field_name} must be positive") + for field_name, default in ( + ("fov_yaw_samples", 25), + ("fov_pitch_samples", 20), + ("fov_depth_samples", 20), + ("plucker_grid_h", 4), + ("plucker_grid_w", 4), + ): + value = int(self._cfg_get(revisit_cfg, field_name, default)) + if value <= 0: + raise ValueError(f"dememwm.revisit.{field_name} must be positive") + stage_policy_cfg = self._stage_policy_cfg() + if not bool(self._cfg_get(stage_policy_cfg, "noise_bucket_logging", True)): + raise ValueError("final DeMemWM keeps noise_bucket logging enabled") + proxy_cfg = self._generated_history_proxy_cfg() + proxy_max_prob = float(self._cfg_get(proxy_cfg, "max_prob", 0.0)) + proxy_dropout_prob = float(self._cfg_get(proxy_cfg, "dropout_prob", 0.0)) + proxy_noise_std = float(self._cfg_get(proxy_cfg, "noise_std", 0.0)) + proxy_ramp_steps = int(self._cfg_get(proxy_cfg, "ramp_steps", 0)) + if proxy_max_prob < 0.0 or proxy_max_prob > 1.0: + raise ValueError("dememwm.generated_history_proxy.max_prob must be in [0, 1]") + if proxy_dropout_prob < 0.0 or proxy_dropout_prob > 1.0: + raise ValueError("dememwm.generated_history_proxy.dropout_prob must be in [0, 1]") + if proxy_noise_std < 0.0: + 
raise ValueError("dememwm.generated_history_proxy.noise_std must be non-negative") + if proxy_ramp_steps < 0: + raise ValueError("dememwm.generated_history_proxy.ramp_steps must be non-negative") + eval_ablation_cfg = self._eval_ablation_cfg() + normalize_eval_ablation_branch(self._cfg_get(eval_ablation_cfg, "branch", "A_plus_D_plus_R_normal")) + + diagnostics = { + "dynamic_exclude_latest_local_frames": exclude_latest_local_frames, + "revisit_deterministic_fov_plucker_retrieval": True, + "revisit_local_context_exclusion_frames": self._local_context_exclusion_frames(), + "revisit_fov_overlap_threshold": -1.0 if fov_overlap_threshold is None else fov_overlap_threshold, + "revisit_high_quality_fov_threshold": high_quality_fov_threshold, + "revisit_plucker_weight": plucker_weight, + "stage_policy_noise_bucket_logging": True, + } + self._dememwm_contract_validated = True + self._last_dememwm_config_diagnostics = diagnostics + return diagnostics + + def _stream_enabled(self, stream_cfg) -> bool: + return bool(self._cfg_get(stream_cfg, "enabled", True)) + + def _context_frame_count(self) -> int: + frame_stack = max(1, int(getattr(self, "frame_stack", 1) or 1)) + return max(0, int(getattr(self, "context_frames", 0) or 0) // frame_stack) + + def _local_context_exclusion_frames(self) -> int: + n_tokens = max(0, int(getattr(self, "n_tokens", 0) or 0)) + frame_stack = max(1, int(getattr(self, "frame_stack", 1) or 1)) + return n_tokens * frame_stack + + def _curriculum_state(self, step: int | None = None): + if step is None: + step = int(getattr(self, "global_step", 0) or 0) + return resolve_curriculum(self._memory_cfg(), step) + + def _generated_history_proxy_prob(self, step: int | None = None) -> float: + cfg = self._generated_history_proxy_cfg() + if not bool(self._cfg_get(cfg, "enabled", False)): + return 0.0 + max_prob = min(max(float(self._cfg_get(cfg, "max_prob", 0.0)), 0.0), 1.0) + if max_prob <= 0.0: + return 0.0 + if step is None: + step = int(getattr(self, 
"global_step", 0) or 0) + start_step = int(self._cfg_get(cfg, "start_step", 0)) + if step < start_step: + return 0.0 + ramp_steps = int(self._cfg_get(cfg, "ramp_steps", 0)) + if ramp_steps <= 0: + return max_prob + ramp_fraction = min(max(float(step - start_step) / float(ramp_steps), 0.0), 1.0) + return max_prob * ramp_fraction + + def _apply_generated_history_proxy( + self, + source_latents: torch.Tensor, + source_is_generated: torch.Tensor | None, + context_frame_count: int | None = None, + target_start_frame: int | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, dict]: + cfg = self._generated_history_proxy_cfg() + prob = self._generated_history_proxy_prob() + noise_std = float(self._cfg_get(cfg, "noise_std", 0.0)) + dropout_prob = float(self._cfg_get(cfg, "dropout_prob", 0.0)) + diagnostics = { + "generated_history_proxy_enabled": bool(self._cfg_get(cfg, "enabled", False)), + "generated_history_proxy_prob": float(prob), + "generated_history_proxy_noise_std": float(noise_std), + "generated_history_proxy_dropout_prob": float(dropout_prob), + "generated_history_proxy_frame_count": 0, + "generated_history_proxy_frame_fraction": 0.0, + } + if source_is_generated is None: + source_is_generated = torch.zeros(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool) + else: + source_is_generated = source_is_generated.to(device=source_latents.device, dtype=torch.bool) + if prob <= 0.0 or source_latents.numel() == 0: + return source_latents, source_is_generated, diagnostics + + eligible_mask = torch.ones(source_latents.shape[:2], device=source_latents.device, dtype=torch.bool) + if context_frame_count is not None or target_start_frame is not None: + frame_positions = torch.arange(source_latents.shape[0], device=source_latents.device)[:, None] + if context_frame_count is not None: + eligible_mask &= frame_positions >= max(0, int(context_frame_count)) + if target_start_frame is not None: + eligible_mask &= frame_positions < max(0, 
int(target_start_frame)) + proxy_mask = (torch.rand(source_latents.shape[:2], device=source_latents.device) < prob) & eligible_mask + proxy_count = int(proxy_mask.detach().long().sum().item()) + total_count = max(1, int(proxy_mask.numel())) + diagnostics["generated_history_proxy_frame_count"] = proxy_count + diagnostics["generated_history_proxy_frame_fraction"] = float(proxy_count / total_count) + if proxy_count == 0: + return source_latents, source_is_generated, diagnostics + + corrupt_latents = source_latents.clone() + frame_mask = proxy_mask[:, :, None, None, None].to(dtype=corrupt_latents.dtype) + if noise_std > 0.0: + corrupt_latents = corrupt_latents + torch.randn_like(corrupt_latents) * float(noise_std) * frame_mask + if dropout_prob > 0.0: + dropout_mask = torch.rand( + (*source_latents.shape[:2], 1, source_latents.shape[-2], source_latents.shape[-1]), + device=source_latents.device, + ) < dropout_prob + dropout_mask = dropout_mask & proxy_mask[:, :, None, None, None] + corrupt_latents = torch.where(dropout_mask, corrupt_latents.new_zeros(()), corrupt_latents) + source_is_generated = source_is_generated.clone() + source_is_generated |= proxy_mask + return corrupt_latents, source_is_generated, diagnostics + + def _checkpoint_cfg(self): + return self._cfg_get(self._memory_cfg(), "checkpoint", None) + + def _strict_eval_load_enabled(self) -> bool: + return bool(self._cfg_get(self._checkpoint_cfg(), "strict_dememwm_eval_load", True)) + + def _cache_cfg(self): + return self._cfg_get(self._memory_cfg(), "cache", None) + + def _cache_enabled(self) -> bool: + return bool(self._cfg_get(self._cache_cfg(), "enabled", False)) + + def _new_streaming_cache(self, video_id=None) -> StreamingCache | None: + if not self._cache_enabled(): + return None + cache = StreamingCache.from_config(self._cache_cfg(), enabled_default=True) + if cache.clear_between_videos: + cache.reset(video_id=video_id) + return cache + + def _is_memory_adapter_param(self, name: str) -> bool: + return 
".memory_token_cross_attn." in name + + def _param_group_name(self, name: str, state=None) -> str: + state = state or self._curriculum_state() + if name.startswith("vae.") or name.startswith("validation_lpips_model."): + return "excluded_frozen" + if name.startswith(("dememwm_dynamic_compressor.", "dememwm_anchor_proj.", "dememwm_revisit_proj.")): + return "dememwm_modules" + if self._is_memory_adapter_param(name): + return "memory_adapters" + if name.startswith("diffusion_model."): + return "full_dit" + return "dememwm_modules" + + def _group_trainable(self, group_name: str, state) -> bool: + if group_name in {"dememwm_modules", "memory_adapters"}: + return True + if group_name == "full_dit": + return state.dit_full_trainable + return False + + def _group_lr(self, group_name: str, state) -> float: + if group_name == "dememwm_modules": + return state.dememwm_lr + if group_name == "memory_adapters": + return state.memory_adapter_lr + if group_name == "full_dit": + return state.full_dit_lr + return 0.0 + + def _apply_freeze_policy(self, optimizer=None, step: int | None = None): + state = self._curriculum_state(step) + + # Keep DDP's trainable graph stable: DiT params stay requires_grad=True + # from step 0 and are frozen by optimizer LR=0 until the full stage. + # Re-walk only when curriculum diagnostics can change. 
def configure_optimizers(self):
    """Build an AdamW optimizer with one named param group per family.

    Parameters are bucketed by ``_param_group_name`` under the step-0
    curriculum state. Names mapping to any other group are left out of the
    optimizer. A group that is not trainable at step 0 is still added, with
    learning rate 0.0.

    Raises:
        RuntimeError: if no bucket received any parameter.
    """
    group_order = ("dememwm_modules", "memory_adapters", "full_dit")
    curriculum = self._curriculum_state(0)
    self._apply_freeze_policy(step=0)

    buckets: dict[str, list[torch.nn.Parameter]] = {key: [] for key in group_order}
    for param_name, parameter in self.named_parameters():
        bucket = buckets.get(self._param_group_name(param_name, curriculum))
        if bucket is not None:
            bucket.append(parameter)

    optimizer_groups = []
    for key in group_order:
        members = buckets[key]
        if not members:
            continue
        if self._group_trainable(key, curriculum):
            learning_rate = self._group_lr(key, curriculum)
        else:
            learning_rate = 0.0
        optimizer_groups.append({"params": members, "lr": learning_rate, "name": key})

    if not optimizer_groups:
        raise RuntimeError("DeMemWM optimizer found no trainable parameter groups")
    return torch.optim.AdamW(
        optimizer_groups,
        weight_decay=self.cfg.weight_decay,
        betas=self.cfg.optimizer_beta,
    )


def on_train_start(self):
    """Apply the freeze policy to every trainer optimizer at the current step.

    Does nothing when the module has no trainer or the trainer has no
    optimizers.
    """
    trainer = getattr(self, "trainer", None)
    attached = getattr(trainer, "optimizers", []) or []
    for attached_optimizer in attached:
        current_step = int(getattr(self, "global_step", 0) or 0)
        self._apply_freeze_policy(attached_optimizer, current_step)
[] + for optimizer in optimizers: + self._apply_freeze_policy(optimizer, int(getattr(self, "global_step", 0) or 0)) + + def on_after_backward(self): + step = int(getattr(self, "global_step", 0) or 0) + state = self._apply_freeze_policy(step=step) + for name, param in self.named_parameters(): + if param.grad is None: + continue + group_name = self._param_group_name(name, state) + if not self._group_trainable(group_name, state): + param.grad = None + + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure): + self._apply_freeze_policy(optimizer, int(getattr(self, "global_step", 0) or 0)) + optimizer.step(closure=optimizer_closure) + self._apply_freeze_policy(optimizer, int(getattr(self, "global_step", 0) or 0) + 1) + + def on_load_checkpoint(self, checkpoint): + super().on_load_checkpoint(checkpoint) + if self._strict_eval_load_enabled(): + state_dict = checkpoint.get("state_dict", checkpoint) if isinstance(checkpoint, dict) else checkpoint + self.strict_checkpoint_key_check(state_dict) + + def _preprocess_batch(self, batch): + """Preprocess RGB or precomputed-latent Minecraft batches for DeMemWM. + + MinecraftVideoLatentDataset returns an extra image_hw tensor. Keep the + DeMemWM path on VAE latents while preserving RGB image size for Plucker + pose embeddings. This mirrors the existing latent-dataset contract + without routing through the legacy SSM memory implementation. 
+ """ + from ..df_video import euler_to_camera_to_world_matrix + + if len(batch) == 5: + xs, conditions, pose_conditions, frame_index, image_hw = batch + self._last_dememwm_xs_are_latents = True + self._last_dememwm_image_hw = image_hw + else: + xs, conditions, pose_conditions, frame_index = batch + self._last_dememwm_xs_are_latents = False + self._last_dememwm_image_hw = None + + if self.action_cond_dim: + conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1) + conditions = rearrange(conditions, "b t d -> t b d").contiguous() + else: + raise NotImplementedError("Only support external cond.") + + pose_conditions = rearrange(pose_conditions, "b t d -> t b d").contiguous() + c2w_mat = euler_to_camera_to_world_matrix(pose_conditions) + xs = rearrange(xs, "b t c ... -> t b c ...").contiguous() + frame_index = rearrange(frame_index, "b t -> t b").contiguous() + return xs, conditions, pose_conditions, c2w_mat, frame_index + + def _as_latents(self, xs: torch.Tensor) -> torch.Tensor: + if bool(getattr(self, "_last_dememwm_xs_are_latents", False)): + return xs + return self.encode(xs) + + def _image_size(self, xs: torch.Tensor) -> tuple[int, int]: + image_hw = getattr(self, "_last_dememwm_image_hw", None) + if image_hw is not None: + if torch.is_tensor(image_hw): + values = image_hw.detach().cpu().reshape(-1).tolist() + else: + values = list(image_hw) + if len(values) >= 2: + return int(values[0]), int(values[1]) + return int(xs.shape[-2]), int(xs.shape[-1]) + + def _update_streaming_cache( + self, + cache: StreamingCache | None, + new_latents: torch.Tensor, + frame_indices: torch.Tensor, + pose: torch.Tensor | None = None, + source_is_generated: torch.Tensor | None = None, + action: torch.Tensor | None = None, + ) -> None: + if cache is None or not cache.enabled or new_latents is None or new_latents.shape[0] == 0: + return + cache.add_raw_latents(new_latents, frame_indices, source_is_generated, pose) + if not cache.keep_compressed_records: + 
return + memory_cfg = self._memory_cfg() + anchor_cfg = self._cfg_get(memory_cfg, "anchor", None) + dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None) + revisit_cfg = self._cfg_get(memory_cfg, "revisit", None) + token_patch_size = int(self._cfg_get(memory_cfg, "token_patch_size", 2)) + anchor_indices = [int(x) for x in self._cfg_get(anchor_cfg, "anchor_indices", [0, 1, 2, 3])] + anchor_compress_cfg = self._cfg_get(anchor_cfg, "compress", None) + anchor_src_h, anchor_src_w = self._projected_spatial_grid_size( + int(new_latents.shape[-2]), + int(new_latents.shape[-1]), + self.dememwm_anchor_proj, + token_patch_size, + ) + anchor_pool_h, anchor_pool_w = self._resolve_spatial_pool_size( + anchor_compress_cfg, anchor_src_h, anchor_src_w, 5, 8 + ) + anchor_diverse = bool(self._cfg_get(anchor_cfg, "diverse_selection", False)) + allow_generated_anchor = bool(self._cfg_get(anchor_cfg, "allow_generated_as_anchor", False)) + # Prefix anchors are a per-video prefix resource. Do not add new prefix + # anchors for later committed segments unless explicitly generated anchors are allowed. 
def _build_model(self):
    """Instantiate the DiT backbone, DeMemWM memory modules, LPIPS metric and VAE.

    Validates the DeMemWM config contract after the diffusion model is built
    and before any memory module is created. The VAE is put in eval mode and
    all of its parameters have ``requires_grad`` disabled. A pose prediction
    network is added only when ``self.require_pose_prediction`` is truthy.
    """
    # Imported lazily inside the method rather than at module import time.
    from algorithms.common.metrics import LearnedPerceptualImagePatchSimilarity
    from .gates import RevisitRawGate
    from ..models.diffusion import Diffusion
    from ..models.pose_prediction import PosePredictionNet
    from ..models.vae import VAE_models

    self.diffusion_model = Diffusion(
        reference_length=self.memory_condition_length,
        x_shape=self.x_stacked_shape,
        action_cond_dim=self.action_cond_dim,
        pose_cond_dim=self.pose_cond_dim,
        is_causal=self.causal,
        cfg=self.cfg.diffusion,
        is_dit=True,
        use_plucker=self.use_plucker,
        relative_embedding=self.relative_embedding,
        state_embed_only_on_qk=self.state_embed_only_on_qk,
        # NOTE(review): memory attention is hard-disabled here while token
        # cross-attention defaults to on; presumably the two are alternative
        # memory paths; confirm against Diffusion.
        use_memory_attention=False,
        add_timestamp_embedding=self.add_timestamp_embedding,
        memory_token_cross_attention=getattr(self.cfg, "memory_token_cross_attention", True),
        memory_cross_attn_layers=getattr(self.cfg, "memory_cross_attn_layers", None),
        ref_mode=self.ref_mode,
    )
    memory_cfg = self._memory_cfg()
    # Fail fast on stale or invalid config before building memory modules.
    self._validate_config_contract()
    injection_cfg = self._cfg_get(memory_cfg, "injection", None)
    dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None)
    hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024))
    token_patch_size = int(self._cfg_get(memory_cfg, "token_patch_size", 2))
    max_source_frames = int(self._cfg_get(dynamic_cfg, "recent_frames", 8))
    self.dememwm_dynamic_compressor = CausalConv3DDynamicCompressor(
        latent_channels=self.x_stacked_shape[0],
        dit_hidden_size=hidden_size,
        patch_size=token_patch_size,
        conv_kernel_t=int(self._cfg_get(dynamic_cfg, "conv_kernel_t", 3)),
        conv_stride_t=int(self._cfg_get(dynamic_cfg, "conv_stride_t", 2)),
        max_source_frames=max_source_frames,
        exclude_latest_local_frames=int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4)),
    )
    # Mid width of the spatial projectors: latent channels times patch area.
    spatial_mid_channels = self.x_stacked_shape[0] * token_patch_size * token_patch_size
    # Anchor and revisit streams get separate projectors built with
    # identical hyperparameters.
    self.dememwm_anchor_proj = SpatialConv2DMemoryProjector(
        latent_channels=self.x_stacked_shape[0],
        dit_hidden_size=hidden_size,
        mid_channels=spatial_mid_channels,
        kernel_size=3,
    )
    self.dememwm_revisit_proj = SpatialConv2DMemoryProjector(
        latent_channels=self.x_stacked_shape[0],
        dit_hidden_size=hidden_size,
        mid_channels=spatial_mid_channels,
        kernel_size=3,
    )
    self.dememwm_revisit_gate = RevisitRawGate()
    self.dememwm_injection_adapter = InjectionAdapter()
    self.validation_lpips_model = LearnedPerceptualImagePatchSimilarity()
    # The VAE is used frozen: eval mode and no gradients on any parameter.
    self.vae = VAE_models["vit-l-20-shallow-encoder"]().eval()
    for param in self.vae.parameters():
        param.requires_grad_(False)
    if self.require_pose_prediction:
        self.pose_prediction_model = PosePredictionNet()
+ if bool(getattr(projection, "projects_spatial_latents", False)): + return projection(latents) + patch_vectors = latent_patch_tokens(latents, patch_size) + return projection(patch_vectors).permute(1, 0, 2, 3).contiguous() + + def _projected_spatial_grid_size( + self, + latent_h: int, + latent_w: int, + projection: torch.nn.Module, + patch_size: int, + ) -> tuple[int, int]: + if bool(getattr(projection, "projects_spatial_latents", False)): + return int(latent_h), int(latent_w) + return int(latent_h) // int(patch_size), int(latent_w) // int(patch_size) + + def _take_uniform_slots(self, tokens: torch.Tensor, num_slots: int) -> torch.Tensor: + if tokens.ndim != 2: + raise ValueError("tokens must have shape (N,D)") + num_slots = max(0, int(num_slots)) + if num_slots == 0: + return tokens[:0] + if tokens.shape[0] <= num_slots: + return tokens + idx = torch.linspace(0, tokens.shape[0] - 1, num_slots, device=tokens.device).round().long() + return tokens.index_select(0, idx) + + def _spatial_pool_tokens( + self, + tokens: torch.Tensor, + pool_h: int, + pool_w: int, + src_h: int, + src_w: int, + ) -> torch.Tensor: + return spatial_pool_tokens(tokens, pool_h, pool_w, src_h, src_w) + + def _resolve_spatial_pool_size( + self, + compress_cfg, + src_h: int, + src_w: int, + default_pool_h: int, + default_pool_w: int, + ) -> tuple[int, int]: + ratio = self._cfg_get(compress_cfg, "downsample_ratio", None) + ratio_h = self._cfg_get(compress_cfg, "downsample_h", ratio) + ratio_w = self._cfg_get(compress_cfg, "downsample_w", ratio) + if ratio_h is not None or ratio_w is not None: + if ratio_h is None: + ratio_h = ratio_w + if ratio_w is None: + ratio_w = ratio_h + ratio_h = float(ratio_h) + ratio_w = float(ratio_w) + if ratio_h <= 0.0 or ratio_w <= 0.0: + raise ValueError("DeMemWM compress downsample ratios must be positive") + return ( + max(1, int(math.ceil(float(src_h) / ratio_h))), + max(1, int(math.ceil(float(src_w) / ratio_w))), + ) + pool_h = int(self._cfg_get(compress_cfg, 
"pool_h", default_pool_h)) + pool_w = int(self._cfg_get(compress_cfg, "pool_w", default_pool_w)) + if pool_h <= 0 or pool_w <= 0: + raise ValueError("DeMemWM compress pool_h/pool_w must be positive") + return pool_h, pool_w + + def _select_diverse_anchor_positions( + self, + source_positions: torch.Tensor, + pose: torch.Tensor | None, + num_anchors: int, + ) -> torch.Tensor: + num_anchors = max(0, int(num_anchors)) + if num_anchors == 0: + return source_positions[:0] + if source_positions.numel() <= num_anchors or pose is None: + return source_positions[:num_anchors] + poses = pose.float() + selected = [0] + dists = torch.cdist(poses[0:1], poses).squeeze(0) + for _ in range(num_anchors - 1): + farthest = int(dists.argmax().item()) + selected.append(farthest) + d_new = torch.cdist(poses[farthest:farthest + 1], poses).squeeze(0) + dists = torch.minimum(dists, d_new) + return source_positions[torch.tensor(sorted(selected), device=source_positions.device)] + + def _build_streaming_cache_records( + self, + source_latents: torch.Tensor, + source_frame_indices: torch.Tensor, + source_is_generated: torch.Tensor | None, + pose: torch.Tensor | None, + action: torch.Tensor | None, + allow_generated_anchor: bool, + anchor_indices: list[int], + anchor_pool_h: int, + anchor_pool_w: int, + anchor_diverse: bool, + token_patch_size: int, + ) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank]]: + if source_latents.ndim != 5: + raise ValueError("source_latents must have shape (T,B,C,H,W)") + if source_frame_indices.ndim != 2: + raise ValueError("source_frame_indices must have shape (T,B)") + T_src, B = source_frame_indices.shape + if source_latents.shape[:2] != (T_src, B): + raise ValueError("source_latents and source_frame_indices must share T/B dimensions") + _, _, _, latent_H, latent_W = source_latents.shape + src_h, src_w = self._projected_spatial_grid_size( + latent_H, + latent_W, + self.dememwm_anchor_proj, + token_patch_size, + ) + + param = 
next(iter(self.dememwm_anchor_proj.parameters())) + project_device = param.device + project_dtype = param.dtype + hidden_size = int(getattr(self.dememwm_revisit_proj, "out_features", 0) or self.dememwm_revisit_proj.weight.shape[0]) + generated = None if source_is_generated is None else source_is_generated.bool().to(device=source_frame_indices.device) + anchor_banks: list[CausalMemoryBank] = [] + revisit_banks: list[CausalMemoryBank] = [] + + def _tensor_subset(tensor: torch.Tensor | None, positions: torch.Tensor, batch_idx: int): + if tensor is None or tensor.ndim < 2: + return None + pos = positions.to(device=tensor.device) + if tensor.shape[0] == T_src and tensor.shape[1] == B: + return tensor[pos, batch_idx] + if tensor.shape[0] == B and tensor.shape[1] == T_src: + return tensor[batch_idx, pos] + return None + + def _metadata_subset(positions: torch.Tensor, batch_idx: int): + return {"dememwm_revisit_metadata_only": True} + + def _pose_subset(positions: torch.Tensor, batch_idx: int): + return _tensor_subset(pose, positions, batch_idx) + + def _add_anchor_records(bank: CausalMemoryBank, batch_idx: int, positions: torch.Tensor, generated_anchor: bool) -> None: + if positions.numel() == 0: + return + projected = self._project_latent_patch_tokens( + source_latents.index_select(0, positions.to(device=source_latents.device))[:, batch_idx:batch_idx + 1].to(device=project_device, dtype=project_dtype), + self.dememwm_anchor_proj, + token_patch_size, + )[0] + src_frames = source_frame_indices[:, batch_idx] + for local_idx, source_pos in enumerate(positions): + source_pos_i = int(source_pos.item()) + anchor_tokens = self._spatial_pool_tokens(projected[local_idx], anchor_pool_h, anchor_pool_w, src_h, src_w) + n_slots = anchor_tokens.shape[0] + record_mask = torch.ones((n_slots,), device=anchor_tokens.device, dtype=torch.bool) + if generated_anchor: + bank.add_generated_records( + anchor_tokens.unsqueeze(0), + record_mask.unsqueeze(0), + src_frames[source_pos_i:source_pos_i 
+ 1].to(device=anchor_tokens.device), + source_type=MemorySourceType.GENERATED, + ) + else: + bank.add_prefix_anchors( + anchor_tokens.unsqueeze(0), + record_mask.unsqueeze(0), + src_frames[source_pos_i:source_pos_i + 1].to(device=anchor_tokens.device), + slots_per_anchor=n_slots, + ) + + for batch_idx in range(B): + anchor_bank = CausalMemoryBank() + revisit_bank = CausalMemoryBank() + src_frames = source_frame_indices[:, batch_idx] + if generated is None: + non_generated = torch.ones_like(src_frames, dtype=torch.bool) + else: + non_generated = ~generated[:, batch_idx] + + source_positions = torch.nonzero(non_generated, as_tuple=False).flatten() + if source_positions.numel() > 0: + if anchor_diverse: + anchor_pose = _pose_subset(source_positions, batch_idx) + selected_anchor_positions = self._select_diverse_anchor_positions( + source_positions, anchor_pose, len(anchor_indices) + ) + else: + selected_list = [] + for anchor_idx in anchor_indices: + if 0 <= int(anchor_idx) < source_positions.numel(): + selected_list.append(source_positions[int(anchor_idx)]) + selected_anchor_positions = torch.stack(selected_list).long() if selected_list else source_positions[:0] + if selected_anchor_positions.numel() > 0: + _add_anchor_records(anchor_bank, batch_idx, selected_anchor_positions.long(), False) + + dummy_tokens = torch.zeros((1, hidden_size), device=source_frame_indices.device, dtype=project_dtype) + dummy_mask = torch.ones((1,), device=source_frame_indices.device, dtype=torch.bool) + for prefix, positions, source_type, is_generated in ( + ("prefix", source_positions, MemorySourceType.PREFIX_GT, False), + ( + "generated", + torch.empty(0, device=source_frame_indices.device, dtype=torch.long) if generated is None else torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten(), + MemorySourceType.GENERATED, + True, + ), + ): + if positions.numel() == 0: + continue + for source_pos in positions.to(device=source_frame_indices.device, dtype=torch.long): + source_pos_i = 
int(source_pos.item()) + frame_index = src_frames[source_pos_i] + frame = int(frame_index.detach().item()) + frame_pos = source_pos.reshape(1) + revisit_bank.add_frame_record( + dummy_tokens, + dummy_mask, + frame_index, + pose=_pose_subset(frame_pos, batch_idx), + source_type=source_type, + metadata=_metadata_subset(frame_pos, batch_idx), + is_generated=is_generated, + record_id=f"{prefix}_revisit_b{batch_idx}_f{frame}", + ) + + if allow_generated_anchor and generated is not None and anchor_indices: + generated_positions = torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten() + _add_anchor_records(anchor_bank, batch_idx, generated_positions[:len(anchor_indices)].long(), True) + + anchor_banks.append(anchor_bank) + revisit_banks.append(revisit_bank) + return anchor_banks, revisit_banks + + + def _build_causal_memory_banks( + self, + anchor_projected: torch.Tensor, + revisit_projected: torch.Tensor, + source_frame_indices: torch.Tensor, + source_is_generated: torch.Tensor | None, + pose: torch.Tensor | None, + action: torch.Tensor | None, + allow_generated_anchor: bool, + anchor_indices: list[int], + anchor_pool_h: int, + anchor_pool_w: int, + revisit_pool_h: int, + revisit_pool_w: int, + src_h: int, + src_w: int, + ) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank]]: + # projected tensors use the same batch/source convention as + # _project_latent_patch_tokens: (B, T_src, T_frame, D), while frame indices are + # (T_src, B). Build separate banks because anchor and revisit records + # come from different projections. 
+ if anchor_projected.ndim != 4 or revisit_projected.ndim != 4: + raise ValueError("anchor/revisit projected tensors must have shape (B,T_src,T_frame,D)") + B, T_src, _, _ = anchor_projected.shape + if revisit_projected.shape[:3] != anchor_projected.shape[:3]: + raise ValueError("anchor/revisit projected tensors must share batch/source/token dimensions") + generated = None if source_is_generated is None else source_is_generated.bool().to(source_frame_indices.device) + anchor_banks: list[CausalMemoryBank] = [] + revisit_banks: list[CausalMemoryBank] = [] + + def _tensor_subset(tensor: torch.Tensor | None, positions: torch.Tensor, batch_idx: int): + if tensor is None or tensor.ndim < 2: + return None + if tensor.shape[0] == T_src and tensor.shape[1] == B: + return tensor[positions, batch_idx] + if tensor.shape[0] == B and tensor.shape[1] == T_src: + return tensor[batch_idx, positions] + return None + + def _metadata_subset(positions: torch.Tensor, batch_idx: int): + return {} + + def _pose_subset(positions: torch.Tensor, batch_idx: int): + return _tensor_subset(pose, positions, batch_idx) + + for batch_idx in range(B): + anchor_bank = CausalMemoryBank() + revisit_bank = CausalMemoryBank() + src_frames = source_frame_indices[:, batch_idx] + if generated is None: + non_generated = torch.ones_like(src_frames, dtype=torch.bool) + else: + non_generated = ~generated[:, batch_idx] + + source_positions = torch.nonzero(non_generated, as_tuple=False).flatten() + if source_positions.numel() > 0: + selected_anchor_positions = [] + for anchor_idx in anchor_indices: + if 0 <= int(anchor_idx) < source_positions.numel(): + selected_anchor_positions.append(source_positions[int(anchor_idx)]) + for source_pos in selected_anchor_positions: + source_pos_i = int(source_pos.item()) if torch.is_tensor(source_pos) else int(source_pos) + anchor_tokens = self._spatial_pool_tokens( + anchor_projected[batch_idx, source_pos_i], + anchor_pool_h, anchor_pool_w, src_h, src_w, + ) + n_slots = 
anchor_tokens.shape[0] + record_mask = torch.ones( + (n_slots,), + device=anchor_projected.device, + dtype=torch.bool, + ) + anchor_bank.add_prefix_anchors( + anchor_tokens.unsqueeze(0), + record_mask.unsqueeze(0), + src_frames[source_pos_i:source_pos_i + 1], + slots_per_anchor=n_slots, + ) + + for source_pos in source_positions: + source_pos_i = int(source_pos.item()) + frame_index = src_frames[source_pos_i] + frame = int(frame_index.detach().item()) + frame_pos = source_pos.reshape(1) + frame_tokens = self._spatial_pool_tokens( + revisit_projected[batch_idx, source_pos_i], + revisit_pool_h, revisit_pool_w, src_h, src_w, + ) + frame_mask = torch.ones((frame_tokens.shape[0],), device=revisit_projected.device, dtype=torch.bool) + revisit_bank.add_frame_record( + frame_tokens, + frame_mask, + frame_index, + pose=_pose_subset(frame_pos, batch_idx), + source_type=MemorySourceType.PREFIX_GT, + metadata=_metadata_subset(frame_pos, batch_idx), + is_generated=False, + record_id=f"prefix_revisit_b{batch_idx}_f{frame}", + ) + + if generated is not None: + generated_positions = torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten() + if generated_positions.numel() > 0: + for source_pos in generated_positions: + source_pos_i = int(source_pos.item()) + frame_index = src_frames[source_pos_i] + frame = int(frame_index.detach().item()) + frame_pos = source_pos.reshape(1) + frame_tokens = self._spatial_pool_tokens( + revisit_projected[batch_idx, source_pos_i], + revisit_pool_h, revisit_pool_w, src_h, src_w, + ) + frame_mask = torch.ones((frame_tokens.shape[0],), device=revisit_projected.device, dtype=torch.bool) + revisit_bank.add_frame_record( + frame_tokens, + frame_mask, + frame_index, + pose=_pose_subset(frame_pos, batch_idx), + source_type=MemorySourceType.GENERATED, + metadata=_metadata_subset(frame_pos, batch_idx), + is_generated=True, + record_id=f"generated_revisit_b{batch_idx}_f{frame}", + ) + if allow_generated_anchor: + for source_pos in 
def _build_preselected_causal_memory_banks(
    self,
    committed_latents: torch.Tensor,
    source_frame_indices: torch.Tensor,
    source_is_generated: torch.Tensor | None,
    pose: torch.Tensor | None,
    action: torch.Tensor | None,
    target_frame_indices: torch.Tensor,
    target_pose: torch.Tensor | None,
    target_action: torch.Tensor | None,
    target_video_ids,
    allow_generated_anchor: bool,
    anchor_indices: list[int],
    anchor_pool_h: int,
    anchor_pool_w: int,
    anchor_diverse: bool,
    revisit_pool_h: int,
    revisit_pool_w: int,
    revisit_max_frames: int,
    exclude_local_context_frames: int,
    fov_overlap_threshold,
    plucker_weight: float,
    revisit_retrieval_kwargs: dict | None,
    token_patch_size: int,
) -> tuple[list[CausalMemoryBank], list[CausalMemoryBank], int, dict]:
    """Build per-batch anchor/revisit banks, projecting only the frames that are needed.

    For every batch element the method:

    1. Picks anchor frames among the non-generated source positions (either
       via ``_select_diverse_anchor_positions`` or by using ``anchor_indices``
       as ranks into those positions), projects them with
       ``dememwm_anchor_proj``, pools them and stores them with
       ``add_prefix_anchors``.
    2. Creates token-less candidate records (one shared dummy token) for every
       source frame whose index is below
       ``max(target frames) - exclude_local_context_frames``.
    3. Runs ``deterministic_revisit_retrieval`` once per target frame and
       collects the union of selected record ids.
    4. Projects and pools only those selected frames with
       ``dememwm_revisit_proj`` and adds them to the revisit bank.

    Args:
        committed_latents: 5-D tensor unpacked as ``(T_src, B, C, H, W)``.
        source_frame_indices: tensor of shape ``(T_src, B)``.
        source_is_generated: optional flags indexed as ``[:, batch_idx]``;
            ``None`` means every source frame is treated as non-generated.
        pose: optional source pose, laid out ``(T_src, B, ...)`` or
            ``(B, T_src, ...)``; any other layout is ignored.
        action: accepted for signature compatibility; not read here.
        target_frame_indices: ``(T_tgt, B)`` or 1-D (expanded to ``(T_tgt, 1)``).
        target_pose: optional target pose, ``(T_tgt, B, ...)`` or
            ``(B, T_tgt, ...)``.
        target_action: accepted for signature compatibility; not read here.
        target_video_ids: ``None``, a tensor, a list/tuple, or any scalar that
            is passed through unchanged.
        allow_generated_anchor: accepted for signature compatibility; not read
            here.
        anchor_indices: ranks into the non-generated source positions.
        anchor_pool_h, anchor_pool_w: pooled anchor grid size.
        anchor_diverse: use pose-diverse anchor selection instead of ranks.
        revisit_pool_h, revisit_pool_w: pooled revisit grid size.
        revisit_max_frames: ``topk`` forwarded to the retrieval call.
        exclude_local_context_frames: gap subtracted from the newest target
            frame when filtering candidates; also forwarded to retrieval.
        fov_overlap_threshold, plucker_weight: forwarded to retrieval.
        revisit_retrieval_kwargs: extra keyword arguments for retrieval.
        token_patch_size: patch size forwarded to the projection helpers.

    Returns:
        ``(anchor_banks, revisit_banks, tokens_per_frame, diagnostics)`` where
        the bank lists have one entry per batch element and ``diagnostics``
        holds integer counters accumulated over the whole batch.

    Raises:
        ValueError: if the latent, source-index or target-index shapes do not
            match the layouts above.
    """
    if committed_latents.ndim != 5:
        raise ValueError("committed_latents must have shape (T_src,B,C,H,W)")
    T_src, B, _, H, W = committed_latents.shape
    if source_frame_indices.shape != (T_src, B):
        raise ValueError("source_frame_indices must have shape (T_src,B)")
    if target_frame_indices.ndim == 1:
        target_frame_indices = target_frame_indices[:, None]
    if target_frame_indices.shape[1] != B:
        raise ValueError("target_frame_indices must have batch dimension B")
    T_tgt = target_frame_indices.shape[0]
    stream_device = committed_latents.device
    hidden_size = int(
        getattr(self.dememwm_revisit_proj, "out_features", 0)
        or self.dememwm_revisit_proj.weight.shape[0]
    )
    src_h, src_w = self._projected_spatial_grid_size(
        H,
        W,
        self.dememwm_anchor_proj,
        token_patch_size,
    )
    tokens_per_frame = src_h * src_w
    generated = (
        None
        if source_is_generated is None
        else source_is_generated.bool().to(device=source_frame_indices.device)
    )
    anchor_banks: list[CausalMemoryBank] = []
    revisit_banks: list[CausalMemoryBank] = []
    # Candidates carry a single shared placeholder token: retrieval only needs
    # frame/pose metadata, real tokens are projected after selection.
    dummy_tokens = committed_latents.new_zeros((1, hidden_size))
    dummy_mask = torch.ones((1,), device=stream_device, dtype=torch.bool)
    preselection_candidate_count = 0
    preselection_valid_candidate_label_count = 0
    preselection_selected_count = 0
    projected_anchor_frames = 0
    projected_revisit_frames = 0
    projected_revisit_records = 0
    retrieval_kwargs = dict(revisit_retrieval_kwargs or {})

    # Pre-convert pose tensors to stream_device once so that the
    # _tensor_subset / _target_tensor closures below never trigger a
    # device transfer on every call.
    if pose is not None:
        pose = pose.to(device=stream_device)
    if target_pose is not None:
        target_pose = target_pose.to(device=stream_device)

    # Same idea for tensor video ids: copy to the host once instead of once
    # per (batch, target) lookup.
    video_ids_cpu = target_video_ids.detach().cpu() if torch.is_tensor(target_video_ids) else None

    def _tensor_subset(tensor: torch.Tensor | None, positions: torch.Tensor, batch_idx: int):
        if tensor is None or tensor.ndim < 2:
            return None
        if tensor.shape[0] == T_src and tensor.shape[1] == B:
            return tensor[positions, batch_idx]
        if tensor.shape[0] == B and tensor.shape[1] == T_src:
            return tensor[batch_idx, positions]
        return None

    def _target_tensor(tensor: torch.Tensor | None, batch_idx: int, target_idx: int):
        if tensor is None or tensor.ndim < 2:
            return None
        if tensor.shape[0] == T_tgt and tensor.shape[1] == B:
            return tensor[target_idx, batch_idx]
        if tensor.shape[0] == B and tensor.shape[1] == T_tgt:
            return tensor[batch_idx, target_idx]
        return None

    def _target_video_id(batch_idx: int, target_idx: int):
        if target_video_ids is None:
            return None
        if video_ids_cpu is not None:
            ids = video_ids_cpu
            if ids.ndim == 0:
                return ids.item()
            if ids.ndim >= 2 and ids.shape[0] == T_tgt and ids.shape[1] == B:
                return ids[target_idx, batch_idx].item()
            if ids.ndim >= 2 and ids.shape[0] == B and ids.shape[1] == T_tgt:
                return ids[batch_idx, target_idx].item()
            return None
        if isinstance(target_video_ids, (list, tuple)):
            if len(target_video_ids) == B:
                return target_video_ids[batch_idx]
            if len(target_video_ids) == T_tgt:
                row = target_video_ids[target_idx]
                if isinstance(row, (list, tuple)) and len(row) == B:
                    return row[batch_idx]
                return row
        return target_video_ids

    def _metadata_subset(positions: torch.Tensor, batch_idx: int):
        # Currently no per-frame metadata is attached to candidates.
        return {}

    def _pose_subset(positions: torch.Tensor, batch_idx: int):
        return _tensor_subset(pose, positions, batch_idx)

    def _candidate_record(
        *,
        batch_idx: int,
        frame_position: torch.Tensor,
        source_type: MemorySourceType,
        is_generated: bool,
        record_id: str,
    ) -> MemoryRecord:
        frame_values = source_frame_indices[frame_position, batch_idx].to(device=stream_device)
        frame = int(frame_values.reshape(-1)[0].item())
        return MemoryRecord(
            tokens=dummy_tokens,
            mask=dummy_mask,
            source_start=frame,
            source_end=frame + 1,
            frame_indices=frame_values.reshape(1),
            pose=_pose_subset(frame_position, batch_idx),
            source_type=source_type,
            is_generated=bool(is_generated),
            chunk_id=record_id,
            metadata=_metadata_subset(frame_position, batch_idx),
        )

    for batch_idx in range(B):
        anchor_bank = CausalMemoryBank()
        revisit_bank = CausalMemoryBank()
        src_frames = source_frame_indices[:, batch_idx]
        if generated is None:
            non_generated = torch.ones_like(src_frames, dtype=torch.bool)
        else:
            non_generated = ~generated[:, batch_idx]
        source_positions = torch.nonzero(non_generated, as_tuple=False).flatten()

        # ---- anchors: choose positions, then project only those frames ----
        anchor_positions = source_positions[:0].to(device=stream_device, dtype=torch.long)
        if anchor_indices and source_positions.numel() > 0:
            if anchor_diverse:
                anchor_source_positions = source_positions[source_positions < self._context_frame_count()]
                if anchor_source_positions.numel() > 0:
                    anchor_pose = _pose_subset(anchor_source_positions, batch_idx)
                    anchor_positions = self._select_diverse_anchor_positions(
                        anchor_source_positions, anchor_pose, len(anchor_indices)
                    ).to(device=stream_device, dtype=torch.long)
            else:
                selected_anchor_positions = []
                for anchor_idx in anchor_indices:
                    # anchor_idx is a rank into the non-generated positions;
                    # out-of-range ranks are silently dropped.
                    if 0 <= int(anchor_idx) < source_positions.numel():
                        selected_anchor_positions.append(source_positions[int(anchor_idx)])
                if selected_anchor_positions:
                    anchor_positions = torch.stack(selected_anchor_positions).to(
                        device=stream_device, dtype=torch.long
                    )
        if anchor_positions.numel() > 0:
            projected_anchor_frames += int(anchor_positions.numel())
            anchor_projected = self._project_latent_patch_tokens(
                committed_latents.index_select(0, anchor_positions)[:, batch_idx:batch_idx + 1],
                self.dememwm_anchor_proj,
                token_patch_size,
            )[0]
            for local_idx, source_pos in enumerate(anchor_positions):
                source_pos_i = int(source_pos.item())
                anchor_tokens = self._spatial_pool_tokens(
                    anchor_projected[local_idx], anchor_pool_h, anchor_pool_w, src_h, src_w
                )
                n_slots = anchor_tokens.shape[0]
                record_mask = torch.ones((n_slots,), device=stream_device, dtype=torch.bool)
                anchor_bank.add_prefix_anchors(
                    anchor_tokens.unsqueeze(0),
                    record_mask.unsqueeze(0),
                    src_frames[source_pos_i:source_pos_i + 1],
                    slots_per_anchor=n_slots,
                )

        # ---- revisit candidates: metadata-only records, no projection yet ----
        candidate_records: list[MemoryRecord] = []
        candidate_positions: dict[str, torch.Tensor] = {}
        src_frames_cpu = src_frames.detach().cpu()
        target_frames_cpu = target_frame_indices[:, batch_idx].detach().cpu().to(dtype=torch.long)
        latest_valid_source_frame_exclusive = (
            int(target_frames_cpu.max().item()) - int(exclude_local_context_frames)
        )
        for prefix, positions, source_type, is_generated in (
            ("prefix", source_positions, MemorySourceType.PREFIX_GT, False),
            (
                "generated",
                torch.empty(0, device=stream_device, dtype=torch.long)
                if generated is None
                else torch.nonzero(generated[:, batch_idx], as_tuple=False).flatten(),
                MemorySourceType.GENERATED,
                True,
            ),
        ):
            if positions.numel() == 0 or latest_valid_source_frame_exclusive <= 0:
                continue
            positions_cpu = positions.detach().cpu().to(dtype=torch.long)
            for frame_position_cpu in positions_cpu:
                frame = int(src_frames_cpu[int(frame_position_cpu.item())].item())
                if frame >= latest_valid_source_frame_exclusive:
                    continue
                frame_position = frame_position_cpu.reshape(1).to(device=stream_device, dtype=torch.long)
                record_id = f"{prefix}_revisit_b{batch_idx}_f{frame}"
                candidate_positions[record_id] = frame_position
                candidate_records.append(_candidate_record(
                    batch_idx=batch_idx,
                    frame_position=frame_position,
                    source_type=source_type,
                    is_generated=is_generated,
                    record_id=record_id,
                ))

        # ---- retrieval: union of the records selected by any target frame ----
        selected_frame_record_ids: set[str] = set()
        selected_frame_metadata: dict[str, dict] = {}
        for target_idx in range(T_tgt):
            # Read from the host copy made above rather than calling .item()
            # on the (possibly device-resident) original for every target.
            target_frame = int(target_frames_cpu[target_idx].item())
            result = deterministic_revisit_retrieval(
                candidate_records,
                target_frame=target_frame,
                target_pose=_target_tensor(target_pose, batch_idx, target_idx),
                target_summary=None,
                topk=revisit_max_frames,
                exclude_local_context_frames=exclude_local_context_frames,
                fov_overlap_threshold=fov_overlap_threshold,
                plucker_weight=plucker_weight,
                target_video_id=_target_video_id(batch_idx, target_idx),
                **retrieval_kwargs,
            )
            preselection_candidate_count += int(result.diagnostics.get(
                "revisit_candidate_frame_count",
                result.diagnostics.get("revisit_candidate_count", 0),
            ))
            preselection_valid_candidate_label_count += int(
                result.diagnostics.get("valid_candidate_label_count", 0)
            )
            preselection_selected_count += int(result.diagnostics.get(
                "revisit_selected_frame_count",
                result.diagnostics.get("revisit_selected_count", 0),
            ))
            for selected_record in result.records:
                if selected_record.chunk_id is None:
                    continue
                record_id = str(selected_record.chunk_id)
                selected_frame_record_ids.add(record_id)
                # Later targets overwrite the metadata of earlier ones.
                selected_frame_metadata[record_id] = dict(selected_record.metadata)

        # ---- projection: only frames that at least one target selected ----
        for record in candidate_records:
            if record.chunk_id not in selected_frame_record_ids:
                continue
            record_id = str(record.chunk_id)
            frame_position = candidate_positions[record_id]
            projected_revisit_records += 1
            projected_revisit_frames += int(frame_position.numel())
            revisit_projected = self._project_latent_patch_tokens(
                committed_latents.index_select(0, frame_position)[:, batch_idx:batch_idx + 1],
                self.dememwm_revisit_proj,
                token_patch_size,
            )[0]
            frame_tokens = self._spatial_pool_tokens(
                revisit_projected[0], revisit_pool_h, revisit_pool_w, src_h, src_w
            )
            frame_mask = torch.ones((frame_tokens.shape[0],), device=stream_device, dtype=torch.bool)
            record_metadata = dict(record.metadata)
            record_metadata.update(selected_frame_metadata.get(record_id, {}))
            revisit_bank.add_frame_record(
                frame_tokens,
                frame_mask,
                record.frame_indices.reshape(-1)[0],
                pose=record.pose,
                source_type=record.source_type,
                metadata=record_metadata,
                is_generated=record.is_generated,
                record_id=record.chunk_id,
            )

        anchor_banks.append(anchor_bank)
        revisit_banks.append(revisit_bank)

    diagnostics = {
        "preselected_anchor_projected_frame_count": projected_anchor_frames,
        "preselected_revisit_projected_frame_count": projected_revisit_frames,
        "preselected_revisit_projected_frame_record_count": projected_revisit_records,
        "preselected_revisit_candidate_frame_count": preselection_candidate_count,
        "preselected_revisit_candidate_count": preselection_candidate_count,
        "preselected_revisit_valid_candidate_label_count": preselection_valid_candidate_label_count,
        "preselected_revisit_selected_frame_count": preselection_selected_count,
        "preselected_revisit_selected_count": preselection_selected_count,
    }
    return anchor_banks, revisit_banks, tokens_per_frame, diagnostics
def _records_to_stream(
    self,
    records,
    max_tokens: int,
    hidden_size: int,
    device: torch.device,
    dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Flatten memory records into a fixed-length token stream.

    The records are stacked with ``stack_record_tokens`` (capped at
    ``max_tokens`` slots), moved to ``device``/``dtype``, truncated to the
    slot budget and zero-padded up to it.

    Returns:
        ``(tokens, mask, max_source_frame)`` where ``tokens`` has
        ``max_tokens`` rows, ``mask`` is a boolean vector of the same length
        (padding rows are ``False``) and ``max_source_frame`` is the largest
        ``record.max_source_frame`` seen, or ``-1`` when there are no records.
        A negative ``max_tokens`` is treated as ``0``.
    """
    slot_budget = max(0, int(max_tokens))
    materialized = list(records)
    stacked_tokens, stacked_mask = stack_record_tokens(materialized, max_slots=slot_budget)
    newest_source_frame = max(
        (int(item.max_source_frame) for item in materialized),
        default=-1,
    )

    # Nothing stackable (or no budget at all): hand back an all-padding stream.
    nothing_to_copy = stacked_tokens is None or stacked_mask is None or slot_budget == 0
    if nothing_to_copy:
        return (
            torch.zeros((slot_budget, hidden_size), device=device, dtype=dtype),
            torch.zeros((slot_budget,), device=device, dtype=torch.bool),
            newest_source_frame,
        )

    used = min(slot_budget, stacked_tokens.shape[0])
    tokens = stacked_tokens[:used].to(device=device, dtype=dtype)
    mask = stacked_mask[:used].to(device=device, dtype=torch.bool)

    missing = slot_budget - used
    if missing > 0:
        # Right-pad with zero tokens that are masked out.
        tokens = torch.cat([tokens, tokens.new_zeros(missing, hidden_size)], dim=0)
        mask = torch.cat(
            [mask, torch.zeros(missing, device=device, dtype=torch.bool)],
            dim=0,
        )
    return tokens, mask, newest_source_frame
def _project_streaming_revisit_records(
    self,
    *,
    cache: StreamingCache,
    batch_idx: int,
    records: list[MemoryRecord],
    device: torch.device,
    dtype: torch.dtype,
    token_patch_size: int,
    revisit_pool_h: int,
    revisit_pool_w: int,
    projection_cache: dict[tuple[int, str, int, int, int, bool], MemoryRecord],
) -> list[MemoryRecord]:
    """Turn metadata-only revisit records into records with real tokens.

    Records whose metadata does not set ``dememwm_revisit_metadata_only`` are
    passed through untouched. For the others, one frame is chosen (the
    ``dememwm_selected_frame_index`` metadata entry if present, otherwise the
    largest value in ``record.frame_indices``), its raw latents are fetched
    from ``cache``, projected with ``dememwm_revisit_proj`` and spatially
    pooled. Results are memoised in ``projection_cache`` (mutated in place)
    so the same frame is projected at most once per cache dict.

    Returns:
        A new list, in the same order as ``records``.
    """
    resolved = []
    for entry in records:
        needs_projection = bool(entry.metadata.get("dememwm_revisit_metadata_only", False))
        if not needs_projection:
            resolved.append(entry)
            continue

        chosen_index = entry.metadata.get("dememwm_selected_frame_index")
        if chosen_index is not None:
            frame_selector = torch.as_tensor(
                [int(chosen_index)],
                device=entry.frame_indices.device,
                dtype=entry.frame_indices.dtype,
            )
        else:
            # No explicit choice recorded: fall back to the largest frame index.
            latest_slot = torch.argmax(entry.frame_indices)
            frame_selector = entry.frame_indices[latest_slot].reshape(1)

        cache_key = (
            int(batch_idx),
            str(entry.chunk_id or ""),
            int(entry.source_start),
            int(entry.source_end),
            int(frame_selector.detach().cpu().reshape(-1)[0].item()),
            bool(entry.is_generated),
        )
        hit = projection_cache.get(cache_key)
        if hit is not None:
            resolved.append(hit)
            continue

        latents = cache.raw_latents_for_frames(
            batch_idx=batch_idx,
            frame_indices=frame_selector,
            device=device,
            dtype=dtype,
        )
        projected_tokens = self._project_latent_patch_tokens(
            latents,
            self.dememwm_revisit_proj,
            token_patch_size,
        )[0]
        grid_h, grid_w = self._projected_spatial_grid_size(
            latents.shape[3],
            latents.shape[4],
            self.dememwm_revisit_proj,
            token_patch_size,
        )
        pooled = self._spatial_pool_tokens(
            projected_tokens[0], revisit_pool_h, revisit_pool_w, grid_h, grid_w
        )
        pooled_mask = torch.ones((pooled.shape[0],), device=device, dtype=torch.bool)

        # Copy the metadata, moving any tensor values next to the tokens, and
        # clear the flag so the record is not projected again downstream.
        moved_metadata = {}
        for meta_name, meta_value in entry.metadata.items():
            if torch.is_tensor(meta_value):
                moved_metadata[meta_name] = meta_value.to(device=device)
            else:
                moved_metadata[meta_name] = meta_value
        moved_metadata["dememwm_revisit_metadata_only"] = False

        materialized = MemoryRecord(
            tokens=pooled,
            mask=pooled_mask,
            source_start=int(entry.source_start),
            source_end=int(entry.source_end),
            frame_indices=entry.frame_indices.to(device=device),
            pose=None if entry.pose is None else entry.pose.to(device=device),
            source_type=entry.source_type,
            is_generated=bool(entry.is_generated),
            score=entry.score,
            chunk_id=entry.chunk_id,
            metadata=moved_metadata,
        )
        projection_cache[cache_key] = materialized
        resolved.append(materialized)
    return resolved
_proj_src_w) + frame_mask = torch.ones((frame_tokens.shape[0],), device=device, dtype=torch.bool) + metadata = { + key: (value.to(device=device) if torch.is_tensor(value) else value) + for key, value in record.metadata.items() + } + metadata["dememwm_revisit_metadata_only"] = False + projected = MemoryRecord( + tokens=frame_tokens, + mask=frame_mask, + source_start=int(record.source_start), + source_end=int(record.source_end), + frame_indices=record.frame_indices.to(device=device), + pose=None if record.pose is None else record.pose.to(device=device), + source_type=record.source_type, + is_generated=bool(record.is_generated), + score=record.score, + chunk_id=record.chunk_id, + metadata=metadata, + ) + projection_cache[key] = projected + projected_records.append(projected) + return projected_records + + def build_memory_streams( + self, + committed_latents: torch.Tensor | None, + source_frame_indices: torch.Tensor | None, + target_frame_indices: torch.Tensor | None = None, + pose: torch.Tensor | None = None, + target_pose: torch.Tensor | None = None, + action: torch.Tensor | None = None, + target_action: torch.Tensor | None = None, + target_video_ids=None, + source_is_generated: torch.Tensor | None = None, + denoising_fraction: float | None = None, + noise_bucket: str | None = None, + noise_bucket_ids: torch.Tensor | None = None, + streaming_cache: StreamingCache | None = None, + ) -> MemoryStreamTensors: + if target_frame_indices is None: + if source_frame_indices is None: + raise ValueError("target_frame_indices or source_frame_indices is required") + target_frame_indices = source_frame_indices + memory_cfg = self._memory_cfg() + anchor_cfg = self._cfg_get(memory_cfg, "anchor", None) + dynamic_cfg = self._cfg_get(memory_cfg, "dynamic", None) + revisit_cfg = self._cfg_get(memory_cfg, "revisit", None) + injection_cfg = self._cfg_get(memory_cfg, "injection", None) + contract_diag = self._validate_config_contract() + gate_state = self._effective_gate_state( + 
denoising_fraction=denoising_fraction, + noise_bucket=noise_bucket, + ) + anchor_config_enabled = gate_state["anchor_config_enabled"] + dynamic_config_enabled = gate_state["dynamic_config_enabled"] + revisit_config_enabled = gate_state["revisit_config_enabled"] + curriculum_state = gate_state["curriculum_state"] + eval_ablation_enabled = gate_state["eval_ablation_enabled"] + eval_ablation_branch = gate_state["eval_ablation_branch"] + resolved_noise_bucket = gate_state["resolved_noise_bucket"] + gates = gate_state["gates"] + anchor_effective_enabled = gate_state["anchor_effective_enabled"] + dynamic_effective_enabled = gate_state["dynamic_effective_enabled"] + revisit_stage_config_enabled = gate_state["revisit_stage_config_enabled"] + force_revisit_off = gate_state["force_revisit_off"] + force_revisit_on = gate_state["force_revisit_on"] + token_patch_size = int(self._cfg_get(memory_cfg, "token_patch_size", 2)) + anchor_indices = [int(x) for x in self._cfg_get(anchor_cfg, "anchor_indices", [0, 1, 2, 3])] + anchor_compress_cfg = self._cfg_get(anchor_cfg, "compress", None) + pool_latent_h = int(committed_latents.shape[-2]) if committed_latents is not None else int(self.x_stacked_shape[-2]) + pool_latent_w = int(committed_latents.shape[-1]) if committed_latents is not None else int(self.x_stacked_shape[-1]) + anchor_src_h, anchor_src_w = self._projected_spatial_grid_size( + pool_latent_h, + pool_latent_w, + self.dememwm_anchor_proj, + token_patch_size, + ) + anchor_pool_h, anchor_pool_w = self._resolve_spatial_pool_size( + anchor_compress_cfg, anchor_src_h, anchor_src_w, 5, 8 + ) + anchor_num_tokens = len(anchor_indices) * anchor_pool_h * anchor_pool_w + anchor_diverse = bool(self._cfg_get(anchor_cfg, "diverse_selection", False)) + allow_generated_anchor = bool(self._cfg_get(anchor_cfg, "allow_generated_as_anchor", False)) + revisit_max_frames = int(self._cfg_get(revisit_cfg, "max_frames", 2)) + revisit_compress_cfg = self._cfg_get(revisit_cfg, "compress", None) + 
revisit_src_h, revisit_src_w = self._projected_spatial_grid_size( + pool_latent_h, + pool_latent_w, + self.dememwm_revisit_proj, + token_patch_size, + ) + revisit_pool_h, revisit_pool_w = self._resolve_spatial_pool_size( + revisit_compress_cfg, revisit_src_h, revisit_src_w, 5, 8 + ) + revisit_max_tokens = revisit_max_frames * revisit_pool_h * revisit_pool_w + recent_frames = int(self._cfg_get(dynamic_cfg, "recent_frames", 8)) + exclude_latest_local_frames = int(self._cfg_get(dynamic_cfg, "exclude_latest_local_frames", 4)) + local_context_exclusion_frames = self._local_context_exclusion_frames() + fov_overlap_threshold = self._cfg_get(revisit_cfg, "fov_overlap_threshold", 0.30) + high_quality_fov_threshold = float(self._cfg_get(revisit_cfg, "high_quality_fov_threshold", 0.70)) + plucker_weight = float(self._cfg_get(revisit_cfg, "plucker_weight", 0.10)) + revisit_retrieval_kwargs = { + "high_quality_fov_threshold": high_quality_fov_threshold, + "fov_half_h": float(self._cfg_get(revisit_cfg, "fov_half_h", 52.5)), + "fov_half_v": float(self._cfg_get(revisit_cfg, "fov_half_v", 37.5)), + "fov_yaw_samples": int(self._cfg_get(revisit_cfg, "fov_yaw_samples", 25)), + "fov_pitch_samples": int(self._cfg_get(revisit_cfg, "fov_pitch_samples", 20)), + "fov_depth_samples": int(self._cfg_get(revisit_cfg, "fov_depth_samples", 20)), + "fov_radius": float(self._cfg_get(revisit_cfg, "fov_radius", 30.0)), + "pose_preselect_topk": self._cfg_get(revisit_cfg, "pose_preselect_topk", 64), + "plucker_grid_h": int(self._cfg_get(revisit_cfg, "plucker_grid_h", 4)), + "plucker_grid_w": int(self._cfg_get(revisit_cfg, "plucker_grid_w", 4)), + "plucker_focal_length": float(self._cfg_get(revisit_cfg, "plucker_focal_length", 0.35)), + } + preselection_diag = {} + use_cache_revisit_records = False + revisit_record_batches: list[tuple[MemoryRecord, ...]] | None = None + + cache = streaming_cache if streaming_cache is not None and getattr(streaming_cache, "enabled", False) else None + cache_diag = 
cache.diagnostics("cache") if cache is not None else {"cache_enabled": False, "cache_records": 0, "cache_slots": 0, "cache_evictions": 0, "cache_resets": 0} + if committed_latents is not None: + stream_device = committed_latents.device + stream_dtype = committed_latents.dtype + else: + param = next(iter(self.dememwm_anchor_proj.parameters())) + stream_device = param.device + stream_dtype = param.dtype + target_frame_indices = target_frame_indices.to(device=stream_device) + if target_frame_indices.ndim == 1: + target_frame_indices = target_frame_indices[:, None] + + use_cache_records = cache is not None and cache.keep_compressed_records and cache.record_count > 0 + dynamic_latents = committed_latents if dynamic_effective_enabled else None + dynamic_frame_indices = source_frame_indices if dynamic_effective_enabled else None + dynamic_generated = source_is_generated if dynamic_effective_enabled else None + dynamic_pose = pose if dynamic_effective_enabled else None + if dynamic_effective_enabled and cache is not None and cache.raw_frame_slots > 0: + raw_latents, raw_frames, raw_generated, raw_pose = cache.materialize_raw_latents( + device=stream_device, + dtype=stream_dtype, + max_recent_frames=recent_frames, + target_frame_indices=target_frame_indices, + exclude_latest_local_frames=exclude_latest_local_frames, + ) + if raw_latents is not None: + dynamic_latents = raw_latents + dynamic_frame_indices = raw_frames + dynamic_generated = raw_generated + dynamic_pose = raw_pose + + if use_cache_records: + B = target_frame_indices.shape[1] + hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024)) + anchor_banks = ( + cache.memory_banks("anchor", device=stream_device, dtype=stream_dtype, batch_size=B) + if anchor_effective_enabled else [CausalMemoryBank() for _ in range(B)] + ) + revisit_banks = [CausalMemoryBank() for _ in range(B)] + revisit_record_batches = ( + [cache.records_for_batch("revisit", batch_idx) for batch_idx in range(B)] + if 
revisit_stage_config_enabled else [tuple() for _ in range(B)] + ) + use_cache_revisit_records = bool(revisit_stage_config_enabled) + if dynamic_latents is not None and dynamic_latents.ndim == 5 and dynamic_latents.shape[0] > 0: + tokens_per_frame_h, tokens_per_frame_w = self._projected_spatial_grid_size( + dynamic_latents.shape[-2], + dynamic_latents.shape[-1], + self.dememwm_anchor_proj, + token_patch_size, + ) + tokens_per_frame = tokens_per_frame_h * tokens_per_frame_w + else: + latent_h = int(self.x_stacked_shape[-2]) if len(self.x_stacked_shape) >= 2 else 0 + latent_w = int(self.x_stacked_shape[-1]) if len(self.x_stacked_shape) >= 1 else 0 + tokens_per_frame_h, tokens_per_frame_w = self._projected_spatial_grid_size( + latent_h, + latent_w, + self.dememwm_anchor_proj, + token_patch_size, + ) + tokens_per_frame = tokens_per_frame_h * tokens_per_frame_w + else: + if committed_latents is None or source_frame_indices is None: + raise ValueError("committed_latents/source_frame_indices are required when no streaming cache records are available") + B = committed_latents.shape[1] + hidden_size = int(self._cfg_get(injection_cfg, "dit_hidden_size", 1024)) + target_pose_source = target_pose if target_pose is not None else pose + anchor_banks, revisit_banks, tokens_per_frame, preselection_diag = self._build_preselected_causal_memory_banks( + committed_latents, + source_frame_indices.to(device=stream_device), + None if source_is_generated is None else source_is_generated.to(device=stream_device, dtype=torch.bool), + None if pose is None else pose.to(device=stream_device), + None, + target_frame_indices, + None if target_pose_source is None else target_pose_source.to(device=stream_device), + None, + target_video_ids, + allow_generated_anchor, + anchor_indices, + anchor_pool_h, + anchor_pool_w, + anchor_diverse, + revisit_pool_h, + revisit_pool_w, + revisit_max_frames, + local_context_exclusion_frames, + fov_overlap_threshold, + plucker_weight, + revisit_retrieval_kwargs, + 
token_patch_size, + ) + revisit_record_batches = [tuple(bank.records) for bank in revisit_banks] + + T_tgt = target_frame_indices.shape[0] + anchor_slots = max(0, anchor_num_tokens) + revisit_slots = max(0, revisit_max_tokens) + anchor_source_type = None if allow_generated_anchor else MemorySourceType.PREFIX_GT + anchor_include_generated = allow_generated_anchor + anchor_token_rows = [] + anchor_mask_rows = [] + anchor_max_rows = [] + for batch_idx, anchor_bank in enumerate(anchor_banks): + batch_token_rows = [] + batch_mask_rows = [] + batch_max_rows = [] + for target_idx in range(T_tgt): + target_frame = int(target_frame_indices[target_idx, batch_idx].item()) + records = anchor_bank.query( + MemoryBankQuery( + target_frame=target_frame, + source_type=anchor_source_type, + include_generated=anchor_include_generated, + max_records=len(anchor_indices), + max_slots=anchor_slots, + ) + ) + anchor_bank.assert_causal(target_frame, records) + stream_tokens, stream_mask, max_source_frame = self._records_to_stream( + records, + anchor_slots, + hidden_size, + stream_device, + stream_dtype, + ) + batch_token_rows.append(stream_tokens) + batch_mask_rows.append(stream_mask) + batch_max_rows.append(torch.as_tensor(max_source_frame, device=stream_device, dtype=torch.long)) + anchor_token_rows.append(torch.stack(batch_token_rows, dim=0)) + anchor_mask_rows.append(torch.stack(batch_mask_rows, dim=0)) + anchor_max_rows.append(torch.stack(batch_max_rows, dim=0)) + anchor_tokens = torch.stack(anchor_token_rows, dim=0) + anchor_mask = torch.stack(anchor_mask_rows, dim=0) + anchor_max = torch.stack(anchor_max_rows, dim=0) + + if dynamic_latents is None or dynamic_frame_indices is None or dynamic_latents.shape[0] == 0: + _fallback_h = int(self.x_stacked_shape[-2]) if len(self.x_stacked_shape) >= 2 else 18 + _fallback_w = int(self.x_stacked_shape[-1]) if len(self.x_stacked_shape) >= 1 else 32 + dynamic_num_slots = self.dememwm_dynamic_compressor.tokens_per_target(_fallback_h, 
_fallback_w) + dynamic_tokens = torch.zeros((B, T_tgt, dynamic_num_slots, hidden_size), device=stream_device, dtype=stream_dtype) + dynamic_mask = torch.zeros((B, T_tgt, dynamic_num_slots), device=stream_device, dtype=torch.bool) + dynamic_diag = { + "selected_source_count": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device), + "max_source_frame": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device), + "generated_source_fraction": torch.zeros((B, T_tgt), dtype=torch.float32, device=stream_device), + "dynamic_min_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device), + "dynamic_max_gap_to_target_per_target": torch.full((B, T_tgt), -1, dtype=torch.long, device=stream_device), + "dynamic_overlap_with_c_short_count_per_target": torch.zeros((B, T_tgt), dtype=torch.long, device=stream_device), + "dynamic_exclude_latest_local_frames": exclude_latest_local_frames, + } + else: + # Pre-select dynamic source frame positions using only frame index metadata + # before touching latents, so we pass a small slice instead of the full + # 1000-frame tensor to the compressor. 
+ _dfi = dynamic_frame_indices.to(device=stream_device) + _max_src = self.dememwm_dynamic_compressor.max_source_frames + _needed: list[int] = [] + for _b in range(B): + for _j in range(T_tgt): + _target = int(target_frame_indices[_j, _b].item()) + _valid = (_dfi[:, _b] < _target - exclude_latest_local_frames).nonzero(as_tuple=False).flatten() + _needed.extend(_valid[-_max_src:].tolist()) + if _needed: + _needed_idx = torch.tensor(sorted(set(_needed)), device=stream_device, dtype=torch.long) + _dynamic_latents_small = dynamic_latents.index_select(0, _needed_idx) + _dynamic_fi_small = _dfi.index_select(0, _needed_idx) + _dynamic_pose_small = dynamic_pose.index_select(0, _needed_idx) if dynamic_pose is not None else None + _dynamic_gen_small = ( + dynamic_generated.to(device=stream_device, dtype=torch.bool).index_select(0, _needed_idx) + if dynamic_generated is not None else None + ) + else: + _dynamic_latents_small = dynamic_latents[:0] + _dynamic_fi_small = _dfi[:0] + _dynamic_pose_small = dynamic_pose[:0] if dynamic_pose is not None else None + _dynamic_gen_small = None + dynamic_tokens, dynamic_mask, dynamic_diag = self.dememwm_dynamic_compressor( + _dynamic_latents_small, + _dynamic_fi_small, + _dynamic_pose_small, + target_frame_indices, + _dynamic_gen_small, + exclude_latest_local_frames=exclude_latest_local_frames, + ) + + dynamic_min_gap_tensor = torch.as_tensor( + dynamic_diag.get("dynamic_min_gap_to_target_per_target", torch.full((B, T_tgt), -1, device=stream_device)), + device=stream_device, + ) + dynamic_max_gap_tensor = torch.as_tensor( + dynamic_diag.get("dynamic_max_gap_to_target_per_target", torch.full((B, T_tgt), -1, device=stream_device)), + device=stream_device, + ) + dynamic_gap_valid = dynamic_min_gap_tensor >= 0 + dynamic_min_gap_to_target = int(dynamic_min_gap_tensor[dynamic_gap_valid].min().item()) if dynamic_gap_valid.any() else -1 + dynamic_max_gap_valid = dynamic_max_gap_tensor >= 0 + dynamic_max_gap_to_target = 
int(dynamic_max_gap_tensor[dynamic_max_gap_valid].max().item()) if dynamic_max_gap_valid.any() else -1 + def _target_tensor_or_none(tensor: torch.Tensor | None, batch_idx: int, target_idx: int): + if tensor is None or tensor.ndim < 2: + return None + tensor_dev = tensor.to(device=stream_device) + if tensor_dev.shape[0] == T_tgt and tensor_dev.shape[1] == B: + return tensor_dev[target_idx, batch_idx] + if tensor_dev.shape[0] == B and tensor_dev.shape[1] == T_tgt: + return tensor_dev[batch_idx, target_idx] + return None + + def _target_video_id_or_none(batch_idx: int, target_idx: int): + if target_video_ids is None: + return None + if torch.is_tensor(target_video_ids): + ids = target_video_ids.detach().cpu() + if ids.ndim == 0: + return ids.item() + if ids.ndim >= 2 and ids.shape[0] == T_tgt and ids.shape[1] == B: + return ids[target_idx, batch_idx].item() + if ids.ndim >= 2 and ids.shape[0] == B and ids.shape[1] == T_tgt: + return ids[batch_idx, target_idx].item() + return None + if isinstance(target_video_ids, (list, tuple)): + if len(target_video_ids) == B: + return target_video_ids[batch_idx] + if len(target_video_ids) == T_tgt: + row = target_video_ids[target_idx] + if isinstance(row, (list, tuple)) and len(row) == B: + return row[batch_idx] + return row + return target_video_ids + + target_pose_source = target_pose if target_pose is not None else pose + + revisit_token_rows = [] + revisit_mask_rows = [] + revisit_max_rows = [] + valid_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool) + revisit_candidate_count = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32) + revisit_selected_count = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32) + revisit_best_selected_fov_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32) + revisit_best_selected_plucker_overlap = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.float32) + revisit_selected_gap_frames = torch.full((B, T_tgt), 
-1.0, device=stream_device, dtype=torch.float32) + valid_revisit_target_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool) + eval_corrupted_revisit_mask = torch.zeros((B, T_tgt), device=stream_device, dtype=torch.bool) + revisit_causal_max = torch.full((B, T_tgt), -1, device=stream_device, dtype=torch.long) + eval_corruption_enabled = bool(eval_ablation_enabled and eval_ablation_branch in EVAL_CORRUPTION_BRANCHES) + revisit_result_diagnostics = [] + projected_revisit_record_cache: dict[tuple[int, str, int, int, int, bool], MemoryRecord] = {} + if revisit_record_batches is None: + revisit_record_batches = [tuple(bank.records) for bank in revisit_banks] + for batch_idx in range(B): + revisit_bank = revisit_banks[batch_idx] + batch_token_rows = [] + batch_mask_rows = [] + batch_max_rows = [] + for target_idx in range(T_tgt): + target_frame = int(target_frame_indices[target_idx, batch_idx].item()) + if use_cache_revisit_records: + candidate_records = list(revisit_record_batches[batch_idx]) + else: + candidate_records = revisit_bank.query( + MemoryBankQuery( + target_frame=target_frame, + include_generated=True, + ) + ) + result = deterministic_revisit_retrieval( + candidate_records, + target_frame=target_frame, + target_pose=_target_tensor_or_none(target_pose_source, batch_idx, target_idx), + target_summary=None, + topk=revisit_max_frames, + exclude_local_context_frames=local_context_exclusion_frames, + fov_overlap_threshold=fov_overlap_threshold, + plucker_weight=plucker_weight, + target_video_id=_target_video_id_or_none(batch_idx, target_idx), + **revisit_retrieval_kwargs, + ) + selected_records = result.records + if use_cache_revisit_records and selected_records: + selected_records = self._project_streaming_revisit_records( + cache=cache, + batch_idx=batch_idx, + records=selected_records, + device=stream_device, + dtype=stream_dtype, + token_patch_size=token_patch_size, + revisit_pool_h=revisit_pool_h, + revisit_pool_w=revisit_pool_w, + 
projection_cache=projected_revisit_record_cache, + ) + revisit_result_diagnostics.append(result.diagnostics) + revisit_candidate_count[batch_idx, target_idx] = float(result.diagnostics.get("revisit_candidate_frame_count", result.diagnostics.get("revisit_candidate_count", 0))) + revisit_selected_count[batch_idx, target_idx] = float(result.diagnostics.get("revisit_selected_frame_count", result.diagnostics.get("revisit_selected_count", 0))) + revisit_best_selected_fov_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_fov_overlap", 0.0)) + revisit_best_selected_plucker_overlap[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_plucker_overlap", 0.0)) + revisit_selected_gap_frames[batch_idx, target_idx] = float(result.diagnostics.get("best_selected_gap_frames", -1)) + valid_revisit_target_mask[batch_idx, target_idx] = bool(result.diagnostics.get("valid_revisit_target_count", 0)) + revisit_bank.assert_causal(target_frame, selected_records) + if selected_records: + valid_revisit_mask[batch_idx, target_idx] = True + stream_tokens, stream_mask, max_source_frame = self._records_to_stream( + selected_records, + revisit_slots, + hidden_size, + stream_device, + stream_dtype, + ) + revisit_causal_max[batch_idx, target_idx] = max_source_frame + if eval_corruption_enabled: + stream_tokens, was_corrupted = apply_revisit_eval_corruption( + tokens=stream_tokens, + mask=stream_mask, + branch=eval_ablation_branch, + target_frame=target_frame, + ) + eval_corrupted_revisit_mask[batch_idx, target_idx] = bool(was_corrupted) + actual_max_source_frame = max((int(record.max_source_frame) for record in selected_records), default=max_source_frame) + batch_token_rows.append(stream_tokens) + batch_mask_rows.append(stream_mask) + batch_max_rows.append(torch.as_tensor(actual_max_source_frame, device=stream_device, dtype=torch.long)) + revisit_token_rows.append(torch.stack(batch_token_rows, dim=0)) + revisit_mask_rows.append(torch.stack(batch_mask_rows, 
dim=0)) + revisit_max_rows.append(torch.stack(batch_max_rows, dim=0)) + revisit_tokens = torch.stack(revisit_token_rows, dim=0) + revisit_mask = torch.stack(revisit_mask_rows, dim=0) + revisit_max = torch.stack(revisit_max_rows, dim=0) + + if anchor_tokens.shape[-2] != anchor_num_tokens: + raise AssertionError(f"anchor token budget mismatch: got {anchor_tokens.shape[-2]}, expected {anchor_num_tokens}") + if dynamic_latents is not None and dynamic_latents.shape[0] > 0: + _expected_dyn = self.dememwm_dynamic_compressor.tokens_per_target( + int(dynamic_latents.shape[-2]), int(dynamic_latents.shape[-1]) + ) + if dynamic_tokens.shape[-2] != _expected_dyn: + raise AssertionError(f"dynamic token budget mismatch: got {dynamic_tokens.shape[-2]}, expected {_expected_dyn}") + if revisit_tokens.shape[-2] > revisit_max_tokens: + raise AssertionError(f"revisit token cap exceeded: got {revisit_tokens.shape[-2]}, cap {revisit_max_tokens}") + anchor_gate = gates.anchor_gate if anchor_effective_enabled else 0.0 + dynamic_gate = gates.dynamic_gate if dynamic_effective_enabled else 0.0 + gate_module = getattr(self, "dememwm_revisit_gate", None) + if gate_module is None: + revisit_gate_raw = torch.ones((B, T_tgt), device=stream_device, dtype=stream_dtype) + else: + revisit_gate_raw = gate_module( + valid_revisit_mask=valid_revisit_mask, + best_selected_fov_overlap=revisit_best_selected_fov_overlap, + best_selected_plucker_overlap=revisit_best_selected_plucker_overlap, + selected_gap_frames=revisit_selected_gap_frames, + ).to(device=stream_device, dtype=stream_dtype) + valid_revisit_eff_mask = valid_revisit_mask + if not revisit_stage_config_enabled or force_revisit_off: + revisit_gate = torch.zeros_like(revisit_gate_raw) + elif force_revisit_on: + revisit_gate = valid_revisit_eff_mask.to(device=stream_device, dtype=stream_dtype) * torch.ones_like(revisit_gate_raw) + else: + revisit_gate = valid_revisit_eff_mask.to(device=stream_device, dtype=stream_dtype) * revisit_gate_raw * 
float(gates.revisit_gate) + revisit_effective_enabled = bool(revisit_stage_config_enabled and (revisit_gate > 0).any().item()) + if not anchor_effective_enabled: + anchor_mask = torch.zeros_like(anchor_mask) + if not dynamic_effective_enabled: + dynamic_mask = torch.zeros_like(dynamic_mask) + if not revisit_stage_config_enabled: + revisit_mask = torch.zeros_like(revisit_mask) + valid_revisit_mask = torch.zeros_like(valid_revisit_mask) + revisit_candidate_count = torch.zeros_like(revisit_candidate_count) + revisit_selected_count = torch.zeros_like(revisit_selected_count) + revisit_best_selected_fov_overlap = torch.zeros_like(revisit_best_selected_fov_overlap) + revisit_best_selected_plucker_overlap = torch.zeros_like(revisit_best_selected_plucker_overlap) + revisit_selected_gap_frames = torch.full_like(revisit_selected_gap_frames, -1.0) + valid_revisit_target_mask = torch.zeros_like(valid_revisit_target_mask) + eval_corrupted_revisit_mask = torch.zeros_like(eval_corrupted_revisit_mask) + valid_revisit_eff_mask = torch.zeros_like(valid_revisit_eff_mask) + revisit_gate_raw = torch.zeros_like(revisit_gate_raw) + revisit_gate = torch.zeros_like(revisit_gate) + no_valid_revisit_mask = (~valid_revisit_mask) if revisit_stage_config_enabled else torch.zeros_like(valid_revisit_mask) + revisit_diag = summarize_revisit_diagnostics(revisit_result_diagnostics, valid_revisit_mask) + causal_violation_count = 0 + for source_max in (anchor_max, dynamic_diag.get("max_source_frame"), revisit_causal_max): + if source_max is None: + continue + source_max_t = torch.as_tensor(source_max, device=target_frame_indices.device) + valid = source_max_t >= 0 + if valid.any(): + causal_violation_count += int((source_max_t[valid] >= target_frame_indices.transpose(0, 1)[valid]).sum().item()) + diagnostics = { + **curriculum_state.diagnostics(), + **getattr(self, "_last_dememwm_freeze_diagnostics", {}), + **contract_diag, + **cache_diag, + **preselection_diag, + **revisit_diag, + "dememwm_stage": 
gates.stage, + "dememwm_gate_reason": gates.reason, + "anchor_config_enabled": anchor_config_enabled, + "dynamic_config_enabled": dynamic_config_enabled, + "revisit_config_enabled": revisit_config_enabled, + "anchor_effective_enabled": anchor_effective_enabled, + "dynamic_effective_enabled": dynamic_effective_enabled, + "revisit_effective_enabled": revisit_effective_enabled, + "revisit_stage_config_enabled": revisit_stage_config_enabled, + "revisit_gate_raw": revisit_gate_raw.detach(), + "revisit_gate_eff": revisit_gate.detach() if torch.is_tensor(revisit_gate) else torch.tensor(float(revisit_gate)), + "no_valid_revisit_mask": no_valid_revisit_mask, + "valid_revisit_eff_mask": valid_revisit_eff_mask, + "valid_revisit_target_mask": valid_revisit_target_mask, + "revisit_candidate_frame_count_per_target": revisit_candidate_count, + "revisit_selected_frame_count_per_target": revisit_selected_count, + "revisit_best_selected_fov_overlap_per_target": revisit_best_selected_fov_overlap, + "revisit_best_selected_plucker_overlap_per_target": revisit_best_selected_plucker_overlap, + "revisit_selected_gap_frames_per_target": revisit_selected_gap_frames, + "revisit_learned_gate_mean": float(revisit_gate_raw.detach().float().mean().item()) if revisit_gate_raw.numel() else 0.0, + "revisit_effective_gate_mean": float(torch.as_tensor(revisit_gate, device=stream_device).float().mean().item()), + **summarize_noise_bucket_diagnostics( + noise_bucket=resolved_noise_bucket, + valid_revisit_mask=valid_revisit_mask, + no_valid_revisit_mask=no_valid_revisit_mask, + noise_bucket_ids=noise_bucket_ids, + ), + **summarize_eval_ablation_diagnostics( + enabled=eval_ablation_enabled, + branch=eval_ablation_branch, + valid_revisit_mask=valid_revisit_mask, + no_valid_revisit_mask=no_valid_revisit_mask, + eval_corrupted_revisit_mask=eval_corrupted_revisit_mask if eval_corruption_enabled else None, + ), + "token_patch_size": token_patch_size, + "tokens_per_frame": tokens_per_frame, + 
"anchor_token_slots": int(anchor_tokens.shape[-2]), + "anchor_budget_tokens": anchor_num_tokens, + "anchor_pool_h": anchor_pool_h, + "anchor_pool_w": anchor_pool_w, + "dynamic_token_slots": int(dynamic_tokens.shape[-2]), + "dynamic_budget_tokens": int(dynamic_tokens.shape[-2]), + "dynamic_min_gap_to_target": dynamic_min_gap_to_target, + "dynamic_max_gap_to_target": dynamic_max_gap_to_target, + "dynamic_exclude_latest_local_frames": exclude_latest_local_frames, + "revisit_token_slots": int(revisit_tokens.shape[-2]), + "revisit_max_tokens": revisit_max_tokens, + "revisit_local_context_exclusion_frames": local_context_exclusion_frames, + "revisit_high_quality_fov_threshold": high_quality_fov_threshold, + "revisit_pool_h": revisit_pool_h, + "revisit_pool_w": revisit_pool_w, + "revisit_max_frames": revisit_max_frames, + "anchor_valid_tokens_per_target_max": int(anchor_mask.sum(dim=-1).max().item()) if anchor_mask.numel() else 0, + "dynamic_valid_tokens_per_target_max": int(dynamic_mask.sum(dim=-1).max().item()) if dynamic_mask.numel() else 0, + "revisit_valid_tokens_per_target_max": int(revisit_mask.sum(dim=-1).max().item()) if revisit_mask.numel() else 0, + "causal_violation_count": causal_violation_count, + "anchor_max_source_frame": anchor_max, + "dynamic_max_source_frame": dynamic_diag.get("max_source_frame"), + "revisit_max_source_frame": revisit_max, + "dynamic_generated_source_fraction": dynamic_diag.get("generated_source_fraction"), + } + if eval_corruption_enabled: + diagnostics["eval_corrupted_revisit_mask"] = eval_corrupted_revisit_mask + + return MemoryStreamTensors( + anchor_tokens=anchor_tokens, + anchor_mask=anchor_mask, + dynamic_tokens=dynamic_tokens, + dynamic_mask=dynamic_mask, + revisit_tokens=revisit_tokens, + revisit_mask=revisit_mask, + anchor_gate=anchor_gate, + dynamic_gate=dynamic_gate, + revisit_gate=revisit_gate, + revisit_gate_raw=revisit_gate_raw, + valid_revisit_mask=valid_revisit_mask, + no_valid_revisit_mask=no_valid_revisit_mask, + 
def _refresh_stream_gates(
    self,
    streams: MemoryStreamTensors,
    denoising_fraction: float | None = None,
    noise_bucket: str | None = None,
) -> MemoryStreamTensors:
    """Recompute the gate values of already-built memory streams.

    The gate state is re-read from ``self._effective_gate_state`` for the given
    ``denoising_fraction`` / ``noise_bucket``; the anchor, dynamic and revisit
    gates, the revisit validity masks and the diagnostics dict are then rebuilt
    from it. Only those fields are handed to ``replace``; every other field of
    ``streams`` is carried over as-is.

    Args:
        streams: Streams produced earlier; ``anchor_tokens`` supplies the
            device, dtype and the leading two dimensions used for all
            per-target tensors created here.
        denoising_fraction: Forwarded unchanged to ``_effective_gate_state``.
        noise_bucket: Forwarded unchanged to ``_effective_gate_state``.

    Returns:
        The result of ``replace(streams, ...)`` with the refreshed gates,
        masks and diagnostics.
    """
    state = self._effective_gate_state(
        denoising_fraction=denoising_fraction,
        noise_bucket=noise_bucket,
    )
    gates = state["gates"]

    anchor_tokens = streams.anchor_tokens
    device, dtype = anchor_tokens.device, anchor_tokens.dtype
    num_batch, num_targets = anchor_tokens.shape[:2]
    grid = (num_batch, num_targets)

    # A missing validity mask is treated as "no target has a valid revisit".
    valid_mask = streams.valid_revisit_mask
    valid_mask = (
        torch.zeros(grid, device=device, dtype=torch.bool)
        if valid_mask is None
        else valid_mask.to(device=device, dtype=torch.bool)
    )

    # Work on a shallow copy so the incoming streams' dict is not mutated.
    diagnostics = dict(streams.diagnostics)

    def _per_target(key: str, fill: float = 0.0) -> torch.Tensor:
        """Fetch a diagnostic as a float32 tensor shaped like ``grid``."""
        stored = diagnostics.get(key)
        if stored is not None:
            stored_t = torch.as_tensor(stored, device=device, dtype=torch.float32)
            if stored_t.ndim:
                return stored_t.expand(grid)
            # Scalar entries are broadcast by filling a fresh tensor.
            fill = stored_t.item()
        return torch.full(grid, float(fill), device=device, dtype=torch.float32)

    fov_overlap = _per_target("revisit_best_selected_fov_overlap_per_target")
    plucker_overlap = _per_target("revisit_best_selected_plucker_overlap_per_target")
    gap_frames = _per_target("revisit_selected_gap_frames_per_target", -1.0)

    anchor_on = state["anchor_effective_enabled"]
    dynamic_on = state["dynamic_effective_enabled"]
    stage_on = state["revisit_stage_config_enabled"]
    anchor_gate = gates.anchor_gate if anchor_on else 0.0
    dynamic_gate = gates.dynamic_gate if dynamic_on else 0.0

    # Without a learned gate module the raw gate is all ones.
    learned_gate = getattr(self, "dememwm_revisit_gate", None)
    if learned_gate is not None:
        raw_gate = learned_gate(
            valid_revisit_mask=valid_mask,
            best_selected_fov_overlap=fov_overlap,
            best_selected_plucker_overlap=plucker_overlap,
            selected_gap_frames=gap_frames,
        ).to(device=device, dtype=dtype)
    else:
        raw_gate = torch.ones(grid, device=device, dtype=dtype)

    eff_mask = valid_mask
    if not stage_on or state["force_revisit_off"]:
        gate = torch.zeros_like(raw_gate)
    else:
        masked = eff_mask.to(device=device, dtype=dtype)
        if state["force_revisit_on"]:
            gate = masked * torch.ones_like(raw_gate)
        else:
            gate = masked * raw_gate * float(gates.revisit_gate)

    if stage_on:
        no_valid_mask = ~valid_mask
    else:
        # Stage disabled: every revisit mask and gate is reported as zero.
        valid_mask = torch.zeros_like(valid_mask)
        eff_mask = torch.zeros_like(eff_mask)
        raw_gate = torch.zeros_like(raw_gate)
        gate = torch.zeros_like(gate)
        no_valid_mask = torch.zeros_like(valid_mask)

    corrupted_mask = diagnostics.get("eval_corrupted_revisit_mask")
    if corrupted_mask is not None:
        corrupted_mask = torch.as_tensor(corrupted_mask, device=device, dtype=torch.bool)

    effective_on = bool(stage_on and (gate > 0).any().item())

    diagnostics.update(state["curriculum_state"].diagnostics())
    refreshed = {
        "dememwm_stage": gates.stage,
        "dememwm_gate_reason": gates.reason,
        "anchor_config_enabled": state["anchor_config_enabled"],
        "dynamic_config_enabled": state["dynamic_config_enabled"],
        "revisit_config_enabled": state["revisit_config_enabled"],
        "anchor_effective_enabled": anchor_on,
        "dynamic_effective_enabled": dynamic_on,
        "revisit_effective_enabled": effective_on,
        "revisit_stage_config_enabled": stage_on,
        "revisit_gate_raw": raw_gate.detach(),
        "revisit_gate_eff": gate.detach() if torch.is_tensor(gate) else torch.tensor(float(gate)),
        "no_valid_revisit_mask": no_valid_mask,
        "valid_revisit_eff_mask": eff_mask,
        "revisit_learned_gate_mean": float(raw_gate.detach().float().mean().item()) if raw_gate.numel() else 0.0,
        "revisit_effective_gate_mean": float(gate.detach().float().mean().item()) if gate.numel() else 0.0,
    }
    diagnostics.update(refreshed)
    diagnostics.update(
        summarize_noise_bucket_diagnostics(
            noise_bucket=state["resolved_noise_bucket"],
            valid_revisit_mask=valid_mask,
            no_valid_revisit_mask=no_valid_mask,
        )
    )
    diagnostics.update(
        summarize_eval_ablation_diagnostics(
            enabled=state["eval_ablation_enabled"],
            branch=state["eval_ablation_branch"],
            valid_revisit_mask=valid_mask,
            no_valid_revisit_mask=no_valid_mask,
            eval_corrupted_revisit_mask=corrupted_mask,
        )
    )

    return replace(
        streams,
        anchor_gate=anchor_gate,
        dynamic_gate=dynamic_gate,
        revisit_gate=gate,
        revisit_gate_raw=raw_gate,
        valid_revisit_mask=valid_mask,
        no_valid_revisit_mask=no_valid_mask,
        diagnostics=diagnostics,
    )
getattr(dit_model, "memory_adapter_delta_diagnostics", None) + if diagnostics_fn is None: + return {} + return diagnostics_fn() + + def _log_memory_diagnostics(self, namespace: str, diagnostics: dict) -> None: + if namespace == "training/dememwm": + allowed_keys = self._TRAIN_DIAGNOSTIC_LOG_KEYS + elif namespace.endswith("/dememwm"): + allowed_keys = self._VALIDATION_DIAGNOSTIC_LOG_KEYS + else: + allowed_keys = None + for key, value in diagnostics.items(): + if allowed_keys is not None and key not in allowed_keys: + continue + if isinstance(value, str) or value is None: + continue + if torch.is_tensor(value): + if value.numel() > 0: + self.log(f"{namespace}/{key}", value.float().mean().item(), prog_bar=False, sync_dist=True) + elif isinstance(value, (bool, int, float)): + self.log(f"{namespace}/{key}", float(value), prog_bar=False, sync_dist=True) + + def _training_pose_condition(self, xs, pose_conditions, c2w_mat, frame_idx): + from ..df_video import convert_to_plucker + image_height, image_width = self._image_size(xs) + if self.use_plucker: + if self.relative_embedding: + input_pose_condition = [] + frame_idx_list = [] + ref_c2w = c2w_mat[-self.memory_condition_length:] if self.memory_condition_length else c2w_mat[:0] + ref_idx = frame_idx[-self.memory_condition_length:] if self.memory_condition_length else frame_idx[:0] + for i in range(c2w_mat.shape[0]): + input_pose_condition.append( + convert_to_plucker( + torch.cat([c2w_mat[i:i + 1], ref_c2w]).clone(), + 0, + focal_length=self.focal_length, + image_height=image_height, image_width=image_width + ).to(xs.dtype) + ) + frame_idx_list.append(torch.cat([frame_idx[i:i + 1] - frame_idx[i:i + 1], ref_idx - frame_idx[i:i + 1]]).clone()) + return torch.cat(input_pose_condition), torch.cat(frame_idx_list) + return convert_to_plucker( + c2w_mat, 0, focal_length=self.focal_length, + image_height=image_height, image_width=image_width + ).to(xs.dtype), frame_idx + return pose_conditions.to(xs.dtype), None + + def 
_training_window_bounds(self, total_frames: int, device: torch.device) -> tuple[int, int]: + total_frames = max(0, int(total_frames)) + n_tokens = max(1, min(int(self.n_tokens), total_frames)) + max_start = max(0, total_frames - n_tokens) + if max_start == 0: + return 0, n_tokens + context_start = self._context_frame_count() + min_start = min(context_start, max_start) + if min_start == max_start: + return min_start, min_start + n_tokens + start = int(torch.randint(min_start, max_start + 1, (1,), device=device).item()) + return start, start + n_tokens + + def training_step(self, batch, batch_idx): + xs, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch) + xs = self._as_latents(xs) + + # Randomly select a contiguous n_tokens denoising window inside the long + # clip. DeMemWM memory streams are selected causally from frames before + # each target, then only those selected frames are projected. + total_frames = xs.shape[0] + start, end = self._training_window_bounds(total_frames, xs.device) + + xs_window = xs[start:end] + conditions_window = conditions[start:end].clone() + frame_idx_window = frame_idx[start:end] + + input_pose_condition, frame_idx_list = self._training_pose_condition( + xs_window, pose_conditions[start:end], c2w_mat[start:end], frame_idx_window + ) + + noise_levels = self._generate_noise_levels(xs_window) + if self.memory_condition_length: + noise_levels[-self.memory_condition_length:] = self.diffusion_model.stabilization_level + conditions_window[-self.memory_condition_length:] *= 0 + source_is_generated = torch.zeros(frame_idx.shape, device=frame_idx.device, dtype=torch.bool) + memory_source_latents, source_is_generated, proxy_diagnostics = self._apply_generated_history_proxy( + xs, + source_is_generated, + context_frame_count=self._context_frame_count(), + target_start_frame=start, + ) + timesteps = int(getattr(self, "timesteps", 0) or 0) + training_noise_bucket = noise_bucket_from_noise_levels(noise_levels, timesteps) + 
training_noise_bucket_ids = noise_bucket_ids_from_noise_levels(noise_levels, timesteps) + training_denoising_fraction = denoising_fraction_from_noise_levels(noise_levels, timesteps) + memory_kwargs, diagnostics = self.build_memory_kwargs( + memory_source_latents, + frame_idx, + target_frame_indices=frame_idx_window, + pose=pose_conditions, + target_pose=pose_conditions[start:end], + action=conditions, + target_action=conditions_window, + source_is_generated=source_is_generated, + denoising_fraction=training_denoising_fraction, + noise_bucket=training_noise_bucket, + noise_bucket_ids=None if training_noise_bucket_ids is None else training_noise_bucket_ids.transpose(0, 1), + ) + diagnostics.update(proxy_diagnostics) + _, loss = self.diffusion_model( + xs_window, + conditions_window, + input_pose_condition, + noise_levels=noise_levels, + reference_length=self.memory_condition_length, + frame_idx=frame_idx_list, + **memory_kwargs, + ) + diagnostics.update(self._memory_adapter_delta_diagnostics()) + if self.memory_condition_length: + loss = loss[:-self.memory_condition_length] + loss_denoise = self.reweight_loss(loss, None) + loss_total = loss_denoise + diagnostics["training_window_start"] = int(start) + diagnostics["training_window_end"] = int(end) + diagnostics["training_window_size"] = int(end - start) + diagnostics["loss_denoise"] = float(loss_denoise.detach().item()) + diagnostics["loss_total"] = float(loss_total.detach().item()) + if batch_idx % 20 == 0: + self.log("training/loss", loss_total.detach().cpu()) + self._log_memory_diagnostics("training/dememwm", diagnostics) + return {"loss": loss_total} + + def validation_step(self, batch, batch_idx, namespace="validation"): + import numpy as np + from tqdm import tqdm + + memory_condition_length = self.memory_condition_length + xs_raw, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch) + total_frame = xs_raw.shape[0] + if bool(getattr(self, "_last_dememwm_xs_are_latents", False)): + xs = 
xs_raw.cpu() + elif total_frame > 10: + xs = torch.cat([self.encode(xs_raw[int(total_frame * i / 10):int(total_frame * (i + 1) / 10)]).cpu() for i in range(10)]) + else: + xs = self.encode(xs_raw).cpu() + n_frames, batch_size, *_ = xs.shape + curr_frame = 0 + n_context_frames = self.context_frames // self.frame_stack + xs_pred = xs[:n_context_frames].clone() + curr_frame += n_context_frames + streaming_cache = self._new_streaming_cache(video_id=f"{namespace}:{batch_idx}") + cached_until = 0 + pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling") + last_diagnostics = None + while curr_frame < n_frames: + if streaming_cache is not None and curr_frame > cached_until: + new_generated = torch.zeros(frame_idx[cached_until:curr_frame].shape, dtype=torch.bool, device=frame_idx.device) + if curr_frame > n_context_frames: + rel_start = max(0, n_context_frames - cached_until) + new_generated[rel_start:] = True + self._update_streaming_cache( + streaming_cache, + xs_pred[cached_until:curr_frame], + frame_idx[cached_until:curr_frame], + pose=pose_conditions[cached_until:curr_frame], + source_is_generated=new_generated, + action=conditions[cached_until:curr_frame], + ) + cached_until = curr_frame + horizon = min(n_frames - curr_frame, self.chunk_size) if self.chunk_size > 0 else n_frames - curr_frame + assert horizon <= self.n_tokens, "Horizon exceeds the number of tokens." 
+ scheduling_matrix = self._generate_scheduling_matrix(horizon) + chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:])) + chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise).to(xs_pred.device) + xs_pred = torch.cat([xs_pred, chunk], 0) + start_frame = max(0, curr_frame + horizon - self.n_tokens) + pbar.set_postfix({"start": start_frame, "end": curr_frame + horizon}) + if memory_condition_length: + random_idx = self._generate_condition_indices(curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, horizon) + xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0) + else: + random_idx = torch.empty((0, batch_size), dtype=torch.long, device=frame_idx.device) + input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions( + start_frame, curr_frame, horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx, + image_width=self._image_size(xs_raw)[1], image_height=self._image_size(xs_raw)[0] + ) + target_idx = frame_idx[start_frame:curr_frame + horizon].to(input_condition.device) + use_streaming_cache = streaming_cache is not None and streaming_cache.record_count > 0 + target_pose = pose_conditions[start_frame:curr_frame + horizon].to(input_condition.device) + target_action = conditions[start_frame:curr_frame + horizon].to(input_condition.device) + if use_streaming_cache: + committed_latents = None + committed_idx = None + generated_flags = None + source_pose = None + source_action = None + else: + committed_latents = xs_pred[:curr_frame].to(input_condition.device) + committed_idx = frame_idx[:curr_frame].to(input_condition.device) + generated_flags = torch.zeros(committed_idx.shape, device=input_condition.device, dtype=torch.bool) + if curr_frame > n_context_frames: + generated_flags[n_context_frames:] = True + source_pose = pose_conditions[:curr_frame].to(input_condition.device) + source_action = 
conditions[:curr_frame].to(input_condition.device) + memory_streams = self.build_memory_streams( + committed_latents, + committed_idx, + target_frame_indices=target_idx, + pose=source_pose, + target_pose=target_pose, + action=source_action, + target_action=target_action, + source_is_generated=generated_flags, + denoising_fraction=None, + streaming_cache=streaming_cache, + ) + for m in range(scheduling_matrix.shape[0] - 1): + from_noise_levels, to_noise_levels = self._prepare_noise_levels(scheduling_matrix, m, curr_frame, batch_size, memory_condition_length) + denoise_frac = float(m + 1) / max(float(scheduling_matrix.shape[0] - 1), 1.0) + step_streams = self._refresh_stream_gates(memory_streams, denoising_fraction=denoise_frac) + memory_kwargs, last_diagnostics = self._streams_to_kwargs(step_streams) + xs_pred[start_frame:] = self.diffusion_model.sample_step( + xs_pred[start_frame:].to(input_condition.device), + input_condition, + input_pose_condition, + from_noise_levels[start_frame:], + to_noise_levels[start_frame:], + current_frame=curr_frame, + mode="validation", + reference_length=memory_condition_length, + frame_idx=frame_idx_list, + **memory_kwargs, + ).cpu() + if memory_condition_length: + xs_pred = xs_pred[:-memory_condition_length] + curr_frame += horizon + if streaming_cache is not None and curr_frame > cached_until: + new_generated = torch.zeros(frame_idx[cached_until:curr_frame].shape, dtype=torch.bool, device=frame_idx.device) + if curr_frame > n_context_frames: + rel_start = max(0, n_context_frames - cached_until) + new_generated[rel_start:] = True + self._update_streaming_cache( + streaming_cache, + xs_pred[cached_until:curr_frame], + frame_idx[cached_until:curr_frame], + pose=pose_conditions[cached_until:curr_frame], + source_is_generated=new_generated, + action=conditions[cached_until:curr_frame], + ) + cached_until = curr_frame + if last_diagnostics is not None: + last_diagnostics.update(streaming_cache.diagnostics("cache")) + pbar.update(horizon) 
def strict_checkpoint_key_check(self, state_dict: dict, required_prefixes: Iterable[str] | None = None) -> None:
    """Raise unless ``state_dict`` covers every required key prefix and marker.

    Each key is considered both verbatim and with one leading ``model.``,
    ``module.`` or ``algo.`` wrapper removed. A required prefix is satisfied
    when any such candidate starts with it; a marker from
    ``self.strict_key_substrings`` is satisfied when any candidate contains it.

    Args:
        state_dict: Mapping whose keys are checked (values are ignored).
        required_prefixes: Prefixes to require. ``None`` or an empty
            collection falls back to ``self.strict_key_prefixes``.

    Raises:
        RuntimeError: If at least one prefix or marker is not covered.
    """
    wanted_prefixes = tuple(required_prefixes or self.strict_key_prefixes)
    wrappers = ("model.", "module.", "algo.")

    candidates = []
    for raw_key in state_dict:
        name = str(raw_key)
        # The untouched key always counts, plus one variant per matching wrapper.
        candidates.append(name)
        candidates.extend(name.removeprefix(wrapper) for wrapper in wrappers if name.startswith(wrapper))

    missing_prefixes = []
    for wanted in wanted_prefixes:
        if not any(candidate.startswith(wanted) for candidate in candidates):
            missing_prefixes.append(wanted)

    missing_substrings = []
    for marker in self.strict_key_substrings:
        if not any(marker in candidate for candidate in candidates):
            missing_substrings.append(marker)

    if not (missing_prefixes or missing_substrings):
        return
    raise RuntimeError(
        "DeMemWM checkpoint is missing required DeMemWM key coverage: "
        f"prefixes={missing_prefixes}, memory_adapter_markers={missing_substrings}"
    )
+ dememwm_strict_key_prefixes = strict_key_prefixes + dememwm_strict_key_substrings = strict_key_substrings + _DEMEMWM_TRAIN_DIAGNOSTIC_LOG_KEYS = _TRAIN_DIAGNOSTIC_LOG_KEYS + _DEMEMWM_VALIDATION_DIAGNOSTIC_LOG_KEYS = _VALIDATION_DIAGNOSTIC_LOG_KEYS + _dememwm_cfg = _memory_cfg + _dememwm_stage_policy_cfg = _stage_policy_cfg + _dememwm_eval_ablation_cfg = _eval_ablation_cfg + _dememwm_generated_history_proxy_cfg = _generated_history_proxy_cfg + _dememwm_eval_ablation_state = _eval_ablation_state + _dememwm_effective_gate_state = _effective_gate_state + _dememwm_validate_config_contract = _validate_config_contract + _dememwm_stream_enabled = _stream_enabled + _dememwm_context_frame_count = _context_frame_count + _dememwm_local_context_exclusion_frames = _local_context_exclusion_frames + _dememwm_curriculum_state = _curriculum_state + _dememwm_generated_history_proxy_prob = _generated_history_proxy_prob + _dememwm_apply_generated_history_proxy = _apply_generated_history_proxy + _dememwm_checkpoint_cfg = _checkpoint_cfg + _dememwm_strict_eval_load_enabled = _strict_eval_load_enabled + _dememwm_cache_cfg = _cache_cfg + _dememwm_cache_enabled = _cache_enabled + _dememwm_new_streaming_cache = _new_streaming_cache + _dememwm_is_memory_adapter_param = _is_memory_adapter_param + _dememwm_param_group_name = _param_group_name + _dememwm_group_trainable = _group_trainable + _dememwm_group_lr = _group_lr + _dememwm_apply_freeze_policy = _apply_freeze_policy + _dememwm_as_latents = _as_latents + _dememwm_image_size = _image_size + _dememwm_update_streaming_cache = _update_streaming_cache + _build_dememwm_streaming_cache_records = _build_streaming_cache_records + _build_dememwm_causal_memory_banks = _build_causal_memory_banks + _build_dememwm_preselected_causal_memory_banks = _build_preselected_causal_memory_banks + _dememwm_records_to_stream = _records_to_stream + build_dememwm_memory_streams = build_memory_streams + _dememwm_refresh_stream_gates = _refresh_stream_gates + 
_dememwm_streams_to_kwargs = _streams_to_kwargs + build_dememwm_memory_kwargs = build_memory_kwargs + _dememwm_memory_adapter_delta_diagnostics = _memory_adapter_delta_diagnostics + _log_dememwm_diagnostics = _log_memory_diagnostics + _dememwm_training_window_bounds = _training_window_bounds + strict_dememwm_checkpoint_key_check = strict_checkpoint_key_check + + +DeMemWMMemoryDiTMixin = MemoryDiTMixin diff --git a/algorithms/worldmem/dememwm/cache.py b/algorithms/worldmem/dememwm/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..51a69cfb2498dfbe52a541456aa5b180cb3565d5 --- /dev/null +++ b/algorithms/worldmem/dememwm/cache.py @@ -0,0 +1,513 @@ + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Iterable, Optional +import warnings + +import torch + +from .memory import CausalMemoryBank +from .types import MemoryRecord + + +@dataclass +class _RawLatentSegment: + latents: torch.Tensor + frame_indices: torch.Tensor + source_is_generated: torch.Tensor + pose: Optional[torch.Tensor] + + +class StreamingCache: + """Per-video DeMemWM streaming cache with strict no-eviction semantics. + + The cache is intentionally allowed to grow for the current video. It stores + detached CPU (or pinned CPU) raw latents plus compressed MemoryRecord objects, + while DiT readout tensors remain bounded by the caller's manual budgets. 
+ """ + + def __init__( + self, + *, + enabled: bool = True, + device: str = "cpu", + keep_raw_latents: str = "all", + keep_compressed_records: bool = True, + keep_prefix_anchors: bool = True, + eviction_policy: str = "none", + no_evict: bool = True, + clear_between_videos: bool = True, + max_records: Optional[int] = None, + max_slots: Optional[int] = None, + on_capacity_exceeded: str = "warn", + ) -> None: + self.enabled = bool(enabled) + self.device = str(device or "cpu") + self.keep_raw_latents = keep_raw_latents + self.keep_compressed_records = bool(keep_compressed_records) + self.keep_prefix_anchors = bool(keep_prefix_anchors) + self.eviction_policy = str(eviction_policy or "none") + self.no_evict = bool(no_evict) + self.clear_between_videos = bool(clear_between_videos) + self.max_records = max_records + self.max_slots = max_slots + self.on_capacity_exceeded = str(on_capacity_exceeded or "warn") + if self.eviction_policy != "none" or not self.no_evict: + raise ValueError("DeMemWMStreamingCache only supports eviction_policy='none' with no_evict=true") + if self.device not in {"cpu", "pinned_cpu", "cuda"}: + raise ValueError("cache.device must be one of: cpu, pinned_cpu, cuda") + self.reset_count = 0 + self.evictions = 0 + self.capacity_exceeded_count = 0 + self.current_video_id: Any = None + self._raw_segments: list[_RawLatentSegment] = [] + self._records: dict[str, dict[int, list[MemoryRecord]]] = {"anchor": {}, "revisit": {}} + self._raw_keys: set[tuple[int, int]] = set() + self._raw_index: dict[tuple[int, int], tuple[int, int]] = {} + self._record_keys: set[tuple[str, int, str, int, int, bool]] = set() + self._batch_size: Optional[int] = None + # Concat cache: avoids repeated torch.cat across DDIM steps within one chunk. + # Invalidated whenever new raw segments are added. 
@classmethod
def from_config(cls, cfg: Any, *, enabled_default: bool = True) -> "StreamingCache":
    """Build a cache from an attribute-style config object.

    Each constructor argument is read from ``cfg`` by attribute name. A field
    that is missing, or present but set to ``None`` (for example a YAML
    ``null``), falls back to its default. Without that rule ``str(None)``
    would hand the literal string ``"None"`` to the constructor, bypassing its
    own ``device or "cpu"`` / ``eviction_policy or "none"`` /
    ``on_capacity_exceeded or "warn"`` fallbacks.

    Args:
        cfg: Config object exposing the cache fields as attributes, or
            ``None`` to use defaults for everything.
        enabled_default: Value for ``enabled`` when the config does not set it.

    Returns:
        The object produced by calling ``cls`` with the resolved keyword
        arguments.
    """

    def get(name: str, default: Any) -> Any:
        # None means "unset", whether the attribute is absent or explicitly null.
        value = getattr(cfg, name, None) if cfg is not None else None
        return default if value is None else value

    return cls(
        enabled=bool(get("enabled", enabled_default)),
        device=str(get("device", "cpu")),
        keep_raw_latents=str(get("keep_raw_latents", "all")),
        keep_compressed_records=bool(get("keep_compressed_records", True)),
        keep_prefix_anchors=bool(get("keep_prefix_anchors", True)),
        eviction_policy=str(get("eviction_policy", "none")),
        no_evict=bool(get("no_evict", True)),
        clear_between_videos=bool(get("clear_between_videos", True)),
        # The limits stay None when unset: None means "no limit" to the constructor.
        max_records=get("max_records", None),
        max_slots=get("max_slots", None),
        on_capacity_exceeded=str(get("on_capacity_exceeded", "warn")),
    )
def _store_tensor(self, tensor: Optional[torch.Tensor], *, dtype: torch.dtype | None = None) -> Optional[torch.Tensor]:
    """Return a detached copy of ``tensor`` prepared for cache storage.

    Args:
        tensor: Tensor to store; ``None`` is passed through unchanged.
        dtype: Optional dtype applied only when the tensor is floating point.

    Returns:
        ``None`` for a ``None`` input. For ``self.device`` of ``"cpu"`` or
        ``"pinned_cpu"`` a fresh CPU copy (pinned when pinning succeeds); for
        ``"cuda"`` a clone on the tensor's current device; for any other
        setting the detached (and possibly cast) tensor itself.
    """
    if tensor is None:
        return None

    stored = tensor.detach()
    if dtype is not None and stored.is_floating_point():
        stored = stored.to(dtype=dtype)

    placement = self.device
    if placement == "cuda":
        return stored.clone()
    if placement not in ("cpu", "pinned_cpu"):
        return stored

    stored = stored.to(device="cpu", copy=True)
    if placement != "pinned_cpu":
        return stored
    try:
        return stored.pin_memory()
    except RuntimeError:
        # Pinning is best effort: fall back to the plain CPU copy when it is
        # unavailable in this worker/process.
        return stored
def _check_capacity(self) -> None:
    """Count, and optionally report, a breach of the record or slot limits.

    Nothing is evicted. When ``self.record_count`` exceeds ``self.max_records``
    or ``self.slot_count`` exceeds ``self.max_slots`` (a limit of ``None`` is
    never exceeded), ``self.capacity_exceeded_count`` is incremented and
    ``self.on_capacity_exceeded`` selects the reaction: ``"error"`` raises,
    ``"warn"`` emits a ``RuntimeWarning``, anything else stays silent.

    Raises:
        RuntimeError: If a limit is exceeded and the policy is ``"error"``.
    """
    record_limit = self.max_records
    slot_limit = self.max_slots
    over_records = record_limit is not None and self.record_count > int(record_limit)
    over_slots = slot_limit is not None and self.slot_count > int(slot_limit)
    if not (over_records or over_slots):
        return

    self.capacity_exceeded_count += 1
    msg = (
        "DeMemWMStreamingCache capacity exceeded "
        f"records={self.record_count}/{self.max_records}, slots={self.slot_count}/{self.max_slots}; "
        "no eviction performed because no_evict=true"
    )
    policy = self.on_capacity_exceeded
    if policy == "error":
        raise RuntimeError(msg)
    if policy == "warn":
        # stacklevel=2 attributes the warning to the caller of this method.
        warnings.warn(msg, RuntimeWarning, stacklevel=2)
source_is_generated is None: + seg_generated = torch.zeros(seg_frames.shape, device=seg_frames.device, dtype=torch.bool) + else: + seg_generated = source_is_generated.index_select(0, pos.to(device=source_is_generated.device)).bool() + seg_pose = None if pose is None else pose.index_select(0, pos.to(device=pose.device)) + segment_idx = len(self._raw_segments) + self._raw_segments.append( + _RawLatentSegment( + latents=self._store_tensor(seg_latents), + frame_indices=self._store_tensor(seg_frames), + source_is_generated=self._store_tensor(seg_generated), + pose=self._store_tensor(seg_pose), + ) + ) + for local_pos, source_pos in enumerate(keep_positions): + for b in range(B): + key = (b, int(frame_cpu[source_pos, b].item())) + self._raw_index.setdefault(key, (segment_idx, local_pos)) + # Invalidate the concat cache — new segment was added. + self._raw_concat_version += 1 + self._raw_concat_cache = None + + def add_records(self, kind: str, batch_idx: int, records: Iterable[MemoryRecord]) -> None: + if not self.enabled or not self.keep_compressed_records: + return + if kind not in self._records: + raise ValueError(f"unsupported cache record kind: {kind}") + batch_idx = int(batch_idx) + bucket = self._records[kind].setdefault(batch_idx, []) + added_any = False + for record in records: + if kind == "anchor" and not self.keep_prefix_anchors: + continue + key = ( + kind, + batch_idx, + str(record.chunk_id or ""), + int(record.source_start), + int(record.source_end), + bool(record.is_generated), + ) + if key in self._record_keys: + continue + self._record_keys.add(key) + bucket.append(self._record_to_storage(record)) + added_any = True + if added_any: + # Invalidate the GPU banks cache — new records were added. 
def raw_latents_for_frames(
    self,
    *,
    batch_idx: int,
    frame_indices: torch.Tensor,
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Gather cached raw latents for specific frames of one batch element.

    Args:
        batch_idx: Batch element whose latents are requested.
        frame_indices: Frame ids to fetch; flattened and returned in that order.
        device: Device of the returned tensor.
        dtype: Dtype of the returned tensor.

    Returns:
        The selected per-frame latents stacked along a new leading axis, with
        a singleton axis inserted at position 1. An empty request returns a
        zero-length slice of the first stored segment.

    Raises:
        KeyError: If a requested (batch, frame) pair is not in ``self._raw_index``.
    """
    b = int(batch_idx)
    wanted = frame_indices.detach().cpu().reshape(-1).tolist()
    if not wanted:
        # Borrow the trailing dimensions from the first stored segment.
        template = self._raw_segments[0].latents
        return template[:0, b:b + 1].to(device=device, dtype=dtype)

    picked = []
    for frame_id in wanted:
        hit = self._raw_index.get((b, int(frame_id)))
        if hit is None:
            raise KeyError(f"raw latent for batch={b}, frame={int(frame_id)} is not cached")
        segment_no, row = hit
        picked.append(self._raw_segments[segment_no].latents[row, b])
    return torch.stack(picked, dim=0).unsqueeze(1).to(device=device, dtype=dtype)
+ keep_any = keep_f.flip(1).any(dim=0).any(dim=1) # (T,) + return keep_any.nonzero(as_tuple=False).flatten() + + def materialize_raw_latents( + self, + *, + device: torch.device, + dtype: torch.dtype, + max_recent_frames: Optional[int] = None, + target_frame_indices: Optional[torch.Tensor] = None, + exclude_latest_local_frames: int = 0, + ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + if not self._raw_segments: + return None, None, None, None + if target_frame_indices is not None and max_recent_frames is not None and int(max_recent_frames) > 0: + return self._materialize_recent_raw_latents( + device=device, + dtype=dtype, + max_recent_frames=int(max_recent_frames), + target_frame_indices=target_frame_indices, + exclude_latest_local_frames=exclude_latest_local_frames, + ) + # Rebuild the concatenated CPU tensors only when new segments were added. + if self._raw_concat_cache is None or self._raw_concat_built != self._raw_concat_version: + latents = torch.cat([seg.latents for seg in self._raw_segments], dim=0) + frame_indices = torch.cat([seg.frame_indices for seg in self._raw_segments], dim=0) + generated = torch.cat([seg.source_is_generated for seg in self._raw_segments], dim=0) + pose: Optional[torch.Tensor] = None + if all(seg.pose is not None for seg in self._raw_segments): + pose = torch.cat([seg.pose for seg in self._raw_segments if seg.pose is not None], dim=0) + self._raw_concat_cache = (latents, frame_indices, generated, pose) + self._raw_concat_built = self._raw_concat_version + else: + latents, frame_indices, generated, pose = self._raw_concat_cache + pos = self._select_time_positions(frame_indices, target_frame_indices, max_recent_frames, exclude_latest_local_frames) + if pos.numel() == 0: + empty_latents = latents[:0].to(device=device, dtype=dtype) + empty_frames = frame_indices[:0].to(device=device) + empty_generated = generated[:0].to(device=device, dtype=torch.bool) + empty_pose = None if 
pose is None else pose[:0].to(device=device) + return empty_latents, empty_frames, empty_generated, empty_pose + latents = latents.index_select(0, pos.to(device=latents.device)).to(device=device, dtype=dtype) + frame_indices = frame_indices.index_select(0, pos.to(device=frame_indices.device)).to(device=device) + generated = generated.index_select(0, pos.to(device=generated.device)).to(device=device, dtype=torch.bool) + if pose is not None: + pose = pose.index_select(0, pos.to(device=pose.device)).to(device=device) + return latents, frame_indices, generated, pose + + def _materialize_recent_raw_latents( + self, + *, + device: torch.device, + dtype: torch.dtype, + max_recent_frames: int, + target_frame_indices: torch.Tensor, + exclude_latest_local_frames: int = 0, + ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + B = self.batch_size + targets = target_frame_indices.detach().cpu() + if targets.ndim == 1: + targets = targets[:, None].expand(-1, B) + elif targets.shape[1] == 1 and B > 1: + targets = targets.expand(-1, B) + if targets.shape[1] != B: + raise ValueError("target_frame_indices batch dimension does not match streaming cache") + + recent = max(0, int(max_recent_frames)) + exclude = max(0, int(exclude_latest_local_frames)) + counts = torch.zeros(targets.shape, dtype=torch.long) + selected: list[tuple[_RawLatentSegment, int]] = [] + + for segment in reversed(self._raw_segments): + frames = segment.frame_indices.detach().cpu() + for local_pos in range(frames.shape[0] - 1, -1, -1): + valid = frames[local_pos].unsqueeze(0) < (targets - exclude) + needed = valid & (counts < recent) + if not needed.any(): + continue + selected.append((segment, local_pos)) + counts += needed.long() + if bool((counts >= recent).all().item()): + break + if bool((counts >= recent).all().item()): + break + + if not selected: + template = self._raw_segments[0] + empty_latents = template.latents[:0].to(device=device, dtype=dtype) 
def diagnostics(self, prefix: str = "cache") -> dict[str, Any]:
    """Return a flat snapshot of the cache's counters and settings.

    Args:
        prefix: Text prepended, with an underscore, to every key.

    Returns:
        A dict keyed ``f"{prefix}_<field>"`` holding the enabled flag, record
        and slot counts (total and per kind), raw-latent counts, eviction,
        reset and capacity-exceeded counters, the storage device, the current
        video id and the retention flags.
    """
    fields = (
        ("enabled", bool(self.enabled)),
        ("records", int(self.record_count)),
        ("anchor_records", int(self.records_count("anchor"))),
        ("revisit_records", int(self.records_count("revisit"))),
        ("slots", int(self.slot_count)),
        ("raw_frame_slots", int(self.raw_frame_slots)),
        ("raw_segments", int(self.raw_segment_count)),
        ("evictions", int(self.evictions)),
        ("resets", int(self.reset_count)),
        ("capacity_exceeded", int(self.capacity_exceeded_count)),
        ("device", self.device),
        ("current_video_id", self.current_video_id),
        ("clear_between_videos", bool(self.clear_between_videos)),
        ("no_evict", bool(self.no_evict)),
    )
    return {f"{prefix}_{suffix}": value for suffix, value in fields}
def latent_patch_tokens(latents: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Cut latent maps into non-overlapping square patches, one token per patch.

    The result has the layout of
    ``F.unfold(x, kernel_size=p, stride=p).transpose(1, 2)``: tokens run
    row-major over the patch grid, and each token lists its values
    channel-major, then patch row, then patch column. It is built with a
    reshape/permute rather than ``F.unfold``, so integer and boolean tensors
    are accepted as well as floating-point ones.

    Args:
        latents: Tensor of shape ``(T, B, C, H, W)``.
        patch_size: Side length of each patch; must be positive and divide
            both ``H`` and ``W``.

    Returns:
        A freshly allocated contiguous tensor of shape
        ``(T, B, (H // p) * (W // p), C * p * p)`` with the input's dtype.

    Raises:
        ValueError: If ``latents`` is not 5-D, ``patch_size`` is not positive,
            or ``H``/``W`` is not divisible by ``patch_size``.
    """
    if latents.ndim != 5:
        raise ValueError("latents must have shape (T,B,C,H,W)")
    if patch_size <= 0:
        raise ValueError("patch_size must be positive")
    T, B, C, H, W = latents.shape
    if H % patch_size != 0 or W % patch_size != 0:
        raise ValueError(f"latent H,W=({H},{W}) must be divisible by patch_size={patch_size}")
    p = int(patch_size)
    rows, cols = H // p, W // p
    # (T, B, C, rows, p, cols, p) -> (T, B, rows, cols, C, p, p)
    grid = latents.reshape(T, B, C, rows, p, cols, p).permute(0, 1, 3, 5, 2, 4, 6)
    # Always copy so the result never aliases the input, as with F.unfold's
    # output; the copy is laid out contiguously so the reshape is a view of it.
    patches = grid.clone(memory_format=torch.contiguous_format)
    return patches.reshape(T, B, rows * cols, C * p * p)
class SpatialConv2DMemoryProjector(nn.Module):
    """Map latent feature maps to per-pixel hidden tokens on the same HxW grid.

    A 1x1 convolution lifts the latent channels to ``mid_channels``; a
    ``kernel_size`` x ``kernel_size`` convolution padded by ``kernel_size // 2``
    then produces ``dit_hidden_size`` channels at the same resolution.
    """

    # Class-level flag exposed to callers.
    projects_spatial_latents = True

    def __init__(
        self,
        latent_channels: int,
        dit_hidden_size: int,
        mid_channels: int,
        kernel_size: int = 3,
    ):
        """Create the two convolutions.

        Raises:
            ValueError: If ``kernel_size`` is not a positive odd integer.
        """
        super().__init__()
        k = int(kernel_size)
        if k < 1 or k % 2 != 1:
            raise ValueError("kernel_size must be a positive odd integer")

        self.latent_channels = int(latent_channels)
        self.dit_hidden_size = int(dit_hidden_size)
        self.mid_channels = int(mid_channels)
        self.kernel_size = k
        self.out_features = self.dit_hidden_size

        # Creation order is kept (proj_in, then proj_spatial) so parameter
        # initialisation draws from the RNG in the same sequence.
        self.proj_in = nn.Conv2d(self.latent_channels, self.mid_channels, kernel_size=1)
        self.proj_spatial = nn.Conv2d(
            self.mid_channels,
            self.dit_hidden_size,
            kernel_size=k,
            padding=k // 2,
        )

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        """Project ``(T, B, C, H, W)`` latents to ``(B, T, H*W, dit_hidden_size)`` tokens.

        Raises:
            ValueError: If ``latents`` is not 5-D or its channel count differs
                from ``latent_channels``.
        """
        if latents.ndim != 5:
            raise ValueError("latents must have shape (T,B,C,H,W)")
        steps, batch, channels, height, width = latents.shape
        if channels != self.latent_channels:
            raise ValueError(f"expected {self.latent_channels} latent channels, got {channels}")

        hidden = self.dit_hidden_size
        # Fold time into the batch axis for the 2D convolutions.
        frames = latents.reshape(steps * batch, channels, height, width)
        features = self.proj_spatial(self.proj_in(frames))
        features = features.reshape(steps, batch, hidden, height, width)
        # Batch-first, channels-last, then flatten the spatial grid.
        batch_first = features.permute(1, 0, 3, 4, 2)
        return batch_first.reshape(batch, steps, height * width, hidden).contiguous()
+ - Operates directly on (T, C, H, W) raw latents + - Delta: inp[0]=latent[0], inp[t]=latent[t]-latent[t-1] + - Causal padding prepends temporal zeros and right-aligns fixed outputs + - Zero-padded to max_source_frames for fixed output shape + - No slot cross-attention, no chunking + """ + + def __init__( + self, + latent_channels: int, + dit_hidden_size: int, + patch_size: int = 2, + conv_kernel_t: int = 3, + conv_stride_t: int = 2, + max_source_frames: int = 8, + exclude_latest_local_frames: int = 4, + ): + super().__init__() + self.latent_channels = latent_channels + self.dit_hidden_size = dit_hidden_size + self.patch_size = patch_size + self.conv_kernel_t = conv_kernel_t + self.conv_stride_t = conv_stride_t + self.max_source_frames = max_source_frames + self.exclude_latest_local_frames = int(exclude_latest_local_frames) + self.causal_pad = self._temporal_left_pad() + self.conv3d = nn.Conv3d( + latent_channels, dit_hidden_size, + kernel_size=(conv_kernel_t, patch_size, patch_size), + stride=(conv_stride_t, patch_size, patch_size), + padding=0, + ) + self.out_norm = nn.LayerNorm(dit_hidden_size) + self._init_temporal_as_delta() + + def _init_temporal_as_delta(self) -> None: + with torch.no_grad(): + self.conv3d.weight.zero_() + k_t, p = self.conv_kernel_t, self.patch_size + D_out, D_in = self.conv3d.weight.shape[:2] + scale = 1.0 / (p * p) + # Delta preprocessing happens in forward. Initialize every output + # channel to read a patch-averaged current delta, repeating latent + # channels across the wider DiT hidden dimension. 
def forward(
    self,
    latents: torch.Tensor,
    frame_indices: torch.Tensor,
    pose: Optional[torch.Tensor],
    target_frame_indices: torch.Tensor,
    source_is_generated: Optional[torch.Tensor] = None,
    exclude_latest_local_frames: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor, dict]:
    """Compress the causal history of each target frame into memory tokens.

    For every (batch element, target) pair the method selects the source
    frames whose index is below ``target - exclude_latest_local_frames``,
    keeps the ``max_source_frames`` with the largest indices, left-pads with
    zero frames to a fixed length, converts the frames to temporal deltas,
    and runs the causal-padded Conv3D followed by LayerNorm.

    Args:
        latents: Source latents of shape ``(T_src, B, C, H, W)``.
        frame_indices: Source frame ids indexed ``[t, b]``; moved to the
            latents' device before use.
        pose: Accepted for interface compatibility; not read by this method.
        target_frame_indices: Target frame ids indexed ``[j, b]``; a 1-D
            tensor is broadcast across the batch.
        source_is_generated: Optional per-source flags indexed ``[t, b]``,
            used only for the ``generated_source_fraction`` diagnostic.
        exclude_latest_local_frames: Overrides
            ``self.exclude_latest_local_frames`` when not ``None``.

    Returns:
        ``(tokens, mask, diagnostics)``. ``tokens`` has shape
        ``(B, T_tgt, num_slots, dit_hidden_size)`` and ``mask`` has shape
        ``(B, T_tgt, num_slots)``. A slot is valid when its output time index
        lies inside the source window and lands on a real, non-padding frame.
        Targets with no eligible source get all-zero tokens and an all-False
        mask. ``diagnostics`` holds slot counts and per-target source
        statistics.

    Raises:
        ValueError: If ``latents`` is not 5-D or ``H``/``W`` is not divisible
            by the patch size.
    """
    if latents.ndim != 5:
        raise ValueError("latents must have shape (T_src,B,C,H,W)")
    exclude_latest_local_frames = (
        self.exclude_latest_local_frames
        if exclude_latest_local_frames is None
        else int(exclude_latest_local_frames)
    )
    T_src, B, C, H, W = latents.shape
    p = self.patch_size
    if H % p != 0 or W % p != 0:
        raise ValueError(f"latent H,W=({H},{W}) must be divisible by patch_size={p}")
    if target_frame_indices.ndim == 1:
        target_frame_indices = target_frame_indices[:, None].expand(-1, B)
    T_tgt = target_frame_indices.shape[0]
    device = latents.device
    # Keep every index tensor on the latents' device so the index_select calls
    # below never mix devices.
    frame_indices = frame_indices.to(device=device)
    generated_flags = None if source_is_generated is None else source_is_generated.to(device=device, dtype=torch.bool)
    # One host transfer for all targets instead of one .item() sync per pair.
    targets_host = target_frame_indices.tolist()

    n_spatial = (H // p) * (W // p)
    T_out = self._temporal_output_count()
    num_slots = T_out * n_spatial
    max_frames = self.max_source_frames
    output_time_idx = self._output_time_indices(device)
    # Loop-invariant parts of the slot validity mask.
    clamped_time_idx = output_time_idx.clamp(min=0, max=max_frames - 1)
    time_in_window = (output_time_idx >= 0) & (output_time_idx < max_frames)

    selected_source_count = torch.zeros((B, T_tgt), dtype=torch.long, device=device)
    max_source_frame = torch.full((B, T_tgt), -1, dtype=torch.long, device=device)
    generated_source_fraction = torch.zeros((B, T_tgt), dtype=torch.float32, device=device)
    min_gap = torch.full((B, T_tgt), -1, dtype=torch.long, device=device)
    max_gap = torch.full((B, T_tgt), -1, dtype=torch.long, device=device)

    output_rows, mask_rows = [], []
    for b in range(B):
        src_frames_b = frame_indices[:, b]
        tgt_outputs, tgt_masks = [], []
        for j in range(T_tgt):
            target = int(targets_host[j][b])
            valid_idx = (
                src_frames_b < target - exclude_latest_local_frames
            ).nonzero(as_tuple=False).flatten()
            if valid_idx.numel() == 0:
                # No causal source for this target: emit a fully masked row.
                tgt_outputs.append(latents.new_zeros(num_slots, self.dit_hidden_size))
                tgt_masks.append(torch.zeros(num_slots, device=device, dtype=torch.bool))
                continue

            # Order by frame id and keep the most recent max_frames sources.
            selected_frames = src_frames_b.index_select(0, valid_idx)
            order = torch.argsort(selected_frames)
            valid_idx = valid_idx.index_select(0, order)[-max_frames:]
            selected_frames = src_frames_b.index_select(0, valid_idx)

            selected_source_count[b, j] = int(selected_frames.numel())
            max_source_frame[b, j] = selected_frames.max()
            gaps = target - selected_frames
            min_gap[b, j] = gaps.min()
            max_gap[b, j] = gaps.max()
            if generated_flags is not None:
                generated = generated_flags.index_select(0, valid_idx)[:, b]
                generated_source_fraction[b, j] = generated.float().mean()

            chunk = latents[valid_idx, b]
            real_mask = torch.ones((chunk.shape[0],), device=device, dtype=torch.bool)
            if chunk.shape[0] < max_frames:
                # Left-pad with zero frames so real frames stay right-aligned.
                pad = chunk.new_zeros(max_frames - chunk.shape[0], C, H, W)
                chunk = torch.cat([pad, chunk], dim=0)
                real_mask = torch.cat([
                    torch.zeros((pad.shape[0],), device=device, dtype=torch.bool),
                    real_mask,
                ])

            # Delta preprocessing: first frame as-is, then frame differences.
            inp = chunk.clone()
            inp[1:] = chunk[1:] - chunk[:-1]
            x = inp.permute(1, 0, 2, 3).unsqueeze(0)  # (1,C,T,H,W)
            x = F.pad(x, (0, 0, 0, 0, self.causal_pad, 0))  # left-pad time
            x = self.conv3d(x)  # (1,D,T_out,H//p,W//p)
            x = x.squeeze(0).permute(1, 2, 3, 0)  # (T_out,H//p,W//p,D)
            x = self.out_norm(x)
            tokens = x.reshape(num_slots, self.dit_hidden_size)

            temporal_mask = time_in_window & real_mask.index_select(0, clamped_time_idx)
            mask = temporal_mask[:, None].expand(T_out, n_spatial).reshape(num_slots)
            tgt_outputs.append(tokens)
            tgt_masks.append(mask)
        output_rows.append(torch.stack(tgt_outputs))
        mask_rows.append(torch.stack(tgt_masks))

    out_tokens = torch.stack(output_rows)
    out_mask = torch.stack(mask_rows)
    diagnostics = {
        "num_dynamic_slots": num_slots,
        "dynamic_T_out": T_out,
        "dynamic_n_spatial": n_spatial,
        "dynamic_temporal_left_pad": self.causal_pad,
        "dynamic_output_time_indices": output_time_idx,
        "selected_source_count": selected_source_count,
        "max_source_frame": max_source_frame,
        "generated_source_fraction": generated_source_fraction,
        "dynamic_min_gap_to_target_per_target": min_gap,
        "dynamic_max_gap_to_target_per_target": max_gap,
        "dynamic_exclude_latest_local_frames": exclude_latest_local_frames,
    }
    return out_tokens, out_mask, diagnostics
# Restated in full: the collapsed source splits this assignment across the
# preceding line, leaving only the string literal at the start of this span.
_REVISIT_LABEL_SOURCE = "deterministic_fov_coverage_plucker"


def tensor_valid_fraction(mask: torch.Tensor | None) -> float:
    """Return the fraction of truthy entries in *mask* (0.0 for None/empty)."""
    if mask is None or mask.numel() == 0:
        return 0.0
    return float(mask.detach().bool().float().mean().item())


def gate_stats(gate: torch.Tensor | float | int | None) -> dict[str, float]:
    """Return the mean, min and max of *gate* as plain Python floats.

    ``None`` and zero-element tensors yield all-zero stats; a Python scalar is
    reported as its own mean, min and max.
    """
    if gate is None:
        return {"mean": 0.0, "min": 0.0, "max": 0.0}
    if not torch.is_tensor(gate):
        value = float(gate)
        return {"mean": value, "min": value, "max": value}
    g = gate.detach().float()
    if g.numel() == 0:
        # min()/max() raise on a zero-element tensor and mean() is NaN, so an
        # empty gate is reported like a missing one instead of crashing.
        return {"mean": 0.0, "min": 0.0, "max": 0.0}
    return {
        "mean": float(g.mean().item()),
        "min": float(g.min().item()),
        "max": float(g.max().item()),
    }


def summarize_stream(name: str, tokens: torch.Tensor | None, mask: torch.Tensor | None, gate: torch.Tensor | float | None) -> dict[str, Any]:
    """Build the ``{name}_*`` diagnostics for one stream.

    Returns a dict with the token shape (``None`` when *tokens* is ``None``),
    the valid fraction and valid count of *mask*, and the stats of *gate*.
    """
    return {
        f"{name}_tokens_shape": None if tokens is None else tuple(tokens.shape),
        f"{name}_valid_fraction": tensor_valid_fraction(mask),
        f"{name}_valid_tokens": 0 if mask is None else int(mask.detach().bool().sum().item()),
        f"{name}_gate": gate_stats(gate),
    }


def assert_no_future_sources(target_frame: int, max_source_frame: int | torch.Tensor) -> None:
    """Raise ``AssertionError`` if the largest source frame is >= *target_frame*.

    A tensor *max_source_frame* is reduced with ``max()`` before comparing.
    """
    if torch.is_tensor(max_source_frame):
        max_src = int(max_source_frame.detach().max().item())
    else:
        max_src = int(max_source_frame)
    if max_src >= int(target_frame):
        raise AssertionError(f"DeMemWM memory source {max_src} is not causal for target {target_frame}")


def _collect_values(result_diagnostics: list[dict[str, Any]], key: str) -> list[float]:
    """Flatten ``diag[key]`` across all diagnostics into one list of floats."""
    values: list[float] = []
    for diag in result_diagnostics:
        # A missing key or a falsy value (None, empty list) contributes nothing.
        entries = diag.get(key, []) or []
        values.extend(float(value) for value in entries)
    return values


def _value_stats(values: list[float], prefix: str) -> dict[str, float]:
    """Return ``{prefix}_mean/_min/_max`` for *values* (all 0.0 when empty)."""
    if not values:
        return {f"{prefix}_mean": 0.0, f"{prefix}_min": 0.0, f"{prefix}_max": 0.0}
    return {
        f"{prefix}_mean": float(sum(values) / len(values)),
        f"{prefix}_min": float(min(values)),
        f"{prefix}_max": float(max(values)),
    }


def _first_present(diag: dict[str, Any], keys: tuple[str, ...], default: Any = 0) -> Any:
    """Return the value of the first key of *keys* present in *diag*, else *default*."""
    for key in keys:
        if key in diag:
            return diag[key]
    return default


def _flat_bool_mask(mask: torch.Tensor | None, like: torch.Tensor | None = None) -> torch.Tensor:
    """Return *mask* as a detached, flattened boolean CPU tensor.

    When *mask* is ``None`` the result is ``zeros_like(like)``, or an empty
    boolean tensor if *like* is not given.
    """
    if mask is not None:
        return mask.detach().bool().reshape(-1).cpu()
    if like is None:
        return torch.zeros(0, dtype=torch.bool)
    return torch.zeros_like(like)


def summarize_revisit_diagnostics(result_diagnostics: list[dict[str, Any]], valid_revisit_mask: torch.Tensor | None) -> dict[str, Any]:
    """Aggregate per-target revisit diagnostics into one flat summary dict.

    Each entry of *result_diagnostics* is treated as one target. Counters are
    read under their current key with fallbacks to alternate key names, then
    either summed or averaged per target. Overlap/gap value lists are reduced
    to ``*_mean/_min/_max`` triples.
    """
    target_count = len(result_diagnostics)

    def total(*keys: str) -> int:
        # Sum, over all targets, the first of *keys* each diagnostic provides.
        return sum(int(_first_present(diag, keys)) for diag in result_diagnostics)

    def per_target(value: int) -> float:
        return float(value / target_count) if target_count else 0.0

    candidate_count_mean = per_target(
        total("revisit_candidate_frame_count", "revisit_candidate_count", "candidate_count")
    )
    valid_candidate_label_count = total("valid_candidate_label_count", "valid_candidate_count")
    pose_preselect_input_count = total("revisit_pose_preselect_input_count")
    pose_preselect_selected_count = total("revisit_pose_preselect_selected_count")
    exact_fov_candidate_count = total("revisit_exact_fov_candidate_count")
    valid_count_mean = per_target(
        total("valid_revisit_frame_count", "valid_revisit_count", "valid_candidate_count")
    )
    valid_target_count = total("valid_revisit_target_count", "high_quality_selected_revisit")
    selected_count = total("revisit_selected_frame_count", "revisit_selected_count", "selected_count")
    no_valid_count = total("no_valid_revisit_count")
    # The fallback here is derived from the boolean "abstained" flag.
    abstained_count = sum(
        int(
            _first_present(
                diag,
                ("revisit_abstained_count",),
                int(bool(diag.get("abstained", False))),
            )
        )
        for diag in result_diagnostics
    )
    # Negative gaps are excluded from the minimum.
    selected_gaps = [
        int(diag["revisit_min_gap_to_target"])
        for diag in result_diagnostics
        if int(diag.get("revisit_min_gap_to_target", -1)) >= 0
    ]
    diagnostics: dict[str, Any] = {
        "revisit_candidate_frame_count": candidate_count_mean,
        "revisit_candidate_count": candidate_count_mean,
        "valid_candidate_label_count": int(valid_candidate_label_count),
        "revisit_pose_preselect_input_count": per_target(pose_preselect_input_count),
        "revisit_pose_preselect_selected_count": per_target(pose_preselect_selected_count),
        "revisit_exact_fov_candidate_count": per_target(exact_fov_candidate_count),
        "valid_revisit_frame_count": valid_count_mean,
        "valid_revisit_count": valid_count_mean,
        "valid_revisit_target_count": int(valid_target_count),
        "no_valid_revisit_count": int(no_valid_count),
        "valid_revisit_mask_fraction": tensor_valid_fraction(valid_revisit_mask),
        "revisit_selected_frame_count": int(selected_count),
        "revisit_selected_count": int(selected_count),
        "revisit_abstained_count": int(abstained_count),
        "revisit_min_gap_to_target": int(min(selected_gaps)) if selected_gaps else -1,
        "revisit_label_source": _REVISIT_LABEL_SOURCE,
    }
    # Frame-level FOV values fall back to the "fov_overlap_values" key and are
    # published under both prefixes.
    frame_fov_values = _collect_values(result_diagnostics, "frame_fov_overlap_values")
    if not frame_fov_values:
        frame_fov_values = _collect_values(result_diagnostics, "fov_overlap_values")
    diagnostics.update(_value_stats(frame_fov_values, "revisit_frame_fov_overlap"))
    diagnostics.update(_value_stats(frame_fov_values, "revisit_fov_overlap"))
    stat_specs = (
        ("plucker_overlap_values", "revisit_plucker_overlap"),
        ("best_selected_fov_overlap_values", "revisit_best_selected_fov_overlap"),
        ("best_selected_plucker_overlap_values", "revisit_best_selected_plucker_overlap"),
        ("best_selected_gap_frame_values", "revisit_best_selected_gap_frames"),
        ("best_selected_frame_fov_overlap_values", "revisit_best_selected_frame_fov_overlap"),
        ("selected_frame_fov_overlap_values", "revisit_selected_frame_fov_overlap"),
        ("selected_incremental_fov_overlap_values", "revisit_incremental_fov_overlap"),
    )
    for source_key, prefix in stat_specs:
        diagnostics.update(_value_stats(_collect_values(result_diagnostics, source_key), prefix))
    return diagnostics


def summarize_noise_bucket_diagnostics(
    *,
    noise_bucket: str | None,
    valid_revisit_mask: torch.Tensor | None,
    no_valid_revisit_mask: torch.Tensor | None,
    noise_bucket_ids: torch.Tensor | None = None,
) -> dict[str, Any]:
    """Count targets per noise bucket, overall and per revisit mask.

    The number of targets is the number of elements of *valid_revisit_mask*.
    Without *noise_bucket_ids*, every target is assigned the id of the
    normalized *noise_bucket*.

    Raises:
        ValueError: if *noise_bucket_ids* does not hold one id per target.
    """
    bucket = normalize_noise_bucket(noise_bucket)
    diagnostics: dict[str, Any] = {
        "noise_bucket": bucket,
        "noise_bucket_id": int(NOISE_BUCKET_TO_ID[bucket]),
    }
    for candidate in NOISE_BUCKETS:
        diagnostics[f"noise_bucket_is_{candidate}"] = int(bucket == candidate)

    valid = _flat_bool_mask(valid_revisit_mask)
    no_valid = _flat_bool_mask(no_valid_revisit_mask, like=valid)
    target_count = int(valid.numel())
    diagnostics["noise_bucket_target_count"] = target_count
    if noise_bucket_ids is None:
        target_bucket_ids = torch.full((target_count,), int(NOISE_BUCKET_TO_ID[bucket]), dtype=torch.long)
    else:
        target_bucket_ids = noise_bucket_ids.detach().long().reshape(-1).cpu()
    if int(target_bucket_ids.numel()) != target_count:
        raise ValueError(
            f"noise_bucket_ids has {int(target_bucket_ids.numel())} targets, expected {target_count}"
        )

    # Membership masks are computed once and reused by both loops below.
    bucket_masks = {
        bucket_name: target_bucket_ids == int(NOISE_BUCKET_TO_ID[bucket_name])
        for bucket_name in NOISE_BUCKETS
    }
    for bucket_name in NOISE_BUCKETS:
        diagnostics[f"noise_bucket_{bucket_name}_target_count"] = int(bucket_masks[bucket_name].long().sum().item())

    mask_specs = (
        ("valid_revisit", valid),
        ("no_valid_revisit", no_valid),
    )
    for mask_name, mask in mask_specs:
        for bucket_name in NOISE_BUCKETS:
            count = int((mask & bucket_masks[bucket_name]).long().sum().item()) if mask.numel() else 0
            diagnostics[f"{mask_name}_noise_bucket_{bucket_name}_count"] = count
    return diagnostics


def summarize_eval_ablation_diagnostics(
    *,
    enabled: bool,
    branch: str | None,
    valid_revisit_mask: torch.Tensor | None,
    no_valid_revisit_mask: torch.Tensor | None,
    eval_corrupted_revisit_mask: torch.Tensor | None,
) -> dict[str, Any]:
    """Report eval-ablation branch info plus per-bucket counts and fractions.

    A target counts as a "true revisit" when it is valid and not corrupted.
    Fractions are taken over the number of elements of *valid_revisit_mask*,
    with the denominator floored at 1.
    """
    branch = normalize_eval_ablation_branch(branch)
    valid = _flat_bool_mask(valid_revisit_mask)
    no_valid = _flat_bool_mask(no_valid_revisit_mask, like=valid)
    corrupted = _flat_bool_mask(eval_corrupted_revisit_mask, like=valid)
    true_revisit = valid & (~corrupted)
    diagnostics: dict[str, Any] = {
        "eval_ablation_enabled": bool(enabled),
        "eval_ablation_branch": branch,
        "eval_ablation_branch_id": int(EVAL_ABLATION_BRANCH_TO_ID[branch]),
        "eval_bucket_true_revisit_count": int(true_revisit.long().sum().item()),
        "eval_bucket_no_valid_revisit_count": int(no_valid.long().sum().item()),
        "eval_bucket_corrupted_memory_count": int(corrupted.long().sum().item()),
    }
    total = max(int(valid.numel()), 1)
    diagnostics["eval_bucket_true_revisit_fraction"] = float(diagnostics["eval_bucket_true_revisit_count"] / total)
    diagnostics["eval_bucket_no_valid_revisit_fraction"] = float(diagnostics["eval_bucket_no_valid_revisit_count"] / total)
    diagnostics["eval_bucket_corrupted_memory_fraction"] = float(diagnostics["eval_bucket_corrupted_memory_count"] / total)
    return diagnostics


# Patch metadata preserved verbatim from the collapsed diff:
# diff --git a/algorithms/worldmem/dememwm/gates.py b/algorithms/worldmem/dememwm/gates.py new file mode 100644 index
# Patch metadata preserved verbatim from the collapsed diff:
# 9cebdfdd509a4d40f0d9a1f6bcecfb76bfd834dc --- /dev/null +++ b/algorithms/worldmem/dememwm/gates.py @@ -0,0 +1,46 @@
from __future__ import annotations

import torch
from torch import nn


class RevisitRawGate(nn.Module):
    """Learned per-target revisit gate from selected-revisit quality features.

    The caller applies validity masking and stage/denoise scaling after this
    module. This module never turns selected revisit validity into a target.
    """

    def __init__(self, init_logit: float = 1.0):
        """Create the gate with zero weights and bias set to *init_logit*.

        With zero weights the initial output is ``sigmoid(init_logit)`` for
        every target, independent of the input features.
        """
        super().__init__()
        self.net = nn.Linear(3, 1)
        nn.init.zeros_(self.net.weight)
        nn.init.constant_(self.net.bias, float(init_logit))

    def forward(
        self,
        *,
        valid_revisit_mask: torch.Tensor,
        best_selected_fov_overlap: torch.Tensor | None = None,
        best_selected_plucker_overlap: torch.Tensor | None = None,
        selected_gap_frames: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return a sigmoid gate with the same shape as *valid_revisit_mask*.

        Args:
            valid_revisit_mask: 2-D tensor; only its shape and device are used.
            best_selected_fov_overlap: optional feature, clamped to [0, 1].
            best_selected_plucker_overlap: optional feature, clamped to [0, 1].
            selected_gap_frames: optional feature, mapped through
                ``log1p(max(x, 0))``, capped at 8 and divided by 8.

        Each feature may be ``None`` (treated as zeros), a scalar, or anything
        ``expand``-able to the mask shape.

        Raises:
            ValueError: if *valid_revisit_mask* is not 2-D.
        """
        if valid_revisit_mask.ndim != 2:
            raise ValueError("valid_revisit_mask must have shape (B,T)")
        device = valid_revisit_mask.device
        dtype = torch.float32
        shape = valid_revisit_mask.shape

        def _feature(value: torch.Tensor | None) -> torch.Tensor:
            if value is None:
                return torch.zeros(shape, device=device, dtype=dtype)
            # expand() broadcasts 0-dim tensors as well, so scalars need no
            # .item() round trip (which forces a device-to-host sync).
            # as_tensor also admits plain Python numbers.
            return torch.as_tensor(value, device=device, dtype=dtype).expand(shape)

        fov = _feature(best_selected_fov_overlap).clamp(min=0.0, max=1.0)
        plucker = _feature(best_selected_plucker_overlap).clamp(min=0.0, max=1.0)
        log_age = torch.log1p(_feature(selected_gap_frames).clamp_min(0.0)).clamp(max=8.0) / 8.0
        features = torch.stack([fov, plucker, log_age], dim=-1)
        return torch.sigmoid(self.net(features).squeeze(-1))


# Patch metadata preserved verbatim from the collapsed diff:
# diff --git a/algorithms/worldmem/dememwm/injection.py b/algorithms/worldmem/dememwm/injection.py new file mode 100644 index
0000000000000000000000000000000000000000..91524095036149ccbd48ed175af32de467c5439b --- /dev/null +++ b/algorithms/worldmem/dememwm/injection.py @@ -0,0 +1,83 @@ + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import torch + +from .diagnostics import summarize_stream +from .types import MemoryStreamTensors + + +@dataclass +class InjectionAdapter: + """Convert DeMemWM stream tensors to Diffusion/DiT Option-C kwargs.""" + + omit_disabled: bool = False + + def _tokens(self, name: str, tokens: torch.Tensor, device, dtype) -> torch.Tensor: + if tokens.ndim != 4: + raise ValueError(f"{name} tokens must have shape (B,T,M,D), got {tuple(tokens.shape)}") + return tokens.to(device=device, dtype=dtype) + + def _mask(self, name: str, mask: torch.Tensor, tokens: torch.Tensor, device) -> torch.Tensor: + if mask.shape != tokens.shape[:3]: + raise ValueError(f"{name} mask must have shape {tuple(tokens.shape[:3])}, got {tuple(mask.shape)}") + return mask.to(device=device, dtype=torch.bool) + + def _gate(self, gate: torch.Tensor | float | int, tokens: torch.Tensor, device, dtype): + if torch.is_tensor(gate): + return gate.to(device=device, dtype=dtype) + return torch.tensor(float(gate), device=device, dtype=dtype) + + def __call__(self, streams: MemoryStreamTensors, device=None, dtype=None) -> tuple[dict[str, Any], dict[str, Any]]: + ref = streams.anchor_tokens + device = device or ref.device + dtype = dtype or ref.dtype + anchor_tokens = self._tokens("anchor", streams.anchor_tokens, device, dtype) + dynamic_tokens = self._tokens("dynamic", streams.dynamic_tokens, device, dtype) + revisit_tokens = self._tokens("revisit", streams.revisit_tokens, device, dtype) + anchor_mask = self._mask("anchor", streams.anchor_mask, anchor_tokens, device) + dynamic_mask = self._mask("dynamic", streams.dynamic_mask, dynamic_tokens, device) + revisit_mask = self._mask("revisit", streams.revisit_mask, revisit_tokens, device) + kwargs = { + 
"memory_tokens": anchor_tokens, + "memory_token_mask": anchor_mask, + "memory_dynamic_tokens": dynamic_tokens, + "memory_dynamic_mask": dynamic_mask, + "memory_retrieval_tokens": revisit_tokens, + "memory_retrieval_mask": revisit_mask, + "memory_anchor_gate": self._gate(streams.anchor_gate, anchor_tokens, device, dtype), + "memory_dynamic_gate": self._gate(streams.dynamic_gate, dynamic_tokens, device, dtype), + "memory_retrieval_gate": self._gate(streams.revisit_gate, revisit_tokens, device, dtype), + } + if self.omit_disabled: + if not anchor_mask.any(): + kwargs["memory_tokens"] = None + kwargs["memory_token_mask"] = None + if not dynamic_mask.any(): + kwargs["memory_dynamic_tokens"] = None + kwargs["memory_dynamic_mask"] = None + if not revisit_mask.any(): + kwargs["memory_retrieval_tokens"] = None + kwargs["memory_retrieval_mask"] = None + diagnostics = dict(streams.diagnostics) + diagnostics.update(summarize_stream("anchor", anchor_tokens, anchor_mask, kwargs["memory_anchor_gate"])) + diagnostics.update(summarize_stream("dynamic", dynamic_tokens, dynamic_mask, kwargs["memory_dynamic_gate"])) + diagnostics.update(summarize_stream("revisit", revisit_tokens, revisit_mask, kwargs["memory_retrieval_gate"])) + if streams.revisit_gate_raw is not None: + raw_gate = streams.revisit_gate_raw.to(device=device, dtype=dtype) + diagnostics["revisit_gate_raw"] = raw_gate + diagnostics["revisit_gate_raw_mean"] = float(raw_gate.detach().float().mean().item()) if raw_gate.numel() else 0.0 + diagnostics["revisit_gate_raw_min"] = float(raw_gate.detach().float().min().item()) if raw_gate.numel() else 0.0 + diagnostics["revisit_gate_raw_max"] = float(raw_gate.detach().float().max().item()) if raw_gate.numel() else 0.0 + if streams.no_valid_revisit_mask is not None: + diagnostics["no_valid_revisit_mask"] = streams.no_valid_revisit_mask.to(device=device, dtype=torch.bool) + max_sources = [v for k, v in streams.diagnostics.items() if k.endswith("max_source_frame")] + if max_sources: + 
diagnostics["max_source_frame"] = max(int(torch.as_tensor(v).max().item()) for v in max_sources) + return kwargs, diagnostics + + +DeMemWMInjectionAdapter = InjectionAdapter diff --git a/algorithms/worldmem/dememwm/labels.py b/algorithms/worldmem/dememwm/labels.py new file mode 100644 index 0000000000000000000000000000000000000000..31ca536b92d0122c85f5d43044c34b52a97b8ab4 --- /dev/null +++ b/algorithms/worldmem/dememwm/labels.py @@ -0,0 +1,479 @@ +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Any, Optional + +import torch + +from .types import MemoryRecord + + +LABEL_SOURCE = "deterministic_fov_coverage_plucker" + + +@dataclass(frozen=True) +class RevisitCandidateLabel: + record: MemoryRecord + valid: bool + gap_valid: bool + gap_to_target: int + fov_overlap: Optional[float] + plucker_overlap: Optional[float] + primary_overlap: float + coverage_mask: Optional[torch.Tensor] + reject_reasons: tuple[str, ...] + best_frame_index: Optional[int] = None + best_frame_fov_overlap: Optional[float] = None + + @property + def sort_key(self) -> tuple[float, float, float, int, int, str]: + fov = 0.0 if self.fov_overlap is None else float(self.fov_overlap) + plucker = 0.0 if self.plucker_overlap is None else float(self.plucker_overlap) + return ( + -self.primary_overlap, + -fov, + -plucker, + self.gap_to_target, + int(self.record.source_start), + str(self.record.chunk_id or ""), + ) + + +def _as_float_tensor(value) -> Optional[torch.Tensor]: + if value is None: + return None + if torch.is_tensor(value): + if value.numel() == 0: + return None + return value.detach().float() + tensor = torch.as_tensor(value, dtype=torch.float32) + if tensor.numel() == 0: + return None + return tensor + + +def _pose_frames(value) -> Optional[torch.Tensor]: + tensor = _as_float_tensor(value) + if tensor is None or tensor.shape[-1] < 5: + return None + return tensor.reshape(-1, tensor.shape[-1])[:, :5] + + +def _angle_diff_degrees(value: 
torch.Tensor) -> torch.Tensor: + diff = value.abs() % 360.0 + return torch.where(diff > 180.0, 360.0 - diff, diff) + + +def _target_fov_points( + target_pose: torch.Tensor, + *, + fov_half_h: float, + fov_half_v: float, + yaw_samples: int, + pitch_samples: int, + depth_samples: int, + radius: float, +) -> torch.Tensor: + yaw_samples = max(1, int(yaw_samples)) + pitch_samples = max(1, int(pitch_samples)) + depth_samples = max(1, int(depth_samples)) + device = target_pose.device + dtype = target_pose.dtype + if yaw_samples == 1: + yaw_offsets = torch.zeros((1,), device=device, dtype=dtype) + else: + yaw_offsets = torch.linspace(-float(fov_half_h), float(fov_half_h), yaw_samples + 2, device=device, dtype=dtype)[1:-1] + if pitch_samples == 1: + pitch_offsets = torch.zeros((1,), device=device, dtype=dtype) + else: + pitch_offsets = torch.linspace(-float(fov_half_v), float(fov_half_v), pitch_samples + 2, device=device, dtype=dtype)[1:-1] + if depth_samples == 1: + depths = torch.full((1,), float(radius), device=device, dtype=dtype) + else: + depths = torch.linspace(float(radius) / float(depth_samples), float(radius), depth_samples, device=device, dtype=dtype) + depth_grid, pitch_grid, yaw_grid = torch.meshgrid(depths, pitch_offsets, yaw_offsets, indexing="ij") + pitch = torch.deg2rad(target_pose[3] + pitch_grid.reshape(-1)) + yaw = torch.deg2rad(target_pose[4] + yaw_grid.reshape(-1)) + depth = depth_grid.reshape(-1) + cos_pitch = torch.cos(pitch) + vectors = torch.stack( + [ + depth * cos_pitch * torch.sin(yaw), + depth * torch.sin(pitch), + depth * cos_pitch * torch.cos(yaw), + ], + dim=-1, + ) + return target_pose[:3].reshape(1, 3) + vectors + + +def _inside_fov_3d_hv( + points: torch.Tensor, + poses: torch.Tensor, + *, + fov_half_h: float, + fov_half_v: float, +) -> torch.Tensor: + vectors = points.unsqueeze(0) - poses[:, None, :3] + x = vectors[..., 0] + y = vectors[..., 1] + z = vectors[..., 2] + azimuth = torch.atan2(x, z) * (180.0 / math.pi) + elevation = 
torch.atan2(y, torch.sqrt(x.square() + z.square()).clamp_min(1e-8)) * (180.0 / math.pi) + diff_azimuth = _angle_diff_degrees(azimuth - poses[:, None, 4]) + diff_elevation = _angle_diff_degrees(elevation - poses[:, None, 3]) + return (diff_azimuth < float(fov_half_h)) & (diff_elevation < float(fov_half_v)) + + +def fov_coverage_overlap( + source_pose, + target_pose, + *, + fov_half_h: float = 105.0 / 2.0, + fov_half_v: float = 75.0 / 2.0, + yaw_samples: int = 25, + pitch_samples: int = 20, + depth_samples: int = 20, + radius: float = 30.0, +) -> tuple[Optional[float], Optional[torch.Tensor]]: + source_poses = _pose_frames(source_pose) + target_poses = _pose_frames(target_pose) + if source_poses is None or target_poses is None: + return None, None + target = target_poses[-1].to(device=source_poses.device, dtype=source_poses.dtype) + points = _target_fov_points( + target, + fov_half_h=fov_half_h, + fov_half_v=fov_half_v, + yaw_samples=yaw_samples, + pitch_samples=pitch_samples, + depth_samples=depth_samples, + radius=radius, + ) + coverage_mask = _inside_fov_3d_hv(points, source_poses, fov_half_h=fov_half_h, fov_half_v=fov_half_v).any(dim=0) + return float(coverage_mask.float().mean().item()), coverage_mask.detach() + + +def _rotation_from_pose(poses: torch.Tensor) -> torch.Tensor: + pitch = torch.deg2rad(poses[:, 3]) + yaw = torch.deg2rad(poses[:, 4]) + cos_pitch, sin_pitch = torch.cos(pitch), torch.sin(pitch) + cos_yaw, sin_yaw = torch.cos(yaw), torch.sin(yaw) + zeros = torch.zeros_like(pitch) + ones = torch.ones_like(pitch) + r_pitch = torch.stack( + [ + ones, zeros, zeros, + zeros, cos_pitch, -sin_pitch, + zeros, sin_pitch, cos_pitch, + ], + dim=-1, + ).reshape(-1, 3, 3) + r_yaw = torch.stack( + [ + cos_yaw, zeros, sin_yaw, + zeros, ones, zeros, + -sin_yaw, zeros, cos_yaw, + ], + dim=-1, + ).reshape(-1, 3, 3) + return torch.matmul(r_yaw, r_pitch) + + +def _plucker_descriptor( + poses: torch.Tensor, + *, + grid_h: int, + grid_w: int, + focal_length: float, +) -> 
torch.Tensor: + grid_h = max(1, int(grid_h)) + grid_w = max(1, int(grid_w)) + poses = poses.float() + device = poses.device + dtype = torch.float32 + ys, xs = torch.meshgrid( + torch.linspace(0, grid_h - 1, grid_h, device=device, dtype=dtype), + torch.linspace(0, grid_w - 1, grid_w, device=device, dtype=dtype), + indexing="ij", + ) + fx = float(focal_length) * float(grid_w) + fy = float(focal_length) * float(grid_h) + cx = 0.5 * float(grid_w) + cy = 0.5 * float(grid_h) + zs = torch.ones_like(xs) + dirs = torch.stack([-(xs + 0.5 - cx) / fx, -(ys + 0.5 - cy) / fy, zs], dim=-1) + dirs = dirs.reshape(-1, 3) + dirs = dirs / dirs.norm(dim=-1, keepdim=True).clamp_min(1e-8) + rotation = _rotation_from_pose(poses) + rays_d = torch.matmul(dirs.unsqueeze(0), rotation.transpose(1, 2)).float() + rays_o = poses[:, None, :3].expand_as(rays_d).float() + moments = torch.linalg.cross(rays_o, rays_d, dim=-1) + return torch.cat([moments, rays_d], dim=-1).reshape(poses.shape[0], -1) + + +def plucker_overlap( + source_pose, + target_pose, + *, + grid_h: int = 4, + grid_w: int = 4, + focal_length: float = 0.35, +) -> Optional[float]: + source_poses = _pose_frames(source_pose) + target_poses = _pose_frames(target_pose) + if source_poses is None or target_poses is None: + return None + target = target_poses[-1:].to(device=source_poses.device, dtype=source_poses.dtype) + source_desc = _plucker_descriptor(source_poses, grid_h=grid_h, grid_w=grid_w, focal_length=focal_length) + target_desc = _plucker_descriptor(target, grid_h=grid_h, grid_w=grid_w, focal_length=focal_length) + diff = source_desc - target_desc + distance = torch.linalg.vector_norm(diff, dim=-1) / math.sqrt(float(diff.shape[-1])) + best_distance = float(distance.min().item()) + return float(1.0 / (1.0 + max(best_distance, 0.0))) + + +def _passes_threshold(overlap: Optional[float], threshold: Optional[float]) -> bool: + if threshold is None or overlap is None: + return True + return float(overlap) >= float(threshold) + + +def 
_batched_pose_overlaps( + records: list[MemoryRecord], + *, + target_pose=None, + fov_half_h: float = 105.0 / 2.0, + fov_half_v: float = 75.0 / 2.0, + fov_yaw_samples: int = 25, + fov_pitch_samples: int = 20, + fov_depth_samples: int = 20, + fov_radius: float = 30.0, + plucker_grid_h: int = 4, + plucker_grid_w: int = 4, + plucker_focal_length: float = 0.35, +) -> tuple[ + list[Optional[float]], + list[Optional[float]], + list[Optional[torch.Tensor]], + list[Optional[int]], + list[Optional[float]], +]: + fov_overlaps: list[Optional[float]] = [None] * len(records) + plucker_overlaps: list[Optional[float]] = [None] * len(records) + coverage_masks: list[Optional[torch.Tensor]] = [None] * len(records) + best_frame_indices: list[Optional[int]] = [None] * len(records) + best_frame_fov_overlaps: list[Optional[float]] = [None] * len(records) + target_poses = _pose_frames(target_pose) + if not records or target_poses is None: + return fov_overlaps, plucker_overlaps, coverage_masks, best_frame_indices, best_frame_fov_overlaps + + pose_records: list[tuple[int, torch.Tensor, torch.Tensor | None]] = [] + device = None + for record_idx, record in enumerate(records): + source_poses = _pose_frames(record.pose) + if source_poses is None: + continue + if device is None: + device = source_poses.device + frame_values = record.frame_indices.detach().reshape(-1) + frame_values = frame_values if int(frame_values.numel()) == int(source_poses.shape[0]) else None + pose_records.append((record_idx, source_poses, frame_values)) + if not pose_records or device is None: + return fov_overlaps, plucker_overlaps, coverage_masks, best_frame_indices, best_frame_fov_overlaps + + source_pose_blocks = [poses.to(device=device, dtype=torch.float32) for _, poses, _ in pose_records] + source_poses = torch.cat(source_pose_blocks, dim=0) + record_ids = torch.cat( + [ + torch.full((poses.shape[0],), int(record_idx), device=device, dtype=torch.long) + for (record_idx, _, _), poses in zip(pose_records, 
source_pose_blocks) + ], + dim=0, + ) + source_frame_blocks = [ + torch.full((poses.shape[0],), -1, device=device, dtype=torch.long) + if frame_values is None + else frame_values.to(device=device, dtype=torch.long) + for (_, _, frame_values), poses in zip(pose_records, source_pose_blocks) + ] + source_frame_values = torch.cat(source_frame_blocks, dim=0) + pose_record_indices = [record_idx for record_idx, _, _ in pose_records] + target = target_poses[-1].to(device=device, dtype=source_poses.dtype) + + points = _target_fov_points( + target, + fov_half_h=fov_half_h, + fov_half_v=fov_half_v, + yaw_samples=fov_yaw_samples, + pitch_samples=fov_pitch_samples, + depth_samples=fov_depth_samples, + radius=fov_radius, + ) + inside = _inside_fov_3d_hv(points, source_poses, fov_half_h=fov_half_h, fov_half_v=fov_half_v) + per_frame_fov = inside.float().mean(dim=1) + + source_desc = _plucker_descriptor(source_poses, grid_h=plucker_grid_h, grid_w=plucker_grid_w, focal_length=plucker_focal_length) + target_desc = _plucker_descriptor( + target.reshape(1, -1), + grid_h=plucker_grid_h, + grid_w=plucker_grid_w, + focal_length=plucker_focal_length, + ) + diff = source_desc - target_desc + distance = torch.linalg.vector_norm(diff, dim=-1) / math.sqrt(float(diff.shape[-1])) + best_distance = torch.full((len(records),), float("inf"), device=device, dtype=distance.dtype) + best_distance.scatter_reduce_(0, record_ids, distance, reduce="amin", include_self=True) + plucker_values = (1.0 / (1.0 + best_distance.clamp_min(0.0))).detach().cpu().tolist() + + for record_idx in pose_record_indices: + rows = record_ids == int(record_idx) + if not rows.any(): + continue + record_fov = per_frame_fov[rows] + record_inside = inside[rows] + best_pose_row = int(torch.argmax(record_fov).item()) + best_score_value = float(record_fov[best_pose_row].item()) + fov_overlaps[record_idx] = best_score_value + plucker_overlaps[record_idx] = float(plucker_values[record_idx]) + coverage_masks[record_idx] = 
record_inside[best_pose_row].detach() + valid_rows = rows & (source_frame_values >= 0) + if valid_rows.any(): + frame_scores = per_frame_fov[valid_rows].detach().cpu().tolist() + frame_values = source_frame_values[valid_rows].detach().cpu().tolist() + best_score, best_frame = max( + ((float(score), int(frame)) for score, frame in zip(frame_scores, frame_values)), + key=lambda item: (item[0], item[1]), + ) + best_frame_indices[record_idx] = int(best_frame) + best_frame_fov_overlaps[record_idx] = float(best_score) + return fov_overlaps, plucker_overlaps, coverage_masks, best_frame_indices, best_frame_fov_overlaps + + +def make_revisit_candidate_labels( + records: list[MemoryRecord], + *, + target_frame: int, + exclude_local_context_frames: int, + target_pose=None, + fov_overlap_threshold: Optional[float] = 0.0, + plucker_weight: float = 0.1, + target_video_id: Any = None, + fov_half_h: float = 105.0 / 2.0, + fov_half_v: float = 75.0 / 2.0, + fov_yaw_samples: int = 25, + fov_pitch_samples: int = 20, + fov_depth_samples: int = 20, + fov_radius: float = 30.0, + plucker_grid_h: int = 4, + plucker_grid_w: int = 4, + plucker_focal_length: float = 0.35, +) -> list[RevisitCandidateLabel]: + del plucker_weight, target_video_id + target_frame = int(target_frame) + exclude_local_context_frames = int(exclude_local_context_frames) + + score_indices: list[int] = [] + max_source_frames: list[int] = [] + gap_flags: list[bool] = [] + for record_idx, record in enumerate(records): + max_source_frame = int(record.source_end) - 1 + gap_ok = max_source_frame < target_frame - exclude_local_context_frames + max_source_frames.append(max_source_frame) + gap_flags.append(gap_ok) + if gap_ok: + score_indices.append(record_idx) + + fov_overlaps: list[Optional[float]] = [None] * len(records) + plucker_overlaps: list[Optional[float]] = [None] * len(records) + coverage_masks: list[Optional[torch.Tensor]] = [None] * len(records) + best_frame_indices: list[Optional[int]] = [None] * len(records) + 
best_frame_fov_overlaps: list[Optional[float]] = [None] * len(records) + scored_records = [records[record_idx] for record_idx in score_indices] + scored_fov, scored_plucker, scored_masks, scored_best_frames, scored_best_frame_fov = _batched_pose_overlaps( + scored_records, + target_pose=target_pose, + fov_half_h=fov_half_h, + fov_half_v=fov_half_v, + fov_yaw_samples=fov_yaw_samples, + fov_pitch_samples=fov_pitch_samples, + fov_depth_samples=fov_depth_samples, + fov_radius=fov_radius, + plucker_grid_h=plucker_grid_h, + plucker_grid_w=plucker_grid_w, + plucker_focal_length=plucker_focal_length, + ) + for scored_idx, record_idx in enumerate(score_indices): + fov_overlaps[record_idx] = scored_fov[scored_idx] + plucker_overlaps[record_idx] = scored_plucker[scored_idx] + coverage_masks[record_idx] = scored_masks[scored_idx] + best_frame_indices[record_idx] = scored_best_frames[scored_idx] + best_frame_fov_overlaps[record_idx] = scored_best_frame_fov[scored_idx] + + labels: list[RevisitCandidateLabel] = [] + for record_idx, record in enumerate(records): + gap = target_frame - max_source_frames[record_idx] + fov_overlap = fov_overlaps[record_idx] + reasons: list[str] = [] + if not gap_flags[record_idx]: + reasons.append("inside_c_short") + if not _passes_threshold(fov_overlap, fov_overlap_threshold): + reasons.append("fov_overlap_below_threshold") + + fov_score = 0.0 if fov_overlap is None else float(fov_overlap) + labels.append( + RevisitCandidateLabel( + record=record, + valid=not reasons, + gap_valid=gap_flags[record_idx], + gap_to_target=gap, + fov_overlap=fov_overlap, + plucker_overlap=plucker_overlaps[record_idx], + primary_overlap=fov_score, + coverage_mask=coverage_masks[record_idx], + reject_reasons=tuple(reasons), + best_frame_index=best_frame_indices[record_idx], + best_frame_fov_overlap=best_frame_fov_overlaps[record_idx], + ) + ) + return labels + + +def make_revisit_candidate_label( + record: MemoryRecord, + *, + target_frame: int, + 
exclude_local_context_frames: int, + target_pose=None, + fov_overlap_threshold: Optional[float] = 0.0, + plucker_weight: float = 0.1, + target_video_id: Any = None, + fov_half_h: float = 105.0 / 2.0, + fov_half_v: float = 75.0 / 2.0, + fov_yaw_samples: int = 25, + fov_pitch_samples: int = 20, + fov_depth_samples: int = 20, + fov_radius: float = 30.0, + plucker_grid_h: int = 4, + plucker_grid_w: int = 4, + plucker_focal_length: float = 0.35, +) -> RevisitCandidateLabel: + return make_revisit_candidate_labels( + [record], + target_frame=target_frame, + exclude_local_context_frames=exclude_local_context_frames, + target_pose=target_pose, + fov_overlap_threshold=fov_overlap_threshold, + plucker_weight=plucker_weight, + target_video_id=target_video_id, + fov_half_h=fov_half_h, + fov_half_v=fov_half_v, + fov_yaw_samples=fov_yaw_samples, + fov_pitch_samples=fov_pitch_samples, + fov_depth_samples=fov_depth_samples, + fov_radius=fov_radius, + plucker_grid_h=plucker_grid_h, + plucker_grid_w=plucker_grid_w, + plucker_focal_length=plucker_focal_length, + )[0] diff --git a/algorithms/worldmem/dememwm/memory.py b/algorithms/worldmem/dememwm/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..d54aeb68f35d6ee16f998c767082d879135d37d9 --- /dev/null +++ b/algorithms/worldmem/dememwm/memory.py @@ -0,0 +1,208 @@ + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable, Optional + +import torch + +from .types import MemoryRecord, MemorySourceType + + +@dataclass +class MemoryBankQuery: + target_frame: int + source_type: Optional[MemorySourceType] = None + include_generated: bool = True + max_records: Optional[int] = None + max_slots: Optional[int] = None + + +class CausalMemoryBank: + """Small causal memory bank for DeMemWM records.""" + + def __init__(self, max_records: Optional[int] = None, max_slots: Optional[int] = None): + self.max_records = max_records + self.max_slots = max_slots + self._records: 
def add_record(self, record: MemoryRecord) -> None:
    """Append ``record`` and evict the oldest records beyond ``max_records``.

    Args:
        record: Memory item to store; it is appended after existing records.

    Raises:
        ValueError: If ``record`` is flagged as generated while its source
            type is ``MemorySourceType.PREFIX_GT``.
    """
    # Test the plain boolean first; the enum comparison only matters for
    # generated records.
    if record.is_generated and record.source_type == MemorySourceType.PREFIX_GT:
        raise ValueError("generated records cannot be high-trust prefix anchors")
    self._records.append(record)
    if self.max_records is None:
        return
    overflow = len(self._records) - self.max_records
    if overflow > 0:
        # Drop from the front by count. A negative-index slice such as
        # ``records[-max_records:]`` silently keeps everything when
        # ``max_records`` is 0, because ``-0 == 0``.
        del self._records[:overflow]
def add_frame_record(
    self,
    tokens: torch.Tensor,
    mask: torch.Tensor,
    frame_index: torch.Tensor | int,
    pose: Optional[torch.Tensor] = None,
    source_type: MemorySourceType = MemorySourceType.REVISIT,
    is_generated: bool = False,
    record_id: Optional[str] = None,
    metadata: Optional[dict] = None,
) -> None:
    """Store one record that covers exactly one source frame.

    Args:
        tokens: Token tensor stored on the record as-is.
        mask: Slot mask; converted to bool before storing.
        frame_index: Frame number as an int or tensor. Only the first
            element (after flattening) is used.
        pose: Optional pose stored on the record as-is.
        source_type: Source label of the record.
        is_generated: Whether the record is marked as generated.
        record_id: Explicit ``chunk_id``; when falsy an id of the form
            ``"<source_type.value>_frame_<frame>"`` is used.
        metadata: Optional mapping; a shallow copy is stored.
    """
    # Resolve the frame number to a plain int once, then rebuild the
    # one-element index tensor on the token device.
    frame = int(torch.as_tensor(frame_index).reshape(-1)[0].item())
    frame_tensor = torch.as_tensor([frame], device=tokens.device)

    if record_id:
        chunk_id = record_id
    else:
        chunk_id = f"{source_type.value}_frame_{frame}"

    new_record = MemoryRecord(
        tokens=tokens,
        mask=mask.bool(),
        source_start=frame,
        source_end=frame + 1,  # half-open range [frame, frame + 1)
        frame_indices=frame_tensor,
        pose=pose,
        source_type=source_type,
        is_generated=bool(is_generated),
        chunk_id=chunk_id,
        metadata=dict(metadata or {}),
    )
    self.add_record(new_record)
def query(self, query: MemoryBankQuery | int, **kwargs) -> list[MemoryRecord]:
    """Return stored records usable for ``query.target_frame``, in insertion order.

    A record is eligible when its ``source_end`` does not exceed the target
    frame, its source type matches ``query.source_type`` (if given), and it is
    not generated when ``query.include_generated`` is false.

    Args:
        query: A ``MemoryBankQuery``, or a bare target frame number.
        **kwargs: Extra ``MemoryBankQuery`` fields; only used when ``query``
            is an int.

    Returns:
        Eligible records, truncated by ``max_records`` and by ``max_slots``.
        The slot budget is checked before each append, so the last returned
        record may push the running total of ``valid_slots`` past the budget.
    """
    if isinstance(query, int):
        query = MemoryBankQuery(target_frame=query, **kwargs)

    target_frame = int(query.target_frame)
    max_records = query.max_records
    max_slots = query.max_slots

    out: list[MemoryRecord] = []
    # A non-positive record cap means "return nothing". Checking the cap only
    # after appending would leak one record for max_records=0, unlike
    # max_slots=0 which already yields an empty result.
    if max_records is not None and max_records <= 0:
        return out

    used_slots = 0
    for record in self._records:
        if int(record.source_end) > target_frame:
            continue  # not causal for this target
        if query.source_type is not None and record.source_type != query.source_type:
            continue
        if record.is_generated and not query.include_generated:
            continue
        if max_slots is not None and used_slots >= max_slots:
            break
        out.append(record)
        if max_slots is not None:
            used_slots += record.valid_slots
            if used_slots >= max_slots:
                break
        if max_records is not None and len(out) >= max_records:
            break
    return out
b/algorithms/worldmem/dememwm/negatives.py new file mode 100644 index 0000000000000000000000000000000000000000..704641ddf14632b302275eeaed907cc9ff32f92b --- /dev/null +++ b/algorithms/worldmem/dememwm/negatives.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import torch + +from .schedules import EVAL_CORRUPTION_BRANCHES + + +def _deterministic_noise_like(tokens: torch.Tensor, seed: int) -> torch.Tensor: + flat = torch.arange(tokens.numel(), device=tokens.device, dtype=torch.float32).reshape(tokens.shape) + noise = torch.sin(flat + float(seed) * 0.137).to(dtype=tokens.dtype) + scale = tokens.detach().float().std().to(device=tokens.device).clamp_min(0.05).to(dtype=tokens.dtype) + return noise * scale + + +def apply_revisit_eval_corruption( + *, + tokens: torch.Tensor, + mask: torch.Tensor, + branch: str, + target_frame: int, +) -> tuple[torch.Tensor, bool]: + if branch not in EVAL_CORRUPTION_BRANCHES or not mask.any(): + return tokens, False + + corrupted = tokens.clone() + if branch == "wrong_pose": + corrupted = -corrupted + elif branch == "time_shuffle": + corrupted = torch.flip(corrupted, dims=(0,)) + elif branch == "source_matched_random": + corrupted = _deterministic_noise_like(corrupted, seed=int(target_frame)) + elif branch == "local_context_overlap_fake_revisit": + corrupted = torch.roll(corrupted, shifts=1, dims=0) + elif branch == "pose_shuffle": + corrupted = torch.roll(corrupted, shifts=1, dims=-1) + elif branch == "wrong_video": + corrupted = corrupted.detach().mean().to(dtype=corrupted.dtype).expand_as(corrupted).clone() + else: + return tokens, False + + return corrupted, True diff --git a/algorithms/worldmem/dememwm/retrieval.py b/algorithms/worldmem/dememwm/retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..8c545ff7f6088c41137e07eba46078f0fe1993a3 --- /dev/null +++ b/algorithms/worldmem/dememwm/retrieval.py @@ -0,0 +1,476 @@ +from __future__ import annotations + +import math +from dataclasses import replace 
+from typing import Any, Optional + +import torch + +from .labels import ( + LABEL_SOURCE, + RevisitCandidateLabel, + _inside_fov_3d_hv, + _plucker_descriptor, + _target_fov_points, +) +from .types import MemoryRecord, RevisitRetrievalResult + + +def _overlap_values(labels, name: str) -> list[float]: + values: list[float] = [] + for label in labels: + value = getattr(label, name) + if value is not None: + values.append(float(value)) + return values + + +def _overlap_stats(values: list[float], prefix: str) -> dict[str, float]: + if not values: + return {f"{prefix}_mean": 0.0, f"{prefix}_min": 0.0, f"{prefix}_max": 0.0} + return { + f"{prefix}_mean": float(sum(values) / len(values)), + f"{prefix}_min": float(min(values)), + f"{prefix}_max": float(max(values)), + } + + +def _pose_rows(pose) -> torch.Tensor | None: + if pose is None: + return None + pose_tensor = pose if torch.is_tensor(pose) else torch.as_tensor(pose, dtype=torch.float32) + if pose_tensor.ndim == 0 or pose_tensor.numel() == 0 or pose_tensor.shape[-1] < 5: + return None + return pose_tensor.detach().reshape(-1, pose_tensor.shape[-1])[:, :5].to(dtype=torch.float32) + + +def _pose_forward(poses: torch.Tensor) -> torch.Tensor: + pitch = torch.deg2rad(poses[:, 3]) + yaw = torch.deg2rad(poses[:, 4]) + cos_pitch = torch.cos(pitch) + return torch.stack( + [ + cos_pitch * torch.sin(yaw), + torch.sin(pitch), + cos_pitch * torch.cos(yaw), + ], + dim=-1, + ) + + +def _single_frame_pose(record: MemoryRecord) -> torch.Tensor | None: + if int(record.frame_indices.numel()) != 1: + return None + pose_rows = _pose_rows(record.pose) + if pose_rows is None or int(pose_rows.shape[0]) != 1: + return None + return pose_rows[0] + + +def _vectorized_frame_candidate_labels( + records: list[MemoryRecord], + *, + target_frame: int, + target_pose, + fov_overlap_threshold: Optional[float], + fov_half_h: float, + fov_half_v: float, + fov_yaw_samples: int, + fov_pitch_samples: int, + fov_depth_samples: int, + fov_radius: float, + 
plucker_grid_h: int, + plucker_grid_w: int, + plucker_focal_length: float, + pose_preselect_topk: Optional[int], +) -> tuple[list[RevisitCandidateLabel], dict[str, float | int]]: + diagnostics: dict[str, float | int] = { + "revisit_pose_preselect_input_count": len(records), + "revisit_pose_preselect_scored_count": len(records), + "revisit_pose_preselect_unscored_count": 0, + "revisit_pose_preselect_selected_count": len(records), + "revisit_pose_preselect_min_distance": 0.0, + "revisit_pose_preselect_max_distance": 0.0, + "revisit_exact_fov_candidate_count": len(records), + "revisit_vectorized_frame_scorer_used": 1, + } + if not records: + return [], diagnostics + + target_poses = _pose_rows(target_pose) + if target_poses is None: + raise ValueError("DeMemWM revisit retrieval requires target_pose for frame-level FoV scoring") + + pose_rows: list[torch.Tensor] = [] + for record in records: + pose_row = _single_frame_pose(record) + if pose_row is None: + raise ValueError( + "DeMemWM revisit retrieval requires frame-level records with exactly one frame index and one pose row" + ) + pose_rows.append(pose_row) + + device = pose_rows[0].device + if target_poses.is_cuda: + device = target_poses.device + source_poses = torch.stack([row.to(device=device, dtype=torch.float32) for row in pose_rows], dim=0) + target = target_poses[-1].to(device=device, dtype=torch.float32) + + selected_indices = list(range(len(records))) + topk = None if pose_preselect_topk is None else int(pose_preselect_topk) + if topk is not None and topk > 0 and len(records) > topk: + translation_norm = torch.linalg.vector_norm(source_poses[:, :3] - target[:3], dim=-1) / max(float(fov_radius), 1e-6) + source_forward = _pose_forward(source_poses) + target_forward = _pose_forward(target.reshape(1, -1)).squeeze(0) + dot = (source_forward * target_forward.reshape(1, 3)).sum(dim=-1).clamp(-1.0, 1.0) + distances = translation_norm + (torch.acos(dot) / math.pi) + distance_values = [float(value) for value in 
distances.detach().cpu().tolist()] + ranked = [ + ( + distance_values[idx], + -int(record.max_source_frame), + int(record.source_start), + str(record.chunk_id or ""), + idx, + ) + for idx, record in enumerate(records) + ] + ranked.sort() + selected_indices = [idx for *_, idx in ranked[:topk]] + diagnostics["revisit_pose_preselect_selected_count"] = len(selected_indices) + diagnostics["revisit_pose_preselect_min_distance"] = float(min(distance_values)) + diagnostics["revisit_pose_preselect_max_distance"] = float(max(distance_values)) + + selected_tensor = torch.tensor(selected_indices, device=device, dtype=torch.long) + selected_records = [records[idx] for idx in selected_indices] + selected_poses = source_poses.index_select(0, selected_tensor) + points = _target_fov_points( + target, + fov_half_h=fov_half_h, + fov_half_v=fov_half_v, + yaw_samples=fov_yaw_samples, + pitch_samples=fov_pitch_samples, + depth_samples=fov_depth_samples, + radius=fov_radius, + ) + inside = _inside_fov_3d_hv(points, selected_poses, fov_half_h=fov_half_h, fov_half_v=fov_half_v) + fov_values = inside.float().mean(dim=1) + + source_desc = _plucker_descriptor( + selected_poses, + grid_h=plucker_grid_h, + grid_w=plucker_grid_w, + focal_length=plucker_focal_length, + ) + target_desc = _plucker_descriptor( + target.reshape(1, -1), + grid_h=plucker_grid_h, + grid_w=plucker_grid_w, + focal_length=plucker_focal_length, + ) + diff = source_desc - target_desc + distances = torch.linalg.vector_norm(diff, dim=-1) / math.sqrt(float(diff.shape[-1])) + plucker_values = 1.0 / (1.0 + distances.clamp_min(0.0)) + valid_mask = torch.ones_like(fov_values, dtype=torch.bool) + if fov_overlap_threshold is not None: + valid_mask = fov_values >= float(fov_overlap_threshold) + + diagnostics["revisit_exact_fov_candidate_count"] = len(selected_records) + fov_list = [float(value) for value in fov_values.detach().cpu().tolist()] + plucker_list = [float(value) for value in plucker_values.detach().cpu().tolist()] + 
valid_list = [bool(value) for value in valid_mask.detach().cpu().tolist()] + + labels: list[RevisitCandidateLabel] = [] + for row_idx, record in enumerate(selected_records): + fov_overlap = fov_list[row_idx] + reasons = () if valid_list[row_idx] else ("fov_overlap_below_threshold",) + gap_to_target = int(target_frame) - (int(record.source_end) - 1) + labels.append( + RevisitCandidateLabel( + record=record, + valid=valid_list[row_idx], + gap_valid=True, + gap_to_target=gap_to_target, + fov_overlap=fov_overlap, + plucker_overlap=plucker_list[row_idx], + primary_overlap=fov_overlap, + coverage_mask=inside[row_idx].detach(), + reject_reasons=reasons, + best_frame_index=int(record.max_source_frame), + best_frame_fov_overlap=fov_overlap, + ) + ) + return labels, diagnostics + + +def _coverage_gain(label: RevisitCandidateLabel, covered_mask: torch.Tensor | None) -> float: + mask = label.coverage_mask + if mask is None or mask.numel() == 0: + return 0.0 if label.fov_overlap is None else float(label.fov_overlap) + mask = mask.detach().bool() + if covered_mask is None or covered_mask.shape != mask.shape: + return float(mask.float().mean().item()) + return float((mask & ~covered_mask.to(device=mask.device, dtype=torch.bool)).float().mean().item()) + + +def _coverage_gains(labels: list[RevisitCandidateLabel], covered_mask: torch.Tensor | None) -> list[float]: + masks = [label.coverage_mask for label in labels] + valid_masks = [mask for mask in masks if mask is not None and mask.numel() > 0] + if not valid_masks: + return [0.0 if label.fov_overlap is None else float(label.fov_overlap) for label in labels] + + shape = valid_masks[0].shape + device = valid_masks[0].device + if any(mask.shape != shape for mask in valid_masks): + return [_coverage_gain(label, covered_mask) for label in labels] + + stacked = torch.stack([ + torch.zeros(shape, device=device, dtype=torch.bool) + if mask is None or mask.numel() == 0 + else mask.detach().to(device=device, dtype=torch.bool) + for mask in 
masks + ]) + if covered_mask is None or covered_mask.shape != shape: + gains = stacked.float().mean(dim=1) + else: + covered = covered_mask.to(device=device, dtype=torch.bool) + gains = (stacked & ~covered).float().mean(dim=1) + return [float(value) for value in gains.detach().cpu().tolist()] + + +def _select_greedy_coverage( + labels: list[RevisitCandidateLabel], + *, + topk: int, + plucker_weight: float, +) -> tuple[list[RevisitCandidateLabel], list[float], list[float]]: + remaining = list(labels) + selected: list[RevisitCandidateLabel] = [] + selected_scores: list[float] = [] + selected_gains: list[float] = [] + covered_mask: torch.Tensor | None = None + for _ in range(max(0, int(topk))): + if not remaining: + break + gains = _coverage_gains(remaining, covered_mask) + ranked = [] + for idx, (label, gain) in enumerate(zip(remaining, gains)): + plucker = 0.0 if label.plucker_overlap is None else float(label.plucker_overlap) + fov = 0.0 if label.fov_overlap is None else float(label.fov_overlap) + plucker_secondary = float(plucker_weight) * plucker + ranked.append(( + -gain, + -fov, + -plucker_secondary, + label.gap_to_target, + int(label.record.source_start), + str(label.record.chunk_id or ""), + idx, + gain, + )) + ranked.sort() + _, _, _, _, _, _, best_idx, best_gain = ranked[0] + label = remaining.pop(best_idx) + selected.append(label) + selected_scores.append(float(best_gain)) + selected_gains.append(float(best_gain)) + if label.coverage_mask is not None and label.coverage_mask.numel() > 0: + mask = label.coverage_mask.detach().bool() + if covered_mask is None or covered_mask.shape != mask.shape: + covered_mask = torch.zeros_like(mask, dtype=torch.bool) + covered_mask = covered_mask.to(device=mask.device, dtype=torch.bool) | mask + return selected, selected_scores, selected_gains + + +def _best_selected_label(labels: list[RevisitCandidateLabel]) -> RevisitCandidateLabel | None: + if not labels: + return None + return max( + labels, + key=lambda label: ( + 0.0 
if label.fov_overlap is None else float(label.fov_overlap), + 0.0 if label.plucker_overlap is None else float(label.plucker_overlap), + -int(label.gap_to_target), + -int(label.record.source_start), + str(label.record.chunk_id or ""), + ), + ) + + +def _best_selected_frame_label(labels: list[RevisitCandidateLabel]) -> RevisitCandidateLabel | None: + frame_labels = [label for label in labels if label.best_frame_fov_overlap is not None] + if not frame_labels: + return None + return max( + frame_labels, + key=lambda label: ( + float(label.best_frame_fov_overlap), + 0.0 if label.fov_overlap is None else float(label.fov_overlap), + 0.0 if label.plucker_overlap is None else float(label.plucker_overlap), + -int(label.gap_to_target), + -int(label.record.source_start), + str(label.record.chunk_id or ""), + ), + ) + + +def _record_with_selected_frame_metadata( + label: RevisitCandidateLabel, + *, + high_quality_fov_threshold: float, +) -> MemoryRecord: + metadata = dict(label.record.metadata or {}) + if label.fov_overlap is not None: + metadata["dememwm_selected_revisit_fov_overlap"] = float(label.fov_overlap) + if label.plucker_overlap is not None: + metadata["dememwm_selected_revisit_plucker_overlap"] = float(label.plucker_overlap) + if label.best_frame_index is not None: + metadata["dememwm_selected_frame_index"] = int(label.best_frame_index) + if label.best_frame_fov_overlap is not None: + frame_fov = float(label.best_frame_fov_overlap) + metadata["dememwm_selected_frame_fov_overlap"] = frame_fov + metadata["dememwm_selected_frame_fov_threshold"] = float(high_quality_fov_threshold) + metadata["dememwm_selected_frame_passes_high_quality"] = bool(frame_fov >= float(high_quality_fov_threshold)) + return replace(label.record, metadata=metadata) + + +def deterministic_revisit_retrieval( + records: list[MemoryRecord], + target_frame: int, + target_pose: Optional[torch.Tensor] = None, + target_summary: Optional[torch.Tensor] = None, + topk: int = 2, + 
exclude_local_context_frames: int = 0, + fov_overlap_threshold: Optional[float] = 0.30, + high_quality_fov_threshold: float = 0.70, + plucker_weight: float = 0.1, + target_video_id: Any = None, + fov_half_h: float = 105.0 / 2.0, + fov_half_v: float = 75.0 / 2.0, + fov_yaw_samples: int = 25, + fov_pitch_samples: int = 20, + fov_depth_samples: int = 20, + fov_radius: float = 30.0, + plucker_grid_h: int = 4, + plucker_grid_w: int = 4, + plucker_focal_length: float = 0.35, + pose_preselect_topk: Optional[int] = 64, + **_legacy_scoring_kwargs, +) -> RevisitRetrievalResult: + del target_summary, target_video_id + topk = max(0, int(topk)) + target_frame = int(target_frame) + exclude_local_context_frames = int(exclude_local_context_frames) + causal_records = [record for record in records if int(record.source_end) <= target_frame] + score_records = [ + record + for record in causal_records + if int(record.source_end) <= target_frame - exclude_local_context_frames + ] + labels, pose_preselect_diagnostics = _vectorized_frame_candidate_labels( + score_records, + target_frame=target_frame, + target_pose=target_pose, + fov_overlap_threshold=fov_overlap_threshold, + fov_half_h=fov_half_h, + fov_half_v=fov_half_v, + fov_yaw_samples=fov_yaw_samples, + fov_pitch_samples=fov_pitch_samples, + fov_depth_samples=fov_depth_samples, + fov_radius=fov_radius, + plucker_grid_h=plucker_grid_h, + plucker_grid_w=plucker_grid_w, + plucker_focal_length=plucker_focal_length, + pose_preselect_topk=pose_preselect_topk, + ) + exact_fov_candidate_count = int(pose_preselect_diagnostics["revisit_exact_fov_candidate_count"]) + valid_labels = [label for label in labels if label.valid] + selected_labels, selected_scores, selected_gains = _select_greedy_coverage( + valid_labels, + topk=topk, + plucker_weight=float(plucker_weight), + ) + best_selected = _best_selected_label(selected_labels) + best_selected_frame = _best_selected_frame_label(selected_labels) + best_selected_fov = 0.0 if best_selected is None 
or best_selected.fov_overlap is None else float(best_selected.fov_overlap) + best_selected_plucker = 0.0 if best_selected is None or best_selected.plucker_overlap is None else float(best_selected.plucker_overlap) + best_selected_gap = -1 if best_selected is None else int(best_selected.gap_to_target) + best_selected_frame_fov = 0.0 if best_selected_frame is None else float(best_selected_frame.best_frame_fov_overlap) + best_selected_frame_index = -1 if best_selected_frame is None or best_selected_frame.best_frame_index is None else int(best_selected_frame.best_frame_index) + high_quality_selected = int(best_selected_frame is not None and best_selected_frame_fov >= float(high_quality_fov_threshold)) + selected_records = [ + _record_with_selected_frame_metadata(label, high_quality_fov_threshold=float(high_quality_fov_threshold)) + for label in selected_labels + ] + score_device = selected_records[0].tokens.device if selected_records else torch.device("cpu") + scores = torch.tensor(selected_scores, dtype=torch.float32, device=score_device) + + fov_values = _overlap_values(valid_labels, "fov_overlap") + plucker_values = _overlap_values(valid_labels, "plucker_overlap") + selected_gaps = [label.gap_to_target for label in selected_labels] + selected_frame_fov_values = [ + float(label.best_frame_fov_overlap) + for label in selected_labels + if label.best_frame_fov_overlap is not None + ] + diagnostics = { + "target_frame": int(target_frame), + "candidate_count": len(causal_records), + "candidate_frame_count": len(causal_records), + "valid_candidate_count": len(valid_labels), + "revisit_exact_fov_candidate_count": exact_fov_candidate_count, + "valid_candidate_frame_count": len(valid_labels), + "valid_candidate_label_count": len(valid_labels), + "selected_count": len(selected_records), + "selected_frame_count": len(selected_records), + "revisit_candidate_frame_count": len(causal_records), + "revisit_candidate_count": len(causal_records), + "valid_revisit_frame_count": 
len(valid_labels), + "valid_revisit_count": len(valid_labels), + "valid_revisit_target_count": high_quality_selected, + "no_valid_revisit_count": int(len(valid_labels) == 0), + "valid_revisit_mask": int(len(valid_labels) > 0), + "revisit_abstained_count": int(len(selected_records) == 0), + "abstained": bool(len(selected_records) == 0), + "revisit_selected_frame_count": len(selected_records), + "revisit_selected_count": len(selected_records), + "revisit_min_gap_to_target": int(min(selected_gaps)) if selected_gaps else -1, + "best_selected_fov_overlap": best_selected_fov, + "best_selected_plucker_overlap": best_selected_plucker, + "best_selected_gap_frames": best_selected_gap, + "best_selected_frame_index": best_selected_frame_index, + "best_selected_frame_fov_overlap": best_selected_frame_fov, + "best_selected_frame_passes_high_quality": high_quality_selected, + "high_quality_selected_revisit": high_quality_selected, + "high_quality_fov_threshold": float(high_quality_fov_threshold), + "revisit_label_source": LABEL_SOURCE, + "selected_frame_ids": [int(record.max_source_frame) for record in selected_records], + "selected_frame_record_ids": [record.chunk_id for record in selected_records], + "selected_ranges": [(record.source_start, record.source_end) for record in selected_records], + "frame_fov_overlap_values": fov_values, + "fov_overlap_values": fov_values, + "plucker_overlap_values": plucker_values, + "best_selected_fov_overlap_values": [] if best_selected is None else [best_selected_fov], + "best_selected_plucker_overlap_values": [] if best_selected is None else [best_selected_plucker], + "best_selected_gap_frame_values": [] if best_selected is None else [best_selected_gap], + "best_selected_frame_fov_overlap_values": [] if best_selected_frame is None else [best_selected_frame_fov], + "selected_frame_fov_overlap_values": selected_frame_fov_values, + "selected_incremental_fov_overlap_values": selected_gains, + "selected_revisit_scores": selected_scores, + 
**pose_preselect_diagnostics, + } + diagnostics.update(_overlap_stats(fov_values, "revisit_frame_fov_overlap")) + diagnostics.update(_overlap_stats(fov_values, "revisit_fov_overlap")) + diagnostics.update(_overlap_stats(plucker_values, "revisit_plucker_overlap")) + diagnostics.update(_overlap_stats(diagnostics["best_selected_fov_overlap_values"], "revisit_best_selected_fov_overlap")) + diagnostics.update(_overlap_stats(diagnostics["best_selected_plucker_overlap_values"], "revisit_best_selected_plucker_overlap")) + diagnostics.update(_overlap_stats(diagnostics["best_selected_gap_frame_values"], "revisit_best_selected_gap_frames")) + diagnostics.update(_overlap_stats(diagnostics["best_selected_frame_fov_overlap_values"], "revisit_best_selected_frame_fov_overlap")) + diagnostics.update(_overlap_stats(selected_frame_fov_values, "revisit_selected_frame_fov_overlap")) + diagnostics.update(_overlap_stats(selected_gains, "revisit_incremental_fov_overlap")) + return RevisitRetrievalResult( + records=selected_records, + scores=scores, + selected_frame_ids=[int(record.max_source_frame) for record in selected_records], + diagnostics=diagnostics, + ) diff --git a/algorithms/worldmem/dememwm/schedules.py b/algorithms/worldmem/dememwm/schedules.py new file mode 100644 index 0000000000000000000000000000000000000000..3ac007904468e5a3a855d2da96e1441ed9b26847 --- /dev/null +++ b/algorithms/worldmem/dememwm/schedules.py @@ -0,0 +1,223 @@ + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import torch + +from .types import StreamGateState + +NOISE_BUCKETS = ("high", "mid", "low") +NOISE_BUCKET_TO_ID = {name: idx for idx, name in enumerate(NOISE_BUCKETS)} +EVAL_ABLATION_BRANCHES = ( + "memory_off", + "A_only", + "D_only", + "A_plus_D", + "A_plus_D_plus_R_normal", + "R_forced_off", + "R_forced_on", + "wrong_pose", + "time_shuffle", + "source_matched_random", + "pose_shuffle", + "wrong_video", + "local_context_overlap_fake_revisit", +) 
def noise_bucket_from_denoising_fraction(denoising_fraction: float | None) -> str:
    """Map a denoising fraction to the bucket name ``"high"``, ``"mid"`` or ``"low"``.

    The fraction is clamped to ``[0, 1]``. Values below one third map to
    ``"high"``, values below two thirds to ``"mid"``, and everything else to
    ``"low"``. ``None`` maps to ``"mid"``.
    """
    if denoising_fraction is None:
        return "mid"

    # Clamp to the unit interval before bucketing.
    clamped = max(0.0, min(1.0, float(denoising_fraction)))
    for upper_bound, bucket in ((1.0 / 3.0, "high"), (2.0 / 3.0, "mid")):
        if clamped < upper_bound:
            return bucket
    return "low"
@dataclass(frozen=True)
class CurriculumState:
    """Immutable snapshot of the DeMemWM curriculum and freezing state at one step."""

    global_step: int
    enabled: bool
    stage: str
    # Per-stream switches.
    anchor_enabled: bool
    dynamic_enabled: bool
    revisit_enabled: bool
    # ``dit_full_trainable`` is true exactly when this equals "full".
    dit_train_state: str
    freeze_vae: bool
    # Learning rates, reported by ``diagnostics`` under the ``lr_*`` keys.
    dememwm_lr: float
    memory_adapter_lr: float
    full_dit_lr: float

    @property
    def dit_full_trainable(self) -> bool:
        """Whether ``dit_train_state`` is ``"full"``."""
        return self.dit_train_state == "full"

    def diagnostics(self) -> dict[str, Any]:
        """Return the state as a flat dict keyed by diagnostic names."""
        return dict(
            dememwm_global_step=self.global_step,
            dememwm_curriculum_enabled=self.enabled,
            dememwm_stage=self.stage,
            curriculum_anchor_enabled=self.anchor_enabled,
            curriculum_dynamic_enabled=self.dynamic_enabled,
            curriculum_revisit_enabled=self.revisit_enabled,
            dit_train_state=self.dit_train_state,
            dit_full_trainable=self.dit_full_trainable,
            freeze_vae=self.freeze_vae,
            lr_dememwm_modules=self.dememwm_lr,
            lr_memory_adapters=self.memory_adapter_lr,
            lr_full_dit=self.full_dit_lr,
        )
def resolve_curriculum(memory_cfg: Any, global_step: int | None = None) -> CurriculumState:
    """Resolve the internal DeMemWM curriculum phase for ``global_step``.

    Intended for a single continuous training run: stage names are internal
    gates only and do not imply separate jobs or checkpoints.

    Args:
        memory_cfg: Config object; its ``curriculum`` attribute (and nested
            ``lr``) are read when present, defaults are used otherwise.
        global_step: Training step; ``None`` and negative values are
            treated as 0.

    Raises:
        ValueError: If the resolved stage is not a key of ``_STAGE_ENABLES``.
    """
    step = max(0, int(global_step or 0))
    curriculum_cfg = _cfg_get(memory_cfg, "curriculum", None)
    lr_cfg = _cfg_get(curriculum_cfg, "lr", None)
    enabled = bool(_cfg_get(curriculum_cfg, "enabled", False))

    if not enabled:
        # Curriculum off: take the stage from the config, train the full DiT
        # and keep the VAE frozen.
        stage = str(_cfg_get(memory_cfg, "training_stage", "stage_1"))
        dit_state = "full"
        freeze_vae = True
    else:
        stage = _stage_for_step(curriculum_cfg, step)
        dit_state = _dit_train_state(curriculum_cfg, step)
        freeze_vae = bool(_cfg_get(curriculum_cfg, "freeze_vae", True))

    stream_flags = _STAGE_ENABLES.get(stage)
    if stream_flags is None:
        raise ValueError(f"unknown DeMemWM stage: {stage}")
    anchor_on, dynamic_on, revisit_on = stream_flags

    # Learning rates are read only after the stage has been validated.
    dememwm_lr = float(_cfg_get(lr_cfg, "dememwm_modules", 1.0e-4))
    memory_adapter_lr = float(_cfg_get(lr_cfg, "memory_adapters", 1.0e-4))
    full_dit_lr = float(_cfg_get(lr_cfg, "full_dit", 1.0e-5))

    return CurriculumState(
        global_step=step,
        enabled=enabled,
        stage=stage,
        anchor_enabled=anchor_on,
        dynamic_enabled=dynamic_on,
        revisit_enabled=revisit_on,
        dit_train_state=dit_state,
        freeze_vae=freeze_vae,
        dememwm_lr=dememwm_lr,
        memory_adapter_lr=memory_adapter_lr,
        full_dit_lr=full_dit_lr,
    )
@dataclass
class MemoryRecord:
    """A single causal memory entry used by DeMemWM.

    The covered frame range is half-open, ``[source_start, source_end)``.
    Every source frame index held by a record has to be strictly smaller than
    the target frame being queried, unless the caller deliberately queries a
    prefix frame that is already committed.

    Construction validates the record and normalises ``mask`` to ``bool``.
    """

    tokens: torch.Tensor
    mask: torch.Tensor
    source_start: int
    source_end: int
    frame_indices: torch.Tensor
    pose: Optional[torch.Tensor]
    source_type: MemorySourceType
    is_generated: bool
    score: Optional[float | torch.Tensor] = None
    chunk_id: Optional[str] = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Checks run in a fixed order so the first violated rule is reported.
        if not self.source_end > self.source_start:
            raise ValueError("source_end must be greater than source_start")

        token_tensor = self.tokens
        if token_tensor.ndim < 2:
            raise ValueError("tokens must include slot and channel dimensions")

        slot_mask = self.mask.bool()
        self.mask = slot_mask
        if slot_mask.ndim != 1:
            raise ValueError("mask must have shape (M,)")
        if slot_mask.shape[0] != token_tensor.shape[0]:
            raise ValueError("mask length must match token slots")

        is_prefix_anchor = self.source_type == MemorySourceType.PREFIX_GT
        if is_prefix_anchor and self.is_generated:
            raise ValueError("generated records cannot be PREFIX_GT anchors")

        if self.frame_indices.numel() == 0:
            raise ValueError("frame_indices cannot be empty")

    @property
    def max_source_frame(self) -> int:
        """Largest frame index stored in ``frame_indices``."""
        newest = self.frame_indices.detach().max()
        return int(newest.item())

    @property
    def valid_slots(self) -> int:
        """Sum of ``mask``, i.e. the number of slots marked valid."""
        total = self.mask.detach().sum()
        return int(total.item())
+@dataclass +class RevisitRetrievalResult: + records: list[MemoryRecord] + scores: torch.Tensor + selected_frame_ids: list[int] + diagnostics: dict[str, Any] diff --git a/algorithms/worldmem/dememwm_memory_dit.py b/algorithms/worldmem/dememwm_memory_dit.py new file mode 100644 index 0000000000000000000000000000000000000000..e8849eda7034bf87158393a1a169d38ba7c2d462 --- /dev/null +++ b/algorithms/worldmem/dememwm_memory_dit.py @@ -0,0 +1,18 @@ + +from __future__ import annotations + +from .dememwm.algorithm import MemoryDiTMixin +from .df_video import BaseVideoDiTMinecraft + + +class DeMemWMMinecraft(MemoryDiTMixin, BaseVideoDiTMinecraft): + """Standalone DeMemWM / Memory-DiT algorithm. + + Reuses the base video-DiT VAE/diffusion/training infrastructure, + but owns memory construction/injection. Does not route through the legacy memory method. + """ + + pass + + +DeMemWMMemoryDiTMinecraft = DeMemWMMinecraft diff --git a/algorithms/worldmem/df_base.py b/algorithms/worldmem/df_base.py new file mode 100644 index 0000000000000000000000000000000000000000..96607e3f99a1f3a7c0a5f4233b4cfce829e308a4 --- /dev/null +++ b/algorithms/worldmem/df_base.py @@ -0,0 +1,307 @@ +""" +This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research +template [repo](https://github.com/buoyancy99/research-template). +By its MIT license, you must keep the above sentence in `README.md` +and the `LICENSE` file to credit the author. 
def configure_optimizers(self):
    """Create the AdamW optimizer over all diffusion-model parameters.

    Learning rate, weight decay and betas are read from ``self.cfg``
    (``lr``, ``weight_decay``, ``optimizer_beta``).

    Returns:
        torch.optim.AdamW: Optimizer holding a single parameter group.
    """
    cfg = self.cfg
    trainable = tuple(self.diffusion_model.parameters())
    return torch.optim.AdamW(
        trainable,
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        betas=cfg.optimizer_beta,
    )
@torch.no_grad()
def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT:
    """Sample a sequence chunk by chunk and score it against the ground truth.

    The first ``context_frames // frame_stack`` tokens are copied from the
    ground truth. The rest is generated in chunks of at most ``chunk_size``
    tokens: each chunk starts as clipped Gaussian noise and is denoised by
    walking the scheduling matrix row by row, feeding the diffusion model
    only the last ``n_tokens`` tokens (sliding window).

    When the instance defines a truthy ``condtion_similar_length`` (attribute
    name kept as spelled elsewhere in the project), that many ground-truth
    tokens preceding the chunk are temporarily appended as clean (noise level
    0) tokens, and the last entries of ``conditions`` are appended as their
    conditions.
    NOTE(review): the sequence is then assumed to carry that many trailing
    extra tokens, which are excluded from generation and from the loss —
    confirm against the dataset/subclass.

    Args:
        batch: Raw batch, forwarded to ``_preprocess_batch``.
        batch_idx: Index of the batch (unused).
        namespace: Logging namespace; accepted for interface compatibility
            with ``test_step`` and unused here.

    Returns:
        The re-weighted MSE between prediction and ground truth. The
        ``(prediction, ground truth)`` pair is also appended to
        ``self.validation_step_outputs`` on CPU.
    """
    xs, conditions, masks = self._preprocess_batch(batch)
    n_frames, batch_size, *_ = xs.shape

    # This class never sets the attribute itself; a missing or falsy value
    # simply disables the feature instead of raising AttributeError.
    n_similar = getattr(self, "condtion_similar_length", 0) or 0

    # context
    n_context_frames = self.context_frames // self.frame_stack
    xs_pred = xs[:n_context_frames].clone()
    curr_frame = n_context_frames

    if n_similar:
        n_frames -= n_similar

    pbar = tqdm(total=n_frames, initial=curr_frame, desc="Sampling")
    while curr_frame < n_frames:
        if self.chunk_size > 0:
            horizon = min(n_frames - curr_frame, self.chunk_size)
        else:
            horizon = n_frames - curr_frame
        assert horizon <= self.n_tokens, "horizon exceeds the number of tokens."
        scheduling_matrix = self._generate_scheduling_matrix(horizon)

        chunk = torch.randn((horizon, batch_size, *self.x_stacked_shape), device=self.device)
        chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise)
        xs_pred = torch.cat([xs_pred, chunk], 0)

        # sliding window: only input the last n_tokens frames
        start_frame = max(0, curr_frame + horizon - self.n_tokens)

        pbar.set_postfix(
            {
                "start": start_frame,
                "end": curr_frame + horizon,
            }
        )

        if n_similar:
            xs_pred = torch.cat([xs_pred, xs[curr_frame - n_similar:curr_frame].clone()], 0)

        # Already generated / context tokens are clean (noise level 0).
        context_levels = np.zeros((curr_frame,), dtype=np.int64)
        # One clean row per appended token, sized by the real batch size
        # (previously hard-coded to a single row of four entries).
        similar_levels = np.zeros((n_similar, batch_size), dtype=np.int64)

        for m in range(scheduling_matrix.shape[0] - 1):
            from_noise_levels = np.concatenate((context_levels, scheduling_matrix[m]))[
                :, None
            ].repeat(batch_size, axis=1)
            to_noise_levels = np.concatenate((context_levels, scheduling_matrix[m + 1]))[
                :, None
            ].repeat(batch_size, axis=1)

            if n_similar:
                from_noise_levels = np.concatenate([from_noise_levels, similar_levels], axis=0)
                to_noise_levels = np.concatenate([to_noise_levels, similar_levels], axis=0)

            from_noise_levels = torch.from_numpy(from_noise_levels).to(self.device)
            to_noise_levels = torch.from_numpy(to_noise_levels).to(self.device)

            # update xs_pred by DDIM or DDPM sampling
            # input frames within the sliding window
            window = conditions[start_frame : curr_frame + horizon]
            if torch.is_tensor(window):
                if n_similar:
                    input_condition = torch.cat([window, conditions[-n_similar:]], dim=0)
                else:
                    input_condition = window.clone()
            else:
                # Without action conditioning `_preprocess_batch` returns a
                # plain list of None, which has no `.clone()`; pass it on
                # as-is rather than dropping into a debugger.
                input_condition = list(window)
                if n_similar:
                    input_condition += list(conditions[-n_similar:])

            xs_pred[start_frame:] = self.diffusion_model.sample_step(
                xs_pred[start_frame:],
                input_condition,
                from_noise_levels[start_frame:],
                to_noise_levels[start_frame:],
            )

        if n_similar:
            # Drop the temporarily appended tokens again.
            xs_pred = xs_pred[:-n_similar]

        curr_frame += horizon
        pbar.update(horizon)
    pbar.close()

    if n_similar:
        xs = xs[:-n_similar]
    # FIXME: loss
    loss = F.mse_loss(xs_pred, xs, reduction="none")
    loss = self.reweight_loss(loss, masks)
    self.validation_step_outputs.append((xs_pred.detach().cpu(), xs.detach().cpu()))

    return loss
horizon: int, uncertainty_scale: float): + height = self.sampling_timesteps + int((horizon - 1) * uncertainty_scale) + 1 + scheduling_matrix = np.zeros((height, horizon), dtype=np.int64) + for m in range(height): + for t in range(horizon): + scheduling_matrix[m, t] = self.sampling_timesteps + int(t * uncertainty_scale) - m + + return np.clip(scheduling_matrix, 0, self.sampling_timesteps) + + def _generate_trapezoid_scheduling_matrix(self, horizon: int, uncertainty_scale: float): + height = self.sampling_timesteps + int((horizon + 1) // 2 * uncertainty_scale) + scheduling_matrix = np.zeros((height, horizon), dtype=np.int64) + for m in range(height): + for t in range((horizon + 1) // 2): + scheduling_matrix[m, t] = self.sampling_timesteps + int(t * uncertainty_scale) - m + scheduling_matrix[m, -t] = self.sampling_timesteps + int(t * uncertainty_scale) - m + + return np.clip(scheduling_matrix, 0, self.sampling_timesteps) + + def reweight_loss(self, loss, weight=None): + # Note there is another part of loss reweighting (fused_snr) inside the Diffusion class! + loss = rearrange(loss, "t b (fs c) ... -> t b fs c ...", fs=self.frame_stack) + if weight is not None: + expand_dim = len(loss.shape) - len(weight.shape) - 1 + weight = rearrange( + weight, + "(t fs) b ... -> t b fs ..." 
+ " 1" * expand_dim, + fs=self.frame_stack, + ) + loss = loss * weight + + return loss.mean() + + def _preprocess_batch(self, batch): + xs = batch[0] + batch_size, n_frames = xs.shape[:2] + + if n_frames % self.frame_stack != 0: + raise ValueError("Number of frames must be divisible by frame stack size") + if self.context_frames % self.frame_stack != 0: + raise ValueError("Number of context frames must be divisible by frame stack size") + + masks = torch.ones(n_frames, batch_size).to(xs.device) + n_frames = n_frames // self.frame_stack + + if self.action_cond_dim: + conditions = batch[1] + conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1) + conditions = rearrange(conditions, "b (t fs) d -> t b (fs d)", fs=self.frame_stack).contiguous() + + # f, _, _ = conditions.shape + # predefined_1 = torch.tensor([0,0,0,1]).to(conditions.device) + # predefined_2 = torch.tensor([0,0,1,0]).to(conditions.device) + # conditions[:f//2] = predefined_1 + # conditions[f//2:] = predefined_2 + else: + conditions = [None for _ in range(n_frames)] + + xs = self._normalize_x(xs) + xs = rearrange(xs, "b (t fs) c ... -> t b (fs c) ...", fs=self.frame_stack).contiguous() + + return xs, conditions, masks + + def _normalize_x(self, xs): + shape = [1] * (xs.ndim - self.data_mean.ndim) + list(self.data_mean.shape) + mean = self.data_mean.reshape(shape) + std = self.data_std.reshape(shape) + return (xs - mean) / std + + def _unnormalize_x(self, xs): + shape = [1] * (xs.ndim - self.data_mean.ndim) + list(self.data_mean.shape) + mean = self.data_mean.reshape(shape) + std = self.data_std.reshape(shape) + return xs * std + mean + + def _unstack_and_unnormalize(self, xs): + xs = rearrange(xs, "t b (fs c) ... 
-> (t fs) b c ...", fs=self.frame_stack) + return self._unnormalize_x(xs) diff --git a/algorithms/worldmem/df_video.py b/algorithms/worldmem/df_video.py new file mode 100644 index 0000000000000000000000000000000000000000..51f73fbb36d9dee0680e62983ca8d3c14e97d5dc --- /dev/null +++ b/algorithms/worldmem/df_video.py @@ -0,0 +1,926 @@ +import os +import random +import math +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from torchvision.transforms import InterpolationMode +from PIL import Image +from packaging import version as pver +from einops import rearrange +from tqdm import tqdm +from omegaconf import DictConfig +from lightning.pytorch.utilities.types import STEP_OUTPUT +from algorithms.common.metrics import ( + LearnedPerceptualImagePatchSimilarity, +) +from utils.logging_utils import log_video, get_validation_metrics_for_videos +from .df_base import DiffusionForcingBase +from .models.vae import VAE_models +from .models.diffusion import Diffusion +from .models.pose_prediction import PosePredictionNet +import glob + +# Utility Functions +def euler_to_rotation_matrix(pitch, yaw): + """ + Convert pitch and yaw angles (in radians) to a 3x3 rotation matrix. + Supports batch input. + + Args: + pitch (torch.Tensor): Pitch angles in radians. + yaw (torch.Tensor): Yaw angles in radians. + + Returns: + torch.Tensor: Rotation matrix of shape (batch_size, 3, 3). 
def euler_to_camera_to_world_matrix(pose):
    """
    Convert (x, y, z, pitch, yaw) poses to 4x4 camera-to-world matrices.

    Pitch and yaw are given in degrees. Any number of leading dimensions is
    supported, e.g. (5,), (n, 5) or (f, b, 5).

    Args:
        pose (torch.Tensor): Pose tensor of shape (..., 5).

    Returns:
        torch.Tensor: float32 camera-to-world matrices of shape (..., 4, 4),
        i.e. (4, 4) for a single pose and (f, b, 4, 4) for (f, b, 5) input.
    """
    # Remember the caller's leading shape and work on a flat list of poses.
    # The previous 2-D path reshaped to (1, 4, 4) and failed for n > 1.
    lead_shape = pose.shape[:-1]
    flat = pose.reshape(-1, pose.shape[-1])

    x, y, z = flat[:, 0], flat[:, 1], flat[:, 2]
    pitch = torch.deg2rad(flat[:, 3])
    yaw = torch.deg2rad(flat[:, 4])

    # Compute rotation matrices (batch mode)
    R = euler_to_rotation_matrix(pitch, yaw)  # Shape (N, 3, 3)

    # Start from identity so the bottom row is already [0, 0, 0, 1].
    eye = torch.eye(4, dtype=torch.float32, device=pose.device)
    camera_to_world = eye.repeat(R.shape[0], 1, 1)  # Shape (N, 4, 4)

    # Assign rotation and translation
    camera_to_world[:, :3, :3] = R
    camera_to_world[:, :3, 3] = torch.stack([x, y, z], dim=-1)

    # Restore the leading dimensions of the input.
    return camera_to_world.view(*lead_shape, 4, 4)
def camera_to_world_to_world_to_camera(camera_to_world: torch.Tensor) -> torch.Tensor:
    """
    Invert rigid camera-to-world transforms into world-to-camera transforms.

    For a transform with rotation R and translation T the inverse is built in
    closed form as rotation R^T and translation -R^T T.

    Args:
        camera_to_world (torch.Tensor): Matrices of shape (f, b, 4, 4), with
            f frames and batch size b.

    Returns:
        torch.Tensor: World-to-camera matrices of shape (f, b, 4, 4) in the
        dtype of the input.
    """
    assert camera_to_world.ndim == 4 and camera_to_world.shape[2:] == (4, 4), \
        "Input must be of shape (f, b, 4, 4)"

    n_frames, n_batch = camera_to_world.size(0), camera_to_world.size(1)

    rotation_t = camera_to_world[..., :3, :3].transpose(2, 3)  # (f, b, 3, 3)
    translation = camera_to_world[..., :3, 3:4]                # (f, b, 3, 1)

    # Identity template: the bottom row [0, 0, 0, 1] never needs touching.
    identity = torch.eye(4, device=camera_to_world.device)
    inverse = identity.expand(n_frames, n_batch, 4, 4).clone()

    inverse[..., :3, :3] = rotation_t
    inverse[..., :3, 3] = -torch.matmul(rotation_t, translation)[..., 0]

    return inverse.to(camera_to_world.dtype)
def ray_condition(K, c2w, H, W, device):
    """Compute per-pixel Plücker ray embeddings for a batch of camera views.

    Args:
        K: Intrinsics of shape (B, V, 4), ordered (fx, fy, cx, cy).
        c2w: Camera-to-world matrices of shape (B, V, 4, 4).
        H: Image height in pixels.
        W: Image width in pixels.
        device: Device on which the pixel grid is created.

    Returns:
        torch.Tensor: Tensor of shape (B, V, H, W, 6). The last axis holds
        the moment ``o x d`` followed by the unit ray direction ``d``, where
        ``o`` is the camera origin in world coordinates.
    """
    n_batch = K.shape[0]
    n_pixels = H * W

    rows, cols = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    # Sample rays through pixel centres, hence the half-pixel offset.
    u = cols.reshape([1, 1, n_pixels]).expand([n_batch, 1, n_pixels]) + 0.5  # [B, 1, HxW]
    v = rows.reshape([1, 1, n_pixels]).expand([n_batch, 1, n_pixels]) + 0.5  # [B, 1, HxW]

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # each [B, V, 1]

    # Camera-space directions at unit depth; broadcasting against the
    # intrinsics adds the view axis.
    depth = torch.ones_like(u, device=device, dtype=c2w.dtype)  # [B, 1, HxW]
    dir_x = -(u - cx) / fx * depth  # [B, V, HxW]
    dir_y = -(v - cy) / fy * depth  # [B, V, HxW]
    dir_z = depth.expand_as(dir_y)

    cam_dirs = torch.stack((dir_x, dir_y, dir_z), dim=-1)  # [B, V, HxW, 3]
    cam_dirs = cam_dirs / cam_dirs.norm(dim=-1, keepdim=True)

    # Rotate into world space: d_world = R @ d_cam, written for row vectors.
    rotation_t = c2w[..., :3, :3].transpose(-1, -2)
    rays_d = cam_dirs @ rotation_t  # [B, V, HxW, 3]

    origins = c2w[..., :3, 3]  # [B, V, 3]
    rays_o = origins[:, :, None].expand_as(rays_d)  # [B, V, HxW, 3]

    moment = torch.linalg.cross(rays_o, rays_d)
    embedding = torch.cat([moment, rays_d], dim=-1)  # [B, V, HxW, 6]
    return embedding.reshape(n_batch, c2w.shape[1], H, W, 6)
+ """ + if tensor.ndim != 5: + raise ValueError("Input tensor must have shape (F, B, 3, H, W)") + + F, B, C, H, W = tensor.shape + + # Generate random transformation parameters + max_translate = 0.2 # Translate up to 20% of width/height + max_rotate = 30 # Rotate up to 30 degrees + max_scale = 0.2 # Scale change by up to +/- 20% + + translate_x = random.uniform(-max_translate, max_translate) * W + translate_y = random.uniform(-max_translate, max_translate) * H + rotate_angle = random.uniform(-max_rotate, max_rotate) + scale_factor = 1 + random.uniform(-max_scale, max_scale) + + # Apply the same transformation to all frames and batches + + tensor = tensor.reshape(F*B, C, H, W) + transformed_tensor = TF.affine( + tensor, + angle=rotate_angle, + translate=(translate_x, translate_y), + scale=scale_factor, + shear=(0, 0), + interpolation=InterpolationMode.BILINEAR, + fill=0 + ) + + transformed_tensor = transformed_tensor.reshape(F, B, C, H, W) + return transformed_tensor + +def save_tensor_as_png(tensor, file_path): + """ + Save a 3*H*W tensor as a PNG image. + + Args: + tensor (torch.Tensor): Input tensor of shape (3, H, W). + file_path (str): Path to save the PNG file. + """ + if tensor.ndim != 3 or tensor.shape[0] != 3: + raise ValueError("Input tensor must have shape (3, H, W)") + + # Convert tensor to PIL Image + image = TF.to_pil_image(tensor) + + # Save image + image.save(file_path) + +class BaseVideoDiTMinecraft(DiffusionForcingBase): + """ + Video generation for MineCraft with memory. + """ + + def __init__(self, cfg: DictConfig): + """ + Initialize the base video-DiT Minecraft class with the given configuration. + + Args: + cfg (DictConfig): Configuration object. 
+ """ + self.n_tokens = cfg.n_frames // cfg.frame_stack # number of max tokens for the model + self.n_frames = cfg.n_frames + if hasattr(cfg, "n_tokens"): + self.n_tokens = cfg.n_tokens // cfg.frame_stack + self.memory_condition_length = cfg.memory_condition_length + self.pose_cond_dim = getattr(cfg, "pose_cond_dim", 5) + + self.use_plucker = getattr(cfg, "use_plucker", True) + self.relative_embedding = getattr(cfg, "relative_embedding", True) + self.state_embed_only_on_qk = getattr(cfg, "state_embed_only_on_qk", True) + self.use_memory_attention = getattr(cfg, "use_memory_attention", True) + self.add_timestamp_embedding = getattr(cfg, "add_timestamp_embedding", True) + self.ref_mode = getattr(cfg, "ref_mode", 'sequential') + self.log_curve = getattr(cfg, "log_curve", False) + self.focal_length = getattr(cfg, "focal_length", 0.35) + self.log_video = cfg.log_video + self.save_local = getattr(cfg, "save_local", True) + self.local_save_dir = getattr(cfg, "local_save_dir", None) + self.lpips_batch_size = getattr(cfg, "lpips_batch_size", 16) + self.next_frame_length = getattr(cfg, "next_frame_length", 1) + self.require_pose_prediction = getattr(cfg, "require_pose_prediction", False) + + super().__init__(cfg) + + def _build_model(self): + + self.diffusion_model = Diffusion( + reference_length=self.memory_condition_length, + x_shape=self.x_stacked_shape, + action_cond_dim=self.action_cond_dim, + pose_cond_dim=self.pose_cond_dim, + is_causal=self.causal, + cfg=self.cfg.diffusion, + is_dit=True, + use_plucker=self.use_plucker, + relative_embedding=self.relative_embedding, + state_embed_only_on_qk=self.state_embed_only_on_qk, + use_memory_attention=self.use_memory_attention, + add_timestamp_embedding=self.add_timestamp_embedding, + ref_mode=self.ref_mode + ) + + self.validation_lpips_model = LearnedPerceptualImagePatchSimilarity() + vae = VAE_models["vit-l-20-shallow-encoder"]() + self.vae = vae.eval() + + if self.require_pose_prediction: + self.pose_prediction_model = 
def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
    """
    Perform a single training step.

    The batch is preprocessed, pose conditioning is built (Plücker
    embeddings or raw poses), frames are encoded with the VAE, noise levels
    are drawn, and the diffusion loss is computed. The trailing
    ``memory_condition_length`` frames act as memory/reference frames: their
    action conditions are zeroed, their noise level is set to the diffusion
    model's ``stabilization_level``, and they are excluded from the loss.

    Args:
        batch: Input batch of data containing frames, conditions, poses, etc.
        batch_idx: Index of the current batch.

    Returns:
        dict: A dictionary containing the training loss under ``"loss"``.
    """
    xs, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)

    # Pixel-space frame size, captured before `encode` replaces xs below.
    image_height, image_width = xs.shape[-2], xs.shape[-1]

    if self.use_plucker:
        if self.relative_embedding:
            input_pose_condition = []
            frame_idx_list = []
            for i in range(self.n_frames):
                # Frame i followed by the memory frames, all expressed
                # relative to frame i (index 0 of the concatenation).
                input_pose_condition.append(
                    convert_to_plucker(
                        torch.cat([c2w_mat[i:i + 1], c2w_mat[-self.memory_condition_length:]]).clone(),
                        0,
                        focal_length=self.focal_length,
                        image_height=image_height,
                        image_width=image_width,
                    ).to(xs.dtype)
                )
                frame_idx_list.append(
                    torch.cat([
                        frame_idx[i:i + 1] - frame_idx[i:i + 1],
                        frame_idx[-self.memory_condition_length:] - frame_idx[i:i + 1]
                    ]).clone()
                )
            input_pose_condition = torch.cat(input_pose_condition)
            frame_idx_list = torch.cat(frame_idx_list)
        else:
            # The image size is a required argument of convert_to_plucker;
            # omitting it made this branch raise TypeError.
            input_pose_condition = convert_to_plucker(
                c2w_mat,
                0,
                focal_length=self.focal_length,
                image_height=image_height,
                image_width=image_width,
            ).to(xs.dtype)
            frame_idx_list = frame_idx
    else:
        input_pose_condition = pose_conditions.to(xs.dtype)
        frame_idx_list = None

    xs = self.encode(xs)

    noise_levels = self._generate_noise_levels(xs)

    if self.memory_condition_length:
        noise_levels[-self.memory_condition_length:] = self.diffusion_model.stabilization_level
        conditions[-self.memory_condition_length:] *= 0

    _, loss = self.diffusion_model(
        xs,
        conditions,
        input_pose_condition,
        noise_levels=noise_levels,
        reference_length=self.memory_condition_length,
        frame_idx=frame_idx_list
    )

    if self.memory_condition_length:
        # Memory frames are conditioning only; keep them out of the loss.
        loss = loss[:-self.memory_condition_length]

    loss = self.reweight_loss(loss, None)

    if batch_idx % 20 == 0:
        self.log("training/loss", loss.cpu())

    return {"loss": loss}
self.log_video: + log_video( + xs_pred, + xs, + step=None if namespace == "test" else self.global_step, + namespace=namespace + "_vis", + context_frames=self.context_frames, + logger=self.logger.experiment, + save_local=self.save_local, + local_save_dir=self.local_save_dir, + ) + + if xs is not None: + # Move data to the same device as LPIPS model for metric calculation + device = next(self.validation_lpips_model.parameters()).device + xs_pred_device = xs_pred.to(device) + xs_device = xs.to(device) + + metric_dict = get_validation_metrics_for_videos( + xs_pred_device, xs_device, + lpips_model=self.validation_lpips_model, + lpips_batch_size=self.lpips_batch_size) + + self.log_dict( + {"mse": metric_dict['mse'], + "psnr": metric_dict['psnr'], + "lpips": metric_dict['lpips']}, + sync_dist=True + ) + + if self.log_curve: + psnr_values = metric_dict['frame_wise_psnr'].cpu().tolist() + frames = list(range(len(psnr_values))) + line_plot = wandb.plot.line_series( + xs = frames, + ys = [psnr_values], + keys = ["PSNR"], + title = "Frame-wise PSNR", + xname = "Frame index" + ) + + self.logger.experiment.log({"frame_wise_psnr_plot": line_plot}) + + self.validation_step_outputs.clear() + + def _preprocess_batch(self, batch): + + xs, conditions, pose_conditions, frame_index = batch + + if self.action_cond_dim: + conditions = torch.cat([torch.zeros_like(conditions[:, :1]), conditions[:, 1:]], 1) + conditions = rearrange(conditions, "b t d -> t b d").contiguous() + else: + raise NotImplementedError("Only support external cond.") + + pose_conditions = rearrange(pose_conditions, "b t d -> t b d").contiguous() + c2w_mat = euler_to_camera_to_world_matrix(pose_conditions) + xs = rearrange(xs, "b t c ... 
-> t b c ...").contiguous() + frame_index = rearrange(frame_index, "b t -> t b").contiguous() + + return xs, conditions, pose_conditions, c2w_mat, frame_index + + def encode(self, x): + # vae encoding + T = x.shape[0] + H, W = x.shape[-2:] + scaling_factor = 0.07843137255 + + x = rearrange(x, "t b c h w -> (t b) c h w") + with torch.no_grad(): + x = self.vae.encode(x * 2 - 1).mean * scaling_factor + x = rearrange(x, "(t b) (h w) c -> t b c h w", t=T, h=H // self.vae.patch_size, w=W // self.vae.patch_size) + return x + + def decode(self, x): + total_frames = x.shape[0] + scaling_factor = 0.07843137255 + x = rearrange(x, "t b c h w -> (t b) (h w) c") + with torch.no_grad(): + x = (self.vae.decode(x / scaling_factor) + 1) / 2 + x = rearrange(x, "(t b) c h w-> t b c h w", t=total_frames) + return x + + def _generate_condition_indices(self, curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, horizon): + """ + Generate indices for condition similarity based on the current frame and pose conditions. 
+ """ + if curr_frame < memory_condition_length: + random_idx = [i for i in range(curr_frame)] + [0] * (memory_condition_length - curr_frame) + random_idx = np.repeat(np.array(random_idx)[:, None], xs_pred.shape[1], -1) + else: + # Generate points in a sphere and filter based on field of view + num_samples = 10000 + radius = 30 + points = generate_points_in_sphere(num_samples, radius).to(pose_conditions.device) + points = points[:, None].repeat(1, pose_conditions.shape[1], 1) + points += pose_conditions[curr_frame, :, :3][None] + fov_half_h = torch.tensor(105 / 2, device=pose_conditions.device) + fov_half_v = torch.tensor(75 / 2, device=pose_conditions.device) + + # in_fov1 = is_inside_fov_3d_hv( + # points, pose_conditions[curr_frame, :, :3], + # pose_conditions[curr_frame, :, -2], pose_conditions[curr_frame, :, -1], + # fov_half_h, fov_half_v + # ) + + in_fov1 = torch.stack([ + is_inside_fov_3d_hv(points, pc[:, :3], pc[:, -2], pc[:, -1], fov_half_h, fov_half_v) + for pc in pose_conditions[curr_frame:curr_frame+horizon] + ]) + + in_fov1 = torch.sum(in_fov1, 0) > 0 + + # Compute overlap ratios and select indices + in_fov_list = torch.stack([ + is_inside_fov_3d_hv(points, pc[:, :3], pc[:, -2], pc[:, -1], fov_half_h, fov_half_v) + for pc in pose_conditions[:curr_frame] + ]) + + random_idx = [] + for _ in range(memory_condition_length): + overlap_ratio = ((in_fov1.bool() & in_fov_list).sum(1)) / in_fov1.sum() + + confidence = overlap_ratio + (curr_frame - frame_idx[:curr_frame]) / curr_frame * (-0.2) + + if len(random_idx) > 0: + confidence[torch.cat(random_idx)] = -1e10 + _, r_idx = torch.topk(confidence, k=1, dim=0) + random_idx.append(r_idx[0]) + + # choice 1: directly remove overlapping region + occupied_mask = in_fov_list[r_idx[0, range(in_fov1.shape[-1])], :, range(in_fov1.shape[-1])].permute(1,0) + in_fov1 = in_fov1 & ~occupied_mask + + # choice 2: apply similarity filter + # cos_sim = F.cosine_similarity(xs_pred.to(r_idx.device)[r_idx[:, 
range(in_fov1.shape[1])], + # range(in_fov1.shape[1])], xs_pred.to(r_idx.device)[:curr_frame], dim=2) + # cos_sim = cos_sim.mean((-2,-1)) + + # mask_sim = cos_sim>0.9 + # in_fov_list = in_fov_list & ~mask_sim[:,None].to(in_fov_list.device) + + random_idx = torch.stack(random_idx).cpu() + + return random_idx + + def _prepare_conditions(self, + start_frame, curr_frame, horizon, conditions, + pose_conditions, c2w_mat, frame_idx, random_idx, + image_width, image_height): + """ + Prepare input conditions and pose conditions for sampling. + """ + + padding = torch.zeros((len(random_idx),) + conditions.shape[1:], device=conditions.device, dtype=conditions.dtype) + input_condition = torch.cat([conditions[start_frame:curr_frame + horizon], padding], dim=0) + + batch_size = conditions.shape[1] + + if self.use_plucker: + if self.relative_embedding: + frame_idx_list = [] + input_pose_condition = [] + for i in range(start_frame, curr_frame + horizon): + input_pose_condition.append(convert_to_plucker(torch.cat([c2w_mat[i:i+1],c2w_mat[random_idx[:,range(batch_size)], range(batch_size)]]).clone(), 0, focal_length=self.focal_length, + image_width=image_width, image_height=image_height).to(conditions.dtype)) + frame_idx_list.append(torch.cat([frame_idx[i:i+1]-frame_idx[i:i+1], frame_idx[random_idx[:,range(batch_size)], range(batch_size)]-frame_idx[i:i+1]])) + input_pose_condition = torch.cat(input_pose_condition) + frame_idx_list = torch.cat(frame_idx_list) + + else: + input_pose_condition = torch.cat([c2w_mat[start_frame : curr_frame + horizon], c2w_mat[random_idx[:,range(batch_size)], range(batch_size)]], dim=0).clone() + input_pose_condition = convert_to_plucker(input_pose_condition, 0, focal_length=self.focal_length) + frame_idx_list = None + else: + input_pose_condition = torch.cat([pose_conditions[start_frame : curr_frame + horizon], pose_conditions[random_idx[:,range(batch_size)], range(batch_size)]], dim=0).clone() + frame_idx_list = None + + return input_condition, 
def validation_step(self, batch, batch_idx, namespace="validation") -> STEP_OUTPUT:
    """
    Perform a single validation step.

    Encodes the batch, then autoregressively samples the frames after the
    context window chunk by chunk. Each chunk is denoised over the rows of
    the scheduling matrix, optionally with retrieved memory frames appended
    to the end of the sequence. Predictions and ground truth are decoded and
    stored for ``on_validation_epoch_end``.

    Args:
        batch: Input batch of data containing frames, conditions, poses, etc.
        batch_idx: Index of the current batch (unused).
        namespace: Namespace for logging (unused here).

    Returns:
        None. Appends ``(prediction, ground_truth)`` CPU tensors to
        ``self.validation_step_outputs``.
    """
    memory_condition_length = self.memory_condition_length
    xs_raw, conditions, pose_conditions, c2w_mat, frame_idx = self._preprocess_batch(batch)

    # Encode long clips in ten slices and park the latents on the CPU.
    total_frame = xs_raw.shape[0]
    if total_frame > 10:
        xs = torch.cat([
            self.encode(xs_raw[int(total_frame * i / 10):int(total_frame * (i + 1) / 10)]).cpu()
            for i in range(10)
        ])
    else:
        xs = self.encode(xs_raw).cpu()

    n_frames, batch_size, *_ = xs.shape

    # The context frames are taken from ground truth; sampling starts after them.
    n_context_frames = self.context_frames // self.frame_stack
    xs_pred = xs[:n_context_frames].clone()
    curr_frame = n_context_frames

    # `random_idx` was previously bound only when memory retrieval is enabled
    # but is always handed to `_prepare_conditions`, which raised
    # UnboundLocalError for memory-free configs. An empty (0, batch) index
    # keeps the `len()` and fancy-indexing uses there valid.
    random_idx = torch.zeros((0, batch_size), dtype=torch.long)

    # The context manager guarantees the progress bar is closed, also on error.
    with tqdm(total=n_frames, initial=curr_frame, desc="Sampling") as pbar:
        while curr_frame < n_frames:
            # Determine the horizon for the current chunk
            if self.chunk_size > 0:
                horizon = min(n_frames - curr_frame, self.chunk_size)
            else:
                horizon = n_frames - curr_frame
            assert horizon <= self.n_tokens, "Horizon exceeds the number of tokens."

            # Generate scheduling matrix and append clipped Gaussian noise
            scheduling_matrix = self._generate_scheduling_matrix(horizon)
            chunk = torch.randn((horizon, batch_size, *xs_pred.shape[2:]))
            chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise).to(xs_pred.device)
            xs_pred = torch.cat([xs_pred, chunk], 0)

            # Sliding window: only input the last `n_tokens` frames
            start_frame = max(0, curr_frame + horizon - self.n_tokens)
            pbar.set_postfix({"start": start_frame, "end": curr_frame + horizon})

            # Retrieve memory frames and append them after the noisy chunk
            if memory_condition_length:
                random_idx = self._generate_condition_indices(
                    curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, horizon
                )
                batch_range = range(xs_pred.shape[1])
                xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, batch_range], batch_range].clone()], 0)

            # Prepare input conditions and pose conditions
            input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions(
                start_frame, curr_frame, horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx,
                image_width=xs_raw.shape[-1], image_height=xs_raw.shape[-2]
            )

            # Perform sampling for each step in the scheduling matrix
            for m in range(scheduling_matrix.shape[0] - 1):
                from_noise_levels, to_noise_levels = self._prepare_noise_levels(
                    scheduling_matrix, m, curr_frame, batch_size, memory_condition_length
                )

                xs_pred[start_frame:] = self.diffusion_model.sample_step(
                    xs_pred[start_frame:].to(input_condition.device),
                    input_condition,
                    input_pose_condition,
                    from_noise_levels[start_frame:],
                    to_noise_levels[start_frame:],
                    current_frame=curr_frame,
                    mode="validation",
                    reference_length=memory_condition_length,
                    frame_idx=frame_idx_list
                ).cpu()

            # Drop the appended memory frames again before the next chunk
            if memory_condition_length:
                xs_pred = xs_pred[:-memory_condition_length]

            curr_frame += horizon
            pbar.update(horizon)

    # Decode predictions and ground truth
    xs_pred = self.decode(xs_pred[n_context_frames:].to(conditions.device))
    xs_decode = self.decode(xs[n_context_frames:].to(conditions.device))

    # Store results for evaluation (move to CPU to save GPU memory)
    self.validation_step_outputs.append((xs_pred.detach().cpu(), xs_decode.detach().cpu()))
    return
last_frame = xs_pred[-1].clone() + last_pose_condition = memory_poses[-1].clone() + curr_actions = new_actions.clone() + for hi in range(len(new_actions)): + last_pose_condition[:,3:] = last_pose_condition[:,3:] // 15 + new_pose_condition_offset = self.pose_prediction_model(last_frame.to(device), curr_actions[None, hi], last_pose_condition) + new_pose_condition_offset[:,3:] = torch.round(new_pose_condition_offset[:,3:]) + new_pose_condition = last_pose_condition + new_pose_condition_offset + new_pose_condition[:,3:] = new_pose_condition[:,3:] * 15 + new_pose_condition[:,3:] %= 360 + last_pose_condition = new_pose_condition.clone() + new_pose_condition_list.append(new_pose_condition[None]) + new_pose_condition_list = torch.cat(new_pose_condition_list, 0) + + ai = 0 + while ai < len(new_actions): + next_horizon = min(horizon, len(new_actions) - ai) + last_frame = xs_pred[-1].clone() + curr_actions = new_actions[ai:ai+next_horizon].clone() + + new_pose_condition = new_pose_condition_list[ai:ai+next_horizon].clone() + + new_c2w_mat = euler_to_camera_to_world_matrix(new_pose_condition) + memory_poses = torch.cat([memory_poses, new_pose_condition]) + memory_actions = torch.cat([memory_actions, curr_actions[:, None]]) + memory_c2w = torch.cat([memory_c2w, new_c2w_mat]) + new_indices = memory_frame_idx[-1,0] + torch.arange(next_horizon, device=memory_frame_idx.device) + 1 + + memory_frame_idx = torch.cat([memory_frame_idx, new_indices[:, None]]) + + conditions = memory_actions.clone() + pose_conditions = memory_poses.clone() + c2w_mat = memory_c2w .clone() + frame_idx = memory_frame_idx.clone() + + # generation on frame + scheduling_matrix = self._generate_scheduling_matrix(next_horizon) + chunk = torch.randn((next_horizon, batch_size, *xs_pred.shape[2:])).to(xs_pred.device) + chunk = torch.clamp(chunk, -self.clip_noise, self.clip_noise) + + xs_pred = torch.cat([xs_pred, chunk], 0) + + # sliding window: only input the last n_tokens frames + start_frame = max(0, curr_frame 
- self.n_tokens) + + pbar.set_postfix( + { + "start": start_frame, + "end": curr_frame + next_horizon, + } + ) + + # Handle condition similarity logic + if memory_condition_length: + random_idx = self._generate_condition_indices( + curr_frame, memory_condition_length, xs_pred, pose_conditions, frame_idx, next_horizon + ) + + # random_idx = np.unique(random_idx)[:, None] + # memory_condition_length = len(random_idx) + xs_pred = torch.cat([xs_pred, xs_pred[random_idx[:, range(xs_pred.shape[1])], range(xs_pred.shape[1])].clone()], 0) + + # Prepare input conditions and pose conditions + input_condition, input_pose_condition, frame_idx_list = self._prepare_conditions( + start_frame, curr_frame, next_horizon, conditions, pose_conditions, c2w_mat, frame_idx, random_idx, + image_width=first_frame.shape[-1], image_height=first_frame.shape[-2] + ) + + # Perform sampling for each step in the scheduling matrix + for m in range(scheduling_matrix.shape[0] - 1): + from_noise_levels, to_noise_levels = self._prepare_noise_levels( + scheduling_matrix, m, curr_frame, batch_size, memory_condition_length + ) + + xs_pred[start_frame:] = self.diffusion_model.sample_step( + xs_pred[start_frame:].to(input_condition.device), + input_condition, + input_pose_condition, + from_noise_levels[start_frame:], + to_noise_levels[start_frame:], + current_frame=curr_frame, + mode="validation", + reference_length=memory_condition_length, + frame_idx=frame_idx_list + ).cpu() + + + if memory_condition_length: + xs_pred = xs_pred[:-memory_condition_length] + + curr_frame += next_horizon + pbar.update(next_horizon) + ai += next_horizon + + memory_latent_frames = torch.cat([memory_latent_frames, xs_pred[n_context_frames:]]) + xs_pred = self.decode(xs_pred[n_context_frames:].to(device)).cpu() + + return xs_pred.cpu().numpy(), memory_latent_frames.cpu().numpy(), memory_actions.cpu().numpy(), \ + memory_poses.cpu().numpy(), memory_c2w.cpu().numpy(), memory_frame_idx.cpu().numpy() diff --git 
class TemporalAxialAttention(nn.Module):
    """
    Multi-head self-attention along the time axis of a ``(B, T, H, W, D)`` input.

    Every spatial location is treated as an independent sequence of ``T``
    tokens. With ``use_domain_adapter`` a rank-8 LoRA branch is added to the
    fused QKV projection.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        reference_length: int,
        rotary_emb: RotaryEmbedding,
        is_causal: bool = True,
        is_temporal_independent: bool = False,
        use_domain_adapter = False
    ):
        super().__init__()
        self.inner_dim = dim_head * heads
        self.heads = heads
        self.head_dim = dim_head
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)

        self.use_domain_adapter = use_domain_adapter
        if self.use_domain_adapter:
            lora_rank = 8
            self.lora_A = nn.Linear(dim, lora_rank, bias=False)
            self.lora_B = nn.Linear(lora_rank, self.inner_dim * 3, bias=False)

        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb
        self.is_causal = is_causal
        self.is_temporal_independent = is_temporal_independent

        self.reference_length = reference_length

    def _build_attn_bias(self, T, dtype, device):
        """Return the additive ``(T, T)`` attention bias, or ``None`` for full attention."""
        if self.is_temporal_independent:
            # Every frame attends to itself only.
            attn_bias = torch.ones((T, T), dtype=dtype, device=device)
            attn_bias = attn_bias.masked_fill(attn_bias == 1, float('-inf'))
            attn_bias[range(T), range(T)] = 0
            return attn_bias
        if self.is_causal:
            # Causal mask over all frames ...
            attn_bias = torch.triu(torch.ones((T, T), dtype=dtype, device=device), diagonal=1)
            attn_bias = attn_bias.masked_fill(attn_bias == 1, float('-inf'))
            # ... except that the trailing `reference_length` frames attend
            # to nothing but themselves (diagonal restored below).
            attn_bias[(T - self.reference_length):] = float('-inf')
            attn_bias[range(T), range(T)] = 0
            return attn_bias
        return None

    def forward(self, x: torch.Tensor):
        """
        Apply temporal attention.

        Args:
            x: Tensor of shape ``(B, T, H, W, D)``.

        Returns:
            Tensor of the same shape after attention and the output projection.
        """
        B, T, H, W, _ = x.shape

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        if self.use_domain_adapter:
            q_lora, k_lora, v_lora = self.lora_B(self.lora_A(x)).chunk(3, dim=-1)
            q = q + q_lora
            k = k + k_lora
            v = v + v_lora

        q = rearrange(q, "B T H W (h d) -> (B H W) h T d", h=self.heads)
        k = rearrange(k, "B T H W (h d) -> (B H W) h T d", h=self.heads)
        v = rearrange(v, "B T H W (h d) -> (B H W) h T d", h=self.heads)

        q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
        k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)

        q, k, v = (t.contiguous() for t in (q, k, v))

        attn_bias = self._build_attn_bias(T, q.dtype, q.device)

        # A bare `except:` used to drop into pdb here, which hangs
        # non-interactive jobs and hides the real error; let it propagate.
        x = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_bias)

        x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
        x = x.to(q.dtype)

        # linear proj
        return self.to_out(x)
class MemTemporalAxialAttention(nn.Module):
    """
    Temporal attention in which regular frames read from trailing memory frames.

    The last ``reference_length`` positions of the time axis are treated as
    memory. A regular frame attends to itself and to every memory frame; a
    memory frame attends only to itself. No rotary embedding is applied to
    queries or keys in this module.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: int,
        rotary_emb: RotaryEmbedding,
        is_causal: bool = True,
        reference_length: int = 3,
    ):
        """
        Args:
            dim: Channel width of the input and output.
            heads: Number of attention heads.
            dim_head: Width of each head.
            rotary_emb: Stored on the module; not used in ``forward``.
            is_causal: Stored on the module; the mask below does not depend on it.
            reference_length: Number of trailing memory frames. Defaults to 3,
                the value that was previously hard-coded.
        """
        super().__init__()
        self.inner_dim = dim_head * heads
        self.heads = heads
        self.head_dim = dim_head
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb
        self.is_causal = is_causal

        self.reference_length = reference_length

    def _build_attn_bias(self, T, dtype, device):
        """Return the additive ``(T, T)`` bias: 0 where attention is allowed, ``-inf`` elsewhere."""
        attn_bias = torch.zeros((T, T), dtype=dtype, device=device)
        attn_bias = attn_bias.masked_fill(attn_bias == 0, float('-inf'))
        T_origin = T - self.reference_length
        # Regular frames may read every memory frame ...
        attn_bias[:T_origin, T_origin:] = 0
        # ... and every position may read itself.
        attn_bias[range(T), range(T)] = 0
        return attn_bias

    def forward(self, x: torch.Tensor):
        """
        Apply memory-read temporal attention.

        Args:
            x: Tensor of shape ``(B, T, H, W, D)`` whose last
                ``self.reference_length`` time steps are the memory frames.

        Returns:
            Tensor of the same shape after attention and the output projection.
        """
        B, T, H, W, _ = x.shape

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        q = rearrange(q, "B T H W (h d) -> (B H W) h T d", h=self.heads)
        k = rearrange(k, "B T H W (h d) -> (B H W) h T d", h=self.heads)
        v = rearrange(v, "B T H W (h d) -> (B H W) h T d", h=self.heads)

        q, k, v = (t.contiguous() for t in (q, k, v))

        attn_bias = self._build_attn_bias(T, q.dtype, q.device)

        # A bare `except:` used to drop into pdb here, which hangs
        # non-interactive jobs and hides the real error; let it propagate.
        x = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_bias)

        x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
        x = x.to(q.dtype)

        # linear proj
        return self.to_out(x)
self, + dim: int, + heads: int, + dim_head: int, + reference_length: int, + rotary_emb: RotaryEmbedding, + is_causal: bool = True + ): + super().__init__() + self.inner_dim = dim_head * heads + self.heads = heads + self.head_dim = dim_head + self.inner_dim = dim_head * heads + self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False) + self.to_out = nn.Linear(self.inner_dim, dim) + + self.rotary_emb = rotary_emb + self.is_causal = is_causal + + self.reference_length = reference_length + + self.store = None + + def forward(self, x: torch.Tensor, relative_embedding=False, + extra_condition=None, + state_embed_only_on_qk=False, + reference_length=None): + + B, T, H, W, D = x.shape + + if state_embed_only_on_qk: + q, k, _ = self.to_qkv(x+extra_condition).chunk(3, dim=-1) + _, _, v = self.to_qkv(x).chunk(3, dim=-1) + else: + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + + if relative_embedding: + length = reference_length+1 + n_frames = T // length + x = x.reshape(B, n_frames, length, H, W, D) + + x_list = [] + + for i in range(n_frames): + if i == n_frames-1: + q_i = rearrange(q[:, i*length:], "B T H W (h d) -> B h (T H W) d", h=self.heads) + k_i = rearrange(k[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads) + v_i = rearrange(v[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads) + else: + q_i = rearrange(q[:, i*length:i*length+1], "B T H W (h d) -> B h (T H W) d", h=self.heads) + k_i = rearrange(k[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads) + v_i = rearrange(v[:, i*length+1:(i+1)*length], "B T H W (h d) -> B h (T H W) d", h=self.heads) + + q_i, k_i, v_i = map(lambda t: t.contiguous(), (q_i, k_i, v_i)) + x_i = F.scaled_dot_product_attention(query=q_i, key=k_i, value=v_i) + x_i = rearrange(x_i, "B h (T H W) d -> B T H W (h d)", B=B, H=H, W=W) + x_i = x_i.to(q.dtype) + x_list.append(x_i) + + x = torch.cat(x_list, dim=1) + + + else: + T_ = T - reference_length + q = rearrange(q, "B T H W 
(h d) -> B h (T H W) d", h=self.heads) + k = rearrange(k[:, T_:], "B T H W (h d) -> B h (T H W) d", h=self.heads) + v = rearrange(v[:, T_:], "B T H W (h d) -> B h (T H W) d", h=self.heads) + + q, k, v = map(lambda t: t.contiguous(), (q, k, v)) + x = F.scaled_dot_product_attention(query=q, key=k, value=v) + x = rearrange(x, "B h (T H W) d -> B T H W (h d)", B=B, H=H, W=W) + x = x.to(q.dtype) + + # linear proj + x = self.to_out(x) + + return x diff --git a/algorithms/worldmem/models/diffusion.py b/algorithms/worldmem/models/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..f48003e39aff3614db1a285a8339c0e6bea11bfb --- /dev/null +++ b/algorithms/worldmem/models/diffusion.py @@ -0,0 +1,594 @@ +from typing import Optional, Callable +from collections import namedtuple +from omegaconf import DictConfig +import torch +from torch import nn +from torch.nn import functional as F +from einops import rearrange +from .utils import linear_beta_schedule, cosine_beta_schedule, sigmoid_beta_schedule, extract +from .dit import DiT_models + +ModelPrediction = namedtuple("ModelPrediction", ["pred_noise", "pred_x_start", "model_out"]) + + +class Diffusion(nn.Module): + # Special thanks to lucidrains for the implementation of the base Diffusion model + # https://github.com/lucidrains/denoising-diffusion-pytorch + + def __init__( + self, + x_shape: torch.Size, + reference_length: int, + action_cond_dim: int, + pose_cond_dim, + is_causal: bool, + cfg: DictConfig, + is_dit: bool=False, + use_plucker=False, + relative_embedding=False, + state_embed_only_on_qk=False, + use_memory_attention=False, + add_timestamp_embedding=False, + memory_token_cross_attention=False, + memory_cross_attn_layers=None, + ref_mode='sequential' + ): + super().__init__() + self.cfg = cfg + + self.x_shape = x_shape + self.action_cond_dim = action_cond_dim + self.timesteps = cfg.timesteps + self.sampling_timesteps = cfg.sampling_timesteps + self.beta_schedule = cfg.beta_schedule + 
self.schedule_fn_kwargs = cfg.schedule_fn_kwargs + self.objective = cfg.objective + self.use_fused_snr = cfg.use_fused_snr + self.snr_clip = cfg.snr_clip + self.cum_snr_decay = cfg.cum_snr_decay + self.ddim_sampling_eta = cfg.ddim_sampling_eta + self.clip_noise = cfg.clip_noise + self.arch = cfg.architecture + self.stabilization_level = cfg.stabilization_level + self.is_causal = is_causal + self.is_dit = is_dit + self.reference_length = reference_length + self.pose_cond_dim = pose_cond_dim + self.use_plucker = use_plucker + self.relative_embedding = relative_embedding + self.state_embed_only_on_qk = state_embed_only_on_qk + self.use_memory_attention = use_memory_attention + self.add_timestamp_embedding = add_timestamp_embedding + self.memory_token_cross_attention = memory_token_cross_attention + self.memory_cross_attn_layers = memory_cross_attn_layers + self.ref_mode = ref_mode + if self.use_memory_attention: + raise ValueError( + "WorldMem reference-frame use_memory_attention has been removed from DiT. " + "Use memory_token_cross_attention=True for compact memory tokens." 
+ ) + + self._build_model() + self._build_buffer() + + def _build_model(self): + x_channel = self.x_shape[0] + if self.is_dit: + self.model = DiT_models["DiT-S/2"](action_cond_dim=self.action_cond_dim, + reference_length=self.reference_length, + memory_token_cross_attention=self.memory_token_cross_attention, + memory_cross_attn_layers=self.memory_cross_attn_layers, + ref_mode=self.ref_mode) + else: + raise NotImplementedError + + def _build_buffer(self): + if self.beta_schedule == "linear": + beta_schedule_fn = linear_beta_schedule + elif self.beta_schedule == "cosine": + beta_schedule_fn = cosine_beta_schedule + elif self.beta_schedule == "sigmoid": + beta_schedule_fn = sigmoid_beta_schedule + else: + raise ValueError(f"unknown beta schedule {self.beta_schedule}") + + betas = beta_schedule_fn(self.timesteps, **self.schedule_fn_kwargs) + + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) + + # sampling related parameters + assert self.sampling_timesteps <= self.timesteps + self.is_ddim_sampling = self.sampling_timesteps < self.timesteps + + # helper function to register buffer from float64 to float32 + register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) + + register_buffer("betas", betas) + register_buffer("alphas_cumprod", alphas_cumprod) + register_buffer("alphas_cumprod_prev", alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + + register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod)) + register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)) + register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod)) + register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)) + register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + + posterior_variance = betas 
* (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + + register_buffer("posterior_variance", posterior_variance) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + + register_buffer( + "posterior_log_variance_clipped", + torch.log(posterior_variance.clamp(min=1e-20)), + ) + register_buffer( + "posterior_mean_coef1", + betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod), + ) + register_buffer( + "posterior_mean_coef2", + (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod), + ) + + # calculate p2 reweighting + + # register_buffer( + # "p2_loss_weight", + # (self.p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) + # ** -self.p2_loss_weight_gamma, + # ) + + # derive loss weight + # https://arxiv.org/abs/2303.09556 + # snr: signal noise ratio + snr = alphas_cumprod / (1 - alphas_cumprod) + clipped_snr = snr.clone() + clipped_snr.clamp_(max=self.snr_clip) + + register_buffer("clipped_snr", clipped_snr) + register_buffer("snr", snr) + + def add_shape_channels(self, x): + return rearrange(x, f"... 
def model_predictions(self, x, t, action_cond=None, current_frame=None,
                      pose_cond=None, mode="training", reference_length=None, frame_idx=None,
                      memory_tokens=None, memory_token_mask=None,
                      memory_retrieval_tokens=None, memory_retrieval_mask=None,
                      **memory_kwargs):
    """Run the backbone and convert its output into noise / x0 predictions.

    The first two axes of ``x`` (5-D), ``t`` (2-D) and ``action_cond`` (3-D)
    are swapped before calling ``self.model``, and the model output is
    swapped back afterwards.
    # NOTE(review): presumably this converts (time, batch, ...) to
    # (batch, time, ...) -- confirm against the model's expected layout.

    Args:
        x: noisy input, rank 5.
        t: noise levels, rank 2.
        action_cond: optional rank-3 conditioning tensor; ``None`` is passed
            through to the model untouched.
        pose_cond: optional pose conditioning. It is permuted like
            ``action_cond`` when possible and otherwise passed through as is.
        Remaining arguments are forwarded to ``self.model`` unchanged.

    Returns:
        ``ModelPrediction(pred_noise, x_start, model_output)``.

    Raises:
        ValueError: if ``self.objective`` is not one of "pred_noise",
            "pred_x0" or "pred_v".
    """
    x_swapped = x.permute(1, 0, 2, 3, 4)
    t_swapped = t.permute(1, 0)
    # The signature allows action_cond=None, so do not call .permute on it.
    if action_cond is not None:
        action_cond = action_cond.permute(1, 0, 2)
    if pose_cond is not None and pose_cond[0] is not None:
        try:
            pose_cond = pose_cond.permute(1, 0, 2)
        except Exception:
            # Best effort: pose_cond that cannot be permuted this way is
            # handed to the model unchanged. Catching Exception (not a bare
            # except) keeps KeyboardInterrupt/SystemExit propagating.
            pass

    model_output = self.model(
        x_swapped, t_swapped, action_cond,
        current_frame=current_frame, pose_cond=pose_cond,
        mode=mode, reference_length=reference_length, frame_idx=frame_idx,
        memory_tokens=memory_tokens, memory_token_mask=memory_token_mask,
        memory_retrieval_tokens=memory_retrieval_tokens,
        memory_retrieval_mask=memory_retrieval_mask, **memory_kwargs,
    )
    model_output = model_output.permute(1, 0, 2, 3, 4)

    if self.objective == "pred_noise":
        pred_noise = torch.clamp(model_output, -self.clip_noise, self.clip_noise)
        x_start = self.predict_start_from_noise(x, t, pred_noise)
    elif self.objective == "pred_x0":
        x_start = model_output
        pred_noise = self.predict_noise_from_start(x, t, x_start)
    elif self.objective == "pred_v":
        x_start = self.predict_start_from_v(x, t, model_output)
        pred_noise = self.predict_noise_from_start(x, t, x_start)
    else:
        # Previously an unknown objective surfaced as an UnboundLocalError.
        raise ValueError(f"unknown objective {self.objective}")

    return ModelPrediction(pred_noise, x_start, model_output)
def p_mean_variance(
    self,
    x,
    t,
    action_cond=None,
    pose_cond=None,
    reference_length=None,
    frame_idx=None,
    memory_tokens=None,
    memory_token_mask=None,
    memory_retrieval_tokens=None,
    memory_retrieval_mask=None,
    **memory_kwargs,
):
    """Return the posterior statistics implied by the model's x0 estimate.

    Calls ``self.model_predictions`` on ``(x, t)`` with all conditioning
    arguments, then feeds the predicted ``pred_x_start`` together with
    ``x`` and ``t`` to ``self.q_posterior`` and returns its result.
    """
    prediction_kwargs = {
        "x": x,
        "t": t,
        "action_cond": action_cond,
        "pose_cond": pose_cond,
        "reference_length": reference_length,
        "frame_idx": frame_idx,
        "memory_tokens": memory_tokens,
        "memory_token_mask": memory_token_mask,
        "memory_retrieval_tokens": memory_retrieval_tokens,
        "memory_retrieval_mask": memory_retrieval_mask,
    }
    # Extra keywords can never collide with the explicit ones above, because
    # those names are bound by this method's own parameters.
    prediction_kwargs.update(memory_kwargs)
    prediction = self.model_predictions(**prediction_kwargs)
    return self.q_posterior(x_start=prediction.pred_x_start, x_t=x, t=t)
def compute_loss_weights(self, noise_levels: torch.Tensor): + + snr = self.snr[noise_levels] + clipped_snr = self.clipped_snr[noise_levels] + normalized_clipped_snr = clipped_snr / self.snr_clip + normalized_snr = snr / self.snr_clip + + if not self.use_fused_snr: + # min SNR reweighting + match self.objective: + case "pred_noise": + return clipped_snr / snr + case "pred_x0": + return clipped_snr + case "pred_v": + return clipped_snr / (snr + 1) + + cum_snr = torch.zeros_like(normalized_snr) + for t in range(0, noise_levels.shape[0]): + if t == 0: + cum_snr[t] = normalized_clipped_snr[t] + else: + cum_snr[t] = self.cum_snr_decay * cum_snr[t - 1] + (1 - self.cum_snr_decay) * normalized_clipped_snr[t] + + cum_snr = F.pad(cum_snr[:-1], (0, 0, 1, 0), value=0.0) + clipped_fused_snr = 1 - (1 - cum_snr * self.cum_snr_decay) * (1 - normalized_clipped_snr) + fused_snr = 1 - (1 - cum_snr * self.cum_snr_decay) * (1 - normalized_snr) + + match self.objective: + case "pred_noise": + return clipped_fused_snr / fused_snr + case "pred_x0": + return clipped_fused_snr * self.snr_clip + case "pred_v": + return clipped_fused_snr * self.snr_clip / (fused_snr * self.snr_clip + 1) + case _: + raise ValueError(f"unknown objective {self.objective}") + + def forward( + self, + x: torch.Tensor, + action_cond: Optional[torch.Tensor], + pose_cond, + noise_levels: torch.Tensor, + reference_length, + frame_idx=None, + memory_tokens=None, + memory_token_mask=None, + memory_retrieval_tokens=None, + memory_retrieval_mask=None, + **memory_kwargs, + ): + noise = torch.randn_like(x) + noise = torch.clamp(noise, -self.clip_noise, self.clip_noise) + + noised_x = self.q_sample(x_start=x, t=noise_levels, noise=noise) + + model_pred = self.model_predictions(x=noised_x, t=noise_levels, action_cond=action_cond, + pose_cond=pose_cond,reference_length=reference_length, frame_idx=frame_idx, + memory_tokens=memory_tokens, memory_token_mask=memory_token_mask, + memory_retrieval_tokens=memory_retrieval_tokens, + 
def sample_step(
    self,
    x: torch.Tensor,
    action_cond: Optional[torch.Tensor],
    pose_cond,
    curr_noise_level: torch.Tensor,
    next_noise_level: torch.Tensor,
    guidance_fn: Optional[Callable] = None,
    current_frame=None,
    mode="training",
    reference_length=None,
    frame_idx=None,
    memory_tokens=None,
    memory_token_mask=None,
    memory_retrieval_tokens=None,
    memory_retrieval_mask=None,
    **memory_kwargs,
):
    """Take one sampling step, dispatching to the DDIM or DDPM sampler.

    ``curr_noise_level`` and ``next_noise_level`` are indices in
    ``0 .. sampling_timesteps``. They are mapped onto real noise levels in
    ``-1 .. timesteps - 1`` before the chosen sampler is called. The DDPM
    path does not receive ``next_noise_level``, ``current_frame`` or
    ``mode``.

    Raises:
        AssertionError: on the DDPM path, if the noise levels do not
            decrease by exactly one (or stay at -1), or if
            ``sampling_timesteps != timesteps``.
    """
    # Lookup table: sampling index -> real noise level.
    level_table = torch.linspace(
        -1, self.timesteps - 1, steps=self.sampling_timesteps + 1, device=x.device
    ).long()
    curr_noise_level = level_table[curr_noise_level]
    next_noise_level = level_table[next_noise_level]

    # Keyword arguments both samplers accept.
    common_kwargs = dict(
        x=x,
        action_cond=action_cond,
        pose_cond=pose_cond,
        curr_noise_level=curr_noise_level,
        guidance_fn=guidance_fn,
        reference_length=reference_length,
        frame_idx=frame_idx,
        memory_tokens=memory_tokens,
        memory_token_mask=memory_token_mask,
        memory_retrieval_tokens=memory_retrieval_tokens,
        memory_retrieval_mask=memory_retrieval_mask,
        **memory_kwargs,
    )

    if self.is_ddim_sampling:
        return self.ddim_sample_step(
            next_noise_level=next_noise_level,
            current_frame=current_frame,
            mode=mode,
            **common_kwargs,
        )

    # FIXME: temporary sanity checks for ddpm sampling
    steps_down_by_one = curr_noise_level - 1 == next_noise_level
    both_clean = (curr_noise_level == -1) & (next_noise_level == -1)
    assert torch.all(steps_down_by_one | both_clean), "Wrong noise level given for ddpm sampling."

    assert (
        self.sampling_timesteps == self.timesteps
    ), "sampling_timesteps should be equal to timesteps for ddpm sampling."

    return self.ddpm_sample_step(**common_kwargs)
def ddim_sample_step(
    self,
    x: torch.Tensor,
    action_cond: Optional[torch.Tensor],
    pose_cond,
    curr_noise_level: torch.Tensor,
    next_noise_level: torch.Tensor,
    guidance_fn: Optional[Callable] = None,
    current_frame=None,
    mode="training",
    reference_length=None,
    frame_idx=None,
    memory_tokens=None,
    memory_token_mask=None,
    memory_retrieval_tokens=None,
    memory_retrieval_mask=None,
    **memory_kwargs,
):
    """Take one DDIM step from ``curr_noise_level`` to ``next_noise_level``.

    Noise levels are real levels: ``-1`` means "clean", otherwise an index
    into ``self.alphas_cumprod``. Entries whose current level is ``-1`` are
    fed to the model rescaled as if at ``self.stabilization_level - 1``.
    Entries whose level does not change keep their input value.

    Args:
        x: current sample.
        curr_noise_level / next_noise_level: integer tensors of real noise
            levels.
        guidance_fn: optional callable mapping the predicted ``x_start`` to
            a loss. The negative gradient of that loss w.r.t. ``x`` is
            added to the predicted noise.
        Remaining arguments are forwarded to ``self.model_predictions``.

    Returns:
        The updated sample, same shape as ``x``.
    """
    # Map noise level -1 to self.stabilization_level - 1.
    clipped_curr_noise_level = torch.where(
        curr_noise_level < 0,
        torch.full_like(curr_noise_level, self.stabilization_level - 1, dtype=torch.long),
        curr_noise_level,
    )

    # Treating clean entries as stabilized means scaling them by
    # sqrt(alpha_cumprod); q_sample with zero noise does exactly that.
    orig_x = x.clone().detach()
    scaled_context = self.q_sample(
        x,
        clipped_curr_noise_level,
        noise=torch.zeros_like(x),
    )
    x = torch.where(self.add_shape_channels(curr_noise_level < 0), scaled_context, orig_x)

    alpha = self.alphas_cumprod[clipped_curr_noise_level]
    next_is_clean = next_noise_level < 0
    # Build the constant branches in alpha's floating dtype instead of the
    # integer dtype of next_noise_level, so torch.where never depends on
    # implicit int/float type promotion.
    alpha_next = torch.where(
        next_is_clean,
        torch.ones_like(next_noise_level, dtype=alpha.dtype),
        self.alphas_cumprod[next_noise_level],
    )
    sigma = torch.where(
        next_is_clean,
        torch.zeros_like(next_noise_level, dtype=alpha.dtype),
        self.ddim_sampling_eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt(),
    )
    c = (1 - alpha_next - sigma**2).sqrt()

    alpha_next = self.add_shape_channels(alpha_next)
    c = self.add_shape_channels(c)
    sigma = self.add_shape_channels(sigma)

    # Everything except x is identical for the guided and unguided call.
    predict_kwargs = dict(
        t=clipped_curr_noise_level,
        action_cond=action_cond,
        pose_cond=pose_cond,
        current_frame=current_frame,
        mode=mode,
        reference_length=reference_length,
        frame_idx=frame_idx,
        memory_tokens=memory_tokens,
        memory_token_mask=memory_token_mask,
        memory_retrieval_tokens=memory_retrieval_tokens,
        memory_retrieval_mask=memory_retrieval_mask,
        **memory_kwargs,
    )

    if guidance_fn is not None:
        with torch.enable_grad():
            x = x.detach().requires_grad_()
            model_pred = self.model_predictions(x=x, **predict_kwargs)

            guidance_loss = guidance_fn(model_pred.pred_x_start)
            grad = -torch.autograd.grad(
                guidance_loss,
                x,
            )[0]

            pred_noise = model_pred.pred_noise + (1 - alpha_next).sqrt() * grad
            x_start = self.predict_start_from_noise(x, clipped_curr_noise_level, pred_noise)
    else:
        model_pred = self.model_predictions(x=x, **predict_kwargs)
        x_start = model_pred.pred_x_start
        pred_noise = model_pred.pred_noise

    noise = torch.randn_like(x)
    noise = torch.clamp(noise, -self.clip_noise, self.clip_noise)

    x_pred = x_start * alpha_next.sqrt() + pred_noise * c + sigma * noise

    # Only update entries whose noise level actually changes.
    unchanged = curr_noise_level == next_noise_level
    return torch.where(self.add_shape_channels(unchanged), orig_x, x_pred)
def modulate(x, shift, scale):
    """Apply an affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` are tiled along axis 0 until they match
    ``x.shape[0]`` (integer division decides the repeat count), then given
    extra singleton axes just before their last axis until ``shift`` has as
    many axes as ``x``, so the final arithmetic broadcasts.
    """
    # One repeat factor of 1 for every non-leading axis of shift.
    keep_axes = (1,) * (shift.dim() - 1)
    shift = shift.repeat(x.shape[0] // shift.shape[0], *keep_axes)
    scale = scale.repeat(x.shape[0] // scale.shape[0], *keep_axes)

    # Both tensors get the same number of inserted axes, driven by shift.
    for _ in range(max(x.dim() - shift.dim(), 0)):
        shift = shift.unsqueeze(-2)
        scale = scale.unsqueeze(-2)

    scaled = x * (1 + scale)
    return scaled + shift
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A sinusoidal feature vector of size ``frequency_embedding_size`` is
    computed for each input value and passed through a two-layer MLP with a
    SiLU in between, producing ``hidden_size`` features.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256, freq_type='time_step'):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),  # hidden_size is diffusion model hidden size
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.freq_type = freq_type

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000, freq_type='time_step'):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings
                          (only used when freq_type is 'time_step').
        :param freq_type: 'time_step' for geometrically spaced frequencies,
                          'spatial' for integer multiples of pi, or 'angle'
                          for integer multiples of pi / 180.
        :return: an (N, D) Tensor of positional embeddings.
        :raises ValueError: if freq_type is not one of the values above.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2

        if freq_type == 'time_step':
            freqs = torch.exp(
                -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
            ).to(device=t.device)
        elif freq_type == 'spatial':  # ~(-5 5)
            freqs = torch.linspace(1.0, half, half).to(device=t.device) * torch.pi
        elif freq_type == 'angle':  # 0-360
            freqs = torch.linspace(1.0, half, half).to(device=t.device) * torch.pi / 180
        else:
            # Previously an unknown freq_type left `freqs` unbound and failed
            # below with an UnboundLocalError.
            raise ValueError(
                f"unknown freq_type {freq_type!r}; expected 'time_step', 'spatial' or 'angle'"
            )

        args = t[:, None].float() * freqs[None]

        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd dim: pad one zero column so the output is exactly `dim` wide.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        """Embed ``t`` sinusoidally, then project it through the MLP."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size, freq_type=self.freq_type)
        t_emb = self.mlp(t_freq)
        return t_emb
def reset_identity_init(self):
    """Zero the last linear layer of both modulation heads.

    Sets weight and bias of ``self.adaLN_modulation[-1]`` and
    ``self.memory_type_gate[-1]`` to 0, so both heads initially emit zeros
    before their downstream nonlinearity.
    """
    for head in (self.adaLN_modulation, self.memory_type_gate):
        final_linear = head[-1]
        for parameter in (final_linear.weight, final_linear.bias):
            nn.init.constant_(parameter, 0)
def _apply_memory_type(self, memory_tokens, memory_type_ids):
    """Scale and offset memory tokens according to their type id.

    Returns ``memory_tokens`` untouched when ``memory_type_ids`` is None.
    Otherwise looks up a per-type scale (``self.memory_type_scale``) and a
    per-type embedding (``self.memory_type_embed``), prepends singleton axes
    until they have as many axes as ``memory_tokens``, and returns
    ``memory_tokens * scale + embedding``.
    """
    if memory_type_ids is None:
        return memory_tokens

    token_dtype = memory_tokens.dtype
    ids = memory_type_ids.to(device=memory_tokens.device, dtype=torch.long)
    offset = self.memory_type_embed(ids).to(token_dtype)
    factor = self.memory_type_scale[ids].to(token_dtype)

    # Prepend leading singleton axes so the lookups broadcast over the
    # leading axes of memory_tokens.
    missing_axes = memory_tokens.dim() - offset.dim()
    if missing_axes > 0:
        lead = (None,) * missing_axes
        offset = offset[lead]
        factor = factor[lead]

    return memory_tokens * factor + offset
def _combine_memory_gate(self, memory_tokens, memory_token_gate, type_stage_gate):
    """Merge the per-token stream gate with the per-type stage gate.

    Returns ``type_stage_gate`` when no stream gate is given, the stream
    gate (moved to the tokens' device and dtype) when no stage gate is
    given, and their elementwise product when both are present.

    Raises:
        ValueError: if ``memory_token_gate`` does not have the shape of
            ``memory_tokens`` without its last axis.
    """
    if memory_token_gate is None:
        return type_stage_gate

    expected_shape = tuple(memory_tokens.shape[:-1])
    actual_shape = tuple(memory_token_gate.shape)
    if actual_shape != expected_shape:
        raise ValueError(
            f"memory_token_gate must have shape {expected_shape}, "
            f"got {actual_shape}"
        )

    stream_gate = memory_token_gate.to(device=memory_tokens.device, dtype=memory_tokens.dtype)
    if type_stage_gate is None:
        return stream_gate
    return type_stage_gate * stream_gate
def _gate_valid_mask(self, valid_rows, batch_size, num_frames, dtype, device):
    """Reshape a row-validity vector so it broadcasts over 3-D gate tensors.

    Returns None when ``valid_rows`` is None. A vector with ``batch_size``
    elements becomes shape ``(batch_size, 1, 1)``; one with
    ``batch_size * num_frames`` elements becomes
    ``(batch_size, num_frames, 1)``. The result has the requested ``dtype``
    and ``device``.

    Raises:
        ValueError: if the element count matches neither layout.
    """
    if valid_rows is None:
        return None

    mask = valid_rows.to(device=device, dtype=dtype)
    element_count = mask.numel()

    # One entry per batch element.
    if element_count == batch_size:
        return mask.view(batch_size, 1, 1)

    # One entry per (batch, frame) pair, flattened batch-major.
    if element_count == batch_size * num_frames:
        per_frame = rearrange(mask, "(b t) -> b t", b=batch_size, t=num_frames)
        return per_frame[:, :, None]

    raise ValueError(f"valid_rows has incompatible shape: {tuple(mask.shape)}")
def _store_diagnostics(self, output, base, gate_msa, gate_mlp, valid_rows):
    """Record summary statistics of the last forward pass on ``self``.

    Sets three attributes, all computed without gradient tracking:

    * ``last_gate_mean``: mean absolute value of the two gates
      (concatenated on the last axis), averaged over valid rows only when
      ``valid_rows`` is given.
    * ``last_valid_fraction``: mean of ``valid_rows``, or 1.0 when it is
      None.
    * ``last_delta_ratio``: ``||output - base|| / (||base|| + 1e-6)``.

    Args:
        output: block output.
        base: residual input the output is compared against; its first two
            axes are taken as (batch_size, num_frames).
        gate_msa, gate_mlp: gate tensors of identical shape.
        valid_rows: optional validity vector understood by
            ``self._gate_valid_mask``.
    """
    with torch.no_grad():
        batch_size, num_frames = base.shape[:2]
        gate_values = torch.cat(
            [gate_msa.detach().float().abs(), gate_mlp.detach().float().abs()],
            dim=-1,
        )
        gate_mask = self._gate_valid_mask(
            valid_rows,
            batch_size,
            num_frames,
            dtype=gate_values.dtype,
            device=gate_values.device,
        )
        if gate_mask is not None:
            masked_values = gate_values * gate_mask
            self.last_valid_fraction = valid_rows.detach().float().mean()
            # Count the elements that actually survive the mask. The mask
            # may be per-batch (B, 1, 1) while the gates carry a frame axis;
            # summing the un-broadcast mask would undercount and inflate
            # the mean by the number of frames.
            valid_count = gate_mask.expand_as(masked_values).sum().clamp_min(1.0)
            self.last_gate_mean = masked_values.sum() / valid_count
        else:
            self.last_valid_fraction = base.detach().new_tensor(1.0, dtype=torch.float32)
            self.last_gate_mean = gate_values.mean()

        base_f32 = base.detach().float()
        delta_norm = (output.detach().float() - base_f32).norm()
        # Epsilon guards against a zero-norm base.
        self.last_delta_ratio = delta_norm / (base_f32.norm() + 1e-6)
!= tuple(memory_tokens.shape[:2]): + raise ValueError( + f"legacy memory mask must have shape {tuple(memory_tokens.shape[:2])}, " + f"got {tuple(memory_token_mask.shape)}" + ) + out, valid_rows = self._attend( + query, + memory_tokens, + memory_token_mask=memory_token_mask, + memory_token_gate=effective_token_gate, + ) + out = rearrange(out, "b (t h w) d -> b t h w d", t=T, h=H, w=W) + elif memory_tokens.dim() == 4: + assert memory_tokens.shape[:2] == (B, T), ( + f"per-frame memory tokens must have shape (B, T, M, D), got {tuple(memory_tokens.shape)}" + ) + query = rearrange(query_source, "b t h w d -> (b t) (h w) d") + memory_tokens = self._apply_memory_type(self.norm_mem(memory_tokens), memory_type_ids) + memory_tokens = rearrange(memory_tokens, "b t m d -> (b t) m d") + if effective_token_gate is not None: + effective_token_gate = rearrange(effective_token_gate, "b t m -> (b t) m") + valid_rows = None + if memory_token_mask is not None: + expected_mask_shape = (B, T, memory_tokens.shape[1]) + if tuple(memory_token_mask.shape) != expected_mask_shape: + raise ValueError( + f"per-frame memory mask must have shape {expected_mask_shape}, " + f"got {tuple(memory_token_mask.shape)}" + ) + memory_token_mask = rearrange(memory_token_mask.bool(), "b t m -> (b t) m") + out, valid_rows = self._attend( + query, + memory_tokens, + memory_token_mask=memory_token_mask, + memory_token_gate=effective_token_gate, + ) + out = rearrange(out, "(b t) (h w) d -> b t h w d", b=B, t=T, h=H, w=W) + else: + raise ValueError(f"memory_tokens must be rank 3 or 4, got rank {memory_tokens.dim()}") + + valid_mask = self._valid_mask(valid_rows, B, T, dtype=out.dtype, device=out.device) + residual_gate_tensor = self._residual_gate(residual_gate, B, T, dtype=out.dtype, device=out.device) + attn_delta = gate(out, m_gate_msa) + if valid_mask is not None: + attn_delta = attn_delta * valid_mask + if residual_gate_tensor is not None: + attn_delta = attn_delta * residual_gate_tensor + output = 
residual_base + attn_delta + + mlp_delta = gate(self.mlp(modulate(self.norm_mlp(output), m_shift_mlp, m_scale_mlp)), m_gate_mlp) + if valid_mask is not None: + mlp_delta = mlp_delta * valid_mask + if residual_gate_tensor is not None: + mlp_delta = mlp_delta * residual_gate_tensor + output = output + mlp_delta + self._store_diagnostics(output, residual_base, m_gate_msa, m_gate_mlp, valid_rows) + if return_delta: + return attn_delta + mlp_delta + return output + +class SpatioTemporalDiTBlock(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + reference_length, + mlp_ratio=4.0, + is_causal=True, + spatial_rotary_emb: Optional[RotaryEmbedding] = None, + temporal_rotary_emb: Optional[RotaryEmbedding] = None, + use_memory_token_cross_attention=False, + ref_mode='sequential' + ): + super().__init__() + self.is_causal = is_causal + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + + self.s_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.s_attn = SpatialAxialAttention( + hidden_size, + heads=num_heads, + dim_head=hidden_size // num_heads, + rotary_emb=spatial_rotary_emb + ) + self.s_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.s_mlp = Mlp( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=approx_gelu, + drop=0, + ) + self.s_adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + self.t_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.t_attn = TemporalAxialAttention( + hidden_size, + heads=num_heads, + dim_head=hidden_size // num_heads, + is_causal=is_causal, + rotary_emb=temporal_rotary_emb, + reference_length=reference_length + ) + self.t_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.t_mlp = Mlp( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=approx_gelu, + drop=0, + ) + self.t_adaLN_modulation = 
nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)) + + self.reference_length = reference_length + self.use_memory_token_cross_attention = use_memory_token_cross_attention + if self.use_memory_token_cross_attention: + self.memory_token_cross_attn = MemoryTokenCrossAttention(hidden_size, num_heads, mlp_ratio=mlp_ratio) + + self.ref_mode = ref_mode + + if self.ref_mode == 'parallel': + self.parallel_map = nn.Linear(hidden_size, hidden_size) + + def _expand_memory_stream(self, tokens, mask, stream_gate, type_idx, batch_size, num_frames): + if tokens is None or tokens.shape[-2] == 0: + return None + if tokens.dim() == 3: + if tokens.shape[0] != batch_size: + raise ValueError(f"rank-3 memory tokens must start with B={batch_size}, got {tuple(tokens.shape)}") + tokens = tokens[:, None].expand(-1, num_frames, -1, -1) + if mask is None: + mask = torch.ones(tokens.shape[:3], device=tokens.device, dtype=torch.bool) + elif mask.dim() == 2: + mask = mask[:, None].expand(-1, num_frames, -1) + elif mask.dim() != 3: + raise ValueError(f"rank-3 stream mask must have rank 2 or 3, got {tuple(mask.shape)}") + elif tokens.dim() == 4: + if tuple(tokens.shape[:2]) != (batch_size, num_frames): + raise ValueError( + f"rank-4 memory tokens must start with (B,T)={(batch_size, num_frames)}, " + f"got {tuple(tokens.shape)}" + ) + if mask is None: + mask = torch.ones(tokens.shape[:3], device=tokens.device, dtype=torch.bool) + elif mask.dim() != 3: + raise ValueError(f"rank-4 stream mask must have rank 3, got {tuple(mask.shape)}") + else: + raise ValueError(f"memory stream tokens must be rank 3 or 4, got rank {tokens.dim()}") + if tuple(mask.shape) != tuple(tokens.shape[:3]): + raise ValueError(f"memory stream mask must have shape {tuple(tokens.shape[:3])}, got {tuple(mask.shape)}") + gate_tensor = self._expand_memory_stream_gate(stream_gate, tokens) + type_ids = torch.full((tokens.shape[2],), int(type_idx), device=tokens.device, dtype=torch.long) + return tokens, 
mask.to(device=tokens.device, dtype=torch.bool), gate_tensor, type_ids + + def _expand_memory_stream_gate(self, stream_gate, tokens): + batch_size, num_frames, num_tokens = tokens.shape[:3] + if stream_gate is None: + return torch.ones(tokens.shape[:3], device=tokens.device, dtype=tokens.dtype) + if not torch.is_tensor(stream_gate): + return torch.full(tokens.shape[:3], float(stream_gate), device=tokens.device, dtype=tokens.dtype) + gate_tensor = stream_gate.to(device=tokens.device, dtype=tokens.dtype) + if gate_tensor.dim() == 0: + return gate_tensor.view(1, 1, 1).expand(batch_size, num_frames, num_tokens) + if gate_tensor.dim() == 1: + if gate_tensor.numel() != batch_size: + raise ValueError(f"rank-1 memory gate must have B={batch_size} values, got {tuple(gate_tensor.shape)}") + return gate_tensor.view(batch_size, 1, 1).expand(batch_size, num_frames, num_tokens) + if gate_tensor.dim() == 2: + if tuple(gate_tensor.shape) == (batch_size, num_frames): + return gate_tensor[:, :, None].expand(batch_size, num_frames, num_tokens) + if tuple(gate_tensor.shape) == (batch_size, num_tokens): + return gate_tensor[:, None, :].expand(batch_size, num_frames, num_tokens) + raise ValueError( + f"rank-2 memory gate must have shape (B,T) or (B,M), got {tuple(gate_tensor.shape)}" + ) + if gate_tensor.dim() == 3: + if tuple(gate_tensor.shape) == (batch_size, num_frames, 1): + return gate_tensor.expand(batch_size, num_frames, num_tokens) + if tuple(gate_tensor.shape) == (batch_size, num_frames, num_tokens): + return gate_tensor + raise ValueError( + f"rank-3 memory gate must have shape (B,T,1) or (B,T,M), got {tuple(gate_tensor.shape)}" + ) + raise ValueError(f"memory gate rank must be <=3, got rank {gate_tensor.dim()}") + + def _pack_typed_memory_streams( + self, + batch_size, + num_frames, + memory_tokens=None, + memory_token_mask=None, + memory_dynamic_tokens=None, + memory_dynamic_mask=None, + memory_retrieval_tokens=None, + memory_retrieval_mask=None, + memory_anchor_gate=None, + 
def forward(self, x, c, current_frame=None, timestep=None, is_last_block=False,
            pose_cond=None, mode="training", c_action_cond=None, reference_length=None,
            memory_tokens=None, memory_token_mask=None, memory_dynamic_tokens=None, memory_dynamic_mask=None,
            memory_retrieval_tokens=None, memory_retrieval_mask=None, memory_anchor_gate=None,
            memory_dynamic_gate=None, memory_retrieval_gate=None):
    """Run spatial attention, temporal attention and optional memory cross-attention.

    Args:
        x: Rank-5 activations, unpacked as ``(B, T, H, W, D)``.
        c: Conditioning fed to the spatial adaLN modulation, to the memory
            adapter, and to the temporal modulation when ``c_action_cond``
            is ``None``.
        c_action_cond: Optional conditioning that replaces ``c`` for the
            temporal modulation only.
        memory_tokens / memory_dynamic_tokens / memory_retrieval_tokens:
            Optional memory streams, with their ``*_mask`` and ``*_gate``
            companions, forwarded to ``_pack_typed_memory_streams``.
        current_frame, timestep, is_last_block, pose_cond, mode,
        reference_length: Accepted for interface compatibility; this
            method does not read them.

    Returns:
        Tensor produced by the residual updates described below.
    """
    batch_size, num_frames, _, _, _ = x.shape

    # --- spatial sub-block -------------------------------------------------
    (s_shift_attn, s_scale_attn, s_gate_attn,
     s_shift_mlp, s_scale_mlp, s_gate_mlp) = self.s_adaLN_modulation(c).chunk(6, dim=-1)
    spatial_attn = self.s_attn(modulate(self.s_norm1(x), s_shift_attn, s_scale_attn))
    x = x + gate(spatial_attn, s_gate_attn)
    spatial_mlp = self.s_mlp(modulate(self.s_norm2(x), s_shift_mlp, s_scale_mlp))
    x = x + gate(spatial_mlp, s_gate_mlp)

    # --- temporal sub-block ------------------------------------------------
    # Action-augmented conditioning, when supplied, drives only this stage.
    temporal_cond = c if c_action_cond is None else c_action_cond
    (t_shift_attn, t_scale_attn, t_gate_attn,
     t_shift_mlp, t_scale_mlp, t_gate_mlp) = self.t_adaLN_modulation(temporal_cond).chunk(6, dim=-1)
    temporal_attn = self.t_attn(modulate(self.t_norm1(x), t_shift_attn, t_scale_attn))
    x_t = x + gate(temporal_attn, t_gate_attn)
    temporal_mlp = self.t_mlp(modulate(self.t_norm2(x_t), t_shift_mlp, t_scale_mlp))
    x_t = x_t + gate(temporal_mlp, t_gate_mlp)

    if self.ref_mode == 'sequential':
        x = x_t

    # --- memory cross-attention ---------------------------------------------
    # In non-sequential modes ``x`` is still the post-spatial activation here.
    if self.use_memory_token_cross_attention:
        packed_memory = self._pack_typed_memory_streams(
            batch_size,
            num_frames,
            memory_tokens=memory_tokens,
            memory_token_mask=memory_token_mask,
            memory_dynamic_tokens=memory_dynamic_tokens,
            memory_dynamic_mask=memory_dynamic_mask,
            memory_retrieval_tokens=memory_retrieval_tokens,
            memory_retrieval_mask=memory_retrieval_mask,
            memory_anchor_gate=memory_anchor_gate,
            memory_dynamic_gate=memory_dynamic_gate,
            memory_retrieval_gate=memory_retrieval_gate,
        )
        # No usable stream -> the adapter is skipped and x is left untouched.
        if packed_memory is not None:
            tokens, token_mask, token_gate, type_ids, residual_gate = packed_memory
            x = self.memory_token_cross_attn(
                x,
                c,
                tokens,
                token_mask,
                residual_gate=residual_gate,
                memory_type_ids=type_ids,
                memory_token_gate=token_gate,
            )

    if self.ref_mode == 'parallel':
        x = x_t + self.parallel_map(x)

    return x
+ """ + + def __init__( + self, + input_h=18, + input_w=32, + patch_size=2, + in_channels=16, + hidden_size=1024, + depth=12, + num_heads=16, + mlp_ratio=4.0, + action_cond_dim=25, + max_frames=32, + reference_length=8, + memory_token_cross_attention=False, + memory_cross_attn_layers=None, + ref_mode='sequential' + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.max_frames = max_frames + + self.x_embedder = PatchEmbed(input_h, input_w, patch_size, in_channels, hidden_size, flatten=False) + self.t_embedder = TimestepEmbedder(hidden_size) + + self.spatial_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads // 2, freqs_for="pixel", max_freq=256) + self.temporal_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads) + + self.external_cond = nn.Linear(action_cond_dim, hidden_size) if action_cond_dim > 0 else nn.Identity() + if memory_cross_attn_layers is None: + memory_cross_attn_layer_set = None + else: + memory_cross_attn_layer_set = {int(layer_idx) for layer_idx in memory_cross_attn_layers} + invalid_layers = sorted( + layer_idx for layer_idx in memory_cross_attn_layer_set if layer_idx < 0 or layer_idx >= depth + ) + if invalid_layers: + raise ValueError( + f"memory_cross_attn_layers contains invalid indices {invalid_layers} for depth={depth}" + ) + + self.blocks = nn.ModuleList( + [ + SpatioTemporalDiTBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + is_causal=True, + reference_length=reference_length, + spatial_rotary_emb=self.spatial_rotary_emb, + temporal_rotary_emb=self.temporal_rotary_emb, + use_memory_token_cross_attention=memory_token_cross_attention + and (memory_cross_attn_layer_set is None or block_idx in memory_cross_attn_layer_set), + ref_mode=ref_mode + ) + for block_idx in range(depth) + ] + ) + self.memory_token_cross_attention = memory_token_cross_attention + self.memory_cross_attn_layers = ( + None if 
def initialize_weights(self):
    """Apply the DiT initialisation scheme to every sub-module.

    Order of operations (kept so random draws happen in the same sequence):

    1. Xavier-uniform weights and zero biases for every ``nn.Linear``.
    2. Xavier-uniform on the patch-embedding projection, treated as a
       2-D matrix, with zero bias.
    3. ``N(0, 0.02)`` weights for layers 0 and 2 of the timestep MLP.
    4. Zero the last layer of each block's spatial/temporal adaLN
       modulation, the final layer's modulation and its output projection.
    5. If memory cross-attention is enabled, call ``reset_identity_init``
       on every block that owns a ``memory_token_cross_attn`` adapter.
    """

    def _xavier_linear(module):
        if not isinstance(module, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)

    def _zero_linear(layer):
        nn.init.constant_(layer.weight, 0)
        nn.init.constant_(layer.bias, 0)

    self.apply(_xavier_linear)

    # Patch embedding is initialised like a linear layer over flattened patches.
    patch_proj = self.x_embedder.proj
    patch_weight = patch_proj.weight.data
    nn.init.xavier_uniform_(patch_weight.view([patch_weight.shape[0], -1]))
    nn.init.constant_(patch_proj.bias, 0)

    for layer_idx in (0, 2):
        nn.init.normal_(self.t_embedder.mlp[layer_idx].weight, std=0.02)

    for block in self.blocks:
        _zero_linear(block.s_adaLN_modulation[-1])
        _zero_linear(block.t_adaLN_modulation[-1])

    _zero_linear(self.final_layer.adaLN_modulation[-1])
    _zero_linear(self.final_layer.linear)

    if not self.memory_token_cross_attention:
        return
    for block in self.blocks:
        adapter = getattr(block, "memory_token_cross_attn", None)
        if adapter is not None:
            adapter.reset_identity_init()
def unpatchify(self, x):
    """Fold per-patch channel vectors back into an image grid.

    Args:
        x: Tensor of shape ``(N, h, w, p * p * C)`` where ``p`` is the
            patch size taken from ``self.x_embedder.patch_size[0]`` and
            ``C`` is ``self.out_channels``.

    Returns:
        Tensor of shape ``(N, C, h * p, w * p)``.
    """
    channels = self.out_channels
    patch = self.x_embedder.patch_size[0]
    batch, grid_h, grid_w = x.shape[0], x.shape[1], x.shape[2]

    # (N, h, w, p, q, C) -> (N, C, h, p, w, q): interleave each patch's rows
    # and columns with the grid position they belong to.
    patches = x.reshape(batch, grid_h, grid_w, patch, patch, channels)
    interleaved = patches.permute(0, 5, 1, 3, 2, 4)
    return interleaved.reshape(batch, channels, grid_h * patch, grid_w * patch)
def DiT_S_2(
    action_cond_dim,
    reference_length,
    ref_mode,
    memory_token_cross_attention=False,
    memory_cross_attn_layers=None,
):
    """Build the ``DiT-S/2`` preset.

    The architecture is fixed to patch size 2, hidden size 1024, depth 16
    and 16 attention heads; every argument is forwarded unchanged to
    ``DiT``.
    """
    architecture = {
        "patch_size": 2,
        "hidden_size": 1024,
        "depth": 16,
        "num_heads": 16,
    }
    caller_options = {
        "action_cond_dim": action_cond_dim,
        "reference_length": reference_length,
        "memory_token_cross_attention": memory_token_cross_attention,
        "memory_cross_attn_layers": memory_cross_attn_layers,
        "ref_mode": ref_mode,
    }
    return DiT(**architecture, **caller_options)
class PosePredictionNet(nn.Module):
    """Predict the next pose from an image, the current pose and an action.

    The image goes through three stride-2 convolutions followed by global
    average pooling and a linear projection; pose and action are
    concatenated and passed through a two-layer MLP; both feature vectors
    are concatenated and mapped to ``pose_dim`` outputs.
    """

    def __init__(self, img_channels=16, img_feat_dim=256, pose_dim=5, action_dim=25, hidden_dim=128):
        super().__init__()

        # Conv stack: channels img_channels -> 32 -> 64 -> 128, each followed
        # by ReLU, then pooled down to a single spatial position.
        widths = (img_channels, 32, 64, 128)
        conv_layers = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            conv_layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1))
            conv_layers.append(nn.ReLU())
        conv_layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.cnn = nn.Sequential(*conv_layers)

        self.fc_img = nn.Linear(128, img_feat_dim)

        self.mlp_motion = nn.Sequential(
            nn.Linear(pose_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        self.fc_out = nn.Sequential(
            nn.Linear(img_feat_dim + hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, pose_dim),
        )

    def forward(self, img, action, pose):
        """Return the predicted next pose.

        ``img`` is fed to the conv stack; ``pose`` and ``action`` are
        concatenated along dim 1 (pose first). The result has one row per
        image and ``pose_dim`` columns.
        """
        pooled = self.cnn(img)
        image_features = self.fc_img(pooled.view(img.size(0), -1))

        motion_input = torch.cat([pose, action], dim=1)
        motion_features = self.mlp_motion(motion_input)

        fused = torch.cat([image_features, motion_features], dim=1)
        return self.fc_out(fused)
@autocast("cuda", enabled=False)
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0, seq_dim=-2):
    """Rotate a slice of the last dimension of ``t`` by the angles in ``freqs``.

    Features ``[start_index, start_index + freqs.shape[-1])`` are rotated;
    features outside that window are passed through unchanged. CUDA
    autocast is disabled for the computation and the result is cast back
    to the dtype ``t`` had on entry.

    Args:
        freqs: Rotation angles; the size of its last dimension is the
            number of rotated features.
        t: Tensor to rotate.
        start_index: First feature index that gets rotated.
        scale: Multiplier applied to both the cosine and sine terms.
        seq_dim: Sequence axis of ``t``; only consulted when ``t`` is
            rank 3, in which case ``freqs`` is trimmed to its trailing
            ``t.shape[seq_dim]`` entries.
    """
    input_dtype = t.dtype

    if t.ndim == 3:
        freqs = freqs[-t.shape[seq_dim]:]

    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"

    # Slicing (rather than in-place assignment) keeps ``t`` unmodified.
    untouched_head = t[..., :start_index]
    to_rotate = t[..., start_index:end_index]
    untouched_tail = t[..., end_index:]

    cos_part = to_rotate * freqs.cos() * scale
    sin_part = rotate_half(to_rotate) * freqs.sin() * scale
    rotated = cos_part + sin_part

    return torch.cat((untouched_head, rotated, untouched_tail), dim=-1).type(input_dtype)
f", rotations, freq_ranges) + rotations = rearrange(rotations, "... r f -> ... (r f)") + + rotations = repeat(rotations, "... n -> ... (n r)", r=2) + return apply_rotary_emb(rotations, t, start_index=start_index) + + +# classes + + +class RotaryEmbedding(Module): + def __init__( + self, + dim, + custom_freqs: Tensor | None = None, + freqs_for: Literal["lang", "pixel", "constant"] = "lang", + theta=10000, + max_freq=10, + num_freqs=1, + learned_freq=False, + use_xpos=False, + xpos_scale_base=512, + interpolate_factor=1.0, + theta_rescale_factor=1.0, + seq_before_head_dim=False, + cache_if_possible=True, + cache_max_seq_len=8192, + ): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + + theta *= theta_rescale_factor ** (dim / (dim - 2)) + + self.freqs_for = freqs_for + + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "spacetime": + time_freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + + if freqs_for == "spacetime": + self.time_freqs = nn.Parameter(time_freqs, requires_grad=learned_freq) + self.freqs = nn.Parameter(freqs, requires_grad=learned_freq) + + self.cache_if_possible = cache_if_possible + self.cache_max_seq_len = cache_max_seq_len + + self.register_buffer("cached_freqs", torch.zeros(cache_max_seq_len, dim), persistent=False) + self.register_buffer("cached_freqs_seq_len", torch.tensor(0), persistent=False) + + self.learned_freq = learned_freq + + # dummy for 
def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0):
    """Rotate queries against a longer (cached) key sequence.

    The queries are assumed to be the trailing ``q_len`` positions of the
    key sequence, so they are rotated with a position offset of
    ``k_len - q_len + offset`` while the keys start at position 0.

    Args:
        q: Query tensor.
        k: Key tensor whose length along ``seq_dim`` is at least that of ``q``.
        seq_dim: Sequence axis; defaults to ``self.default_seq_dim``.
        offset: Extra position offset added to the queries.

    Returns:
        ``(rotated_q, rotated_k)`` cast back to the dtypes of ``q`` and ``k``.

    Raises:
        AssertionError: If ``q`` is longer than ``k`` along ``seq_dim``.
    """
    dtype, device = q.dtype, q.device
    if seq_dim is None:
        seq_dim = self.default_seq_dim

    q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
    assert q_len <= k_len

    q_scale = k_scale = 1.0

    if self.use_xpos:
        seq = self.get_seq_pos(k_len, dtype=dtype, device=device)

        q_scale = self.get_scale(seq[-q_len:]).type(dtype)
        k_scale = self.get_scale(seq).type(dtype)

    # ``rotate_queries_or_keys`` takes the frequency bank as a required
    # argument; the previous calls omitted it and raised TypeError on every
    # invocation. The module's own ``self.freqs`` is the bank to use here.
    rotated_q = self.rotate_queries_or_keys(
        q, self.freqs, seq_dim=seq_dim, scale=q_scale, offset=k_len - q_len + offset
    )
    rotated_k = self.rotate_queries_or_keys(
        k, self.freqs, seq_dim=seq_dim, scale=k_scale**-1
    )

    rotated_q = rotated_q.type(q.dtype)
    rotated_k = rotated_k.type(k.dtype)

    return rotated_q, rotated_k
def get_axial_freqs(self, *dims):
    """Build a concatenated frequency grid with one axis per entry of ``dims``.

    For every axis a 1-D position vector is turned into frequencies via
    ``self.forward`` and reshaped so it varies only along its own axis.
    The per-axis tensors are broadcast against each other and concatenated
    on the last dimension.

    Positions are ``linspace(-1, 1)`` for the last two axes when
    ``freqs_for`` is ``"pixel"`` or ``"spacetime"``, and ``arange``
    otherwise. In ``"spacetime"`` mode the non-pixel axes use
    ``self.time_freqs`` in place of ``self.freqs``.
    """
    num_axes = len(dims)
    pixel_style = self.freqs_for in ("pixel", "spacetime")
    everything = slice(None)
    per_axis = []

    for axis, size in enumerate(dims):
        # Only the trailing two axes are ever treated as pixel axes.
        is_pixel_axis = pixel_style and axis >= num_axes - 2

        if is_pixel_axis:
            positions = torch.linspace(-1, 1, steps=size, device=self.device)
        else:
            positions = torch.arange(size, device=self.device)

        use_time_bank = self.freqs_for == "spacetime" and not is_pixel_axis
        bank = self.time_freqs if use_time_bank else self.freqs
        axis_freqs = self.forward(positions, bank, seq_len=size)

        # Insert singleton dims everywhere except this axis and the feature dim.
        placement = [None] * num_axes
        placement[axis] = everything
        per_axis.append(axis_freqs[(Ellipsis, *placement, everything)])

    return torch.cat(broadcast_tensors(*per_axis), dim=-1)
def cosine_beta_schedule(timesteps, s=0.008):
    """Return the cosine noise schedule as a float64 tensor of betas.

    Proposed in https://openreview.net/forum?id=-NEXDKk8gZ. The cumulative
    signal level follows a squared cosine over ``timesteps + 1`` evenly
    spaced points, normalised so it starts at 1; each beta is one minus
    the ratio of consecutive levels, clipped to ``[0, 0.999]``.

    Args:
        timesteps: Number of diffusion steps (length of the result).
        s: Small offset applied inside the cosine.
    """
    num_points = timesteps + 1
    fraction = torch.linspace(0, timesteps, num_points, dtype=torch.float64) / timesteps

    signal_level = torch.cos((fraction + s) / (1 + s) * math.pi * 0.5) ** 2
    signal_level = signal_level / signal_level[0]

    step_ratio = signal_level[1:] / signal_level[:-1]
    return torch.clip(1 - step_ratio, 0, 0.999)
def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
    """Encode a sequence of action dicts as a ``(len(actions), len(ACTION_KEYS))`` tensor.

    Column order follows ``ACTION_KEYS``. Non-camera entries are copied as
    they are and must lie in ``[0, 1]``. ``cameraX`` / ``cameraY`` are read
    from elements 0 / 1 of the ``"camera"`` entry and mapped with
    ``(value - 40) / 40``, which must land in ``[-1, 1]`` (with a 1e-3
    tolerance).

    Raises:
        ValueError: For a key starting with ``"camera"`` that is neither
            ``cameraX`` nor ``cameraY``.
        AssertionError: If a value falls outside its permitted range.
    """
    camera_component = {"cameraX": 0, "cameraY": 1}
    # 20 / 0.5 = 40 buckets on each side of the centre bucket.
    max_val = 20
    bin_size = 0.5
    num_buckets = int(max_val / bin_size)

    encoded = torch.zeros(len(actions), len(ACTION_KEYS))
    for row, step_actions in enumerate(actions):
        for col, action_key in enumerate(ACTION_KEYS):
            if not action_key.startswith("camera"):
                value = step_actions[action_key]
                assert 0 <= value <= 1, f"Action value must be in [0, 1] got {value}"
            else:
                if action_key not in camera_component:
                    raise ValueError(f"Unknown camera action key: {action_key}")
                value = step_actions["camera"][camera_component[action_key]]
                value = (value - num_buckets) / num_buckets
                assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}"
            encoded[row, col] = value

    return encoded
path.lower().split(".")[-1] in IMAGE_EXTENSIONS: + print("prompt is image; ignoring video_offset and n_prompt_frames") + prompt = read_image(path) + # add frame dimension + prompt = rearrange(prompt, "c h w -> 1 c h w") + elif path.lower().split(".")[-1] in VIDEO_EXTENSIONS: + prompt = read_video(path, pts_unit="sec")[0] + if video_offset is not None: + prompt = prompt[video_offset:] + prompt = prompt[:n_prompt_frames] + else: + raise ValueError(f"unrecognized prompt file extension; expected one in {IMAGE_EXTENSIONS} or {VIDEO_EXTENSIONS}") + assert prompt.shape[0] == n_prompt_frames, f"input prompt {path} had less than n_prompt_frames={n_prompt_frames} frames" + prompt = resize(prompt, (360, 640)) + # add batch dimension + prompt = rearrange(prompt, "t c h w -> 1 t c h w") + prompt = prompt.float() / 255.0 + return prompt + + +def load_actions(path, action_offset=None): + if path.endswith(".actions.pt"): + actions = one_hot_actions(torch.load(path)) + elif path.endswith(".one_hot_actions.pt"): + actions = torch.load(path, weights_only=True) + else: + raise ValueError("unrecognized action file extension; expected '*.actions.pt' or '*.one_hot_actions.pt'") + if action_offset is not None: + actions = actions[action_offset:] + actions = torch.cat([torch.zeros_like(actions[:1]), actions], dim=0) + # add batch dimension + actions = rearrange(actions, "t d -> 1 t d") + return actions diff --git a/algorithms/worldmem/models/vae.py b/algorithms/worldmem/models/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..6cad52b41fd533d4ecdcc964a106c03b170cea64 --- /dev/null +++ b/algorithms/worldmem/models/vae.py @@ -0,0 +1,359 @@ +""" +References: + - VQGAN: https://github.com/CompVis/taming-transformers + - MAE: https://github.com/facebookresearch/mae +""" + +import numpy as np +import math +import functools +from collections import namedtuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from 
class DiagonalGaussianDistribution(object):
    """
    Diagonal Gaussian described by a tensor of concatenated means and log-variances.

    Args:
        parameters: tensor whose ``dim`` axis holds the mean half followed by the log-variance half.
        deterministic: when True, std and var are set to zero so ``sample()`` returns the mean.
        dim: axis along which ``parameters`` is split; only 1 and 2 are supported.

    Raises:
        NotImplementedError: if ``dim`` is neither 1 nor 2.
    """

    def __init__(self, parameters, deterministic=False, dim=1):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
        if dim == 1:
            self.dims = [1, 2, 3]
        elif dim == 2:
            self.dims = [1, 2]
        else:
            raise NotImplementedError
        # Clamp so exp() below can neither overflow nor collapse to zero.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # zeros_like already matches the device and dtype of the mean.
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Draw one reparameterized sample ``mean + std * eps`` with ``eps ~ N(0, I)``."""
        # Generate the noise directly on the mean's device and in its dtype:
        # no CPU allocation, no host-to-device copy, and no silent upcast of
        # half-precision latents to float32.
        noise = torch.randn_like(self.mean)
        return self.mean + self.std * noise

    def mode(self):
        """Return the distribution mode, which for a Gaussian is the mean."""
        return self.mean
class AttentionBlock(nn.Module):
    """
    Pre-norm transformer block: residual self-attention followed by a residual MLP.

    Args:
        dim: channel width of the tokens.
        num_heads: number of attention heads handed to ``Attention``.
        frame_height: token-grid height handed to ``Attention``.
        frame_width: token-grid width handed to ``Attention``.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: whether the attention's qkv projection uses a bias.
        attn_causal: accepted for interface compatibility; not used by this block.
        act_layer: activation class used inside the MLP.
        norm_layer: normalization class applied before each sub-layer.
    """

    def __init__(
        self,
        dim,
        num_heads,
        frame_height,
        frame_width,
        mlp_ratio=4.0,
        qkv_bias=False,
        attn_causal=False,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        hidden_width = int(dim * mlp_ratio)

        # Sub-modules are registered in the same order as before
        # (norm1, attn, norm2, mlp).
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads, frame_height, frame_width, qkv_bias=qkv_bias)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer)

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each wrapped in a residual connection."""
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))
mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + ) + for i in range(enc_depth) + ] + ) + self.enc_norm = norm_layer(enc_dim) + + # bottleneck + self.use_variational = use_variational + mult = 2 if self.use_variational else 1 + self.quant_conv = nn.Linear(enc_dim, mult * latent_dim) + self.post_quant_conv = nn.Linear(latent_dim, dec_dim) + + # decoder + self.decoder = nn.ModuleList( + [ + AttentionBlock( + dec_dim, + dec_heads, + self.seq_h, + self.seq_w, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + ) + for i in range(dec_depth) + ] + ) + self.dec_norm = norm_layer(dec_dim) + self.predictor = nn.Linear(dec_dim, self.patch_dim) # decoder to patch + + # initialize this weight first + self.initialize_weights() + + def initialize_weights(self): + # initialization + # initialize nn.Linear and nn.LayerNorm + self.apply(self._init_weights) + + # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0.0) + nn.init.constant_(m.weight, 1.0) + + def patchify(self, x): + # patchify + bsz, _, h, w = x.shape + x = x.reshape( + bsz, + 3, + self.seq_h, + self.patch_size, + self.seq_w, + self.patch_size, + ).permute([0, 1, 3, 5, 2, 4]) # [b, c, h, p, w, p] --> [b, c, p, p, h, w] + x = x.reshape(bsz, self.patch_dim, self.seq_h, self.seq_w) # --> [b, cxpxp, h, w] + x = x.permute([0, 2, 3, 1]).reshape(bsz, self.seq_len, self.patch_dim) # --> [b, hxw, cxpxp] + return x + + def unpatchify(self, x): + bsz = x.shape[0] + # unpatchify + x = x.reshape(bsz, self.seq_h, self.seq_w, self.patch_dim).permute([0, 3, 1, 2]) # [b, h, w, cxpxp] --> [b, cxpxp, h, w] + x = x.reshape( + 
bsz, + 3, + self.patch_size, + self.patch_size, + self.seq_h, + self.seq_w, + ).permute([0, 1, 4, 2, 5, 3]) # [b, c, p, p, h, w] --> [b, c, h, p, w, p] + x = x.reshape( + bsz, + 3, + self.input_height, + self.input_width, + ) # [b, c, hxp, wxp] + return x + + def encode(self, x): + # patchify + x = self.patch_embed(x) + + # encoder + for blk in self.encoder: + x = blk(x) + x = self.enc_norm(x) + + # bottleneck + moments = self.quant_conv(x) + if not self.use_variational: + moments = torch.cat((moments, torch.zeros_like(moments)), 2) + posterior = DiagonalGaussianDistribution(moments, deterministic=(not self.use_variational), dim=2) + return posterior + + def decode(self, z): + # bottleneck + z = self.post_quant_conv(z) + + # decoder + for blk in self.decoder: + z = blk(z) + z = self.dec_norm(z) + + # predictor + z = self.predictor(z) + + # unpatchify + dec = self.unpatchify(z) + return dec + + def autoencode(self, input, sample_posterior=True): + posterior = self.encode(input) + if self.use_variational and sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior, z + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def forward(self, inputs, labels, split="train"): + rec, post, latent = self.autoencode(inputs) + return rec, post, latent + + def get_last_layer(self): + return self.predictor.weight + + +def ViT_L_20_Shallow_Encoder(**kwargs): + if "latent_dim" in kwargs: + latent_dim = kwargs.pop("latent_dim") + else: + latent_dim = 16 + return AutoencoderKL( + latent_dim=latent_dim, + patch_size=20, + enc_dim=1024, + enc_depth=6, + enc_heads=16, + dec_dim=1024, + dec_depth=12, + dec_heads=16, + input_height=360, + input_width=640, + **kwargs, + ) + + +VAE_models = { + "vit-l-20-shallow-encoder": ViT_L_20_Shallow_Encoder, +} diff --git 
a/configurations/algorithm/base_algo.yaml b/configurations/algorithm/base_algo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3a116a5d5147fb8aede0a857ff057677186b8a54 --- /dev/null +++ b/configurations/algorithm/base_algo.yaml @@ -0,0 +1,3 @@ +# This will be passed as the cfg to Algo.__init__(cfg) of your algorithm class + +debug: ${debug} # inherited from configurations/config.yaml diff --git a/configurations/algorithm/base_pytorch_algo.yaml b/configurations/algorithm/base_pytorch_algo.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d8870c21cd59ff4e5e88e7fc5d97fdfe120f93f5 --- /dev/null +++ b/configurations/algorithm/base_pytorch_algo.yaml @@ -0,0 +1,4 @@ +defaults: + - base_algo # inherits from configurations/algorithm/base_algo.yaml + +lr: ${experiment.training.lr} diff --git a/configurations/algorithm/base_video_dit.yaml b/configurations/algorithm/base_video_dit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0119d67457c84b5d5a35ef76801473f7d1505fab --- /dev/null +++ b/configurations/algorithm/base_video_dit.yaml @@ -0,0 +1,36 @@ +defaults: + - df_base + +n_frames: ${dataset.n_frames} +frame_skip: ${dataset.frame_skip} +metadata: ${dataset.metadata} + +# training hyperparameters +weight_decay: 2e-3 +warmup_steps: 1000 +optimizer_beta: [0.9, 0.99] +action_cond_dim: 25 +use_plucker: true + +diffusion: + # training + beta_schedule: sigmoid + objective: pred_v + use_fused_snr: True + cum_snr_decay: 0.96 + clip_noise: 20. 
+ # sampling + sampling_timesteps: 20 + ddim_sampling_eta: 0.0 + stabilization_level: 15 + # architecture + architecture: + network_size: 64 + attn_heads: 4 + attn_dim_head: 64 + dim_mults: [1, 2, 4, 8] + resolution: ${dataset.resolution} + attn_resolutions: [16, 32, 64, 128] + use_init_temporal_attn: True + use_linear_attn: True + time_emb_type: rotary diff --git a/configurations/algorithm/dememwm_memory_dit.yaml b/configurations/algorithm/dememwm_memory_dit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fea9bd751cdec2e2be263eef335a4b96db3c33cb --- /dev/null +++ b/configurations/algorithm/dememwm_memory_dit.yaml @@ -0,0 +1,103 @@ + +defaults: + - base_video_dit + - _self_ + +_name: dememwm_memory_dit + +# Standalone Memory-DiT path. Do not route through old SSM-memory config. +memory_token_cross_attention: true +memory_cross_attn_layers: null +memory_condition_length: 0 +pose_cond_dim: 5 +log_video: false + +dememwm: + enabled: true + training_stage: stage_1 # fallback only when curriculum.enabled=false + debug_force_all_streams: false + curriculum: + enabled: true + full_stage_start_step: 60000 + freeze_vae: true + dit_freeze: + enabled: true + lr: + dememwm_modules: 1.0e-4 + memory_adapters: 1.0e-4 + full_dit: 1.0e-5 + # Current Conv2D memory projectors preserve latent H,W=(18,32). + # Pool sizes are resolved from projected spatial grid size and downsample ratios. + token_patch_size: 2 + anchor: + enabled: true + anchor_indices: [0, 1, 2, 3] + allow_generated_as_anchor: false + diverse_selection: true + compress: + downsample_ratio: 4 + dynamic: + enabled: true + exclude_latest_local_frames: 4 + recent_frames: 8 + conv_kernel_t: 3 + conv_stride_t: 2 + revisit: + enabled: true + deterministic_pose_retrieval: true + fov_overlap_threshold: 0.30 + high_quality_fov_threshold: 0.70 + plucker_weight: 0.10 + max_frames: 2 + # FoV geometry for coverage-based retrieval scoring. 
+ # fov_half_h/v: half-angles (degrees) of the horizontal/vertical field of view. + # fov_radius: world-space radius of the sample sphere. + # fov_{yaw,pitch,depth}_samples: grid resolution for FoV point sampling. + fov_half_h: 52.5 # 105 deg total horizontal FoV + fov_half_v: 37.5 # 75 deg total vertical FoV + fov_radius: 30.0 + fov_yaw_samples: 25 + fov_pitch_samples: 20 + fov_depth_samples: 20 + pose_preselect_topk: 64 + # Plucker descriptor grid for secondary pose-similarity scoring. + plucker_grid_h: 4 + plucker_grid_w: 4 + plucker_focal_length: 0.35 + compress: + downsample_ratio: 4 + stage_policy: + noise_bucket_logging: true + eval_ablation: + enabled: false + branch: A_plus_D_plus_R_normal + generated_history_proxy: + enabled: false + start_step: 0 + ramp_steps: 1 + max_prob: 0.0 + noise_std: 0.25 + dropout_prob: 0.0 + injection: + dit_hidden_size: 1024 + anchor_gate: 1.0 + dynamic_gate: 1.0 + revisit_gate: 1.0 + cache: + enabled: true + device: cpu + keep_raw_latents: all + keep_compressed_records: true + keep_prefix_anchors: true + eviction_policy: none + no_evict: true + clear_between_videos: true + max_records: null + max_slots: null + on_capacity_exceeded: warn + checkpoint: + strict_dememwm_eval_load: true + +diffusion: + architecture: + network_size: 64 diff --git a/configurations/algorithm/df_base.yaml b/configurations/algorithm/df_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..65be51cfe5837e3941a967955abb373cb40aaf76 --- /dev/null +++ b/configurations/algorithm/df_base.yaml @@ -0,0 +1,42 @@ +defaults: + - base_pytorch_algo + +# dataset-dependent configurations +x_shape: ${dataset.observation_shape} +frame_stack: 1 +frame_skip: 1 +data_mean: ${dataset.data_mean} +data_std: ${dataset.data_std} +external_cond_dim: 0 #${dataset.action_dim} +context_frames: ${dataset.context_length} +# training hyperparameters +weight_decay: 1e-4 +warmup_steps: 10000 +optimizer_beta: [0.9, 0.999] +# diffusion-related +uncertainty_scale: 
1 +guidance_scale: 0.0 +chunk_size: 1 # -1 for full trajectory diffusion, number to specify diffusion chunk size +scheduling_matrix: autoregressive +noise_level: random_all +causal: True + +diffusion: + # training + objective: pred_x0 + beta_schedule: cosine + schedule_fn_kwargs: {} + clip_noise: 20.0 + use_snr: False + use_cum_snr: False + use_fused_snr: False + snr_clip: 5.0 + cum_snr_decay: 0.98 + timesteps: 1000 + # sampling + sampling_timesteps: 50 # fixme, numer of diffusion steps, should be increased + ddim_sampling_eta: 1.0 + stabilization_level: 10 + # architecture + architecture: + network_size: 64 diff --git a/configurations/dataset/base_dataset.yaml b/configurations/dataset/base_dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e4fc0a9d01954d5968ce3438c5af00e97a8308ec --- /dev/null +++ b/configurations/dataset/base_dataset.yaml @@ -0,0 +1,3 @@ +# This will be passed as the cfg to Dataset.__init__(cfg) of your dataset class + +debug: ${debug} # inherited from configurations/config.yaml diff --git a/configurations/dataset/base_video.yaml b/configurations/dataset/base_video.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b364d9b62dea195afc4f36dc23b842b9a039bc80 --- /dev/null +++ b/configurations/dataset/base_video.yaml @@ -0,0 +1,14 @@ +defaults: + - base_dataset + +metadata: "data/${dataset.name}/metadata.json" +data_mean: "data/${dataset.name}/data_mean.npy" +data_std: "data/${dataset.name}/data_std.npy" +save_dir: ??? 
+n_frames: 32 +context_length: 4 +resolution: 128 +observation_shape: [3, "${dataset.resolution}", "${dataset.resolution}"] +external_cond_dim: 0 +validation_multiplier: 1 +frame_skip: 1 \ No newline at end of file diff --git a/configurations/dataset/video_minecraft.yaml b/configurations/dataset/video_minecraft.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b79b93dfa7baa70a66c8a29a3ac79181a13d2e25 --- /dev/null +++ b/configurations/dataset/video_minecraft.yaml @@ -0,0 +1,14 @@ +defaults: + - base_video + +save_dir: data/minecraft_simple_backforward +n_frames: 16 # TODO: increase later +resolution: 128 +data_mean: 0.5 +data_std: 0.5 +action_cond_dim: 25 +context_length: 1 +frame_skip: 1 +validation_multiplier: 1 + +_name: video_minecraft_oasis \ No newline at end of file diff --git a/configurations/dataset/video_minecraft_latent.yaml b/configurations/dataset/video_minecraft_latent.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a5475733121ddc4e5ca51c87fe5da0d11854ce5c --- /dev/null +++ b/configurations/dataset/video_minecraft_latent.yaml @@ -0,0 +1,6 @@ +defaults: + - video_minecraft + +precomputed_feature_dir: /share_1/users/bonan_ding/worldmem_data/minecraft/vae_features + +_name: video_minecraft_latent diff --git a/configurations/experiment/base_experiment.yaml b/configurations/experiment/base_experiment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f6884b5e8d8c937671fa88d2a37369234e2cdf30 --- /dev/null +++ b/configurations/experiment/base_experiment.yaml @@ -0,0 +1,2 @@ +debug: ${debug} # inherited from configurations/config.yaml +tasks: [main] # tasks to run sequantially, such as [training, test], useful when your project has multiple stages and you want to run only a subset of them. 
diff --git a/configurations/experiment/base_pytorch.yaml b/configurations/experiment/base_pytorch.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e789e66ff8eedb594047e10d73b2c2030121bdbe --- /dev/null +++ b/configurations/experiment/base_pytorch.yaml @@ -0,0 +1,51 @@ +# inherits from base_experiment.yaml +# most of the options have docs at https://lightning.ai/docs/pytorch/stable/common/trainer.html + +defaults: + - base_experiment + +tasks: [training] # tasks to run sequentially, change when your project has multiple stages and you want to run only a subset of them. +num_nodes: 1 # number of gpu servers used in large scale distributed training + +training: + precision: 16-mixed # set float precision, 16-mixed is faster while 32 is more stable + compile: False # whether to compile the model with torch.compile + lr: 0.001 # learning rate + batch_size: 16 # training batch size; effective batch size is this number * gpu * nodes iff using distributed training + max_epochs: 1000 # set to -1 to train forever + max_steps: -1 # set to -1 to train forever, will override max_epochs + max_time: null # set to something like "00:12:00:00" to enable + data: + num_workers: 4 # number of CPU threads for data preprocessing. + shuffle: True # whether training data will be shuffled + optim: + accumulate_grad_batches: 1 # accumulate gradients for n batches before backprop + gradient_clip_val: 1.0 # clip gradients with norm above this value, set to 0 to disable + checkpointing: + # these are arguments to pytorch lightning's callback, `ModelCheckpoint` class + every_n_train_steps: 5000 # save a checkpoint every n train steps + every_n_epochs: null # mutually exclusive with ``every_n_train_steps`` and ``train_time_interval`` + train_time_interval: null # in format of "00:12:00:00", mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``. 
+ save_last: True # keep last.ckpt for automatic interrupted-run resume + enable_version_counter: False # If this is ``False``, later checkpoints will overwrite previous ones. + +validation: + precision: 16-mixed + compile: False # whether to compile the model with torch.compile + batch_size: 16 # validation batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training + val_every_n_step: 2000 # controls how frequently we run validation, can be float (fraction of epochs) or int (steps) or null (if val_every_n_epoch is set) + val_every_n_epoch: null # if you want to do validation every n epochs, requires val_every_n_step to be null. + limit_batch: null # if null, run through validation set. Otherwise limit the number of batches to use for validation. + inference_mode: True # whether to run validation in inference mode (enable_grad won't work!) + data: + num_workers: 4 # number of CPU threads for data preprocessing, for validation. + shuffle: False # whether validation data will be shuffled + +test: + precision: 16-mixed + compile: False # whether to compile the model with torch.compile + batch_size: 4 # test batch size per GPU; effective batch size is this number * gpu * nodes iff using distributed training + limit_batch: null # if null, run through test set. Otherwise limit the number of batches to use for test. + data: + num_workers: 4 # number of CPU threads for data preprocessing, for test. 
+ shuffle: False # whether test data will be shuffled diff --git a/configurations/experiment/exp_video.yaml b/configurations/experiment/exp_video.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8d62a447e85947415424041504713c38e1debc9e --- /dev/null +++ b/configurations/experiment/exp_video.yaml @@ -0,0 +1,31 @@ +defaults: + - base_pytorch + +tasks: [training] + +training: + lr: 2e-5 + precision: 16-mixed + batch_size: 4 + max_epochs: -1 + max_steps: 2000005 + checkpointing: + every_n_train_steps: 2500 + optim: + gradient_clip_val: 1.0 + +validation: + val_every_n_step: 2500 + val_every_n_epoch: null + batch_size: 4 + limit_batch: 1 + +test: + limit_batch: 1 + batch_size: 1 + +logging: + metrics: + # - fvd + # - fid + # - lpips diff --git a/configurations/training.yaml b/configurations/training.yaml new file mode 100644 index 0000000000000000000000000000000000000000..acb48dac7ce0b937db5ed657c9b9236ba616da70 --- /dev/null +++ b/configurations/training.yaml @@ -0,0 +1,18 @@ +# configuration parsing starts here +defaults: + - experiment: exp_video # experiment yaml file name in configurations/experiments folder [fixme] + - dataset: video_minecraft # dataset yaml file name in configurations/dataset folder [fixme] + - algorithm: dememwm_memory_dit # algorithm yaml file name in configurations/algorithm folder [fixme] + - cluster: null # optional, cluster yaml file name in configurations/cluster folder. 
Leave null for local compute + +debug: false # global debug flag will be passed into configuration of experiment, dataset and algorithm + +wandb: + entity: turlin # wandb account name / organization name [fixme] + project: DeMemWM # wandb project name; if not provided, defaults to root folder name [fixme] + mode: offline # set wandb logging to online, offline or dryrun + +resume: null # wandb run id to resume logging and loading checkpoint from +load: null # wandb run id containing checkpoint or a path to a checkpoint file +auto_resume: true # automatically resume training from output_dir/checkpoints when available +resume_ckpt_path: null # explicit full Lightning checkpoint path for deterministic training resume diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7454c379131207e12ae7cf528b0e2c67ecdf78cc --- /dev/null +++ b/datasets/__init__.py @@ -0,0 +1 @@ +from .video import MinecraftVideoDataset, MinecraftVideoLatentDataset diff --git a/datasets/video/__init__.py b/datasets/video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6395c45af9c56ecdd83c1c1afd018c81bd80e653 --- /dev/null +++ b/datasets/video/__init__.py @@ -0,0 +1,2 @@ +from .minecraft_video_dataset import MinecraftVideoDataset +from .minecraft_video_latent_dataset import MinecraftVideoLatentDataset diff --git a/datasets/video/base_video_dataset.py b/datasets/video/base_video_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..dc8068cc50709957f333672d410f6610ccdde09d --- /dev/null +++ b/datasets/video/base_video_dataset.py @@ -0,0 +1,155 @@ +from typing import Sequence +import torch +import random +import os +import numpy as np +import cv2 +from omegaconf import DictConfig +from torchvision import transforms +from pathlib import Path +from abc import abstractmethod, ABC +import json + + +class BaseVideoDataset(torch.utils.data.Dataset, ABC): + """ + Base class for 
video datasets. Videos may be of variable length. + + Folder structure of each dataset: + - [save_dir] (specified in config, e.g., data/phys101) + - /[split] (one per split) + - /data_folder_name (e.g., videos) + metadata.json + """ + + def __init__(self, cfg: DictConfig, split: str = "training"): + super().__init__() + self.cfg = cfg + self.split = split + self.resolution = cfg.resolution + self.external_cond_dim = cfg.external_cond_dim + self.n_frames = ( + cfg.n_frames * cfg.frame_skip + if split == "training" + else cfg.n_frames * cfg.frame_skip * cfg.validation_multiplier + ) + self.frame_skip = cfg.frame_skip + self.save_dir = Path(cfg.save_dir) + self.save_dir.mkdir(exist_ok=True, parents=True) + self.split_dir = self.save_dir / f"{split}" + self.data_paths = self.get_data_paths(self.split) + + if self.split == 'training': + self.metadata = [1300] * len(self.data_paths) # total 1500 f, randomly sampled ~1300 clips + else: + self.metadata = [1] * len(self.data_paths) # total 1500 f + + self.clips_per_video = np.clip(np.array(self.metadata) - self.n_frames + 1, a_min=1, a_max=None).astype( + np.int32 + ) + self.cum_clips_per_video = np.cumsum(self.clips_per_video) + self.transform = transforms.Resize((self.resolution, self.resolution), antialias=True) + + # shuffle but keep the same order for each epoch, so validation sample is diverse yet deterministic + # Use seed from config if provided, otherwise default to 0 + shuffle_seed = cfg.get("seed", 0) + random.seed(shuffle_seed) + self.idx_remap = list(range(self.__len__())) + random.shuffle(self.idx_remap) + + @abstractmethod + def download_dataset(self) -> Sequence[int]: + """ + Download dataset from the internet and build it in save_dir + + Returns a list of video lengths + """ + raise NotImplementedError + + @abstractmethod + def get_data_paths(self, split): + """Return a list of data paths (e.g. 
xxx.mp4) for a given split""" + raise NotImplementedError + + def get_data_lengths(self, split): + """Return a list of num_frames for each data path (e.g. xxx.mp4) for a given split""" + lengths = [] + for path in self.get_data_paths(split): + length = cv2.VideoCapture(str(path)).get(cv2.CAP_PROP_FRAME_COUNT) + lengths.append(length) + return lengths + + def split_idx(self, idx): + video_idx = np.argmax(self.cum_clips_per_video > idx) + frame_idx = idx - np.pad(self.cum_clips_per_video, (1, 0))[video_idx] + return video_idx, frame_idx + + @staticmethod + def load_video(path: Path): + """ + Load video from a path + :param filename: path to the video + :return: video as a numpy array + """ + + cap = cv2.VideoCapture(str(path)) + + frames = [] + while cap.isOpened(): + ret, frame = cap.read() + if ret: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frames.append(frame) + else: + break + + cap.release() + frames = np.stack(frames, dtype=np.uint8) + return np.transpose(frames, (0, 3, 1, 2)) # (T, C, H, W) + + @staticmethod + def load_image(filename: Path): + """ + Load image from a path + :param filename: path to the image + :return: image as a numpy array + """ + image = cv2.imread(str(filename)) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + return np.transpose(image, (2, 0, 1)) + + def __len__(self): + return self.clips_per_video.sum() + + def __getitem__(self, idx): + idx = self.idx_remap[idx] + video_idx, frame_idx = self.split_idx(idx) + video_path = self.data_paths[video_idx] + video = self.load_video(video_path)[frame_idx : frame_idx + self.n_frames] + + pad_len = self.n_frames - len(video) + + nonterminal = np.ones(self.n_frames) + if len(video) < self.n_frames: + video = np.pad(video, ((0, pad_len), (0, 0), (0, 0), (0, 0))) + nonterminal[-pad_len:] = 0 + + video = torch.from_numpy(video / 256.0).float() + video = self.transform(video) + + if self.external_cond_dim: + external_cond = np.load( + # pylint: disable=no-member + self.condition_dir + / 
ACTION_KEYS = [
    "inventory",
    "ESC",
    "hotbar.1",
    "hotbar.2",
    "hotbar.3",
    "hotbar.4",
    "hotbar.5",
    "hotbar.6",
    "hotbar.7",
    "hotbar.8",
    "hotbar.9",
    "forward",
    "back",
    "left",
    "right",
    "cameraY",
    "cameraX",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]


def convert_action_space(actions):
    """
    Re-encode a 2-D table of raw discrete actions as rows over ACTION_KEYS.

    Each rule reads one source column of ``actions``; wherever that column
    equals the rule's trigger value, the rule's target column of the output is
    set to the rule's value. Entries matched by no rule stay zero.

    Returns a float tensor of shape (len(actions), len(ACTION_KEYS)).
    """
    # (source column, trigger value, target column, written value)
    rules = (
        (0, 1, 11, 1),
        (0, 2, 12, 1),
        (4, 11, 16, -1),
        (4, 13, 16, 1),
        (3, 11, 15, -1),
        (3, 13, 15, 1),
        (5, 6, 24, 1),
        (5, 1, 24, 1),
        (1, 1, 13, 1),
        (1, 2, 14, 1),
        (7, 1, 2, 1),
    )
    vec_25 = torch.zeros(len(actions), len(ACTION_KEYS))
    for source_col, trigger, target_col, written in rules:
        vec_25[actions[:, source_col] == trigger, target_col] = written
    return vec_25
class MinecraftVideoDataset(BaseVideoDataset):
    """
    Minecraft video dataset for training and validation.

    Each sample is a clip of ``n_frames`` frames with its actions and poses.
    For the training split, ``memory_condition_length`` extra "memory" frames
    from the same video are appended to the clip.

    Args:
        cfg (DictConfig): Configuration object.
        split (str): Dataset split ("training", "validation" or "test").
    """

    def __init__(self, cfg: DictConfig, split: str = "training"):
        # Set before the base constructor runs; presumably it calls
        # get_data_paths(), which reads this flag (TODO confirm).
        self.wo_updown = getattr(cfg, "wo_updown", False)
        super().__init__(cfg, split)

        # The old one-line conditional parsed as `A or (B and C)`, so the
        # validation split read cfg.n_frames_valid even when it was absent.
        if split in ("validation", "test") and hasattr(cfg, "n_frames_valid"):
            self.n_frames = cfg.n_frames_valid
        else:
            self.n_frames = cfg.n_frames

        self.memory_condition_length = getattr(cfg, "memory_condition_length", 8)
        self.customized_validation = cfg.customized_validation
        if split == "training":
            self.angle_range = cfg.angle_range
            self.pos_range = cfg.pos_range
        self.add_timestamp_embedding = getattr(cfg, "add_timestamp_embedding", True)
        self.training_dropout = 0.1
        self.sample_more_event = getattr(cfg, "sample_more_event", False)
        self.causal_frame = getattr(cfg, "causal_frame", False)

    def get_data_paths(self, split: str):
        """
        Retrieve all video file paths for the given split.

        With ``wo_updown`` set, paths containing "w_updown" are dropped for
        every split; otherwise evaluation splits keep only those paths.

        Args:
            split (str): Dataset split ("training", "validation" or "test").

        Returns:
            List[Path]: List of video file paths.
        """
        data_dir = self.save_dir / split
        paths = sorted(data_dir.glob("**/*.mp4"), key=lambda x: x.name)

        if self.wo_updown:
            paths = [p for p in paths if "w_updown" not in str(p)]
        elif split in ("validation", "test"):
            paths = [p for p in paths if "w_updown" in str(p)]

        if not paths:
            # Fallback kept from the original: collect every video under each
            # sub-directory. NOTE(review): this bypasses the filter above.
            for sub_dir in os.listdir(data_dir):
                sub_path = data_dir / sub_dir
                paths += sorted(sub_path.glob("**/*.mp4"), key=lambda x: x.name)
        return paths

    def download_dataset(self):
        """No-op: nothing is downloaded by this class."""
        pass

    def __getitem__(self, idx: int):
        """
        Retrieve a single data sample by index.

        A sample that fails to load is skipped and the next index is tried.

        Args:
            idx (int): Index of the data sample.

        Returns:
            Tuple: video, actions, poses and timestamps (see ``load_data``).

        Raises:
            RuntimeError: If no sample could be loaded after ``max_retries``
                attempts; the last underlying error is chained as the cause.
        """
        max_retries = 1000
        last_error = None
        for _ in range(max_retries):
            try:
                return self.load_data(idx)
            except Exception as error:  # best effort: skip unreadable samples
                last_error = error
                idx = (idx + 1) % len(self)
        # The original fell through and returned None, which only surfaced
        # later as an unrelated failure when batching.
        raise RuntimeError(
            f"Failed to load a sample after {max_retries} attempts"
        ) from last_error

    def load_data(self, idx):
        """
        Load one clip, plus memory frames for the training split.

        Args:
            idx (int): Dataset index before remapping through ``idx_remap``.

        Returns:
            Tuple: ``(video, actions, poses, timestamp)``. The first three are
            subsampled by ``frame_skip``; ``timestamp`` is returned as built.

        Raises:
            ValueError: If the pose heights vary by 2 or more, or the decoded
                clip is shorter than ``n_frames``.
        """
        # === 1. Remap index and skip first few frames ===
        idx = self.idx_remap[idx]
        file_idx, frame_idx = self.split_idx(idx)
        frame_idx += 100  # initial few frames are low quality

        # === 2. Load paths and data arrays ===
        video_path = self.data_paths[file_idx]
        action_path = video_path.with_suffix(".npz")
        data = np.load(action_path)
        actions_pool = convert_action_space(data["actions"])
        poses_pool = data["poses"]

        # Fix corrupted height (maybe) in the first frame
        poses_pool[0, 1] = poses_pool[1, 1]
        # np.ptp() instead of ndarray.ptp(): the method was removed in NumPy 2.0.
        height_span = np.ptp(poses_pool[:, 1])
        if not height_span < 2:
            raise ValueError(f"Height variation too large: {height_span} - {video_path}")

        # Pad poses if shorter than actions
        if len(poses_pool) < len(actions_pool):
            poses_pool = np.pad(poses_pool, ((1, 0), (0, 0)))

        # === 3. Load video clip ===
        video_raw = EncodedVideo.from_path(video_path, decode_audio=False)
        fps = 10
        clip = video_raw.get_clip(
            start_sec=frame_idx / fps,
            end_sec=(frame_idx + self.n_frames) / fps,
        )["video"]
        video = clip.permute(1, 2, 3, 0).numpy()

        actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames])
        poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames])

        # === 4. Normalize poses relative to current segment ===
        def normalize_pose(pose, ref_pose):
            # In place: shift the first three columns by the reference's first
            # row, negate the last column, wrap columns 3+ into [0, 360).
            pose[:, :3] -= ref_pose[:1, :3]
            pose[:, -1] = -pose[:, -1]
            pose[:, 3:] %= 360
            return pose

        # Order matters: the pool must be shifted by the un-normalized clip.
        poses_pool = normalize_pose(poses_pool, poses)
        poses = normalize_pose(poses, poses)

        if len(video) < self.n_frames:
            raise ValueError(f"{video_path}")

        # === 5. Sample memory frames for training ===
        if self.split == "training" and self.memory_condition_length > 0:
            use_memory = random.uniform(0, 1) > self.training_dropout

            if use_memory:
                # Compute pose distance between current and candidate frames
                dis = np.abs(poses[:, None] - poses_pool[None, :])
                dis[..., 3:][dis[..., 3:] > 180] = 360 - dis[..., 3:][dis[..., 3:] > 180]

                spatial_match = (dis[..., :3] <= self.pos_range).sum(-1) >= 3
                angular_match = (dis[..., 3:] <= self.angle_range).sum(-1) >= 2
                not_exact_match = ((dis[..., :3] > 0).sum(-1) >= 1) | ((dis[..., 3:] > 0).sum(-1) >= 1)

                valid_index = (spatial_match & angular_match & not_exact_match).sum(0)
                valid_index[:100] = 0  # skip unstable early frames

                # Exclude future if causality and timestamp are enabled
                if (
                    self.add_timestamp_embedding
                    and self.causal_frame
                    and (actions_pool[:frame_idx, 24] == 1).sum() > 0
                ):
                    valid_index[frame_idx:] = 0

                # Select indices satisfying condition
                mask = valid_index >= 1
                mask[0] = False
                candidate_indices = np.argwhere(mask)

                # Backup candidates with weaker conditions
                mask2 = valid_index >= 0
                mask2[0] = False

                count = min(self.memory_condition_length, candidate_indices.shape[0])
                selected = candidate_indices[
                    np.random.choice(candidate_indices.shape[0], count, replace=True)
                ][:, 0]

                if count < self.memory_condition_length:
                    extra = np.argwhere(mask2)
                    extra = extra[
                        np.random.choice(
                            extra.shape[0],
                            self.memory_condition_length - count,
                            replace=True,
                        )
                    ][:, 0]
                    selected = np.concatenate([selected, extra])

                # Prioritize event-trigger frames if applicable
                if self.sample_more_event and random.uniform(0, 1) > 0.3:
                    event_idx = torch.nonzero(actions_pool[:frame_idx, 24] == 1)[:, 0]
                    if len(event_idx) > self.memory_condition_length // 2:
                        # NOTE(review): `-n // 2` parses as `(-n) // 2`, so an
                        # odd n keeps one more frame than `n // 2`. Unchanged.
                        event_idx = event_idx[-self.memory_condition_length // 2:]
                    if len(event_idx) > 0:
                        selected[-len(event_idx):] = event_idx + 4

            else:
                # Memory dropout: repeat one random earlier frame.
                selected = np.full(self.memory_condition_length, random.randint(0, frame_idx))

            # === 6. Retrieve video frames for selected memory indices ===
            video_pool = []
            for si in selected:
                frame = video_raw.get_clip(start_sec=si / fps, end_sec=(si + 1) / fps)["video"][:, 0].permute(1, 2, 0)
                video_pool.append(frame)

            video = np.concatenate([video, np.stack(video_pool)], axis=0)
            actions = np.concatenate([actions, actions_pool[selected]], axis=0)
            poses = np.concatenate([poses, poses_pool[selected]], axis=0)
            timestamp = np.concatenate([np.arange(frame_idx, frame_idx + self.n_frames), selected])
        else:
            timestamp = np.arange(self.n_frames)

        # === 7. Convert video to torch format ===
        video = torch.from_numpy(video / 255.0).float().permute(0, 3, 1, 2).contiguous()

        # === 8. Return all items ===
        return (
            video[:: self.frame_skip],
            actions[:: self.frame_skip],
            poses[:: self.frame_skip],
            timestamp,
        )
class MinecraftVideoLatentDataset(MinecraftVideoDataset):
    """
    Minecraft dataset variant backed by precomputed VAE latent grids.

    It preserves the same action, pose, timestamp, and split indexing contract as
    MinecraftVideoDataset, but returns latent tensors instead of RGB frames.
    """

    def __init__(self, cfg: DictConfig, split: str = "training"):
        super().__init__(cfg, split)
        self.feature_root = Path(cfg.precomputed_feature_dir)
        # Single-entry cache: the most recently opened feature file.
        self._cached_feature_path = None
        self._cached_features = None
        self._cached_image_hw = None

    def _feature_cache_paths(self, video_path: Path) -> tuple[Path, Path]:
        """Return the ``(feature .npy, metadata .json)`` paths for a video.

        The video's location relative to ``save_dir`` is mirrored under
        ``feature_root``.
        """
        relative_video_path = video_path.relative_to(self.save_dir)
        feature_dir = self.feature_root / relative_video_path.parent
        video_index = relative_video_path.stem
        return (
            feature_dir / f"{video_index}_vae_feature.npy",
            feature_dir / f"{video_index}_vae_feature_meta.json",
        )

    def _load_precomputed_features(self, video_path: Path) -> tuple[np.ndarray, tuple[int, int]]:
        """Return the memory-mapped latents and image ``(height, width)``.

        Raises:
            FileNotFoundError: If the feature file does not exist.
        """
        feature_path, meta_path = self._feature_cache_paths(video_path)
        if not feature_path.exists():
            raise FileNotFoundError(f"Missing precomputed VAE features: {feature_path}")

        if self._cached_feature_path != str(feature_path):
            features = np.load(feature_path, mmap_mode="r")
            if meta_path.exists():
                with open(meta_path, "r", encoding="utf-8") as handle:
                    meta = json.load(handle)
                image_hw = (int(meta["image_height"]), int(meta["image_width"]))
            else:
                # No metadata: fall back to a square image of side `resolution`.
                image_hw = (int(self.resolution), int(self.resolution))
            # Commit all three fields together, so a failed metadata read
            # cannot leave new features paired with an old image size.
            self._cached_features = features
            self._cached_feature_path = str(feature_path)
            self._cached_image_hw = image_hw

        return self._cached_features, self._cached_image_hw

    def load_data(self, idx):
        """Load one clip of latents with its actions, poses and timestamps.

        Returns:
            Tuple: ``(latents, actions, poses, timestamp, image_hw)``; the
            first four are subsampled by ``frame_skip``.

        Raises:
            ValueError: If the pose heights vary by 2 or more, or fewer than
                ``n_frames`` latents are available.
        """
        idx = self.idx_remap[idx]
        file_idx, frame_idx = self.split_idx(idx)
        frame_idx += 100  # same initial-frame offset as the RGB dataset

        video_path = self.data_paths[file_idx]
        action_path = video_path.with_suffix(".npz")
        data = np.load(action_path)
        actions_pool = convert_action_space(data["actions"])
        poses_pool = data["poses"]

        poses_pool[0, 1] = poses_pool[1, 1]
        # np.ptp() instead of ndarray.ptp(): the method was removed in NumPy 2.0.
        height_span = np.ptp(poses_pool[:, 1])
        if not height_span < 2:
            raise ValueError(f"Height variation too large: {height_span} - {video_path}")
        if len(poses_pool) < len(actions_pool):
            poses_pool = np.pad(poses_pool, ((1, 0), (0, 0)))

        feature_pool, image_hw = self._load_precomputed_features(video_path)
        # Copy out of the memory map so the returned tensor owns its data.
        latents = np.array(
            feature_pool[frame_idx : frame_idx + self.n_frames],
            dtype=np.float32,
            copy=True,
        )
        actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames])
        poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames])

        def normalize_pose(pose, ref_pose):
            pose[:, :3] -= ref_pose[:1, :3]
            pose[:, -1] = -pose[:, -1]
            pose[:, 3:] %= 360
            return pose

        poses = normalize_pose(poses, poses)
        if len(latents) < self.n_frames:
            raise ValueError(f"{video_path}")

        timestamp = np.arange(self.n_frames)
        return (
            torch.from_numpy(latents).float().contiguous()[:: self.frame_skip],
            actions[:: self.frame_skip],
            poses[:: self.frame_skip],
            timestamp[:: self.frame_skip],
            torch.tensor(image_hw, dtype=torch.int32),
        )
feature_pool[frame_idx : frame_idx + self.n_frames], + dtype=np.float32, + copy=True, + ) + actions = np.copy(actions_pool[frame_idx : frame_idx + self.n_frames]) + poses = np.copy(poses_pool[frame_idx : frame_idx + self.n_frames]) + + def normalize_pose(pose, ref_pose): + pose[:, :3] -= ref_pose[:1, :3] + pose[:, -1] = -pose[:, -1] + pose[:, 3:] %= 360 + return pose + + poses = normalize_pose(poses, poses) + assert len(latents) >= self.n_frames, f"{video_path}" + + timestamp = np.arange(self.n_frames) + return ( + torch.from_numpy(latents).float().contiguous()[:: self.frame_skip], + actions[:: self.frame_skip], + poses[:: self.frame_skip], + timestamp[:: self.frame_skip], + torch.tensor(image_hw, dtype=torch.int32), + ) diff --git a/experiments/__init__.py b/experiments/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fb267849646a72ef4dc18dc10cced74a55ee108d --- /dev/null +++ b/experiments/__init__.py @@ -0,0 +1,33 @@ +from typing import Optional, Union +from omegaconf import DictConfig +import pathlib +from lightning.pytorch.loggers.wandb import WandbLogger + +from .exp_base import BaseExperiment +from .exp_video import VideoPredictionExperiment + +# each key has to be a yaml file under '[project_root]/configurations/experiment' without .yaml suffix +exp_registry = dict( + exp_video=VideoPredictionExperiment +) + + +def build_experiment( + cfg: DictConfig, + logger: Optional[WandbLogger] = None, + ckpt_path: Optional[Union[str, pathlib.Path]] = None, +) -> BaseExperiment: + """ + Build an experiment instance based on registry + :param cfg: configuration file + :param logger: optional logger for the experiment + :param ckpt_path: optional checkpoint path for saving and loading + :return: + """ + if cfg.experiment._name not in exp_registry: + raise ValueError( + f"Experiment {cfg.experiment._name} not found in registry {list(exp_registry.keys())}. 
" + "Make sure you register it correctly in 'experiments/__init__.py' under the same name as yaml file." + ) + + return exp_registry[cfg.experiment._name](cfg, logger, ckpt_path) diff --git a/experiments/exp_base.py b/experiments/exp_base.py new file mode 100644 index 0000000000000000000000000000000000000000..80853eef2a8588460acea7e1ff1d01147f59ec9d --- /dev/null +++ b/experiments/exp_base.py @@ -0,0 +1,706 @@ +""" +This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research +template [repo](https://github.com/buoyancy99/research-template). +By its MIT license, you must keep the above sentence in `README.md` +and the `LICENSE` file to credit the author. +""" + +from abc import ABC, abstractmethod +from typing import Optional, Union, Literal, List, Dict +import pathlib +import os + +import hydra +import torch +from lightning.pytorch.strategies.ddp import DDPStrategy + +import lightning.pytorch as pl +from lightning.pytorch.loggers.wandb import WandbLogger +from lightning.pytorch.utilities.types import TRAIN_DATALOADERS +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint +from pytorch_lightning.utilities import rank_zero_info + +from omegaconf import DictConfig + +from utils.print_utils import cyan +from utils.distributed_utils import is_rank_zero +from safetensors.torch import load_file, load_model +from pathlib import Path +from huggingface_hub import hf_hub_download +from huggingface_hub import model_info + +torch.set_float32_matmul_precision("high") + +def is_huggingface_model(path: str) -> bool: + hf_ckpt = str(path).split('/') + repo_id = '/'.join(hf_ckpt[:2]) + try: + model_info(repo_id) + return True + except: + return False + +def _extract_state_dict(checkpoint): + if isinstance(checkpoint, dict) and isinstance(checkpoint.get("state_dict"), dict): + return checkpoint["state_dict"] + return checkpoint + + +def _compatible_state_dict(algo, state_dict): + model_state = algo.state_dict() + prefixes = ( + "", + "model.", + 
"module.", + "algo.", + "diffusion_model.", + "diffusion_model.model.", + "vae.", + ) + best = ( + {}, + { + "matched": 0, + "missing": 0, + "unexpected": 0, + "skipped_prefix": 0, + "shape_mismatch": 0, + "missing_model": len(model_state), + "total": len(state_dict), + "prefix": "", + "unexpected_checkpoint_keys": [], + "shape_mismatch_keys": [], + "missing_model_keys": list(model_state.keys()), + }, + ) + for prefix in prefixes: + compatible = {} + unexpected_checkpoint_keys = [] + shape_mismatch_keys = [] + skipped_prefix = 0 + for key, value in state_dict.items(): + if key in ["data_mean", "data_std"]: + continue + if prefix and not key.startswith(prefix): + skipped_prefix += 1 + continue + stripped = key.removeprefix(prefix) + if stripped not in model_state: + unexpected_checkpoint_keys.append(key) + continue + if hasattr(value, "shape") and value.shape != model_state[stripped].shape: + shape_mismatch_keys.append( + ( + key, + stripped, + tuple(value.shape), + tuple(model_state[stripped].shape), + ) + ) + continue + compatible[stripped] = value + missing_model_keys = [key for key in model_state.keys() if key not in compatible] + if len(compatible) > best[1]["matched"]: + best = ( + compatible, + { + "matched": len(compatible), + "missing": skipped_prefix + len(unexpected_checkpoint_keys), + "unexpected": len(unexpected_checkpoint_keys), + "skipped_prefix": skipped_prefix, + "shape_mismatch": len(shape_mismatch_keys), + "missing_model": len(missing_model_keys), + "total": len(state_dict), + "prefix": prefix, + "unexpected_checkpoint_keys": unexpected_checkpoint_keys, + "shape_mismatch_keys": shape_mismatch_keys, + "missing_model_keys": missing_model_keys, + }, + ) + return best + + +def _key_matches_any(key: str, markers: tuple[str, ...]) -> bool: + return any(marker in key for marker in markers) + + +def _format_key_samples(keys, limit: int = 25, indent: str = " ") -> str: + sample = list(keys[:limit]) + lines = [f"{indent}{key}" for key in sample] + if 
len(keys) > limit: + lines.append(f"{indent}... {len(keys) - limit} more") + return "\n".join(lines) + + +def _format_shape_mismatch_samples(shape_mismatch_keys, limit: int = 25, indent: str = " ") -> str: + sample = list(shape_mismatch_keys[:limit]) + lines = [ + f"{indent}{checkpoint_key} -> {model_key}: checkpoint{checkpoint_shape} != model{model_shape}" + for checkpoint_key, model_key, checkpoint_shape, model_shape in sample + ] + if len(shape_mismatch_keys) > limit: + lines.append(f"{indent}... {len(shape_mismatch_keys) - limit} more") + return "\n".join(lines) + + +def _log_checkpoint_mismatch_report( + stats, + checkpoint_path, + label: str | None = None, + dememwm_key_check: bool = False, +) -> None: + mismatch_count = stats["missing_model"] + stats["unexpected"] + stats["shape_mismatch"] + if mismatch_count == 0: + return + + title = label or str(checkpoint_path) + lines = [ + f"Checkpoint mismatch report for {title}:", + f" checkpoint={checkpoint_path}", + " selected_prefix={!r}".format(stats["prefix"]), + ( + " counts: " + "matched={matched} " + "model_not_loaded={model_not_loaded} " + "checkpoint_not_used={checkpoint_not_used} " + "shape_mismatch={shape_mismatch} " + "skipped_by_prefix={skipped_by_prefix}" + ).format( + matched=stats["matched"], + model_not_loaded=stats["missing_model"], + checkpoint_not_used=stats["unexpected"], + shape_mismatch=stats["shape_mismatch"], + skipped_by_prefix=stats["skipped_prefix"], + ), + ] + if stats["missing_model_keys"]: + lines.append(" Model keys not loaded from checkpoint:") + lines.append(_format_key_samples(stats["missing_model_keys"])) + if stats["unexpected_checkpoint_keys"]: + lines.append(" Checkpoint keys not used by current model:") + lines.append(_format_key_samples(stats["unexpected_checkpoint_keys"])) + if stats["shape_mismatch_keys"]: + lines.append(" Shape mismatches:") + lines.append(_format_shape_mismatch_samples(stats["shape_mismatch_keys"])) + + if dememwm_key_check: + markers = ("dememwm_", 
".dememwm_", ".memory_token_cross_attn.") + missing_dememwm = [key for key in stats["missing_model_keys"] if _key_matches_any(key, markers)] + unexpected_dememwm = [ + key for key in stats["unexpected_checkpoint_keys"] if _key_matches_any(key, markers) + ] + shape_dememwm = [ + item + for item in stats["shape_mismatch_keys"] + if _key_matches_any(item[0], markers) or _key_matches_any(item[1], markers) + ] + if missing_dememwm or unexpected_dememwm or shape_dememwm: + lines.append(" DeMemWM mismatch subset:") + if missing_dememwm: + lines.append(" DeMemWM model keys not loaded:") + lines.append(_format_key_samples(missing_dememwm, indent=" ")) + if unexpected_dememwm: + lines.append(" DeMemWM checkpoint keys not used:") + lines.append(_format_key_samples(unexpected_dememwm, indent=" ")) + if shape_dememwm: + lines.append(" DeMemWM shape mismatches:") + lines.append(_format_shape_mismatch_samples(shape_dememwm, indent=" ")) + + rank_zero_info("\n".join(lines)) + + +def load_custom_checkpoint( + algo, + checkpoint_path, + require_match: bool = False, + label: str | None = None, + dememwm_key_check: bool = False, + report_key_mismatch: bool = False, +): + if not checkpoint_path: + if require_match: + target = label or "model" + raise FileNotFoundError(f"Expected checkpoint for {target}, but no path was provided.") + rank_zero_info("No checkpoint path provided, skipping checkpoint loading.") + return None + + if not isinstance(checkpoint_path, Path): + checkpoint_path = Path(checkpoint_path) + + if is_huggingface_model(str(checkpoint_path)): + hf_ckpt = str(checkpoint_path).split("/") + repo_id = "/".join(hf_ckpt[:2]) + file_name = "/".join(hf_ckpt[2:]) + model_path = hf_hub_download(repo_id=repo_id, filename=file_name) + ckpt = torch.load(model_path, map_location=torch.device("cpu")) + state_dict = _extract_state_dict(ckpt) + + elif checkpoint_path.suffix == ".pt": + ckpt = torch.load(checkpoint_path, weights_only=True) + state_dict = _extract_state_dict(ckpt) + + elif 
checkpoint_path.suffix == ".ckpt": + ckpt = torch.load(checkpoint_path, map_location=torch.device("cpu")) + state_dict = _extract_state_dict(ckpt) + + elif checkpoint_path.suffix == ".safetensors": + state_dict = load_file(checkpoint_path) + + elif os.path.isdir(checkpoint_path): + ckpt_files = [f for f in os.listdir(checkpoint_path) if f.endswith(".ckpt")] + if not ckpt_files: + raise FileNotFoundError("No .ckpt files found in the specified directory!") + selected_ckpt = max(ckpt_files) + selected_ckpt_path = os.path.join(checkpoint_path, selected_ckpt) + print(f"Checkpoint file selected for loading: {selected_ckpt_path}") + + ckpt = torch.load(selected_ckpt_path, map_location=torch.device("cpu")) + state_dict = _extract_state_dict(ckpt) + + else: + raise ValueError(f"Unsupported checkpoint: {checkpoint_path}") + if dememwm_key_check and hasattr(algo, "strict_dememwm_checkpoint_key_check"): + algo.strict_dememwm_checkpoint_key_check(state_dict) + compatible, stats = _compatible_state_dict(algo, state_dict) + if require_match and stats["matched"] == 0: + raise RuntimeError(f"Expected checkpoint for {label or checkpoint_path} matched zero model weights.") + if compatible: + algo.load_state_dict(compatible, strict=False) + elif checkpoint_path.suffix == ".safetensors": + load_model(algo, checkpoint_path, strict=False) + rank_zero_info( + "Model weights loaded from {}: matched={} missing={} shape_mismatch={} prefix={!r}".format( + checkpoint_path, + stats["matched"], + stats["missing"], + stats["shape_mismatch"], + stats["prefix"], + ) + ) + if report_key_mismatch: + _log_checkpoint_mismatch_report(stats, checkpoint_path, label=label, dememwm_key_check=dememwm_key_check) + return stats + + +class BaseExperiment(ABC): + """ + Abstract class for an experiment. This generalizes the pytorch lightning Trainer & lightning Module to more + flexible experiments that doesn't fit in the typical ml loop, e.g. multi-stage reinforcement learning benchmarks. 
class BaseExperiment(ABC):
    """
    Abstract class for an experiment. This generalizes the pytorch lightning Trainer & lightning Module to more
    flexible experiments that don't fit in the typical ml loop, e.g. multi-stage reinforcement learning benchmarks.
    """

    # each key has to be a yaml file under '[project_root]/configurations/algorithm' without .yaml suffix
    compatible_algorithms: Dict = NotImplementedError

    def __init__(
        self,
        root_cfg: DictConfig,
        logger: Optional[WandbLogger] = None,
        ckpt_path: Optional[Union[str, pathlib.Path]] = None,
    ) -> None:
        """
        Constructor

        Args:
            root_cfg: configuration file that contains everything about the experiment
            logger: a pytorch-lightning WandbLogger instance
            ckpt_path: an optional path to saved checkpoint
        """
        super().__init__()
        self.root_cfg = root_cfg
        self.cfg = root_cfg.experiment
        self.debug = root_cfg.debug
        self.logger = logger
        self.ckpt_path = ckpt_path
        self.algo = None
        # Optional root-level switches; each falls back to a default when the
        # key is absent from the config.
        self.customized_load = getattr(root_cfg, "customized_load", False)
        self.seperate_load = getattr(root_cfg, "seperate_load", False)
        self.zero_init_gate = getattr(root_cfg, "zero_init_gate", False)
        self.only_tune_memory = getattr(root_cfg, "only_tune_memory", False)
        self.diffusion_model_path = getattr(root_cfg, "diffusion_model_path", None)
        self.vae_path = getattr(root_cfg, "vae_path", None)
        self.pose_predictor_path = getattr(root_cfg, "pose_predictor_path", None)
        self.auto_resuming = getattr(root_cfg, "_auto_resuming", False)

    def _build_algo(self):
        """
        Build the lightning module
        :return: a pytorch-lightning module to be launched
        """
        algo_name = self.root_cfg.algorithm._name
        if algo_name not in self.compatible_algorithms:
            raise ValueError(
                f"Algorithm {algo_name} not found in compatible_algorithms for this Experiment class. "
                "Make sure you define compatible_algorithms correctly and make sure that each key has "
                "same name as yaml file under '[project_root]/configurations/algorithm' without .yaml suffix"
            )
        return self.compatible_algorithms[algo_name](self.root_cfg.algorithm)

    def _resolve_task(self, task: str):
        """Return the bound method implementing *task*, announcing it first.

        Raises:
            ValueError: If *task* is not a callable attribute of this object.
        """
        handler = getattr(self, task, None)
        if not callable(handler):
            raise ValueError(
                f"Specified task '{task}' not defined for class {self.__class__.__name__} or is not callable."
            )
        if is_rank_zero:
            print(cyan("Executing task:"), f"{task} out of {self.cfg.tasks}")
        return handler

    def exec_task(self, task: str) -> None:
        """
        Executing a certain task specified by string. Each task should be a stage of experiment.
        In most computer vision / nlp applications, tasks should be just train and test.
        In reinforcement learning, you might have more stages such as collecting dataset etc

        Args:
            task: a string specifying a task implemented for this experiment
        """
        self._resolve_task(task)()

    def exec_interactive(self, task: str) -> None:
        """
        Same as ``exec_task`` but hands back whatever the task returns.

        Note that the ``-> None`` annotation is kept for interface stability;
        the task's return value is in fact returned.

        Args:
            task: a string specifying a task implemented for this experiment
        """
        return self._resolve_task(task)()
class BaseLightningExperiment(BaseExperiment):
    """
    Abstract class for pytorch lightning experiments. Useful for computer vision & nlp where main components are
    simply models, datasets and train loop.
    """

    # each key has to be a yaml file under '[project_root]/configurations/algorithm' without .yaml suffix
    compatible_algorithms: Dict = NotImplementedError

    # each key has to be a yaml file under '[project_root]/configurations/dataset' without .yaml suffix
    compatible_datasets: Dict = NotImplementedError

    def _build_trainer_callbacks(self):
        """Return the default trainer callbacks (an LR monitor when logging)."""
        callbacks = []
        if self.logger:
            callbacks.append(LearningRateMonitor("step", True))
        # The original built this list and then dropped it (implicit None).
        return callbacks

    def _build_loader(self, split: str, split_cfg):
        """Build the DataLoader for *split*, or return None for an empty dataset.

        Args:
            split: "training", "validation" or "test".
            split_cfg: Matching section of the experiment config; its
                ``batch_size``, ``data.num_workers`` and ``data.shuffle`` are read.
        """
        dataset = self._build_dataset(split)
        # Iterable datasets cannot be shuffled by the loader.
        shuffle = False if isinstance(dataset, torch.utils.data.IterableDataset) else split_cfg.data.shuffle
        if not dataset:
            return None
        # os.cpu_count() may return None.
        num_workers = min(os.cpu_count() or 1, split_cfg.data.num_workers)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=split_cfg.batch_size,
            num_workers=num_workers,
            shuffle=shuffle,
            # DataLoader raises ValueError for persistent workers with
            # num_workers == 0, so only enable them when workers exist.
            persistent_workers=num_workers > 0,
            pin_memory=torch.cuda.is_available(),
        )

    def _build_training_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
        """Return the training DataLoader, or None when the dataset is empty."""
        return self._build_loader("training", self.cfg.training)

    def _build_validation_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
        """Return the validation DataLoader, or None when the dataset is empty."""
        return self._build_loader("validation", self.cfg.validation)

    def _build_test_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
        """Return the test DataLoader, or None when the dataset is empty."""
        return self._build_loader("test", self.cfg.test)

    def _load_customized_weights(self, report_key_mismatch: bool, algo_label: Optional[str]) -> None:
        """Load weights through ``load_custom_checkpoint`` per the load flags.

        With ``seperate_load`` the diffusion model and the VAE are loaded from
        their own paths and must match; otherwise ``ckpt_path`` is loaded into
        the whole algorithm.
        """
        if self.seperate_load:
            if "oasis500m" in self.diffusion_model_path:
                target, label = self.algo.diffusion_model.model, "diffusion_model.model"
            else:
                target, label = self.algo.diffusion_model, "diffusion_model"
            load_custom_checkpoint(
                algo=target,
                checkpoint_path=self.diffusion_model_path,
                require_match=True,
                label=label,
                report_key_mismatch=report_key_mismatch,
            )
            load_custom_checkpoint(
                algo=self.algo.vae,
                checkpoint_path=self.vae_path,
                require_match=True,
                label="vae",
                report_key_mismatch=report_key_mismatch,
            )
        else:
            load_custom_checkpoint(
                algo=self.algo,
                checkpoint_path=self.ckpt_path,
                label=algo_label,
                dememwm_key_check=True,
                report_key_mismatch=report_key_mismatch,
            )

    def _zero_init_gates(self) -> None:
        """Zero rows [2048, 3072) and [5120, 6144) of every ``r_adaLN_modulation`` parameter."""
        for name, para in self.algo.diffusion_model.named_parameters():
            if "r_adaLN_modulation" in name:
                # Gradients are switched off around the in-place write.
                para.requires_grad_(False)
                para[2 * 1024 : 3 * 1024] = 0
                para[5 * 1024 : 6 * 1024] = 0
                para.requires_grad_(True)

    def _freeze_all_but_memory(self) -> None:
        """Freeze the diffusion model except parameters matched by name."""
        for name, para in self.algo.diffusion_model.named_parameters():
            para.requires_grad_(False)
            if "r_" in name or "pose_embedder" in name or "pose_cond_mlp" in name or "lora_" in name:
                para.requires_grad_(True)

    def training(self) -> None:
        """
        All training happens here
        """
        if not self.algo:
            self.algo = self._build_algo()
        if self.cfg.training.compile:
            self.algo = torch.compile(self.algo)

        callbacks = []
        if self.logger:
            callbacks.append(LearningRateMonitor("step", True))
        if "checkpointing" in self.cfg.training:
            callbacks.append(
                ModelCheckpoint(
                    pathlib.Path(hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"]) / "checkpoints",
                    filename="epoch{epoch}_step{step}",
                    auto_insert_metric_name=False,
                    **self.cfg.training.checkpointing,
                )
            )

        trainer = pl.Trainer(
            accelerator="auto",
            devices="auto",
            strategy=DDPStrategy(find_unused_parameters=True) if torch.cuda.device_count() > 1 else "auto",
            logger=self.logger or False,
            callbacks=callbacks,
            gradient_clip_val=self.cfg.training.optim.gradient_clip_val or 0.0,
            val_check_interval=self.cfg.validation.val_every_n_step if self.cfg.validation.val_every_n_step else None,
            limit_val_batches=self.cfg.validation.limit_batch,
            check_val_every_n_epoch=self.cfg.validation.val_every_n_epoch if not self.cfg.validation.val_every_n_step else None,
            accumulate_grad_batches=self.cfg.training.optim.accumulate_grad_batches or 1,
            precision=self.cfg.training.precision or 32,
            detect_anomaly=False,
            num_sanity_val_steps=int(self.cfg.debug) if self.cfg.debug else 0,
            max_epochs=self.cfg.training.max_epochs,
            max_steps=self.cfg.training.max_steps,
            max_time=self.cfg.training.max_time,
        )

        if self.auto_resuming:
            # Resume the full trainer state from ckpt_path.
            self.algo._strict_resume_state = True
            fit_ckpt_path = self.ckpt_path
        elif self.customized_load:
            # Weights are loaded manually, so Lightning gets no checkpoint.
            self._load_customized_weights(report_key_mismatch=False, algo_label=None)
            if self.zero_init_gate:
                self._zero_init_gates()
            if self.only_tune_memory:
                self._freeze_all_but_memory()
            fit_ckpt_path = None
        else:
            if self.only_tune_memory:
                self._freeze_all_but_memory()
            fit_ckpt_path = self.ckpt_path

        trainer.fit(
            self.algo,
            train_dataloaders=self._build_training_loader(),
            val_dataloaders=self._build_validation_loader(),
            ckpt_path=fit_ckpt_path,
        )

    def validation(self) -> None:
        """
        All validation happens here
        """
        if not self.algo:
            self.algo = self._build_algo()
        if self.cfg.validation.compile:
            self.algo = torch.compile(self.algo)

        callbacks = []

        trainer = pl.Trainer(
            accelerator="auto",
            logger=self.logger,
            devices="auto",
            num_nodes=self.cfg.num_nodes,
            strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
            callbacks=callbacks,
            limit_val_batches=self.cfg.validation.limit_batch,
            precision=self.cfg.validation.precision,
            detect_anomaly=False,  # self.cfg.debug,
            inference_mode=self.cfg.validation.inference_mode,
        )

        if self.customized_load:
            self._load_customized_weights(report_key_mismatch=True, algo_label="algo")
            if self.zero_init_gate:
                self._zero_init_gates()
            validate_ckpt_path = None
        else:
            validate_ckpt_path = self.ckpt_path

        trainer.validate(
            self.algo,
            dataloaders=self._build_validation_loader(),
            ckpt_path=validate_ckpt_path,
        )

    def test(self) -> None:
        """
        All testing happens here
        """
        if not self.algo:
            self.algo = self._build_algo()
        if self.cfg.test.compile:
            self.algo = torch.compile(self.algo)

        callbacks = []

        trainer = pl.Trainer(
            accelerator="auto",
            logger=self.logger,
            devices="auto",
            num_nodes=self.cfg.num_nodes,
            strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
            callbacks=callbacks,
            limit_test_batches=self.cfg.test.limit_batch,
            precision=self.cfg.test.precision,
            detect_anomaly=False,  # self.cfg.debug,
        )

        if self.customized_load:
            self._load_customized_weights(report_key_mismatch=True, algo_label="algo")
            if self.zero_init_gate:
                self._zero_init_gates()
            test_ckpt_path = None
        else:
            test_ckpt_path = self.ckpt_path

        trainer.test(
            self.algo,
            dataloaders=self._build_test_loader(),
            ckpt_path=test_ckpt_path,
        )

    def _build_dataset(self, split: str) -> Optional[torch.utils.data.Dataset]:
        """Instantiate the configured dataset class for *split*.

        Raises:
            NotImplementedError: If *split* is not "training", "validation" or "test".
        """
        if split in ["training", "test", "validation"]:
            return self.compatible_datasets[self.root_cfg.dataset._name](self.root_cfg.dataset, split=split)
        else:
            raise NotImplementedError(f"split '{split}' is not implemented")
class VideoPredictionExperiment(BaseLightningExperiment):
    """
    A video prediction experiment

    Declares, by name, the algorithm and dataset classes this experiment can
    be combined with.
    """

    # Algorithm name -> algorithm class.
    # NOTE(review): a static test checks this file's text for the
    # `key=Class` keyword form of the entry below, so keep the `dict(...)`
    # call syntax rather than switching to a dict literal.
    compatible_algorithms = dict(
        dememwm_memory_dit=DeMemWMMinecraft,
    )

    # Dataset name -> dataset class.
    compatible_datasets = dict(
        # video datasets
        video_minecraft=MinecraftVideoDataset,
        video_minecraft_latent=MinecraftVideoLatentDataset,
    )
def get_latest_checkpoint(checkpoint_folder: Path, pattern: str = "*.ckpt"):
    """
    Pick the checkpoint to resume from inside ``checkpoint_folder``.

    Only regular, non-empty files matching ``pattern`` are considered.
    ``last.ckpt`` is returned whenever it is among them; otherwise the file
    with the largest step number parsed from its stem wins (files without a
    step number count as -1), with modification time breaking ties.

    Returns ``None`` when the folder does not exist or holds no usable file.
    """
    if not checkpoint_folder.exists():
        return None

    candidates = []
    for entry in checkpoint_folder.glob(pattern):
        if not entry.is_file():
            continue
        if entry.stat().st_size <= 0:
            # Zero-byte files are ignored.
            continue
        candidates.append(entry)

    preferred = checkpoint_folder / "last.ckpt"
    if preferred in candidates:
        return preferred

    def rank(entry: Path):
        found = re.search(r"step[=_-]?(\d+)", entry.stem)
        return (-1 if found is None else int(found.group(1)), entry.stat().st_mtime)

    return max(candidates, key=rank, default=None)
def wait_for_wandb_run_id(output_dir: Path, timeout_s: float = 300.0) -> str:
    """
    Block until the W&B run id file in ``output_dir`` has content, then return it.

    The file is polled roughly every 0.5 s.  It is always checked at least
    once, and once more after the final sleep, so a file that already exists
    is found even with ``timeout_s=0``.

    Args:
        output_dir: Directory expected to contain the run id file.
        timeout_s: Maximum time to wait, in seconds.

    Returns:
        The stripped, non-empty contents of the run id file.

    Raises:
        TimeoutError: If no non-empty run id file appears within ``timeout_s``.
    """
    run_id_file = output_dir / WANDB_RUN_ID_FILE
    # Monotonic clock: the deadline must not move when the wall clock is adjusted.
    deadline = time.monotonic() + timeout_s
    while True:
        if run_id_file.exists():
            run_id = run_id_file.read_text().strip()
            # An empty file is treated as "not ready yet" and polled again.
            if run_id:
                return run_id
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        time.sleep(min(0.5, remaining))
    raise TimeoutError(f"Timed out waiting for rank 0 to create W&B run id file: {run_id_file}")
def run_local(cfg: DictConfig) -> None:
    """
    Run the configured experiment tasks in the current process.

    In order: seed, record the chosen Hydra config names on ``cfg``, resolve
    the output directory, decide whether training resumes from a checkpoint,
    build the W&B logger (unless ``wandb.mode`` is ``disabled``), pick the
    checkpoint to load, then build the experiment and execute every task in
    ``cfg.experiment.tasks``.
    """
    # delay some imports in case they are not needed in non-local envs for submission
    from experiments import build_experiment
    from utils.wandb_utils import OfflineWandbLogger, SpaceEfficientWandbLogger
    import lightning.pytorch as pl

    # Set global seed for reproducibility
    if cfg.get("seed", None) is not None:
        pl.seed_everything(cfg.seed, workers=True)

    # Get yaml names
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)

    # Store the selected config-group names under `_name` so later code can
    # look classes up by name.
    with open_dict(cfg):
        if cfg_choice["experiment"] is not None:
            cfg.experiment._name = cfg_choice["experiment"]
        if cfg_choice["dataset"] is not None:
            cfg.dataset._name = cfg_choice["dataset"]
        if cfg_choice["algorithm"] is not None:
            cfg.algorithm._name = cfg_choice["algorithm"]

    # Set up the output directory.
    # An explicit `output_dir` in cfg overrides Hydra's runtime output dir; the
    # read-only flag on the Hydra config is lifted only for that assignment.
    output_dir = getattr(cfg, "output_dir", None)
    if output_dir is not None:
        OmegaConf.set_readonly(hydra_cfg, False)
        hydra_cfg.runtime.output_dir = output_dir
        OmegaConf.set_readonly(hydra_cfg, True)

    output_dir = Path(hydra_cfg.runtime.output_dir)

    if is_rank_zero:
        print(cyan(f"Outputs will be saved to:"), output_dir)
        # Re-point the `latest-run` symlink that lives two levels above the
        # output dir.
        # NOTE(review): `parents[1]` raises IndexError when the output dir has
        # fewer than two ancestors - confirm custom output_dir values are
        # always nested deeply enough.
        (output_dir.parents[1] / "latest-run").unlink(missing_ok=True)
        (output_dir.parents[1] / "latest-run").symlink_to(output_dir, target_is_directory=True)

    # Resume is only considered when the `training` task is requested.  An
    # explicit `resume_ckpt_path` wins; otherwise, with `auto_resume` on
    # (default True), the newest checkpoint in `<output_dir>/checkpoints` is
    # used if there is one.
    training_requested = "training" in cfg.experiment.tasks
    checkpoint_dir = output_dir / "checkpoints"
    auto_resume = bool(getattr(cfg, "auto_resume", True))
    explicit_resume_ckpt = getattr(cfg, "resume_ckpt_path", None)
    auto_resume_checkpoint_path = None
    if training_requested:
        if explicit_resume_ckpt:
            auto_resume_checkpoint_path = validate_resume_checkpoint(Path(explicit_resume_ckpt))
        elif auto_resume:
            auto_resume_checkpoint_path = get_latest_checkpoint(checkpoint_dir)
            if auto_resume_checkpoint_path is not None:
                auto_resume_checkpoint_path = validate_resume_checkpoint(auto_resume_checkpoint_path)

    if auto_resume_checkpoint_path and is_rank_zero:
        print(cyan("Auto-resuming training from:"), auto_resume_checkpoint_path)

    # Publish the resume decision on cfg under private keys.
    with open_dict(cfg):
        cfg._auto_resuming = auto_resume_checkpoint_path is not None
        cfg._resume_checkpoint_path = str(auto_resume_checkpoint_path) if auto_resume_checkpoint_path else None

    # Set up logging with wandb.
    if cfg.wandb.mode != "disabled":
        # If resuming, merge into the existing run on wandb.
        resume = cfg.get("resume", None)
        wandb_run_id = get_or_create_wandb_run_id(output_dir, requested_run_id=resume)
        # No display name is passed when resuming from a checkpoint.
        name = None if auto_resume_checkpoint_path else f"{cfg.name} ({output_dir.parent.name}/{output_dir.name})"

        if "_on_compute_node" in cfg and cfg.cluster.is_compute_node_offline:
            logger_cls = OfflineWandbLogger
        else:
            logger_cls = SpaceEfficientWandbLogger

        # Any mode other than exactly "online" logs offline.
        offline = cfg.wandb.mode != "online"
        logger = logger_cls(
            name=name,
            save_dir=str(output_dir),
            offline=offline,
            entity=cfg.wandb.entity,
            project=cfg.wandb.project,
            log_model=False,
            config=OmegaConf.to_container(cfg),
            id=wandb_run_id,
            resume="auto"
        )

    else:
        logger = None

    # Load ckpt
    # Precedence: the auto-resume checkpoint, then `load` given as a path,
    # then a run id taken from `resume` or from `load`.
    # `resume` is read again here because the read above only happens inside
    # the wandb branch.
    resume = cfg.get("resume", None)
    load = cfg.get("load", None)
    checkpoint_path = auto_resume_checkpoint_path
    load_id = None
    if checkpoint_path is None and load and not is_run_id(load):
        checkpoint_path = load
    if checkpoint_path is None and resume:
        load_id = resume
    elif checkpoint_path is None and load and is_run_id(load):
        load_id = load
    else:
        load_id = None

    if load_id:
        # A run id is resolved against the local `<output_dir>/checkpoints`
        # folder only.
        # NOTE(review): `download_latest_checkpoint` is imported at the top of
        # this file but never called here - confirm that a purely local lookup
        # is intended.
        checkpoint_path = get_latest_checkpoint(output_dir / "checkpoints")
        if checkpoint_path is None:
            raise FileNotFoundError(f"No checkpoint found under {output_dir / 'checkpoints'} for run id {load_id}")
        checkpoint_path = validate_resume_checkpoint(checkpoint_path)

    if checkpoint_path and is_rank_zero:
        print(f"Will load checkpoint from {checkpoint_path}")

    # launch experiment
    experiment = build_experiment(cfg, logger, checkpoint_path)
    for task in cfg.experiment.tasks:
        experiment.exec_task(task)
def run_slurm(cfg: DictConfig):
    """
    Submit this run as a Slurm job instead of executing it locally.

    The current command-line overrides are forwarded to the job with
    ``+_on_compute_node=True`` appended.  The repository root is the nearest
    directory, starting at the current working directory and walking up, that
    contains ``.git``.  When the cluster's compute nodes are offline and W&B is
    in ``online`` mode, a ``wandb-osh`` sync process is started in the
    background.  Finally the function waits for the first ``*.out`` / ``*.err``
    file in the Slurm log directory and prints a ``tail`` command; Ctrl+C
    skips the wait.

    Raises:
        FileNotFoundError: If no directory containing ``.git`` is found.
    """
    # NOTE(review): joining with plain spaces drops the shell quoting of the
    # original arguments; if `submit_slurm_job` pastes this string into a
    # shell script, values with spaces or metacharacters need re-quoting.
    python_args = " ".join(sys.argv[1:]) + " +_on_compute_node=True"

    # Inspect the working directory and every ancestor, including the
    # filesystem root, without relying on the root being spelled "/".
    cwd = Path.cwd()
    project_root = next(
        (candidate for candidate in (cwd, *cwd.parents) if (candidate / ".git").exists()),
        None,
    )
    if project_root is None:
        raise FileNotFoundError("Could not find repo directory!")

    slurm_log_dir = submit_slurm_job(
        cfg,
        python_args,
        project_root,
    )

    if "cluster" in cfg and cfg.cluster.is_compute_node_offline and cfg.wandb.mode == "online":
        print("Job submitted to a compute node without internet. This requires manual syncing on login node.")
        osh_command_dir = project_root / ".wandb_osh_command_dir"

        # The sync loop is started unconditionally (no interactive prompt).
        osh_proc = subprocess.Popen(["wandb-osh", "--command-dir", osh_command_dir])
        print(f"Running wandb-osh in background... PID: {osh_proc.pid}")
        print(f"To kill the sync process, run 'kill {osh_proc.pid}' in the terminal.")
        print(
            "You can manually start a sync loop later by running the following:",
            cyan(f"wandb-osh --command-dir {osh_command_dir}"),
        )

    print(
        "Once the job gets allocated and starts running, we will print a command below "
        "for you to trace the errors and outputs: (Ctrl + C to exit without waiting)"
    )
    msg = f"tail -f {slurm_log_dir}/* \n"
    try:
        # Poll once per second until the first log file shows up.
        while not list(slurm_log_dir.glob("*.out")) and not list(slurm_log_dir.glob("*.err")):
            time.sleep(1)
        print(cyan("To trace the outputs and errors, run the following command:"), msg)
    except KeyboardInterrupt:
        print("Keyboard interrupt detected. Exiting...")
        print(
            cyan("To trace the outputs and errors, manually wait for the job to start and run the following command:"),
            msg,
        )
b/scripts/dememwm_full_eval.slurm new file mode 100644 index 0000000000000000000000000000000000000000..e276fef6f1de0a94a952908c12acbc92f332367e --- /dev/null +++ b/scripts/dememwm_full_eval.slurm @@ -0,0 +1,179 @@ +#!/usr/bin/env bash +#SBATCH --job-name=dememwm_full_eval +#SBATCH --partition=gpu +#SBATCH --time=1-00:00:00 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=256G +#SBATCH --gres=gpu:1 +#SBATCH --chdir=/share_1/users/bonan_ding/DeMemWM +#SBATCH --output=/share_1/users/bonan_ding/DeMemWM/slurm_logs/%x_%j.out +#SBATCH --error=/share_1/users/bonan_ding/DeMemWM/slurm_logs/%x_%j.err + +# Full DeMemWM evaluation script for DeMemWM/H200. +# Submit from the remote repo after training has produced a checkpoint: +# sbatch --export=ALL,CHECKPOINT=/share_1/users/bonan_ding/DeMemWM/outputs//train/checkpoints/last.ckpt scripts/dememwm_full_eval.slurm +# or: +# CHECKPOINT=/path/to/last.ckpt sbatch --export=ALL scripts/dememwm_full_eval.slurm + +set -euo pipefail + +CHECKPOINT=${CHECKPOINT:-${1:-}} +if [[ -z "${CHECKPOINT}" ]]; then + echo "ERROR: set CHECKPOINT=/path/to/dememwm.ckpt, e.g." >&2 + echo " sbatch --export=ALL,CHECKPOINT=/share_1/users/bonan_ding/DeMemWM/outputs//train/checkpoints/last.ckpt scripts/dememwm_full_eval.slurm" >&2 + exit 2 +fi +if [[ ! 
-s "${CHECKPOINT}" ]]; then + echo "ERROR: checkpoint does not exist or is empty: ${CHECKPOINT}" >&2 + exit 2 +fi + +REPO=${REPO:-/share_1/users/bonan_ding/DeMemWM} +DATA_DIR=${DATA_DIR:-/share_1/users/bonan_ding/worldmem_data/minecraft} +FEATURE_DIR=${FEATURE_DIR:-/share_1/users/bonan_ding/worldmem_data/minecraft/vae_features} + +RUN_TAG=${RUN_TAG:-dememwm_full_eval_${SLURM_JOB_ID:-manual_$(date +%Y%m%d_%H%M%S)}} +RUN_ROOT=${RUN_ROOT:-${REPO}/outputs/${RUN_TAG}} +EVAL_OUT=${EVAL_OUT:-${RUN_ROOT}/eval} +LOG_DIR=${LOG_DIR:-${REPO}/slurm_logs/${RUN_TAG}} +mkdir -p "${EVAL_OUT}" "${LOG_DIR}" "${REPO}/slurm_logs" + +DATASET_N_FRAMES=${DATASET_N_FRAMES:-300} +N_FRAMES_VALID=${N_FRAMES_VALID:-216} +CONTEXT_FRAMES=${CONTEXT_FRAMES:-116} +N_TOKENS=${N_TOKENS:-8} +SAMPLING_TIMESTEPS=${SAMPLING_TIMESTEPS:-20} +VAL_BATCH_SIZE=${VAL_BATCH_SIZE:-1} +VAL_LIMIT=${VAL_LIMIT:-16} +LOG_VIDEO=${LOG_VIDEO:-true} +SEED=${SEED:-42} +ABLATION_BRANCH=${ABLATION_BRANCH:-A_plus_D_plus_R_normal} + +# Consumed DeMemWM memory-shape knobs for current latent setup. +# Anchor: ratio 6 over 18x32 -> 4 prefixes * 3x6 pooled slots = 72 tokens. +# Revisit: ratio 3 over 18x32 -> 2 frames * 6x11 pooled slots = 132 tokens. 
+ANCHOR_DOWNSAMPLE_RATIO=${ANCHOR_DOWNSAMPLE_RATIO:-6} +REVISIT_MAX_FRAMES=${REVISIT_MAX_FRAMES:-2} +REVISIT_DOWNSAMPLE_RATIO=${REVISIT_DOWNSAMPLE_RATIO:-3} + +cd "${REPO}" +source ~/.bashrc >/dev/null 2>&1 || true +if command -v conda >/dev/null 2>&1; then + eval "$(conda shell.bash hook)" +elif [[ -f "${HOME}/.conda/etc/profile.d/conda.sh" ]]; then + source "${HOME}/.conda/etc/profile.d/conda.sh" +elif [[ -f /share_0/conda/etc/profile.d/conda.sh ]]; then + source /share_0/conda/etc/profile.d/conda.sh +fi +conda activate worldmem +PY=$(which python) + +export PYTHONPATH="./:${PYTHONPATH:-}" +export HYDRA_FULL_ERROR=1 +export PYTHONWARNINGS=ignore +export OMP_NUM_THREADS="${SLURM_CPUS_PER_TASK:-16}" +export WANDB_MODE=offline +export NCCL_P2P_DISABLE=1 +wandb offline >/dev/null 2>&1 || true + +echo "JOB_ID=${SLURM_JOB_ID:-manual}" +echo "RUN_TAG=${RUN_TAG}" +echo "RUN_ROOT=${RUN_ROOT}" +echo "CHECKPOINT=${CHECKPOINT}" +echo "ABLATION_BRANCH=${ABLATION_BRANCH}" +echo "HOST=$(hostname)" +echo "START=$(date --iso-8601=seconds)" +echo "PWD=$PWD" +echo "PY=${PY}" +"${PY}" --version +nvidia-smi || true +nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits > "${LOG_DIR}/gpu_memory_before_mb.txt" || true +git branch --show-current || true +git rev-parse HEAD || true + +EVAL_ARGS=( + "+name=eval_${RUN_TAG}" + "+output_dir=${EVAL_OUT}/" + "experiment.tasks=[validation]" + "wandb.mode=offline" + "dataset.validation_multiplier=1" + "+dataset.seed=${SEED}" + "+customized_load=true" + "+seperate_load=false" + "algorithm=dememwm_memory_dit" + "load=${CHECKPOINT}" + "dataset=video_minecraft_latent" + "dataset.save_dir=${DATA_DIR}" + "dataset.precomputed_feature_dir=${FEATURE_DIR}" + "dataset.n_frames=${DATASET_N_FRAMES}" + "+dataset.n_frames_valid=${N_FRAMES_VALID}" + "+dataset.customized_validation=true" + "+dataset.memory_condition_length=0" + "++dataset.angle_range=180" + "++dataset.pos_range=1000000000" + "++algorithm.n_tokens=${N_TOKENS}" + 
"algorithm.x_shape=[16,18,32]" + "++algorithm.context_frames=${CONTEXT_FRAMES}" + "++algorithm.log_video=${LOG_VIDEO}" + "++algorithm.diffusion.sampling_timesteps=${SAMPLING_TIMESTEPS}" + "++algorithm.dememwm.debug_force_all_streams=false" + "++algorithm.dememwm.training_stage=stage_2" + "++algorithm.dememwm.anchor.enabled=true" + "++algorithm.dememwm.anchor.anchor_indices=[0,1,2,3]" + "++algorithm.dememwm.anchor.diverse_selection=true" + "++algorithm.dememwm.anchor.compress.downsample_ratio=${ANCHOR_DOWNSAMPLE_RATIO}" + "++algorithm.dememwm.anchor.allow_generated_as_anchor=false" + "++algorithm.dememwm.dynamic.enabled=true" + "++algorithm.dememwm.dynamic.exclude_latest_local_frames=4" + "++algorithm.dememwm.dynamic.recent_frames=8" + "++algorithm.dememwm.revisit.enabled=true" + "++algorithm.dememwm.revisit.deterministic_pose_retrieval=true" + "++algorithm.dememwm.revisit.fov_overlap_threshold=0.30" + "++algorithm.dememwm.revisit.high_quality_fov_threshold=0.70" + "++algorithm.dememwm.revisit.pose_preselect_topk=64" + "++algorithm.dememwm.revisit.fov_yaw_samples=25" + "++algorithm.dememwm.revisit.fov_pitch_samples=20" + "++algorithm.dememwm.revisit.fov_depth_samples=20" + "++algorithm.dememwm.revisit.plucker_weight=0.10" + "++algorithm.dememwm.revisit.max_frames=${REVISIT_MAX_FRAMES}" + "++algorithm.dememwm.revisit.compress.downsample_ratio=${REVISIT_DOWNSAMPLE_RATIO}" + "++algorithm.dememwm.stage_policy.noise_bucket_logging=true" + "++algorithm.dememwm.eval_ablation.enabled=true" + "++algorithm.dememwm.eval_ablation.branch=${ABLATION_BRANCH}" + "++algorithm.dememwm.cache.enabled=true" + "++algorithm.dememwm.cache.device=cpu" + "++algorithm.dememwm.cache.keep_raw_latents=all" + "++algorithm.dememwm.cache.keep_compressed_records=true" + "++algorithm.dememwm.cache.eviction_policy=none" + "++algorithm.dememwm.cache.no_evict=true" + "++algorithm.dememwm.cache.clear_between_videos=true" + "++algorithm.dememwm.cache.max_records=null" + 
"++algorithm.dememwm.cache.max_slots=null" + "++algorithm.dememwm.cache.on_capacity_exceeded=warn" + "experiment.validation.batch_size=${VAL_BATCH_SIZE}" + "experiment.validation.limit_batch=${VAL_LIMIT}" +) + +printf '%s\n' "${EVAL_ARGS[@]}" > "${LOG_DIR}/eval_args.txt" +echo "Launching evaluation..." +SECONDS=0 +srun "${PY}" -m main "${EVAL_ARGS[@]}" +EVAL_DURATION_SECONDS=${SECONDS} +nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits > "${LOG_DIR}/gpu_memory_after_mb.txt" || true + +cat > "${RUN_ROOT}/eval_manifest.txt" </dev/null 2>&1 || true + +srun python -m main \ + +name=train_dememwm_full_h200_2gpu_bs32_350k \ + +output_dir=/share_1/users/bonan_ding/worldmem_ckpt/dememwm_full_h200_2gpu_bs32_350k/ \ + wandb.mode=online \ + auto_resume=true \ + "experiment.tasks=[training]" \ + algorithm=dememwm_memory_dit \ + +customized_load=true \ + +seperate_load=true \ + +diffusion_model_path=/share_1/users/bonan_ding/WorldMem/open-oasis/checkpoints/oasis500m.safetensors \ + +vae_path=/share_1/users/bonan_ding/WorldMem/open-oasis/checkpoints/vit-l-20.safetensors \ + +only_tune_memory=false \ + dataset=video_minecraft_latent \ + dataset.save_dir=/share_1/users/bonan_ding/worldmem_data/minecraft \ + dataset.precomputed_feature_dir=/share_1/users/bonan_ding/worldmem_data/minecraft/vae_features \ + dataset.n_frames=1000 \ + +dataset.n_frames_valid=1100 \ + +dataset.customized_validation=true \ + +dataset.memory_condition_length=0 \ + +dataset.wo_updown=false \ + +dataset.angle_range=180 \ + +dataset.pos_range=8 \ + ++algorithm.n_tokens=4 \ + "algorithm.x_shape=[16,18,32]" \ + ++algorithm.context_frames=100 \ + ++algorithm.log_video=true \ + ++algorithm.diffusion.sampling_timesteps=20 \ + ++algorithm.dememwm.debug_force_all_streams=false \ + ++algorithm.dememwm.generated_history_proxy.enabled=true \ + ++algorithm.dememwm.generated_history_proxy.start_step=40000 \ + ++algorithm.dememwm.generated_history_proxy.ramp_steps=40000 \ + 
++algorithm.dememwm.generated_history_proxy.max_prob=0.25 \ + ++algorithm.dememwm.generated_history_proxy.noise_std=0.25 \ + ++algorithm.dememwm.generated_history_proxy.dropout_prob=0.0 \ + ++algorithm.dememwm.anchor.enabled=true \ + ++algorithm.dememwm.anchor.anchor_indices=[0,1,2,3] \ + ++algorithm.dememwm.anchor.diverse_selection=true \ + ++algorithm.dememwm.anchor.compress.downsample_ratio=3 \ + ++algorithm.dememwm.anchor.allow_generated_as_anchor=false \ + ++algorithm.dememwm.dynamic.enabled=true \ + ++algorithm.dememwm.dynamic.exclude_latest_local_frames=4 \ + ++algorithm.dememwm.dynamic.recent_frames=4 \ + ++algorithm.dememwm.revisit.enabled=true \ + ++algorithm.dememwm.revisit.deterministic_pose_retrieval=true \ + ++algorithm.dememwm.revisit.fov_overlap_threshold=0.30 \ + ++algorithm.dememwm.revisit.high_quality_fov_threshold=0.70 \ + ++algorithm.dememwm.revisit.pose_preselect_topk=64 \ + ++algorithm.dememwm.revisit.fov_yaw_samples=25 \ + ++algorithm.dememwm.revisit.fov_pitch_samples=20 \ + ++algorithm.dememwm.revisit.fov_depth_samples=20 \ + ++algorithm.dememwm.revisit.plucker_weight=0.10 \ + ++algorithm.dememwm.revisit.max_frames=2 \ + ++algorithm.dememwm.revisit.compress.downsample_ratio=3 \ + ++algorithm.dememwm.stage_policy.noise_bucket_logging=true \ + ++algorithm.dememwm.cache.enabled=true \ + ++algorithm.dememwm.cache.device=cpu \ + ++algorithm.dememwm.cache.keep_raw_latents=all \ + ++algorithm.dememwm.cache.keep_compressed_records=true \ + ++algorithm.dememwm.cache.eviction_policy=none \ + ++algorithm.dememwm.cache.no_evict=true \ + ++algorithm.dememwm.cache.clear_between_videos=true \ + ++algorithm.dememwm.cache.max_records=null \ + ++algorithm.dememwm.cache.max_slots=null \ + ++algorithm.dememwm.cache.on_capacity_exceeded=warn \ + ++algorithm.dememwm.curriculum.enabled=true \ + ++algorithm.dememwm.curriculum.full_stage_start_step=20000 \ + ++algorithm.dememwm.curriculum.freeze_vae=true \ + ++algorithm.dememwm.curriculum.dit_freeze.enabled=true \ 
def install_dememwm_namespace():
    """
    Make ``algorithms`` and ``algorithms.worldmem`` resolvable from ``sys.modules``.

    For each of the two names, an existing ``sys.modules`` entry is reused or
    an empty module object is registered, and its ``__path__`` is set to the
    matching directory under the repository root (the parent of this file's
    directory).
    """
    repo_root = Path(__file__).resolve().parents[1]
    package_dirs = (
        ("algorithms", repo_root / "algorithms"),
        ("algorithms.worldmem", repo_root / "algorithms" / "worldmem"),
    )
    for dotted_name, directory in package_dirs:
        if dotted_name not in sys.modules:
            sys.modules[dotted_name] = types.ModuleType(dotted_name)
        # Overwrite __path__ even on a pre-existing module object.
        sys.modules[dotted_name].__path__ = [str(directory)]
def test_algorithm_mixin_has_strict_checkpoint_helper_and_no_old_imports():
    """
    Static text check on ``algorithms/worldmem/dememwm/algorithm.py``.

    The file must mention the strict checkpoint key-check helper and its
    ``strict_dememwm_checkpoint_key_check`` alias, and must not mention any of
    the legacy ``*ssm_memory`` names.
    """
    # Resolve from this test file rather than the working directory, so the
    # test passes no matter where pytest is started from.
    repo_root = Path(__file__).resolve().parents[1]
    src = (repo_root / "algorithms" / "worldmem" / "dememwm" / "algorithm.py").read_text()
    assert "strict_checkpoint_key_check" in src
    assert "strict_dememwm_checkpoint_key_check = strict_checkpoint_key_check" in src
    assert "spatial_ssm_memory" not in src
    assert "df_video_ssm_memory" not in src
    assert "ssm_memory" not in src
def small_compressor(**kwargs):
    """
    Build a small ``CausalConv3DDynamicCompressor`` for the tests in this module.

    Fixed settings: 3 latent channels, hidden size 8, patch size 2, temporal
    kernel 3 with stride 2, and at most 4 source frames.  Any extra keyword
    arguments are forwarded to the constructor unchanged.
    """
    compressor = CausalConv3DDynamicCompressor(
        max_source_frames=4,
        conv_stride_t=2,
        conv_kernel_t=3,
        patch_size=2,
        dit_hidden_size=8,
        latent_channels=3,
        **kwargs,
    )
    return compressor
def test_cache_materialize_raw_latents_excludes_c_short_overlap():
    """
    Raw latents close to the target frame are left out when materializing.

    Six frames (indices 0..5) are cached.  With target frame 6 and
    ``exclude_latest_local_frames=4`` only frames 0 and 1 are expected back,
    with no pose tensor and a generated-flag tensor shaped like the frame
    indices.
    """
    streaming_cache = StreamingCache(enabled=True, keep_raw_latents="all", keep_compressed_records=False)
    source_latents = torch.randn(6, 1, 3, 2, 2)
    source_frames = torch.arange(6).view(6, 1)
    streaming_cache.add_raw_latents(source_latents, source_frames)

    out_latents, out_frames, out_generated, out_pose = streaming_cache.materialize_raw_latents(
        device=torch.device("cpu"),
        dtype=source_latents.dtype,
        max_recent_frames=8,
        target_frame_indices=torch.tensor([[6]]),
        exclude_latest_local_frames=4,
    )

    assert out_pose is None
    assert out_latents.shape[0] == 2
    assert out_generated.shape == out_frames.shape
    assert out_frames.flatten().tolist() == [0, 1]
def test_config_is_distinct_standalone_memory_dit_path():
    """
    Static text check on ``configurations/algorithm/dememwm_memory_dit.yaml``.

    The config must carry its own ``_name``, reference ``base_video_dit``,
    enable memory-token cross attention, contain a ``dememwm:`` section with
    ``debug_force_all_streams``, and contain no ``ssm_memory`` entries.
    """
    # Resolve from this test file rather than the working directory, so the
    # test passes no matter where pytest is started from.
    repo_root = Path(__file__).resolve().parents[1]
    text = (repo_root / "configurations" / "algorithm" / "dememwm_memory_dit.yaml").read_text()
    assert "_name: dememwm_memory_dit" in text
    assert "base_video_dit" in text
    assert "memory_token_cross_attention: true" in text
    assert "dememwm:" in text
    assert "debug_force_all_streams" in text
    assert "ssm_memory" not in text
    assert "ssm_memory_ckpt_path" not in text
"fov_pitch_samples: 20", + "fov_depth_samples: 20", + "pose_preselect_topk: 64", + "plucker_grid_h: 4", + "plucker_grid_w: 4", + "plucker_focal_length: 0.35", + "noise_bucket_logging: true", + "eval_ablation:", + "branch: A_plus_D_plus_R_normal", + "generated_history_proxy:", + ] + for token in required: + assert token in text + for forbidden in ( + "anchor_ratio", + "dynamic_ratio", + "revisit_ratio", + "lambda_abstain", + "abstention:", + "force_gate_zero_when_invalid", + "use_residual_bound_loss", + "use_utility_loss", + "use_revisit_classifier_loss", + "min_score", + "generated_penalty", + "min_gap_frames", + "max_chunks", + "chunk_frames", + ): + assert forbidden not in text + + +def test_full_scripts_use_consumed_contract_overrides(): + required = [ + "algorithm.dememwm.dynamic.exclude_latest_local_frames=4", + "algorithm.dememwm.revisit.deterministic_pose_retrieval=true", + "algorithm.dememwm.revisit.fov_overlap_threshold=0.30", + "algorithm.dememwm.revisit.high_quality_fov_threshold=0.70", + "algorithm.dememwm.revisit.pose_preselect_topk=64", + "algorithm.dememwm.revisit.fov_yaw_samples=25", + "algorithm.dememwm.revisit.fov_pitch_samples=20", + "algorithm.dememwm.revisit.fov_depth_samples=20", + "algorithm.dememwm.revisit.plucker_weight=0.10", + "algorithm.dememwm.stage_policy.noise_bucket_logging=true", + "algorithm.dememwm.cache.keep_compressed_records=true", + ] + stale = [ + "algorithm.dememwm.loss.", + "algorithm.dememwm.abstention.", + "algorithm.dememwm.anchor.topk", + "algorithm.dememwm.anchor.pin_prefix", + "algorithm.dememwm.dynamic.include_generated_recent", + "algorithm.dememwm.revisit.deterministic_only", + "algorithm.dememwm.revisit.min_age_frames", + "algorithm.dememwm.revisit.topk", + "algorithm.dememwm.revisit.min_gap_frames", + "algorithm.dememwm.revisit.max_chunks", + "algorithm.dememwm.revisit.chunk_frames", + "algorithm.dememwm.revisit.min_score", + "algorithm.dememwm.revisit.generated_penalty", + "algorithm.dememwm.rollout.", + ] + for 
rel in ("scripts/dememwm_full_train.slurm", "scripts/dememwm_full_eval.slurm"): + text = Path(rel).read_text() + for token in required: + assert token in text, f"{token} missing from {rel}" + for token in stale: + assert token not in text, f"stale {token} override remains in {rel}" + + +def test_algorithm_consumes_final_contract_guards_and_revisit_geometry_args(): + text = Path("algorithms/worldmem/dememwm/algorithm.py").read_text() + assert "_dememwm_validate_config_contract = _validate_config_contract" in text + for token in [ + "_validate_config_contract", + "deterministic_pose_retrieval", + "exclude_latest_local_frames", + "noise_bucket_logging", + "anchor_effective_enabled", + "dynamic_effective_enabled", + "revisit_effective_enabled", + "stale DeMemWM config fields", + "revisit_retrieval_kwargs", + "fov_half_h", + "fov_yaw_samples", + "plucker_grid_h", + "plucker_focal_length", + "pose_preselect_topk", + ]: + assert token in text + assert '_cfg_get(revisit_cfg, "topk"' not in text + assert "lambda_abstain" not in text + + +def test_revisit_retrieval_is_deterministic_fov_plucker_contract(): + retrieval = Path("algorithms/worldmem/dememwm/retrieval.py").read_text() + labels = Path("algorithms/worldmem/dememwm/labels.py").read_text() + algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text() + diagnostics = Path("algorithms/worldmem/dememwm/diagnostics.py").read_text() + for token in [ + "exclude_local_context_frames", + "fov_overlap_threshold", + "plucker_weight", + "high_quality_fov_threshold", + "best_selected_frame_fov_overlap", + "deterministic_fov_coverage_plucker", + "valid_revisit_mask", + "revisit_candidate_frame_count", + "valid_candidate_label_count", + "valid_revisit_target_count", + "valid_revisit_frame_count", + "no_valid_revisit_count", + "revisit_selected_frame_count", + "revisit_frame_fov_overlap", + "revisit_abstained_count", + ]: + assert token in retrieval + labels + algorithm + diagnostics + assert "same_video" not in 
retrieval + labels + assert "wrong_video" not in retrieval + labels + for stale in ["time_weight", "pose_weight", "latent_weight", "generated_penalty", "min_score"]: + assert f'self._cfg_get(revisit_cfg, "{stale}"' not in algorithm + + +def test_dynamic_compressor_excludes_c_short_contract(): + compression = Path("algorithms/worldmem/dememwm/compression.py").read_text() + cache = Path("algorithms/worldmem/dememwm/cache.py").read_text() + algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text() + for token in [ + "exclude_latest_local_frames", + "src_frames_b < target - exclude_latest_local_frames", + "dynamic_min_gap_to_target_per_target", + "dynamic_max_gap_to_target_per_target", + "dynamic_exclude_latest_local_frames", + "_local_context_exclusion_frames", + ]: + assert token in compression + cache + algorithm + assert "src_frames_b < target, as_tuple=False" not in compression + assert "src < int(target), as_tuple=False" not in cache + + +def test_eval_ablation_and_noise_bucket_logging_contracts(): + schedules = Path("algorithms/worldmem/dememwm/schedules.py").read_text() + diagnostics = Path("algorithms/worldmem/dememwm/diagnostics.py").read_text() + algorithm = Path("algorithms/worldmem/dememwm/algorithm.py").read_text() + matrix = Path("scripts/dememwm_eval_ablation_matrix.sh").read_text() + for branch in [ + "memory_off", + "A_only", + "D_only", + "A_plus_D", + "A_plus_D_plus_R_normal", + "R_forced_off", + "R_forced_on", + "wrong_pose", + "time_shuffle", + "source_matched_random", + "pose_shuffle", + "wrong_video", + "local_context_overlap_fake_revisit", + ]: + assert branch in schedules + assert branch in matrix + for token in [ + "noise_bucket_from_denoising_fraction", + "noise_bucket_from_noise_levels", + "summarize_noise_bucket_diagnostics", + "noise_bucket_id", + "summarize_eval_ablation_diagnostics", + "eval_bucket_true_revisit_count", + "eval_bucket_no_valid_revisit_count", + "eval_bucket_corrupted_memory_count", + 
"apply_revisit_eval_corruption", + ]: + assert token in schedules + diagnostics + algorithm diff --git a/tests/test_dememwm_dit_extension_static.py b/tests/test_dememwm_dit_extension_static.py new file mode 100644 index 0000000000000000000000000000000000000000..b42499770bb5b1305d0b5f62c6e0c74e750b1627 --- /dev/null +++ b/tests/test_dememwm_dit_extension_static.py @@ -0,0 +1,157 @@ + +import torch +from dememwm_import_helper import install_dememwm_namespace +install_dememwm_namespace() + +from algorithms.worldmem.models.dit import DiT, MemoryTokenCrossAttention, SpatioTemporalDiTBlock +import ast +from pathlib import Path + + +def test_block_uses_one_shared_typed_adapter_when_memory_enabled(): + block = SpatioTemporalDiTBlock(hidden_size=16, num_heads=4, reference_length=1, use_memory_token_cross_attention=True) + assert hasattr(block, "memory_token_cross_attn") + assert not hasattr(block, "dynamic_memory_token_cross_attn") + assert not hasattr(block, "retrieval_memory_token_cross_attn") + adapters = [module for module in block.modules() if isinstance(module, MemoryTokenCrossAttention)] + assert len(adapters) == 1 + assert hasattr(block.memory_token_cross_attn, "memory_type_embed") + assert hasattr(block.memory_token_cross_attn, "memory_type_gate") + + +def test_memory_cross_attention_zero_residual_gate_zeroes_delta(): + attn = MemoryTokenCrossAttention(hidden_size=8, num_heads=2) + with torch.no_grad(): + attn.adaLN_modulation[-1].bias.fill_(1.0) + x = torch.randn(1, 2, 1, 1, 8) + c = torch.randn(1, 2, 8) + mem = torch.randn(1, 2, 3, 8) + mask = torch.ones(1, 2, 3, dtype=torch.bool) + delta = attn(x, c, mem, mask, return_delta=True, residual_gate=torch.zeros(1, 2, 1)) + assert torch.allclose(delta, torch.zeros_like(delta), atol=1e-6) + + +def test_memory_cross_attention_zero_token_gate_masks_valid_tokens(): + attn = MemoryTokenCrossAttention(hidden_size=8, num_heads=2) + with torch.no_grad(): + attn.adaLN_modulation[-1].bias.fill_(1.0) + x = torch.randn(1, 2, 1, 1, 
8) + c = torch.randn(1, 2, 8) + mem = torch.randn(1, 2, 3, 8) + mask = torch.ones(1, 2, 3, dtype=torch.bool) + delta = attn(x, c, mem, mask, return_delta=True, memory_token_gate=torch.zeros(1, 2, 3)) + assert torch.allclose(delta, torch.zeros_like(delta), atol=1e-6) + + +def test_shared_memory_attention_zero_revisit_gate_matches_anchor_only(): + torch.manual_seed(0) + attn = MemoryTokenCrossAttention(hidden_size=8, num_heads=2) + with torch.no_grad(): + attn.adaLN_modulation[-1].bias.fill_(1.0) + x = torch.randn(1, 2, 1, 1, 8) + c = torch.randn(1, 2, 8) + anchor = torch.randn(1, 2, 2, 8) + revisit = torch.randn(1, 2, 3, 8) * 100.0 + packed = torch.cat([anchor, revisit], dim=2) + out_anchor = attn( + x, + c, + anchor, + torch.ones(1, 2, 2, dtype=torch.bool), + memory_type_ids=torch.tensor([0, 0]), + memory_token_gate=torch.ones(1, 2, 2), + ) + out_packed = attn( + x, + c, + packed, + torch.ones(1, 2, 5, dtype=torch.bool), + memory_type_ids=torch.tensor([0, 0, 2, 2, 2]), + memory_token_gate=torch.tensor([[[1, 1, 0, 0, 0], [1, 1, 0, 0, 0]]], dtype=torch.float32), + residual_gate=torch.ones(1, 2), + ) + assert torch.allclose(out_anchor, out_packed, atol=1e-5) + + +def test_memory_cross_attention_type_ids_record_stage_gates(): + attn = MemoryTokenCrossAttention(hidden_size=8, num_heads=2) + x = torch.randn(1, 2, 1, 1, 8) + c = torch.randn(1, 2, 8) + mem = torch.randn(1, 2, 3, 8) + mask = torch.ones(1, 2, 3, dtype=torch.bool) + type_ids = torch.tensor([0, 1, 2]) + _ = attn(x, c, mem, mask, memory_type_ids=type_ids, memory_token_gate=torch.ones(1, 2, 3)) + assert torch.is_tensor(attn.last_type_gate_anchor_mean) + assert torch.is_tensor(attn.last_type_gate_dynamic_mean) + assert torch.is_tensor(attn.last_type_gate_revisit_mean) + + +def test_dit_accepts_dynamic_rank4_tokens_and_all_false_masks_without_nan(): + model = DiT(input_h=4, input_w=4, patch_size=2, in_channels=2, hidden_size=32, depth=1, num_heads=4, action_cond_dim=0, max_frames=2, reference_length=1, 
memory_token_cross_attention=True) + x = torch.randn(1, 2, 2, 4, 4) + t = torch.zeros(1, 2, dtype=torch.long) + mem = torch.randn(1, 2, 3, 32) + mask = torch.zeros(1, 2, 3, dtype=torch.bool) + out = model(x, t, action_cond=None, memory_dynamic_tokens=mem, memory_dynamic_mask=mask, memory_dynamic_gate=torch.ones(1, 2, 1)) + assert out.shape == x.shape + assert torch.isfinite(out).all() + + +def test_dit_old_two_stream_call_still_works(): + model = DiT(input_h=4, input_w=4, patch_size=2, in_channels=2, hidden_size=32, depth=1, num_heads=4, action_cond_dim=0, max_frames=2, reference_length=1, memory_token_cross_attention=True) + x = torch.randn(1, 2, 2, 4, 4) + t = torch.zeros(1, 2, dtype=torch.long) + mem = torch.randn(1, 2, 3, 32) + mask = torch.ones(1, 2, 3, dtype=torch.bool) + out = model(x, t, action_cond=None, memory_tokens=mem, memory_token_mask=mask, memory_retrieval_tokens=mem, memory_retrieval_mask=mask) + assert out.shape == x.shape + + +def test_diffusion_methods_accept_option_c_kwargs_by_signature(): + # Static/API check avoids importing omegaconf under base Python. 
+ tree = ast.parse(Path("algorithms/worldmem/models/diffusion.py").read_text()) + wanted = {"model_predictions", "p_mean_variance", "forward", "sample_step", "ddpm_sample_step", "ddim_sample_step"} + found = {} + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef) and node.name in wanted: + found[node.name] = node.args.kwarg.arg if node.args.kwarg else None + assert found.keys() >= wanted + assert all(found[name] == "memory_kwargs" for name in wanted) + + + +def test_fresh_memory_cross_attention_identity_init_delta_ratio_is_zero(): + attn = MemoryTokenCrossAttention(hidden_size=8, num_heads=2) + x = torch.randn(1, 2, 1, 1, 8) + c = torch.randn(1, 2, 8) + mem = torch.randn(1, 2, 3, 8) + mask = torch.ones(1, 2, 3, dtype=torch.bool) + delta = attn(x, c, mem, mask, return_delta=True, residual_gate=torch.ones(1, 2, 1)) + assert torch.allclose(delta, torch.zeros_like(delta), atol=1e-6) + assert float(attn.last_delta_ratio.item()) <= 1e-7 + + +def test_fresh_dit_memory_on_matches_memory_off_and_reports_delta_ratio(): + model = DiT(input_h=4, input_w=4, patch_size=2, in_channels=2, hidden_size=32, depth=1, num_heads=4, action_cond_dim=0, max_frames=2, reference_length=1, memory_token_cross_attention=True) + x = torch.randn(1, 2, 2, 4, 4) + t = torch.zeros(1, 2, dtype=torch.long) + mem = torch.randn(1, 2, 3, 32) + mask = torch.ones(1, 2, 3, dtype=torch.bool) + out_off = model(x, t, action_cond=None) + out_on = model( + x, + t, + action_cond=None, + memory_tokens=mem, + memory_token_mask=mask, + memory_dynamic_tokens=mem, + memory_dynamic_mask=mask, + memory_retrieval_tokens=mem, + memory_retrieval_mask=mask, + memory_anchor_gate=torch.ones(1, 2, 1), + memory_dynamic_gate=torch.ones(1, 2, 1), + memory_retrieval_gate=torch.ones(1, 2, 1), + ) + diagnostics = model.memory_adapter_delta_diagnostics() + assert torch.allclose(out_on, out_off, atol=1e-6) + assert diagnostics["memory_adapter_delta_ratio_max"] <= 1e-7 diff --git a/tests/test_dememwm_eval_ablation.py 
b/tests/test_dememwm_eval_ablation.py new file mode 100644 index 0000000000000000000000000000000000000000..a59fa233915c4710a4c4e4dc279106e36019520a --- /dev/null +++ b/tests/test_dememwm_eval_ablation.py @@ -0,0 +1,201 @@ +import types + +import pytest +import torch +from dememwm_import_helper import install_dememwm_namespace + +install_dememwm_namespace() +from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin +from algorithms.worldmem.dememwm.compression import CausalConv3DDynamicCompressor +from algorithms.worldmem.dememwm.diagnostics import summarize_eval_ablation_diagnostics +from algorithms.worldmem.dememwm.schedules import ( + EVAL_ABLATION_BRANCHES, + EVAL_ABLATION_BRANCH_TO_ID, + EVAL_CORRUPTION_BRANCHES, + normalize_eval_ablation_branch, +) + + +WAVE9_BRANCHES = ( + "memory_off", + "A_only", + "D_only", + "A_plus_D", + "A_plus_D_plus_R_normal", + "R_forced_off", + "R_forced_on", + "wrong_pose", + "time_shuffle", + "source_matched_random", + "pose_shuffle", + "wrong_video", + "local_context_overlap_fake_revisit", +) + + +def _device(): + return torch.device("cpu") + + +def test_wave9_branch_registry_is_exact_and_validated(): + assert EVAL_ABLATION_BRANCHES == WAVE9_BRANCHES + assert EVAL_ABLATION_BRANCH_TO_ID["memory_off"] == 0 + assert EVAL_ABLATION_BRANCH_TO_ID["local_context_overlap_fake_revisit"] == len(WAVE9_BRANCHES) - 1 + assert EVAL_CORRUPTION_BRANCHES == WAVE9_BRANCHES[7:] + assert normalize_eval_ablation_branch(None) == "A_plus_D_plus_R_normal" + assert normalize_eval_ablation_branch("wrong_pose") == "wrong_pose" + with pytest.raises(ValueError): + normalize_eval_ablation_branch("ratio_sweep") + + +def test_eval_ablation_diagnostics_bucket_counts(): + diag = summarize_eval_ablation_diagnostics( + enabled=True, + branch="wrong_pose", + valid_revisit_mask=torch.tensor([[True, True, True, False]]), + no_valid_revisit_mask=torch.tensor([[False, False, False, True]]), + eval_corrupted_revisit_mask=torch.tensor([[False, True, True, False]]), 
+ ) + assert diag["eval_ablation_enabled"] is True + assert diag["eval_ablation_branch"] == "wrong_pose" + assert diag["eval_ablation_branch_id"] == EVAL_ABLATION_BRANCH_TO_ID["wrong_pose"] + assert diag["eval_bucket_true_revisit_count"] == 1 + assert diag["eval_bucket_no_valid_revisit_count"] == 1 + assert diag["eval_bucket_corrupted_memory_count"] == 2 + assert diag["eval_bucket_true_revisit_fraction"] == pytest.approx(0.25) + assert diag["eval_bucket_no_valid_revisit_fraction"] == pytest.approx(0.25) + assert diag["eval_bucket_corrupted_memory_fraction"] == pytest.approx(0.5) + + +class ConstantGate(torch.nn.Module): + def __init__(self, value: float): + super().__init__() + self.value = float(value) + + def forward(self, *, valid_revisit_mask, best_selected_fov_overlap, best_selected_plucker_overlap, selected_gap_frames): + del valid_revisit_mask, best_selected_plucker_overlap, selected_gap_frames + return torch.full_like(best_selected_fov_overlap, self.value, dtype=torch.float32) + + +class DummyDeMemWM(MemoryDiTMixin): + def __init__(self, branch: str, device: torch.device): + self.cfg = types.SimpleNamespace( + dememwm=types.SimpleNamespace( + enabled=True, + training_stage="stage_2", + debug_force_all_streams=False, + token_patch_size=2, + curriculum=types.SimpleNamespace(enabled=False), + anchor=types.SimpleNamespace( + enabled=True, + anchor_indices=[0, 1], + allow_generated_as_anchor=False, + diverse_selection=False, + compress=types.SimpleNamespace(pool_h=1, pool_w=1), + ), + dynamic=types.SimpleNamespace( + enabled=True, + exclude_latest_local_frames=2, + recent_frames=4, + conv_kernel_t=3, + conv_stride_t=2, + ), + revisit=types.SimpleNamespace( + enabled=True, + deterministic_pose_retrieval=True, + fov_overlap_threshold=0.0, + plucker_weight=0.1, + max_frames=2, + compress=types.SimpleNamespace(pool_h=1, pool_w=1), + ), + stage_policy=types.SimpleNamespace(noise_bucket_logging=True), + eval_ablation=types.SimpleNamespace(enabled=True, branch=branch), 
+ generated_history_proxy=types.SimpleNamespace(enabled=False), + injection=types.SimpleNamespace(dit_hidden_size=8, anchor_gate=1.0, dynamic_gate=1.0, revisit_gate=1.0), + cache=types.SimpleNamespace(enabled=False), + checkpoint=types.SimpleNamespace(strict_dememwm_eval_load=True), + ), + weight_decay=0.0, + optimizer_beta=(0.9, 0.999), + ) + self.global_step = 0 + self.x_stacked_shape = (1, 4, 4) + self.dememwm_anchor_proj = torch.nn.Linear(4, 8, bias=False).to(device) + self.dememwm_revisit_proj = torch.nn.Linear(4, 8, bias=False).to(device) + self.dememwm_dynamic_compressor = CausalConv3DDynamicCompressor( + latent_channels=1, + dit_hidden_size=8, + patch_size=2, + conv_kernel_t=3, + conv_stride_t=2, + max_source_frames=4, + exclude_latest_local_frames=2, + ).to(device) + self.dememwm_revisit_gate = ConstantGate(0.25).to(device) + + +def _streams(branch: str): + device = _device() + model = DummyDeMemWM(branch, device) + latents = torch.arange(12 * 1 * 1 * 4 * 4, device=device, dtype=torch.float32).reshape(12, 1, 1, 4, 4) / 100.0 + source_frames = torch.arange(12, device=device).reshape(12, 1) + target_frames = torch.tensor([[8], [12]], device=device) + pose = torch.zeros((12, 1, 5), device=device, dtype=torch.float32) + target_pose = torch.zeros((2, 1, 5), device=device, dtype=torch.float32) + return model.build_memory_streams( + latents, + source_frames, + target_frame_indices=target_frames, + pose=pose, + target_pose=target_pose, + action=None, + target_action=None, + ) + + +def test_eval_ablation_stream_enable_branches_control_masks_and_gates(): + memory_off = _streams("memory_off") + assert memory_off.anchor_gate == 0.0 + assert memory_off.dynamic_gate == 0.0 + assert torch.count_nonzero(memory_off.revisit_gate).item() == 0 + assert not memory_off.anchor_mask.any() + assert not memory_off.dynamic_mask.any() + assert not memory_off.revisit_mask.any() + + a_only = _streams("A_only") + assert a_only.anchor_gate == 1.0 + assert a_only.anchor_mask.any() + 
assert not a_only.dynamic_mask.any() + assert not a_only.revisit_mask.any() + + d_only = _streams("D_only") + assert d_only.dynamic_gate == 1.0 + assert d_only.dynamic_mask.any() + assert d_only.anchor_gate == 0.0 + assert not d_only.anchor_mask.any() + assert not d_only.revisit_mask.any() + + a_plus_d = _streams("A_plus_D") + assert a_plus_d.anchor_mask.any() + assert a_plus_d.dynamic_mask.any() + assert not a_plus_d.revisit_mask.any() + assert torch.count_nonzero(a_plus_d.revisit_gate).item() == 0 + + +def test_eval_ablation_forced_revisit_controls_are_isolated_to_eval_branch(): + normal = _streams("A_plus_D_plus_R_normal") + forced_off = _streams("R_forced_off") + forced_on = _streams("R_forced_on") + assert normal.valid_revisit_mask.all() + assert torch.allclose(normal.revisit_gate, torch.full_like(normal.revisit_gate, 0.25)) + assert torch.count_nonzero(forced_off.revisit_gate).item() == 0 + assert torch.equal(forced_on.revisit_gate, forced_on.valid_revisit_mask.to(dtype=forced_on.revisit_gate.dtype)) + assert forced_on.diagnostics["eval_ablation_branch"] == "R_forced_on" + + +def test_eval_ablation_corruption_branch_marks_corrupted_revisit_without_zeroing_gate(): + wrong_pose = _streams("wrong_pose") + assert wrong_pose.valid_revisit_mask.all() + assert torch.allclose(wrong_pose.revisit_gate, torch.full_like(wrong_pose.revisit_gate, 0.25)) + assert wrong_pose.diagnostics["eval_bucket_corrupted_memory_count"] == int(wrong_pose.valid_revisit_mask.numel()) + assert wrong_pose.diagnostics["eval_bucket_true_revisit_count"] == 0 diff --git a/tests/test_dememwm_freeze_policy.py b/tests/test_dememwm_freeze_policy.py new file mode 100644 index 0000000000000000000000000000000000000000..ce5306770769f77dd81dd46e66e649d0ca8c5bf1 --- /dev/null +++ b/tests/test_dememwm_freeze_policy.py @@ -0,0 +1,96 @@ +from types import SimpleNamespace + +import torch + +from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin + + +class _DummyBlock(torch.nn.Module): + def 
__init__(self): + super().__init__() + self.base = torch.nn.Linear(2, 2) + self.memory_token_cross_attn = torch.nn.Linear(2, 2) + + +class _DummyDiffusionModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.model = torch.nn.Module() + self.model.blocks = torch.nn.ModuleList([_DummyBlock()]) + + +class _DummyDeMemWMModel(MemoryDiTMixin, torch.nn.Module): + def __init__(self): + super().__init__() + self.global_step = 0 + self.cfg = SimpleNamespace( + weight_decay=0.01, + optimizer_beta=(0.9, 0.999), + dememwm=SimpleNamespace( + curriculum=SimpleNamespace( + enabled=True, + full_stage_start_step=10, + freeze_vae=True, + dit_freeze=SimpleNamespace(enabled=True), + lr=SimpleNamespace( + dememwm_modules=1.0e-4, + memory_adapters=2.0e-4, + full_dit=1.0e-5, + ), + ) + ), + ) + self.dememwm_anchor_proj = torch.nn.Linear(2, 2) + self.diffusion_model = _DummyDiffusionModel() + self.vae = torch.nn.Linear(2, 2) + + +def _params_for_group(model, group_name, state): + return [ + param + for name, param in model.named_parameters() + if model._param_group_name(name, state) == group_name + ] + + +def test_dit_freeze_keeps_requires_grad_stable_and_zeroes_optimizer_lr(): + model = _DummyDeMemWMModel() + + frozen_state = model._apply_freeze_policy(step=0) + full_dit_params = _params_for_group(model, "full_dit", frozen_state) + + assert frozen_state.dit_train_state == "frozen" + assert full_dit_params + assert all(param.requires_grad for param in full_dit_params) + assert model._last_dememwm_freeze_diagnostics["trainable_tensors_full_dit"] == 0 + assert model._last_dememwm_freeze_diagnostics["requires_grad_tensors_full_dit"] == len(full_dit_params) + assert all(not param.requires_grad for param in model.vae.parameters()) + + for param in full_dit_params: + param.grad = torch.ones_like(param) + memory_adapter_params = _params_for_group(model, "memory_adapters", frozen_state) + for param in memory_adapter_params: + param.grad = torch.ones_like(param) + 
model.on_after_backward() + assert all(param.grad is None for param in full_dit_params) + assert all(param.grad is not None for param in memory_adapter_params) + + optimizer = model.configure_optimizers() + lr_by_name = {group["name"]: group["lr"] for group in optimizer.param_groups} + assert lr_by_name["full_dit"] == 0.0 + assert lr_by_name["memory_adapters"] == 2.0e-4 + + full_state = model._apply_freeze_policy(optimizer, step=10) + lr_by_name = {group["name"]: group["lr"] for group in optimizer.param_groups} + + assert full_state.dit_train_state == "full" + assert all(param.requires_grad for param in full_dit_params) + assert model._last_dememwm_freeze_diagnostics["trainable_tensors_full_dit"] == len(full_dit_params) + assert lr_by_name["full_dit"] == 1.0e-5 + assert all(not param.requires_grad for param in model.vae.parameters()) + + model.global_step = 10 + for param in full_dit_params: + param.grad = torch.ones_like(param) + model.on_after_backward() + assert all(param.grad is not None for param in full_dit_params) diff --git a/tests/test_dememwm_gates_losses.py b/tests/test_dememwm_gates_losses.py new file mode 100644 index 0000000000000000000000000000000000000000..3e3b18ea17ddb09ed1ecf42d79c275d2eaae4a7a --- /dev/null +++ b/tests/test_dememwm_gates_losses.py @@ -0,0 +1,31 @@ +import torch +from dememwm_import_helper import install_dememwm_namespace + +install_dememwm_namespace() +from algorithms.worldmem.dememwm.gates import RevisitRawGate + + +def test_revisit_raw_gate_is_learned_and_effective_gate_masks_validity(): + gate = RevisitRawGate(init_logit=0.0) + valid = torch.tensor([[False, True]]) + raw = gate( + valid_revisit_mask=valid, + best_selected_fov_overlap=torch.tensor([[0.0, 0.8]]), + best_selected_plucker_overlap=torch.tensor([[0.0, 0.6]]), + selected_gap_frames=torch.tensor([[-1.0, 8.0]]), + ) + eff = valid.float() * raw + assert gate.net.in_features == 3 + assert raw.shape == (1, 2) + assert raw.requires_grad + assert eff[0, 0].item() == 0.0 + 
assert eff[0, 1].item() == raw[0, 1].item() + + +def test_revisit_raw_gate_returns_pre_stage_learned_output(): + gate = RevisitRawGate(init_logit=0.0) + valid = torch.ones((1, 2), dtype=torch.bool) + learned = gate(valid_revisit_mask=valid) + effective = learned * 0.25 + assert torch.allclose(learned, torch.full_like(learned, 0.5)) + assert torch.allclose(effective, torch.full_like(effective, 0.125)) diff --git a/tests/test_dememwm_generated_history_proxy.py b/tests/test_dememwm_generated_history_proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..2313f45a73f5b4e0481a4635374a634df9ef5f5f --- /dev/null +++ b/tests/test_dememwm_generated_history_proxy.py @@ -0,0 +1,147 @@ +from pathlib import Path +from types import SimpleNamespace + +import torch +from dememwm_import_helper import install_dememwm_namespace + +install_dememwm_namespace() + +from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin +from algorithms.worldmem.dememwm.types import MemorySourceType + + +class DummyDeMemWM(MemoryDiTMixin): + def __init__(self, proxy_cfg=None, step=0): + self.cfg = SimpleNamespace(dememwm=SimpleNamespace(generated_history_proxy=proxy_cfg)) + self.global_step = step + + +def _proxy_cfg(**overrides): + values = { + "enabled": True, + "start_step": 0, + "ramp_steps": 0, + "max_prob": 1.0, + "noise_std": 0.5, + "dropout_prob": 0.0, + } + values.update(overrides) + return SimpleNamespace(**values) + + +def test_generated_history_proxy_config_defaults_disabled_and_train_script_enables_explicit_values(): + config = Path("configurations/algorithm/dememwm_memory_dit.yaml").read_text() + train_script = Path("scripts/dememwm_full_train.slurm").read_text() + for token in [ + "generated_history_proxy:", + "enabled: false", + "start_step: 0", + "ramp_steps: 1", + "max_prob: 0.0", + "noise_std: 0.25", + "dropout_prob: 0.0", + ]: + assert token in config + for token in [ + "algorithm.dememwm.generated_history_proxy.enabled=true", + 
"algorithm.dememwm.generated_history_proxy.start_step=40000", + "algorithm.dememwm.generated_history_proxy.ramp_steps=40000", + "algorithm.dememwm.generated_history_proxy.max_prob=0.25", + "algorithm.dememwm.generated_history_proxy.noise_std=0.25", + "algorithm.dememwm.generated_history_proxy.dropout_prob=0.0", + ]: + assert token in train_script + + +def test_generated_history_proxy_probability_ramps_after_start_step(): + model = DummyDeMemWM(_proxy_cfg(start_step=10, ramp_steps=10, max_prob=0.5)) + assert model._generated_history_proxy_prob(step=9) == 0.0 + assert model._generated_history_proxy_prob(step=10) == 0.0 + assert model._generated_history_proxy_prob(step=15) == 0.25 + assert model._generated_history_proxy_prob(step=20) == 0.5 + assert model._generated_history_proxy_prob(step=30) == 0.5 + + +def test_generated_history_proxy_corrupts_only_returned_memory_source_and_marks_frames(): + model = DummyDeMemWM(_proxy_cfg(max_prob=1.0, noise_std=0.5, dropout_prob=0.0), step=0) + source_latents = torch.zeros(4, 1, 1, 2, 2) + source_is_generated = torch.zeros(4, 1, dtype=torch.bool) + + torch.manual_seed(123) + corrupted, generated, diagnostics = model._apply_generated_history_proxy( + source_latents, + source_is_generated, + ) + + assert torch.equal(source_latents, torch.zeros_like(source_latents)) + assert not torch.equal(corrupted, source_latents) + assert generated.all() + assert not source_is_generated.any() + assert diagnostics["generated_history_proxy_frame_count"] == 4 + assert diagnostics["generated_history_proxy_frame_fraction"] == 1.0 + + + +def test_generated_history_proxy_respects_context_prefix_and_target_window_bounds(): + model = DummyDeMemWM(_proxy_cfg(max_prob=1.0, noise_std=0.5, dropout_prob=0.0), step=0) + source_latents = torch.zeros(8, 1, 1, 2, 2) + source_is_generated = torch.zeros(8, 1, dtype=torch.bool) + + torch.manual_seed(123) + corrupted, generated, diagnostics = model._apply_generated_history_proxy( + source_latents, + 
source_is_generated, + context_frame_count=3, + target_start_frame=6, + ) + + expected_generated = torch.tensor( + [[False], [False], [False], [True], [True], [True], [False], [False]], + dtype=torch.bool, + ) + assert torch.equal(source_latents, torch.zeros_like(source_latents)) + assert torch.equal(generated, expected_generated) + assert torch.equal(corrupted[:3], source_latents[:3]) + assert not torch.equal(corrupted[3:6], source_latents[3:6]) + assert torch.equal(corrupted[6:], source_latents[6:]) + assert diagnostics["generated_history_proxy_frame_count"] == 3 + assert diagnostics["generated_history_proxy_frame_fraction"] == 3 / 8 + + +def test_generated_proxy_frames_skip_prefix_anchors_but_remain_revisit_sources(): + model = DummyDeMemWM(_proxy_cfg(enabled=False)) + model.dememwm_anchor_proj = torch.nn.Linear(1, 2, bias=False) + model.dememwm_revisit_proj = torch.nn.Linear(1, 2, bias=False) + with torch.no_grad(): + model.dememwm_anchor_proj.weight.fill_(1.0) + model.dememwm_revisit_proj.weight.fill_(1.0) + + latents = torch.arange(4, dtype=torch.float32).reshape(4, 1, 1, 1, 1) + frame_indices = torch.arange(4).reshape(4, 1) + source_is_generated = torch.tensor([[True], [False], [False], [False]]) + anchor_projected = model._project_latent_patch_tokens(latents, model.dememwm_anchor_proj, patch_size=1) + revisit_projected = model._project_latent_patch_tokens(latents, model.dememwm_revisit_proj, patch_size=1) + + anchor_banks, revisit_banks = model._build_causal_memory_banks( + anchor_projected, + revisit_projected, + frame_indices, + source_is_generated, + pose=None, + action=None, + allow_generated_anchor=False, + anchor_indices=[0, 1], + anchor_pool_h=1, + anchor_pool_w=1, + revisit_pool_h=1, + revisit_pool_w=1, + src_h=1, + src_w=1, + ) + + anchor_frames = [int(record.frame_indices.item()) for record in anchor_banks[0].records] + assert anchor_frames == [1, 2] + generated_revisit = [record for record in revisit_banks[0].records if record.is_generated] + 
assert len(generated_revisit) == 1 + assert generated_revisit[0].source_type == MemorySourceType.GENERATED + assert generated_revisit[0].frame_indices.tolist() == [0] diff --git a/tests/test_dememwm_injection_static.py b/tests/test_dememwm_injection_static.py new file mode 100644 index 0000000000000000000000000000000000000000..a9eccda2c3e727784b0a71a674b97fc4bb90cd76 --- /dev/null +++ b/tests/test_dememwm_injection_static.py @@ -0,0 +1,51 @@ + +import torch +from dememwm_import_helper import install_dememwm_namespace +install_dememwm_namespace() + +from algorithms.worldmem.dememwm.injection import InjectionAdapter +from algorithms.worldmem.dememwm.types import MemoryStreamTensors + + +def _streams(dtype=torch.float32): + return MemoryStreamTensors( + anchor_tokens=torch.randn(2, 3, 1, 4, dtype=dtype), + anchor_mask=torch.ones(2, 3, 1), + dynamic_tokens=torch.randn(2, 3, 2, 4, dtype=dtype), + dynamic_mask=torch.tensor([[[1, 0], [1, 1], [0, 0]], [[1, 1], [0, 0], [1, 0]]]), + revisit_tokens=torch.randn(2, 3, 1, 4, dtype=dtype), + revisit_mask=torch.zeros(2, 3, 1), + anchor_gate=1.0, + dynamic_gate=torch.ones(2, 3, 1) * 0.5, + revisit_gate=0.0, + diagnostics={"selected_revisit_frame_record_ids": ["c1"], "dynamic_max_source_frame": torch.tensor(2)}, + ) + + +def test_injection_kwarg_names_masks_dtype_and_diagnostics(): + kwargs, diag = InjectionAdapter()(_streams(), dtype=torch.float64) + assert set(kwargs) == {"memory_tokens", "memory_token_mask", "memory_dynamic_tokens", "memory_dynamic_mask", "memory_retrieval_tokens", "memory_retrieval_mask", "memory_anchor_gate", "memory_dynamic_gate", "memory_retrieval_gate"} + assert kwargs["memory_tokens"].dtype == torch.float64 + assert kwargs["memory_dynamic_mask"].dtype == torch.bool + assert diag["anchor_valid_tokens"] == 6 + assert diag["dynamic_valid_fraction"] > 0.0 + assert diag["selected_revisit_frame_record_ids"] == ["c1"] + assert diag["max_source_frame"] == 2 + + +def test_injection_omit_disabled_streams(): + kwargs, 
_ = InjectionAdapter(omit_disabled=True)(_streams()) + assert kwargs["memory_retrieval_tokens"] is None + assert kwargs["memory_retrieval_mask"] is None + assert kwargs["memory_dynamic_tokens"] is not None + + +def test_injection_rejects_bad_mask_shape(): + streams = _streams() + streams.dynamic_mask = torch.ones(2, 3, 3) + try: + InjectionAdapter()(streams) + except ValueError as exc: + assert "dynamic mask" in str(exc) + else: + raise AssertionError("expected bad mask shape to fail") diff --git a/tests/test_dememwm_memory.py b/tests/test_dememwm_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..829f1a3f2f3824e1108205faf3055562b17126e3 --- /dev/null +++ b/tests/test_dememwm_memory.py @@ -0,0 +1,55 @@ + +import pytest +import torch +from dememwm_import_helper import install_dememwm_namespace +install_dememwm_namespace() +from algorithms.worldmem.dememwm.memory import CausalMemoryBank, MemoryBankQuery, stack_record_tokens +from algorithms.worldmem.dememwm.types import MemoryRecord, MemorySourceType + + +def _record(frame, source_type=MemorySourceType.PREFIX_GT, generated=False, slots=2): + return MemoryRecord(tokens=torch.ones(slots, 4) * frame, mask=torch.ones(slots, dtype=torch.bool), source_start=frame, source_end=frame + 1, frame_indices=torch.tensor([frame]), pose=None, source_type=source_type, is_generated=generated, chunk_id=f"r{frame}") + + +def test_prefix_anchors_are_prefix_gt_records(): + bank = CausalMemoryBank() + bank.add_prefix_anchors(torch.randn(2, 3, 4), torch.ones(2, 3, dtype=torch.bool), torch.tensor([0, 4])) + assert [r.source_type for r in bank.records] == [MemorySourceType.PREFIX_GT, MemorySourceType.PREFIX_GT] + assert not any(r.is_generated for r in bank.records) + + +def test_generated_records_are_not_prefix_gt_by_default(): + bank = CausalMemoryBank() + bank.add_generated_records(torch.randn(1, 2, 4), torch.ones(1, 2, dtype=torch.bool), torch.tensor([3])) + assert bank.records[0].source_type == 
MemorySourceType.GENERATED + assert bank.records[0].is_generated + with pytest.raises(ValueError): + bank.add_generated_records(torch.randn(1, 2, 4), torch.ones(1, 2, dtype=torch.bool), torch.tensor([4]), source_type=MemorySourceType.PREFIX_GT) + + +def test_query_never_returns_future_sources(): + bank = CausalMemoryBank() + for f in [0, 2, 5, 7]: + bank.add_record(_record(f)) + records = bank.query(5) + assert [r.max_source_frame for r in records] == [0, 2] + bank.assert_causal(5, records) + + +def test_all_false_masks_are_valid_abstention_outputs(): + rec = MemoryRecord(tokens=torch.zeros(3, 4), mask=torch.zeros(3, dtype=torch.bool), source_start=0, source_end=1, frame_indices=torch.tensor([0]), pose=None, source_type=MemorySourceType.REVISIT, is_generated=False) + assert rec.valid_slots == 0 + tokens, mask = stack_record_tokens([rec]) + assert tokens.shape == (3, 4) + assert mask.sum().item() == 0 + + +def test_budgets_cap_records_and_slots(): + bank = CausalMemoryBank(max_records=10) + for f in range(6): + bank.add_record(_record(f, slots=2)) + records = bank.query(MemoryBankQuery(target_frame=10, max_records=2, max_slots=3)) + assert len(records) == 2 + tokens, mask = stack_record_tokens(records, max_slots=3) + assert tokens.shape[0] == 3 + assert mask.shape[0] == 3 diff --git a/tests/test_dememwm_negatives.py b/tests/test_dememwm_negatives.py new file mode 100644 index 0000000000000000000000000000000000000000..929967d56b33f2765559c9be2049a390c772c71b --- /dev/null +++ b/tests/test_dememwm_negatives.py @@ -0,0 +1,44 @@ +import torch +from dememwm_import_helper import install_dememwm_namespace + +install_dememwm_namespace() +from algorithms.worldmem.dememwm.negatives import apply_revisit_eval_corruption +from algorithms.worldmem.dememwm.schedules import EVAL_CORRUPTION_BRANCHES + + +def test_eval_corruption_branches_modify_valid_revisit_tokens(): + tokens = torch.arange(8, dtype=torch.float32).reshape(2, 4) + mask = torch.ones(2, dtype=torch.bool) + for branch 
in EVAL_CORRUPTION_BRANCHES: + corrupted, was_corrupted = apply_revisit_eval_corruption( + tokens=tokens, + mask=mask, + branch=branch, + target_frame=10, + ) + assert was_corrupted is True, branch + assert not torch.equal(corrupted, tokens), branch + + +def test_eval_corruption_is_disabled_for_unknown_branch_or_empty_mask(): + tokens = torch.arange(8, dtype=torch.float32).reshape(2, 4) + full_mask = torch.ones(2, dtype=torch.bool) + empty_mask = torch.zeros(2, dtype=torch.bool) + + unchanged, was_corrupted = apply_revisit_eval_corruption( + tokens=tokens, + mask=full_mask, + branch="A_plus_D_plus_R_normal", + target_frame=10, + ) + assert was_corrupted is False + assert torch.equal(unchanged, tokens) + + unchanged, was_corrupted = apply_revisit_eval_corruption( + tokens=tokens, + mask=empty_mask, + branch="wrong_pose", + target_frame=10, + ) + assert was_corrupted is False + assert torch.equal(unchanged, tokens) diff --git a/tests/test_dememwm_no_old_path_mutation_static.py b/tests/test_dememwm_no_old_path_mutation_static.py new file mode 100644 index 0000000000000000000000000000000000000000..ea0f3f1d5fb5859bf0db9eae6a0877b4cb023a2a --- /dev/null +++ b/tests/test_dememwm_no_old_path_mutation_static.py @@ -0,0 +1,22 @@ + +from pathlib import Path + + +def test_new_standalone_files_do_not_reference_old_ssm_method_paths(): + forbidden = [ + "spatial_ssm_memory", + "df_video_ssm_memory", + "StateSpaceSpatialMemoryMinecraft", + "ssm_memory_ckpt_path", + "dememwm_compression", + ] + for rel in ["algorithms/worldmem/dememwm_memory_dit.py", "algorithms/worldmem/dememwm/algorithm.py", "configurations/algorithm/dememwm_memory_dit.yaml"]: + text = (Path(__file__).resolve().parents[1] / rel).read_text(encoding="utf-8") + for token in forbidden: + assert token not in text, f"{token} leaked into {rel}" + + +def test_legacy_ssm_bridge_scaffold_is_not_present(): + assert not (Path(__file__).resolve().parents[1] / "algorithms/worldmem/models/dememwm_compression.py").exists() + assert not (Path(__file__).resolve().parents[1] / "algorithms/worldmem/df_video_ssm_memory.py").exists() + assert not 
(Path(__file__).resolve().parents[1] / "algorithms/worldmem/models/spatial_ssm_memory.py").exists() diff --git a/tests/test_dememwm_noise_bucket.py b/tests/test_dememwm_noise_bucket.py new file mode 100644 index 0000000000000000000000000000000000000000..441afdf966b17136c91fbcd7500811a92bee98ad --- /dev/null +++ b/tests/test_dememwm_noise_bucket.py @@ -0,0 +1,103 @@ +import torch +from dememwm_import_helper import install_dememwm_namespace + +install_dememwm_namespace() +from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin +from algorithms.worldmem.dememwm.diagnostics import summarize_noise_bucket_diagnostics, summarize_revisit_diagnostics + + +def test_revisit_diagnostics_report_mean_counts_per_target(): + diag = summarize_revisit_diagnostics( + [ + {"valid_revisit_frame_count": 6, "revisit_candidate_frame_count": 8, "revisit_selected_frame_count": 2}, + {"valid_revisit_frame_count": 3, "revisit_candidate_frame_count": 5, "revisit_selected_frame_count": 1}, + {"valid_revisit_frame_count": 0, "revisit_candidate_frame_count": 2, "revisit_selected_frame_count": 0, "no_valid_revisit_count": 1}, + ], + valid_revisit_mask=torch.tensor([[True, True, False]]), + ) + assert diag["revisit_candidate_frame_count"] == 5.0 + assert diag["revisit_candidate_count"] == 5.0 + assert diag["valid_revisit_frame_count"] == 3.0 + assert diag["valid_revisit_count"] == 3.0 + assert diag["revisit_selected_frame_count"] == 3 + assert diag["no_valid_revisit_count"] == 1 + + +def test_noise_bucket_diagnostics_include_valid_and_no_valid_counts(): + diag = summarize_noise_bucket_diagnostics( + noise_bucket="high", + valid_revisit_mask=torch.tensor([[True, True, False]]), + no_valid_revisit_mask=torch.tensor([[False, False, True]]), + ) + assert diag["noise_bucket"] == "high" + assert diag["noise_bucket_id"] == 0 + assert diag["noise_bucket_is_high"] == 1 + assert diag["noise_bucket_is_mid"] == 0 + assert diag["noise_bucket_high_target_count"] == 3 + assert diag["noise_bucket_mid_target_count"] == 0 + assert 
diag["valid_revisit_noise_bucket_high_count"] == 2 + assert diag["no_valid_revisit_noise_bucket_high_count"] == 1 + + +def test_noise_bucket_diagnostics_count_per_target_bucket_ids(): + diag = summarize_noise_bucket_diagnostics( + noise_bucket="mid", + noise_bucket_ids=torch.tensor([[0, 1, 2]]), + valid_revisit_mask=torch.tensor([[True, True, False]]), + no_valid_revisit_mask=torch.tensor([[False, False, True]]), + ) + assert diag["noise_bucket"] == "mid" + assert diag["noise_bucket_id"] == 1 + assert diag["noise_bucket_is_mid"] == 1 + assert diag["noise_bucket_high_target_count"] == 1 + assert diag["noise_bucket_mid_target_count"] == 1 + assert diag["noise_bucket_low_target_count"] == 1 + assert diag["valid_revisit_noise_bucket_high_count"] == 1 + assert diag["valid_revisit_noise_bucket_mid_count"] == 1 + assert diag["valid_revisit_noise_bucket_low_count"] == 0 + assert diag["no_valid_revisit_noise_bucket_low_count"] == 1 + + +def test_noise_bucket_log_allowlist_keeps_target_counts_only(): + keys = MemoryDiTMixin._TRAIN_DIAGNOSTIC_LOG_KEYS + for key in ( + "anchor_valid_fraction", + "dynamic_valid_fraction", + "revisit_valid_fraction", + "valid_revisit_mask_fraction", + "revisit_candidate_count", + "valid_revisit_count", + "revisit_selected_count", + "revisit_fov_overlap_mean", + "revisit_incremental_fov_overlap_mean", + "revisit_plucker_overlap_mean", + "causal_violation_count", + "noise_bucket_id", + "noise_bucket_is_high", + "noise_bucket_is_mid", + "noise_bucket_is_low", + "revisit_raw_gate_mean", + "valid_revisit_noise_bucket_high_count", + "valid_revisit_noise_bucket_mid_count", + "valid_revisit_noise_bucket_low_count", + "no_valid_revisit_noise_bucket_high_count", + "no_valid_revisit_noise_bucket_mid_count", + "no_valid_revisit_noise_bucket_low_count", + ): + assert key not in keys + for key in [ + "noise_bucket_target_count", + "noise_bucket_high_target_count", + "noise_bucket_mid_target_count", + "noise_bucket_low_target_count", + 
"revisit_candidate_frame_count", + "valid_revisit_frame_count", + "valid_revisit_target_count", + "revisit_selected_frame_count", + "revisit_frame_fov_overlap_mean", + "revisit_best_selected_frame_fov_overlap_mean", + "revisit_best_selected_plucker_overlap_mean", + "revisit_best_selected_gap_frames_mean", + "revisit_learned_gate_mean", + ]: + assert key in keys diff --git a/tests/test_dememwm_preselection.py b/tests/test_dememwm_preselection.py new file mode 100644 index 0000000000000000000000000000000000000000..3e9501ba173c64caba713ff52e83f0c8f1d87a41 --- /dev/null +++ b/tests/test_dememwm_preselection.py @@ -0,0 +1,220 @@ +import torch +from dememwm_import_helper import install_dememwm_namespace + +install_dememwm_namespace() +from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin +from algorithms.worldmem.dememwm.cache import StreamingCache +from algorithms.worldmem.dememwm.types import MemoryRecord, MemorySourceType + + +class Harness(MemoryDiTMixin): + def __init__(self): + self.n_tokens = 8 + self.context_frames = 0 + self.frame_stack = 1 + self.dememwm_anchor_proj = torch.nn.Linear(12, 8) + self.dememwm_revisit_proj = torch.nn.Linear(12, 8) + self.project_call_lengths = [] + self.project_call_values = [] + + def _project_latent_patch_tokens(self, latents, projection, patch_size): + self.project_call_lengths.append(int(latents.shape[0])) + self.project_call_values.append(latents[:, 0, 0, 0, 0].detach().cpu().tolist()) + return MemoryDiTMixin._project_latent_patch_tokens(self, latents, projection, patch_size) + + +def test_training_window_bounds_samples_inside_long_clip(): + harness = Harness() + torch.manual_seed(0) + starts = [] + for _ in range(20): + start, end = harness._training_window_bounds(128, torch.device("cpu")) + starts.append(start) + assert end - start == 8 + assert 0 <= start <= 120 + assert any(start != 120 for start in starts) + + +def test_training_window_bounds_respects_context_frames(): + harness = Harness() + 
harness.context_frames = 100 + torch.manual_seed(0) + starts = [] + for _ in range(20): + start, end = harness._training_window_bounds(128, torch.device("cpu")) + starts.append(start) + assert end - start == 8 + assert 100 <= start <= 120 + assert any(start != 120 for start in starts) + + +def test_revisit_local_context_exclusion_uses_n_tokens_times_frame_stack(): + harness = Harness() + harness.n_tokens = 4 + harness.frame_stack = 2 + harness.context_frames = 100 + assert harness._local_context_exclusion_frames() == 8 + + +def test_diverse_anchor_selection_uses_context_frames_not_literal_limit(): + harness = Harness() + harness.context_frames = 2 + latents = torch.randn(8, 1, 3, 2, 2) + frame_indices = torch.arange(8)[:, None] + poses = torch.zeros((8, 1, 5), dtype=torch.float32) + target_pose = torch.zeros((1, 1, 5), dtype=torch.float32) + anchor_banks, _, _, diag = harness._build_preselected_causal_memory_banks( + committed_latents=latents, + source_frame_indices=frame_indices, + source_is_generated=None, + pose=poses, + action=None, + target_frame_indices=torch.tensor([[6]]), + target_pose=target_pose, + target_action=None, + target_video_ids=None, + allow_generated_anchor=False, + anchor_indices=[0, 1, 2, 3], + anchor_pool_h=1, + anchor_pool_w=1, + anchor_diverse=True, + revisit_pool_h=1, + revisit_pool_w=1, + revisit_max_frames=0, + exclude_local_context_frames=4, + fov_overlap_threshold=0.0, + plucker_weight=0.1, + revisit_retrieval_kwargs=None, + token_patch_size=2, + ) + + assert [int(record.frame_indices.item()) for record in anchor_banks[0].records] == [0, 1] + assert diag["preselected_anchor_projected_frame_count"] == 2 + + +def test_preselected_memory_banks_project_only_selected_frames(): + harness = Harness() + latents = torch.randn(20, 1, 3, 2, 2) + frame_indices = torch.arange(20)[:, None] + target_frame_indices = torch.tensor([[10], [11]]) + poses = torch.zeros((20, 1, 5), dtype=torch.float32) + target_pose = torch.zeros((2, 1, 5), 
dtype=torch.float32) + anchor_banks, revisit_banks, tokens_per_frame, diag = harness._build_preselected_causal_memory_banks( + committed_latents=latents, + source_frame_indices=frame_indices, + source_is_generated=None, + pose=poses, + action=None, + target_frame_indices=target_frame_indices, + target_pose=target_pose, + target_action=None, + target_video_ids=None, + allow_generated_anchor=False, + anchor_indices=[0, 1, 2, 3], + anchor_pool_h=1, + anchor_pool_w=1, + anchor_diverse=False, + revisit_pool_h=1, + revisit_pool_w=1, + revisit_max_frames=2, + exclude_local_context_frames=4, + fov_overlap_threshold=0.0, + plucker_weight=0.1, + revisit_retrieval_kwargs=None, + token_patch_size=2, + ) + assert tokens_per_frame == 1 + assert len(anchor_banks[0].records) == 4 + assert len(revisit_banks[0].records) == 3 + assert diag["preselected_anchor_projected_frame_count"] == 4 + assert diag["preselected_revisit_projected_frame_count"] == 3 + assert diag["preselected_revisit_projected_frame_record_count"] == 3 + assert harness.project_call_lengths == [4, 1, 1, 1] + assert 20 not in harness.project_call_lengths + + +def test_preselected_revisit_projects_best_fov_frame_not_latest(): + harness = Harness() + latents = torch.arange(8, dtype=torch.float32).view(8, 1, 1, 1, 1).expand(8, 1, 3, 2, 2).clone() + frame_indices = torch.arange(8)[:, None] + pose_rows = torch.tensor( + [ + [0.0, 0.0, 0.0, 0.0, 180.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 180.0], + [0.0, 0.0, 0.0, 0.0, 180.0], + [0.0, 0.0, 0.0, 0.0, 180.0], + [0.0, 0.0, 0.0, 0.0, 180.0], + [0.0, 0.0, 0.0, 0.0, 180.0], + [0.0, 0.0, 0.0, 0.0, 180.0], + ], + dtype=torch.float32, + ) + poses = pose_rows[:, None, :] + + _, revisit_banks, _, _ = harness._build_preselected_causal_memory_banks( + committed_latents=latents, + source_frame_indices=frame_indices, + source_is_generated=None, + pose=poses, + action=None, + target_frame_indices=torch.tensor([[8]]), + target_pose=torch.tensor([[[0.0, 0.0, 0.0, 0.0, 
0.0]]]), + target_action=None, + target_video_ids=None, + allow_generated_anchor=False, + anchor_indices=[], + anchor_pool_h=1, + anchor_pool_w=1, + anchor_diverse=False, + revisit_pool_h=1, + revisit_pool_w=1, + revisit_max_frames=1, + exclude_local_context_frames=4, + fov_overlap_threshold=0.30, + plucker_weight=0.1, + revisit_retrieval_kwargs={"high_quality_fov_threshold": 0.70}, + token_patch_size=2, + ) + + assert len(revisit_banks[0].records) == 1 + assert revisit_banks[0].records[0].metadata["dememwm_selected_frame_index"] == 1 + assert harness.project_call_values == [[1.0]] + + +def test_streaming_revisit_projection_uses_selected_frame_metadata(): + harness = Harness() + cache = StreamingCache(enabled=True, keep_raw_latents="all", keep_compressed_records=True) + latents = torch.arange(4, dtype=torch.float32).view(4, 1, 1, 1, 1).expand(4, 1, 3, 2, 2).clone() + cache.add_raw_latents(latents, torch.arange(4)[:, None]) + record = MemoryRecord( + tokens=torch.zeros((1, 8)), + mask=torch.ones(1, dtype=torch.bool), + source_start=0, + source_end=4, + frame_indices=torch.tensor([0, 1, 2, 3]), + pose=None, + source_type=MemorySourceType.PREFIX_GT, + is_generated=False, + chunk_id="frame", + metadata={ + "dememwm_revisit_metadata_only": True, + "dememwm_selected_frame_index": 1, + }, + ) + + projected = harness._project_streaming_revisit_records( + cache=cache, + batch_idx=0, + records=[record], + device=torch.device("cpu"), + dtype=torch.float32, + token_patch_size=2, + revisit_pool_h=1, + revisit_pool_w=1, + projection_cache={}, + ) + + assert len(projected) == 1 + assert projected[0].metadata["dememwm_selected_frame_index"] == 1 + assert harness.project_call_values == [[1.0]] diff --git a/tests/test_dememwm_retrieval.py b/tests/test_dememwm_retrieval.py new file mode 100644 index 0000000000000000000000000000000000000000..336dc15bb5a7dba2cbcc914693cd672a967e9bc6 --- /dev/null +++ b/tests/test_dememwm_retrieval.py @@ -0,0 +1,252 @@ +import pytest +import torch +from 
dememwm_import_helper import install_dememwm_namespace + +install_dememwm_namespace() +from algorithms.worldmem.dememwm.labels import RevisitCandidateLabel, plucker_overlap +from algorithms.worldmem.dememwm.retrieval import _select_greedy_coverage, deterministic_revisit_retrieval +from algorithms.worldmem.dememwm.types import MemoryRecord, MemorySourceType + + +def rec(frame, value, generated=False, pose=None, action=None, video_id="v0", chunk_id=None): + metadata = {"video_id": video_id} + if action is not None: + metadata["action"] = torch.tensor(action, dtype=torch.float32) + return MemoryRecord( + tokens=torch.full((2, 4), float(value)), + mask=torch.ones(2, dtype=torch.bool), + source_start=frame, + source_end=frame + 1, + frame_indices=torch.tensor([frame]), + pose=None if pose is None else torch.tensor(pose, dtype=torch.float32), + source_type=MemorySourceType.REVISIT, + is_generated=generated, + chunk_id=chunk_id or f"c{frame}", + metadata=metadata, + ) + + +def candidate_label(chunk_id, frame, fov, plucker, coverage_mask): + return RevisitCandidateLabel( + record=rec(frame, frame, pose=[0.0, 0.0, 0.0, 0.0, 0.0], chunk_id=chunk_id), + valid=True, + gap_valid=True, + gap_to_target=10 - int(frame), + fov_overlap=float(fov), + plucker_overlap=float(plucker), + primary_overlap=float(fov), + coverage_mask=torch.tensor(coverage_mask, dtype=torch.bool), + reject_reasons=(), + ) + + +def test_plucker_cannot_outrank_higher_incremental_fov_gain(): + low_fov_high_plucker = candidate_label( + "low_fov_high_plucker", 0, 0.11, 1.0, [True] * 11 + [False] * 89 + ) + high_fov_low_plucker = candidate_label( + "high_fov_low_plucker", 1, 0.20, 0.0, [True] * 20 + [False] * 80 + ) + + selected, scores, gains = _select_greedy_coverage( + [low_fov_high_plucker, high_fov_low_plucker], topk=1, plucker_weight=0.10 + ) + + assert selected[0].record.chunk_id == "high_fov_low_plucker" + assert abs(scores[0] - gains[0]) < 1e-6 + assert abs(gains[0] - 0.20) < 1e-6 + + +def 
test_plucker_breaks_ties_after_fov_gain_and_overlap(): + low_plucker = candidate_label("low_plucker", 0, 0.20, 0.1, [True] * 20 + [False] * 80) + high_plucker = candidate_label("high_plucker", 0, 0.20, 0.9, [True] * 20 + [False] * 80) + + selected, _, _ = _select_greedy_coverage([low_plucker, high_plucker], topk=1, plucker_weight=0.10) + + assert selected[0].record.chunk_id == "high_plucker" + + +def test_plucker_overlap_handles_cuda_autocast_mixed_precision(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required to exercise plucker_overlap under autocast") + source_pose = torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0]], device="cuda", dtype=torch.float32) + target_pose = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0], device="cuda", dtype=torch.float32) + with torch.autocast(device_type="cuda", dtype=torch.float16): + overlap = plucker_overlap(source_pose, target_pose) + assert overlap is not None + assert overlap > 0.0 + + +def test_revisit_candidates_require_causal_c_short_gap(): + records = [ + rec(1, 1, pose=[0.0, 0.0, 0.0, 0.0, 0.0]), + rec(2, 2, pose=[0.0, 0.0, 0.0, 0.0, 0.0]), + rec(9, 9, pose=[0.0, 0.0, 0.0, 0.0, 0.0]), + ] + result = deterministic_revisit_retrieval( + records, + target_frame=6, + target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]), + topk=5, + exclude_local_context_frames=4, + ) + assert [r.max_source_frame for r in result.records] == [1] + assert result.diagnostics["revisit_candidate_frame_count"] == 2 + assert result.diagnostics["revisit_candidate_count"] == 2 + assert result.diagnostics["valid_revisit_frame_count"] == 1 + assert result.diagnostics["valid_revisit_count"] == 1 + assert result.diagnostics["valid_candidate_label_count"] == 1 + assert result.diagnostics["valid_revisit_target_count"] == 1 + assert result.diagnostics["revisit_min_gap_to_target"] == 5 + assert result.diagnostics["revisit_vectorized_frame_scorer_used"] == 1 + + +def test_revisit_abstains_when_no_valid_candidate(): + result = deterministic_revisit_retrieval([rec(2, 2), rec(3, 3)], target_frame=6, topk=2, 
exclude_local_context_frames=4) + assert result.records == [] + assert result.diagnostics["abstained"] is True + assert result.diagnostics["valid_revisit_mask"] == 0 + assert result.diagnostics["valid_revisit_target_count"] == 0 + assert result.diagnostics["no_valid_revisit_count"] == 1 + + +def test_revisit_retrieval_rejects_non_vectorized_inputs(): + with pytest.raises(ValueError, match="target_pose"): + deterministic_revisit_retrieval( + [rec(0, 0, pose=[0.0, 0.0, 0.0, 0.0, 0.0])], + target_frame=10, + exclude_local_context_frames=4, + ) + + chunk_record = MemoryRecord( + tokens=torch.zeros((2, 4)), + mask=torch.ones(2, dtype=torch.bool), + source_start=0, + source_end=2, + frame_indices=torch.tensor([0, 1]), + pose=torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]), + source_type=MemorySourceType.REVISIT, + is_generated=False, + chunk_id="chunk", + ) + with pytest.raises(ValueError, match="frame-level records"): + deterministic_revisit_retrieval( + [chunk_record], + target_frame=10, + target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]), + exclude_local_context_frames=4, + ) + + +def test_fov_threshold_filters_candidates_without_action(): + records = [ + rec(0, 0, pose=[0.0, 0.0, 0.0, 0.0, 0.0]), + rec(1, 1, pose=[0.0, 0.0, 0.0, 0.0, 180.0]), + rec(2, 2, pose=[100.0, 0.0, 0.0, 0.0, 0.0]), + ] + result = deterministic_revisit_retrieval( + records, + target_frame=10, + target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]), + fov_overlap_threshold=0.5, + exclude_local_context_frames=4, + topk=4, + ) + assert result.diagnostics["selected_frame_record_ids"] == ["c0"] + assert result.diagnostics["valid_revisit_frame_count"] == 1 + assert result.diagnostics["valid_revisit_count"] == 1 + assert result.diagnostics["valid_revisit_target_count"] == 1 + assert result.diagnostics["best_selected_fov_overlap"] == 1.0 + assert result.diagnostics["revisit_best_selected_fov_overlap_max"] == 1.0 + assert result.diagnostics["best_selected_gap_frames"] == 10 + assert 
result.diagnostics["revisit_fov_overlap_max"] == 1.0 + assert result.diagnostics["revisit_plucker_overlap_max"] > 0.0 + + +def test_pose_preselect_uses_local_position_and_view_direction_before_fov(): + records = [ + rec(0, 0, pose=[0.0, 0.0, 0.0, 0.0, 180.0], chunk_id="opposite_same_position"), + rec(1, 1, pose=[90.0, 0.0, 0.0, 0.0, 0.0], chunk_id="far_same_direction"), + rec(2, 2, pose=[1.0, 0.0, 0.0, 0.0, 0.0], chunk_id="near_same_direction"), + ] + result = deterministic_revisit_retrieval( + records, + target_frame=10, + target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]), + fov_overlap_threshold=0.0, + fov_radius=30.0, + exclude_local_context_frames=4, + topk=1, + pose_preselect_topk=1, + ) + + assert result.diagnostics["selected_frame_record_ids"] == ["near_same_direction"] + assert result.diagnostics["revisit_pose_preselect_input_count"] == 3 + assert result.diagnostics["revisit_pose_preselect_scored_count"] == 3 + assert result.diagnostics["revisit_pose_preselect_selected_count"] == 1 + assert result.diagnostics["revisit_exact_fov_candidate_count"] == 1 + assert result.diagnostics["revisit_vectorized_frame_scorer_used"] == 1 + assert abs(result.diagnostics["revisit_pose_preselect_min_distance"] - (1.0 / 30.0)) < 1e-6 + + +def test_selected_frame_carries_frame_metadata_for_projection(): + result = deterministic_revisit_retrieval( + [rec(1, 1, pose=[0.0, 0.0, 0.0, 0.0, 0.0], chunk_id="frame_1")], + target_frame=8, + target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]), + fov_overlap_threshold=0.30, + high_quality_fov_threshold=0.70, + exclude_local_context_frames=4, + topk=1, + ) + + assert result.diagnostics["selected_frame_record_ids"] == ["frame_1"] + assert result.selected_frame_ids == [1] + assert result.records[0].metadata["dememwm_selected_frame_index"] == 1 + assert result.records[0].metadata["dememwm_selected_frame_passes_high_quality"] is True + assert result.diagnostics["best_selected_frame_index"] == 1 + assert 
result.diagnostics["best_selected_frame_fov_overlap"] == 1.0 + assert result.diagnostics["valid_revisit_target_count"] == 1 + + +def test_high_quality_threshold_is_selected_target_diagnostic_only(): + result = deterministic_revisit_retrieval( + [rec(0, 0, pose=[0.0, 0.0, 0.0, 0.0, 0.0])], + target_frame=10, + target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 60.0]), + fov_overlap_threshold=0.30, + high_quality_fov_threshold=0.70, + exclude_local_context_frames=4, + topk=1, + ) + assert result.diagnostics["selected_frame_record_ids"] == ["c0"] + assert result.diagnostics["valid_revisit_count"] == 1 + assert result.diagnostics["valid_revisit_target_count"] == 0 + assert 0.30 <= result.diagnostics["best_selected_fov_overlap"] < 0.70 + + +def test_video_metadata_does_not_filter_revisit_candidates(): + records = [ + rec(0, 0, video_id="v0", pose=[0.0, 0.0, 0.0, 0.0, 0.0]), + rec(1, 1, video_id="other", pose=[0.0, 0.0, 0.0, 0.0, 0.0]), + ] + result = deterministic_revisit_retrieval( + records, + target_frame=10, + target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]), + target_video_id="v0", + exclude_local_context_frames=4, + topk=4, + ) + assert result.diagnostics["selected_frame_record_ids"] == ["c1", "c0"] + assert result.diagnostics["valid_revisit_count"] == 2 + + +def test_tie_breaking_is_overlap_then_age_then_source_then_record_id(): + records = [ + rec(0, 0, pose=[0.0, 0.0, 0.0, 0.0, 0.0], chunk_id="b"), + rec(1, 1, pose=[0.0, 0.0, 0.0, 0.0, 0.0], chunk_id="a"), + rec(2, 2, pose=[0.0, 0.0, 0.0, 0.0, 0.0], chunk_id="c"), + ] + result = deterministic_revisit_retrieval(records, target_frame=10, target_pose=torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0]), exclude_local_context_frames=4, topk=3) + assert result.diagnostics["selected_frame_record_ids"] == ["c", "a", "b"] diff --git a/tests/test_dememwm_schedules.py b/tests/test_dememwm_schedules.py new file mode 100644 index 0000000000000000000000000000000000000000..b4f63fad1243209589cddce2d534c0fe2779f68c --- /dev/null +++ 
b/tests/test_dememwm_schedules.py @@ -0,0 +1,79 @@ + +import pytest +from types import SimpleNamespace + +from algorithms.worldmem.dememwm.schedules import ( + compute_stream_gates, + noise_bucket_from_denoising_fraction, + noise_bucket_from_noise_levels, + noise_bucket_ids_from_noise_levels, + resolve_curriculum, +) + + +def test_stage_1_uses_all_streams(): + gates = compute_stream_gates("stage_1") + assert gates.anchor_enabled + assert gates.dynamic_enabled + assert gates.revisit_enabled + + +def test_stage_2_uses_all_streams(): + gates = compute_stream_gates("stage_2") + assert gates.anchor_enabled + assert gates.dynamic_enabled + assert gates.revisit_enabled + + +def test_two_stage_curriculum_switches_at_full_stage_start(): + cfg = SimpleNamespace( + curriculum=SimpleNamespace( + enabled=True, + full_stage_start_step=10, + freeze_vae=True, + dit_freeze=SimpleNamespace(enabled=True), + lr=SimpleNamespace(dememwm_modules=1.0e-4, memory_adapters=1.0e-4, full_dit=1.0e-5), + ) + ) + + stage_1 = resolve_curriculum(cfg, 9) + stage_2 = resolve_curriculum(cfg, 10) + + assert stage_1.stage == "stage_1" + assert stage_1.anchor_enabled and stage_1.dynamic_enabled and stage_1.revisit_enabled + assert stage_1.dit_train_state == "frozen" + assert not hasattr(stage_1, "dit_late_blocks_trainable") + assert all("late" not in key for key in stage_1.diagnostics()) + assert stage_2.stage == "stage_2" + assert stage_2.anchor_enabled and stage_2.dynamic_enabled and stage_2.revisit_enabled + assert stage_2.dit_train_state == "full" + + +def test_debug_force_all_streams_overrides_stage(): + gates = compute_stream_gates("stage_1", debug_force_all_streams=True) + assert gates.anchor_enabled and gates.dynamic_enabled and gates.revisit_enabled + assert gates.reason == "debug_force_all_streams" + + +def test_unknown_stage_fails(): + with pytest.raises(ValueError): + compute_stream_gates("unknown") + + +def test_noise_bucket_from_denoising_fraction(): + assert 
noise_bucket_from_denoising_fraction(0.0) == "high" + assert noise_bucket_from_denoising_fraction(0.5) == "mid" + assert noise_bucket_from_denoising_fraction(1.0) == "low" + + +def test_noise_bucket_from_training_noise_levels(): + import torch + assert noise_bucket_from_noise_levels(torch.tensor([9, 8]), 10) == "high" + assert noise_bucket_from_noise_levels(torch.tensor([5, 4]), 10) == "mid" + assert noise_bucket_from_noise_levels(torch.tensor([1, 0]), 10) == "low" + + +def test_noise_bucket_ids_from_training_noise_levels(): + import torch + bucket_ids = noise_bucket_ids_from_noise_levels(torch.tensor([[9, 4, 0]]), 10) + assert bucket_ids.tolist() == [[0, 1, 2]] diff --git a/tests/test_dememwm_stream_grad.py b/tests/test_dememwm_stream_grad.py new file mode 100644 index 0000000000000000000000000000000000000000..fe4a6d2efb10bb04486ffe54fccd86fe004b46fa --- /dev/null +++ b/tests/test_dememwm_stream_grad.py @@ -0,0 +1,35 @@ +import torch +from dememwm_import_helper import install_dememwm_namespace + +install_dememwm_namespace() +from algorithms.worldmem.dememwm.algorithm import MemoryDiTMixin +from algorithms.worldmem.dememwm.types import MemoryRecord, MemorySourceType + + +def test_records_to_stream_preserves_grad_to_record_tokens(): + record_tokens = torch.full((2, 4), 3.0) + record_tokens.requires_grad_() + record = MemoryRecord( + tokens=record_tokens, + mask=torch.ones(2, dtype=torch.bool), + source_start=0, + source_end=1, + frame_indices=torch.tensor([0]), + pose=None, + source_type=MemorySourceType.REVISIT, + is_generated=False, + chunk_id="grad", + ) + tokens, mask, max_source = MemoryDiTMixin._records_to_stream( + object(), + [record], + max_tokens=4, + hidden_size=4, + device=torch.device("cpu"), + dtype=torch.float32, + ) + assert mask.tolist() == [True, True, False, False] + assert max_source == 0 + tokens.sum().backward() + assert record_tokens.grad is not None + assert record_tokens.grad.abs().sum().item() > 0 diff --git 
#!/usr/bin/env bash
#SBATCH --job-name=dememwm-full
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --cpus-per-task=8
#SBATCH --time=72:00:00
#SBATCH --output=slurm_logs/dememwm-full-%j.out
#SBATCH --error=slurm_logs/dememwm-full-%j.out
#SBATCH --account=berzelius-2025-436
#SBATCH --gres=gpu:A100-SXM4-80GB:8

# Full DeMemWM training run on one 8xA100 node (per-GPU batch 8, 350k steps).
# NOTE: slurm_logs/ must exist in the submission directory before `sbatch`,
# otherwise SLURM cannot open the --output/--error file.

# Fail fast: a broken environment should not burn a 72h allocation.
# (`-u` is deliberately left off; module/conda init scripts may read unset variables.)
set -eo pipefail

# Single place to change the storage locations used below.
PROJ_ROOT=/proj/cvl/users/x_fahkh2
REPRO_ROOT="${PROJ_ROOT}/WorldMem_Repro"
CACHE_DIR="${PROJ_ROOT}/caches"
OUTPUT_DIR="${REPRO_ROOT}/checkpoints/dememwm_full_berzelius_8a100_bs8_global64_350k"

module load buildenv-gcccuda/12.1.1-gcc12.3.0
source "$(conda info --base)/etc/profile.d/conda.sh"

# TMPDIR and the tool caches all point here, so make sure it exists.
mkdir -p "${CACHE_DIR}"

export PYTHONPATH="./${PYTHONPATH:+:${PYTHONPATH}}"
export HF_HOME="${CACHE_DIR}"
export TORCH_HOME="${CACHE_DIR}"
export PIP_CACHE_DIR="${CACHE_DIR}"
export TMPDIR="${CACHE_DIR}"
export TRITON_CACHE_DIR="${CACHE_DIR}"
export CUDA_HOME=$CUDA_ROOT
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export WANDB_DISABLED=true
export HYDRA_FULL_ERROR=1

# Hydra overrides containing [...] are quoted so bash never treats them as globs.
srun python -m main \
  +name=train_dememwm_full_berzelius_8a100_bs8_global64_350k \
  +output_dir="${OUTPUT_DIR}/" \
  auto_resume=true \
  "experiment.tasks=[training]" \
  algorithm=dememwm_memory_dit \
  +customized_load=true \
  +seperate_load=true \
  +diffusion_model_path="${REPRO_ROOT}/checkpoints/oasis500m.safetensors" \
  +vae_path="${REPRO_ROOT}/checkpoints/vit-l-20.safetensors" \
  +only_tune_memory=false \
  dataset=video_minecraft_latent \
  dataset.save_dir="${REPRO_ROOT}/datasets/minecraft" \
  dataset.precomputed_feature_dir="${REPRO_ROOT}/datasets/minecraft/vae_features" \
  dataset.n_frames=1000 \
  +dataset.n_frames_valid=1100 \
  +dataset.customized_validation=true \
  +dataset.memory_condition_length=0 \
  +dataset.wo_updown=false \
  +dataset.angle_range=180 \
  +dataset.pos_range=8 \
  ++algorithm.n_tokens=4 \
  "algorithm.x_shape=[16,18,32]" \
  ++algorithm.context_frames=100 \
  ++algorithm.log_video=true \
  ++algorithm.diffusion.sampling_timesteps=20 \
  ++algorithm.dememwm.debug_force_all_streams=false \
  ++algorithm.dememwm.generated_history_proxy.enabled=true \
  ++algorithm.dememwm.generated_history_proxy.start_step=40000 \
  ++algorithm.dememwm.generated_history_proxy.ramp_steps=40000 \
  ++algorithm.dememwm.generated_history_proxy.max_prob=0.25 \
  ++algorithm.dememwm.generated_history_proxy.noise_std=0.25 \
  ++algorithm.dememwm.generated_history_proxy.dropout_prob=0.0 \
  ++algorithm.dememwm.anchor.enabled=true \
  "++algorithm.dememwm.anchor.anchor_indices=[0,1,2,3]" \
  ++algorithm.dememwm.anchor.diverse_selection=true \
  ++algorithm.dememwm.anchor.compress.downsample_ratio=3 \
  ++algorithm.dememwm.anchor.allow_generated_as_anchor=false \
  ++algorithm.dememwm.dynamic.enabled=true \
  ++algorithm.dememwm.dynamic.exclude_latest_local_frames=4 \
  ++algorithm.dememwm.dynamic.recent_frames=4 \
  ++algorithm.dememwm.revisit.enabled=true \
  ++algorithm.dememwm.revisit.deterministic_pose_retrieval=true \
  ++algorithm.dememwm.revisit.fov_overlap_threshold=0.30 \
  ++algorithm.dememwm.revisit.high_quality_fov_threshold=0.70 \
  ++algorithm.dememwm.revisit.pose_preselect_topk=64 \
  ++algorithm.dememwm.revisit.fov_yaw_samples=25 \
  ++algorithm.dememwm.revisit.fov_pitch_samples=20 \
  ++algorithm.dememwm.revisit.fov_depth_samples=20 \
  ++algorithm.dememwm.revisit.plucker_weight=0.10 \
  ++algorithm.dememwm.revisit.max_frames=2 \
  ++algorithm.dememwm.revisit.compress.downsample_ratio=3 \
  ++algorithm.dememwm.stage_policy.noise_bucket_logging=true \
  ++algorithm.dememwm.cache.enabled=true \
  ++algorithm.dememwm.cache.device=cpu \
  ++algorithm.dememwm.cache.keep_raw_latents=all \
  ++algorithm.dememwm.cache.keep_compressed_records=true \
  ++algorithm.dememwm.cache.eviction_policy=none \
  ++algorithm.dememwm.cache.no_evict=true \
  ++algorithm.dememwm.cache.clear_between_videos=true \
  ++algorithm.dememwm.cache.max_records=null \
  ++algorithm.dememwm.cache.max_slots=null \
  ++algorithm.dememwm.cache.on_capacity_exceeded=warn \
  ++algorithm.dememwm.curriculum.enabled=true \
  ++algorithm.dememwm.curriculum.full_stage_start_step=20000 \
  ++algorithm.dememwm.curriculum.freeze_vae=true \
  ++algorithm.dememwm.curriculum.dit_freeze.enabled=true \
  ++algorithm.dememwm.curriculum.lr.dememwm_modules=4.0e-5 \
  ++algorithm.dememwm.curriculum.lr.memory_adapters=4.0e-5 \
  ++algorithm.dememwm.curriculum.lr.full_dit=1.0e-5 \
  experiment.training.batch_size=8 \
  experiment.training.optim.accumulate_grad_batches=1 \
  experiment.validation.batch_size=1 \
  experiment.validation.limit_batch=8 \
  experiment.training.checkpointing.every_n_train_steps=2000 \
  experiment.validation.val_every_n_step=2000 \
  experiment.training.max_steps=350000
def submit_slurm_job(
    cfg: DictConfig,
    python_args: str,
    project_root: Path,
):
    """
    Render the cluster launch template into a job script and submit it with `sbatch`.

    A timestamped directory `<project_root>/slurm_logs/<time>-<cfg.name>` is created,
    `slurm_logs/latest` is re-pointed at it, and the rendered script is written to
    `job.slurm` inside it.

    :param cfg: config providing `name`, `cluster.params` and `cluster.launch_template`
    :param python_args: argument string substituted into the template as `python_args`
    :param project_root: repository root; substituted into the template as `project_root`
    :return: the log directory created for this submission
    """
    import subprocess

    logs_root = project_root / "slurm_logs"
    log_dir = logs_root / f"{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}-{cfg.name}"
    log_dir.mkdir(exist_ok=True, parents=True)

    latest_link = logs_root / "latest"
    latest_link.unlink(missing_ok=True)
    latest_link.symlink_to(log_dir, target_is_directory=True)

    # Values from cfg.cluster.params override the built-in template fields.
    params = dict(name=cfg.name, log_dir=log_dir, project_root=project_root, python_args=python_args)
    params.update(cfg.cluster.params)
    slurm_script = cfg.cluster.launch_template.format(**params)

    slurm_script_path = log_dir / "job.slurm"
    with slurm_script_path.open("w") as f:
        f.write(slurm_script)

    # Equivalent of `chmod +x`, without routing the path through a shell.
    slurm_script_path.chmod(slurm_script_path.stat().st_mode | 0o111)

    # Argument list instead of a shell string: the path contains cfg.name, which
    # may hold spaces or shell metacharacters.
    try:
        result = subprocess.run(["sbatch", str(slurm_script_path)], check=False)
    except FileNotFoundError:
        print("sbatch was not found on PATH; the job script was written but not submitted.")
    else:
        if result.returncode != 0:
            print(f"sbatch exited with status {result.returncode}; the job may not have been submitted.")

    print(f"\n{cyan('script:')} {slurm_script_path}\n{cyan('slurm errors and logs:')} {log_dir}\n")

    return log_dir
def _to_uint8_batch_first(video):
    """Clip a (frame, batch, C, H, W) tensor to [0, 1] and return (batch, frame, C, H, W) uint8."""
    video_np = video.detach().cpu().numpy()
    return np.transpose(np.clip(video_np, a_min=0.0, a_max=1.0) * 255, (1, 0, 2, 3, 4)).astype(np.uint8)


def _resolve_video_dir(local_save_dir, namespace):
    """Return the directory that locally saved videos go into."""
    from pathlib import Path

    # An explicit directory is used as given (namespace is not appended).
    if local_save_dir is not None:
        return Path(local_save_dir)
    try:
        from hydra.core.hydra_config import HydraConfig

        output_dir = Path(HydraConfig.get().runtime.output_dir)
    except Exception:
        # hydra is unavailable or not initialised (e.g. called outside a hydra app)
        output_dir = Path.cwd() / "outputs"
    return output_dir / "videos" / namespace


def log_video(
    observation_hat,
    observation_gt=None,
    step=0,
    namespace="train",
    prefix="video",
    context_frames=0,
    color=(255, 0, 0),
    logger=None,
    fps=15,
    format="mp4",
    save_local=True,
    local_save_dir=None,
):
    """
    Save video tensors to local disk and log them to wandb.

    Values are clipped to [0, 1] and scaled to uint8, so inputs are expected to
    already be in the [0, 1] range.

    :param observation_hat: predicted observation tensor of shape (frame, batch, channel, height, width)
    :param observation_gt: ground-truth observation tensor of shape (frame, batch, channel, height, width)
    :param step: an int indicating the step number; None drops the step from the file names
    :param namespace: a string specify a name space this video logging falls under, e.g. train, val
    :param prefix: a string specify a prefix for the video name
    :param context_frames: currently unused by this function
    :param color: currently unused by this function
    :param logger: optional logger to use. use global wandb if not specified
    :param fps: frames per second for the video (default: 15)
    :param format: video format, either "mp4" or "gif" (default: "mp4")
    :param save_local: whether to save videos to local disk (default: True)
    :param local_save_dir: directory to save local videos. If None, uses hydra output dir
    """
    # Every process saves its own files; only local rank 0 logs, to avoid duplicates.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    if not logger:
        logger = wandb

    observation_hat_np = _to_uint8_batch_first(observation_hat)
    has_gt = observation_gt is not None
    observation_gt_np = _to_uint8_batch_first(observation_gt) if has_gt else None

    n_samples = len(observation_hat_np)

    pred_dir = gt_dir = None
    if save_local:
        local_save_dir = _resolve_video_dir(local_save_dir, namespace)
        local_save_dir.mkdir(parents=True, exist_ok=True)

        pred_dir = local_save_dir / "pred"
        pred_dir.mkdir(parents=True, exist_ok=True)

        if has_gt:
            gt_dir = local_save_dir / "gt"
            gt_dir.mkdir(parents=True, exist_ok=True)

    should_log = local_rank == 0
    if should_log and has_gt:
        # Built once for all samples. Axis -2 is height, so prediction ends up
        # stacked on top of ground truth.
        videos_to_log = torch.cat(
            [torch.from_numpy(observation_hat_np), torch.from_numpy(observation_gt_np)], -2
        ).numpy()
    else:
        videos_to_log = observation_hat_np

    step_suffix = f"_step{step}" if step is not None else ""

    for i in range(n_samples):
        if save_local:
            # pred/ and gt/ use the same file name so pairs are easy to match up.
            video_filename = f"{prefix}_{i}_rank{local_rank}{step_suffix}.{format}"
            _save_video_to_file(observation_hat_np[i], str(pred_dir / video_filename), fps)
            if has_gt:
                _save_video_to_file(observation_gt_np[i], str(gt_dir / video_filename), fps)

        if should_log:
            logger.log(
                {
                    f"{namespace}/{prefix}_{i}": wandb.Video(videos_to_log[i], fps=fps, format=format),
                    "trainer/global_step": step,
                }
            )
def _save_video_to_file(video_tensor, output_path, fps=15):
    """
    Save a video tensor to file using imageio (better compatibility than cv2).

    :param video_tensor: numpy array of shape (T, C, H, W) with values in [0, 255]
    :param output_path: path to save the video
    :param fps: frames per second
    :raises ValueError: if video_tensor is not 4-dimensional
    """
    if len(video_tensor.shape) != 4:
        raise ValueError(f"expected a (T, C, H, W) video, got shape {tuple(video_tensor.shape)}")

    # (T, C, H, W) -> (T, H, W, C), the layout imageio expects
    frames = np.transpose(video_tensor, (0, 2, 3, 1)).astype(np.uint8)

    writer = imageio.get_writer(
        output_path,
        fps=fps,
        codec="libx264",  # H.264 codec - widely supported
        quality=8,  # Good quality (scale 0-10, 10 is best)
        pixelformat="yuv420p",  # Standard pixel format for compatibility
        macro_block_size=1,
    )
    # Always finalise the file, even if encoding a frame fails.
    try:
        for frame in frames:
            writer.append_data(frame)
    finally:
        writer.close()


def get_validation_metrics_for_videos(
    observation_hat,
    observation_gt,
    lpips_model: Optional[LearnedPerceptualImagePatchSimilarity] = None,
    fid_model: Optional[FrechetInceptionDistance] = None,
    fvd_model: Optional[FrechetVideoDistance] = None,
    lpips_batch_size: int = 100,
):
    """
    Compute image-quality metrics between predicted and ground-truth videos.

    Both inputs are clamped to [0, 1] first (matching video saving behavior).

    :param observation_hat: predicted observation tensor of shape (frame, batch, channel, height, width)
    :param observation_gt: ground-truth observation tensor of shape (frame, batch, channel, height, width)
    :param lpips_model: a LearnedPerceptualImagePatchSimilarity object from algorithm.common.metrics
    :param fid_model: a FrechetInceptionDistance object from algorithm.common.metrics
    :param fvd_model: a FrechetVideoDistance object from algorithm.common.metrics (accepted, but no
        FVD value is computed here)
    :param lpips_batch_size: batch size for LPIPS calculation to avoid OOM (default: 100)
    :return: a dict with "psnr" and "mse", plus "lpips" / "fid" when the matching model is given
    """
    frame, batch, channel, height, width = observation_hat.shape
    output_dict = {}
    observation_gt = observation_gt.type_as(observation_hat)  # some metrics don't fully support fp16

    if frame < 9:
        fvd_model = None  # FVD requires at least 9 frames

    observation_hat = observation_hat.float()
    observation_gt = observation_gt.float()

    observation_hat_clipped = torch.clamp(observation_hat, 0.0, 1.0)
    observation_gt_clipped = torch.clamp(observation_gt, 0.0, 1.0)

    # Video-wise PSNR: average over frames per video, then over videos.
    video_psnr_list = []
    for b in range(batch):
        frame_psnr_for_video = [
            peak_signal_noise_ratio(observation_hat_clipped[f, b], observation_gt_clipped[f, b], data_range=1.0)
            for f in range(frame)
        ]
        video_psnr_list.append(torch.stack(frame_psnr_for_video).mean())
    output_dict["psnr"] = torch.stack(video_psnr_list).mean()

    # reshape (not view): inputs that were permuted upstream stay non-contiguous
    # through float()/clamp, and view() would raise on them.
    observation_hat_clipped = observation_hat_clipped.reshape(-1, channel, height, width)
    observation_gt_clipped = observation_gt_clipped.reshape(-1, channel, height, width)

    output_dict["mse"] = mean_squared_error(observation_hat_clipped, observation_gt_clipped)

    if lpips_model is not None:
        # Feed LPIPS in chunks to avoid OOM.
        num_frames = observation_hat_clipped.shape[0]
        for i in range(0, num_frames, lpips_batch_size):
            batch_end = min(i + lpips_batch_size, num_frames)
            observation_hat_batch = observation_hat_clipped[i:batch_end]
            observation_gt_batch = observation_gt_clipped[i:batch_end]

            lpips_model.update(observation_hat_batch, observation_gt_batch)

            # Free GPU memory after each chunk
            del observation_hat_batch, observation_gt_batch
            torch.cuda.empty_cache()

        output_dict["lpips"] = lpips_model.compute().item()
        # Reset the states of non-functional metrics
        lpips_model.reset()

    if fid_model is not None:
        observation_hat_uint8 = (observation_hat_clipped * 255).type(torch.uint8)
        observation_gt_uint8 = (observation_gt_clipped * 255).type(torch.uint8)
        fid_model.update(observation_gt_uint8, real=True)
        fid_model.update(observation_hat_uint8, real=False)
        output_dict["fid"] = fid_model.compute()
        # Reset the states of non-functional metrics
        fid_model.reset()

    return output_dict
# Maze layouts keyed by the size keyword found in the env id. Rows are separated
# by a backslash; "#" is a wall, "O" a free cell and "G" the goal.
_MAZE_STRINGS = {
    "large": "############\\#OOOO#OOOOO#\\#O##O#O#O#O#\\#OOOOOO#OOO#\\#O####O###O#\\#OO#O#OOOOO#\\##O#O#O#O###\\#OO#OOO#OGO#\\############",
    "medium": "########\\#OO##OO#\\#OO#OOO#\\##OOO###\\#OO#OOO#\\#O#OO#O#\\#OOO#OG#\\########",
    "umaze": "#####\\#GOO#\\###O#\\#OOO#\\#####",
}


def is_grid_env(env_id):
    """Return True when the env id names a maze2d or diagonal2d environment."""
    return "maze2d" in env_id or "diagonal2d" in env_id


def get_maze_grid(env_id):
    """
    Return the maze interior (outer wall stripped) as a list of row strings.

    :param env_id: env id containing one of "umaze", "medium" or "large"
    :raises ValueError: if the env id contains none of the known size keywords
    """
    # Checked in this order so that, if several keywords appear in the id, the
    # winner is the same one the previous sequence of `if` statements picked.
    for size in ("umaze", "medium", "large"):
        if size in env_id:
            maze_string = _MAZE_STRINGS[size]
            break
    else:
        raise ValueError(f"no maze layout known for env_id {env_id!r}; expected one of {sorted(_MAZE_STRINGS)}")

    lines = maze_string.split("\\")
    # Drop the first/last row and the first/last column: the surrounding wall.
    return [line[1:-1] for line in lines[1:-1]]


def get_random_start_goal(env_id, batch_size):
    """
    Sample `batch_size` random free start cells and return them with the goal cell.

    :return: (start, goal) integer arrays of shape (batch_size, 2); coordinates are
        offset by +1 relative to indices into `get_maze_grid(env_id)`
    """
    symbol_to_id = {"O": 0, "#": 1, "G": 2}
    cells = np.array([[symbol_to_id[s] for s in row] for row in get_maze_grid(env_id)])

    free_x, free_y = np.nonzero(cells == 0)
    picks = np.random.randint(len(free_x), size=batch_size)
    start = np.stack([free_x[picks], free_y[picks]], -1) + 1

    goal_x, goal_y = np.nonzero(cells == 2)
    goal = np.concatenate([goal_x, goal_y], -1)
    goal = np.tile(goal[None, :], (batch_size, 1)) + 1
    return start, goal


def plot_maze_layout(ax, maze_grid):
    """
    Clear `ax` and draw the maze walls plus the background styling.

    :param ax: matplotlib axes to draw on
    :param maze_grid: list of row strings as returned by `get_maze_grid`, or None
        for environments without a grid; then only the styling is applied
    """
    ax.clear()

    if maze_grid is not None:
        for i, row in enumerate(maze_grid):
            for j, cell in enumerate(row):
                if cell == "#":
                    square = plt.Rectangle((i + 0.5, j + 0.5), 1, 1, edgecolor="black", facecolor="black")
                    ax.add_patch(square)

    ax.set_aspect("equal")
    ax.grid(True, color="white", linewidth=4)
    ax.set_axisbelow(True)
    for side in ("top", "right", "bottom", "left"):
        ax.spines[side].set_linewidth(4)
    ax.set_facecolor("lightgray")
    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        top=False,
        left=False,
        right=False,
        labelbottom=False,
        labelleft=False,
    )
    # Ticks and limits depend on the grid size, so they only exist for grid envs.
    if maze_grid is not None:
        n_rows, n_cols = len(maze_grid), len(maze_grid[0])
        ax.set_xticks(np.arange(0.5, n_rows + 0.5))
        ax.set_yticks(np.arange(0.5, n_cols + 0.5))
        ax.set_xlim(0.5, n_rows + 0.5)
        ax.set_ylim(0.5, n_cols + 0.5)
    ax.grid(True, color="white", which="minor", linewidth=4)


def plot_start_goal(ax, start_goal: None):
    """
    Mark the start (ringed dot) and the goal (ringed star) on `ax`.

    :param ax: matplotlib axes to draw on
    :param start_goal: pair `((start_x, start_y), (goal_x, goal_y))`
    """

    def draw_star(center, radius, num_points=5, color="black"):
        # Alternate outer and inner vertices around the centre.
        angles = np.linspace(0.0, 2 * np.pi, num_points, endpoint=False) + 5 * np.pi / (2 * num_points)
        inner_radius = radius / 2.0
        half_step = np.pi / num_points

        points = []
        for angle in angles:
            points.extend(
                [
                    center[0] + radius * np.cos(angle),
                    center[1] + radius * np.sin(angle),
                    center[0] + inner_radius * np.cos(angle + half_step),
                    center[1] + inner_radius * np.sin(angle + half_step),
                ]
            )

        star = plt.Polygon(np.array(points).reshape(-1, 2), color=color)
        ax.add_patch(star)

    start_x, start_y = start_goal[0]
    ax.add_patch(plt.Circle((start_x, start_y), 0.16, facecolor="white", edgecolor="black"))
    ax.add_patch(plt.Circle((start_x, start_y), 0.08, color="black"))

    goal_x, goal_y = start_goal[1]
    ax.add_patch(plt.Circle((goal_x, goal_y), 0.16, facecolor="white", edgecolor="black"))
    draw_star((goal_x, goal_y), radius=0.08)
def make_trajectory_images(env_id, trajectory, batch_size, start, goal, plot_end_points=True):
    """
    Render one RGBA image per batch element showing its trajectory over the maze.

    :param env_id: environment id; selects the maze layout for grid environments
    :param trajectory: array indexed as trajectory[time, batch, xy]
    :param batch_size: number of batch elements to render
    :param start: per-batch start positions, used when plot_end_points is True
    :param goal: per-batch goal positions, used when plot_end_points is True
    :param plot_end_points: whether to mark start and goal
    :return: list of uint8 arrays of shape (height, width, 4)
    """
    # The layout is the same for every sample, so look it up once.
    maze_grid = get_maze_grid(env_id) if is_grid_env(env_id) else None
    time_colors = np.arange(len(trajectory))

    images = []
    for batch_idx in range(batch_size):
        fig, ax = plt.subplots()
        try:
            plot_maze_layout(ax, maze_grid)
            ax.scatter(trajectory[:, batch_idx, 0], trajectory[:, batch_idx, 1], c=time_colors, cmap="Reds")
            if plot_end_points:
                plot_start_goal(ax, (start[batch_idx], goal[batch_idx]))
            fig.tight_layout()
            fig.canvas.draw()
            img_shape = fig.canvas.get_width_height()[::-1] + (4,)
            # copy(): the canvas buffer is released together with the figure
            img = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8).copy().reshape(img_shape)
        finally:
            # Close this sample's own figure; otherwise one figure per sample stays open.
            plt.close(fig)
        images.append(img)

    return images


def make_convergence_animation(
    env_id,
    plan_history,
    trajectory,
    start,
    goal,
    open_loop_horizon,
    namespace,
    interval=100,
    plot_end_points=True,
    batch_idx=0,
):
    """
    Animate how the first plan converges and save it as an mp4 under /tmp.

    plan_history contains, for each time step, all the MPC predicted plans for each
    pyramid noise level: a list of length (episode_len // open_loop_horizon), where each
    element corresponds to a control_time_step and stores a list of length pyramid_height,
    where each element is a plan at a different pyramid noise level and stored as a tensor
    of shape (episode_len // open_loop_horizon - control_time_step, batch_size, x_stacked_shape).

    :param batch_idx: which batch element to animate
    :return: path of the written mp4 file
    """
    # select index and prune history
    start, goal = start[batch_idx], goal[batch_idx]
    trajectory = trajectory[:, batch_idx]
    plan_history = [[pm[:, batch_idx] for pm in pt] for pt in plan_history]
    trajectory, plan_history = prune_history(plan_history, trajectory, goal, open_loop_horizon)

    maze_grid = get_maze_grid(env_id) if is_grid_env(env_id) else None

    fig, ax = plt.subplots()
    try:
        if "large" in env_id:
            fig.set_size_inches(3.5, 5)
        else:
            fig.set_size_inches(3, 3)
        ax.set_axis_off()
        fig.subplots_adjust(left=0, bottom=0, right=1, top=1)

        first_plans = plan_history[0]

        def update(frame):
            plot_maze_layout(ax, maze_grid)

            plan_history_m = first_plans[frame].numpy()
            ax.scatter(
                plan_history_m[:, 0],
                plan_history_m[:, 1],
                c=np.arange(len(plan_history_m))[::-1],
                cmap="Reds",
            )

            if plot_end_points:
                plot_start_goal(ax, (start, goal))

        frames = tqdm(range(len(first_plans)), desc="Making convergence animation")
        ani = animation.FuncAnimation(fig, update, frames=frames, interval=interval)
        prefix = wandb.run.id if wandb.run is not None else env_id
        filename = f"/tmp/{prefix}_{namespace}_convergence.mp4"
        ani.save(filename, writer="ffmpeg", fps=5)
    finally:
        # The figure is only needed while rendering; release it afterwards.
        plt.close(fig)
    return filename
def prune_history(plan_history, trajectory, goal, open_loop_horizon):
    """
    Cut the trajectory and the plans shortly after they first reach the goal.

    A point counts as "reached" when its first two coordinates are within 0.2 of goal.

    :param plan_history: list (per control step) of lists of plans; each plan supports
        `.numpy()` and slicing along its first axis
    :param trajectory: array of shape (time, >=2)
    :param goal: goal position (x, y)
    :param open_loop_horizon: number of steps executed per control step
    :return: (pruned trajectory, pruned plan history)
    """
    goal_xy = np.array(goal)[None]
    reach_radius = 0.2

    reached = np.linalg.norm(trajectory[:, :2] - goal_xy, axis=-1) < reach_radius
    if reached.any():
        first_hit = np.argmax(reached)
        trajectory = trajectory[: first_hit + open_loop_horizon + 1]
        plan_history = plan_history[: first_hit // open_loop_horizon + 2]

    pruned_plan_history = []
    for plans in plan_history:
        kept = [plans[m] for m in range(len(plans))]
        # The last plan of the step decides where every plan of that step is cut.
        final_plan = kept[-1]
        reached = np.linalg.norm(final_plan.numpy()[:, :2] - goal_xy, axis=-1) < reach_radius
        if reached.any():
            cut = np.argmax(reached) + 1
            kept = [p[:cut] for p in kept]
        pruned_plan_history.append(kept)
    return trajectory, pruned_plan_history


def make_mpc_animation(
    env_id,
    plan_history,
    trajectory,
    start,
    goal,
    open_loop_horizon,
    namespace,
    interval=100,
    plot_end_points=True,
    batch_idx=0,
):
    """
    Animate the executed trajectory together with every MPC plan and save it as an mp4 under /tmp.

    plan_history contains, for each time step, all the MPC predicted plans for each
    pyramid noise level: a list of length (episode_len // open_loop_horizon), where each
    element corresponds to a control_time_step and stores a list of length pyramid_height,
    where each element is a plan at a different pyramid noise level and stored as a tensor
    of shape (episode_len // open_loop_horizon - control_time_step, batch_size, x_stacked_shape).

    :param batch_idx: which batch element to animate
    :return: path of the written mp4 file
    """
    # select index and prune history
    start, goal = start[batch_idx], goal[batch_idx]
    trajectory = trajectory[:, batch_idx]
    plan_history = [[pm[:, batch_idx] for pm in pt] for pt in plan_history]
    trajectory, plan_history = prune_history(plan_history, trajectory, goal, open_loop_horizon)

    maze_grid = get_maze_grid(env_id) if is_grid_env(env_id) else None
    trajectory_colors = np.linspace(0, 1, len(trajectory))

    fig, ax = plt.subplots()
    try:
        if "large" in env_id:
            fig.set_size_inches(3.5, 5)
        else:
            fig.set_size_inches(3, 3)
        ax.set_axis_off()
        fig.subplots_adjust(left=0, bottom=0, right=1, top=1)

        def update(frame):
            # Map the flat frame index to (control step, plan index within that step).
            control_time_step = 0
            while frame >= 0:
                frame -= len(plan_history[control_time_step])
                control_time_step += 1
            control_time_step -= 1
            m = frame + len(plan_history[control_time_step])
            num_steps_taken = 1 + open_loop_horizon * control_time_step
            plot_maze_layout(ax, maze_grid)

            plan_history_m = plan_history[control_time_step][m].numpy()
            ax.scatter(
                trajectory[:num_steps_taken, 0],
                trajectory[:num_steps_taken, 1],
                c=trajectory_colors[:num_steps_taken],
                cmap="Blues",
            )
            ax.scatter(
                plan_history_m[:, 0],
                plan_history_m[:, 1],
                c=np.arange(len(plan_history_m))[::-1],
                cmap="Reds",
            )

            if plot_end_points:
                plot_start_goal(ax, (start, goal))

        num_frames = sum(len(p) for p in plan_history)
        frames = tqdm(range(num_frames), desc="Making MPC animation")
        ani = animation.FuncAnimation(fig, update, frames=frames, interval=interval)
        prefix = wandb.run.id if wandb.run is not None else env_id
        filename = f"/tmp/{prefix}_{namespace}_mpc.mp4"
        ani.save(filename, writer="ffmpeg", fps=5)
    finally:
        # The figure is only needed while rendering; release it afterwards.
        plt.close(fig)

    return filename
class SpaceEfficientWandbLogger(WandbLogger):
    """
    A wandb logger that by default overrides artifacts to save space, instead of creating new version.
    A variable expiration_days can be set to control how long older versions of artifacts are kept.
    By default, the latest version is kept indefinitely, while older versions are kept for 5 days.
    Passing expiration_days=None leaves the retention of older versions untouched.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        save_dir: _PATH = ".",
        version: Optional[str] = None,
        offline: bool = False,
        dir: Optional[_PATH] = None,
        id: Optional[str] = None,
        anonymous: Optional[bool] = None,
        project: Optional[str] = None,
        log_model: Union[Literal["all"], bool] = False,
        experiment: Union["Run", "RunDisabled", None] = None,
        prefix: str = "",
        checkpoint_name: Optional[str] = None,
        expiration_days: Optional[int] = 5,
        **kwargs: Any,
    ) -> None:
        # Initialise the base logger exactly once. A second, identical call that
        # only differed in a hard-coded offline=False used to precede this one
        # and was fully overwritten by it.
        super().__init__(
            name=name,
            save_dir=save_dir,
            version=version,
            offline=offline,
            dir=dir,
            id=id,
            anonymous=anonymous,
            project=project,
            log_model=log_model,
            experiment=experiment,
            prefix=prefix,
            checkpoint_name=checkpoint_name,
            **kwargs,
        )
        self.expiration_days = expiration_days
        # Artifacts logged by the previous call; they expire once newer ones exist.
        self._last_artifacts = []

    def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
        """Log every new checkpoint as an artifact and put a TTL on the previous batch."""
        import wandb

        # get checkpoints to be saved with associated score
        checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)

        callback_fields = [
            "monitor",
            "mode",
            "save_last",
            "save_top_k",
            "save_weights_only",
            "_every_n_train_steps",
        ]

        # log iteratively all new checkpoints
        artifacts = []
        for t, p, s, tag in checkpoints:
            metadata = {
                "score": s.item() if isinstance(s, Tensor) else s,
                "original_filename": Path(p).name,
                checkpoint_callback.__class__.__name__: {
                    k: getattr(checkpoint_callback, k)
                    for k in callback_fields
                    # ensure it does not break if `ModelCheckpoint` args change
                    if hasattr(checkpoint_callback, k)
                },
            }
            if not self._checkpoint_name:
                self._checkpoint_name = f"model-{self.experiment.id}"

            artifact = wandb.Artifact(name=self._checkpoint_name, type="model", metadata=metadata)
            artifact.add_file(p, name="model.ckpt")
            aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
            self.experiment.log_artifact(artifact, aliases=aliases)
            # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
            self._logged_model_time[p] = t
            artifacts.append(artifact)

        # expiration_days=None means "keep older versions", so no TTL is written then.
        ttl = None if self.expiration_days is None else timedelta(days=self.expiration_days)
        for artifact in self._last_artifacts:
            if not self._offline:
                artifact.wait()
            if ttl is not None:
                artifact.ttl = ttl
                artifact.save()
        self._last_artifacts = artifacts
class OfflineWandbLogger(SpaceEfficientWandbLogger):
    """
    Wandb logger that periodically triggers the offline sync hook.

    Intended for slurm clusters where only the login nodes, not the compute
    nodes, have internet access.
    """

    def __init__(
        self,
        name: Optional[str] = None,
        save_dir: _PATH = ".",
        version: Optional[str] = None,
        offline: bool = False,
        dir: Optional[_PATH] = None,
        id: Optional[str] = None,
        anonymous: Optional[bool] = None,
        project: Optional[str] = None,
        log_model: Union[Literal["all"], bool] = False,
        experiment: Union["Run", "RunDisabled", None] = None,
        prefix: str = "",
        checkpoint_name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        # The parent is always constructed with offline=False; the requested
        # value is stored on the instance right afterwards.
        parent_kwargs = dict(
            name=name,
            save_dir=save_dir,
            version=version,
            offline=False,
            dir=dir,
            id=id,
            anonymous=anonymous,
            project=project,
            log_model=log_model,
            experiment=experiment,
            prefix=prefix,
            checkpoint_name=checkpoint_name,
        )
        super().__init__(**parent_kwargs, **kwargs)
        self._offline = offline

        # Directory through which sync requests are handed to the sync hook.
        command_dir = Path(".wandb_osh_command_dir")
        command_dir.mkdir(parents=True, exist_ok=True)
        self.trigger_sync = TriggerWandbSyncHook(command_dir)

        self.last_sync_time = 0.0
        self.min_sync_interval = 60  # seconds between two sync triggers
        self.wandb_dir = os.path.join(self._save_dir, "wandb/latest-run")

    @override
    @rank_zero_only
    def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
        """Log metrics as usual, then trigger a sync if the last one is old enough."""
        out = super().log_metrics(metrics, step)
        seconds_since_sync = time.time() - self.last_sync_time
        if seconds_since_sync > self.min_sync_interval:
            self.trigger_sync(self.wandb_dir)
            self.last_sync_time = time.time()
        return out