| """ |
| LongCLIP: Unlocking the Long-Text Capability of CLIP |
| |
| This module provides HuggingFace Transformers-compatible implementations of LongCLIP, |
| which extends CLIP's text encoder to support 248 tokens (vs 77 in original CLIP). |
| |
| Repository: https://github.com/beichenzbc/Long-CLIP |
| Paper: https://arxiv.org/abs/2403.15378 |
| """ |
|
|
| import logging |
| from typing import Any, Dict, List, Optional, Union |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig |
| from transformers import CLIPTextModel, CLIPVisionModel, CLIPModel |
| from transformers import CLIPImageProcessor, CLIPTokenizer |
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.models.clip.modeling_clip import CLIPTextTransformer |
| from transformers.processing_utils import ProcessorMixin |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| |
|
|
|
|
class LongCLIPTextConfig(CLIPTextConfig):
    """
    Configuration for the LongCLIP text encoder.

    Identical to ``CLIPTextConfig`` except that the default context length is
    raised from 77 to 248 tokens and two LongCLIP-specific options are added.

    Args:
        max_position_embeddings (int, optional): Maximum sequence length.
            Defaults to 248.
        use_position_interpolation (bool, optional): Whether to use position
            interpolation. Defaults to True.
        interpolation_keep_length (int, optional): Number of positions kept
            from the original embeddings before interpolation. Defaults to 20.
        **kwargs: Additional arguments forwarded to ``CLIPTextConfig``.
    """

    model_type = "longclip_text_model"

    def __init__(
        self,
        max_position_embeddings: int = 248,
        use_position_interpolation: bool = True,
        interpolation_keep_length: int = 20,
        **kwargs,
    ):
        # Let the parent consume everything it understands first; the
        # LongCLIP-specific fields are attached afterwards.
        super().__init__(max_position_embeddings=max_position_embeddings, **kwargs)

        self.use_position_interpolation = use_position_interpolation
        self.interpolation_keep_length = interpolation_keep_length
|
|
|
|
class LongCLIPVisionConfig(CLIPVisionConfig):
    """
    Configuration class for LongCLIP vision model.

    This is identical to the standard CLIPVisionConfig as LongCLIP
    does not modify the vision encoder; only ``model_type`` differs so the
    config can be distinguished and registered with the Auto* factories.

    Args:
        **kwargs: Arguments passed to CLIPVisionConfig.
    """

    model_type = "longclip_vision_model"

    # No __init__ override: the previous pure-forwarding __init__ was a
    # no-op, so the inherited CLIPVisionConfig constructor is used directly.
