Text Generation
Transformers
Safetensors
DIVEdoc
docvqa
distillation
VLM
document-understanding
OCR-free
custom_code
JayRay5 commited on
Commit
79bb81a
·
verified ·
1 Parent(s): 943bc58

Upload 2 files

Browse files

Model files to use AutoModel

Files changed (2) hide show
  1. configuration_divedoc.py +248 -0
  2. modeling_divedoc.py +541 -0
configuration_divedoc.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from pathlib import Path
3
+ parent_root = Path().resolve().parent.parent
4
+ sys.path.append(str(parent_root))
5
+
6
+
7
+
8
+
9
+ from transformers import PretrainedConfig, DonutSwinConfig, GemmaConfig, CONFIG_MAPPING, SiglipVisionConfig
10
+ from typing import Tuple, Literal
11
+
12
+
13
+
14
class PamConfig(PretrainedConfig):
    """Configuration for the PAM module that aligns a student encoder's
    feature map to a teacher encoder's token grid and embedding width.

    Args:
        sequence_mapping_layer_type: How the student token sequence is resized
            to the teacher token count ("linear_projection" or
            "bilinear_interpolation").
        student_fmap_dim: (height, width) of the student feature map.
        student_embedding_dim: Embedding width of the student encoder.
        teacher_fmap_dim: (height, width) of the teacher feature map.
        teacher_embedding_dim: Embedding width of the teacher encoder.
    """

    model_type = "pam"

    def __init__(
        self,
        sequence_mapping_layer_type: Literal["linear_projection", "bilinear_interpolation"] = "bilinear_interpolation",
        student_fmap_dim: Tuple[int, int] = (80, 60),
        student_embedding_dim: int = 1024,
        teacher_fmap_dim: Tuple[int, int] = (64, 64),
        teacher_embedding_dim: int = 1152,
        **kwargs,
    ):
        # Store the PAM geometry, then let PretrainedConfig consume the rest.
        self.sequence_mapping_layer_type = sequence_mapping_layer_type
        self.student_fmap_dim = student_fmap_dim
        self.student_embedding_dim = student_embedding_dim
        self.teacher_fmap_dim = teacher_fmap_dim
        self.teacher_embedding_dim = teacher_embedding_dim
        super().__init__(**kwargs)
31
+
32
+
33
class SwinPamVisionEncoderConfig(PretrainedConfig):
    """Composite vision-tower config: a DonutSwin encoder followed by PAM.

    Sub-configs may be passed as config objects or as plain dicts (as produced
    by ``to_dict()``); dicts are re-hydrated into config objects. A ``None``
    sub-config is kept as ``None`` (defaults are built by ``get_vision_config``).
    """

    model_type = "swinpam"
    sub_configs = {"encoder_config": DonutSwinConfig, "pam_config": PamConfig}

    def __init__(
        self,
        encoder_config: DonutSwinConfig = None,
        pam_config: PamConfig = None,
        **kwargs,
    ):
        self.encoder_config = encoder_config
        self.pam_config = pam_config

        if isinstance(self.encoder_config, dict):
            # Default to donut-swin when the serialized dict lacks a model_type.
            encoder_type = encoder_config.get("model_type", "donut-swin")
            encoder_config["model_type"] = encoder_type
            if encoder_type == "donut-swin":
                self.encoder_config = DonutSwinConfig(**encoder_config)
            else:
                print(f"Encoder type: {encoder_type}")
                self.encoder_config = CONFIG_MAPPING[encoder_type](**encoder_config)

        if isinstance(self.pam_config, dict):
            # Robustness fix: tolerate dicts without an explicit model_type
            # (the original indexed the key directly and could raise KeyError).
            pam_type = pam_config.get("model_type", "pam")
            if pam_type == "pam":
                self.pam_config = PamConfig(**pam_config)
            else:
                raise ValueError(f"pam_config['model_type'] should be 'pam', got {pam_type}")

        super().__init__(**kwargs)
