| |
| |
|
|
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
|
|
| from __future__ import annotations |
|
|
|
|
| import torch |
| import torch.nn as nn |
| import os |
| from torch import Tensor |
| from functools import lru_cache |
| from itertools import product |
| from typing import Any, Sequence, Tuple, List |
| from pathlib import Path |
| from collections import OrderedDict |
| from transformers.tokenization_utils import PreTrainedTokenizer |
|
|
|
|
# Name of the plain-text vocabulary file (one token per line) used by
# `Tokenizer.save_vocabulary` and expected when loading from a checkpoint.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
# Special tokens serialized in the Hugging Face `AddedToken`-dict format;
# consumed by `get_tokenizer_config(add_special_tokens=True)`.
SPECIAL_TOKENS_MAP = {
    "pad_token": {
        "content": "<pad>",
        "lstrip": False,
        "normalized": False,
        "rstrip": False,
        "single_word": False,
    },
    "cls_token": {
        "content": "<cls>",
        "lstrip": False,
        "normalized": False,
        "rstrip": False,
        "single_word": False,
    },
    "eos_token": {
        "content": "<eos>",
        "lstrip": False,
        "normalized": False,
        "rstrip": False,
        "single_word": False,
    },
    "unk_token": {
        "content": "<unk>",
        "lstrip": False,
        "normalized": False,
        "rstrip": False,
        "single_word": False,
    },
    "mask_token": {
        "content": "<mask>",
        "lstrip": False,
        "normalized": False,
        "rstrip": False,
        "single_word": False,
    },
    "null_token": {
        "content": "<null>",
        "lstrip": False,
        "normalized": False,
        "rstrip": False,
        "single_word": False,
    },
}


# Full RNA alphabet: the four nucleobases, N, the IUPAC degenerate codes
# (expansions listed in VOCAB_MAPPING below), plus '.', 'X', '*', '-', 'I'
# — presumably gap/stop/inosine-style markers; TODO confirm exact semantics.
STANDARD_ALPHABET = list("ACGUNRYSWKMBDHV.X*-I")


# Nucleobases + N + IUPAC degenerate codes + '.'.
IUPAC_ALPHABET = list("ACGUNRYSWKMBDHV.")


# Nucleobases + N only; default base alphabet for k-mer vocabularies.
STREAMLINE_ALPHABET = list("ACGUN")


# The four RNA nucleobases.
NUCLEOBASE_ALPHABET = list("ACGU")


# Names accepted by `get_alphabet` when `alphabet` is given as a string.
ALPHABETS = {
    "standard": STANDARD_ALPHABET,
    "iupac": IUPAC_ALPHABET,
    "streamline": STREAMLINE_ALPHABET,
    "nucleobase": NUCLEOBASE_ALPHABET,
}


# IUPAC degenerate code -> the set of nucleobases it may stand for.
VOCAB_MAPPING = {
    "R": "AG",
    "Y": "CU",
    "S": "CG",
    "W": "AU",
    "K": "GU",
    "M": "AC",
    "B": "CGU",
    "D": "AGU",
    "H": "ACU",
    "V": "ACG",
    "X": "ACGU",
}


# Template for the tokenizer_config.json returned by `get_tokenizer_config`.
TOKENIZER_CONFIG = {
    "tokenizer_class": "RnaTokenizer",
    "clean_up_tokenization_spaces": True,
}
|
|
|
|
def get_alphabet(alphabet: List[str] | str | None = None, nmers: int = 1, **kwargs) -> Alphabet:
    """Resolve ``alphabet`` (a name, a token list, or None) into an `Alphabet`.

    When ``alphabet`` is None, the standard alphabet is used for single-token
    vocabularies and the streamline alphabet for k-mer vocabularies (to keep
    the k-mer vocabulary size manageable).
    """
    if isinstance(alphabet, str):
        tokens = ALPHABETS[alphabet]
    elif alphabet is None:
        tokens = STANDARD_ALPHABET if nmers <= 1 else STREAMLINE_ALPHABET
    else:
        tokens = alphabet
    return Alphabet(tokens, nmers=nmers, **kwargs)
