# flan-t5la-large / t5la_adapter.py
import warnings
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import T5ForConditionalGeneration, T5Config, Cache
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput

# Message for the head_mask FutureWarning below. Defined at module level (mirroring the
# deprecation notice used in transformers' T5 implementation) so the reference inside
# `forward` does not raise a NameError.
_HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, num_heads)`.
"""

class T5LaAdapterConfig(T5Config):
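    """Configuration for the T5 lookahead (LA) adapter model.

    On top of the standard `T5Config` arguments, this config adds (as used in this module):
        lookahead_type (`str`, defaults to `"la"`): one of `"la"` (``lookahead_size`` extra LM heads on the
            decoder states), `"laa"` / `"laa2"` (a single adapter head applied to encoder states), or
            `"lae"` (extra lookahead positions appended to the decoder input).
        lookahead_size (`int`, defaults to 0): number of future tokens predicted ahead of the next token.
        freeze_base (`bool`, defaults to `True`): freeze all base T5 weights and train only the LA heads.
    """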
model_type = "t5la_adapter"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
"hidden_size": "d_model",
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
"head_dim": "d_kv",
}
auto_map = {
"AutoConfig": "t5la_adapter.T5LaAdapterConfig",
"AutoModel": "t5la_adapter.T5LaAdapterForConditionalGeneration",
"AutoModelForSeq2SeqLM": "t5la_adapter.T5LaAdapterForConditionalGeneration",
"AutoTokenizer": [
"transformers.T5TokenizerFast",
"transformers.T5Tokenizer"
]
}
def __init__(
self,
is_encoder_decoder=True,
pad_token_id=0,
eos_token_id=1,
lookahead_type="la",
lookahead_size=0,
freeze_base=True,
**kwargs,
):
self.lookahead_type = lookahead_type
self.lookahead_size = lookahead_size
self.freeze_base = freeze_base
super().__init__(
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
**kwargs,
)
self.auto_map = {
"AutoConfig": "t5la_adapter.T5LaAdapterConfig",
"AutoModel": "t5la_adapter.T5LaAdapterForConditionalGeneration",
"AutoModelForSeq2SeqLM": "t5la_adapter.T5LaAdapterForConditionalGeneration",
"AutoTokenizer": [
"transformers.T5TokenizerFast",
"transformers.T5Tokenizer"
]
}
@dataclass
class Seq2SeqLMOutputLA(Seq2SeqLMOutput):
lookahead_logits: torch.FloatTensor = None
lookahead_loss: Optional[torch.FloatTensor] = None
base_loss: Optional[torch.FloatTensor] = None
    decoder_last_hidden_state: Optional[torch.FloatTensor] = None
class LookAheadHeads(nn.Module):
def __init__(self, config: T5LaAdapterConfig, k: int) -> None:
super().__init__()
self.k = k
self.heads = nn.ModuleList(
[
# K heads for LA positions:
nn.Linear(config.d_model, config.vocab_size, bias=False)
for _ in range(self.k)
]
)
    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        # Apply each head to the shared features; x: [batch_size, seq_len, d_model]
        logits = [head(x) for head in self.heads]
        if self.k > 0:
            # Stack along a new head dimension: [batch_size, k, seq_len, vocab_size]
            logits = torch.stack(logits, dim=1)
        else:
            # Degenerate case; the module is expected to be constructed with k >= 1.
            logits = logits[0]
        return logits
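
# Illustrative shape check for LookAheadHeads (a sketch, not part of the original module;
# the d_model/vocab_size values below are arbitrary placeholders):
#
#     cfg = T5LaAdapterConfig(d_model=512, vocab_size=32128)
#     heads = LookAheadHeads(cfg, k=2)
#     out = heads(torch.randn(3, 7, 512))   # -> shape [3, 2, 7, 32128] (batch, k, seq, vocab)
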
class T5LaAdapterForConditionalGeneration(T5ForConditionalGeneration):
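    """T5 with lookahead (LA) heads.

    Depending on `config.lookahead_type`, extra linear heads (`la_heads`) predict tokens beyond the next
    position: `"la"` adds `lookahead_size` heads on the decoder states, `"laa"`/`"laa2"` add a single head
    applied to encoder states, and `"lae"` appends extra positions to the decoder input instead of adding
    heads. With `config.freeze_base`, only the LA heads remain trainable.
    """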
config_class = T5LaAdapterConfig
def __init__(self, config: T5LaAdapterConfig):
super().__init__(config)
if config.lookahead_type == "la":
self.la_heads = LookAheadHeads(config, config.lookahead_size)
elif config.lookahead_type in ["laa", "laa2"]:
self.la_heads = LookAheadHeads(config, 1)
        # Freeze all parameters except the new heads (no heads exist for the "lae" variant)
        if config.freeze_base:
            for param in self.parameters():
                param.requires_grad = False
            if hasattr(self, "la_heads"):
                for param in self.la_heads.parameters():
                    param.requires_grad = True  # unfreeze the extra heads
    def freeze_base(self):
        # Freeze all parameters except the new heads
        for param in self.parameters():
            param.requires_grad = False
        if hasattr(self, "la_heads"):
            for param in self.la_heads.parameters():
                param.requires_grad = True  # unfreeze the extra heads
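
    # Illustrative check of the freezing behavior (a sketch; assumes `model` is a loaded
    # T5LaAdapterForConditionalGeneration with `freeze_base=True`):
    #     trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    #     assert all(n.startswith("la_heads") for n in trainable)
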
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
lookahead_targets: Optional[torch.LongTensor] = None,
) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutputLA]:
r"""
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. T5LA is a model with relative position embeddings so you
should be able to pad the inputs on both the right and the left.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
            [What are input IDs?](../glossary#input-ids)
            To know more on how to prepare `input_ids` for pretraining, take a look at [T5LA Training](./t5la#training).
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Indices of decoder input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are decoder input IDs?](../glossary#decoder-input-ids)
T5LA uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5LA
Training](./t5la#training).
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
be used by default.
decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
`[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the sequence-to-sequence language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed for
            labels in `[0, ..., config.vocab_size - 1]`.
        lookahead_targets (`torch.LongTensor`, *optional*):
            Labels for computing the loss of the LA heads or positions (models of type `la`, `laa`, and `laa2`
            have LA heads, while `lae` uses LA positions). The expected shape depends on `lookahead_type`.
Examples:
```python
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
>>> config = T5LaAdapterConfig.from_pretrained("google-t5/t5-small", lookahead_size=2)
>>> model = T5LaAdapterForConditionalGeneration.from_pretrained("google-t5/t5-small", config=config)
>>> # training
>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
>>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
>>> outputs = model(input_ids=input_ids, labels=labels)
>>> loss = outputs.loss
>>> logits = outputs.logits
>>> # inference
>>> input_ids = tokenizer(
... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
... ).input_ids # Batch size 1
>>> outputs = model.generate(input_ids)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
>>> # studies have shown that owning a dog is good for you.
```"""
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
decoder_head_mask = head_mask
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
# Convert encoder inputs in embeddings if needed
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
hidden_states = encoder_outputs[0]
if self.model_parallel:
torch.cuda.set_device(self.decoder.first_device)
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
# get decoder inputs from shifting lm labels to the right
decoder_input_ids = self._shift_right(labels)
if self.config.lookahead_type == "lae":
                # Extend the decoder input with `lookahead_size` extra positions filled with zeros as special tokens:
zeros_to_add = torch.zeros(
decoder_input_ids.shape[0],
self.config.lookahead_size,
device=decoder_input_ids.device,
dtype=decoder_input_ids.dtype,
)
decoder_input_ids = torch.cat((decoder_input_ids, zeros_to_add), dim=1)
if decoder_attention_mask is not None:
ones_to_add = torch.ones(
decoder_attention_mask.shape[0],
self.config.lookahead_size,
device=decoder_attention_mask.device,
dtype=decoder_attention_mask.dtype,
)
decoder_attention_mask = torch.cat((decoder_attention_mask, ones_to_add), dim=1)
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.decoder.first_device)
hidden_states = hidden_states.to(self.decoder.first_device)
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
if attention_mask is not None:
attention_mask = attention_mask.to(self.decoder.first_device)
if decoder_attention_mask is not None:
decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_values=past_key_values,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
)
sequence_output = decoder_outputs[0]
# Set device for model parallelism
if self.model_parallel:
torch.cuda.set_device(self.encoder.first_device)
self.lm_head = self.lm_head.to(self.encoder.first_device)
sequence_output = sequence_output.to(self.lm_head.weight.device)
if self.config.tie_word_embeddings:
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim**-0.5)
lm_logits = self.lm_head(sequence_output)
        lookahead_logits = None
        if self.config.lookahead_type == "la":
            # "la": each of the `lookahead_size` heads predicts a future token from the decoder states.
            lookahead_logits = self.la_heads(sequence_output)
        elif self.config.lookahead_type == "laa":
            # "laa": the single adapter head is applied to the final encoder position, repeated `lookahead_size` times.
            la_input = torch.repeat_interleave(hidden_states[:, [-1]], self.config.lookahead_size, dim=1)
            lookahead_logits = self.la_heads(la_input)
        elif self.config.lookahead_type == "laa2":
            # "laa2": the single adapter head is applied to the last `lookahead_size` encoder positions.
            lookahead_logits = self.la_heads(hidden_states[:, -self.config.lookahead_size :])
        elif self.config.lookahead_type == "lae":
            # "lae": the last `lookahead_size` positions of the extended decoder sequence serve as lookahead logits.
            lookahead_logits = lm_logits[:, -self.config.lookahead_size :].contiguous()
            lm_logits = lm_logits[:, : -self.config.lookahead_size].contiguous()
lookahead_loss = None
loss = None
base_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-100)
# move labels to correct device to enable PP
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
base_loss = loss.clone()
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
if self.config.lookahead_size > 0 and lookahead_targets is not None:
lookahead_loss = loss_fct(
lookahead_logits.reshape(-1, lookahead_logits.size(-1)),
lookahead_targets.view(-1),
# vocab_size=self.config.vocab_size,
)
if self.config.lookahead_type == "la":
# If we simply add, the loss will be larger than a non-LA T5 model because
# in a normal T5, the number of tokens is much lower:
loss = (loss + lookahead_loss) / (1 + self.config.lookahead_size)
else:
loss = (loss * lm_logits.shape[1] + lookahead_loss * self.config.lookahead_size) / (
lm_logits.shape[1] + self.config.lookahead_size
)
if not return_dict:
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
return ((loss,) + output) if loss is not None else output
return Seq2SeqLMOutputLA(
loss=loss,
base_loss=base_loss,
logits=lm_logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_last_hidden_state=decoder_outputs.last_hidden_state,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
lookahead_logits=lookahead_logits,
lookahead_loss=lookahead_loss,
)
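

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes the public
# "google/flan-t5-small" checkpoint purely for illustration; when loading a plain
# T5 checkpoint, the LA head weights are freshly initialized.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
    config = T5LaAdapterConfig.from_pretrained(
        "google/flan-t5-small", lookahead_type="la", lookahead_size=2
    )
    model = T5LaAdapterForConditionalGeneration.from_pretrained(
        "google/flan-t5-small", config=config
    )

    inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")
    labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids

    # Forward pass with labels only: the base LM loss is computed, and the LA heads still
    # produce lookahead logits even though no lookahead_targets are supplied here.
    outputs = model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, labels=labels)
    print("loss:", outputs.loss.item())
    print("lookahead_logits shape:", tuple(outputs.lookahead_logits.shape))  # (batch, k, seq_len, vocab)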