| import torch |
| import numpy as np |
| import math |
| import logging |
|
|
| from typing import Any, Optional, Union, Sequence |
| import torch.nn as nn |
| from transformers import PreTrainedModel, T5EncoderModel, T5ForConditionalGeneration, T5ForQuestionAnswering, T5ForTokenClassification, T5Model |
| from torch import nn |
| from transformers.models.t5.modeling_t5 import T5Attention, T5DenseActDense, T5DenseGatedActDense, T5ClassificationHead, T5LayerNorm, T5Stack, T5Block, T5LayerSelfAttention, T5LayerFF |
| from transformers.cache_utils import DynamicCache, EncoderDecoderCache |
| from transformers.models.t5.configuration_t5 import T5Config |
| from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutput |
| from transformers.utils import DUMMY_INPUTS, DUMMY_MASK, is_torch_fx_proxy, is_torchdynamo_compiling |
| from transformers.utils.deprecation import deprecate_kwarg |
| from .common import M5Pooler |
| from .prepare_data import get_positional_encodings_and_align |
|
|
# Module-level logger, named after this module, shared by every class in the file.
logger = logging.getLogger(__name__)
|
|
class M5EncoderConfig(T5Config):
    """``T5Config`` subclass that only changes the default hyper-parameters.

    Every named argument is forwarded unchanged to ``T5Config.__init__``
    together with any extra ``**kwargs``; the sole difference from the parent
    is the set of default values declared in the signature below (for
    instance ``num_decoder_layers`` defaults to ``0``).
    """

    model_type = "m5_model"

    def __init__(
        self,
        d_ff=2048,
        d_kv=64,
        d_model=512,
        num_layers=24,
        num_heads=12,
        pad_token_id=2,
        dropout_rate=0,
        feed_forward_proj="gated-gelu",
        classifier_dropout=0,
        relative_attention_max_distance=96,
        relative_attention_num_buckets=32,
        vocab_size=1032,
        num_decoder_layers=0,
        **kwargs,
    ):
        # Gather the M5 defaults (or caller overrides) in one mapping and let
        # the parent constructor consume them as keyword arguments.
        m5_settings = {
            "d_ff": d_ff,
            "d_kv": d_kv,
            "d_model": d_model,
            "num_layers": num_layers,
            "num_heads": num_heads,
            "pad_token_id": pad_token_id,
            "dropout_rate": dropout_rate,
            "feed_forward_proj": feed_forward_proj,
            "classifier_dropout": classifier_dropout,
            "relative_attention_max_distance": relative_attention_max_distance,
            "relative_attention_num_buckets": relative_attention_num_buckets,
            "vocab_size": vocab_size,
            "num_decoder_layers": num_decoder_layers,
        }
        super().__init__(**m5_settings, **kwargs)
