import warnings
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import T5ForConditionalGeneration, T5Config, Cache
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
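
# Module-level warning text for the deprecated `head_mask` fallback used in forward().
# Wording adapted from the equivalent message in transformers' T5 implementation; a single
# leading underscore is used so the name is not mangled when referenced inside the class body.
_HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two input arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, num_heads)`.
"""
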
class T5LaAdapterConfig(T5Config):
    model_type = "t5la_adapter"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
        "head_dim": "d_kv",
    }
    auto_map = {
        "AutoConfig": "t5la_adapter.T5LaAdapterConfig",
        "AutoModel": "t5la_adapter.T5LaAdapterForConditionalGeneration",
        "AutoModelForSeq2SeqLM": "t5la_adapter.T5LaAdapterForConditionalGeneration",
        "AutoTokenizer": [
            "transformers.T5TokenizerFast",
            "transformers.T5Tokenizer",
        ],
    }

    def __init__(
        self,
        is_encoder_decoder=True,
        pad_token_id=0,
        eos_token_id=1,
        lookahead_type="la",
        lookahead_size=0,
        freeze_base=True,
        **kwargs,
    ):
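        # lookahead_type: "la" (extra heads on the decoder output), "laa"/"laa2" (adapter
        # head on encoder states) or "lae" (extra decoder positions, no extra heads).
        # lookahead_size: number of future tokens to predict ahead.
        # freeze_base: if True, only the lookahead heads are trained.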
        self.lookahead_type = lookahead_type
        self.lookahead_size = lookahead_size
        self.freeze_base = freeze_base
        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )
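        # Also set auto_map on the instance so it is serialized into config.json and the
        # Auto* classes can resolve these custom classes when loading with trust_remote_code.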
        self.auto_map = {
            "AutoConfig": "t5la_adapter.T5LaAdapterConfig",
            "AutoModel": "t5la_adapter.T5LaAdapterForConditionalGeneration",
            "AutoModelForSeq2SeqLM": "t5la_adapter.T5LaAdapterForConditionalGeneration",
            "AutoTokenizer": [
                "transformers.T5TokenizerFast",
                "transformers.T5Tokenizer",
            ],
        }


@dataclass
class Seq2SeqLMOutputLA(Seq2SeqLMOutput):
    # Extends Seq2SeqLMOutput with the lookahead predictions and the individual
    # loss terms that make up the combined training loss.
    lookahead_logits: torch.FloatTensor = None
    lookahead_loss: Optional[torch.FloatTensor] = None
    base_loss: Optional[torch.FloatTensor] = None
    # A single tensor (the decoder's last hidden state), not a tuple.
    decoder_last_hidden_state: Optional[torch.FloatTensor] = None


class LookAheadHeads(nn.Module):
    """A stack of `k` LM-head-style projections, one per lookahead position."""

    def __init__(self, config: T5LaAdapterConfig, k: int) -> None:
        super().__init__()
        self.k = k
        self.heads = nn.ModuleList(
            [
                nn.Linear(config.d_model, config.vocab_size, bias=False)
                for _ in range(self.k)
            ]
        )

    def forward(self, x):
        # Project the shared hidden states through every lookahead head.
        logits = [head(x) for head in self.heads]
        if self.k > 0:
            # Shape: (batch_size, k, seq_len, vocab_size).
            logits = torch.stack(logits, dim=1)
        else:
            # No lookahead heads configured, nothing to predict.
            logits = None
        return logits


class T5LaAdapterForConditionalGeneration(T5ForConditionalGeneration):
    config_class = T5LaAdapterConfig

    def __init__(self, config: T5LaAdapterConfig):
        super().__init__(config)
        # "la": one head per lookahead position on top of the decoder output;
        # "laa"/"laa2": a single adapter head fed with encoder hidden states;
        # "lae": no extra heads, the decoder is simply run for extra positions.
        if config.lookahead_type == "la":
            self.la_heads = LookAheadHeads(config, config.lookahead_size)
        elif config.lookahead_type in ["laa", "laa2"]:
            self.la_heads = LookAheadHeads(config, 1)

        if config.freeze_base:
            # Train only the lookahead heads; keep the T5 backbone frozen.
            for param in self.parameters():
                param.requires_grad = False
            if hasattr(self, "la_heads"):
                for param in self.la_heads.parameters():
                    param.requires_grad = True

    def freeze_base(self):
        """Freeze the T5 backbone so that only the lookahead heads stay trainable."""
        for param in self.parameters():
            param.requires_grad = False
        if hasattr(self, "la_heads"):
            for param in self.la_heads.parameters():
                param.requires_grad = True


    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        lookahead_targets: Optional[torch.LongTensor] = None,
    ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutputLA]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5LA is a model with relative position embeddings so
            you should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            To know more on how to prepare `input_ids` for pretraining take a look at [T5LA Training](./t5la#training).
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            T5LA uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5LA
            Training](./t5la#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the sequence-to-sequence language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size - 1]`.
        lookahead_targets (`torch.LongTensor`, *optional*):
            Labels for computing the loss of the LA heads or positions (models of type `la`, `laa`, and `laa2` have
            LA heads and `lae` has LA positions). The shape must match the lookahead logits without the final
            vocabulary dimension.

        Examples:

        ```python
        >>> from transformers import AutoTokenizer
        >>> # assuming this file is available as `t5la_adapter.py` (see `auto_map` above)
        >>> from t5la_adapter import T5LaAdapterConfig, T5LaAdapterForConditionalGeneration

        >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
        >>> config = T5LaAdapterConfig.from_pretrained("google-t5/t5-small", lookahead_size=2)
        >>> model = T5LaAdapterForConditionalGeneration.from_pretrained("google-t5/t5-small", config=config)

        >>> # training
        >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
        >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
        >>> outputs = model(input_ids=input_ids, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits

        >>> # inference
        >>> input_ids = tokenizer(
        ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model.generate(input_ids)
        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        >>> # studies have shown that owning a dog is good for you.
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

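        # Encode if needed (training, or the first prediction pass during generation).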
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            # Wrap legacy tuple encoder outputs so attribute access below keeps working.
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # Get decoder inputs by shifting the labels to the right.
            decoder_input_ids = self._shift_right(labels)

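        # For the "lae" variant, extend the decoder inputs by `lookahead_size` placeholder
        # positions so the decoder also produces hidden states for the lookahead tokens.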
        if self.config.lookahead_type == "lae":
            zeros_to_add = torch.zeros(
                decoder_input_ids.shape[0],
                self.config.lookahead_size,
                device=decoder_input_ids.device,
                dtype=decoder_input_ids.dtype,
            )
            decoder_input_ids = torch.cat((decoder_input_ids, zeros_to_add), dim=1)
            if decoder_attention_mask is not None:
                # Keep the attention mask in sync with the extended decoder inputs.
                ones_to_add = torch.ones(
                    decoder_attention_mask.shape[0],
                    self.config.lookahead_size,
                    device=decoder_attention_mask.device,
                    dtype=decoder_attention_mask.dtype,
                )
                decoder_attention_mask = torch.cat((decoder_attention_mask, ones_to_add), dim=1)

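        # Move inputs to the decoder's device when model parallelism is enabled.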
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

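        # Decode (cross-attending over the encoder hidden states).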
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        sequence_output = decoder_outputs[0]

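        # Move the LM head and the decoder output to the same device under model parallelism.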
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale the decoder output before projecting onto the vocabulary when the
            # input embeddings and the LM head share weights (standard T5 behavior).
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        lookahead_logits = None
        if self.config.lookahead_type == "la":
            # One set of logits per lookahead head, computed from the decoder output.
            lookahead_logits = self.la_heads(sequence_output)
        elif self.config.lookahead_type == "laa":
            # Adapter head fed with the last encoder state, repeated once per lookahead position.
            la_input = torch.repeat_interleave(hidden_states[:, [-1]], self.config.lookahead_size, dim=1)
            lookahead_logits = self.la_heads(la_input)
        elif self.config.lookahead_type == "laa2":
            # Adapter head fed with the last `lookahead_size` encoder states.
            lookahead_logits = self.la_heads(hidden_states[:, -self.config.lookahead_size :])
        elif self.config.lookahead_type == "lae":
            # The last `lookahead_size` decoder positions hold the lookahead predictions;
            # the remaining positions are the regular LM logits.
            lookahead_logits = lm_logits[:, -self.config.lookahead_size :].contiguous()
            lm_logits = lm_logits[:, : -self.config.lookahead_size].contiguous()

        lookahead_loss = None
        loss = None
        base_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # Move labels to the same device as the logits to enable model parallelism.
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            base_loss = loss.clone()

            if self.config.lookahead_size > 0 and lookahead_targets is not None:
                lookahead_loss = loss_fct(
                    lookahead_logits.reshape(-1, lookahead_logits.size(-1)),
                    lookahead_targets.view(-1),
                )
                if self.config.lookahead_type == "la":
                    # Combine the base and lookahead losses, normalizing by 1 + lookahead_size.
                    loss = (loss + lookahead_loss) / (1 + self.config.lookahead_size)
                else:
                    # Weight the two loss terms by the number of positions each covers.
                    loss = (loss * lm_logits.shape[1] + lookahead_loss * self.config.lookahead_size) / (
                        lm_logits.shape[1] + self.config.lookahead_size
                    )

        if not return_dict:
            # Note: the plain tuple output does not expose the lookahead logits or losses;
            # use return_dict=True to get a Seq2SeqLMOutputLA with those fields.
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutputLA(
            loss=loss,
            base_loss=base_loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            lookahead_logits=lookahead_logits,
            lookahead_loss=lookahead_loss,
        )