| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | from dataclasses import dataclass
|
| | from collections.abc import Iterable
|
| | from typing import Any, Dict, Optional, Tuple, Union
|
| |
|
| | import torch
|
| | try:
|
| | import torch_npu
|
| | except ImportError as e:
|
| | pass
|
| | import torch.distributions as dists
|
| | from torch.nn import functional as F
|
| |
|
| | from transformers.cache_utils import Cache, DynamicCache
|
| |
|
| |
|
def top_p_logits(logits, top_p=None):
    """Nucleus (top-p) filtering: mask out the tail of the distribution.

    Tokens are sorted by logit; the smallest set whose cumulative probability
    exceeds ``top_p`` is kept and everything after it is set to the dtype's
    minimum so it gets ~zero probability after softmax.

    Args:
        logits: unnormalized scores, last dim is the vocabulary.
        top_p: cumulative-probability cutoff in (0, 1]; callers guarantee
            it is not None here.

    Returns:
        Logits with the filtered entries replaced by ``finfo(dtype).min``.
    """
    desc_logits, desc_indices = torch.sort(logits, descending=True)
    cum_probs = F.softmax(desc_logits, dim=-1).cumsum(dim=-1)
    remove_sorted = cum_probs > top_p

    # Shift right by one so the first token that crosses the threshold is
    # still kept; the top-1 token is never removed.
    remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
    remove_sorted[..., 0] = 0

    # Scatter the sorted-order removal flags back to vocabulary order.
    remove_mask = torch.zeros_like(logits, dtype=torch.bool)
    remove_mask = remove_mask.scatter_(-1, desc_indices, remove_sorted)
    return logits.masked_fill(remove_mask, torch.finfo(logits.dtype).min)
|
| |
|
| |
|
def top_k_logits(logits, top_k=None):
    """Top-k filtering: keep only the ``top_k`` largest logits per row.

    Everything strictly below the k-th largest value is replaced with the
    dtype's minimum so it vanishes after softmax. ``top_k`` is clamped to
    the vocabulary size; callers guarantee it is not None here.
    """
    k = min(top_k, logits.size(-1))

    # k-th largest value per row, kept broadcastable via the trailing None.
    kth_value = torch.topk(logits, k)[0][..., -1, None]
    below_kth = logits < kth_value
    return logits.masked_fill(below_kth, torch.finfo(logits.dtype).min)
|
| |
|
| |
|
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    """Sample token ids from logits and compute a per-position confidence.

    Args:
        logits: unnormalized scores; last dim is the vocabulary.
        temperature: > 0 enables stochastic sampling (logits are divided by
            it); 0 means greedy argmax.
        top_p: optional nucleus cutoff, applied when ``< 1``.
        top_k: optional top-k cutoff.
        margin_confidence: if True, confidence is (top1 - top2) probability.
        neg_entropy: if True, confidence is the negative entropy
            ``sum(p * log p)`` (higher = more peaked). Overrides
            ``margin_confidence`` when both are set, matching original order.

    Returns:
        (confidence, x0): confidence scores and sampled token ids, both with
        the logits' leading dims.
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except (ValueError, RuntimeError):
            # BUGFIX: was a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit. Categorical raises ValueError on
            # invalid (e.g. NaN/all-zero) probs and sampling can raise
            # RuntimeError; fall back to greedy decoding in those cases.
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # Generalized from `[:, 0]` / `[:, 1]` so logits of any rank work
        # (the old indexing silently assumed 2-D input).
        top1_probs = sorted_probs[..., 0]
        top2_probs = sorted_probs[..., 1]
        confidence = top1_probs - top2_probs

    if neg_entropy:
        epsilon = 1e-10  # guards log(0)
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0
|
| |
|
| |
|
class BlockDynamicCache(DynamicCache):
    """`DynamicCache` variant with an optional read-only mode.

    While ``skip_cache_update`` is True, `update` leaves the stored key/value
    cache untouched; instead it returns the cached states concatenated with
    the incoming states along the sequence dimension.

    Example:

    ```python
    >>> past_key_values = BlockDynamicCache()
    >>> past_key_values.skip_cache_update = True
    >>> outputs.past_key_values
    ```
    """
    def __init__(self, _distributed_cache_data: Optional[Iterable] = None) -> None:
        """Build the cache; the read-only flag starts out disabled."""
        super().__init__(_distributed_cache_data)
        self.skip_cache_update = False

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the key/value states for layer ``layer_idx``.

        With ``skip_cache_update`` False this defers entirely to the parent
        class, which appends the new states to the cache. With it True, the
        stored cache is left as-is and the result is
        ``cat([cached, new], dim=-2)`` for both keys and values.

        Parameters:
            key_states (`torch.Tensor`): new key states for this layer.
            value_states (`torch.Tensor`): new value states for this layer.
            layer_idx (`int`): index of the layer being updated.
            cache_kwargs (`dict[str, Any]`, `optional`): extra arguments
                forwarded to the parent class (unused by `DynamicCache`).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: the key and value states —
            a transient concatenation when ``skip_cache_update`` is True,
            otherwise whatever the parent class returns after updating.
        """
        if not self.skip_cache_update:
            return super().update(key_states, value_states, layer_idx, cache_kwargs)
        joined_keys = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
        joined_values = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return joined_keys, joined_values
