shunk031 commited on
Commit
56d2b5b
·
verified ·
1 Parent(s): 4316696

Upload configuration_longclip.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. configuration_longclip.py +184 -0
configuration_longclip.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LongCLIP configuration classes.
3
+
4
+ These configuration classes extend the standard CLIP configuration to support
5
+ the extended context length and custom positional embeddings of LongCLIP.
6
+ """
7
+
8
+ from typing import Dict, Any
9
+ from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
10
+ from transformers.configuration_utils import PretrainedConfig
11
+
12
+
13
+ class LongCLIPTextConfig(CLIPTextConfig):
14
+ """
15
+ Configuration class for LongCLIP text model.
16
+
17
+ Extends CLIPTextConfig to support 248 token context length
18
+ and custom positional embedding interpolation.
19
+
20
+ Args:
21
+ max_position_embeddings (int, optional): Maximum sequence length. Defaults to 248.
22
+ use_position_interpolation (bool, optional): Whether to use position interpolation.
23
+ Defaults to True.
24
+ interpolation_keep_length (int, optional): Number of positions to keep from
25
+ original embeddings before interpolation. Defaults to 20.
26
+ **kwargs: Additional arguments passed to CLIPTextConfig.
27
+ """
28
+
29
+ model_type = "longclip_text_model"
30
+
31
+ def __init__(
32
+ self,
33
+ max_position_embeddings: int = 248,
34
+ use_position_interpolation: bool = True,
35
+ interpolation_keep_length: int = 20,
36
+ **kwargs,
37
+ ):
38
+ super().__init__(max_position_embeddings=max_position_embeddings, **kwargs)
39
+
40
+ self.use_position_interpolation = use_position_interpolation
41
+ self.interpolation_keep_length = interpolation_keep_length
42
+
43
+
44
+ class LongCLIPVisionConfig(CLIPVisionConfig):
45
+ """
46
+ Configuration class for LongCLIP vision model.
47
+
48
+ This is identical to the standard CLIPVisionConfig as LongCLIP
49
+ does not modify the vision encoder.
50
+
51
+ Args:
52
+ **kwargs: Arguments passed to CLIPVisionConfig.
53
+ """
54
+
55
+ model_type = "longclip_vision_model"
56
+
57
+ def __init__(self, **kwargs):
58
+ super().__init__(**kwargs)
59
+
60
+
61
+ class LongCLIPConfig(CLIPConfig):
62
+ """
63
+ Configuration class for LongCLIP model.
64
+
65
+ Combines LongCLIPTextConfig and LongCLIPVisionConfig to create
66
+ a complete LongCLIP model configuration.
67
+
68
+ Args:
69
+ text_config (Dict[str, Any] or LongCLIPTextConfig, optional):
70
+ Configuration for the text model. If None, uses default LongCLIPTextConfig.
71
+ vision_config (Dict[str, Any] or LongCLIPVisionConfig, optional):
72
+ Configuration for the vision model. If None, uses default LongCLIPVisionConfig.
73
+ projection_dim (int, optional): Dimensionality of text and vision projection layers.
74
+ Defaults to 512.
75
+ **kwargs: Additional arguments passed to CLIPConfig.
76
+
77
+ Example:
78
+ ```python
79
+ >>> from long_clip_hf import LongCLIPConfig
80
+ >>> # Initialize with default settings
81
+ >>> config = LongCLIPConfig()
82
+ >>>
83
+ >>> # Initialize with custom text config
84
+ >>> text_config = {"max_position_embeddings": 248, "hidden_size": 512}
85
+ >>> config = LongCLIPConfig(text_config=text_config)
86
+ >>>
87
+ >>> # Save config
88
+ >>> config.save_pretrained("./my-longclip-config")
89
+ >>>
90
+ >>> # Load config
91
+ >>> config = LongCLIPConfig.from_pretrained("./my-longclip-config")
92
+ ```
93
+ """
94
+
95
+ model_type = "longclip"
96
+ is_composition = True
97
+
98
+ def __init__(
99
+ self,
100
+ text_config: Dict[str, Any] | None = None,
101
+ vision_config: Dict[str, Any] | None = None,
102
+ projection_dim: int = 512,
103
+ **kwargs,
104
+ ):
105
+ # Initialize text config
106
+ if text_config is None:
107
+ text_config = {}
108
+ logger.info(
109
+ "text_config is None. Initializing the LongCLIPTextConfig with default values."
110
+ )
111
+
112
+ if vision_config is None:
113
+ vision_config = {}
114
+ logger.info(
115
+ "vision_config is None. Initializing the LongCLIPVisionConfig with default values."
116
+ )
117
+
118
+ # Create config objects if they're dictionaries
119
+ if isinstance(text_config, dict):
120
+ text_config = LongCLIPTextConfig(**text_config)
121
+
122
+ if isinstance(vision_config, dict):
123
+ vision_config = LongCLIPVisionConfig(**vision_config)
124
+
125
+ # Call parent init with config dicts
126
+ super().__init__(
127
+ text_config=text_config.to_dict(),
128
+ vision_config=vision_config.to_dict(),
129
+ projection_dim=projection_dim,
130
+ **kwargs,
131
+ )
132
+
133
+ # Store as config objects for easier access
134
+ self.text_config = text_config
135
+ self.vision_config = vision_config
136
+
137
+ @classmethod
138
+ def from_text_vision_configs(
139
+ cls,
140
+ text_config: LongCLIPTextConfig,
141
+ vision_config: LongCLIPVisionConfig,
142
+ **kwargs,
143
+ ):
144
+ """
145
+ Instantiate a LongCLIPConfig from text and vision configs.
146
+
147
+ Args:
148
+ text_config (LongCLIPTextConfig): Text model configuration.
149
+ vision_config (LongCLIPVisionConfig): Vision model configuration.
150
+ **kwargs: Additional keyword arguments.
151
+
152
+ Returns:
153
+ LongCLIPConfig: Configuration object.
154
+ """
155
+ return cls(
156
+ text_config=text_config.to_dict(),
157
+ vision_config=vision_config.to_dict(),
158
+ **kwargs,
159
+ )
160
+
161
+ def to_dict(self) -> Dict[str, Any]:
162
+ """
163
+ Serializes this instance to a Python dictionary.
164
+
165
+ Returns:
166
+ Dict[str, Any]: Dictionary of all attributes.
167
+ """
168
+ output = super().to_dict()
169
+ # Ensure text_config and vision_config are properly serialized
170
+ if hasattr(self, "text_config") and isinstance(
171
+ self.text_config, PretrainedConfig
172
+ ):
173
+ output["text_config"] = self.text_config.to_dict()
174
+ if hasattr(self, "vision_config") and isinstance(
175
+ self.vision_config, PretrainedConfig
176
+ ):
177
+ output["vision_config"] = self.vision_config.to_dict()
178
+ return output
179
+
180
+
181
+ # For logging
182
+ import logging
183
+
184
+ logger = logging.getLogger(__name__)