import math from typing import Tuple, Union, List, Dict, Optional, Callable import tqdm import numpy as np import torch import torch.nn as nn import torch.nn.functional as F def validate_psi2_config(config) -> None: """ Validate PSI2 config to catch misconfigurations early. Raises: AssertionError: If any validation fails with a descriptive message. """ if hasattr(config, 'pointer_token'): assert config.pointer_token < config.vocab_size, \ f"pointer_token ({config.pointer_token}) must be < vocab_size ({config.vocab_size})" token_range_attrs = ['rgb_range', 'flow_range', 'depth_range', 'campose_range', 'imagenet_cls_range'] for attr in token_range_attrs: if hasattr(config, attr): lo, hi = getattr(config, attr) assert 0 <= lo <= hi <= config.vocab_size, \ f"{attr} ({lo}, {hi}) out of bounds for vocab_size={config.vocab_size}" channel_range_attrs = [ 'pointer_channel_range', 'rgb_channel_range', 'campose_channel_range', 'flow_channel_range', 'depth_channel_range', 'cls_channel_range' ] for attr in channel_range_attrs: if hasattr(config, attr): lo, hi = getattr(config, attr) assert 0 <= lo < hi <= config.channel_size, \ f"{attr} ({lo}, {hi}) out of bounds for channel_size={config.channel_size}" assert 0.0 <= config.dropout <= 1.0, \ f"dropout ({config.dropout}) must be in [0.0, 1.0]" if hasattr(config, 'drop_path_rate'): assert 0.0 <= config.drop_path_rate <= 1.0, \ f"drop_path_rate ({config.drop_path_rate}) must be in [0.0, 1.0]" if hasattr(config, 'residual_scale'): assert config.residual_scale > 0, \ f"residual_scale ({config.residual_scale}) must be positive" if hasattr(config, 'n_kv_head') and config.n_kv_head is not None: assert config.n_kv_head > 0, \ f"n_kv_head ({config.n_kv_head}) must be positive (or None to disable GQA)" assert config.n_head % config.n_kv_head == 0, \ f"n_head ({config.n_head}) must be divisible by n_kv_head ({config.n_kv_head})" assert config.n_embd % config.n_head == 0, \ f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})" if hasattr(config, 'mlp_activation'): valid_activations = {'swiglu', 'gelu'} assert config.mlp_activation.lower() in valid_activations, \ f"mlp_activation ({config.mlp_activation}) must be one of {valid_activations}" if hasattr(config, 'attention_mask'): valid_masks = {'causal', 'none', 'full'} assert config.attention_mask.lower() in valid_masks, \ f"attention_mask ({config.attention_mask}) must be one of {valid_masks}" if hasattr(config, 'context_parallel_size'): assert int(config.context_parallel_size) >= 1, \ f"context_parallel_size ({config.context_parallel_size}) must be >= 1" if hasattr(config, 'load_balance_cp') and config.load_balance_cp: assert getattr(config, 'context_parallel_size', 1) > 1, \ "load_balance_cp requires context_parallel_size > 1" def _to_additive_mask( mask: Optional[torch.Tensor], *, L: int, S: int, device: torch.device, dtype: torch.dtype, ) -> Optional[torch.Tensor]: """ Convert a user-provided mask to an additive mask with 0 for allowed and -inf for blocked. Accepts: - None - boolean mask (True = blocked) - float mask already in additive form """ if mask is None: return None if mask.dtype == torch.bool: add = torch.zeros_like(mask, dtype=dtype, device=mask.device) add = add.masked_fill(mask, float('-inf')) return add return mask.to(dtype) if mask.dtype != dtype else mask def _build_causal_additive_mask( L: int, S: int, *, device: torch.device, dtype: torch.dtype, past_k: int = 0, ) -> torch.Tensor: """ Additive causal mask for rectangular attention where query length is L and key length is S. A query at row i may attend up to index (past_k + i). Positions j > past_k + i are masked. Returns [1,1,L,S]. """ i = torch.arange(L, device=device).unsqueeze(-1) j = torch.arange(S, device=device).unsqueeze(0) blocked = j > (past_k + i) add = torch.zeros((L, S), device=device, dtype=dtype) add = add.masked_fill(blocked, float('-inf')) return add.view(1, 1, L, S) class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5): super().__init__() self.weight = nn.Parameter(torch.ones(dim)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: rms = x.float().pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt().to(x.dtype) return self.weight * (x * rms) class Rotary3D(nn.Module): """ 3D Rotary embedding for per-head dims. Splits head_dim across (x,y,t) and applies rotations. """ def __init__(self, head_dim: int, base: int = 100): super().__init__() assert head_dim % 2 == 0, "head_dim must be even." x_dim = (head_dim * 6) // 16 y_dim = (head_dim * 6) // 16 t_dim = head_dim - x_dim - y_dim for d in (x_dim, y_dim, t_dim): if d % 2 != 0: raise AssertionError("Each axis dim must be even for RoPE.") self.base = base self.x_dim = x_dim self.y_dim = y_dim self.t_dim = t_dim self.register_buffer("inv_freq_x", self._inv_freqs(self.x_dim), persistent=False) self.register_buffer("inv_freq_y", self._inv_freqs(self.y_dim), persistent=False) self.register_buffer("inv_freq_t", self._inv_freqs(self.t_dim), persistent=False) def initialize_buffers(self): '''Called in Checkpointing.instantiate_model when weight init is skipped''' self.inv_freq_x = self._inv_freqs(self.x_dim) self.inv_freq_y = self._inv_freqs(self.y_dim) self.inv_freq_t = self._inv_freqs(self.t_dim) def _inv_freqs(self, dim: int): return 1.0 / (self.base ** (torch.arange(0, dim, 2).float() / dim)) def _axis_rot(self, x_half1, x_half2, cos, sin): return torch.cat([x_half1 * cos - x_half2 * sin, x_half1 * sin + x_half2 * cos], dim=-1) def forward(self, x: torch.Tensor, pos_xyz: torch.Tensor) -> torch.Tensor: """ x: [B, nH, T, H] pos_xyz: [B, T, 3] int/float (x, y, t) """ B, nH, T, H = x.shape assert pos_xyz.shape == (B, T, 3), f"pos must be [B, T, 3], got {pos_xyz.shape}" # Move inv_freq buffers to input device if needed (FSDP doesn't move buffers) if self.inv_freq_x.device != x.device: self.inv_freq_x = self.inv_freq_x.to(x.device) self.inv_freq_y = self.inv_freq_y.to(x.device) self.inv_freq_t = self.inv_freq_t.to(x.device) px = pos_xyz[..., 0].to(self.inv_freq_x.dtype) py = pos_xyz[..., 1].to(self.inv_freq_y.dtype) pt = pos_xyz[..., 2].to(self.inv_freq_t.dtype) fx = torch.einsum("bt,f->btf", px, self.inv_freq_x) fy = torch.einsum("bt,f->btf", py, self.inv_freq_y) ft = torch.einsum("bt,f->btf", pt, self.inv_freq_t) cx, sx = fx.cos().unsqueeze(1), fx.sin().unsqueeze(1) cy, sy = fy.cos().unsqueeze(1), fy.sin().unsqueeze(1) ct, st = ft.cos().unsqueeze(1), ft.sin().unsqueeze(1) half = H // 2 x1, x2 = x[..., :half], x[..., half:] def split_axis(tensor, dims): a, b, c = dims a1 = a // 2; b1 = b // 2; c1 = c // 2 s0 = tensor[..., :a1] s1 = tensor[..., a1:a1 + b1] s2 = tensor[..., a1 + b1:a1 + b1 + c1] return s0, s1, s2 x1x, x1y, x1t = split_axis(x1, (self.x_dim, self.y_dim, self.t_dim)) x2x, x2y, x2t = split_axis(x2, (self.x_dim, self.y_dim, self.t_dim)) rx = self._axis_rot(x1x, x2x, cx, sx) ry = self._axis_rot(x1y, x2y, cy, sy) rt = self._axis_rot(x1t, x2t, ct, st) rotated = torch.cat([rx, ry, rt], dim=-1) return rotated class DropPath(nn.Module): """Per-sample DropPath, broadcast across non-batch dims.""" def __init__(self, drop_prob: float = 0.0): super().__init__() self.drop_prob = float(drop_prob) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.drop_prob == 0.0 or not self.training: return x keep = 1.0 - self.drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) mask = x.new_empty(shape).bernoulli_(keep) return x * (mask / keep) class MLP(nn.Module): def __init__(self, n_embd: int, mlp_hidden_size: Optional[int] = None, *, bias: bool = False, dropout: float = 0.0, activation: str = "swiglu"): super().__init__() inner = mlp_hidden_size if mlp_hidden_size is not None else 4 * n_embd self.activation = activation.lower() if self.activation == "swiglu": self.c_fc = nn.Linear(n_embd, 2 * inner, bias=bias) self.c_proj = nn.Linear(inner, n_embd, bias=bias) elif self.activation == "gelu": self.c_fc = nn.Linear(n_embd, inner, bias=bias) self.c_proj = nn.Linear(inner, n_embd, bias=bias) else: raise ValueError(f"Unsupported mlp_activation: {activation}") self.drop = nn.Dropout(dropout) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.activation == "swiglu": u, v = self.c_fc(x).chunk(2, dim=-1) x = F.silu(v) * u else: x = F.gelu(self.c_fc(x)) x = self.c_proj(x) x = self.drop(x) return x class PSIAttentionLayer(nn.Module): def __init__(self, config): super().__init__() assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head" self.n_head = config.n_head self.n_kv_head = getattr(config, "n_kv_head", None) or self.n_head assert self.n_head % self.n_kv_head == 0, "n_head must be a multiple of n_kv_head" self.head_dim = config.n_embd // self.n_head self.dropout = config.dropout self.bias = getattr(config, "bias", False) self.is_causal = (getattr(config, "attention_mask", "causal") == "causal") self.q_proj = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=self.bias) self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=self.bias) self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=self.bias) self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=self.bias) self.attn_dropout = nn.Dropout(self.dropout) self.resid_dropout = nn.Dropout(self.dropout) self.rope = Rotary3D(self.head_dim) self._use_sdpa = hasattr(F, "scaled_dot_product_attention") @torch.compiler.disable def emplace_kv(self, T, k_cache, v_cache, k, v): # torch.compile doesn't play well with this op (5x slowdown) # so we insert a graph break and copy eagerly k_cache[:,:,-T:].copy_(k) v_cache[:,:,-T:].copy_(v) return k_cache, v_cache def forward( self, x: torch.Tensor, pos_xyz: torch.Tensor, *, k_cache: Optional[torch.Tensor] = None, v_cache: Optional[torch.Tensor] = None, return_kv: bool = False, inplace_kv: bool = False, mask: Optional[torch.Tensor] = None ): B, T, C = x.shape H, HKV, D = self.n_head, self.n_kv_head, self.head_dim groups = H // HKV q = self.q_proj(x).view(B, T, H, D).transpose(1, 2) k = self.k_proj(x).view(B, T, HKV, D).transpose(1, 2) v = self.v_proj(x).view(B, T, HKV, D).transpose(1, 2) q = self.rope(q, pos_xyz) k = self.rope(k, pos_xyz) past_k = k_cache.shape[2] if (k_cache is not None) else 0 if inplace_kv and (k_cache is not None) and (v_cache is not None): k, v = self.emplace_kv(T, k_cache, v_cache, k, v) else: if k_cache is not None: k = torch.cat([k_cache, k], dim=2) if v_cache is not None: v = torch.cat([v_cache, v], dim=2) S = k.shape[2] # Check if we're on XLA device at runtime - more reliable than init-time check # The init-time _tpu_available() can return False if called before XLA is initialized is_xla = 'xla' in str(x.device) expand_kv_heads = HKV != H if expand_kv_heads: # Note that despite using expand + reshape, this still technically copies the # tensor because it is impossible to create a strided view of repeated data. # This is equivalent to repeat_interleave. # We use this even with F.scaled_dot_product_attention because the enable_gqa arg # is slower and uses more memory. This may change with later torch versions. k = k.unsqueeze(2).expand(B, HKV, groups, S, D).reshape(B, H, S, D) v = v.unsqueeze(2).expand(B, HKV, groups, S, D).reshape(B, H, S, D) add_mask = _to_additive_mask(mask, L=T, S=S, device=x.device, dtype=q.dtype) if add_mask is None and self.is_causal: add_mask = _build_causal_additive_mask(T, S, device=x.device, dtype=q.dtype, past_k=past_k) if self._use_sdpa and not is_xla: y = F.scaled_dot_product_attention( q, k, v, attn_mask=add_mask, dropout_p=self.dropout if self.training else 0.0, ) else: att = torch.einsum("bhtd,bhsd->bhts", q, k) * (1.0 / math.sqrt(D)) if add_mask is not None: att = att + add_mask att = F.softmax(att, dim=-1, dtype=torch.float32).to(q.dtype) att = self.attn_dropout(att) y = torch.einsum("bhts,bhsd->bhtd", att, v) y = y.transpose(1, 2).contiguous().view(B, T, C) y = self.resid_dropout(self.out_proj(y)) if return_kv: if expand_kv_heads: return y, k[:, ::groups], v[:, ::groups] return y, k, v return y class Block(nn.Module): def __init__(self, config, layer_idx: int, total_layers: int): super().__init__() self.residual_scale = float(getattr(config, "residual_scale", 1.0)) drop_path_rate = float(getattr(config, "drop_path_rate", 0.0)) layer_dp = drop_path_rate * (layer_idx / max(1, total_layers - 1)) self.ln_1 = RMSNorm(config.n_embd) self.attn = PSIAttentionLayer(config) self.drop_path_1 = DropPath(layer_dp) self.ln_2 = RMSNorm(config.n_embd) self.mlp = MLP( config.n_embd, getattr(config, "mlp_hidden_size", None), bias=getattr(config, "bias", False), dropout=config.dropout, activation=getattr(config, "mlp_activation", "swiglu"), ) self.drop_path_2 = DropPath(layer_dp) def forward( self, x: torch.Tensor, pos: torch.Tensor, *, k_cache: Optional[torch.Tensor] = None, v_cache: Optional[torch.Tensor] = None, return_kv: bool = False, inplace_kv: bool = False, mask: Optional[torch.Tensor] = None, ): if return_kv: a, k, v = self.attn(self.ln_1(x), pos, k_cache=k_cache, v_cache=v_cache, return_kv=True, inplace_kv=inplace_kv, mask=mask) x = x + self.drop_path_1(a) * self.residual_scale m = self.mlp(self.ln_2(x)) x = x + self.drop_path_2(m) * self.residual_scale return x, k, v else: a = self.attn(self.ln_1(x), pos, k_cache=k_cache, v_cache=v_cache, return_kv=False, inplace_kv=inplace_kv, mask=mask) x = x + self.drop_path_1(a) * self.residual_scale m = self.mlp(self.ln_2(x)) x = x + self.drop_path_2(m) * self.residual_scale return x class KVCache: @torch.no_grad() def __init__(self, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor], size: int ): self.k = k self.v = v self.mask = mask self.size = size @property def capacity(self): return self.k.shape[-2] def shallow_copy(self, mask: Optional[torch.Tensor], size: int ) -> 'KVCache': return KVCache(k=self.k, v=self.v, mask=mask, size=size) @torch.no_grad() @staticmethod def allocate( model: 'PSI2', capacity: int, batch_size: int = 1, dtype=None, device=None ) -> 'KVCache': ''' Allocates a standard KVCache. You'll usually want to use `PSI2.allocate_kvcache()` instead. ''' n_kv = getattr(model.config, "n_kv_head", None) or model.config.n_head head_dim = model.config.n_embd // model.config.n_head kv_shape = ( model.config.n_layer, batch_size, n_kv, capacity, head_dim ) return KVCache( k=torch.empty(kv_shape, dtype=dtype, device=device), v=torch.empty(kv_shape, dtype=dtype, device=device), mask=None, size=0 ) def validate(self, batch_size: int): assert self.k.shape == self.v.shape, \ f'KV shape mismatch: k={self.k.shape}, v={self.v.shape}' assert self.k.ndim == 5 assert self.k.shape[1] == batch_size assert self.k.device == self.v.device assert self.k.dtype == self.v.dtype if self.mask is not None: assert self.mask.ndim == 4 and self.mask.shape[1:3] == (1,1) assert self.mask.shape[-1] == self.size assert self.size <= self.k.shape[-2] def get_token_slice(self, start: int, stop: Optional[int] ) -> Tuple[torch.Tensor, torch.Tensor]: return ( self.k[...,start:stop,:], self.v[...,start:stop,:] ) def copy_from(self, kvcache: 'KVCache'): assert self.capacity >= kvcache.size self.k[...,:kvcache.size,:].copy_(kvcache.k[...,:kvcache.size,:], non_blocking=True) self.v[...,:kvcache.size,:].copy_(kvcache.v[...,:kvcache.size,:], non_blocking=True) self.mask = kvcache.mask self.size = kvcache.size def transpose_last(self, end: int, count: int, tpp: int): # npt=tokens, npp=patches # [... npt E] -> [... tpp npp E] -> [... npp tpp E] -> [... npt E] edims = self.k.shape[:-2] par_k = self.k[...,end-count:end,:].reshape(*edims, tpp, count//tpp, -1) par_v = self.v[...,end-count:end,:].reshape(*edims, tpp, count//tpp, -1) par_k = par_k.transpose(-2, -3).reshape(*edims, count, -1) par_v = par_v.transpose(-2, -3).reshape(*edims, count, -1) self.k[...,end-count:end,:].copy_(par_k.clone()) self.v[...,end-count:end,:].copy_(par_v.clone()) class PSI2(nn.Module): """ PSI2 with training-time features: - GQA (n_kv_head < n_head) - SwiGLU MLP - Stochastic depth (DropPath) - Residual scaling - Robust weight tying (works with XLA/FSDP, saves memory) """ def __init__(self, config, verbose=True): super().__init__() self.config = config if verbose: print("[PSI2] Using config:", config, flush=True) validate_psi2_config(config) self.context_parallel_size = int(getattr(config, "context_parallel_size", 1)) self.tie_weights = getattr(config, "tie_weights", False) self.n_lm_vocab = getattr(config, "n_lm_vocab", None) or config.vocab_size if self.tie_weights and self.n_lm_vocab != config.vocab_size: raise ValueError( f"Cannot tie weights when n_lm_vocab ({self.n_lm_vocab}) != vocab_size ({config.vocab_size})" ) token_embedding = nn.Embedding(config.vocab_size, config.n_embd) self.transformer = nn.ModuleDict(dict( token_embedding=token_embedding, channel_embedding=nn.Embedding(config.channel_size, config.n_embd), drop=nn.Dropout(config.dropout), h=nn.ModuleList([Block(config, i, config.n_layer) for i in range(config.n_layer)]), ln_f=RMSNorm(config.n_embd), )) if self.tie_weights: self.lm_head = None if getattr(config, "lm_head_bias", False): self.lm_head_bias = nn.Parameter(torch.zeros(self.n_lm_vocab)) else: self.register_parameter('lm_head_bias', None) if verbose: print(f"[PSI2] Weight tying enabled: using token_embedding weights for lm_head " f"(saving {config.vocab_size * config.n_embd:,} parameters)") else: self.lm_head = nn.Linear(config.n_embd, self.n_lm_vocab, bias=False) self.lm_head_bias = None self.apply(self._init_weights) for pn, p in self.named_parameters(): if pn.endswith("out_proj.weight"): torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)) self.unsharded_param_count = self.get_num_params() # Store the "effective" param count for compute purposes # When weights are tied, the lm_head matmul still happens, so we count it for FLOPs self.compute_param_count = self._get_compute_params() @property def device(self) -> torch.device: return self.transformer.token_embedding.weight.device @property def dtype(self) -> torch.device: return self.transformer.token_embedding.weight.dtype def _init_weights(self, module: nn.Module): if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) def _get_compute_params(self) -> int: """ Get effective parameter count for compute/FLOPs estimation. When weights are tied, the embedding is used twice (embedding + lm_head), so we count it twice for compute purposes even though it's stored once. """ base_params = sum(p.numel() for p in self.parameters()) if self.tie_weights: # Add the embedding size again since it's used for lm_head computation tied_weight_size = self.config.vocab_size * self.config.n_embd return base_params + tied_weight_size return base_params def get_num_params(self, non_embedding: bool = True) -> int: """Return the number of stored model parameters. The ``non_embedding`` argument is accepted for compatibility with the training implementation; the compact inference runtime reports the full stored parameter count. """ return sum(p.numel() for p in self.parameters()) def _lm_head_forward(self, x: torch.Tensor) -> torch.Tensor: """ Compute lm_head output, using embedding weights if tie_weights is enabled. """ if self.tie_weights: logits = F.linear(x, self.transformer.token_embedding.weight, self.lm_head_bias) else: logits = self.lm_head(x) return logits def forward( self, seq: torch.Tensor, pos: torch.Tensor, tgt: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None, k_cache: Optional[torch.Tensor] = None, v_cache: Optional[torch.Tensor] = None, return_kv: bool = False, inplace_kv: bool = False, output_hidden_states: bool = False, ): """ Returns: if tgt is None: (logits, None) or (logits, (k_cache, v_cache)...) depending on flags else: (logits, loss) """ st_pos = pos[:, :, :-1] channel_pos = pos[:, :, -1] tok_emb = self.transformer.token_embedding(seq) ch_emb = self.transformer.channel_embedding(channel_pos) x = self.transformer.drop(tok_emb + ch_emb) if output_hidden_states: hidden_states = [x] k_list, v_list = [], [] for i, block in enumerate(self.transformer.h): x = block( x, pos=st_pos, k_cache=None if k_cache is None else k_cache[i], v_cache=None if v_cache is None else v_cache[i], return_kv=return_kv, inplace_kv=inplace_kv, mask=mask, ) if return_kv: x, k, v = x k_list.append(k) v_list.append(v) if output_hidden_states: hidden_states.append(x) x = self.transformer.ln_f(x) if output_hidden_states: hidden_states.append(x) if tgt is None: logits = self._lm_head_forward(x) if output_hidden_states: logits = {"logits": logits, "hidden_states": hidden_states} if return_kv: if inplace_kv: return logits, k_cache, v_cache else: return logits, torch.stack(k_list), torch.stack(v_list) return logits, None logits = self._lm_head_forward(x[:, -tgt.size(1):]) flat_logits = logits.reshape(-1, logits.size(-1)) flat_tgt = tgt.reshape(-1) # Context parallelism should not change the training objective itself: # we still optimize the standard token-mean cross-entropy on the logits # produced by the current forward pass. Any CP-specific correctness work # belongs in the attention plumbing, not in ad hoc loss rescaling here. loss = F.cross_entropy( flat_logits, flat_tgt, ignore_index=-1 ) if output_hidden_states: logits = {"logits": logits, "hidden_states": hidden_states} return logits, loss def unpack_and_sort_img_seq(self, img_seq: torch.Tensor, num_revealed_patches: int = 0, mark_revealed_patches: Optional[int] = None, logits: Optional[torch.Tensor] = None, unpatchify_seq: bool = True): """ Unpacks a sequence of patch-index + RGB tokens into an image. Assumes first token per patch is the index. """ n_tokens_per_patch = self.config.patch_size ** 2 + 1 B = img_seq.size(0) img_seq = img_seq.view(B, -1, n_tokens_per_patch) img_idxs = (img_seq[:, :, 0].long() - (65536 + 256)) reconstruct_indxs = torch.argsort(img_idxs, dim=1) rgb_seq = img_seq[:, :, 1:] if mark_revealed_patches is not None: rgb_seq[:, :num_revealed_patches] = mark_revealed_patches rgb_seq = rgb_seq[torch.arange(B).unsqueeze(1).expand_as(img_idxs), reconstruct_indxs] if unpatchify_seq: img = _unpatchify(rgb_seq) else: h = w = int(math.sqrt(rgb_seq.size(1))) img = rgb_seq.reshape(B, h, w, -1) if logits is not None: logits = logits.view(B, -1, n_tokens_per_patch, logits.size(-1)) logits = logits[:, :, 1:] logits = logits[torch.arange(B).unsqueeze(1).expand_as(img_idxs), reconstruct_indxs] logits = _unpatchify_logits(logits) return img, logits return img @torch.no_grad() def sample_logits(self, logits: torch.FloatTensor, temp: Optional[float] = None, post_temp: Optional[float] = None, top_k: Optional[int] = None, top_p: Optional[float] = None, min_p: Optional[float] = None, sample_range: Optional[Tuple[int,int]] = None, blacklist: Optional[Union[List[int], torch.LongTensor]] = None, content_whitelist: Optional[torch.LongTensor] = None ) -> torch.LongTensor: """ Samples an integer from the distribution of logits Parameters: logits (torch.FloatTensor): The logits of the distribution temp (float): The temperature of the sampling, if 0.0, then argmax is used top_k (int): The number of top k tokens to consider during sampling top_p (float): The cumulative probability threshold for nucleus (top-p) sampling min_p (float): The minimum probability threshold factor for min-p sampling blacklist (Union[List[int], torch.LongTensor]): The list of tokens to blacklist during sampling content_whitelist (torch.LongTensor): The list of tokens to whitelist during sampling (mutually exclusive with blacklist) Returns: torch.LongTensor: The sampled integers """ if isinstance(temp, list): temp = temp[0] if isinstance(post_temp, list): post_temp = post_temp[0] if isinstance(top_k, list): top_k = top_k[0] if isinstance(top_p, list): top_p = top_p[0] assert temp is None or temp >= 0.0 assert post_temp is None or post_temp >= 0.0 assert top_k is None or top_k > 0 assert top_p is None or top_p >= 0.0 assert min_p is None or 0.0 <= min_p <= 1.0 assert sample_range is None or ( sample_range[0] < sample_range[1] and sample_range[0] >= 0 and sample_range[1] <= logits.shape[-1] ) assert blacklist is None or content_whitelist is None, "blacklist and content_whitelist cannot be used together" # Apply blacklist & sample range if blacklist is not None: logits[...,blacklist] = float('-inf') if content_whitelist is not None: # Create a mask that allows only whitelisted tokens whitelist_mask = torch.full_like(logits, float('-inf')) whitelist_mask[..., content_whitelist] = 0.0 logits = logits + whitelist_mask if sample_range is not None: logits = logits[...,sample_range[0]:sample_range[1]] token_offset = sample_range[0] else: token_offset = 0 # Apply temperature, or use argmax if 0.0 if (temp is not None and temp == 0.0) or (post_temp is not None and post_temp == 0.0): return token_offset + torch.argmax(logits, dim=-1) if temp is not None and temp != 1.0: logits.div_(temp) # Sort the logits once. More efficient when using top-k and top-p together (min-p doesn't require sorting). # We sample in sorted order then re-order before returning. if top_k is not None or top_p is not None: logits, order = torch.sort(logits, dim=-1, descending=True) else: order = None # Don't sort # Apply top-k filtering if specified if top_k is not None: logits = logits[...,:top_k] # Apply top-p (nucleus) filtering if specified if top_p is not None: probs = F.softmax(logits, dim=-1) cumulative_probs = probs.cumsum_(dim=-1) idxs_to_remove = cumulative_probs > top_p # Shift the mask right to keep at least one token logits[...,1:][idxs_to_remove[...,:-1]] = float('-inf') del probs, cumulative_probs, idxs_to_remove # Apply min-p filtering if specified if min_p is not None: probs = F.softmax(logits, dim=-1) maxprob = probs[...,[0]] if order is not None else torch.max(probs, dim=-1, keepdim=True).values logits[probs < maxprob * min_p] = float('-inf') del probs, maxprob # Apply optional post-temperature if post_temp is not None and post_temp != 1.0: logits.div_(post_temp) # Compute softmax probabilities orig_shape = logits.shape probs = torch.softmax(logits, dim=-1, out=logits) # Flatten probabilities to (batch_size * sequence_length, vocab_size) for sampling flat_probs = probs.view(-1, probs.size(-1)) sampled = torch.multinomial(flat_probs, num_samples=1) sampled = sampled.view(*orig_shape[:-1]) # If we sorted, unsort to collect the actual token values if order is not None: sampled = torch.gather(order, dim=-1, index=sampled.unsqueeze(-1)).squeeze(-1) return token_offset + sampled @torch.no_grad() def allocate_kvcache(self, n_tokens: int, batch_size: int = 1, dtype=None, device=None): assert n_tokens > 0 assert batch_size > 0 dtype = dtype or self.transformer.token_embedding.weight.dtype device = device or self.device return KVCache.allocate(self, n_tokens, batch_size, dtype=dtype, device=device) @torch.no_grad() def rollout_patches(self, seq: Union[Optional[torch.LongTensor], List[Optional[torch.LongTensor]]], pos: Union[torch.LongTensor, List[torch.LongTensor]], idx: Union[torch.LongTensor, List[torch.LongTensor]], n_tokens_per_patch: int = 5, n_seq_patches: int = -1, weights: Optional[Union[List[float], torch.Tensor]] = None, scatter_indices: Optional[Union[List[int], torch.LongTensor]] = None, kvcache: Optional[KVCache] = None, policy: Callable[..., torch.LongTensor] = None, *, unmask_parallel: bool = False, return_logits: bool = False, return_idx_logits: bool = True, return_kv: bool = False, temp: Optional[Union[float, List[float]]] = None, post_temp: Optional[Union[float, List[float]]] = None, top_k: Optional[Union[int, List[int]]] = None, top_p: Optional[Union[float, List[float]]] = None, min_p: Optional[Union[float, List[float]]] = None, sample_range: Optional[Tuple[int, int]] = None, blacklist: Optional[Union[List[int], torch.LongTensor]] = None, tqdm_kwargs: dict = {}, print_and_visualize_tokens: bool = False, out_dir: Optional[str] = "debug", ) -> Union[ torch.LongTensor, # seq Tuple[torch.LongTensor, torch.Tensor], # seq, logits Tuple[torch.LongTensor, KVCache], # seq, kvcache Tuple[torch.LongTensor, torch.Tensor, KVCache], # seq, logits, kvcache ]: """ K = number of given sequences (1 if seq is not a list) T = length of a conditioning sequence (per sequence) N = total length of conditioning + generated tokens (per sequence) I = number of index tokens to roll out B = number of batched rollouts to make (0 if not using batched rollouts) max(T) = maximum of T across sequences num_new_tokens = len(idx) * n_tokens_per_patch ***Logit Mixing***: To use logit mixing, supply multiple sequences/positions and a `weights` list/tensor. Before sampling tokens, the logits from each sequence will be mixed using the provided weights. ***Batched Rollouts***: To run batched rollouts *without* logit mixing, supply multiple (K) sequences/positions and leave `weights=None`. The returned sequences/logits will include an extra batch dimension of size K. ***Batched Rollouts + Logit Mixing***: To run batched rollouts and logit mixing at the same time, supply multiple (K) sequences/positions, a `weights` list, and `scatter_indices`. `scatter_indices` indicates which output sequence each input sequence corresponds to. For example, to run 3 cfg rollouts, you may supply K=6 sequences, with `weights=[cfg, 1-cfg, cfg, 1-cfg, cfg, 1-cfg]` and `scatter_indices=[0,0,1,1,2,2]`. **WARNING:** With batched rollouts and logit mixing, and logits may not be numerically deterministic on CUDA (see torch.Tensor.scatter_add_ docs.) ***Tips for long rollouts***: 1. Use `gc.collect()` and then `torch.cuda.empty_cache()` (in that order) 2. Avoid fragmentation wherever possible. This method needs to allocate the entire KV cache contiguously in memory. If you create tensors between rollouts, consider moving them to CPU or cloning them to defragment VRAM. 3. Even if a single N-token rollout fits in VRAM, running two consecutive rollouts (e.g. running n1 then giving its KV cache to n2, with n1 + n2 = N) may not fit, because reallocating the KV cache will duplicate memory. To avoid this, try moving the cache to CPU first: ``` seq1, kvcache = predictor.rollout_patches(..., return_kv=True) kvcache = { k: v.cpu() for k, v in kvcache.items() } gc.collect(); torch.cuda.empty_cache() seq2 = predictor.rollout_patches(..., **kvcache) ``` With this, the new KV cache will be allocated on GPU, and the CPU cache will be copied into it. TODO: Support temp/top_k/top_p/min_p scheduling with parallel. Currently uses index -1 for all parallel tokens Parameters: seq (Union[Optional[torch.LongTensor], List[Optional[torch.LongTensor]]]): [T], [K T], or list of [T] / sequence(s) to condition the generation on. None means empty sequence pos (Union[torch.LongTensor, List[torch.LongTensor]]): [N 4], [K N 4], or list of [N 4] / 4D position(s) corresponding to each sequence. If multiple, must have length K=len(seq) idx (Union[torch.LongTensor. List[torch.LongTensor]]): [I], [B I], or list of [I] / the patch indices to use in the rollout. 1D by default, used for all sequences. If using batched rollouts, idx may be 1D or 2D, and the second dimension must be equal to the number of batched rollouts (if scatter_indices is not None, this may be less than K). n_tokens_per_patch (int): number of tokens per patch, including patch index n_seq_patches (int): number of patches to roll out sequentially (-1 for all). The remaining patches will be parallel weights (Optional[Union[List[float], torch.Tensor]]): float weights for the logits produced by each sequence, used for logit mixing. If None and multiple sequences are given, batched rollouts will be used instead of logit mixing. If a list or 1D tensor, must have size K. If 2D, must have shape [K num_new_tokens]. If 2D, the first n_seq_patches patches will use the weights in order as expected, but the remaining (parallel) patches may be arbitrarily reordered. When using parallel and a 2D weight schedule, it is recommended to make the weights for parallel patches uniform for consistency scatter_indices (Optional[Union[List[int], torch.LongTensor]]): integer scatter indices for logit mixing, indicating which generated sequence each input sequence is used for. If provided, multiple sequences will be generated and returned; see above for details. Must have exactly K elements, one for each input sequence. Indices must be set-contiguous in the range [0,K) and must include 0, e.g. for K=5, [1,0,3,2,2] are valid indices but [1,3,4,4,3] are not. kvcache (Optional[KVCache]): optional KV cache for this rollout. If kvcache.size > 0, the tokens in the cache will be used as conditioning for this rollout. If the cache has enough empty space to cache this rollout, the KV tensors will be modified in-place. Note that only the tensor data is modified, not the other entries of the KVCache (e.g. size, mask). To receive an update cache object, set return_kv=True and pass the returned cache into downstream rollouts. If no cache is provided, or if the cache is too small for this rollout, a new cache will be allocated and the tensor data from this cache will be copied. policy (Callable[..., torch.LongTensor]): optional callback defining the policy for rollout order. Must accept an argument `idx` (torch.LongTensor of shape [I]) with the candidate indices to generate next. Must return the *index* into the `idx` tensor to generate next, either an int or 0-dimensional torch.LongTensor. For example, given candidate indices [4,12,1023], return 2 to generate the patch with index 1023 next. Only used for the sequential part of the generation. The following kwargs are given - `idx` (torch.LongTensor of shape [I] or [B I]) the candidate patch indices - `pos` (torch.LongTensor of shape [K N 4]): the remaining poses for all yet-ungenerated tokens in the same order as `idx` - `weights` (torch.Tensor of shape [K N]): the weights for all yet-ungenerated tokens, or None - `scatter_indices` (torch.LongTensor of shape [K]) - `kvcache` (KVCache): the cache for this rollout, where kvcache.size includes the tokens generated up to this point - `n_tokens_per_patch` (int) - `sample_range` (Optional[Tuple[int,int]]) - `idx_pos` (torch.LongTensor of shape [K I 4]): same as `pos`, but only for the candidate index tokens - `idx_weights` (torch.Tensor of shape [K I]): same as `weights`, but only for the candidate index tokens - `predictor` (PSIPredictor) self - `device` (torch.device)\n The callback must return the following - (Union[int, torch.LongTensor]): the *index* into the `idx` tensor with the patch token to generate next (*not* the value of the patch index itself) unmask_parallel (bool): if True, all parallel patches can attend to each other. If False (default), parallel patches can only attend to themselves return_logits (bool): return the logits of the sequence return_idx_logits (bool): if True (default), returns logits that would predict index tokens, so there is one set of logits for every returned token. If False, only returns logits used to sample content tokens, e.g. returns (n_tokens_per_patch - 1) sets of logits per patch. The latter may not need to compute logits for all tokens, so it may be more efficient for some computations (such as patchwise entropy). Ignored if return_logits=False return_kv (bool): return the KV cache(s) as a dict with keys 'k_cache', 'v_cache', 'cache_mask', and 'cache_size', useful for downstream operations. If True and return_logits=False, returns (new_tokens, kvcache). If True and return_logits=True, returns (new_tokens, logits, kvcache). **All KVs are returned in patch-major order,** even if the rollout is partially or fully parallel. Note that KVs from parallel prediction are not computed causally Returns: torch.LongTensor: [num_new_tokens] or [B num_new_tokens] the generated tokens only (the input sequence is not prepended). 1D by default, or 2D if using batched rollouts torch.Tensor: (optional) [n_tokens vocab_size] or [B n_tokens vocab_size] the logits of the sequence, where n_tokens depends on return_idx_logits. 2D by default, or 3D if using batched rollouts Dict[str, torch.Tensor]: (optional) the KV cache, with the following key/value pairs - `k_cache` (torch.Tensor) [n_layer K n_head n_tok n_embd//n_head] - `v_cache` (torch.Tensor) [n_layer K n_head n_tok n_embd//n_head] - `cache_mask` (torch.Tensor) [K 1 1 n_tok] - `cache_size` (int) """ if print_and_visualize_tokens: raise NotImplementedError("Token visualization is not included in the Hugging Face inference build.") ######################### # === Preprocessing === # ######################### if not isinstance(seq, list): seq = [seq] if seq is None or seq.ndim == 1 else list(seq) if not isinstance(pos, list): pos = [pos] if pos.ndim == 2 else list(pos) if isinstance(idx, list): idx = torch.stack(idx) nnt = idx.shape[-1] * n_tokens_per_patch idtype = pos[0].dtype dtype = self.dtype device = self.device if weights is not None: if isinstance(weights, list): weights = torch.tensor(weights, dtype=dtype, device=device) if weights.ndim != 2: weights = weights.unsqueeze(-1).expand(-1, nnt) weights = weights.to(dtype).to(device) if scatter_indices is not None: if isinstance(scatter_indices, list): scatter_indices = torch.tensor(scatter_indices, dtype=idtype, device=device) if n_seq_patches < 0: n_seq_patches = idx.shape[-1] if temp is not None and not isinstance(temp, list): temp = [temp] * nnt if post_temp is not None and not isinstance(post_temp, list): post_temp = [post_temp] * nnt if top_k is not None and not isinstance(top_k, list): top_k = [top_k] * nnt if top_p is not None and not isinstance(top_p, list): top_p = [top_p] * nnt if min_p is not None and not isinstance(min_p, list): min_p = [min_p] * nnt K = len(seq) I = idx.shape[-1] T = [0 if s is None else s.shape[0] for s in seq] maxT = max(T) tpp = n_tokens_per_patch # Number of tokens we need to use from the input cache in_cache_size = 0 if kvcache is None else kvcache.size n_rollout_tokens = tpp * n_seq_patches n_par_patches = I - n_seq_patches return_idx_logits = return_logits and return_idx_logits run_last_parallel_tokens = return_idx_logits or return_kv # Number of batched rollouts to make. If 0, we return a single 1D sequence, else return 2D if scatter_indices is not None: B = int(torch.max(scatter_indices).cpu().item()) + 1 else: B = K if (weights is None and K > 1) else 0 if idx.ndim == 2: assert B > 0, f'2D idx tensors are only supported for batched rollouts (got idx.shape={tuple(idx.shape)})' assert idx.shape == (B, I), f'Expected idx.shape={(B,I)}, but got {tuple(idx.shape)}.' if scatter_indices is not None: idx = torch.gather(idx, 0, scatter_indices.unsqueeze(-1).repeat(1, idx.shape[-1])) # [B I] -> [K I] # else B=K, so idx is already [K I] elif idx.ndim == 1 and B > 0: idx = idx.expand(K, idx.shape[-1]) if idx.ndim == 2: # TODO: Batched rollouts + policies are currently not supported if policy is not None: raise NotImplementedError('Batched rollouts with a policy callback is not currently supported') # Later, we may need to reduce idx from [K I] to [B I], such as before invoking # the policy callback or running parallel prediction. Here we precompute the rows # to use so we can efficiently perform the reduction on-demand. # Note we can't use torch.unique here because it doesn't maintain order. def precompute_idx_batch_rows(scatter: List[int]): seen = set() out = [] for i, s in enumerate(scatter): if s not in seen: seen.add(s) out.append(i) return out if scatter_indices is None: idx_batch_rows = list(range(K)) else: idx_batch_rows = precompute_idx_batch_rows(scatter_indices.tolist()) # Validate inputs as best we can assert len(pos) == K, f'Expected seq and pos lists to have the same length, but got {K} and {len(pos)}' assert (idx.shape == (K, I)) or (idx.ndim == 1), f'idx tensor must be 1D (or 2D for batched rollouts)' if weights is not None: assert weights.ndim == 2 and weights.shape == (K, nnt) if scatter_indices is not None: assert scatter_indices.ndim == 1 and len(scatter_indices) == K assert torch.min(scatter_indices) >= 0 and torch.max(scatter_indices) < K _usi = torch.unique(scatter_indices, sorted=True) assert torch.equal(_usi, torch.arange(len(_usi), device=_usi.device)) del _usi assert I > 0, f'No idx tokens were provided' assert tpp > 0, f'Must have a positive number of tokens per patch' assert I * tpp == nnt assert 0 <= n_seq_patches <= I if kvcache is not None: kvcache.validate(K) for i, (s, p) in enumerate(zip(seq, pos)): if s is not None: assert s.ndim == 1 assert p.ndim == 2 assert p.shape[1] == 4 assert p.shape[0] == T[i] + nnt # General notes: # B>0 means batched rollout mode, even if B==1. Batched means returned sequences are [B nnt], but B==0 means they're [nnt]. # 0 <= B <= K. If B==K, then all sequences are independent (no logit mixing). # If B>0, then idx is [K I] with B unique rows (some may be repeated). If B==0, then idx is [I]. # Use idx[idx_batch_rows] to select idx with shape [B I]. # Use Tensor.scatter_add + scatter_indices to sum-reduce from [K *] to [B *]. ######################### # === Preallocation === # ######################### # Preallocate the KV cache (if we need to) so we don't need to constantly resize it # If we won't run the last parallel pass, we don't need to cache those toks n_kvcache = in_cache_size + maxT + nnt - (0 if run_last_parallel_tokens else n_par_patches) if kvcache is None or kvcache.capacity < n_kvcache: # Allocate a new cache and copy the existing one into it new_kvcache = self.allocate_kvcache(n_kvcache, batch_size=K, dtype=dtype, device=device) if kvcache is not None: new_kvcache.copy_from(kvcache) kvcache = new_kvcache del new_kvcache # else we can just use the provided cache in-place # Also preallocate the output logits tensor, if requested if return_logits: n_logits = n_rollout_tokens + (I - n_seq_patches) * (tpp if return_idx_logits else (tpp - 1)) if B > 0: # Easiest to store/append this shape then transpose first two axes on return all_logits = torch.empty((n_logits, B, self.config.vocab_size), dtype=dtype, device=device) else: all_logits = torch.empty((n_logits, self.config.vocab_size), dtype=dtype, device=device) ################################ # === Initial Forward Pass === # ################################ # Stack seq/pos into a batch, left-padded # [K maxT] if maxT > 0: seq = torch.stack([( torch.zeros(maxT, dtype=idtype, device=device) if s is None else F.pad(s, (maxT - s.shape[0], 0)) ) for s in seq]) # [K maxN 4] pos = torch.stack([F.pad(p, (0, 0, maxT - t, 0)) for t, p in zip(T, pos)]) if maxT > 0: # Build attention mask for initial forward pass [K 1 maxT maxT] # Batch size K, each mask in the batch is fully causal except for the first (maxT - T) tokens, which are masked mask = torch.zeros(K, 1, maxT, maxT, device=device) mask.masked_fill_(torch.ones_like(mask, dtype=torch.bool).triu(1), float('-inf')) for i, t in enumerate(T): mask[i, ..., :maxT-t] = float('-inf') # Unmask the diagonal so pad tokens can self-attend # This doesn't matter with torch sdpa, but prevents NaNs with manual attention # NOTE: If t==0, the diagonal is *only* pad tokens, so this will unmask the last pad token # in the last row (which we use for rollouts). We re-mask this pad token in the rollout mask below mask[i, 0].fill_diagonal_(0.0) if kvcache is not None: if kvcache.mask is not None: mask = torch.cat([kvcache.mask.to(mask.device).expand((K, 1, maxT, -1)), mask], dim=-1) else: mask = F.pad(mask, (in_cache_size, 0, 0, 0, 0, 0, 0, 0)) # The above mask[:,0,-1,:] might look something like this: # kv cache | sequences # T[0]==3 [ T T T T T T | F F F F T T T ] # T[1]==6 [ T T T T T T | F T T T T T T ] # T[2]==7 [ T T T T T T | T T T T T T T ] # T[3]==4 [ T T T T T T | F F F T T T T ] # For one element in the batch, mask[0,0,:,:] with T[0]==3 might look like: # kv cache | sequences # [ T T T T T T | T F F F F F F ] # [ T T T T T T | F T F F F F F ] # [ T T T T T T | F F T F F F F ] # [ T T T T T T | F F F T F F F ] # [ T T T T T T | F F F F T F F ] # [ T T T T T T | F F F F T T F ] # [ T T T T T T | F F F F T T T ] # If a custom kvcache.mask is given, the kv cache part above may be different # Initial forward pass (conditioning sequences only) k_cache, v_cache = kvcache.get_token_slice(0, in_cache_size+maxT) self.forward( seq=seq, pos=pos[:,:maxT], mask=mask, k_cache=k_cache, v_cache=v_cache, inplace_kv=True ) pos = pos[:,maxT:] ############################## # === Sequential Rollout === # ############################## # Build attention mask for rollout [K 1 1 in_cache_size+maxT+n_rollout_tokens] if maxT == 0: if kvcache is None or kvcache.mask is None: # We have no context and no input cache mask, so we only need rollout tokens # Note we may still have an input cache without a mask, so add in_cache_size tokens mask = torch.zeros((K, 1, 1, in_cache_size+n_rollout_tokens), dtype=dtype, device=device) else: # We have no context and an input cache mask # Pad [K 1 1 in_cache_size] -> [K 1 1 in_cache_size+n_rollout_tokens] mask = F.pad(kvcache.mask, (0, n_rollout_tokens, 0, 0, 0, 0, 0, 0)) else: # We created a mask while computing context above. Take the last entry and clone to free memory mask = F.pad(mask[...,[-1],:].clone(), (0, n_rollout_tokens, 0, 0, 0, 0, 0, 0)) for i, t in enumerate(T): # If t==0, the fill_diagonal call above unmasked the last pad token, # so we need to re-mask it before we start rolling out if t == 0: mask[i, ..., in_cache_size+maxT-1] = float('-inf') # The above mask[:,0,0,:] might look something like this: # kv cache | sequences | rollout # T[0]==3 [ T T T T T T | F F F F T T T | T T T T ... T T T T ] # T[1]==6 [ T T T T T T | F T T T T T T | T T T T ... T T T T ] # T[2]==7 [ T T T T T T | T T T T T T T | T T T T ... T T T T ] # T[3]==4 [ T T T T T T | F F F T T T T | T T T T ... T T T T ] # We construct this mask once, then slice off part of the right side at each rollout step # List of 0D (if B==0) or 1D (if B>0) tensors rollout_seq = [] # Rollout tqdm_kwargs = dict(desc='Rollout', unit='tok') | tqdm_kwargs if n_rollout_tokens == 0: tqdm_kwargs |= dict(disable=True) for i in tqdm.tqdm(range(n_rollout_tokens), **tqdm_kwargs): if (i % tpp) == 0: patch_number = i // tpp if policy is None: # Use provided order next_token = idx[...,patch_number] # 0D or [K] else: # Use callback to select the next patch policy_cache_mask = mask[..., :in_cache_size+maxT+i] idx_of_next_idx = policy( idx=idx[idx_batch_rows,patch_number:] if B > 0 else idx[patch_number:], pos=pos[:, i:], weights=None if weights is None else weights[:, i:], scatter_indices=scatter_indices, kvcache=kvcache.shallow_copy(mask=policy_cache_mask, size=in_cache_size+maxT+i), n_tokens_per_patch=n_tokens_per_patch, sample_range=sample_range, idx_pos=pos[:, i::tpp], idx_weights=None if weights is None else weights[:, i::tpp], predictor=self, device=idx.device, ) del policy_cache_mask # Move the patch patch_number+idx_of_next_idx to the next position by swapping # TODO: Policies are currently not supported with batched rollouts if idx_of_next_idx != 0: i1, i2 = patch_number, patch_number + int(idx_of_next_idx) idx[[i1,i2]] = idx[[i2,i1]] # [K N 4] -> [K I tpp 4] -swap-idxs-> [K I tpp 4] -> [K N 4] pos = pos.reshape(K, I, tpp, 4) pos[:,[i1,i2]] = pos[:,[i2,i1]] pos = pos.reshape(K, -1, 4) next_token = idx[patch_number] rollout_seq.append(next_token[idx_batch_rows] if B > 0 else next_token) # Forward call k_cache, v_cache = kvcache.get_token_slice(0, in_cache_size+maxT+i+1) logits, _ = self.forward( seq=next_token.unsqueeze(-1).expand(K, 1), # 0D or [K] -> [K 1] pos=pos[:, [i]], # [K 1 4] mask=mask[..., :in_cache_size+maxT+i+1], # [K 1 1 n_prev+1] k_cache=k_cache, v_cache=v_cache, inplace_kv=True ) logits = logits.squeeze(1) # [K 1 V] -> [K V] # If next is an index token, we just needed to cache the previous token # so we can skip logits if we're not returning them if ((i + 1) % tpp) == 0 and not return_idx_logits: continue # Logit mixing: [K V] -> [1 V] or [B V] if weights is not None: logits.mul_(weights[:,[i]]) # weights: [K nnt] -> [K 1], logits: [K V] if scatter_indices is not None: new_logits = torch.zeros((B, logits.shape[-1]), dtype=dtype, device=device) si = scatter_indices.unsqueeze(-1).repeat(1, logits.shape[-1]) # [B V] new_logits.scatter_add_(0, si, logits) # [B V] logits = new_logits del new_logits elif B == 0: logits = logits.sum(0, keepdim=True) # [1 V] # else logits: [K V] = [B V] # If next is an index token, we just needed to cache the previous token if ((i + 1) % tpp) == 0: if return_idx_logits: all_logits[i].copy_(logits if B > 0 else logits.squeeze()) continue if return_logits: all_logits[i].copy_(logits if B > 0 else logits.squeeze()) # Sample from logits next_token = self.sample_logits( logits, temp=temp[i] if temp else None, post_temp=post_temp[i] if post_temp else None, top_k=top_k[i] if top_k else None, top_p=top_p[i] if top_p else None, min_p=min_p[i] if min_p else None, sample_range=sample_range, blacklist=blacklist ) rollout_seq.append(next_token if B > 0 else next_token.squeeze()) # Batched rollouts: We need to remap next_token from shape [B] to [K 1] # Note that if B>0 and scatter_indices is None, then B=K, as enforced during preprocessing if B > 0 and scatter_indices is not None: next_token = torch.gather(next_token, 0, scatter_indices) # [B] -> [K] if n_rollout_tokens > 0: rollout_seq = torch.stack(rollout_seq) # [nnt] or [nnt B] if B > 0: rollout_seq = rollout_seq.mT # [nnt B] -> [B nnt] if n_rollout_tokens == nnt: ret = (rollout_seq,) if return_logits: # [nnt B V] -> [B nnt V], or [nnt V] ret = (*ret, all_logits.transpose(0, 1)) if B > 0 else (*ret, all_logits) if return_kv: ret = (*ret, kvcache.shallow_copy(mask=mask, size=n_kvcache)) return ret if len(ret) > 1 else ret[0] ############################ # === Parallel Rollout === # ############################ npp = n_par_patches npt = npp * tpp # num parallel tokens idx = idx[-npp:] # Build attention mask for parallel part [K 1 npp n_past+npp] if unmask_parallel: par_mask = torch.zeros(1, dtype=mask.dtype, device=device).expand(K, 1, npp, npp) else: par_mask = torch.full((npp, npp), float('-inf'), dtype=mask.dtype, device=device) par_mask.fill_diagonal_(0.0) par_mask = par_mask.expand(K, 1, npp, npp) mask = mask.expand(K, 1, npp, -1) # Initial shape is [K 1 npp n_past]. Before each iter, we append par_mask [K 1 npp npp] along the last dim # Reshape positions for parallel passes # [K nnt 4] -trim-> [K npt 4] -> [K npp tpp 4] -> [K tpp npp 4] -> [K npt 4] pos = pos[:,-npt:].reshape(K, npp, tpp, 4).transpose(1, 2).reshape(K, npt, 4) # Reshape scheduled properties (transpose similarly to positions) # [nnt] -trim-> [npt] -> [npp tpp] -> [tpp npp] -> [npt] if temp is not None: temp = np.array(temp[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist() if post_temp is not None: post_temp = np.array(post_temp[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist() if top_k is not None: top_k = np.array(top_k[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist() if top_p is not None: top_p = np.array(top_p[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist() if min_p is not None: min_p = np.array(min_p[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist() if weights is not None: # [K nnt] -trim-> [K npt] -> [K npp tpp] -> [K tpp npp] -> [K npt] weights = weights[:,-npt:].reshape(K, npp, tpp).transpose(1, 2).reshape(K, npt) next_tokens = idx.expand(K, npp) # [K npp] or [npp] -> [K npp] parallel_seq = [idx[idx_batch_rows]] if B > 0 else [idx] # [B npp] or [npp] # Run parallel passes for i in range(tpp if run_last_parallel_tokens else (tpp - 1)): mask = torch.cat([mask, par_mask], dim=-1) parallel_slice = slice(i*npp, (i+1)*npp) k_cache, v_cache = kvcache.get_token_slice(0, in_cache_size+maxT+n_rollout_tokens+(i+1)*npp) logits, _ = self.forward( seq=next_tokens, # [K npp] pos=pos[:,parallel_slice], # [K npp 4] mask=mask, # [K 1 npp n_past+npp] k_cache=k_cache, v_cache=v_cache, inplace_kv=True ) # Weighted sum of logits [K npp V] -> [1 npp V] or [B npp V] if weights is not None: w = weights[:,parallel_slice] # [K npt] -> [K npp] logits.mul_(w.unsqueeze(-1)) # [K npp V] if scatter_indices is not None: new_logits = torch.zeros((B, npp, logits.shape[-1]), dtype=dtype, device=device) new_logits.scatter_add_(0, scatter_indices.reshape(-1,1,1).repeat(1, npp, logits.shape[-1]), logits) logits = new_logits # [B npp V] del new_logits elif B == 0: logits = logits.sum(0, keepdim=True) # [K npp V] -> [1 npp V] # else B==K, as enforced in preprocessing if return_logits and (return_idx_logits or i < tpp - 1): # Store the logits with a stride so we don't need to transpose later # NOTE: If return_idx_logits=False, we have tpp-1 instead of tpp stride = tpp if return_idx_logits else (tpp - 1) all_logits[n_rollout_tokens+i::stride].copy_( logits.transpose(0,1) if B > 0 else logits.squeeze(0) # [npp B V] or [npp V] ) if i == (tpp - 1): # We just needed to compute logits and/or KV to return them; no need to predict break # Sample from logits # TODO: Index using parallel_slice instead of -1 to support scheduled parameters next_tokens = self.sample_logits( logits, temp=temp[-1] if temp else None, post_temp=post_temp[-1] if post_temp else None, top_k=top_k[-1] if top_k else None, top_p=top_p[-1] if top_p else None, min_p=min_p[-1] if min_p else None, sample_range=sample_range, blacklist=blacklist ) parallel_seq.append(next_tokens if B > 0 else next_tokens.squeeze(0)) # Batched rollouts: We need to remap next_tokens from shape [B npp] to [K npp] # Note that if B>0 and scatter_indices is None, then B=K, as enforced during preprocessing if B > 0 and scatter_indices is not None: next_tokens = torch.gather(next_tokens, 0, scatter_indices.unsqueeze(-1).repeat(1,npp)) # [tpp npp] -> [npp tpp] -> [npt] parallel_seq = torch.stack(parallel_seq).transpose(0, 1).flatten() if return_kv: # Transpose the parallel part of the cache so it's returned the same order as sequential rollouts kvcache.transpose_last(n_kvcache, npt, tpp) # [K 1 npp N] -> [K 1 1 N], clone to free memory, doesn't matter which row we take (only different in last npt cols) # Unmask the last npt cols (if necessary) to make the parallel part fully unmasked mask = mask[...,[-1],:].clone() if not unmask_parallel: mask[...,-npt:] = 0.0 ret = (torch.cat([rollout_seq, parallel_seq], -1),) if n_rollout_tokens > 0 else (parallel_seq,) if return_logits: # [nnt B V] -> [B nnt V], or [nnt V] ret = (*ret, all_logits.transpose(0, 1)) if B > 0 else (*ret, all_logits) if return_kv: ret = (*ret, KVCache(k=kvcache.k, v=kvcache.v, mask=mask, size=n_kvcache)) return ret if len(ret) > 1 else ret[0] def _unpatchify(labels: torch.Tensor) -> torch.Tensor: """ labels: [B, N, P*3], where P = patch_area Returns [B, 3, H, W] assuming square grid and patches. """ B, N, threeP = labels.shape C = 3 P = threeP // C p = int(math.sqrt(P)) side = int(math.sqrt(N)) assert p * p == P and side * side == N rec = labels.view(B, side, side, p, p, C).permute(0, 5, 1, 3, 2, 4).contiguous() img = rec.view(B, C, side * p, side * p) return img def _unpatchify_logits(logits: torch.Tensor) -> torch.Tensor: """ logits: [B, N, P, V] -> [B, V, H, W] aligned to the unpatchify layout. """ B, N, P, V = logits.shape p = int(math.sqrt(P)) side = int(math.sqrt(N)) assert p * p == P and side * side == N out = logits.view(B, side, side, p, p, V).permute(0, 5, 1, 3, 2, 4).contiguous() out = out.view(B, V, side * p, side * p) return out