|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""PyTorch BLOOM model.""" |
|
|
|
|
|
import math |
|
|
import warnings |
|
|
from typing import Optional, Union |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss |
|
|
from torch.nn import functional as F |
|
|
|
|
|
from ...cache_utils import Cache, DynamicCache, StaticCache |
|
|
from ...generation import GenerationMixin |
|
|
from ...modeling_attn_mask_utils import AttentionMaskConverter |
|
|
from ...modeling_layers import GradientCheckpointingLayer |
|
|
from ...modeling_outputs import ( |
|
|
BaseModelOutputWithPastAndCrossAttentions, |
|
|
CausalLMOutputWithCrossAttentions, |
|
|
QuestionAnsweringModelOutput, |
|
|
SequenceClassifierOutputWithPast, |
|
|
TokenClassifierOutput, |
|
|
) |
|
|
from ...modeling_utils import PreTrainedModel |
|
|
from ...utils import ( |
|
|
auto_docstring, |
|
|
is_torch_flex_attn_available, |
|
|
logging, |
|
|
) |
|
|
from .configuration_bloom import BloomConfig |
|
|
|
|
|
|
|
|
if is_torch_flex_attn_available(): |
|
|
from torch.nn.attention.flex_attention import BlockMask |
|
|
|
|
|
from ...integrations.flex_attention import make_flex_block_causal_mask |
|
|
|
|
|
|
|
|
# Module-level logger, created through the library's own `logging` utilities and keyed by this module's name.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
    """
    Build the ALiBi bias tensor (paper: https://huggingface.co/papers/2108.12409).

    The Alibi tensor is not causal as the original paper mentions; it relies on the translation invariance of
    softmax for a quick implementation: with l being a tensor and a a fixed value, `softmax(l+a) = softmax(l)`.
    Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
    TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.

    Args:
        attention_mask (`torch.Tensor`):
            Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
        num_heads (`int`):
            number of heads
        dtype (`torch.dtype`):
            dtype of the output tensor

    Returns:
        tensor shaped (batch_size * num_heads, 1, max_seq_len)
    """
    batch_size, seq_length = attention_mask.shape
    device = attention_mask.device

    def _slopes(head_count: int, exponents: torch.Tensor) -> torch.Tensor:
        # Per-head slopes form a geometric sequence whose ratio depends only on `head_count`.
        ratio = torch.tensor(2 ** (-(2 ** -(math.log2(head_count) - 3))), device=device, dtype=torch.float32)
        return torch.pow(ratio, exponents)

    # Largest power of two that does not exceed `num_heads`.
    largest_pow2 = 2 ** math.floor(math.log2(num_heads))
    slopes = _slopes(largest_pow2, torch.arange(1, 1 + largest_pow2, device=device, dtype=torch.int32))

    if largest_pow2 != num_heads:
        # Heads beyond the power of two reuse the sequence for twice as many heads, taking odd exponents only.
        extra_heads = min(largest_pow2, num_heads - largest_pow2)
        odd_exponents = torch.arange(1, 1 + 2 * extra_heads, 2, device=device, dtype=torch.int32)
        slopes = torch.cat([slopes, _slopes(2 * largest_pow2, odd_exponents)], dim=0)

    # Position index of every token counted over attended tokens only; masked-out tokens are forced to 0.
    # Shape goes from (batch_size, seq_length) to (batch_size, 1, seq_length).
    positions = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]

    # (num_heads, 1) * (batch_size, 1, seq_length) broadcasts to (batch_size, num_heads, seq_length).
    alibi = slopes[..., None] * positions
    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
|
|
|
|
|
|
|
|
def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
    """
    Apply dropout to `x` and add the result to `residual`.

    Args:
        x (`torch.tensor`):
            input tensor, the only operand dropout is applied to
        residual (`torch.tensor`):
            residual tensor, added unchanged
        prob (`float`):
            dropout probability
        training (`bool`):
            training mode, forwarded to `F.dropout`

    Returns:
        `residual + dropout(x)`
    """
    return residual + F.dropout(x, p=prob, training=training)
|
|
|
|
|
|
|
|
def bloom_gelu_forward(x: torch.Tensor) -> torch.Tensor:
    """
    Custom bias GELU function. Adapted from Megatron-DeepSpeed code. This is the plain, jitable (inference)
    implementation based on the tanh approximation of GELU.

    Args:
        x (`torch.tensor`):
            input hidden states

    Returns:
        tensor with the same shape as `x`
    """
    tanh_arg = 0.79788456 * x * (1 + 0.044715 * x * x)
    return x * 0.5 * (1.0 + torch.tanh(tanh_arg))
|
|
|
|
|
|
|
|
def bloom_gelu_back(g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """
    Gradient of the tanh approximation of gelu. For reference, the gradient of the actual gelu is:
    0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)

    Args:
        g (`torch.tensor`):
            gradient output tensor
        x (`torch.tensor`):
            container whose first element (`x[0]`) is the forward input tensor; `GeLUFunction.backward`
            passes `ctx.saved_tensors` here

    Returns:
        gradient with respect to the forward input, i.e. the local derivative multiplied by `g`
    """
    # Only the first entry is used: callers hand over the saved-tensors tuple.
    inp = x[0]
    tanh_out = torch.tanh(0.79788456 * inp * (1 + 0.044715 * inp * inp))
    # 1 - tanh^2 is the derivative of tanh with respect to its argument.
    one_minus_tanh_sq = 1 - tanh_out * tanh_out
    ff = 0.5 * inp * (one_minus_tanh_sq * (0.79788456 + 0.1070322243 * inp * inp)) + 0.5 * (1 + tanh_out)
    return ff * g
|
|
|
|
|
|
|
|
class GeLUFunction(torch.autograd.Function):
    """
    Autograd function pairing `bloom_gelu_forward` with the hand-written gradient `bloom_gelu_back`.
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        # Compute the activation and stash the input: the backward formula is expressed in terms of it.
        output = bloom_gelu_forward(input)
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # `bloom_gelu_back` takes the whole saved-tensors tuple and indexes it itself.
        return bloom_gelu_back(grad_output, ctx.saved_tensors)
