shunk031 committed on
Commit
7f4f8bb
·
verified ·
1 Parent(s): fb78dac

Upload long_clip_hf.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. long_clip_hf.py +604 -0
long_clip_hf.py ADDED
@@ -0,0 +1,604 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LongCLIP: Unlocking the Long-Text Capability of CLIP
3
+
4
+ This module provides HuggingFace Transformers-compatible implementations of LongCLIP,
5
+ which extends CLIP's text encoder to support 248 tokens (vs 77 in original CLIP).
6
+
7
+ Repository: https://github.com/beichenzbc/Long-CLIP
8
+ Paper: https://arxiv.org/abs/2403.15378
9
+ """
10
+
11
+ import logging
12
+ from typing import Any, Dict, List, Optional, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
17
+ from transformers import CLIPTextModel, CLIPVisionModel, CLIPModel
18
+ from transformers import CLIPImageProcessor, CLIPTokenizer
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.models.clip.modeling_clip import CLIPTextTransformer
21
+ from transformers.processing_utils import ProcessorMixin
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ # ================== Configuration Classes ==================
27
+
28
+
29
class LongCLIPTextConfig(CLIPTextConfig):
    """
    Configuration for the LongCLIP text encoder.

    Identical to ``CLIPTextConfig`` except that the default context length is
    raised to 248 tokens and two knobs controlling the positional-embedding
    interpolation scheme are added.

    Args:
        max_position_embeddings (int, optional): Maximum sequence length.
            Defaults to 248.
        use_position_interpolation (bool, optional): Whether position
            interpolation is enabled. Defaults to True.
        interpolation_keep_length (int, optional): Number of leading positions
            that keep the original (non-interpolated) embeddings.
            Defaults to 20.
        **kwargs: Forwarded unchanged to ``CLIPTextConfig``.
    """

    model_type = "longclip_text_model"

    def __init__(
        self,
        max_position_embeddings: int = 248,
        use_position_interpolation: bool = True,
        interpolation_keep_length: int = 20,
        **kwargs,
    ):
        super().__init__(max_position_embeddings=max_position_embeddings, **kwargs)

        # LongCLIP-specific knobs; everything else is handled by the parent.
        self.use_position_interpolation = use_position_interpolation
        self.interpolation_keep_length = interpolation_keep_length
58
+
59
+
60
class LongCLIPVisionConfig(CLIPVisionConfig):
    """
    Configuration for the LongCLIP vision encoder.

    LongCLIP leaves the vision tower untouched, so this class only exists to
    give the vision config a distinct ``model_type``; all behavior comes from
    ``CLIPVisionConfig``.

    Args:
        **kwargs: Forwarded unchanged to ``CLIPVisionConfig``.
    """

    model_type = "longclip_vision_model"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
75
+
76
+
77
class LongCLIPConfig(CLIPConfig):
    """
    Configuration for the full LongCLIP model.

    Pairs a :class:`LongCLIPTextConfig` with a :class:`LongCLIPVisionConfig`
    into a composite configuration compatible with ``CLIPConfig``.

    Args:
        text_config (Dict[str, Any] or LongCLIPTextConfig, optional):
            Text-model configuration; a default LongCLIPTextConfig is built
            when None.
        vision_config (Dict[str, Any] or LongCLIPVisionConfig, optional):
            Vision-model configuration; a default LongCLIPVisionConfig is
            built when None.
        projection_dim (int, optional): Size of the text/vision projection
            layers. Defaults to 512.
        **kwargs: Forwarded to ``CLIPConfig``.
    """

    model_type = "longclip"
    is_composition = True

    def __init__(
        self,
        text_config: Optional[Dict[str, Any]] = None,
        vision_config: Optional[Dict[str, Any]] = None,
        projection_dim: int = 512,
        **kwargs,
    ):
        # Fall back to empty dicts so the LongCLIP config defaults apply.
        if text_config is None:
            text_config = {}
            logger.info(
                "text_config is None. Initializing the LongCLIPTextConfig with default values."
            )

        if vision_config is None:
            vision_config = {}
            logger.info(
                "vision_config is None. Initializing the LongCLIPVisionConfig with default values."
            )

        # Promote plain dicts to typed config objects.
        if isinstance(text_config, dict):
            text_config = LongCLIPTextConfig(**text_config)
        if isinstance(vision_config, dict):
            vision_config = LongCLIPVisionConfig(**vision_config)

        # CLIPConfig expects plain dicts and rebuilds standard CLIP configs.
        super().__init__(
            text_config=text_config.to_dict(),
            vision_config=vision_config.to_dict(),
            projection_dim=projection_dim,
            **kwargs,
        )

        # Overwrite the parent's plain CLIP configs with the LongCLIP-typed
        # objects so downstream code sees the extra attributes.
        self.text_config = text_config
        self.vision_config = vision_config

    @classmethod
    def from_text_vision_configs(
        cls,
        text_config: LongCLIPTextConfig,
        vision_config: LongCLIPVisionConfig,
        **kwargs,
    ):
        """
        Build a LongCLIPConfig from already-constructed sub-configs.

        Args:
            text_config (LongCLIPTextConfig): Text model configuration.
            vision_config (LongCLIPVisionConfig): Vision model configuration.
            **kwargs: Additional keyword arguments.

        Returns:
            LongCLIPConfig: The combined configuration object.
        """
        return cls(
            text_config=text_config.to_dict(),
            vision_config=vision_config.to_dict(),
            **kwargs,
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this instance to a plain Python dictionary.

        Returns:
            Dict[str, Any]: All configuration attributes, with the two
            sub-configs expanded to nested dictionaries.
        """
        output = super().to_dict()
        # Make sure the typed sub-configs round-trip as nested dicts.
        for key in ("text_config", "vision_config"):
            value = getattr(self, key, None)
            if isinstance(value, PretrainedConfig):
                output[key] = value.to_dict()
        return output
