# psi0_5 / model.py
# Uploaded by klemenk with huggingface_hub (commit d05a6b4, verified).
import math
from typing import Tuple, Union, List, Dict, Optional, Callable
import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def _require(condition: bool, message: str) -> None:
    """
    Raise ``AssertionError(message)`` when ``condition`` is false.

    Used instead of a bare ``assert`` so the checks still run under ``python -O``
    while keeping the exception type that callers already expect.
    """
    if not condition:
        raise AssertionError(message)


def validate_psi2_config(config) -> None:
    """
    Validate PSI2 config to catch misconfigurations early.

    Optional attributes are only checked when present on ``config``;
    ``dropout``, ``n_embd`` and ``n_head`` are always read.

    Raises:
        AssertionError: If any validation fails with a descriptive message.
    """
    if hasattr(config, 'pointer_token'):
        _require(config.pointer_token < config.vocab_size,
                 f"pointer_token ({config.pointer_token}) must be < vocab_size ({config.vocab_size})")

    # Token-id ranges must lie inside [0, vocab_size]; lo == hi (empty) is accepted.
    token_range_attrs = ['rgb_range', 'flow_range', 'depth_range', 'campose_range', 'imagenet_cls_range']
    for attr in token_range_attrs:
        if hasattr(config, attr):
            lo, hi = getattr(config, attr)
            _require(0 <= lo <= hi <= config.vocab_size,
                     f"{attr} ({lo}, {hi}) out of bounds for vocab_size={config.vocab_size}")

    # Channel ranges must lie inside [0, channel_size] and be non-empty (lo < hi).
    channel_range_attrs = [
        'pointer_channel_range', 'rgb_channel_range', 'campose_channel_range',
        'flow_channel_range', 'depth_channel_range', 'cls_channel_range'
    ]
    for attr in channel_range_attrs:
        if hasattr(config, attr):
            lo, hi = getattr(config, attr)
            _require(0 <= lo < hi <= config.channel_size,
                     f"{attr} ({lo}, {hi}) out of bounds for channel_size={config.channel_size}")

    _require(0.0 <= config.dropout <= 1.0,
             f"dropout ({config.dropout}) must be in [0.0, 1.0]")
    if hasattr(config, 'drop_path_rate'):
        _require(0.0 <= config.drop_path_rate <= 1.0,
                 f"drop_path_rate ({config.drop_path_rate}) must be in [0.0, 1.0]")
    if hasattr(config, 'residual_scale'):
        _require(config.residual_scale > 0,
                 f"residual_scale ({config.residual_scale}) must be positive")

    if hasattr(config, 'n_kv_head') and config.n_kv_head is not None:
        _require(config.n_kv_head > 0,
                 f"n_kv_head ({config.n_kv_head}) must be positive (or None to disable GQA)")
        # Only reached when n_kv_head > 0, so the modulo cannot divide by zero.
        _require(config.n_head % config.n_kv_head == 0,
                 f"n_head ({config.n_head}) must be divisible by n_kv_head ({config.n_kv_head})")
    _require(config.n_embd % config.n_head == 0,
             f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})")

    if hasattr(config, 'mlp_activation'):
        valid_activations = {'swiglu', 'gelu'}
        _require(config.mlp_activation.lower() in valid_activations,
                 f"mlp_activation ({config.mlp_activation}) must be one of {valid_activations}")
    if hasattr(config, 'attention_mask'):
        valid_masks = {'causal', 'none', 'full'}
        _require(config.attention_mask.lower() in valid_masks,
                 f"attention_mask ({config.attention_mask}) must be one of {valid_masks}")

    if hasattr(config, 'context_parallel_size'):
        _require(int(config.context_parallel_size) >= 1,
                 f"context_parallel_size ({config.context_parallel_size}) must be >= 1")
    if hasattr(config, 'load_balance_cp') and config.load_balance_cp:
        # int() for consistency with the check above (the value may arrive as a string).
        _require(int(getattr(config, 'context_parallel_size', 1)) > 1,
                 "load_balance_cp requires context_parallel_size > 1")
def _to_additive_mask(
    mask: Optional[torch.Tensor],
    *,
    L: int,
    S: int,
    device: torch.device,
    dtype: torch.dtype,
) -> Optional[torch.Tensor]:
    """
    Convert a user-provided mask to an additive mask with 0 for allowed and -inf for blocked.

    Accepts:
      - None (returned as None)
      - boolean mask (True = blocked)
      - float mask already in additive form (only converted)

    The result is placed on ``device`` with dtype ``dtype`` so it can be added
    directly to attention scores computed there. ``L`` and ``S`` (query / key
    lengths) are accepted for interface symmetry but are not used; the mask
    shape is passed through unchanged and must broadcast against the scores.
    """
    if mask is None:
        return None
    if mask.dtype == torch.bool:
        blocked = mask.to(device)
        additive = torch.zeros(blocked.shape, dtype=dtype, device=device)
        return additive.masked_fill(blocked, float('-inf'))
    # Tensor.to returns the tensor itself when device and dtype already match.
    return mask.to(device=device, dtype=dtype)
def _build_causal_additive_mask(
    L: int,
    S: int,
    *,
    device: torch.device,
    dtype: torch.dtype,
    past_k: int = 0,
) -> torch.Tensor:
    """
    Build an additive causal mask of shape [1, 1, L, S].

    Query row ``i`` (0-based among the L queries) sits at key position
    ``past_k + i`` and may see keys ``0 .. past_k + i``. Later key columns
    hold ``-inf``; allowed columns hold ``0``.
    """
    # Last visible key column for each query row, as an [L, 1] column vector.
    last_visible = torch.arange(L, device=device)[:, None] + past_k
    key_cols = torch.arange(S, device=device)[None, :]
    mask = torch.zeros(L, S, dtype=dtype, device=device)
    mask.masked_fill_(key_cols > last_visible, float('-inf'))
    return mask[None, None]
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learned per-feature gain (no bias)."""

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        # Learned gain, initialised to ones.
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The statistic is computed in float32 for stability, then cast back to x's dtype.
        mean_square = x.float().pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps).to(x.dtype)
        normalized = x * inv_rms
        return self.weight * normalized
class Rotary3D(nn.Module):
    """
    3D rotary position embedding for per-head features.

    ``head_dim`` is split into three sub-spaces for the x, y and t coordinates:
    ``(head_dim * 6) // 16`` dims each for x and y, and the remainder for t.
    Each sub-space must be even and is rotated with its own inverse-frequency
    bank ``1 / base ** (2i / dim)``.

    The output is a fixed permutation of the rotated features: per axis, the
    rotated first-half and second-half pieces are placed side by side. Because
    the permutation is the same for every input, it cancels in dot products
    between two rotated vectors (e.g. q . k).
    """

    def __init__(self, head_dim: int, base: int = 100):
        super().__init__()
        assert head_dim % 2 == 0, "head_dim must be even."
        x_dim = (head_dim * 6) // 16
        y_dim = (head_dim * 6) // 16
        t_dim = head_dim - x_dim - y_dim
        for d in (x_dim, y_dim, t_dim):
            if d % 2 != 0:
                raise AssertionError("Each axis dim must be even for RoPE.")
        self.base = base
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.t_dim = t_dim
        self.register_buffer("inv_freq_x", self._inv_freqs(self.x_dim), persistent=False)
        self.register_buffer("inv_freq_y", self._inv_freqs(self.y_dim), persistent=False)
        self.register_buffer("inv_freq_t", self._inv_freqs(self.t_dim), persistent=False)

    def initialize_buffers(self):
        '''Called in Checkpointing.instantiate_model when weight init is skipped'''
        self.inv_freq_x = self._inv_freqs(self.x_dim)
        self.inv_freq_y = self._inv_freqs(self.y_dim)
        self.inv_freq_t = self._inv_freqs(self.t_dim)

    def _inv_freqs(self, dim: int):
        """Float32 inverse frequencies for one axis: dim // 2 values."""
        return 1.0 / (self.base ** (torch.arange(0, dim, 2).float() / dim))

    def _axis_rot(self, x_half1, x_half2, cos, sin):
        """Rotate each (x_half1[k], x_half2[k]) pair by its angle; returns [rotated1, rotated2]."""
        return torch.cat([x_half1 * cos - x_half2 * sin, x_half1 * sin + x_half2 * cos], dim=-1)

    def forward(self, x: torch.Tensor, pos_xyz: torch.Tensor) -> torch.Tensor:
        """
        x: [B, nH, T, H]
        pos_xyz: [B, T, 3] int/float (x, y, t)
        Returns a tensor with the same shape and dtype as ``x``.
        """
        B, nH, T, H = x.shape
        assert pos_xyz.shape == (B, T, 3), f"pos must be [B, T, 3], got {pos_xyz.shape}"
        # Move inv_freq buffers to input device if needed (FSDP doesn't move buffers)
        if self.inv_freq_x.device != x.device:
            self.inv_freq_x = self.inv_freq_x.to(x.device)
            self.inv_freq_y = self.inv_freq_y.to(x.device)
            self.inv_freq_t = self.inv_freq_t.to(x.device)
        px = pos_xyz[..., 0].to(self.inv_freq_x.dtype)
        py = pos_xyz[..., 1].to(self.inv_freq_y.dtype)
        pt = pos_xyz[..., 2].to(self.inv_freq_t.dtype)
        fx = torch.einsum("bt,f->btf", px, self.inv_freq_x)
        fy = torch.einsum("bt,f->btf", py, self.inv_freq_y)
        ft = torch.einsum("bt,f->btf", pt, self.inv_freq_t)
        # unsqueeze(1) broadcasts the [B, T, F] angles across the head dim.
        cx, sx = fx.cos().unsqueeze(1), fx.sin().unsqueeze(1)
        cy, sy = fy.cos().unsqueeze(1), fy.sin().unsqueeze(1)
        ct, st = ft.cos().unsqueeze(1), ft.sin().unsqueeze(1)
        half = H // 2
        x1, x2 = x[..., :half], x[..., half:]

        def split_axis(tensor, dims):
            # Split one half of the features into its x / y / t pieces.
            a, b, c = dims
            a1 = a // 2; b1 = b // 2; c1 = c // 2
            s0 = tensor[..., :a1]
            s1 = tensor[..., a1:a1 + b1]
            s2 = tensor[..., a1 + b1:a1 + b1 + c1]
            return s0, s1, s2

        x1x, x1y, x1t = split_axis(x1, (self.x_dim, self.y_dim, self.t_dim))
        x2x, x2y, x2t = split_axis(x2, (self.x_dim, self.y_dim, self.t_dim))
        rx = self._axis_rot(x1x, x2x, cx, sx)
        ry = self._axis_rot(x1y, x2y, cy, sy)
        rt = self._axis_rot(x1t, x2t, ct, st)
        rotated = torch.cat([rx, ry, rt], dim=-1)
        # cos/sin are float32, so half-precision inputs were promoted to float32
        # by the rotation. Cast back so rotated q/k keep the same dtype as v.
        return rotated.to(x.dtype)