|
|
|
|
def get_vocab_mapping():
    """Return the table expanding IUPAC degenerate codes to their nucleobases."""
    return VOCAB_MAPPING
|
|
|
|
def get_special_tokens_map():
    """Return the serialized special-token definitions for this tokenizer family."""
    return SPECIAL_TOKENS_MAP
|
|
|
|
def get_tokenizer_config(add_special_tokens: bool = False):
    """Build a tokenizer configuration dict.

    Args:
        add_special_tokens: When True, include an ``added_tokens_decoder`` entry
            mapping token ids (as strings) to the special-token definitions.

    Returns:
        A fresh dict based on ``TOKENIZER_CONFIG``.

    Note:
        Works on a copy of ``TOKENIZER_CONFIG``; the previous implementation
        mutated the module-level constant via ``setdefault``, so a single call
        with ``add_special_tokens=True`` permanently leaked the
        ``added_tokens_decoder`` key into every subsequent call.
    """
    config = dict(TOKENIZER_CONFIG)
    if add_special_tokens:
        # Special tokens occupy ids 0..5 in declaration order.
        config["added_tokens_decoder"] = {
            str(index): token for index, token in enumerate(SPECIAL_TOKENS_MAP.values())
        }
    return config
|
|
|
|
class Alphabet:
    """An ordered token set plus special tokens, expandable to k-mer vocabularies.

    The full vocabulary is ``prepend_tokens + kmers(tokens) + append_tokens``,
    built lazily and memoised across instances with identical configuration.
    """

    # Class-level defaults; instances shadow them when explicit values are given.
    prepend_tokens: Tuple[str, ...] = ("<pad>", "<cls>", "<eos>", "<unk>", "<mask>", "<null>")
    append_tokens: Tuple[str, ...] = ()
    tokens: Tuple[str, ...]
    nmers: int

    def __init__(
        self,
        tokens: Sequence[str],
        prepend_tokens: Tuple[str, ...] | None = None,
        append_tokens: Tuple[str, ...] | None = None,
        nmers: int = 1,
    ):
        # Accept another Alphabet as a token source (copy-construction).
        source = tokens.tokens if isinstance(tokens, Alphabet) else tokens
        self.tokens = tuple(source)
        if prepend_tokens is not None:
            self.prepend_tokens = tuple(prepend_tokens)
        if append_tokens is not None:
            self.append_tokens = tuple(append_tokens)
        self.nmers = nmers

    @property
    def vocabulary(self) -> Tuple[str, ...]:
        """The complete token tuple, special tokens included."""
        return self._vocabulary(self.prepend_tokens, self.tokens, self.nmers, self.append_tokens)

    @staticmethod
    @lru_cache(maxsize=None)
    def _vocabulary(
        prepend_tokens: Tuple[str, ...], tokens: Tuple[str, ...], nmers: int, append_tokens: Tuple[str, ...]
    ) -> Tuple[str, ...]:
        # Memoised on the hashable tuple arguments so identical alphabets share work.
        return prepend_tokens + generate_kmer_vocabulary(tokens, nmers) + append_tokens

    def __iter__(self):
        yield from self.vocabulary

    def __len__(self):
        return len(self.vocabulary)

    def __contains__(self, item: str):
        return item in self.vocabulary

    def __repr__(self) -> str:
        fields = [f"tokens={self.tokens}"]
        if self.nmers > 1:
            fields.append(f"nmers={self.nmers}")
        fields.append(f"prepend_tokens={self.prepend_tokens}")
        fields.append(f"append_tokens={self.append_tokens}")
        return "Alphabet(" + ", ".join(fields) + ")"
