| |
|
|
| from dataclasses import dataclass |
|
|
| import copy |
| import math |
| from typing import Optional, Tuple, Union |
|
|
| import torch |
| from torch import nn |
| from torch.nn import CrossEntropyLoss |
| import torch.nn.functional as F |
|
|
| from transformers.modeling_utils import ModuleUtilsMixin |
| from transformers.modeling_outputs import ModelOutput, Seq2SeqModelOutput, BaseModelOutput |
| from transformers import PreTrainedModel |
|
|
| try: |
| from .rms_norm import fast_rms_layernorm |
| except ImportError: |
| fast_rms_layernorm = None |
|
|
| try: |
| from .cross_entropy_loss import fast_cross_entropy_loss |
| except ImportError: |
| fast_cross_entropy_loss = None |
|
|
| try: |
| from .flash_attention_v2_bias import attention as flash_attention_triton |
| except ImportError: |
| fast_cross_entropy_loss = None |
|
|
| try: |
| from .gated_mlp import gated_mlp |
| except ImportError: |
| gated_mlp = None |
|
|
| try: |
| |
| from .fa2_compilable import flash_attn_kvpacked_func, flash_attn_func |
| except ImportError: |
| flash_attn_kvpacked_func, flash_attn_func = None, None |
|
|
| from .attn_ref import attn_ref |
|
|
| from .configuration_flash_t5 import FlashT5Config |
| from .positional_encoding import ALiBiPositionalEncoding, RelativePositionalEncoding, RotaryPositionalEncoding |
|
|
| @dataclass |
| class EncoderOutput(ModelOutput): |
| hidden_states: torch.FloatTensor = None |
| attention_mask: torch.FloatTensor = None |
| |
| @dataclass |
| class Seq2SeqLMOutput(ModelOutput): |
| loss: torch.FloatTensor = None |
| logits: torch.FloatTensor = None |
| encoder_outputs: EncoderOutput = None |
|
|
| |
| class FlashT5CrossEntropyLoss(nn.Module): |
| def __init__(self, z_loss_factor=0.0, label_smoothing=0.0, use_triton_crossentropy=False): |
|
|
| super().__init__() |
|
|
| if use_triton_crossentropy and fast_cross_entropy_loss is None: |
| raise ImportError("fast_cross_entropy_loss is not available") |
|
|
| self.use_triton_crossentropy = use_triton_crossentropy |
| self.z_loss_factor = z_loss_factor |
|
|
| self.cross_entropy_loss = nn.CrossEntropyLoss(label_smoothing=label_smoothing) |
|
|
| def compute_zloss(self, logits: torch.Tensor, z_loss: float): |
| logits_sum = torch.logsumexp(logits, dim=-1, keepdim=True) |
| log_z = torch.squeeze(logits_sum, axis=-1) |
| total_z_loss = z_loss * torch.square(log_z) |
| return total_z_loss.mean() |
|
|
| def forward(self, logits, labels): |
|
|
| if self.use_triton_crossentropy: |
| return fast_cross_entropy_loss(logits, labels, z_loss_factor=self.z_loss_factor) |
|
|
| |
| batch, seq_len, d = logits.shape |
| logits_flatten = logits.float().view(batch*seq_len, d) |
| labels_flatten = labels.view(-1) |
| loss = self.cross_entropy_loss(logits_flatten, labels_flatten) |
| z_loss = 0.0 |
| if self.z_loss_factor != 0.0: |
| z_loss = self.compute_zloss(logits_flatten[labels_flatten != -100], |
| z_loss=self.z_loss_factor) |
| return loss, z_loss |
|
|
| class FlashT5LayerNorm(nn.Module): |
| def __init__(self, hidden_size, eps=1e-6, use_triton_layernorm=False): |
| """ |
| Construct a layernorm module in the T5 style. No bias and no subtraction of mean. |
| """ |
| super().__init__() |
|
|
| if use_triton_layernorm and fast_rms_layernorm is None: |
| raise ImportError("fast_rms_layernorm is not available") |
|
|
| self.use_triton_layernorm = use_triton_layernorm |
| self.weight = nn.Parameter(torch.ones(hidden_size)) |
| self.variance_epsilon = eps |
|
|
| def forward(self, hidden_states): |
|
|
| if self.use_triton_layernorm: |
| return fast_rms_layernorm(hidden_states, self.weight, self.variance_epsilon) |
|
|
| |
| |
| |
| |
|
|
| variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) |
| hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) |
|
|
| |
| if self.weight.dtype in [torch.float16, torch.bfloat16]: |
| hidden_states = hidden_states.to(self.weight.dtype) |
|
|
| return self.weight * hidden_states |
|
|
| class FlashT5DenseAct(nn.Module): |
| def __init__(self, config: FlashT5Config): |
| super().__init__() |
| self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| self.act = torch.nn.GELU(approximate='tanh') if config.use_gelu_act else torch.nn.ReLU() |
|
|
| def forward(self, hidden_states): |
| hidden_states = self.wi(hidden_states) |
| hidden_states = self.act(hidden_states) |
| hidden_states = self.dropout(hidden_states) |
| if ( |
| isinstance(self.wo.weight, torch.Tensor) |
| and hidden_states.dtype != self.wo.weight.dtype |
| and self.wo.weight.dtype != torch.int8 |
| ): |
| hidden_states = hidden_states.to(self.wo.weight.dtype) |
|
|
| return hidden_states |
|
|
| class FlashT5DenseGatedAct(nn.Module): |
| def __init__(self, config: FlashT5Config): |
| super().__init__() |
| self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) |
| self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| self.act = torch.nn.GELU(approximate='tanh') if config.use_gelu_act else torch.nn.ReLU() |
|
|
| self.use_triton_gated_mlp = config.use_triton_gated_mlp |
| if self.use_triton_gated_mlp and gated_mlp is None: |
| raise ImportError("gated_mlp is not available") |
| self.use_gelu_act = config.use_gelu_act |
|
|
| def forward(self, hidden_states): |
|
|
| if self.use_triton_gated_mlp: |
| return gated_mlp(hidden_states, self.wi_0.weight, self.wi_1.weight, self.use_gelu_act) |
|
|
| hidden_act = self.act(self.wi_0(hidden_states)) |
| hidden_linear = self.wi_1(hidden_states) |
| hidden_states = hidden_act * hidden_linear |
| hidden_states = self.dropout(hidden_states) |
|
|
| return hidden_states |
|
|
| class FlashT5LayerFF(nn.Module): |
| def __init__(self, config: FlashT5Config): |
| super().__init__() |
| if config.use_glu_mlp: |
| self.act = FlashT5DenseGatedAct(config) |
| else: |
| self.act = FlashT5DenseAct(config) |
|
|
| self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm) |
| self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) |
| self.dropout = nn.Dropout(config.dropout_rate) |
|
|
| def forward(self, hidden_states): |
| forwarded_states = self.layer_norm(hidden_states).type_as(hidden_states) |
| forwarded_states = self.act(forwarded_states) |
| forwarded_states = self.wo(forwarded_states) |
| hidden_states = hidden_states + self.dropout(forwarded_states) |
| return hidden_states |
|
|
|
|
| class FlashT5Attention(nn.Module, ModuleUtilsMixin): |
| def __init__(self, config: FlashT5Config, has_positional_encoding=False, is_causal=False): |
| super().__init__() |
| self.is_decoder = config.is_decoder |
| self.has_positional_encoding = has_positional_encoding |
| self.is_causal = is_causal |
| self.relative_attention_num_buckets = config.relative_attention_num_buckets |
| self.relative_attention_max_distance = config.relative_attention_max_distance |
| self.d_model = config.d_model |
| self.key_value_proj_dim = config.d_kv |
| self.n_heads = config.num_heads |
| self.p_dropout = config.attention_dropout_rate |
| self.inner_dim = self.n_heads * self.key_value_proj_dim |
| self.use_flash_attention = config.use_flash_attention |
| self.position_encoding_type = config.position_encoding_type |
| self.max_sequence_length = config.max_sequence_length |
| self.softmax_scale = 1.0/math.sqrt(self.n_heads) |
| self.use_full_bias_size = config.use_full_bias_size |
|
|
| if self.use_flash_attention == "triton" and flash_attention_triton is None: |
| raise ImportError("flash_attention_triton is not available") |
| elif self.use_flash_attention == "fa2" and flash_attn_func is None: |
| raise ImportError("Flash Attention 2 is not available") |
|
|
| assert (self.p_dropout == 0.0) or (self.use_flash_attention != "triton"), "Triton attention does not support dropout" |
|
|
| self.pe_encoding = None |
| if self.position_encoding_type == "ALiBi" and has_positional_encoding: |
| |
| self.pe_encoding = ALiBiPositionalEncoding(self.max_sequence_length, self.n_heads, config.alibi_mode, config.use_randomized_position_encoding) |
| elif self.position_encoding_type == "t5" and has_positional_encoding: |
| self.pe_encoding = RelativePositionalEncoding(self.relative_attention_num_buckets, self.relative_attention_max_distance, self.n_heads, self.max_sequence_length, config.use_randomized_position_encoding) |
| elif self.position_encoding_type == "RoPE": |
| self.pe_encoding = RotaryPositionalEncoding(int(self.key_value_proj_dim * config.rotary_emb_fraction), self.max_sequence_length, config.rotary_base, config.rotary_interleaved, config.rotary_scale_base, config.use_randomized_position_encoding) |
|
|
| self.Wq = nn.Linear(self.d_model, self.inner_dim, bias=False) |
| self.Wk = nn.Linear(self.d_model, self.inner_dim, bias=False) |
| self.Wv = nn.Linear(self.d_model, self.inner_dim, bias=False) |
| self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) |
|
|
| def forward( |
| self, |
| hidden_states, |
| mask=None, |
| key_value_states=None, |
| position_bias=None, |
| ): |
| """ |
| Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). |
| """ |
| |
| |
| batch_size, seq_length = hidden_states.shape[:2] |
| key_length = seq_length if key_value_states is None else key_value_states.shape[1] |
| q = self.Wq(hidden_states) |
| if key_value_states is None: |
| k = self.Wk(hidden_states) |
| v = self.Wv(hidden_states) |
| else: |
| k = self.Wk(key_value_states) |
| v = self.Wv(key_value_states) |
|
|
| q = q.view(batch_size, seq_length, self.n_heads, self.key_value_proj_dim) |
| k = k.view(batch_size, key_length, self.n_heads, self.key_value_proj_dim) |
| v = v.view(batch_size, key_length, self.n_heads, self.key_value_proj_dim) |
|
|
| if position_bias is None and self.pe_encoding is not None: |
| q, k, v, position_bias = self.pe_encoding(q, k, v) |
|
|
| if position_bias is not None and self.use_full_bias_size and (self.use_flash_attention == "fa2" or self.use_flash_attention == "triton"): |
| position_bias = position_bias.expand(q.shape[0], q.shape[2], q.shape[1], k.shape[1]).contiguous() |
|
|
| if self.use_flash_attention == "fa2": |
| output = flash_attn_func(q, k, v, dropout_p=self.p_dropout, softmax_scale=self.softmax_scale, attn_bias=position_bias, causal=self.is_causal) |
| elif self.use_flash_attention == "triton": |
| q = q.permute(0, 2, 1, 3) |
| k = k.permute(0, 2, 1, 3) |
| v = v.permute(0, 2, 1, 3) |
| output = flash_attention_triton(q, k, v, position_bias, self.is_causal, self.softmax_scale) |
| output = output.permute(0, 2, 1, 3) |
| else: |
| q = q.permute(0, 2, 1, 3) |
| k = k.permute(0, 2, 1, 3) |
| v = v.permute(0, 2, 1, 3) |
| output = attn_ref(q, k, v, position_bias, dropout_p=self.p_dropout, sm_scale=self.softmax_scale, causal=self.is_causal) |
| output = output.permute(0, 2, 1, 3) |
|
|
| output = self.o(output.reshape(output.shape[0], output.shape[1], self.inner_dim)) |
| return (output, position_bias) |
|
|
|
|
| class FlashT5LayerSelfAttention(nn.Module): |
| def __init__(self, config, has_positional_encoding=False): |
| super().__init__() |
| self.self_attention = FlashT5Attention(config, has_positional_encoding=has_positional_encoding, is_causal=config.is_decoder) |
| self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm) |
| self.dropout = nn.Dropout(config.dropout_rate) |
|
|
| def forward( |
| self, |
| hidden_states, |
| attention_mask=None, |
| position_bias=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states).type_as(hidden_states) |
| attention_output = self.self_attention( |
| normed_hidden_states, |
| mask=attention_mask, |
| position_bias=position_bias, |
| ) |
| hidden_states = hidden_states + self.dropout(attention_output[0]) |
| outputs = (hidden_states,) + attention_output[1:] |
| return outputs |
|
|
|
|
| class FlashT5LayerCrossAttention(nn.Module): |
| def __init__(self, config): |
| super().__init__() |
| self.cross_attention = FlashT5Attention(config, has_positional_encoding=False) |
| self.layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm) |
| self.dropout = nn.Dropout(config.dropout_rate) |
|
|
| def forward( |
| self, |
| hidden_states, |
| key_value_states, |
| attention_mask=None, |
| position_bias=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states) |
| attention_output = self.cross_attention( |
| normed_hidden_states, |
| mask=attention_mask, |
| key_value_states=key_value_states, |
| position_bias=position_bias, |
| ) |
| layer_output = hidden_states + self.dropout(attention_output[0]) |
| outputs = (layer_output,) + attention_output[1:] |
| return outputs |
|
|
|
|
| class FlashT5Block(nn.Module): |
| def __init__(self, config, has_positional_encoding=False): |
| super().__init__() |
| self.is_decoder = config.is_decoder |
|
|
| self.self_attention_layer = FlashT5LayerSelfAttention(config, has_positional_encoding=has_positional_encoding) |
|
|
| if self.is_decoder: |
| self.cross_attention_layer = FlashT5LayerCrossAttention(config) |
|
|
| self.ff_layer = FlashT5LayerFF(config) |
|
|
| def forward( |
| self, |
| hidden_states, |
| attention_mask=None, |
| position_bias=None, |
| encoder_hidden_states=None, |
| encoder_attention_mask=None, |
| encoder_decoder_position_bias=None, |
| ): |
| self_attention_outputs = self.self_attention_layer( |
| hidden_states, |
| attention_mask=attention_mask, |
| position_bias=position_bias, |
| ) |
| hidden_states = self_attention_outputs[0] |
| attention_outputs = self_attention_outputs[1:] |
|
|
| if self.is_decoder and encoder_hidden_states is not None: |
| cross_attention_outputs = self.cross_attention_layer( |
| hidden_states, |
| key_value_states=encoder_hidden_states, |
| attention_mask=encoder_attention_mask, |
| position_bias=encoder_decoder_position_bias, |
| ) |
| hidden_states = cross_attention_outputs[0] |
|
|
| |
| attention_outputs = attention_outputs + cross_attention_outputs[1:] |
|
|
| |
| hidden_states = self.ff_layer(hidden_states) |
|
|
| outputs = (hidden_states,) + attention_outputs |
| return outputs |
|
|
| class FlashT5Stack(nn.Module, ModuleUtilsMixin): |
| def __init__(self, config, embed_tokens): |
| super().__init__() |
| assert embed_tokens is not None |
|
|
| self.config = config |
| self.embed_tokens = embed_tokens |
| self.is_decoder = config.is_decoder |
| self.use_flash_attention = config.use_flash_attention |
|
|
| self.block = nn.ModuleList( |
| [FlashT5Block(config, has_positional_encoding=bool(i == 0)) for i in range(config.num_layers)] |
| ) |
|
|
| self.final_layer_norm = FlashT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon, use_triton_layernorm=config.use_triton_layernorm) |
| self.dropout = nn.Dropout(config.dropout_rate) |
|
|
| def forward( |
| self, |
| input_ids=None, |
| attention_mask=None, |
| encoder_hidden_states=None, |
| encoder_attention_mask=None, |
| inputs_embeds=None, |
| head_mask=None, |
| cross_attn_head_mask=None, |
| past_key_values=None, |
| use_cache=None, |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None) -> BaseModelOutput: |
| input_shape = input_ids.size() |
| batch_size, seq_length = input_shape |
|
|
| if inputs_embeds is None: |
| inputs_embeds = self.embed_tokens(input_ids) |
|
|
| if torch.is_autocast_enabled() and input_ids.device.type == 'cuda': |
| inputs_embeds = inputs_embeds.to(torch.get_autocast_gpu_dtype()) |
|
|
| |
| if attention_mask is None: |
| attention_mask = torch.ones(batch_size, seq_length, device=inputs_embeds.device, dtype=torch.bool) |
|
|
| if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: |
| encoder_seq_length = encoder_hidden_states.shape[1] |
| encoder_attention_mask = torch.ones( |
| batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.bool |
| ) |
|
|
| position_bias = None |
| encoder_decoder_position_bias = None |
|
|
| hidden_states = self.dropout(inputs_embeds) |
|
|
| for _, layer_module in enumerate(self.block): |
| layer_outputs = layer_module( |
| hidden_states, |
| attention_mask=attention_mask, |
| position_bias=position_bias, |
| encoder_hidden_states=encoder_hidden_states, |
| encoder_attention_mask=encoder_attention_mask, |
| encoder_decoder_position_bias=encoder_decoder_position_bias, |
| ) |
|
|
| |
| position_bias = layer_outputs[1] |
| if self.is_decoder and encoder_hidden_states is not None: |
| encoder_decoder_position_bias = layer_outputs[2] |
|
|
| hidden_states = layer_outputs[0] |
|
|
| hidden_states = self.final_layer_norm(hidden_states).type_as(hidden_states) |
| hidden_states = self.dropout(hidden_states) |
|
|
| return BaseModelOutput( |
| last_hidden_state=hidden_states |
| ) |
|
|
|
|
| class FlashT5PreTrainedModel(PreTrainedModel): |
| """ |
| An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained |
| models. |
| """ |
|
|
| config_class = FlashT5Config |
| base_model_prefix = "transformer" |
| is_parallelizable = False |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["FlashT5Block"] |
| _keep_in_fp32_modules = [] |
|
|
| def _init_weights(self, module): |
| factor = self.config.initializer_factor |
| if isinstance(module, FlashT5LayerNorm): |
| module.weight.data.fill_(factor * 1.0) |
| elif isinstance(module, (FlashT5ForConditionalGeneration)): |
| module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) |
| if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: |
| module.lm_head.weight.data.normal_(mean=0.0, std=factor * self.config.d_model ** -0.5) |
| elif isinstance(module, FlashT5DenseGatedAct): |
| d_ff, d_model = module.wi_0.weight.data.size() |
| module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) |
| module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) |
| elif isinstance(module, FlashT5LayerFF): |
| d_ff, d_model = module.wo.weight.data.size() |
| module.wo.weight.data.normal_(mean=0.0, std=factor * ((d_ff) ** -0.5)) |
| elif isinstance(module, FlashT5Attention): |
| d_model = self.config.d_model |
| key_value_proj_dim = self.config.d_kv |
| n_heads = self.config.num_heads |
| module.Wq.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5)) |
| module.Wk.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) |
| module.Wv.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) |
| module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5)) |
| if module.has_positional_encoding: |
| if hasattr(module.pe_encoding, "relative_attention_bias"): |
| module.pe_encoding.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5)) |
|
|
| def _shift_right(self, input_ids): |
| decoder_start_token_id = self.config.decoder_start_token_id |
| pad_token_id = self.config.pad_token_id |
|
|
| shifted_input_ids = input_ids.new_zeros(input_ids.shape) |
| shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() |
| shifted_input_ids[..., 0] = decoder_start_token_id |
|
|
| |
| shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) |
|
|
| return shifted_input_ids |
|
|
|
|
| class FlashT5Model(FlashT5PreTrainedModel): |
| def __init__(self, config: FlashT5Config): |
| super().__init__(config) |
| self.shared = nn.Embedding(config.vocab_size, config.d_model) |
|
|
| encoder_config = copy.deepcopy(config) |
| encoder_config.is_decoder = False |
| encoder_config.use_cache = False |
| encoder_config.is_encoder_decoder = False |
| self.encoder = FlashT5Stack(encoder_config, self.shared) |
|
|
| decoder_config = copy.deepcopy(config) |
| decoder_config.is_decoder = True |
| decoder_config.is_encoder_decoder = False |
| decoder_config.num_layers = config.num_decoder_layers |
| self.decoder = FlashT5Stack(decoder_config, self.shared) |
|
|
| |
| self.post_init() |
|
|
| |
| self.model_parallel = False |
| self.device_map = None |
|
|
| def get_input_embeddings(self): |
| return self.shared |
|
|
| def set_input_embeddings(self, new_embeddings): |
| self.shared = new_embeddings |
| self.encoder.set_input_embeddings(new_embeddings) |
| self.decoder.set_input_embeddings(new_embeddings) |
|
|
| def get_encoder(self): |
| return self.encoder |
|
|
| def get_decoder(self): |
| return self.decoder |
|
|
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.FloatTensor] = None, |
| decoder_input_ids: Optional[torch.LongTensor] = None, |
| decoder_attention_mask: Optional[torch.BoolTensor] = None, |
| head_mask: Optional[torch.FloatTensor] = None, |
| decoder_head_mask: Optional[torch.FloatTensor] = None, |
| cross_attn_head_mask: Optional[torch.Tensor] = None, |
| encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, |
| past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, |
| inputs_embeds: Optional[torch.Tensor] = None, |
| decoder_inputs_embeds: Optional[torch.Tensor] = None, |
| use_cache: Optional[bool] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: |
|
|
| |
| if encoder_outputs is None: |
| encoder_outputs = self.encoder( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| inputs_embeds=inputs_embeds |
| ) |
|
|
| hidden_states = encoder_outputs[0] |
|
|
| |
| decoder_outputs = self.decoder( |
| input_ids=decoder_input_ids, |
| attention_mask=decoder_attention_mask, |
| inputs_embeds=decoder_inputs_embeds, |
| encoder_hidden_states=hidden_states, |
| encoder_attention_mask=attention_mask |
| ) |
|
|
| return Seq2SeqModelOutput( |
| last_hidden_state=decoder_outputs.last_hidden_state, |
| decoder_hidden_states=decoder_outputs.hidden_states, |
| encoder_last_hidden_state=encoder_outputs.last_hidden_state, |
| encoder_hidden_states=encoder_outputs.hidden_states, |
| ) |
|
|
| class FlashT5ForConditionalGeneration(FlashT5PreTrainedModel): |
|
|
| def __init__(self, config: FlashT5Config): |
| super().__init__(config) |
| config.is_encoder_decoder = False |
| assert not config.tie_word_embeddings |
|
|
| self.config = config |
| self.model_dim = config.d_model |
| self.shared = nn.Embedding(config.vocab_size, config.d_model) |
|
|
| encoder_config = copy.deepcopy(config) |
| encoder_config.is_decoder = False |
| self.encoder = FlashT5Stack(encoder_config, self.shared) |
|
|
| decoder_config = copy.deepcopy(config) |
| decoder_config.is_decoder = True |
| decoder_config.num_layers = config.num_decoder_layers |
| self.decoder = FlashT5Stack(decoder_config, self.shared) |
|
|
| self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) |
|
|
| self.loss_fct = FlashT5CrossEntropyLoss(z_loss_factor=config.z_loss, |
| label_smoothing=config.label_smoothing, |
| use_triton_crossentropy=config.use_triton_crossentropy) |
|
|
| |
| self.post_init() |
|
|
| def prepare_inputs_for_generation( |
| self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs |
| ): |
| |
| model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask} |
|
|
| return model_inputs |
|
|
| def get_input_embeddings(self): |
| return self.shared |
|
|
| def set_input_embeddings(self, value): |
| self.shared = value |
|
|
| def generate( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.FloatTensor] = None, |
| max_length = 32, |
| **kwargs, |
| ) -> torch.LongTensor: |
| """ |
| input_ids: B x L_encoder, int64 |
| attention_mask: B x L_encoder, int64 |
| 1 for tokens to attend to, 0 for tokens to ignore |
| |
| Generation: |
| Starts with 0, ends with 1, padding is 0 |
| |
| # For 20 input/outputs, the diff between my implementation and HF is 9.8s vs 11.4s |
| """ |
| B, _ = input_ids.size() |
| labels = torch.zeros(B, 1, dtype=torch.long, device=input_ids.device) |
| encoder_outputs = None |
|
|
| for _ in range(max_length): |
| out = self.forward( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| decoder_input_ids=labels, |
| encoder_outputs=encoder_outputs, |
| ) |
| encoder_outputs = out.encoder_outputs |
| top_labels = out.logits[:, -1].argmax(-1).unsqueeze(-1) |
| labels = torch.cat([labels, top_labels], dim=-1) |
|
|
| if (labels == 1).sum(-1).clamp(min=0, max=1).sum().item() == B: |
| break |
|
|
| labels[:, -1] = 1 |
|
|
| |
| B, L = labels.size() |
| mask = torch.arange(L, device=labels.device).unsqueeze(0) <= (labels == 1).long().argmax(-1).unsqueeze(-1) |
| labels = labels.masked_fill(~mask, 0) |
|
|
| return labels |
|
|
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.FloatTensor] = None, |
| decoder_input_ids: Optional[torch.LongTensor] = None, |
| decoder_attention_mask: Optional[torch.BoolTensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| encoder_outputs = None, |
| ) -> Seq2SeqLMOutput: |
| """ |
| input_ids: B x L_encoder, int64 |
| attention_mask: B x L_encoder, int64 |
| 1 for tokens to attend to, 0 for tokens to ignore |
| labels: B x L_decoder, int64 |
| """ |
| if encoder_outputs is None: |
| encoder_outputs = self.encoder( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| ) |
|
|
| hidden_states = encoder_outputs.hidden_states |
|
|
| if labels is not None and decoder_input_ids is None: |
| decoder_input_ids = self._shift_right(labels) |
|
|
| decoder_outputs = self.decoder( |
| input_ids=decoder_input_ids, |
| attention_mask=decoder_attention_mask, |
| encoder_hidden_states=hidden_states, |
| encoder_attention_mask=attention_mask, |
| ) |
|
|
| sequence_output = decoder_outputs[0] |
| lm_logits = self.lm_head(sequence_output) |
|
|
| loss = None |
| if labels is not None: |
| loss, z_loss = self.loss_fct(lm_logits, labels) |
| loss += z_loss |
|
|
| return Seq2SeqLMOutput( |
| loss=loss, |
| logits=lm_logits, |
| encoder_outputs=encoder_outputs, |
| ) |
|
|
|
|
|
|
| class FlashT5EncoderModel(FlashT5PreTrainedModel): |
| _tied_weights_keys = ["encoder.embed_tokens.weight"] |
|
|
| def __init__(self, config: FlashT5Config): |
| super().__init__(config) |
| self.shared = nn.Embedding(config.vocab_size, config.d_model) |
|
|
| encoder_config = copy.deepcopy(config) |
| encoder_config.use_cache = False |
| encoder_config.is_encoder_decoder = False |
| self.encoder = FlashT5Stack(encoder_config, self.shared) |
|
|
| |
| self.post_init() |
|
|
| |
| self.model_parallel = False |
| self.device_map = None |
|
|
| |
| def parallelize(self, device_map=None): |
| warnings.warn( |
| "`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" |
| " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" |
| " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0," |
| " 'block.1': 1, ...}", |
| FutureWarning, |
| ) |
| self.device_map = ( |
| get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) |
| if device_map is None |
| else device_map |
| ) |
| assert_device_map(self.device_map, len(self.encoder.block)) |
| self.encoder.parallelize(self.device_map) |
| self.model_parallel = True |
|
|
| def deparallelize(self): |
| warnings.warn( |
| "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", |
| FutureWarning, |
| ) |
| self.encoder.deparallelize() |
| self.encoder = self.encoder.to("cpu") |
| self.model_parallel = False |
| self.device_map = None |
| torch.cuda.empty_cache() |
|
|
| def get_input_embeddings(self): |
| return self.shared |
|
|
| def set_input_embeddings(self, new_embeddings): |
| self.shared = new_embeddings |
| self.encoder.set_input_embeddings(new_embeddings) |
|
|
| def get_encoder(self): |
| return self.encoder |
|
|
| def _prune_heads(self, heads_to_prune): |
| """ |
| Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base |
| class PreTrainedModel |
| """ |
| for layer, heads in heads_to_prune.items(): |
| self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads) |
|
|
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.FloatTensor] = None, |
| head_mask: Optional[torch.FloatTensor] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: |
| r""" |
| Returns: |
| |
| Example: |
| |
| ```python |
| >>> from transformers import AutoTokenizer, T5EncoderModel |
| |
| >>> tokenizer = AutoTokenizer.from_pretrained("t5-small") |
| >>> model = T5EncoderModel.from_pretrained("t5-small") |
| >>> input_ids = tokenizer( |
| ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" |
| ... ).input_ids # Batch size 1 |
| >>> outputs = model(input_ids=input_ids) |
| >>> last_hidden_states = outputs.last_hidden_state |
| ```""" |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| encoder_outputs = self.encoder( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| inputs_embeds=inputs_embeds, |
| head_mask=head_mask, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| ) |
|
|
| return encoder_outputs |