76
+
77
+
78
class SiglipPAMVisionEncoderConfig(PretrainedConfig):
    """Composite vision-tower config: a Siglip vision encoder followed by PAM.

    Sub-configs may be passed as config objects or as plain dicts; dicts are
    re-hydrated into config objects. A ``None`` sub-config is kept as ``None``.
    """

    model_type = "siglippam"
    sub_configs = {"encoder_config": SiglipVisionConfig, "pam_config": PamConfig}

    def __init__(
        self,
        encoder_config: SiglipVisionConfig = None,
        pam_config: PamConfig = None,
        **kwargs,
    ):
        self.encoder_config = encoder_config
        self.pam_config = pam_config

        if isinstance(self.encoder_config, dict):
            # Default to siglip_vision_model when the dict lacks a model_type.
            encoder_type = encoder_config.get("model_type", "siglip_vision_model")
            encoder_config["model_type"] = encoder_type
            if encoder_type == "siglip_vision_model":
                self.encoder_config = SiglipVisionConfig(**encoder_config)
            else:
                # Clarified error message (was: "Need siglip_model_type, ...").
                raise ValueError(
                    f"encoder_config['model_type'] should be 'siglip_vision_model', got {encoder_type}"
                )

        if isinstance(self.pam_config, dict):
            # Robustness fix: tolerate dicts without an explicit model_type
            # (the original indexed the key directly and could raise KeyError).
            pam_type = pam_config.get("model_type", "pam")
            if pam_type == "pam":
                self.pam_config = PamConfig(**pam_config)
            else:
                raise ValueError(f"pam_config['model_type'] should be 'pam', got {pam_type}")

        super().__init__(**kwargs)
106
+
107
+
108
class DIVEdocConfig(PretrainedConfig):
    """End-to-end DIVEdoc configuration: a (Swin|Siglip)+PAM vision tower
    paired with a Gemma language model, PaliGemma-style.

    Args:
        vision_config: Sub-config (object or dict) for the vision tower.
            Defaults to the "swinpam" config built by ``get_vision_config``.
        text_config: Sub-config (object or dict) for the language model.
            Defaults to a Gemma configuration matching the values below.
        ignore_index: Label value ignored by the training loss.
        image_token_index: Token id of the image placeholder token.
        vocab_size: Vocabulary size used when building the default text config.
        projection_dim: Output width of the multimodal projector.
        hidden_size: Hidden size advertised to downstream code.
    """

    keys_to_ignore_at_inference = ["past_key_values"]
    sub_configs = {"vision_config": SwinPamVisionEncoderConfig, "text_config": GemmaConfig}
    model_type = "DIVEdoc"

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=256000,
        vocab_size=257152,
        projection_dim=2048,
        hidden_size=2048,
        **kwargs,
    ):
        self._ignore_index = ignore_index
        self.image_token_index = image_token_index
        self._vocab_size = vocab_size
        self.projection_dim = projection_dim
        self.hidden_size = hidden_size
        self.vision_config = vision_config
        self.is_encoder_decoder = False

        if isinstance(self.vision_config, dict):
            vision_config["model_type"] = vision_config.get("model_type", "swinpam")
            if vision_config["model_type"] == "swinpam":
                self.vision_config = SwinPamVisionEncoderConfig(
                    encoder_config=vision_config["encoder_config"],
                    pam_config=vision_config["pam_config"],
                )
            elif vision_config["model_type"] == "siglippam":
                self.vision_config = SiglipPAMVisionEncoderConfig(
                    encoder_config=vision_config["encoder_config"],
                    pam_config=vision_config["pam_config"],
                )
            else:
                self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
        elif vision_config is None:
            # Forward reference: get_vision_config is defined later in this
            # module, which is fine because it is resolved at call time.
            self.vision_config = get_vision_config("swinpam")

        self.text_config = text_config
        if isinstance(self.text_config, dict):
            text_config["model_type"] = text_config.get("model_type", "gemma")
            self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            self.text_config = CONFIG_MAPPING["gemma"](
                hidden_size=2048,
                num_hidden_layers=18,
                intermediate_size=16384,
                num_attention_heads=8,
                num_key_value_heads=1,
                is_encoder_decoder=False,
                vocab_size=vocab_size,
            )

        # The LM sees one image token per teacher-grid cell.
        # NOTE(review): assumes the vision config exposes `pam_config` — only
        # guaranteed for the swinpam/siglippam branches above; confirm for
        # vision configs built via CONFIG_MAPPING.
        self.text_config.num_image_tokens = (
            self.vision_config.pam_config.teacher_fmap_dim[0]
            * self.vision_config.pam_config.teacher_fmap_dim[1]
        )
        self.vision_config.projection_dim = projection_dim
        super().__init__(**kwargs)

    @property
    def ignore_index(self):
        """Bug fix: the modeling code reads ``config.ignore_index`` but only
        ``_ignore_index`` was stored; expose it read-only (excluded from
        serialization by ``to_dict``)."""
        return self._ignore_index

    def to_dict(self):
        """Serialize the config, dropping the private ``_ignore_index`` field."""
        output = super().to_dict()
        output.pop("_ignore_index", None)
        return output
