dheeena's picture
Add files using upload-large-folder tool
b9ae124 verified
# coding=utf-8
# Copyright 2022 The Salesforce Team Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BLIP model."""
import warnings
from dataclasses import dataclass
from typing import Any, Optional, Union
import torch
from torch import nn
from torch.nn.functional import normalize
from ...activations import ACT2FN
from ...generation import GenerationMixin
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int
from ...utils.generic import check_model_inputs
from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel
logger = logging.get_logger(__name__)
# Copied from transformers.models.clip.modeling_clip.contrastive_loss
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
    """
    Cross-entropy over a square similarity matrix whose matching pairs sit on the diagonal.

    Args:
        logits (`torch.Tensor`): Similarity scores; row `i` is scored against target class `i`.

    Returns:
        `torch.Tensor`: The cross-entropy loss as returned by `nn.functional.cross_entropy`.
    """
    # Row i is expected to match column i, so the class targets are simply 0..N-1.
    diagonal_targets = torch.arange(len(logits), device=logits.device)
    return nn.functional.cross_entropy(logits, diagonal_targets)
# Copied from transformers.models.clip.modeling_clip.clip_loss with clip->blip
def blip_loss(similarity: torch.Tensor) -> torch.Tensor:
    """
    Symmetric contrastive loss over a similarity matrix.

    Args:
        similarity (`torch.Tensor`): Similarity scores between two sets of embeddings.

    Returns:
        `torch.Tensor`: The average of `contrastive_loss(similarity)` and `contrastive_loss(similarity.t())`.
    """
    # One loss per direction: the matrix as given first, then its transpose.
    row_wise = contrastive_loss(similarity)
    column_wise = contrastive_loss(similarity.t())
    total = row_wise + column_wise
    return total / 2.0
