Image-to-Video
Transformers
psi
feature-extraction
world-model
video-generation
multimodal
physical-world-model
controllable-generation
custom_code
Instructions to use StanfordNeuroAILab/psi0_5 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use StanfordNeuroAILab/psi0_5 with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("StanfordNeuroAILab/psi0_5", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """Inference-only PLPQ/HLQ quantizers used by PSI2. | |
| The RGB, flow, and depth tokenizers released with PSI-0.5 all use the same | |
| pyramidal local patch quantizer architecture with different channel counts and | |
| numbers of residual scalar-quantizer codebooks. This file keeps only the pieces | |
| needed for encode/decode at inference time: | |
| - Haar patchwise wavelet projection. | |
| - Local residual convolution blocks. | |
| - Pyramidal finite scalar quantization (PFSQ). | |
| - The PLPQ wrapper with ``quantize()``, ``decode()``, and | |
| ``decode_coarse_tokens()``. | |
| The module names mirror the training implementation so existing checkpoints load | |
| without key surgery. | |
| """ | |
| from __future__ import annotations | |
| import math | |
| from contextlib import nullcontext | |
| from functools import wraps | |
| from types import SimpleNamespace | |
| from typing import Any, Dict, Iterable, List, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from einops import pack, rearrange, unpack | |
| from torch import Tensor, int32 | |
def _autocast_disabled(device_type: str = "cuda"):
    """Return an autocast-disabled context/decorator without deprecation noise.

    Args:
        device_type: Autocast backend to switch off. Defaults to ``"cuda"``,
            which is what this helper previously hard-coded; pass ``"cpu"``
            (or another backend name) to guard code running under that
            backend's autocast instead.

    Returns:
        A ``torch.amp.autocast`` object built with ``enabled=False``; usable
        as a context manager or as a decorator.
    """
    return torch.amp.autocast(device_type, enabled=False)
def _exists(value: Any) -> bool:
    """Report whether ``value`` is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    if value is None:
        return False
    return True
def _default(*values: Any) -> Any:
    """Return the first argument that is not ``None``.

    Yields ``None`` when every argument is ``None`` or when no arguments are
    supplied. Falsy-but-present values (``0``, ``False``, ``""``) are returned
    as-is.
    """
    return next((candidate for candidate in values if candidate is not None), None)
def _maybe(fn):
    """Wrap ``fn`` so that a ``None`` first argument short-circuits the call.

    Args:
        fn: Callable whose first positional argument may legitimately be
            ``None``.

    Returns:
        A wrapper that hands back ``None`` untouched when its first argument
        is ``None`` and otherwise forwards every argument to ``fn``. The
        wrapper keeps ``fn``'s name and docstring (via ``functools.wraps``) so
        tracebacks and introspection still point at the wrapped function.
    """

    @wraps(fn)
    def inner(x, *args, **kwargs):
        if x is None:
            return x
        return fn(x, *args, **kwargs)

    return inner
def _pack_one(tensor: torch.Tensor, pattern: str):
    """Pack a single tensor with einops ``pack``.

    Returns the ``(packed_tensor, packed_shapes)`` pair produced by ``pack``
    for the one-element list ``[tensor]``.
    """
    packed, packed_shapes = pack([tensor], pattern)
    return packed, packed_shapes
def _unpack_one(tensor: torch.Tensor, packed_shape, pattern: str):
    """Unpack ``tensor`` with einops ``unpack`` and return the first component."""
    components = unpack(tensor, packed_shape, pattern)
    return components[0]
def _round_ste(z: Tensor) -> Tensor:
    """Round with a straight-through gradient.

    The rounding residual ``round(z) - z`` is detached before being added back,
    so the forward value equals the rounded tensor while the backward pass
    treats the operation as the identity.
    """
    residual = (z.round() - z).detach()
    return z + residual
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with an optional bias.

    The learnable tensors are named ``weight`` and ``bias`` so state-dict keys
    line up with the training implementation.
    """

    def __init__(self, ndim: int, bias: bool):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            # No parameter is registered; F.layer_norm accepts ``None`` here.
            self.bias = None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Normalise ``input`` over its trailing ``ndim`` features (eps=1e-5)."""
        normalized_shape = self.weight.shape
        return F.layer_norm(input, normalized_shape, weight=self.weight, bias=self.bias, eps=1e-5)
class PatchResidualConvBlock(nn.Module):
    """Residual block over a patch grid: LayerNorm, two convolutions, skip add.

    Args:
        in_dim: Channel count of the input (also the LayerNorm width).
        out_dim: Channel count produced by the second convolution.
        hidden_dim: Channel count between the two convolutions.
        kernel_size: Kernel size shared by both convolutions.
        stride: Stride shared by both convolutions.
        padding: Padding shared by both convolutions.
        dorpout: Dropout probability applied after the first activation. The
            misspelled name is kept as-is so existing keyword callers keep
            working.

    The skip connection adds the block input to the convolution output, so the
    two must have broadcast-compatible shapes.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        hidden_dim: int,
        kernel_size: int,
        stride: int,
        padding: int,
        dorpout: float = 0.1,
    ) -> None:
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        self.nonlinearity = nn.SiLU()
        self.ln1 = LayerNorm(in_dim, bias=True)
        self.dropout = nn.Dropout(dorpout)
        self.conv1 = nn.Conv2d(in_dim, hidden_dim, **conv_kwargs)
        self.conv2 = nn.Conv2d(hidden_dim, out_dim, **conv_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to a channel-first ``(b, c, h, w)`` tensor."""
        batch, channels, height, width = x.shape

        # LayerNorm acts on the channel axis, so flatten every spatial
        # position into a row of ``channels`` features first.
        tokens = x.permute(0, 2, 3, 1).reshape(batch * height * width, channels)
        tokens = self.ln1(tokens)

        # Back to channel-first layout for the convolutions.
        hidden = tokens.reshape(batch, height, width, channels).permute(0, 3, 1, 2).contiguous()
        hidden = self.dropout(self.nonlinearity(self.conv1(hidden)))
        hidden = self.nonlinearity(self.conv2(hidden))
        return hidden + x
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Stride 1 with padding 1 keeps the (already doubled) spatial size.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Double the spatial resolution of ``x`` and project its channels."""
        enlarged = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(enlarged)
class Downsample(nn.Module):
    """Stride-2 3x3 convolution with one row/column of zero padding."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # The convolution itself is unpadded; ``forward`` pads asymmetrically.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Zero-pad the right and bottom edges by one, then convolve."""
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class WaveletTransform(nn.Module):
    """Patchwise Haar wavelet transform used by released PLPQ/HLQ checkpoints.

    Calling the module runs ``transform`` (analysis) or, when constructed with
    ``inverse=True``, ``invert`` (synthesis). The filter taps live in plain
    tensor attributes rather than registered buffers or parameters.
    """

    def __init__(self, patch_size: int, inverse: bool = False):
        super().__init__()
        self.patch_size = int(patch_size)
        self.inverse = bool(inverse)
        # Two-tap Haar filter, each tap equal to 1/sqrt(2).
        self.haar = torch.tensor([0.7071067811865476, 0.7071067811865476])
        # Tap indices; ``(-1) ** arange`` gives the high-pass sign pattern.
        self.arange = torch.arange(len(self.haar))
        # One decomposition level per factor of two in the patch size.
        self.steps = int(math.log2(self.patch_size))

    def num_transformed_channels(self, in_channels: int = 3) -> int:
        """Return the channel count after ``steps`` decomposition levels."""
        expansion = 4 ** self.steps
        return int(in_channels) * expansion

    def forward(self, x: torch.Tensor, patchwise: bool = True, reshape: bool = False) -> torch.Tensor:
        """Dispatch to ``invert`` or ``transform`` depending on ``self.inverse``."""
        if not self.inverse:
            return self.transform(x, patchwise=patchwise, reshape=reshape)
        return self.invert(x, patchwise=patchwise, from_reshaped=reshape)

    def transform(self, x: torch.Tensor, patchwise: bool = True, reshape: bool = False) -> torch.Tensor:
        """Apply ``steps`` levels of the Haar DWT.

        With ``patchwise`` the ``(b, c, h, w)`` input is cut into
        ``patch_size`` x ``patch_size`` tiles that are transformed
        independently, giving one coefficient vector per tile. With ``reshape``
        the coefficient channels are additionally folded back into a
        ``(b, c, h, w)`` spatial layout.
        """
        p = self.patch_size
        if patchwise:
            n_images, channels, height, width = x.shape
            grid_h, grid_w = height // p, width // p
            # (b, c, gh, p, gw, p) -> (b, gh, gw, c, p, p) -> one row per tile.
            tiles = x.reshape(n_images, channels, grid_h, p, grid_w, p).moveaxis(4, 3)
            x = tiles.moveaxis(1, 3).reshape(-1, channels, p, p)

        for _ in range(self.steps):
            x = self.dwt(x)

        if patchwise:
            # Restore the tile grid with coefficients on the channel axis.
            x = x.reshape(n_images, grid_h, grid_w, -1).moveaxis(3, 1)

        if reshape:
            batch, packed_channels, rows, cols = x.shape
            base_channels = packed_channels // (p**2)
            full_h, full_w = rows * p, cols * p
            x = x.reshape(batch, p, p, base_channels, rows, cols)
            x = x.moveaxis(3, 1).moveaxis(3, 4).reshape(batch, base_channels, full_h, full_w).contiguous()
        return x

    def invert(self, x: torch.Tensor, patchwise: bool = True, from_reshaped: bool = False) -> torch.Tensor:
        """Undo ``transform`` using ``steps`` levels of the inverse Haar DWT.

        ``from_reshaped`` mirrors the ``reshape`` flag of ``transform``: when
        set, the spatially folded layout is first unfolded back into
        coefficient channels.
        """
        p = self.patch_size
        if from_reshaped:
            batch, channels, height, width = x.shape
            packed_channels = channels * p**2
            rows, cols = height // p, width // p
            x = x.reshape(batch, channels, p, rows, p, cols)
            x = x.moveaxis(4, 3).moveaxis(1, 3).reshape(batch, packed_channels, rows, cols)

        if patchwise:
            n_images = x.shape[0]
            coeff_channels = x.shape[1]
            grid_h, grid_w = x.shape[2], x.shape[3]
            # One 1x1 "image" of coefficients per tile.
            x = x.moveaxis(1, 3).reshape(-1, coeff_channels, 1, 1)

        for _ in range(self.steps):
            x = self.idwt(x)

        if patchwise:
            # (tiles, c, p, p) -> (b, c, gh, gw, p, p) -> (b, c, gh*p, gw*p).
            x = x.reshape(n_images, grid_h, grid_w, *x.shape[1:]).moveaxis(3, 1)
            leading = x.shape[:2]
            x = x.moveaxis(3, 4).reshape(*leading, grid_h * p, grid_w * p)
        return x

    def _haar_filters(self, groups: int, reference: torch.Tensor):
        """Build per-group low/high-pass taps on ``reference``'s device/dtype."""
        taps = self.haar
        low = taps.flip(0).reshape(1, 1, -1).repeat(groups, 1, 1)
        high = (taps * ((-1) ** self.arange)).reshape(1, 1, -1).repeat(groups, 1, 1)
        low = low.to(device=reference.device, dtype=reference.dtype)
        high = high.to(device=reference.device, dtype=reference.dtype)
        return low, high

    def dwt(self, x: torch.Tensor) -> torch.Tensor:
        """Single-level 2-D Haar analysis; returns ``[LL, LH, HL, HH]`` on dim 1."""
        dtype = x.dtype
        n_taps = self.haar.shape[0]
        groups = x.shape[1]
        low, high = self._haar_filters(groups, x)

        padded = F.pad(x, pad=(n_taps - 2, n_taps - 1, n_taps - 2, n_taps - 1), mode="reflect").to(dtype)

        # Filter along the width first, then along the height.
        width_low, width_high = low.unsqueeze(2), high.unsqueeze(2)
        height_low, height_high = low.unsqueeze(3), high.unsqueeze(3)
        xl = F.conv2d(padded, width_low, groups=groups, stride=(1, 2))
        xh = F.conv2d(padded, width_high, groups=groups, stride=(1, 2))
        xll = F.conv2d(xl, height_low, groups=groups, stride=(2, 1))
        xlh = F.conv2d(xl, height_high, groups=groups, stride=(2, 1))
        xhl = F.conv2d(xh, height_low, groups=groups, stride=(2, 1))
        xhh = F.conv2d(xh, height_high, groups=groups, stride=(2, 1))

        subbands = torch.cat([xll, xlh, xhl, xhh], dim=1)
        return 0.5 * subbands

    def idwt(self, x: torch.Tensor) -> torch.Tensor:
        """Single-level 2-D Haar synthesis from ``[LL, LH, HL, HH]`` on dim 1."""
        dtype = x.dtype
        n_taps = self.haar.shape[0]
        groups = x.shape[1] // 4
        low, high = self._haar_filters(groups, x)

        xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1)

        height_low, height_high = low.unsqueeze(3), high.unsqueeze(3)
        width_low, width_high = low.unsqueeze(2), high.unsqueeze(2)
        height_pad = (n_taps - 2, 0)
        width_pad = (0, n_taps - 2)

        # Undo the height filtering for the low- and high-width branches.
        yl = F.conv_transpose2d(xll, height_low, groups=groups, stride=(2, 1), padding=height_pad)
        yl += F.conv_transpose2d(xlh, height_high, groups=groups, stride=(2, 1), padding=height_pad)
        yh = F.conv_transpose2d(xhl, height_low, groups=groups, stride=(2, 1), padding=height_pad)
        yh += F.conv_transpose2d(xhh, height_high, groups=groups, stride=(2, 1), padding=height_pad)

        # Then undo the width filtering and merge both branches.
        y = F.conv_transpose2d(yl, width_low, groups=groups, stride=(1, 2), padding=width_pad)
        y += F.conv_transpose2d(yh, width_high, groups=groups, stride=(1, 2), padding=width_pad)
        return 2.0 * y
