|
|
""" |
|
|
LongCLIP configuration classes.

These configuration classes extend the standard CLIP configuration to support
the extended context length and custom positional embeddings of LongCLIP.
|
|
""" |
|
|
|
|
|
import logging
from typing import Any, Dict, Optional, Union

from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from transformers.configuration_utils import PretrainedConfig
|
|
|
|
|
|
|
|
class LongCLIPTextConfig(CLIPTextConfig):
    """
    Text-encoder configuration for LongCLIP.

    Subclasses CLIPTextConfig, changing the default context length to
    248 tokens and recording two extra settings that describe positional
    embedding interpolation.

    Args:
        max_position_embeddings (int, optional): Maximum sequence length,
            forwarded to CLIPTextConfig. Defaults to 248.
        use_position_interpolation (bool, optional): Whether to use position
            interpolation. Defaults to True.
        interpolation_keep_length (int, optional): Number of positions to keep
            from the original embeddings before interpolation. Defaults to 20.
        **kwargs: Additional arguments passed to CLIPTextConfig.
    """

    model_type = "longclip_text_model"

    def __init__(
        self,
        max_position_embeddings: int = 248,
        use_position_interpolation: bool = True,
        interpolation_keep_length: int = 20,
        **kwargs: Any,
    ) -> None:
        # The context length is the only LongCLIP-specific value the parent
        # initializer needs; everything else is passed through untouched.
        parent_kwargs = {"max_position_embeddings": max_position_embeddings, **kwargs}
        super().__init__(**parent_kwargs)

        # LongCLIP-only settings are stored on the instance after the parent
        # initializer has run.
        self.use_position_interpolation = use_position_interpolation
        self.interpolation_keep_length = interpolation_keep_length
|
|
|
|
|
|
|
|
class LongCLIPVisionConfig(CLIPVisionConfig):
    """
    Vision-encoder configuration for LongCLIP.

    LongCLIP does not modify the vision encoder, so this class adds no
    fields to CLIPVisionConfig; it only carries its own ``model_type``.

    Args:
        **kwargs: Arguments passed to CLIPVisionConfig.
    """

    model_type = "longclip_vision_model"

    def __init__(self, **kwargs: Any) -> None:
        # Keyword-only pass-through to the parent initializer.
        super().__init__(**kwargs)
|
|
|
|
|
|
|
|
class LongCLIPConfig(CLIPConfig):
    """
    Configuration class for LongCLIP model.

    Combines LongCLIPTextConfig and LongCLIPVisionConfig to create
    a complete LongCLIP model configuration.

    Args:
        text_config (Dict[str, Any] or LongCLIPTextConfig, optional):
            Configuration for the text model. If None, uses default LongCLIPTextConfig.
        vision_config (Dict[str, Any] or LongCLIPVisionConfig, optional):
            Configuration for the vision model. If None, uses default LongCLIPVisionConfig.
        projection_dim (int, optional): Dimensionality of text and vision projection layers.
            Defaults to 512.
        **kwargs: Additional arguments passed to CLIPConfig.

    Example:
        ```python
        >>> from long_clip_hf import LongCLIPConfig
        >>> # Initialize with default settings
        >>> config = LongCLIPConfig()
        >>>
        >>> # Initialize with custom text config
        >>> text_config = {"max_position_embeddings": 248, "hidden_size": 512}
        >>> config = LongCLIPConfig(text_config=text_config)
        >>>
        >>> # Save config
        >>> config.save_pretrained("./my-longclip-config")
        >>>
        >>> # Load config
        >>> config = LongCLIPConfig.from_pretrained("./my-longclip-config")
        ```
    """

    model_type = "longclip"
    is_composition = True

    def __init__(
        self,
        # Optional/Union instead of ``X | None``: the latter is evaluated when
        # the function is defined and fails on Python < 3.10 for typing.Dict.
        text_config: Optional[Union[Dict[str, Any], LongCLIPTextConfig]] = None,
        vision_config: Optional[Union[Dict[str, Any], LongCLIPVisionConfig]] = None,
        projection_dim: int = 512,
        **kwargs: Any,
    ) -> None:
        # A missing sub-config is treated as an empty dict, i.e. all defaults.
        if text_config is None:
            text_config = {}
            logger.info(
                "text_config is None. Initializing the LongCLIPTextConfig with default values."
            )
        if vision_config is None:
            vision_config = {}
            logger.info(
                "vision_config is None. Initializing the LongCLIPVisionConfig with default values."
            )

        # Dicts are promoted to LongCLIP config objects; values that are
        # already config objects are used as given.
        if isinstance(text_config, dict):
            text_config = LongCLIPTextConfig(**text_config)
        if isinstance(vision_config, dict):
            vision_config = LongCLIPVisionConfig(**vision_config)

        # The parent initializer is handed plain dicts.
        super().__init__(
            text_config=text_config.to_dict(),
            vision_config=vision_config.to_dict(),
            projection_dim=projection_dim,
            **kwargs,
        )

        # Assigned after the parent initializer so that the LongCLIP config
        # objects are what ends up stored on the instance.
        self.text_config = text_config
        self.vision_config = vision_config

    @classmethod
    def from_text_vision_configs(
        cls,
        text_config: LongCLIPTextConfig,
        vision_config: LongCLIPVisionConfig,
        **kwargs: Any,
    ) -> "LongCLIPConfig":
        """
        Instantiate a LongCLIPConfig from text and vision configs.

        Both configs are serialized with ``to_dict()`` and passed to the
        constructor, which builds fresh sub-config objects from those dicts.

        Args:
            text_config (LongCLIPTextConfig): Text model configuration.
            vision_config (LongCLIPVisionConfig): Vision model configuration.
            **kwargs: Additional keyword arguments.

        Returns:
            LongCLIPConfig: Configuration object.
        """
        return cls(
            text_config=text_config.to_dict(),
            vision_config=vision_config.to_dict(),
            **kwargs,
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Starts from the parent serialization, then replaces the
        ``text_config`` / ``vision_config`` entries with the dict form of the
        stored sub-config whenever that attribute is a PretrainedConfig.

        Returns:
            Dict[str, Any]: Dictionary of all attributes.
        """
        output = super().to_dict()

        for key in ("text_config", "vision_config"):
            # getattr with a default covers the case where the attribute has
            # not been set on this instance.
            sub_config = getattr(self, key, None)
            if isinstance(sub_config, PretrainedConfig):
                output[key] = sub_config.to_dict()

        return output
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger used by LongCLIPConfig.__init__.
# ``logging`` is imported in the import block at the top of the file.
# Binding the name after the class definitions is safe: ``logger`` is only
# looked up when the initializer runs, by which time the module is fully loaded.
logger = logging.getLogger(__name__)
|
|
|