171
+
172
def get_siglip_vision_config(image_size=[896,896],num_image_token = 4096,hidden_size = 768):
    """Build the SiglipVisionConfig used as DIVEdoc's Siglip student encoder."""
    settings = dict(
        hidden_size=hidden_size,
        image_size=image_size,
        intermediate_size=2860,
        model_type="siglip_vision_model",
        num_attention_heads=8,
        num_hidden_layers=12,
        num_image_tokens=num_image_token,
        patch_size=14,
        projection_dim=2048,
        projector_hidden_act="gelu_fast",
        torch_dtype="float32",
        vision_use_head=False,
    )
    return SiglipVisionConfig(**settings)
188
+
189
def get_swin_vision_config(image_size=[2560,1920],hidden_size = 1024):
    """Build the DonutSwinConfig used as DIVEdoc's Swin student encoder."""
    settings = dict(
        attention_probs_dropout_prob=0.0,
        depths=[2, 2, 14, 2],
        drop_path_rate=0.1,
        embed_dim=128,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        hidden_size=hidden_size,
        image_size=image_size,
        initializer_range=0.02,
        layer_norm_eps=1e-05,
        mlp_ratio=4.0,
        model_type="donut-swin",
        num_channels=3,
        num_heads=[4, 8, 16, 32],
        num_layers=4,
        patch_size=4,
        path_norm=True,
        qkv_bias=True,
        use_absolute_embeddings=False,
        window_size=10,
    )
    return DonutSwinConfig(**settings)