class PFSQ(nn.Module):
    """Pyramidal finite scalar quantizer used inside PLPQ.

    Each of the ``num_codebooks`` codebooks quantizes ``len(levels)`` scalars;
    scalar ``i`` is bounded with ``tanh`` and rounded onto ``levels[i]``
    integer steps. Optional linear projections map between the model width
    ``dim`` and the concatenated codebook width.

    Args:
        levels: Number of quantization steps per scalar dimension. Any
            sequence of ints is accepted (list, tuple, ...).
        dim: Input/output feature width. Defaults to
            ``len(levels) * num_codebooks``, in which case no projections are
            created.
        num_codebooks: Number of codebooks applied side by side.
        keep_num_codebooks_dim: Whether returned indices keep a trailing
            codebook axis. Defaults to ``num_codebooks > 1``.
        scale: Stored on the instance; not otherwise used in this class.
        allowed_dtypes: Dtypes left untouched when quantization is forced to
            full precision; anything else is cast to ``float32``.
        channel_first: Treat inputs as ``(b, d, ...)`` even when they have
            fewer than four dimensions.
        projection_has_bias: Whether the in/out projections carry a bias.
        return_indices: Whether ``forward`` computes code indices (and whether
            ``codebook_size`` / ``implicit_codebook`` are created).
        force_quantization_f32: Run the quantization step with autocast
            disabled.

    Raises:
        ValueError: If several codebooks are requested while
            ``keep_num_codebooks_dim`` is false.
    """

    def __init__(
        self,
        levels: List[int],
        dim: int | None = None,
        num_codebooks: int = 1,
        keep_num_codebooks_dim: bool | None = None,
        scale: float | None = None,
        allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64),
        channel_first: bool = False,
        projection_has_bias: bool = True,
        return_indices: bool = True,
        force_quantization_f32: bool = True,
    ):
        super().__init__()
        # Normalise to a list: ``[1] + levels[:-1]`` below would raise a
        # TypeError for tuples or other non-list sequences.
        levels = list(levels)
        # Both buffers are non-persistent, so they never appear in checkpoints.
        self.register_buffer("_levels", torch.tensor(levels, dtype=int32), persistent=False)
        # Mixed-radix place values: basis[i] = prod(levels[:i]).
        self.register_buffer("_basis", torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32), persistent=False)
        self.scale = scale
        self.codebook_dim = len(levels)
        self.num_codebooks = int(num_codebooks)
        self.effective_codebook_dim = self.codebook_dim * self.num_codebooks
        self.keep_num_codebooks_dim = _default(keep_num_codebooks_dim, self.num_codebooks > 1)
        if self.num_codebooks > 1 and not self.keep_num_codebooks_dim:
            raise ValueError("PFSQ with multiple codebooks must keep the codebook dimension.")
        self.dim = _default(dim, self.effective_codebook_dim)
        self.channel_first = bool(channel_first)
        has_projections = self.dim != self.effective_codebook_dim
        self.project_in = nn.Linear(self.dim, self.effective_codebook_dim, bias=projection_has_bias) if has_projections else nn.Identity()
        self.project_out = nn.Linear(self.effective_codebook_dim, self.dim, bias=projection_has_bias) if has_projections else nn.Identity()
        self.has_projections = has_projections
        self.return_indices = bool(return_indices)
        if self.return_indices:
            self.codebook_size = self._levels.prod().item()
            # Code vector for every index of a single codebook.
            self.register_buffer(
                "implicit_codebook",
                self._indices_to_codes(torch.arange(self.codebook_size)),
                persistent=False,
            )
        self.allowed_dtypes = allowed_dtypes
        self.force_quantization_f32 = bool(force_quantization_f32)

    def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
        """Squash ``z`` with ``tanh`` into the range covered by each level count.

        Dimensions with an even number of levels are shifted by half a step so
        that rounding yields exactly ``levels`` distinct integers.
        """
        half_l = (self._levels - 1) * (1 + eps) / 2
        offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
        shift = (offset / half_l).atanh()
        return (z + shift).tanh() * half_l - offset

    def quantize(self, z: torch.Tensor) -> torch.Tensor:
        """Bound, round (straight-through) and normalise by ``levels // 2``."""
        half_width = self._levels // 2
        return _round_ste(self.bound(z)) / half_width

    def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
        """Map normalised codes to non-negative per-dimension level indices."""
        half_width = self._levels // 2
        return (zhat_normalized * half_width) + half_width

    def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
        """Map per-dimension level indices back to normalised codes."""
        half_width = self._levels // 2
        return (zhat - half_width) / half_width

    def indices_to_level_indices(self, indices: torch.Tensor) -> torch.Tensor:
        """Split flat indices into mixed-radix digits along a new last axis."""
        indices = rearrange(indices, "... -> ... 1")
        return (indices // self._basis) % self._levels

    def _indices_to_codes(self, indices: torch.Tensor) -> torch.Tensor:
        """Convert flat indices to normalised code vectors (no projection)."""
        return self._scale_and_shift_inverse(self.indices_to_level_indices(indices))

    def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
        """Convert normalised code vectors to flat ``int32`` indices.

        Raises:
            ValueError: If the last dimension of ``zhat`` is not
                ``codebook_dim``.
        """
        if zhat.shape[-1] != self.codebook_dim:
            raise ValueError(f"Expected last dim {self.codebook_dim}, got {zhat.shape[-1]}.")
        return (self._scale_and_shift(zhat) * self._basis).sum(dim=-1).to(int32)

    def indices_to_codes(self, indices: torch.Tensor, return_first: bool = False) -> torch.Tensor:
        """Turn code indices back into (projected) code vectors.

        Args:
            indices: Integer indices whose last axis is read as the number of
                codes per position.
            return_first: Accepted for interface compatibility; not used.

        Returns:
            When the last axis of ``indices`` has size one, the raw
            un-projected codes in channel-last layout. Otherwise the codes
            after ``project_out``, moved to channel-first layout for
            image/video-shaped input or when ``channel_first`` is set.

        Raises:
            ValueError: If ``indices`` is ``None``.
        """
        if not _exists(indices):
            raise ValueError("indices must not be None.")
        n_codes = indices.shape[-1]
        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
        codes = self._indices_to_codes(indices)
        if self.keep_num_codebooks_dim:
            codes = rearrange(codes, "... c d -> ... (c d)")
        # A single code per position is returned as-is: it is narrower than
        # ``effective_codebook_dim``, so ``project_out`` does not apply.
        if n_codes == 1:
            return codes
        codes = self.project_out(codes)
        if is_img_or_video or self.channel_first:
            codes = rearrange(codes, "b ... d -> b d ...")
        return codes

    def forward(self, z: torch.Tensor):
        """Quantize ``z``.

        Args:
            z: Features with ``dim`` channels; channel-first ``(b, d, ...)``
                when it has four or more dimensions or ``channel_first`` is
                set, channel-last ``(b, ..., d)`` otherwise.

        Returns:
            A tuple ``(out, first_codes, indices)``: the quantized features
            after ``project_out``, the codes of the first codebook in packed
            ``(b, n, codebook_dim)`` layout, and the ``int32`` code indices
            (``None`` when ``return_indices`` is false).

        Raises:
            ValueError: If the feature width of ``z`` differs from ``dim``.
        """
        is_img_or_video = z.ndim >= 4
        need_move_channel_last = is_img_or_video or self.channel_first
        if need_move_channel_last:
            z = rearrange(z, "b d ... -> b ... d")
        # Flatten all middle axes into one token axis; remember how to undo it.
        z, packed_shape = _pack_one(z, "b * d")
        if z.shape[-1] != self.dim:
            raise ValueError(f"Expected dimension {self.dim}, found {z.shape[-1]}.")
        z = self.project_in(z)
        z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
        quantization_context = _autocast_disabled if self.force_quantization_f32 else nullcontext
        with quantization_context():
            orig_dtype = z.dtype
            if self.force_quantization_f32 and orig_dtype not in self.allowed_dtypes:
                z = z.float()
            codes = self.quantize(z)
            indices = self.codes_to_indices(codes) if self.return_indices else None
            first_codes = codes[:, :, 0, :].type(orig_dtype)
            codes = rearrange(codes, "b n c d -> b n (c d)").type(orig_dtype)
        out = self.project_out(codes)
        if need_move_channel_last:
            # Restore the original spatial axes and channel-first layout.
            out = _unpack_one(out, packed_shape, "b * d")
            out = rearrange(out, "b ... d -> b d ...")
            indices = _maybe(_unpack_one)(indices, packed_shape, "b * c")
        if not self.keep_num_codebooks_dim and self.return_indices:
            indices = _maybe(rearrange)(indices, "... 1 -> ...")
        return out, first_codes, indices
class PLPQ(nn.Module):
    """Pyramidal Local Patch Quantizer inference wrapper.

    Builds an encoder, a ``PFSQ`` quantizer, a full decoder and a 1x1
    ``coarse_decoder`` from ``config``. The attributes read from ``config``
    are ``patch_size``, ``num_in_channels``, ``num_out_channels``,
    ``encoder_blocks``, ``decoder_blocks``, ``levels``, ``num_quantizers`` and
    the optional ``use_wavelets`` flag.

    Each entry of ``encoder_blocks`` / ``decoder_blocks`` is a sequence whose
    first element selects the block type (``"ResBlock"`` for
    ``PatchResidualConvBlock``; anything else selects ``Downsample`` in the
    encoder and ``Upsample`` in the decoder) and whose remaining elements are
    passed positionally to that block's constructor.
    """

    def __init__(self, config: SimpleNamespace):
        super().__init__()
        self.config = config
        if getattr(config, "use_wavelets", False):
            # Haar patch projection in, inverse Haar projection out.
            wavelets = WaveletTransform(patch_size=config.patch_size)
            wavelet_channels = wavelets.num_transformed_channels(config.num_in_channels)
            in_proj = nn.Sequential(
                wavelets,
                nn.Conv2d(wavelet_channels, config.encoder_blocks[0][1], kernel_size=1, stride=1),
            )
            out_proj = nn.Sequential(
                nn.Conv2d(config.decoder_blocks[-1][2], wavelet_channels, kernel_size=3, stride=1, padding=1),
                WaveletTransform(patch_size=config.patch_size, inverse=True),
            )
        else:
            # Plain strided patch embedding in, 3x3 convolution out.
            in_proj = nn.Conv2d(
                config.num_in_channels,
                config.encoder_blocks[0][1],
                kernel_size=config.patch_size,
                stride=config.patch_size,
            )
            out_proj = nn.Conv2d(config.decoder_blocks[-1][2], config.num_out_channels, kernel_size=3, stride=1, padding=1)
        self.encoder = nn.Sequential(
            in_proj,
            nn.SiLU(),
            *[
                PatchResidualConvBlock(*params[1:]) if params[0] == "ResBlock" else Downsample(*params[1:])
                for params in config.encoder_blocks
            ],
        )
        self.quantizer = PFSQ(
            levels=config.levels,
            num_codebooks=config.num_quantizers,
            dim=config.encoder_blocks[-1][2],
        )
        # Decodes the un-projected codes of a single codebook.
        self.coarse_decoder = nn.Conv2d(len(config.levels), config.num_out_channels, kernel_size=1, stride=1)
        self.decoder = nn.Sequential(
            *[
                PatchResidualConvBlock(*params[1:]) if params[0] == "ResBlock" else Upsample(*params[1:])
                for params in config.decoder_blocks
            ],
            out_proj,
        )

    def quantize(self, x: torch.Tensor, flatten: bool = True) -> torch.Tensor:
        """Encode ``x`` and return its code indices.

        Args:
            x: Input passed to the encoder.
            flatten: When true, return indices with one flat token axis
                (``b, h * w, ...``); otherwise reshape them to
                ``(b, h, w, -1)`` using the encoder's output grid.

        Returns:
            The indices produced by the quantizer.
        """
        z = self.encoder(x).permute(0, 2, 3, 1).contiguous()
        b, h, w, _c = z.shape
        z = z.view(b, h * w, -1)
        _quantized, _coarse_quantized, all_codes = self.quantizer(z)
        if not flatten:
            all_codes = all_codes.view(b, h, w, -1)
        return all_codes

    def decode(self, indices: torch.Tensor, shape: Tuple[int, int] | None = None) -> torch.Tensor:
        """Decode code indices back to the output space.

        Args:
            indices: Code indices; the size of the last axis is the number of
                codes per position. A size of one selects ``coarse_decoder``,
                anything else the full ``decoder``.
            shape: Optional ``(height, width)`` used to recover the token grid
                for flat token sequences as ``shape // patch_size``. When
                omitted, a square grid is inferred from the token count.

        Returns:
            The decoder (or coarse decoder) output.

        Raises:
            ValueError: If ``shape`` is omitted and the number of tokens is
                not a perfect square.
        """
        n_codes = indices.shape[-1]
        is_coarse = n_codes == 1
        emb = self.quantizer.indices_to_codes(indices).squeeze(-1)
        if emb.ndim == 4:
            if is_coarse:
                # The quantizer hands back single-code embeddings channel-last
                # (b, h, w, d); the convolution needs them channel-first.
                emb = emb.permute(0, 3, 1, 2)
            emb = emb.contiguous()
        else:
            b, n_tokens = emb.size(0), emb.size(1)
            if shape is not None:
                h = shape[0] // self.config.patch_size
                w = shape[1] // self.config.patch_size
            else:
                # Exact integer square root; float sqrt can be off by one.
                h = w = math.isqrt(n_tokens)
                if h * w != n_tokens:
                    raise ValueError(
                        f"Cannot infer a square token grid from {n_tokens} tokens; pass `shape` explicitly."
                    )
            emb = emb.permute(0, 2, 1).reshape(b, -1, h, w).contiguous()
        if is_coarse:
            return self.coarse_decoder(emb)
        return self.decoder(emb)

    def decode_coarse_tokens(self, indices: torch.Tensor) -> torch.Tensor:
        """Decode a flat sequence of tokens with ``coarse_decoder``.

        The embeddings are laid out as a ``(b, d, n, 1)`` grid, i.e. the token
        axis is treated as the height of a one-pixel-wide image.
        """
        emb = self.quantizer.indices_to_codes(indices).squeeze(-1)
        emb = emb.transpose(1, 2).unsqueeze(-1).contiguous()
        return self.coarse_decoder(emb)
def quantizer_config_from_dict(config_dict: Dict[str, Any]) -> SimpleNamespace:
    """Convert a checkpoint config mapping into an attribute-style namespace.

    The mapping is shallow-copied before being expanded into keyword
    arguments, so each key becomes an attribute of the returned namespace.
    """
    fields = dict(config_dict)
    return SimpleNamespace(**fields)
def quantizer_from_checkpoint_dict(ckpt: Dict[str, Any]) -> PLPQ:
    """Build a PLPQ module from the ``"cfg"`` entry of a checkpoint dictionary.

    Only ``ckpt["cfg"]`` is read here; no weights are loaded into the
    returned module.
    """
    config = quantizer_config_from_dict(ckpt["cfg"])
    quantizer = PLPQ(config)
    return quantizer