|
|
class M5Encoder(PreTrainedModel):
    """``PreTrainedModel`` wrapper around :class:`M5EncoderModel`.

    Besides delegating ``forward`` to the wrapped model, the class bundles two
    static data helpers: one that prepares positional encodings for a single
    molecule and one that collates prepared samples into padded batches.
    """

    config_class = M5EncoderConfig
    # NOTE(review): the wrapped network is stored on ``self.model`` while the
    # prefix says "encoder" — confirm this mismatch is intended for checkpoint
    # loading.
    base_model_prefix = "encoder"

    def __init__(self, config):
        super().__init__(config)
        self.model = M5EncoderModel(config)

    def get_input_embeddings(self):
        """Return the token embedding module shared with the encoder stack."""
        return self.model.shared

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embeddings on both the model and its encoder stack."""
        self.model.shared = new_embeddings
        self.model.encoder.embed_tokens = new_embeddings

    def forward(self, input_ids, attention_mask=None, relative_position=None, **kwargs):
        """Run the wrapped encoder.

        Only ``input_ids``, ``attention_mask`` and ``relative_position`` are
        forwarded; any additional keyword arguments are accepted and ignored.
        """
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            relative_position=relative_position,
        )

    @staticmethod
    def get_positional_encodings_and_align(
        smiles: str,
        seed: int,
        token_regr: Optional[np.ndarray] = None,
    ) -> tuple[str, np.ndarray, Optional[np.ndarray]]:
        """
        Convert a SMILES string into a SELFIES tokenization, compute pairwise
        molecular-graph distance encodings, and optionally align token-level
        regression labels to the new token order.

        This is a thin wrapper around the module-level helper of the same name
        imported from ``.prepare_data``.

        Args:
            smiles: Input molecule as a SMILES string. Does not need to be
                canonical — canonicalization and optional randomization are
                applied internally.
            seed: Epoch/seed value controlling SMILES augmentation. When 0,
                the canonical SELFIES is used; any other value produces a
                reproducible randomized SELFIES variant.
            token_regr: Optional array of per-atom regression labels (e.g.
                Löwdin charges) aligned to the original SMILES atom order.
                If provided, labels are re-aligned to match the SELFIES token
                order of the (possibly randomized) output SMILES.
                Shape: ``(n_atoms,)``.

        Returns:
            A tuple of:
            - **selfies** (``str``): SELFIES encoding of the (possibly
              randomized) SMILES.
            - **pos_encod** (``np.ndarray``): Pairwise distance matrix of
              shape ``(seq_len, seq_len)`` with ``dtype=np.int16``. Entries
              are shortest-path graph distances between atoms, capped at
              ``np.iinfo(np.int16).max - 1``. Special values: ``0`` for
              CLS-to-token, token-to-CLS, and ring/dot-separated fragment
              pairs; ``-1`` for intra-branch/ring structural tokens;
              ``np.iinfo(np.int16).max`` for padding positions.
            - **token_regr_selfies** (``np.ndarray`` or ``None``): Labels
              re-aligned to SELFIES token positions, shape
              ``(seq_len - 1,)``, with ``np.nan`` for non-atom tokens
              (branches, rings, dots). ``None`` if ``token_regr`` was not
              provided.
        """
        # Inside a method body this name resolves to the module-level import,
        # not to this static method.
        # NOTE(review): the helper receives (smiles, token_regr, seed), which
        # differs from this wrapper's own (smiles, seed, token_regr) order —
        # confirm it matches the helper's signature in ``prepare_data``.
        return get_positional_encodings_and_align(smiles, token_regr, seed)

    @staticmethod
    def collate_for_dataset(
        batch: Sequence[tuple[dict[str, Any], np.ndarray, Optional[np.ndarray]]],
        n_global_regr: int = 0,
        PAD_TOKEN_ID: int = 2,
    ) -> dict[str, torch.Tensor]:
        """
        Collate processed data for pytorch dataloaders.

        Each item in ``batch`` is a 3-tuple ``(token_dict, pos_encod, reg)``
        where:

        - ``token_dict`` is a dict with keys ``"input_ids"`` (``np.ndarray``,
          shape ``(L,)``) and ``"attention_mask"`` (``np.ndarray``, shape
          ``(L,)``), as produced by a tokenizer.
        - ``pos_encod`` is an ``np.ndarray`` of shape ``(L, L)`` and dtype
          ``np.int16`` holding pairwise molecular-graph distances, as returned
          by :meth:`get_positional_encodings_and_align`.
        - ``reg`` is an ``np.ndarray`` of shape
          ``(n_global_regr + L - 1,)`` containing first the
          ``n_global_regr`` sequence-level regression targets followed by
          ``L - 1`` token-level targets (one per non-CLS token). Ignored when
          ``n_global_regr == 0``.

        All sequences are right-padded to the length of the longest sequence
        in the batch (``L_max``):

        - ``input_ids`` is padded with ``PAD_TOKEN_ID``.
        - ``attention_mask`` is padded with ``0``.
        - ``pos_encod`` is padded with ``np.iinfo(np.int16).max``; the
          diagonal of the padded region is set to ``0`` to be consistent with
          real token self-distances.
        - ``labels`` (when present) is padded with ``float("nan")`` so that
          padding positions can be masked out in the loss.

        Args:
            batch: Sequence of ``(token_dict, pos_encod, reg)`` tuples, one
                per sample. Must not be empty.
            n_global_regr: Number of sequence-level regression targets at the
                start of each ``reg`` array. When ``0``, no ``"labels"`` key
                is included in the returned dict.
            PAD_TOKEN_ID: Token id used to fill padded positions in
                ``input_ids``. Defaults to ``2``.

        Returns:
            A dict with the following keys:

            - ``"input_ids"`` — ``torch.LongTensor`` of shape
              ``(B, L_max)``.
            - ``"attention_mask"`` — ``torch.LongTensor`` of shape
              ``(B, L_max)``; ``1`` for real tokens, ``0`` for padding.
            - ``"positional_encodings"`` — ``torch.ShortTensor`` of shape
              ``(B, L_max, L_max)``.
            - ``"labels"`` *(only when* ``n_global_regr > 0`` *)* —
              ``torch.FloatTensor`` of shape
              ``(B, n_global_regr + L_max - 1)``; ``nan`` for padding
              positions.

        Raises:
            ValueError: If ``batch`` is empty.
        """
        # Fail with a clear message instead of the opaque unpacking error that
        # ``zip(*batch)`` would otherwise raise for an empty batch.
        if len(batch) == 0:
            raise ValueError("collate_for_dataset() received an empty batch")

        token_dicts, pos_encod, regs = zip(*batch)
        lengths = [td["input_ids"].shape[0] for td in token_dicts]
        L_max = max(lengths)
        B = len(batch)
        with_labels = n_global_regr > 0

        # Pre-fill every output with its padding value; real data is copied in below.
        input_ids_out = np.full((B, L_max), PAD_TOKEN_ID, dtype=np.int64)
        attn_mask_out = np.zeros((B, L_max), dtype=np.int64)
        pos_encod_out = np.full((B, L_max, L_max), np.iinfo(np.int16).max, dtype=np.int16)

        if with_labels:
            reg_out = np.full((B, n_global_regr + L_max - 1), float("nan"), dtype=np.float32)

        # Zero the whole diagonal first; the real (L, L) region is overwritten
        # per sample, so only the padded part of the diagonal keeps this value.
        diag_idx = np.arange(L_max)
        pos_encod_out[:, diag_idx, diag_idx] = 0

        for i, (td, pe, reg) in enumerate(zip(token_dicts, pos_encod, regs)):
            L = lengths[i]

            input_ids_out[i, :L] = td["input_ids"]
            attn_mask_out[i, :L] = td["attention_mask"]
            pos_encod_out[i, :L, :L] = pe

            if with_labels:
                # Global targets first, then the L - 1 per-token targets.
                reg_out[i, :n_global_regr] = reg[:n_global_regr]
                reg_out[i, n_global_regr:n_global_regr + L - 1] = reg[n_global_regr:]

        out = {
            "input_ids": torch.from_numpy(input_ids_out),
            "attention_mask": torch.from_numpy(attn_mask_out),
            "positional_encodings": torch.from_numpy(pos_encod_out),
        }

        if with_labels:
            out["labels"] = torch.from_numpy(reg_out)

        return out
| |
| |
|
|
class M5EncoderModel(T5EncoderModel):
    """``T5EncoderModel`` whose encoder stack is replaced by an :class:`M5Stack`.

    The replacement stack accepts an extra ``relative_position`` tensor that is
    used in place of sequence-order offsets when computing attention bias.
    """

    def __init__(self, config: T5Config):
        super().__init__(config)

        # The stack receives the caller's config object itself (no copy), so
        # these two flags are also changed on the object that was passed in.
        config.use_cache = False
        config.is_encoder_decoder = False
        self.encoder = M5Stack(config, self.shared)

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        relative_position: Optional[torch.LongTensor] = None
    ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
        r"""
        Encode a batch of token sequences.

        Args:
            input_ids: Token indices; mutually exclusive with ``inputs_embeds``.
            attention_mask: Optional mask forwarded to the encoder stack.
            head_mask: Optional per-head mask forwarded to the encoder stack.
            inputs_embeds: Pre-computed embeddings used instead of ``input_ids``.
            output_attentions: Whether attention weights are returned.
            output_hidden_states: Whether per-layer hidden states are returned.
            return_dict: Falls back to ``config.use_return_dict`` when ``None``.
            relative_position: Optional pairwise position/distance tensor. It
                is cast to ``torch.int32`` before being handed to the stack
                (presumably shaped ``(batch, seq, seq)`` — confirm with the
                data pipeline).

        Returns:
            Whatever the encoder stack returns: a tuple, or a model-output
            object when ``return_dict`` is true.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Bucketing downstream works on int32 values; leave ``None`` untouched.
        if relative_position is not None:
            relative_position = relative_position.to(dtype=torch.int32)

        return self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            relative_position=relative_position,
        )
