Image-to-Video
Transformers
psi
feature-extraction
world-model
video-generation
multimodal
physical-world-model
controllable-generation
custom_code
Instructions to use StanfordNeuroAILab/psi0_5 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use StanfordNeuroAILab/psi0_5 with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("StanfordNeuroAILab/psi0_5", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import math | |
| from typing import Tuple, Union, List, Dict, Optional, Callable | |
| import tqdm | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
def validate_psi2_config(config) -> None:
    """
    Validate a PSI2 config to catch misconfigurations early.

    Optional attributes are only checked when present on ``config``;
    ``dropout``, ``n_embd`` and ``n_head`` are always required.

    Checks use explicit ``raise`` statements instead of ``assert`` so they
    still run when Python is started with ``-O`` (which strips asserts).
    ``AssertionError`` is kept as the exception type so existing callers
    that catch it keep working.

    Args:
        config: Object exposing the PSI2 hyperparameters as attributes.

    Raises:
        AssertionError: If any validation fails, with a descriptive message.
    """
    def _require(condition, message: str) -> None:
        # Explicit raise (not `assert`) so validation survives `python -O`.
        if not condition:
            raise AssertionError(message)

    if hasattr(config, 'pointer_token'):
        _require(
            config.pointer_token < config.vocab_size,
            f"pointer_token ({config.pointer_token}) must be < vocab_size ({config.vocab_size})",
        )

    # Token-id ranges: (lo, hi) inside [0, vocab_size]; lo == hi is allowed.
    token_range_attrs = ['rgb_range', 'flow_range', 'depth_range', 'campose_range', 'imagenet_cls_range']
    for attr in token_range_attrs:
        if hasattr(config, attr):
            lo, hi = getattr(config, attr)
            _require(
                0 <= lo <= hi <= config.vocab_size,
                f"{attr} ({lo}, {hi}) out of bounds for vocab_size={config.vocab_size}",
            )

    # Channel-id ranges: non-empty (lo < hi) inside [0, channel_size].
    channel_range_attrs = [
        'pointer_channel_range', 'rgb_channel_range', 'campose_channel_range',
        'flow_channel_range', 'depth_channel_range', 'cls_channel_range'
    ]
    for attr in channel_range_attrs:
        if hasattr(config, attr):
            lo, hi = getattr(config, attr)
            _require(
                0 <= lo < hi <= config.channel_size,
                f"{attr} ({lo}, {hi}) out of bounds for channel_size={config.channel_size}",
            )

    _require(0.0 <= config.dropout <= 1.0,
             f"dropout ({config.dropout}) must be in [0.0, 1.0]")
    if hasattr(config, 'drop_path_rate'):
        _require(0.0 <= config.drop_path_rate <= 1.0,
                 f"drop_path_rate ({config.drop_path_rate}) must be in [0.0, 1.0]")
    if hasattr(config, 'residual_scale'):
        _require(config.residual_scale > 0,
                 f"residual_scale ({config.residual_scale}) must be positive")

    # n_kv_head=None disables grouped-query attention (every head has its own K/V).
    if getattr(config, 'n_kv_head', None) is not None:
        _require(config.n_kv_head > 0,
                 f"n_kv_head ({config.n_kv_head}) must be positive (or None to disable GQA)")
        _require(config.n_head % config.n_kv_head == 0,
                 f"n_head ({config.n_head}) must be divisible by n_kv_head ({config.n_kv_head})")
    _require(config.n_embd % config.n_head == 0,
             f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})")

    # String options are validated case-insensitively.
    if hasattr(config, 'mlp_activation'):
        valid_activations = {'swiglu', 'gelu'}
        _require(config.mlp_activation.lower() in valid_activations,
                 f"mlp_activation ({config.mlp_activation}) must be one of {valid_activations}")
    if hasattr(config, 'attention_mask'):
        valid_masks = {'causal', 'none', 'full'}
        _require(config.attention_mask.lower() in valid_masks,
                 f"attention_mask ({config.attention_mask}) must be one of {valid_masks}")

    if hasattr(config, 'context_parallel_size'):
        _require(int(config.context_parallel_size) >= 1,
                 f"context_parallel_size ({config.context_parallel_size}) must be >= 1")
    if getattr(config, 'load_balance_cp', False):
        # int() for consistency with the context_parallel_size check above.
        _require(int(getattr(config, 'context_parallel_size', 1)) > 1,
                 "load_balance_cp requires context_parallel_size > 1")
def _to_additive_mask(
    mask: Optional[torch.Tensor],
    *,
    L: int,
    S: int,
    device: torch.device,
    dtype: torch.dtype,
) -> Optional[torch.Tensor]:
    """
    Normalize an optional attention mask into additive form.

    The result holds 0 where attention is allowed and -inf where it is
    blocked, so it can be added directly to attention scores.

    Args:
        mask: ``None``; a boolean mask where ``True`` marks a *blocked*
            position; or a floating mask already in additive form.
        L, S, device: Not used by this function (a boolean mask keeps its
            own device); kept so existing call sites stay valid.
        dtype: Floating dtype of the returned mask.

    Returns:
        The additive mask, or ``None`` when ``mask`` is ``None``.
    """
    if mask is None:
        return None
    if mask.dtype != torch.bool:
        # Already additive: only harmonize the dtype.
        return mask if mask.dtype == dtype else mask.to(dtype)
    additive = torch.zeros_like(mask, dtype=dtype, device=mask.device)
    return additive.masked_fill(mask, float('-inf'))
def _build_causal_additive_mask(
    L: int,
    S: int,
    *,
    device: torch.device,
    dtype: torch.dtype,
    past_k: int = 0,
) -> torch.Tensor:
    """
    Build an additive causal mask for rectangular attention.

    Query row ``i`` (out of ``L`` queries) may attend to key columns
    ``0 .. past_k + i`` (out of ``S`` keys). Every later column is set to
    -inf; all other entries are 0.

    Returns:
        Tensor of shape ``[1, 1, L, S]``, broadcastable over batch and heads.
    """
    query_idx = torch.arange(L, device=device)[:, None]
    key_idx = torch.arange(S, device=device)[None, :]
    future = key_idx > query_idx + past_k
    additive = torch.zeros((L, S), device=device, dtype=dtype)
    additive.masked_fill_(future, float('-inf'))
    return additive.view(1, 1, L, S)
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dim with a learnable per-channel gain (no bias, no centering)."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The mean of squares is reduced in float32; the resulting scale is
        # cast back to x's dtype before being applied.
        mean_sq = x.float().pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps).to(x.dtype)
        return self.weight * (x * inv_rms)
class Rotary3D(nn.Module):
    """
    3D rotary position embedding (RoPE) for per-head dims.

    ``head_dim`` is split across the (x, y, t) axes as 6/16, 6/16 and the
    remainder (each must be even). Each axis's share is rotated by angles
    proportional to the matching coordinate in ``pos_xyz``.
    """
    def __init__(self, head_dim: int, base: int = 100):
        super().__init__()
        assert head_dim % 2 == 0, "head_dim must be even."
        x_dim = (head_dim * 6) // 16
        y_dim = (head_dim * 6) // 16
        t_dim = head_dim - x_dim - y_dim
        for d in (x_dim, y_dim, t_dim):
            if d % 2 != 0:
                raise AssertionError("Each axis dim must be even for RoPE.")
        self.base = base
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.t_dim = t_dim
        # Non-persistent buffers: excluded from state_dict, recomputed from `base`.
        self.register_buffer("inv_freq_x", self._inv_freqs(self.x_dim), persistent=False)
        self.register_buffer("inv_freq_y", self._inv_freqs(self.y_dim), persistent=False)
        self.register_buffer("inv_freq_t", self._inv_freqs(self.t_dim), persistent=False)

    def initialize_buffers(self):
        '''Called in Checkpointing.instantiate_model when weight init is skipped'''
        self.inv_freq_x = self._inv_freqs(self.x_dim)
        self.inv_freq_y = self._inv_freqs(self.y_dim)
        self.inv_freq_t = self._inv_freqs(self.t_dim)

    def _inv_freqs(self, dim: int):
        """Inverse frequencies ``base ** (-2i / dim)`` for ``i`` in ``[0, dim/2)``, as float32."""
        return 1.0 / (self.base ** (torch.arange(0, dim, 2).float() / dim))

    def _axis_rot(self, x_half1, x_half2, cos, sin):
        """Rotate each pair ``(x_half1[k], x_half2[k])`` by the angle whose cos/sin are given."""
        return torch.cat([x_half1 * cos - x_half2 * sin, x_half1 * sin + x_half2 * cos], dim=-1)

    def forward(self, x: torch.Tensor, pos_xyz: torch.Tensor) -> torch.Tensor:
        """
        Apply the rotation.

        Args:
            x: [B, nH, T, H]
            pos_xyz: [B, T, 3] int/float (x, y, t) positions.

        Returns:
            Tensor with the same shape and dtype as ``x``. Output channels are
            regrouped per axis ([x-rotated, y-rotated, t-rotated]). The same
            regrouping applies to every input, so dot products between two
            rotated vectors are unaffected by it.
        """
        B, nH, T, H = x.shape
        assert pos_xyz.shape == (B, T, 3), f"pos must be [B, T, 3], got {pos_xyz.shape}"
        # Move inv_freq buffers to input device if needed (FSDP doesn't move buffers)
        if self.inv_freq_x.device != x.device:
            self.inv_freq_x = self.inv_freq_x.to(x.device)
            self.inv_freq_y = self.inv_freq_y.to(x.device)
            self.inv_freq_t = self.inv_freq_t.to(x.device)
        px = pos_xyz[..., 0].to(self.inv_freq_x.dtype)
        py = pos_xyz[..., 1].to(self.inv_freq_y.dtype)
        pt = pos_xyz[..., 2].to(self.inv_freq_t.dtype)
        # Angles [B, T, axis_dim/2]; unsqueeze(1) broadcasts over heads.
        fx = torch.einsum("bt,f->btf", px, self.inv_freq_x)
        fy = torch.einsum("bt,f->btf", py, self.inv_freq_y)
        ft = torch.einsum("bt,f->btf", pt, self.inv_freq_t)
        cx, sx = fx.cos().unsqueeze(1), fx.sin().unsqueeze(1)
        cy, sy = fy.cos().unsqueeze(1), fy.sin().unsqueeze(1)
        ct, st = ft.cos().unsqueeze(1), ft.sin().unsqueeze(1)
        # Rotate-half convention: dim k of the first half pairs with dim k of the second half.
        half = H // 2
        x1, x2 = x[..., :half], x[..., half:]

        def split_axis(tensor, dims):
            a, b, c = dims
            a1 = a // 2; b1 = b // 2; c1 = c // 2
            s0 = tensor[..., :a1]
            s1 = tensor[..., a1:a1 + b1]
            s2 = tensor[..., a1 + b1:a1 + b1 + c1]
            return s0, s1, s2

        x1x, x1y, x1t = split_axis(x1, (self.x_dim, self.y_dim, self.t_dim))
        x2x, x2y, x2t = split_axis(x2, (self.x_dim, self.y_dim, self.t_dim))
        rx = self._axis_rot(x1x, x2x, cx, sx)
        ry = self._axis_rot(x1y, x2y, cy, sy)
        rt = self._axis_rot(x1t, x2t, ct, st)
        rotated = torch.cat([rx, ry, rt], dim=-1)
        # cos/sin are computed from float32 buffers, so the products above are
        # promoted to float32 when x is half/bfloat16. Cast back so q/k keep
        # the same dtype as v; scaled_dot_product_attention requires matching
        # dtypes when not running under autocast.
        return rotated.to(x.dtype)