|
|
|
|
|
|
|
|
class BloomGelu(nn.Module):
    """
    BloomBiasGelu wrapper: in inference mode it calls the simple `bloom_gelu_forward` function so the model
    stays torchscriptable, and in training mode it goes through the `GeLUFunction` autograd function to get
    accurate gradients. Partly copied from Megatron-DeepSpeed code and adapted for our needs.

    See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
    """

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pick the implementation from the module's train/eval flag, then apply it.
        activation = GeLUFunction.apply if self.training else bloom_gelu_forward
        return activation(x)
|
|
|
|
|
|
|
|
class BloomAttention(nn.Module):
    """
    BLOOM self-attention: a fused query/key/value projection, attention scores biased by an ALiBi tensor,
    an output projection, and a final dropout + residual add.
    """

    def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None):
        super().__init__()

        self.pretraining_tp = config.pretraining_tp
        self.slow_but_exact = config.slow_but_exact

        self.hidden_size = config.hidden_size
        self.num_heads = config.n_head
        self.head_dim = self.hidden_size // self.num_heads
        self.split_size = self.hidden_size
        self.hidden_dropout = config.hidden_dropout

        heads_cover_hidden = self.head_dim * self.num_heads == self.hidden_size
        if not heads_cover_hidden:
            raise ValueError(
                f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
                f" {self.num_heads})."
            )

        # Scaling applied to the query-key product (`alpha`) and to the alibi bias (`beta`) in `baddbmm`.
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.beta = 1.0

        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def _reshape(self, fused_qkv: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Split the last dimension into (num_heads, 3, head_dim) and return query, key and value in
        (bs, heads, len, dim) layout. Only `view`, indexing and `transpose` are used.

        Args:
            fused_qkv (`torch.tensor`): [batch_size, seq_length, num_heads * 3 * head_dim]

        Returns:
            query: [batch_size, num_heads, seq_length, head_dim]
            key: [batch_size, num_heads, seq_length, head_dim]
            value: [batch_size, num_heads, seq_length, head_dim]
        """
        batch_size, seq_length, _ = fused_qkv.shape
        per_head = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
        # Index 0 / 1 / 2 of the size-3 axis selects query / key / value.
        query_layer, key_layer, value_layer = (per_head[..., idx, :].transpose(1, 2) for idx in range(3))
        return query_layer, key_layer, value_layer

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        Merge heads together over the last dimension

        Args:
            x (`torch.tensor`): [batch_size * num_heads, seq_length, head_dim]

        Returns:
            torch.tensor: [batch_size, seq_length, num_heads * head_dim]
        """
        flat_batch, seq_length, _ = x.shape
        batch_size = flat_batch // self.num_heads

        # (batch * heads, len, dim) -> (batch, heads, len, dim) -> (batch, len, heads, dim)
        per_head = x.view(batch_size, self.num_heads, seq_length, self.head_dim).permute(0, 2, 1, 3)

        # Collapse the two trailing axes into one of size num_heads * head_dim.
        return per_head.reshape(batch_size, seq_length, self.num_heads * self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Cache] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        """
        Run self-attention over `hidden_states` and add the projected result to `residual`.

        Args:
            hidden_states: input of shape (batch_size, q_length, hidden_size).
            residual: tensor the (dropped-out) attention output is added to.
            alibi: additive bias used as the `baddbmm` input for the attention scores.
            attention_mask: additive 4D mask or `None`; it is sliced to the key length and summed into
                the scores.
            layer_past: optional cache; when given, its `update` result replaces the key/value states.
            head_mask: optional multiplier applied to the attention probabilities.
            use_cache, output_attentions: accepted but not read by this method.
            cache_position: forwarded to the cache update.

        Returns:
            `(output_tensor, attention_probs)`.
        """
        batch_size, q_length, _ = hidden_states.shape
        query_layer, key_layer, value_layer = self._reshape(self.query_key_value(hidden_states))

        if layer_past is not None:
            key_layer, value_layer = layer_past.update(
                key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
            )

        # Fold batch and head axes together for the batched matmuls; keys are transposed to (dim, kv_len).
        flat_batch = batch_size * self.num_heads
        queries = query_layer.reshape(flat_batch, -1, self.head_dim)
        keys_t = key_layer.reshape(flat_batch, -1, self.head_dim).transpose(-1, -2)
        values = value_layer.reshape(flat_batch, -1, self.head_dim)

        # beta * alibi + alpha * (queries @ keys_t)
        attention_scores = alibi.baddbmm(
            batch1=queries,
            batch2=keys_t,
            beta=self.beta,
            alpha=self.inv_norm_factor,
        )

        attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
        if attention_mask is not None:
            # The mask may be longer than the current key length; keep only the matching columns.
            attn_weights = attn_weights + attention_mask[:, :, :, : keys_t.shape[-1]]

        # Softmax runs in float32 and is cast back to the query dtype.
        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(queries.dtype)
        attention_probs = self.attention_dropout(attention_probs)
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.bmm(attention_probs.view(flat_batch, q_length, -1), values)
        context_layer = self._merge_heads(context_layer)

        if self.pretraining_tp > 1 and self.slow_but_exact:
            # Apply the output projection slice by slice along the input features and sum the partial results.
            # NOTE(review): `self.dense.bias` is never added on this path, unlike the branch below — confirm intended.
            slices = self.hidden_size / self.pretraining_tp
            bounds = [int(rank * slices) for rank in range(self.pretraining_tp + 1)]
            output_tensor = torch.zeros_like(context_layer)
            for start, stop in zip(bounds[:-1], bounds[1:]):
                output_tensor = output_tensor + F.linear(
                    context_layer[:, :, start:stop],
                    self.dense.weight[:, start:stop],
                )
        else:
            output_tensor = self.dense(context_layer)

        output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
        return output_tensor, attention_probs
|
|
|
|
|
|
|
|
class BloomMLP(nn.Module):
    """
    BLOOM feed-forward block: hidden -> 4 * hidden projection, `BloomGelu`, 4 * hidden -> hidden projection,
    followed by dropout and a residual add.
    """

    def __init__(self, config: BloomConfig):
        super().__init__()
        hidden_size = config.hidden_size
        inner_size = 4 * hidden_size

        self.pretraining_tp = config.pretraining_tp
        self.slow_but_exact = config.slow_but_exact
        self.dense_h_to_4h = nn.Linear(hidden_size, inner_size)
        self.gelu_impl = BloomGelu()
        self.dense_4h_to_h = nn.Linear(inner_size, hidden_size)
        self.hidden_dropout = config.hidden_dropout

    def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: input to the feed-forward block.
            residual: tensor the (dropped-out) feed-forward output is added to.

        Returns:
            `residual + dropout(mlp(hidden_states))`.
        """
        activated = self.gelu_impl(self.dense_h_to_4h(hidden_states))

        if self.pretraining_tp > 1 and self.slow_but_exact:
            # Apply the down projection slice by slice along the input features and sum the partial results.
            # NOTE(review): `self.dense_4h_to_h.bias` is never added on this path — confirm intended.
            slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
            bounds = [int(rank * slices) for rank in range(self.pretraining_tp + 1)]
            intermediate_output = torch.zeros_like(residual)
            for start, stop in zip(bounds[:-1], bounds[1:]):
                intermediate_output = intermediate_output + F.linear(
                    activated[:, :, start:stop],
                    self.dense_4h_to_h.weight[:, start:stop],
                )
        else:
            intermediate_output = self.dense_4h_to_h(activated)

        return dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
|
|
|
|
|
|
|
|
class BloomBlock(GradientCheckpointingLayer):
    """
    One BLOOM transformer layer: layer norm -> self-attention (residual added inside the attention module)
    -> layer norm -> MLP (residual added inside the MLP module).
    """

    def __init__(self, config: BloomConfig, layer_idx: Optional[int] = None):
        super().__init__()
        hidden_size = config.hidden_size
        norm_eps = config.layer_norm_epsilon

        self.input_layernorm = LayerNorm(hidden_size, eps=norm_eps)
        self.num_heads = config.n_head
        self.self_attention = BloomAttention(config, layer_idx)
        self.post_attention_layernorm = LayerNorm(hidden_size, eps=norm_eps)

        self.mlp = BloomMLP(config)

        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
        self.hidden_dropout = config.hidden_dropout

    def forward(
        self,
        hidden_states: torch.Tensor,
        alibi: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_past: Optional[Cache] = None,
        head_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        output_attentions: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ):
        """
        Run the layer on `hidden_states`; every other argument is forwarded to the self-attention module.

        Returns:
            `(output, attn_weights)`: the MLP output and the attention probabilities returned by the
            attention module.
        """
        post_ln_residual = self.apply_residual_connection_post_layernorm

        # Attention sub-layer. The residual is either the normalized or the raw input, depending on config.
        normed = self.input_layernorm(hidden_states)
        residual = normed if post_ln_residual else hidden_states
        attention_output, attn_weights = self.self_attention(
            normed,
            residual,
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )

        # MLP sub-layer, with the same residual selection rule applied to the attention output.
        normed = self.post_attention_layernorm(attention_output)
        residual = normed if post_ln_residual else attention_output
        output = self.mlp(normed, residual)

        return output, attn_weights
|
|
|
|
|
|
|
|
@auto_docstring
class BloomPreTrainedModel(PreTrainedModel):
    # Shared base class of the BLOOM models in this file: it carries the class-level settings below and the
    # per-module weight initialisation.
    config: BloomConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BloomBlock"]
    _skip_keys_device_placement = "past_key_values"

    _can_compile_fullgraph = True

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module: nn.Module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Linear and embedding weights share the same normal initialisation.
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                # Embedding only: zero the row of the padding index.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, LayerNorm):
            # Zero bias and unit weight.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
|
|
|
|
|
|
|
|
@auto_docstring
class BloomModel(BloomPreTrainedModel):
    # Bare BLOOM transformer: word embeddings + embedding layer norm, a stack of `BloomBlock`s and a final
    # layer norm. It also builds the ALiBi bias and the causal attention mask consumed by the blocks.
    def __init__(self, config: BloomConfig):
        super().__init__(config)

        self.embed_dim = config.hidden_size
        self.num_heads = config.n_head

        # Token embedding followed by a layer norm applied directly to the embeddings.
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
        self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks; each one receives its index as `layer_idx` (used for cache updates).
        self.h = nn.ModuleList([BloomBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])

        # Final layer norm applied to the output of the last block.
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        self.gradient_checkpointing = False

        # `post_init()` comes from the base class; it is called last, once all submodules are registered.
        self.post_init()

    def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
        """Thin method wrapper around the module-level `build_alibi_tensor` function."""
        return build_alibi_tensor(attention_mask, num_heads, dtype)

    def get_input_embeddings(self):
        """Return the word embedding module."""
        return self.word_embeddings

    def set_input_embeddings(self, new_embeddings: torch.Tensor):
        """Replace the word embedding module with `new_embeddings`."""
        self.word_embeddings = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor], ...]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **deprecated_arguments,
    ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        """
        # `position_ids` is the only tolerated extra keyword: it is dropped with a warning. The `False`
        # sentinel lets an explicit `position_ids=None` still trigger the warning.
        if deprecated_arguments.pop("position_ids", False) is not False:
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        # Unset flags fall back to the values stored on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # The XOR is true when both or neither of `input_ids` / `inputs_embeds` were given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # Create a fresh dynamic cache when caching is requested but no cache was passed in.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        batch_size, seq_length, _ = inputs_embeds.shape
        past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        seq_length_with_past = seq_length + past_length
        if cache_position is None:
            # Default positions continue right after the tokens already held by the cache.
            cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device)

        # `get_head_mask` is inherited; its result is indexed per layer in the loop below.
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
        hidden_states = self.word_embeddings_layernorm(inputs_embeds)

        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # Without a user mask, attend to every position (past + current).
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
        else:
            attention_mask = attention_mask.to(hidden_states.device)

        # The alibi bias is built from the token-wise mask before `_update_causal_mask` turns it into
        # the mask that is handed to the blocks.
        alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        for i, block in enumerate(self.h):
            # Hidden states are recorded *before* each block; the final one is appended after `ln_f`.
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = block(
                hidden_states,
                layer_past=past_key_values,
                attention_mask=causal_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
                output_attentions=output_attentions,
                alibi=alibi,
                cache_position=cache_position,
            )

            hidden_states = outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[1],)

        # Final layer norm.
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Tuple output keeps only the entries that are not None, in this fixed order.
        if not return_dict:
            return tuple(
                v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """
        Turn the user-facing attention mask into the mask handed to the blocks, depending on
        `config._attn_implementation`.

        Returns `None`, the (possibly converted) input mask, or a 4D additive causal mask built by
        `_prepare_4d_causal_attention_mask_with_cache_position`.
        """
        # flash_attention_2: forward the mask only if it actually masks something (contains a 0).
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        # flex_attention: tensor masks are converted to a block mask; anything else is passed through.
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False

        # SDPA shortcut: return no mask at all when the converter helper says the explicit causal mask can be
        # ignored. Skipped for compilable caches and when attention weights are requested.
        if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        # Key/value length of the mask: the cache's maximum shape for compilable caches, otherwise the mask
        # width (or past + current + 1 when the mask is not a tensor).
        if using_compilable_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention_mask` is 2D, the 4D causal mask is generated here.
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # NOTE(review): presumably this un-masks rows that are fully masked so SDPA does not hit
            # degenerate rows on these devices — confirm against `AttentionMaskConverter._unmask_unattended`.
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.

        Returns:
            The 4D mask: 0 where attention is allowed and `torch.finfo(dtype).min` where it is blocked (when
            built here), or the input mask unchanged if it was already 4D.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # A 4D mask is assumed to be in final form already and is returned as-is.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            # Start fully blocked: every entry holds the most negative value of `dtype`.
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                # Keep the blocking value strictly above the diagonal only.
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Multiply by a boolean grid: key positions greater than the query's cache position keep their
            # value, every other entry becomes 0 (attendable).
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                # `expand` returns a view; clone before writing into it.
                causal_mask = causal_mask.clone()
                mask_length = attention_mask.shape[-1]
                # An entry sums to 0 only where the causal mask allows attention (0) and the 2D mask is 0
                # (assumes a 0/1 mask — TODO confirm); those padded positions get blocked as well.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
|
|
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """
)
class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
    # `lm_head.weight` is declared as a tied weight key.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: BloomConfig):
        super().__init__(config)
        self.transformer = BloomModel(config)
        # Bias-free projection from hidden states to vocabulary logits.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # `post_init()` comes from the base class; it is called last, once all submodules are registered.
        self.post_init()

    def set_output_embeddings(self, new_embeddings: torch.Tensor):
        """Replace the language modeling head with `new_embeddings`."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        """
        Assemble the keyword arguments for one generation step.

        When a cache is present, `input_ids` / `inputs_embeds` are trimmed to the positions described by
        `cache_position`; with a `StaticCache` the 2D attention mask is zero-padded to the cache's maximum
        length. Returns a dict holding `input_ids`, `inputs_embeds`, `cache_position`, `past_key_values`,
        `use_cache`, `attention_mask`, plus any extra `kwargs` not already present.
        """
        # Keep only the unprocessed tokens when a cache is in use.
        if past_key_values is not None:
            # Embeddings were given and `input_ids` is empty along the sequence axis: trim the embeddings.
            if inputs_embeds is not None and input_ids.shape[1] == 0:
                inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
            elif (
                inputs_embeds is not None
                or cache_position[-1] >= input_ids.shape[1]
            ):
                # Embeddings were given, or the last cache position lies beyond `input_ids`: keep the
                # trailing `len(cache_position)` ids.
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:
                # Otherwise, if the lengths differ, pick exactly the ids at `cache_position`.
                input_ids = input_ids[:, cache_position]

        # `inputs_embeds` is used only when it has exactly one entry per cache position.
        if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The ids are cloned into contiguous memory format.
            # NOTE(review): presumably to give the tensor a stable layout across decoding steps — confirm.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        # The mask stays 2D here (`BloomModel.forward` builds the alibi tensor from a token-wise mask);
        # with a static cache it is right-padded with zeros up to the cache's maximum length.
        if isinstance(past_key_values, StaticCache) and attention_mask is not None:
            target_length = past_key_values.get_max_cache_shape()
            batch_size, seq_length = attention_mask.shape
            diff = target_length - seq_length

            new_attn_mask = torch.zeros(batch_size, diff, device=attention_mask.device, dtype=attention_mask.dtype)
            attention_mask = torch.cat(
                [attention_mask, new_attn_mask],
                dim=-1,
            )

        model_inputs.update(
            {
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )

        # Forward every remaining keyword argument that was not set above.
        for key, value in kwargs.items():
            if key not in model_inputs:
                model_inputs[key] = value

        return model_inputs

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor], ...]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **deprecated_arguments,
    ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        # `num_items_in_batch` arrives through the catch-all kwargs, so it is popped before the
        # unexpected-argument check below.
        num_items_in_batch = deprecated_arguments.pop("num_items_in_batch", None)
        # `position_ids` is tolerated but ignored; the `False` sentinel lets an explicit `None` still warn.
        if deprecated_arguments.pop("position_ids", False) is not False:
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        # Index 0 is the last hidden state for both the tuple and the output-object return forms.
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Move labels to the logits' device before computing the loss.
            labels = labels.to(lm_logits.device)

            # `loss_function` is not defined in this class (presumably inherited — confirm).
            loss = self.loss_function(
                lm_logits,
                labels,
                vocab_size=self.config.vocab_size,
                num_items_in_batch=num_items_in_batch,
            )

        if not return_dict:
            # Tuple form: logits replace the hidden state; the loss is prepended when it was computed.
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
|
|
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    The Bloom Model transformer with a sequence classification head on top (linear layer).

    [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """
)
class BloomForSequenceClassification(BloomPreTrainedModel):
    def __init__(self, config: BloomConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = BloomModel(config)
        # Bias-free projection from every token's hidden state to `num_labels` scores.
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor], ...]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        # `False` is the "not provided" sentinel so that an explicit `position_ids=None` is still
        # reported as use of the deprecated argument.
        if deprecated_arguments.pop("position_ids", False) is not False:
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        # Scores are computed for every position; a single position per row is selected below.
        logits = self.score(hidden_states)

        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
        pad_token_id = self.config.pad_token_id

        if pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # Works for both left and right padding: padded positions are zeroed out, so the argmax
            # of (position index * non-pad mask) is the rightmost non-padding position of each row.
            non_pad_mask = (input_ids != pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            # Padding cannot be detected from embeddings, so fall back to the last position.
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            # Move labels to the logits' device so the loss also works when the model is spread
            # across several devices (same handling as the other heads in this file).
            labels = labels.to(pooled_logits.device)

            if self.config.problem_type is None:
                # Infer the problem type once from the label count / dtype and remember it on the config.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
|
|
|
|
|
|
|
|
@auto_docstring
class BloomForTokenClassification(BloomPreTrainedModel):
    def __init__(self, config: BloomConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.transformer = BloomModel(config)

        # Dropout rate precedence: `classifier_dropout`, then `hidden_dropout`, then 0.1.
        # A missing attribute and an attribute set to `None` are treated the same way.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = getattr(config, "hidden_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor, torch.Tensor], ...]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss (Cross-Entropy over every token). Indices should be
            in `[0, ..., config.num_labels - 1]`; positions labelled `-100` (the default `ignore_index` of
            `CrossEntropyLoss`) do not contribute to the loss.
        """
        # `False` is the "not provided" sentinel so that an explicit `position_ids=None` is still
        # reported as use of the deprecated argument.
        if deprecated_arguments.pop("position_ids", False) is not False:
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )
        if len(deprecated_arguments) > 0:
            raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            batch_size, seq_length = labels.shape
            num_tokens = batch_size * seq_length
            # Flatten (batch, seq) into one axis so each token is an independent classification example.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(num_tokens, self.num_labels), labels.view(num_tokens))

        if not return_dict:
            # NOTE(review): `[2:]` drops the second element of the backbone tuple — presumably the
            # cache, which `TokenClassifierOutput` has no field for. Confirm the tuple layout when
            # caching is disabled.
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
|
|
|
|
|
|
|
|
@auto_docstring
class BloomForQuestionAnswering(BloomPreTrainedModel):
    def __init__(self, config: BloomConfig):
        super().__init__(config)
        self.transformer = BloomModel(config)
        # Two scores per token: answer-span start and answer-span end.
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, QuestionAnsweringModelOutput]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values.get_seq_length()`
            (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        """
        # `position_ids` is kept in the signature for backward compatibility only. It is no longer
        # forwarded to the backbone: the other heads in this file treat any value other than the
        # `False` sentinel (including `None`) as "explicitly passed", so forwarding the default
        # `None` would presumably emit the deprecation warning on every call (confirm against
        # `BloomModel.forward`). Warn here only when the caller really supplied a value.
        if position_ids is not None:
            warnings.warn(
                "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
                " passing `position_ids`.",
                FutureWarning,
            )

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Drop a trailing singleton dimension so the targets are one index per example.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # Keep targets on the logits' device so the loss also works when the model is spread
            # across several devices.
            start_positions = start_positions.to(start_logits.device)
            end_positions = end_positions.to(end_logits.device)

            # Positions outside the sequence are clamped to `seq_len`, which is then used as the
            # loss's `ignore_index`, so those examples do not contribute to the loss.
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
|
|
|
|
# Public names of this module, i.e. what a star-import of it exposes.
__all__ = [
    "BloomForCausalLM",
    "BloomModel",
    "BloomPreTrainedModel",
    "BloomForSequenceClassification",
    "BloomForTokenClassification",
    "BloomForQuestionAnswering",
]
|
|
|