"""
LongCLIP model implementation compatible with HuggingFace Transformers.

This module provides transformers-compatible implementations of LongCLIP models.
"""

from typing import Optional

import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPVisionModel, CLIPModel
from transformers.models.clip.modeling_clip import (
    CLIPTextTransformer,
)

from .configuration_longclip import (
    LongCLIPConfig,
    LongCLIPTextConfig,
    LongCLIPVisionConfig,
)


class LongCLIPTextEmbeddings(nn.Module):
    """
    Text embeddings for LongCLIP with a custom positional embedding mechanism.

    This module implements the dual positional embedding approach used in LongCLIP:

    - The first 20 positions use the original CLIP positional embeddings,
      selected by `mask1`.
    - The remaining positions (21-248) use the interpolated embeddings,
      selected by `mask2`.
    - `position_embedding`: fixed base embeddings.
    - `position_embedding_res`: trainable residual embeddings used for the
      interpolated positions.

    Args:
        config (LongCLIPTextConfig): Configuration for text embeddings.

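    Example:
        A minimal sketch of how the module is used (assumes the default
        configuration; shapes are illustrative):

        ```python
        >>> import torch
        >>> from long_clip_hf import LongCLIPTextConfig
        >>>
        >>> config = LongCLIPTextConfig()
        >>> embeddings = LongCLIPTextEmbeddings(config)
        >>>
        >>> input_ids = torch.randint(0, config.vocab_size, (2, config.max_position_embeddings))
        >>> output = embeddings(input_ids=input_ids)
        >>> output.shape  # [batch_size, seq_length, hidden_size]
        ```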
    """

    def __init__(self, config: LongCLIPTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)

        # Base positional embeddings; only the first `interpolation_keep_length`
        # positions contribute (selected by `mask1`).
        self.position_embedding = nn.Embedding(
            config.max_position_embeddings, embed_dim
        )

        # Trainable residual positional embeddings used for the remaining,
        # interpolated positions (selected by `mask2`).
        self.position_embedding_res = nn.Parameter(
            torch.zeros(config.max_position_embeddings, embed_dim)
        )

        # Complementary 0/1 masks selecting which embedding table applies to
        # each position.
        self.register_buffer(
            "mask1", self._create_mask(config, use_first=True), persistent=False
        )
        self.register_buffer(
            "mask2", self._create_mask(config, use_first=False), persistent=False
        )

        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

    def _create_mask(self, config: LongCLIPTextConfig, use_first: bool) -> torch.Tensor:
        """
        Create a 0/1 selection mask for the positional embeddings.

        Args:
            config: Configuration object.
            use_first: If True, the mask is 1 for the first
                `interpolation_keep_length` positions and 0 elsewhere.
                If False, the mask is 1 for the remaining positions.

        Returns:
            Mask tensor of shape [max_position_embeddings, 1].

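        Example:
            An illustrative sketch of the resulting buffers (assumes the
            default `interpolation_keep_length` of 20):

            ```python
            >>> emb = LongCLIPTextEmbeddings(LongCLIPTextConfig())
            >>> emb.mask1[:20].sum().item(), emb.mask1[20:].sum().item()
            (20.0, 0.0)
            >>> (emb.mask1 + emb.mask2).unique()  # the two masks are complementary
            tensor([1.])
            ```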
        """
        mask = torch.zeros(config.max_position_embeddings, 1)
        if use_first:
            # Select the positions that keep the original CLIP embeddings.
            mask[: config.interpolation_keep_length] = 1.0
        else:
            # Select the interpolated positions.
            mask[config.interpolation_keep_length :] = 1.0
        return mask

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for text embeddings.

        Args:
            input_ids: Token IDs of shape [batch_size, seq_length].
            position_ids: Position IDs of shape [batch_size, seq_length].
            inputs_embeds: Pre-computed token embeddings of shape
                [batch_size, seq_length, hidden_size].

        Returns:
            Embeddings of shape [batch_size, seq_length, hidden_size].
        """
        seq_length = (
            input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
        )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        # Base positional embeddings gathered at the given positions.
        position_embeddings = self.position_embedding(position_ids)

        # Residual positional embeddings, expanded to the batch and truncated to
        # the current sequence length.
        position_embeddings_res = self.position_embedding_res.unsqueeze(0).expand(
            position_ids.shape[0], -1, -1
        )[:, :seq_length, :]

        # Selection masks reshaped to [1, seq_length] for broadcasting.
        mask1 = self.mask1[:seq_length].transpose(0, 1)
        mask2 = self.mask2[:seq_length].transpose(0, 1)

        # Combine: original embeddings for the first positions (mask1) and
        # residual embeddings for the interpolated positions (mask2).
        embeddings = (
            inputs_embeds
            + position_embeddings * mask1.unsqueeze(-1)
            + position_embeddings_res * mask2.unsqueeze(-1)
        )

        return embeddings


class LongCLIPTextTransformer(CLIPTextTransformer):
    """
    Text transformer for LongCLIP.

    This extends CLIPTextTransformer to use LongCLIPTextEmbeddings with the
    custom positional embedding mechanism.

    Args:
        config (LongCLIPTextConfig): Configuration for the text transformer.
    """

    def __init__(self, config: LongCLIPTextConfig):
        super().__init__(config)
        # Replace the standard CLIP text embeddings with the LongCLIP variant.
        self.embeddings = LongCLIPTextEmbeddings(config)


class LongCLIPTextModel(CLIPTextModel):
    """
    LongCLIP text model compatible with HuggingFace Transformers.

    This model extends CLIPTextModel to support a 248-token context length
    with custom positional embedding interpolation.

    Args:
        config (LongCLIPTextConfig): Configuration for the text model.

    Example:
        ```python
        >>> from long_clip_hf import LongCLIPTextConfig, LongCLIPTextModel
        >>> from transformers import CLIPTokenizer
        >>>
        >>> # Initialize model
        >>> config = LongCLIPTextConfig()
        >>> model = LongCLIPTextModel(config)
        >>>
        >>> # Tokenize text
        >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        >>> inputs = tokenizer(
        ...     ["a photo of a cat"],
        ...     return_tensors="pt",
        ...     padding="max_length",
        ...     max_length=248,
        ...     truncation=True,
        ... )
        >>>
        >>> # Get text features
        >>> outputs = model(**inputs)
        >>> text_features = outputs.pooler_output
        ```
    """

    config_class = LongCLIPTextConfig

    def __init__(self, config: LongCLIPTextConfig):
        super().__init__(config)
        # Swap in the LongCLIP text transformer with the custom embeddings.
        self.text_model = LongCLIPTextTransformer(config)
        # Re-run weight initialization for the replaced module.
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Get the token embedding layer."""
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value: nn.Module):
        """Set the token embedding layer."""
        self.text_model.embeddings.token_embedding = value


class LongCLIPVisionModel(CLIPVisionModel):
    """
    LongCLIP vision model.

    This is identical to CLIPVisionModel, as LongCLIP does not modify the
    vision encoder. It is provided for API consistency.

    Args:
        config (LongCLIPVisionConfig): Configuration for the vision model.

    Example:
        ```python
        >>> from long_clip_hf import LongCLIPVisionConfig, LongCLIPVisionModel
        >>> from transformers import CLIPImageProcessor
        >>> from PIL import Image
        >>>
        >>> # Initialize model
        >>> config = LongCLIPVisionConfig()
        >>> model = LongCLIPVisionModel(config)
        >>>
        >>> # Process image
        >>> processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
        >>> image = Image.open("path/to/image.jpg")
        >>> inputs = processor(images=image, return_tensors="pt")
        >>>
        >>> # Get image features
        >>> outputs = model(**inputs)
        >>> image_features = outputs.pooler_output
        ```
    """

    config_class = LongCLIPVisionConfig


class LongCLIPModel(CLIPModel):
    """
    LongCLIP model combining text and vision encoders.

    This model extends CLIPModel to use LongCLIPTextModel with a 248-token
    context length while keeping the standard vision encoder.

    Args:
        config (LongCLIPConfig): Configuration for the complete model.

    Example:
        ```python
        >>> from long_clip_hf import LongCLIPConfig, LongCLIPModel
        >>> from transformers import CLIPTokenizer, CLIPImageProcessor
        >>> from PIL import Image
        >>>
        >>> # Initialize model
        >>> config = LongCLIPConfig()
        >>> model = LongCLIPModel(config)
        >>>
        >>> # Prepare inputs
        >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
        >>>
        >>> text = "a photo of a cat"
        >>> image = Image.open("path/to/image.jpg")
        >>>
        >>> text_inputs = tokenizer(
        ...     [text],
        ...     return_tensors="pt",
        ...     padding="max_length",
        ...     max_length=248,
        ...     truncation=True,
        ... )
        >>> image_inputs = processor(images=image, return_tensors="pt")
        >>>
        >>> # Get features
        >>> outputs = model(
        ...     input_ids=text_inputs["input_ids"],
        ...     pixel_values=image_inputs["pixel_values"],
        ... )
        >>>
        >>> # Compute similarity
        >>> logits_per_image = outputs.logits_per_image
        >>> probs = logits_per_image.softmax(dim=1)
        ```
    """

    config_class = LongCLIPConfig

    def __init__(self, config: LongCLIPConfig):
        super().__init__(config)

        # Ensure the text sub-config is the LongCLIP variant; it may arrive as a
        # plain CLIPTextConfig or a dict.
        text_config = config.text_config
        if not isinstance(text_config, LongCLIPTextConfig):
            if not isinstance(text_config, dict):
                text_config = text_config.to_dict()
            text_config = LongCLIPTextConfig(**text_config)

        # Replace the text encoder with the LongCLIP text model.
        self.text_model = LongCLIPTextModel(text_config)

        # Ensure the vision sub-config is the LongCLIP variant as well.
        vision_config = config.vision_config
        if not isinstance(vision_config, LongCLIPVisionConfig):
            if not isinstance(vision_config, dict):
                vision_config = vision_config.to_dict()
            vision_config = LongCLIPVisionConfig(**vision_config)

        # The vision encoder is unchanged; wrapped for API consistency.
        self.vision_model = LongCLIPVisionModel(vision_config)

        # Re-run weight initialization for the replaced modules.
        self.post_init()

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """
        Get text features from the text encoder.

        Args:
            input_ids: Token IDs.
            attention_mask: Attention mask.
            position_ids: Position IDs.
            output_attentions: Whether to output attention weights.
            output_hidden_states: Whether to output hidden states.
            return_dict: Whether to return a ModelOutput object.

        Returns:
            Text features of shape [batch_size, projection_dim].

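        Example:
            A minimal sketch (assumes `model` is a LongCLIPModel instance, as in
            the class-level example above):

            ```python
            >>> from transformers import CLIPTokenizer
            >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
            >>> inputs = tokenizer(
            ...     ["a photo of a cat"],
            ...     return_tensors="pt",
            ...     padding="max_length",
            ...     max_length=248,
            ...     truncation=True,
            ... )
            >>> text_features = model.get_text_features(**inputs)
            >>> text_features.shape  # [batch_size, projection_dim]
            ```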
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pooled output is the second tuple element, or `pooler_output` when a
        # ModelOutput is returned.
        pooled_output = (
            text_outputs[1] if not return_dict else text_outputs.pooler_output
        )
        text_features = self.text_projection(pooled_output)

        return text_features

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """
        Get image features from the vision encoder.

        Args:
            pixel_values: Pixel values.
            output_attentions: Whether to output attention weights.
            output_hidden_states: Whether to output hidden states.
            return_dict: Whether to return a ModelOutput object.

        Returns:
            Image features of shape [batch_size, projection_dim].

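        Example:
            A minimal sketch (assumes `model` is a LongCLIPModel instance, as in
            the class-level example above):

            ```python
            >>> from PIL import Image
            >>> from transformers import CLIPImageProcessor
            >>> processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
            >>> image = Image.open("path/to/image.jpg")
            >>> inputs = processor(images=image, return_tensors="pt")
            >>> image_features = model.get_image_features(**inputs)
            >>> image_features.shape  # [batch_size, projection_dim]
            ```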
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pooled output is the second tuple element, or `pooler_output` when a
        # ModelOutput is returned.
        pooled_output = (
            vision_outputs[1] if not return_dict else vision_outputs.pooler_output
        )
        image_features = self.visual_projection(pooled_output)

        return image_features