| | """ |
| | PSI Model Definition |
| | """ |
| |
|
| |
|
| | import math |
| | import importlib |
| | from typing import Tuple, Union, List, Optional, Callable, Dict |
| | from transformers import PreTrainedModel |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import numpy as np |
| | import tqdm |
| |
|
| | from .config import PSIConfig |
| | from .modeling import ( |
| | RMSNorm, PSIBlock, PartitionedEmbedding, PartitionedLinear |
| | ) |
| |
|
# torch_xla is an optional dependency (TPU support). If it is not installed,
# fall back to None sentinels so CPU/GPU execution works unchanged.
try:
    from torch_xla.core import xla_model as xm
    from torch_xla.distributed.spmd import xla_sharding as xs
except ImportError:
    xm = None
    xs = None
| |
|
| |
|
| |
|
class PSI(PreTrainedModel):
    """PSI autoregressive transformer, wrapped as a HuggingFace PreTrainedModel.

    Architecture (all sizes taken from the PSIConfig): a token embedding plus a
    channel embedding feed a stack of `n_layer` PSIBlocks, followed by an RMSNorm
    and a linear LM head.
    """

    # Config class used by transformers' save/from_pretrained machinery.
    config_class = PSIConfig

    def __init__(self, config: PSIConfig) -> None:
        """Build embeddings, transformer blocks, and the LM head from `config`.

        Parameters:
            config (PSIConfig): model hyperparameters (vocab_size, n_embd,
                n_layer, channel_size, dropout, bias, and the optional
                partition_embedding / n_lm_vocab / tie_weights flags).
        """
        super().__init__(config)
        self.config = config

        # Optionally shard the (large) vocabulary embedding and LM head across
        # several partitioned sub-layers; otherwise use plain nn modules.
        if hasattr(config, "partition_embedding") and config.partition_embedding:
            token_embedding = PartitionedEmbedding(config.vocab_size, config.n_embd)
            lm_head = PartitionedLinear(config.n_embd, config.vocab_size, bias=False)
        else:
            token_embedding = nn.Embedding(config.vocab_size, config.n_embd)
            # The output vocabulary may be smaller than the input vocabulary.
            if hasattr(config, "n_lm_vocab") and config.n_lm_vocab is not None:
                n_lm_vocab = config.n_lm_vocab
            else:
                n_lm_vocab = config.vocab_size
            lm_head = nn.Linear(config.n_embd, n_lm_vocab, bias=False)

        self.transformer = nn.ModuleDict(dict(
            token_embedding = token_embedding,
            channel_embedding = nn.Embedding(config.channel_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([PSIBlock(config) for _ in range(config.n_layer)]),
            ln_f = RMSNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = lm_head

        # Base init for all modules, then GPT-2-style scaled init for the
        # residual-path projections (std shrinks with depth).
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # Optionally tie the input embedding and output head weights
        # (per-partition when the embedding is partitioned).
        if hasattr(config, "tie_weights") and config.tie_weights:
            if hasattr(config, "partition_embedding") and config.partition_embedding:
                for i in range(len(self.transformer.token_embedding.embedding_layers)):
                    self.lm_head.linear_layers[i].weight = self.transformer.token_embedding.embedding_layers[i].weight
            else:
                self.lm_head.weight = self.transformer.token_embedding.weight

        # Probe for an XLA (TPU) device when torch_xla is importable.
        # NOTE(review): `xla_device_available` is never read afterwards — this
        # looks like dead code left over from an earlier sharding path; confirm
        # before removing (the probe itself may initialize the XLA runtime).
        xla_device_available = False
        if xm is not None:
            try:
                device_kind = xm.xla_device_kind()
                if device_kind is not None:
                    xla_device_available = True
            except RuntimeError:
                pass

        # Parameter count recorded before any external sharding is applied.
        self.unsharded_param_count = self.get_num_params()
| |
|
| | def _init_weights(self, module): |
| | if isinstance(module, nn.Linear): |
| | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) |
| | if module.bias is not None: |
| | torch.nn.init.zeros_(module.bias) |
| | elif isinstance(module, nn.Embedding): |
| | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) |
| |
|
| | def get_num_params(self): |
| | """Return the number of parameters in the model.""" |
| | return sum(p.numel() for p in self.parameters()) |
| |
|
| | |
| | |
| |
|
| | def forward( |
| | self, |
| | seq: torch.Tensor, |
| | pos: torch.Tensor, |
| | tgt: torch.Tensor = None, |
| | mask: torch.Tensor = None, |
| | k_cache: torch.Tensor = None, |
| | v_cache: torch.Tensor = None, |
| | return_kv: bool = False, |
| | inplace_kv: bool = False, |
| | output_hidden_states: bool = False, |
| | ) -> torch.Tensor: |
| | """ |
| | Forward pass of the model |
| | |
| | Parameters: |
| | seq (torch.Tensor) of size b, t: The input sequence |
| | pos (torch.Tensor) of size b, t, d: The positional indices of the sequence of shape (batch, tokens, dimensions) |
| | They consist of x, y, t and c coordinates, where x, y are the spatial coordinates of the patch, |
| | t is the time index and c is the channel index |
| | tgt (torch.Tensor) of size b, t_tgt: The target sequence |
| | mask (torch.Tensor) of size b, t, t: The mask of the sequence |
| | k_cache (torch.Tensor) of size n_layer, b, n_head, n, n_embd//n_head: A k-cache to prepend |
| | v_cache (torch.Tensor) of size n_layer, b, n_head, n, n_embd//n_head: A v-cache to prepend |
| | return_kv (bool): If True, returns (logits, k, v). Ignored if tgt != None |
| | inplace_kv (bool): If True, k_cache/v_cache are modified in-place. They must be sufficiently large to store |
| | the new tokens, and the last N tokens will be overwritten. If False (default), the input kv will not be |
| | modified, and a concat operation will be used instead. No effect if k_cache/v_cache are None. |
| | |
| | Returns: |
| | torch.Tensor: The logits of the model. Size b, t if tgt is None, else b, t_tgt |
| | if tgt != None: |
| | torch.Tensor: The cross entropy loss between the logits and tgt |
| | elif return_k: |
| | torch.Tensor: the k-cache |
| | torch.Tensor: the v-cache |
| | """ |
| |
|
| | st_pos = pos[:, :, :-1] |
| | channel_pos = pos[:, :, -1] |
| |
|
| | |
| | tok_emb = self.transformer.token_embedding(seq) |
| | channel_emb = self.transformer.channel_embedding(channel_pos) |
| | x = self.transformer.drop(tok_emb + channel_emb) |
| |
|
| | if output_hidden_states: |
| | hidden_states = [x] |
| |
|
| | k_list, v_list = [], [] |
| | for i, block in enumerate(self.transformer.h): |
| | x = block(x, pos=st_pos, mask=mask, |
| | k_cache=None if k_cache is None else k_cache[i], |
| | v_cache=None if v_cache is None else v_cache[i], |
| | return_kv=return_kv, inplace_kv=inplace_kv) |
| | if return_kv: |
| | x, k, v = x |
| | k_list.append(k) |
| | v_list.append(v) |
| | if output_hidden_states: |
| | hidden_states.append(x) |
| |
|
| | x = self.transformer.ln_f(x) |
| | if output_hidden_states: |
| | hidden_states.append(x) |
| |
|
| | |
| | if tgt is None: |
| | logits = self.lm_head(x) |
| | if output_hidden_states: |
| | logits = {"logits": logits, "hidden_states": hidden_states} |
| | if return_kv: |
| | if inplace_kv: |
| | |
| | return logits, k_cache, v_cache |
| | else: |
| | return logits, torch.stack(k_list), torch.stack(v_list) |
| | return logits, None |
| | |
| | |
| | logits = self.lm_head(x[:, -tgt.size(1):]) |
| | if output_hidden_states: |
| | logits = {"logits": logits, "hidden_states": hidden_states} |
| |
|
| | loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), tgt.reshape(-1), ignore_index=-1) |
| | return logits, loss |
| |
|
| |
|
| | |
| |
|
    @torch.no_grad()
    def sample_logits(self,
        logits: torch.FloatTensor,
        temp: Optional[float] = None,
        post_temp: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        min_p: Optional[float] = None,
        sample_range: Optional[Tuple[int,int]] = None,
        blacklist: Optional[Union[List[int], torch.LongTensor]] = None
    ) -> torch.LongTensor:
        """
        Samples an integer from the distribution of logits

        NOTE: `logits` is mutated in place (blacklisting, temperature division,
        top-p/min-p masking, softmax) — pass a clone if the caller still needs it.

        Parameters:
            logits (torch.FloatTensor): The logits of the distribution
            temp (float): The temperature of the sampling, applied before truncation;
                if 0.0, then argmax is used
            post_temp (float): A second temperature applied after top-k/top-p/min-p
                truncation; if 0.0, then argmax is used
            top_k (int): The number of top k tokens to consider during sampling
            top_p (float): The cumulative probability threshold for nucleus (top-p) sampling
            min_p (float): The minimum probability threshold factor for min-p sampling
            sample_range (Tuple[int,int]): Restrict sampling to token ids in
                [sample_range[0], sample_range[1])
            blacklist (Union[List[int], torch.LongTensor]): The list of tokens to blacklist during sampling
        Returns:
            torch.LongTensor: The sampled integers
        """
        # Per-token schedules may leak in from rollout_patches; use the first entry.
        if isinstance(temp, list):
            temp = temp[0]
        if isinstance(post_temp, list):
            post_temp = post_temp[0]
        if isinstance(top_k, list):
            top_k = top_k[0]
        if isinstance(top_p, list):
            top_p = top_p[0]
        assert temp is None or temp >= 0.0
        assert post_temp is None or post_temp >= 0.0
        assert top_k is None or top_k > 0
        assert top_p is None or top_p >= 0.0
        assert min_p is None or 0.0 <= min_p <= 1.0
        assert sample_range is None or (
            sample_range[0] < sample_range[1] and
            sample_range[0] >= 0 and
            sample_range[1] <= logits.shape[-1]
        )

        # Forbid blacklisted tokens, then optionally restrict to a token range.
        # Sampled indices are later shifted back by token_offset.
        if blacklist is not None:
            logits[...,blacklist] = float('-inf')
        if sample_range is not None:
            logits = logits[...,sample_range[0]:sample_range[1]]
            token_offset = sample_range[0]
        else:
            token_offset = 0

        # Temperature 0 (pre or post) means greedy decoding.
        if (temp is not None and temp == 0.0) or (post_temp is not None and post_temp == 0.0):
            return token_offset + torch.argmax(logits, dim=-1)
        if temp is not None and temp != 1.0:
            logits.div_(temp)  # in-place pre-truncation temperature

        # top-k / top-p both need logits sorted descending; `order` maps
        # sorted positions back to original token ids.
        if top_k is not None or top_p is not None:
            logits, order = torch.sort(logits, dim=-1, descending=True)
        else:
            order = None

        # top-k: keep only the k most likely tokens (already sorted).
        if top_k is not None:
            logits = logits[...,:top_k]

        # top-p (nucleus): drop tokens once cumulative probability exceeds top_p.
        # The one-position shift keeps the first token that crosses the threshold.
        if top_p is not None:
            probs = F.softmax(logits, dim=-1)
            cumulative_probs = probs.cumsum_(dim=-1)  # in-place on probs
            idxs_to_remove = cumulative_probs > top_p
            logits[...,1:][idxs_to_remove[...,:-1]] = float('-inf')
            del probs, cumulative_probs, idxs_to_remove

        # min-p: drop tokens whose probability is below min_p * max probability.
        # If sorted, the max is simply the first entry.
        if min_p is not None:
            probs = F.softmax(logits, dim=-1)
            maxprob = probs[...,[0]] if order is not None else torch.max(probs, dim=-1, keepdim=True).values
            logits[probs < maxprob * min_p] = float('-inf')
            del probs, maxprob

        # Post-truncation temperature (reshapes the surviving distribution).
        if post_temp is not None and post_temp != 1.0:
            logits.div_(post_temp)

        # Sample from the (flattened) categorical distribution.
        orig_shape = logits.shape
        probs = torch.softmax(logits, dim=-1, out=logits)  # softmax in place
        flat_probs = probs.view(-1, probs.size(-1))
        sampled = torch.multinomial(flat_probs, num_samples=1)
        sampled = sampled.view(*orig_shape[:-1])

        # Map sorted positions back to original token ids, then undo the
        # sample_range offset.
        if order is not None:
            sampled = torch.gather(order, dim=-1, index=sampled.unsqueeze(-1)).squeeze(-1)
        return token_offset + sampled