class DropPath(nn.Module):
    """
    Per-sample DropPath (stochastic depth), broadcast across non-batch dims.

    In training, each sample's input is zeroed with probability
    ``drop_prob``; surviving samples are scaled by ``1 / (1 - drop_prob)``.
    In eval mode, or when ``drop_prob == 0``, the input is returned unchanged.
    """
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep = 1.0 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep)
        # drop_prob == 1.0 is accepted by the config validator (and is the
        # rate of the last layer when drop_path_rate == 1.0). Rescaling by
        # 1/keep would then compute 0/0 = NaN, so only rescale when keep > 0.
        if keep > 0.0:
            mask.div_(keep)
        return x * mask
class MLP(nn.Module):
    """
    Position-wise feed-forward block: ``c_fc`` -> activation -> ``c_proj`` -> dropout.

    Args:
        n_embd: Input/output width.
        mlp_hidden_size: Hidden width; defaults to ``4 * n_embd``.
        bias: Whether the linear layers have biases.
        dropout: Dropout probability applied to the output.
        activation: ``"swiglu"`` (``c_fc`` emits two halves: a linear branch
            and a SiLU-gated branch) or ``"gelu"`` (case-insensitive).

    Raises:
        ValueError: For any other activation name.
    """
    def __init__(self,
                 n_embd: int,
                 mlp_hidden_size: Optional[int] = None,
                 *,
                 bias: bool = False,
                 dropout: float = 0.0,
                 activation: str = "swiglu"):
        super().__init__()
        hidden = 4 * n_embd if mlp_hidden_size is None else mlp_hidden_size
        self.activation = activation.lower()
        if self.activation not in ("swiglu", "gelu"):
            raise ValueError(f"Unsupported mlp_activation: {activation}")
        # SwiGLU needs both the linear branch and the gate from one projection.
        fc_width = 2 * hidden if self.activation == "swiglu" else hidden
        self.c_fc = nn.Linear(n_embd, fc_width, bias=bias)
        self.c_proj = nn.Linear(hidden, n_embd, bias=bias)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.c_fc(x)
        if self.activation == "swiglu":
            # First half: linear branch; second half: SiLU gate.
            linear_part, gate = h.chunk(2, dim=-1)
            h = linear_part * F.silu(gate)
        else:
            h = F.gelu(h)
        return self.drop(self.c_proj(h))
class PSIAttentionLayer(nn.Module):
    """
    Multi-head self-attention with optional grouped-query attention (GQA),
    3D rotary position embeddings and an optional key/value cache.

    Config attributes read: ``n_embd``, ``n_head``, ``dropout`` and, when
    present, ``n_kv_head`` (defaults to ``n_head``), ``bias`` (default
    False) and ``attention_mask`` (default ``"causal"``).
    """
    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = config.n_head
        self.n_kv_head = getattr(config, "n_kv_head", None) or self.n_head
        assert self.n_head % self.n_kv_head == 0, "n_head must be a multiple of n_kv_head"
        self.head_dim = config.n_embd // self.n_head
        self.dropout = config.dropout
        self.bias = getattr(config, "bias", False)
        # validate_psi2_config accepts attention_mask case-insensitively
        # ("Causal", "CAUSAL", ...). Compare the same way here; otherwise such
        # configs would silently run without a causal mask.
        self.is_causal = str(getattr(config, "attention_mask", "causal")).lower() == "causal"
        self.q_proj = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=self.bias)
        self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=self.bias)
        self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=self.bias)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=self.bias)
        self.attn_dropout = nn.Dropout(self.dropout)
        self.resid_dropout = nn.Dropout(self.dropout)
        self.rope = Rotary3D(self.head_dim)
        self._use_sdpa = hasattr(F, "scaled_dot_product_attention")

    def emplace_kv(self, T, k_cache, v_cache, k, v):
        """Copy the new keys/values into the last ``T`` slots of the caches and return the caches."""
        # torch.compile doesn't play well with this op (5x slowdown), so the
        # intent is to copy eagerly behind a graph break.
        # NOTE(review): no graph-break decorator (e.g. torch.compiler.disable)
        # is applied here — confirm the break actually happens.
        k_cache[:, :, -T:].copy_(k)
        v_cache[:, :, -T:].copy_(v)
        return k_cache, v_cache

    def forward(
        self,
        x: torch.Tensor,
        pos_xyz: torch.Tensor,
        *,
        k_cache: Optional[torch.Tensor] = None,
        v_cache: Optional[torch.Tensor] = None,
        return_kv: bool = False,
        inplace_kv: bool = False,
        mask: Optional[torch.Tensor] = None
    ):
        """
        Attend over the new tokens (and any cached keys/values).

        Args:
            x: [B, T, C] input hidden states.
            pos_xyz: [B, T, 3] (x, y, t) positions for the rotary embedding.
            k_cache, v_cache: Optional caches shaped [B, n_kv_head, S_cache, head_dim].
                By default they hold only *past* tokens, and the new keys/values
                are concatenated after them. With ``inplace_kv=True`` (and both
                given), their last ``T`` slots are overwritten with the new
                keys/values, so those slots must be reserved by the caller.
            return_kv: Also return keys/values with ``n_kv_head`` heads.
            inplace_kv: Write into the caches instead of concatenating.
            mask: Optional boolean (True = blocked) or additive mask. When
                given, it replaces the automatic causal mask.

        Returns:
            ``y`` of shape [B, T, C], or ``(y, k, v)`` when ``return_kv``.
        """
        B, T, C = x.shape
        H, HKV, D = self.n_head, self.n_kv_head, self.head_dim
        groups = H // HKV
        q = self.q_proj(x).view(B, T, H, D).transpose(1, 2)
        k = self.k_proj(x).view(B, T, HKV, D).transpose(1, 2)
        v = self.v_proj(x).view(B, T, HKV, D).transpose(1, 2)
        q = self.rope(q, pos_xyz)
        k = self.rope(k, pos_xyz)
        if inplace_kv and (k_cache is not None) and (v_cache is not None):
            k, v = self.emplace_kv(T, k_cache, v_cache, k, v)
        else:
            if k_cache is not None:
                k = torch.cat([k_cache, k], dim=2)
            if v_cache is not None:
                v = torch.cat([v_cache, v], dim=2)
        S = k.shape[2]
        # Number of keys preceding the T new queries. In the concat path this
        # equals the cache length. In the in-place path the cache already
        # contains the T new entries, so using its length would overcount by T
        # and the causal mask would let new tokens attend to later new tokens.
        past_k = S - T
        # Check if we're on XLA device at runtime - more reliable than init-time check
        # The init-time _tpu_available() can return False if called before XLA is initialized
        is_xla = 'xla' in str(x.device)
        expand_kv_heads = HKV != H
        if expand_kv_heads:
            # Note that despite using expand + reshape, this still technically copies the
            # tensor because it is impossible to create a strided view of repeated data.
            # This is equivalent to repeat_interleave.
            # We use this even with F.scaled_dot_product_attention because the enable_gqa arg
            # is slower and uses more memory. This may change with later torch versions.
            k = k.unsqueeze(2).expand(B, HKV, groups, S, D).reshape(B, H, S, D)
            v = v.unsqueeze(2).expand(B, HKV, groups, S, D).reshape(B, H, S, D)
        add_mask = _to_additive_mask(mask, L=T, S=S, device=x.device, dtype=q.dtype)
        if add_mask is None and self.is_causal:
            add_mask = _build_causal_additive_mask(T, S, device=x.device, dtype=q.dtype, past_k=past_k)
        if self._use_sdpa and not is_xla:
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=add_mask,
                dropout_p=self.dropout if self.training else 0.0,
            )
        else:
            att = torch.einsum("bhtd,bhsd->bhts", q, k) * (1.0 / math.sqrt(D))
            if add_mask is not None:
                att = att + add_mask
            att = F.softmax(att, dim=-1, dtype=torch.float32).to(q.dtype)
            att = self.attn_dropout(att)
            y = torch.einsum("bhts,bhsd->bhtd", att, v)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.out_proj(y))
        if return_kv:
            if expand_kv_heads:
                # Expanded head h = kv * groups + g; stride `groups` picks one copy per KV head.
                return y, k[:, ::groups], v[:, ::groups]
            return y, k, v
        return y
