| """PyTorch CLIP model.""" |
|
|
| from typing import Dict, List, Optional, Set, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from .modeling_clip import ( |
| CLIPConfig, |
| CLIPTextConfig, |
| CLIPVisionConfig, |
| CLIPEncoderLayer, |
| CLIPTextTransformer, |
| CLIPVisionTransformer, |
| CLIPModel, |
| CLIPVisionEmbeddings, |
| CLIPVisionModel, |
| CLIPOutput, |
| BaseModelOutput, |
| BaseModelOutputWithPooling, |
| ) |


class ModLN(nn.Module):
    """Conditional modulation: predicts a per-sample (shift, scale) pair from a
    condition vector and applies it to every token."""

    def __init__(self, inner_dim: int, mod_dim: int = 32):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(mod_dim, inner_dim * 2),
        )
        # Zero-init so the modulation starts as the identity mapping.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.zeros_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor, condition: torch.Tensor):
        """
        x: [N, M, C_in], M: number of tokens
        condition: [N, C_mod]
        """
        shift, scale = self.mlp(condition).unsqueeze(1).chunk(2, dim=-1)
        return x * (1 + scale) + shift
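

# A quick shape sketch (hypothetical sizes). Because the MLP above is
# zero-initialized, ModLN is exactly the identity at initialization:
#
#   mod = ModLN(inner_dim=1024, mod_dim=32)
#   x = torch.randn(2, 257, 1024)      # [N, M, C_in] token features
#   c = torch.randn(2, 32)             # [N, C_mod] condition vector
#   assert torch.equal(mod(x, c), x)   # shift == scale == 0 at init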


class ConditionalCLIPVisionConfig(CLIPVisionConfig):
    """CLIPVisionConfig extended with the width of the condition vector consumed by ModLN."""

    def __init__(self, modulation_dim: int = 32, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.modulation_dim = modulation_dim
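

# Example (hypothetical hyper-parameters): a ViT-L/14-style vision config with a
# 32-dim modulation vector; the remaining kwargs pass straight to CLIPVisionConfig:
#
#   cfg = ConditionalCLIPVisionConfig(
#       modulation_dim=32, hidden_size=1024, num_hidden_layers=24,
#       num_attention_heads=16, image_size=224, patch_size=14,
#   )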


class ConditionalCLIPEncoderLayer(CLIPEncoderLayer):
    """CLIPEncoderLayer with ModLN modulation applied after each layer norm."""

    def __init__(self, config: ConditionalCLIPVisionConfig) -> None:
        super().__init__(config)
        self.mod_norm1 = ModLN(config.hidden_size, config.modulation_dim)
        self.mod_norm2 = ModLN(config.hidden_size, config.modulation_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        condition: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        # Self-attention block; the normalized features are modulated by the condition.
        residual = hidden_states
        hidden_states = self.mod_norm1(self.layer_norm1(hidden_states), condition)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        # MLP block, modulated the same way.
        residual = hidden_states
        hidden_states = self.mod_norm2(self.layer_norm2(hidden_states), condition)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs


class ConditionalCLIPEncoder(nn.Module):
    def __init__(self, config: ConditionalCLIPVisionConfig) -> None:
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                ConditionalCLIPEncoderLayer(config)
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        condition: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, BaseModelOutput]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # `_gradient_checkpointing_func` is installed on submodules by
                # HF's `gradient_checkpointing_enable()`.
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    condition=condition,
                    output_attentions=output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    condition=condition,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, encoder_states, all_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )


class ConditionalCLIPVisionTransformer(CLIPVisionTransformer):
    def __init__(self, config: ConditionalCLIPVisionConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size

        # Rebuild the stack so the stock encoder is swapped for the conditional one.
        # Note: `pre_layrnorm` (sic) matches the attribute name used by upstream
        # CLIP checkpoints and must keep that spelling.
        self.embeddings = CLIPVisionEmbeddings(config)
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = ConditionalCLIPEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            condition=condition,
            return_dict=return_dict,
        )

        # Pool by taking the CLS token (position 0), then apply the final layer norm.
        last_hidden_state = encoder_outputs[0]
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class ConditionalCLIPVisionModel(CLIPVisionModel):
    config_class = ConditionalCLIPVisionConfig

    def __init__(self, config: ConditionalCLIPVisionConfig):
        super().__init__(config)
        self.vision_model = ConditionalCLIPVisionTransformer(config)
        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        return self.vision_model(
            pixel_values=pixel_values,
            condition=condition,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
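

# A minimal usage sketch (assumes the standard "openai/clip-vit-large-patch14"
# vision weights; the new ModLN parameters are absent from that checkpoint and
# are freshly initialized):
#
#   model = ConditionalCLIPVisionModel.from_pretrained(
#       "openai/clip-vit-large-patch14", modulation_dim=32
#   )
#   pixel_values = torch.randn(2, 3, 224, 224)  # preprocessed images
#   condition = torch.randn(2, 32)              # e.g. a camera/pose embedding
#   out = model(pixel_values=pixel_values, condition=condition)
#   out.last_hidden_state.shape                 # [2, 257, 1024]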


class ConditionalCLIPModel(CLIPModel):
    config_class = CLIPConfig

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        if not isinstance(config.text_config, CLIPTextConfig):
            raise ValueError(
                "config.text_config is expected to be of type CLIPTextConfig but is of type"
                f" {type(config.text_config)}."
            )

        if not isinstance(config.vision_config, CLIPVisionConfig):
            raise ValueError(
                "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
                f" {type(config.vision_config)}."
            )

        text_config = config.text_config
        # The vision config must carry a `modulation_dim` entry, since
        # ConditionalCLIPVisionTransformer reads it when building its ModLN layers.
        vision_config = config.vision_config

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        self.text_model = CLIPTextTransformer(text_config)
        self.vision_model = ConditionalCLIPVisionTransformer(vision_config)

        self.visual_projection = nn.Linear(
            self.vision_embed_dim, self.projection_dim, bias=False
        )
        self.text_projection = nn.Linear(
            self.text_embed_dim, self.projection_dim, bias=False
        )
        self.logit_scale = nn.Parameter(
            torch.tensor(self.config.logit_scale_init_value)
        )

        # Initialize weights and apply final processing.
        self.post_init()

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        condition: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        # Use the top-level CLIP config for these flags rather than the vision config.
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            condition=condition,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = vision_outputs[1]
        image_features = self.visual_projection(pooled_output)

        return image_features

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        condition: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPOutput]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            condition=condition,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            loss = clip_loss(logits_per_text)

        if not return_dict:
            output = (
                logits_per_image,
                logits_per_text,
                text_embeds,
                image_embeds,
                text_outputs,
                vision_outputs,
            )
            return ((loss,) + output) if loss is not None else output

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
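

# A minimal end-to-end sketch (hypothetical hyper-parameters; the model is
# randomly initialized here, and `from_text_vision_configs` should carry the
# `modulation_dim` entry through into the combined config's vision_config):
#
#   text_cfg = CLIPTextConfig()
#   vision_cfg = ConditionalCLIPVisionConfig(modulation_dim=32)
#   cfg = CLIPConfig.from_text_vision_configs(text_cfg, vision_cfg)
#   model = ConditionalCLIPModel(cfg)
#   input_ids = torch.randint(0, text_cfg.vocab_size, (2, 16))
#   pixel_values = torch.randn(2, 3, vision_cfg.image_size, vision_cfg.image_size)
#   condition = torch.randn(2, 32)
#   out = model(input_ids=input_ids, pixel_values=pixel_values,
#               condition=condition, return_loss=True)
#   out.logits_per_image.shape  # [2, 2]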