|
|
|
|
class LongCLIPConfig(CLIPConfig):
    """
    Configuration class for LongCLIP model.

    Combines LongCLIPTextConfig and LongCLIPVisionConfig to create
    a complete LongCLIP model configuration.

    Args:
        text_config (Dict[str, Any] or LongCLIPTextConfig, optional):
            Configuration for the text model. If None, uses default LongCLIPTextConfig.
        vision_config (Dict[str, Any] or LongCLIPVisionConfig, optional):
            Configuration for the vision model. If None, uses default LongCLIPVisionConfig.
        projection_dim (int, optional): Dimensionality of text and vision projection layers.
            Defaults to 512.
        **kwargs: Additional arguments passed to CLIPConfig.
    """

    model_type = "longclip"
    is_composition = True

    def __init__(
        self,
        # Optional[...] instead of `Dict[...] | None`: `|` on typing generics
        # requires Python >= 3.10 (it raises TypeError at definition time on
        # 3.9) and Optional is the style used elsewhere in this module.
        text_config: Optional[Dict[str, Any]] = None,
        vision_config: Optional[Dict[str, Any]] = None,
        projection_dim: int = 512,
        **kwargs,
    ):
        if text_config is None:
            text_config = {}
            logger.info(
                "text_config is None. Initializing the LongCLIPTextConfig with default values."
            )

        if vision_config is None:
            vision_config = {}
            logger.info(
                "vision_config is None. Initializing the LongCLIPVisionConfig with default values."
            )

        # Accept plain dicts as well as already-built config objects.
        if isinstance(text_config, dict):
            text_config = LongCLIPTextConfig(**text_config)

        if isinstance(vision_config, dict):
            vision_config = LongCLIPVisionConfig(**vision_config)

        # The parent constructor expects dicts and builds plain CLIP
        # sub-configs from them...
        super().__init__(
            text_config=text_config.to_dict(),
            vision_config=vision_config.to_dict(),
            projection_dim=projection_dim,
            **kwargs,
        )

        # ...so overwrite those with the LongCLIP-specific config objects.
        self.text_config = text_config
        self.vision_config = vision_config

    @classmethod
    def from_text_vision_configs(
        cls,
        text_config: LongCLIPTextConfig,
        vision_config: LongCLIPVisionConfig,
        **kwargs,
    ) -> "LongCLIPConfig":
        """
        Instantiate a LongCLIPConfig from text and vision configs.

        Args:
            text_config (LongCLIPTextConfig): Text model configuration.
            vision_config (LongCLIPVisionConfig): Vision model configuration.
            **kwargs: Additional keyword arguments.

        Returns:
            LongCLIPConfig: Configuration object.
        """
        return cls(
            text_config=text_config.to_dict(),
            vision_config=vision_config.to_dict(),
            **kwargs,
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            Dict[str, Any]: Dictionary of all attributes.
        """
        output = super().to_dict()

        # Re-serialize the sub-configs explicitly: self.text_config and
        # self.vision_config were replaced with LongCLIP config objects after
        # the parent constructor ran, so the parent's serialization may not
        # reflect the LongCLIP-specific fields.
        if hasattr(self, "text_config") and isinstance(
            self.text_config, PretrainedConfig
        ):
            output["text_config"] = self.text_config.to_dict()
        if hasattr(self, "vision_config") and isinstance(
            self.vision_config, PretrainedConfig
        ):
            output["vision_config"] = self.vision_config.to_dict()
        return output
|
|
|
|
| |
|
|
|
|
class LongCLIPTextEmbeddings(nn.Module):
    """
    Text embeddings for LongCLIP with custom positional embedding mechanism.

    This module implements the dual positional embedding approach used in LongCLIP:
    - The first `interpolation_keep_length` (default 20) positions use the base
      positional embeddings (selected by mask1)
    - The remaining positions use the residual embeddings (selected by mask2)
    - position_embedding: base embedding table
    - position_embedding_res: residual embedding table (nn.Parameter)

    Args:
        config (LongCLIPTextConfig): Configuration for text embeddings.
    """

    def __init__(self, config: "LongCLIPTextConfig"):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)

        # Base positional table covering the full (stretched) context length.
        self.position_embedding = nn.Embedding(
            config.max_position_embeddings, embed_dim
        )

        # Residual positional table used for the interpolated region.
        self.position_embedding_res = nn.Parameter(
            torch.zeros(config.max_position_embeddings, embed_dim)
        )

        # Position-indexed selection masks, shape [max_position_embeddings, 1].
        # Non-persistent: they are deterministic and not part of checkpoints.
        self.register_buffer(
            "mask1", self._create_mask(config, use_first=True), persistent=False
        )
        self.register_buffer(
            "mask2", self._create_mask(config, use_first=False), persistent=False
        )

        # Default position ids 0..max_position_embeddings-1, shape [1, max].
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

    def _create_mask(self, config: "LongCLIPTextConfig", use_first: bool) -> torch.Tensor:
        """
        Create mask for positional embeddings.

        Args:
            config: Configuration object.
            use_first: If True, mask first `interpolation_keep_length` positions.
                If False, mask remaining positions.

        Returns:
            Mask tensor of shape [max_position_embeddings, 1].
        """
        mask = torch.zeros(config.max_position_embeddings, 1)
        if use_first:
            # 1.0 for positions [0, keep_length), 0.0 elsewhere.
            mask[: config.interpolation_keep_length] = 1.0
        else:
            # Complement: 1.0 for positions [keep_length, max).
            mask[config.interpolation_keep_length :] = 1.0
        return mask

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for text embeddings.

        Args:
            input_ids: Token IDs of shape [batch_size, seq_length].
            position_ids: Position IDs of shape [batch_size, seq_length].
            inputs_embeds: Pre-computed token embeddings.

        Returns:
            Embeddings of shape [batch_size, seq_length, hidden_size].
        """
        seq_length = (
            input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
        )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        # Base positional lookup follows position_ids.
        position_embeddings = self.position_embedding(position_ids)

        # Fix: index the residual table and both masks by `position_ids` as
        # well. The previous code sliced them by sequence position, which
        # disagreed with the base lookup whenever a caller passed custom
        # (non-contiguous) position ids. For the default arange ids this is
        # identical to slicing the first `seq_length` rows.
        position_embeddings_res = self.position_embedding_res[position_ids]  # [B, S, D]
        mask1 = self.mask1[position_ids]  # [B, S, 1]
        mask2 = self.mask2[position_ids]  # [B, S, 1]

        # Blend: base embeddings where mask1 == 1, residual where mask2 == 1.
        embeddings = (
            inputs_embeds
            + position_embeddings * mask1
            + position_embeddings_res * mask2
        )

        return embeddings
|
|
|
|
class LongCLIPTextTransformer(CLIPTextTransformer):
    """
    Text transformer stack for LongCLIP.

    Reuses the full CLIP text transformer (attention layers, final layer
    norm, pooling) and only swaps the embedding module for the LongCLIP
    variant with its dual positional tables.

    Args:
        config (LongCLIPTextConfig): Configuration for text transformer.
    """

    def __init__(self, config: LongCLIPTextConfig):
        super().__init__(config)
        # Replace the stock CLIP embeddings built by the parent constructor
        # with the LongCLIP embedding module.
        self.embeddings = LongCLIPTextEmbeddings(config)
|
|
|
|
class LongCLIPTextModel(CLIPTextModel):
    """
    HuggingFace-compatible LongCLIP text encoder.

    Behaves like ``CLIPTextModel`` but installs ``LongCLIPTextTransformer``,
    which supports the extended 248-token context via interpolated
    positional embeddings.

    Args:
        config (LongCLIPTextConfig): Configuration for the text model.
    """

    config_class = LongCLIPTextConfig

    def __init__(self, config: LongCLIPTextConfig):
        super().__init__(config)
        # Swap in the LongCLIP transformer, then re-run weight
        # initialization / final processing for the new submodule.
        self.text_model = LongCLIPTextTransformer(config)
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding layer of the underlying transformer."""
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value: nn.Module):
        """Replace the token embedding layer of the underlying transformer."""
        self.text_model.embeddings.token_embedding = value
|
|
|
|
class LongCLIPVisionModel(CLIPVisionModel):
    """
    Vision tower for LongCLIP.

    LongCLIP leaves the image encoder untouched, so this subclass adds no
    behavior over ``CLIPVisionModel``; it exists for API consistency and so
    the vision tower is associated with the LongCLIP config class.

    Args:
        config (LongCLIPVisionConfig): Configuration for the vision model.
    """

    config_class = LongCLIPVisionConfig
|
|
|
|
class LongCLIPModel(CLIPModel):
    """
    LongCLIP model combining text and vision encoders.

    This model extends CLIPModel to use LongCLIPTextModel with 248 token
    context length while keeping the standard vision encoder.

    Args:
        config (LongCLIPConfig): Configuration for the complete model.
    """

    config_class = LongCLIPConfig

    def __init__(self, config: LongCLIPConfig):
        # Parent constructor builds the standard CLIP towers and the
        # text/visual projection layers; both towers are replaced below.
        super().__init__(config)

        # Coerce the text sub-config to the LongCLIP variant if needed.
        # NOTE(review): `**config.text_config` assumes the non-LongCLIP case
        # is a plain dict; a CLIPTextConfig object would not unpack with
        # `**` — confirm against how configs actually arrive here.
        if not isinstance(config.text_config, LongCLIPTextConfig):
            text_config = LongCLIPTextConfig(**config.text_config)
        else:
            text_config = config.text_config

        # NOTE(review): this installs a full CLIPTextModel subclass where the
        # parent normally keeps a bare text transformer, so the inner
        # transformer lives at `self.text_model.text_model` — verify that
        # inherited CLIPModel code paths tolerate this extra nesting.
        self.text_model = LongCLIPTextModel(text_config)

        # Same coercion for the vision sub-config (encoder is unchanged;
        # LongCLIPVisionModel exists for API consistency).
        if not isinstance(config.vision_config, LongCLIPVisionConfig):
            vision_config = LongCLIPVisionConfig(**config.vision_config)
        else:
            vision_config = config.vision_config

        self.vision_model = LongCLIPVisionModel(vision_config)

        # Re-run weight initialization / final processing for the new towers.
        self.post_init()

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """
        Get text features from the text encoder.

        Args:
            input_ids: Token IDs.
            attention_mask: Attention mask.
            position_ids: Position IDs.
            output_attentions: Whether to output attention weights.
            output_hidden_states: Whether to output hidden states.
            return_dict: Whether to return a ModelOutput object.

        Returns:
            Text features of shape [batch_size, projection_dim].
        """
        # Fall back to the model-wide default when the caller does not choose.
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Tuple output places pooler_output at index 1.
        pooled_output = (
            text_outputs[1] if not return_dict else text_outputs.pooler_output
        )
        # Project pooled text representation into the shared embedding space.
        text_features = self.text_projection(pooled_output)

        return text_features

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """
        Get image features from the vision encoder.

        Args:
            pixel_values: Pixel values.
            output_attentions: Whether to output attention weights.
            output_hidden_states: Whether to output hidden states.
            return_dict: Whether to return a ModelOutput object.

        Returns:
            Image features of shape [batch_size, projection_dim].
        """
        # Fall back to the model-wide default when the caller does not choose.
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Tuple output places pooler_output at index 1.
        pooled_output = (
            vision_outputs[1] if not return_dict else vision_outputs.pooler_output
        )
        # Project pooled image representation into the shared embedding space.
        image_features = self.visual_projection(pooled_output)

        return image_features
|
|
|
|
| |
|
|
|
|
class LongCLIPProcessor(ProcessorMixin):
    """
    Processor for LongCLIP that combines image and text preprocessing.

    This processor wraps CLIPImageProcessor and CLIPTokenizer to provide
    a unified interface for preprocessing inputs for LongCLIP models.

    Args:
        image_processor (CLIPImageProcessor): Image processor for preprocessing images.
        tokenizer (CLIPTokenizer): Tokenizer for preprocessing text.

    Attributes:
        image_processor_class (str): Name of the image processor class.
        tokenizer_class (str): Name of the tokenizer class.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "CLIPImageProcessor"
    tokenizer_class = "CLIPTokenizer"

    def __init__(
        self,
        image_processor: Optional[CLIPImageProcessor] = None,
        tokenizer: Optional[CLIPTokenizer] = None,
        **kwargs,
    ):
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        text: Union[str, List[str], None] = None,
        images=None,
        return_tensors: Optional[str] = "pt",
        padding: Union[bool, str] = True,
        max_length: Optional[int] = 248,
        truncation: Optional[bool] = True,
        **kwargs,
    ):
        """
        Preprocess text and images for LongCLIP model.

        Args:
            text (str, List[str], optional): Text or list of texts to process.
            images: Image or list of images to process. Can be PIL Image, numpy array, or tensor.
            return_tensors (str, optional): Type of tensors to return ('pt' for PyTorch).
            padding (bool or str, optional): Padding strategy. Defaults to True.
            max_length (int, optional): Maximum sequence length. Defaults to 248 for LongCLIP.
            truncation (bool, optional): Whether to truncate sequences. Defaults to True.
            **kwargs: Additional keyword arguments.

        Returns:
            BatchEncoding: Dictionary-like object containing processed inputs with keys:
                - input_ids: Tokenized text (if text provided)
                - attention_mask: Attention mask for text (if text provided)
                - pixel_values: Processed images (if images provided)

        Raises:
            ValueError: If neither `text` nor `images` is provided.
        """
        # Fail loudly instead of silently returning an empty mapping, matching
        # the behavior of the standard CLIP processor.
        if text is None and images is None:
            raise ValueError(
                "You have to specify either `text` or `images`. Both cannot be None."
            )

        encoding = None
        if text is not None:
            encoding = self.tokenizer(
                text,
                return_tensors=return_tensors,
                padding=padding,
                max_length=max_length,
                truncation=truncation,
                **kwargs,
            )

        if images is not None:
            image_features = self.image_processor(
                images,
                return_tensors=return_tensors,
            )
            if encoding is not None:
                # Merge pixel values into the tokenizer's BatchEncoding so the
                # returned object keeps its conveniences (e.g. `.to(device)`).
                encoding["pixel_values"] = image_features["pixel_values"]
            else:
                encoding = image_features

        return encoding

    def batch_decode(self, *args, **kwargs):
        """
        Decode token IDs back to text.

        This method is forwarded to the tokenizer's batch_decode method.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        Decode token IDs back to text.

        This method is forwarded to the tokenizer's decode method.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """
        Get the names of model inputs.

        Returns:
            List[str]: De-duplicated list of input names from the tokenizer
            and the image processor, in first-seen order.
        """
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
|
|
|
| |
# Register LongCLIP with the Auto* factories so AutoConfig/AutoModel can
# resolve checkpoints saved with these model_type strings. Imported here,
# after the class definitions, so the classes exist when registered.
from transformers import AutoConfig, AutoModel

AutoConfig.register("longclip", LongCLIPConfig)
AutoModel.register(LongCLIPConfig, LongCLIPModel)

# The text/vision sub-configs declare their own model_type strings
# ("longclip_text_model" / "longclip_vision_model"), so register them too;
# otherwise standalone sub-model checkpoints cannot be loaded via Auto*.
AutoConfig.register("longclip_text_model", LongCLIPTextConfig)
AutoConfig.register("longclip_vision_model", LongCLIPVisionConfig)
AutoModel.register(LongCLIPTextConfig, LongCLIPTextModel)
AutoModel.register(LongCLIPVisionConfig, LongCLIPVisionModel)
|
|