|
| |
|
| |
|
| |
|
@torch.no_grad()
def diffusion_generate(
    model,
    inputs: Optional[torch.Tensor] = None,
    top_p: Optional[int] = None,
    top_k: Optional[int] = None,
    threshold: Optional[float] = 0.9,
    num_small_blocks: Optional[int] = 1,
    **kwargs,
):
    """Block-wise diffusion decoding with a KV cache.

    Generates ``max_length - prompt_length`` tokens in blocks of
    ``block_length``. Within a block (split into ``num_small_blocks``
    sub-blocks), masked positions are repeatedly re-predicted with a
    bidirectional (padding-only) attention mask and the most confident
    predictions are committed, without growing the KV cache; once a block is
    fully unmasked, one causal pass commits it to the cache.

    Args:
        model: HF-style causal LM returning ``.logits`` / ``.past_key_values``
            and exposing ``.device``.
        inputs: prompt token ids, shape (batch, prompt_length). Required.
        top_p, top_k: optional sampling filters, forwarded to sample_tokens.
        threshold: confidence cutoff used when ``alg == 'confidence_threshold'``.
        num_small_blocks: sub-blocks per block; must divide ``block_length``.
        **kwargs: ``block_length`` (default 32), ``attention_mask`` (required,
            (batch, prompt_length)), ``alg`` ('origin' | 'entropy' |
            'topk_margin' | 'confidence_threshold'), ``temperature``,
            ``mask_token_id`` (required), ``eos_token_id`` (required), and
            ``max_new_tokens`` or ``max_length``.

    Returns:
        Token ids of shape (batch, max_length); may still contain mask tokens
        if decoding stops early on the all-sequences-have-EOS check.

    Raises:
        ValueError: on missing required arguments or bad length/block setup.
    """
    block_length = kwargs.pop("block_length", 32)
    attention_mask = kwargs.pop("attention_mask", None)
    alg = kwargs.get("alg", 'origin')
    temperature = kwargs.get("temperature", 0.0)
    mask_token_id = kwargs.get("mask_token_id", None)
    eos_token_id = kwargs.get("eos_token_id", None)

    if mask_token_id is None:
        raise ValueError("mask_token_id must be provided")

    if eos_token_id is None:
        raise ValueError("eos_token_id must be provided")

    if inputs is None:
        raise ValueError("inputs must be provided")

    if attention_mask is None:
        raise ValueError("attention_mask must be provided")

    input_ids = inputs

    if type(kwargs.get('max_new_tokens', None)) is int:
        max_length = kwargs.get('max_new_tokens') + input_ids.shape[-1]
    elif kwargs.get('max_length', None) is None:
        raise ValueError("Pass max_new_tokens or max_length")
    else:
        # BUGFIX: a kwargs-provided max_length was validated but never bound,
        # so the code below raised NameError. Bind it explicitly.
        max_length = kwargs.get('max_length')

    prompt_length = input_ids.shape[1]
    if (max_length - prompt_length) % block_length != 0:
        raise ValueError(
            f"The token length ({max_length - prompt_length}) "
            f"cannot be evenly divided by the block length ({block_length})."
        )

    # Hoisted out of the generation loop (loop-invariant): fail fast before
    # any model call instead of after the prompt prefill.
    if block_length % num_small_blocks != 0:
        raise ValueError(
            f"block_length ({block_length}) must be divisible by num_small_blocks ({num_small_blocks})."
        )
    small_block_length = block_length // num_small_blocks

    num_blocks = (max_length - prompt_length) // block_length
    device = model.device

    # Pad the sequence out to max_length with mask tokens.
    x = F.pad(input_ids, (0, max_length - prompt_length), value=mask_token_id)

    past_key_values = BlockDynamicCache()

    causal_mask = torch.tril(torch.ones(max_length, max_length, device=device, dtype=torch.bool))[None, None, :, :]

    # Extend the 2-D padding mask over the generated region (always valid),
    # and derive position ids that skip left padding.
    padding_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
    position_ids = padding_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(padding_mask == 0, 1)

    # 4-D pairwise padding mask: (batch, 1, q_len, kv_len).
    padding_mask = torch.logical_and(
        padding_mask.unsqueeze(1).unsqueeze(-2),
        padding_mask.unsqueeze(1).unsqueeze(-1),
    )
    attention_mask = padding_mask & causal_mask

    if prompt_length > 0:
        # Causal prefill of the prompt; fills the KV cache.
        cur_x = x[:, :prompt_length]
        cur_attn_mask = attention_mask[:, :, :prompt_length, :prompt_length]
        cur_position_ids = position_ids[:, :prompt_length]
        output = model(cur_x,
                       attention_mask=cur_attn_mask,
                       position_ids=cur_position_ids,
                       past_key_values=past_key_values,
                       use_cache=True
                       )
        past_key_values = output.past_key_values

        # Seed the first token of the first block from the last prompt logit.
        logits = output.logits[:, -1:]
        confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k)
        x[:, prompt_length:prompt_length + 1] = x0

    for num_block in range(num_blocks):
        block_start = prompt_length + num_block * block_length
        block_end = prompt_length + (num_block + 1) * block_length
        # NOTE: cur_x is a view into x, so in-place commits below update x.
        cur_x = x[:, block_start:block_end]
        cur_attn_mask = attention_mask[:, :, block_start:block_end, :block_end]
        cur_padding_mask = padding_mask[:, :, block_start:block_end, :block_end]
        cur_position_ids = position_ids[:, block_start:block_end]

        # Denoising passes must not grow the cache; BlockDynamicCache then
        # returns cached + current states without storing them.
        past_key_values.skip_cache_update = True
        for small_block_idx in range(num_small_blocks):
            small_block_start = small_block_idx * small_block_length
            small_block_end = small_block_start + small_block_length

            while True:
                sub_mask_index = (cur_x[:, small_block_start:small_block_end] == mask_token_id)
                if sub_mask_index.sum() == 0:
                    break

                output = model(cur_x,
                               attention_mask=cur_padding_mask,
                               position_ids=cur_position_ids,
                               past_key_values=past_key_values,
                               use_cache=True)
                logits = output.logits
                # Shift logits right by one so logits[:, i] scores token i
                # (position 0 keeps its own logit as filler — it was seeded
                # by the previous block's causal pass).
                logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)
                logits = logits[:, small_block_start:small_block_end]

                confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k,
                                               neg_entropy=(alg == 'entropy'), margin_confidence=(alg == 'topk_margin'))
                # Only still-masked positions may be committed.
                confidence = torch.where(sub_mask_index, confidence, -torch.inf)
                # Always commit at least the single most confident position;
                # 'confidence_threshold' additionally commits everything
                # above the threshold.
                transfer_index = (F.one_hot(torch.max(confidence, dim=1)[1], num_classes=small_block_length) == 1)
                if alg == 'confidence_threshold':
                    transfer_index |= (confidence > threshold)
                cur_x[:, small_block_start:small_block_end][transfer_index] = x0[transfer_index]

                # Stop as soon as every sequence contains an EOS in the
                # generated region (remaining positions may stay masked).
                if eos_token_id and (x[:, prompt_length:] == eos_token_id).any(dim=1).all():
                    return x

        # Final causal pass over the finished block: commit it to the cache.
        past_key_values.skip_cache_update = False
        output = model(cur_x,
                       attention_mask=cur_attn_mask,
                       position_ids=cur_position_ids,
                       past_key_values=past_key_values,
                       use_cache=True,
                       )
        past_key_values = output.past_key_values
        if num_block < num_blocks - 1:
            # Seed the first token of the next block.
            logits = output.logits[:, -1:]
            confidence, x0 = sample_tokens(logits, temperature=temperature, top_p=top_p, top_k=top_k)
            x[:, block_end:block_end + 1] = x0

    return x
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|