| |
| |
| |
| |
| @@ -1475,11 +1475,7 @@ def from_legacy_cache( |
| def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: |
| """Returns the sequence length of the cached states. A layer index can be optionally passed.""" |
| # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor` |
| - if self.self_attention_cache.key_cache == []: |
| - return 0 |
| - if len(self.self_attention_cache.key_cache) > 1 and self.self_attention_cache.key_cache[layer_idx] == []: |
| - return 0 |
| - return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum() |
| + return self.self_attention_cache.get_seq_length(layer_idx) |
| |
| def reset(self): |
| if hasattr(self.self_attention_cache, "reset"): |
| |
| |
| |
| |
| @@ -1535,8 +1535,12 @@ def _prepare_generation_config( |
| def _get_initial_cache_position(self, input_ids, model_kwargs): |
| """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" |
| # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` |
| - if "inputs_embeds" in model_kwargs: |
| + if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder: |
| cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 |
| + elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder: |
| + cache_position = ( |
| + torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 |
| + ) |
| else: |
| cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 |
| |
| @@ -1633,7 +1637,7 @@ def get_layer_device_map(execution_device_map: Optional[dict] = None): |
| |
| cache_kwargs = { |
| "config": self.config.get_text_config(), |
| - "max_batch_size": batch_size, |
| + "batch_size": batch_size, |
| "max_cache_len": max_cache_len, |
| "device": device, |
| "dtype": cache_dtype, |
| |
| |
| |
| |
| @@ -79,7 +79,12 @@ class LongT5Config(PretrainedConfig): |
| |
| model_type = "longt5" |
| keys_to_ignore_at_inference = ["past_key_values"] |
| - attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} |
| + attribute_map = { |
| + "hidden_size": "d_model", |
| + "num_attention_heads": "num_heads", |
| + "num_hidden_layers": "num_layers", |
| + "head_dim": "d_kv", |
| + } |
| |
| def __init__( |
| self, |
| |
| |
| |
| |
| @@ -24,7 +24,9 @@ |
| from torch.nn import CrossEntropyLoss |
| |
| from ...activations import ACT2FN |
| +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache |
| from ...generation import GenerationMixin |
| +from ...modeling_attn_mask_utils import AttentionMaskConverter |
| from ...modeling_outputs import ( |
| BaseModelOutput, |
| BaseModelOutputWithPastAndCrossAttentions, |
| @@ -39,6 +41,7 @@ |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_torch_fx_proxy, |
| + is_torchdynamo_compiling, |
| logging, |
| replace_return_docstrings, |
| ) |
| @@ -317,7 +320,12 @@ def forward(self, hidden_states): |
| |
| # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->LongT5 |
| class LongT5Attention(nn.Module): |
| - def __init__(self, config: LongT5Config, has_relative_attention_bias=False): |
| + def __init__( |
| + self, |
| + config: LongT5Config, |
| + has_relative_attention_bias=False, |
| + layer_idx: Optional[int] = None, |
| + ): |
| super().__init__() |
| self.is_decoder = config.is_decoder |
| self.has_relative_attention_bias = has_relative_attention_bias |
| @@ -328,6 +336,13 @@ def __init__(self, config: LongT5Config, has_relative_attention_bias=False): |
| self.n_heads = config.num_heads |
| self.dropout = config.dropout_rate |
| self.inner_dim = self.n_heads * self.key_value_proj_dim |
| + self.layer_idx = layer_idx |
| + if layer_idx is None and self.is_decoder: |
| + logger.warning_once( |
| + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " |
| + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " |
| + "when creating this class." |
| + ) |
| |
| # Mesh TensorFlow initialization to avoid scaling before softmax |
| self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) |
| @@ -404,11 +419,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets |
| relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) |
| return relative_buckets |
| |
| - def compute_bias(self, query_length, key_length, device=None): |
| + def compute_bias(self, query_length, key_length, device=None, cache_position=None): |
| """Compute binned relative position bias""" |
| if device is None: |
| device = self.relative_attention_bias.weight.device |
| - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] |
| + if cache_position is None: |
| + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] |
| + else: |
| + context_position = cache_position[:, None].to(device) |
| memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] |
| relative_position = memory_position - context_position # shape (query_length, key_length) |
| relative_position_bucket = self._relative_position_bucket( |
| @@ -432,94 +450,72 @@ def forward( |
| query_length=None, |
| use_cache=False, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| """ |
| Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). |
| """ |
| # Input is (batch_size, seq_length, dim) |
| - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) |
| - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) |
| + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) |
| batch_size, seq_length = hidden_states.shape[:2] |
| |
| - real_seq_length = seq_length |
| + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder |
| + is_cross_attention = key_value_states is not None |
| |
| - if past_key_value is not None: |
| - if len(past_key_value) != 2: |
| - raise ValueError( |
| - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" |
| - ) |
| - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length |
| + query_states = self.q(hidden_states) |
| + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| |
| - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] |
| - |
| - def shape(states): |
| - """projection""" |
| - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| + if past_key_value is not None: |
| + is_updated = past_key_value.is_updated.get(self.layer_idx) |
| + if is_cross_attention: |
| + # after the first generated id, we can subsequently re-use all key/value_states from cache |
| + curr_past_key_value = past_key_value.cross_attention_cache |
| + else: |
| + curr_past_key_value = past_key_value.self_attention_cache |
| |
| - def unshape(states): |
| - """reshape""" |
| - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) |
| - |
| - def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| - """projects hidden states correctly to key/query states""" |
| - if key_value_states is None: |
| - # self-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(hidden_states)) |
| - elif past_key_value is None: |
| - # cross-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(key_value_states)) |
| + current_states = key_value_states if is_cross_attention else hidden_states |
| + if is_cross_attention and past_key_value is not None and is_updated: |
| + # reuse k,v, cross_attentions |
| + key_states = curr_past_key_value.key_cache[self.layer_idx] |
| + value_states = curr_past_key_value.value_cache[self.layer_idx] |
| + else: |
| + key_states = self.k(current_states) |
| + value_states = self.v(current_states) |
| + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| |
| if past_key_value is not None: |
| - if key_value_states is None: |
| - # self-attn |
| - # (batch_size, n_heads, key_length, dim_per_head) |
| - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) |
| - elif past_key_value.shape[2] != key_value_states.shape[1]: |
| - # checking that the `sequence_length` of the `past_key_value` is the same as |
| - # the provided `key_value_states` to support prefix tuning |
| - # cross-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(key_value_states)) |
| - else: |
| - # cross-attn |
| - hidden_states = past_key_value |
| - return hidden_states |
| - |
| - # get query states |
| - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) |
| - |
| - # get key/value states |
| - key_states = project( |
| - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None |
| - ) |
| - value_states = project( |
| - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None |
| - ) |
| + # save all key/value_states to cache to be re-used for fast auto-regressive generation |
| + cache_position = cache_position if not is_cross_attention else None |
| + key_states, value_states = curr_past_key_value.update( |
| + key_states, value_states, self.layer_idx, {"cache_position": cache_position} |
| + ) |
| + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls |
| + if is_cross_attention: |
| + past_key_value.is_updated[self.layer_idx] = True |
| |
| - # compute scores |
| - scores = torch.matmul( |
| - query_states, key_states.transpose(3, 2) |
| - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 |
| + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 |
| + scores = torch.matmul(query_states, key_states.transpose(3, 2)) |
| |
| if position_bias is None: |
| + key_length = key_states.shape[-2] |
| + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) |
| + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 |
| if not self.has_relative_attention_bias: |
| position_bias = torch.zeros( |
| - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype |
| + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype |
| ) |
| if self.gradient_checkpointing and self.training: |
| position_bias.requires_grad = True |
| else: |
| - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) |
| - |
| - # if key and values are already calculated |
| - # we want only the last query position bias |
| - if past_key_value is not None: |
| - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] |
| + position_bias = self.compute_bias( |
| + real_seq_length, key_length, device=scores.device, cache_position=cache_position |
| + ) |
| + position_bias = position_bias[:, :, -seq_length:, :] |
| |
| if mask is not None: |
| - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) |
| + causal_mask = mask[:, :, :, : key_states.shape[-2]] |
| + position_bias = position_bias + causal_mask |
| |
| if self.pruned_heads: |
| mask = torch.ones(position_bias.shape[1]) |
| @@ -529,22 +525,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| position_bias_masked = position_bias |
| |
| scores += position_bias_masked |
| - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( |
| - scores |
| - ) # (batch_size, n_heads, seq_length, key_length) |
| - attn_weights = nn.functional.dropout( |
| - attn_weights, p=self.dropout, training=self.training |
| - ) # (batch_size, n_heads, seq_length, key_length) |
| + |
| + # (batch_size, n_heads, seq_length, key_length) |
| + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) |
| + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) |
| |
| # Mask heads if we want to |
| if layer_head_mask is not None: |
| attn_weights = attn_weights * layer_head_mask |
| |
| - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) |
| + attn_output = torch.matmul(attn_weights, value_states) |
| + |
| + attn_output = attn_output.transpose(1, 2).contiguous() |
| + attn_output = attn_output.view(batch_size, -1, self.inner_dim) |
| attn_output = self.o(attn_output) |
| |
| - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None |
| - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) |
| + outputs = (attn_output, past_key_value, position_bias) |
| |
| if output_attentions: |
| outputs = outputs + (attn_weights,) |
| @@ -1008,9 +1004,11 @@ def unshape(states): |
| |
| # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->LongT5 |
| class LongT5LayerSelfAttention(nn.Module): |
| - def __init__(self, config, has_relative_attention_bias=False): |
| + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): |
| super().__init__() |
| - self.SelfAttention = LongT5Attention(config, has_relative_attention_bias=has_relative_attention_bias) |
| + self.SelfAttention = LongT5Attention( |
| + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx |
| + ) |
| self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| |
| @@ -1023,6 +1021,7 @@ def forward( |
| past_key_value=None, |
| use_cache=False, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states) |
| attention_output = self.SelfAttention( |
| @@ -1033,6 +1032,7 @@ def forward( |
| past_key_value=past_key_value, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| hidden_states = hidden_states + self.dropout(attention_output[0]) |
| outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them |
| @@ -1042,7 +1042,7 @@ def forward( |
| class LongT5LayerLocalSelfAttention(nn.Module): |
| """Local self attention used in encoder""" |
| |
| - def __init__(self, config, has_relative_attention_bias=False): |
| + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): |
| super().__init__() |
| self.LocalSelfAttention = LongT5LocalAttention(config, has_relative_attention_bias=has_relative_attention_bias) |
| self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| @@ -1073,7 +1073,7 @@ def forward( |
| class LongT5LayerTransientGlobalSelfAttention(nn.Module): |
| """Transient-Global self attention used in encoder""" |
| |
| - def __init__(self, config, has_relative_attention_bias=False): |
| + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): |
| super().__init__() |
| self.TransientGlobalSelfAttention = LongT5TransientGlobalAttention( |
| config, has_relative_attention_bias=has_relative_attention_bias |
| @@ -1105,9 +1105,9 @@ def forward( |
| |
| # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->LongT5 |
| class LongT5LayerCrossAttention(nn.Module): |
| - def __init__(self, config): |
| + def __init__(self, config, layer_idx: Optional[int] = None): |
| super().__init__() |
| - self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False) |
| + self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) |
| self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| |
| @@ -1122,6 +1122,7 @@ def forward( |
| use_cache=False, |
| query_length=None, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states) |
| attention_output = self.EncDecAttention( |
| @@ -1134,6 +1135,7 @@ def forward( |
| use_cache=use_cache, |
| query_length=query_length, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| layer_output = hidden_states + self.dropout(attention_output[0]) |
| outputs = (layer_output,) + attention_output[1:] # add attentions if we output them |
| @@ -1141,7 +1143,7 @@ def forward( |
| |
| |
| class LongT5Block(nn.Module): |
| - def __init__(self, config, has_relative_attention_bias=False): |
| + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): |
| super().__init__() |
| self.is_decoder = config.is_decoder |
| if config.is_decoder: |
| @@ -1156,9 +1158,11 @@ def __init__(self, config, has_relative_attention_bias=False): |
| f"but got {config.encoder_attention_type}." |
| ) |
| self.layer = nn.ModuleList() |
| - self.layer.append(attention_layer(config, has_relative_attention_bias=has_relative_attention_bias)) |
| + self.layer.append( |
| + attention_layer(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx) |
| + ) |
| if self.is_decoder: |
| - self.layer.append(LongT5LayerCrossAttention(config)) |
| + self.layer.append(LongT5LayerCrossAttention(config, layer_idx=layer_idx)) |
| |
| self.layer.append(LongT5LayerFF(config)) |
| |
| @@ -1176,34 +1180,19 @@ def forward( |
| use_cache=False, |
| output_attentions=False, |
| return_dict=True, |
| + cache_position=None, |
| ): |
| - if past_key_value is not None: |
| - if not self.is_decoder: |
| - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") |
| - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 |
| - |
| - if len(past_key_value) != expected_num_past_key_values: |
| - raise ValueError( |
| - f"There should be {expected_num_past_key_values} past states. " |
| - f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" |
| - f"Got {len(past_key_value)} past key / value states" |
| - ) |
| - |
| - self_attn_past_key_value = past_key_value[:2] |
| - cross_attn_past_key_value = past_key_value[2:] |
| - else: |
| - self_attn_past_key_value, cross_attn_past_key_value = None, None |
| - |
| self_attention_outputs = self.layer[0]( |
| hidden_states, |
| attention_mask=attention_mask, |
| position_bias=position_bias, |
| layer_head_mask=layer_head_mask, |
| - past_key_value=self_attn_past_key_value, |
| + past_key_value=past_key_value, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| - hidden_states, present_key_value_state = self_attention_outputs[:2] |
| + hidden_states, past_key_value = self_attention_outputs[:2] |
| attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights |
| |
| # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/ |
| @@ -1213,35 +1202,25 @@ def forward( |
| |
| do_cross_attention = self.is_decoder and encoder_hidden_states is not None |
| if do_cross_attention: |
| - # the actual query length is unknown for cross attention |
| - # if using past key value states. Need to inject it here |
| - if present_key_value_state is not None: |
| - query_length = present_key_value_state[0].shape[2] |
| - else: |
| - query_length = None |
| - |
| cross_attention_outputs = self.layer[1]( |
| hidden_states, |
| key_value_states=encoder_hidden_states, |
| attention_mask=encoder_attention_mask, |
| position_bias=encoder_decoder_position_bias, |
| layer_head_mask=cross_attn_layer_head_mask, |
| - past_key_value=cross_attn_past_key_value, |
| - query_length=query_length, |
| + past_key_value=past_key_value, |
| + query_length=cache_position[-1] + 1, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| - hidden_states = cross_attention_outputs[0] |
| + hidden_states, past_key_value = cross_attention_outputs[:2] |
| |
| # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/ |
| if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): |
| clamp_value = torch.finfo(hidden_states.dtype).max - 1000 |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
| |
| - # Combine self attn and cross attn key value states |
| - if present_key_value_state is not None: |
| - present_key_value_state = present_key_value_state + cross_attention_outputs[1] |
| - |
| # Keep cross-attention outputs and relative position weights |
| attention_outputs = attention_outputs + cross_attention_outputs[2:] |
| |
| @@ -1256,7 +1235,7 @@ def forward( |
| outputs = (hidden_states,) |
| |
| if use_cache: |
| - outputs = outputs + (present_key_value_state,) + attention_outputs |
| + outputs = outputs + (past_key_value,) + attention_outputs |
| else: |
| outputs = outputs + attention_outputs |
| |
| @@ -1273,6 +1252,8 @@ class LongT5PreTrainedModel(PreTrainedModel): |
| base_model_prefix = "transformer" |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["LongT5Block"] |
| + _supports_cache_class = True |
| + _supports_static_cache = False # TODO: @raushan more involved due to local/global attn |
| |
| @property |
| # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs |
| @@ -1376,7 +1357,10 @@ def __init__(self, config, embed_tokens=None): |
| self.block_len = self.local_radius + 1 |
| |
| self.block = nn.ModuleList( |
| - [LongT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] |
| + [ |
| + LongT5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) |
| + for i in range(config.num_layers) |
| + ] |
| ) |
| self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| @@ -1408,6 +1392,7 @@ def forward( |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| + cache_position=None, |
| ): |
| use_cache = use_cache if use_cache is not None else self.config.use_cache |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| @@ -1430,36 +1415,65 @@ def forward( |
| err_msg_prefix = "decoder_" if self.is_decoder else "" |
| raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") |
| |
| + if self.gradient_checkpointing and self.training: |
| + if use_cache: |
| + logger.warning_once( |
| + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
| + ) |
| + use_cache = False |
| + |
| if inputs_embeds is None: |
| assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" |
| inputs_embeds = self.embed_tokens(input_ids) |
| |
| batch_size, seq_length = input_shape |
| |
| - # required mask seq length can be calculated via length of past |
| - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length |
| - |
| - if use_cache is True: |
| - assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder" |
| + # initialize past_key_values |
| + return_legacy_cache = False |
| + return_self_attention_cache = False |
| + if self.is_decoder and (use_cache or past_key_values is not None): |
| + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): |
| + return_self_attention_cache = True |
| + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) |
| + elif not isinstance(past_key_values, EncoderDecoderCache): |
| + return_legacy_cache = True |
| + logger.warning_once( |
| + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " |
| + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " |
| + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." |
| + ) |
| + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) |
| + elif past_key_values is None: |
| + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) |
| + elif not self.is_decoder: |
| + # do not pass cache object down the line for encoder stack |
| + # it messes indexing later in decoder-stack because cache object is modified in-place |
| + past_key_values = None |
| + |
| + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| + if cache_position is None: |
| + cache_position = torch.arange( |
| + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device |
| + ) |
| |
| - if attention_mask is None: |
| + if attention_mask is None and not is_torchdynamo_compiling(): |
| + # required mask seq length can be calculated via length of past |
| + mask_seq_length = past_key_values_length + seq_length |
| attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) |
| |
| - # initialize past_key_values with `None` if past does not exist |
| - if past_key_values is None: |
| - past_key_values = [None] * len(self.block) |
| - |
| - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] |
| - # ourselves in which case we just need to make it broadcastable to all heads. |
| - # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used |
| if self.is_decoder: |
| - extended_attention_mask = self.get_extended_attention_mask( |
| - attention_mask, input_shape, inputs_embeds.device |
| + causal_mask = self._update_causal_mask( |
| + attention_mask, |
| + inputs_embeds, |
| + cache_position, |
| + past_key_values.self_attention_cache if past_key_values is not None else None, |
| + output_attentions, |
| ) |
| + # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used |
| elif self.config.encoder_attention_type == "local": |
| - extended_attention_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device) |
| + causal_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device) |
| else: # we need to use both local attention mask and standard extended mask for transient-global attention |
| - extended_attention_mask = attention_mask |
| + causal_mask = attention_mask |
| |
| # If a 2D or 3D attention mask is provided for the cross-attention |
| # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] |
| @@ -1472,17 +1486,9 @@ def forward( |
| else: |
| encoder_extended_attention_mask = None |
| |
| - if self.gradient_checkpointing and self.training: |
| - if use_cache: |
| - logger.warning_once( |
| - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
| - ) |
| - use_cache = False |
| - |
| # Prepare head mask if needed |
| head_mask = self.get_head_mask(head_mask, self.config.num_layers) |
| cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) |
| - present_key_value_states = () if use_cache else None |
| all_hidden_states = () if output_hidden_states else None |
| all_attentions = () if output_attentions else None |
| all_cross_attentions = () if (output_attentions and self.is_decoder) else None |
| @@ -1491,7 +1497,7 @@ def forward( |
| |
| hidden_states = self.dropout(inputs_embeds) |
| |
| - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): |
| + for i, layer_module in enumerate(self.block): |
| layer_head_mask = head_mask[i] |
| cross_attn_layer_head_mask = cross_attn_head_mask[i] |
| |
| @@ -1502,7 +1508,7 @@ def forward( |
| layer_outputs = self._gradient_checkpointing_func( |
| layer_module.forward, |
| hidden_states, |
| - extended_attention_mask, |
| + causal_mask, |
| position_bias, |
| encoder_hidden_states, |
| encoder_extended_attention_mask, |
| @@ -1512,20 +1518,24 @@ def forward( |
| None, # past_key_value is always None with gradient checkpointing |
| use_cache, |
| output_attentions, |
| + return_dict, |
| + cache_position, |
| ) |
| else: |
| layer_outputs = layer_module( |
| hidden_states, |
| - attention_mask=extended_attention_mask, |
| + attention_mask=causal_mask, |
| position_bias=position_bias, |
| encoder_hidden_states=encoder_hidden_states, |
| encoder_attention_mask=encoder_extended_attention_mask, |
| encoder_decoder_position_bias=encoder_decoder_position_bias, |
| layer_head_mask=layer_head_mask, |
| cross_attn_layer_head_mask=cross_attn_layer_head_mask, |
| - past_key_value=past_key_value, |
| + past_key_value=past_key_values, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| # layer_outputs is a tuple with: |
| @@ -1533,7 +1543,7 @@ def forward( |
| if use_cache is False: |
| layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] |
| |
| - hidden_states, present_key_value_state = layer_outputs[:2] |
| + hidden_states, next_decoder_cache = layer_outputs[:2] |
| |
| # We share the position biases between the layers - the first layer store them |
| # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), |
| @@ -1541,9 +1551,6 @@ def forward( |
| position_bias = layer_outputs[2] |
| if self.is_decoder and encoder_hidden_states is not None: |
| encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] |
| - # append next layer key value states |
| - if use_cache: |
| - present_key_value_states = present_key_value_states + (present_key_value_state,) |
| |
| if output_attentions: |
| all_attentions = all_attentions + (layer_outputs[3],) |
| @@ -1557,12 +1564,18 @@ def forward( |
| if output_hidden_states: |
| all_hidden_states = all_hidden_states + (hidden_states,) |
| |
| + next_cache = next_decoder_cache if use_cache else None |
| + if return_self_attention_cache: |
| + next_cache = past_key_values.self_attention_cache |
| + if return_legacy_cache: |
| + next_cache = past_key_values.to_legacy_cache() |
| + |
| if not return_dict: |
| return tuple( |
| v |
| for v in [ |
| hidden_states, |
| - present_key_value_states, |
| + next_cache, |
| all_hidden_states, |
| all_attentions, |
| all_cross_attentions, |
| @@ -1571,12 +1584,135 @@ def forward( |
| ) |
| return BaseModelOutputWithPastAndCrossAttentions( |
| last_hidden_state=hidden_states, |
| - past_key_values=present_key_value_states, |
| + past_key_values=next_cache, |
| hidden_states=all_hidden_states, |
| attentions=all_attentions, |
| cross_attentions=all_cross_attentions, |
| ) |
| |
| + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask |
| + def _update_causal_mask( |
| + self, |
| + attention_mask: torch.Tensor, |
| + input_tensor: torch.Tensor, |
| + cache_position: torch.Tensor, |
| + past_key_values: Cache, |
| + output_attentions: bool, |
| + ): |
| + if self.config._attn_implementation == "flash_attention_2": |
| + if attention_mask is not None and 0.0 in attention_mask: |
| + return attention_mask |
| + return None |
| + |
| + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in |
| + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail |
| + # to infer the attention mask. |
| + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| + using_static_cache = isinstance(past_key_values, StaticCache) |
| + |
| + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward |
| + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: |
| + if AttentionMaskConverter._ignore_causal_mask_sdpa( |
| + attention_mask, |
| + inputs_embeds=input_tensor, |
| + past_key_values_length=past_seen_tokens, |
| + is_training=self.training, |
| + ): |
| + return None |
| + |
| + dtype, device = input_tensor.dtype, input_tensor.device |
| + sequence_length = input_tensor.shape[1] |
| + if using_static_cache: |
| + target_length = past_key_values.get_max_cache_shape() |
| + else: |
| + target_length = ( |
| + attention_mask.shape[-1] |
| + if isinstance(attention_mask, torch.Tensor) |
| + else past_seen_tokens + sequence_length + 1 |
| + ) |
| + |
| + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). |
| + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask, |
| + sequence_length=sequence_length, |
| + target_length=target_length, |
| + dtype=dtype, |
| + device=device, |
| + cache_position=cache_position, |
| + batch_size=input_tensor.shape[0], |
| + ) |
| + |
| + if ( |
| + self.config._attn_implementation == "sdpa" |
| + and attention_mask is not None |
| + and attention_mask.device.type == "cuda" |
| + and not output_attentions |
| + ): |
| + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when |
| + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. |
| + # Details: https://github.com/pytorch/pytorch/issues/110213 |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) |
| + |
| + return causal_mask |
| + |
| + @staticmethod |
| + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position |
| + def _prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask: torch.Tensor, |
| + sequence_length: int, |
| + target_length: int, |
| + dtype: torch.dtype, |
| + device: torch.device, |
| + cache_position: torch.Tensor, |
| + batch_size: int, |
| + **kwargs, |
| + ): |
| + """ |
| + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
| + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. |
| + |
| + Args: |
| + attention_mask (`torch.Tensor`): |
| + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape |
| + `(batch_size, 1, query_length, key_value_length)`. |
| + sequence_length (`int`): |
| + The sequence length being processed. |
| + target_length (`int`): |
| + The target length: when generating with static cache, the mask should be as long as the static cache, |
| + to account for the 0 padding, the part of the cache that is not filled yet. |
| + dtype (`torch.dtype`): |
| + The dtype to use for the 4D attention mask. |
| + device (`torch.device`): |
| + The device to plcae the 4D attention mask on. |
| + cache_position (`torch.Tensor`): |
| + Indices depicting the position of the input sequence tokens in the sequence. |
| + batch_size (`torch.Tensor`): |
| + Batch size. |
| + """ |
| + if attention_mask is not None and attention_mask.dim() == 4: |
| + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. |
| + causal_mask = attention_mask |
| + else: |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = torch.full( |
| + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device |
| + ) |
| + if sequence_length != 1: |
| + causal_mask = torch.triu(causal_mask, diagonal=1) |
| + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |
| + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
| + if attention_mask is not None: |
| + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit |
| + mask_length = attention_mask.shape[-1] |
| + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] |
| + padding_mask = padding_mask == 0 |
| + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| + padding_mask, min_dtype |
| + ) |
| + |
| + return causal_mask |
| + |
| |
| LONGT5_START_DOCSTRING = r""" |
| |
| @@ -1693,6 +1829,9 @@ def forward( |
| more detail. |
| return_dict (`bool`, *optional*): |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
| + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the |
| + cache in the correct position and to infer the complete sequence length. |
| """ |
| |
| LONGT5_ENCODER_INPUTS_DOCSTRING = r""" |
| @@ -1817,6 +1956,7 @@ def forward( |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| + cache_position: Optional[torch.LongTensor] = None, |
| ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: |
| r""" |
| Returns: |
| @@ -1883,6 +2023,7 @@ def forward( |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| if not return_dict: |
| @@ -1975,6 +2116,7 @@ def forward( |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| + cache_position: Optional[torch.LongTensor] = None, |
| ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: |
| r""" |
| labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| @@ -2050,6 +2192,7 @@ def forward( |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| sequence_output = decoder_outputs[0] |
| |
| |
| |
| |
| @@ -72,7 +72,12 @@ class MT5Config(PretrainedConfig): |
| |
| model_type = "mt5" |
| keys_to_ignore_at_inference = ["past_key_values"] |
| - attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} |
| + attribute_map = { |
| + "hidden_size": "d_model", |
| + "num_attention_heads": "num_heads", |
| + "num_hidden_layers": "num_layers", |
| + "head_dim": "d_kv", |
| + } |
| |
| def __init__( |
| self, |
| |
| |
| |
| |
| @@ -25,7 +25,9 @@ |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
| |
| from ...activations import ACT2FN |
| +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache |
| from ...generation import GenerationMixin |
| +from ...modeling_attn_mask_utils import AttentionMaskConverter |
| from ...modeling_outputs import ( |
| BaseModelOutput, |
| BaseModelOutputWithPastAndCrossAttentions, |
| @@ -43,6 +45,7 @@ |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_torch_fx_proxy, |
| + is_torchdynamo_compiling, |
| logging, |
| replace_return_docstrings, |
| ) |
| @@ -214,7 +217,12 @@ def forward(self, hidden_states): |
| |
| # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->MT5 |
| class MT5Attention(nn.Module): |
| - def __init__(self, config: MT5Config, has_relative_attention_bias=False): |
| + def __init__( |
| + self, |
| + config: MT5Config, |
| + has_relative_attention_bias=False, |
| + layer_idx: Optional[int] = None, |
| + ): |
| super().__init__() |
| self.is_decoder = config.is_decoder |
| self.has_relative_attention_bias = has_relative_attention_bias |
| @@ -225,6 +233,13 @@ def __init__(self, config: MT5Config, has_relative_attention_bias=False): |
| self.n_heads = config.num_heads |
| self.dropout = config.dropout_rate |
| self.inner_dim = self.n_heads * self.key_value_proj_dim |
| + self.layer_idx = layer_idx |
| + if layer_idx is None and self.is_decoder: |
| + logger.warning_once( |
| + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " |
| + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " |
| + "when creating this class." |
| + ) |
| |
| # Mesh TensorFlow initialization to avoid scaling before softmax |
| self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) |
| @@ -301,11 +316,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets |
| relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) |
| return relative_buckets |
| |
| - def compute_bias(self, query_length, key_length, device=None): |
| + def compute_bias(self, query_length, key_length, device=None, cache_position=None): |
| """Compute binned relative position bias""" |
| if device is None: |
| device = self.relative_attention_bias.weight.device |
| - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] |
| + if cache_position is None: |
| + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] |
| + else: |
| + context_position = cache_position[:, None].to(device) |
| memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] |
| relative_position = memory_position - context_position # shape (query_length, key_length) |
| relative_position_bucket = self._relative_position_bucket( |
| @@ -329,94 +347,72 @@ def forward( |
| query_length=None, |
| use_cache=False, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| """ |
| Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). |
| """ |
| # Input is (batch_size, seq_length, dim) |
| - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) |
| - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) |
| + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) |
| batch_size, seq_length = hidden_states.shape[:2] |
| |
| - real_seq_length = seq_length |
| + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder |
| + is_cross_attention = key_value_states is not None |
| |
| - if past_key_value is not None: |
| - if len(past_key_value) != 2: |
| - raise ValueError( |
| - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" |
| - ) |
| - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length |
| - |
| - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] |
| - |
| - def shape(states): |
| - """projection""" |
| - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| + query_states = self.q(hidden_states) |
| + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| |
| - def unshape(states): |
| - """reshape""" |
| - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) |
| + if past_key_value is not None: |
| + is_updated = past_key_value.is_updated.get(self.layer_idx) |
| + if is_cross_attention: |
| + # after the first generated id, we can subsequently re-use all key/value_states from cache |
| + curr_past_key_value = past_key_value.cross_attention_cache |
| + else: |
| + curr_past_key_value = past_key_value.self_attention_cache |
| |
| - def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| - """projects hidden states correctly to key/query states""" |
| - if key_value_states is None: |
| - # self-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(hidden_states)) |
| - elif past_key_value is None: |
| - # cross-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(key_value_states)) |
| + current_states = key_value_states if is_cross_attention else hidden_states |
| + if is_cross_attention and past_key_value is not None and is_updated: |
| + # reuse k,v, cross_attentions |
| + key_states = curr_past_key_value.key_cache[self.layer_idx] |
| + value_states = curr_past_key_value.value_cache[self.layer_idx] |
| + else: |
| + key_states = self.k(current_states) |
| + value_states = self.v(current_states) |
| + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| |
| if past_key_value is not None: |
| - if key_value_states is None: |
| - # self-attn |
| - # (batch_size, n_heads, key_length, dim_per_head) |
| - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) |
| - elif past_key_value.shape[2] != key_value_states.shape[1]: |
| - # checking that the `sequence_length` of the `past_key_value` is the same as |
| - # the provided `key_value_states` to support prefix tuning |
| - # cross-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(key_value_states)) |
| - else: |
| - # cross-attn |
| - hidden_states = past_key_value |
| - return hidden_states |
| - |
| - # get query states |
| - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) |
| - |
| - # get key/value states |
| - key_states = project( |
| - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None |
| - ) |
| - value_states = project( |
| - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None |
| - ) |
| + # save all key/value_states to cache to be re-used for fast auto-regressive generation |
| + cache_position = cache_position if not is_cross_attention else None |
| + key_states, value_states = curr_past_key_value.update( |
| + key_states, value_states, self.layer_idx, {"cache_position": cache_position} |
| + ) |
| + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls |
| + if is_cross_attention: |
| + past_key_value.is_updated[self.layer_idx] = True |
| |
| - # compute scores |
| - scores = torch.matmul( |
| - query_states, key_states.transpose(3, 2) |
| - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 |
| + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 |
| + scores = torch.matmul(query_states, key_states.transpose(3, 2)) |
| |
| if position_bias is None: |
| + key_length = key_states.shape[-2] |
| + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) |
| + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 |
| if not self.has_relative_attention_bias: |
| position_bias = torch.zeros( |
| - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype |
| + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype |
| ) |
| if self.gradient_checkpointing and self.training: |
| position_bias.requires_grad = True |
| else: |
| - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) |
| - |
| - # if key and values are already calculated |
| - # we want only the last query position bias |
| - if past_key_value is not None: |
| - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] |
| + position_bias = self.compute_bias( |
| + real_seq_length, key_length, device=scores.device, cache_position=cache_position |
| + ) |
| + position_bias = position_bias[:, :, -seq_length:, :] |
| |
| if mask is not None: |
| - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) |
| + causal_mask = mask[:, :, :, : key_states.shape[-2]] |
| + position_bias = position_bias + causal_mask |
| |
| if self.pruned_heads: |
| mask = torch.ones(position_bias.shape[1]) |
| @@ -426,22 +422,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| position_bias_masked = position_bias |
| |
| scores += position_bias_masked |
| - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( |
| - scores |
| - ) # (batch_size, n_heads, seq_length, key_length) |
| - attn_weights = nn.functional.dropout( |
| - attn_weights, p=self.dropout, training=self.training |
| - ) # (batch_size, n_heads, seq_length, key_length) |
| + |
| + # (batch_size, n_heads, seq_length, key_length) |
| + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) |
| + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) |
| |
| # Mask heads if we want to |
| if layer_head_mask is not None: |
| attn_weights = attn_weights * layer_head_mask |
| |
| - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) |
| + attn_output = torch.matmul(attn_weights, value_states) |
| + |
| + attn_output = attn_output.transpose(1, 2).contiguous() |
| + attn_output = attn_output.view(batch_size, -1, self.inner_dim) |
| attn_output = self.o(attn_output) |
| |
| - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None |
| - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) |
| + outputs = (attn_output, past_key_value, position_bias) |
| |
| if output_attentions: |
| outputs = outputs + (attn_weights,) |
| @@ -450,9 +446,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| |
| # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->MT5 |
| class MT5LayerSelfAttention(nn.Module): |
| - def __init__(self, config, has_relative_attention_bias=False): |
| + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): |
| super().__init__() |
| - self.SelfAttention = MT5Attention(config, has_relative_attention_bias=has_relative_attention_bias) |
| + self.SelfAttention = MT5Attention( |
| + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx |
| + ) |
| self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| |
| @@ -465,6 +463,7 @@ def forward( |
| past_key_value=None, |
| use_cache=False, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states) |
| attention_output = self.SelfAttention( |
| @@ -475,6 +474,7 @@ def forward( |
| past_key_value=past_key_value, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| hidden_states = hidden_states + self.dropout(attention_output[0]) |
| outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them |
| @@ -483,9 +483,9 @@ def forward( |
| |
| # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->MT5 |
| class MT5LayerCrossAttention(nn.Module): |
| - def __init__(self, config): |
| + def __init__(self, config, layer_idx: Optional[int] = None): |
| super().__init__() |
| - self.EncDecAttention = MT5Attention(config, has_relative_attention_bias=False) |
| + self.EncDecAttention = MT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) |
| self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| |
| @@ -500,6 +500,7 @@ def forward( |
| use_cache=False, |
| query_length=None, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states) |
| attention_output = self.EncDecAttention( |
| @@ -512,6 +513,7 @@ def forward( |
| use_cache=use_cache, |
| query_length=query_length, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| layer_output = hidden_states + self.dropout(attention_output[0]) |
| outputs = (layer_output,) + attention_output[1:] # add attentions if we output them |
| @@ -520,13 +522,15 @@ def forward( |
| |
| # Copied from transformers.models.t5.modeling_t5.T5Block with T5->MT5 |
| class MT5Block(nn.Module): |
| - def __init__(self, config, has_relative_attention_bias=False): |
| + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): |
| super().__init__() |
| self.is_decoder = config.is_decoder |
| self.layer = nn.ModuleList() |
| - self.layer.append(MT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) |
| + self.layer.append( |
| + MT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx) |
| + ) |
| if self.is_decoder: |
| - self.layer.append(MT5LayerCrossAttention(config)) |
| + self.layer.append(MT5LayerCrossAttention(config, layer_idx=layer_idx)) |
| |
| self.layer.append(MT5LayerFF(config)) |
| |
| @@ -544,34 +548,19 @@ def forward( |
| use_cache=False, |
| output_attentions=False, |
| return_dict=True, |
| + cache_position=None, |
| ): |
| - if past_key_value is not None: |
| - if not self.is_decoder: |
| - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") |
| - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 |
| - |
| - if len(past_key_value) != expected_num_past_key_values: |
| - raise ValueError( |
| - f"There should be {expected_num_past_key_values} past states. " |
| - f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}" |
| - f"Got {len(past_key_value)} past key / value states" |
| - ) |
| - |
| - self_attn_past_key_value = past_key_value[:2] |
| - cross_attn_past_key_value = past_key_value[2:] |
| - else: |
| - self_attn_past_key_value, cross_attn_past_key_value = None, None |
| - |
| self_attention_outputs = self.layer[0]( |
| hidden_states, |
| attention_mask=attention_mask, |
| position_bias=position_bias, |
| layer_head_mask=layer_head_mask, |
| - past_key_value=self_attn_past_key_value, |
| + past_key_value=past_key_value, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| - hidden_states, present_key_value_state = self_attention_outputs[:2] |
| + hidden_states, past_key_value = self_attention_outputs[:2] |
| attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights |
| |
| # clamp inf values to enable fp16 training |
| @@ -585,25 +574,18 @@ def forward( |
| |
| do_cross_attention = self.is_decoder and encoder_hidden_states is not None |
| if do_cross_attention: |
| - # the actual query length is unknown for cross attention |
| - # if using past key value states. Need to inject it here |
| - if present_key_value_state is not None: |
| - query_length = present_key_value_state[0].shape[2] |
| - else: |
| - query_length = None |
| - |
| cross_attention_outputs = self.layer[1]( |
| hidden_states, |
| key_value_states=encoder_hidden_states, |
| attention_mask=encoder_attention_mask, |
| position_bias=encoder_decoder_position_bias, |
| layer_head_mask=cross_attn_layer_head_mask, |
| - past_key_value=cross_attn_past_key_value, |
| - query_length=query_length, |
| + past_key_value=past_key_value, |
| + query_length=cache_position[-1] + 1, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| ) |
| - hidden_states = cross_attention_outputs[0] |
| + hidden_states, past_key_value = cross_attention_outputs[:2] |
| |
| # clamp inf values to enable fp16 training |
| if hidden_states.dtype == torch.float16: |
| @@ -614,10 +596,6 @@ def forward( |
| ) |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
| |
| - # Combine self attn and cross attn key value states |
| - if present_key_value_state is not None: |
| - present_key_value_state = present_key_value_state + cross_attention_outputs[1] |
| - |
| # Keep cross-attention outputs and relative position weights |
| attention_outputs = attention_outputs + cross_attention_outputs[2:] |
| |
| @@ -636,11 +614,11 @@ def forward( |
| outputs = (hidden_states,) |
| |
| if use_cache: |
| - outputs = outputs + (present_key_value_state,) + attention_outputs |
| + outputs = outputs + (past_key_value,) + attention_outputs |
| else: |
| outputs = outputs + attention_outputs |
| |
| - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) |
| + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) |
| |
| |
| def load_tf_weights_in_mt5(model, config, tf_checkpoint_path): |
| @@ -780,6 +758,9 @@ class MT5PreTrainedModel(PreTrainedModel): |
| base_model_prefix = "transformer" |
| is_parallelizable = True |
| supports_gradient_checkpointing = True |
| + _supports_quantized_cache = False # enc-dec models don't support yet |
| + _supports_static_cache = True |
| + _supports_cache_class = True |
| _no_split_modules = ["MT5Block"] |
| _keep_in_fp32_modules = ["wo"] |
| |
| @@ -892,7 +873,7 @@ def __init__(self, config, embed_tokens=None): |
| self.is_decoder = config.is_decoder |
| |
| self.block = nn.ModuleList( |
| - [MT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] |
| + [MT5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)] |
| ) |
| self.final_layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| @@ -968,6 +949,7 @@ def forward( |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| + cache_position=None, |
| ): |
| # Model parallel |
| if self.model_parallel: |
| @@ -994,6 +976,13 @@ def forward( |
| err_msg_prefix = "decoder_" if self.is_decoder else "" |
| raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") |
| |
| + if self.gradient_checkpointing and self.training: |
| + if use_cache: |
| + logger.warning_once( |
| + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
| + ) |
| + use_cache = False |
| + |
| if inputs_embeds is None: |
| if self.embed_tokens is None: |
| raise ValueError("You have to initialize the model with valid token embeddings") |
| @@ -1001,23 +990,57 @@ def forward( |
| |
| batch_size, seq_length = input_shape |
| |
| - # required mask seq length can be calculated via length of past |
| - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length |
| - |
| if use_cache is True: |
| if not self.is_decoder: |
| raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") |
| |
| - # initialize past_key_values with `None` if past does not exist |
| - if past_key_values is None: |
| - past_key_values = [None] * len(self.block) |
| + # initialize past_key_values |
| + return_legacy_cache = False |
| + return_self_attention_cache = False |
| + if self.is_decoder and (use_cache or past_key_values is not None): |
| + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): |
| + return_self_attention_cache = True |
| + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) |
| + elif not isinstance(past_key_values, EncoderDecoderCache): |
| + return_legacy_cache = True |
| + logger.warning_once( |
| + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " |
| + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " |
| + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." |
| + ) |
| + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) |
| + elif past_key_values is None: |
| + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) |
| + elif not self.is_decoder: |
| + # do not pass cache object down the line for encoder stack |
| + # it messes indexing later in decoder-stack because cache object is modified in-place |
| + past_key_values = None |
| + |
| + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| + if cache_position is None: |
| + cache_position = torch.arange( |
| + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device |
| + ) |
| |
| - if attention_mask is None: |
| + if attention_mask is None and not is_torchdynamo_compiling(): |
| + # required mask seq length can be calculated via length of past cache |
| + mask_seq_length = past_key_values_length + seq_length |
| attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) |
| |
| - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] |
| - # ourselves in which case we just need to make it broadcastable to all heads. |
| - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) |
| + if self.config.is_decoder: |
| + causal_mask = self._update_causal_mask( |
| + attention_mask, |
| + inputs_embeds, |
| + cache_position, |
| + past_key_values.self_attention_cache if past_key_values is not None else None, |
| + output_attentions, |
| + ) |
| + elif attention_mask is not None: |
| + causal_mask = attention_mask[:, None, None, :] |
| + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) |
| + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min |
| + else: |
| + causal_mask = None |
| |
| # If a 2D or 3D attention mask is provided for the cross-attention |
| # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] |
| @@ -1032,17 +1055,9 @@ def forward( |
| else: |
| encoder_extended_attention_mask = None |
| |
| - if self.gradient_checkpointing and self.training: |
| - if use_cache: |
| - logger.warning_once( |
| - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
| - ) |
| - use_cache = False |
| - |
| # Prepare head mask if needed |
| head_mask = self.get_head_mask(head_mask, self.config.num_layers) |
| cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) |
| - present_key_value_states = () if use_cache else None |
| all_hidden_states = () if output_hidden_states else None |
| all_attentions = () if output_attentions else None |
| all_cross_attentions = () if (output_attentions and self.is_decoder) else None |
| @@ -1051,15 +1066,15 @@ def forward( |
| |
| hidden_states = self.dropout(inputs_embeds) |
| |
| - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): |
| + for i, layer_module in enumerate(self.block): |
| layer_head_mask = head_mask[i] |
| cross_attn_layer_head_mask = cross_attn_head_mask[i] |
| # Model parallel |
| if self.model_parallel: |
| torch.cuda.set_device(hidden_states.device) |
| # Ensure that attention_mask is always on the same device as hidden_states |
| - if attention_mask is not None: |
| - attention_mask = attention_mask.to(hidden_states.device) |
| + if causal_mask is not None: |
| + causal_mask = causal_mask.to(hidden_states.device) |
| if position_bias is not None: |
| position_bias = position_bias.to(hidden_states.device) |
| if encoder_hidden_states is not None: |
| @@ -1079,7 +1094,7 @@ def forward( |
| layer_outputs = self._gradient_checkpointing_func( |
| layer_module.forward, |
| hidden_states, |
| - extended_attention_mask, |
| + causal_mask, |
| position_bias, |
| encoder_hidden_states, |
| encoder_extended_attention_mask, |
| @@ -1089,20 +1104,24 @@ def forward( |
| None, # past_key_value is always None with gradient checkpointing |
| use_cache, |
| output_attentions, |
| + return_dict, |
| + cache_position, |
| ) |
| else: |
| layer_outputs = layer_module( |
| hidden_states, |
| - attention_mask=extended_attention_mask, |
| + attention_mask=causal_mask, |
| position_bias=position_bias, |
| encoder_hidden_states=encoder_hidden_states, |
| encoder_attention_mask=encoder_extended_attention_mask, |
| encoder_decoder_position_bias=encoder_decoder_position_bias, |
| layer_head_mask=layer_head_mask, |
| cross_attn_layer_head_mask=cross_attn_layer_head_mask, |
| - past_key_value=past_key_value, |
| + past_key_value=past_key_values, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| # layer_outputs is a tuple with: |
| @@ -1110,7 +1129,7 @@ def forward( |
| if use_cache is False: |
| layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] |
| |
| - hidden_states, present_key_value_state = layer_outputs[:2] |
| + hidden_states, next_decoder_cache = layer_outputs[:2] |
| |
| # We share the position biases between the layers - the first layer store them |
| # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), |
| @@ -1118,9 +1137,6 @@ def forward( |
| position_bias = layer_outputs[2] |
| if self.is_decoder and encoder_hidden_states is not None: |
| encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] |
| - # append next layer key value states |
| - if use_cache: |
| - present_key_value_states = present_key_value_states + (present_key_value_state,) |
| |
| if output_attentions: |
| all_attentions = all_attentions + (layer_outputs[3],) |
| @@ -1140,12 +1156,18 @@ def forward( |
| if output_hidden_states: |
| all_hidden_states = all_hidden_states + (hidden_states,) |
| |
| + next_cache = next_decoder_cache if use_cache else None |
| + if return_self_attention_cache: |
| + next_cache = past_key_values.self_attention_cache |
| + if return_legacy_cache: |
| + next_cache = past_key_values.to_legacy_cache() |
| + |
| if not return_dict: |
| return tuple( |
| v |
| for v in [ |
| hidden_states, |
| - present_key_value_states, |
| + next_cache, |
| all_hidden_states, |
| all_attentions, |
| all_cross_attentions, |
| @@ -1154,12 +1176,135 @@ def forward( |
| ) |
| return BaseModelOutputWithPastAndCrossAttentions( |
| last_hidden_state=hidden_states, |
| - past_key_values=present_key_value_states, |
| + past_key_values=next_cache, |
| hidden_states=all_hidden_states, |
| attentions=all_attentions, |
| cross_attentions=all_cross_attentions, |
| ) |
| |
| + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask |
| + def _update_causal_mask( |
| + self, |
| + attention_mask: torch.Tensor, |
| + input_tensor: torch.Tensor, |
| + cache_position: torch.Tensor, |
| + past_key_values: Cache, |
| + output_attentions: bool, |
| + ): |
| + if self.config._attn_implementation == "flash_attention_2": |
| + if attention_mask is not None and 0.0 in attention_mask: |
| + return attention_mask |
| + return None |
| + |
| + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in |
| + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail |
| + # to infer the attention mask. |
| + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| + using_static_cache = isinstance(past_key_values, StaticCache) |
| + |
| + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward |
| + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: |
| + if AttentionMaskConverter._ignore_causal_mask_sdpa( |
| + attention_mask, |
| + inputs_embeds=input_tensor, |
| + past_key_values_length=past_seen_tokens, |
| + is_training=self.training, |
| + ): |
| + return None |
| + |
| + dtype, device = input_tensor.dtype, input_tensor.device |
| + sequence_length = input_tensor.shape[1] |
| + if using_static_cache: |
| + target_length = past_key_values.get_max_cache_shape() |
| + else: |
| + target_length = ( |
| + attention_mask.shape[-1] |
| + if isinstance(attention_mask, torch.Tensor) |
| + else past_seen_tokens + sequence_length + 1 |
| + ) |
| + |
| + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). |
| + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask, |
| + sequence_length=sequence_length, |
| + target_length=target_length, |
| + dtype=dtype, |
| + device=device, |
| + cache_position=cache_position, |
| + batch_size=input_tensor.shape[0], |
| + ) |
| + |
| + if ( |
| + self.config._attn_implementation == "sdpa" |
| + and attention_mask is not None |
| + and attention_mask.device.type == "cuda" |
| + and not output_attentions |
| + ): |
| + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when |
| + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. |
| + # Details: https://github.com/pytorch/pytorch/issues/110213 |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) |
| + |
| + return causal_mask |
| + |
| + @staticmethod |
| + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position |
| + def _prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask: torch.Tensor, |
| + sequence_length: int, |
| + target_length: int, |
| + dtype: torch.dtype, |
| + device: torch.device, |
| + cache_position: torch.Tensor, |
| + batch_size: int, |
| + **kwargs, |
| + ): |
| + """ |
| + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
| + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. |
| + |
| + Args: |
| + attention_mask (`torch.Tensor`): |
| + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape |
| + `(batch_size, 1, query_length, key_value_length)`. |
| + sequence_length (`int`): |
| + The sequence length being processed. |
| + target_length (`int`): |
| + The target length: when generating with static cache, the mask should be as long as the static cache, |
| + to account for the 0 padding, the part of the cache that is not filled yet. |
| + dtype (`torch.dtype`): |
| + The dtype to use for the 4D attention mask. |
| + device (`torch.device`): |
| + The device to plcae the 4D attention mask on. |
| + cache_position (`torch.Tensor`): |
| + Indices depicting the position of the input sequence tokens in the sequence. |
| + batch_size (`torch.Tensor`): |
| + Batch size. |
| + """ |
| + if attention_mask is not None and attention_mask.dim() == 4: |
| + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. |
| + causal_mask = attention_mask |
| + else: |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = torch.full( |
| + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device |
| + ) |
| + if sequence_length != 1: |
| + causal_mask = torch.triu(causal_mask, diagonal=1) |
| + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |
| + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
| + if attention_mask is not None: |
| + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit |
| + mask_length = attention_mask.shape[-1] |
| + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] |
| + padding_mask = padding_mask == 0 |
| + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| + padding_mask, min_dtype |
| + ) |
| + |
| + return causal_mask |
| + |
| |
| MT5_START_DOCSTRING = r""" |
| |
| @@ -1454,6 +1599,7 @@ def forward( |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| + cache_position: Optional[torch.LongTensor] = None, |
| ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: |
| r""" |
| Returns: |
| @@ -1533,6 +1679,7 @@ def forward( |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| if not return_dict: |
| @@ -1685,6 +1832,7 @@ def forward( |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| + cache_position: Optional[torch.LongTensor] = None, |
| ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: |
| r""" |
| labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| @@ -1779,6 +1927,7 @@ def forward( |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| sequence_output = decoder_outputs[0] |
| |
| |
| |
| |
| @@ -22,7 +22,9 @@ |
| from torch import nn |
| |
| from ...activations import ACT2FN |
| +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache |
| from ...generation import GenerationMixin |
| +from ...modeling_attn_mask_utils import AttentionMaskConverter |
| from ...modeling_outputs import ( |
| BaseModelOutput, |
| BaseModelOutputWithPooling, |
| @@ -38,6 +40,7 @@ |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_torch_fx_proxy, |
| + is_torchdynamo_compiling, |
| logging, |
| replace_return_docstrings, |
| ) |
| @@ -184,14 +187,17 @@ def to_projection_shape(states): |
| if self.gradient_checkpointing and self.training: |
| position_bias.requires_grad = True |
| |
| - if attention_mask is None: |
| - attention_mask = torch.ones((batch_size, seq_length), device=scores.device, dtype=scores.dtype) |
| - |
| if attention_mask.dim() == 2: |
| position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device) |
| - else: |
| + elif attention_mask is not None: |
| # (batch_size, n_heads, seq_length, key_length) |
| position_bias = position_bias + attention_mask.to(position_bias.device) |
| + elif not is_torchdynamo_compiling(): |
| + attention_mask = torch.ones( |
| + (batch_size, seq_length), device=position_bias.device, dtype=position_bias.dtype |
| + ) |
| + position_bias = position_bias + attention_mask.to(position_bias.device) |
| + |
| position_bias = 1 - position_bias |
| |
| position_bias_masked = position_bias.masked_fill(position_bias == 1, torch.finfo(scores.dtype).min) |
| @@ -355,6 +361,8 @@ class Pix2StructPreTrainedModel(PreTrainedModel): |
| """ |
| |
| config_class = Pix2StructConfig |
| + _supports_cache_class = True |
| + _supports_static_cache = False |
| |
| @property |
| def dummy_inputs(self): |
| @@ -673,7 +681,9 @@ def forward(self, hidden_states): |
| |
| |
| class Pix2StructTextAttention(nn.Module): |
| - def __init__(self, config: Pix2StructTextConfig, has_relative_attention_bias=False): |
| + def __init__( |
| + self, config: Pix2StructTextConfig, has_relative_attention_bias=False, layer_idx: Optional[int] = None |
| + ): |
| super().__init__() |
| self.has_relative_attention_bias = has_relative_attention_bias |
| self.relative_attention_num_buckets = config.relative_attention_num_buckets |
| @@ -683,6 +693,13 @@ def __init__(self, config: Pix2StructTextConfig, has_relative_attention_bias=Fal |
| self.n_heads = config.num_heads |
| self.dropout = config.dropout_rate |
| self.inner_dim = self.n_heads * self.key_value_proj_dim |
| + self.layer_idx = layer_idx |
| + if layer_idx is None: |
| + logger.warning_once( |
| + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " |
| + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " |
| + "when creating this class." |
| + ) |
| |
| # Mesh TensorFlow initialization to avoid scaling before softmax |
| self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=False) |
| @@ -773,75 +790,56 @@ def forward( |
| query_length=None, |
| use_cache=False, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| """ |
| Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). |
| """ |
| # Input is (batch_size, seq_length, dim) |
| - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) |
| - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) |
| + # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, query_length, key_length) |
| batch_size, seq_length = hidden_states.shape[:2] |
| |
| - real_seq_length = seq_length |
| - |
| - if past_key_value is not None: |
| - if len(past_key_value) != 2: |
| - raise ValueError( |
| - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" |
| - ) |
| - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length |
| - |
| - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] |
| - |
| - def to_projection_shape(states): |
| - """projection""" |
| - return states.contiguous().view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder |
| + is_cross_attention = key_value_states is not None |
| |
| - def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| - """projects hidden states correctly to key/query states""" |
| - if key_value_states is None: |
| - # self-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = to_projection_shape(proj_layer(hidden_states)) |
| - elif past_key_value is None: |
| - # cross-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = to_projection_shape(proj_layer(key_value_states)) |
| - |
| - if past_key_value is not None: |
| - if key_value_states is None: |
| - # self-attn |
| - # (batch_size, n_heads, key_length, dim_per_head) |
| - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) |
| - elif past_key_value.shape[2] != key_value_states.shape[1]: |
| - # checking that the `sequence_length` of the `past_key_value` is the same as |
| - # the provided `key_value_states` to support prefix tuning |
| - # cross-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = to_projection_shape(proj_layer(key_value_states)) |
| - else: |
| - # cross-attn |
| - hidden_states = past_key_value |
| - return hidden_states |
| + query_states = self.query(hidden_states).contiguous() |
| + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| |
| - # get query states |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - query_states = to_projection_shape(self.query(hidden_states)) |
| + if past_key_value is not None: |
| + is_updated = past_key_value.is_updated.get(self.layer_idx) |
| + if is_cross_attention: |
| + # after the first generated id, we can subsequently re-use all key/value_states from cache |
| + past_key_value = past_key_value.cross_attention_cache |
| + else: |
| + past_key_value = past_key_value.self_attention_cache |
| |
| # get key/value states |
| - key_states = project( |
| - hidden_states, self.key, key_value_states, past_key_value[0] if past_key_value is not None else None |
| - ) |
| - value_states = project( |
| - hidden_states, self.value, key_value_states, past_key_value[1] if past_key_value is not None else None |
| - ) |
| + current_states = key_value_states if is_cross_attention else hidden_states |
| + if is_cross_attention and past_key_value and is_updated: |
| + # reuse k,v, cross_attentions |
| + key_states = past_key_value.key_cache[self.layer_idx] |
| + value_states = past_key_value.value_cache[self.layer_idx] |
| + else: |
| + key_states = self.key(current_states).contiguous() |
| + value_states = self.value(current_states).contiguous() |
| + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| + if past_key_value is not None: |
| + # save all key/value_states to cache to be re-used for fast auto-regressive generation |
| + cache_position = cache_position if not is_cross_attention else None |
| + key_states, value_states = past_key_value.update( |
| + key_states, value_states, self.layer_idx, {"cache_position": cache_position} |
| + ) |
| + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls |
| + if is_cross_attention: |
| + past_key_value.is_updated[self.layer_idx] = True |
| |
| # compute scores |
| - scores = torch.matmul( |
| - query_states, key_states.transpose(3, 2) |
| - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 |
| + scores = torch.matmul(query_states, key_states.transpose(3, 2)) |
| |
| if position_bias is None: |
| + real_seq_length = cache_position[-1] + 1 if query_length is None else query_length |
| + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] |
| if not self.has_relative_attention_bias: |
| position_bias = torch.zeros( |
| (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype |
| @@ -851,11 +849,6 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| else: |
| position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) |
| |
| - # if key and values are already calculated |
| - # we want only the last query position bias |
| - if past_key_value is not None: |
| - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] |
| - |
| if mask is not None: |
| position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) |
| |
| @@ -883,19 +876,20 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| |
| attn_output = self.output(attn_output) |
| |
| - present_key_value_state = (key_states, value_states) if use_cache else None |
| - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) |
| + outputs = (attn_output,) + (past_key_value,) + (position_bias,) |
| |
| if output_attentions: |
| outputs = outputs + (attn_weights,) |
| return outputs |
| |
| |
| -# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,self.SelfAttention->self.attention,config.d_model->config.hidden_size |
| +# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerSelfAttention->Pix2StructTextLayerSelfAttention,self.SelfAttention->self.attention,config.d_model->config.hidden_size |
| class Pix2StructTextLayerSelfAttention(nn.Module): |
| - def __init__(self, config, has_relative_attention_bias=False): |
| + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): |
| super().__init__() |
| - self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=has_relative_attention_bias) |
| + self.attention = Pix2StructTextAttention( |
| + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx |
| + ) |
| self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| |
| @@ -908,6 +902,7 @@ def forward( |
| past_key_value=None, |
| use_cache=False, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states) |
| attention_output = self.attention( |
| @@ -918,17 +913,18 @@ def forward( |
| past_key_value=past_key_value, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| hidden_states = hidden_states + self.dropout(attention_output[0]) |
| outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them |
| return outputs |
| |
| |
| -# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,self.EncDecAttention->self.attention,config.d_model->config.hidden_size |
| +# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerCrossAttention->Pix2StructTextLayerCrossAttention,self.EncDecAttention->self.attention,config.d_model->config.hidden_size |
| class Pix2StructTextLayerCrossAttention(nn.Module): |
| - def __init__(self, config): |
| + def __init__(self, config, layer_idx: Optional[int] = None): |
| super().__init__() |
| - self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False) |
| + self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) |
| self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| |
| @@ -943,6 +939,7 @@ def forward( |
| use_cache=False, |
| query_length=None, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states) |
| attention_output = self.attention( |
| @@ -955,6 +952,7 @@ def forward( |
| use_cache=use_cache, |
| query_length=query_length, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| layer_output = hidden_states + self.dropout(attention_output[0]) |
| outputs = (layer_output,) + attention_output[1:] # add attentions if we output them |
| @@ -962,11 +960,13 @@ def forward( |
| |
| |
| class Pix2StructTextBlock(nn.Module): |
| - def __init__(self, config, has_relative_attention_bias=False): |
| + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): |
| super().__init__() |
| |
| self.self_attention = Pix2StructTextLayerSelfAttention( |
| - config, has_relative_attention_bias=has_relative_attention_bias |
| + config, |
| + has_relative_attention_bias=has_relative_attention_bias, |
| + layer_idx=layer_idx, |
| ) |
| |
| self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(config) |
| @@ -987,32 +987,19 @@ def forward( |
| use_cache=False, |
| output_attentions=False, |
| return_dict=True, |
| + cache_position=None, |
| ): |
| - if past_key_value is not None: |
| - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 |
| - |
| - if len(past_key_value) != expected_num_past_key_values: |
| - raise ValueError( |
| - f"There should be {expected_num_past_key_values} past states. " |
| - f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" |
| - f"Got {len(past_key_value)} past key / value states" |
| - ) |
| - |
| - self_attn_past_key_value = past_key_value[:2] |
| - cross_attn_past_key_value = past_key_value[2:] |
| - else: |
| - self_attn_past_key_value, cross_attn_past_key_value = None, None |
| - |
| self_attention_outputs = self.self_attention( |
| hidden_states, |
| attention_mask=attention_mask, |
| position_bias=position_bias, |
| layer_head_mask=layer_head_mask, |
| - past_key_value=self_attn_past_key_value, |
| + past_key_value=past_key_value, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| - hidden_states, present_key_value_state = self_attention_outputs[:2] |
| + hidden_states, past_key_value = self_attention_outputs[:2] |
| attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights |
| |
| # clamp inf values to enable fp16 training |
| @@ -1022,35 +1009,25 @@ def forward( |
| |
| do_cross_attention = encoder_hidden_states is not None |
| if do_cross_attention: |
| - # the actual query length is unknown for cross attention |
| - # if using past key value states. Need to inject it here |
| - if present_key_value_state is not None: |
| - query_length = present_key_value_state[0].shape[2] |
| - else: |
| - query_length = None |
| - |
| cross_attention_outputs = self.encoder_decoder_attention( |
| hidden_states, |
| key_value_states=encoder_hidden_states, |
| attention_mask=encoder_attention_mask, |
| position_bias=encoder_decoder_position_bias, |
| layer_head_mask=cross_attn_layer_head_mask, |
| - past_key_value=cross_attn_past_key_value, |
| - query_length=query_length, |
| + past_key_value=past_key_value, |
| + query_length=cache_position[-1] + 1, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| - hidden_states = cross_attention_outputs[0] |
| + hidden_states, past_key_value = cross_attention_outputs[:2] |
| |
| # clamp inf values to enable fp16 training |
| if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): |
| clamp_value = torch.finfo(hidden_states.dtype).max - 1000 |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
| |
| - # Combine self attn and cross attn key value states |
| - if present_key_value_state is not None: |
| - present_key_value_state = present_key_value_state + cross_attention_outputs[1] |
| - |
| # Keep cross-attention outputs and relative position weights |
| attention_outputs = attention_outputs + cross_attention_outputs[2:] |
| |
| @@ -1065,7 +1042,7 @@ def forward( |
| outputs = (hidden_states,) |
| |
| if use_cache: |
| - outputs = outputs + (present_key_value_state,) + attention_outputs |
| + outputs = outputs + (past_key_value,) + attention_outputs |
| else: |
| outputs = outputs + attention_outputs |
| |
| @@ -1187,6 +1164,9 @@ def forward( |
| more detail. |
| return_dict (`bool`, *optional*): |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
| + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the |
| + cache in the correct position and to infer the complete sequence length. |
| """ |
| |
| PIX2STRUCT_INPUTS_DOCSTRING = r""" |
| @@ -1293,7 +1273,10 @@ def __init__(self, config): |
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) |
| |
| self.layer = nn.ModuleList( |
| - [Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] |
| + [ |
| + Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) |
| + for i in range(config.num_layers) |
| + ] |
| ) |
| self.final_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| @@ -1364,6 +1347,7 @@ def forward( |
| output_hidden_states: Optional[bool] = None, |
| labels: Optional[torch.LongTensor] = None, |
| return_dict: Optional[bool] = None, |
| + cache_position: Optional[torch.LongTensor] = None, |
| **kwargs, |
| ) -> Union[Tuple[torch.FloatTensor, ...], CausalLMOutputWithCrossAttentions]: |
| r""" |
| @@ -1405,24 +1389,54 @@ def forward( |
| |
| batch_size, seq_length = input_shape |
| |
| - # required mask seq length can be calculated via length of past |
| - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length |
| + # initialize past_key_values |
| + return_legacy_cache = False |
| + return_self_attention_cache = False |
| + if use_cache or past_key_values is not None: |
| + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): |
| + return_self_attention_cache = True |
| + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) |
| + elif not isinstance(past_key_values, EncoderDecoderCache): |
| + return_legacy_cache = True |
| + logger.warning_once( |
| + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " |
| + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " |
| + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." |
| + ) |
| + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) |
| + elif past_key_values is None: |
| + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) |
| + |
| + past_key_values_length = 0 |
| + if cache_position is not None: |
| + past_key_values_length = cache_position[0] |
| + elif past_key_values is not None: |
| + past_key_values_length = past_key_values.get_seq_length() |
| + |
| + if cache_position is None: |
| + cache_position = torch.arange( |
| + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device |
| + ) |
| |
| if attention_mask is None: |
| - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) |
| - if encoder_attention_mask is None and encoder_hidden_states is not None: |
| - encoder_seq_length = encoder_hidden_states.shape[1] |
| - encoder_attention_mask = torch.ones( |
| - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long |
| + # required mask seq length can be calculated via length of past |
| + mask_seq_length = ( |
| + past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length |
| ) |
| + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) |
| |
| - # initialize past_key_values with `None` if past does not exist |
| - if past_key_values is None: |
| - past_key_values = [None] * len(self.layer) |
| - |
| - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] |
| - # ourselves in which case we just need to make it broadcastable to all heads. |
| - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) |
| + if self.config.is_decoder: |
| + causal_mask = self._update_causal_mask( |
| + attention_mask, |
| + inputs_embeds, |
| + cache_position, |
| + past_key_values.self_attention_cache if past_key_values is not None else None, |
| + output_attentions, |
| + ) |
| + else: |
| + causal_mask = attention_mask[:, None, None, :] |
| + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) |
| + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min |
| |
| # If a 2D or 3D attention mask is provided for the cross-attention |
| # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] |
| @@ -1438,7 +1452,6 @@ def forward( |
| # Prepare head mask if needed |
| head_mask = self.get_head_mask(head_mask, self.config.num_layers) |
| cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) |
| - present_key_value_states = () if use_cache else None |
| all_hidden_states = () if output_hidden_states else None |
| all_attentions = () if output_attentions else None |
| all_cross_attentions = () if (output_attentions) else None |
| @@ -1447,7 +1460,7 @@ def forward( |
| |
| hidden_states = self.dropout(inputs_embeds) |
| |
| - for i, (layer_module, past_key_value) in enumerate(zip(self.layer, past_key_values)): |
| + for i, layer_module in enumerate(self.layer): |
| layer_head_mask = head_mask[i] |
| cross_attn_layer_head_mask = cross_attn_head_mask[i] |
| if output_hidden_states: |
| @@ -1462,7 +1475,7 @@ def forward( |
| layer_outputs = self._gradient_checkpointing_func( |
| layer_module.forward, |
| hidden_states, |
| - extended_attention_mask, |
| + causal_mask, |
| position_bias, |
| encoder_hidden_states, |
| encoder_extended_attention_mask, |
| @@ -1472,20 +1485,22 @@ def forward( |
| None, # past_key_value is always None with gradient checkpointing |
| use_cache, |
| output_attentions, |
| + cache_position, |
| ) |
| else: |
| layer_outputs = layer_module( |
| hidden_states, |
| - attention_mask=extended_attention_mask, |
| + attention_mask=causal_mask, |
| position_bias=position_bias, |
| encoder_hidden_states=encoder_hidden_states, |
| encoder_attention_mask=encoder_extended_attention_mask, |
| encoder_decoder_position_bias=encoder_decoder_position_bias, |
| layer_head_mask=layer_head_mask, |
| cross_attn_layer_head_mask=cross_attn_layer_head_mask, |
| - past_key_value=past_key_value, |
| + past_key_value=past_key_values, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| |
| # layer_outputs is a tuple with: |
| @@ -1493,7 +1508,7 @@ def forward( |
| if use_cache is False: |
| layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] |
| |
| - hidden_states, present_key_value_state = layer_outputs[:2] |
| + hidden_states, next_decoder_cache = layer_outputs[:2] |
| |
| # We share the position biases between the layers - the first layer store them |
| # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), |
| @@ -1501,9 +1516,6 @@ def forward( |
| position_bias = layer_outputs[2] |
| if encoder_hidden_states is not None: |
| encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] |
| - # append next layer key value states |
| - if use_cache: |
| - present_key_value_states = present_key_value_states + (present_key_value_state,) |
| |
| if output_attentions: |
| all_attentions = all_attentions + (layer_outputs[3],) |
| @@ -1527,13 +1539,19 @@ def forward( |
| |
| loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), labels.contiguous().view(-1)) |
| |
| + next_cache = next_decoder_cache if use_cache else None |
| + if return_self_attention_cache: |
| + next_cache = past_key_values.self_attention_cache |
| + if return_legacy_cache: |
| + next_cache = past_key_values.to_legacy_cache() |
| + |
| if not return_dict: |
| return tuple( |
| v |
| for v in [ |
| loss, |
| logits, |
| - present_key_value_states, |
| + next_cache, |
| all_hidden_states, |
| all_attentions, |
| all_cross_attentions, |
| @@ -1543,12 +1561,135 @@ def forward( |
| return CausalLMOutputWithCrossAttentions( |
| loss=loss, |
| logits=logits, |
| - past_key_values=present_key_value_states, |
| + past_key_values=next_cache, |
| hidden_states=all_hidden_states, |
| attentions=all_attentions, |
| cross_attentions=all_cross_attentions, |
| ) |
| |
| + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask |
| + def _update_causal_mask( |
| + self, |
| + attention_mask: torch.Tensor, |
| + input_tensor: torch.Tensor, |
| + cache_position: torch.Tensor, |
| + past_key_values: Cache, |
| + output_attentions: bool, |
| + ): |
| + if self.config._attn_implementation == "flash_attention_2": |
| + if attention_mask is not None and 0.0 in attention_mask: |
| + return attention_mask |
| + return None |
| + |
| + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in |
| + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail |
| + # to infer the attention mask. |
| + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| + using_static_cache = isinstance(past_key_values, StaticCache) |
| + |
| + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward |
| + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: |
| + if AttentionMaskConverter._ignore_causal_mask_sdpa( |
| + attention_mask, |
| + inputs_embeds=input_tensor, |
| + past_key_values_length=past_seen_tokens, |
| + is_training=self.training, |
| + ): |
| + return None |
| + |
| + dtype, device = input_tensor.dtype, input_tensor.device |
| + sequence_length = input_tensor.shape[1] |
| + if using_static_cache: |
| + target_length = past_key_values.get_max_cache_shape() |
| + else: |
| + target_length = ( |
| + attention_mask.shape[-1] |
| + if isinstance(attention_mask, torch.Tensor) |
| + else past_seen_tokens + sequence_length + 1 |
| + ) |
| + |
| + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). |
| + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask, |
| + sequence_length=sequence_length, |
| + target_length=target_length, |
| + dtype=dtype, |
| + device=device, |
| + cache_position=cache_position, |
| + batch_size=input_tensor.shape[0], |
| + ) |
| + |
| + if ( |
| + self.config._attn_implementation == "sdpa" |
| + and attention_mask is not None |
| + and attention_mask.device.type == "cuda" |
| + and not output_attentions |
| + ): |
| + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when |
| + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. |
| + # Details: https://github.com/pytorch/pytorch/issues/110213 |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) |
| + |
| + return causal_mask |
| + |
| + @staticmethod |
| + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position |
| + def _prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask: torch.Tensor, |
| + sequence_length: int, |
| + target_length: int, |
| + dtype: torch.dtype, |
| + device: torch.device, |
| + cache_position: torch.Tensor, |
| + batch_size: int, |
| + **kwargs, |
| + ): |
| + """ |
| + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
| + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. |
| + |
| + Args: |
| + attention_mask (`torch.Tensor`): |
| + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape |
| + `(batch_size, 1, query_length, key_value_length)`. |
| + sequence_length (`int`): |
| + The sequence length being processed. |
| + target_length (`int`): |
| + The target length: when generating with static cache, the mask should be as long as the static cache, |
| + to account for the 0 padding, the part of the cache that is not filled yet. |
| + dtype (`torch.dtype`): |
| + The dtype to use for the 4D attention mask. |
| + device (`torch.device`): |
| + The device to plcae the 4D attention mask on. |
| + cache_position (`torch.Tensor`): |
| + Indices depicting the position of the input sequence tokens in the sequence. |
| + batch_size (`torch.Tensor`): |
| + Batch size. |
| + """ |
| + if attention_mask is not None and attention_mask.dim() == 4: |
| + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. |
| + causal_mask = attention_mask |
| + else: |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = torch.full( |
| + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device |
| + ) |
| + if sequence_length != 1: |
| + causal_mask = torch.triu(causal_mask, diagonal=1) |
| + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |
| + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
| + if attention_mask is not None: |
| + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit |
| + mask_length = attention_mask.shape[-1] |
| + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] |
| + padding_mask = padding_mask == 0 |
| + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| + padding_mask, min_dtype |
| + ) |
| + |
| + return causal_mask |
| + |
| |
| @add_start_docstrings( |
| "A conditional generation model with a language modeling head. Can be used for sequence generation tasks.", |
| @@ -1615,6 +1756,7 @@ def forward( |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| + cache_position: Optional[torch.LongTensor] = None, |
| ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: |
| r""" |
| Returns: |
| @@ -1723,6 +1865,7 @@ def forward( |
| output_hidden_states=output_hidden_states, |
| labels=labels, |
| return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| if not return_dict: |
| |
| |
| |
| |
| @@ -25,7 +25,9 @@ |
| from transformers.generation import GenerationConfig |
| |
| from ...activations import ACT2FN |
| +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache |
| from ...generation import GenerationMixin |
| +from ...modeling_attn_mask_utils import AttentionMaskConverter |
| from ...modeling_outputs import ( |
| BaseModelOutput, |
| BaseModelOutputWithPastAndCrossAttentions, |
| @@ -37,6 +39,7 @@ |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_torch_fx_proxy, |
| + is_torchdynamo_compiling, |
| logging, |
| replace_return_docstrings, |
| ) |
| @@ -136,6 +139,9 @@ |
| more detail. |
| return_dict (`bool`, *optional*): |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
| + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the |
| + cache in the correct position and to infer the complete sequence length. |
| """ |
| |
| |
| @@ -245,7 +251,12 @@ def forward(self, hidden_states): |
| |
| # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Pop2Piano,t5->pop2piano |
| class Pop2PianoAttention(nn.Module): |
| - def __init__(self, config: Pop2PianoConfig, has_relative_attention_bias=False): |
| + def __init__( |
| + self, |
| + config: Pop2PianoConfig, |
| + has_relative_attention_bias=False, |
| + layer_idx: Optional[int] = None, |
| + ): |
| super().__init__() |
| self.is_decoder = config.is_decoder |
| self.has_relative_attention_bias = has_relative_attention_bias |
| @@ -256,6 +267,13 @@ def __init__(self, config: Pop2PianoConfig, has_relative_attention_bias=False): |
| self.n_heads = config.num_heads |
| self.dropout = config.dropout_rate |
| self.inner_dim = self.n_heads * self.key_value_proj_dim |
| + self.layer_idx = layer_idx |
| + if layer_idx is None and self.is_decoder: |
| + logger.warning_once( |
| + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " |
| + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " |
| + "when creating this class." |
| + ) |
| |
| # Mesh TensorFlow initialization to avoid scaling before softmax |
| self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) |
| @@ -332,11 +350,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets |
| relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) |
| return relative_buckets |
| |
| - def compute_bias(self, query_length, key_length, device=None): |
| + def compute_bias(self, query_length, key_length, device=None, cache_position=None): |
| """Compute binned relative position bias""" |
| if device is None: |
| device = self.relative_attention_bias.weight.device |
| - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] |
| + if cache_position is None: |
| + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] |
| + else: |
| + context_position = cache_position[:, None].to(device) |
| memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] |
| relative_position = memory_position - context_position # shape (query_length, key_length) |
| relative_position_bucket = self._relative_position_bucket( |
| @@ -360,94 +381,72 @@ def forward( |
| query_length=None, |
| use_cache=False, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| """ |
| Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). |
| """ |
| # Input is (batch_size, seq_length, dim) |
| - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) |
| - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) |
| + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) |
| batch_size, seq_length = hidden_states.shape[:2] |
| |
| - real_seq_length = seq_length |
| + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder |
| + is_cross_attention = key_value_states is not None |
| |
| - if past_key_value is not None: |
| - if len(past_key_value) != 2: |
| - raise ValueError( |
| - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" |
| - ) |
| - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length |
| - |
| - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] |
| - |
| - def shape(states): |
| - """projection""" |
| - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| + query_states = self.q(hidden_states) |
| + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| |
| - def unshape(states): |
| - """reshape""" |
| - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) |
| + if past_key_value is not None: |
| + is_updated = past_key_value.is_updated.get(self.layer_idx) |
| + if is_cross_attention: |
| + # after the first generated id, we can subsequently re-use all key/value_states from cache |
| + curr_past_key_value = past_key_value.cross_attention_cache |
| + else: |
| + curr_past_key_value = past_key_value.self_attention_cache |
| |
| - def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| - """projects hidden states correctly to key/query states""" |
| - if key_value_states is None: |
| - # self-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(hidden_states)) |
| - elif past_key_value is None: |
| - # cross-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(key_value_states)) |
| + current_states = key_value_states if is_cross_attention else hidden_states |
| + if is_cross_attention and past_key_value is not None and is_updated: |
| + # reuse k,v, cross_attentions |
| + key_states = curr_past_key_value.key_cache[self.layer_idx] |
| + value_states = curr_past_key_value.value_cache[self.layer_idx] |
| + else: |
| + key_states = self.k(current_states) |
| + value_states = self.v(current_states) |
| + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| |
| if past_key_value is not None: |
| - if key_value_states is None: |
| - # self-attn |
| - # (batch_size, n_heads, key_length, dim_per_head) |
| - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) |
| - elif past_key_value.shape[2] != key_value_states.shape[1]: |
| - # checking that the `sequence_length` of the `past_key_value` is the same as |
| - # the provided `key_value_states` to support prefix tuning |
| - # cross-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(key_value_states)) |
| - else: |
| - # cross-attn |
| - hidden_states = past_key_value |
| - return hidden_states |
| - |
| - # get query states |
| - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) |
| - |
| - # get key/value states |
| - key_states = project( |
| - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None |
| - ) |
| - value_states = project( |
| - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None |
| - ) |
| + # save all key/value_states to cache to be re-used for fast auto-regressive generation |
| + cache_position = cache_position if not is_cross_attention else None |
| + key_states, value_states = curr_past_key_value.update( |
| + key_states, value_states, self.layer_idx, {"cache_position": cache_position} |
| + ) |
| + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls |
| + if is_cross_attention: |
| + past_key_value.is_updated[self.layer_idx] = True |
| |
| - # compute scores |
| - scores = torch.matmul( |
| - query_states, key_states.transpose(3, 2) |
| - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 |
| + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 |
| + scores = torch.matmul(query_states, key_states.transpose(3, 2)) |
| |
| if position_bias is None: |
| + key_length = key_states.shape[-2] |
| + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) |
| + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 |
| if not self.has_relative_attention_bias: |
| position_bias = torch.zeros( |
| - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype |
| + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype |
| ) |
| if self.gradient_checkpointing and self.training: |
| position_bias.requires_grad = True |
| else: |
| - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) |
| - |
| - # if key and values are already calculated |
| - # we want only the last query position bias |
| - if past_key_value is not None: |
| - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] |
| + position_bias = self.compute_bias( |
| + real_seq_length, key_length, device=scores.device, cache_position=cache_position |
| + ) |
| + position_bias = position_bias[:, :, -seq_length:, :] |
| |
| if mask is not None: |
| - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) |
| + causal_mask = mask[:, :, :, : key_states.shape[-2]] |
| + position_bias = position_bias + causal_mask |
| |
| if self.pruned_heads: |
| mask = torch.ones(position_bias.shape[1]) |
| @@ -457,22 +456,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| position_bias_masked = position_bias |
| |
| scores += position_bias_masked |
| - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( |
| - scores |
| - ) # (batch_size, n_heads, seq_length, key_length) |
| - attn_weights = nn.functional.dropout( |
| - attn_weights, p=self.dropout, training=self.training |
| - ) # (batch_size, n_heads, seq_length, key_length) |
| + |
| + # (batch_size, n_heads, seq_length, key_length) |
| + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) |
| + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) |
| |
| # Mask heads if we want to |
| if layer_head_mask is not None: |
| attn_weights = attn_weights * layer_head_mask |
| |
| - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) |
| + attn_output = torch.matmul(attn_weights, value_states) |
| + |
| + attn_output = attn_output.transpose(1, 2).contiguous() |
| + attn_output = attn_output.view(batch_size, -1, self.inner_dim) |
| attn_output = self.o(attn_output) |
| |
| - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None |
| - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) |
| + outputs = (attn_output, past_key_value, position_bias) |
| |
| if output_attentions: |
| outputs = outputs + (attn_weights,) |
| @@ -481,9 +480,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| |
| # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Pop2Piano,t5->pop2piano |
| class Pop2PianoLayerSelfAttention(nn.Module): |
| - def __init__(self, config, has_relative_attention_bias=False): |
| + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): |
| super().__init__() |
| - self.SelfAttention = Pop2PianoAttention(config, has_relative_attention_bias=has_relative_attention_bias) |
| + self.SelfAttention = Pop2PianoAttention( |
| + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx |
| + ) |
| self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| |
| @@ -496,6 +497,7 @@ def forward( |
| past_key_value=None, |
| use_cache=False, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states) |
| attention_output = self.SelfAttention( |
| @@ -506,6 +508,7 @@ def forward( |
| past_key_value=past_key_value, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| hidden_states = hidden_states + self.dropout(attention_output[0]) |
| outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them |
| @@ -514,9 +517,9 @@ def forward( |
| |
| # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Pop2Piano,t5->pop2piano |
| class Pop2PianoLayerCrossAttention(nn.Module): |
| - def __init__(self, config): |
| + def __init__(self, config, layer_idx: Optional[int] = None): |
| super().__init__() |
| - self.EncDecAttention = Pop2PianoAttention(config, has_relative_attention_bias=False) |
| + self.EncDecAttention = Pop2PianoAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) |
| self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| |
| @@ -531,6 +534,7 @@ def forward( |
| use_cache=False, |
| query_length=None, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states) |
| attention_output = self.EncDecAttention( |
| @@ -543,6 +547,7 @@ def forward( |
| use_cache=use_cache, |
| query_length=query_length, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| layer_output = hidden_states + self.dropout(attention_output[0]) |
| outputs = (layer_output,) + attention_output[1:] # add attentions if we output them |
| @@ -551,13 +556,17 @@ def forward( |
| |
| # Copied from transformers.models.t5.modeling_t5.T5Block with T5->Pop2Piano,t5->pop2piano |
| class Pop2PianoBlock(nn.Module): |
| - def __init__(self, config, has_relative_attention_bias=False): |
| + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): |
| super().__init__() |
| self.is_decoder = config.is_decoder |
| self.layer = nn.ModuleList() |
| - self.layer.append(Pop2PianoLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) |
| + self.layer.append( |
| + Pop2PianoLayerSelfAttention( |
| + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx |
| + ) |
| + ) |
| if self.is_decoder: |
| - self.layer.append(Pop2PianoLayerCrossAttention(config)) |
| + self.layer.append(Pop2PianoLayerCrossAttention(config, layer_idx=layer_idx)) |
| |
| self.layer.append(Pop2PianoLayerFF(config)) |
| |
| @@ -575,34 +584,19 @@ def forward( |
| use_cache=False, |
| output_attentions=False, |
| return_dict=True, |
| + cache_position=None, |
| ): |
| - if past_key_value is not None: |
| - if not self.is_decoder: |
| - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") |
| - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 |
| - |
| - if len(past_key_value) != expected_num_past_key_values: |
| - raise ValueError( |
| - f"There should be {expected_num_past_key_values} past states. " |
| - f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}" |
| - f"Got {len(past_key_value)} past key / value states" |
| - ) |
| - |
| - self_attn_past_key_value = past_key_value[:2] |
| - cross_attn_past_key_value = past_key_value[2:] |
| - else: |
| - self_attn_past_key_value, cross_attn_past_key_value = None, None |
| - |
| self_attention_outputs = self.layer[0]( |
| hidden_states, |
| attention_mask=attention_mask, |
| position_bias=position_bias, |
| layer_head_mask=layer_head_mask, |
| - past_key_value=self_attn_past_key_value, |
| + past_key_value=past_key_value, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| - hidden_states, present_key_value_state = self_attention_outputs[:2] |
| + hidden_states, past_key_value = self_attention_outputs[:2] |
| attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights |
| |
| # clamp inf values to enable fp16 training |
| @@ -616,25 +610,18 @@ def forward( |
| |
| do_cross_attention = self.is_decoder and encoder_hidden_states is not None |
| if do_cross_attention: |
| - # the actual query length is unknown for cross attention |
| - # if using past key value states. Need to inject it here |
| - if present_key_value_state is not None: |
| - query_length = present_key_value_state[0].shape[2] |
| - else: |
| - query_length = None |
| - |
| cross_attention_outputs = self.layer[1]( |
| hidden_states, |
| key_value_states=encoder_hidden_states, |
| attention_mask=encoder_attention_mask, |
| position_bias=encoder_decoder_position_bias, |
| layer_head_mask=cross_attn_layer_head_mask, |
| - past_key_value=cross_attn_past_key_value, |
| - query_length=query_length, |
| + past_key_value=past_key_value, |
| + query_length=cache_position[-1] + 1, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| ) |
| - hidden_states = cross_attention_outputs[0] |
| + hidden_states, past_key_value = cross_attention_outputs[:2] |
| |
| # clamp inf values to enable fp16 training |
| if hidden_states.dtype == torch.float16: |
| @@ -645,10 +632,6 @@ def forward( |
| ) |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
| |
| - # Combine self attn and cross attn key value states |
| - if present_key_value_state is not None: |
| - present_key_value_state = present_key_value_state + cross_attention_outputs[1] |
| - |
| # Keep cross-attention outputs and relative position weights |
| attention_outputs = attention_outputs + cross_attention_outputs[2:] |
| |
| @@ -667,11 +650,11 @@ def forward( |
| outputs = (hidden_states,) |
| |
| if use_cache: |
| - outputs = outputs + (present_key_value_state,) + attention_outputs |
| + outputs = outputs + (past_key_value,) + attention_outputs |
| else: |
| outputs = outputs + attention_outputs |
| |
| - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) |
| + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) |
| |
| |
| class Pop2PianoPreTrainedModel(PreTrainedModel): |
| @@ -684,6 +667,8 @@ class Pop2PianoPreTrainedModel(PreTrainedModel): |
| base_model_prefix = "transformer" |
| is_parallelizable = False |
| supports_gradient_checkpointing = True |
| + _supports_cache_class = True |
| + _supports_static_cache = False |
| _no_split_modules = ["Pop2PianoBlock"] |
| _keep_in_fp32_modules = ["wo"] |
| |
| @@ -769,7 +754,10 @@ def __init__(self, config, embed_tokens=None): |
| self.is_decoder = config.is_decoder |
| |
| self.block = nn.ModuleList( |
| - [Pop2PianoBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] |
| + [ |
| + Pop2PianoBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) |
| + for i in range(config.num_layers) |
| + ] |
| ) |
| self.final_layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| @@ -803,6 +791,7 @@ def forward( |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| + cache_position=None, |
| ): |
| use_cache = use_cache if use_cache is not None else self.config.use_cache |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| @@ -825,6 +814,13 @@ def forward( |
| err_msg_prefix = "decoder_" if self.is_decoder else "" |
| raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") |
| |
| + if self.gradient_checkpointing and self.training: |
| + if use_cache: |
| + logger.warning_once( |
| + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
| + ) |
| + use_cache = False |
| + |
| if inputs_embeds is None: |
| if self.embed_tokens is None: |
| raise ValueError("You have to initialize the model with valid token embeddings") |
| @@ -832,28 +828,55 @@ def forward( |
| |
| batch_size, seq_length = input_shape |
| |
| - # required mask seq length can be calculated via length of past |
| - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length |
| - |
| if use_cache is True: |
| if not self.is_decoder: |
| raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") |
| |
| - if attention_mask is None: |
| - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) |
| - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: |
| - encoder_seq_length = encoder_hidden_states.shape[1] |
| - encoder_attention_mask = torch.ones( |
| - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long |
| + # initialize past_key_values |
| + return_legacy_cache = False |
| + return_self_attention_cache = False |
| + if self.is_decoder and (use_cache or past_key_values is not None): |
| + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): |
| + return_self_attention_cache = True |
| + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) |
| + elif not isinstance(past_key_values, EncoderDecoderCache): |
| + return_legacy_cache = True |
| + logger.warning_once( |
| + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " |
| + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " |
| + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." |
| + ) |
| + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) |
| + elif past_key_values is None: |
| + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) |
| + elif not self.is_decoder: |
| + # do not pass cache object down the line for encoder stack |
| + # it messes indexing later in decoder-stack because cache object is modified in-place |
| + past_key_values = None |
| + |
| + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| + if cache_position is None: |
| + cache_position = torch.arange( |
| + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device |
| ) |
| |
| - # initialize past_key_values with `None` if past does not exist |
| - if past_key_values is None: |
| - past_key_values = [None] * len(self.block) |
| + if attention_mask is None and not is_torchdynamo_compiling(): |
| + # required mask seq length can be calculated via length of past cache |
| + mask_seq_length = past_key_values_length + seq_length |
| + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) |
| |
| - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] |
| - # ourselves in which case we just need to make it broadcastable to all heads. |
| - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) |
| + if self.config.is_decoder: |
| + causal_mask = self._update_causal_mask( |
| + attention_mask, |
| + inputs_embeds, |
| + cache_position, |
| + past_key_values.self_attention_cache if past_key_values is not None else None, |
| + output_attentions, |
| + ) |
| + else: |
| + causal_mask = attention_mask[:, None, None, :] |
| + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) |
| + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min |
| |
| # If a 2D or 3D attention mask is provided for the cross-attention |
| # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] |
| @@ -866,17 +889,9 @@ def forward( |
| else: |
| encoder_extended_attention_mask = None |
| |
| - if self.gradient_checkpointing and self.training: |
| - if use_cache: |
| - logger.warning_once( |
| - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
| - ) |
| - use_cache = False |
| - |
| # Prepare head mask if needed |
| head_mask = self.get_head_mask(head_mask, self.config.num_layers) |
| cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) |
| - present_key_value_states = () if use_cache else None |
| all_hidden_states = () if output_hidden_states else None |
| all_attentions = () if output_attentions else None |
| all_cross_attentions = () if (output_attentions and self.is_decoder) else None |
| @@ -885,7 +900,7 @@ def forward( |
| |
| hidden_states = self.dropout(inputs_embeds) |
| |
| - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): |
| + for i, layer_module in enumerate(self.block): |
| layer_head_mask = head_mask[i] |
| cross_attn_layer_head_mask = cross_attn_head_mask[i] |
| if output_hidden_states: |
| @@ -895,7 +910,7 @@ def forward( |
| layer_outputs = self._gradient_checkpointing_func( |
| layer_module.forward, |
| hidden_states, |
| - extended_attention_mask, |
| + causal_mask, |
| position_bias, |
| encoder_hidden_states, |
| encoder_extended_attention_mask, |
| @@ -905,20 +920,22 @@ def forward( |
| None, # past_key_value is always None with gradient checkpointing |
| use_cache, |
| output_attentions, |
| + cache_position, |
| ) |
| else: |
| layer_outputs = layer_module( |
| hidden_states, |
| - attention_mask=extended_attention_mask, |
| + attention_mask=causal_mask, |
| position_bias=position_bias, |
| encoder_hidden_states=encoder_hidden_states, |
| encoder_attention_mask=encoder_extended_attention_mask, |
| encoder_decoder_position_bias=encoder_decoder_position_bias, |
| layer_head_mask=layer_head_mask, |
| cross_attn_layer_head_mask=cross_attn_layer_head_mask, |
| - past_key_value=past_key_value, |
| + past_key_value=past_key_values, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| |
| # layer_outputs is a tuple with: |
| @@ -926,7 +943,7 @@ def forward( |
| if use_cache is False: |
| layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] |
| |
| - hidden_states, present_key_value_state = layer_outputs[:2] |
| + hidden_states, next_decoder_cache = layer_outputs[:2] |
| |
| # We share the position biases between the layers - the first layer store them |
| # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), |
| @@ -934,9 +951,6 @@ def forward( |
| position_bias = layer_outputs[2] |
| if self.is_decoder and encoder_hidden_states is not None: |
| encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] |
| - # append next layer key value states |
| - if use_cache: |
| - present_key_value_states = present_key_value_states + (present_key_value_state,) |
| |
| if output_attentions: |
| all_attentions = all_attentions + (layer_outputs[3],) |
| @@ -950,12 +964,18 @@ def forward( |
| if output_hidden_states: |
| all_hidden_states = all_hidden_states + (hidden_states,) |
| |
| + next_cache = next_decoder_cache if use_cache else None |
| + if return_self_attention_cache: |
| + next_cache = past_key_values.self_attention_cache |
| + if return_legacy_cache: |
| + next_cache = past_key_values.to_legacy_cache() |
| + |
| if not return_dict: |
| return tuple( |
| v |
| for v in [ |
| hidden_states, |
| - present_key_value_states, |
| + next_cache, |
| all_hidden_states, |
| all_attentions, |
| all_cross_attentions, |
| @@ -964,12 +984,135 @@ def forward( |
| ) |
| return BaseModelOutputWithPastAndCrossAttentions( |
| last_hidden_state=hidden_states, |
| - past_key_values=present_key_value_states, |
| + past_key_values=next_cache, |
| hidden_states=all_hidden_states, |
| attentions=all_attentions, |
| cross_attentions=all_cross_attentions, |
| ) |
| |
| + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask |
| + def _update_causal_mask( |
| + self, |
| + attention_mask: torch.Tensor, |
| + input_tensor: torch.Tensor, |
| + cache_position: torch.Tensor, |
| + past_key_values: Cache, |
| + output_attentions: bool, |
| + ): |
| + if self.config._attn_implementation == "flash_attention_2": |
| + if attention_mask is not None and 0.0 in attention_mask: |
| + return attention_mask |
| + return None |
| + |
| + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in |
| + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail |
| + # to infer the attention mask. |
| + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| + using_static_cache = isinstance(past_key_values, StaticCache) |
| + |
| + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward |
| + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: |
| + if AttentionMaskConverter._ignore_causal_mask_sdpa( |
| + attention_mask, |
| + inputs_embeds=input_tensor, |
| + past_key_values_length=past_seen_tokens, |
| + is_training=self.training, |
| + ): |
| + return None |
| + |
| + dtype, device = input_tensor.dtype, input_tensor.device |
| + sequence_length = input_tensor.shape[1] |
| + if using_static_cache: |
| + target_length = past_key_values.get_max_cache_shape() |
| + else: |
| + target_length = ( |
| + attention_mask.shape[-1] |
| + if isinstance(attention_mask, torch.Tensor) |
| + else past_seen_tokens + sequence_length + 1 |
| + ) |
| + |
| + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). |
| + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask, |
| + sequence_length=sequence_length, |
| + target_length=target_length, |
| + dtype=dtype, |
| + device=device, |
| + cache_position=cache_position, |
| + batch_size=input_tensor.shape[0], |
| + ) |
| + |
| + if ( |
| + self.config._attn_implementation == "sdpa" |
| + and attention_mask is not None |
| + and attention_mask.device.type == "cuda" |
| + and not output_attentions |
| + ): |
| + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when |
| + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. |
| + # Details: https://github.com/pytorch/pytorch/issues/110213 |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) |
| + |
| + return causal_mask |
| + |
| + @staticmethod |
| + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position |
| + def _prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask: torch.Tensor, |
| + sequence_length: int, |
| + target_length: int, |
| + dtype: torch.dtype, |
| + device: torch.device, |
| + cache_position: torch.Tensor, |
| + batch_size: int, |
| + **kwargs, |
| + ): |
| + """ |
| + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
| + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. |
| + |
| + Args: |
| + attention_mask (`torch.Tensor`): |
| + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape |
| + `(batch_size, 1, query_length, key_value_length)`. |
| + sequence_length (`int`): |
| + The sequence length being processed. |
| + target_length (`int`): |
| + The target length: when generating with static cache, the mask should be as long as the static cache, |
| + to account for the 0 padding, the part of the cache that is not filled yet. |
| + dtype (`torch.dtype`): |
| + The dtype to use for the 4D attention mask. |
| + device (`torch.device`): |
| + The device to plcae the 4D attention mask on. |
| + cache_position (`torch.Tensor`): |
| + Indices depicting the position of the input sequence tokens in the sequence. |
| + batch_size (`torch.Tensor`): |
| + Batch size. |
| + """ |
| + if attention_mask is not None and attention_mask.dim() == 4: |
| + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. |
| + causal_mask = attention_mask |
| + else: |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = torch.full( |
| + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device |
| + ) |
| + if sequence_length != 1: |
| + causal_mask = torch.triu(causal_mask, diagonal=1) |
| + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |
| + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
| + if attention_mask is not None: |
| + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit |
| + mask_length = attention_mask.shape[-1] |
| + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] |
| + padding_mask = padding_mask == 0 |
| + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| + padding_mask, min_dtype |
| + ) |
| + |
| + return causal_mask |
| + |
| |
| class Pop2PianoConcatEmbeddingToMel(nn.Module): |
| """Embedding Matrix for `composer` tokens.""" |
| @@ -1122,6 +1265,7 @@ def forward( |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| + cache_position: Optional[torch.LongTensor] = None, |
| ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: |
| r""" |
| labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| @@ -1177,6 +1321,7 @@ def forward( |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| sequence_output = decoder_outputs[0] |
| |
| |
| |
| |
| @@ -24,7 +24,9 @@ |
| from torch.nn import CrossEntropyLoss |
| |
| from ...activations import ACT2FN |
| +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache |
| from ...generation import GenerationMixin |
| +from ...modeling_attn_mask_utils import AttentionMaskConverter |
| from ...modeling_outputs import ( |
| MoEModelOutput, |
| MoEModelOutputWithPastAndCrossAttentions, |
| @@ -39,6 +41,7 @@ |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_torch_fx_proxy, |
| + is_torchdynamo_compiling, |
| logging, |
| replace_return_docstrings, |
| ) |
| @@ -355,7 +358,12 @@ def forward(self, hidden_states, output_router_logits): |
| |
| # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->SwitchTransformers |
| class SwitchTransformersAttention(nn.Module): |
| - def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias=False): |
| + def __init__( |
| + self, |
| + config: SwitchTransformersConfig, |
| + has_relative_attention_bias=False, |
| + layer_idx: Optional[int] = None, |
| + ): |
| super().__init__() |
| self.is_decoder = config.is_decoder |
| self.has_relative_attention_bias = has_relative_attention_bias |
| @@ -366,6 +374,13 @@ def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias |
| self.n_heads = config.num_heads |
| self.dropout = config.dropout_rate |
| self.inner_dim = self.n_heads * self.key_value_proj_dim |
| + self.layer_idx = layer_idx |
| + if layer_idx is None and self.is_decoder: |
| + logger.warning_once( |
| + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " |
| + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " |
| + "when creating this class." |
| + ) |
| |
| # Mesh TensorFlow initialization to avoid scaling before softmax |
| self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) |
| @@ -442,11 +457,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets |
| relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) |
| return relative_buckets |
| |
| - def compute_bias(self, query_length, key_length, device=None): |
| + def compute_bias(self, query_length, key_length, device=None, cache_position=None): |
| """Compute binned relative position bias""" |
| if device is None: |
| device = self.relative_attention_bias.weight.device |
| - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] |
| + if cache_position is None: |
| + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] |
| + else: |
| + context_position = cache_position[:, None].to(device) |
| memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] |
| relative_position = memory_position - context_position # shape (query_length, key_length) |
| relative_position_bucket = self._relative_position_bucket( |
| @@ -470,94 +488,72 @@ def forward( |
| query_length=None, |
| use_cache=False, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| """ |
| Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). |
| """ |
| # Input is (batch_size, seq_length, dim) |
| - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) |
| - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) |
| + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) |
| batch_size, seq_length = hidden_states.shape[:2] |
| |
| - real_seq_length = seq_length |
| + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder |
| + is_cross_attention = key_value_states is not None |
| |
| - if past_key_value is not None: |
| - if len(past_key_value) != 2: |
| - raise ValueError( |
| - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" |
| - ) |
| - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length |
| - |
| - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] |
| - |
| - def shape(states): |
| - """projection""" |
| - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| + query_states = self.q(hidden_states) |
| + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| |
| - def unshape(states): |
| - """reshape""" |
| - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) |
| + if past_key_value is not None: |
| + is_updated = past_key_value.is_updated.get(self.layer_idx) |
| + if is_cross_attention: |
| + # after the first generated id, we can subsequently re-use all key/value_states from cache |
| + curr_past_key_value = past_key_value.cross_attention_cache |
| + else: |
| + curr_past_key_value = past_key_value.self_attention_cache |
| |
| - def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| - """projects hidden states correctly to key/query states""" |
| - if key_value_states is None: |
| - # self-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(hidden_states)) |
| - elif past_key_value is None: |
| - # cross-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(key_value_states)) |
| + current_states = key_value_states if is_cross_attention else hidden_states |
| + if is_cross_attention and past_key_value is not None and is_updated: |
| + # reuse k,v, cross_attentions |
| + key_states = curr_past_key_value.key_cache[self.layer_idx] |
| + value_states = curr_past_key_value.value_cache[self.layer_idx] |
| + else: |
| + key_states = self.k(current_states) |
| + value_states = self.v(current_states) |
| + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| |
| if past_key_value is not None: |
| - if key_value_states is None: |
| - # self-attn |
| - # (batch_size, n_heads, key_length, dim_per_head) |
| - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) |
| - elif past_key_value.shape[2] != key_value_states.shape[1]: |
| - # checking that the `sequence_length` of the `past_key_value` is the same as |
| - # the provided `key_value_states` to support prefix tuning |
| - # cross-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(key_value_states)) |
| - else: |
| - # cross-attn |
| - hidden_states = past_key_value |
| - return hidden_states |
| - |
| - # get query states |
| - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) |
| - |
| - # get key/value states |
| - key_states = project( |
| - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None |
| - ) |
| - value_states = project( |
| - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None |
| - ) |
| + # save all key/value_states to cache to be re-used for fast auto-regressive generation |
| + cache_position = cache_position if not is_cross_attention else None |
| + key_states, value_states = curr_past_key_value.update( |
| + key_states, value_states, self.layer_idx, {"cache_position": cache_position} |
| + ) |
| + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls |
| + if is_cross_attention: |
| + past_key_value.is_updated[self.layer_idx] = True |
| |
| - # compute scores |
| - scores = torch.matmul( |
| - query_states, key_states.transpose(3, 2) |
| - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 |
| + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 |
| + scores = torch.matmul(query_states, key_states.transpose(3, 2)) |
| |
| if position_bias is None: |
| + key_length = key_states.shape[-2] |
| + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) |
| + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 |
| if not self.has_relative_attention_bias: |
| position_bias = torch.zeros( |
| - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype |
| + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype |
| ) |
| if self.gradient_checkpointing and self.training: |
| position_bias.requires_grad = True |
| else: |
| - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) |
| - |
| - # if key and values are already calculated |
| - # we want only the last query position bias |
| - if past_key_value is not None: |
| - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] |
| + position_bias = self.compute_bias( |
| + real_seq_length, key_length, device=scores.device, cache_position=cache_position |
| + ) |
| + position_bias = position_bias[:, :, -seq_length:, :] |
| |
| if mask is not None: |
| - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) |
| + causal_mask = mask[:, :, :, : key_states.shape[-2]] |
| + position_bias = position_bias + causal_mask |
| |
| if self.pruned_heads: |
| mask = torch.ones(position_bias.shape[1]) |
| @@ -567,22 +563,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| position_bias_masked = position_bias |
| |
| scores += position_bias_masked |
| - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( |
| - scores |
| - ) # (batch_size, n_heads, seq_length, key_length) |
| - attn_weights = nn.functional.dropout( |
| - attn_weights, p=self.dropout, training=self.training |
| - ) # (batch_size, n_heads, seq_length, key_length) |
| + |
| + # (batch_size, n_heads, seq_length, key_length) |
| + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) |
| + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) |
| |
| # Mask heads if we want to |
| if layer_head_mask is not None: |
| attn_weights = attn_weights * layer_head_mask |
| |
| - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) |
| + attn_output = torch.matmul(attn_weights, value_states) |
| + |
| + attn_output = attn_output.transpose(1, 2).contiguous() |
| + attn_output = attn_output.view(batch_size, -1, self.inner_dim) |
| attn_output = self.o(attn_output) |
| |
| - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None |
| - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) |
| + outputs = (attn_output, past_key_value, position_bias) |
| |
| if output_attentions: |
| outputs = outputs + (attn_weights,) |
| @@ -591,10 +587,10 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| |
| # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->SwitchTransformers |
| class SwitchTransformersLayerSelfAttention(nn.Module): |
| - def __init__(self, config, has_relative_attention_bias=False): |
| + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): |
| super().__init__() |
| self.SelfAttention = SwitchTransformersAttention( |
| - config, has_relative_attention_bias=has_relative_attention_bias |
| + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx |
| ) |
| self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| @@ -608,6 +604,7 @@ def forward( |
| past_key_value=None, |
| use_cache=False, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states) |
| attention_output = self.SelfAttention( |
| @@ -618,6 +615,7 @@ def forward( |
| past_key_value=past_key_value, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| hidden_states = hidden_states + self.dropout(attention_output[0]) |
| outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them |
| @@ -626,9 +624,11 @@ def forward( |
| |
| # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->SwitchTransformers |
| class SwitchTransformersLayerCrossAttention(nn.Module): |
| - def __init__(self, config): |
| + def __init__(self, config, layer_idx: Optional[int] = None): |
| super().__init__() |
| - self.EncDecAttention = SwitchTransformersAttention(config, has_relative_attention_bias=False) |
| + self.EncDecAttention = SwitchTransformersAttention( |
| + config, has_relative_attention_bias=False, layer_idx=layer_idx |
| + ) |
| self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| |
| @@ -643,6 +643,7 @@ def forward( |
| use_cache=False, |
| query_length=None, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states) |
| attention_output = self.EncDecAttention( |
| @@ -655,6 +656,7 @@ def forward( |
| use_cache=use_cache, |
| query_length=query_length, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| layer_output = hidden_states + self.dropout(attention_output[0]) |
| outputs = (layer_output,) + attention_output[1:] # add attentions if we output them |
| @@ -662,16 +664,18 @@ def forward( |
| |
| |
| class SwitchTransformersBlock(nn.Module): |
| - def __init__(self, config, has_relative_attention_bias=False, is_sparse=False): |
| + def __init__(self, config, has_relative_attention_bias=False, is_sparse=False, layer_idx: Optional[int] = None): |
| super().__init__() |
| self.is_decoder = config.is_decoder |
| self.is_sparse = is_sparse |
| self.layer = nn.ModuleList() |
| self.layer.append( |
| - SwitchTransformersLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias) |
| + SwitchTransformersLayerSelfAttention( |
| + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx |
| + ) |
| ) |
| if self.is_decoder: |
| - self.layer.append(SwitchTransformersLayerCrossAttention(config)) |
| + self.layer.append(SwitchTransformersLayerCrossAttention(config, layer_idx=layer_idx)) |
| |
| self.layer.append(SwitchTransformersLayerFF(config, is_sparse=self.is_sparse)) |
| |
| @@ -690,34 +694,19 @@ def forward( |
| output_attentions=False, |
| output_router_logits=True, |
| return_dict=True, |
| + cache_position=None, |
| ): |
| - if past_key_value is not None: |
| - if not self.is_decoder: |
| - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") |
| - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 |
| - |
| - if len(past_key_value) != expected_num_past_key_values: |
| - raise ValueError( |
| - f"There should be {expected_num_past_key_values} past states. " |
| - f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" |
| - f"Got {len(past_key_value)} past key / value states" |
| - ) |
| - |
| - self_attn_past_key_value = past_key_value[:2] |
| - cross_attn_past_key_value = past_key_value[2:] |
| - else: |
| - self_attn_past_key_value, cross_attn_past_key_value = None, None |
| - |
| self_attention_outputs = self.layer[0]( |
| hidden_states, |
| attention_mask=attention_mask, |
| position_bias=position_bias, |
| layer_head_mask=layer_head_mask, |
| - past_key_value=self_attn_past_key_value, |
| + past_key_value=past_key_value, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| - hidden_states, present_key_value_state = self_attention_outputs[:2] |
| + hidden_states, past_key_value = self_attention_outputs[:2] |
| attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights |
| |
| # clamp inf values to enable fp16 training |
| @@ -727,35 +716,25 @@ def forward( |
| |
| do_cross_attention = self.is_decoder and encoder_hidden_states is not None |
| if do_cross_attention: |
| - # the actual query length is unknown for cross attention |
| - # if using past key value states. Need to inject it here |
| - if present_key_value_state is not None: |
| - query_length = present_key_value_state[0].shape[2] |
| - else: |
| - query_length = None |
| - |
| cross_attention_outputs = self.layer[1]( |
| hidden_states, |
| key_value_states=encoder_hidden_states, |
| attention_mask=encoder_attention_mask, |
| position_bias=encoder_decoder_position_bias, |
| layer_head_mask=cross_attn_layer_head_mask, |
| - past_key_value=cross_attn_past_key_value, |
| - query_length=query_length, |
| + past_key_value=past_key_value, |
| + query_length=cache_position[-1] + 1, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| - hidden_states = cross_attention_outputs[0] |
| + hidden_states, past_key_value = cross_attention_outputs[:2] |
| |
| # clamp inf values to enable fp16 training |
| if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): |
| clamp_value = torch.finfo(hidden_states.dtype).max - 1000 |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
| |
| - # Combine self attn and cross attn key value states |
| - if present_key_value_state is not None: |
| - present_key_value_state = present_key_value_state + cross_attention_outputs[1] |
| - |
| # Keep cross-attention outputs and relative position weights |
| attention_outputs = attention_outputs + cross_attention_outputs[2:] |
| |
| @@ -775,11 +754,11 @@ def forward( |
| outputs = (hidden_states,) |
| |
| if use_cache: |
| - outputs = outputs + (present_key_value_state,) + attention_outputs + (router_tuple,) |
| + outputs = outputs + (past_key_value,) + attention_outputs + (router_tuple,) |
| else: |
| outputs = outputs + attention_outputs + (router_tuple,) |
| |
| - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights), (router_tuple) |
| + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights), (router_tuple) |
| |
| |
| class SwitchTransformersPreTrainedModel(PreTrainedModel): |
| @@ -791,6 +770,8 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): |
| config_class = SwitchTransformersConfig |
| base_model_prefix = "switch_transformers" |
| supports_gradient_checkpointing = True |
| + _supports_cache_class = True |
| + _supports_static_cache = False |
| _no_split_modules = ["SwitchTransformersBlock"] |
| |
| @property |
| @@ -897,7 +878,9 @@ def __init__(self, config, embed_tokens=None): |
| is_sparse = (i % sparse_step == 1 or sparse_step == 1) if sparse_step > 0 else False |
| |
| self.block.append( |
| - SwitchTransformersBlock(config, has_relative_attention_bias=bool(i == 0), is_sparse=is_sparse) |
| + SwitchTransformersBlock( |
| + config, has_relative_attention_bias=bool(i == 0), is_sparse=is_sparse, layer_idx=i |
| + ) |
| ) |
| |
| self.final_layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| @@ -930,6 +913,7 @@ def forward( |
| output_hidden_states=None, |
| output_router_logits=True, |
| return_dict=None, |
| + cache_position=None, |
| ): |
| use_cache = use_cache if use_cache is not None else self.config.use_cache |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| @@ -952,6 +936,13 @@ def forward( |
| err_msg_prefix = "decoder_" if self.is_decoder else "" |
| raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") |
| |
| + if self.gradient_checkpointing and self.training: |
| + if use_cache: |
| + logger.warning_once( |
| + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
| + ) |
| + use_cache = False |
| + |
| if inputs_embeds is None: |
| if self.embed_tokens is None: |
| raise ValueError("You have to initialize the model with valid token embeddings") |
| @@ -959,28 +950,55 @@ def forward( |
| |
| batch_size, seq_length = input_shape |
| |
| - # required mask seq length can be calculated via length of past |
| - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length |
| - |
| if use_cache is True: |
| if not self.is_decoder: |
| raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") |
| |
| - if attention_mask is None: |
| - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) |
| - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: |
| - encoder_seq_length = encoder_hidden_states.shape[1] |
| - encoder_attention_mask = torch.ones( |
| - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long |
| + # initialize past_key_values |
| + return_legacy_cache = False |
| + return_self_attention_cache = False |
| + if self.is_decoder and (use_cache or past_key_values is not None): |
| + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): |
| + return_self_attention_cache = True |
| + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) |
| + elif not isinstance(past_key_values, EncoderDecoderCache): |
| + return_legacy_cache = True |
| + logger.warning_once( |
| + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " |
| + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " |
| + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." |
| + ) |
| + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) |
| + elif past_key_values is None: |
| + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) |
| + elif not self.is_decoder: |
| + # do not pass cache object down the line for encoder stack |
| + # it messes indexing later in decoder-stack because cache object is modified in-place |
| + past_key_values = None |
| + |
| + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| + if cache_position is None: |
| + cache_position = torch.arange( |
| + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device |
| ) |
| |
| - # initialize past_key_values with `None` if past does not exist |
| - if past_key_values is None: |
| - past_key_values = [None] * len(self.block) |
| + if attention_mask is None and not is_torchdynamo_compiling(): |
| + # required mask seq length can be calculated via length of past cache |
| + mask_seq_length = past_key_values_length + seq_length |
| + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) |
| |
| - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] |
| - # ourselves in which case we just need to make it broadcastable to all heads. |
| - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) |
| + if self.config.is_decoder: |
| + causal_mask = self._update_causal_mask( |
| + attention_mask, |
| + inputs_embeds, |
| + cache_position, |
| + past_key_values.self_attention_cache if past_key_values is not None else None, |
| + output_attentions, |
| + ) |
| + else: |
| + causal_mask = attention_mask[:, None, None, :] |
| + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) |
| + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min |
| |
| # If a 2D or 3D attention mask is provided for the cross-attention |
| # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] |
| @@ -993,17 +1011,9 @@ def forward( |
| else: |
| encoder_extended_attention_mask = None |
| |
| - if self.gradient_checkpointing and self.training: |
| - if use_cache: |
| - logger.warning_once( |
| - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
| - ) |
| - use_cache = False |
| - |
| # Prepare head mask if needed |
| head_mask = self.get_head_mask(head_mask, self.config.num_layers) |
| cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) |
| - present_key_value_states = () if use_cache else None |
| all_hidden_states = () if output_hidden_states else None |
| all_attentions = () if output_attentions else None |
| all_router_probs = () if output_router_logits else None |
| @@ -1013,7 +1023,7 @@ def forward( |
| |
| hidden_states = self.dropout(inputs_embeds) |
| |
| - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): |
| + for i, layer_module in enumerate(self.block): |
| layer_head_mask = head_mask[i] |
| cross_attn_layer_head_mask = cross_attn_head_mask[i] |
| |
| @@ -1024,7 +1034,7 @@ def forward( |
| layer_outputs = self._gradient_checkpointing_func( |
| layer_module.forward, |
| hidden_states, |
| - extended_attention_mask, |
| + causal_mask, |
| position_bias, |
| encoder_hidden_states, |
| encoder_extended_attention_mask, |
| @@ -1034,21 +1044,26 @@ def forward( |
| None, # past_key_value is always None with gradient checkpointing |
| use_cache, |
| output_attentions, |
| + output_router_logits, |
| + return_dict, |
| + cache_position, |
| ) |
| else: |
| layer_outputs = layer_module( |
| hidden_states, |
| - attention_mask=extended_attention_mask, |
| + attention_mask=causal_mask, |
| position_bias=position_bias, |
| encoder_hidden_states=encoder_hidden_states, |
| encoder_attention_mask=encoder_extended_attention_mask, |
| encoder_decoder_position_bias=encoder_decoder_position_bias, |
| layer_head_mask=layer_head_mask, |
| cross_attn_layer_head_mask=cross_attn_layer_head_mask, |
| - past_key_value=past_key_value, |
| + past_key_value=past_key_values, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_router_logits=output_router_logits, |
| + return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| router_probs = layer_outputs[-1] |
| @@ -1059,7 +1074,7 @@ def forward( |
| if use_cache is False: |
| layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] |
| |
| - hidden_states, present_key_value_state = layer_outputs[:2] |
| + hidden_states, next_decoder_cache = layer_outputs[:2] |
| |
| # We share the position biases between the layers - the first layer store them |
| # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), |
| @@ -1067,9 +1082,6 @@ def forward( |
| position_bias = layer_outputs[2] |
| if self.is_decoder and encoder_hidden_states is not None: |
| encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] |
| - # append next layer key value states |
| - if use_cache: |
| - present_key_value_states = present_key_value_states + (present_key_value_state,) |
| |
| if output_attentions: |
| all_attentions = all_attentions + (layer_outputs[3],) |
| @@ -1086,12 +1098,18 @@ def forward( |
| if output_hidden_states: |
| all_hidden_states = all_hidden_states + (hidden_states,) |
| |
| + next_cache = next_decoder_cache if use_cache else None |
| + if return_self_attention_cache: |
| + next_cache = past_key_values.self_attention_cache |
| + if return_legacy_cache: |
| + next_cache = past_key_values.to_legacy_cache() |
| + |
| if not return_dict: |
| return tuple( |
| v |
| for v in [ |
| hidden_states, |
| - present_key_value_states, |
| + next_cache, |
| all_hidden_states, |
| all_attentions, |
| all_cross_attentions, |
| @@ -1101,13 +1119,136 @@ def forward( |
| ) |
| return MoEModelOutputWithPastAndCrossAttentions( |
| last_hidden_state=hidden_states, |
| - past_key_values=present_key_value_states, |
| + past_key_values=next_cache, |
| hidden_states=all_hidden_states, |
| attentions=all_attentions, |
| cross_attentions=all_cross_attentions, |
| router_probs=all_router_probs, |
| ) |
| |
| + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask |
| + def _update_causal_mask( |
| + self, |
| + attention_mask: torch.Tensor, |
| + input_tensor: torch.Tensor, |
| + cache_position: torch.Tensor, |
| + past_key_values: Cache, |
| + output_attentions: bool, |
| + ): |
| + if self.config._attn_implementation == "flash_attention_2": |
| + if attention_mask is not None and 0.0 in attention_mask: |
| + return attention_mask |
| + return None |
| + |
| + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in |
| + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail |
| + # to infer the attention mask. |
| + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| + using_static_cache = isinstance(past_key_values, StaticCache) |
| + |
| + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward |
| + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: |
| + if AttentionMaskConverter._ignore_causal_mask_sdpa( |
| + attention_mask, |
| + inputs_embeds=input_tensor, |
| + past_key_values_length=past_seen_tokens, |
| + is_training=self.training, |
| + ): |
| + return None |
| + |
| + dtype, device = input_tensor.dtype, input_tensor.device |
| + sequence_length = input_tensor.shape[1] |
| + if using_static_cache: |
| + target_length = past_key_values.get_max_cache_shape() |
| + else: |
| + target_length = ( |
| + attention_mask.shape[-1] |
| + if isinstance(attention_mask, torch.Tensor) |
| + else past_seen_tokens + sequence_length + 1 |
| + ) |
| + |
| + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). |
| + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask, |
| + sequence_length=sequence_length, |
| + target_length=target_length, |
| + dtype=dtype, |
| + device=device, |
| + cache_position=cache_position, |
| + batch_size=input_tensor.shape[0], |
| + ) |
| + |
| + if ( |
| + self.config._attn_implementation == "sdpa" |
| + and attention_mask is not None |
| + and attention_mask.device.type == "cuda" |
| + and not output_attentions |
| + ): |
| + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when |
| + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. |
| + # Details: https://github.com/pytorch/pytorch/issues/110213 |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) |
| + |
| + return causal_mask |
| + |
| + @staticmethod |
| + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position |
| + def _prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask: torch.Tensor, |
| + sequence_length: int, |
| + target_length: int, |
| + dtype: torch.dtype, |
| + device: torch.device, |
| + cache_position: torch.Tensor, |
| + batch_size: int, |
| + **kwargs, |
| + ): |
| + """ |
| + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
| + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. |
| + |
| + Args: |
| + attention_mask (`torch.Tensor`): |
| + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape |
| + `(batch_size, 1, query_length, key_value_length)`. |
| + sequence_length (`int`): |
| + The sequence length being processed. |
| + target_length (`int`): |
| + The target length: when generating with static cache, the mask should be as long as the static cache, |
| + to account for the 0 padding, the part of the cache that is not filled yet. |
| + dtype (`torch.dtype`): |
| + The dtype to use for the 4D attention mask. |
| + device (`torch.device`): |
| + The device to plcae the 4D attention mask on. |
| + cache_position (`torch.Tensor`): |
| + Indices depicting the position of the input sequence tokens in the sequence. |
| + batch_size (`torch.Tensor`): |
| + Batch size. |
| + """ |
| + if attention_mask is not None and attention_mask.dim() == 4: |
| + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. |
| + causal_mask = attention_mask |
| + else: |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = torch.full( |
| + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device |
| + ) |
| + if sequence_length != 1: |
| + causal_mask = torch.triu(causal_mask, diagonal=1) |
| + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |
| + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
| + if attention_mask is not None: |
| + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit |
| + mask_length = attention_mask.shape[-1] |
| + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] |
| + padding_mask = padding_mask == 0 |
| + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| + padding_mask, min_dtype |
| + ) |
| + |
| + return causal_mask |
| + |
| |
| SWITCH_TRANSFORMERS_START_DOCSTRING = r""" |
| |
| @@ -1228,6 +1369,9 @@ def forward( |
| should not be returned during inference. |
| return_dict (`bool`, *optional*): |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
| + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the |
| + cache in the correct position and to infer the complete sequence length. |
| """ |
| |
| SWITCH_TRANSFORMERS_ENCODER_INPUTS_DOCSTRING = r""" |
| @@ -1355,6 +1499,7 @@ def forward( |
| output_hidden_states: Optional[bool] = None, |
| output_router_logits: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| + cache_position: Optional[torch.LongTensor] = None, |
| ) -> Union[Tuple[torch.FloatTensor], Seq2SeqMoEModelOutput]: |
| r""" |
| Returns: |
| @@ -1435,6 +1580,7 @@ def forward( |
| output_hidden_states=output_hidden_states, |
| output_router_logits=output_router_logits, |
| return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| if not return_dict: |
| @@ -1535,6 +1681,7 @@ def forward( |
| output_hidden_states: Optional[bool] = None, |
| output_router_logits: Optional[bool] = True, |
| return_dict: Optional[bool] = None, |
| + cache_position: Optional[torch.LongTensor] = None, |
| ) -> Union[Tuple[torch.FloatTensor], Seq2SeqMoEOutput]: |
| r""" |
| labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| @@ -1618,6 +1765,7 @@ def forward( |
| output_hidden_states=output_hidden_states, |
| output_router_logits=output_router_logits, |
| return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| sequence_output = decoder_outputs[0] |
| |
| |
| |
| |
| @@ -73,7 +73,12 @@ class T5Config(PretrainedConfig): |
| |
| model_type = "t5" |
| keys_to_ignore_at_inference = ["past_key_values"] |
| - attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} |
| + attribute_map = { |
| + "hidden_size": "d_model", |
| + "num_attention_heads": "num_heads", |
| + "num_hidden_layers": "num_layers", |
| + "head_dim": "d_kv", |
| + } |
| |
| def __init__( |
| self, |
| |
| |
| |
| |
| @@ -25,7 +25,9 @@ |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
| |
| from ...activations import ACT2FN |
| +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache |
| from ...generation import GenerationMixin |
| +from ...modeling_attn_mask_utils import AttentionMaskConverter |
| from ...modeling_outputs import ( |
| BaseModelOutput, |
| BaseModelOutputWithPastAndCrossAttentions, |
| @@ -43,6 +45,7 @@ |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_torch_fx_proxy, |
| + is_torchdynamo_compiling, |
| logging, |
| replace_return_docstrings, |
| ) |
| @@ -339,7 +342,12 @@ def forward(self, hidden_states): |
| |
| |
| class T5Attention(nn.Module): |
| - def __init__(self, config: T5Config, has_relative_attention_bias=False): |
| + def __init__( |
| + self, |
| + config: T5Config, |
| + has_relative_attention_bias=False, |
| + layer_idx: Optional[int] = None, |
| + ): |
| super().__init__() |
| self.is_decoder = config.is_decoder |
| self.has_relative_attention_bias = has_relative_attention_bias |
| @@ -350,6 +358,13 @@ def __init__(self, config: T5Config, has_relative_attention_bias=False): |
| self.n_heads = config.num_heads |
| self.dropout = config.dropout_rate |
| self.inner_dim = self.n_heads * self.key_value_proj_dim |
| + self.layer_idx = layer_idx |
| + if layer_idx is None and self.is_decoder: |
| + logger.warning_once( |
| + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " |
| + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " |
| + "when creating this class." |
| + ) |
| |
| # Mesh TensorFlow initialization to avoid scaling before softmax |
| self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) |
| @@ -426,11 +441,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets |
| relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) |
| return relative_buckets |
| |
| - def compute_bias(self, query_length, key_length, device=None): |
| + def compute_bias(self, query_length, key_length, device=None, cache_position=None): |
| """Compute binned relative position bias""" |
| if device is None: |
| device = self.relative_attention_bias.weight.device |
| - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] |
| + if cache_position is None: |
| + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] |
| + else: |
| + context_position = cache_position[:, None].to(device) |
| memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] |
| relative_position = memory_position - context_position # shape (query_length, key_length) |
| relative_position_bucket = self._relative_position_bucket( |
| @@ -454,94 +472,72 @@ def forward( |
| query_length=None, |
| use_cache=False, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| """ |
| Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). |
| """ |
| # Input is (batch_size, seq_length, dim) |
| - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) |
| - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) |
| + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) |
| batch_size, seq_length = hidden_states.shape[:2] |
| |
| - real_seq_length = seq_length |
| + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder |
| + is_cross_attention = key_value_states is not None |
| |
| - if past_key_value is not None: |
| - if len(past_key_value) != 2: |
| - raise ValueError( |
| - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" |
| - ) |
| - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length |
| - |
| - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] |
| - |
| - def shape(states): |
| - """projection""" |
| - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| + query_states = self.q(hidden_states) |
| + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| |
| - def unshape(states): |
| - """reshape""" |
| - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) |
| + if past_key_value is not None: |
| + is_updated = past_key_value.is_updated.get(self.layer_idx) |
| + if is_cross_attention: |
| + # after the first generated id, we can subsequently re-use all key/value_states from cache |
| + curr_past_key_value = past_key_value.cross_attention_cache |
| + else: |
| + curr_past_key_value = past_key_value.self_attention_cache |
| |
| - def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| - """projects hidden states correctly to key/query states""" |
| - if key_value_states is None: |
| - # self-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(hidden_states)) |
| - elif past_key_value is None: |
| - # cross-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(key_value_states)) |
| + current_states = key_value_states if is_cross_attention else hidden_states |
| + if is_cross_attention and past_key_value is not None and is_updated: |
| + # reuse k,v, cross_attentions |
| + key_states = curr_past_key_value.key_cache[self.layer_idx] |
| + value_states = curr_past_key_value.value_cache[self.layer_idx] |
| + else: |
| + key_states = self.k(current_states) |
| + value_states = self.v(current_states) |
| + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| |
| if past_key_value is not None: |
| - if key_value_states is None: |
| - # self-attn |
| - # (batch_size, n_heads, key_length, dim_per_head) |
| - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) |
| - elif past_key_value.shape[2] != key_value_states.shape[1]: |
| - # checking that the `sequence_length` of the `past_key_value` is the same as |
| - # the provided `key_value_states` to support prefix tuning |
| - # cross-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(key_value_states)) |
| - else: |
| - # cross-attn |
| - hidden_states = past_key_value |
| - return hidden_states |
| - |
| - # get query states |
| - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) |
| - |
| - # get key/value states |
| - key_states = project( |
| - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None |
| - ) |
| - value_states = project( |
| - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None |
| - ) |
| + # save all key/value_states to cache to be re-used for fast auto-regressive generation |
| + cache_position = cache_position if not is_cross_attention else None |
| + key_states, value_states = curr_past_key_value.update( |
| + key_states, value_states, self.layer_idx, {"cache_position": cache_position} |
| + ) |
| + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls |
| + if is_cross_attention: |
| + past_key_value.is_updated[self.layer_idx] = True |
| |
| - # compute scores |
| - scores = torch.matmul( |
| - query_states, key_states.transpose(3, 2) |
| - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 |
| + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 |
| + scores = torch.matmul(query_states, key_states.transpose(3, 2)) |
| |
| if position_bias is None: |
| + key_length = key_states.shape[-2] |
| + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) |
| + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 |
| if not self.has_relative_attention_bias: |
| position_bias = torch.zeros( |
| - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype |
| + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype |
| ) |
| if self.gradient_checkpointing and self.training: |
| position_bias.requires_grad = True |
| else: |
| - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) |
| - |
| - # if key and values are already calculated |
| - # we want only the last query position bias |
| - if past_key_value is not None: |
| - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] |
| + position_bias = self.compute_bias( |
| + real_seq_length, key_length, device=scores.device, cache_position=cache_position |
| + ) |
| + position_bias = position_bias[:, :, -seq_length:, :] |
| |
| if mask is not None: |
| - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) |
| + causal_mask = mask[:, :, :, : key_states.shape[-2]] |
| + position_bias = position_bias + causal_mask |
| |
| if self.pruned_heads: |
| mask = torch.ones(position_bias.shape[1]) |
| @@ -551,22 +547,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| position_bias_masked = position_bias |
| |
| scores += position_bias_masked |
| - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( |
| - scores |
| - ) # (batch_size, n_heads, seq_length, key_length) |
| - attn_weights = nn.functional.dropout( |
| - attn_weights, p=self.dropout, training=self.training |
| - ) # (batch_size, n_heads, seq_length, key_length) |
| + |
| + # (batch_size, n_heads, seq_length, key_length) |
| + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) |
| + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) |
| |
| # Mask heads if we want to |
| if layer_head_mask is not None: |
| attn_weights = attn_weights * layer_head_mask |
| |
| - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) |
| + attn_output = torch.matmul(attn_weights, value_states) |
| + |
| + attn_output = attn_output.transpose(1, 2).contiguous() |
| + attn_output = attn_output.view(batch_size, -1, self.inner_dim) |
| attn_output = self.o(attn_output) |
| |
| - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None |
| - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) |
| + outputs = (attn_output, past_key_value, position_bias) |
| |
| if output_attentions: |
| outputs = outputs + (attn_weights,) |
| @@ -574,9 +570,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| |
| |
| class T5LayerSelfAttention(nn.Module): |
| - def __init__(self, config, has_relative_attention_bias=False): |
| + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): |
| super().__init__() |
| - self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) |
| + self.SelfAttention = T5Attention( |
| + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx |
| + ) |
| self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| |
| @@ -589,6 +587,7 @@ def forward( |
| past_key_value=None, |
| use_cache=False, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states) |
| attention_output = self.SelfAttention( |
| @@ -599,6 +598,7 @@ def forward( |
| past_key_value=past_key_value, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| hidden_states = hidden_states + self.dropout(attention_output[0]) |
| outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them |
| @@ -606,9 +606,9 @@ def forward( |
| |
| |
| class T5LayerCrossAttention(nn.Module): |
| - def __init__(self, config): |
| + def __init__(self, config, layer_idx: Optional[int] = None): |
| super().__init__() |
| - self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) |
| + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) |
| self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| |
| @@ -623,6 +623,7 @@ def forward( |
| use_cache=False, |
| query_length=None, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states) |
| attention_output = self.EncDecAttention( |
| @@ -635,6 +636,7 @@ def forward( |
| use_cache=use_cache, |
| query_length=query_length, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| layer_output = hidden_states + self.dropout(attention_output[0]) |
| outputs = (layer_output,) + attention_output[1:] # add attentions if we output them |
| @@ -642,13 +644,15 @@ def forward( |
| |
| |
| class T5Block(nn.Module): |
| - def __init__(self, config, has_relative_attention_bias=False): |
| + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): |
| super().__init__() |
| self.is_decoder = config.is_decoder |
| self.layer = nn.ModuleList() |
| - self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) |
| + self.layer.append( |
| + T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx) |
| + ) |
| if self.is_decoder: |
| - self.layer.append(T5LayerCrossAttention(config)) |
| + self.layer.append(T5LayerCrossAttention(config, layer_idx=layer_idx)) |
| |
| self.layer.append(T5LayerFF(config)) |
| |
| @@ -666,34 +670,19 @@ def forward( |
| use_cache=False, |
| output_attentions=False, |
| return_dict=True, |
| + cache_position=None, |
| ): |
| - if past_key_value is not None: |
| - if not self.is_decoder: |
| - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") |
| - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 |
| - |
| - if len(past_key_value) != expected_num_past_key_values: |
| - raise ValueError( |
| - f"There should be {expected_num_past_key_values} past states. " |
| - f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}" |
| - f"Got {len(past_key_value)} past key / value states" |
| - ) |
| - |
| - self_attn_past_key_value = past_key_value[:2] |
| - cross_attn_past_key_value = past_key_value[2:] |
| - else: |
| - self_attn_past_key_value, cross_attn_past_key_value = None, None |
| - |
| self_attention_outputs = self.layer[0]( |
| hidden_states, |
| attention_mask=attention_mask, |
| position_bias=position_bias, |
| layer_head_mask=layer_head_mask, |
| - past_key_value=self_attn_past_key_value, |
| + past_key_value=past_key_value, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| - hidden_states, present_key_value_state = self_attention_outputs[:2] |
| + hidden_states, past_key_value = self_attention_outputs[:2] |
| attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights |
| |
| # clamp inf values to enable fp16 training |
| @@ -707,25 +696,18 @@ def forward( |
| |
| do_cross_attention = self.is_decoder and encoder_hidden_states is not None |
| if do_cross_attention: |
| - # the actual query length is unknown for cross attention |
| - # if using past key value states. Need to inject it here |
| - if present_key_value_state is not None: |
| - query_length = present_key_value_state[0].shape[2] |
| - else: |
| - query_length = None |
| - |
| cross_attention_outputs = self.layer[1]( |
| hidden_states, |
| key_value_states=encoder_hidden_states, |
| attention_mask=encoder_attention_mask, |
| position_bias=encoder_decoder_position_bias, |
| layer_head_mask=cross_attn_layer_head_mask, |
| - past_key_value=cross_attn_past_key_value, |
| - query_length=query_length, |
| + past_key_value=past_key_value, |
| + query_length=cache_position[-1] + 1, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| ) |
| - hidden_states = cross_attention_outputs[0] |
| + hidden_states, past_key_value = cross_attention_outputs[:2] |
| |
| # clamp inf values to enable fp16 training |
| if hidden_states.dtype == torch.float16: |
| @@ -736,10 +718,6 @@ def forward( |
| ) |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
| |
| - # Combine self attn and cross attn key value states |
| - if present_key_value_state is not None: |
| - present_key_value_state = present_key_value_state + cross_attention_outputs[1] |
| - |
| # Keep cross-attention outputs and relative position weights |
| attention_outputs = attention_outputs + cross_attention_outputs[2:] |
| |
| @@ -758,11 +736,11 @@ def forward( |
| outputs = (hidden_states,) |
| |
| if use_cache: |
| - outputs = outputs + (present_key_value_state,) + attention_outputs |
| + outputs = outputs + (past_key_value,) + attention_outputs |
| else: |
| outputs = outputs + attention_outputs |
| |
| - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) |
| + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) |
| |
| |
| class T5ClassificationHead(nn.Module): |
| @@ -794,6 +772,9 @@ class T5PreTrainedModel(PreTrainedModel): |
| base_model_prefix = "transformer" |
| is_parallelizable = True |
| supports_gradient_checkpointing = True |
| + _supports_quantized_cache = False # enc-dec models don't support yet |
| + _supports_static_cache = True |
| + _supports_cache_class = True |
| _no_split_modules = ["T5Block"] |
| _keep_in_fp32_modules = ["wo"] |
| |
| @@ -905,7 +886,7 @@ def __init__(self, config, embed_tokens=None): |
| self.is_decoder = config.is_decoder |
| |
| self.block = nn.ModuleList( |
| - [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] |
| + [T5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)] |
| ) |
| self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| @@ -981,6 +962,7 @@ def forward( |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| + cache_position=None, |
| ): |
| # Model parallel |
| if self.model_parallel: |
| @@ -1007,6 +989,13 @@ def forward( |
| err_msg_prefix = "decoder_" if self.is_decoder else "" |
| raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") |
| |
| + if self.gradient_checkpointing and self.training: |
| + if use_cache: |
| + logger.warning_once( |
| + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
| + ) |
| + use_cache = False |
| + |
| if inputs_embeds is None: |
| if self.embed_tokens is None: |
| raise ValueError("You have to initialize the model with valid token embeddings") |
| @@ -1014,23 +1003,57 @@ def forward( |
| |
| batch_size, seq_length = input_shape |
| |
| - # required mask seq length can be calculated via length of past |
| - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length |
| - |
| if use_cache is True: |
| if not self.is_decoder: |
| raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") |
| |
| - # initialize past_key_values with `None` if past does not exist |
| - if past_key_values is None: |
| - past_key_values = [None] * len(self.block) |
| + # initialize past_key_values |
| + return_legacy_cache = False |
| + return_self_attention_cache = False |
| + if self.is_decoder and (use_cache or past_key_values is not None): |
| + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): |
| + return_self_attention_cache = True |
| + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) |
| + elif not isinstance(past_key_values, EncoderDecoderCache): |
| + return_legacy_cache = True |
| + logger.warning_once( |
| + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " |
| + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " |
| + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." |
| + ) |
| + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) |
| + elif past_key_values is None: |
| + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) |
| + elif not self.is_decoder: |
| + # do not pass cache object down the line for encoder stack |
| + # it messes indexing later in decoder-stack because cache object is modified in-place |
| + past_key_values = None |
| + |
| + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| + if cache_position is None: |
| + cache_position = torch.arange( |
| + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device |
| + ) |
| |
| - if attention_mask is None: |
| + if attention_mask is None and not is_torchdynamo_compiling(): |
| + # required mask seq length can be calculated via length of past cache |
| + mask_seq_length = past_key_values_length + seq_length |
| attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) |
| |
| - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] |
| - # ourselves in which case we just need to make it broadcastable to all heads. |
| - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) |
| + if self.config.is_decoder: |
| + causal_mask = self._update_causal_mask( |
| + attention_mask, |
| + inputs_embeds, |
| + cache_position, |
| + past_key_values.self_attention_cache if past_key_values is not None else None, |
| + output_attentions, |
| + ) |
| + elif attention_mask is not None: |
| + causal_mask = attention_mask[:, None, None, :] |
| + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) |
| + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min |
| + else: |
| + causal_mask = None |
| |
| # If a 2D or 3D attention mask is provided for the cross-attention |
| # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] |
| @@ -1045,17 +1068,9 @@ def forward( |
| else: |
| encoder_extended_attention_mask = None |
| |
| - if self.gradient_checkpointing and self.training: |
| - if use_cache: |
| - logger.warning_once( |
| - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
| - ) |
| - use_cache = False |
| - |
| # Prepare head mask if needed |
| head_mask = self.get_head_mask(head_mask, self.config.num_layers) |
| cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) |
| - present_key_value_states = () if use_cache else None |
| all_hidden_states = () if output_hidden_states else None |
| all_attentions = () if output_attentions else None |
| all_cross_attentions = () if (output_attentions and self.is_decoder) else None |
| @@ -1064,15 +1079,15 @@ def forward( |
| |
| hidden_states = self.dropout(inputs_embeds) |
| |
| - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): |
| + for i, layer_module in enumerate(self.block): |
| layer_head_mask = head_mask[i] |
| cross_attn_layer_head_mask = cross_attn_head_mask[i] |
| # Model parallel |
| if self.model_parallel: |
| torch.cuda.set_device(hidden_states.device) |
| # Ensure that attention_mask is always on the same device as hidden_states |
| - if attention_mask is not None: |
| - attention_mask = attention_mask.to(hidden_states.device) |
| + if causal_mask is not None: |
| + causal_mask = causal_mask.to(hidden_states.device) |
| if position_bias is not None: |
| position_bias = position_bias.to(hidden_states.device) |
| if encoder_hidden_states is not None: |
| @@ -1092,7 +1107,7 @@ def forward( |
| layer_outputs = self._gradient_checkpointing_func( |
| layer_module.forward, |
| hidden_states, |
| - extended_attention_mask, |
| + causal_mask, |
| position_bias, |
| encoder_hidden_states, |
| encoder_extended_attention_mask, |
| @@ -1102,20 +1117,24 @@ def forward( |
| None, # past_key_value is always None with gradient checkpointing |
| use_cache, |
| output_attentions, |
| + return_dict, |
| + cache_position, |
| ) |
| else: |
| layer_outputs = layer_module( |
| hidden_states, |
| - attention_mask=extended_attention_mask, |
| + attention_mask=causal_mask, |
| position_bias=position_bias, |
| encoder_hidden_states=encoder_hidden_states, |
| encoder_attention_mask=encoder_extended_attention_mask, |
| encoder_decoder_position_bias=encoder_decoder_position_bias, |
| layer_head_mask=layer_head_mask, |
| cross_attn_layer_head_mask=cross_attn_layer_head_mask, |
| - past_key_value=past_key_value, |
| + past_key_value=past_key_values, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| # layer_outputs is a tuple with: |
| @@ -1123,7 +1142,7 @@ def forward( |
| if use_cache is False: |
| layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] |
| |
| - hidden_states, present_key_value_state = layer_outputs[:2] |
| + hidden_states, next_decoder_cache = layer_outputs[:2] |
| |
| # We share the position biases between the layers - the first layer store them |
| # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), |
| @@ -1131,9 +1150,6 @@ def forward( |
| position_bias = layer_outputs[2] |
| if self.is_decoder and encoder_hidden_states is not None: |
| encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] |
| - # append next layer key value states |
| - if use_cache: |
| - present_key_value_states = present_key_value_states + (present_key_value_state,) |
| |
| if output_attentions: |
| all_attentions = all_attentions + (layer_outputs[3],) |
| @@ -1153,12 +1169,18 @@ def forward( |
| if output_hidden_states: |
| all_hidden_states = all_hidden_states + (hidden_states,) |
| |
| + next_cache = next_decoder_cache if use_cache else None |
| + if return_self_attention_cache: |
| + next_cache = past_key_values.self_attention_cache |
| + if return_legacy_cache: |
| + next_cache = past_key_values.to_legacy_cache() |
| + |
| if not return_dict: |
| return tuple( |
| v |
| for v in [ |
| hidden_states, |
| - present_key_value_states, |
| + next_cache, |
| all_hidden_states, |
| all_attentions, |
| all_cross_attentions, |
| @@ -1167,12 +1189,135 @@ def forward( |
| ) |
| return BaseModelOutputWithPastAndCrossAttentions( |
| last_hidden_state=hidden_states, |
| - past_key_values=present_key_value_states, |
| + past_key_values=next_cache, |
| hidden_states=all_hidden_states, |
| attentions=all_attentions, |
| cross_attentions=all_cross_attentions, |
| ) |
| |
| + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask |
| + def _update_causal_mask( |
| + self, |
| + attention_mask: torch.Tensor, |
| + input_tensor: torch.Tensor, |
| + cache_position: torch.Tensor, |
| + past_key_values: Cache, |
| + output_attentions: bool, |
| + ): |
| + if self.config._attn_implementation == "flash_attention_2": |
| + if attention_mask is not None and 0.0 in attention_mask: |
| + return attention_mask |
| + return None |
| + |
| + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in |
| + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail |
| + # to infer the attention mask. |
| + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| + using_static_cache = isinstance(past_key_values, StaticCache) |
| + |
| + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward |
| + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: |
| + if AttentionMaskConverter._ignore_causal_mask_sdpa( |
| + attention_mask, |
| + inputs_embeds=input_tensor, |
| + past_key_values_length=past_seen_tokens, |
| + is_training=self.training, |
| + ): |
| + return None |
| + |
| + dtype, device = input_tensor.dtype, input_tensor.device |
| + sequence_length = input_tensor.shape[1] |
| + if using_static_cache: |
| + target_length = past_key_values.get_max_cache_shape() |
| + else: |
| + target_length = ( |
| + attention_mask.shape[-1] |
| + if isinstance(attention_mask, torch.Tensor) |
| + else past_seen_tokens + sequence_length + 1 |
| + ) |
| + |
| + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). |
| + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask, |
| + sequence_length=sequence_length, |
| + target_length=target_length, |
| + dtype=dtype, |
| + device=device, |
| + cache_position=cache_position, |
| + batch_size=input_tensor.shape[0], |
| + ) |
| + |
| + if ( |
| + self.config._attn_implementation == "sdpa" |
| + and attention_mask is not None |
| + and attention_mask.device.type == "cuda" |
| + and not output_attentions |
| + ): |
| + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when |
| + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. |
| + # Details: https://github.com/pytorch/pytorch/issues/110213 |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) |
| + |
| + return causal_mask |
| + |
| + @staticmethod |
| + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position |
| + def _prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask: torch.Tensor, |
| + sequence_length: int, |
| + target_length: int, |
| + dtype: torch.dtype, |
| + device: torch.device, |
| + cache_position: torch.Tensor, |
| + batch_size: int, |
| + **kwargs, |
| + ): |
| + """ |
| + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
| + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. |
| + |
| + Args: |
| + attention_mask (`torch.Tensor`): |
| + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape |
| + `(batch_size, 1, query_length, key_value_length)`. |
| + sequence_length (`int`): |
| + The sequence length being processed. |
| + target_length (`int`): |
| + The target length: when generating with static cache, the mask should be as long as the static cache, |
| + to account for the 0 padding, the part of the cache that is not filled yet. |
| + dtype (`torch.dtype`): |
| + The dtype to use for the 4D attention mask. |
| + device (`torch.device`): |
| + The device to plcae the 4D attention mask on. |
| + cache_position (`torch.Tensor`): |
| + Indices depicting the position of the input sequence tokens in the sequence. |
| + batch_size (`torch.Tensor`): |
| + Batch size. |
| + """ |
| + if attention_mask is not None and attention_mask.dim() == 4: |
| + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. |
| + causal_mask = attention_mask |
| + else: |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = torch.full( |
| + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device |
| + ) |
| + if sequence_length != 1: |
| + causal_mask = torch.triu(causal_mask, diagonal=1) |
| + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |
| + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
| + if attention_mask is not None: |
| + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit |
| + mask_length = attention_mask.shape[-1] |
| + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] |
| + padding_mask = padding_mask == 0 |
| + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| + padding_mask, min_dtype |
| + ) |
| + |
| + return causal_mask |
| + |
| |
| T5_START_DOCSTRING = r""" |
| |
| @@ -1286,6 +1431,9 @@ def forward( |
| more detail. |
| return_dict (`bool`, *optional*): |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
| + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the |
| + cache in the correct position and to infer the complete sequence length. |
| """ |
| |
| T5_ENCODER_INPUTS_DOCSTRING = r""" |
| @@ -1446,6 +1594,7 @@ def forward( |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| + cache_position: Optional[torch.LongTensor] = None, |
| ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: |
| r""" |
| Returns: |
| @@ -1525,6 +1674,7 @@ def forward( |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| if not return_dict: |
| @@ -1656,6 +1806,7 @@ def forward( |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| + cache_position: Optional[torch.LongTensor] = None, |
| ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: |
| r""" |
| labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| @@ -1750,6 +1901,7 @@ def forward( |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| sequence_output = decoder_outputs[0] |
| |
| |
| |
| |
| @@ -34,13 +34,16 @@ |
| ) |
| |
| from ...activations import ACT2FN |
| +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache |
| from ...generation import GenerationMixin |
| +from ...modeling_attn_mask_utils import AttentionMaskConverter |
| from ...modeling_utils import PreTrainedModel |
| from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer |
| from ...utils import ( |
| ModelOutput, |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| + is_torchdynamo_compiling, |
| replace_return_docstrings, |
| ) |
| |
| @@ -154,6 +157,9 @@ |
| more detail. |
| return_dict (`bool`, *optional*): |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
| + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the |
| + cache in the correct position and to infer the complete sequence length. |
| """ |
| |
| |
| @@ -411,6 +417,8 @@ class UdopPreTrainedModel(PreTrainedModel): |
| config_class = UdopConfig |
| base_model_prefix = "transformer" |
| supports_gradient_checkpointing = True |
| + _supports_cache_class = True |
| + _supports_static_cache = False |
| _keep_in_fp32_modules = ["wo"] |
| |
| def _init_weights(self, module): |
| @@ -598,7 +606,12 @@ def forward(self, hidden_states): |
| |
| # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Udop |
| class UdopAttention(nn.Module): |
| - def __init__(self, config: UdopConfig, has_relative_attention_bias=False): |
| + def __init__( |
| + self, |
| + config: UdopConfig, |
| + has_relative_attention_bias=False, |
| + layer_idx: Optional[int] = None, |
| + ): |
| super().__init__() |
| self.is_decoder = config.is_decoder |
| self.has_relative_attention_bias = has_relative_attention_bias |
| @@ -609,6 +622,13 @@ def __init__(self, config: UdopConfig, has_relative_attention_bias=False): |
| self.n_heads = config.num_heads |
| self.dropout = config.dropout_rate |
| self.inner_dim = self.n_heads * self.key_value_proj_dim |
| + self.layer_idx = layer_idx |
| + if layer_idx is None and self.is_decoder: |
| + logger.warning_once( |
| + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " |
| + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " |
| + "when creating this class." |
| + ) |
| |
| # Mesh TensorFlow initialization to avoid scaling before softmax |
| self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) |
| @@ -685,11 +705,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets |
| relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) |
| return relative_buckets |
| |
| - def compute_bias(self, query_length, key_length, device=None): |
| + def compute_bias(self, query_length, key_length, device=None, cache_position=None): |
| """Compute binned relative position bias""" |
| if device is None: |
| device = self.relative_attention_bias.weight.device |
| - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] |
| + if cache_position is None: |
| + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] |
| + else: |
| + context_position = cache_position[:, None].to(device) |
| memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] |
| relative_position = memory_position - context_position # shape (query_length, key_length) |
| relative_position_bucket = self._relative_position_bucket( |
| @@ -713,94 +736,72 @@ def forward( |
| query_length=None, |
| use_cache=False, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| """ |
| Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). |
| """ |
| # Input is (batch_size, seq_length, dim) |
| - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) |
| - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) |
| + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) |
| batch_size, seq_length = hidden_states.shape[:2] |
| |
| - real_seq_length = seq_length |
| - |
| - if past_key_value is not None: |
| - if len(past_key_value) != 2: |
| - raise ValueError( |
| - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" |
| - ) |
| - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length |
| - |
| - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] |
| + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder |
| + is_cross_attention = key_value_states is not None |
| |
| - def shape(states): |
| - """projection""" |
| - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| + query_states = self.q(hidden_states) |
| + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| |
| - def unshape(states): |
| - """reshape""" |
| - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) |
| + if past_key_value is not None: |
| + is_updated = past_key_value.is_updated.get(self.layer_idx) |
| + if is_cross_attention: |
| + # after the first generated id, we can subsequently re-use all key/value_states from cache |
| + curr_past_key_value = past_key_value.cross_attention_cache |
| + else: |
| + curr_past_key_value = past_key_value.self_attention_cache |
| |
| - def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| - """projects hidden states correctly to key/query states""" |
| - if key_value_states is None: |
| - # self-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(hidden_states)) |
| - elif past_key_value is None: |
| - # cross-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(key_value_states)) |
| + current_states = key_value_states if is_cross_attention else hidden_states |
| + if is_cross_attention and past_key_value is not None and is_updated: |
| + # reuse k,v, cross_attentions |
| + key_states = curr_past_key_value.key_cache[self.layer_idx] |
| + value_states = curr_past_key_value.value_cache[self.layer_idx] |
| + else: |
| + key_states = self.k(current_states) |
| + value_states = self.v(current_states) |
| + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| |
| if past_key_value is not None: |
| - if key_value_states is None: |
| - # self-attn |
| - # (batch_size, n_heads, key_length, dim_per_head) |
| - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) |
| - elif past_key_value.shape[2] != key_value_states.shape[1]: |
| - # checking that the `sequence_length` of the `past_key_value` is the same as |
| - # the provided `key_value_states` to support prefix tuning |
| - # cross-attn |
| - # (batch_size, n_heads, seq_length, dim_per_head) |
| - hidden_states = shape(proj_layer(key_value_states)) |
| - else: |
| - # cross-attn |
| - hidden_states = past_key_value |
| - return hidden_states |
| - |
| - # get query states |
| - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) |
| - |
| - # get key/value states |
| - key_states = project( |
| - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None |
| - ) |
| - value_states = project( |
| - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None |
| - ) |
| + # save all key/value_states to cache to be re-used for fast auto-regressive generation |
| + cache_position = cache_position if not is_cross_attention else None |
| + key_states, value_states = curr_past_key_value.update( |
| + key_states, value_states, self.layer_idx, {"cache_position": cache_position} |
| + ) |
| + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls |
| + if is_cross_attention: |
| + past_key_value.is_updated[self.layer_idx] = True |
| |
| - # compute scores |
| - scores = torch.matmul( |
| - query_states, key_states.transpose(3, 2) |
| - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 |
| + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 |
| + scores = torch.matmul(query_states, key_states.transpose(3, 2)) |
| |
| if position_bias is None: |
| + key_length = key_states.shape[-2] |
| + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) |
| + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 |
| if not self.has_relative_attention_bias: |
| position_bias = torch.zeros( |
| - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype |
| + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype |
| ) |
| if self.gradient_checkpointing and self.training: |
| position_bias.requires_grad = True |
| else: |
| - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) |
| - |
| - # if key and values are already calculated |
| - # we want only the last query position bias |
| - if past_key_value is not None: |
| - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] |
| + position_bias = self.compute_bias( |
| + real_seq_length, key_length, device=scores.device, cache_position=cache_position |
| + ) |
| + position_bias = position_bias[:, :, -seq_length:, :] |
| |
| if mask is not None: |
| - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) |
| + causal_mask = mask[:, :, :, : key_states.shape[-2]] |
| + position_bias = position_bias + causal_mask |
| |
| if self.pruned_heads: |
| mask = torch.ones(position_bias.shape[1]) |
| @@ -810,22 +811,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| position_bias_masked = position_bias |
| |
| scores += position_bias_masked |
| - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( |
| - scores |
| - ) # (batch_size, n_heads, seq_length, key_length) |
| - attn_weights = nn.functional.dropout( |
| - attn_weights, p=self.dropout, training=self.training |
| - ) # (batch_size, n_heads, seq_length, key_length) |
| + |
| + # (batch_size, n_heads, seq_length, key_length) |
| + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) |
| + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) |
| |
| # Mask heads if we want to |
| if layer_head_mask is not None: |
| attn_weights = attn_weights * layer_head_mask |
| |
| - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) |
| + attn_output = torch.matmul(attn_weights, value_states) |
| + |
| + attn_output = attn_output.transpose(1, 2).contiguous() |
| + attn_output = attn_output.view(batch_size, -1, self.inner_dim) |
| attn_output = self.o(attn_output) |
| |
| - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None |
| - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) |
| + outputs = (attn_output, past_key_value, position_bias) |
| |
| if output_attentions: |
| outputs = outputs + (attn_weights,) |
| @@ -834,9 +835,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): |
| |
| # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Udop |
| class UdopLayerSelfAttention(nn.Module): |
| - def __init__(self, config, has_relative_attention_bias=False): |
| + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): |
| super().__init__() |
| - self.SelfAttention = UdopAttention(config, has_relative_attention_bias=has_relative_attention_bias) |
| + self.SelfAttention = UdopAttention( |
| + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx |
| + ) |
| self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| |
| @@ -849,6 +852,7 @@ def forward( |
| past_key_value=None, |
| use_cache=False, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states) |
| attention_output = self.SelfAttention( |
| @@ -859,6 +863,7 @@ def forward( |
| past_key_value=past_key_value, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| hidden_states = hidden_states + self.dropout(attention_output[0]) |
| outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them |
| @@ -867,9 +872,9 @@ def forward( |
| |
| # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Udop |
| class UdopLayerCrossAttention(nn.Module): |
| - def __init__(self, config): |
| + def __init__(self, config, layer_idx: Optional[int] = None): |
| super().__init__() |
| - self.EncDecAttention = UdopAttention(config, has_relative_attention_bias=False) |
| + self.EncDecAttention = UdopAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) |
| self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| |
| @@ -884,6 +889,7 @@ def forward( |
| use_cache=False, |
| query_length=None, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states) |
| attention_output = self.EncDecAttention( |
| @@ -896,6 +902,7 @@ def forward( |
| use_cache=use_cache, |
| query_length=query_length, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| layer_output = hidden_states + self.dropout(attention_output[0]) |
| outputs = (layer_output,) + attention_output[1:] # add attentions if we output them |
| @@ -904,13 +911,17 @@ def forward( |
| |
| # Copied from transformers.models.t5.modeling_t5.T5Block with T5->Udop |
| class UdopBlock(nn.Module): |
| - def __init__(self, config, has_relative_attention_bias=False): |
| + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): |
| super().__init__() |
| self.is_decoder = config.is_decoder |
| self.layer = nn.ModuleList() |
| - self.layer.append(UdopLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) |
| + self.layer.append( |
| + UdopLayerSelfAttention( |
| + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx |
| + ) |
| + ) |
| if self.is_decoder: |
| - self.layer.append(UdopLayerCrossAttention(config)) |
| + self.layer.append(UdopLayerCrossAttention(config, layer_idx=layer_idx)) |
| |
| self.layer.append(UdopLayerFF(config)) |
| |
| @@ -928,34 +939,19 @@ def forward( |
| use_cache=False, |
| output_attentions=False, |
| return_dict=True, |
| + cache_position=None, |
| ): |
| - if past_key_value is not None: |
| - if not self.is_decoder: |
| - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") |
| - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 |
| - |
| - if len(past_key_value) != expected_num_past_key_values: |
| - raise ValueError( |
| - f"There should be {expected_num_past_key_values} past states. " |
| - f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}" |
| - f"Got {len(past_key_value)} past key / value states" |
| - ) |
| - |
| - self_attn_past_key_value = past_key_value[:2] |
| - cross_attn_past_key_value = past_key_value[2:] |
| - else: |
| - self_attn_past_key_value, cross_attn_past_key_value = None, None |
| - |
| self_attention_outputs = self.layer[0]( |
| hidden_states, |
| attention_mask=attention_mask, |
| position_bias=position_bias, |
| layer_head_mask=layer_head_mask, |
| - past_key_value=self_attn_past_key_value, |
| + past_key_value=past_key_value, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| - hidden_states, present_key_value_state = self_attention_outputs[:2] |
| + hidden_states, past_key_value = self_attention_outputs[:2] |
| attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights |
| |
| # clamp inf values to enable fp16 training |
| @@ -969,25 +965,18 @@ def forward( |
| |
| do_cross_attention = self.is_decoder and encoder_hidden_states is not None |
| if do_cross_attention: |
| - # the actual query length is unknown for cross attention |
| - # if using past key value states. Need to inject it here |
| - if present_key_value_state is not None: |
| - query_length = present_key_value_state[0].shape[2] |
| - else: |
| - query_length = None |
| - |
| cross_attention_outputs = self.layer[1]( |
| hidden_states, |
| key_value_states=encoder_hidden_states, |
| attention_mask=encoder_attention_mask, |
| position_bias=encoder_decoder_position_bias, |
| layer_head_mask=cross_attn_layer_head_mask, |
| - past_key_value=cross_attn_past_key_value, |
| - query_length=query_length, |
| + past_key_value=past_key_value, |
| + query_length=cache_position[-1] + 1, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| ) |
| - hidden_states = cross_attention_outputs[0] |
| + hidden_states, past_key_value = cross_attention_outputs[:2] |
| |
| # clamp inf values to enable fp16 training |
| if hidden_states.dtype == torch.float16: |
| @@ -998,10 +987,6 @@ def forward( |
| ) |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
| |
| - # Combine self attn and cross attn key value states |
| - if present_key_value_state is not None: |
| - present_key_value_state = present_key_value_state + cross_attention_outputs[1] |
| - |
| # Keep cross-attention outputs and relative position weights |
| attention_outputs = attention_outputs + cross_attention_outputs[2:] |
| |
| @@ -1020,11 +1005,11 @@ def forward( |
| outputs = (hidden_states,) |
| |
| if use_cache: |
| - outputs = outputs + (present_key_value_state,) + attention_outputs |
| + outputs = outputs + (past_key_value,) + attention_outputs |
| else: |
| outputs = outputs + attention_outputs |
| |
| - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) |
| + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) |
| |
| |
| class UdopCellEmbeddings(nn.Module): |
| @@ -1286,7 +1271,7 @@ def __init__(self, config, embed_tokens=None, embed_patches=None): |
| self.num_layers = config.num_layers |
| |
| self.block = nn.ModuleList( |
| - [UdopBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(self.num_layers)] |
| + [UdopBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(self.num_layers)] |
| ) |
| self.final_layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| |
| @@ -1338,6 +1323,7 @@ def forward( |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| + cache_position=None, |
| ): |
| use_cache = use_cache if use_cache is not None else self.config.use_cache |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| @@ -1399,26 +1385,54 @@ def forward( |
| |
| batch_size, seq_length = input_shape |
| |
| - # required mask seq length can be calculated via length of past |
| - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length |
| - |
| if use_cache is True: |
| assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self) |
| |
| - if attention_mask is None: |
| - attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) |
| - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: |
| - encoder_seq_length = encoder_hidden_states.shape[1] |
| - encoder_attention_mask = torch.ones( |
| - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long |
| + # initialize past_key_values |
| + return_legacy_cache = False |
| + return_self_attention_cache = False |
| + if self.is_decoder and (use_cache or past_key_values is not None): |
| + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): |
| + return_self_attention_cache = True |
| + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) |
| + elif not isinstance(past_key_values, EncoderDecoderCache): |
| + return_legacy_cache = True |
| + logger.warning_once( |
| + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " |
| + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " |
| + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." |
| + ) |
| + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) |
| + elif past_key_values is None: |
| + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) |
| + elif not self.is_decoder: |
| + # do not pass cache object down the line for encoder stack |
| + # it messes indexing later in decoder-stack because cache object is modified in-place |
| + past_key_values = None |
| + |
| + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| + if cache_position is None: |
| + cache_position = torch.arange( |
| + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device |
| ) |
| |
| - # initialize past_key_values with `None` if past does not exist |
| - if past_key_values is None: |
| - past_key_values = [None] * len(self.block) |
| + if attention_mask is None and not is_torchdynamo_compiling(): |
| + # required mask seq length can be calculated via length of past cache |
| + mask_seq_length = past_key_values_length + seq_length |
| + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) |
| |
| - # ourselves in which case we just need to make it broadcastable to all heads. |
| - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) |
| + if self.config.is_decoder: |
| + causal_mask = self._update_causal_mask( |
| + attention_mask, |
| + inputs_embeds, |
| + cache_position, |
| + past_key_values.self_attention_cache if past_key_values is not None else None, |
| + output_attentions, |
| + ) |
| + else: |
| + causal_mask = attention_mask[:, None, None, :] |
| + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) |
| + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min |
| |
| if self.is_decoder and encoder_attention_mask is not None: |
| encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) |
| @@ -1427,7 +1441,6 @@ def forward( |
| |
| # Prepare head mask if needed |
| head_mask = self.get_head_mask(head_mask, self.num_layers) |
| - present_key_value_states = () if use_cache else None |
| all_hidden_states = () if output_hidden_states else None |
| all_attentions = () if output_attentions else None |
| all_cross_attentions = () if (output_attentions and self.is_decoder) else None |
| @@ -1436,34 +1449,35 @@ def forward( |
| position_bias = None |
| else: |
| position_bias = self.relative_bias(attention_mask=attention_mask, bbox=bbox) |
| - position_bias = position_bias + extended_attention_mask |
| + position_bias = position_bias + causal_mask |
| encoder_decoder_position_bias = None |
| |
| hidden_states = inputs_embeds |
| |
| hidden_states = self.dropout(hidden_states) |
| |
| - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): |
| + for i, layer_module in enumerate(self.block): |
| if output_hidden_states: |
| all_hidden_states = all_hidden_states + (hidden_states,) |
| |
| layer_outputs = layer_module( |
| hidden_states, |
| - attention_mask=extended_attention_mask, |
| + attention_mask=causal_mask, |
| position_bias=position_bias, |
| encoder_hidden_states=encoder_hidden_states, |
| encoder_attention_mask=encoder_extended_attention_mask, |
| encoder_decoder_position_bias=encoder_decoder_position_bias, |
| layer_head_mask=head_mask[i], |
| - past_key_value=past_key_value, |
| + past_key_value=past_key_values, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| # layer_outputs is a tuple with: |
| # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) |
| if use_cache is False: # MP fixes |
| layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] |
| - hidden_states, present_key_value_state = layer_outputs[:2] |
| + hidden_states, next_decoder_cache = layer_outputs[:2] |
| |
| # We share the position biases between the layers - the first layer store them |
| # layer_outputs = hidden-states, key-value-states (self-attention weights), |
| @@ -1472,9 +1486,6 @@ def forward( |
| position_bias = layer_outputs[2] |
| if self.is_decoder and encoder_hidden_states is not None: |
| encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] |
| - # append next layer key value states |
| - if use_cache: |
| - present_key_value_states = present_key_value_states + (present_key_value_state,) |
| |
| if output_attentions: |
| all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now |
| @@ -1488,13 +1499,19 @@ def forward( |
| if output_hidden_states: |
| all_hidden_states = all_hidden_states + (hidden_states,) |
| |
| + next_cache = next_decoder_cache if use_cache else None |
| + if return_self_attention_cache: |
| + next_cache = past_key_values.self_attention_cache |
| + if return_legacy_cache: |
| + next_cache = past_key_values.to_legacy_cache() |
| + |
| if not return_dict: |
| return tuple( |
| v |
| for v in [ |
| hidden_states, |
| attention_mask, |
| - present_key_value_states, |
| + next_cache, |
| all_hidden_states, |
| all_attentions, |
| all_cross_attentions, |
| @@ -1505,12 +1522,135 @@ def forward( |
| return BaseModelOutputWithAttentionMask( |
| last_hidden_state=hidden_states, |
| attention_mask=attention_mask, |
| - past_key_values=present_key_value_states, |
| + past_key_values=next_cache, |
| hidden_states=all_hidden_states, |
| attentions=all_attentions, |
| cross_attentions=all_cross_attentions, |
| ) |
| |
| + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask |
| + def _update_causal_mask( |
| + self, |
| + attention_mask: torch.Tensor, |
| + input_tensor: torch.Tensor, |
| + cache_position: torch.Tensor, |
| + past_key_values: Cache, |
| + output_attentions: bool, |
| + ): |
| + if self.config._attn_implementation == "flash_attention_2": |
| + if attention_mask is not None and 0.0 in attention_mask: |
| + return attention_mask |
| + return None |
| + |
| + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in |
| + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail |
| + # to infer the attention mask. |
| + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| + using_static_cache = isinstance(past_key_values, StaticCache) |
| + |
| + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward |
| + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: |
| + if AttentionMaskConverter._ignore_causal_mask_sdpa( |
| + attention_mask, |
| + inputs_embeds=input_tensor, |
| + past_key_values_length=past_seen_tokens, |
| + is_training=self.training, |
| + ): |
| + return None |
| + |
| + dtype, device = input_tensor.dtype, input_tensor.device |
| + sequence_length = input_tensor.shape[1] |
| + if using_static_cache: |
| + target_length = past_key_values.get_max_cache_shape() |
| + else: |
| + target_length = ( |
| + attention_mask.shape[-1] |
| + if isinstance(attention_mask, torch.Tensor) |
| + else past_seen_tokens + sequence_length + 1 |
| + ) |
| + |
| + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). |
| + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask, |
| + sequence_length=sequence_length, |
| + target_length=target_length, |
| + dtype=dtype, |
| + device=device, |
| + cache_position=cache_position, |
| + batch_size=input_tensor.shape[0], |
| + ) |
| + |
| + if ( |
| + self.config._attn_implementation == "sdpa" |
| + and attention_mask is not None |
| + and attention_mask.device.type == "cuda" |
| + and not output_attentions |
| + ): |
| + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when |
| + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. |
| + # Details: https://github.com/pytorch/pytorch/issues/110213 |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) |
| + |
| + return causal_mask |
| + |
| + @staticmethod |
| + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position |
| + def _prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask: torch.Tensor, |
| + sequence_length: int, |
| + target_length: int, |
| + dtype: torch.dtype, |
| + device: torch.device, |
| + cache_position: torch.Tensor, |
| + batch_size: int, |
| + **kwargs, |
| + ): |
| + """ |
| + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
| + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. |
| + |
| + Args: |
| + attention_mask (`torch.Tensor`): |
| + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape |
| + `(batch_size, 1, query_length, key_value_length)`. |
| + sequence_length (`int`): |
| + The sequence length being processed. |
| + target_length (`int`): |
| + The target length: when generating with static cache, the mask should be as long as the static cache, |
| + to account for the 0 padding, the part of the cache that is not filled yet. |
| + dtype (`torch.dtype`): |
| + The dtype to use for the 4D attention mask. |
| + device (`torch.device`): |
| + The device to plcae the 4D attention mask on. |
| + cache_position (`torch.Tensor`): |
| + Indices depicting the position of the input sequence tokens in the sequence. |
| + batch_size (`torch.Tensor`): |
| + Batch size. |
| + """ |
| + if attention_mask is not None and attention_mask.dim() == 4: |
| + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. |
| + causal_mask = attention_mask |
| + else: |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = torch.full( |
| + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device |
| + ) |
| + if sequence_length != 1: |
| + causal_mask = torch.triu(causal_mask, diagonal=1) |
| + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |
| + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
| + if attention_mask is not None: |
| + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit |
| + mask_length = attention_mask.shape[-1] |
| + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] |
| + padding_mask = padding_mask == 0 |
| + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| + padding_mask, min_dtype |
| + ) |
| + |
| + return causal_mask |
| + |
| |
| @add_start_docstrings( |
| "The bare UDOP encoder-decoder Transformer outputting raw hidden-states without any specific head on top.", |
| @@ -1584,6 +1724,7 @@ def forward( |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| + cache_position: Optional[torch.LongTensor] = None, |
| ) -> Tuple[Tensor, ...]: |
| r""" |
| Returns: |
| @@ -1653,6 +1794,7 @@ def forward( |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| if not return_dict: |
| @@ -1759,6 +1901,7 @@ def forward( |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| labels: Optional[Tensor] = None, |
| + cache_position: Optional[torch.LongTensor] = None, |
| ) -> Tuple[Tensor, ...]: |
| r""" |
| labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| @@ -1837,6 +1980,7 @@ def forward( |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| sequence_output = decoder_outputs[0] |
| |
| |
| |
| |
| @@ -72,7 +72,12 @@ class UMT5Config(PretrainedConfig): |
| |
| model_type = "umt5" |
| keys_to_ignore_at_inference = ["past_key_values"] |
| - attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} |
| + attribute_map = { |
| + "hidden_size": "d_model", |
| + "num_attention_heads": "num_heads", |
| + "num_hidden_layers": "num_layers", |
| + "head_dim": "d_kv", |
| + } |
| |
| def __init__( |
| self, |
| |
| |
| |
| |
| @@ -23,7 +23,9 @@ |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
| |
| from ...activations import ACT2FN |
| +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache |
| from ...generation import GenerationMixin |
| +from ...modeling_attn_mask_utils import AttentionMaskConverter |
| from ...modeling_outputs import ( |
| BaseModelOutput, |
| BaseModelOutputWithPastAndCrossAttentions, |
| @@ -40,6 +42,7 @@ |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| is_torch_fx_proxy, |
| + is_torchdynamo_compiling, |
| logging, |
| replace_return_docstrings, |
| ) |
| @@ -155,7 +158,7 @@ class UMT5Attention(nn.Module): |
| T5's attention using relative_attention_bias. |
| """ |
| |
| - def __init__(self, config, has_relative_attention_bias=False): |
| + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): |
| super().__init__() |
| self.is_decoder = config.is_decoder |
| self.has_relative_attention_bias = has_relative_attention_bias |
| @@ -166,6 +169,13 @@ def __init__(self, config, has_relative_attention_bias=False): |
| self.n_heads = config.num_heads |
| self.dropout = config.dropout_rate |
| self.inner_dim = self.n_heads * self.key_value_proj_dim |
| + self.layer_idx = layer_idx |
| + if layer_idx is None and self.is_decoder: |
| + logger.warning_once( |
| + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " |
| + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " |
| + "when creating this class." |
| + ) |
| |
| # Mesh TensorFlow initialization to avoid scaling before softmax |
| self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) |
| @@ -230,11 +240,14 @@ def _relative_position_bucket(self, relative_position): |
| relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) |
| return relative_buckets |
| |
| - def compute_bias(self, query_length, key_length, device=None): |
| + def compute_bias(self, query_length, key_length, device=None, cache_position=None): |
| """Compute binned relative position bias""" |
| if device is None: |
| device = self.relative_attention_bias.weight.device |
| - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] |
| + if cache_position is None: |
| + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] |
| + else: |
| + context_position = cache_position[:, None] |
| memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] |
| relative_position = memory_position - context_position # shape (query_length, key_length) |
| relative_position_bucket = self._relative_position_bucket(relative_position) |
| @@ -249,78 +262,95 @@ def forward( |
| past_key_value: Optional[Tuple[torch.Tensor]] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| layer_head_mask: Optional[torch.Tensor] = None, |
| + cache_position: Optional[torch.Tensor] = None, |
| ): |
| - is_cross_attention = encoder_hidden_states is not None |
| batch_size, seq_length = hidden_states.shape[:2] |
| |
| - # use encoder_hidden_states if cross attention |
| - current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states |
| - # checking that the `sequence_length` of the `past_key_value` is the same as the he provided |
| - # `encoder_hidden_states` to support prefix tuning |
| - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: |
| + # if encoder_hidden_states are provided this layer is used as a cross-attention layer for the decoder |
| + is_cross_attention = encoder_hidden_states is not None |
| + |
| + query_states = self.q(hidden_states) |
| + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| + |
| + if past_key_value is not None: |
| + is_updated = past_key_value.is_updated.get(self.layer_idx) |
| + if is_cross_attention: |
| + # after the first generated id, we can subsequently re-use all key/value_states from cache |
| + curr_past_key_value = past_key_value.cross_attention_cache |
| + else: |
| + curr_past_key_value = past_key_value.self_attention_cache |
| + |
| + current_states = encoder_hidden_states if is_cross_attention else hidden_states |
| + if is_cross_attention and past_key_value is not None and is_updated: |
| # reuse k,v, cross_attentions |
| - key_states = past_key_value[0] |
| - value_states = past_key_value[1] |
| + key_states = curr_past_key_value.key_cache[self.layer_idx] |
| + value_states = curr_past_key_value.value_cache[self.layer_idx] |
| else: |
| - key_states = self._shape(self.k(current_states)) |
| - value_states = self._shape(self.v(current_states)) |
| - if past_key_value is not None and not is_cross_attention: |
| - # reuse k, v, self_attention |
| - key_states = torch.cat([past_key_value[0], key_states], dim=2) |
| - value_states = torch.cat([past_key_value[1], value_states], dim=2) |
| - |
| - query_states = self._shape(self.q(hidden_states)) |
| - attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) |
| + key_states = self.k(current_states) |
| + value_states = self.v(current_states) |
| + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) |
| |
| - # compute positional bias |
| - if self.has_relative_attention_bias: |
| - query_length = seq_length |
| if past_key_value is not None: |
| - query_length += past_key_value[0].shape[2] |
| - position_bias = self.compute_bias(query_length, key_states.size(2), device=attention_scores.device) |
| - else: |
| + # save all key/value_states to cache to be re-used for fast auto-regressive generation |
| + cache_position = cache_position if not is_cross_attention else None |
| + key_states, value_states = curr_past_key_value.update( |
| + key_states, value_states, self.layer_idx, {"cache_position": cache_position} |
| + ) |
| + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls |
| + if is_cross_attention: |
| + past_key_value.is_updated[self.layer_idx] = True |
| + |
| + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 |
| + scores = torch.matmul(query_states, key_states.transpose(3, 2)) |
| + |
| + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) |
| + real_seq_length = seq_length + past_key_value.get_seq_length() if past_key_value is not None else seq_length |
| + key_length = key_states.shape[-2] |
| + if not self.has_relative_attention_bias: |
| position_bias = torch.zeros( |
| - (1, self.n_heads, seq_length, key_states.size(2)), |
| - device=attention_scores.device, |
| - dtype=attention_scores.dtype, |
| - requires_grad=self.training, |
| + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype |
| ) |
| - if past_key_value is not None: |
| - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] |
| + else: |
| + position_bias = self.compute_bias( |
| + real_seq_length, key_length, device=scores.device, cache_position=cache_position |
| + ) |
| + position_bias = position_bias[:, :, -seq_length:, :] |
| + |
| if attention_mask is not None: |
| - position_bias = position_bias + attention_mask # (batch_size, n_heads, seq_length, key_length) |
| + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] |
| + position_bias = position_bias + causal_mask |
| + |
| + if self.pruned_heads: |
| + mask = torch.ones(position_bias.shape[1]) |
| + mask[list(self.pruned_heads)] = 0 |
| + position_bias_masked = position_bias[:, mask.bool()] |
| + else: |
| + position_bias_masked = position_bias |
| + |
| + scores += position_bias_masked |
| |
| - if self.is_decoder: |
| - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. |
| - # Further calls to cross_attention layer can then reuse all cross-attention |
| - # key/value_states (first "if" case) |
| - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of |
| - # all previous decoder key/value_states. Further calls to uni-directional self-attention |
| - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) |
| - # if encoder bi-directional self-attention `past_key_value` is always `None` |
| - past_key_value = (key_states, value_states) |
| - |
| - attention_scores += position_bias |
| # (batch_size, n_heads, seq_length, key_length) |
| - attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).type_as(attention_scores) |
| + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) |
| attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) |
| |
| # Mask heads if we want to |
| if layer_head_mask is not None: |
| attn_weights = attn_weights * layer_head_mask |
| |
| - # attn_output = torch.bmm(attn_probs, value_states) ? |
| - context_states = torch.matmul(attn_weights, value_states) |
| - # attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) ? |
| - context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1) |
| - attn_output = self.o(context_states) |
| + attn_output = torch.matmul(attn_weights, value_states) |
| + |
| + attn_output = attn_output.transpose(1, 2).contiguous() |
| + attn_output = attn_output.view(batch_size, seq_length, -1) |
| + |
| + attn_output = self.o(attn_output) |
| return attn_output, attn_weights, past_key_value |
| |
| |
| class UMT5LayerSelfAttention(nn.Module): |
| - def __init__(self, config): |
| + def __init__(self, config, layer_idx: Optional[int] = None): |
| super().__init__() |
| - self.SelfAttention = UMT5Attention(config, has_relative_attention_bias=True) |
| + self.SelfAttention = UMT5Attention(config, has_relative_attention_bias=True, layer_idx=layer_idx) |
| self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| |
| @@ -330,6 +360,7 @@ def forward( |
| attention_mask=None, |
| layer_head_mask=None, |
| past_key_value=None, |
| + cache_position=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states) |
| attention_output = self.SelfAttention( |
| @@ -337,6 +368,7 @@ def forward( |
| attention_mask=attention_mask, |
| layer_head_mask=layer_head_mask, |
| past_key_value=past_key_value, |
| + cache_position=cache_position, |
| ) |
| hidden_states = hidden_states + self.dropout(attention_output[0]) |
| outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them |
| @@ -344,9 +376,9 @@ def forward( |
| |
| |
| class UMT5LayerCrossAttention(nn.Module): |
| - def __init__(self, config): |
| + def __init__(self, config, layer_idx: Optional[int] = None): |
| super().__init__() |
| - self.EncDecAttention = UMT5Attention(config, has_relative_attention_bias=False) |
| + self.EncDecAttention = UMT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) |
| self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| |
| @@ -357,6 +389,7 @@ def forward( |
| attention_mask=None, |
| layer_head_mask=None, |
| past_key_value=None, |
| + cache_position=None, |
| ): |
| normed_hidden_states = self.layer_norm(hidden_states) |
| attention_output = self.EncDecAttention( |
| @@ -365,6 +398,7 @@ def forward( |
| attention_mask=attention_mask, |
| layer_head_mask=layer_head_mask, |
| past_key_value=past_key_value, |
| + cache_position=cache_position, |
| ) |
| layer_output = hidden_states + self.dropout(attention_output[0]) |
| outputs = (layer_output,) + attention_output[1:] # add attentions if we output them |
| @@ -372,13 +406,13 @@ def forward( |
| |
| |
| class UMT5Block(nn.Module): |
| - def __init__(self, config): |
| + def __init__(self, config, layer_idx: Optional[int] = None): |
| super().__init__() |
| self.is_decoder = config.is_decoder |
| self.layer = nn.ModuleList() |
| - self.layer.append(UMT5LayerSelfAttention(config)) |
| + self.layer.append(UMT5LayerSelfAttention(config, layer_idx=layer_idx)) |
| if self.is_decoder: |
| - self.layer.append(UMT5LayerCrossAttention(config)) |
| + self.layer.append(UMT5LayerCrossAttention(config, layer_idx=layer_idx)) |
| |
| self.layer.append(UMT5LayerFF(config)) |
| |
| @@ -393,16 +427,14 @@ def forward( |
| past_key_value=None, |
| use_cache=False, |
| output_attentions=False, |
| + cache_position=None, |
| ): |
| - # Self Attention |
| - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 |
| - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None |
| - |
| - hidden_states, self_attn_weights, present_key_value = self.layer[0]( |
| + hidden_states, self_attn_weights, past_key_value = self.layer[0]( |
| hidden_states, |
| attention_mask=attention_mask, |
| layer_head_mask=layer_head_mask, |
| - past_key_value=self_attn_past_key_value, |
| + past_key_value=past_key_value, |
| + cache_position=cache_position, |
| ) |
| |
| # clamp inf values to enable fp16 training |
| @@ -412,18 +444,16 @@ def forward( |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
| |
| # Cross-Attention Block |
| - cross_attn_present_key_value = None |
| cross_attn_weights = None |
| do_cross_attention = self.is_decoder and encoder_hidden_states is not None |
| if do_cross_attention: |
| - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple |
| - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None |
| - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.layer[1]( |
| + hidden_states, cross_attn_weights, past_key_value = self.layer[1]( |
| hidden_states, |
| encoder_hidden_states=encoder_hidden_states, |
| attention_mask=encoder_attention_mask, |
| layer_head_mask=cross_attn_layer_head_mask, |
| - past_key_value=cross_attn_past_key_value, |
| + past_key_value=past_key_value, |
| + cache_position=cache_position, |
| ) |
| # clamp inf values to enable fp16 training |
| if hidden_states.dtype == torch.float16: |
| @@ -431,8 +461,6 @@ def forward( |
| clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype) |
| hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) |
| |
| - present_key_value += cross_attn_present_key_value |
| - |
| # Apply Feed Forward layer |
| hidden_states = self.layer[-1](hidden_states) |
| |
| @@ -444,7 +472,7 @@ def forward( |
| |
| outputs = ( |
| hidden_states, |
| - present_key_value, |
| + past_key_value, |
| ) |
| |
| if output_attentions: |
| @@ -481,6 +509,8 @@ class UMT5PreTrainedModel(PreTrainedModel): |
| config_class = UMT5Config |
| base_model_prefix = "transformer" |
| supports_gradient_checkpointing = True |
| + _supports_cache_class = True |
| + _supports_static_cache = True |
| _no_split_modules = ["UMT5Block"] |
| _keep_in_fp32_modules = ["wo"] |
| |
| @@ -594,7 +624,7 @@ def __init__(self, config, embed_tokens=None): |
| super().__init__(config) |
| self.embed_tokens = embed_tokens |
| self.is_decoder = config.is_decoder |
| - self.block = nn.ModuleList([UMT5Block(config) for i in range(config.num_layers)]) |
| + self.block = nn.ModuleList([UMT5Block(config, layer_idx=i) for i in range(config.num_layers)]) |
| self.final_layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) |
| self.dropout = nn.Dropout(config.dropout_rate) |
| |
| @@ -622,6 +652,7 @@ def forward( |
| output_attentions=None, |
| output_hidden_states=None, |
| return_dict=None, |
| + cache_position=None, |
| ): |
| use_cache = use_cache if use_cache is not None else self.config.use_cache |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| @@ -644,6 +675,13 @@ def forward( |
| err_msg_prefix = "decoder_" if self.is_decoder else "" |
| raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") |
| |
| + if self.gradient_checkpointing and self.training: |
| + if use_cache: |
| + logger.warning_once( |
| + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
| + ) |
| + use_cache = False |
| + |
| if inputs_embeds is None: |
| if self.embed_tokens is None: |
| raise ValueError("You have to initialize the model with valid token embeddings") |
| @@ -651,28 +689,57 @@ def forward( |
| |
| batch_size, seq_length = input_shape |
| |
| - # required mask seq length can be calculated via length of past |
| - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length |
| - |
| if use_cache is True: |
| if not self.is_decoder: |
| raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") |
| |
| - if attention_mask is None: |
| - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) |
| - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: |
| - encoder_seq_length = encoder_hidden_states.shape[1] |
| - encoder_attention_mask = torch.ones( |
| - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long |
| + # initialize past_key_values |
| + return_legacy_cache = False |
| + return_self_attention_cache = False |
| + if self.is_decoder and (use_cache or past_key_values is not None): |
| + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): |
| + return_self_attention_cache = True |
| + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) |
| + elif not isinstance(past_key_values, EncoderDecoderCache): |
| + return_legacy_cache = True |
| + logger.warning_once( |
| + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " |
| + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " |
| + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." |
| + ) |
| + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) |
| + elif past_key_values is None: |
| + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) |
| + elif not self.is_decoder: |
| + # do not pass cache object down the line for encoder stack |
| + # it messes indexing later in decoder-stack because cache object is modified in-place |
| + past_key_values = None |
| + |
| + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| + if cache_position is None: |
| + cache_position = torch.arange( |
| + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device |
| ) |
| |
| - # initialize past_key_values with `None` if past does not exist |
| - if past_key_values is None: |
| - past_key_values = [None] * len(self.block) |
| + if attention_mask is None and not is_torchdynamo_compiling(): |
| + # required mask seq length can be calculated via length of past cache |
| + mask_seq_length = past_key_values_length + seq_length |
| + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) |
| |
| - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] |
| - # ourselves in which case we just need to make it broadcastable to all heads. |
| - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) |
| + if self.is_decoder: |
| + causal_mask = self._update_causal_mask( |
| + attention_mask, |
| + inputs_embeds, |
| + cache_position, |
| + past_key_values.self_attention_cache if past_key_values is not None else None, |
| + output_attentions, |
| + ) |
| + elif attention_mask is not None: |
| + causal_mask = attention_mask[:, None, None, :] |
| + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) |
| + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min |
| + else: |
| + causal_mask = None |
| |
| # If a 2D or 3D attention mask is provided for the cross-attention |
| # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] |
| @@ -685,24 +752,16 @@ def forward( |
| else: |
| encoder_extended_attention_mask = None |
| |
| - if self.gradient_checkpointing and self.training: |
| - if use_cache: |
| - logger.warning_once( |
| - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." |
| - ) |
| - use_cache = False |
| - |
| # Prepare head mask if needed |
| head_mask = self.get_head_mask(head_mask, self.config.num_layers) |
| cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) |
| - present_key_value_states = () if use_cache else None |
| all_hidden_states = () if output_hidden_states else None |
| all_attentions = () if output_attentions else None |
| all_cross_attentions = () if output_attentions and self.is_decoder else None |
| |
| hidden_states = self.dropout(inputs_embeds) |
| |
| - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): |
| + for i, layer_module in enumerate(self.block): |
| layer_head_mask = head_mask[i] |
| cross_attn_layer_head_mask = cross_attn_head_mask[i] |
| |
| @@ -713,7 +772,7 @@ def forward( |
| layer_outputs = self._gradient_checkpointing_func( |
| layer_module.forward, |
| hidden_states, |
| - extended_attention_mask, |
| + causal_mask, |
| encoder_hidden_states, |
| encoder_extended_attention_mask, |
| layer_head_mask, |
| @@ -721,24 +780,26 @@ def forward( |
| None, # past_key_value is always None with gradient checkpointing |
| use_cache, |
| output_attentions, |
| + cache_position, |
| ) |
| else: |
| layer_outputs = layer_module( |
| hidden_states, |
| - attention_mask=extended_attention_mask, |
| + attention_mask=causal_mask, |
| encoder_hidden_states=encoder_hidden_states, |
| encoder_attention_mask=encoder_extended_attention_mask, |
| layer_head_mask=layer_head_mask, |
| cross_attn_layer_head_mask=cross_attn_layer_head_mask, |
| - past_key_value=past_key_value, |
| + past_key_value=past_key_values, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| + cache_position=cache_position, |
| ) |
| |
| hidden_states = layer_outputs[0] |
| |
| if use_cache: |
| - present_key_value_states += (layer_outputs[1],) |
| + next_decoder_cache = layer_outputs[1] |
| |
| if output_attentions: |
| all_attentions += (layer_outputs[2],) |
| @@ -752,12 +813,18 @@ def forward( |
| if output_hidden_states: |
| all_hidden_states = all_hidden_states + (hidden_states,) |
| |
| + next_cache = next_decoder_cache if use_cache else None |
| + if return_self_attention_cache: |
| + next_cache = past_key_values.self_attention_cache |
| + if return_legacy_cache: |
| + next_cache = past_key_values.to_legacy_cache() |
| + |
| if not return_dict: |
| return tuple( |
| v |
| for v in [ |
| hidden_states, |
| - present_key_value_states, |
| + next_cache, |
| all_hidden_states, |
| all_attentions, |
| all_cross_attentions, |
| @@ -766,12 +833,135 @@ def forward( |
| ) |
| return BaseModelOutputWithPastAndCrossAttentions( |
| last_hidden_state=hidden_states, |
| - past_key_values=present_key_value_states, |
| + past_key_values=next_cache, |
| hidden_states=all_hidden_states, |
| attentions=all_attentions, |
| cross_attentions=all_cross_attentions, |
| ) |
| |
| + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask |
| + def _update_causal_mask( |
| + self, |
| + attention_mask: torch.Tensor, |
| + input_tensor: torch.Tensor, |
| + cache_position: torch.Tensor, |
| + past_key_values: Cache, |
| + output_attentions: bool, |
| + ): |
| + if self.config._attn_implementation == "flash_attention_2": |
| + if attention_mask is not None and 0.0 in attention_mask: |
| + return attention_mask |
| + return None |
| + |
| + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in |
| + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail |
| + # to infer the attention mask. |
| + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| + using_static_cache = isinstance(past_key_values, StaticCache) |
| + |
| + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward |
| + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: |
| + if AttentionMaskConverter._ignore_causal_mask_sdpa( |
| + attention_mask, |
| + inputs_embeds=input_tensor, |
| + past_key_values_length=past_seen_tokens, |
| + is_training=self.training, |
| + ): |
| + return None |
| + |
| + dtype, device = input_tensor.dtype, input_tensor.device |
| + sequence_length = input_tensor.shape[1] |
| + if using_static_cache: |
| + target_length = past_key_values.get_max_cache_shape() |
| + else: |
| + target_length = ( |
| + attention_mask.shape[-1] |
| + if isinstance(attention_mask, torch.Tensor) |
| + else past_seen_tokens + sequence_length + 1 |
| + ) |
| + |
| + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). |
| + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask, |
| + sequence_length=sequence_length, |
| + target_length=target_length, |
| + dtype=dtype, |
| + device=device, |
| + cache_position=cache_position, |
| + batch_size=input_tensor.shape[0], |
| + ) |
| + |
| + if ( |
| + self.config._attn_implementation == "sdpa" |
| + and attention_mask is not None |
| + and attention_mask.device.type == "cuda" |
| + and not output_attentions |
| + ): |
| + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when |
| + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. |
| + # Details: https://github.com/pytorch/pytorch/issues/110213 |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) |
| + |
| + return causal_mask |
| + |
| + @staticmethod |
| + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position |
| + def _prepare_4d_causal_attention_mask_with_cache_position( |
| + attention_mask: torch.Tensor, |
| + sequence_length: int, |
| + target_length: int, |
| + dtype: torch.dtype, |
| + device: torch.device, |
| + cache_position: torch.Tensor, |
| + batch_size: int, |
| + **kwargs, |
| + ): |
| + """ |
| + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
| + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. |
| + |
| + Args: |
| + attention_mask (`torch.Tensor`): |
| + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape |
| + `(batch_size, 1, query_length, key_value_length)`. |
| + sequence_length (`int`): |
| + The sequence length being processed. |
| + target_length (`int`): |
| + The target length: when generating with static cache, the mask should be as long as the static cache, |
| + to account for the 0 padding, the part of the cache that is not filled yet. |
| + dtype (`torch.dtype`): |
| + The dtype to use for the 4D attention mask. |
| + device (`torch.device`): |
| + The device to plcae the 4D attention mask on. |
| + cache_position (`torch.Tensor`): |
| + Indices depicting the position of the input sequence tokens in the sequence. |
| + batch_size (`torch.Tensor`): |
| + Batch size. |
| + """ |
| + if attention_mask is not None and attention_mask.dim() == 4: |
| + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. |
| + causal_mask = attention_mask |
| + else: |
| + min_dtype = torch.finfo(dtype).min |
| + causal_mask = torch.full( |
| + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device |
| + ) |
| + if sequence_length != 1: |
| + causal_mask = torch.triu(causal_mask, diagonal=1) |
| + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |
| + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
| + if attention_mask is not None: |
| + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit |
| + mask_length = attention_mask.shape[-1] |
| + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] |
| + padding_mask = padding_mask == 0 |
| + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| + padding_mask, min_dtype |
| + ) |
| + |
| + return causal_mask |
| + |
| |
| UMT5_START_DOCSTRING = r""" |
| |
| @@ -885,6 +1075,9 @@ def forward( |
| more detail. |
| return_dict (`bool`, *optional*): |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): |
| + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the |
| + cache in the correct position and to infer the complete sequence length. |
| """ |
| |
| UMT5_ENCODER_INPUTS_DOCSTRING = r""" |
| @@ -1022,6 +1215,7 @@ def forward( |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| + cache_position: Optional[torch.LongTensor] = None, |
| ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: |
| r""" |
| Returns: |
| @@ -1084,6 +1278,7 @@ def forward( |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| if not return_dict: |
| @@ -1197,6 +1392,7 @@ def forward( |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| + cache_position: Optional[torch.LongTensor] = None, |
| ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: |
| r""" |
| labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
| @@ -1268,6 +1464,7 @@ def forward( |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| + cache_position=cache_position, |
| ) |
| |
| sequence_output = decoder_outputs[0] |
| |
| |
| |
| |
| @@ -31,6 +31,7 @@ |
| |
| if is_torch_available(): |
| import torch |
| + import torch.nn.functional as F |
| |
| from transformers import ( |
| MODEL_FOR_QUESTION_ANSWERING_MAPPING, |
| @@ -574,6 +575,41 @@ def test_decoder_model_past_with_3d_attn_mask(self): |
| lm_labels, |
| ) |
| |
| + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` |
| + def test_custom_4d_attention_mask(self): |
| + for model_class in self.all_generative_model_classes: |
| + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() |
| + model = model_class(config).to(device=torch_device, dtype=torch.float32) |
| + |
| + ( |
| + input_ids, |
| + _, |
| + input_ids_shared_prefix, |
| + mask_shared_prefix, |
| + _, |
| + ) = self._get_custom_4d_mask_test_data() |
| + |
| + logits = model.forward( |
| + decoder_input_ids=input_ids, |
| + input_ids=input_dict["input_ids"][:3], |
| + ).logits |
| + # logits.shape == torch.Size([3, 4, ...]) |
| + |
| + logits_shared_prefix = model( |
| + input_ids=input_dict["input_ids"][:1], |
| + decoder_input_ids=input_ids_shared_prefix, |
| + decoder_attention_mask=mask_shared_prefix, |
| + )[0] |
| + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) |
| + |
| + out_last_tokens = logits[:, -1, :] # last tokens in each batch line |
| + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens |
| + |
| + # comparing softmax-normalized logits: |
| + normalized_0 = F.softmax(out_last_tokens) |
| + normalized_1 = F.softmax(out_shared_prefix_last_tokens) |
| + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) |
| + |
| def test_decoder_model_past_with_large_inputs(self): |
| config_and_inputs = self.model_tester.prepare_config_and_inputs() |
| self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) |
| @@ -602,7 +638,7 @@ def test_export_to_onnx(self): |
| (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]), |
| f"{tmpdirname}/longt5_test.onnx", |
| export_params=True, |
| - opset_version=13, |
| + opset_version=14, |
| input_names=["input_ids", "decoder_input_ids"], |
| ) |
| |
| |
| |
| |
| |
| @@ -40,6 +40,7 @@ |
| |
| if is_torch_available(): |
| import torch |
| + import torch.nn.functional as F |
| |
| from transformers import ( |
| AutoModelForSeq2SeqLM, |
| @@ -575,6 +576,9 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, |
| # The small MT5 model needs higher percentages for CPU/MP tests |
| model_split_percents = [0.5, 0.8, 0.9] |
| |
| + # used in `test_torch_compile` |
| + _torch_compile_test_ckpt = "google/mt5-small" |
| + |
| def setUp(self): |
| self.model_tester = MT5ModelTester(self) |
| self.config_tester = ConfigTester(self, config_class=MT5Config, d_model=37) |
| @@ -627,12 +631,9 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa |
| ] |
| if labels is not None: |
| input_names.append("labels") |
| - |
| filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} |
| input_names = list(filtered_inputs.keys()) |
| - |
| model_output = model(**filtered_inputs) |
| - |
| traced_model = symbolic_trace(model, input_names) |
| traced_output = traced_model(**filtered_inputs) |
| else: |
| @@ -647,7 +648,6 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa |
| "visual_feats", |
| "visual_pos", |
| ] |
| - |
| labels = inputs.get("labels", None) |
| start_positions = inputs.get("start_positions", None) |
| end_positions = inputs.get("end_positions", None) |
| @@ -657,15 +657,12 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa |
| input_names.append("start_positions") |
| if end_positions is not None: |
| input_names.append("end_positions") |
| - |
| filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} |
| input_names = list(filtered_inputs.keys()) |
| - |
| if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( |
| not hasattr(model.config, "problem_type") or model.config.problem_type is None |
| ): |
| model.config.problem_type = "single_label_classification" |
| - |
| traced_model = symbolic_trace(model, input_names) |
| traced_output = traced_model(**filtered_inputs) |
| model_output = model(**filtered_inputs) |
| @@ -718,6 +715,41 @@ def flatten_output(output): |
| # (Even with this call, there are still memory leak by ~0.04MB) |
| self.clear_torch_jit_class_registry() |
| |
| + # overwrite because MT5 doesn't accept position ids as input and expects `decoder_input_ids` |
| + def test_custom_4d_attention_mask(self): |
| + for model_class in self.all_generative_model_classes: |
| + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() |
| + model = model_class(config).to(device=torch_device, dtype=torch.float32) |
| + |
| + ( |
| + input_ids, |
| + _, |
| + input_ids_shared_prefix, |
| + mask_shared_prefix, |
| + _, |
| + ) = self._get_custom_4d_mask_test_data() |
| + |
| + logits = model.forward( |
| + decoder_input_ids=input_ids, |
| + input_ids=input_dict["input_ids"][:3], |
| + ).logits |
| + # logits.shape == torch.Size([3, 4, ...]) |
| + |
| + logits_shared_prefix = model( |
| + input_ids=input_dict["input_ids"][:1], |
| + decoder_input_ids=input_ids_shared_prefix, |
| + decoder_attention_mask=mask_shared_prefix, |
| + )[0] |
| + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) |
| + |
| + out_last_tokens = logits[:, -1, :] # last tokens in each batch line |
| + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens |
| + |
| + # comparing softmax-normalized logits: |
| + normalized_0 = F.softmax(out_last_tokens) |
| + normalized_1 = F.softmax(out_shared_prefix_last_tokens) |
| + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) |
| + |
| def test_config(self): |
| self.config_tester.run_common_tests() |
| |
| |
| |
| |
| |
| @@ -620,7 +620,7 @@ def test_export_to_onnx(self): |
| (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]), |
| f"{tmpdirname}/Pop2Piano_test.onnx", |
| export_params=True, |
| - opset_version=9, |
| + opset_version=14, |
| input_names=["input_ids", "decoder_input_ids"], |
| ) |
| |
| |
| |
| |
| |
| @@ -36,6 +36,7 @@ |
| |
| if is_torch_available(): |
| import torch |
| + import torch.nn.functional as F |
| |
| from transformers import ( |
| AutoTokenizer, |
| @@ -645,6 +646,41 @@ def test_decoder_model_past_with_3d_attn_mask(self): |
| lm_labels, |
| ) |
| |
| + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` |
| + def test_custom_4d_attention_mask(self): |
| + for model_class in self.all_generative_model_classes: |
| + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() |
| + model = model_class(config).to(device=torch_device, dtype=torch.float32) |
| + |
| + ( |
| + input_ids, |
| + _, |
| + input_ids_shared_prefix, |
| + mask_shared_prefix, |
| + _, |
| + ) = self._get_custom_4d_mask_test_data() |
| + |
| + logits = model.forward( |
| + decoder_input_ids=input_ids, |
| + input_ids=input_dict["input_ids"][:3], |
| + ).logits |
| + # logits.shape == torch.Size([3, 4, ...]) |
| + |
| + logits_shared_prefix = model( |
| + input_ids=input_dict["input_ids"][:1], |
| + decoder_input_ids=input_ids_shared_prefix, |
| + decoder_attention_mask=mask_shared_prefix, |
| + )[0] |
| + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) |
| + |
| + out_last_tokens = logits[:, -1, :] # last tokens in each batch line |
| + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens |
| + |
| + # comparing softmax-normalized logits: |
| + normalized_0 = F.softmax(out_last_tokens) |
| + normalized_1 = F.softmax(out_shared_prefix_last_tokens) |
| + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) |
| + |
| def test_decoder_model_past_with_large_inputs(self): |
| config_and_inputs = self.model_tester.prepare_config_and_inputs() |
| self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) |
| |
| |
| |
| |
| @@ -27,6 +27,7 @@ |
| require_sentencepiece, |
| require_tokenizers, |
| require_torch, |
| + require_torch_gpu, |
| slow, |
| torch_device, |
| ) |
| @@ -44,6 +45,7 @@ |
| |
| if is_torch_available(): |
| import torch |
| + import torch.nn.functional as F |
| |
| from transformers import ( |
| AutoTokenizer, |
| @@ -578,6 +580,9 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, |
| # The small T5 model needs higher percentages for CPU/MP tests |
| model_split_percents = [0.5, 0.8, 0.9] |
| |
| + # used in `test_torch_compile` |
| + _torch_compile_test_ckpt = "google-t5/t5-small" |
| + |
| def setUp(self): |
| self.model_tester = T5ModelTester(self) |
| self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37) |
| @@ -630,12 +635,9 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa |
| ] |
| if labels is not None: |
| input_names.append("labels") |
| - |
| filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} |
| input_names = list(filtered_inputs.keys()) |
| - |
| model_output = model(**filtered_inputs) |
| - |
| traced_model = symbolic_trace(model, input_names) |
| traced_output = traced_model(**filtered_inputs) |
| else: |
| @@ -650,7 +652,6 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa |
| "visual_feats", |
| "visual_pos", |
| ] |
| - |
| labels = inputs.get("labels", None) |
| start_positions = inputs.get("start_positions", None) |
| end_positions = inputs.get("end_positions", None) |
| @@ -660,15 +661,12 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa |
| input_names.append("start_positions") |
| if end_positions is not None: |
| input_names.append("end_positions") |
| - |
| filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} |
| input_names = list(filtered_inputs.keys()) |
| - |
| if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( |
| not hasattr(model.config, "problem_type") or model.config.problem_type is None |
| ): |
| model.config.problem_type = "single_label_classification" |
| - |
| traced_model = symbolic_trace(model, input_names) |
| traced_output = traced_model(**filtered_inputs) |
| model_output = model(**filtered_inputs) |
| @@ -721,6 +719,41 @@ def flatten_output(output): |
| # (Even with this call, there are still memory leak by ~0.04MB) |
| self.clear_torch_jit_class_registry() |
| |
| + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` |
| + def test_custom_4d_attention_mask(self): |
| + for model_class in self.all_generative_model_classes: |
| + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() |
| + model = model_class(config).to(device=torch_device, dtype=torch.float32) |
| + |
| + ( |
| + input_ids, |
| + _, |
| + input_ids_shared_prefix, |
| + mask_shared_prefix, |
| + _, |
| + ) = self._get_custom_4d_mask_test_data() |
| + |
| + logits = model.forward( |
| + decoder_input_ids=input_ids, |
| + input_ids=input_dict["input_ids"][:3], |
| + ).logits |
| + # logits.shape == torch.Size([3, 4, ...]) |
| + |
| + logits_shared_prefix = model( |
| + input_ids=input_dict["input_ids"][:1], |
| + decoder_input_ids=input_ids_shared_prefix, |
| + decoder_attention_mask=mask_shared_prefix, |
| + )[0] |
| + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) |
| + |
| + out_last_tokens = logits[:, -1, :] # last tokens in each batch line |
| + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens |
| + |
| + # comparing softmax-normalized logits: |
| + normalized_0 = F.softmax(out_last_tokens) |
| + normalized_1 = F.softmax(out_shared_prefix_last_tokens) |
| + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) |
| + |
| def test_config(self): |
| self.config_tester.run_common_tests() |
| |
| @@ -1482,6 +1515,7 @@ def test_summarization(self): |
| [model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]], |
| padding="max_length", |
| truncation=True, |
| + max_length=512, |
| return_tensors="pt", |
| ).to(torch_device) |
| self.assertEqual(512, dct["input_ids"].shape[1]) |
| @@ -1604,14 +1638,76 @@ def test_contrastive_search_t5(self): |
| outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64) |
| generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True) |
| |
| + # TODO: @arthur? |
| + # PR #31938 caused regression on this test which was fixed by PR #34089 |
| self.assertListEqual( |
| generated_text, |
| [ |
| - "Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for " |
| - "permanent residence after the marriages, prosecutors say." |
| + "Liana Barrientos has been married 10 times, nine of them in the Bronx . Her husbands filed for " |
| + "permanent residence after the marriages, prosecutors say ." |
| ], |
| ) |
| |
| + @slow |
| + @require_torch_gpu |
| + def test_compile_static_cache(self): |
| + NUM_TOKENS_TO_GENERATE = 40 |
| + EXPECTED_TEXT_COMPLETION = [ |
| + "theory of relativity states that 1) the speed of light is constant in all inertial reference frames. the laws of physics are the same for all inertial reference frames.", |
| + "ketchup is my favorite condiment.", |
| + ] |
| + |
| + prompts = [ |
| + "summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " |
| + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " |
| + "theory of relativity is not hard to grasp.", |
| + "summarize: My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " |
| + "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my pizza.", |
| + ] |
| + model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small").to(torch_device) |
| + tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small") |
| + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) |
| + |
| + # Dynamic Cache |
| + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) |
| + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) |
| + self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) |
| + |
| + # Static Cache |
| + generated_ids = model.generate( |
| + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" |
| + ) |
| + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) |
| + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) |
| + |
| + # Static Cache + compile |
| + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) |
| + generated_ids = model.generate( |
| + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" |
| + ) |
| + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) |
| + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) |
| + |
| + @slow |
| + @require_torch_gpu |
| + def test_compile_static_cache_encoder(self): |
| + prompts = [ |
| + "summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " |
| + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " |
| + "theory of relativity is not hard to grasp.", |
| + "summarize: My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " |
| + "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my pizza.", |
| + ] |
| + model = T5EncoderModel.from_pretrained("google-t5/t5-small").to(torch_device) |
| + tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small") |
| + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) |
| + |
| + logits = model(**inputs) |
| + |
| + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) |
| + logits_compiled = model(**inputs) |
| + self.assertTrue(torch.allclose(logits[0][:, -3:, -3], logits_compiled[0][:, -3:, -3], atol=1e-5)) |
| + |
| |
| @require_torch |
| class TestAsymmetricT5(unittest.TestCase): |
| |
| |
| |
| |
| @@ -37,6 +37,7 @@ |
| |
| if is_torch_available(): |
| import torch |
| + import torch.nn.functional as F |
| |
| from transformers import UdopEncoderModel, UdopForConditionalGeneration, UdopModel, UdopProcessor |
| |
| @@ -348,6 +349,7 @@ def test_forward_signature(self): |
| expected_arg_names = [ |
| "attention_mask", |
| "bbox", |
| + "cache_position", |
| "cross_attn_head_mask", |
| "decoder_attention_mask", |
| "decoder_head_mask", |
| @@ -365,6 +367,43 @@ def test_forward_signature(self): |
| expected_arg_names = sorted(expected_arg_names) |
| self.assertListEqual(sorted(arg_names[: len(expected_arg_names)]), expected_arg_names) |
| |
| + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` |
| + def test_custom_4d_attention_mask(self): |
| + for model_class in self.all_generative_model_classes: |
| + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() |
| + model = model_class(config).to(device=torch_device, dtype=torch.float32) |
| + |
| + ( |
| + input_ids, |
| + _, |
| + input_ids_shared_prefix, |
| + mask_shared_prefix, |
| + _, |
| + ) = self._get_custom_4d_mask_test_data() |
| + |
| + logits = model.forward( |
| + decoder_input_ids=input_ids, |
| + input_ids=input_dict["input_ids"][:3], |
| + bbox=input_dict["bbox"][:3], |
| + ).logits |
| + # logits.shape == torch.Size([3, 4, ...]) |
| + |
| + logits_shared_prefix = model( |
| + input_ids=input_dict["input_ids"][:1], |
| + bbox=input_dict["bbox"][:1], |
| + decoder_input_ids=input_ids_shared_prefix, |
| + decoder_attention_mask=mask_shared_prefix, |
| + )[0] |
| + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) |
| + |
| + out_last_tokens = logits[:, -1, :] # last tokens in each batch line |
| + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens |
| + |
| + # comparing softmax-normalized logits: |
| + normalized_0 = F.softmax(out_last_tokens) |
| + normalized_1 = F.softmax(out_shared_prefix_last_tokens) |
| + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) |
| + |
| @unittest.skip( |
| "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" |
| ) |
| @@ -534,6 +573,41 @@ def test_model(self): |
| config_and_inputs = self.model_tester.prepare_config_and_inputs() |
| self.model_tester.create_and_check_model(*config_and_inputs) |
| |
| + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` |
| + def test_custom_4d_attention_mask(self): |
| + for model_class in self.all_generative_model_classes: |
| + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() |
| + model = model_class(config).to(device=torch_device, dtype=torch.float32) |
| + |
| + ( |
| + input_ids, |
| + _, |
| + input_ids_shared_prefix, |
| + mask_shared_prefix, |
| + _, |
| + ) = self._get_custom_4d_mask_test_data() |
| + |
| + logits = model.forward( |
| + decoder_input_ids=input_ids, |
| + input_ids=input_dict["input_ids"][:3], |
| + ).logits |
| + # logits.shape == torch.Size([3, 4, ...]) |
| + |
| + logits_shared_prefix = model( |
| + input_ids=input_dict["input_ids"][:1], |
| + decoder_input_ids=input_ids_shared_prefix, |
| + decoder_attention_mask=mask_shared_prefix, |
| + )[0] |
| + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) |
| + |
| + out_last_tokens = logits[:, -1, :] # last tokens in each batch line |
| + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens |
| + |
| + # comparing softmax-normalized logits: |
| + normalized_0 = F.softmax(out_last_tokens) |
| + normalized_1 = F.softmax(out_shared_prefix_last_tokens) |
| + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) |
| + |
| @unittest.skip( |
| "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" |
| ) |
| |
| |
| |
| |
| @@ -41,6 +41,7 @@ |
| |
| if is_torch_available(): |
| import torch |
| + import torch.nn.functional as F |
| |
| from transformers import ( |
| AutoTokenizer, |
| @@ -316,6 +317,9 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin |
| # The small UMT5 model needs higher percentages for CPU/MP tests |
| model_split_percents = [0.5, 0.8, 0.9] |
| |
| + # used in `test_torch_compile` |
| + _torch_compile_test_ckpt = "google/umt5-small" |
| + |
| def setUp(self): |
| self.model_tester = UMT5ModelTester(self) |
| |
| @@ -486,6 +490,41 @@ def test_inputs_embeds(self): |
| with torch.no_grad(): |
| model(**inputs)[0] |
| |
| + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` |
| + def test_custom_4d_attention_mask(self): |
| + for model_class in self.all_generative_model_classes: |
| + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() |
| + model = model_class(config).to(device=torch_device, dtype=torch.float32) |
| + |
| + ( |
| + input_ids, |
| + _, |
| + input_ids_shared_prefix, |
| + mask_shared_prefix, |
| + _, |
| + ) = self._get_custom_4d_mask_test_data() |
| + |
| + logits = model.forward( |
| + decoder_input_ids=input_ids, |
| + input_ids=input_dict["input_ids"][:3], |
| + ).logits |
| + # logits.shape == torch.Size([3, 4, ...]) |
| + |
| + logits_shared_prefix = model( |
| + input_ids=input_dict["input_ids"][:1], |
| + decoder_input_ids=input_ids_shared_prefix, |
| + decoder_attention_mask=mask_shared_prefix, |
| + )[0] |
| + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) |
| + |
| + out_last_tokens = logits[:, -1, :] # last tokens in each batch line |
| + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens |
| + |
| + # comparing softmax-normalized logits: |
| + normalized_0 = F.softmax(out_last_tokens) |
| + normalized_1 = F.softmax(out_shared_prefix_last_tokens) |
| + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) |
| + |
| def test_with_sequence_classification_head(self): |
| config_and_inputs = self.model_tester.prepare_config_and_inputs() |
| self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs) |
| |
| |
| |
| |
| @@ -37,6 +37,7 @@ |
| from transformers import ( |
| AutoModel, |
| AutoModelForCausalLM, |
| + AutoModelForSeq2SeqLM, |
| AutoModelForSequenceClassification, |
| AutoTokenizer, |
| GenerationConfig, |
| @@ -5109,10 +5110,15 @@ def test_torch_compile(self): |
| batch_size = 1 |
| n_iter = 3 |
| |
| - tokenizer = AutoTokenizer.from_pretrained(ckpt, revision=revision) |
| - model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( |
| - torch_device |
| - ) |
| + tokenizer = AutoTokenizer.from_pretrained(ckpt) |
| + if self.is_encoder_decoder: |
| + model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( |
| + torch_device |
| + ) |
| + else: |
| + model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( |
| + torch_device |
| + ) |
| |
| model.generation_config.max_new_tokens = 4 |
| |
| @@ -5184,10 +5190,15 @@ def test_compile_cuda_graph_time(self): |
| |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" |
| |
| - tokenizer = AutoTokenizer.from_pretrained(ckpt, revision=revision) |
| - model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( |
| - torch_device |
| - ) |
| + tokenizer = AutoTokenizer.from_pretrained(ckpt) |
| + if self.is_encoder_decoder: |
| + model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( |
| + torch_device |
| + ) |
| + else: |
| + model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( |
| + torch_device |
| + ) |
| |
| cache_implementation = "static" |
| if model.config.model_type == "gemma2": |
|
|