223
+
224
def get_vision_config( visual_encoder_type:Literal["swinpam","siglip80m"],
                      image_size=[2560,1920],
                      sequence_mapping_layer_type= "bilinear",
                      student_fmap_dim=(80,60),
                      student_embedding_dim= 1024,
                      teacher_fmap_dim= (64,64),
                      teacher_embedding_dim= 1152):
    """Build the vision-tower config (encoder + PAM) for DIVEdoc.

    Args:
        visual_encoder_type: "swinpam" (DonutSwin encoder) or "siglip80m"
            (Siglip encoder).
        image_size: (height, width) of input images.
        sequence_mapping_layer_type: PAM token-resampling strategy.
        student_fmap_dim / student_embedding_dim: Student grid and width.
        teacher_fmap_dim / teacher_embedding_dim: Teacher grid and width.

    Returns:
        A SwinPamVisionEncoderConfig or SiglipPAMVisionEncoderConfig.

    Raises:
        ValueError: For an unknown ``visual_encoder_type``.
    """
    pam_config = PamConfig(
        sequence_mapping_layer_type=sequence_mapping_layer_type,
        student_fmap_dim=student_fmap_dim,
        student_embedding_dim=student_embedding_dim,
        teacher_fmap_dim=teacher_fmap_dim,
        teacher_embedding_dim=teacher_embedding_dim)

    if visual_encoder_type == "swinpam":
        encoder_config = get_swin_vision_config(image_size=image_size, hidden_size=student_embedding_dim)
        return SwinPamVisionEncoderConfig(encoder_config=encoder_config, pam_config=pam_config)

    elif visual_encoder_type == "siglip80m":
        # Bug fix: image_size is a (height, width) sequence, so the original
        # `(image_size // 14) ** 2` raised TypeError; compute the patch-grid
        # token count per dimension instead (patch_size is 14).
        num_image_token = (image_size[0] // 14) * (image_size[1] // 14)
        encoder_config = get_siglip_vision_config(
            image_size=image_size, num_image_token=num_image_token, hidden_size=student_embedding_dim
        )
        return SiglipPAMVisionEncoderConfig(encoder_config=encoder_config, pam_config=pam_config)
    else:
        # Bug fix: added the missing closing quote around 'siglip80m'.
        raise ValueError(f"Unknown visual encoder type, need 'swinpam' or 'siglip80m', got {visual_encoder_type}.")
modeling_divedoc.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
from pathlib import Path

# NOTE(review): this resolves relative to the *current working directory*,
# not this file's location — confirm the intended layout. Kept for
# backward compatibility with the original.
parent_root = Path().resolve().parent.parent
sys.path.append(str(parent_root))


import logging

import torch
import torch.nn as nn
import torch.utils.checkpoint
import torch.nn.functional as F
from torch import Tensor  # Bug fix: `Tensor` is used in annotations below but was never imported.


from transformers import Cache, HybridCache, StaticCache
from transformers.modeling_outputs import BaseModelOutput
from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, is_torchdynamo_compiling, replace_return_docstrings
from transformers.utils.deprecation import deprecate_kwarg
from transformers import PreTrainedModel, AutoConfig, PaliGemmaPreTrainedModel, AutoModelForCausalLM, GenerationMixin
from transformers.models.paligemma.modeling_paligemma import PaliGemmaMultiModalProjector, PaliGemmaCausalLMOutputWithPast, PALIGEMMA_INPUTS_DOCSTRING
from transformers.models.paligemma.configuration_paligemma import PaliGemmaConfig
from transformers.models.donut.modeling_donut_swin import DonutSwinModel


# Bug fix: the sibling file added alongside this one is `configuration_divedoc.py`,
# not `config_divedoc.py`; the old name raised ImportError.
from .configuration_divedoc import SwinPamVisionEncoderConfig, SiglipPAMVisionEncoderConfig, DIVEdocConfig
# Bug fix: `Literal` is used in annotations below but was not imported (NameError at class definition).
from typing import List, Optional, Tuple, Union, Literal
from dataclasses import dataclass

# Bug fix: `logger` is used in DIVEdoc.forward but was never defined.
logger = logging.getLogger(__name__)
27
+
28
+
29
class PAM(nn.Module):
    """Projection & Alignment Module.

    Aligns a student encoder's token sequence (length H_s*W_s, width D_s) to a
    teacher's grid (length H_t*W_t, width D_t): the token count is resampled
    (learned linear projection or spatial interpolation), then the embedding
    width is projected with Linear + LayerNorm.
    """

    def __init__(
        self,
        sequence_mapping_layer_type: 'Literal["linear_projection","bilinear","bicubic","nearest-exact"]' = "bilinear",
        student_fmap_dim: "Tuple[int,int]" = (80, 60),
        student_embedding_dim: int = 1024,
        teacher_fmap_dim: "Tuple[int,int]" = (64, 64),
        teacher_embedding_dim: int = 1152,
    ):
        super().__init__()
        self.sequence_mapping_layer_type = sequence_mapping_layer_type
        # A learned token-count mapping is only needed for the linear variant;
        # the interpolation variants are parameter-free.
        self.sequence_mapping_layer = (
            nn.Linear(
                student_fmap_dim[0] * student_fmap_dim[1],
                teacher_fmap_dim[0] * teacher_fmap_dim[1],
            )
            if sequence_mapping_layer_type == "linear_projection"
            else None
        )
        self.embedding_projection_layer = nn.Sequential(
            nn.Linear(student_embedding_dim, teacher_embedding_dim),
            nn.LayerNorm((teacher_embedding_dim,), eps=1e-06),
        )

        self.student_fmap_dim = student_fmap_dim
        self.student_embedding_dim = student_embedding_dim
        self.teacher_fmap_dim = teacher_fmap_dim
        self.teacher_embedding_dim = teacher_embedding_dim
        # Fix: removed leftover debug print of student_fmap_dim.

    def forward(self, x: "Tensor") -> "Tensor":
        """Map `x` of shape (batch, tokens, D_s) to (batch, H_t*W_t, D_t).

        Fix: the sequence is resampled only when its length differs from the
        teacher token count (the original shape check merely printed debug
        output instead of gating the resampling).
        """
        if x.shape[1] != (self.teacher_fmap_dim[0] * self.teacher_fmap_dim[1]):
            if self.sequence_mapping_layer_type == "linear_projection":
                # Resample along the token axis with the learned projection.
                x = torch.permute(x, (0, 2, 1))
                x = self.sequence_mapping_layer(x)
                x = torch.permute(x, (0, 2, 1))

            elif self.sequence_mapping_layer_type in ["bilinear", "bicubic", "nearest-exact"]:
                # Fold tokens back to the student grid, interpolate to the
                # teacher grid, then flatten back to a token sequence.
                batch_size, _, embedding_size = x.size()
                x = x.view(batch_size, self.student_fmap_dim[0], self.student_fmap_dim[1], embedding_size).permute(0, 3, 1, 2)
                x = F.interpolate(x, size=self.teacher_fmap_dim, mode=self.sequence_mapping_layer_type)  # (B, D, target_h, target_w)
                x = x.permute(0, 2, 3, 1).reshape(batch_size, -1, embedding_size)

        x = self.embedding_projection_layer(x)
        return x
77
+
78
class SwinPam(nn.Module):
    """DonutSwin encoder followed by a PAM alignment module.

    Args:
        encoder_config: Config for the DonutSwin backbone.
        pam_*: Parameters forwarded to :class:`PAM`.
    """

    def __init__(
        self,
        encoder_config: "AutoConfig",
        pam_sequence_mapping_layer_type: 'Literal["linear_projection","bilinear","bicubic","nearest-exact"]' = "bilinear",
        pam_student_fmap_dim: "Tuple[int,int]" = (80, 60),
        pam_student_embedding_dim: int = 1024,
        pam_teacher_fmap_dim: "Tuple[int,int]" = (64, 64),
        pam_teacher_embedding_dim: int = 1152,
    ):
        super().__init__()
        self.encoder_model = DonutSwinModel(encoder_config)
        # Fix: removed leftover debug print of pam_student_fmap_dim.
        self.pam = PAM(
            sequence_mapping_layer_type=pam_sequence_mapping_layer_type,
            student_fmap_dim=pam_student_fmap_dim,
            student_embedding_dim=pam_student_embedding_dim,
            teacher_fmap_dim=pam_teacher_fmap_dim,
            teacher_embedding_dim=pam_teacher_embedding_dim,
        )

    def forward(self, x):
        """Encode pixels with DonutSwin, then align tokens/width with PAM."""
        x = self.encoder_model(x).last_hidden_state
        x = self.pam(x)
        return x
102
+
103
+
104
+
105
@dataclass
class SwinPamVisionEncoderOutput(ModelOutput):
    """
    Output type for the SwinPam vision encoder.

    Args:
        last_hidden_states (`torch.FloatTensor`, *optional*):
            A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`:
            the vision encoder's hidden states after PAM projection.
    """

    # NOTE(review): field is named `last_hidden_states` (plural) while the
    # encoder classes below return BaseModelOutput(last_hidden_state=...);
    # this dataclass appears unused in this file — confirm before relying on it.
    last_hidden_states: Optional[torch.FloatTensor] = None
117
+
118
class SwinPamVisionEncoder(PreTrainedModel):
    """HF-compatible wrapper exposing SwinPam as a vision tower whose forward
    returns a BaseModelOutput."""

    config_class = SwinPamVisionEncoderConfig
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(self, config):
        super().__init__(config)
        pam_cfg = config.pam_config
        self.model = SwinPam(
            config.encoder_config,
            pam_cfg.sequence_mapping_layer_type,
            pam_cfg.student_fmap_dim,
            pam_cfg.student_embedding_dim,
            pam_cfg.teacher_fmap_dim,
            pam_cfg.teacher_embedding_dim,
        )

    def forward(self, x):
        """Run the wrapped SwinPam and expose the result as .last_hidden_state."""
        hidden = self.model(x)
        return BaseModelOutput(last_hidden_state=hidden)
135
+
136
class SiglipPAMVisionEncoder(PreTrainedModel):
    # HF-compatible wrapper exposing a Siglip encoder + PAM as a vision tower.
    # NOTE(review): `SiglipPAM` is not defined or imported anywhere in this
    # file; instantiating this class raises NameError unless SiglipPAM is
    # provided elsewhere — TODO confirm/define it.
    config_class = SiglipPAMVisionEncoderConfig
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(self, config):
        super().__init__(config)
        self.model = SiglipPAM(
            config.encoder_config,
            config.pam_config.sequence_mapping_layer_type,
            config.pam_config.student_fmap_dim,
            config.pam_config.student_embedding_dim,
            config.pam_config.teacher_fmap_dim,
            config.pam_config.teacher_embedding_dim,
        )

    def forward(self, x):
        # Wrap the raw tensor so downstream code can read .last_hidden_state.
        x = self.model(x)
        return BaseModelOutput(last_hidden_state=x)
153
+
154
+
155
class PaliGemmaMultiModalProjector(nn.Module):
    """Linear projection from the PAM/teacher embedding width to the LM width.

    NOTE: intentionally shadows the `PaliGemmaMultiModalProjector` imported
    from transformers above, because DIVEdoc reads the input width from
    `vision_config.pam_config` rather than a Siglip config.
    """

    def __init__(self, config: "DIVEdocConfig"):
        # Fix (annotation only): the config passed in by DIVEdoc is a
        # DIVEdocConfig, not a PaliGemmaConfig as previously annotated.
        super().__init__()
        self.linear = nn.Linear(
            config.vision_config.pam_config.teacher_embedding_dim,
            config.vision_config.projection_dim,
            bias=True,
        )

    def forward(self, image_features):
        """Project (batch, seq, teacher_dim) -> (batch, seq, projection_dim)."""
        hidden_states = self.linear(image_features)

        return hidden_states
164
+
165
+
166
+
167
_CONFIG_FOR_DOC = "DIVEdocConfig"


class DIVEdoc(PaliGemmaPreTrainedModel, GenerationMixin):
    """DIVEdoc: a PaliGemma-style vision-language model pairing a
    (Swin|Siglip)+PAM vision tower with a Gemma causal LM for OCR-free
    document understanding."""

    config_class = DIVEdocConfig

    def __init__(self, config: DIVEdocConfig):
        super().__init__(config)

        print(f"Vision config in end-to-end model: {config.vision_config.model_type}")
        # Select the vision tower from the sub-config's model_type.
        if config.vision_config.model_type == "swinpam":
            self.vision_tower = SwinPamVisionEncoder(config=config.vision_config)

        elif config.vision_config.model_type == "siglippam":
            self.vision_tower = SiglipPAMVisionEncoder(config=config.vision_config)

        else:
            raise ValueError("Unknown model_type in vision_config")

        # Projects vision features (teacher width) into the LM hidden space.
        self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size

        language_model = AutoModelForCausalLM.from_config(config=config.text_config)

        # Re-prefix tied weight keys so weight tying survives the wrapper.
        if language_model._tied_weights_keys is not None:
            self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
        self.language_model = language_model

        # -1 sentinel when the config defines no pad token.
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings with Llava->PaliGemma
    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings with Llava->PaliGemma
    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings with Llava->PaliGemma
    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings with Llava->PaliGemma
    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder with Llava->PaliGemma
    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder with Llava->PaliGemma
    def get_decoder(self):
        return self.language_model.get_decoder()

    def get_dtype(self):
        # Exposes the model dtype; used to pick the mask fill value before any
        # quantized-dtype switch (see forward).
        return self.dtype

    def _update_causal_mask(
        self,
        attention_mask,
        token_type_ids=None,
        past_key_values=None,
        cache_position=None,
        input_tensor=None,
        is_training: bool = None,
        dtype=None,  # to handle quantized finetuning issue when model switch between 4 or 8bit and float
    ):
        """Build the 4D additive causal mask (PaliGemma-style prefix-LM).

        During training, prefix tokens (token_type_ids == 0) attend bidirectionally
        and suffix tokens causally; at inference the whole prefix is attended.
        Returns None for flash-attention-2 with no padding, a ready-made 4D mask
        unchanged, or a freshly built (batch, 1, seq, target) float mask.
        """
        if self.config.text_config._attn_implementation == "flash_attention_2":
            # FA2 consumes the 2D padding mask directly (or none at all).
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
        is_training = is_training if is_training is not None else self.training
        using_static_cache = isinstance(past_key_values, StaticCache)

        # Handle the case when the model is quantized in 4 or 8 bit:
        # use the caller-provided dtype for the mask's -inf fill value.
        if dtype is not None:
            min_dtype = torch.finfo(dtype).min
        else:
            min_dtype = torch.finfo(self.get_dtype()).min

        if input_tensor is None:
            input_tensor = attention_mask

        inputs_lead_dim, sequence_length = input_tensor.shape[:2]
        # Static/Hybrid caches have a fixed max length; otherwise size from the
        # mask (or the current position plus one lookahead slot).
        if using_static_cache:
            target_length = past_key_values.get_max_cache_shape()
        elif isinstance(past_key_values, HybridCache):
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else cache_position[0] + sequence_length + 1
            )

        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            return attention_mask
        ''' initial line but changed for quantization processing
        causal_mask = torch.full(
            (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
        )
        '''
        causal_mask = torch.full(
            (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
        )
        # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
        if sequence_length != 1:
            if is_training:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            else:
                causal_mask[:, :sequence_length] = 0.0

        # Mask out positions beyond the current cache position.
        causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]

            # First unmask prefix tokens during training
            if is_training:
                if token_type_ids is None:
                    raise ValueError("Token type ids must be provided during training")
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
                )

            # Then apply padding mask (will mask pad tokens)
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

        return causal_mask

    def get_image_features(self, pixel_values: torch.FloatTensor):
        """
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
               The tensors corresponding to the input images.
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        """
        image_outputs = self.vision_tower(pixel_values)
        selected_image_feature = image_outputs.last_hidden_state
        image_features = self.multi_modal_projector(selected_image_feature)
        # Scale like PaliGemma: normalize by sqrt of the LM hidden size.
        image_features = image_features / (self.config.text_config.hidden_size**0.5)
        return image_features

    @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
    @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **lm_kwargs,
    ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.

        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

        >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
        >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")

        >>> prompt = "Where is the cat standing?"
        >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs,)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Where is the cat standing?\nsnow"
        ```"""
        # save the original dtype before switching to 4bit when quantization
        dtype = self.get_dtype()

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Training mode requires both token_type_ids (prefix/suffix) and labels.
        is_training = token_type_ids is not None and labels is not None

        # Replace image id woth PAD if the image token if OOV, to avoid index-errors
        if input_ids is not None and self.config.image_token_index >= self.vocab_size:
            special_image_mask = input_ids == self.config.image_token_index
            llm_input_ids = input_ids.clone()
            llm_input_ids[special_image_mask] = 0
        else:
            llm_input_ids = input_ids

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(llm_input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0) + 1  # Paligemma positions are 1-indexed

        # Merge text and images
        if pixel_values is not None:
            image_features = self.get_image_features(pixel_values)

            # Locate image-placeholder positions either by embedding equality
            # (no input_ids) or by token id.
            if input_ids is None:
                special_image_mask = inputs_embeds == self.get_input_embeddings()(
                    torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
                )
            else:
                special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
                special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)

            if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
                image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
                raise ValueError(
                    f"Number of images does not match number of special image tokens in the input text. "
                    f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
                    "tokens from image embeddings."
                )
            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
            # Scatter projected image features into the placeholder positions.
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        # mask out pad-token-ids in labels for BC
        if labels is not None and self.pad_token_id in labels:
            # NOTE(review): `logger` is not defined in this module — this branch
            # raises NameError if ever hit; define a module-level logger.
            logger.warning_once(
                "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
                "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
            )
            # NOTE(review): DIVEdocConfig stores `_ignore_index` only; confirm
            # `config.ignore_index` is actually resolvable here.
            labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)

        causal_mask = self._update_causal_mask(
            attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training, dtype=dtype
        )
        outputs = self.language_model(
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **lm_kwargs,
        )

        logits = outputs[0]
        loss = None
        if labels is not None:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]

            if attention_mask is not None:
                # we use the input attention mask to shift the logits and labels, because it is 2D.
                # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
                shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
                shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
            else:
                shift_logits = shift_logits.contiguous()
                shift_labels = shift_labels.contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()

            flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
            flat_labels = shift_labels.view(-1).to(shift_logits.device)

            # Drop ignored (-100) positions before computing cross-entropy.
            valid_mask = flat_labels != -100

            flat_labels = flat_labels[valid_mask]
            flat_logits = flat_logits[valid_mask]

            loss = loss_fct(flat_logits, flat_labels)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return PaliGemmaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        pixel_values=None,
        attention_mask=None,
        token_type_ids=None,
        use_cache=True,
        logits_to_keep=None,
        labels=None,
        **kwargs,
    ):
        # Overwritten -- custom `position_ids` and `pixel_values` handling
        model_inputs = self.language_model.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            token_type_ids=token_type_ids,
            **kwargs,
        )

        # position_ids in Paligemma are 1-indexed
        if model_inputs.get("position_ids") is not None:
            model_inputs["position_ids"] += 1
        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
        # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
        is_training = token_type_ids is not None and labels is not None
        # HybridCache needs a full-width mask built up-front on the prefill step.
        if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
            input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
            causal_mask = self._update_causal_mask(
                attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
            )
            model_inputs["attention_mask"] = causal_mask

        return model_inputs