| """ |
| This module implements the TPTT model with linear attention (LiZA) and LoRA support. |
| Author : Fabien FURFARO |
| """ |
|
|
| import logging |
| import os |
| import re |
| import shutil |
from typing import Dict, List, Optional, Type
|
|
| import torch |
| import torch.nn.functional as F |
| from einops import rearrange |
| from huggingface_hub import hf_hub_download, list_repo_files |
| from peft import LoraConfig, get_peft_model |
| from safetensors import safe_open |
| from torch import nn |
| from transformers import AutoModelForCausalLM, DynamicCache, PreTrainedModel |
| from transformers.configuration_utils import PretrainedConfig |
|
|
| from .configuration_tptt import TpttConfig |
|
|
|
|
def import_fla_ops():
    """Import fused GLA kernels from `fla` (flash-linear-attention), if available."""
    if torch.cuda.is_available():
        try:
            from fla.ops.gla import fused_chunk_gla, fused_recurrent_gla

            return fused_chunk_gla, fused_recurrent_gla
        except ImportError:
            return None, None
    return None, None
|
|
|
|
| fused_chunk_gla, fused_recurrent_gla = import_fla_ops() |
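# On CPU-only machines these are None; only the "gla" operator mode needs the
# fused kernels, while "delta_rule" runs in pure PyTorch (see AttentionOperator).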
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class LCache:
    """
    Cache for storing intermediate recurrent states of linear attention layers,
    one state dict per layer.
    """

    def __init__(self):
        """Initialize an empty cache and reset the token counter."""
        self.states: List[Dict[str, torch.Tensor]] = []
        self.seen_tokens = 0
|
|
| def __getitem__(self, layer_idx: int) -> Optional[Dict[str, torch.Tensor]]: |
| """ |
| Retrieve the state for the given layer index, if it exists. |
| """ |
| if layer_idx < len(self.states): |
| return self.states[layer_idx] |
| return None |
|
|
    def update(self, layer_idx: int, **kwargs):
        """
        Update the cached state for a given layer.
        Tensors are detached so the cache does not keep the autograd graph alive.
        """
| detached_kwargs = {} |
| for key, value in kwargs.items(): |
| if isinstance(value, torch.Tensor): |
| value = value.detach() |
| detached_kwargs[key] = value |
|
|
| if len(self.states) <= layer_idx: |
| self.states.append(detached_kwargs) |
| else: |
| self.states[layer_idx].update(detached_kwargs) |
|
|
| def reset(self): |
| """ |
| Reset the cache and token counter. |
| """ |
| self.states.clear() |
| self.seen_tokens = 0 |
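

# Minimal LCache usage sketch (illustrative; the state shape below is an
# assumption matching the [batch, heads, head_dim, head_dim] recurrent states):
#   cache = LCache()
#   cache.update(0, recurrent_state=torch.zeros(1, 8, 64, 64))
#   state = cache[0]["recurrent_state"]  # cache[i] is None for unseen layers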
|
|
|
|
| class LiZAttention(nn.Module): |
| """LiZA Linear Attention module, mixing linear and vanilla attention.""" |
|
|
| def __init__( |
| self, |
| base_attn: nn.Module, |
| layer_idx: int, |
| base_config, |
| linear_cache: Optional[LCache] = None, |
| operator_mode: str = "delta_rule", |
| max_self_attn_length: int = 2048, |
| mag_weight: float = 0.5, |
| max_chunk_size: int = 64, |
| ): |
| super().__init__() |
| self.base_attn = base_attn |
| self.base_config = base_config |
| self.layer_idx = layer_idx |
| self.max_self_attn_length = max_self_attn_length |
| self.mag_weight = mag_weight |
| self.max_chunk_size = max_chunk_size |
| self.linear_cache = linear_cache or LCache() |
| ( |
| self.num_heads, |
| self.head_dim, |
| self.num_key_value_heads, |
| self.num_key_value_groups, |
| ) = self._get_attention_parameters(base_attn, base_config) |
| self.operator = get_attention_operator(operator_mode) |
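        # Adaptive pooling maps k to the kv hidden size to form the gate input g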
| self.pool_g = nn.AdaptiveAvgPool1d( |
| output_size=self.head_dim * self.num_key_value_heads |
| ) |
|
|
| def _get_attention_parameters(self, base_attn, base_config): |
| """Retrieve the attention parameters from the base attention module.""" |
| |
| num_heads = ( |
| getattr(base_attn, "num_heads", None) |
| or getattr(base_attn, "num_q_heads", None) |
| or getattr(base_config, "num_heads", None) |
| or getattr(base_config, "num_attention_heads", None) |
| ) |
| head_dim = getattr(base_attn, "head_dim", None) or getattr( |
| base_config, "head_dim", None |
| ) |
| num_key_value_heads = ( |
| getattr(base_attn, "num_kv_heads", None) |
| or getattr(base_attn, "num_k_heads", None) |
| or getattr(base_config, "num_key_value_heads", None) |
| or num_heads |
| ) |
| num_key_value_groups = getattr(base_attn, "num_key_value_groups", None) or ( |
| num_heads // num_key_value_heads if num_heads and num_key_value_heads else 1 |
| ) |
| return ( |
| num_heads, |
| head_dim, |
| num_key_value_heads, |
| num_key_value_groups, |
| ) |
|
|
    def _apply_projections(self, hidden_states):
        base_attn = self.base_attn
        if hasattr(base_attn, "q_proj"):
            # Separate Q/K/V projections (e.g. Llama-style attention)
            q = base_attn.q_proj(hidden_states)
            k = base_attn.k_proj(hidden_states)
            v = base_attn.v_proj(hidden_states)
            out_proj = base_attn.o_proj
        elif hasattr(base_attn, "qkv_proj"):
            # Fused QKV projection, split according to per-head sizes
            qkv = base_attn.qkv_proj(hidden_states)
            q, k, v = split_qkv(base_attn, qkv)
            out_proj = base_attn.out_proj
        elif hasattr(base_attn, "c_attn") and hasattr(base_attn, "c_proj"):
            # GPT-2-style fused projection (c_attn) with c_proj as output
            qkv = base_attn.c_attn(hidden_states)
            q, k, v = qkv.chunk(3, dim=-1)
            out_proj = base_attn.c_proj
        else:
            raise ValueError("Unsupported attention module: cannot find projections.")
        # Clamp raw projections for numerical stability in low-precision dtypes
        q = torch.clamp(q, min=-1e4, max=1e4)
        k = torch.clamp(k, min=-1e4, max=1e4)
        v = torch.clamp(v, min=-1e4, max=1e4)
        return q, k, v, out_proj
|
|
    def _prepare_attn_input(self, q, k, v, gate_norm):
        # Derive the gating signal g from k via adaptive pooling
        g = self.pool_g(k)

        # Split the hidden dimension into heads: [b, n, h*d] -> [b, h, n, d]
        q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.num_key_value_heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.num_key_value_heads)
        g = rearrange(g, "b n (h m) -> b h n m", h=self.num_key_value_heads)

        # Expand key/value heads to match the number of query heads (GQA)
        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)
        g = repeat_kv(g, self.num_key_value_groups)

        # Softmax feature maps keep q and k positive and bounded
        q = torch.clamp(F.softmax(q, dim=-1), min=1e-6, max=1 - 1e-6)
        k = torch.clamp(F.softmax(k, dim=-1), min=1e-6, max=1 - 1e-6)

        # Log-sigmoid gate scaled by gate_norm, clamped for stability
        g = F.logsigmoid(g) / gate_norm
        g = torch.clamp(g, min=-gate_norm, max=gate_norm)

        # Linear-attention kernels expect contiguous float32 inputs
        q, k, v, g = (x.to(torch.float32).contiguous() for x in (q, k, v, g))

        return q, k, v, g
|
|
    def _process_linear_attn(self, q, k, v, g, out_proj, tensor_dtype, kwargs):
        # Retrieve the recurrent state from the cache when caching is enabled
        if kwargs["use_cache"]:
            last_state = self.linear_cache[self.layer_idx]
            recurrent_state = (
                last_state["recurrent_state"]
                if last_state is not None and "recurrent_state" in last_state
                else None
            )
        else:
            recurrent_state = None

        # Run the linear-attention operator (delta rule or GLA)
        o_lin, recurrent_state = self.operator(
            q,
            k,
            v,
            beta=g,
            chunk_size=self.max_chunk_size,
            recurrent_state=recurrent_state,
        )
        o_lin = rearrange(o_lin, "b h n d -> b n (h d)").to(tensor_dtype)
        o_lin = out_proj(o_lin)
        # Clamp the output for numerical stability
        o_lin = torch.clamp(o_lin, min=-1e4, max=1e4)

        # Persist the updated recurrent state for the next decoding step
        if kwargs["use_cache"]:
            self.linear_cache.update(self.layer_idx, recurrent_state=recurrent_state)
        return o_lin
|
|
    def _process_self_attn(self, hidden_states, attention_mask, kwargs):
        # Restrict softmax attention to a sliding window of max_self_attn_length
        hidden_states, attention_mask = truncate_attention_mask(
            hidden_states, attention_mask, self.max_self_attn_length
        )

        # Truncate rotary embeddings to the same window
        if kwargs.get("position_embeddings", None) is not None:
            cos, sin = kwargs["position_embeddings"]
            cos = cos[:, -self.max_self_attn_length :]
            sin = sin[:, -self.max_self_attn_length :]
            kwargs["position_embeddings"] = (cos, sin)

        if isinstance(kwargs.get("past_key_value", None), DynamicCache):
            # Crop the shared KV cache once (at the first layer) to stay in-window
            if len(kwargs["past_key_value"]) > self.layer_idx and self.layer_idx == 0:
                kwargs["past_key_value"].crop(self.max_self_attn_length - 1)

        # Standard (vanilla) attention pass through the wrapped module
        base_attn_outputs = self.base_attn(
            hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )

        # Normalize the different return signatures of attention modules
        if isinstance(base_attn_outputs, tuple):
            if len(base_attn_outputs) == 3:
                o_base, attn_weights, present_key_value = base_attn_outputs
                expected_attn_mode = 3
            elif len(base_attn_outputs) == 2:
                o_base, attn_weights = base_attn_outputs
                present_key_value, expected_attn_mode = None, 2
            else:
                raise ValueError(
                    f"Unexpected number of outputs from base_attn: {len(base_attn_outputs)}"
                )
        else:
            o_base = base_attn_outputs
            attn_weights, present_key_value, expected_attn_mode = None, None, 1
        # Clamp the output for numerical stability
        o_base = torch.clamp(o_base, min=-1e4, max=1e4)
        return o_base, attn_weights, present_key_value, expected_attn_mode
|
|
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        device = hidden_states.device
        tensor_dtype = hidden_states.dtype
        self.base_attn.to(device)

        if self.training:
            kwargs.pop("past_key_value", None)
            kwargs["use_cache"] = False
        else:
            # Always cache at inference time (KV cache and linear recurrent state)
            kwargs["use_cache"] = True

        kwargs.pop("position_ids", None)

        # Shared Q/K/V projections feed both attention paths
        q, k, v, out_proj = self._apply_projections(hidden_states)

        # Zero out padded positions before the (unmasked) linear attention
        if attention_mask is not None:
            v = apply_linear_attention_mask(attention_mask, v)

        # Head reshaping, feature maps, and gate preparation
        gate_norm = kwargs.get("gate_logit_normalizer", 16)
        q, k, v, g = self._prepare_attn_input(q, k, v, gate_norm)

        # Linear (recurrent) attention path
        o_lin = self._process_linear_attn(q, k, v, g, out_proj, tensor_dtype, kwargs)

        # Vanilla softmax attention path over a sliding window
        o_base, attn_weights, present_key_value, expected_attn_mode = (
            self._process_self_attn(hidden_states, attention_mask, kwargs)
        )

        # Cast both paths back to the input dtype before mixing
        o_lin = o_lin.to(tensor_dtype)
        o_base = o_base.to(tensor_dtype)

        # Align sequence lengths: left-pad the windowed vanilla output with zeros,
        # or left-truncate both outputs to the shorter length
        if o_lin.shape[1] > o_base.shape[1]:
            o_padding = torch.zeros_like(o_lin).to(tensor_dtype)
            o_padding[:, -o_base.shape[1] :] = o_base
            o_base = o_padding
        elif o_lin.shape[1] != o_base.shape[1]:
            left_trunc = min(o_lin.shape[1], o_base.shape[1])
            o_lin, o_base = o_lin[:, -left_trunc:], o_base[:, -left_trunc:]
        # Weighted mix of the linear and vanilla attention outputs
        out = self.mag_weight * o_lin + (1 - self.mag_weight) * o_base
        # Clamp the mixed output for numerical stability
        out = torch.clamp(out, min=-1e4, max=1e4)

        # Match the return signature of the wrapped attention module
        if expected_attn_mode == 3:
            return out, attn_weights, present_key_value
        elif expected_attn_mode == 2:
            return out, attn_weights
        else:
            return out
|
|
|
|
| def get_tptt_model( |
| model: nn.Module, |
| base_config: PretrainedConfig, |
    liza_attention: Type[LiZAttention],
| target_modules: list, |
| linear_cache: Optional[LCache] = None, |
| operator_mode: str = "delta_rule", |
| mag_weight: float = 0.5, |
| max_chunk_size: int = 64, |
| max_self_attn_length: int = 2048, |
| ): |
| """Replace target modules in a model with LiZAttention.""" |
| linear_cache = linear_cache or LCache() |
| |
    # Materialize the module list first, since the tree is mutated while iterating
    for name, _ in list(model.named_modules()):
        if name in target_modules:
            parent = model
            *path, last = name.split(".")
            for p in path:
                parent = getattr(parent, p)
| layer_idx = extract_layer_idx(name) |
| setattr( |
| parent, |
| last, |
| liza_attention( |
| getattr(parent, last), |
| layer_idx=layer_idx, |
| base_config=base_config, |
| linear_cache=linear_cache, |
| operator_mode=operator_mode, |
| max_self_attn_length=max_self_attn_length, |
| mag_weight=mag_weight, |
| max_chunk_size=max_chunk_size, |
| ), |
| ) |
| return model, linear_cache |
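

# Hedged usage sketch for get_tptt_model (module names are illustrative):
#   targets = [n for n, _ in model.named_modules() if n.endswith("self_attn")]
#   model, cache = get_tptt_model(
#       model,
#       base_config=model.config,
#       liza_attention=LiZAttention,
#       target_modules=targets,
#   )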
|
|
|
|
| class TpttModel(PreTrainedModel): |
| """ |
| TPTT model wrapper with linear attention (LiZA) and LoRA support. |
| Handles only architecture and weights. |
| """ |
|
|
| config_class = TpttConfig |
|
|
| def __init__( |
| self, |
| config: TpttConfig, |
| **kwargs, |
| ): |
| """ |
| Initialize TpttModel with a given config and backbone. |
| Injects LiZA attention modules into the backbone. |
| """ |
        super().__init__(config, **kwargs)
        repo_or_path = getattr(config, "_base_path", None) or config._name_or_path

        # Load the backbone causal LM
        self.backbone = AutoModelForCausalLM.from_pretrained(
            config.base_model_name, **kwargs
        )
        self._retie_lm_after_load(**kwargs)

        # Inject LiZA linear attention into the backbone
        self.linear_cache = LCache()
        self.backbone, self.linear_cache = self.inject_liza_attention(
            self.backbone, config, self.linear_cache
        )
        # Optionally wrap the backbone with LoRA adapters and load their weights
        if config.lora_config is not None:
            lora_config_obj = LoraConfig(**config.lora_config)
            self.backbone = get_peft_model(self.backbone, lora_config_obj)
            if repo_or_path:
                self.load_peft_safetensors(
                    repo_or_path, token=kwargs.get("token", None)
                )
|
|
    def load_peft_safetensors(self, src, token=None):
        """Load LoRA adapter weights from a local directory or a Hub repo, if present."""
| fname = "adapter_model.safetensors" |
| if os.path.isdir(src): |
| path = os.path.join(src, fname) |
| if not os.path.exists(path): |
| return |
| else: |
| if fname not in list_repo_files(src, token=token): |
| return |
| path = hf_hub_download(src, fname, token=token) |
| with safe_open(path, framework="pt") as f: |
| self.backbone.load_state_dict( |
| {k: f.get_tensor(k) for k in f.keys()}, strict=False |
| ) |
|
|
| @staticmethod |
| def inject_liza_attention( |
| backbone, |
| config, |
| linear_cache, |
| ): |
| """ |
| Inject LiZAttention into the specified target modules of the base model. |
| """ |
| |
        # Collect fully-qualified names of the modules to wrap
        target_modules = [
            name
            for name, _ in backbone.named_modules()
            if any(name.endswith(suffix) for suffix in config.target_modules_names)
        ]
        if not target_modules:
            raise ValueError(
                f"Target modules '{config.target_modules_names}' not found in the model."
            )
        # Replace each target module with a LiZAttention wrapper
        return get_tptt_model(
| backbone, |
| base_config=backbone.config, |
| liza_attention=LiZAttention, |
| target_modules=target_modules, |
| linear_cache=linear_cache, |
| operator_mode=config.operator_mode, |
| max_self_attn_length=config.max_self_attn_length, |
| mag_weight=config.mag_weight, |
| max_chunk_size=config.max_chunk_size, |
| ) |
|
|
| def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs): |
| """ |
| Forward pass. All arguments are passed to the underlying base model. |
| """ |
| if self.training: |
| kwargs["use_cache"] = False |
| kwargs.pop("num_items_in_batch", None) |
| else: |
| kwargs["use_cache"] = True |
| return self.backbone( |
| input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs |
| ) |
|
|
    def generate(self, *args, **kwargs):
        """Delegate generation to the backbone model."""
        return self.backbone.generate(*args, **kwargs)
|
|
    def save_pretrained(self, path: str, **kwargs):
        """Save model weights, config, and source code to the given path."""
        super().save_pretrained(path, **kwargs)

        # Save LoRA/PEFT weights alongside the main checkpoint
        self._save_peft_weights(path, **kwargs)
        # Ship the module sources so the checkpoint loads with trust_remote_code
        self._copy_source_files(path)
|
|
| def _save_peft_weights(self, path: str, **kwargs): |
| """Save PEFT weights and remove redundant adapter config.""" |
| self.backbone.save_pretrained(path, **kwargs) |
| adapter_config_path = os.path.join(path, "adapter_config.json") |
| if os.path.exists(adapter_config_path): |
| os.remove(adapter_config_path) |
|
|
| def _copy_source_files(self, path: str): |
| """Copy all .py files from package directory for trust_remote_code.""" |
| src_dir = os.path.dirname(os.path.abspath(__file__)) |
| for fname in os.listdir(src_dir): |
| if fname.endswith(".py"): |
| src = os.path.join(src_dir, fname) |
| dst = os.path.join(path, fname) |
| shutil.copy2(src, dst) |
|
|
| def _retie_lm_after_load(self, **kwargs): |
| """Re-link lm_head after loading external weights.""" |
| embed_lm = find_embedding_lm(self.backbone) |
| if embed_lm is not None and hasattr(self.backbone, "lm_head"): |
| if self.backbone.lm_head is None: |
| self.backbone.lm_head = nn.Linear( |
| embed_lm.weight.shape[1], embed_lm.weight.shape[0], bias=False |
| ) |
| if kwargs.get("tie_word_embeddings", True): |
| self.backbone.lm_head.weight = embed_lm.weight |
| logger.info("Weights of lm_head have been shared with embedding.") |
| else: |
| self.backbone.lm_head.weight = nn.Parameter(embed_lm.weight.clone()) |
| logger.info("Weights of lm_head have been cloned from the embedding.") |
|
|
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load the model, then re-tie lm_head to the token embeddings."""
        model = super().from_pretrained(*args, **kwargs)
| model._retie_lm_after_load(**kwargs) |
| return model |
|
|
|
|
| TpttModel.register_for_auto_class("AutoModelForCausalLM") |
|
|
|
|
| class AttentionOperator(nn.Module): |
| """Base class for linear attention operators.""" |
|
|
| def __init__(self, mode="delta_rule"): |
| super().__init__() |
| self.mode = mode |
|
|
| def forward(self, q, k, v, **options): |
| """Forward pass for the attention operator.""" |
| beta = options.get("beta", None) |
| chunk_size = options.get("chunk_size", 64) |
| scale = options.get("scale", 1) |
| recurrent_state = options.get("recurrent_state", None) |
|
|
| if self.mode == "delta_rule": |
| return self.chunk_delta_rule_forward( |
| q, k, v, beta, chunk_size, initial_state=recurrent_state |
| ) |
| if self.mode == "gla": |
| return self.gla_forward(q, k, v, beta, scale, initial_state=recurrent_state) |
| raise ValueError(f"Unknown operator mode: {self.mode}") |
|
|
| @staticmethod |
| def chunk_delta_rule_forward( |
| query, key, value, beta, chunk_size, initial_state=None |
| ): |
| """ |
| Implementation of https://arxiv.org/abs/2406.06484 |
| query, key, value, beta: [batch, num_heads, seq_len, head_dim] |
| chunk_size: int |
| initial_state: [batch, num_heads, head_dim, head_dim] or None |
| """ |
| batch_size, num_heads, seq_len, head_dim = query.shape |
| chunk_size = get_valid_chunk_size(seq_len, chunk_size) |
| num_chunks = seq_len // chunk_size |
|
|
| |
        # Split the sequence into chunks: [b, h, n, d] -> [b, h, n/c, c, d]
        q_chunks = query.reshape(
            batch_size, num_heads, num_chunks, chunk_size, head_dim
        )
        k_chunks = key.reshape(batch_size, num_heads, num_chunks, chunk_size, head_dim)
        v_chunks = value.reshape(
            batch_size, num_heads, num_chunks, chunk_size, head_dim
        )
        beta_chunks = beta.reshape(
            batch_size, num_heads, num_chunks, chunk_size, head_dim
        )

        output = torch.empty_like(q_chunks)
        # Reuse the previous recurrent state when its shape matches, else start fresh
        expect_state_shape = (batch_size, num_heads, head_dim, head_dim)
        if initial_state is not None and initial_state.shape == expect_state_shape:
            state = initial_state.to(device=query.device, dtype=query.dtype)
        else:
            state = torch.zeros(
                batch_size,
                num_heads,
                head_dim,
                head_dim,
                device=query.device,
                dtype=query.dtype,
            )
|
|
        def process_chunk(q, k, v, b, state):
            """
            q, k, v, b: [batch, num_heads, chunk_size, head_dim]
            state: [batch, num_heads, head_dim, head_dim]
            Returns: (output_chunk, new_state)
            """
            # Clamp inputs for numerical stability
            k = torch.clamp(k, min=-1e4, max=1e4)
            v = torch.clamp(v, min=-1e4, max=1e4)
            b = torch.clamp(b, min=1e-6, max=1e4)
            q = torch.clamp(q, min=-1e4, max=1e4)

            # Gate the keys and values with beta
            k_beta = k * b
            v_beta = v * b

            # Intra-chunk transition matrix: T = I - strictly_lower(K_beta @ K^T)
            t_matrix = -(k_beta @ k.transpose(-2, -1)).tril(-1)
            t_matrix = torch.clamp(t_matrix, min=-1e4, max=1e4)
            t_matrix = t_matrix + torch.eye(
                q.shape[-2], device=q.device, dtype=q.dtype
            ).unsqueeze(0).unsqueeze(0)

            # Chunkwise delta-rule terms: W = T @ K_beta, U = T @ V_beta
            w_matrix = t_matrix @ k_beta
            w_matrix = torch.clamp(w_matrix, min=-1e4, max=1e4)

            u_matrix = t_matrix @ v_beta
            u_matrix = torch.clamp(u_matrix, min=-1e4, max=1e4)

            # Remove the contribution already carried by the running state
            u_i = u_matrix - torch.matmul(w_matrix, state)

            # Inter-chunk output from the running state
            o_inter = torch.matmul(q, state)

            # Causal intra-chunk attention scores
            a_i = (q @ k.transpose(-2, -1)).tril()

            # Intra-chunk output
            o_intra = torch.matmul(a_i, u_i)

            # Update the running state with this chunk's contribution
            new_state = state + torch.matmul(k.transpose(-2, -1), u_i)
            new_state = torch.clamp(new_state, min=-1e4, max=1e4)

            return o_intra + o_inter, new_state
|
|
        # Process chunks sequentially, threading the recurrent state through
        for chunk_idx in range(num_chunks):
            q = q_chunks[:, :, chunk_idx]
            k = k_chunks[:, :, chunk_idx]
            v = v_chunks[:, :, chunk_idx]
            b = beta_chunks[:, :, chunk_idx]

            chunk_out, state = process_chunk(q, k, v, b, state)
            output[:, :, chunk_idx] = chunk_out

        # Merge chunks back into the full sequence
        output = output.reshape(batch_size, num_heads, seq_len, head_dim)
        return output, state
|
|
| @staticmethod |
| def gla_forward(q, k, v, beta, scale, initial_state=None): |
| """Forward pass for GLA attention operator.""" |
        if fused_chunk_gla is None or fused_recurrent_gla is None:
            raise RuntimeError(
                "GLA kernels unavailable: CUDA and the `fla` package are required."
            )
        if q.shape[-2] > 1:
            # Chunked kernel for multi-token (prefill) passes
            return fused_chunk_gla(
                q,
                k,
                v,
                beta,
                scale=scale,
                initial_state=initial_state,
                output_final_state=True,
            )
        # Recurrent kernel for single-token decoding
        return fused_recurrent_gla(
| q, |
| k, |
| v, |
| beta, |
| scale=scale, |
| initial_state=initial_state, |
| output_final_state=True, |
| ) |
|
|
|
|
| def get_attention_operator(mode): |
| """Factory for AttentionOperator.""" |
| return AttentionOperator(mode=mode) |
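

# Hedged usage sketch for the delta-rule operator (shapes are illustrative):
#   op = get_attention_operator("delta_rule")
#   q = k = v = torch.randn(1, 8, 128, 64)  # [batch, heads, seq, head_dim]
#   beta = torch.rand(1, 8, 128, 64)        # positive gate values
#   out, state = op(q, k, v, beta=beta, chunk_size=64)
#   # out: [1, 8, 128, 64]; state: [1, 8, 64, 64], reusable as recurrent_state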
|
|
|
|
| def extract_layer_idx(module_name: str) -> int: |
| """ |
| Extract the layer index from a module name string. |
| """ |
| match = re.search(r"\.(\d+)\.", module_name) |
| if match: |
| return int(match.group(1)) |
| return -1 |
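

# e.g. extract_layer_idx("model.layers.3.self_attn") == 3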
|
|
|
|
def find_embedding_lm(module):
    """Find the token-embedding module in a model, if any."""
| for _, child in module.named_modules(): |
| if hasattr(child, "embed_tokens") and hasattr(child.embed_tokens, "weight"): |
| return child.embed_tokens |
| if hasattr(child, "token_embeddings") and hasattr( |
| child.token_embeddings, "weight" |
| ): |
| return child.token_embeddings |
| return None |
|
|
|
|
def soft_clamp(x, min_val=-1e4, max_val=1e4):
    """Differentiable (tanh-based) clamping for numerical stability."""
| scale = (max_val - min_val) / 2 |
| center = (max_val + min_val) / 2 |
| return torch.tanh((x - center) / scale) * scale + center |
|
|
|
|
| def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: |
| """Repeat key/value heads for grouped query attention (GQA).""" |
| return x.repeat_interleave(n_rep, dim=1) |
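

# e.g. with n_rep=4, a [b, 2, n, d] key/value tensor becomes [b, 8, n, d]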
|
|
|
|
| def split_qkv(base_attn, qkv): |
| """Split the QKV tensor into separate Q, K, and V tensors.""" |
| num_q_heads = getattr(base_attn, "num_q_heads", None) |
| num_k_heads = getattr(base_attn, "num_k_heads", None) |
| num_v_heads = getattr(base_attn, "num_v_heads", None) |
| head_dim = getattr(base_attn, "head_dim", None) |
|
|
| q_len = num_q_heads * head_dim |
| k_len = num_k_heads * head_dim |
| v_len = num_v_heads * head_dim |
|
|
| q, k, v = torch.split(qkv, [q_len, k_len, v_len], dim=-1) |
| return q, k, v |
|
|
|
|
def apply_linear_attention_mask(attention_mask, v):
    """Zero out padded positions of v using a 2D, 3D, or 4D attention mask."""
    if attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
        # 4D mask [b, 1, n, n]: take its diagonal as a per-position keep-mask
        mask = attention_mask.diagonal(dim1=-2, dim2=-1).squeeze(1)
    else:
        # Squeeze singleton dimensions down to a [b, n] padding mask
        mask = attention_mask.squeeze(
            dim=tuple(
                i
                for i in range(1, attention_mask.dim())
                if attention_mask.shape[i] == 1
            )
        )
    # Align with the last v positions and broadcast over the feature dims
    mask = mask[:, -v.shape[-2] :][(...,) + (None,) * (v.dim() - 2)]
    return v * mask
|
|
|
|
def truncate_attention_mask(hidden_states, attention_mask, max_length):
    """
    Truncate hidden_states and attention_mask to the last window of size max_length,
    matching the sequence dimension of hidden_states.
    """
    seq_dim = 1
    seq_len = hidden_states.shape[seq_dim]
    if seq_len > max_length:
        hidden_states = hidden_states.narrow(seq_dim, seq_len - max_length, max_length)
        if attention_mask is not None:
            # 2D padding mask [b, n]
            if attention_mask.dim() == 2:
                attention_mask = attention_mask[:, -max_length:]
            # 3D mask [b, n, n]
            elif attention_mask.dim() == 3:
                attention_mask = attention_mask[:, -max_length:, -max_length:]
            # 4D mask [b, 1, n, n]
            elif attention_mask.dim() == 4 and attention_mask.shape[1] == 1:
                attention_mask = attention_mask[:, :, -max_length:, -max_length:]
            else:
                raise ValueError(
                    f"Unsupported attention_mask shape: {tuple(attention_mask.shape)}"
                )
    return hidden_states, attention_mask
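

# e.g. with seq_len 4096 and max_length 2048, both tensors keep the last
# 2048 positions (and a 4D mask keeps its last 2048x2048 block)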
|
|
|
|
def get_valid_chunk_size(total_l: int, chunk_size: int) -> int:
    """
    Return the largest divisor of total_l that is <= chunk_size.
    Falls back to 1 if no larger divisor fits.
    """
    for c in range(min(chunk_size, total_l), 0, -1):
        if total_l % c == 0:
            return c
    return 1
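

# e.g. get_valid_chunk_size(100, 64) == 50; get_valid_chunk_size(97, 64) == 1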
|
|
|
|
| def match_dim(x: torch.Tensor, dim: int, target_size: int) -> torch.Tensor: |
| """ |
| Match the size of tensor x along dimension dim to target_size by interpolation |
| or projection. |
| """ |
| src_size = x.shape[dim] |
| if src_size == target_size: |
| return x |
| x = torch.moveaxis(x, dim, -1) |
| shape = x.shape |
| if src_size < target_size: |
| x = x.reshape(-1, 1, src_size) |
| x = F.interpolate(x, size=target_size, mode="linear", align_corners=False) |
| x = x.reshape(*shape[:-1], target_size) |
| else: |
| eye = torch.eye(target_size, src_size, device=x.device, dtype=x.dtype) |
| x = F.linear(x, eye) |
| x = torch.moveaxis(x, -1, dim) |
| return x |
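

# e.g. match_dim(torch.randn(2, 10, 8), dim=1, target_size=16).shape == (2, 16, 8)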
|
|