|
|
class M5Stack(T5Stack):
    """``T5Stack`` built from :class:`M5Block` layers.

    ``forward`` mirrors the parent's control flow and additionally threads a
    ``relative_position`` argument through to every block.
    """

    def __init__(self, config, embed_tokens=None):
        super().__init__(config, embed_tokens)

        # Replace the blocks created by the parent constructor; only the first
        # block (i == 0) owns a relative-attention-bias table.
        self.block = nn.ModuleList(
            [M5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)]
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
        relative_position=None
    ):
        """Run all blocks over the input and return the final hidden states.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        ``relative_position`` is passed unchanged to every block. The position
        bias returned by a block is fed into the following blocks.

        Returns a tuple of the non-``None`` outputs when ``return_dict`` is
        false, otherwise a ``BaseModelOutputWithPastAndCrossAttentions``.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are given, if token embeddings are missing, or if
                ``use_cache`` is ``True`` on a non-decoder stack.
        """
        # Model parallel
        if self.model_parallel:
            torch.cuda.set_device(self.first_device)
            self.embed_tokens = self.embed_tokens.to(self.first_device)
        # Resolve unset flags from the config.
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Validate the input specification and derive (batch, seq) from it.
        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            if self.embed_tokens is None:
                raise ValueError("You have to initialize the model with valid token embeddings")
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        if use_cache is True:
            if not self.is_decoder:
                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")

        # Decoders lazily create a cache when caching is requested; encoders
        # never carry one.
        if self.is_decoder:
            if use_cache and past_key_values is None:
                if self.config.is_encoder_decoder:
                    past_key_values = EncoderDecoderCache(
                        DynamicCache(config=self.config), DynamicCache(config=self.config)
                    )
                else:
                    past_key_values = DynamicCache(config=self.config)
        elif not self.is_decoder:
            # do not pass cache object down the line for encoder stack
            # it messes indexing later in decoder-stack because cache object is modified in-place
            past_key_values = None

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if cache_position is None:
            cache_position = torch.arange(
                past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
            )

        if attention_mask is None and not is_torchdynamo_compiling():
            # required mask seq length can be calculated via length of past cache
            mask_seq_length = past_key_values_length + seq_length
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        # Build the additive mask: causal for decoders, otherwise the padding
        # mask broadcast to 4-D with masked positions set to the dtype minimum.
        if self.config.is_decoder:
            causal_mask = self._update_causal_mask(
                attention_mask,
                inputs_embeds,
                cache_position,
                past_key_values.self_attention_cache
                if isinstance(past_key_values, EncoderDecoderCache)
                else past_key_values,
                output_attentions,
            )
        elif attention_mask is not None:
            causal_mask = attention_mask[:, None, None, :]
            causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
            causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
        else:
            causal_mask = None

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(
                    encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
                )
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, layer_module in enumerate(self.block):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure that attention_mask is always on the same device as hidden_states
                if causal_mask is not None:
                    causal_mask = causal_mask.to(hidden_states.device)
                if position_bias is not None:
                    position_bias = position_bias.to(hidden_states.device)
                if encoder_hidden_states is not None:
                    encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
                if encoder_extended_attention_mask is not None:
                    encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                if encoder_decoder_position_bias is not None:
                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
                if layer_head_mask is not None:
                    layer_head_mask = layer_head_mask.to(hidden_states.device)
                if cross_attn_layer_head_mask is not None:
                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # Same call as the parent stack, plus the extra ``relative_position``.
            layer_outputs = layer_module(
                hidden_states,
                causal_mask,
                position_bias,
                encoder_hidden_states,
                encoder_extended_attention_mask,
                encoder_decoder_position_bias,
                layer_head_mask=layer_head_mask,
                cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                return_dict=return_dict,
                cache_position=cache_position,
                relative_position=relative_position
            )

            hidden_states = layer_outputs[0]

            # We share the position biases between the layers - the first layer store them
            # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
            position_bias = layer_outputs[1]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[2],)
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[4],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Tuple output keeps only the entries that were actually produced.
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    past_key_values,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
| |
class M5Block(T5Block):
    """``T5Block`` whose self-attention sub-layer is an :class:`M5LayerSelfAttention`.

    Sub-layer order is: self-attention, cross-attention (decoder configs
    only), feed-forward. ``forward`` passes ``relative_position`` to the
    self-attention sub-layer.
    """

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__(config, has_relative_attention_bias, layer_idx)
        # Imported here so the fix is self-contained; only decoder blocks use it.
        from transformers.models.t5.modeling_t5 import T5LayerCrossAttention

        self.layer = nn.ModuleList()
        self.layer.append(
            M5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
        )
        if self.is_decoder:
            # ``forward`` calls this sub-layer with ``key_value_states`` and
            # ``query_length``, which M5LayerSelfAttention.forward does not
            # accept; a real cross-attention layer is required here.
            self.layer.append(T5LayerCrossAttention(config, layer_idx=layer_idx))
        self.layer.append(T5LayerFF(config))

    @staticmethod
    def _clamp_fp16(hidden_states):
        """Clamp float16 activations to the finite range; other dtypes pass through."""
        if hidden_states.dtype != torch.float16:
            return hidden_states
        # Leave a safety margin below the dtype maximum once an inf has appeared.
        clamp_value = torch.where(
            torch.isinf(hidden_states).any(),
            torch.finfo(hidden_states.dtype).max - 1000,
            torch.finfo(hidden_states.dtype).max,
        )
        return torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        encoder_decoder_position_bias=None,
        layer_head_mask=None,
        cross_attn_layer_head_mask=None,
        past_key_values=None,
        use_cache=False,
        output_attentions=False,
        return_dict=True,
        cache_position=None,
        relative_position=None,
    ):
        """Apply self-attention, optional cross-attention and feed-forward.

        Returns a tuple ``(hidden_states, self-attention position bias,
        [self-attention weights], [cross-attention position bias],
        [cross-attention weights])``; the bracketed entries appear only when
        the corresponding sub-layer produced them. ``return_dict`` is accepted
        for signature compatibility and not used.
        """
        self_attention_outputs = self.layer[0](
            hidden_states,
            attention_mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
            relative_position=relative_position,
        )
        hidden_states = self._clamp_fp16(self_attention_outputs[0])
        # Keep self-attention position bias (and weights, when requested).
        attention_outputs = self_attention_outputs[1:]

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            cross_attention_outputs = self.layer[1](
                hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                position_bias=encoder_decoder_position_bias,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_values=past_key_values,
                query_length=cache_position[-1] + 1,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            hidden_states = self._clamp_fp16(cross_attention_outputs[0])
            # Append cross-attention position bias (and weights, when requested).
            attention_outputs = attention_outputs + cross_attention_outputs[1:]

        # Feed-forward is always the last sub-layer.
        hidden_states = self._clamp_fp16(self.layer[-1](hidden_states))

        return (hidden_states,) + attention_outputs
|
|
class M5LayerSelfAttention(T5LayerSelfAttention):
    """Pre-norm residual self-attention sub-layer backed by :class:`M5Attention`."""

    def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
        super().__init__(config, has_relative_attention_bias, layer_idx)
        # Swap the attention module created by the parent for the M5 variant.
        self.SelfAttention = M5Attention(
            config,
            has_relative_attention_bias=has_relative_attention_bias,
            layer_idx=layer_idx,
        )

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_bias=None,
        layer_head_mask=None,
        past_key_values=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
        relative_position=None,
    ):
        """Normalize, attend, and add the dropped-out result back onto the input.

        Returns ``(updated hidden states, *extra attention outputs)`` where the
        extras are whatever the attention module returned after its first item.
        """
        attn_result = self.SelfAttention(
            self.layer_norm(hidden_states),
            mask=attention_mask,
            position_bias=position_bias,
            layer_head_mask=layer_head_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
            relative_position=relative_position,
        )
        # Residual connection around the attention output.
        residual_sum = hidden_states + self.dropout(attn_result[0])
        return (residual_sum, *attn_result[1:])
