shunk031 committed
Commit 4316696 · verified · 1 Parent(s): 8b61a4a

Upload modeling_longclip.py with huggingface_hub

Files changed (1):
  1. modeling_longclip.py +400 -0

modeling_longclip.py ADDED
@@ -0,0 +1,400 @@
"""
LongCLIP model implementation compatible with HuggingFace Transformers.

This module provides transformers-compatible implementations of LongCLIP models.
"""

from typing import Optional

import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPVisionModel, CLIPModel
from transformers.models.clip.modeling_clip import (
    CLIPTextTransformer,
)

from .configuration_longclip import (
    LongCLIPConfig,
    LongCLIPTextConfig,
    LongCLIPVisionConfig,
)

class LongCLIPTextEmbeddings(nn.Module):
    """
    Text embeddings for LongCLIP with a custom positional embedding mechanism.

    This module implements the dual positional embedding approach used in LongCLIP:
    - The first 20 positions use the original CLIP positional embeddings (mask1)
    - The remaining positions (21-248) use interpolated embeddings (mask2)
    - position_embedding: Fixed base embeddings
    - position_embedding_res: Trainable residual embeddings

    Args:
        config (LongCLIPTextConfig): Configuration for the text embeddings.
    """

    def __init__(self, config: LongCLIPTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        # Token embeddings
        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)

        # Dual positional embeddings (LongCLIP approach)
        # position_embedding: Base embeddings (typically loaded from a checkpoint)
        self.position_embedding = nn.Embedding(
            config.max_position_embeddings, embed_dim
        )

        # position_embedding_res: Trainable residual embeddings
        self.position_embedding_res = nn.Parameter(
            torch.zeros(config.max_position_embeddings, embed_dim)
        )

        # Create masks for applying the embeddings
        # mask1: Use original embeddings for the first interpolation_keep_length positions
        # mask2: Use interpolated embeddings for the remaining positions
        self.register_buffer(
            "mask1", self._create_mask(config, use_first=True), persistent=False
        )
        self.register_buffer(
            "mask2", self._create_mask(config, use_first=False), persistent=False
        )

        # Store position IDs for efficiency
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

    def _create_mask(self, config: LongCLIPTextConfig, use_first: bool) -> torch.Tensor:
        """
        Create a mask for the positional embeddings.

        Args:
            config: Configuration object.
            use_first: If True, the mask selects the first `interpolation_keep_length`
                positions; if False, it selects the remaining positions.

        Returns:
            Mask tensor of shape [max_position_embeddings, 1].
        """
        mask = torch.zeros(config.max_position_embeddings, 1)
        if use_first:
            # mask1: First interpolation_keep_length positions
            mask[: config.interpolation_keep_length] = 1.0
        else:
            # mask2: Remaining positions
            mask[config.interpolation_keep_length :] = 1.0
        return mask
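
    # Illustration (assuming the defaults described in the class docstring,
    # i.e. interpolation_keep_length=20 and max_position_embeddings=248):
    #   mask1 is 1.0 for positions 0-19 and 0.0 for positions 20-247;
    #   mask2 is 0.0 for positions 0-19 and 1.0 for positions 20-247.
    # Each position therefore draws its positional signal from exactly one table.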

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass for the text embeddings.

        Args:
            input_ids: Token IDs of shape [batch_size, seq_length].
            position_ids: Position IDs of shape [batch_size, seq_length].
            inputs_embeds: Pre-computed token embeddings.

        Returns:
            Embeddings of shape [batch_size, seq_length, hidden_size].
        """
        seq_length = (
            input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
        )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        # Get token embeddings
        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        # Get positional embeddings
        position_embeddings = self.position_embedding(position_ids)

        # Residual positional embeddings (used for positions beyond interpolation_keep_length)
        # Expand position_embedding_res over the batch dimension
        position_embeddings_res = self.position_embedding_res.unsqueeze(0).expand(
            position_ids.shape[0], -1, -1
        )[:, :seq_length, :]

        # Apply masks: mask1 gates the first interpolation_keep_length positions, mask2 the rest.
        # After transpose and unsqueeze the masks are [1, seq_length, 1] and broadcast
        # over the batch and hidden dimensions.
        mask1 = self.mask1[:seq_length].transpose(0, 1)  # [1, seq_length]
        mask2 = self.mask2[:seq_length].transpose(0, 1)  # [1, seq_length]

        # Combine embeddings with masking
        embeddings = (
            inputs_embeds
            + position_embeddings * mask1.unsqueeze(-1)
            + position_embeddings_res * mask2.unsqueeze(-1)
        )

        return embeddings


class LongCLIPTextTransformer(CLIPTextTransformer):
    """
    Text transformer for LongCLIP.

    This extends CLIPTextTransformer to use LongCLIPTextEmbeddings
    with the custom positional embedding mechanism.

    Args:
        config (LongCLIPTextConfig): Configuration for the text transformer.
    """

    def __init__(self, config: LongCLIPTextConfig):
        super().__init__(config)
        # Replace embeddings with the LongCLIP version
        self.embeddings = LongCLIPTextEmbeddings(config)


class LongCLIPTextModel(CLIPTextModel):
    """
    LongCLIP text model compatible with HuggingFace Transformers.

    This model extends CLIPTextModel to support a 248-token context length
    with custom positional embedding interpolation.

    Args:
        config (LongCLIPTextConfig): Configuration for the text model.

    Example:
        ```python
        >>> from long_clip_hf import LongCLIPTextConfig, LongCLIPTextModel
        >>> from transformers import CLIPTokenizer
        >>>
        >>> # Initialize model
        >>> config = LongCLIPTextConfig()
        >>> model = LongCLIPTextModel(config)
        >>>
        >>> # Tokenize text
        >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        >>> inputs = tokenizer(
        ...     ["a photo of a cat"],
        ...     return_tensors="pt",
        ...     padding="max_length",
        ...     max_length=248,
        ...     truncation=True,
        ... )
        >>>
        >>> # Get text features
        >>> outputs = model(**inputs)
        >>> text_features = outputs.pooler_output
        ```
    """

    config_class = LongCLIPTextConfig

    def __init__(self, config: LongCLIPTextConfig):
        super().__init__(config)
        # Replace text_model with the LongCLIP version
        self.text_model = LongCLIPTextTransformer(config)
        # Initialize weights
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Get the token embedding layer."""
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value: nn.Module):
        """Set the token embedding layer."""
        self.text_model.embeddings.token_embedding = value

class LongCLIPVisionModel(CLIPVisionModel):
    """
    LongCLIP vision model.

    This is identical to CLIPVisionModel, as LongCLIP does not modify
    the vision encoder. It is provided for API consistency.

    Args:
        config (LongCLIPVisionConfig): Configuration for the vision model.

    Example:
        ```python
        >>> from long_clip_hf import LongCLIPVisionConfig, LongCLIPVisionModel
        >>> from transformers import CLIPImageProcessor
        >>> from PIL import Image
        >>>
        >>> # Initialize model
        >>> config = LongCLIPVisionConfig()
        >>> model = LongCLIPVisionModel(config)
        >>>
        >>> # Process image
        >>> processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
        >>> image = Image.open("path/to/image.jpg")
        >>> inputs = processor(images=image, return_tensors="pt")
        >>>
        >>> # Get image features
        >>> outputs = model(**inputs)
        >>> image_features = outputs.pooler_output
        ```
    """

    config_class = LongCLIPVisionConfig

class LongCLIPModel(CLIPModel):
    """
    LongCLIP model combining text and vision encoders.

    This model extends CLIPModel to use LongCLIPTextModel with a 248-token
    context length while keeping the standard vision encoder.

    Args:
        config (LongCLIPConfig): Configuration for the complete model.

    Example:
        ```python
        >>> from long_clip_hf import LongCLIPConfig, LongCLIPModel
        >>> from transformers import CLIPTokenizer, CLIPImageProcessor
        >>> from PIL import Image
        >>>
        >>> # Initialize model
        >>> config = LongCLIPConfig()
        >>> model = LongCLIPModel(config)
        >>>
        >>> # Prepare inputs
        >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
        >>>
        >>> text = "a photo of a cat"
        >>> image = Image.open("path/to/image.jpg")
        >>>
        >>> text_inputs = tokenizer(
        ...     [text],
        ...     return_tensors="pt",
        ...     padding="max_length",
        ...     max_length=248,
        ...     truncation=True,
        ... )
        >>> image_inputs = processor(images=image, return_tensors="pt")
        >>>
        >>> # Get features
        >>> outputs = model(
        ...     input_ids=text_inputs["input_ids"],
        ...     pixel_values=image_inputs["pixel_values"],
        ... )
        >>>
        >>> # Compute similarity
        >>> logits_per_image = outputs.logits_per_image
        >>> probs = logits_per_image.softmax(dim=1)
        ```
    """

    config_class = LongCLIPConfig

    def __init__(self, config: LongCLIPConfig):
        super().__init__(config)

        # Replace the text model with the LongCLIP version
        if not isinstance(config.text_config, LongCLIPTextConfig):
            text_config = LongCLIPTextConfig(**config.text_config)
        else:
            text_config = config.text_config

        self.text_model = LongCLIPTextModel(text_config)

        # The vision model stays the same (standard CLIP)
        if not isinstance(config.vision_config, LongCLIPVisionConfig):
            vision_config = LongCLIPVisionConfig(**config.vision_config)
        else:
            vision_config = config.vision_config

        self.vision_model = LongCLIPVisionModel(vision_config)
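
        # text_projection, visual_projection, and logit_scale are created by
        # CLIPModel.__init__ (invoked via super().__init__(config) above), so
        # only the text and vision encoders are swapped out here.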

        # Initialize weights
        self.post_init()

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """
        Get text features from the text encoder.

        Args:
            input_ids: Token IDs.
            attention_mask: Attention mask.
            position_ids: Position IDs.
            output_attentions: Whether to output attention weights.
            output_hidden_states: Whether to output hidden states.
            return_dict: Whether to return a ModelOutput object.

        Returns:
            Text features of shape [batch_size, projection_dim].
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = (
            text_outputs[1] if not return_dict else text_outputs.pooler_output
        )
        text_features = self.text_projection(pooled_output)

        return text_features

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """
        Get image features from the vision encoder.

        Args:
            pixel_values: Pixel values.
            output_attentions: Whether to output attention weights.
            output_hidden_states: Whether to output hidden states.
            return_dict: Whether to return a ModelOutput object.

        Returns:
            Image features of shape [batch_size, projection_dim].
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = (
            vision_outputs[1] if not return_dict else vision_outputs.pooler_output
        )
        image_features = self.visual_projection(pooled_output)

        return image_features
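
For reference, a minimal usage sketch (not part of the uploaded file). The repository id below is a placeholder, and the sketch assumes the companion configuration_longclip.py, tokenizer, and image-processor files live in the same repository with an auto_map entry in config.json so that trust_remote_code loading resolves to the classes above:

```python
import torch
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor, CLIPTokenizer

repo_id = "your-namespace/longclip-hf"  # hypothetical placeholder repo id

# trust_remote_code=True lets transformers import modeling_longclip.py from the repo
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = CLIPTokenizer.from_pretrained(repo_id)
processor = CLIPImageProcessor.from_pretrained(repo_id)

# LongCLIP pads/truncates text to the 248-token context length
text_inputs = tokenizer(
    ["a photo of a cat", "a photo of a dog"],
    return_tensors="pt",
    padding="max_length",
    max_length=248,
    truncation=True,
)
image_inputs = processor(images=Image.open("path/to/image.jpg"), return_tensors="pt")

with torch.no_grad():
    text_features = model.get_text_features(**text_inputs)
    image_features = model.get_image_features(**image_inputs)

# Cosine similarity between the single image and each caption
similarity = torch.nn.functional.cosine_similarity(image_features, text_features)
print(similarity)
```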