# PSI / psi.py
# (Uploaded via huggingface_hub, revision e3d287d)
"""
PSI Model Definition
"""
import math
import importlib
from typing import Tuple, Union, List, Optional, Callable, Dict
from transformers import PreTrainedModel
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import tqdm
from .config import PSIConfig
from .modeling import (
RMSNorm, PSIBlock, PartitionedEmbedding, PartitionedLinear
)
try:
xm = importlib.import_module('torch_xla.core.xla_model')
xs = importlib.import_module('torch_xla.distributed.spmd.xla_sharding')
except ImportError:
xm = None
xs = None
class PSI(PreTrainedModel):
config_class = PSIConfig
### Initialization Functions ###
    def __init__(self, config):
        """
        Build the PSI transformer from a PSIConfig.

        Depending on config flags, the token embedding / LM head are either plain
        nn.Embedding / nn.Linear or partitioned variants (for sharded vocabularies).
        Optionally ties the LM head weights to the token embedding.

        Parameters:
            config (PSIConfig): model hyperparameters (vocab_size, n_embd, n_layer,
                n_head, channel_size, dropout, bias, and optional flags
                partition_embedding, n_lm_vocab, tie_weights)
        """
        super().__init__(config)
        self.config = config
        if hasattr(config, "partition_embedding") and config.partition_embedding:
            # Partitioned embedding/head pair; sizes must match for weight tying below
            token_embedding = PartitionedEmbedding(config.vocab_size, config.n_embd)
            lm_head = PartitionedLinear(config.n_embd, config.vocab_size, bias=False)
        else:
            token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
            # The LM head may predict a smaller vocabulary than the input embedding covers
            if hasattr(config, "n_lm_vocab") and config.n_lm_vocab is not None:
                n_lm_vocab = config.n_lm_vocab
            else:
                n_lm_vocab = config.vocab_size
            lm_head = nn.Linear(config.n_embd, n_lm_vocab, bias=False)
        self.transformer = nn.ModuleDict(dict(
            token_embedding = token_embedding,
            channel_embedding = nn.Embedding(config.channel_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([PSIBlock(config) for _ in range(config.n_layer)]),
            ln_f = RMSNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = lm_head
        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))
        if hasattr(config, "tie_weights") and config.tie_weights:
            if hasattr(config, "partition_embedding") and config.partition_embedding:
                # Tie each partition of the head to the corresponding embedding partition
                for i in range(len(self.transformer.token_embedding.embedding_layers)):
                    self.lm_head.linear_layers[i].weight = self.transformer.token_embedding.embedding_layers[i].weight
            else:
                self.lm_head.weight = self.transformer.token_embedding.weight
        # Apply XLA sharding for model parallelism if on XLA and model axis > 1
        # NOTE(review): xla_device_available is computed below but never used in this
        # method — the sharding step the comment above describes appears to be missing
        # or removed. Dead code candidate; confirm before deleting (the xm call has
        # no visible side effect here beyond detection).
        xla_device_available = False
        if xm is not None:
            try:
                device_kind = xm.xla_device_kind()
                if device_kind is not None:
                    xla_device_available = True
            except RuntimeError:
                pass
        # Parameter count before any sharding is applied
        self.unsharded_param_count = self.get_num_params()
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def get_num_params(self):
"""Return the number of parameters in the model."""
return sum(p.numel() for p in self.parameters())
### Training Functions ###
def forward(
self,
seq: torch.Tensor,
pos: torch.Tensor,
tgt: torch.Tensor = None,
mask: torch.Tensor = None,
k_cache: torch.Tensor = None,
v_cache: torch.Tensor = None,
return_kv: bool = False,
inplace_kv: bool = False,
output_hidden_states: bool = False,
) -> torch.Tensor:
"""
Forward pass of the model
Parameters:
seq (torch.Tensor) of size b, t: The input sequence
pos (torch.Tensor) of size b, t, d: The positional indices of the sequence of shape (batch, tokens, dimensions)
They consist of x, y, t and c coordinates, where x, y are the spatial coordinates of the patch,
t is the time index and c is the channel index
tgt (torch.Tensor) of size b, t_tgt: The target sequence
mask (torch.Tensor) of size b, t, t: The mask of the sequence
k_cache (torch.Tensor) of size n_layer, b, n_head, n, n_embd//n_head: A k-cache to prepend
v_cache (torch.Tensor) of size n_layer, b, n_head, n, n_embd//n_head: A v-cache to prepend
return_kv (bool): If True, returns (logits, k, v). Ignored if tgt != None
inplace_kv (bool): If True, k_cache/v_cache are modified in-place. They must be sufficiently large to store
the new tokens, and the last N tokens will be overwritten. If False (default), the input kv will not be
modified, and a concat operation will be used instead. No effect if k_cache/v_cache are None.
Returns:
torch.Tensor: The logits of the model. Size b, t if tgt is None, else b, t_tgt
if tgt != None:
torch.Tensor: The cross entropy loss between the logits and tgt
elif return_k:
torch.Tensor: the k-cache
torch.Tensor: the v-cache
"""
st_pos = pos[:, :, :-1]
channel_pos = pos[:, :, -1]
# forward the GPT model itself
tok_emb = self.transformer.token_embedding(seq) # token embeddings of shape (b, t, n_embd)
channel_emb = self.transformer.channel_embedding(channel_pos) # position embeddings of shape (t, n_embd)
x = self.transformer.drop(tok_emb + channel_emb)
if output_hidden_states:
hidden_states = [x]
k_list, v_list = [], []
for i, block in enumerate(self.transformer.h):
x = block(x, pos=st_pos, mask=mask,
k_cache=None if k_cache is None else k_cache[i],
v_cache=None if v_cache is None else v_cache[i],
return_kv=return_kv, inplace_kv=inplace_kv)
if return_kv:
x, k, v = x
k_list.append(k)
v_list.append(v)
if output_hidden_states:
hidden_states.append(x)
x = self.transformer.ln_f(x)
if output_hidden_states:
hidden_states.append(x)
# if tgt is not none, compute the logits for the entire sequence
if tgt is None:
logits = self.lm_head(x)
if output_hidden_states:
logits = {"logits": logits, "hidden_states": hidden_states}
if return_kv:
if inplace_kv:
# We modified in-place; avoid allocating a new tensor with torch.stack
return logits, k_cache, v_cache
else:
return logits, torch.stack(k_list), torch.stack(v_list)
return logits, None
# if tgt is not none, compute the logits and the loss for the target sequence
logits = self.lm_head(x[:, -tgt.size(1):])
if output_hidden_states:
logits = {"logits": logits, "hidden_states": hidden_states}
loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), tgt.reshape(-1), ignore_index=-1)
return logits, loss
### Rollout Functions ###
    @torch.no_grad()
    def sample_logits(self,
        logits: torch.FloatTensor,
        temp: Optional[float] = None,
        post_temp: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        min_p: Optional[float] = None,
        sample_range: Optional[Tuple[int,int]] = None,
        blacklist: Optional[Union[List[int], torch.LongTensor]] = None
    ) -> torch.LongTensor:
        """
        Samples an integer from the distribution of logits

        WARNING: `logits` is mutated in-place (temperature division, filtering,
        and the final softmax all write into it). Pass a clone if the caller
        needs the original values afterwards.

        Parameters:
            logits (torch.FloatTensor): The logits of the distribution
            temp (float): The temperature of the sampling, if 0.0, then argmax is used.
                Applied BEFORE top-k/top-p/min-p filtering
            post_temp (float): A second temperature applied AFTER filtering; 0.0 also
                triggers argmax. If a list is given for any of temp/post_temp/top_k/top_p,
                only its first element is used
            top_k (int): The number of top k tokens to consider during sampling
            top_p (float): The cumulative probability threshold for nucleus (top-p) sampling
            min_p (float): The minimum probability threshold factor for min-p sampling
            sample_range (Tuple[int,int]): Half-open [lo, hi) slice of the vocabulary to
                sample from; returned token ids are offset back into the full vocabulary
            blacklist (Union[List[int], torch.LongTensor]): The list of tokens to blacklist during sampling
        Returns:
            torch.LongTensor: The sampled integers
        """
        # Scheduled parameters may arrive as lists; use only the first entry here
        if isinstance(temp, list):
            temp = temp[0]
        if isinstance(post_temp, list):
            post_temp = post_temp[0]
        if isinstance(top_k, list):
            top_k = top_k[0]
        if isinstance(top_p, list):
            top_p = top_p[0]
        assert temp is None or temp >= 0.0
        assert post_temp is None or post_temp >= 0.0
        assert top_k is None or top_k > 0
        assert top_p is None or top_p >= 0.0
        assert min_p is None or 0.0 <= min_p <= 1.0
        assert sample_range is None or (
            sample_range[0] < sample_range[1] and
            sample_range[0] >= 0 and
            sample_range[1] <= logits.shape[-1]
        )
        # Apply blacklist & sample range
        if blacklist is not None:
            # In-place: blacklisted tokens get probability 0 after softmax
            logits[...,blacklist] = float('-inf')
        if sample_range is not None:
            logits = logits[...,sample_range[0]:sample_range[1]]
            token_offset = sample_range[0]
        else:
            token_offset = 0
        # Apply temperature, or use argmax if 0.0
        if (temp is not None and temp == 0.0) or (post_temp is not None and post_temp == 0.0):
            return token_offset + torch.argmax(logits, dim=-1)
        if temp is not None and temp != 1.0:
            logits.div_(temp)
        # Sort the logits once. More efficient when using top-k and top-p together (min-p doesn't require sorting).
        # We sample in sorted order then re-order before returning.
        if top_k is not None or top_p is not None:
            logits, order = torch.sort(logits, dim=-1, descending=True)
        else:
            order = None # Don't sort
        # Apply top-k filtering if specified
        if top_k is not None:
            # Sorted descending, so the first top_k entries are the k largest
            logits = logits[...,:top_k]
        # Apply top-p (nucleus) filtering if specified
        if top_p is not None:
            probs = F.softmax(logits, dim=-1) # Already sorted
            cumulative_probs = probs.cumsum_(dim=-1)  # in-place cumsum over probs
            idxs_to_remove = cumulative_probs > top_p
            # Shift the mask right to keep at least one token
            logits[...,1:][idxs_to_remove[...,:-1]] = float('-inf')
            del probs, cumulative_probs, idxs_to_remove
        # Apply min-p filtering if specified
        if min_p is not None:
            probs = F.softmax(logits, dim=-1)
            # If sorted (order is not None), position 0 holds the max probability
            maxprob = probs[...,[0]] if order is not None else torch.max(probs, dim=-1, keepdim=True).values
            logits[probs < maxprob * min_p] = float('-inf')
            del probs, maxprob
        # Apply optional post-temperature
        if post_temp is not None and post_temp != 1.0:
            logits.div_(post_temp)
        # Compute softmax probabilities (written into logits to avoid a new allocation)
        orig_shape = logits.shape
        probs = torch.softmax(logits, dim=-1, out=logits)
        # Flatten probabilities to (batch_size * sequence_length, vocab_size)
        flat_probs = probs.view(-1, probs.size(-1))
        # Sample from the distribution
        sampled = torch.multinomial(flat_probs, num_samples=1)
        # Reshape to original shape except for the last dimension
        sampled = sampled.view(*orig_shape[:-1])
        # If we sorted, unsort to collect the actual token values
        if order is not None:
            sampled = torch.gather(order, dim=-1, index=sampled.unsqueeze(-1)).squeeze(-1)
        return token_offset + sampled
    @torch.no_grad()
    def rollout_patches(self,
        seq: Union[Optional[torch.LongTensor], List[Optional[torch.LongTensor]]],
        pos: Union[torch.LongTensor, List[torch.LongTensor]],
        idx: torch.LongTensor,
        n_tokens_per_patch: int = 5,
        n_seq_patches: int = -1,
        weights: Optional[Union[List[float], torch.Tensor]] = None,
        k_cache: Optional[torch.Tensor] = None,
        v_cache: Optional[torch.Tensor] = None,
        cache_mask: Optional[torch.Tensor] = None,
        policy: Callable[..., torch.LongTensor] = None,
        *,
        unmask_parallel: bool = False,
        return_logits: bool = False,
        return_idx_logits: bool = True,
        return_kv: bool = False,
        verbose: bool = True,
        temp: Optional[Union[float, List[float]]] = None,
        post_temp: Optional[Union[float, List[float]]] = None,
        top_k: Optional[Union[int, List[int]]] = None,
        top_p: Optional[Union[float, List[float]]] = None,
        min_p: Optional[Union[float, List[float]]] = None,
        sample_range: Optional[Tuple[int, int]] = None,
        blacklist: Optional[Union[List[int], torch.LongTensor]] = None
    ) -> Union[
        torch.LongTensor, # seq
        Tuple[torch.LongTensor, torch.Tensor], # seq, logits
        Tuple[torch.LongTensor, Dict[str, torch.Tensor]], # seq, kvcache
        Tuple[torch.LongTensor, torch.Tensor, Dict[str, torch.Tensor]], # seq, logits, kvcache
    ]:
        """
        K = number of given sequences (1 if seq is not a list)
        T = length of a conditioning sequence (per sequence)
        N = total length of conditioning + generated tokens (per sequence)
        I = number of index tokens to roll out
        max(T) = maximum of T across sequences
        num_new_tokens = len(idx) * n_tokens_per_patch
        ***Tips for long rollouts***:
        1. Use `gc.collect()` and then `torch.cuda.empty_cache()` (in that order)
        2. Avoid fragmentation wherever possible. This method needs to allocate the entire KV cache
            contiguously in memory. If you create tensors between rollouts, consider moving them to
            CPU or cloning them to defragment VRAM.
        3. Even if a single N-token rollout fits in VRAM, running two consecutive rollouts (e.g. running
            n1 then giving its KV cache to n2, with n1 + n2 = N) may not fit, because reallocating the
            KV cache will duplicate memory. To avoid this, try moving the cache to CPU first:
            ```
            seq1, kvcache = predictor.rollout_patches(..., return_kv=True)
            kvcache = { k: v.cpu() for k, v in kvcache.items() }
            gc.collect(); torch.cuda.empty_cache()
            seq2 = predictor.rollout_patches(..., **kvcache)
            ```
            With this, the new KV cache will be allocated on GPU, and the CPU cache will be copied into it.
        TODO: Support temp/top_k/top_p/min_p scheduling with parallel. Currently uses index -1 for all parallel tokens
        Parameters:
            seq (Union[Optional[torch.LongTensor], List[Optional[torch.LongTensor]]]):
                [T], [K T], or list of [T] / sequence(s) to condition the generation on. None means empty sequence
            pos (Union[torch.LongTensor, List[torch.LongTensor]]):
                [N 4], [K N 4], or list of [N 4] / 4D position(s) corresponding to each sequence. If multiple, must have length K=len(seq)
            idx (torch.LongTensor):
                [I] / the patch indices to use in the rollout (same for all sequences)
            n_tokens_per_patch (int):
                number of tokens per patch, including patch index
            n_seq_patches (int):
                number of patches to roll out sequentially (-1 for all). The remaining patches will be parallel
            weights (Optional[Union[List[float], torch.Tensor]]):
                float weights for the logits produced by each sequence. If None and multiple sequences are given, defaults to all ones.
                If a list or 1D tensor, must have size K. If 2D, must have shape [K num_new_tokens].
                If 2D, the first n_seq_patches patches will use the weights in order as expected, but the remaining (parallel) patches may be arbitrarily reordered.
                When using parallel and a 2D weight schedule, it is recommended to make the weights for parallel patches uniform for consistency
            k_cache (Optional[torch.Tensor]):
                optional k_cache to prepend to all seqs, broadcastable to shape [n_layer K n_head n_tok n_embd//n_head]. May be on a different device
            v_cache (Optional[torch.Tensor]):
                optional v_cache to prepend to all seqs, broadcastable to shape [n_layer K n_head n_tok n_embd//n_head]. May be on a different device
            cache_mask (Optional[torch.Tensor]):
                optional mask to be applied to the provided KV cache with shape [K 1 1 n_tok], where n_tok matches k_cache/v_cache. Useful when a KV cache is supplied
                for multiple conditioning sequences of different lengths, where the cache_mask indicates which elements of the cache should be attended to for each sequence.
                If k_cache/v_cache are given and cache_mask is None, the cache will be fully unmasked. May be on a different device
            policy (Callable[..., torch.LongTensor]):
                optional callback defining the policy for rollout order. Must accept an argument `idx` (torch.LongTensor of shape [I]) with the candidate indices to generate next.
                Must return the *index* into the `idx` tensor to generate next, either an int or 0-dimensional torch.LongTensor. For example, given candidate indices [4,12,1023],
                return 2 to generate the patch with index 1023 next. Only used for the sequential part of the generation. The following kwargs are given
                - `idx` (torch.LongTensor of shape [I]) the candidate patch indices
                - `pos` (torch.LongTensor of shape [K N 4]): the remaining poses for all yet-ungenerated tokens in the same order as `idx`
                - `weights` (torch.Tensor of shape [K N]): the weights for all yet-ungenerated tokens, or None
                - `k_cache` (torch.Tensor)
                - `v_cache` (torch.Tensor)
                - `cache_mask` (torch.Tensor of shape [K 1 1 n_tok])
                - `kvcache` (Dict[str, torch.Tensor]): the kvcache dict with keys 'k_cache', 'v_cache', and 'cache_mask'
                - `all_k_cache` (torch.Tensor): the entire preallocated k-cache, including uninitialized tokens
                - `all_v_cache` (torch.Tensor): the entire preallocated v-cache, including uninitialized tokens
                - `n_tokens_per_patch` (int)
                - `sample_range` (Optional[Tuple[int,int]])
                - `idx_pos` (torch.LongTensor of shape [K I 4]): same as `pos`, but only for the candidate index tokens
                - `idx_weights` (torch.Tensor of shape [K I]): same as `weights`, but only for the candidate index tokens
                - `device` (torch.device)\n
                The callback must return the following
                - (Union[int, torch.LongTensor]): the *index* into the `idx` tensor with the patch token to generate next (*not* the value of the patch index itself)
            unmask_parallel (bool):
                if True, all parallel patches can attend to each other. If False (default), parallel patches can only attend to themselves
            return_logits (bool):
                return the logits of the sequence
            return_idx_logits (bool):
                if True (default), returns logits that would predict index tokens, so there is one set of logits for every returned token. If False, only returns
                logits used to sample content tokens, e.g. returns (n_tokens_per_patch - 1) sets of logits per patch. The latter may not need to compute logits
                for all tokens, so it may be more efficient for some computations (such as patchwise entropy). Ignored if return_logits=False
            return_kv (bool):
                return the KV cache(s) as a dict with keys 'k_cache', 'v_cache', and 'cache_mask', useful for downstream operations. If True and return_logits=False,
                returns (new_tokens, kvcache). If True and return_logits=True, returns (new_tokens, logits, kvcache). **All KVs are returned in patch-major order,** even if
                the rollout is partially or fully parallel. Note that KVs from parallel prediction are not computed causally
        Returns:
            torch.LongTensor:
                [num_new_tokens] the generated tokens only (the input sequence is not prepended)
            torch.Tensor:
                (optional) [n_tokens vocab_size] the logits of the sequence, where n_tokens depends on return_idx_logits
            Dict[str, torch.Tensor]:
                (optional) the KV cache, with the following key/value pairs
                - `k_cache` (torch.Tensor) [n_layer K n_head n_tok n_embd//n_head]
                - `v_cache` (torch.Tensor) [n_layer K n_head n_tok n_embd//n_head]
                - `cache_mask` (torch.Tensor) [K 1 1 n_tok]
        """
        #########################
        # === Preprocessing === #
        #########################
        # Normalize seq/pos into lists of per-sequence tensors
        if not isinstance(seq, list):
            seq = [seq] if seq is None or seq.ndim == 1 else list(seq)
        if not isinstance(pos, list):
            pos = [pos] if pos.ndim == 2 else list(pos)
        nnt = idx.numel() * n_tokens_per_patch # num new tokens
        device = pos[0].device
        idtype = pos[0].dtype
        dtype = self.lm_head.weight.dtype
        if weights is not None:
            if isinstance(weights, list):
                weights = torch.tensor(weights, dtype=dtype, device=device)
            if weights.ndim != 2:
                # Broadcast a per-sequence weight to every new token
                weights = weights.unsqueeze(-1).expand(-1, nnt)
            weights = weights.to(dtype).to(device)
        elif len(seq) > 1:
            # Multiple sequences but no weights: default to uniform ones
            weights = torch.ones(len(seq), nnt, dtype=dtype, device=device)
        if n_seq_patches < 0:
            n_seq_patches = idx.shape[0]
        # Expand scalar sampling parameters into per-token schedules
        if temp is not None and not isinstance(temp, list):
            temp = [temp] * nnt
        if post_temp is not None and not isinstance(post_temp, list):
            post_temp = [post_temp] * nnt
        if top_k is not None and not isinstance(top_k, list):
            top_k = [top_k] * nnt
        if top_p is not None and not isinstance(top_p, list):
            top_p = [top_p] * nnt
        if min_p is not None and not isinstance(min_p, list):
            min_p = [min_p] * nnt
        K = len(seq)
        I = idx.shape[0]
        T = [0 if s is None else s.shape[0] for s in seq]
        # maxT >= 1 so an all-empty conditioning batch still has one (pad) token
        maxT = max(1, max(T))
        tpp = n_tokens_per_patch
        in_cache_size = 0 if k_cache is None else k_cache.shape[3]
        n_rollout_tokens = tpp * n_seq_patches
        n_par_patches = I - n_seq_patches
        return_idx_logits = return_logits and return_idx_logits
        # The final parallel pass is only needed to produce index logits and/or KVs
        run_last_parallel_tokens = return_idx_logits or return_kv
        # Validate inputs as best we can
        assert len(pos) == K, f'Expected seq and pos lists to have the same length, but got {K} and {len(pos)}'
        assert idx.ndim == 1
        if weights is not None:
            assert weights.ndim == 2 and weights.shape == (K, nnt)
        assert I * tpp == nnt, f'Requested {nnt} new tokens, but ({I} idx tokens) * ({tpp} tok per patch) = {I*tpp} != {nnt}'
        assert 0 <= n_seq_patches <= I
        assert k_cache is None or k_cache.ndim == 5
        assert v_cache is None or v_cache.ndim == 5
        assert (k_cache is None) == (v_cache is None)
        assert k_cache is None or k_cache.shape[3] == v_cache.shape[3]
        if cache_mask is not None:
            assert cache_mask.ndim == 4 and cache_mask.shape[1] == 1 and cache_mask.shape[2] == 1
            assert cache_mask.shape[-1] == in_cache_size, f'cache_mask ({cache_mask.shape[-1]} tokens) does not match the size of k_cache/v_cache ({in_cache_size} tokens)'
        for i, (s, p) in enumerate(zip(seq, pos)):
            if s is not None:
                assert s.ndim == 1, f'Expected all sequence tensors to be 1D, but got seq[{i}].ndim={s.ndim}'
            assert p.ndim == 2, f'Expected all position tensors to be 2D, but got pos[{i}].ndim={p.ndim}'
            assert p.shape[1] == 4, f'Expected all position tensors have shape (*,4), but got pos[{i}].shape[1]={p.shape[1]}'
            assert p.shape[0] == T[i] + nnt, f'Sequence {i}: With {T[i]} conditioning and {nnt} new tokens, expected pos[{i}].shape[0]={T[i]+nnt}, but got {p.shape[0]}'
        #########################
        # === Preallocation === #
        #########################
        # Preallocate the KV cache so we don't need to constantly resize it
        # If we won't run the last parallel pass, we don't need to cache those toks
        n_kvcache = in_cache_size + maxT + nnt - (0 if run_last_parallel_tokens else n_par_patches)
        # [n_layer K n_head n_tok n_embd//n_head]
        kv_shape = (
            self.config.n_layer, K, self.config.n_head,
            n_kvcache, self.config.n_embd // self.config.n_head
        )
        all_v_cache = torch.empty(kv_shape, dtype=dtype, device=device)
        all_k_cache = torch.empty(kv_shape, dtype=dtype, device=device)
        if in_cache_size > 0:
            # Copy the provided cache into the front of the preallocated buffers
            all_k_cache[...,:in_cache_size,:].copy_(k_cache, non_blocking=True)
            all_v_cache[...,:in_cache_size,:].copy_(v_cache, non_blocking=True)
        # Also preallocate the output logits tensor, if requested
        if return_logits:
            n_logits = n_rollout_tokens + (I - n_seq_patches) * (tpp if return_idx_logits else (tpp - 1))
            all_logits = torch.empty((n_logits, self.config.vocab_size), dtype=dtype, device=device)
        ################################
        # === Initial Forward Pass === #
        ################################
        # Stack seq/pos into a batch, left-padded
        # [K maxT]
        seq = torch.stack([(
            torch.zeros(maxT, dtype=idtype, device=device) if s is None else
            F.pad(s, (maxT - s.shape[0], 0))
        ) for s in seq])
        # [K maxN 4]
        pos = torch.stack([F.pad(p, (0, 0, maxT - t, 0)) for t, p in zip(T, pos)])
        # Build attention mask for initial forward pass [K 1 maxT maxT]
        # Batch size K, each mask in the batch is fully causal except for the first (maxT - T) tokens, which are masked
        mask = torch.zeros(K, 1, maxT, maxT, device=device)
        mask.masked_fill_(torch.ones_like(mask, dtype=torch.bool).triu(1), float('-inf'))
        for i, t in enumerate(T):
            mask[i, ..., :maxT-t] = float('-inf')
            # Unmask the diagonal so pad tokens can self-attend
            # This doesn't matter with torch sdpa, but prevents NaNs with manual attention
            # NOTE: If t==0, the diagonal is *only* pad tokens, so this will unmask the last pad token
            # in the last row (which we use for rollouts). We re-mask this pad token in the rollout mask below
            mask[i, 0].fill_diagonal_(0.0)
        if k_cache is not None:
            if cache_mask is not None:
                mask = torch.cat([cache_mask.to(mask.device).expand((K, 1, maxT, -1)), mask], dim=-1)
            else:
                # No cache_mask given: the provided cache is fully attendable
                mask = F.pad(mask, (in_cache_size, 0, 0, 0, 0, 0, 0, 0))
        # The above mask[:,0,-1,:] might look something like this:
        #           kv cache    | sequences
        # T[0]==3 [ T T T T T T | F F F F T T T ]
        # T[1]==6 [ T T T T T T | F T T T T T T ]
        # T[2]==7 [ T T T T T T | T T T T T T T ]
        # T[3]==4 [ T T T T T T | F F F T T T T ]
        # For one element in the batch, mask[0,0,:,:] with T[0]==3 might look like:
        #   kv cache    | sequences
        # [ T T T T T T | T F F F F F F ]
        # [ T T T T T T | F T F F F F F ]
        # [ T T T T T T | F F T F F F F ]
        # [ T T T T T T | F F F T F F F ]
        # [ T T T T T T | F F F F T F F ]
        # [ T T T T T T | F F F F T T F ]
        # [ T T T T T T | F F F F T T T ]
        # If a custom cache_mask is given, the kv cache part above may be different
        # Initial forward pass (conditioning sequences only)
        k_cache = all_k_cache[...,:in_cache_size+maxT,:]
        v_cache = all_v_cache[...,:in_cache_size+maxT,:]
        self.forward(
            seq=seq, pos=pos[:,:maxT], mask=mask,
            k_cache=k_cache, v_cache=v_cache, inplace_kv=True
        )
        # Drop conditioning positions; pos now covers only the tokens to generate
        pos = pos[:,maxT:]
        ##############################
        # === Sequential Rollout === #
        ##############################
        # Build attention mask for rollout [K 1 1 maxT+n_rollout_tokens], clone to free memory
        mask = F.pad(mask[...,[-1],:].clone(), (0, n_rollout_tokens, 0, 0, 0, 0, 0, 0))
        for i, t in enumerate(T):
            # If t==0, the fill_diagonal call above unmasked the last pad token,
            # so we need to re-mask it before we start rolling out
            if t == 0:
                mask[i, ..., in_cache_size+maxT-1] = float('-inf')
        # The above mask[:,0,0,:] might look something like this:
        #           kv cache    | sequences     | rollout
        # T[0]==3 [ T T T T T T | F F F F T T T | T T T T ... T T T T ]
        # T[1]==6 [ T T T T T T | F T T T T T T | T T T T ... T T T T ]
        # T[2]==7 [ T T T T T T | T T T T T T T | T T T T ... T T T T ]
        # T[3]==4 [ T T T T T T | F F F T T T T | T T T T ... T T T T ]
        # We construct this mask once, then slice off part of the right side at each rollout step
        rollout_seq = []
        # Rollout
        for i in tqdm.tqdm(range(n_rollout_tokens), desc='Rollout', unit='tok', disable=(not verbose or n_rollout_tokens==0)):
            if (i % tpp) == 0:
                # Start of a new patch: the next token to feed is the patch index itself
                patch_number = i // tpp
                if policy is None:
                    # Use provided order
                    next_token = idx[patch_number]
                else:
                    # Use callback to select the next patch
                    policy_cache_mask = mask[..., :in_cache_size+maxT+i]
                    idx_of_next_idx = policy(
                        idx=idx[patch_number:], # [N]
                        pos=pos[:, i:], # [K N 4]
                        weights=None if weights is None else weights[:, i:], # [K N]
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_mask=policy_cache_mask, # [K 1 1 n_tok]
                        kvcache=dict(k_cache=k_cache, v_cache=v_cache, cache_mask=policy_cache_mask),
                        all_k_cache=all_k_cache,
                        all_v_cache=all_v_cache,
                        n_tokens_per_patch=n_tokens_per_patch,
                        sample_range=sample_range,
                        idx_pos=pos[:, i::tpp], # [K I 4]
                        idx_weights=None if weights is None else weights[:, i::tpp], # [K I]
                        device=idx.device,
                    )
                    del policy_cache_mask
                    # Move the patch patch_number+idx_of_next_idx to the next position by swapping
                    if idx_of_next_idx != 0: # Don't bother if it's already next
                        i1, i2 = patch_number, patch_number + int(idx_of_next_idx)
                        idx[[i1,i2]] = idx[[i2,i1]]
                        # [K N 4] -> [K I tpp 4] -swap-idxs-> [K I tpp 4] -> [K N 4]
                        pos = pos.reshape(K, I, tpp, 4)
                        pos[:,[i1,i2]] = pos[:,[i2,i1]]
                        pos = pos.reshape(K, -1, 4)
                    next_token = idx[patch_number]
                rollout_seq.append(next_token)
            # Forward call
            # Grow the active cache view by one token; forward writes KVs in-place
            k_cache = all_k_cache[...,:in_cache_size+maxT+i+1,:]
            v_cache = all_v_cache[...,:in_cache_size+maxT+i+1,:]
            logits, _ = self.forward(
                seq=next_token.expand(K, 1), # [K 1]
                pos=pos[:, [i]], # [K 1 4]
                mask=mask[..., :in_cache_size+maxT+i+1], # [K 1 1 n_prev+1]
                k_cache=k_cache,
                v_cache=v_cache,
                inplace_kv=True
            )
            # Weighted sum of logits [K 1 V] -> [1 V]
            if weights is not None:
                w = weights[:,[i]] # [K nnt] -> [K 1]
                logits = w.T @ logits.squeeze(1) # [1 V]
            else:
                logits = logits.squeeze(0) # [1 1 V] -> [1 V]
            # If next is an index token, we just needed to cache the previous token
            if ((i + 1) % tpp) == 0:
                if return_idx_logits:
                    all_logits[i].copy_(logits.squeeze())
                continue
            if return_logits:
                all_logits[i].copy_(logits.squeeze())
            # Sample from logits
            next_token = self.sample_logits(
                logits,
                temp=temp[i] if temp else None,
                post_temp=post_temp[i] if post_temp else None,
                top_k=top_k[i] if top_k else None,
                top_p=top_p[i] if top_p else None,
                min_p=min_p[i] if min_p else None,
                sample_range=sample_range,
                blacklist=blacklist
            ).squeeze(0)
            rollout_seq.append(next_token)
        if n_rollout_tokens > 0:
            rollout_seq = torch.stack(rollout_seq)
        # Fully sequential rollout: assemble and return early
        if n_rollout_tokens == nnt:
            ret = (rollout_seq,)
            if return_logits:
                ret = (*ret, all_logits)
            if return_kv:
                ret = (*ret, dict(k_cache=all_k_cache, v_cache=all_v_cache, cache_mask=mask))
            return ret if len(ret) > 1 else ret[0]
        ############################
        # === Parallel Rollout === #
        ############################
        npp = n_par_patches
        npt = npp * tpp # num parallel tokens
        # Only the last npp patches remain to be generated (in parallel)
        idx = idx[-npp:]
        # Build attention mask for parallel part [K 1 npp n_past+npp]
        if unmask_parallel:
            par_mask = torch.zeros(1, dtype=mask.dtype, device=device).expand(K, 1, npp, npp)
        else:
            par_mask = torch.full((npp, npp), float('-inf'), dtype=mask.dtype, device=device)
            par_mask.fill_diagonal_(0.0)
            par_mask = par_mask.expand(K, 1, npp, npp)
        mask = mask.expand(K, 1, npp, -1)
        # Initial shape is [K 1 npp n_past]. Before each iter, we append par_mask [K 1 npp npp] along the last dim
        # Reshape positions for parallel passes
        # [K nnt 4] -trim-> [K npt 4] -> [K npp tpp 4] -> [K tpp npp 4] -> [K npt 4]
        pos = pos[:,-npt:].reshape(K, npp, tpp, 4).transpose(1, 2).reshape(K, npt, 4)
        # Reshape scheduled properties (transpose similarly to positions)
        # [nnt] -trim-> [npt] -> [npp tpp] -> [tpp npp] -> [npt]
        # NOTE(review): numpy's ndarray.transpose takes an axes *permutation*, so
        # .transpose(0, 1) on a 2D array is a no-op (unlike torch's dim-swapping
        # .transpose(1, 2) used for pos/weights above). The [npp tpp] -> [tpp npp]
        # swap described in the comment therefore does NOT happen for these schedules.
        # Currently harmless because only index -1 is used for parallel sampling
        # (see the TODO near self.sample_logits below), but fix before implementing
        # scheduled parallel sampling — should presumably be .transpose(1, 0).
        if temp is not None:
            temp = np.array(temp[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
        if post_temp is not None:
            post_temp = np.array(post_temp[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
        if top_k is not None:
            top_k = np.array(top_k[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
        if top_p is not None:
            top_p = np.array(top_p[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
        if min_p is not None:
            min_p = np.array(min_p[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
        if weights is not None:
            # [K nnt] -trim-> [K npt] -> [K npp tpp] -> [K tpp npp] -> [K npt]
            weights = weights[:,-npt:].reshape(K, npp, tpp).transpose(1, 2).reshape(K, npt)
        next_tokens = idx
        parallel_seq = [next_tokens]
        # Run parallel passes: one pass per token position within a patch
        for i in range(tpp if run_last_parallel_tokens else (tpp - 1)):
            mask = torch.cat([mask, par_mask], dim=-1)
            parallel_slice = slice(i*npp, (i+1)*npp)
            k_cache = all_k_cache[...,:in_cache_size+maxT+n_rollout_tokens+(i+1)*npp,:]
            v_cache = all_v_cache[...,:in_cache_size+maxT+n_rollout_tokens+(i+1)*npp,:]
            logits, _ = self.forward(
                seq=next_tokens.expand(K, npp), # [K npp]
                pos=pos[:,parallel_slice], # [K npp 4]
                mask=mask, # [K 1 npp n_past+npp]
                k_cache=k_cache,
                v_cache=v_cache,
                inplace_kv=True
            )
            # Weighted sum of logits [K npp V] -> [npp V]
            if weights is not None:
                w = weights[:,parallel_slice] # [K npt] -> [K npp]
                logits = (logits * w.unsqueeze(-1)).sum(0) # [K npp V] -> [npp V]
            else:
                logits = logits.squeeze(0) # [1 npp V] -> [npp V]
            if return_logits and (i < tpp - 1 or return_idx_logits):
                # Store the logits with a stride so we don't need to transpose later
                # NOTE: If return_idx_logits=False, we have tpp-1 instead of tpp
                stride = tpp if return_idx_logits else (tpp - 1)
                all_logits[n_rollout_tokens+i::stride].copy_(logits)
            if i == (tpp - 1):
                # We just needed to compute logits and/or KV to return them; no need to predict
                break
            # Sample from logits
            # TODO: Index using parallel_slice instead of -1 to support scheduled parameters
            next_tokens = self.sample_logits(
                logits,
                temp=temp[-1] if temp else None,
                post_temp=post_temp[-1] if post_temp else None,
                top_k=top_k[-1] if top_k else None,
                top_p=top_p[-1] if top_p else None,
                min_p=min_p[-1] if min_p else None,
                sample_range=sample_range,
                blacklist=blacklist
            )
            parallel_seq.append(next_tokens)
        # [tpp npp] -> [npp tpp] -> [npt]
        parallel_seq = torch.stack(parallel_seq).transpose(0, 1).flatten()
        if return_kv:
            # Transpose only the last npp (num parallel patches) patches
            # so returned KVs are in patch-major order, matching the sequential part
            # [... npt E] -> [... tpp npp E] -> [... npp tpp E] -> [... npt E]
            edims = all_k_cache.shape[:-2]
            par_k = all_k_cache[...,-npt:,:].reshape(*edims, tpp, npp, -1).transpose(-2, -3).reshape(*edims, npt, -1)
            par_v = all_v_cache[...,-npt:,:].reshape(*edims, tpp, npp, -1).transpose(-2, -3).reshape(*edims, npt, -1)
            all_k_cache[...,-npt:,:].copy_(par_k.clone())
            all_v_cache[...,-npt:,:].copy_(par_v.clone())
            del par_k, par_v
        # [K 1 npp N] -> [K 1 1 N], clone to free memory, doesn't matter which row we take (only different in last npt cols)
        # Unmask the last npt cols (if necessary) to make the parallel part fully unmasked
        mask = mask[...,[-1],:].clone()
        if not unmask_parallel:
            mask[...,-npt:] = 0.0
        ret = (torch.cat([rollout_seq, parallel_seq]),) if n_rollout_tokens > 0 else (parallel_seq,)
        if return_logits:
            ret = (*ret, all_logits)
        if return_kv:
            ret = (*ret, dict(k_cache=all_k_cache, v_cache=all_v_cache, cache_mask=mask))
        return ret if len(ret) > 1 else ret[0]