"""PyTorch AltCLIP model.""" |
|
|
|
|
|
import math |
|
|
from dataclasses import dataclass |
|
|
from typing import Any, Callable, Optional, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from ...activations import ACT2FN |
|
|
from ...modeling_layers import GradientCheckpointingLayer |
|
|
from ...modeling_outputs import ( |
|
|
BaseModelOutput, |
|
|
BaseModelOutputWithPooling, |
|
|
|
|
BaseModelOutputWithPoolingAndProjection, |
|
|
) |
|
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel |
|
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer |
|
|
from ...utils import ModelOutput, auto_docstring, can_return_tuple, filter_out_non_signature_kwargs, logging, torch_int |
|
|
from .configuration_altclip import AltCLIPConfig, AltCLIPTextConfig, AltCLIPVisionConfig |
|
|
|
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: |
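    # InfoNCE-style loss: the i-th item in the batch is the positive match for the i-th query,
    # so the class targets are simply 0..N-1.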
|
|
return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) |
|
|
|
|
|
|
|
|
def clip_loss(similarity: torch.Tensor) -> torch.Tensor: |
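    # Symmetric contrastive loss: average the text-to-image and image-to-text cross-entropies.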
|
|
caption_loss = contrastive_loss(similarity) |
|
|
image_loss = contrastive_loss(similarity.t()) |
|
|
return (caption_loss + image_loss) / 2.0 |
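# Illustrative only (a rough sketch, not executed here): for a batch of matched text/image pairs,
# the similarity matrix fed to `clip_loss` is typically built as
#     similarity = logit_scale * text_embeds @ image_embeds.t()
#     loss = clip_loss(similarity)
# where `text_embeds` and `image_embeds` are L2-normalized projections, as in `AltCLIPModel.forward`.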
|
|
|
|
|
|
|
|
@dataclass |
|
|
@auto_docstring |
|
|
|
|
|
class AltCLIPOutput(ModelOutput): |
|
|
r""" |
|
|
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): |
|
|
Contrastive loss for image-text similarity. |
|
|
logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): |
|
|
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text |
|
|
similarity scores. |
|
|
logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): |
|
|
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image |
|
|
similarity scores. |
|
|
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
|
The text embeddings obtained by applying the projection layer to the pooled output of [`AltCLIPTextModel`]. |
|
|
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
|
|
The image embeddings obtained by applying the projection layer to the pooled output of [`AltCLIPVisionModel`]. |
|
|
text_model_output (`BaseModelOutputWithPooling`): |
|
|
The output of the [`AltCLIPTextModel`]. |
|
|
vision_model_output (`BaseModelOutputWithPooling`): |
|
|
The output of the [`AltCLIPVisionModel`]. |
|
|
""" |
|
|
|
|
|
loss: Optional[torch.FloatTensor] = None |
|
|
logits_per_image: Optional[torch.FloatTensor] = None |
|
|
logits_per_text: Optional[torch.FloatTensor] = None |
|
|
text_embeds: Optional[torch.FloatTensor] = None |
|
|
image_embeds: Optional[torch.FloatTensor] = None |
|
|
    text_model_output: Optional[BaseModelOutputWithPooling] = None
    vision_model_output: Optional[BaseModelOutputWithPooling] = None
|
|
|
|
|
def to_tuple(self) -> tuple[Any]: |
|
|
return tuple( |
|
|
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() |
|
|
for k in self.keys() |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
class AltRobertaEmbeddings(nn.Module): |
|
|
""" |
|
|
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) |
|
|
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) |
|
|
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) |
|
|
|
|
|
|
|
|
|
|
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
|
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
|
|
|
|
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") |
|
|
self.register_buffer( |
|
|
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False |
|
|
) |
|
|
self.register_buffer( |
|
|
"token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False |
|
|
) |
|
|
|
|
|
|
|
|
self.padding_idx = config.pad_token_id |
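        # RoBERTa-style positions: re-create the position embedding with `padding_idx` so that padding
        # positions map to a dedicated embedding and real positions start at `padding_idx + 1`.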
|
|
self.position_embeddings = nn.Embedding( |
|
|
config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 |
|
|
): |
|
|
if position_ids is None: |
|
|
if input_ids is not None: |
|
|
|
|
|
position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) |
|
|
else: |
|
|
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) |
|
|
|
|
|
if input_ids is not None: |
|
|
input_shape = input_ids.size() |
|
|
else: |
|
|
input_shape = inputs_embeds.size()[:-1] |
|
|
|
|
|
seq_length = input_shape[1] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if token_type_ids is None: |
|
|
if hasattr(self, "token_type_ids"): |
|
|
buffered_token_type_ids = self.token_type_ids[:, :seq_length] |
|
|
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) |
|
|
token_type_ids = buffered_token_type_ids_expanded |
|
|
else: |
|
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) |
|
|
|
|
|
if inputs_embeds is None: |
|
|
inputs_embeds = self.word_embeddings(input_ids) |
|
|
token_type_embeddings = self.token_type_embeddings(token_type_ids) |
|
|
|
|
|
embeddings = inputs_embeds + token_type_embeddings |
|
|
if self.position_embedding_type == "absolute": |
|
|
position_embeddings = self.position_embeddings(position_ids) |
|
|
embeddings += position_embeddings |
|
|
embeddings = self.LayerNorm(embeddings) |
|
|
embeddings = self.dropout(embeddings) |
|
|
return embeddings |
|
|
|
|
|
def create_position_ids_from_inputs_embeds(self, inputs_embeds): |
|
|
""" |
|
|
        We are provided embeddings directly. We cannot infer which tokens are padded, so we just generate sequential position ids.
|
|
|
|
|
Args: |
|
|
inputs_embeds: torch.Tensor |
|
|
|
|
|
Returns: torch.Tensor |
|
|
""" |
|
|
input_shape = inputs_embeds.size()[:-1] |
|
|
sequence_length = input_shape[1] |
|
|
|
|
|
position_ids = torch.arange( |
|
|
self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device |
|
|
) |
|
|
return position_ids.unsqueeze(0).expand(input_shape) |
|
|
|
|
|
|
|
|
class AltRobertaSelfAttention(nn.Module): |
|
|
def __init__(self, config, position_embedding_type=None): |
|
|
super().__init__() |
|
|
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): |
|
|
raise ValueError( |
|
|
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " |
|
|
f"heads ({config.num_attention_heads})" |
|
|
) |
|
|
|
|
|
self.num_attention_heads = config.num_attention_heads |
|
|
self.attention_head_size = int(config.hidden_size / config.num_attention_heads) |
|
|
self.all_head_size = self.num_attention_heads * self.attention_head_size |
|
|
|
|
|
self.query = nn.Linear(config.hidden_size, self.all_head_size) |
|
|
self.key = nn.Linear(config.hidden_size, self.all_head_size) |
|
|
self.value = nn.Linear(config.hidden_size, self.all_head_size) |
|
|
|
|
|
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) |
|
|
self.position_embedding_type = position_embedding_type or getattr( |
|
|
config, "position_embedding_type", "absolute" |
|
|
) |
|
|
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": |
|
|
self.max_position_embeddings = config.max_position_embeddings |
|
|
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
|
head_mask: Optional[torch.FloatTensor] = None, |
|
|
output_attentions: Optional[bool] = False, |
|
|
) -> tuple[torch.Tensor]: |
|
|
input_shape = hidden_states.shape[:-1] |
|
|
hidden_shape = (*input_shape, -1, self.attention_head_size) |
|
|
|
|
|
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2) |
|
|
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2) |
|
|
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2) |
|
|
|
|
|
|
|
|
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) |
|
|
|
|
|
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": |
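            # Relative position embeddings: score each query/key pair against an embedding of their
            # (clipped) signed distance and add the result to the content-based attention scores.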
|
|
query_length, key_length = query_layer.shape[2], key_layer.shape[2] |
|
|
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) |
|
|
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) |
|
|
distance = position_ids_l - position_ids_r |
|
|
|
|
|
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) |
|
|
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) |
|
|
|
|
|
if self.position_embedding_type == "relative_key": |
|
|
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) |
|
|
attention_scores = attention_scores + relative_position_scores |
|
|
elif self.position_embedding_type == "relative_key_query": |
|
|
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) |
|
|
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) |
|
|
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key |
|
|
|
|
|
attention_scores = attention_scores / math.sqrt(self.attention_head_size) |
|
|
if attention_mask is not None: |
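            # Apply the additive attention mask (precomputed for all layers in the model's forward pass).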
|
|
|
|
|
attention_scores = attention_scores + attention_mask |
|
|
|
|
|
|
|
|
attention_probs = nn.functional.softmax(attention_scores, dim=-1) |
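        # Dropping out attention probabilities effectively drops entire tokens to attend to, which may
        # seem unusual but follows the original Transformer paper.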
|
|
|
|
|
|
|
|
|
|
|
attention_probs = self.dropout(attention_probs) |
|
|
|
|
|
|
|
|
if head_mask is not None: |
|
|
attention_probs = attention_probs * head_mask |
|
|
|
|
|
context_layer = torch.matmul(attention_probs, value_layer) |
|
|
|
|
|
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() |
|
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) |
|
|
context_layer = context_layer.view(new_context_layer_shape) |
|
|
|
|
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) |
|
|
|
|
|
return outputs |
|
|
|
|
|
|
|
|
|
|
|
class AltRobertaSelfOutput(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size) |
|
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
|
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: |
|
|
hidden_states = self.dense(hidden_states) |
|
|
hidden_states = self.dropout(hidden_states) |
|
|
hidden_states = self.LayerNorm(hidden_states + input_tensor) |
|
|
return hidden_states |
|
|
|
|
|
|
|
|
ALT_ROBERTA_SELF_ATTENTION_CLASSES = { |
|
|
"eager": AltRobertaSelfAttention, |
|
|
} |
|
|
|
|
|
|
|
|
class AltRobertaAttention(nn.Module): |
|
|
def __init__(self, config, position_embedding_type=None): |
|
|
super().__init__() |
|
|
self.self = ALT_ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( |
|
|
config, position_embedding_type=position_embedding_type |
|
|
) |
|
|
self.output = AltRobertaSelfOutput(config) |
|
|
self.pruned_heads = set() |
|
|
|
|
|
def prune_heads(self, heads): |
|
|
if len(heads) == 0: |
|
|
return |
|
|
heads, index = find_pruneable_heads_and_indices( |
|
|
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads |
|
|
) |
|
|
|
|
|
|
|
|
self.self.query = prune_linear_layer(self.self.query, index) |
|
|
self.self.key = prune_linear_layer(self.self.key, index) |
|
|
self.self.value = prune_linear_layer(self.self.value, index) |
|
|
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) |
|
|
|
|
|
|
|
|
self.self.num_attention_heads = self.self.num_attention_heads - len(heads) |
|
|
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads |
|
|
self.pruned_heads = self.pruned_heads.union(heads) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
|
head_mask: Optional[torch.FloatTensor] = None, |
|
|
output_attentions: Optional[bool] = False, |
|
|
) -> tuple[torch.Tensor]: |
|
|
self_outputs = self.self( |
|
|
hidden_states, |
|
|
attention_mask=attention_mask, |
|
|
head_mask=head_mask, |
|
|
output_attentions=output_attentions, |
|
|
) |
|
|
attention_output = self.output(self_outputs[0], hidden_states) |
|
|
outputs = (attention_output,) + self_outputs[1:] |
|
|
return outputs |
|
|
|
|
|
|
|
|
|
|
|
class AltRobertaIntermediate(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.dense = nn.Linear(config.hidden_size, config.intermediate_size) |
|
|
if isinstance(config.hidden_act, str): |
|
|
self.intermediate_act_fn = ACT2FN[config.hidden_act] |
|
|
else: |
|
|
self.intermediate_act_fn = config.hidden_act |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
|
|
hidden_states = self.dense(hidden_states) |
|
|
hidden_states = self.intermediate_act_fn(hidden_states) |
|
|
return hidden_states |
|
|
|
|
|
|
|
|
|
|
|
class AltRobertaOutput(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.dense = nn.Linear(config.intermediate_size, config.hidden_size) |
|
|
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
|
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: |
|
|
hidden_states = self.dense(hidden_states) |
|
|
hidden_states = self.dropout(hidden_states) |
|
|
hidden_states = self.LayerNorm(hidden_states + input_tensor) |
|
|
return hidden_states |
|
|
|
|
|
|
|
|
|
|
|
class AltRobertaLayer(GradientCheckpointingLayer): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.chunk_size_feed_forward = config.chunk_size_feed_forward |
|
|
self.seq_len_dim = 1 |
|
|
self.attention = AltRobertaAttention(config) |
|
|
self.intermediate = AltRobertaIntermediate(config) |
|
|
self.output = AltRobertaOutput(config) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
|
head_mask: Optional[torch.FloatTensor] = None, |
|
|
output_attentions: Optional[bool] = False, |
|
|
**kwargs, |
|
|
) -> tuple[torch.Tensor]: |
|
|
self_attention_outputs = self.attention( |
|
|
hidden_states, |
|
|
attention_mask=attention_mask, |
|
|
head_mask=head_mask, |
|
|
output_attentions=output_attentions, |
|
|
**kwargs, |
|
|
) |
|
|
attention_output = self_attention_outputs[0] |
|
|
|
|
|
outputs = self_attention_outputs[1:] |
|
|
layer_output = apply_chunking_to_forward( |
|
|
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output |
|
|
) |
|
|
outputs = (layer_output,) + outputs |
|
|
|
|
|
return outputs |
|
|
|
|
|
def feed_forward_chunk(self, attention_output): |
|
|
intermediate_output = self.intermediate(attention_output) |
|
|
layer_output = self.output(intermediate_output, attention_output) |
|
|
return layer_output |
|
|
|
|
|
|
|
|
|
|
|
class AltRobertaEncoder(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
        self.layer = nn.ModuleList([AltRobertaLayer(config) for _ in range(config.num_hidden_layers)])
|
|
self.gradient_checkpointing = False |
|
|
|
|
|
@can_return_tuple |
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
|
head_mask: Optional[torch.FloatTensor] = None, |
|
|
output_attentions: Optional[bool] = False, |
|
|
output_hidden_states: Optional[bool] = False, |
|
|
return_dict: Optional[bool] = True, |
|
|
**kwargs, |
|
|
) -> Union[tuple[torch.Tensor], BaseModelOutput]: |
|
|
all_hidden_states = () if output_hidden_states else None |
|
|
all_self_attentions = () if output_attentions else None |
|
|
|
|
|
for i, layer_module in enumerate(self.layer): |
|
|
if output_hidden_states: |
|
|
all_hidden_states = all_hidden_states + (hidden_states,) |
|
|
|
|
|
layer_head_mask = head_mask[i] if head_mask is not None else None |
|
|
|
|
|
layer_outputs = layer_module( |
|
|
hidden_states=hidden_states, |
|
|
attention_mask=attention_mask, |
|
|
head_mask=layer_head_mask, |
|
|
output_attentions=output_attentions, |
|
|
**kwargs, |
|
|
) |
|
|
|
|
|
hidden_states = layer_outputs[0] |
|
|
if output_attentions: |
|
|
all_self_attentions = all_self_attentions + (layer_outputs[1],) |
|
|
|
|
|
if output_hidden_states: |
|
|
all_hidden_states = all_hidden_states + (hidden_states,) |
|
|
|
|
|
return BaseModelOutput( |
|
|
last_hidden_state=hidden_states, |
|
|
hidden_states=all_hidden_states, |
|
|
attentions=all_self_attentions, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
class AltRobertaPooler(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size) |
|
|
self.activation = nn.Tanh() |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
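        # "Pool" the model by simply taking the hidden state corresponding to the first token.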
|
|
|
|
|
|
|
|
first_token_tensor = hidden_states[:, 0] |
|
|
pooled_output = self.dense(first_token_tensor) |
|
|
pooled_output = self.activation(pooled_output) |
|
|
return pooled_output |
|
|
|
|
|
|
|
|
|
|
|
def eager_attention_forward( |
|
|
module: nn.Module, |
|
|
query: torch.Tensor, |
|
|
key: torch.Tensor, |
|
|
value: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor], |
|
|
scaling: float, |
|
|
dropout: float = 0.0, |
|
|
**kwargs, |
|
|
): |
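    # Plain (eager) scaled dot-product attention with an additive mask; softmax is computed in
    # float32 for numerical stability and cast back to the query dtype.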
|
|
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling |
|
|
if attention_mask is not None: |
|
|
attn_weights = attn_weights + attention_mask |
|
|
|
|
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) |
|
|
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) |
|
|
|
|
|
attn_output = torch.matmul(attn_weights, value) |
|
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
|
|
|
|
return attn_output, attn_weights |
|
|
|
|
|
|
|
|
class AltCLIPAttention(nn.Module): |
|
|
"""Multi-headed attention from 'Attention Is All You Need' paper""" |
|
|
|
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.embed_dim = config.hidden_size |
|
|
self.num_heads = config.num_attention_heads |
|
|
self.head_dim = self.embed_dim // self.num_heads |
|
|
if self.head_dim * self.num_heads != self.embed_dim: |
|
|
raise ValueError( |
|
|
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" |
|
|
f" {self.num_heads})." |
|
|
) |
|
|
self.scale = self.head_dim**-0.5 |
|
|
self.dropout = config.attention_dropout |
|
|
self.is_causal = False |
|
|
|
|
|
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) |
|
|
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) |
|
|
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) |
|
|
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
causal_attention_mask: Optional[torch.Tensor] = None, |
|
|
output_attentions: Optional[bool] = False, |
|
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: |
|
|
"""Input shape: Batch x Time x Channel""" |
|
|
|
|
|
batch_size, seq_length, embed_dim = hidden_states.shape |
|
|
|
|
|
queries = self.q_proj(hidden_states) |
|
|
keys = self.k_proj(hidden_states) |
|
|
values = self.v_proj(hidden_states) |
|
|
|
|
|
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) |
|
|
|
|
|
|
|
|
if self.config._attn_implementation != "flash_attention_2": |
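            # For non-FlashAttention backends, merge the padding mask and the causal mask into a single
            # additive mask; FlashAttention instead receives `is_causal` and no explicit mask.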
|
|
if attention_mask is not None and causal_attention_mask is not None: |
|
|
attention_mask = attention_mask + causal_attention_mask |
|
|
elif causal_attention_mask is not None: |
|
|
attention_mask = causal_attention_mask |
|
|
else: |
|
|
self.is_causal = causal_attention_mask is not None |
|
|
|
|
|
attention_interface: Callable = eager_attention_forward |
|
|
if self.config._attn_implementation != "eager": |
|
|
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] |
|
|
|
|
|
attn_output, attn_weights = attention_interface( |
|
|
self, |
|
|
queries, |
|
|
keys, |
|
|
values, |
|
|
attention_mask, |
|
|
is_causal=self.is_causal, |
|
|
scaling=self.scale, |
|
|
dropout=0.0 if not self.training else self.dropout, |
|
|
) |
|
|
|
|
|
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() |
|
|
attn_output = self.out_proj(attn_output) |
|
|
if not output_attentions: |
|
|
attn_weights = None |
|
|
return attn_output, attn_weights |
|
|
|
|
|
|
|
|
|
|
|
class AltCLIPMLP(nn.Module): |
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.activation_fn = ACT2FN[config.hidden_act] |
|
|
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) |
|
|
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
|
|
hidden_states = self.fc1(hidden_states) |
|
|
hidden_states = self.activation_fn(hidden_states) |
|
|
hidden_states = self.fc2(hidden_states) |
|
|
return hidden_states |
|
|
|
|
|
|
|
|
class AltCLIPEncoderLayer(GradientCheckpointingLayer): |
|
|
def __init__(self, config: AltCLIPConfig): |
|
|
super().__init__() |
|
|
self.embed_dim = config.hidden_size |
|
|
self.self_attn = AltCLIPAttention(config) |
|
|
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) |
|
|
self.mlp = AltCLIPMLP(config) |
|
|
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
attention_mask: torch.Tensor, |
|
|
causal_attention_mask: torch.Tensor, |
|
|
output_attentions: Optional[bool] = False, |
|
|
) -> tuple[torch.FloatTensor]: |
|
|
""" |
|
|
Args: |
|
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` |
|
|
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            causal_attention_mask (`torch.FloatTensor`): causal attention mask of size
                `(batch, 1, tgt_len, src_len)` where masked positions are indicated by very large negative values.
|
|
output_attentions (`bool`, *optional*): |
|
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
|
returned tensors for more detail. |
|
|
""" |
|
|
residual = hidden_states |
|
|
|
|
|
hidden_states = self.layer_norm1(hidden_states) |
|
|
hidden_states, attn_weights = self.self_attn( |
|
|
hidden_states=hidden_states, |
|
|
attention_mask=attention_mask, |
|
|
causal_attention_mask=causal_attention_mask, |
|
|
output_attentions=output_attentions, |
|
|
) |
|
|
hidden_states = residual + hidden_states |
|
|
|
|
|
residual = hidden_states |
|
|
hidden_states = self.layer_norm2(hidden_states) |
|
|
hidden_states = self.mlp(hidden_states) |
|
|
hidden_states = residual + hidden_states |
|
|
|
|
|
outputs = (hidden_states,) |
|
|
|
|
|
if output_attentions: |
|
|
outputs += (attn_weights,) |
|
|
|
|
|
return outputs |
|
|
|
|
|
|
|
|
class AltCLIPEncoder(nn.Module): |
|
|
""" |
|
|
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a |
|
|
[`AltCLIPEncoderLayer`]. |
|
|
|
|
|
Args: |
|
|
config: AltCLIPConfig |
|
|
""" |
|
|
|
|
|
def __init__(self, config: AltCLIPConfig): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.layers = nn.ModuleList([AltCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) |
|
|
self.gradient_checkpointing = False |
|
|
|
|
|
@can_return_tuple |
|
|
def forward( |
|
|
self, |
|
|
inputs_embeds, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
causal_attention_mask: Optional[torch.Tensor] = None, |
|
|
output_attentions: Optional[bool] = None, |
|
|
output_hidden_states: Optional[bool] = None, |
|
|
return_dict: Optional[bool] = None, |
|
|
) -> Union[tuple, BaseModelOutput]: |
|
|
r""" |
|
|
Args: |
|
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): |
|
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. |
|
|
This is useful if you want more control over how to convert `input_ids` indices into associated vectors |
|
|
than the model's internal embedding lookup matrix. |
|
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
|
|
|
|
|
- 1 for tokens that are **not masked**, |
|
|
- 0 for tokens that are **masked**. |
|
|
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
|
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
|
Causal mask for the text model. Mask values selected in `[0, 1]`: |
|
|
|
|
|
- 1 for tokens that are **not masked**, |
|
|
- 0 for tokens that are **masked**. |
|
|
|
|
|
[What are attention masks?](../glossary#attention-mask) |
|
|
output_attentions (`bool`, *optional*): |
|
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
|
|
returned tensors for more detail. |
|
|
output_hidden_states (`bool`, *optional*): |
|
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
|
|
for more detail. |
|
|
return_dict (`bool`, *optional*): |
|
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
|
|
""" |
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
|
output_hidden_states = ( |
|
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
|
) |
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
encoder_states = () if output_hidden_states else None |
|
|
all_attentions = () if output_attentions else None |
|
|
|
|
|
hidden_states = inputs_embeds |
|
|
for idx, encoder_layer in enumerate(self.layers): |
|
|
if output_hidden_states: |
|
|
encoder_states = encoder_states + (hidden_states,) |
|
|
layer_outputs = encoder_layer( |
|
|
hidden_states, |
|
|
attention_mask, |
|
|
causal_attention_mask, |
|
|
output_attentions=output_attentions, |
|
|
) |
|
|
|
|
|
hidden_states = layer_outputs[0] |
|
|
|
|
|
if output_attentions: |
|
|
all_attentions = all_attentions + (layer_outputs[1],) |
|
|
|
|
|
if output_hidden_states: |
|
|
encoder_states = encoder_states + (hidden_states,) |
|
|
|
|
|
return BaseModelOutput( |
|
|
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
class AltCLIPVisionEmbeddings(nn.Module): |
|
|
def __init__(self, config: AltCLIPVisionConfig): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.embed_dim = config.hidden_size |
|
|
self.image_size = config.image_size |
|
|
self.patch_size = config.patch_size |
|
|
|
|
|
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) |
|
|
|
|
|
self.patch_embedding = nn.Conv2d( |
|
|
in_channels=config.num_channels, |
|
|
out_channels=self.embed_dim, |
|
|
kernel_size=self.patch_size, |
|
|
stride=self.patch_size, |
|
|
bias=False, |
|
|
) |
|
|
|
|
|
self.num_patches = (self.image_size // self.patch_size) ** 2 |
|
|
self.num_positions = self.num_patches + 1 |
|
|
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) |
|
|
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) |
|
|
|
|
|
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: |
|
|
""" |
|
|
        This method interpolates the pre-trained position encodings so the model can be used on higher-resolution
        images. It is also adapted to support torch.jit tracing.
|
|
|
|
|
Adapted from: |
|
|
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and |
|
|
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211 |
|
|
""" |
|
|
|
|
|
num_patches = embeddings.shape[1] - 1 |
|
|
position_embedding = self.position_embedding.weight.unsqueeze(0) |
|
|
num_positions = position_embedding.shape[1] - 1 |
|
|
|
|
|
|
|
|
if not torch.jit.is_tracing() and num_patches == num_positions and height == width: |
|
|
return self.position_embedding(self.position_ids) |
|
|
|
|
|
class_pos_embed = position_embedding[:, :1] |
|
|
patch_pos_embed = position_embedding[:, 1:] |
|
|
|
|
|
dim = embeddings.shape[-1] |
|
|
|
|
|
new_height = height // self.patch_size |
|
|
new_width = width // self.patch_size |
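        # Reshape the flat patch position embeddings into their original 2D grid, resize that grid with
        # bicubic interpolation to the new resolution, then flatten back into a sequence.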
|
|
|
|
|
sqrt_num_positions = torch_int(num_positions**0.5) |
|
|
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim) |
|
|
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) |
|
|
|
|
|
patch_pos_embed = nn.functional.interpolate( |
|
|
patch_pos_embed, |
|
|
size=(new_height, new_width), |
|
|
mode="bicubic", |
|
|
align_corners=False, |
|
|
) |
|
|
|
|
|
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) |
|
|
|
|
|
return torch.cat((class_pos_embed, patch_pos_embed), dim=1) |
|
|
|
|
|
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: |
|
|
batch_size, _, height, width = pixel_values.shape |
|
|
if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size): |
|
|
raise ValueError( |
|
|
f"Input image size ({height}*{width}) doesn't match model ({self.image_size}*{self.image_size})." |
|
|
) |
|
|
target_dtype = self.patch_embedding.weight.dtype |
|
|
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) |
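        # The conv output has shape (batch, embed_dim, grid_h, grid_w); flatten the spatial grid into a sequence of patches.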
|
|
patch_embeds = patch_embeds.flatten(2).transpose(1, 2) |
|
|
|
|
|
class_embeds = self.class_embedding.expand(batch_size, 1, -1) |
|
|
embeddings = torch.cat([class_embeds, patch_embeds], dim=1) |
|
|
if interpolate_pos_encoding: |
|
|
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) |
|
|
else: |
|
|
embeddings = embeddings + self.position_embedding(self.position_ids) |
|
|
return embeddings |
|
|
|
|
|
|
|
|
@auto_docstring |
|
|
class AltCLIPPreTrainedModel(PreTrainedModel): |
|
|
config: AltCLIPConfig |
|
|
base_model_prefix = "altclip" |
|
|
supports_gradient_checkpointing = True |
|
|
    _no_split_modules = []
|
|
|
|
|
def _init_weights(self, module): |
|
|
"""Initialize the weights""" |
|
|
factor = self.config.initializer_factor |
|
|
if isinstance(module, AltCLIPVisionEmbeddings): |
|
|
factor = self.config.initializer_factor |
|
|
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) |
|
|
nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) |
|
|
nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) |
|
|
elif isinstance(module, AltCLIPAttention): |
|
|
factor = self.config.initializer_factor |
|
|
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor |
|
|
out_proj_std = (module.embed_dim**-0.5) * factor |
|
|
nn.init.normal_(module.q_proj.weight, std=in_proj_std) |
|
|
nn.init.normal_(module.k_proj.weight, std=in_proj_std) |
|
|
nn.init.normal_(module.v_proj.weight, std=in_proj_std) |
|
|
nn.init.normal_(module.out_proj.weight, std=out_proj_std) |
|
|
elif isinstance(module, AltCLIPMLP): |
|
|
factor = self.config.initializer_factor |
|
|
in_proj_std = (module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor |
|
|
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor |
|
|
nn.init.normal_(module.fc1.weight, std=fc_std) |
|
|
nn.init.normal_(module.fc2.weight, std=in_proj_std) |
|
|
elif isinstance(module, AltCLIPModel): |
|
|
nn.init.normal_( |
|
|
module.text_projection.weight, |
|
|
std=module.text_embed_dim**-0.5 * self.config.initializer_factor, |
|
|
) |
|
|
module.text_projection._is_hf_initialized = True |
|
|
nn.init.normal_( |
|
|
module.visual_projection.weight, |
|
|
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, |
|
|
) |
|
|
module.visual_projection._is_hf_initialized = True |
|
|
elif isinstance(module, nn.LayerNorm): |
|
|
module.bias.data.zero_() |
|
|
module.weight.data.fill_(1.0) |
|
|
elif isinstance(module, nn.Linear): |
|
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor) |
|
|
if module.bias is not None: |
|
|
module.bias.data.zero_() |
|
|
elif isinstance(module, nn.Embedding): |
|
|
module.weight.data.normal_(mean=0.0, std=self.config.initializer_factor) |
|
|
if module.padding_idx is not None: |
|
|
module.weight.data[module.padding_idx].zero_() |
|
|
|
|
|
|
|
|
class AltCLIPVisionTransformer(nn.Module): |
|
|
def __init__(self, config: AltCLIPVisionConfig): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
embed_dim = config.hidden_size |
|
|
|
|
|
self.embeddings = AltCLIPVisionEmbeddings(config) |
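        # Note: "pre_layrnorm" (sic) keeps the misspelled attribute name used by the original CLIP
        # checkpoints so that pretrained weights map onto this module correctly.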
|
|
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) |
|
|
self.encoder = AltCLIPEncoder(config) |
|
|
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) |
|
|
|
|
|
@can_return_tuple |
|
|
@auto_docstring |
|
|
def forward( |
|
|
self, |
|
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
|
output_attentions: Optional[bool] = None, |
|
|
output_hidden_states: Optional[bool] = None, |
|
|
return_dict: Optional[bool] = None, |
|
|
interpolate_pos_encoding: Optional[bool] = False, |
|
|
) -> Union[tuple, BaseModelOutputWithPooling]: |
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
|
output_hidden_states = ( |
|
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
|
) |
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
if pixel_values is None: |
|
|
raise ValueError("You have to specify pixel_values") |
|
|
|
|
|
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) |
|
|
hidden_states = self.pre_layrnorm(hidden_states) |
|
|
|
|
|
encoder_outputs = self.encoder( |
|
|
inputs_embeds=hidden_states, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict=True, |
|
|
) |
|
|
|
|
|
last_hidden_state = encoder_outputs[0] |
|
|
pooled_output = last_hidden_state[:, 0, :] |
|
|
pooled_output = self.post_layernorm(pooled_output) |
|
|
|
|
|
return BaseModelOutputWithPooling( |
|
|
last_hidden_state=last_hidden_state, |
|
|
pooler_output=pooled_output, |
|
|
hidden_states=encoder_outputs.hidden_states, |
|
|
attentions=encoder_outputs.attentions, |
|
|
) |
|
|
|
|
|
|
|
|
class AltCLIPVisionModel(AltCLIPPreTrainedModel): |
|
|
config: AltCLIPVisionConfig |
|
|
main_input_name = "pixel_values" |
|
|
|
|
|
def __init__(self, config: AltCLIPVisionConfig): |
|
|
super().__init__(config) |
|
|
self.vision_model = AltCLIPVisionTransformer(config) |
|
|
|
|
|
self.post_init() |
|
|
|
|
|
def get_input_embeddings(self) -> nn.Module: |
|
|
return self.vision_model.embeddings.patch_embedding |
|
|
|
|
|
@auto_docstring |
|
|
def forward( |
|
|
self, |
|
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
|
output_attentions: Optional[bool] = None, |
|
|
output_hidden_states: Optional[bool] = None, |
|
|
interpolate_pos_encoding: bool = False, |
|
|
return_dict: Optional[bool] = None, |
|
|
) -> Union[tuple, BaseModelOutputWithPooling]: |
|
|
r""" |
|
|
Examples: |
|
|
|
|
|
```python |
|
|
>>> from PIL import Image |
|
|
>>> import requests |
|
|
>>> from transformers import AutoProcessor, AltCLIPVisionModel |
|
|
|
|
|
>>> model = AltCLIPVisionModel.from_pretrained("BAAI/AltCLIP") |
|
|
>>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP") |
|
|
|
|
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" |
|
|
>>> image = Image.open(requests.get(url, stream=True).raw) |
|
|
|
|
|
>>> inputs = processor(images=image, return_tensors="pt") |
|
|
|
|
|
>>> outputs = model(**inputs) |
|
|
>>> last_hidden_state = outputs.last_hidden_state |
|
|
>>> pooled_output = outputs.pooler_output # pooled CLS states |
|
|
```""" |
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
return self.vision_model( |
|
|
pixel_values=pixel_values, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
interpolate_pos_encoding=interpolate_pos_encoding, |
|
|
return_dict=return_dict, |
|
|
) |
|
|
|
|
|
|
|
|
@auto_docstring( |
|
|
custom_intro=""" |
|
|
The model behaves as an encoder following the architecture described in *Attention is |
|
|
all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz |
|
|
Kaiser and Illia Polosukhin. |
|
|
|
|
|
.. _*Attention is all you need*: https://huggingface.co/papers/1706.03762 |
|
|
""" |
|
|
) |
|
|
class AltRobertaModel(AltCLIPPreTrainedModel): |
|
|
config: AltCLIPTextConfig |
|
|
|
|
|
|
|
|
def __init__(self, config, add_pooling_layer=True): |
|
|
r""" |
|
|
add_pooling_layer (bool, *optional*, defaults to `True`): |
|
|
Whether to add a pooling layer |
|
|
""" |
|
|
super().__init__(config) |
|
|
self.config = config |
|
|
|
|
|
self.embeddings = AltRobertaEmbeddings(config) |
|
|
self.encoder = AltRobertaEncoder(config) |
|
|
|
|
|
self.pooler = AltRobertaPooler(config) if add_pooling_layer else None |
|
|
|
|
|
|
|
|
self.post_init() |
|
|
|
|
|
def get_input_embeddings(self): |
|
|
return self.embeddings.word_embeddings |
|
|
|
|
|
def set_input_embeddings(self, value): |
|
|
self.embeddings.word_embeddings = value |
|
|
|
|
|
def _prune_heads(self, heads_to_prune): |
|
|
""" |
|
|
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See the
        base class [`PreTrainedModel`].
|
|
""" |
|
|
for layer, heads in heads_to_prune.items(): |
|
|
self.encoder.layer[layer].attention.prune_heads(heads) |
|
|
|
|
|
@auto_docstring |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids: Optional[torch.Tensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
token_type_ids: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.Tensor] = None, |
|
|
head_mask: Optional[torch.Tensor] = None, |
|
|
inputs_embeds: Optional[torch.Tensor] = None, |
|
|
output_attentions: Optional[bool] = None, |
|
|
output_hidden_states: Optional[bool] = None, |
|
|
return_dict: Optional[bool] = None, |
|
|
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPooling]:
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
|
output_hidden_states = ( |
|
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
|
) |
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
if input_ids is not None and inputs_embeds is not None: |
|
|
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") |
|
|
elif input_ids is not None: |
|
|
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) |
|
|
input_shape = input_ids.size() |
|
|
elif inputs_embeds is not None: |
|
|
input_shape = inputs_embeds.size()[:-1] |
|
|
else: |
|
|
raise ValueError("You have to specify either input_ids or inputs_embeds") |
|
|
|
|
|
batch_size, seq_length = input_shape |
|
|
device = input_ids.device if input_ids is not None else inputs_embeds.device |
|
|
|
|
|
if attention_mask is None: |
|
|
            attention_mask = torch.ones((batch_size, seq_length), device=device)
|
|
|
|
|
if token_type_ids is None: |
|
|
if hasattr(self.embeddings, "token_type_ids"): |
|
|
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] |
|
|
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) |
|
|
token_type_ids = buffered_token_type_ids_expanded |
|
|
else: |
|
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) |
|
|
|
|
|
|
|
|
|
|
|
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) |
|
|
|
|
|
|
|
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) |
|
|
|
|
|
embedding_output = self.embeddings( |
|
|
input_ids=input_ids, |
|
|
position_ids=position_ids, |
|
|
token_type_ids=token_type_ids, |
|
|
inputs_embeds=inputs_embeds, |
|
|
) |
|
|
encoder_outputs = self.encoder( |
|
|
embedding_output, |
|
|
attention_mask=extended_attention_mask, |
|
|
head_mask=head_mask, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict=True, |
|
|
) |
|
|
sequence_output = encoder_outputs[0] |
|
|
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None |
|
|
|
|
|
return BaseModelOutputWithPooling( |
|
|
last_hidden_state=sequence_output, |
|
|
pooler_output=pooled_output, |
|
|
hidden_states=encoder_outputs.hidden_states, |
|
|
attentions=encoder_outputs.attentions, |
|
|
) |
|
|
|
|
|
|
|
|
class AltCLIPTextModel(AltCLIPPreTrainedModel): |
|
|
config: AltCLIPTextConfig |
|
|
|
|
|
def __init__(self, config): |
|
|
super().__init__(config) |
|
|
self.roberta = AltRobertaModel(config, add_pooling_layer=False) |
|
|
self.transformation = nn.Linear(config.hidden_size, config.project_dim) |
|
|
self.pre_LN = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
|
|
self.post_init() |
|
|
|
|
|
def get_input_embeddings(self) -> nn.Module: |
|
|
return self.roberta.embeddings.word_embeddings |
|
|
|
|
|
def set_input_embeddings(self, value: nn.Embedding) -> None: |
|
|
self.roberta.embeddings.word_embeddings = value |
|
|
|
|
|
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding: |
|
|
return super().resize_token_embeddings(new_num_tokens) |
|
|
|
|
|
@can_return_tuple |
|
|
@auto_docstring |
|
|
def forward( |
|
|
self, |
|
|
input_ids: Optional[torch.Tensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
token_type_ids: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.Tensor] = None, |
|
|
head_mask: Optional[torch.Tensor] = None, |
|
|
inputs_embeds: Optional[torch.Tensor] = None, |
|
|
output_attentions: Optional[bool] = None, |
|
|
return_dict: Optional[bool] = None, |
|
|
output_hidden_states: Optional[bool] = None, |
|
|
) -> Union[tuple, BaseModelOutputWithPoolingAndProjection]: |
|
|
r""" |
|
|
Examples: |
|
|
|
|
|
```python |
|
|
>>> from transformers import AutoProcessor, AltCLIPTextModel |
|
|
|
|
|
>>> model = AltCLIPTextModel.from_pretrained("BAAI/AltCLIP") |
|
|
>>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP") |
|
|
|
|
|
>>> texts = ["it's a cat", "it's a dog"] |
|
|
|
|
|
>>> inputs = processor(text=texts, padding=True, return_tensors="pt") |
|
|
|
|
|
>>> outputs = model(**inputs) |
|
|
>>> last_hidden_state = outputs.last_hidden_state |
|
|
>>> pooled_output = outputs.pooler_output # pooled CLS states |
|
|
```""" |
|
|
|
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
outputs = self.roberta( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
token_type_ids=token_type_ids, |
|
|
position_ids=position_ids, |
|
|
head_mask=head_mask, |
|
|
inputs_embeds=inputs_embeds, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict=True, |
|
|
) |
|
|
|
|
|
|
|
|
sequence_output = outputs[0] |
|
|
|
|
|
|
|
|
sequence_output = self.pre_LN(sequence_output) |
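        # Project the RoBERTa hidden states to `project_dim`; the projected embedding of the first
        # (<s>) token serves as the pooled output.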
|
|
|
|
|
|
|
|
projection_state = self.transformation(sequence_output) |
|
|
pooler_output = projection_state[:, 0] |
|
|
|
|
|
return BaseModelOutputWithPoolingAndProjection( |
|
|
last_hidden_state=projection_state, |
|
|
pooler_output=pooler_output, |
|
|
hidden_states=outputs.hidden_states, |
|
|
attentions=outputs.attentions, |
|
|
) |
|
|
|
|
|
|
|
|
class AltCLIPModel(AltCLIPPreTrainedModel): |
|
|
config: AltCLIPConfig |
|
|
|
|
|
def __init__(self, config: AltCLIPConfig): |
|
|
super().__init__(config) |
|
|
|
|
|
if not isinstance(config.vision_config, AltCLIPVisionConfig): |
|
|
raise TypeError( |
|
|
"config.vision_config is expected to be of type AltCLIPVisionConfig but is of type" |
|
|
f" {type(config.vision_config)}." |
|
|
) |
|
|
if not isinstance(config.text_config, AltCLIPTextConfig): |
|
|
raise TypeError( |
|
|
"config.text_config is expected to be of type AltCLIPTextConfig but is of type" |
|
|
f" {type(config.text_config)}." |
|
|
) |
|
|
|
|
|
text_config = config.text_config |
|
|
vision_config = config.vision_config |
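        # Propagate the attention backend choice of the composite config to the vision sub-config.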
|
|
|
|
|
vision_config._attn_implementation = config._attn_implementation |
|
|
|
|
|
self.projection_dim = config.projection_dim |
|
|
self.text_embed_dim = text_config.project_dim |
|
|
self.vision_embed_dim = vision_config.hidden_size |
|
|
|
|
|
self.text_model = AltCLIPTextModel(text_config) |
|
|
self.vision_model = AltCLIPVisionTransformer(vision_config) |
|
|
|
|
|
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) |
|
|
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) |
|
|
self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) |
|
|
|
|
|
|
|
|
self.post_init() |
|
|
|
|
|
@filter_out_non_signature_kwargs() |
|
|
@auto_docstring |
|
|
def get_text_features( |
|
|
self, |
|
|
input_ids: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.Tensor] = None, |
|
|
token_type_ids: Optional[torch.Tensor] = None, |
|
|
) -> torch.FloatTensor: |
|
|
r""" |
|
|
Returns: |
|
|
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
|
|
applying the projection layer to the pooled output of [`AltCLIPTextModel`]. |
|
|
|
|
|
Examples: |
|
|
|
|
|
```python |
|
|
>>> import torch |
|
|
>>> from transformers import AutoProcessor, AltCLIPModel |
|
|
|
|
|
>>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP") |
|
|
>>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP") |
|
|
|
|
|
>>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") |
|
|
>>> with torch.inference_mode(): |
|
|
... text_features = model.get_text_features(**inputs) |
|
|
```""" |
|
|
text_outputs = self.text_model( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
position_ids=position_ids, |
|
|
token_type_ids=token_type_ids, |
|
|
) |
|
|
pooled_output = text_outputs.pooler_output |
|
|
text_features = self.text_projection(pooled_output) |
|
|
|
|
|
return text_features |
|
|
|
|
|
@filter_out_non_signature_kwargs() |
|
|
@auto_docstring |
|
|
def get_image_features( |
|
|
self, |
|
|
pixel_values: torch.FloatTensor, |
|
|
interpolate_pos_encoding: bool = False, |
|
|
) -> torch.FloatTensor: |
|
|
r""" |
|
|
Returns: |
|
|
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
|
|
applying the projection layer to the pooled output of [`AltCLIPVisionModel`]. |
|
|
|
|
|
Examples: |
|
|
|
|
|
```python |
|
|
>>> import torch |
|
|
>>> from transformers import AutoProcessor, AltCLIPModel |
|
|
>>> from transformers.image_utils import load_image |
|
|
|
|
|
>>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP") |
|
|
>>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP") |
|
|
|
|
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" |
|
|
>>> image = load_image(url) |
|
|
|
|
|
>>> inputs = processor(images=image, return_tensors="pt") |
|
|
>>> with torch.inference_mode(): |
|
|
... image_features = model.get_image_features(**inputs) |
|
|
```""" |
|
|
vision_outputs = self.vision_model( |
|
|
pixel_values=pixel_values, |
|
|
interpolate_pos_encoding=interpolate_pos_encoding, |
|
|
) |
|
|
pooled_output = vision_outputs.pooler_output |
|
|
image_features = self.visual_projection(pooled_output) |
|
|
|
|
|
return image_features |
|
|
|
|
|
@auto_docstring |
|
|
def forward( |
|
|
self, |
|
|
input_ids: Optional[torch.LongTensor] = None, |
|
|
pixel_values: Optional[torch.FloatTensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
position_ids: Optional[torch.LongTensor] = None, |
|
|
token_type_ids: Optional[torch.Tensor] = None, |
|
|
return_loss: Optional[bool] = None, |
|
|
output_attentions: Optional[bool] = None, |
|
|
output_hidden_states: Optional[bool] = None, |
|
|
interpolate_pos_encoding: bool = False, |
|
|
return_dict: Optional[bool] = None, |
|
|
) -> Union[tuple, AltCLIPOutput]: |
|
|
r""" |
|
|
return_loss (`bool`, *optional*): |
|
|
Whether or not to return the contrastive loss. |
|
|
|
|
|
Examples: |
|
|
|
|
|
```python |
|
|
>>> from PIL import Image |
|
|
>>> import requests |
|
|
>>> from transformers import AutoProcessor, AltCLIPModel |
|
|
|
|
|
>>> model = AltCLIPModel.from_pretrained("BAAI/AltCLIP") |
|
|
>>> processor = AutoProcessor.from_pretrained("BAAI/AltCLIP") |
|
|
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" |
|
|
>>> image = Image.open(requests.get(url, stream=True).raw) |
|
|
>>> inputs = processor( |
|
|
... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True |
|
|
... ) |
|
|
>>> outputs = model(**inputs) |
|
|
>>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score |
|
|
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities |
|
|
```""" |
|
|
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
|
output_hidden_states = ( |
|
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
|
) |
|
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
text_outputs = self.text_model( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
token_type_ids=token_type_ids, |
|
|
position_ids=position_ids, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict=return_dict, |
|
|
) |
|
|
|
|
|
vision_outputs = self.vision_model( |
|
|
pixel_values=pixel_values, |
|
|
output_attentions=output_attentions, |
|
|
output_hidden_states=output_hidden_states, |
|
|
interpolate_pos_encoding=interpolate_pos_encoding, |
|
|
return_dict=return_dict, |
|
|
) |
|
|
|
|
|
image_embeds = vision_outputs[1] |
|
|
image_embeds = self.visual_projection(image_embeds) |
|
|
|
|
|
text_embeds = text_outputs[1] |
|
|
text_embeds = self.text_projection(text_embeds) |
|
|
|
|
|
|
|
|
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) |
|
|
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) |
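        # The embeddings are L2-normalized, so the scaled dot product below is a temperature-scaled cosine similarity.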
|
|
|
|
|
|
|
|
logit_scale = self.logit_scale.exp() |
|
|
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale |
|
|
logits_per_image = logits_per_text.T |
|
|
|
|
|
loss = None |
|
|
if return_loss: |
|
|
loss = clip_loss(logits_per_text) |
|
|
|
|
|
if not return_dict: |
|
|
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) |
|
|
return ((loss,) + output) if loss is not None else output |
|
|
|
|
|
return AltCLIPOutput( |
|
|
loss=loss, |
|
|
logits_per_image=logits_per_image, |
|
|
logits_per_text=logits_per_text, |
|
|
text_embeds=text_embeds, |
|
|
image_embeds=image_embeds, |
|
|
text_model_output=text_outputs, |
|
|
vision_model_output=vision_outputs, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): |
|
|
""" |
|
|
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols |
|
|
are ignored. This is modified from fairseq's `utils.make_positions`. |
|
|
|
|
|
    Args:
        input_ids (`torch.Tensor`): Tensor of input token ids.
        padding_idx (`int`): Index of the padding token.
        past_key_values_length (`int`, *optional*, defaults to 0): Length of previously cached key/value states.

    Returns: torch.Tensor
    """
|
|
|
|
|
mask = input_ids.ne(padding_idx).int() |
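    # cumsum over the non-padding mask yields 1-based positions for real tokens; multiplying by the mask
    # zeroes out padding positions, and adding padding_idx shifts real positions to start at padding_idx + 1.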
|
|
incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask |
|
|
return incremental_indices.long() + padding_idx |
|
|
|
|
|
|
|
|
__all__ = ["AltCLIPPreTrainedModel", "AltCLIPVisionModel", "AltCLIPTextModel", "AltCLIPModel"] |
|
|
|