|
|
class M5Attention(T5Attention):
    """``T5Attention`` variant whose position bias can come from a supplied distance matrix.

    When ``relative_position`` is passed to ``forward`` it is bucketed by
    :meth:`_relative_position_bucket` and looked up in the relative-attention
    embedding, instead of deriving offsets from token order.
    """

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_values=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
        relative_position=None

    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        Returns ``(attn_output, position_bias)`` plus the attention weights
        when ``output_attentions`` is true. ``relative_position`` is only used
        when ``position_bias`` is ``None`` and this layer owns a bias table.
        """
        # The first two dims of hidden_states are (batch_size, seq_length).
        batch_size, seq_length = hidden_states.shape[:2]

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        query_states = self.q(hidden_states)
        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        # An EncoderDecoderCache holds separate self- and cross-attention caches;
        # any other cache object is used directly.
        is_updated = False
        if isinstance(past_key_values, EncoderDecoderCache):
            is_updated = past_key_values.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                curr_past_key_value = past_key_values.cross_attention_cache
            else:
                curr_past_key_value = past_key_values.self_attention_cache
        else:
            curr_past_key_value = past_key_values

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # reuse cached cross-attention keys/values
            key_states = curr_past_key_value.layers[self.layer_idx].keys
            value_states = curr_past_key_value.layers[self.layer_idx].values
        else:
            key_states = self.k(current_states)
            value_states = self.v(current_states)
            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

            if past_key_values is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(
                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                )
                # flag that this layer's cross-attention cache is filled so later calls can re-use it
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # Raw attention scores: query x key^T over the last two dims.
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            key_length = key_states.shape[-2]
            # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
            real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(
                    real_seq_length,
                    key_length,
                    device=scores.device,
                    cache_position=cache_position,
                    relative_position=relative_position,
                )
                position_bias = position_bias[:, :, -seq_length:, :]

            # Fold the additive attention mask into the bias so later layers
            # that re-use this bias get the mask as well.
            if mask is not None:
                causal_mask = mask[:, :, :, : key_states.shape[-2]]
                position_bias = position_bias + causal_mask

        if self.pruned_heads:
            mask = torch.ones(position_bias.shape[1])
            mask[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, mask.bool()]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked

        # Softmax in float32 for stability, then cast back to the score dtype.
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = torch.matmul(attn_weights, value_states)

        # Merge heads back into a single feature dimension and project out.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.inner_dim)
        attn_output = self.o(attn_output)

        outputs = (attn_output, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Translate pairwise distance values into bucket ids for the relative attention bias.

        The mapping implemented here is:

        1. Every value is shifted by ``+1`` and clamped at ``0`` (so an input
           of ``-1`` lands in bucket ``0`` and an input of ``0`` in bucket ``1``).
        2. Shifted values below ``num_buckets // 2`` are used as their own
           bucket id.
        3. Larger shifted values are placed on a logarithmic scale relative to
           ``max_distance`` and capped at bucket ``num_buckets - 2``.
        4. Entries whose shifted value equals ``np.iinfo(np.int16).max + 1``
           (i.e. an input equal to the int16 maximum) are assigned the last
           bucket, ``num_buckets - 1``.

        Args:
            relative_position: an integer Tensor of distances
            bidirectional: accepted for signature compatibility; not used
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing integer values in the range
            [0, num_buckets)
        """
        relative_position = relative_position + 1
        relative_position = torch.max(relative_position, torch.zeros_like(relative_position))

        # Half of the buckets hold exact (small) values.
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # One bucket is held back for the special marker handled below.
        num_log_buckets = num_buckets - max_exact - 1

        # NOTE(review): the log term is scaled by ``num_buckets - num_log_buckets``
        # rather than by the number of log buckets — kept as-is so existing
        # weights keep their meaning; confirm this factor is intended.
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - num_log_buckets)
        ).to(torch.long)

        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 2)
        )

        relative_buckets = torch.where(is_small, relative_position, relative_position_if_large)

        # Inputs equal to the int16 maximum (after the +1 shift) get the reserved last bucket.
        special_mask = (relative_position == np.iinfo(np.int16).max + 1)
        relative_buckets[special_mask] = num_buckets - 1

        return relative_buckets

    def compute_bias(self, query_length, key_length, device=None, cache_position=None, relative_position=None):
        """Compute binned relative position bias.

        Args:
            query_length: Number of query positions; only used by the fallback.
            key_length: Number of key positions; only used by the fallback.
            device: Device for the fallback position tensors; defaults to the
                device of the bias embedding.
            cache_position: Optional query positions for the fallback.
            relative_position: Optional distance matrix to bucket. A 2-D
                matrix is treated as a single sample shared across the batch.
                When ``None``, token-order offsets
                (``key index - query index``) are used instead.

        Returns:
            Bias tensor laid out as ``(batch or 1, n_heads, query, key)``.
        """
        if device is None:
            device = self.relative_attention_bias.weight.device

        if relative_position is None:
            if cache_position is None:
                context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
            else:
                context_position = cache_position[:, None].to(device)
            memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
            relative_position = memory_position - context_position

        # The permute below needs a 4-D embedding output. The fallback (and any
        # unbatched caller input) is 2-D, so give it a leading batch axis;
        # without this the fallback path raised inside ``permute``.
        if relative_position.dim() == 2:
            relative_position = relative_position.unsqueeze(0)

        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )

        # Embedding lookup appends a heads axis last; move it next to the batch axis.
        values = self.relative_attention_bias(relative_position_bucket)
        values = values.permute([0, 3, 1, 2])
        return values