178
+
179
+
180
+ # ================== Model Classes ==================
181
+
182
+
183
class LongCLIPTextEmbeddings(nn.Module):
    """
    Text embeddings for LongCLIP with a dual positional-embedding scheme.

    The positional table is split into two regions:

    - positions ``[0, interpolation_keep_length)`` take rows from the base
      ``position_embedding`` table (selected by ``mask1``);
    - positions ``[interpolation_keep_length, max_position_embeddings)`` take
      rows from the trainable residual table ``position_embedding_res``
      (selected by ``mask2``).

    Args:
        config (LongCLIPTextConfig): Must provide ``hidden_size``,
            ``vocab_size``, ``max_position_embeddings`` and
            ``interpolation_keep_length``.
    """

    def __init__(self, config: "LongCLIPTextConfig"):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        # Token embeddings.
        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)

        # Base positional embeddings (typically loaded from a checkpoint).
        self.position_embedding = nn.Embedding(
            config.max_position_embeddings, embed_dim
        )

        # Trainable residual positional embeddings for the extended positions,
        # initialized to zero.
        self.position_embedding_res = nn.Parameter(
            torch.zeros(config.max_position_embeddings, embed_dim)
        )

        # Per-position {0, 1} selection masks, shape [max_position_embeddings, 1]:
        # mask1 is 1 on the first `interpolation_keep_length` positions,
        # mask2 is 1 on the remaining ones.
        self.register_buffer(
            "mask1", self._create_mask(config, use_first=True), persistent=False
        )
        self.register_buffer(
            "mask2", self._create_mask(config, use_first=False), persistent=False
        )

        # Default position IDs [[0, 1, ..., max_position_embeddings - 1]].
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

    def _create_mask(
        self, config: "LongCLIPTextConfig", use_first: bool
    ) -> torch.Tensor:
        """
        Build a {0, 1} mask over positions.

        Args:
            config: Configuration object.
            use_first: If True, select the first ``interpolation_keep_length``
                positions; otherwise select the remaining ones.

        Returns:
            Mask tensor of shape [max_position_embeddings, 1].
        """
        mask = torch.zeros(config.max_position_embeddings, 1)
        if use_first:
            # mask1: leading positions kept from the original table.
            mask[: config.interpolation_keep_length] = 1.0
        else:
            # mask2: extended positions served by the residual table.
            mask[config.interpolation_keep_length :] = 1.0
        return mask

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """
        Compute token plus positional embeddings.

        Args:
            input_ids: Token IDs of shape [batch_size, seq_length].
            position_ids: Position IDs of shape [batch_size, seq_length].
            inputs_embeds: Pre-computed token embeddings of shape
                [batch_size, seq_length, hidden_size].

        Returns:
            Embeddings of shape [batch_size, seq_length, hidden_size].
        """
        seq_length = (
            input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
        )

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        # Gather everything by position_ids. Previously the residual table and
        # the masks were sliced by seq_length while the base table was gathered
        # by position_ids, so custom or offset position_ids silently selected
        # mismatched rows. For the default contiguous IDs the result is
        # unchanged.
        position_embeddings = self.position_embedding(position_ids)
        position_embeddings_res = self.position_embedding_res[position_ids]
        mask1 = self.mask1[position_ids]  # [..., seq_length, 1]
        mask2 = self.mask2[position_ids]  # [..., seq_length, 1]

        # Base rows for the kept positions, residual rows for the extension;
        # the [.., seq, 1] masks broadcast over the hidden dimension.
        return (
            inputs_embeds
            + position_embeddings * mask1
            + position_embeddings_res * mask2
        )