class Block(nn.Module):
    """
    Pre-norm transformer block: ``x + attn(ln_1(x))`` then ``x + mlp(ln_2(x))``.

    Each residual branch goes through DropPath and is multiplied by
    ``residual_scale``. The DropPath rate increases linearly with
    ``layer_idx``, from 0 at the first layer up to ``drop_path_rate`` at the
    last layer (when there is more than one layer).
    """
    def __init__(self, config, layer_idx: int, total_layers: int):
        super().__init__()
        self.residual_scale = float(getattr(config, "residual_scale", 1.0))
        max_rate = float(getattr(config, "drop_path_rate", 0.0))
        layer_dp = max_rate * (layer_idx / max(1, total_layers - 1))
        # Submodule registration order is kept as-is (it defines state_dict key
        # order and the order in which weights are initialized).
        self.ln_1 = RMSNorm(config.n_embd)
        self.attn = PSIAttentionLayer(config)
        self.drop_path_1 = DropPath(layer_dp)
        self.ln_2 = RMSNorm(config.n_embd)
        self.mlp = MLP(
            config.n_embd,
            getattr(config, "mlp_hidden_size", None),
            bias=getattr(config, "bias", False),
            dropout=config.dropout,
            activation=getattr(config, "mlp_activation", "swiglu"),
        )
        self.drop_path_2 = DropPath(layer_dp)

    def forward(
        self,
        x: torch.Tensor,
        pos: torch.Tensor,
        *,
        k_cache: Optional[torch.Tensor] = None,
        v_cache: Optional[torch.Tensor] = None,
        return_kv: bool = False,
        inplace_kv: bool = False,
        mask: Optional[torch.Tensor] = None,
    ):
        """Return the updated hidden states, plus ``(k, v)`` from attention when ``return_kv``."""
        want_kv = bool(return_kv)
        attn_result = self.attn(self.ln_1(x), pos,
                                k_cache=k_cache, v_cache=v_cache,
                                return_kv=want_kv, inplace_kv=inplace_kv, mask=mask)
        if want_kv:
            attn_out, k, v = attn_result
        else:
            attn_out = attn_result
        x = x + self.drop_path_1(attn_out) * self.residual_scale
        x = x + self.drop_path_2(self.mlp(self.ln_2(x))) * self.residual_scale
        if want_kv:
            return x, k, v
        return x