class DropPath(nn.Module):
    """
    Per-sample stochastic depth, broadcast across non-batch dims.

    In training mode each sample along dim 0 keeps its whole input with
    probability ``1 - drop_prob``. Kept samples are scaled by
    ``1 / (1 - drop_prob)``, so the expected output equals the input; dropped
    samples become zero. In eval mode, or when ``drop_prob == 0``, the input is
    returned unchanged (same object).
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep = 1.0 - self.drop_prob
        if keep <= 0.0:
            # Every sample is dropped. The generic path below would compute
            # mask / keep = 0 / 0 = NaN; the correct result is all zeros.
            return torch.zeros_like(x)
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        mask = x.new_empty(shape).bernoulli_(keep)
        return x * (mask / keep)
class MLP(nn.Module):
    """
    Position-wise feed-forward network.

    Args:
        n_embd: input/output width.
        mlp_hidden_size: hidden width; defaults to ``4 * n_embd``.
        bias: whether both linear layers carry a bias.
        dropout: dropout probability applied to the output.
        activation: ``"swiglu"`` (gated; the input projection produces two
            hidden-width halves) or ``"gelu"``; case-insensitive.

    Raises:
        ValueError: for any other activation name.
    """

    def __init__(self,
                 n_embd: int,
                 mlp_hidden_size: Optional[int] = None,
                 *,
                 bias: bool = False,
                 dropout: float = 0.0,
                 activation: str = "swiglu"):
        super().__init__()
        hidden = 4 * n_embd if mlp_hidden_size is None else mlp_hidden_size
        self.activation = activation.lower()
        if self.activation not in ("swiglu", "gelu"):
            raise ValueError(f"Unsupported mlp_activation: {activation}")
        # SwiGLU needs twice the hidden width: one half is gated by the other.
        fc_width = 2 * hidden if self.activation == "swiglu" else hidden
        self.c_fc = nn.Linear(n_embd, fc_width, bias=bias)
        self.c_proj = nn.Linear(hidden, n_embd, bias=bias)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.c_fc(x)
        if self.activation == "swiglu":
            # The second half passes through SiLU and gates the first half.
            linear_part, gate_part = hidden.chunk(2, dim=-1)
            hidden = F.silu(gate_part) * linear_part
        else:
            hidden = F.gelu(hidden)
        return self.drop(self.c_proj(hidden))
class PSIAttentionLayer(nn.Module):
    """
    Multi-head self-attention with 3D rotary position embeddings.

    Supports grouped-query attention (``config.n_kv_head`` < ``config.n_head``;
    KV heads are repeated to match the query heads before attention), an
    optional key/value cache (concatenated, or written in place through
    ``emplace_kv``), and an optional user mask. When no mask is given and
    ``config.attention_mask`` is "causal" (the default, case-insensitive), a
    causal mask is built.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = config.n_head
        # A missing/None/0 n_kv_head means one KV head per query head (no GQA).
        self.n_kv_head = getattr(config, "n_kv_head", None) or self.n_head
        assert self.n_head % self.n_kv_head == 0, "n_head must be a multiple of n_kv_head"
        self.head_dim = config.n_embd // self.n_head
        self.dropout = config.dropout
        self.bias = getattr(config, "bias", False)
        # validate_psi2_config accepts attention_mask case-insensitively, so
        # normalise here as well ("Causal" must enable causal masking).
        self.is_causal = str(getattr(config, "attention_mask", "causal")).lower() == "causal"
        self.q_proj = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=self.bias)
        self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=self.bias)
        self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=self.bias)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=self.bias)
        self.attn_dropout = nn.Dropout(self.dropout)
        self.resid_dropout = nn.Dropout(self.dropout)
        self.rope = Rotary3D(self.head_dim)
        self._use_sdpa = hasattr(F, "scaled_dot_product_attention")

    @torch.compiler.disable
    def emplace_kv(self, T, k_cache, v_cache, k, v):
        """Copy new keys/values into the last T slots of the caches and return the caches."""
        # torch.compile doesn't play well with this op (5x slowdown)
        # so we insert a graph break and copy eagerly
        k_cache[:, :, -T:].copy_(k)
        v_cache[:, :, -T:].copy_(v)
        return k_cache, v_cache

    def forward(
        self,
        x: torch.Tensor,
        pos_xyz: torch.Tensor,
        *,
        k_cache: Optional[torch.Tensor] = None,
        v_cache: Optional[torch.Tensor] = None,
        return_kv: bool = False,
        inplace_kv: bool = False,
        mask: Optional[torch.Tensor] = None
    ):
        """
        Args:
            x: [B, T, C] input with C == n_embd.
            pos_xyz: [B, T, 3] positions passed to Rotary3D.
            k_cache, v_cache: optional cached keys/values [B, n_kv_head, S_cache, head_dim].
                With ``inplace_kv`` the caches must already contain T trailing slots
                for the new tokens; those slots are overwritten.
            return_kv: also return the keys/values attended over (n_kv_head heads).
            inplace_kv: write the new k/v into the caches instead of concatenating.
            mask: optional bool (True = blocked) or additive float attention mask.

        Returns:
            y [B, T, C], or (y, k, v) when ``return_kv``.
        """
        B, T, C = x.shape
        H, HKV, D = self.n_head, self.n_kv_head, self.head_dim
        groups = H // HKV

        q = self.q_proj(x).view(B, T, H, D).transpose(1, 2)
        k = self.k_proj(x).view(B, T, HKV, D).transpose(1, 2)
        v = self.v_proj(x).view(B, T, HKV, D).transpose(1, 2)
        q = self.rope(q, pos_xyz)
        k = self.rope(k, pos_xyz)

        if inplace_kv and (k_cache is not None) and (v_cache is not None):
            k, v = self.emplace_kv(T, k_cache, v_cache, k, v)
        else:
            if k_cache is not None:
                k = torch.cat([k_cache, k], dim=2)
            if v_cache is not None:
                v = torch.cat([v_cache, v], dim=2)
        S = k.shape[2]
        # Number of keys preceding the T new queries. In both cache modes the new
        # tokens occupy the last T key slots, so this is S - T. (k_cache.shape[2]
        # is wrong in inplace mode: there the cache already holds the new slots,
        # which would disable causal masking among the new tokens.)
        past_k = S - T
        # Keep the un-expanded (n_kv_head) keys/values for return_kv.
        k_kv, v_kv = k, v

        # Check the device at runtime. This is more reliable than an init-time
        # check, which can run before XLA is initialized.
        is_xla = x.device.type == "xla"
        if HKV != H:
            # Note that despite using expand + reshape, this still technically copies the
            # tensor because it is impossible to create a strided view of repeated data.
            # This is equivalent to repeat_interleave.
            # We use this even with F.scaled_dot_product_attention because the enable_gqa arg
            # is slower and uses more memory. This may change with later torch versions.
            k = k.unsqueeze(2).expand(B, HKV, groups, S, D).reshape(B, H, S, D)
            v = v.unsqueeze(2).expand(B, HKV, groups, S, D).reshape(B, H, S, D)

        add_mask = _to_additive_mask(mask, L=T, S=S, device=x.device, dtype=q.dtype)
        if add_mask is None and self.is_causal:
            add_mask = _build_causal_additive_mask(T, S, device=x.device, dtype=q.dtype, past_k=past_k)

        if self._use_sdpa and not is_xla:
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=add_mask,
                dropout_p=self.dropout if self.training else 0.0,
            )
        else:
            att = torch.einsum("bhtd,bhsd->bhts", q, k) * (1.0 / math.sqrt(D))
            if add_mask is not None:
                att = att + add_mask
            att = F.softmax(att, dim=-1, dtype=torch.float32).to(q.dtype)
            att = self.attn_dropout(att)
            y = torch.einsum("bhts,bhsd->bhtd", att, v)

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.out_proj(y))
        if return_kv:
            return y, k_kv, v_kv
        return y