304
+
305
+
306
class LongCLIPTextTransformer(CLIPTextTransformer):
    """
    CLIP text transformer variant used by LongCLIP.

    Identical to ``CLIPTextTransformer`` except that the embedding layer is
    swapped for :class:`LongCLIPTextEmbeddings`, which implements the dual
    positional-embedding scheme; the encoder stack is inherited unchanged.

    Args:
        config (LongCLIPTextConfig): Text transformer configuration.
    """

    def __init__(self, config: LongCLIPTextConfig):
        super().__init__(config)
        # Swap in the LongCLIP embedding module after the parent builds the
        # standard one.
        self.embeddings = LongCLIPTextEmbeddings(config)
321
+
322
+
323
class LongCLIPTextModel(CLIPTextModel):
    """
    HuggingFace-compatible LongCLIP text model.

    Behaves like ``CLIPTextModel`` but routes through
    :class:`LongCLIPTextTransformer`, giving a 248-token context with the
    interpolated positional embeddings.

    Args:
        config (LongCLIPTextConfig): Configuration for the text model.
    """

    config_class = LongCLIPTextConfig

    def __init__(self, config: LongCLIPTextConfig):
        super().__init__(config)
        # Replace the parent's transformer with the LongCLIP variant, then
        # re-run weight initialization over the new submodules.
        self.text_model = LongCLIPTextTransformer(config)
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding layer."""
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, value: nn.Module):
        """Replace the token embedding layer."""
        self.text_model.embeddings.token_embedding = value
350
+
351
+
352
class LongCLIPVisionModel(CLIPVisionModel):
    """
    LongCLIP vision model.

    LongCLIP does not touch the vision encoder, so this subclass only pins
    ``config_class`` and otherwise behaves exactly like ``CLIPVisionModel``.
    It exists for API symmetry with the text side.

    Args:
        config (LongCLIPVisionConfig): Configuration for the vision model.
    """

    config_class = LongCLIPVisionConfig
364
+
365
+
366
class LongCLIPModel(CLIPModel):
    """
    Full LongCLIP dual-encoder model.

    Extends ``CLIPModel`` by swapping in :class:`LongCLIPTextModel` (248-token
    context) while keeping the standard CLIP vision encoder.

    Args:
        config (LongCLIPConfig): Configuration for the complete model.
    """

    config_class = LongCLIPConfig

    def __init__(self, config: LongCLIPConfig):
        super().__init__(config)

        # Coerce the sub-configs to their LongCLIP types when needed, then
        # replace the encoders built by CLIPModel with the LongCLIP variants.
        text_cfg = (
            config.text_config
            if isinstance(config.text_config, LongCLIPTextConfig)
            else LongCLIPTextConfig(**config.text_config)
        )
        self.text_model = LongCLIPTextModel(text_cfg)

        vision_cfg = (
            config.vision_config
            if isinstance(config.vision_config, LongCLIPVisionConfig)
            else LongCLIPVisionConfig(**config.vision_config)
        )
        self.vision_model = LongCLIPVisionModel(vision_cfg)

        # Re-run weight initialization over the replaced submodules.
        self.post_init()

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """
        Encode text into the joint embedding space.

        Args:
            input_ids: Token IDs.
            attention_mask: Attention mask.
            position_ids: Position IDs.
            output_attentions: Whether to output attention weights.
            output_hidden_states: Whether to output hidden states.
            return_dict: Whether to return a ModelOutput object.

        Returns:
            Text features of shape [batch_size, projection_dim].
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pooled output is field `pooler_output` (dict form) / index 1 (tuple).
        pooled = outputs.pooler_output if return_dict else outputs[1]
        return self.text_projection(pooled)

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """
        Encode images into the joint embedding space.

        Args:
            pixel_values: Pixel values.
            output_attentions: Whether to output attention weights.
            output_hidden_states: Whether to output hidden states.
            return_dict: Whether to return a ModelOutput object.

        Returns:
            Image features of shape [batch_size, projection_dim].
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Pooled output is field `pooler_output` (dict form) / index 1 (tuple).
        pooled = outputs.pooler_output if return_dict else outputs[1]
        return self.visual_projection(pooled)
480
+
481
+
482
+ # ================== Processor Class ==================
483
+
484
+
485
class LongCLIPProcessor(ProcessorMixin):
    """
    Processor for LongCLIP that combines image and text preprocessing.

    Wraps a ``CLIPImageProcessor`` and a ``CLIPTokenizer`` behind a single
    callable interface, defaulting to LongCLIP's 248-token context length.

    Args:
        image_processor (CLIPImageProcessor): Image processor for images.
        tokenizer (CLIPTokenizer): Tokenizer for text.

    Raises:
        ValueError: If either ``image_processor`` or ``tokenizer`` is missing.

    Attributes:
        image_processor_class (str): Name of the image processor class.
        tokenizer_class (str): Name of the tokenizer class.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "CLIPImageProcessor"
    tokenizer_class = "CLIPTokenizer"

    def __init__(
        self,
        image_processor: Optional[CLIPImageProcessor] = None,
        tokenizer: Optional[CLIPTokenizer] = None,
        **kwargs,
    ):
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
            raise ValueError("You need to specify a `tokenizer`.")

        super().__init__(image_processor, tokenizer)

    def __call__(
        self,
        text: Union[str, List[str], None] = None,
        images=None,
        return_tensors: Optional[str] = "pt",
        padding: Union[bool, str] = True,
        max_length: Optional[int] = 248,
        truncation: Optional[bool] = True,
        **kwargs,
    ):
        """
        Preprocess text and/or images for a LongCLIP model.

        Args:
            text (str, List[str], optional): Text or list of texts to process.
            images: Image or list of images (PIL Image, numpy array, tensor).
            return_tensors (str, optional): Tensor type to return ('pt').
            padding (bool or str, optional): Padding strategy. Defaults to True.
            max_length (int, optional): Maximum sequence length. Defaults to
                248 for LongCLIP.
            truncation (bool, optional): Whether to truncate. Defaults to True.
            **kwargs: Forwarded to the tokenizer.

        Returns:
            The tokenizer's BatchEncoding (with ``pixel_values`` merged in
            when images are also given), or the image processor's output when
            only images are given. Keys among: input_ids, attention_mask,
            pixel_values.

        Raises:
            ValueError: If neither `text` nor `images` is provided.
        """
        # Fail loudly instead of silently returning an empty dict, matching
        # the standard CLIPProcessor contract.
        if text is None and images is None:
            raise ValueError(
                "You have to specify either `text` or `images`. Both cannot be None."
            )

        encoding = None
        if text is not None:
            encoding = self.tokenizer(
                text,
                return_tensors=return_tensors,
                padding=padding,
                max_length=max_length,
                truncation=truncation,
                **kwargs,
            )

        if images is not None:
            image_features = self.image_processor(
                images,
                return_tensors=return_tensors,
            )
            if encoding is None:
                # Images only: return the image processor's output directly.
                encoding = image_features
            else:
                # Merge into the tokenizer output so the result keeps its
                # BatchEncoding type (a dict-splat merge would downgrade it
                # to a plain dict, losing `.to(device)` and friends).
                encoding["pixel_values"] = image_features["pixel_values"]

        return encoding

    def batch_decode(self, *args, **kwargs):
        """
        Decode batches of token IDs back to text.

        Forwarded to the tokenizer's ``batch_decode`` method.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        Decode token IDs back to text.

        Forwarded to the tokenizer's ``decode`` method.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        """
        Get the names of model inputs.

        Returns:
            List[str]: Deduplicated union of the tokenizer's and image
            processor's input names, in order.
        """
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
598
+
599
+
600
# Register configuration for auto classes
# NOTE: runs at import time so that AutoConfig.from_pretrained /
# AutoModel.from_pretrained can resolve the "longclip" model_type once this
# module has been imported.
from transformers import AutoConfig, AutoModel

AutoConfig.register("longclip", LongCLIPConfig)
AutoModel.register(LongCLIPConfig, LongCLIPModel)