|
|
|
|
| def _merge_extra_special_tokens( |
| additional_special_tokens: List | Tuple | None, |
| kwargs: dict[str, Any], |
| ) -> List | Tuple | None: |
| if "extra_special_tokens" not in kwargs: |
| return additional_special_tokens |
|
|
| extra_special_tokens = kwargs.pop("extra_special_tokens") |
| if additional_special_tokens is None: |
| merged_special_tokens = [] |
| else: |
| merged_special_tokens = list(additional_special_tokens) |
|
|
| if isinstance(extra_special_tokens, dict): |
| extra_tokens = list(extra_special_tokens.values()) |
| elif isinstance(extra_special_tokens, (list, tuple)): |
| extra_tokens = list(extra_special_tokens) |
| else: |
| raise TypeError( |
| f"extra_special_tokens must be dict, list, or tuple, got {type(extra_special_tokens).__name__}" |
| ) |
|
|
| for token in extra_tokens: |
| token_value = token |
| if isinstance(token, dict) and "content" in token: |
| token_value = token["content"] |
| if token_value not in merged_special_tokens: |
| merged_special_tokens.append(token_value) |
| return merged_special_tokens |
|
|
|
|
def generate_kmer_vocabulary(vocabulary: Tuple[str, ...], nmers: int = 1) -> Tuple[str, ...]:
    """
    Generates a kmer vocabulary given an original vocabulary and the size of kmer.

    Special tokens (those starting with ``<`` or ``[``) are kept as-is at the
    front; every remaining token participates in the cartesian k-mer expansion.

    Args:
        vocabulary: The original vocabulary.
        nmers (defaults to 1): The size of kmer to generate.

    Returns:
        The kmer vocabulary (input returned unchanged when ``nmers <= 1``).
    """
    if nmers <= 1:
        return vocabulary

    is_special = lambda tok: tok.startswith(("<", "["))  # noqa: E731
    specials = tuple(tok for tok in vocabulary if is_special(tok))
    plain = [tok for tok in vocabulary if not is_special(tok)]
    kmers = tuple("".join(combo) for combo in product(plain, repeat=nmers))
    return specials + kmers