|
|
| |
| |
class M5RegressionHead(nn.Module):
    """Sequence-level head: pool, linear transform, T5 feed-forward, output projection."""

    def __init__(self, config: T5Config):
        super().__init__()

        self.pooler = M5Pooler(config)
        self.transform = nn.Linear(config.d_model, config.d_model)
        # Pick the feed-forward flavour that matches the configured activation.
        dense_cls = T5DenseGatedActDense if config.is_gated_act else T5DenseActDense
        self.DenseReluDense = dense_cls(config)
        self.out_proj = nn.Linear(config.d_model, config.num_labels)

    def forward(self, input_ids: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        """Pool the hidden states and map them to ``config.num_labels`` outputs."""
        features = self.pooler(input_ids, hidden_states)
        # The remaining stages are applied strictly in sequence.
        for stage in (self.transform, self.DenseReluDense, self.out_proj):
            features = stage(features)
        return features
|
|
| |
class M5TokenRegressionHead(nn.Module):
    """Per-token head producing one scalar for every position after the first.

    Each token's hidden state is concatenated with the hidden state at
    position 0 before going through two linear + feed-forward stages and a
    final one-unit projection.
    """

    def __init__(self, config: T5Config):
        super().__init__()

        # Both feed-forward stages use the flavour matching the configured activation.
        dense_cls = T5DenseGatedActDense if config.is_gated_act else T5DenseActDense

        # First stage consumes the concatenated (token, position-0) features.
        self.transform1 = nn.Linear(config.d_model * 2, config.d_model)
        self.DenseReluDense1 = dense_cls(config)

        self.transform2 = nn.Linear(config.d_model, config.d_model)
        self.DenseReluDense2 = dense_cls(config)

        self.output = nn.Linear(config.d_model, 1)
        self.config = config

    def forward(self, token_hidden_states: torch.Tensor) -> torch.Tensor:
        """Return one value per non-first position, with the trailing unit dim squeezed."""
        first_state = token_hidden_states[:, 0, :]
        remaining_states = token_hidden_states[:, 1:, :]

        # Broadcast the position-0 state along the token axis and append it to
        # every token's features.
        n_tokens = remaining_states.size(1)
        first_broadcast = first_state.unsqueeze(1).expand(-1, n_tokens, -1)
        features = torch.cat([remaining_states, first_broadcast], dim=-1).contiguous()

        pipeline = (
            self.transform1,
            self.DenseReluDense1,
            self.transform2,
            self.DenseReluDense2,
            self.output,
        )
        for stage in pipeline:
            features = stage(features)

        return features.squeeze(-1)
|
|
|
|
class M5PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = T5Config
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    _supports_quantized_cache = False
    _supports_static_cache = True
    _supports_cache_class = True
    _no_split_modules = ["T5Block"]
    _keep_in_fp32_modules = ["wo"]

    @property
    def dummy_inputs(self):
        """Return a dict of dummy tensors built from ``DUMMY_INPUTS`` and ``DUMMY_MASK``."""
        input_ids = torch.tensor(DUMMY_INPUTS)
        input_mask = torch.tensor(DUMMY_MASK)
        dummy_inputs = {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs

    @staticmethod
    def _normal_init_linear(layer, std):
        """Draw ``layer.weight`` from N(0, std^2) and zero ``layer.bias`` when it exists."""
        layer.weight.data.normal_(mean=0.0, std=std)
        if getattr(layer, "bias", None) is not None:
            layer.bias.data.zero_()

    def _init_weights(self, module):
        """Initialize the weights of ``module`` according to its type.

        Every standard deviation is scaled by ``config.initializer_factor``.
        Modules of a type not listed below are left untouched.
        NOTE(review): the T5 branches appear to mirror the upstream Hugging
        Face T5 initialisation scheme - confirm when upgrading transformers.
        """
        factor = self.config.initializer_factor
        d_model = self.config.d_model
        d_ff = self.config.d_ff

        if isinstance(module, T5LayerNorm):
            module.weight.data.fill_(factor * 1.0)
        elif isinstance(
            module,
            (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5ForQuestionAnswering),
        ):
            # Shared token embedding, plus the optional untied LM head / QA head.
            module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
            if hasattr(module, "qa_outputs"):
                module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
                module.qa_outputs.bias.data.zero_()
        elif isinstance(module, T5ForTokenClassification):
            if hasattr(module, "classifier"):
                module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0)
                module.classifier.bias.data.zero_()
        elif isinstance(module, T5ClassificationHead):
            self._normal_init_linear(module.dense, factor * (d_model ** -0.5))
            self._normal_init_linear(module.out_proj, factor * (d_model ** -0.5))
        elif isinstance(module, T5DenseActDense):
            # Each projection is scaled by the inverse square root of its fan-in.
            self._normal_init_linear(module.wi, factor * (d_model ** -0.5))
            self._normal_init_linear(module.wo, factor * (d_ff ** -0.5))
        elif isinstance(module, T5DenseGatedActDense):
            self._normal_init_linear(module.wi_0, factor * (d_model ** -0.5))
            self._normal_init_linear(module.wi_1, factor * (d_model ** -0.5))
            self._normal_init_linear(module.wo, factor * (d_ff ** -0.5))
        elif isinstance(module, M5RegressionHead):
            # Only the head's own linear layers are handled here; its
            # feed-forward submodule is a T5Dense* module with its own branch.
            self._normal_init_linear(module.transform, factor * (d_model ** -0.5))
            self._normal_init_linear(module.out_proj, factor * (d_model ** -0.5))
        elif isinstance(module, M5TokenRegressionHead):
            self._normal_init_linear(module.transform1, factor * ((d_model * 2) ** -0.5))
            self._normal_init_linear(module.transform2, factor * (d_model ** -0.5))
            # NOTE(review): 37.84 is an unexplained constant; the layer's real
            # fan-in is d_model. Kept as-is - confirm its origin before changing.
            self._normal_init_linear(module.output, factor * ((37.84) ** -0.5))
        elif isinstance(module, T5Attention):
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            # q gets an extra d_kv factor in its variance compared with k and v.
            module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
            module.k.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
            module.v.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))
            module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * (d_model ** -0.5))

    def _shift_right(self, input_ids):
        """Shift ``input_ids`` one position to the right along the last axis.

        The first position is filled with ``config.decoder_start_token_id``,
        the last input token is dropped, and any ``-100`` value in the result
        is replaced by ``config.pad_token_id``.

        Raises:
            ValueError: If ``decoder_start_token_id`` or ``pad_token_id`` is
                not set on the config.
        """
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        if decoder_start_token_id is None:
            raise ValueError(
                "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
                "See T5 docs for more information."
            )
        # Validate before allocating anything so a bad config fails fast.
        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")

        if is_torch_fx_proxy(input_ids):
            # FX proxies: build the result with full + cat rather than the
            # in-place item assignment used in the eager branch below.
            shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
            shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        else:
            shifted_input_ids = input_ids.new_zeros(input_ids.shape)
            shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
            shifted_input_ids[..., 0] = decoder_start_token_id

        # Replace any -100 entries with the pad token id.
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
|
|
|
|
class M5ModelForRegression(M5PreTrainedModel):
    """M5 encoder with a sequence-level and a per-token regression head.

    ``forward`` runs the encoder once, feeds its last hidden state to both
    heads and returns their predictions concatenated along the last axis.
    """

    config_class = M5EncoderConfig
    model_type = "m5_model"

    def __init__(
        self,
        config: T5Config,
    ):
        """Create the encoder and both regression heads, then initialise weights.

        Args:
            config: Model configuration passed to the encoder and both heads.
        """
        super().__init__(config)
        self.encoder: M5Encoder = M5Encoder(config)
        self.token_reg_head: M5TokenRegressionHead = M5TokenRegressionHead(config)
        self.reg_head: M5RegressionHead = M5RegressionHead(config)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, relative_position=None, **kwargs):
        """Encode the input and return all regression predictions in one tensor.

        Args:
            input_ids: Token ids; passed to the encoder and to the
                sequence-level head.
            attention_mask: Optional mask, passed positionally to the encoder.
            relative_position: Optional value forwarded to the encoder under
                the same keyword.
            **kwargs: Any further keyword arguments, forwarded to the encoder.

        Returns:
            ``torch.cat([reg_head_output, token_reg_head_output], dim=-1)``:
            the sequence-level predictions come first, followed by one value
            per position after the first.
            NOTE(review): this assumes the pooled head output is 2-D
            ``(batch, num_labels)`` so it can be concatenated with the
            ``(batch, positions - 1)`` token output - confirm against M5Pooler.
        """
        output = self.encoder(input_ids, attention_mask, relative_position=relative_position, **kwargs)
        hidden_states = output.last_hidden_state

        tokreg_head = self.token_reg_head(hidden_states)
        reg_head = self.reg_head(input_ids, hidden_states)

        concatenated_preds = torch.cat([reg_head, tokreg_head], dim=-1)
        return concatenated_preds