File size: 5,997 Bytes
baeaa93 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
"""
LongCLIP configuration classes.
These configuration classes extend the standard CLIP configuration to support
the extended context length and custom positional embeddings of LongCLIP.
"""
import logging
from collections.abc import Mapping
from typing import Any, Dict

from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from transformers.configuration_utils import PretrainedConfig
class LongCLIPTextConfig(CLIPTextConfig):
    """
    Text-tower configuration for LongCLIP.

    Behaves like ``CLIPTextConfig`` but defaults to a 248-token context
    length and carries two extra fields that describe how the positional
    embeddings are interpolated.

    Args:
        max_position_embeddings (int, optional): Longest input sequence the
            text encoder accepts. Defaults to 248.
        use_position_interpolation (bool, optional): Whether positional
            embedding interpolation is enabled. Defaults to True.
        interpolation_keep_length (int, optional): How many of the original
            positional embeddings are kept before interpolation.
            Defaults to 20.
        **kwargs: Forwarded unchanged to ``CLIPTextConfig``.
    """

    model_type = "longclip_text_model"

    def __init__(
        self,
        max_position_embeddings: int = 248,
        use_position_interpolation: bool = True,
        interpolation_keep_length: int = 20,
        **kwargs,
    ):
        # ``kwargs`` is a fresh dict for every call, so adding the context
        # length to it cannot leak into the caller.
        kwargs["max_position_embeddings"] = max_position_embeddings
        super().__init__(**kwargs)

        # LongCLIP-specific fields, stored after the parent initializer runs.
        self.use_position_interpolation = use_position_interpolation
        self.interpolation_keep_length = interpolation_keep_length
class LongCLIPVisionConfig(CLIPVisionConfig):
    """
    Vision-tower configuration for LongCLIP.

    LongCLIP does not modify the vision encoder. This class is therefore the
    standard ``CLIPVisionConfig`` under its own ``model_type``.

    Args:
        **kwargs: Forwarded unchanged to ``CLIPVisionConfig``.
    """

    model_type = "longclip_vision_model"

    def __init__(self, **kwargs):
        # No extra fields: every option is handled by CLIPVisionConfig.
        super().__init__(**kwargs)
class LongCLIPConfig(CLIPConfig):
    """
    Configuration class for LongCLIP model.

    Combines LongCLIPTextConfig and LongCLIPVisionConfig to create
    a complete LongCLIP model configuration.

    Args:
        text_config (Mapping[str, Any] or PretrainedConfig, optional):
            Configuration for the text model. A mapping is expanded into
            ``LongCLIPTextConfig(**text_config)``. A config instance
            (normally a LongCLIPTextConfig) is stored as-is. If None, a
            default LongCLIPTextConfig is created.
        vision_config (Mapping[str, Any] or PretrainedConfig, optional):
            Same as ``text_config``, but for the vision model and
            ``LongCLIPVisionConfig``.
        projection_dim (int, optional): Dimensionality of text and vision
            projection layers. Defaults to 512.
        **kwargs: Additional arguments passed to CLIPConfig.

    Raises:
        TypeError: If ``text_config`` or ``vision_config`` is not None,
            a mapping, or a ``PretrainedConfig`` instance.

    Example:
        ```python
        >>> from long_clip_hf import LongCLIPConfig
        >>> # Initialize with default settings
        >>> config = LongCLIPConfig()
        >>>
        >>> # Initialize with custom text config
        >>> text_config = {"max_position_embeddings": 248, "hidden_size": 512}
        >>> config = LongCLIPConfig(text_config=text_config)
        >>>
        >>> # Save config
        >>> config.save_pretrained("./my-longclip-config")
        >>>
        >>> # Load config
        >>> config = LongCLIPConfig.from_pretrained("./my-longclip-config")
        ```
    """

    model_type = "longclip"
    is_composition = True

    def __init__(
        self,
        text_config: Mapping[str, Any] | PretrainedConfig | None = None,
        vision_config: Mapping[str, Any] | PretrainedConfig | None = None,
        projection_dim: int = 512,
        **kwargs,
    ):
        text_config = self._coerce_sub_config(
            text_config, LongCLIPTextConfig, "text_config"
        )
        vision_config = self._coerce_sub_config(
            vision_config, LongCLIPVisionConfig, "vision_config"
        )

        # The parent initializer is given plain dicts for the sub-configs.
        super().__init__(
            text_config=text_config.to_dict(),
            vision_config=vision_config.to_dict(),
            projection_dim=projection_dim,
            **kwargs,
        )

        # Overwrite whatever the parent stored so these attributes hold the
        # config objects built above (LongCLIP-typed when built from a mapping).
        self.text_config = text_config
        self.vision_config = vision_config

    @staticmethod
    def _coerce_sub_config(
        value: Mapping[str, Any] | PretrainedConfig | None,
        config_cls: type,
        name: str,
    ) -> PretrainedConfig:
        """
        Normalize one sub-config argument into a config object.

        Args:
            value: None (use defaults), a mapping of constructor kwargs for
                ``config_cls``, or an existing ``PretrainedConfig`` instance.
            config_cls: Config class to instantiate from a mapping.
            name: Argument name, used in the log and error messages.

        Returns:
            PretrainedConfig: ``value`` itself if it is already a config,
            otherwise a new ``config_cls`` instance.

        Raises:
            TypeError: If ``value`` is of any other type.
        """
        if value is None:
            logger.info(
                "%s is None. Initializing the %s with default values.",
                name,
                config_cls.__name__,
            )
            value = {}
        # Check for a config first so an existing instance is kept untouched.
        if isinstance(value, PretrainedConfig):
            return value
        if isinstance(value, Mapping):
            return config_cls(**value)
        raise TypeError(
            f"{name} must be a mapping, a PretrainedConfig, or None; "
            f"got {type(value).__name__}."
        )

    @classmethod
    def from_text_vision_configs(
        cls,
        text_config: LongCLIPTextConfig,
        vision_config: LongCLIPVisionConfig,
        **kwargs,
    ) -> "LongCLIPConfig":
        """
        Instantiate a LongCLIPConfig from text and vision configs.

        Args:
            text_config (LongCLIPTextConfig): Text model configuration.
            vision_config (LongCLIPVisionConfig): Vision model configuration.
            **kwargs: Additional keyword arguments forwarded to ``cls``.

        Returns:
            LongCLIPConfig: Configuration object.
        """
        return cls(
            text_config=text_config.to_dict(),
            vision_config=vision_config.to_dict(),
            **kwargs,
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this instance to a Python dictionary.

        Any ``text_config`` / ``vision_config`` attribute that is a
        ``PretrainedConfig`` is replaced by its own ``to_dict()`` output.

        Returns:
            Dict[str, Any]: Dictionary of all attributes.
        """
        output = super().to_dict()
        for key in ("text_config", "vision_config"):
            sub_config = getattr(self, key, None)
            if isinstance(sub_config, PretrainedConfig):
                output[key] = sub_config.to_dict()
        return output
# Module-level logger used by LongCLIPConfig. The name is looked up when a
# config is instantiated, so defining it after the class bodies is safe.
logger = logging.getLogger(__name__)
|