|
|
|
|
class Tokenizer(PreTrainedTokenizer):
    """
    Constructs a Base tokenizer.

    Args:
        alphabet: List of tokens or an Alphabet object to use in tokenization.
            Either alphabet or vocab_file must be specified.
        bos_token: A special token representing the beginning of a sequence.
        cls_token: A special token representing the classification token.
        pad_token: A special token representing padding.
        eos_token: A special token representing the end of a sequence.
        sep_token: A special token representing the separator token.
        unk_token: A special token representing unknown tokens.
        mask_token: A special token representing the mask token.
        null_token: A special token representing the null token.
        additional_special_tokens: Additional special tokens to add to the vocabulary.
        do_upper_case: Whether to convert input to uppercase.
        vocab_file: Path to a vocabulary file.
            Either alphabet or vocab_file must be specified.

    Examples:
        >>> from multimolecule.tokenisers import Tokenizer
        >>> tokenizer = Tokenizer(["A", "C", "G", "T", "N"], unk_token="N")
        >>> tokenizer('ACGTN')["input_ids"]
        [0, 1, 2, 3, 4]
        >>> tokenizer('acgtn')["input_ids"]
        [0, 1, 2, 3, 4]
        >>> len(tokenizer)
        5
        >>> tokenizer = Tokenizer(["A", "C", "G", "T", "N"], unk_token="N", do_upper_case=False)
        >>> tokenizer('ACGTN')["input_ids"]
        [0, 1, 2, 3, 4]
        >>> tokenizer('acgtn')["input_ids"]
        [4, 4, 4, 4, 4]
        >>> tokenizer('ACgtN')["input_ids"]
        [0, 1, 4, 4, 4]
        >>> tokenizer = Tokenizer(["<pad>", "<cls>", "A", "C", "G", "T", "N", "<mask>", "<eos>"])
        >>> tokenizer('ACGTN')["input_ids"]
        [1, 2, 3, 4, 5, 6, 8]
        >>> tokenizer('AC<mask>GTN')["input_ids"]
        [1, 2, 3, 7, 4, 5, 6, 8]
        >>> tokenizer(['TATATAT', 'ATCGN'], padding=True)["input_ids"]
        [[1, 5, 2, 5, 2, 5, 2, 5, 8], [1, 2, 5, 3, 4, 6, 8, 0, 0]]
    """

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = VOCAB_FILES_NAMES
    # Class-level default; overridden per instance in __init__.
    do_upper_case: bool = True

    def __init__(
        self,
        alphabet: Alphabet | List[str] | None = None,
        bos_token: str | None = ...,
        cls_token: str | None = ...,
        pad_token: str | None = ...,
        eos_token: str | None = ...,
        sep_token: str | None = ...,
        unk_token: str | None = ...,
        mask_token: str | None = ...,
        null_token: str | None = ...,
        additional_special_tokens: List | Tuple | None = None,
        do_upper_case: bool = True,
        vocab_file: str | None = None,
        **kwargs,
    ):
        # `...` (Ellipsis) is the "not given" sentinel so callers can still pass
        # an explicit None to disable a special token.
        if alphabet is None and vocab_file is None:
            raise ValueError("You must specify either alphabet or vocab_file")

        if vocab_file is not None:
            alphabet = self.load_vocabulary(vocab_file)

        self._id_to_token: OrderedDict[int, str] = OrderedDict(enumerate(alphabet))
        self._token_to_id: OrderedDict[str, int] = OrderedDict({tok: ind for ind, tok in enumerate(alphabet)})

        # Auto-detect special tokens from the alphabet when not given explicitly.
        if cls_token is ...:
            cls_token = self.identify_special_token(alphabet, "cls")
        if bos_token is ...:
            bos_token = cls_token
        if pad_token is ...:
            pad_token = self.identify_special_token(alphabet, "pad")
        if eos_token is ...:
            eos_token = self.identify_special_token(alphabet, "eos")
        if sep_token is ...:
            # Fall back to EOS as the separator when no dedicated SEP exists.
            sep_token = self.identify_special_token(alphabet, "sep") or self.identify_special_token(alphabet, "eos")
        if unk_token is ...:
            unk_token = self.identify_special_token(alphabet, "unk")
        if mask_token is ...:
            mask_token = self.identify_special_token(alphabet, "mask")
        if null_token is ...:
            null_token = self.identify_special_token(alphabet, "null")
        additional_special_tokens = _merge_extra_special_tokens(additional_special_tokens, kwargs)
        if additional_special_tokens is None:
            additional_special_tokens = []
        # PreTrainedTokenizer has no dedicated null_token slot; register it as
        # an additional special token so it is never split during tokenization.
        if null_token in alphabet and null_token not in additional_special_tokens:
            additional_special_tokens = list(additional_special_tokens)
            additional_special_tokens.append(null_token)

        super().__init__(
            bos_token=bos_token,
            cls_token=cls_token,
            pad_token=pad_token,
            eos_token=eos_token,
            sep_token=sep_token,
            unk_token=unk_token,
            mask_token=mask_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.do_upper_case = do_upper_case
        # Tokens added by the base class (e.g. via kwargs) must be visible in
        # our own id<->token maps.
        self._id_to_token.update(self.added_tokens_decoder)
        self._token_to_id.update(self.added_tokens_encoder)

    def _tokenize(self, text: str, **kwargs):
        """Split text into single-character tokens, upper-casing if configured."""
        if self.do_upper_case:
            text = text.upper()
        return list(text)

    def _convert_token_to_id(self, token: str) -> int:
        # Renamed from `id` to avoid shadowing the builtin.
        token_id = self._token_to_id.get(token, self.unk_token_id)
        if token_id is None:
            raise ValueError(f"Token {token} is not in the vocabulary, and no UNK token is set!")
        return token_id

    def _convert_id_to_token(self, index: int) -> str:
        token = self._id_to_token.get(index, self.unk_token)
        if token is None:
            raise ValueError(f"ID {index} is not in the vocabulary, and no UNK token is set!")
        return token

    def token_to_id(self, token: str) -> int:
        """Public alias for `_convert_token_to_id`."""
        return self._convert_token_to_id(token)

    def id_to_token(self, index: int) -> str:
        """Public alias for `_convert_id_to_token`."""
        return self._convert_id_to_token(index)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: List[int] | None = None
    ) -> List[int]:
        """Wrap token ids with special tokens.

        Single sequence: ``[bos?] ids0 [eos?]`` (each only if the token is set).
        Pair: ``bos ids0 sep ids1 eos`` (requires EOS to be set).
        """
        bos = [self.bos_token_id]
        sep = [self.sep_token_id]
        eos = [self.eos_token_id]
        if token_ids_1 is None:
            if self.bos_token_id is None:
                if self.eos_token_id is None:
                    return token_ids_0
                return token_ids_0 + eos
            if self.eos_token_id is None:
                return bos + token_ids_0
            return bos + token_ids_0 + eos
        if self.eos_token_id is None:
            raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!")
        return bos + token_ids_0 + sep + token_ids_1 + eos

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: List[int] | None = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.

        Args:
            token_ids_0 (`List[int]`):
                List of ids of the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                List of ids of the second sequence.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return [1 if token in self.all_special_ids else 0 for token in token_ids_0]
        # Mirror build_inputs_with_special_tokens exactly. The previous
        # implementation appended a SEP mask entry even for single sequences,
        # yielding a mask one element longer than the built input.
        mask = [0] * len(token_ids_0)
        if token_ids_1 is None:
            if self.bos_token_id is not None:
                mask = [1] + mask
            if self.eos_token_id is not None:
                mask = mask + [1]
            return mask
        if self.eos_token_id is None:
            raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!")
        # Pair: bos ids0 sep ids1 eos
        return [1] + mask + [1] + [0] * len(token_ids_1) + [1]

    @staticmethod
    def load_vocabulary(vocab_file: str | Path) -> List[str]:
        """Read a vocabulary file (one token per line, UTF-8)."""
        with open(vocab_file, encoding="utf-8") as reader:
            vocabulary = reader.read().splitlines()
        return vocabulary

    def save_vocabulary(self, save_directory: str, filename_prefix: str | None = None):
        """Write the vocabulary, one token per line, and return the file path."""
        vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt")
        # Write UTF-8 explicitly so the file round-trips with load_vocabulary
        # regardless of the platform's default encoding.
        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write("\n".join(self.all_tokens))
        return (vocab_file,)

    @staticmethod
    def identify_special_token(alphabet: Alphabet | List[str], token: str) -> str | None:
        """Find the unique alphabet entry whose lowercase form contains `token`.

        Returns None when absent; raises if the match is ambiguous.
        """
        tokens = [i for i in alphabet if token in i.lower()]
        if len(tokens) == 1:
            return tokens[0]
        if len(tokens) == 0:
            return None
        raise ValueError(f"Token {token} is ambiguous, could be {tokens}")

    def get_vocab(self):
        """Return the full token -> id mapping, including added tokens."""
        return dict(self.vocab, **self.added_tokens_encoder)

    @property
    def vocab(self) -> OrderedDict[str, int]:
        # Copy so callers cannot mutate the internal mapping.
        return self._token_to_id.copy()

    @property
    def all_tokens(self) -> List[str]:
        return list(self.get_vocab().keys())

    @property
    def vocab_size(self) -> int:
        return len(self.all_tokens)