class Block(nn.Module):
    """
    Pre-norm transformer block.

    Computes ``x + drop_path(attn(ln_1(x))) * residual_scale`` and then
    ``x + drop_path(mlp(ln_2(x))) * residual_scale``. The stochastic-depth
    probability grows linearly with depth, from 0 at layer 0 to
    ``config.drop_path_rate`` at layer ``total_layers - 1``.
    """

    def __init__(self, config, layer_idx: int, total_layers: int):
        super().__init__()
        self.residual_scale = float(getattr(config, "residual_scale", 1.0))
        depth_fraction = layer_idx / max(1, total_layers - 1)
        layer_dp = float(getattr(config, "drop_path_rate", 0.0)) * depth_fraction
        # Submodule creation order is kept stable (registration / init order).
        self.ln_1 = RMSNorm(config.n_embd)
        self.attn = PSIAttentionLayer(config)
        self.drop_path_1 = DropPath(layer_dp)
        self.ln_2 = RMSNorm(config.n_embd)
        self.mlp = MLP(
            config.n_embd,
            getattr(config, "mlp_hidden_size", None),
            bias=getattr(config, "bias", False),
            dropout=config.dropout,
            activation=getattr(config, "mlp_activation", "swiglu"),
        )
        self.drop_path_2 = DropPath(layer_dp)

    def forward(
        self,
        x: torch.Tensor,
        pos: torch.Tensor,
        *,
        k_cache: Optional[torch.Tensor] = None,
        v_cache: Optional[torch.Tensor] = None,
        return_kv: bool = False,
        inplace_kv: bool = False,
        mask: Optional[torch.Tensor] = None,
    ):
        """Return the updated ``x``, or ``(x, k, v)`` when ``return_kv`` is set."""
        attn_result = self.attn(self.ln_1(x), pos,
                                k_cache=k_cache, v_cache=v_cache,
                                return_kv=bool(return_kv), inplace_kv=inplace_kv, mask=mask)
        if return_kv:
            attn_out, k, v = attn_result
        else:
            attn_out = attn_result
        x = x + self.drop_path_1(attn_out) * self.residual_scale
        x = x + self.drop_path_2(self.mlp(self.ln_2(x))) * self.residual_scale
        if not return_kv:
            return x
        return x, k, v