@dataclass
@auto_docstring(
    custom_intro="""
Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
last hidden states. This class also adds the loss term from the text decoder.
"""
)
class BlipForConditionalGenerationModelOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss from the text decoder.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`, *optional*):
        Prediction scores of the language modeling head of the text decoder model.
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*):
        The image embeddings obtained after applying the Vision Transformer model to the input image.
    hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
        one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

        Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
    attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
        heads.
    """

    # `loss` and `logits` are documented above as plain tensors, so they are annotated as such.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    # Not described in the docstring above; `BlipForConditionalGeneration.forward` fills it with the
    # vision model's `last_hidden_state`.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None

    @property
    def decoder_logits(self):
        """Deprecated alias of `logits`; emits a `FutureWarning` on every access."""
        warnings.warn(
            "`decoder_logits` attribute is deprecated and will be removed in version 5 of Transformers."
            " Please use the `logits` attribute to retrieve the final output instead.",
            FutureWarning,
        )
        return self.logits
@dataclass
@auto_docstring(
    custom_intro="""
Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
last hidden states. This class also adds the loss term from the text decoder.
"""
)
class BlipTextVisionModelOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss from the text decoder.
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    """

    loss: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    # The three fields below are not described in the docstring above.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
@dataclass
@auto_docstring(
    custom_intro="""
Adapted from the base class for vision model's outputs that also contains image embeddings of the pooling of the
last hidden states. This class also adds the loss term from the text decoder as well as the image-text similarity
scores.
"""
)
class BlipImageTextMatchingModelOutput(ModelOutput):
    r"""
    itm_score (`torch.FloatTensor`):
        The image-text similarity scores.
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss from the text decoder.
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when model is initialized with `with_projection=True`):
        The image embeddings obtained by applying the projection layer to the pooler_output.
    vision_pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*):
        Last layer hidden-state of the vision-only branch of the model.
    question_embeds (`torch.FloatTensor`):
        The question embeddings obtained by the text projection layer.
    """

    itm_score: Optional[torch.FloatTensor] = None
    loss: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    # `last_hidden_state`, `hidden_states` and `attentions` are not described in the docstring above.
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    vision_pooler_output: Optional[torch.FloatTensor] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    # NOTE(review): annotated as a tuple while the docstring describes a single tensor; the producer is not
    # visible here, so the annotation is left as it was -- confirm which one is right.
    question_embeds: Optional[tuple[torch.FloatTensor]] = None
@dataclass
@auto_docstring
class BlipOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for image-text similarity.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
        similarity scores.
    text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`BlipTextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
        The image embeddings obtained by applying the projection layer to the pooled output of [`BlipVisionModel`].
    text_model_output (`BaseModelOutputWithPooling`):
        The output of the [`BlipTextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`BlipVisionModel`].
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    text_model_output: BaseModelOutputWithPooling = None
    vision_model_output: BaseModelOutputWithPooling = None

    def to_tuple(self) -> tuple[Any]:
        """Return the stored values as a tuple, converting the two nested model outputs to tuples as well."""
        nested_outputs = ["text_model_output", "vision_model_output"]
        values = []
        for key in self.keys():
            if key in nested_outputs:
                # Sub-model outputs are flattened through their own `to_tuple`.
                values.append(getattr(self, key).to_tuple())
            else:
                values.append(self[key])
        return tuple(values)
class BlipVisionEmbeddings(nn.Module):
    """
    Turns pixel values into a token sequence: one learned class token followed by one token per image patch,
    with learned position embeddings added on top.
    """

    def __init__(self, config: BlipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Parameters are created in the same order as before (class token, patch projection, positions).
        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        # Non-overlapping patches: kernel size and stride are both the patch size.
        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        # One extra position for the class token.
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        """
        patch_count = embeddings.shape[1] - 1
        pretrained_count = self.position_embedding.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        must_interpolate = torch.jit.is_tracing() or patch_count != pretrained_count or height != width
        if not must_interpolate:
            return self.position_embedding

        # The class-token position is kept untouched; only the patch grid is resized.
        cls_positions = self.position_embedding[:, :1]
        grid_positions = self.position_embedding[:, 1:]

        channels = embeddings.shape[-1]
        target_grid = (height // self.patch_size, width // self.patch_size)

        # The stored patch positions are treated as a square grid.
        grid_side = torch_int(pretrained_count**0.5)
        grid_positions = grid_positions.reshape(1, grid_side, grid_side, channels).permute(0, 3, 1, 2)

        grid_positions = nn.functional.interpolate(
            grid_positions,
            size=target_grid,
            mode="bicubic",
            align_corners=False,
        )

        # Back to (1, num_patches, channels) before re-attaching the class-token position.
        grid_positions = grid_positions.permute(0, 2, 3, 1).view(1, -1, channels)
        return torch.cat((cls_positions, grid_positions), dim=1)

    def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """
        Embed `pixel_values` of shape `(batch, channels, height, width)` into
        `(batch, 1 + num_patches, embed_dim)`.

        When `interpolate_pos_encoding` is set, the position embeddings are resized to the input resolution;
        otherwise the stored embeddings are used, truncated to the sequence length.
        """
        batch_size, _, height, width = pixel_values.shape
        param_dtype = self.patch_embedding.weight.dtype

        # Conv output is (batch, embed_dim, grid, grid); flatten the grid and move channels last.
        patch_tokens = self.patch_embedding(pixel_values.to(dtype=param_dtype))
        patch_tokens = patch_tokens.flatten(2).transpose(1, 2)

        cls_tokens = self.class_embedding.expand(batch_size, 1, -1).to(param_dtype)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)

        positions = (
            self.interpolate_pos_encoding(tokens, height, width)
            if interpolate_pos_encoding
            else self.position_embedding
        )
        return tokens + positions[:, : tokens.size(1), :].to(param_dtype)
# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Blip
class BlipTextEmbeddings(nn.Module):
    """Sum of learned token embeddings and learned absolute position embeddings."""

    def __init__(self, config: BlipTextConfig):
        super().__init__()
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)

        # position_ids has shape (1, max_position_embeddings); it is registered as a non-persistent buffer
        all_positions = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_positions, persistent=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """
        Embed either `input_ids` or precomputed `inputs_embeds` and add position embeddings.

        When `position_ids` is not given, positions `0..seq_length-1` are used.

        Raises:
            ValueError: If the sequence is longer than the position-embedding table.
        """
        # The sequence length comes from whichever input was supplied.
        if input_ids is not None:
            seq_length = input_ids.shape[-1]
        else:
            seq_length = inputs_embeds.shape[-2]

        max_position_embedding = self.position_embedding.weight.shape[0]
        if seq_length > max_position_embedding:
            raise ValueError(
                f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
                f"{seq_length} and max_position_embeddings: {max_position_embedding}"
            )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        return inputs_embeds + self.position_embedding(position_ids)
class BlipAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        # Scores are scaled by 1/sqrt(head_dim).
        self.scale = self.head_dim**-0.5
        self.dropout = nn.Dropout(config.attention_dropout)

        # A single projection produces queries, keys and values at once.
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim)

        self.projection = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `(bsz, seq_len, embed_dim)` into a contiguous `(bsz, num_heads, seq_len, head_dim)` tensor."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Input shape: Batch x Time x Channel

        Returns the projected attention output together with the attention probabilities (after dropout and
        after the optional `head_mask` multiplication).
        """
        bsz, tgt_len, embed_dim = hidden_states.size()

        # (bsz, tgt_len, 3 * embed_dim) -> (3, bsz, num_heads, tgt_len, head_dim)
        fused = self.qkv(hidden_states)
        fused = fused.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads)
        mixed_qkv = fused.permute(2, 0, 3, 1, 4)
        query_states = mixed_qkv[0]
        key_states = mixed_qkv[1]
        value_states = mixed_qkv[2]

        # Raw attention scores: scaled dot product between "query" and "key".
        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.scale

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # (bsz, num_heads, tgt_len, head_dim) -> (bsz, tgt_len, num_heads, head_dim) -> (bsz, tgt_len, embed_dim)
        context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
        merged_shape = context_layer.size()[:-2] + (self.embed_dim,)
        context_layer = context_layer.reshape(merged_shape)

        return self.projection(context_layer), attention_probs
# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Blip
class BlipMLP(nn.Module):
    """Two-layer feed-forward block: `hidden_size -> intermediate_size -> hidden_size`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # The activation is looked up by name from the config.
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply `fc1`, the activation, then `fc2`."""
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(expanded)
class BlipEncoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer block: self-attention and an MLP, each wrapped in a residual connection."""

    def __init__(self, config: BlipConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = BlipAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = BlipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    @auto_docstring
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.FloatTensor:
        # Attention sub-block: normalise, attend, then add the un-normalised input back.
        # NOTE: `attention_mask` is forwarded to the attention module as its `head_mask` argument,
        # which multiplies the attention probabilities.
        attn_output, _ = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            head_mask=attention_mask,
            **kwargs,
        )
        hidden_states = attn_output + hidden_states

        # Feed-forward sub-block with the same pre-norm / residual pattern.
        mlp_output = self.mlp(self.layer_norm2(hidden_states))
        return mlp_output + hidden_states
@auto_docstring
class BlipPreTrainedModel(PreTrainedModel):
    # Shared base class of the BLIP models in this file: class-level loading settings and weight initialisation.
    config: BlipConfig
    base_model_prefix = "blip"
    supports_gradient_checkpointing = True
    _no_split_modules = ["BlipEncoderLayer", "BlipTextEmbeddings"]
    _skip_keys_device_placement = ["past_key_values"]

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_range

        # Conv / embedding / linear layers: normal weights, zero bias when a bias exists.
        if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=factor)
            if getattr(module, "bias", None) is not None:
                module.bias.data.zero_()

        # Second, independent dispatch on the module type.
        if isinstance(module, BlipVisionEmbeddings):
            # Composite configs carry a dedicated vision sub-config whose range takes precedence.
            if hasattr(self.config, "vision_config"):
                factor = self.config.vision_config.initializer_range
            for parameter in (module.position_embedding, module.class_embedding):
                nn.init.trunc_normal_(
                    parameter,
                    mean=0.0,
                    std=factor,
                )
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
class BlipEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`BlipEncoderLayer`].

    Args:
        config (`BlipConfig`):
            The corresponding vision configuration for the `BlipEncoder`.
    """

    def __init__(self, config: BlipConfig):
        super().__init__()
        self.config = config
        layer_stack = [BlipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(layer_stack)
        self.gradient_checkpointing = False

    @auto_docstring
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BaseModelOutput]:
        # Run the layers in order, each consuming the previous layer's output.
        current_states = inputs_embeds
        for layer in self.layers:
            current_states = layer(
                current_states,
                attention_mask=attention_mask,
                **kwargs,
            )

        return BaseModelOutput(last_hidden_state=current_states)
class BlipVisionModel(BlipPreTrainedModel):
    # Vision tower: patch embeddings -> transformer encoder -> final layer norm.
    main_input_name = "pixel_values"
    config: BlipVisionConfig
    _can_record_outputs = {
        "hidden_states": BlipEncoderLayer,
        "attentions": BlipAttention,
    }

    def __init__(self, config: BlipVisionConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = BlipVisionEmbeddings(config)
        self.encoder = BlipEncoder(config)
        self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.post_init()

    @check_model_inputs(tie_last_hidden_states=False)
    @auto_docstring
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        interpolate_pos_encoding: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BaseModelOutputWithPooling]:
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedded = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        encoder_outputs: BaseModelOutput = self.encoder(
            inputs_embeds=embedded,
            **kwargs,
        )

        sequence_output = self.post_layernorm(encoder_outputs.last_hidden_state)

        # The pooled output is the first (class) token; it goes through `post_layernorm` a second time,
        # on top of the normalisation already applied to the whole sequence.
        pooled_output = self.post_layernorm(sequence_output[:, 0, :])

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )

    def get_input_embeddings(self):
        # Returns the whole embeddings module.
        return self.embeddings
@auto_docstring(
    custom_intro="""
This model is going to be deprecated in future versions. Please use `BlipForConditionalGeneration`, `BlipForQuestionAnswering` or `BlipForImageTextRetrieval` depending on your usecase.
"""
)
class BlipModel(BlipPreTrainedModel):
    # Dual-tower model: a text model and a vision model whose pooled outputs are projected into a shared
    # space of size `config.projection_dim` and compared with a learned logit scale.
    config: BlipConfig

    def __init__(self, config: BlipConfig):
        super().__init__(config)

        # Fail early when the composite config carries sub-configs of the wrong type.
        if not isinstance(config.text_config, BlipTextConfig):
            raise TypeError(
                "config.text_config is expected to be of type BlipTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, BlipVisionConfig):
            raise TypeError(
                "config.vision_config is expected to be of type BlipVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = BlipTextModel(text_config)
        self.vision_model = BlipVisionModel(vision_config)

        # Bias-free linear maps from each tower's hidden size into the shared projection space.
        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
        # Stored un-exponentiated; `forward` applies `.exp()` before scaling the similarities.
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        logger.warning(
            "`BlipModel` is going to be deprecated in future release, please use `BlipForConditionalGeneration`, `BlipForQuestionAnswering` or `BlipForImageTextRetrieval` depending on your usecase."
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Delegates to the text tower.
        return self.text_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        # Delegates to the text tower.
        self.text_model.set_input_embeddings(value)

    @auto_docstring
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`BlipTextModel`].

        Examples:

        ```python
        >>> from transformers import AutoProcessor, BlipModel

        >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

        >>> inputs = processor(text=["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
        >>> text_features = model.get_text_features(**inputs)
        ```"""
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        pooled_output = text_outputs[1]  # pooled_output
        text_features = self.text_projection(pooled_output)

        return text_features

    @auto_docstring
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`BlipVisionModel`].

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, BlipModel

        >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> image_features = model.get_image_features(**inputs)
        ```"""
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        pooled_output = vision_outputs[1]  # pooled_output
        image_features = self.visual_projection(pooled_output)

        return image_features

    @auto_docstring
    def get_multimodal_features(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.FloatTensor:
        r"""
        Returns:
            multimodal_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The multimodal embeddings
            obtained by applying the image embeddings to the text encoder using the cross-attention mechanism.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, BlipModel

        >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> texts = ["a photo of a cat", "a photo of a dog"]
        >>> inputs = processor(images=image, text=texts, padding=True, return_tensors="pt")

        >>> multimodal_features = model.get_multimodal_features(**inputs)
        ```"""
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        # First element of the vision output: the per-token hidden states.
        image_embeds = vision_outputs[0]
        # Every image token is attended to. The mask is created on the same device as the image
        # embeddings so it does not end up on the CPU when the model runs on an accelerator.
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
        )

        pooled_output = text_outputs[1]  # pooled_output
        multimodal_features = self.text_projection(pooled_output)

        return multimodal_features

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BlipOutput]:
        r"""
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, BlipModel

        >>> model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(
        ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
        ... )

        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```"""
        # The same extra keyword arguments are forwarded to both towers.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            **kwargs,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )

        image_embeds = vision_outputs.pooler_output
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs.pooler_output
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits; everything is aligned to the text embeddings' device and dtype first
        logit_scale = self.logit_scale.exp().to(device=text_embeds.device)
        image_embeds = image_embeds.to(device=text_embeds.device, dtype=text_embeds.dtype)
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            loss = blip_loss(logits_per_text)

        return BlipOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
@auto_docstring(
    custom_intro="""
BLIP Model for image captioning. The model consists of a vision encoder and a text decoder. One can optionally pass
`input_ids` to the model, which serve as a text prompt, to make the text decoder continue the prompt. If no text
input is provided, the decoder starts generating the caption from the [BOS] (beginning-of-sequence) token only.
"""
)
class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin):
    config: BlipConfig
    _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
    main_input_name = "pixel_values"

    def __init__(self, config: BlipConfig):
        super().__init__(config)

        self.vision_model = BlipVisionModel(config.vision_config)

        self.text_decoder = BlipTextLMHeadModel(config.text_config)

        # Despite its name, `decoder_input_ids` holds the BOS token id used to start the decoder prompt.
        self.decoder_input_ids = config.text_config.bos_token_id
        self.decoder_pad_token_id = config.text_config.pad_token_id

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Delegates to the text decoder.
        return self.text_decoder.get_input_embeddings()

    def set_input_embeddings(self, value):
        # Delegates to the text decoder.
        self.text_decoder.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        interpolate_pos_encoding: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BlipForConditionalGenerationModelOutput]:
        r"""
        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, BlipForConditionalGeneration

        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        >>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text = "A picture of"

        >>> inputs = processor(images=image, text=text, return_tensors="pt")

        >>> outputs = model(**inputs)
        ```"""
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            **kwargs,
        )

        # The decoder cross-attends over the vision model's per-token hidden states.
        image_embeds = vision_outputs.last_hidden_state

        outputs = self.text_decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=image_embeds,
            labels=labels,
            reduction="mean",
            **kwargs,
        )

        # Loss and logits come from the text decoder; the remaining fields come from the vision model.
        return BlipForConditionalGenerationModelOutput(
            loss=outputs.loss,
            logits=outputs.logits,
            image_embeds=image_embeds,
            last_hidden_state=vision_outputs.last_hidden_state,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        interpolate_pos_encoding: bool = False,
        **generate_kwargs,
    ) -> torch.LongTensor:
        r"""
        Overrides *generate* function to be able to use the model as a conditional generator

        Parameters:
            pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*):
                Input image to be processed
            input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
                The sequence used as a prompt for the generation. The caller's tensor is not modified.
            attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
            interpolate_pos_encoding (*bool*, *optional*, defaults to `False`):
                Forwarded to the vision model.
            generate_kwargs:
                Additional keyword arguments forwarded to the text decoder's `generate`.

        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, BlipForConditionalGeneration

        >>> model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model.generate(**inputs)
        >>> print(processor.decode(outputs[0], skip_special_tokens=True))
        two cats sleeping on a couch
        ```
        """

        batch_size = pixel_values.shape[0]
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        # First element of the vision output: the per-token hidden states.
        image_embeds = vision_outputs[0]

        # Every image token is visible to the decoder's cross-attention.
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        if isinstance(input_ids, list):
            # Built directly on the image device, like the default prompt below.
            input_ids = torch.tensor(input_ids, dtype=torch.long, device=image_embeds.device)
        elif input_ids is None:
            # Default prompt: [BOS, EOS] per image; the trailing token is sliced off again further down.
            input_ids = (
                torch.LongTensor([[self.decoder_input_ids, self.config.text_config.eos_token_id]])
                .repeat(batch_size, 1)
                .to(image_embeds.device)
            )
        else:
            # Work on a copy so the in-place BOS overwrite below does not modify the caller's tensor.
            input_ids = input_ids.clone()

        # The prompt always starts with the BOS token.
        input_ids[:, 0] = self.config.text_config.bos_token_id
        # The last prompt position is dropped from both the ids and the mask before decoding.
        attention_mask = attention_mask[:, :-1] if attention_mask is not None else None

        outputs = self.text_decoder.generate(
            input_ids=input_ids[:, :-1],
            eos_token_id=self.config.text_config.sep_token_id,
            pad_token_id=self.config.text_config.pad_token_id,
            attention_mask=attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            **generate_kwargs,
        )

        return outputs
@auto_docstring(
    custom_intro="""
    BLIP Model for visual question answering. The model consists of a vision encoder, a text encoder as well as a text
    decoder. The vision encoder will encode the input image, the text encoder will encode the input question together
    with the encoding of the image, and the text decoder will output the answer to the question.
    """
)
class BlipForQuestionAnswering(BlipPreTrainedModel, GenerationMixin):
    config: BlipConfig
    _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]

    def __init__(self, config: BlipConfig):
        super().__init__(config)

        # Three sub-models: image encoder, question encoder (cross-attends to the image),
        # and answer decoder (cross-attends to the encoded question).
        self.vision_model = BlipVisionModel(config.vision_config)
        self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
        self.text_decoder = BlipTextLMHeadModel(config.text_config)

        # Special token ids of the answer decoder, both read from the text config.
        self.decoder_pad_token_id = config.text_config.pad_token_id
        self.decoder_start_token_id = config.text_config.bos_token_id

        # Initialize weights and apply final processing
        self.post_init()

    def set_input_embeddings(self, value):
        # Delegates to the question encoder; the decoder is not touched here.
        self.text_encoder.set_input_embeddings(value)

    def get_input_embeddings(self):
        # This will return shared embeddings if they are shared else specific to encoder.
        return self.text_encoder.get_input_embeddings()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        interpolate_pos_encoding: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BlipTextVisionModelOutput]:
        r"""
        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, BlipForQuestionAnswering

        >>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> # training
        >>> text = "How many cats are in the picture?"
        >>> label = "2"
        >>> inputs = processor(images=image, text=text, return_tensors="pt")
        >>> labels = processor(text=label, return_tensors="pt").input_ids

        >>> inputs["labels"] = labels
        >>> outputs = model(**inputs)
        >>> loss = outputs.loss
        >>> loss.backward()

        >>> # inference
        >>> text = "How many cats are in the picture?"
        >>> inputs = processor(images=image, text=text, return_tensors="pt")
        >>> outputs = model.generate(**inputs)
        >>> print(processor.decode(outputs[0], skip_special_tokens=True))
        2
        ```"""
        # The decoder needs some input sequence: either explicit decoder ids or labels to derive them from.
        if labels is None and decoder_input_ids is None:
            raise ValueError(
                "Either `decoder_input_ids` or `labels` should be passed when calling `forward` with"
                " `BlipForQuestionAnswering`. if you are training the model make sure that `labels` is passed, if you"
                " are using the model for inference make sure that `decoder_input_ids` is passed or call `generate`"
            )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            **kwargs,
        )

        image_embeds = vision_outputs.last_hidden_state
        # All image positions are visible. The mask is created on the same device as the image
        # embeddings (exactly as `generate` does) so it never lives on a different device than
        # the tensors it is combined with.
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        # Encode the question while cross-attending to the image embeddings.
        question_embeds = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            **kwargs,
        )

        if labels is not None and decoder_input_ids is None:
            # labels are already shifted right, see: https://github.com/huggingface/transformers/pull/23153
            decoder_input_ids = labels

        # First element of the encoder output is used as the question representation.
        question_embeds = question_embeds[0]

        # Decode the answer conditioned on the encoded question; the question's own attention
        # mask is reused as the cross-attention mask.
        answer_output = self.text_decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=question_embeds,
            encoder_attention_mask=attention_mask,
            labels=labels,
            reduction="mean",
            **kwargs,
        )

        # A loss is only available when labels were supplied.
        if labels is not None:
            decoder_loss = answer_output.loss.mean()
        else:
            decoder_loss = None

        # Only the loss comes from the decoder; the remaining fields expose the vision encoder outputs.
        return BlipTextVisionModelOutput(
            loss=decoder_loss,
            image_embeds=image_embeds,
            last_hidden_state=vision_outputs.last_hidden_state,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        interpolate_pos_encoding: bool = False,
        **generate_kwargs,
    ) -> torch.LongTensor:
        r"""
        Overrides *generate* function to be able to use the model as a conditional generator

        Parameters:
            input_ids (*torch.LongTensor* of shape *(batch_size, sequence_length)*):
                The sequence used as a prompt for the generation.
            pixel_values (*torch.FloatTensor* of shape *(batch_size, num_channels, image_height, image_width)*):
                Input image to be processed
            attention_mask (*torch.LongTensor* of shape *(batch_size, sequence_length)*, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`. `1` for
                tokens that are NOT MASKED, `0` for MASKED tokens.
            interpolate_pos_encoding (*bool*, *optional*, defaults to `False`):
                Forwarded unchanged to the vision model.
            **generate_kwargs:
                Additional arguments passed to the *generate* function of the decoder

        Returns:
            The output of the text decoder's *generate* function.

        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, BlipForQuestionAnswering

        >>> model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text = "How many cats are in the picture?"

        >>> inputs = processor(images=image, text=text, return_tensors="pt")

        >>> outputs = model.generate(**inputs)
        >>> print(processor.decode(outputs[0], skip_special_tokens=True))
        2
        ```
        """
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        image_embeds = vision_outputs[0]

        # All image positions are visible to the question encoder.
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        if isinstance(input_ids, list):
            # A plain list is converted to a tensor and placed next to the image embeddings,
            # otherwise it would stay on the CPU regardless of where the model lives.
            input_ids = torch.LongTensor(input_ids).to(image_embeds.device)

        question_outputs = self.text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=False,
        )

        question_embeds = question_outputs[0]

        # During generation the decoder attends to every question position (all-ones mask).
        question_attention_mask = torch.ones(
            question_embeds.size()[:-1], dtype=torch.long, device=question_embeds.device
        )

        # One start token per batch element seeds the decoder.
        bos_ids = torch.full(
            (question_embeds.size(0), 1), fill_value=self.decoder_start_token_id, device=question_embeds.device
        )

        outputs = self.text_decoder.generate(
            input_ids=bos_ids,
            eos_token_id=self.config.text_config.sep_token_id,
            pad_token_id=self.config.text_config.pad_token_id,
            encoder_hidden_states=question_embeds,
            encoder_attention_mask=question_attention_mask,
            **generate_kwargs,
        )

        return outputs
@auto_docstring(
    custom_intro="""
    BLIP Model with a vision and text projector, and a classification head on top. The model is used in the context of
    image-text retrieval. Given an image and a text, the model returns the probability of the text being relevant to
    the image.
    """
)
class BlipForImageTextRetrieval(BlipPreTrainedModel):
    config: BlipConfig

    def __init__(self, config: BlipConfig):
        super().__init__(config)

        self.vision_model = BlipVisionModel(config.vision_config)

        self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)

        # vision projection layer
        self.vision_proj = nn.Linear(config.vision_config.hidden_size, config.image_text_hidden_size)

        # text projection layer
        self.text_proj = nn.Linear(config.text_config.hidden_size, config.image_text_hidden_size)

        # image text matching head
        self.itm_head = nn.Linear(config.text_config.hidden_size, 2)

        # A value set directly on the top-level config wins; otherwise fall back to the text config.
        self.decoder_pad_token_id = (
            config.text_config.pad_token_id
            if not hasattr(config, "decoder_pad_token_id")
            else config.decoder_pad_token_id
        )
        self.decoder_start_token_id = (
            config.text_config.bos_token_id
            if not hasattr(config, "decoder_start_token_id")
            else config.decoder_start_token_id
        )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Delegates to the text encoder.
        return self.text_encoder.get_input_embeddings()

    def set_input_embeddings(self, value):
        # Delegates to the text encoder.
        self.text_encoder.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        use_itm_head: Optional[bool] = True,
        attention_mask: Optional[torch.LongTensor] = None,
        interpolate_pos_encoding: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ) -> Union[tuple, BlipImageTextMatchingModelOutput]:
        r"""
        use_itm_head (`bool`, *optional*, defaults to `True`):
            Whether or not to use the image-text matching head.

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, BlipForImageTextRetrieval

        >>> model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
        >>> processor = AutoProcessor.from_pretrained("Salesforce/blip-itm-base-coco")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text = "an image of a cat"

        >>> inputs = processor(images=image, text=text, return_tensors="pt")
        >>> outputs = model(**inputs)
        ```
        """
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
            **kwargs,
        )

        image_embeds = vision_outputs.last_hidden_state
        # All image positions are visible. The mask is created on the device of the image
        # embeddings so it is never on a different device than the tensors it is used with.
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        if use_itm_head:
            # Matching head: the text encoder cross-attends to the image, and the hidden state at
            # the first text position is fed to the 2-way linear head.
            question_embeds = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                **kwargs,
            )
            question_embeds = question_embeds.last_hidden_state

            output = self.itm_head(question_embeds[:, 0, :])
        else:
            # Similarity path: the text is encoded without looking at the image; the first position
            # of each modality is projected, normalized, and compared with a dot product.
            question_embeds = self.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                **kwargs,
            )
            question_embeds = question_embeds.last_hidden_state

            image_feat = normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
            text_feat = normalize(self.text_proj(question_embeds[:, 0, :]), dim=-1)

            # Rows index images, columns index texts.
            output = image_feat @ text_feat.t()

        # `itm_score` holds either the head output or the similarity matrix, depending on `use_itm_head`;
        # the hidden-state fields expose the vision encoder outputs.
        return BlipImageTextMatchingModelOutput(
            itm_score=output,
            last_hidden_state=vision_outputs.last_hidden_state,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
            question_embeds=question_embeds,
        )
# Public names of this module, i.e. what `from <module> import *` exposes.
# `BlipTextModel` is imported from `.modeling_blip_text` at the top of the file and re-exported here.
__all__ = [
"BlipModel",
"BlipPreTrainedModel",
"BlipForConditionalGeneration",
"BlipForQuestionAnswering",
"BlipVisionModel",
"BlipTextModel",
"BlipForImageTextRetrieval",
]