|
|
|
|
class RnaTokenizer(Tokenizer):
    """
    Tokenizer for RNA sequences.

    Args:
        alphabet: alphabet to use for tokenization.

            - If is `None`, the standard RNA alphabet will be used.
            - If is a `string`, it should correspond to the name of a predefined alphabet. The options include
                + `standard`
                + `iupac`
                + `streamline`
                + `nucleobase`
            - If is an alphabet or a list of characters, that specific alphabet will be used.
        nmers: Size of kmer to tokenize.
        codon: Whether to tokenize into codons.
        replace_T_with_U: Whether to replace T with U.
        do_upper_case: Whether to convert input to uppercase.

    Examples:
        >>> from multimolecule import RnaTokenizer
        >>> tokenizer = RnaTokenizer()
        >>> tokenizer('<pad><cls><eos><unk><mask><null>ACGUNRYSWKMBDHV.X*-I')["input_ids"]
        [1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 2]
        >>> tokenizer('acgu')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 9, 2]
        >>> tokenizer = RnaTokenizer(replace_T_with_U=False)
        >>> tokenizer('acgt')["input_ids"]
        [1, 6, 7, 8, 3, 2]
        >>> tokenizer = RnaTokenizer(nmers=3)
        >>> tokenizer('uagcuuauc')["input_ids"]
        [1, 83, 17, 64, 49, 96, 84, 22, 2]
        >>> tokenizer = RnaTokenizer(codon=True)
        >>> tokenizer('uagcuuauc')["input_ids"]
        [1, 83, 49, 22, 2]
        >>> tokenizer('uagcuuauca')["input_ids"]
        Traceback (most recent call last):
        ValueError: length of input sequence must be a multiple of 3 for codon tokenization, but got 10
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        alphabet: Alphabet | str | List[str] | None = None,
        nmers: int = 1,
        # Default False: the doctests above show `RnaTokenizer()` tokenizing per
        # nucleotide and `RnaTokenizer(nmers=3)` using a sliding window, neither
        # of which is possible when codon tokenization is on by default.
        codon: bool = False,
        replace_T_with_U: bool = True,
        do_upper_case: bool = True,
        additional_special_tokens: List | Tuple | None = None,
        **kwargs,
    ):
        if codon and (nmers > 1 and nmers != 3):
            raise ValueError("Codon and nmers cannot be used together.")
        if codon:
            # Codon tokenization is a non-overlapping 3-mer; share the vocabulary.
            nmers = 3
        if not isinstance(alphabet, Alphabet):
            alphabet = get_alphabet(alphabet, nmers=nmers)
        additional_special_tokens = _merge_extra_special_tokens(additional_special_tokens, kwargs)
        super().__init__(
            alphabet=alphabet,
            nmers=nmers,
            codon=codon,
            replace_T_with_U=replace_T_with_U,
            do_upper_case=do_upper_case,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )
        self.replace_T_with_U = replace_T_with_U
        self.nmers = nmers
        self.codon = codon

    def _tokenize(self, text: str, **kwargs):
        """Tokenize into characters, non-overlapping codons, or overlapping k-mers."""
        if self.do_upper_case:
            text = text.upper()
        if self.replace_T_with_U:
            # DNA-style input is accepted and transcribed to the RNA alphabet.
            text = text.replace("T", "U")
        if self.codon:
            if len(text) % 3 != 0:
                raise ValueError(
                    f"length of input sequence must be a multiple of 3 for codon tokenization, but got {len(text)}"
                )
            return [text[i : i + 3] for i in range(0, len(text), 3)]
        if self.nmers > 1:
            # Overlapping (sliding-window) k-mers.
            return [text[i : i + self.nmers] for i in range(len(text) - self.nmers + 1)]
        return list(text)
|
|
|
|
class RotaryEmbedding(nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).

    Query and keys are transformed by rotation
    matrices which depend on their relative positions.

    Tip: **Cache**
        The cos/sin tables are cached and recomputed only when the sequence length or the device changes.

    Success: **Sequence Length**
        Rotary Embedding is irrespective of the sequence length and can be used for any sequence length.
        Use the `scale` parameter to extend context length beyond training (e.g., scale=2.0 doubles effective context).

    Example:
        >>> embedding = RotaryEmbedding(embedding_dim=64)
        >>> query, key = torch.randn(2, 4, 28, 64), torch.randn(2, 4, 28, 64)
        >>> query, key = embedding(query, key)
        >>> query.shape
        torch.Size([2, 4, 28, 64])
        >>> # For extended context length
        >>> embedding_extended = RotaryEmbedding(embedding_dim=64, scale=2.0)
        >>> embedding.state_dict()  # no weight in state_dict
        OrderedDict()
    """

    # Cache of the last-computed cos/sin tables, keyed by length and device.
    _seq_len_cached: int | None = None
    _cos_cached: Tensor | None = None
    _sin_cached: Tensor | None = None

    def __init__(
        self,
        embedding_dim: int,
        base: float = 10000.0,
        scale: float = 1.0,
        dtype: torch.dtype = torch.float32,
    ):
        """
        Initialize rotary position embeddings.

        Args:
            embedding_dim: Dimension of the embeddings (must be even)
            base: Base for computing inverse frequencies. Defaults to 10000.0.
            scale: Scaling factor for frequencies. Values > 1.0 extend context length
                (e.g., scale=2.0 doubles the effective context). Defaults to 1.0.
            dtype: Data type for computations. Defaults to torch.float32.
        """
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, embedding_dim, 2, dtype=dtype) / embedding_dim))
        # Non-persistent: derived from hyperparameters, so kept out of state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.scale = scale

    def forward(self, q: Tensor, k: Tensor, offset: int = 0, seq_length: int | None = None) -> Tuple[Tensor, Tensor]:
        """
        Apply rotary position embeddings to query and key tensors.

        Args:
            q: Query tensor of shape `(batch_size, num_heads, seq_length, embedding_dim)`
            k: Key tensor of shape `(batch_size, num_heads, seq_length, embedding_dim)`
            offset: Position offset for the start of the sequence (used with past_key_values).
                Defaults to 0.
            seq_length: Full sequence length including offset. If None, uses the sequence length
                from the input tensors. Required when offset > 0.

        Returns:
            Tuple of (rotated_query, rotated_key) tensors with the same shapes as inputs.

        Raises:
            ValueError: If `offset > 0` and `seq_length` is not provided.
        """
        if offset > 0 and seq_length is None:
            raise ValueError("seq_length must be provided when offset > 0")

        if seq_length is None:
            seq_length = k.shape[-2]

        self._update_cos_sin_tables(k, seq_len_dim=-2, seq_length=seq_length)
        return self.apply_rotary_pos_emb(q, offset=offset), self.apply_rotary_pos_emb(k, offset=offset)

    def _update_cos_sin_tables(
        self, x: Tensor, seq_len_dim: int = -2, seq_length: int | None = None
    ) -> Tuple[Tensor, Tensor]:
        """
        Update cached cos/sin tables for rotary embeddings.

        Args:
            x: Input tensor to determine device and dtype
            seq_len_dim: Dimension containing sequence length (default: -2)
            seq_length: Full sequence length to cache. If None, uses x.shape[seq_len_dim]
        """
        if seq_length is None:
            seq_length = x.shape[seq_len_dim]

        cos, sin = self._cos_cached, self._sin_cached
        if seq_length == self._seq_len_cached and cos is not None and sin is not None and cos.device == x.device:
            return cos, sin

        inv_freq = self.inv_freq
        if not isinstance(inv_freq, Tensor):
            raise RuntimeError("inv_freq buffer is not a Tensor")
        t = torch.arange(seq_length, device=x.device, dtype=inv_freq.dtype)
        # Dividing by `scale` lengthens the effective wavelength, extending context.
        # `.to(x.device)` guards against the buffer living on a different device.
        freqs = torch.outer(t, inv_freq.to(x.device)) / self.scale
        # Duplicate frequencies so indices i and i + dim/2 share the same angle.
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()[None, None, :, :]
        sin = emb.sin()[None, None, :, :]
        self._seq_len_cached, self._cos_cached, self._sin_cached = seq_length, cos, sin
        return cos, sin

    def apply_rotary_pos_emb(self, x: Tensor, offset: int = 0) -> Tensor:
        """
        Apply rotary position embeddings to a tensor.

        Args:
            x: Input tensor of shape `(batch_size, num_heads, seq_length, embedding_dim)`
            offset: Position offset for the start of the sequence (used with past_key_values).
                Defaults to 0.

        Returns:
            Rotated tensor with the same shape as input.

        Raises:
            RuntimeError: If the cos/sin tables have not been initialized yet.
        """
        if self._cos_cached is None or self._sin_cached is None:
            raise RuntimeError("Cos/sin tables not initialized. Call forward() or _update_cos_sin_tables() first.")

        cos = self._cos_cached[:, :, offset : offset + x.shape[-2], :]
        sin = self._sin_cached[:, :, offset : offset + x.shape[-2], :]
        return (x * cos) + (self.rotate_half(x) * sin)

    @staticmethod
    def rotate_half(x: Tensor) -> Tensor:
        """Swap-and-negate the two halves of the last dim: (x1, x2) -> (-x2, x1)."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
|
|