class KVCache:
    """
    Preallocated key/value cache for all layers.

    ``k`` and ``v`` are shaped ``[n_layer, batch, n_kv_head, capacity, head_dim]``
    when created by ``allocate``. ``size`` is the number of filled token
    slots, and ``mask`` is an optional ``[*, 1, 1, size]`` mask
    (see ``validate``).
    """
    def __init__(self,
                 k: torch.Tensor,
                 v: torch.Tensor,
                 mask: Optional[torch.Tensor],
                 size: int
                 ):
        self.k = k
        self.v = v
        self.mask = mask
        self.size = size

    def capacity(self):
        """Return the number of token slots (``k.shape[-2]``). This is a method: call it."""
        return self.k.shape[-2]

    def shallow_copy(self,
                     mask: Optional[torch.Tensor],
                     size: int
                     ) -> 'KVCache':
        """Return a new cache sharing this cache's k/v storage, with a different mask and size."""
        return KVCache(k=self.k, v=self.v, mask=mask, size=size)

    @staticmethod
    def allocate(
        model: 'PSI2',
        capacity: int,
        batch_size: int = 1,
        dtype=None, device=None
    ) -> 'KVCache':
        '''
        Allocates a standard KVCache.
        You'll usually want to use `PSI2.allocate_kvcache()` instead.

        Declared as a staticmethod (it takes no ``self``), so it also works
        when called through an instance. The contents are uninitialized
        (``torch.empty``); ``size`` starts at 0.
        '''
        n_kv = getattr(model.config, "n_kv_head", None) or model.config.n_head
        head_dim = model.config.n_embd // model.config.n_head
        kv_shape = (
            model.config.n_layer, batch_size, n_kv,
            capacity, head_dim
        )
        return KVCache(
            k=torch.empty(kv_shape, dtype=dtype, device=device),
            v=torch.empty(kv_shape, dtype=dtype, device=device),
            mask=None, size=0
        )

    def validate(self, batch_size: int):
        """Assert internal consistency of shapes, device, dtype, mask and size."""
        assert self.k.shape == self.v.shape, \
            f'KV shape mismatch: k={self.k.shape}, v={self.v.shape}'
        assert self.k.ndim == 5
        assert self.k.shape[1] == batch_size
        assert self.k.device == self.v.device
        assert self.k.dtype == self.v.dtype
        if self.mask is not None:
            assert self.mask.ndim == 4 and self.mask.shape[1:3] == (1, 1)
            assert self.mask.shape[-1] == self.size
        assert self.size <= self.k.shape[-2]

    def get_token_slice(self,
                        start: int,
                        stop: Optional[int]
                        ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return views of k and v restricted to token slots ``[start:stop]``."""
        return (
            self.k[..., start:stop, :],
            self.v[..., start:stop, :]
        )

    def copy_from(self, kvcache: 'KVCache'):
        """Copy the filled prefix (``kvcache.size`` tokens), mask and size of ``kvcache`` into this cache."""
        # `capacity` is a method; comparing the bound method itself to an int
        # raised TypeError, so it must be called.
        assert self.capacity() >= kvcache.size, \
            f'capacity {self.capacity()} < source size {kvcache.size}'
        self.k[..., :kvcache.size, :].copy_(kvcache.k[..., :kvcache.size, :], non_blocking=True)
        self.v[..., :kvcache.size, :].copy_(kvcache.v[..., :kvcache.size, :], non_blocking=True)
        # The mask tensor is shared, not copied.
        self.mask = kvcache.mask
        self.size = kvcache.size

    def transpose_last(self, end: int, count: int, tpp: int):
        """
        Reorder the ``count`` token slots ending at ``end`` in place.

        The slots are viewed as ``[tpp, count // tpp]`` and transposed to
        ``[count // tpp, tpp]`` (``tpp`` presumably = tokens per patch —
        confirm against callers). ``count`` must be divisible by ``tpp``.
        """
        # npt=tokens, npp=patches
        # [... npt E] -> [... tpp npp E] -> [... npp tpp E] -> [... npt E]
        edims = self.k.shape[:-2]
        par_k = self.k[..., end - count:end, :].reshape(*edims, tpp, count // tpp, -1)
        par_v = self.v[..., end - count:end, :].reshape(*edims, tpp, count // tpp, -1)
        par_k = par_k.transpose(-2, -3).reshape(*edims, count, -1)
        par_v = par_v.transpose(-2, -3).reshape(*edims, count, -1)
        # clone() guards against the source aliasing the destination slots.
        self.k[..., end - count:end, :].copy_(par_k.clone())
        self.v[..., end - count:end, :].copy_(par_v.clone())
| class PSI2(nn.Module): | |
| """ | |
| PSI2 with training-time features: | |
| - GQA (n_kv_head < n_head) | |
| - SwiGLU MLP | |
| - Stochastic depth (DropPath) | |
| - Residual scaling | |
| - Robust weight tying (works with XLA/FSDP, saves memory) | |
| """ | |
    def __init__(self, config, verbose=True):
        """
        Build the PSI2 transformer from ``config``.

        Args:
            config: Hyperparameter object, checked by ``validate_psi2_config``.
                Read here: ``vocab_size``, ``channel_size``, ``n_embd``,
                ``n_layer``, ``dropout`` and, if present,
                ``context_parallel_size``, ``tie_weights``, ``n_lm_vocab``,
                ``lm_head_bias``.
            verbose: If True, print the config and the weight-tying note.

        Raises:
            AssertionError: If ``validate_psi2_config`` rejects the config.
            ValueError: If ``tie_weights`` is set but ``n_lm_vocab`` differs
                from ``vocab_size``.
        """
        super().__init__()
        self.config = config
        if verbose:
            print("[PSI2] Using config:", config, flush=True)
        validate_psi2_config(config)
        self.context_parallel_size = int(getattr(config, "context_parallel_size", 1))
        self.tie_weights = getattr(config, "tie_weights", False)
        # Output vocabulary size; falls back to vocab_size when missing or falsy.
        self.n_lm_vocab = getattr(config, "n_lm_vocab", None) or config.vocab_size
        if self.tie_weights and self.n_lm_vocab != config.vocab_size:
            raise ValueError(
                f"Cannot tie weights when n_lm_vocab ({self.n_lm_vocab}) != vocab_size ({config.vocab_size})"
            )
        token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
        # Registration order below fixes parameter order, and therefore the
        # order of random draws made by self.apply(self._init_weights).
        self.transformer = nn.ModuleDict(dict(
            token_embedding=token_embedding,
            channel_embedding=nn.Embedding(config.channel_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config, i, config.n_layer) for i in range(config.n_layer)]),
            ln_f=RMSNorm(config.n_embd),
        ))
        if self.tie_weights:
            # No separate output projection: _lm_head_forward reuses
            # token_embedding.weight (plus the optional bias below).
            self.lm_head = None
            if getattr(config, "lm_head_bias", False):
                self.lm_head_bias = nn.Parameter(torch.zeros(self.n_lm_vocab))
            else:
                self.register_parameter('lm_head_bias', None)
            if verbose:
                print(f"[PSI2] Weight tying enabled: using token_embedding weights for lm_head "
                      f"(saving {config.vocab_size * config.n_embd:,} parameters)")
        else:
            self.lm_head = nn.Linear(config.n_embd, self.n_lm_vocab, bias=False)
            self.lm_head_bias = None
        # N(0, 0.02) for Linear/Embedding weights, then the attention output
        # projections are re-initialized with std scaled by 1/sqrt(2*n_layer)
        # (the GPT-2 residual-projection scheme).
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith("out_proj.weight"):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
        # get_num_params ignores its argument and returns the full stored count.
        self.unsharded_param_count = self.get_num_params()
        # Store the "effective" param count for compute purposes
        # When weights are tied, the lm_head matmul still happens, so we count it for FLOPs
        self.compute_param_count = self._get_compute_params()
| def device(self) -> torch.device: | |
| return self.transformer.token_embedding.weight.device | |
| def dtype(self) -> torch.device: | |
| return self.transformer.token_embedding.weight.dtype | |
| def _init_weights(self, module: nn.Module): | |
| if isinstance(module, nn.Linear): | |
| torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) | |
| if module.bias is not None: | |
| torch.nn.init.zeros_(module.bias) | |
| elif isinstance(module, nn.Embedding): | |
| torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) | |
| def _get_compute_params(self) -> int: | |
| """ | |
| Get effective parameter count for compute/FLOPs estimation. | |
| When weights are tied, the embedding is used twice (embedding + lm_head), | |
| so we count it twice for compute purposes even though it's stored once. | |
| """ | |
| base_params = sum(p.numel() for p in self.parameters()) | |
| if self.tie_weights: | |
| # Add the embedding size again since it's used for lm_head computation | |
| tied_weight_size = self.config.vocab_size * self.config.n_embd | |
| return base_params + tied_weight_size | |
| return base_params | |
| def get_num_params(self, non_embedding: bool = True) -> int: | |
| """Return the number of stored model parameters. | |
| The ``non_embedding`` argument is accepted for compatibility with the | |
| training implementation; the compact inference runtime reports the full | |
| stored parameter count. | |
| """ | |
| return sum(p.numel() for p in self.parameters()) | |
| def _lm_head_forward(self, x: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Compute lm_head output, using embedding weights if tie_weights is enabled. | |
| """ | |
| if self.tie_weights: | |
| logits = F.linear(x, self.transformer.token_embedding.weight, self.lm_head_bias) | |
| else: | |
| logits = self.lm_head(x) | |
| return logits | |
    def forward(
        self,
        seq: torch.Tensor,
        pos: torch.Tensor,
        tgt: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        k_cache: Optional[torch.Tensor] = None,
        v_cache: Optional[torch.Tensor] = None,
        return_kv: bool = False,
        inplace_kv: bool = False,
        output_hidden_states: bool = False,
    ):
        """
        Run the transformer over ``seq``; optionally compute a loss or return keys/values.

        Args:
            seq: [B, T] token ids.
            pos: [B, T, 4]. ``pos[..., :3]`` are the (x, y, t) coordinates for
                the rotary embedding (which requires exactly 3 of them), and
                ``pos[..., 3]`` is the channel id looked up in
                ``channel_embedding``.
            tgt: Optional target ids. Logits and loss are computed only for the
                last ``tgt.size(1)`` positions; targets equal to -1 are ignored.
            mask: Optional attention mask, passed unchanged to every block.
            k_cache, v_cache: Optional per-layer caches, indexed by layer as
                ``k_cache[i]``.
            return_kv: Also return keys/values. Only honoured when ``tgt`` is None.
            inplace_kv: Write new keys/values into the given caches instead of
                concatenating.
            output_hidden_states: Replace ``logits`` in the result with
                ``{"logits": logits, "hidden_states": [...]}``. The list holds
                the input embeddings, each block's output and the final
                ``ln_f`` output.

        Returns:
            tgt is None, return_kv and inplace_kv: ``(logits, k_cache, v_cache)``,
                i.e. the caches that were passed in.
            tgt is None, return_kv only: ``(logits, k, v)``, with k/v stacked
                over layers to [n_layer, B, n_kv_head, S, head_dim].
            tgt is None otherwise: ``(logits, None)``.
            tgt given: ``(logits, loss)``.
        """
        # Split pos into rotary coordinates (x, y, t) and the channel id.
        st_pos = pos[:, :, :-1]
        channel_pos = pos[:, :, -1]
        tok_emb = self.transformer.token_embedding(seq)
        ch_emb = self.transformer.channel_embedding(channel_pos)
        x = self.transformer.drop(tok_emb + ch_emb)
        if output_hidden_states:
            hidden_states = [x]
        k_list, v_list = [], []
        for i, block in enumerate(self.transformer.h):
            x = block(
                x,
                pos=st_pos,
                k_cache=None if k_cache is None else k_cache[i],
                v_cache=None if v_cache is None else v_cache[i],
                return_kv=return_kv,
                inplace_kv=inplace_kv,
                mask=mask,
            )
            if return_kv:
                x, k, v = x
                k_list.append(k)
                v_list.append(v)
            if output_hidden_states:
                hidden_states.append(x)
        x = self.transformer.ln_f(x)
        if output_hidden_states:
            hidden_states.append(x)
        if tgt is None:
            logits = self._lm_head_forward(x)
            if output_hidden_states:
                logits = {"logits": logits, "hidden_states": hidden_states}
            if return_kv:
                if inplace_kv:
                    # The per-layer caches were written in place.
                    return logits, k_cache, v_cache
                else:
                    return logits, torch.stack(k_list), torch.stack(v_list)
            return logits, None
        # Only the trailing positions aligned with tgt are projected to logits.
        logits = self._lm_head_forward(x[:, -tgt.size(1):])
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_tgt = tgt.reshape(-1)
        # Context parallelism should not change the training objective itself:
        # we still optimize the standard token-mean cross-entropy on the logits
        # produced by the current forward pass. Any CP-specific correctness work
        # belongs in the attention plumbing, not in ad hoc loss rescaling here.
        loss = F.cross_entropy(
            flat_logits,
            flat_tgt,
            ignore_index=-1
        )
        if output_hidden_states:
            logits = {"logits": logits, "hidden_states": hidden_states}
        return logits, loss
| def unpack_and_sort_img_seq(self, | |
| img_seq: torch.Tensor, | |
| num_revealed_patches: int = 0, | |
| mark_revealed_patches: Optional[int] = None, | |
| logits: Optional[torch.Tensor] = None, | |
| unpatchify_seq: bool = True): | |
| """ | |
| Unpacks a sequence of patch-index + RGB tokens into an image. | |
| Assumes first token per patch is the index. | |
| """ | |
| n_tokens_per_patch = self.config.patch_size ** 2 + 1 | |
| B = img_seq.size(0) | |
| img_seq = img_seq.view(B, -1, n_tokens_per_patch) | |
| img_idxs = (img_seq[:, :, 0].long() - (65536 + 256)) | |
| reconstruct_indxs = torch.argsort(img_idxs, dim=1) | |
| rgb_seq = img_seq[:, :, 1:] | |
| if mark_revealed_patches is not None: | |
| rgb_seq[:, :num_revealed_patches] = mark_revealed_patches | |
| rgb_seq = rgb_seq[torch.arange(B).unsqueeze(1).expand_as(img_idxs), reconstruct_indxs] | |
| if unpatchify_seq: | |
| img = _unpatchify(rgb_seq) | |
| else: | |
| h = w = int(math.sqrt(rgb_seq.size(1))) | |
| img = rgb_seq.reshape(B, h, w, -1) | |
| if logits is not None: | |
| logits = logits.view(B, -1, n_tokens_per_patch, logits.size(-1)) | |
| logits = logits[:, :, 1:] | |
| logits = logits[torch.arange(B).unsqueeze(1).expand_as(img_idxs), reconstruct_indxs] | |
| logits = _unpatchify_logits(logits) | |
| return img, logits | |
| return img | |
| def sample_logits(self, | |
| logits: torch.FloatTensor, | |
| temp: Optional[float] = None, | |
| post_temp: Optional[float] = None, | |
| top_k: Optional[int] = None, | |
| top_p: Optional[float] = None, | |
| min_p: Optional[float] = None, | |
| sample_range: Optional[Tuple[int,int]] = None, | |
| blacklist: Optional[Union[List[int], torch.LongTensor]] = None, | |
| content_whitelist: Optional[torch.LongTensor] = None | |
| ) -> torch.LongTensor: | |
| """ | |
| Samples an integer from the distribution of logits | |
| Parameters: | |
| logits (torch.FloatTensor): The logits of the distribution | |
| temp (float): The temperature of the sampling, if 0.0, then argmax is used | |
| top_k (int): The number of top k tokens to consider during sampling | |
| top_p (float): The cumulative probability threshold for nucleus (top-p) sampling | |
| min_p (float): The minimum probability threshold factor for min-p sampling | |
| blacklist (Union[List[int], torch.LongTensor]): The list of tokens to blacklist during sampling | |
| content_whitelist (torch.LongTensor): The list of tokens to whitelist during sampling (mutually exclusive with blacklist) | |
| Returns: | |
| torch.LongTensor: The sampled integers | |
| """ | |
| if isinstance(temp, list): | |
| temp = temp[0] | |
| if isinstance(post_temp, list): | |
| post_temp = post_temp[0] | |
| if isinstance(top_k, list): | |
| top_k = top_k[0] | |
| if isinstance(top_p, list): | |
| top_p = top_p[0] | |
| assert temp is None or temp >= 0.0 | |
| assert post_temp is None or post_temp >= 0.0 | |
| assert top_k is None or top_k > 0 | |
| assert top_p is None or top_p >= 0.0 | |
| assert min_p is None or 0.0 <= min_p <= 1.0 | |
| assert sample_range is None or ( | |
| sample_range[0] < sample_range[1] and | |
| sample_range[0] >= 0 and | |
| sample_range[1] <= logits.shape[-1] | |
| ) | |
| assert blacklist is None or content_whitelist is None, "blacklist and content_whitelist cannot be used together" | |
| # Apply blacklist & sample range | |
| if blacklist is not None: | |
| logits[...,blacklist] = float('-inf') | |
| if content_whitelist is not None: | |
| # Create a mask that allows only whitelisted tokens | |
| whitelist_mask = torch.full_like(logits, float('-inf')) | |
| whitelist_mask[..., content_whitelist] = 0.0 | |
| logits = logits + whitelist_mask | |
| if sample_range is not None: | |
| logits = logits[...,sample_range[0]:sample_range[1]] | |
| token_offset = sample_range[0] | |
| else: | |
| token_offset = 0 | |
| # Apply temperature, or use argmax if 0.0 | |
| if (temp is not None and temp == 0.0) or (post_temp is not None and post_temp == 0.0): | |
| return token_offset + torch.argmax(logits, dim=-1) | |
| if temp is not None and temp != 1.0: | |
| logits.div_(temp) | |
| # Sort the logits once. More efficient when using top-k and top-p together (min-p doesn't require sorting). | |
| # We sample in sorted order then re-order before returning. | |
| if top_k is not None or top_p is not None: | |
| logits, order = torch.sort(logits, dim=-1, descending=True) | |
| else: | |
| order = None # Don't sort | |
| # Apply top-k filtering if specified | |
| if top_k is not None: | |
| logits = logits[...,:top_k] | |
| # Apply top-p (nucleus) filtering if specified | |
| if top_p is not None: | |
| probs = F.softmax(logits, dim=-1) | |
| cumulative_probs = probs.cumsum_(dim=-1) | |
| idxs_to_remove = cumulative_probs > top_p | |
| # Shift the mask right to keep at least one token | |
| logits[...,1:][idxs_to_remove[...,:-1]] = float('-inf') | |
| del probs, cumulative_probs, idxs_to_remove | |
| # Apply min-p filtering if specified | |
| if min_p is not None: | |
| probs = F.softmax(logits, dim=-1) | |
| maxprob = probs[...,[0]] if order is not None else torch.max(probs, dim=-1, keepdim=True).values | |
| logits[probs < maxprob * min_p] = float('-inf') | |
| del probs, maxprob | |
| # Apply optional post-temperature | |
| if post_temp is not None and post_temp != 1.0: | |
| logits.div_(post_temp) | |
| # Compute softmax probabilities | |
| orig_shape = logits.shape | |
| probs = torch.softmax(logits, dim=-1, out=logits) | |
| # Flatten probabilities to (batch_size * sequence_length, vocab_size) for sampling | |
| flat_probs = probs.view(-1, probs.size(-1)) | |
| sampled = torch.multinomial(flat_probs, num_samples=1) | |
| sampled = sampled.view(*orig_shape[:-1]) | |
| # If we sorted, unsort to collect the actual token values | |
| if order is not None: | |
| sampled = torch.gather(order, dim=-1, index=sampled.unsqueeze(-1)).squeeze(-1) | |
| return token_offset + sampled | |
| def allocate_kvcache(self, n_tokens: int, batch_size: int = 1, dtype=None, device=None): | |
| assert n_tokens > 0 | |
| assert batch_size > 0 | |
| dtype = dtype or self.transformer.token_embedding.weight.dtype | |
| device = device or self.device | |
| return KVCache.allocate(self, n_tokens, batch_size, dtype=dtype, device=device) | |
def rollout_patches(self,
        seq: Union[Optional[torch.LongTensor], List[Optional[torch.LongTensor]]],
        pos: Union[torch.LongTensor, List[torch.LongTensor]],
        idx: Union[torch.LongTensor, List[torch.LongTensor]],
        n_tokens_per_patch: int = 5,
        n_seq_patches: int = -1,
        weights: Optional[Union[List[float], torch.Tensor]] = None,
        scatter_indices: Optional[Union[List[int], torch.LongTensor]] = None,
        kvcache: Optional[KVCache] = None,
        policy: Optional[Callable[..., torch.LongTensor]] = None,
        *,
        unmask_parallel: bool = False,
        return_logits: bool = False,
        return_idx_logits: bool = True,
        return_kv: bool = False,
        temp: Optional[Union[float, List[float]]] = None,
        post_temp: Optional[Union[float, List[float]]] = None,
        top_k: Optional[Union[int, List[int]]] = None,
        top_p: Optional[Union[float, List[float]]] = None,
        min_p: Optional[Union[float, List[float]]] = None,
        sample_range: Optional[Tuple[int, int]] = None,
        blacklist: Optional[Union[List[int], torch.LongTensor]] = None,
        tqdm_kwargs: Optional[dict] = None,
        print_and_visualize_tokens: bool = False,
        out_dir: Optional[str] = "debug",
    ) -> Union[
        torch.LongTensor,                                # seq
        Tuple[torch.LongTensor, torch.Tensor],           # seq, logits
        Tuple[torch.LongTensor, KVCache],                # seq, kvcache
        Tuple[torch.LongTensor, torch.Tensor, KVCache],  # seq, logits, kvcache
    ]:
    """
    Roll out patches of tokens (an index token followed by content tokens), conditioned on
    one or more sequences. The first `n_seq_patches` patches are generated one token at a
    time; the remaining patches are generated in parallel, one within-patch position per pass.

    K = number of given sequences (1 if seq is not a list)
    T = length of a conditioning sequence (per sequence)
    N = total length of conditioning + generated tokens (per sequence)
    I = number of index tokens to roll out
    B = number of batched rollouts to make (0 if not using batched rollouts)
    max(T) = maximum of T across sequences
    num_new_tokens = I * n_tokens_per_patch

    ***Logit Mixing***:
    To use logit mixing, supply multiple sequences/positions and a `weights` list/tensor. Before sampling tokens, the logits
    from each sequence are multiplied by the provided weights and summed.

    ***Batched Rollouts***:
    To run batched rollouts *without* logit mixing, supply multiple (K) sequences/positions and leave `weights=None`.
    The returned sequences/logits will include an extra batch dimension of size K.

    ***Batched Rollouts + Logit Mixing***:
    To run batched rollouts and logit mixing at the same time, supply multiple (K) sequences/positions, a `weights` list,
    and `scatter_indices`. `scatter_indices` indicates which output sequence each input sequence corresponds to. For example,
    to run 3 cfg rollouts, you may supply K=6 sequences, with `weights=[cfg, 1-cfg, cfg, 1-cfg, cfg, 1-cfg]` and `scatter_indices=[0,0,1,1,2,2]`.
    Output row b always corresponds to the input sequences with scatter index b.
    **WARNING:** With batched rollouts and logit mixing, logits may not be numerically deterministic on CUDA (see torch.Tensor.scatter_add_ docs.)

    ***Tips for long rollouts***:
    1. Use `gc.collect()` and then `torch.cuda.empty_cache()` (in that order)
    2. Avoid fragmentation wherever possible. This method needs to allocate the entire KV cache
       contiguously in memory. If you create tensors between rollouts, consider moving them to
       CPU or cloning them to defragment VRAM.
    3. Even if a single N-token rollout fits in VRAM, running two consecutive rollouts (e.g. running
       n1 then giving its KV cache to n2, with n1 + n2 = N) may not fit, because reallocating the
       KV cache will duplicate memory. To avoid this, try moving the cache tensors to CPU first:
       ```
       seq1, kvcache = predictor.rollout_patches(..., return_kv=True)
       kvcache = KVCache(k=kvcache.k.cpu(), v=kvcache.v.cpu(), mask=kvcache.mask, size=kvcache.size)
       gc.collect(); torch.cuda.empty_cache()
       seq2 = predictor.rollout_patches(..., kvcache=kvcache)
       ```
       Because the CPU cache is too small/on the wrong device, a new cache is allocated and the old
       one is copied into it via `KVCache.copy_from` (confirm it supports cross-device copies).

    TODO: Support temp/top_k/top_p/min_p scheduling with parallel. Currently uses index -1 for all parallel tokens

    Parameters:
        seq (Union[Optional[torch.LongTensor], List[Optional[torch.LongTensor]]]):
            [T], [K T], or list of [T] / sequence(s) to condition the generation on. None means empty sequence
        pos (Union[torch.LongTensor, List[torch.LongTensor]]):
            [N 4], [K N 4], or list of [N 4] / 4D position(s) corresponding to each sequence. If multiple, must have length K=len(seq)
        idx (Union[torch.LongTensor, List[torch.LongTensor]]):
            [I], [B I], or list of [I] / the patch indices to use in the rollout. 1D by default, used for all sequences. If using batched rollouts,
            idx may be 1D or 2D, and the first dimension must be equal to the number of batched rollouts (if scatter_indices is not None, this may be less than K).
            The caller's tensor is never modified.
        n_tokens_per_patch (int):
            number of tokens per patch, including patch index
        n_seq_patches (int):
            number of patches to roll out sequentially (-1 for all). The remaining patches will be parallel
        weights (Optional[Union[List[float], torch.Tensor]]):
            float weights for the logits produced by each sequence, used for logit mixing. If None and multiple sequences are given, batched rollouts will be
            used instead of logit mixing. If a list or 1D tensor, must have size K. If 2D, must have shape [K num_new_tokens].
            If 2D, the first n_seq_patches patches will use the weights in order as expected, but the remaining (parallel) patches may be arbitrarily reordered.
            When using parallel and a 2D weight schedule, it is recommended to make the weights for parallel patches uniform for consistency
        scatter_indices (Optional[Union[List[int], torch.LongTensor]]):
            integer scatter indices for logit mixing, indicating which generated sequence each input sequence is used for. If provided, multiple sequences
            will be generated and returned; see above for details. Must have exactly K elements, one for each input sequence. Indices must be set-contiguous in the range [0,K)
            and must include 0, e.g. for K=5, [1,0,3,2,2] are valid indices but [1,3,4,4,3] are not.
        kvcache (Optional[KVCache]):
            optional KV cache for this rollout. If kvcache.size > 0, the tokens in the cache will be used as conditioning for this rollout. If the cache has
            enough empty space to cache this rollout, the KV tensors will be modified in-place. Note that only the tensor data is modified, not the other entries
            of the KVCache (e.g. size, mask). To receive an updated cache object, set return_kv=True and pass the returned cache into downstream rollouts. If no cache is
            provided, or if the cache is too small for this rollout, a new cache will be allocated and the tensor data from this cache will be copied.
        policy (Optional[Callable[..., torch.LongTensor]]):
            optional callback defining the policy for rollout order (not supported with batched rollouts). Must accept an argument `idx` (torch.LongTensor of shape [I])
            with the candidate indices to generate next. Must return the *index* into the `idx` tensor to generate next, either an int or 0-dimensional torch.LongTensor.
            For example, given candidate indices [4,12,1023], return 2 to generate the patch with index 1023 next. Only used for the sequential part of the generation.
            The following kwargs are given
            - `idx` (torch.LongTensor of shape [I] or [B I]) the candidate patch indices
            - `pos` (torch.LongTensor of shape [K N 4]): the remaining poses for all yet-ungenerated tokens in the same order as `idx`
            - `weights` (torch.Tensor of shape [K N]): the weights for all yet-ungenerated tokens, or None
            - `scatter_indices` (torch.LongTensor of shape [K])
            - `kvcache` (KVCache): the cache for this rollout, where kvcache.size includes the tokens generated up to this point
            - `n_tokens_per_patch` (int)
            - `sample_range` (Optional[Tuple[int,int]])
            - `idx_pos` (torch.LongTensor of shape [K I 4]): same as `pos`, but only for the candidate index tokens
            - `idx_weights` (torch.Tensor of shape [K I]): same as `weights`, but only for the candidate index tokens
            - `predictor` (PSIPredictor) self
            - `device` (torch.device)
            The callback must return the following
            - (Union[int, torch.LongTensor]): the *index* into the `idx` tensor with the patch token to generate next (*not* the value of the patch index itself)
        unmask_parallel (bool):
            if True, all parallel patches can attend to each other. If False (default), parallel patches can only attend to themselves
        return_logits (bool):
            return the logits of the sequence
        return_idx_logits (bool):
            if True (default), returns logits that would predict index tokens, so there is one set of logits for every returned token. If False, only returns
            logits used to sample content tokens, i.e. (n_tokens_per_patch - 1) sets of logits per patch. The latter may not need to compute logits
            for all tokens, so it may be more efficient for some computations (such as patchwise entropy). Ignored if return_logits=False
        return_kv (bool):
            return the KV cache as a KVCache object, useful for downstream operations. If True and return_logits=False,
            returns (new_tokens, kvcache). If True and return_logits=True, returns (new_tokens, logits, kvcache). **All KVs are returned in patch-major order,** even if
            the rollout is partially or fully parallel. Note that KVs from parallel prediction are not computed causally
        temp, post_temp, top_k, top_p, min_p (scalar or list of length num_new_tokens, optional):
            sampling parameters forwarded to `self.sample_logits`. Scalars are broadcast to every token. The sequential part indexes
            the schedule by token step; the parallel part currently uses the last entry for all tokens (see TODO above).
        sample_range (Optional[Tuple[int, int]]), blacklist (Optional[Union[List[int], torch.LongTensor]]):
            forwarded unchanged to `self.sample_logits` (and `sample_range` also to `policy`)
        tqdm_kwargs (Optional[dict]):
            extra keyword arguments for the sequential-rollout progress bar (override the defaults)
        print_and_visualize_tokens (bool):
            not supported in this build; raises NotImplementedError if True
        out_dir (Optional[str]):
            unused in this build

    Returns:
        torch.LongTensor:
            [num_new_tokens] or [B num_new_tokens]
            the generated tokens only (the input sequence is not prepended). 1D by default, or 2D if using batched rollouts
        torch.Tensor:
            (optional) [n_tokens vocab_size] or [B n_tokens vocab_size]
            the logits of the sequence, where n_tokens is num_new_tokens if return_idx_logits else I * (n_tokens_per_patch - 1).
            2D by default, or 3D if using batched rollouts
        KVCache:
            (optional) the KV cache, with fields
            - `k` (torch.Tensor) [n_layer K n_head n_tok n_embd//n_head]
            - `v` (torch.Tensor) [n_layer K n_head n_tok n_embd//n_head]
            - `mask` (torch.Tensor) [K 1 1 n_tok]
            - `size` (int)
    """
    if print_and_visualize_tokens:
        raise NotImplementedError("Token visualization is not included in the Hugging Face inference build.")

    #########################
    # === Preprocessing === #
    #########################
    # Normalize seq/pos into lists with one entry per conditioning sequence
    if not isinstance(seq, list):
        seq = [seq] if seq is None or seq.ndim == 1 else list(seq)
    if not isinstance(pos, list):
        pos = [pos] if pos.ndim == 2 else list(pos)
    if isinstance(idx, list):
        idx = torch.stack(idx)
    nnt = idx.shape[-1] * n_tokens_per_patch  # num_new_tokens
    idtype = pos[0].dtype
    dtype = self.dtype
    device = self.device
    if weights is not None:
        if isinstance(weights, list):
            weights = torch.tensor(weights, dtype=dtype, device=device)
        if weights.ndim != 2:
            # [K] -> [K nnt]: same weight for every generated token
            weights = weights.unsqueeze(-1).expand(-1, nnt)
        weights = weights.to(dtype).to(device)
    if scatter_indices is not None:
        if isinstance(scatter_indices, list):
            scatter_indices = torch.tensor(scatter_indices, dtype=idtype, device=device)
    if n_seq_patches < 0:
        n_seq_patches = idx.shape[-1]
    # Broadcast scalar sampling parameters to per-token schedules of length nnt
    if temp is not None and not isinstance(temp, list):
        temp = [temp] * nnt
    if post_temp is not None and not isinstance(post_temp, list):
        post_temp = [post_temp] * nnt
    if top_k is not None and not isinstance(top_k, list):
        top_k = [top_k] * nnt
    if top_p is not None and not isinstance(top_p, list):
        top_p = [top_p] * nnt
    if min_p is not None and not isinstance(min_p, list):
        min_p = [min_p] * nnt

    K = len(seq)
    I = idx.shape[-1]
    T = [0 if s is None else s.shape[0] for s in seq]
    maxT = max(T)
    tpp = n_tokens_per_patch
    # Number of tokens we need to use from the input cache
    in_cache_size = 0 if kvcache is None else kvcache.size
    n_rollout_tokens = tpp * n_seq_patches
    n_par_patches = I - n_seq_patches
    return_idx_logits = return_logits and return_idx_logits
    run_last_parallel_tokens = return_idx_logits or return_kv
    # Logit slots stored per patch: all tpp positions, or only the tpp-1 content-predicting ones
    logits_per_patch = tpp if return_idx_logits else (tpp - 1)

    # Number of batched rollouts to make. If 0, we return a single 1D sequence, else return 2D
    if scatter_indices is not None:
        B = int(torch.max(scatter_indices).cpu().item()) + 1
    else:
        B = K if (weights is None and K > 1) else 0

    if idx.ndim == 2:
        assert B > 0, f'2D idx tensors are only supported for batched rollouts (got idx.shape={tuple(idx.shape)})'
        assert idx.shape == (B, I), f'Expected idx.shape={(B,I)}, but got {tuple(idx.shape)}.'
        if scatter_indices is not None:
            idx = torch.gather(idx, 0, scatter_indices.unsqueeze(-1).repeat(1, idx.shape[-1]))  # [B I] -> [K I]
        # else B=K, so idx is already [K I]
    elif idx.ndim == 1 and B > 0:
        idx = idx.expand(K, idx.shape[-1])

    if policy is not None:
        # TODO: Batched rollouts + policies are currently not supported
        if idx.ndim == 2:
            raise NotImplementedError('Batched rollouts with a policy callback is not currently supported')
        # The policy branch reorders idx in place; work on a copy so the caller's tensor is untouched
        idx = idx.clone()

    # Validate inputs as best we can
    assert len(pos) == K, f'Expected seq and pos lists to have the same length, but got {K} and {len(pos)}'
    assert (idx.shape == (K, I)) or (idx.ndim == 1), f'idx tensor must be 1D (or 2D for batched rollouts)'
    if weights is not None:
        assert weights.ndim == 2 and weights.shape == (K, nnt)
    if scatter_indices is not None:
        assert scatter_indices.ndim == 1 and len(scatter_indices) == K
        assert torch.min(scatter_indices) >= 0 and torch.max(scatter_indices) < K
        _usi = torch.unique(scatter_indices, sorted=True)
        assert torch.equal(_usi, torch.arange(len(_usi), device=_usi.device))
        del _usi
    assert I > 0, f'No idx tokens were provided'
    assert tpp > 0, f'Must have a positive number of tokens per patch'
    assert I * tpp == nnt
    assert 0 <= n_seq_patches <= I
    if kvcache is not None:
        kvcache.validate(K)
    for i, (s, p) in enumerate(zip(seq, pos)):
        if s is not None:
            assert s.ndim == 1
        assert p.ndim == 2
        assert p.shape[1] == 4
        assert p.shape[0] == T[i] + nnt

    # Rows of the [K I] idx tensor to use when reducing it to [B I]: row b must come from an
    # input sequence with scatter index b, because sampled content tokens (and logits) are
    # produced in scatter-index order 0..B-1. Ordering by first appearance would misalign
    # index tokens with content tokens for scatter indices such as [1,0,3,2,2].
    if scatter_indices is None:
        idx_batch_rows = list(range(K))
    else:
        first_row_of_batch = {}
        for row, b in enumerate(scatter_indices.tolist()):
            first_row_of_batch.setdefault(b, row)
        idx_batch_rows = [first_row_of_batch[b] for b in range(B)]

    # General notes:
    # B>0 means batched rollout mode, even if B==1. Batched means returned sequences are [B nnt], but B==0 means they're [nnt].
    # 0 <= B <= K. If B==K, then all sequences are independent (no logit mixing).
    # If B>0, then idx is [K I] with B unique rows (some may be repeated). If B==0, then idx is [I].
    # Use idx[idx_batch_rows] to select idx with shape [B I].
    # Use Tensor.scatter_add + scatter_indices to sum-reduce from [K *] to [B *].

    #########################
    # === Preallocation === #
    #########################
    # Preallocate the KV cache (if we need to) so we don't need to constantly resize it
    # If we won't run the last parallel pass, we don't need to cache those toks
    n_kvcache = in_cache_size + maxT + nnt - (0 if run_last_parallel_tokens else n_par_patches)
    if kvcache is None or kvcache.capacity < n_kvcache:
        # Allocate a new cache and copy the existing one into it
        new_kvcache = self.allocate_kvcache(n_kvcache, batch_size=K, dtype=dtype, device=device)
        if kvcache is not None:
            new_kvcache.copy_from(kvcache)
        kvcache = new_kvcache
        del new_kvcache
    # else we can just use the provided cache in-place

    # Also preallocate the output logits tensor, if requested (every slot is written below)
    if return_logits:
        n_logits = I * logits_per_patch
        if B > 0:
            # Easiest to store/append this shape then transpose first two axes on return
            all_logits = torch.empty((n_logits, B, self.config.vocab_size), dtype=dtype, device=device)
        else:
            all_logits = torch.empty((n_logits, self.config.vocab_size), dtype=dtype, device=device)

    ################################
    # === Initial Forward Pass === #
    ################################
    # Stack seq/pos into a batch, left-padded
    if maxT > 0:
        # [K maxT]
        seq = torch.stack([(
            torch.zeros(maxT, dtype=idtype, device=device) if s is None else
            F.pad(s, (maxT - s.shape[0], 0))
        ) for s in seq])
    # [K maxN 4]
    pos = torch.stack([F.pad(p, (0, 0, maxT - t, 0)) for t, p in zip(T, pos)])

    if maxT > 0:
        # Build attention mask for initial forward pass [K 1 maxT maxT]
        # Batch size K, each mask in the batch is fully causal except for the first (maxT - T) tokens, which are masked
        mask = torch.zeros(K, 1, maxT, maxT, device=device)
        mask.masked_fill_(torch.ones_like(mask, dtype=torch.bool).triu(1), float('-inf'))
        for i, t in enumerate(T):
            mask[i, ..., :maxT-t] = float('-inf')
            # Unmask the diagonal so pad tokens can self-attend
            # This doesn't matter with torch sdpa, but prevents NaNs with manual attention
            # NOTE: If t==0, the diagonal is *only* pad tokens, so this will unmask the last pad token
            # in the last row (which we use for rollouts). We re-mask this pad token in the rollout mask below
            mask[i, 0].fill_diagonal_(0.0)
        if kvcache is not None:
            if kvcache.mask is not None:
                mask = torch.cat([kvcache.mask.to(mask.device).expand((K, 1, maxT, -1)), mask], dim=-1)
            else:
                mask = F.pad(mask, (in_cache_size, 0, 0, 0, 0, 0, 0, 0))
        # The above mask[:,0,-1,:] might look something like this:
        #            kv cache     |   sequences
        # T[0]==3 [ T T T T T T | F F F F T T T ]
        # T[1]==6 [ T T T T T T | F T T T T T T ]
        # T[2]==7 [ T T T T T T | T T T T T T T ]
        # T[3]==4 [ T T T T T T | F F F T T T T ]
        # For one element in the batch, mask[0,0,:,:] with T[0]==3 might look like:
        #     kv cache     |   sequences
        # [ T T T T T T | T F F F F F F ]
        # [ T T T T T T | F T F F F F F ]
        # [ T T T T T T | F F T F F F F ]
        # [ T T T T T T | F F F T F F F ]
        # [ T T T T T T | F F F F T F F ]
        # [ T T T T T T | F F F F T T F ]
        # [ T T T T T T | F F F F T T T ]
        # If a custom kvcache.mask is given, the kv cache part above may be different

        # Initial forward pass (conditioning sequences only)
        k_cache, v_cache = kvcache.get_token_slice(0, in_cache_size+maxT)
        self.forward(
            seq=seq, pos=pos[:,:maxT], mask=mask,
            k_cache=k_cache, v_cache=v_cache, inplace_kv=True
        )
    pos = pos[:,maxT:]

    ##############################
    # === Sequential Rollout === #
    ##############################
    # Build attention mask for rollout [K 1 1 in_cache_size+maxT+n_rollout_tokens]
    if maxT == 0:
        if kvcache is None or kvcache.mask is None:
            # We have no context and no input cache mask, so we only need rollout tokens
            # Note we may still have an input cache without a mask, so add in_cache_size tokens
            mask = torch.zeros((K, 1, 1, in_cache_size+n_rollout_tokens), dtype=dtype, device=device)
        else:
            # We have no context and an input cache mask
            # Pad [K 1 1 in_cache_size] -> [K 1 1 in_cache_size+n_rollout_tokens]
            mask = F.pad(kvcache.mask, (0, n_rollout_tokens, 0, 0, 0, 0, 0, 0))
    else:
        # We created a mask while computing context above. Take the last entry and clone to free memory
        mask = F.pad(mask[...,[-1],:].clone(), (0, n_rollout_tokens, 0, 0, 0, 0, 0, 0))
        for i, t in enumerate(T):
            # If t==0, the fill_diagonal call above unmasked the last pad token,
            # so we need to re-mask it before we start rolling out
            if t == 0:
                mask[i, ..., in_cache_size+maxT-1] = float('-inf')
    # The above mask[:,0,0,:] might look something like this:
    #            kv cache     |   sequences   |       rollout
    # T[0]==3 [ T T T T T T | F F F F T T T | T T T T ... T T T T ]
    # T[1]==6 [ T T T T T T | F T T T T T T | T T T T ... T T T T ]
    # T[2]==7 [ T T T T T T | T T T T T T T | T T T T ... T T T T ]
    # T[3]==4 [ T T T T T T | F F F T T T T | T T T T ... T T T T ]
    # We construct this mask once, then slice off part of the right side at each rollout step

    # List of 0D (if B==0) or 1D (if B>0) tensors
    rollout_seq = []
    # Rollout
    tqdm_kwargs = dict(desc='Rollout', unit='tok') | (tqdm_kwargs or {})
    if n_rollout_tokens == 0:
        tqdm_kwargs |= dict(disable=True)
    for i in tqdm.tqdm(range(n_rollout_tokens), **tqdm_kwargs):
        if (i % tpp) == 0:
            patch_number = i // tpp
            if policy is None:
                # Use provided order
                next_token = idx[...,patch_number]  # 0D or [K]
            else:
                # Use callback to select the next patch
                policy_cache_mask = mask[..., :in_cache_size+maxT+i]
                idx_of_next_idx = policy(
                    idx=idx[idx_batch_rows,patch_number:] if B > 0 else idx[patch_number:],
                    pos=pos[:, i:],
                    weights=None if weights is None else weights[:, i:],
                    scatter_indices=scatter_indices,
                    kvcache=kvcache.shallow_copy(mask=policy_cache_mask, size=in_cache_size+maxT+i),
                    n_tokens_per_patch=n_tokens_per_patch,
                    sample_range=sample_range,
                    idx_pos=pos[:, i::tpp],
                    idx_weights=None if weights is None else weights[:, i::tpp],
                    predictor=self,
                    device=idx.device,
                )
                del policy_cache_mask
                # Move the patch patch_number+idx_of_next_idx to the next position by swapping
                # (idx is a private clone here, see preprocessing)
                if idx_of_next_idx != 0:
                    i1, i2 = patch_number, patch_number + int(idx_of_next_idx)
                    idx[[i1,i2]] = idx[[i2,i1]]
                    # [K N 4] -> [K I tpp 4] -swap-idxs-> [K I tpp 4] -> [K N 4]
                    pos = pos.reshape(K, I, tpp, 4)
                    pos[:,[i1,i2]] = pos[:,[i2,i1]]
                    pos = pos.reshape(K, -1, 4)
                next_token = idx[patch_number]
            rollout_seq.append(next_token[idx_batch_rows] if B > 0 else next_token)

        # Forward call
        k_cache, v_cache = kvcache.get_token_slice(0, in_cache_size+maxT+i+1)
        logits, _ = self.forward(
            seq=next_token.unsqueeze(-1).expand(K, 1),      # 0D or [K] -> [K 1]
            pos=pos[:, [i]],                                # [K 1 4]
            mask=mask[..., :in_cache_size+maxT+i+1],        # [K 1 1 n_prev+1]
            k_cache=k_cache,
            v_cache=v_cache,
            inplace_kv=True
        )
        logits = logits.squeeze(1)  # [K 1 V] -> [K V]

        # If next is an index token, we just needed to cache the previous token
        # so we can skip logits if we're not returning them
        if ((i + 1) % tpp) == 0 and not return_idx_logits:
            continue

        # Logit mixing: [K V] -> [1 V] or [B V]
        if weights is not None:
            logits.mul_(weights[:,[i]])  # weights: [K nnt] -> [K 1], logits: [K V]
        if scatter_indices is not None:
            new_logits = torch.zeros((B, logits.shape[-1]), dtype=dtype, device=device)
            si = scatter_indices.unsqueeze(-1).repeat(1, logits.shape[-1])  # [K V]
            new_logits.scatter_add_(0, si, logits)  # [B V]
            logits = new_logits
            del new_logits
        elif B == 0:
            logits = logits.sum(0, keepdim=True)  # [1 V]
        # else logits: [K V] = [B V]

        # If next is an index token, we just needed to cache the previous token
        if ((i + 1) % tpp) == 0:
            if return_idx_logits:
                all_logits[i].copy_(logits if B > 0 else logits.squeeze(0))
            continue

        if return_logits:
            # Without index logits, one slot per completed patch is skipped
            slot = i if return_idx_logits else i - i // tpp
            all_logits[slot].copy_(logits if B > 0 else logits.squeeze(0))

        # Sample from logits
        next_token = self.sample_logits(
            logits,
            temp=temp[i] if temp else None,
            post_temp=post_temp[i] if post_temp else None,
            top_k=top_k[i] if top_k else None,
            top_p=top_p[i] if top_p else None,
            min_p=min_p[i] if min_p else None,
            sample_range=sample_range,
            blacklist=blacklist
        )
        rollout_seq.append(next_token if B > 0 else next_token.squeeze())
        # Batched rollouts: We need to remap next_token from shape [B] to [K 1]
        # Note that if B>0 and scatter_indices is None, then B=K, as enforced during preprocessing
        if B > 0 and scatter_indices is not None:
            next_token = torch.gather(next_token, 0, scatter_indices)  # [B] -> [K]

    if n_rollout_tokens > 0:
        rollout_seq = torch.stack(rollout_seq)  # [nnt] or [nnt B]
        if B > 0:
            rollout_seq = rollout_seq.mT  # [nnt B] -> [B nnt]

    if n_rollout_tokens == nnt:
        ret = (rollout_seq,)
        if return_logits:
            # [n_logits B V] -> [B n_logits V], or [n_logits V]
            ret = (*ret, all_logits.transpose(0, 1)) if B > 0 else (*ret, all_logits)
        if return_kv:
            ret = (*ret, kvcache.shallow_copy(mask=mask, size=n_kvcache))
        return ret if len(ret) > 1 else ret[0]

    ############################
    # === Parallel Rollout === #
    ############################
    npp = n_par_patches
    npt = npp * tpp  # num parallel tokens
    n_seq_logits = n_seq_patches * logits_per_patch  # logit slots used by the sequential part
    # Keep the not-yet-generated patches. Slice the patch axis (last) so 2D [K I] idx becomes [K npp]
    idx = idx[..., -npp:]

    # Build attention mask for parallel part [K 1 npp n_past+npp]
    if unmask_parallel:
        par_mask = torch.zeros(1, dtype=mask.dtype, device=device).expand(K, 1, npp, npp)
    else:
        par_mask = torch.full((npp, npp), float('-inf'), dtype=mask.dtype, device=device)
        par_mask.fill_diagonal_(0.0)
        par_mask = par_mask.expand(K, 1, npp, npp)
    mask = mask.expand(K, 1, npp, -1)
    # Initial shape is [K 1 npp n_past]. Before each iter, we append par_mask [K 1 npp npp] along the last dim

    # Reshape positions for parallel passes
    # [K nnt 4] -trim-> [K npt 4] -> [K npp tpp 4] -> [K tpp npp 4] -> [K npt 4]
    pos = pos[:,-npt:].reshape(K, npp, tpp, 4).transpose(1, 2).reshape(K, npt, 4)
    # Reshape scheduled properties (transpose similarly to positions)
    # [nnt] -trim-> [npt] -> [npp tpp] -> [tpp npp] -> [npt]
    if temp is not None:
        temp = np.array(temp[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
    if post_temp is not None:
        post_temp = np.array(post_temp[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
    if top_k is not None:
        top_k = np.array(top_k[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
    if top_p is not None:
        top_p = np.array(top_p[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
    if min_p is not None:
        min_p = np.array(min_p[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
    if weights is not None:
        # [K nnt] -trim-> [K npt] -> [K npp tpp] -> [K tpp npp] -> [K npt]
        weights = weights[:,-npt:].reshape(K, npp, tpp).transpose(1, 2).reshape(K, npt)

    next_tokens = idx.expand(K, npp)  # [K npp] or [npp] -> [K npp]
    parallel_seq = [idx[idx_batch_rows]] if B > 0 else [idx]  # [B npp] or [npp]

    # Run parallel passes; pass i feeds within-patch position i of every parallel patch
    for i in range(tpp if run_last_parallel_tokens else (tpp - 1)):
        mask = torch.cat([mask, par_mask], dim=-1)
        parallel_slice = slice(i*npp, (i+1)*npp)
        k_cache, v_cache = kvcache.get_token_slice(0, in_cache_size+maxT+n_rollout_tokens+(i+1)*npp)
        logits, _ = self.forward(
            seq=next_tokens,               # [K npp]
            pos=pos[:,parallel_slice],     # [K npp 4]
            mask=mask,                     # [K 1 npp n_past+npp]
            k_cache=k_cache,
            v_cache=v_cache,
            inplace_kv=True
        )

        # Weighted sum of logits [K npp V] -> [1 npp V] or [B npp V]
        if weights is not None:
            w = weights[:,parallel_slice]  # [K npt] -> [K npp]
            logits.mul_(w.unsqueeze(-1))   # [K npp V]
        if scatter_indices is not None:
            new_logits = torch.zeros((B, npp, logits.shape[-1]), dtype=dtype, device=device)
            new_logits.scatter_add_(0, scatter_indices.reshape(-1,1,1).repeat(1, npp, logits.shape[-1]), logits)
            logits = new_logits  # [B npp V]
            del new_logits
        elif B == 0:
            logits = logits.sum(0, keepdim=True)  # [K npp V] -> [1 npp V]
        # else B==K, as enforced in preprocessing

        if return_logits and (return_idx_logits or i < tpp - 1):
            # Store the logits with a stride so we don't need to transpose later
            # (logits_per_patch is tpp, or tpp-1 without index logits)
            all_logits[n_seq_logits+i::logits_per_patch].copy_(
                logits.transpose(0,1) if B > 0 else logits.squeeze(0)  # [npp B V] or [npp V]
            )

        if i == (tpp - 1):
            # We just needed to compute logits and/or KV to return them; no need to predict
            break

        # Sample from logits
        # TODO: Index using parallel_slice instead of -1 to support scheduled parameters
        next_tokens = self.sample_logits(
            logits,
            temp=temp[-1] if temp else None,
            post_temp=post_temp[-1] if post_temp else None,
            top_k=top_k[-1] if top_k else None,
            top_p=top_p[-1] if top_p else None,
            min_p=min_p[-1] if min_p else None,
            sample_range=sample_range,
            blacklist=blacklist
        )
        parallel_seq.append(next_tokens if B > 0 else next_tokens.squeeze(0))
        # Remap next_tokens to [K npp] for the next forward pass
        # Note that if B>0 and scatter_indices is None, then B=K, as enforced during preprocessing
        if B > 0 and scatter_indices is not None:
            next_tokens = torch.gather(next_tokens, 0, scatter_indices.unsqueeze(-1).repeat(1,npp))  # [B npp] -> [K npp]
        elif B == 0:
            next_tokens = next_tokens.expand(K, npp)  # [1 npp] -> [K npp], as in the sequential branch

    # Each entry is [npp] or [B npp]: stack on a new last axis -> [npp tpp] / [B npp tpp],
    # then merge (npp, tpp) into patch-major order -> [npt] / [B npt]
    parallel_seq = torch.stack(parallel_seq, dim=-1).flatten(-2)

    if return_kv:
        # Transpose the parallel part of the cache so it's returned the same order as sequential rollouts
        kvcache.transpose_last(n_kvcache, npt, tpp)
        # [K 1 npp N] -> [K 1 1 N], clone to free memory, doesn't matter which row we take (only different in last npt cols)
        # Unmask the last npt cols (if necessary) to make the parallel part fully unmasked
        mask = mask[...,[-1],:].clone()
        if not unmask_parallel:
            mask[...,-npt:] = 0.0

    ret = (torch.cat([rollout_seq, parallel_seq], -1),) if n_rollout_tokens > 0 else (parallel_seq,)
    if return_logits:
        # [n_logits B V] -> [B n_logits V], or [n_logits V]
        ret = (*ret, all_logits.transpose(0, 1)) if B > 0 else (*ret, all_logits)
    if return_kv:
        ret = (*ret, KVCache(k=kvcache.k, v=kvcache.v, mask=mask, size=n_kvcache))
    return ret if len(ret) > 1 else ret[0]
def _unpatchify(labels: torch.Tensor, channels: int = 3) -> torch.Tensor:
    """
    Reassemble flattened square patches into images.

    labels: [B, N, P*C], where P = patch_area (p*p) and C = `channels` (default 3).
        The last axis of each patch is read as a (p, p, C) block (channel fastest), and the
        N patches form a square grid in row-major order (patch n -> row n // sqrt(N), col n % sqrt(N)).
    Returns [B, C, H, W] with H = W = sqrt(N) * p.

    Raises:
        AssertionError: if the last axis is not divisible by `channels`, or if N or P
            is not a perfect square.
    """
    B, N, patch_features = labels.shape
    C = channels
    assert C > 0 and patch_features % C == 0, \
        f"last dim ({patch_features}) must be divisible by channels ({C})"
    P = patch_features // C
    # isqrt is exact for integers, unlike int(math.sqrt(...))
    p = math.isqrt(P)
    side = math.isqrt(N)
    assert p * p == P and side * side == N, \
        f"expected square patches on a square grid, got patch_area={P}, num_patches={N}"
    # [B, N, P*C] -> [B, row, col, py, px, C] -> [B, C, row, py, col, px]
    rec = labels.view(B, side, side, p, p, C).permute(0, 5, 1, 3, 2, 4).contiguous()
    # Merge (row, py) -> H and (col, px) -> W
    img = rec.view(B, C, side * p, side * p)
    return img
def _unpatchify_logits(logits: torch.Tensor) -> torch.Tensor:
    """
    Rearrange per-patch logits into an image-shaped logit map.

    logits: [B, N, P, V], with N patches on a square grid (row-major) and P = p*p pixels per patch.
    Returns [B, V, H, W] with H = W = sqrt(N) * p, using the same view/permute layout as _unpatchify.

    Raises:
        AssertionError: if N or P is not a perfect square.
    """
    batch, n_patches, patch_area, vocab = logits.shape
    patch_side = int(math.sqrt(patch_area))
    grid = int(math.sqrt(n_patches))
    assert patch_side * patch_side == patch_area and grid * grid == n_patches
    # Split N -> (row, col) and P -> (py, px)
    tiles = logits.view(batch, grid, grid, patch_side, patch_side, vocab)
    # Bring vocab to axis 1 and interleave grid/patch axes: [B, V, row, py, col, px]
    tiles = tiles.permute(0, 5, 1, 3, 2, 4).contiguous()
    height = width = grid * patch_side
    return tiles.view(batch, vocab, height, width)