| |
|
| |
|
    @torch.no_grad()
    def rollout_patches(self,
        seq: Union[Optional[torch.LongTensor], List[Optional[torch.LongTensor]]],
        pos: Union[torch.LongTensor, List[torch.LongTensor]],
        idx: torch.LongTensor,
        n_tokens_per_patch: int = 5,
        n_seq_patches: int = -1,
        weights: Optional[Union[List[float], torch.Tensor]] = None,
        k_cache: Optional[torch.Tensor] = None,
        v_cache: Optional[torch.Tensor] = None,
        cache_mask: Optional[torch.Tensor] = None,
        policy: Callable[..., torch.LongTensor] = None,
        *,
        unmask_parallel: bool = False,
        return_logits: bool = False,
        return_idx_logits: bool = True,
        return_kv: bool = False,
        verbose: bool = True,
        temp: Optional[Union[float, List[float]]] = None,
        post_temp: Optional[Union[float, List[float]]] = None,
        top_k: Optional[Union[int, List[int]]] = None,
        top_p: Optional[Union[float, List[float]]] = None,
        min_p: Optional[Union[float, List[float]]] = None,
        sample_range: Optional[Tuple[int, int]] = None,
        blacklist: Optional[Union[List[int], torch.LongTensor]] = None
    ) -> Union[
        torch.LongTensor,
        Tuple[torch.LongTensor, torch.Tensor],
        Tuple[torch.LongTensor, Dict[str, torch.Tensor]],
        Tuple[torch.LongTensor, torch.Tensor, Dict[str, torch.Tensor]],
    ]:
        """
        K = number of given sequences (1 if seq is not a list)
        T = length of a conditioning sequence (per sequence)
        N = total length of conditioning + generated tokens (per sequence)
        I = number of index tokens to roll out
        max(T) = maximum of T across sequences
        num_new_tokens = len(idx) * n_tokens_per_patch

        ***Tips for long rollouts***:
        1. Use `gc.collect()` and then `torch.cuda.empty_cache()` (in that order)
        2. Avoid fragmentation wherever possible. This method needs to allocate the entire KV cache
           contiguously in memory. If you create tensors between rollouts, consider moving them to
           CPU or cloning them to defragment VRAM.
        3. Even if a single N-token rollout fits in VRAM, running two consecutive rollouts (e.g. running
           n1 then giving its KV cache to n2, with n1 + n2 = N) may not fit, because reallocating the
           KV cache will duplicate memory. To avoid this, try moving the cache to CPU first:
           ```
           seq1, kvcache = predictor.rollout_patches(..., return_kv=True)
           kvcache = { k: v.cpu() for k, v in kvcache.items() }
           gc.collect(); torch.cuda.empty_cache()
           seq2 = predictor.rollout_patches(..., **kvcache)
           ```
           With this, the new KV cache will be allocated on GPU, and the CPU cache will be copied into it.

        TODO: Support temp/top_k/top_p/min_p scheduling with parallel. Currently uses index -1 for all parallel tokens

        Parameters:
            seq (Union[Optional[torch.LongTensor], List[Optional[torch.LongTensor]]]):
                [T], [K T], or list of [T] / sequence(s) to condition the generation on. None means empty sequence
            pos (Union[torch.LongTensor, List[torch.LongTensor]]):
                [N 4], [K N 4], or list of [N 4] / 4D position(s) corresponding to each sequence. If multiple, must have length K=len(seq)
            idx (torch.LongTensor):
                [I] / the patch indices to use in the rollout (same for all sequences)
            n_tokens_per_patch (int):
                number of tokens per patch, including patch index
            n_seq_patches (int):
                number of patches to roll out sequentially (-1 for all). The remaining patches will be parallel
            weights (Optional[Union[List[float], torch.Tensor]]):
                float weights for the logits produced by each sequence. If None and multiple sequences are given, defaults to all ones.
                If a list or 1D tensor, must have size K. If 2D, must have shape [K num_new_tokens].
                If 2D, the first n_seq_patches patches will use the weights in order as expected, but the remaining (parallel) patches may be arbitrarily reordered.
                When using parallel and a 2D weight schedule, it is recommended to make the weights for parallel patches uniform for consistency
            k_cache (Optional[torch.Tensor]):
                optional k_cache to prepend to all seqs, broadcastable to shape [n_layer K n_head n_tok n_embd//n_head]. May be on a different device
            v_cache (Optional[torch.Tensor]):
                optional v_cache to prepend to all seqs, broadcastable to shape [n_layer K n_head n_tok n_embd//n_head]. May be on a different device
            cache_mask (Optional[torch.Tensor]):
                optional mask to be applied to the provided KV cache with shape [K 1 1 n_tok], where n_tok matches k_cache/v_cache. Useful when a KV cache is supplied
                for multiple conditioning sequences of different lengths, where the cache_mask indicates which elements of the cache should be attended to for each sequence.
                If k_cache/v_cache are given and cache_mask is None, the cache will be fully unmasked. May be on a different device
            policy (Callable[..., torch.LongTensor]):
                optional callback defining the policy for rollout order. Must accept an argument `idx` (torch.LongTensor of shape [I]) with the candidate indices to generate next.
                Must return the *index* into the `idx` tensor to generate next, either an int or 0-dimensional torch.LongTensor. For example, given candidate indices [4,12,1023],
                return 2 to generate the patch with index 1023 next. Only used for the sequential part of the generation. The following kwargs are given
                - `idx` (torch.LongTensor of shape [I]) the candidate patch indices
                - `pos` (torch.LongTensor of shape [K N 4]): the remaining poses for all yet-ungenerated tokens in the same order as `idx`
                - `weights` (torch.Tensor of shape [K N]): the weights for all yet-ungenerated tokens, or None
                - `k_cache` (torch.Tensor)
                - `v_cache` (torch.Tensor)
                - `cache_mask` (torch.Tensor of shape [K 1 1 n_tok])
                - `kvcache` (Dict[str, torch.Tensor]): the kvcache dict with keys 'k_cache', 'v_cache', and 'cache_mask'
                - `all_k_cache` (torch.Tensor): the entire preallocated k-cache, including uninitialized tokens
                - `all_v_cache` (torch.Tensor): the entire preallocated v-cache, including uninitialized tokens
                - `n_tokens_per_patch` (int)
                - `sample_range` (Optional[Tuple[int,int]])
                - `idx_pos` (torch.LongTensor of shape [K I 4]): same as `pos`, but only for the candidate index tokens
                - `idx_weights` (torch.Tensor of shape [K I]): same as `weights`, but only for the candidate index tokens
                - `device` (torch.device)\n
                The callback must return the following
                - (Union[int, torch.LongTensor]): the *index* into the `idx` tensor with the patch token to generate next (*not* the value of the patch index itself)
            unmask_parallel (bool):
                if True, all parallel patches can attend to each other. If False (default), parallel patches can only attend to themselves
            return_logits (bool):
                return the logits of the sequence
            return_idx_logits (bool):
                if True (default), returns logits that would predict index tokens, so there is one set of logits for every returned token. If False, only returns
                logits used to sample content tokens, e.g. returns (n_tokens_per_patch - 1) sets of logits per patch. The latter may not need to compute logits
                for all tokens, so it may be more efficient for some computations (such as patchwise entropy). Ignored if return_logits=False
            return_kv (bool):
                return the KV cache(s) as a dict with keys 'k_cache', 'v_cache', and 'cache_mask', useful for downstream operations. If True and return_logits=False,
                returns (new_tokens, kvcache). If True and return_logits=True, returns (new_tokens, logits, kvcache). **All KVs are returned in patch-major order,** even if
                the rollout is partially or fully parallel. Note that KVs from parallel prediction are not computed causally

        Returns:
            torch.LongTensor:
                [num_new_tokens] the generated tokens only (the input sequence is not prepended)
            torch.Tensor:
                (optional) [n_tokens vocab_size] the logits of the sequence, where n_tokens depends on return_idx_logits
            Dict[str, torch.Tensor]:
                (optional) the KV cache, with the following key/value pairs
                - `k_cache` (torch.Tensor) [n_layer K n_head n_tok n_embd//n_head]
                - `v_cache` (torch.Tensor) [n_layer K n_head n_tok n_embd//n_head]
                - `cache_mask` (torch.Tensor) [K 1 1 n_tok]
        """

        # ---- input normalization ----
        # Promote single tensors to per-sequence lists (K sequences total).
        if not isinstance(seq, list):
            seq = [seq] if seq is None or seq.ndim == 1 else list(seq)
        if not isinstance(pos, list):
            pos = [pos] if pos.ndim == 2 else list(pos)

        nnt = idx.numel() * n_tokens_per_patch  # num_new_tokens
        device = pos[0].device
        idtype = pos[0].dtype   # integer dtype for token/position tensors
        dtype = self.lm_head.weight.dtype  # compute dtype for caches/logits

        # Normalize `weights` to a [K, nnt] tensor (or leave None for K == 1).
        if weights is not None:
            if isinstance(weights, list):
                weights = torch.tensor(weights, dtype=dtype, device=device)
            if weights.ndim != 2:
                # broadcast one weight per sequence to every new token
                weights = weights.unsqueeze(-1).expand(-1, nnt)
            weights = weights.to(dtype).to(device)
        elif len(seq) > 1:
            # multiple sequences but no weights: uniform ensemble
            weights = torch.ones(len(seq), nnt, dtype=dtype, device=device)

        # Expand scalar sampling params into per-token schedules of length nnt.
        if n_seq_patches < 0:
            n_seq_patches = idx.shape[0]
        if temp is not None and not isinstance(temp, list):
            temp = [temp] * nnt
        if post_temp is not None and not isinstance(post_temp, list):
            post_temp = [post_temp] * nnt
        if top_k is not None and not isinstance(top_k, list):
            top_k = [top_k] * nnt
        if top_p is not None and not isinstance(top_p, list):
            top_p = [top_p] * nnt
        if min_p is not None and not isinstance(min_p, list):
            min_p = [min_p] * nnt

        K = len(seq)
        I = idx.shape[0]
        T = [0 if s is None else s.shape[0] for s in seq]
        # maxT >= 1 so that empty sequences still get one (masked) dummy slot.
        maxT = max(1, max(T))

        tpp = n_tokens_per_patch
        in_cache_size = 0 if k_cache is None else k_cache.shape[3]
        n_rollout_tokens = tpp * n_seq_patches       # sequentially generated tokens
        n_par_patches = I - n_seq_patches            # patches generated in parallel
        return_idx_logits = return_logits and return_idx_logits
        # The last parallel step only matters if its logits or KVs are requested.
        run_last_parallel_tokens = return_idx_logits or return_kv

        # ---- argument validation ----
        assert len(pos) == K, f'Expected seq and pos lists to have the same length, but got {K} and {len(pos)}'
        assert idx.ndim == 1
        if weights is not None:
            assert weights.ndim == 2 and weights.shape == (K, nnt)
        assert I * tpp == nnt, f'Requested {nnt} new tokens, but ({I} idx tokens) * ({tpp} tok per patch) = {I*tpp} != {nnt}'
        assert 0 <= n_seq_patches <= I
        assert k_cache is None or k_cache.ndim == 5
        assert v_cache is None or v_cache.ndim == 5
        assert (k_cache is None) == (v_cache is None)
        assert k_cache is None or k_cache.shape[3] == v_cache.shape[3]
        if cache_mask is not None:
            assert cache_mask.ndim == 4 and cache_mask.shape[1] == 1 and cache_mask.shape[2] == 1
            assert cache_mask.shape[-1] == in_cache_size, f'cache_mask ({cache_mask.shape[-1]} tokens) does not match the size of k_cache/v_cache ({in_cache_size} tokens)'
        for i, (s, p) in enumerate(zip(seq, pos)):
            if s is not None:
                assert s.ndim == 1, f'Expected all sequence tensors to be 1D, but got seq[{i}].ndim={s.ndim}'
            assert p.ndim == 2, f'Expected all position tensors to be 2D, but got pos[{i}].ndim={p.ndim}'
            assert p.shape[1] == 4, f'Expected all position tensors have shape (*,4), but got pos[{i}].shape[1]={p.shape[1]}'
            assert p.shape[0] == T[i] + nnt, f'Sequence {i}: With {T[i]} conditioning and {nnt} new tokens, expected pos[{i}].shape[0]={T[i]+nnt}, but got {p.shape[0]}'

        # ---- preallocate the full KV cache ----
        # One contiguous buffer covering: incoming cache + conditioning + all new
        # tokens. The last parallel step's KVs are only stored when needed.
        n_kvcache = in_cache_size + maxT + nnt - (0 if run_last_parallel_tokens else n_par_patches)
        kv_shape = (
            self.config.n_layer, K, self.config.n_head,
            n_kvcache, self.config.n_embd // self.config.n_head
        )
        all_v_cache = torch.empty(kv_shape, dtype=dtype, device=device)
        all_k_cache = torch.empty(kv_shape, dtype=dtype, device=device)
        if in_cache_size > 0:
            # copy the provided cache into the front of the preallocated buffer
            all_k_cache[...,:in_cache_size,:].copy_(k_cache, non_blocking=True)
            all_v_cache[...,:in_cache_size,:].copy_(v_cache, non_blocking=True)

        # Preallocate the logits buffer if requested.
        if return_logits:
            n_logits = n_rollout_tokens + (I - n_seq_patches) * (tpp if return_idx_logits else (tpp - 1))
            all_logits = torch.empty((n_logits, self.config.vocab_size), dtype=dtype, device=device)

        # ---- pad conditioning sequences to a common length (left padding) ----
        # Empty sequences become a single zero dummy token.
        seq = torch.stack([(
            torch.zeros(maxT, dtype=idtype, device=device) if s is None else
            F.pad(s, (maxT - s.shape[0], 0))
        ) for s in seq])
        # Left-pad the positions correspondingly (pad rows of zeros at the front).
        pos = torch.stack([F.pad(p, (0, 0, maxT - t, 0)) for t, p in zip(T, pos)])

        # ---- build the prefill attention mask ----
        # Causal mask over the (padded) conditioning tokens; padded columns are
        # fully masked out per sequence.
        mask = torch.zeros(K, 1, maxT, maxT, device=device)
        mask.masked_fill_(torch.ones_like(mask, dtype=torch.bool).triu(1), float('-inf'))
        for i, t in enumerate(T):
            mask[i, ..., :maxT-t] = float('-inf')
            # Re-enable self-attention on the diagonal so fully-masked (padded)
            # rows do not produce NaN softmax outputs.
            mask[i, 0].fill_diagonal_(0.0)
        if k_cache is not None:
            if cache_mask is not None:
                # prepend the user-provided per-sequence cache mask
                mask = torch.cat([cache_mask.to(mask.device).expand((K, 1, maxT, -1)), mask], dim=-1)
            else:
                # no cache mask: the whole incoming cache is attendable (zeros)
                mask = F.pad(mask, (in_cache_size, 0, 0, 0, 0, 0, 0, 0))

        # ---- prefill: run all conditioning tokens at once ----
        # The blocks write the conditioning KVs into the preallocated buffer.
        k_cache = all_k_cache[...,:in_cache_size+maxT,:]
        v_cache = all_v_cache[...,:in_cache_size+maxT,:]
        self.forward(
            seq=seq, pos=pos[:,:maxT], mask=mask,
            k_cache=k_cache, v_cache=v_cache, inplace_kv=True
        )
        # Only the positions of the yet-to-be-generated tokens remain.
        pos = pos[:,maxT:]

        # ---- build the single-query decode mask ----
        # Keep only the last row (attention pattern of the newest token) and
        # extend it with zero (attendable) columns for the rollout tokens.
        mask = F.pad(mask[...,[-1],:].clone(), (0, n_rollout_tokens, 0, 0, 0, 0, 0, 0))
        for i, t in enumerate(T):
            # For originally-empty sequences, hide the dummy token that was
            # inserted during padding.
            if t == 0:
                mask[i, ..., in_cache_size+maxT-1] = float('-inf')

        rollout_seq = []

        # ---- sequential rollout: one token per forward call ----
        for i in tqdm.tqdm(range(n_rollout_tokens), desc='Rollout', unit='tok', disable=(not verbose or n_rollout_tokens==0)):
            if (i % tpp) == 0:
                # Patch boundary: the next token is an index token, either taken
                # in order from `idx` or chosen by the user policy.
                patch_number = i // tpp
                if policy is None:
                    next_token = idx[patch_number]
                else:
                    # Expose the currently-valid slice of the mask to the policy.
                    policy_cache_mask = mask[..., :in_cache_size+maxT+i]
                    idx_of_next_idx = policy(
                        idx=idx[patch_number:],
                        pos=pos[:, i:],
                        weights=None if weights is None else weights[:, i:],
                        k_cache=k_cache,
                        v_cache=v_cache,
                        cache_mask=policy_cache_mask,
                        kvcache=dict(k_cache=k_cache, v_cache=v_cache, cache_mask=policy_cache_mask),
                        all_k_cache=all_k_cache,
                        all_v_cache=all_v_cache,
                        n_tokens_per_patch=n_tokens_per_patch,
                        sample_range=sample_range,
                        idx_pos=pos[:, i::tpp],
                        idx_weights=None if weights is None else weights[:, i::tpp],
                        device=idx.device,
                    )
                    del policy_cache_mask
                    # Reorder: swap the chosen patch (and its positions) to the front.
                    if idx_of_next_idx != 0:
                        i1, i2 = patch_number, patch_number + int(idx_of_next_idx)
                        idx[[i1,i2]] = idx[[i2,i1]]
                        pos = pos.reshape(K, I, tpp, 4)
                        pos[:,[i1,i2]] = pos[:,[i2,i1]]
                        pos = pos.reshape(K, -1, 4)
                    next_token = idx[patch_number]
                rollout_seq.append(next_token)

            # Grow the active KV view by one token and run a single decode step.
            k_cache = all_k_cache[...,:in_cache_size+maxT+i+1,:]
            v_cache = all_v_cache[...,:in_cache_size+maxT+i+1,:]
            logits, _ = self.forward(
                seq=next_token.expand(K, 1),
                pos=pos[:, [i]],
                mask=mask[..., :in_cache_size+maxT+i+1],
                k_cache=k_cache,
                v_cache=v_cache,
                inplace_kv=True
            )

            # Ensemble the K sequences' logits (weighted sum), or drop the
            # batch dim when there is a single sequence.
            if weights is not None:
                w = weights[:,[i]]
                logits = w.T @ logits.squeeze(1)
            else:
                logits = logits.squeeze(0)

            # Last token of a patch: the forward pass predicted the *next*
            # patch's index token, which is supplied externally — store logits
            # if requested but do not sample.
            if ((i + 1) % tpp) == 0:
                if return_idx_logits:
                    all_logits[i].copy_(logits.squeeze())
                continue
            if return_logits:
                all_logits[i].copy_(logits.squeeze())

            # Sample the next content token with the scheduled parameters.
            next_token = self.sample_logits(
                logits,
                temp=temp[i] if temp else None,
                post_temp=post_temp[i] if post_temp else None,
                top_k=top_k[i] if top_k else None,
                top_p=top_p[i] if top_p else None,
                min_p=min_p[i] if min_p else None,
                sample_range=sample_range,
                blacklist=blacklist
            ).squeeze(0)
            rollout_seq.append(next_token)

        if n_rollout_tokens > 0:
            rollout_seq = torch.stack(rollout_seq)
        # Fully sequential rollout: return early, no parallel phase needed.
        if n_rollout_tokens == nnt:
            ret = (rollout_seq,)
            if return_logits:
                ret = (*ret, all_logits)
            if return_kv:
                ret = (*ret, dict(k_cache=all_k_cache, v_cache=all_v_cache, cache_mask=mask))
            return ret if len(ret) > 1 else ret[0]

        # ---- parallel rollout of the remaining patches ----
        npp = n_par_patches
        npt = npp * tpp
        idx = idx[-npp:]  # remaining (parallel) patch indices

        # Attention among the parallel tokens of the current step: either fully
        # connected (unmask_parallel) or each patch attends only to itself.
        if unmask_parallel:
            par_mask = torch.zeros(1, dtype=mask.dtype, device=device).expand(K, 1, npp, npp)
        else:
            par_mask = torch.full((npp, npp), float('-inf'), dtype=mask.dtype, device=device)
            par_mask.fill_diagonal_(0.0)
            par_mask = par_mask.expand(K, 1, npp, npp)
        # One query row per parallel patch.
        mask = mask.expand(K, 1, npp, -1)

        # Reorder positions token-major: all 1st tokens of every parallel patch,
        # then all 2nd tokens, etc., matching the step-by-step parallel decode.
        pos = pos[:,-npt:].reshape(K, npp, tpp, 4).transpose(1, 2).reshape(K, npt, 4)

        # Reorder the per-token sampling schedules the same way.
        # NOTE(review): numpy's .transpose(0, 1) is an axes *permutation* (a
        # no-op on 2D arrays), unlike torch's swap — presumably a token-major
        # reorder was intended. Harmless today since only index -1 is used for
        # parallel sampling (see TODO in the docstring), but confirm.
        if temp is not None:
            temp = np.array(temp[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
        if post_temp is not None:
            post_temp = np.array(post_temp[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
        if top_k is not None:
            top_k = np.array(top_k[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
        if top_p is not None:
            top_p = np.array(top_p[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
        if min_p is not None:
            min_p = np.array(min_p[-npt:]).reshape(npp, tpp).transpose(0, 1).flatten().tolist()
        if weights is not None:
            # token-major reorder of the weight schedule (torch transpose swaps)
            weights = weights[:,-npt:].reshape(K, npp, tpp).transpose(1, 2).reshape(K, npt)

        # Step 0 tokens are the patch index tokens themselves.
        next_tokens = idx
        parallel_seq = [next_tokens]

        # One step per intra-patch token position; all parallel patches advance
        # together. The final step is only run when its logits/KVs are needed.
        for i in range(tpp if run_last_parallel_tokens else (tpp - 1)):
            # Extend the mask with this step's parallel-attention columns.
            mask = torch.cat([mask, par_mask], dim=-1)
            parallel_slice = slice(i*npp, (i+1)*npp)
            k_cache = all_k_cache[...,:in_cache_size+maxT+n_rollout_tokens+(i+1)*npp,:]
            v_cache = all_v_cache[...,:in_cache_size+maxT+n_rollout_tokens+(i+1)*npp,:]
            logits, _ = self.forward(
                seq=next_tokens.expand(K, npp),
                pos=pos[:,parallel_slice],
                mask=mask,
                k_cache=k_cache,
                v_cache=v_cache,
                inplace_kv=True
            )

            # Ensemble across the K sequences, as in the sequential phase.
            if weights is not None:
                w = weights[:,parallel_slice]
                logits = (logits * w.unsqueeze(-1)).sum(0)
            else:
                logits = logits.squeeze(0)

            if return_logits and (i < tpp - 1 or return_idx_logits):
                # Scatter this step's logits into their patch-major slots.
                stride = tpp if return_idx_logits else (tpp - 1)
                all_logits[n_rollout_tokens+i::stride].copy_(logits)
            if i == (tpp - 1):
                # Final step was only run for its logits/KVs — nothing to sample.
                break

            # Sample the next token of every parallel patch at once.
            # TODO (see docstring): per-token scheduling; currently index -1.
            next_tokens = self.sample_logits(
                logits,
                temp=temp[-1] if temp else None,
                post_temp=post_temp[-1] if post_temp else None,
                top_k=top_k[-1] if top_k else None,
                top_p=top_p[-1] if top_p else None,
                min_p=min_p[-1] if min_p else None,
                sample_range=sample_range,
                blacklist=blacklist
            )
            parallel_seq.append(next_tokens)

        # Re-interleave from token-major [tpp, npp] back to patch-major order.
        parallel_seq = torch.stack(parallel_seq).transpose(0, 1).flatten()
        if return_kv:
            # The parallel KVs were written token-major; rewrite the tail of the
            # cache in patch-major order (clone() avoids overlapping copy).
            edims = all_k_cache.shape[:-2]
            par_k = all_k_cache[...,-npt:,:].reshape(*edims, tpp, npp, -1).transpose(-2, -3).reshape(*edims, npt, -1)
            par_v = all_v_cache[...,-npt:,:].reshape(*edims, tpp, npp, -1).transpose(-2, -3).reshape(*edims, npt, -1)
            all_k_cache[...,-npt:,:].copy_(par_k.clone())
            all_v_cache[...,-npt:,:].copy_(par_v.clone())
            del par_k, par_v

            # Downstream cache_mask: a single query row where all generated
            # parallel tokens are attendable.
            mask = mask[...,[-1],:].clone()
            if not unmask_parallel:
                mask[...,-npt:] = 0.0

        ret = (torch.cat([rollout_seq, parallel_seq]),) if n_rollout_tokens > 0 else (parallel_seq,)
        if return_logits:
            ret = (*ret, all_logits)
        if return_kv:
            ret = (*ret, dict(k_cache=all_k_cache, v_cache=all_v_cache, cache_mask=mask))
        return ret if len(ret) > 1 else ret[0]
| |
|
| |
|