File size: 5,997 Bytes
baeaa93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
"""
LongCLIP configuration classes.

These configuration classes extend the standard CLIP configuration to support
the extended context length and custom positional embeddings of LongCLIP.
"""

from typing import Dict, Any
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from transformers.configuration_utils import PretrainedConfig


class LongCLIPTextConfig(CLIPTextConfig):
    """
    Configuration for the LongCLIP text encoder.

    Identical to ``CLIPTextConfig`` except that the default context window is
    widened to 248 tokens, and two extra knobs record how the pretrained
    positional embeddings should be stretched to cover that longer context.

    Args:
        max_position_embeddings (int, optional): Maximum sequence length.
            Defaults to 248.
        use_position_interpolation (bool, optional): Whether positional
            embeddings are interpolated to the extended length. Defaults to
            True.
        interpolation_keep_length (int, optional): Number of leading positions
            copied verbatim from the original embedding table before
            interpolation begins. Defaults to 20.
        **kwargs: Forwarded unchanged to ``CLIPTextConfig``.
    """

    model_type = "longclip_text_model"

    def __init__(
        self,
        max_position_embeddings: int = 248,
        use_position_interpolation: bool = True,
        interpolation_keep_length: int = 20,
        **kwargs,
    ):
        # The parent config owns max_position_embeddings; only the two
        # LongCLIP-specific interpolation settings live on this subclass.
        super().__init__(max_position_embeddings=max_position_embeddings, **kwargs)
        self.use_position_interpolation = use_position_interpolation
        self.interpolation_keep_length = interpolation_keep_length


class LongCLIPVisionConfig(CLIPVisionConfig):
    """
    Configuration class for the LongCLIP vision model.

    LongCLIP does not modify the vision encoder, so this subclass only
    overrides ``model_type`` (used for serialization and auto-class dispatch);
    all constructor behavior is inherited from ``CLIPVisionConfig``.

    Args:
        **kwargs: Arguments passed to CLIPVisionConfig.
    """

    # Fix: the original defined a no-op ``__init__(self, **kwargs)`` that only
    # called ``super().__init__(**kwargs)`` — it added nothing and silently
    # rejected the positional arguments the parent constructor accepts.
    # Inheriting the parent constructor is simpler and strictly more permissive.
    model_type = "longclip_vision_model"


class LongCLIPConfig(CLIPConfig):
    """
    Top-level configuration for a LongCLIP model.

    Pairs a :class:`LongCLIPTextConfig` with a :class:`LongCLIPVisionConfig`
    to describe a complete LongCLIP model.

    Args:
        text_config (Dict[str, Any] or LongCLIPTextConfig, optional):
            Text-model configuration; a plain dict is promoted to a
            LongCLIPTextConfig. None selects the defaults.
        vision_config (Dict[str, Any] or LongCLIPVisionConfig, optional):
            Vision-model configuration; a plain dict is promoted to a
            LongCLIPVisionConfig. None selects the defaults.
        projection_dim (int, optional): Dimensionality of the shared text and
            vision projection layers. Defaults to 512.
        **kwargs: Additional arguments passed to CLIPConfig.

    Example:
        ```python
        >>> from long_clip_hf import LongCLIPConfig
        >>> # Initialize with default settings
        >>> config = LongCLIPConfig()
        >>>
        >>> # Initialize with custom text config
        >>> text_config = {"max_position_embeddings": 248, "hidden_size": 512}
        >>> config = LongCLIPConfig(text_config=text_config)
        >>>
        >>> # Save config
        >>> config.save_pretrained("./my-longclip-config")
        >>>
        >>> # Load config
        >>> config = LongCLIPConfig.from_pretrained("./my-longclip-config")
        ```
    """

    model_type = "longclip"
    is_composition = True

    def __init__(
        self,
        text_config: Dict[str, Any] | None = None,
        vision_config: Dict[str, Any] | None = None,
        projection_dim: int = 512,
        **kwargs,
    ):
        # Fall back to defaults when no sub-config is supplied.
        if text_config is None:
            logger.info(
                "text_config is None. Initializing the LongCLIPTextConfig with default values."
            )
            text_config = {}
        if vision_config is None:
            logger.info(
                "vision_config is None. Initializing the LongCLIPVisionConfig with default values."
            )
            vision_config = {}

        # Promote plain dicts to typed LongCLIP config objects.
        if isinstance(text_config, dict):
            text_config = LongCLIPTextConfig(**text_config)
        if isinstance(vision_config, dict):
            vision_config = LongCLIPVisionConfig(**vision_config)

        # The parent expects dict sub-configs.
        super().__init__(
            text_config=text_config.to_dict(),
            vision_config=vision_config.to_dict(),
            projection_dim=projection_dim,
            **kwargs,
        )

        # CLIPConfig.__init__ rebuilt plain CLIP sub-configs from those dicts;
        # replace them with the LongCLIP-typed objects for easier access.
        self.text_config = text_config
        self.vision_config = vision_config

    @classmethod
    def from_text_vision_configs(
        cls,
        text_config: LongCLIPTextConfig,
        vision_config: LongCLIPVisionConfig,
        **kwargs,
    ):
        """
        Instantiate a LongCLIPConfig from text and vision configs.

        Args:
            text_config (LongCLIPTextConfig): Text model configuration.
            vision_config (LongCLIPVisionConfig): Vision model configuration.
            **kwargs: Additional keyword arguments.

        Returns:
            LongCLIPConfig: Configuration object.
        """
        text_dict = text_config.to_dict()
        vision_dict = vision_config.to_dict()
        return cls(text_config=text_dict, vision_config=vision_dict, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this instance to a Python dictionary.

        Returns:
            Dict[str, Any]: Dictionary of all attributes.
        """
        output = super().to_dict()
        # The sub-configs are stored as config objects (see __init__); make
        # sure they are serialized as plain dicts in the output.
        for key in ("text_config", "vision_config"):
            sub_config = getattr(self, key, None)
            if isinstance(sub_config, PretrainedConfig):
                output[key] = sub_config.to_dict()
        return output


# For logging
# NOTE(review): per PEP 8 this import and the logger definition belong at the
# top of the module. Placing them here only works because `logger` is resolved
# lazily — LongCLIPConfig.__init__ runs after the module has fully executed.
# Consider moving both lines up with the other imports.
import logging

logger = logging.getLogger(__name__)