| | """ |
| | Orginally Taken verbatim from xformers library |
| | https://github.com/facebookresearch/xformers/blob/bcb707576c6a80eaf850aa80e8643d3497ec2bc4/xformers/components/positional_embedding/rotary.py |
| | |
| | The difference is that xformers seems to assume the inputs to be |
| | (bs, head, seq_len, dim) while we assume (bs, seq_len, head, dim) |
| | |
| | """ |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | |
| | |
| |
|
import dataclasses
import math
from typing import Dict, List, Optional, Tuple, Union

import torch
from transformers import PretrainedConfig
from transformers.utils import logging

is_dacite_available = False
try:
    import dacite

    is_dacite_available = True
except ImportError:
    pass

logger = logging.get_logger(__name__)

@dataclasses.dataclass
class LongRopeConfig:
    """Rescaling configuration for LongRoPE (the "longrope" / "su" rope_scaling type)."""

    short_factor: List[float]
    long_factor: List[float]
    original_max_position_embeddings: int
    type: str = "longrope"
    short_mscale: float = -1
    long_mscale: float = -1

    def __post_init__(self):
        assert self.type in ("longrope", "su"), (
            f"Invalid type {self.type} for LongRopeConfig. Expected longrope / su"
        )

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Union[float, List[float], int]]) -> "LongRopeConfig":
        if is_dacite_available:
            # dacite validates field types and builds the dataclass directly.
            return dacite.from_dict(data_class=cls, data=config_dict)
        kwargs = {}
        for field in dataclasses.fields(cls):
            if field.name in config_dict:
                if field.init:
                    kwargs[field.name] = config_dict[field.name]
                else:
                    raise ValueError(f"Field {field.name} cannot be set via __init__")
            elif field.default is dataclasses.MISSING:
                raise ValueError(f"Field {field.name} is required")
        extra_keys = set(config_dict.keys()) - set(kwargs.keys())
        if len(extra_keys) > 0:
            for key in extra_keys:
                logger.error(f"Unrecognized key {key} in config_dict")
            raise ValueError(f"Unrecognized keys in config_dict: {sorted(extra_keys)}")
        return cls(**kwargs)

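
# Example (illustrative sketch; the factor values below are made up, real
# checkpoints ship their own `rope_scaling` entries):
#
#   cfg = LongRopeConfig.from_dict({
#       "type": "longrope",
#       "short_factor": [1.0] * 16,
#       "long_factor": [4.0] * 16,
#       "original_max_position_embeddings": 4096,
#   })
#   cfg.short_mscale  # -> -1, i.e. "derive the scale automatically"
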

def rotate_half(x):
    # Split the last dimension in half and rotate: [x1, x2] -> [-x2, x1].
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)

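
# Worked example (illustrative): for a head_dim of 4,
#   rotate_half(torch.tensor([1., 2., 3., 4.]))  ->  tensor([-3., -4., 1., 2.])
# i.e. the second half is negated and swapped in front of the first half.
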

@torch.jit.script
def apply_rotary_pos_emb(x, cos, sin, seq_dimension: int):
    # Slice the cos/sin tables to the current sequence length and add
    # singleton dims so they broadcast against x over batch and heads.
    if seq_dimension == 0:
        # x: (seq_len, bs, num_heads, head_dim)
        cos = cos[: x.shape[0], None, None, :]
        sin = sin[: x.shape[0], None, None, :]
    elif seq_dimension == 1:
        # x: (bs, seq_len, num_heads, head_dim)
        cos = cos[None, : x.shape[1], None, :]
        sin = sin[None, : x.shape[1], None, :]
    elif seq_dimension == 2:
        # x: (bs, num_heads, seq_len, head_dim)
        cos = cos[None, None, : x.shape[2], :]
        sin = sin[None, None, : x.shape[2], :]

    return (x * cos) + (rotate_half(x) * sin)

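
# Shape sketch (illustrative): with seq_dimension == 1 and x of shape
# (bs, seq_len, heads, head_dim), a (max_seq_len, head_dim) cos table is
# sliced and viewed as (1, seq_len, 1, head_dim), so a single table serves
# every batch element and head.
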

class RotaryEmbedding(torch.nn.Module):
    """
    Adapted from the xformers library.

    The rotary position embeddings from RoFormer_ (Su et al.).
    A crucial insight from the method is that the queries and keys are
    transformed by rotation matrices which depend on the relative positions.
    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    .. warning: Please note that this embedding is not registered on purpose, as it is transformative
        (it does not create the embedding dimension) and will likely be picked up (imported) on an ad-hoc basis.

    # Arguments
    :param dim_model: head dimension
    :param max_seq_len: maximum sequence length to cache cos/sin values for
    :param dtype: cos/sin cache dtype
    :param base: base of the exponent for the inverse frequencies
    :param position_scale: scale applied to the position indices
    :param device: device to place the cached buffers on
    :param longrope_config: optional LongRoPE rescaling configuration
    """

    def __init__(
        self,
        dim_model: int,
        *,
        max_seq_len: Optional[int] = None,
        dtype: Optional[torch.dtype] = None,
        base=10000,
        position_scale=1,
        device: Optional[torch.device] = None,
        longrope_config: Optional[LongRopeConfig] = None,
    ):
        super().__init__()
        self.base = base
        self.dim_model = dim_model
        self.max_seq_len = max_seq_len
        self.longrope_config = longrope_config

        if self.is_longrope:
            # LongRoPE keeps the raw position ids and the per-dimension
            # rescale factors around; cos/sin are computed on the fly.
            self.register_buffer(
                "range_vector",
                torch.arange(max_seq_len, device=device, dtype=torch.float32),
                persistent=False,
            )
            self.register_buffer(
                "short_factors",
                torch.tensor(self.longrope_config.short_factor, dtype=torch.float32),
                persistent=False,
            )
            self.register_buffer(
                "long_factors",
                torch.tensor(self.longrope_config.long_factor, dtype=torch.float32),
                persistent=False,
            )
        else:
            # Standard RoPE inverse frequencies: base ** (-2i / dim_model).
            inv_freq = 1.0 / (base ** (torch.arange(0, dim_model, 2).float().to(device) / self.dim_model))
            self.register_buffer("inv_freq", inv_freq)

        self.position_scale = position_scale

        if not self.is_longrope:
            dtype = dtype or torch.get_default_dtype()
            self._set_cos_sin_cache(
                seq_len=max_seq_len,
                device=self.inv_freq.device,
                dtype=dtype,
            )

    @property
    def is_longrope(self):
        return self.longrope_config is not None

    @property
    def original_max_seq_len(self):
        if self.longrope_config is not None:
            return self.longrope_config.original_max_position_embeddings
        logger.warning_once(
            "`original_max_seq_len` is being accessed, but longrope_config has not been set. "
            "Please only do this if you are sure about the context."
        )
        return self.max_seq_len

    def get_range_vector(self, seq_len: int, device: torch.device):
        if self.is_longrope:
            # The cached range vector covers max_seq_len positions; slicing
            # with seq_len == max_seq_len is still valid, hence <=.
            assert seq_len <= self.range_vector.shape[0], (
                f"Found seq_len {seq_len} greater than max_seq_len {self.range_vector.shape[0]}"
            )
            if self.range_vector.device != device:
                self.range_vector = self.range_vector.to(device)
            return self.range_vector[:seq_len]
        return torch.arange(seq_len, device=device, dtype=torch.float32)

    def _calc_mscale(self, scale: float) -> float:
        if scale <= 1.0:
            return 1.0
        return math.sqrt(1 + math.log(scale) / math.log(self.original_max_seq_len))

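    # Worked example (illustrative): with original_max_seq_len = 4096 and a
    # 32x context extension (scale = 32), the magnitude scale is
    #   sqrt(1 + ln(32) / ln(4096)) = sqrt(1 + 5/12) ≈ 1.19
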
    def _set_cos_sin_cache(
        self,
        seq_len: int,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        dtype = dtype or torch.get_default_dtype()
        self.max_seq_len_cached = seq_len
        t = (torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32) * self.position_scale).type_as(self.inv_freq)
        device_type = device.type if device is not None else "cpu"
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        # Compute the angles in float32 with autocast disabled to avoid
        # low-precision drift in the cos/sin tables.
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = torch.outer(t, self.inv_freq)
            # Duplicate rather than interleave the frequencies; together with
            # rotate_half this yields the same result as the paper's layout.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        self.register_buffer("cos_cached", cos.to(dtype), persistent=False)
        self.register_buffer("sin_cached", sin.to(dtype), persistent=False)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_dimension: int = 1,
        seqlen_offset: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """q and k do not include `seqlen_offset`.
        q: Either (bs, seq_len, num_heads, head_dim) or (seq_len, bs, num_heads, head_dim)
        k: Either (bs, seq_len, num_heads, head_dim) or (seq_len, bs, num_heads, head_dim)
        """
        if seq_dimension < 0:
            seq_dimension = k.ndim + seq_dimension
        assert seq_dimension in (0, 1, 2)
        seq_len = k.shape[seq_dimension] + seqlen_offset

        if self.is_longrope:
            if seq_len > self.original_max_seq_len:
                # Long regime: rescale the frequencies with the long factors
                # and scale the magnitude by the (possibly derived) mscale.
                t = self.get_range_vector(seq_len, device=q.device)
                rescale_factors = self.long_factors.to(q.device)
                long_mscale = self.longrope_config.long_mscale
                mscale = long_mscale if long_mscale > 0 else self._calc_mscale(self.max_seq_len / self.original_max_seq_len)
            else:
                # Short regime: stay within the original context window.
                t = self.get_range_vector(self.original_max_seq_len, device=q.device)
                rescale_factors = self.short_factors.to(q.device)
                short_mscale = self.longrope_config.short_mscale
                mscale = short_mscale if short_mscale > 0 else 1.0
            assert rescale_factors.shape == (self.dim_model // 2,), (
                f"misaligned shape for LongRoPE rescale factors:\n"
                f"\tExpected {(self.dim_model // 2,)}, got {rescale_factors.shape}."
            )
            inv_freq = 1.0 / (rescale_factors * (self.base ** (torch.arange(0, self.dim_model, 2).float().to(q.device) / self.dim_model)))
            device_type = q.device.type
            device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
            with torch.autocast(device_type=device_type, enabled=False):
                freqs = torch.outer(t, inv_freq)
                emb = torch.cat((freqs, freqs), dim=-1)
                cos = emb.cos() * mscale
                sin = emb.sin() * mscale
            cos_cached = cos.to(q.dtype)
            sin_cached = sin.to(q.dtype)
        else:
            if seq_len > self.max_seq_len_cached:
                self._set_cos_sin_cache(
                    seq_len=seq_len,
                    device=k.device,
                    dtype=k.dtype,
                )
            cos_cached = self.cos_cached
            sin_cached = self.sin_cached
        return (
            apply_rotary_pos_emb(
                q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
            ).to(q.dtype),
            apply_rotary_pos_emb(
                k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
            ).to(k.dtype),
        )

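    # Frequency rescaling sketch (restating what the branch above computes):
    # plain RoPE uses inv_freq_i = base ** (-2i / d) for dimension pair i;
    # LongRoPE divides each inv_freq_i by a per-dimension factor_i
    # (short_factor inside the original window, long_factor beyond it), i.e.
    #   inv_freq_i = 1 / (factor_i * base ** (2i / d))
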
    @classmethod
    def from_config(cls, config: PretrainedConfig) -> "RotaryEmbedding":
        kwargs = dict(
            dim_model=config.hidden_size // config.num_attention_heads,
            max_seq_len=config.max_position_embeddings,
            base=config.rope_embedding_base,
            position_scale=config.rope_position_scale,
        )
        if config.rope_scaling is not None:
            kwargs["longrope_config"] = LongRopeConfig.from_dict(config.rope_scaling)
        return cls(**kwargs)
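

# Minimal smoke test (a sketch, not part of the original module): build a
# plain RotaryEmbedding and rotate random q/k laid out as
# (bs, seq_len, heads, head_dim), this module's assumed convention.
if __name__ == "__main__":
    rope = RotaryEmbedding(64, max_seq_len=128)
    q = torch.randn(2, 16, 8, 64)
    k = torch.randn(2, 16, 8, 64)
    q_rot, k_rot = rope(q, k, seq_dimension=1)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
    # Rotation preserves norms along the head dimension.
    assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4)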