"""
LongCLIP model implementation compatible with HuggingFace Transformers.
This module provides transformers-compatible implementations of LongCLIP models.
"""
from typing import Optional
import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPVisionModel, CLIPModel
from transformers.models.clip.modeling_clip import (
CLIPTextTransformer,
)
from .configuration_longclip import (
LongCLIPConfig,
LongCLIPTextConfig,
LongCLIPVisionConfig,
)
class LongCLIPTextEmbeddings(nn.Module):
"""
Text embeddings for LongCLIP with custom positional embedding mechanism.
This module implements the dual positional embedding approach used in LongCLIP:
- The first `interpolation_keep_length` positions (20 in the reference setup) keep the original CLIP positional embeddings (mask1)
- The remaining positions (up to `max_position_embeddings`, 248 in the reference setup) use the interpolated embeddings stored in `position_embedding_res` (mask2)
- position_embedding: Fixed base embeddings
- position_embedding_res: Trainable residual embeddings
Args:
config (LongCLIPTextConfig): Configuration for text embeddings.
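Example:
A minimal sketch assuming the default `LongCLIPTextConfig` (imported from
`long_clip_hf` as in the model examples below); shapes are illustrative:
```python
>>> import torch
>>> from long_clip_hf import LongCLIPTextConfig
>>>
>>> config = LongCLIPTextConfig()
>>> embeddings = LongCLIPTextEmbeddings(config)
>>>
>>> input_ids = torch.randint(0, config.vocab_size, (2, 16))
>>> hidden_states = embeddings(input_ids=input_ids)
>>> hidden_states.shape  # [batch_size, seq_length, hidden_size]
```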
"""
def __init__(self, config: LongCLIPTextConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
# Token embeddings
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
# Dual positional embeddings (LongCLIP approach)
# position_embedding: Base embeddings (typically loaded from checkpoint)
self.position_embedding = nn.Embedding(
config.max_position_embeddings, embed_dim
)
# position_embedding_res: Trainable residual embeddings
self.position_embedding_res = nn.Parameter(
torch.zeros(config.max_position_embeddings, embed_dim)
)
# Create masks for applying embeddings
# mask1: Use original embeddings for first interpolation_keep_length positions
# mask2: Use interpolated embeddings for remaining positions
self.register_buffer(
"mask1", self._create_mask(config, use_first=True), persistent=False
)
self.register_buffer(
"mask2", self._create_mask(config, use_first=False), persistent=False
)
# Store position IDs for efficiency
self.register_buffer(
"position_ids",
torch.arange(config.max_position_embeddings).expand((1, -1)),
persistent=False,
)
def _create_mask(self, config: LongCLIPTextConfig, use_first: bool) -> torch.Tensor:
"""
Create mask for positional embeddings.
Args:
config: Configuration object.
use_first: If True, the mask is 1.0 for the first `interpolation_keep_length`
positions and 0.0 elsewhere. If False, the mask is 1.0 for the remaining
positions and 0.0 for the first `interpolation_keep_length` positions.
Returns:
Mask tensor of shape [max_position_embeddings, 1].
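Example:
With the values referenced elsewhere in this module
(`interpolation_keep_length=20`, `max_position_embeddings=248`), mask1 is
1.0 at positions 0-19 and 0.0 elsewhere, while mask2 is the complement
(1.0 at positions 20-247).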
"""
mask = torch.zeros(config.max_position_embeddings, 1)
if use_first:
# mask1: First interpolation_keep_length positions
mask[: config.interpolation_keep_length] = 1.0
else:
# mask2: Remaining positions
mask[config.interpolation_keep_length :] = 1.0
return mask
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
"""
Forward pass for text embeddings.
Args:
input_ids: Token IDs of shape [batch_size, seq_length].
position_ids: Position IDs of shape [batch_size, seq_length].
inputs_embeds: Pre-computed token embeddings.
Returns:
Embeddings of shape [batch_size, seq_length, hidden_size].
"""
seq_length = (
input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
)
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
# Get token embeddings
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
# Get positional embeddings
position_embeddings = self.position_embedding(position_ids)
# Add residual positional embeddings (for positions at or beyond interpolation_keep_length)
# Expand position_embedding_res for batch dimension
position_embeddings_res = self.position_embedding_res.unsqueeze(0).expand(
position_ids.shape[0], -1, -1
)[:, :seq_length, :]
# Apply masks: mask1 selects the first interpolation_keep_length positions,
# mask2 selects the remaining positions.
# Each sliced mask is [seq_length, 1]; transposing gives [1, seq_length], and the
# unsqueeze below yields [1, seq_length, 1] for broadcasting against
# [batch_size, seq_length, hidden_size].
mask1 = self.mask1[:seq_length].transpose(0, 1)  # [1, seq_length]
mask2 = self.mask2[:seq_length].transpose(0, 1)  # [1, seq_length]
# Combine embeddings with masking
embeddings = (
inputs_embeds
+ position_embeddings * mask1.unsqueeze(-1)
+ position_embeddings_res * mask2.unsqueeze(-1)
)
return embeddings
class LongCLIPTextTransformer(CLIPTextTransformer):
"""
Text transformer for LongCLIP.
This extends CLIPTextTransformer to use LongCLIPTextEmbeddings
with custom positional embedding mechanism.
Args:
config (LongCLIPTextConfig): Configuration for text transformer.
"""
def __init__(self, config: LongCLIPTextConfig):
super().__init__(config)
# Replace embeddings with LongCLIP version
self.embeddings = LongCLIPTextEmbeddings(config)
class LongCLIPTextModel(CLIPTextModel):
"""
LongCLIP text model compatible with HuggingFace Transformers.
This model extends CLIPTextModel to support 248 token context length
with custom positional embedding interpolation.
Args:
config (LongCLIPTextConfig): Configuration for the text model.
Example:
```python
>>> from long_clip_hf import LongCLIPTextConfig, LongCLIPTextModel
>>> from transformers import CLIPTokenizer
>>>
>>> # Initialize model
>>> config = LongCLIPTextConfig()
>>> model = LongCLIPTextModel(config)
>>>
>>> # Tokenize text
>>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
>>> inputs = tokenizer(
... ["a photo of a cat"],
... return_tensors="pt",
... padding="max_length",
... max_length=248,
... truncation=True,
... )
>>>
>>> # Get text features
>>> outputs = model(**inputs)
>>> text_features = outputs.pooler_output
```
"""
config_class = LongCLIPTextConfig
def __init__(self, config: LongCLIPTextConfig):
super().__init__(config)
# Replace text_model with LongCLIP version
self.text_model = LongCLIPTextTransformer(config)
# Initialize weights
self.post_init()
def get_input_embeddings(self) -> nn.Module:
"""Get token embedding layer."""
return self.text_model.embeddings.token_embedding
def set_input_embeddings(self, value: nn.Module):
"""Set token embedding layer."""
self.text_model.embeddings.token_embedding = value
class LongCLIPVisionModel(CLIPVisionModel):
"""
LongCLIP vision model.
This is identical to CLIPVisionModel as LongCLIP does not modify
the vision encoder. Provided for API consistency.
Args:
config (LongCLIPVisionConfig): Configuration for the vision model.
Example:
```python
>>> from long_clip_hf import LongCLIPVisionConfig, LongCLIPVisionModel
>>> from transformers import CLIPImageProcessor
>>> from PIL import Image
>>>
>>> # Initialize model
>>> config = LongCLIPVisionConfig()
>>> model = LongCLIPVisionModel(config)
>>>
>>> # Process image
>>> processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
>>> image = Image.open("path/to/image.jpg")
>>> inputs = processor(images=image, return_tensors="pt")
>>>
>>> # Get image features
>>> outputs = model(**inputs)
>>> image_features = outputs.pooler_output
```
"""
config_class = LongCLIPVisionConfig
class LongCLIPModel(CLIPModel):
"""
LongCLIP model combining text and vision encoders.
This model extends CLIPModel to use LongCLIPTextModel with 248 token
context length while keeping the standard vision encoder.
Args:
config (LongCLIPConfig): Configuration for the complete model.
Example:
```python
>>> from long_clip_hf import LongCLIPConfig, LongCLIPModel
>>> from transformers import CLIPTokenizer, CLIPImageProcessor
>>> from PIL import Image
>>>
>>> # Initialize model
>>> config = LongCLIPConfig()
>>> model = LongCLIPModel(config)
>>>
>>> # Prepare inputs
>>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
>>> processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
>>>
>>> text = "a photo of a cat"
>>> image = Image.open("path/to/image.jpg")
>>>
>>> text_inputs = tokenizer(
... [text],
... return_tensors="pt",
... padding="max_length",
... max_length=248,
... truncation=True,
... )
>>> image_inputs = processor(images=image, return_tensors="pt")
>>>
>>> # Get features
>>> outputs = model(
... input_ids=text_inputs["input_ids"],
... pixel_values=image_inputs["pixel_values"],
... )
>>>
>>> # Compute similarity
>>> logits_per_image = outputs.logits_per_image
>>> probs = logits_per_image.softmax(dim=1)
```
"""
config_class = LongCLIPConfig
def __init__(self, config: LongCLIPConfig):
super().__init__(config)
# Replace text model with LongCLIP version
if not isinstance(config.text_config, LongCLIPTextConfig):
text_config = LongCLIPTextConfig(**config.text_config)
else:
text_config = config.text_config
self.text_model = LongCLIPTextModel(text_config)
# Vision model stays the same (standard CLIP)
if not isinstance(config.vision_config, LongCLIPVisionConfig):
vision_config = LongCLIPVisionConfig(**config.vision_config)
else:
vision_config = config.vision_config
self.vision_model = LongCLIPVisionModel(vision_config)
# Initialize weights
self.post_init()
def get_text_features(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> torch.FloatTensor:
"""
Get text features from the text encoder.
Args:
input_ids: Token IDs.
attention_mask: Attention mask.
position_ids: Position IDs.
output_attentions: Whether to output attention weights.
output_hidden_states: Whether to output hidden states.
return_dict: Whether to return a ModelOutput object.
Returns:
Text features of shape [batch_size, projection_dim].
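Example:
A minimal sketch, assuming `model` is a loaded `LongCLIPModel` and `tokenizer`
is a CLIP tokenizer (as in the class-level example above):
```python
>>> inputs = tokenizer(
... ["a photo of a cat"],
... return_tensors="pt",
... padding="max_length",
... max_length=248,
... truncation=True,
... )
>>> text_features = model.get_text_features(input_ids=inputs["input_ids"])
>>> text_features = text_features / text_features.norm(dim=-1, keepdim=True)
```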
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
text_outputs = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = (
text_outputs[1] if not return_dict else text_outputs.pooler_output
)
text_features = self.text_projection(pooled_output)
return text_features
def get_image_features(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> torch.FloatTensor:
"""
Get image features from the vision encoder.
Args:
pixel_values: Pixel values.
output_attentions: Whether to output attention weights.
output_hidden_states: Whether to output hidden states.
return_dict: Whether to return a ModelOutput object.
Returns:
Image features of shape [batch_size, projection_dim].
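Example:
A minimal sketch, assuming `model` is a loaded `LongCLIPModel`, `processor` is
a `CLIPImageProcessor`, and `image` is a PIL image (as in the class-level
example above):
```python
>>> inputs = processor(images=image, return_tensors="pt")
>>> image_features = model.get_image_features(pixel_values=inputs["pixel_values"])
>>> image_features = image_features / image_features.norm(dim=-1, keepdim=True)
```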
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
pooled_output = (
vision_outputs[1] if not return_dict else vision_outputs.pooler_output
)
image_features = self.visual_projection(pooled_output)
return image_features