class KVCache:
    """
    A pre-allocated key/value cache plus bookkeeping.

    Attributes:
        k, v: cache tensors. ``allocate`` creates them with shape
            [n_layer, batch_size, n_kv_head, capacity, head_dim]; ``validate``
            requires 5 dims with the batch at dim 1.
        mask: optional mask over the cached tokens. ``validate`` expects a 4-D
            tensor with shape[1:3] == (1, 1) and last dim == ``size``.
        size: number of leading token slots (dim -2) that hold data.
    """

    @torch.no_grad()
    def __init__(self,
                 k: torch.Tensor,
                 v: torch.Tensor,
                 mask: Optional[torch.Tensor],
                 size: int
                 ):
        self.k = k
        self.v = v
        self.mask = mask
        self.size = size

    @property
    def capacity(self):
        """Total number of token slots (dim -2 of ``k``)."""
        return self.k.shape[-2]

    def shallow_copy(self,
                     mask: Optional[torch.Tensor],
                     size: int
                     ) -> 'KVCache':
        """Return a KVCache that shares this cache's k/v tensors but has its own mask and size."""
        return KVCache(k=self.k, v=self.v, mask=mask, size=size)

    # staticmethod must be the OUTERMOST decorator. torch.no_grad() returns a
    # plain function, so wrapping a staticmethod object with it turned
    # ``allocate`` back into an instance method. It also fails outright on
    # Python < 3.10, where staticmethod objects are not callable.
    @staticmethod
    @torch.no_grad()
    def allocate(
        model: 'PSI2',
        capacity: int,
        batch_size: int = 1,
        dtype=None, device=None
    ) -> 'KVCache':
        '''
        Allocates a standard KVCache.
        You'll usually want to use `PSI2.allocate_kvcache()` instead.
        The tensors are uninitialised (torch.empty); ``size`` starts at 0.
        '''
        n_kv = getattr(model.config, "n_kv_head", None) or model.config.n_head
        head_dim = model.config.n_embd // model.config.n_head
        kv_shape = (
            model.config.n_layer, batch_size, n_kv,
            capacity, head_dim
        )
        return KVCache(
            k=torch.empty(kv_shape, dtype=dtype, device=device),
            v=torch.empty(kv_shape, dtype=dtype, device=device),
            mask=None, size=0
        )

    def validate(self, batch_size: int):
        """Assert that k/v agree in shape/device/dtype, match ``batch_size``, and fit ``size``/``mask``."""
        assert self.k.shape == self.v.shape, \
            f'KV shape mismatch: k={self.k.shape}, v={self.v.shape}'
        assert self.k.ndim == 5
        assert self.k.shape[1] == batch_size
        assert self.k.device == self.v.device
        assert self.k.dtype == self.v.dtype
        if self.mask is not None:
            assert self.mask.ndim == 4 and self.mask.shape[1:3] == (1, 1)
            assert self.mask.shape[-1] == self.size
        assert self.size <= self.k.shape[-2]

    def get_token_slice(self,
                        start: int,
                        stop: Optional[int]
                        ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return views of k and v restricted to token slots [start, stop)."""
        return (
            self.k[..., start:stop, :],
            self.v[..., start:stop, :]
        )

    def copy_from(self, kvcache: 'KVCache'):
        """
        Copy the first ``kvcache.size`` token slots of ``kvcache`` into this cache
        (non-blocking), then adopt its ``size``. The mask object is shared, not copied.
        """
        assert self.capacity >= kvcache.size
        self.k[..., :kvcache.size, :].copy_(kvcache.k[..., :kvcache.size, :], non_blocking=True)
        self.v[..., :kvcache.size, :].copy_(kvcache.v[..., :kvcache.size, :], non_blocking=True)
        self.mask = kvcache.mask
        self.size = kvcache.size

    def transpose_last(self, end: int, count: int, tpp: int):
        """
        Reorder, in place, the ``count`` token slots ending at ``end``.

        The slots are viewed as a [tpp, count // tpp] grid and transposed to
        [count // tpp, tpp], then flattened back. ``count`` must be divisible
        by ``tpp``.
        """
        # npt=tokens, npp=patches
        # [... npt E] -> [... tpp npp E] -> [... npp tpp E] -> [... npt E]
        edims = self.k.shape[:-2]
        par_k = self.k[..., end - count:end, :].reshape(*edims, tpp, count // tpp, -1)
        par_v = self.v[..., end - count:end, :].reshape(*edims, tpp, count // tpp, -1)
        par_k = par_k.transpose(-2, -3).reshape(*edims, count, -1)
        par_v = par_v.transpose(-2, -3).reshape(*edims, count, -1)
        # clone() guarantees the source does not alias the destination slice.
        self.k[..., end - count:end, :].copy_(par_k.clone())
        self.v[..., end - count:end, :].copy_(par_v.clone())
class PSI2(nn.Module):
"""
PSI2 with training-time features:
- GQA (n_kv_head < n_head)
- SwiGLU MLP
- Stochastic depth (DropPath)
- Residual scaling
- Robust weight tying (works with XLA/FSDP, saves memory)
"""
def __init__(self, config, verbose=True):
    """
    Build the PSI2 transformer from ``config``.

    Args:
        config: model config object, validated with ``validate_psi2_config``.
            Read directly here: ``vocab_size``, ``channel_size``, ``n_embd``,
            ``n_layer``, ``dropout``; optional ``context_parallel_size``
            (default 1), ``tie_weights`` (default False), ``n_lm_vocab``
            (defaults to ``vocab_size``) and ``lm_head_bias`` (default False,
            only used when tying). Each ``Block`` reads further attributes.
        verbose: print the config (and a weight-tying notice) to stdout.

    Raises:
        AssertionError: from ``validate_psi2_config`` for an invalid config.
        ValueError: if ``tie_weights`` is set but ``n_lm_vocab != vocab_size``.
    """
    super().__init__()
    self.config = config
    if verbose:
        print("[PSI2] Using config:", config, flush=True)
    validate_psi2_config(config)
    self.context_parallel_size = int(getattr(config, "context_parallel_size", 1))
    self.tie_weights = getattr(config, "tie_weights", False)
    # Output vocabulary size of the LM head; a falsy/missing n_lm_vocab falls back to vocab_size.
    self.n_lm_vocab = getattr(config, "n_lm_vocab", None) or config.vocab_size
    if self.tie_weights and self.n_lm_vocab != config.vocab_size:
        raise ValueError(
            f"Cannot tie weights when n_lm_vocab ({self.n_lm_vocab}) != vocab_size ({config.vocab_size})"
        )
    token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
    # Submodule registration order fixes the state_dict key order and the order in
    # which self.apply(self._init_weights) below visits modules.
    self.transformer = nn.ModuleDict(dict(
        token_embedding=token_embedding,
        channel_embedding=nn.Embedding(config.channel_size, config.n_embd),
        drop=nn.Dropout(config.dropout),
        h=nn.ModuleList([Block(config, i, config.n_layer) for i in range(config.n_layer)]),
        ln_f=RMSNorm(config.n_embd),
    ))
    if self.tie_weights:
        # No separate head module: _lm_head_forward reuses token_embedding.weight.
        self.lm_head = None
        if getattr(config, "lm_head_bias", False):
            self.lm_head_bias = nn.Parameter(torch.zeros(self.n_lm_vocab))
        else:
            self.register_parameter('lm_head_bias', None)
        if verbose:
            print(f"[PSI2] Weight tying enabled: using token_embedding weights for lm_head "
                  f"(saving {config.vocab_size * config.n_embd:,} parameters)")
    else:
        self.lm_head = nn.Linear(config.n_embd, self.n_lm_vocab, bias=False)
        self.lm_head_bias = None
    self.apply(self._init_weights)
    # Attention output projections are re-initialised with a smaller std,
    # 0.02 / sqrt(2 * n_layer), overriding the N(0, 0.02) from _init_weights.
    for pn, p in self.named_parameters():
        if pn.endswith("out_proj.weight"):
            torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
    self.unsharded_param_count = self.get_num_params()
    # Store the "effective" param count for compute purposes
    # When weights are tied, the lm_head matmul still happens, so we count it for FLOPs
    self.compute_param_count = self._get_compute_params()
@property
def device(self) -> torch.device:
    """Device the model lives on, read from the token-embedding weight."""
    embedding_weight = self.transformer.token_embedding.weight
    return embedding_weight.device
@property
def dtype(self) -> torch.dtype:
    """Parameter dtype of the model, read from the token-embedding weight."""
    # Previously annotated as ``-> torch.device``; the value is a torch.dtype.
    return self.transformer.token_embedding.weight.dtype
def _init_weights(self, module: nn.Module):
    """
    Per-module initialiser used with ``nn.Module.apply``.

    Linear and Embedding weights are drawn from N(0, 0.02); Linear biases are
    zeroed. Any other module type is left untouched.
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    if isinstance(module, nn.Linear) and module.bias is not None:
        torch.nn.init.zeros_(module.bias)
def _get_compute_params(self) -> int:
"""
Get effective parameter count for compute/FLOPs estimation.
When weights are tied, the embedding is used twice (embedding + lm_head),
so we count it twice for compute purposes even though it's stored once.
"""
base_params = sum(p.numel() for p in self.parameters())
if self.tie_weights:
# Add the embedding size again since it's used for lm_head computation
tied_weight_size = self.config.vocab_size * self.config.n_embd
return base_params + tied_weight_size
return base_params
def get_num_params(self, non_embedding: bool = True) -> int:
    """
    Return how many parameter elements the model stores.

    ``non_embedding`` exists only for signature compatibility with the
    training implementation. It is ignored, and the full stored count is
    always returned.
    """
    total = 0
    for param in self.parameters():
        total += param.numel()
    return total
def _lm_head_forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Project hidden states to vocabulary logits.

    With ``tie_weights`` the token-embedding matrix serves as the projection
    (plus ``lm_head_bias`` when it is set); otherwise ``self.lm_head`` is used.
    """
    if not self.tie_weights:
        return self.lm_head(x)
    tied_weight = self.transformer.token_embedding.weight
    return F.linear(x, tied_weight, self.lm_head_bias)
def forward(
    self,
    seq: torch.Tensor,
    pos: torch.Tensor,
    tgt: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
    k_cache: Optional[torch.Tensor] = None,
    v_cache: Optional[torch.Tensor] = None,
    return_kv: bool = False,
    inplace_kv: bool = False,
    output_hidden_states: bool = False,
):
    """
    Run the transformer over ``seq``.

    Args:
        seq: token ids fed to ``token_embedding``.
        pos: [B, T, 4] positions. ``pos[..., :3]`` (x, y, t) go to the blocks'
            rotary embedding; ``pos[..., -1]`` indexes ``channel_embedding``.
        tgt: optional targets. Logits and loss cover only the last
            ``tgt.size(1)`` positions; targets equal to -1 are ignored.
        mask: optional attention mask forwarded to every block.
        k_cache, v_cache: optional caches indexed per layer (``k_cache[i]``).
        return_kv: when ``tgt`` is None, also return per-layer keys/values.
        inplace_kv: blocks write new keys/values into the given caches.
        output_hidden_states: wrap logits as ``{"logits", "hidden_states"}``,
            where hidden_states holds the embedding output, every block's
            output and the final-norm output.

    Returns:
        if tgt is None:
            ``(logits, None)``; with ``return_kv``, ``(logits, k, v)``, where
            k/v are the caller's caches when ``inplace_kv`` and otherwise the
            per-layer tensors stacked along a new leading dim.
        else:
            ``(logits, loss)``.
    """
    st_pos = pos[:, :, :-1]
    channel_pos = pos[:, :, -1]
    tok_emb = self.transformer.token_embedding(seq)
    ch_emb = self.transformer.channel_embedding(channel_pos)
    x = self.transformer.drop(tok_emb + ch_emb)
    hidden_states = [x] if output_hidden_states else None

    k_list, v_list = [], []
    for i, block in enumerate(self.transformer.h):
        out = block(
            x,
            pos=st_pos,
            k_cache=None if k_cache is None else k_cache[i],
            v_cache=None if v_cache is None else v_cache[i],
            return_kv=return_kv,
            inplace_kv=inplace_kv,
            mask=mask,
        )
        if return_kv:
            x, k, v = out
            k_list.append(k)
            v_list.append(v)
        else:
            x = out
        if output_hidden_states:
            hidden_states.append(x)
    x = self.transformer.ln_f(x)
    if output_hidden_states:
        hidden_states.append(x)

    if tgt is None:
        logits = self._lm_head_forward(x)
        if output_hidden_states:
            logits = {"logits": logits, "hidden_states": hidden_states}
        if return_kv:
            if inplace_kv:
                return logits, k_cache, v_cache
            return logits, torch.stack(k_list), torch.stack(v_list)
        return logits, None

    # Score only the last tgt.size(1) positions. Use an explicit start index:
    # x[:, -0:] would select the WHOLE sequence when tgt is empty.
    start = max(x.size(1) - tgt.size(1), 0)
    logits = self._lm_head_forward(x[:, start:])
    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_tgt = tgt.reshape(-1)
    # Context parallelism should not change the training objective itself:
    # we still optimize the standard token-mean cross-entropy on the logits
    # produced by the current forward pass. Any CP-specific correctness work
    # belongs in the attention plumbing, not in ad hoc loss rescaling here.
    loss = F.cross_entropy(
        flat_logits,
        flat_tgt,
        ignore_index=-1
    )
    if output_hidden_states:
        logits = {"logits": logits, "hidden_states": hidden_states}
    return logits, loss
def unpack_and_sort_img_seq(self,
                            img_seq: torch.Tensor,
                            num_revealed_patches: int = 0,
                            mark_revealed_patches: Optional[int] = None,
                            logits: Optional[torch.Tensor] = None,
                            unpatchify_seq: bool = True):
    """
    Unpack a sequence of patch-index + RGB tokens into an image.

    Each patch occupies ``patch_size ** 2 + 1`` tokens. The first token of a
    patch is its index token (index = token - (65536 + 256)); the rest are the
    patch's RGB tokens. Patches are reordered by ascending index.

    Args:
        img_seq: [B, ...] tokens; everything after dim 0 is reshaped into
            patches of ``patch_size ** 2 + 1`` tokens.
        num_revealed_patches: number of leading patches (in sequence order)
            to overwrite when ``mark_revealed_patches`` is given.
        mark_revealed_patches: optional fill value for those patches' RGB
            tokens. It is applied to a copy; ``img_seq`` is not modified.
        logits: optional logits over the same tokens, reordered identically.
        unpatchify_seq: rebuild the image with ``_unpatchify``. Otherwise the
            sorted patches are reshaped to a square [B, h, w, patch_tokens] grid.

    Returns:
        ``img``, or ``(img, logits)`` when ``logits`` is given.

    NOTE(review): ``_unpatchify`` / ``_unpatchify_logits`` are module-level
    helpers not visible here; confirm they exist.
    """
    n_tokens_per_patch = self.config.patch_size ** 2 + 1
    B = img_seq.size(0)
    img_seq = img_seq.reshape(B, -1, n_tokens_per_patch)
    img_idxs = (img_seq[:, :, 0].long() - (65536 + 256))
    reconstruct_indxs = torch.argsort(img_idxs, dim=1)
    rgb_seq = img_seq[:, :, 1:]
    if mark_revealed_patches is not None:
        # rgb_seq is a view of the caller's tensor; clone so the fill below
        # does not overwrite the caller's tokens in place.
        rgb_seq = rgb_seq.clone()
        rgb_seq[:, :num_revealed_patches] = mark_revealed_patches
    # Batch row index for each patch, on the same device as the sort indices.
    batch_idx = torch.arange(B, device=img_idxs.device).unsqueeze(1).expand_as(img_idxs)
    rgb_seq = rgb_seq[batch_idx, reconstruct_indxs]
    if unpatchify_seq:
        img = _unpatchify(rgb_seq)
    else:
        h = w = int(math.sqrt(rgb_seq.size(1)))
        img = rgb_seq.reshape(B, h, w, -1)
    if logits is not None:
        logits = logits.reshape(B, -1, n_tokens_per_patch, logits.size(-1))
        logits = logits[:, :, 1:]
        logits = logits[batch_idx, reconstruct_indxs]
        logits = _unpatchify_logits(logits)
        return img, logits
    return img
@torch.no_grad()
def sample_logits(self,
                  logits: torch.FloatTensor,
                  temp: Optional[float] = None,
                  post_temp: Optional[float] = None,
                  top_k: Optional[int] = None,
                  top_p: Optional[float] = None,
                  min_p: Optional[float] = None,
                  sample_range: Optional[Tuple[int, int]] = None,
                  blacklist: Optional[Union[List[int], torch.LongTensor]] = None,
                  content_whitelist: Optional[torch.LongTensor] = None
                  ) -> torch.LongTensor:
    """
    Samples an integer from the distribution of logits
    Parameters:
        logits (torch.FloatTensor): The logits of the distribution (last dim = vocabulary).
            NOTE: may be modified in place (blacklist fill, temperature division and the
            final softmax write into it or into views of it). Pass a clone if you need it later.
        temp (float): The temperature applied before filtering; if 0.0, then argmax is used
        post_temp (float): The temperature applied after top-k/top-p/min-p filtering; 0.0 also means argmax
        top_k (int): The number of top k tokens to consider during sampling
        top_p (float): The cumulative probability threshold for nucleus (top-p) sampling
        min_p (float): The minimum probability threshold factor for min-p sampling
        sample_range (Tuple[int, int]): Restrict sampling to token ids in [lo, hi)
        blacklist (Union[List[int], torch.LongTensor]): The list of tokens to blacklist during sampling
        content_whitelist (torch.LongTensor): The list of tokens to whitelist during sampling (mutually exclusive with blacklist)
        For temp, post_temp, top_k, top_p and min_p a list is also accepted, and only its
        first element is used (callers such as rollout_patches may pass schedules).
    Returns:
        torch.LongTensor: The sampled integers, shape ``logits.shape[:-1]``
    """
    def _first(value):
        # Schedules may arrive as lists; only the first entry applies here.
        return value[0] if isinstance(value, list) else value

    temp = _first(temp)
    post_temp = _first(post_temp)
    top_k = _first(top_k)
    top_p = _first(top_p)
    # min_p was previously not unwrapped, so a list crashed the range assert below.
    min_p = _first(min_p)

    assert temp is None or temp >= 0.0
    assert post_temp is None or post_temp >= 0.0
    assert top_k is None or top_k > 0
    assert top_p is None or top_p >= 0.0
    assert min_p is None or 0.0 <= min_p <= 1.0
    assert sample_range is None or (
        sample_range[0] < sample_range[1] and
        sample_range[0] >= 0 and
        sample_range[1] <= logits.shape[-1]
    )
    assert blacklist is None or content_whitelist is None, "blacklist and content_whitelist cannot be used together"

    # Apply blacklist & sample range
    if blacklist is not None:
        logits[..., blacklist] = float('-inf')
    if content_whitelist is not None:
        # Create a mask that allows only whitelisted tokens
        whitelist_mask = torch.full_like(logits, float('-inf'))
        whitelist_mask[..., content_whitelist] = 0.0
        logits = logits + whitelist_mask
    if sample_range is not None:
        logits = logits[..., sample_range[0]:sample_range[1]]
        token_offset = sample_range[0]
    else:
        token_offset = 0

    # Apply temperature, or use argmax if 0.0
    if (temp is not None and temp == 0.0) or (post_temp is not None and post_temp == 0.0):
        return token_offset + torch.argmax(logits, dim=-1)
    if temp is not None and temp != 1.0:
        logits.div_(temp)

    # Sort the logits once. More efficient when using top-k and top-p together (min-p doesn't require sorting).
    # We sample in sorted order then re-order before returning.
    if top_k is not None or top_p is not None:
        logits, order = torch.sort(logits, dim=-1, descending=True)
    else:
        order = None  # Don't sort

    # Apply top-k filtering if specified
    if top_k is not None:
        logits = logits[..., :top_k]

    # Apply top-p (nucleus) filtering if specified
    if top_p is not None:
        probs = F.softmax(logits, dim=-1)
        cumulative_probs = probs.cumsum_(dim=-1)
        idxs_to_remove = cumulative_probs > top_p
        # Shift the mask right to keep at least one token
        logits[..., 1:][idxs_to_remove[..., :-1]] = float('-inf')
        del probs, cumulative_probs, idxs_to_remove

    # Apply min-p filtering if specified
    if min_p is not None:
        probs = F.softmax(logits, dim=-1)
        # When sorted descending, the max probability is the first entry.
        maxprob = probs[..., [0]] if order is not None else torch.max(probs, dim=-1, keepdim=True).values
        logits[probs < maxprob * min_p] = float('-inf')
        del probs, maxprob

    # Apply optional post-temperature
    if post_temp is not None and post_temp != 1.0:
        logits.div_(post_temp)

    # Compute softmax probabilities (written into logits' storage to avoid an allocation)
    orig_shape = logits.shape
    probs = torch.softmax(logits, dim=-1, out=logits)
    # Flatten probabilities to (batch_size * sequence_length, vocab_size) for sampling
    flat_probs = probs.view(-1, probs.size(-1))
    sampled = torch.multinomial(flat_probs, num_samples=1)
    sampled = sampled.view(*orig_shape[:-1])

    # If we sorted, unsort to collect the actual token values
    if order is not None:
        sampled = torch.gather(order, dim=-1, index=sampled.unsqueeze(-1)).squeeze(-1)
    return token_offset + sampled
@torch.no_grad()
def allocate_kvcache(self, n_tokens: int, batch_size: int = 1, dtype=None, device=None):
    """
    Allocate an empty KVCache sized for this model.

    Args:
        n_tokens: cache capacity in tokens (must be > 0).
        batch_size: batch dimension of the cache (must be > 0).
        dtype: cache dtype; defaults to the token-embedding weight dtype.
        device: cache device; defaults to ``self.device``.

    Returns:
        A KVCache with uninitialised k/v tensors and ``size == 0``.
    """
    assert n_tokens > 0
    assert batch_size > 0
    # Compare with None explicitly: ``device or self.device`` silently ignored
    # valid falsy devices such as the CUDA ordinal 0.
    if dtype is None:
        dtype = self.transformer.token_embedding.weight.dtype
    if device is None:
        device = self.device
    return KVCache.allocate(self, n_tokens, batch_size, dtype=dtype, device=device)
@torch.no_grad()
def rollout_patches(self,
seq: Union[Optional[torch.LongTensor], List[Optional[torch.LongTensor]]],
pos: Union[torch.LongTensor, List[torch.LongTensor]],
idx: Union[torch.LongTensor, List[torch.LongTensor]],
n_tokens_per_patch: int = 5,
n_seq_patches: int = -1,
weights: Optional[Union[List[float], torch.Tensor]] = None,
scatter_indices: Optional[Union[List[int], torch.LongTensor]] = None,
kvcache: Optional[KVCache] = None,
policy: Callable[..., torch.LongTensor] = None,
*,
unmask_parallel: bool = False,
return_logits: bool = False,
return_idx_logits: bool = True,
return_kv: bool = False,
temp: Optional[Union[float, List[float]]] = None,
post_temp: Optional[Union[float, List[float]]] = None,
top_k: Optional[Union[int, List[int]]] = None,
top_p: Optional[Union[float, List[float]]] = None,
min_p: Optional[Union[float, List[float]]] = None,
sample_range: Optional[Tuple[int, int]] = None,
blacklist: Optional[Union[List[int], torch.LongTensor]] = None,
tqdm_kwargs: dict = {},
print_and_visualize_tokens: bool = False,
out_dir: Optional[str] = "debug",
) -> Union[
torch.LongTensor, # seq
Tuple[torch.LongTensor, torch.Tensor], # seq, logits
Tuple[torch.LongTensor, KVCache], # seq, kvcache
Tuple[torch.LongTensor, torch.Tensor, KVCache], # seq, logits, kvcache
]:
"""
K = number of given sequences (1 if seq is not a list)
T = length of a conditioning sequence (per sequence)
N = total length of conditioning + generated tokens (per sequence)
I = number of index tokens to roll out
B = number of batched rollouts to make (0 if not using batched rollouts)
max(T) = maximum of T across sequences
num_new_tokens = len(idx) * n_tokens_per_patch
***Logit Mixing***:
To use logit mixing, supply multiple sequences/positions and a `weights` list/tensor. Before sampling tokens, the logits
from each sequence will be mixed using the provided weights.
***Batched Rollouts***:
To run batched rollouts *without* logit mixing, supply multiple (K) sequences/positions and leave `weights=None`.
The returned sequences/logits will include an extra batch dimension of size K.
***Batched Rollouts + Logit Mixing***:
To run batched rollouts and logit mixing at the same time, supply multiple (K) sequences/positions, a `weights` list,
and `scatter_indices`. `scatter_indices` indicates which output sequence each input sequence corresponds to. For example,
to run 3 cfg rollouts, you may supply K=6 sequences, with `weights=[cfg, 1-cfg, cfg, 1-cfg, cfg, 1-cfg]` and `scatter_indices=[0,0,1,1,2,2]`.
**WARNING:** With batched rollouts and logit mixing, and logits may not be numerically deterministic on CUDA (see torch.Tensor.scatter_add_ docs.)
***Tips for long rollouts***:
1. Use `gc.collect()` and then `torch.cuda.empty_cache()` (in that order)
2. Avoid fragmentation wherever possible. This method needs to allocate the entire KV cache
contiguously in memory. If you create tensors between rollouts, consider moving them to
CPU or cloning them to defragment VRAM.
3. Even if a single N-token rollout fits in VRAM, running two consecutive rollouts (e.g. running
n1 then giving its KV cache to n2, with n1 + n2 = N) may not fit, because reallocating the
KV cache will duplicate memory. To avoid this, try moving the cache to CPU first:
```
seq1, kvcache = predictor.rollout_patches(..., return_kv=True)
kvcache = { k: v.cpu() for k, v in kvcache.items() }
gc.collect(); torch.cuda.empty_cache()
seq2 = predictor.rollout_patches(..., **kvcache)
```
With this, the new KV cache will be allocated on GPU, and the CPU cache will be copied into it.
TODO: Support temp/top_k/top_p/min_p scheduling with parallel. Currently uses index -1 for all parallel tokens
Parameters:
seq (Union[Optional[torch.LongTensor], List[Optional[torch.LongTensor]]]):
[T], [K T], or list of [T] / sequence(s) to condition the generation on. None means empty sequence
pos (Union[torch.LongTensor, List[torch.LongTensor]]):
[N 4], [K N 4], or list of [N 4] / 4D position(s) corresponding to each sequence. If multiple, must have length K=len(seq)
idx (Union[torch.LongTensor. List[torch.LongTensor]]):
[I], [B I], or list of [I] / the patch indices to use in the rollout. 1D by default, used for all sequences. If using batched rollouts, idx may be 1D or 2D,
and the second dimension must be equal to the number of batched rollouts (if scatter_indices is not None, this may be less than K).
n_tokens_per_patch (int):
number of tokens per patch, including patch index
n_seq_patches (int):
number of patches to roll out sequentially (-1 for all). The remaining patches will be parallel
weights (Optional[Union[List[float], torch.Tensor]]):
float weights for the logits produced by each sequence, used for logit mixing. If None and multiple sequences are given, batched rollouts will be
used instead of logit mixing. If a list or 1D tensor, must have size K. If 2D, must have shape [K num_new_tokens].
If 2D, the first n_seq_patches patches will use the weights in order as expected, but the remaining (parallel) patches may be arbitrarily reordered.
When using parallel and a 2D weight schedule, it is recommended to make the weights for parallel patches uniform for consistency
scatter_indices (Optional[Union[List[int], torch.LongTensor]]):
integer scatter indices for logit mixing, indicating which generated sequence each input sequence is used for. If provided, multiple sequences
will be generated and returned; see above for details. Must have exactly K elements, one for each input sequence. Indices must be set-contiguous in the range [0,K)
and must include 0, e.g. for K=5, [1,0,3,2,2] are valid indices but [1,3,4,4,3] are not.
kvcache (Optional[KVCache]):
optional KV cache for this rollout. If kvcache.size > 0, the tokens in the cache will be used as conditioning for this rollout. If the cache has
enough empty space to cache this rollout, the KV tensors will be modified in-place. Note that only the tensor data is modified, not the other entries
of the KVCache (e.g. size, mask). To receive an updated cache object, set return_kv=True and pass the returned cache into downstream rollouts. If no cache is
provided, or if the cache is too small for this rollout, a new cache will be allocated and the tensor data from this cache will be copied.
policy (Callable[..., torch.LongTensor]):
optional callback defining the policy for rollout order. Must accept an argument `idx` (torch.LongTensor of shape [I]) with the candidate indices to generate next.
Must return the *index* into the `idx` tensor to generate next, either an int or 0-dimensional torch.LongTensor. For example, given candidate indices [4,12,1023],
return 2 to generate the patch with index 1023 next. Only used for the sequential part of the generation. The following kwargs are given
- `idx` (torch.LongTensor of shape [I] or [B I]) the candidate patch indices
- `pos` (torch.LongTensor of shape [K N 4]): the remaining poses for all yet-ungenerated tokens in the same order as `idx`
- `weights` (torch.Tensor of shape [K N]): the weights for all yet-ungenerated tokens, or None
- `scatter_indices` (torch.LongTensor of shape [K])
- `kvcache` (KVCache): the cache for this rollout, where kvcache.size includes the tokens generated up to this point
- `n_tokens_per_patch` (int)
- `sample_range` (Optional[Tuple[int,int]])
- `idx_pos` (torch.LongTensor of shape [K I 4]): same as `pos`, but only for the candidate index tokens
- `idx_weights` (torch.Tensor of shape [K I]): same as `weights`, but only for the candidate index tokens
- `predictor` (PSIPredictor) self
- `device` (torch.device)\n
The callback must return the following
- (Union[int, torch.LongTensor]): the *index* into the `idx` tensor with the patch token to generate next (*not* the value of the patch index itself)
unmask_parallel (bool):
if True, all parallel patches can attend to each other. If False (default), parallel patches can only attend to themselves
return_logits (bool):
return the logits of the sequence
return_idx_logits (bool):
if True (default), returns logits that would predict index tokens, so there is one set of logits for every returned token. If False, only returns
logits used to sample content tokens, e.g. returns (n_tokens_per_patch - 1) sets of logits per patch. The latter may not need to compute logits
for all tokens, so it may be more efficient for some computations (such as patchwise entropy). Ignored if return_logits=False
return_kv (bool):
return the KV cache as a KVCache object (fields `k`, `v`, `mask`, `size`), useful for downstream operations such as passing it as `kvcache` to a later rollout. If True and return_logits=False,
returns (new_tokens, kvcache). If True and return_logits=True, returns (new_tokens, logits, kvcache). **All KVs are returned in patch-major order,** even if
the rollout is partially or fully parallel. Note that KVs from parallel prediction are not computed causally
Returns:
torch.LongTensor:
[num_new_tokens] or [B num_new_tokens]
the generated tokens only (the input sequence is not prepended). 1D by default, or 2D if using batched rollouts
torch.Tensor:
(optional) [n_tokens vocab_size] or [B n_tokens vocab_size]
the logits of the sequence, where n_tokens depends on return_idx_logits. 2D by default, or 3D if using batched rollouts
KVCache:
(optional) the KV cache, with the following fields
- `k` (torch.Tensor) [n_layer K n_head n_tok n_embd//n_head]
- `v` (torch.Tensor) [n_layer K n_head n_tok n_embd//n_head]
- `mask` (torch.Tensor) [K 1 1 n_tok]
- `size` (int)
"""
if print_and_visualize_tokens:
raise NotImplementedError("Token visualization is not included in the Hugging Face inference build.")
#########################
# === Preprocessing === #
#########################
if not isinstance(seq, list):
seq = [seq] if seq is None or seq.ndim == 1 else list(seq)
if not isinstance(pos, list):
pos = [pos] if pos.ndim == 2 else list(pos)
if isinstance(idx, list):
idx = torch.stack(idx)
nnt = idx.shape[-1] * n_tokens_per_patch
idtype = pos[0].dtype
dtype = self.dtype
device = self.device
if weights is not None:
if isinstance(weights, list):
weights = torch.tensor(weights, dtype=dtype, device=device)
if weights.ndim != 2:
weights = weights.unsqueeze(-1).expand(-1, nnt)
weights = weights.to(dtype).to(device)
if scatter_indices is not None:
if isinstance(scatter_indices, list):
scatter_indices = torch.tensor(scatter_indices, dtype=idtype, device=device)
if n_seq_patches < 0:
n_seq_patches = idx.shape[-1]
if temp is not None and not isinstance(temp, list):
temp = [temp] * nnt
if post_temp is not None and not isinstance(post_temp, list):
post_temp = [post_temp] * nnt
if top_k is not None and not isinstance(top_k, list):
top_k = [top_k] * nnt
if top_p is not None and not isinstance(top_p, list):
top_p = [top_p] * nnt
if min_p is not None and not isinstance(min_p, list):
min_p = [min_p] * nnt
K = len(seq)
I = idx.shape[-1]
T = [0 if s is None else s.shape[0] for s in seq]
maxT = max(T)
tpp = n_tokens_per_patch
# Number of tokens we need to use from the input cache
in_cache_size = 0 if kvcache is None else kvcache.size
n_rollout_tokens = tpp * n_seq_patches
n_par_patches = I - n_seq_patches
return_idx_logits = return_logits and return_idx_logits
run_last_parallel_tokens = return_idx_logits or return_kv
# Number of batched rollouts to make. If 0, we return a single 1D sequence, else return 2D
if scatter_indices is not None:
B = int(torch.max(scatter_indices).cpu().item()) + 1
else:
B = K if (weights is None and K > 1) else 0
if idx.ndim == 2:
assert B > 0, f'2D idx tensors are only supported for batched rollouts (got idx.shape={tuple(idx.shape)})'
assert idx.shape == (B, I), f'Expected idx.shape={(B,I)}, but got {tuple(idx.shape)}.'
if scatter_indices is not None:
idx = torch.gather(idx, 0, scatter_indices.unsqueeze(-1).repeat(1, idx.shape[-1])) # [B I] -> [K I]
# else B=K, so idx is already [K I]
elif idx.ndim == 1 and B > 0:
idx = idx.expand(K, idx.shape[-1])
if idx.ndim == 2:
# TODO: Batched rollouts + policies are currently not supported
if policy is not None:
raise NotImplementedError('Batched rollouts with a policy callback is not currently supported')
# Later, we may need to reduce idx from [K I] to [B I], such as before invoking
# the policy callback or running parallel prediction. Here we precompute the rows
# to use so we can efficiently perform the reduction on-demand.
# Note we can't use torch.unique here because it doesn't maintain order.
def precompute_idx_batch_rows(scatter: List[int]):
seen = set()
out = []
for i, s in enumerate(scatter):
if s not in seen:
seen.add(s)
out.append(i)
return out
if scatter_indices is None:
idx_batch_rows = list(range(K))
else:
idx_batch_rows = precompute_idx_batch_rows(scatter_indices.tolist())
# Validate inputs as best we can
assert len(pos) == K, f'Expected seq and pos lists to have the same length, but got {K} and {len(pos)}'
assert (idx.shape == (K, I)) or (idx.ndim == 1), f'idx tensor must be 1D (or 2D for batched rollouts)'
if weights is not None:
assert weights.ndim == 2 and weights.shape == (K, nnt)
if scatter_indices is not None:
assert scatter_indices.ndim == 1 and len(scatter_indices) == K
assert torch.min(scatter_indices) >= 0 and torch.max(scatter_indices) < K
_usi = torch.unique(scatter_indices, sorted=True)
assert torch.equal(_usi, torch.arange(len(_usi), device=_usi.device))
del _usi
assert I > 0, f'No idx tokens were provided'
assert tpp > 0, f'Must have a positive number of tokens per patch'
assert I * tpp == nnt
assert 0 <= n_seq_patches <= I
if kvcache is not None:
kvcache.validate(K)
for i, (s, p) in enumerate(zip(seq, pos)):
if s is not None:
assert s.ndim == 1
assert p.ndim == 2
assert p.shape[1] == 4
assert p.shape[0] == T[i] + nnt
# General notes:
# B>0 means batched rollout mode, even if B==1. Batched means returned sequences are [B nnt], but B==0 means they're [nnt].
# 0 <= B <= K. If B==K, then all sequences are independent (no logit mixing).
# If B>0, then idx is [K I] with B unique rows (some may be repeated). If B==0, then idx is [I].
# Use idx[idx_batch_rows] to select idx with shape [B I].
# Use Tensor.scatter_add + scatter_indices to sum-reduce from [K *] to [B *].
#########################
# === Preallocation === #
#########################
# Preallocate the KV cache (if we need to) so we don't need to constantly resize it
# If we won't run the last parallel pass, we don't need to cache those toks
n_kvcache = in_cache_size + maxT + nnt - (0 if run_last_parallel_tokens else n_par_patches)
if kvcache is None or kvcache.capacity < n_kvcache:
# Allocate a new cache and copy the existing one into it
new_kvcache = self.allocate_kvcache(n_kvcache, batch_size=K, dtype=dtype, device=device)
if kvcache is not None:
new_kvcache.copy_from(kvcache)
kvcache = new_kvcache
del new_kvcache
# else we can just use the provided cache in-place
# Also preallocate the output logits tensor, if requested
if return_logits:
n_logits = n_rollout_tokens + (I - n_seq_patches) * (tpp if return_idx_logits else (tpp - 1))
if B > 0:
# Easiest to store/append this shape then transpose first two axes on return
all_logits = torch.empty((n_logits, B, self.config.vocab_size), dtype=dtype, device=device)
else:
all_logits = torch.empty((n_logits, self.config.vocab_size), dtype=dtype, device=device)
################################
# === Initial Forward Pass === #
################################
# Stack seq/pos into a batch, left-padded
# [K maxT]
if maxT > 0:
seq = torch.stack([(
torch.zeros(maxT, dtype=idtype, device=device) if s is None else
F.pad(s, (maxT - s.shape[0], 0))
) for s in seq])
# [K maxN 4]
pos = torch.stack([F.pad(p, (0, 0, maxT - t, 0)) for t, p in zip(T, pos)])
if maxT > 0:
# Build attention mask for initial forward pass [K 1 maxT maxT]
# Batch size K, each mask in the batch is fully causal except for the first (maxT - T) tokens, which are masked
mask = torch.zeros(K, 1, maxT, maxT, device=device)
mask.masked_fill_(torch.ones_like(mask, dtype=torch.bool).triu(1), float('-inf'))
for i, t in enumerate(T):
mask[i, ..., :maxT-t] = float('-inf')
# Unmask the diagonal so pad tokens can self-attend
# This doesn't matter with torch sdpa, but prevents NaNs with manual attention
# NOTE: If t==0, the diagonal is *only* pad tokens, so this will unmask the last pad token
# in the last row (which we use for rollouts). We re-mask this pad token in the rollout mask below
mask[i, 0].fill_diagonal_(0.0)
if kvcache is not None:
if kvcache.mask is not None:
mask = torch.cat([kvcache.mask.to(mask.device).expand((K, 1, maxT, -1)), mask], dim=-1)
else:
mask = F.pad(mask, (in_cache_size, 0, 0, 0, 0, 0, 0, 0))
# The above mask[:,0,-1,:] might look something like this:
# kv cache | sequences
# T[0]==3 [ T T T T T T | F F F F T T T ]
# T[1]==6 [ T T T T T T | F T T T T T T ]
# T[2]==7 [ T T T T T T | T T T T T T T ]
# T[3]==4 [ T T T T T T | F F F T T T T ]
# For one element in the batch, mask[0,0,:,:] with T[0]==3 might look like:
# kv cache | sequences
# [ T T T T T T | T F F F F F F ]
# [ T T T T T T | F T F F F F F ]
# [ T T T T T T | F F T F F F F ]
# [ T T T T T T | F F F T F F F ]
# [ T T T T T T | F F F F T F F ]
# [ T T T T T T | F F F F T T F ]
# [ T T T T T T | F F F F T T T ]
# If a custom kvcache.mask is given, the kv cache part above may be different
# Initial forward pass (conditioning sequences only)
k_cache, v_cache = kvcache.get_token_slice(0, in_cache_size+maxT)
self.forward(
seq=seq, pos=pos[:,:maxT], mask=mask,
k_cache=k_cache, v_cache=v_cache, inplace_kv=True
)
pos = pos[:,maxT:]
##############################
# === Sequential Rollout === #
##############################
# Build attention mask for rollout [K 1 1 in_cache_size+maxT+n_rollout_tokens]
if maxT == 0:
if kvcache is None or kvcache.mask is None:
# We have no context and no input cache mask, so we only need rollout tokens
# Note we may still have an input cache without a mask, so add in_cache_size tokens
mask = torch.zeros((K, 1, 1, in_cache_size+n_rollout_tokens), dtype=dtype, device=device)
else:
# We have no context and an input cache mask
# Pad [K 1 1 in_cache_size] -> [K 1 1 in_cache_size+n_rollout_tokens]
mask = F.pad(kvcache.mask, (0, n_rollout_tokens, 0, 0, 0, 0, 0, 0))
else:
# We created a mask while computing context above. Take the last entry and clone to free memory
mask = F.pad(mask[...,[-1],:].clone(), (0, n_rollout_tokens, 0, 0, 0, 0, 0, 0))
for i, t in enumerate(T):
# If t==0, the fill_diagonal call above unmasked the last pad token,
# so we need to re-mask it before we start rolling out
if t == 0:
mask[i, ..., in_cache_size+maxT-1] = float('-inf')
# The above mask[:,0,0,:] might look something like this:
# kv cache | sequences | rollout
# T[0]==3 [ T T T T T T | F F F F T T T | T T T T ... T T T T ]
# T[1]==6 [ T T T T T T | F T T T T T T | T T T T ... T T T T ]
# T[2]==7 [ T T T T T T | T T T T T T T | T T T T ... T T T T ]
# T[3]==4 [ T T T T T T | F F F T T T T | T T T T ... T T T T ]
# We construct this mask once, then slice off part of the right side at each rollout step
# List of 0D (if B==0) or 1D (if B>0) tensors
rollout_seq = []
# Rollout
tqdm_kwargs = dict(desc='Rollout', unit='tok') | tqdm_kwargs
if n_rollout_tokens == 0:
tqdm_kwargs |= dict(disable=True)
for i in tqdm.tqdm(range(n_rollout_tokens), **tqdm_kwargs):
if (i % tpp) == 0:
patch_number = i // tpp
if policy is None:
# Use provided order
next_token = idx[...,patch_number] # 0D or [K]
else:
# Use callback to select the next patch
policy_cache_mask = mask[..., :in_cache_size+maxT+i]
idx_of_next_idx = policy(
idx=idx[idx_batch_rows,patch_number:] if B > 0 else idx[patch_number:],
pos=pos[:, i:],
weights=None if weights is None else weights[:, i:],
scatter_indices=scatter_indices,
kvcache=kvcache.shallow_copy(mask=policy_cache_mask, size=in_cache_size+maxT+i),
n_tokens_per_patch=n_tokens_per_patch,
sample_range=sample_range,
idx_pos=pos[:, i::tpp],
idx_weights=None if weights is None else weights[:, i::tpp],
predictor=self,
device=idx.device,
)
del policy_cache_mask
# Move the patch patch_number+idx_of_next_idx to the next position by swapping
# TODO: Policies are currently not supported with batched rollouts
if idx_of_next_idx != 0:
i1, i2 = patch_number, patch_number + int(idx_of_next_idx)
idx[[i1,i2]] = idx[[i2,i1]]
# [K N 4] -> [K I tpp 4] -swap-idxs-> [K I tpp 4] -> [K N 4]
pos = pos.reshape(K, I, tpp, 4)
pos[:,[i1,i2]] = pos[:,[i2,i1]]
pos = pos.reshape(K, -1, 4)
next_token = idx[patch_number]
rollout_seq.append(next_token[idx_batch_rows] if B > 0 else next_token)
# Forward call
k_cache, v_cache = kvcache.get_token_slice(0, in_cache_size+maxT+i+1)
logits, _ = self.forward(
seq=next_token.unsqueeze(-1).expand(K, 1), # 0D or [K] -> [K 1]
pos=pos[:, [i]], # [K 1 4]
mask=mask[..., :in_cache_size+maxT+i+1], # [K 1 1 n_prev+1]
k_cache=k_cache,
v_cache=v_cache,
inplace_kv=True
)
logits = logits.squeeze(1) # [K 1 V] -> [K V]
# If next is an index token, we just needed to cache the previous token
# so we can skip logits if we're not returning them
if ((i + 1) % tpp) == 0 and not return_idx_logits:
continue
# Logit mixing: [K V] -> [1 V] or [B V]
if weights is not None:
logits.mul_(weights[:,[i]]) # weights: [K nnt] -> [K 1], logits: [K V]
if scatter_indices is not None:
new_logits = torch.zeros((B, logits.shape[-1]), dtype=dtype, device=device)
si = scatter_indices.unsqueeze(-1).repeat(1, logits.shape[-1]) # [B V]
new_logits.scatter_add_(0, si, logits) # [B V]
logits = new_logits
del new_logits
elif B == 0:
logits = logits.sum(0, keepdim=True) # [1 V]
# else logits: [K V] = [B V]
# If next is an index token, we just needed to cache the previous token
if ((i + 1) % tpp) == 0:
if return_idx_logits:
all_logits[i].copy_(logits if B > 0 else logits.squeeze())
continue
if return_logits:
all_logits[i].copy_(logits if B > 0 else logits.squeeze())
# Sample from logits
next_token = self.sample_logits(
logits,
temp=temp[i] if temp else None,
post_temp=post_temp[i] if post_temp else None,
top_k=top_k[i] if top_k else None,
top_p=top_p[i] if top_p else None,
min_p=min_p[i] if min_p else None,
sample_range=sample_range,
blacklist=blacklist
)
rollout_seq.append(next_token if B > 0 else next_token.squeeze())
# Batched rollouts: We need to remap next_token from shape [B] to [K 1]
# Note that if B>0 and scatter_indices is None, then B=K, as enforced during preprocessing
if B > 0 and scatter_indices is not None:
next_token = torch.gather(next_token, 0, scatter_indices) # [B] -> [K]
if n_rollout_tokens > 0:
rollout_seq = torch.stack(rollout_seq) # [nnt] or [nnt B]
if B > 0:
rollout_seq = rollout_seq.mT # [nnt B] -> [B nnt]
if n_rollout_tokens == nnt:
ret = (rollout_seq,)
if return_logits:
# [nnt B V] -> [B nnt V], or [nnt V]
ret = (*ret, all_logits.transpose(0, 1)) if B > 0 else (*ret, all_logits)
if return_kv:
ret = (*ret, kvcache.shallow_copy(mask=mask, size=n_kvcache))
return ret if len(ret) > 1 else ret[0]
############################
# === Parallel Rollout === #
############################
npp = n_par_patches
npt = npp * tpp # num parallel tokens
idx = idx[-npp:]
# Build attention mask for parallel part [K 1 npp n_past+npp]
if unmask_parallel:
par_mask = torch.zeros(1, dtype=mask.dtype, device=device).expand(K, 1, npp, npp)
else:
par_mask = torch.full((npp, npp), float('-inf'), dtype=mask.dtype, device=device)
par_mask.fill_diagonal_(0.0)
par_mask = par_mask.expand(K, 1, npp, npp)
mask = mask.expand(K, 1, npp, -1)
# Initial shape is [K 1 npp n_past]. Before each iter, we append par_mask [K 1 npp npp] along the last dim
# Reshape positions for parallel passes
# [K nnt 4] -trim-> [K npt 4] -> [K npp tpp 4] -> [K tpp npp 4] -> [K npt 4]
pos = pos[:,-npt:].reshape(K, npp, tpp, 4).transpose(1, 2).reshape(K, npt, 4)
# Reshape scheduled properties (transpose similarly to positions)
# [nnt] -trim-> [npt] -> [npp tpp] -> [tpp npp] -> [npt]
if temp is not None:
temp = np.array(temp[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
if post_temp is not None:
post_temp = np.array(post_temp[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
if top_k is not None:
top_k = np.array(top_k[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
if top_p is not None:
top_p = np.array(top_p[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
if min_p is not None:
min_p = np.array(min_p[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
if weights is not None:
# [K nnt] -trim-> [K npt] -> [K npp tpp] -> [K tpp npp] -> [K npt]
weights = weights[:,-npt:].reshape(K, npp, tpp).transpose(1, 2).reshape(K, npt)
next_tokens = idx.expand(K, npp) # [K npp] or [npp] -> [K npp]
parallel_seq = [idx[idx_batch_rows]] if B > 0 else [idx] # [B npp] or [npp]
# Run parallel passes
for i in range(tpp if run_last_parallel_tokens else (tpp - 1)):
mask = torch.cat([mask, par_mask], dim=-1)
parallel_slice = slice(i*npp, (i+1)*npp)
k_cache, v_cache = kvcache.get_token_slice(0, in_cache_size+maxT+n_rollout_tokens+(i+1)*npp)
logits, _ = self.forward(
seq=next_tokens, # [K npp]
pos=pos[:,parallel_slice], # [K npp 4]
mask=mask, # [K 1 npp n_past+npp]
k_cache=k_cache,
v_cache=v_cache,
inplace_kv=True
)
# Weighted sum of logits [K npp V] -> [1 npp V] or [B npp V]
if weights is not None:
w = weights[:,parallel_slice] # [K npt] -> [K npp]
logits.mul_(w.unsqueeze(-1)) # [K npp V]
if scatter_indices is not None:
new_logits = torch.zeros((B, npp, logits.shape[-1]), dtype=dtype, device=device)
new_logits.scatter_add_(0, scatter_indices.reshape(-1,1,1).repeat(1, npp, logits.shape[-1]), logits)
logits = new_logits # [B npp V]
del new_logits
elif B == 0:
logits = logits.sum(0, keepdim=True) # [K npp V] -> [1 npp V]
# else B==K, as enforced in preprocessing
if return_logits and (return_idx_logits or i < tpp - 1):
# Store the logits with a stride so we don't need to transpose later
# NOTE: If return_idx_logits=False, we have tpp-1 instead of tpp
stride = tpp if return_idx_logits else (tpp - 1)
all_logits[n_rollout_tokens+i::stride].copy_(
logits.transpose(0,1) if B > 0 else logits.squeeze(0) # [npp B V] or [npp V]
)
if i == (tpp - 1):
# We just needed to compute logits and/or KV to return them; no need to predict
break
# Sample from logits
# TODO: Index using parallel_slice instead of -1 to support scheduled parameters
next_tokens = self.sample_logits(
logits,
temp=temp[-1] if temp else None,
post_temp=post_temp[-1] if post_temp else None,
top_k=top_k[-1] if top_k else None,
top_p=top_p[-1] if top_p else None,
min_p=min_p[-1] if min_p else None,
sample_range=sample_range,
blacklist=blacklist
)
parallel_seq.append(next_tokens if B > 0 else next_tokens.squeeze(0))
# Batched rollouts: We need to remap next_tokens from shape [B npp] to [K npp]
# Note that if B>0 and scatter_indices is None, then B=K, as enforced during preprocessing
if B > 0 and scatter_indices is not None:
next_tokens = torch.gather(next_tokens, 0, scatter_indices.unsqueeze(-1).repeat(1,npp))
# [tpp npp] -> [npp tpp] -> [npt]
parallel_seq = torch.stack(parallel_seq).transpose(0, 1).flatten()
if return_kv:
# Transpose the parallel part of the cache so it's returned the same order as sequential rollouts
kvcache.transpose_last(n_kvcache, npt, tpp)
# [K 1 npp N] -> [K 1 1 N], clone to free memory, doesn't matter which row we take (only different in last npt cols)
# Unmask the last npt cols (if necessary) to make the parallel part fully unmasked
mask = mask[...,[-1],:].clone()
if not unmask_parallel:
mask[...,-npt:] = 0.0
ret = (torch.cat([rollout_seq, parallel_seq], -1),) if n_rollout_tokens > 0 else (parallel_seq,)
if return_logits:
# [nnt B V] -> [B nnt V], or [nnt V]
ret = (*ret, all_logits.transpose(0, 1)) if B > 0 else (*ret, all_logits)
if return_kv:
ret = (*ret, KVCache(k=kvcache.k, v=kvcache.v, mask=mask, size=n_kvcache))
return ret if len(ret) > 1 else ret[0]
def _unpatchify(labels: torch.Tensor, channels: int = 3) -> torch.Tensor:
    """
    Reassemble flattened per-patch pixel values into images.

    Args:
        labels: [B, N, P*C] tensor. N is the number of patches and must be a
            perfect square; patches are laid out row-major on a
            sqrt(N) x sqrt(N) grid. P is the patch area and must also be a
            perfect square. Within a patch the last axis is ordered
            (row, col, channel), with channel varying fastest.
        channels: number of channels C per pixel (default 3, i.e. RGB).

    Returns:
        [B, C, H, W] tensor with H = W = sqrt(N) * sqrt(P).

    Raises:
        AssertionError: if the last axis is not divisible by ``channels``, or
            if N or P is not a perfect square.
    """
    B, N, last = labels.shape
    # Check divisibility up front; otherwise the view below fails with an
    # opaque "shape is invalid for input of size" error.
    assert channels > 0 and last % channels == 0, (
        f"last dim ({last}) must be divisible by channels ({channels})"
    )
    P = last // channels
    # Integer square roots are exact, with no float rounding.
    p = math.isqrt(P)
    side = math.isqrt(N)
    assert p * p == P, f"patch area P={P} is not a perfect square"
    assert side * side == N, f"number of patches N={N} is not a perfect square"
    # Split N into a (side, side) grid and each patch into (p, p, C), then
    # interleave grid and intra-patch axes: [B, C, side, p, side, p].
    rec = labels.view(B, side, side, p, p, channels).permute(0, 5, 1, 3, 2, 4).contiguous()
    return rec.view(B, channels, side * p, side * p)
def _unpatchify_logits(logits: torch.Tensor) -> torch.Tensor:
    """
    Rearrange per-patch, per-pixel logits into an image-shaped logit map.

    Args:
        logits: [B, N, P, V] tensor. The N patches sit on a square grid
            (row-major), each patch holds P pixels on a square (row, col)
            layout, and every pixel carries V logits.

    Returns:
        [B, V, H, W] tensor using the same spatial layout as the unpatchify
        helper, with H = W = sqrt(N) * sqrt(P).
    """
    batch, n_patches, patch_area, vocab = logits.shape
    patch_side = int(math.sqrt(patch_area))
    grid_side = int(math.sqrt(n_patches))
    assert patch_side * patch_side == patch_area and grid_side * grid_side == n_patches

    # [B, grid_row, grid_col, pix_row, pix_col, V] -> [B, V, grid_row, pix_row, grid_col, pix_col]
    grid = logits.view(batch, grid_side, grid_side, patch_side, patch_side, vocab)
    grid = grid.permute(0, 5, 1, 3, 2, 4).contiguous()

    extent = grid_side * patch_side
    return grid.view(batch, vocab, extent, extent)