Text Generation
Transformers
Safetensors
DIVEdoc
docvqa
distillation
VLM
document-understanding
OCR-free
custom_code
JayRay5 commited on
Commit
f687af2
·
verified ·
1 Parent(s): ccec861

Update modeling_divedoc.py

Browse files
Files changed (1) hide show
  1. modeling_divedoc.py +539 -540
modeling_divedoc.py CHANGED
@@ -1,541 +1,540 @@
1
- import sys
2
- from pathlib import Path
3
- parent_root = Path().resolve().parent.parent
4
- sys.path.append(str(parent_root))
5
-
6
-
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.utils.checkpoint
11
- import torch.nn.functional as F
12
-
13
-
14
- from transformers import Cache, HybridCache, StaticCache
15
- from transformers.modeling_outputs import BaseModelOutput
16
- from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, is_torchdynamo_compiling, replace_return_docstrings
17
- from transformers.utils.deprecation import deprecate_kwarg
18
- from transformers import PreTrainedModel, AutoConfig, PaliGemmaPreTrainedModel,AutoModelForCausalLM,GenerationMixin
19
- from transformers.models.paligemma.modeling_paligemma import PaliGemmaMultiModalProjector, PaliGemmaCausalLMOutputWithPast, PALIGEMMA_INPUTS_DOCSTRING
20
- from transformers.models.paligemma.configuration_paligemma import PaliGemmaConfig
21
- from transformers.models.donut.modeling_donut_swin import DonutSwinModel
22
-
23
-
24
- from .config_divedoc import SwinPamVisionEncoderConfig, SiglipPAMVisionEncoderConfig, DIVEdocConfig
25
- from typing import List, Optional, Tuple, Union
26
- from dataclasses import dataclass
27
-
28
-
29
- class PAM(nn.Module):
30
- def __init__(
31
- self,
32
- sequence_mapping_layer_type: Literal["linear_projection","bilinear","bicubic","nearest-exact"] = "bilinear",
33
- student_fmap_dim: Tuple[int,int]=(80,60),
34
- student_embedding_dim: int = 1024,
35
- teacher_fmap_dim: Tuple[int,int] = (64,64),
36
- teacher_embedding_dim: int = 1152
37
- ):
38
- super().__init__()
39
- self.sequence_mapping_layer_type = sequence_mapping_layer_type
40
- self.sequence_mapping_layer = nn.Linear(student_fmap_dim[0]*student_fmap_dim[1],teacher_fmap_dim[0]*teacher_fmap_dim[1]) if sequence_mapping_layer_type == "linear_projection" else None
41
- self.embedding_projection_layer = nn.Sequential(
42
- nn.Linear(student_embedding_dim,teacher_embedding_dim),
43
- nn.LayerNorm((teacher_embedding_dim,),eps=1e-06))
44
-
45
- self.student_fmap_dim = student_fmap_dim
46
- self.student_embedding_dim = student_embedding_dim
47
- self.teacher_fmap_dim = teacher_fmap_dim
48
- self.teacher_embedding_dim = teacher_embedding_dim
49
-
50
- print(self.student_fmap_dim)
51
- #take input x of shape (Batch, Nb_token, Dim_embedding)
52
- def forward(self,x:Tensor) -> Tensor:
53
- #
54
- '''
55
- if x.shape[1] != self.student_fmap_dim[0] * self.student_fmap_dim[1] or x.shape[2] != self.student_embedding_dim:
56
- raise ValueError(f"Expected input shape (*, {self.student_fmap_dim[0] * self.student_fmap_dim[1],self.student_embedding_dim}), "
57
- f"but got {x.shape}")
58
- '''
59
-
60
- if x.shape[1]!=(self.teacher_fmap_dim[0]*self.teacher_fmap_dim[1]):
61
- print(x.shape[1])
62
- print(self.teacher_fmap_dim[0]*self.teacher_fmap_dim[1])
63
- print("Resizing")
64
- if self.sequence_mapping_layer_type == "linear_projection":
65
- x = torch.permute(x,(0,2,1))
66
- x = self.sequence_mapping_layer(x)
67
- x = torch.permute(x,(0,2,1))
68
-
69
- elif self.sequence_mapping_layer_type in ["bilinear","bicubic","nearest-exact"]:
70
- batch_size,_,embedding_size = x.size()
71
- x = x.view(batch_size,self.student_fmap_dim[0],self.student_fmap_dim[1],embedding_size).permute(0,3, 1, 2)
72
- x = F.interpolate(x,size=self.teacher_fmap_dim,mode=self.sequence_mapping_layer_type) # Shape: (1, D, target_height, target_width)
73
- x = x.permute(0,2, 3, 1).reshape(batch_size,-1, embedding_size)
74
-
75
- x = self.embedding_projection_layer(x)
76
- return x
77
-
78
- class SwinPam(nn.Module):
79
- def __init__(
80
- self,
81
- encoder_config: AutoConfig,
82
- pam_sequence_mapping_layer_type: Literal["linear_projection","bilinear","bicubic","nearest-exact"] = "bilinear",
83
- pam_student_fmap_dim: Tuple[int,int] = (80,60),
84
- pam_student_embedding_dim: int = 1024,
85
- pam_teacher_fmap_dim: Tuple[int,int] = (64,64),
86
- pam_teacher_embedding_dim: int = 1152
87
- ):
88
- super().__init__()
89
- self.encoder_model = DonutSwinModel(encoder_config)
90
- print(pam_student_fmap_dim)
91
- self.pam = PAM(
92
- sequence_mapping_layer_type = pam_sequence_mapping_layer_type,
93
- student_fmap_dim = pam_student_fmap_dim,
94
- student_embedding_dim = pam_student_embedding_dim,
95
- teacher_fmap_dim = pam_teacher_fmap_dim,
96
- teacher_embedding_dim = pam_teacher_embedding_dim)
97
-
98
- def forward(self,x):
99
- x = self.encoder_model(x).last_hidden_state
100
- x = self.pam(x)
101
- return x
102
-
103
-
104
-
105
- @dataclass
106
- class SwinPamVisionEncoderOutput(ModelOutput):
107
- """
108
- Base class for PaliGemmacausal language model (or autoregressive) outputs.
109
-
110
- Args:
111
- last_hidden_states (`torch.FloatTensor`, *optional*):
112
- A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
113
- image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
114
- """
115
-
116
- last_hidden_states: Optional[torch.FloatTensor] = None
117
-
118
- class SwinPamVisionEncoder(PreTrainedModel):
119
- config_class = SwinPamVisionEncoderConfig
120
- keys_to_ignore_at_inference = ["past_key_values"]
121
-
122
- def __init__(self, config):
123
- super().__init__(config)
124
- self.model = SwinPam(
125
- config.encoder_config,
126
- config.pam_config.sequence_mapping_layer_type,
127
- config.pam_config.student_fmap_dim,
128
- config.pam_config.student_embedding_dim,
129
- config.pam_config.teacher_fmap_dim,
130
- config.pam_config.teacher_embedding_dim,
131
- )
132
- def forward(self,x):
133
- x = self.model(x)
134
- return BaseModelOutput(last_hidden_state=x)
135
-
136
- class SiglipPAMVisionEncoder(PreTrainedModel):
137
- config_class = SiglipPAMVisionEncoderConfig
138
- keys_to_ignore_at_inference = ["past_key_values"]
139
-
140
- def __init__(self, config):
141
- super().__init__(config)
142
- self.model = SiglipPAM(
143
- config.encoder_config,
144
- config.pam_config.sequence_mapping_layer_type,
145
- config.pam_config.student_fmap_dim,
146
- config.pam_config.student_embedding_dim,
147
- config.pam_config.teacher_fmap_dim,
148
- config.pam_config.teacher_embedding_dim,
149
- )
150
- def forward(self,x):
151
- x = self.model(x)
152
- return BaseModelOutput(last_hidden_state=x)
153
-
154
-
155
- class PaliGemmaMultiModalProjector(nn.Module):
156
- def __init__(self, config: PaliGemmaConfig):
157
- super().__init__()
158
- self.linear = nn.Linear(config.vision_config.pam_config.teacher_embedding_dim, config.vision_config.projection_dim, bias=True)
159
-
160
- def forward(self, image_features):
161
- hidden_states = self.linear(image_features)
162
-
163
- return hidden_states
164
-
165
-
166
-
167
- _CONFIG_FOR_DOC = "DIVEdocConfig"
168
- class DIVEdoc(PaliGemmaPreTrainedModel, GenerationMixin):
169
- config_class = DIVEdocConfig
170
- def __init__(self, config: DIVEdocConfig):
171
- super().__init__(config)
172
-
173
- print(f"Vision config in end-to-end model: {config.vision_config.model_type}")
174
- if config.vision_config.model_type == "swinpam":
175
- self.vision_tower = SwinPamVisionEncoder(config=config.vision_config)
176
-
177
- elif config.vision_config.model_type == "siglippam":
178
- self.vision_tower = SiglipPAMVisionEncoder(config=config.vision_config)
179
-
180
- else:
181
- raise ValueError("Unknown model_type in vision_config")
182
-
183
- self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
184
- self.vocab_size = config.text_config.vocab_size
185
-
186
- language_model = AutoModelForCausalLM.from_config(config=config.text_config)
187
-
188
- if language_model._tied_weights_keys is not None:
189
- self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
190
- self.language_model = language_model
191
-
192
- self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
193
- self.post_init()
194
-
195
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings with Llava->PaliGemma
196
- def get_input_embeddings(self):
197
- return self.language_model.get_input_embeddings()
198
-
199
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings with Llava->PaliGemma
200
- def set_input_embeddings(self, value):
201
- self.language_model.set_input_embeddings(value)
202
-
203
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings with Llava->PaliGemma
204
- def get_output_embeddings(self):
205
- return self.language_model.get_output_embeddings()
206
-
207
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings with Llava->PaliGemma
208
- def set_output_embeddings(self, new_embeddings):
209
- self.language_model.set_output_embeddings(new_embeddings)
210
-
211
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder with Llava->PaliGemma
212
- def set_decoder(self, decoder):
213
- self.language_model.set_decoder(decoder)
214
-
215
- # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder with Llava->PaliGemma
216
- def get_decoder(self):
217
- return self.language_model.get_decoder()
218
- def get_dtype(self):
219
- return self.dtype
220
-
221
- def _update_causal_mask(
222
- self,
223
- attention_mask,
224
- token_type_ids=None,
225
- past_key_values=None,
226
- cache_position=None,
227
- input_tensor=None,
228
- is_training: bool = None,
229
- dtype=None, #to handle quantized finetuning issue when model switch between 4 or 8bit and float
230
- ):
231
- if self.config.text_config._attn_implementation == "flash_attention_2":
232
- if attention_mask is not None and 0.0 in attention_mask:
233
- return attention_mask
234
- return None
235
- is_training = is_training if is_training is not None else self.training
236
- using_static_cache = isinstance(past_key_values, StaticCache)
237
-
238
- # Handle the case when the model is quantized in 4 or 8 bit
239
-
240
- if dtype is not None:
241
- min_dtype = torch.finfo(dtype).min
242
- else:
243
- min_dtype = torch.finfo(self.get_dtype()).min
244
-
245
-
246
- if input_tensor is None:
247
- input_tensor = attention_mask
248
-
249
- inputs_lead_dim, sequence_length = input_tensor.shape[:2]
250
- if using_static_cache:
251
- target_length = past_key_values.get_max_cache_shape()
252
- elif isinstance(past_key_values, HybridCache):
253
- target_length = past_key_values.get_max_cache_shape()
254
- else:
255
- target_length = (
256
- attention_mask.shape[-1]
257
- if isinstance(attention_mask, torch.Tensor)
258
- else cache_position[0] + sequence_length + 1
259
- )
260
-
261
- if attention_mask is not None and attention_mask.dim() == 4:
262
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
263
- return attention_mask
264
- ''' initial line but changed for quantization processing
265
- causal_mask = torch.full(
266
- (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
267
- )
268
- '''
269
- causal_mask = torch.full(
270
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
271
- )
272
- # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
273
- if sequence_length != 1:
274
- if is_training:
275
- causal_mask = torch.triu(causal_mask, diagonal=1)
276
- else:
277
- causal_mask[:, :sequence_length] = 0.0
278
-
279
- causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
280
- causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
281
- if attention_mask is not None:
282
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
283
- mask_length = attention_mask.shape[-1]
284
-
285
- # First unmask prefix tokens during training
286
- if is_training:
287
- if token_type_ids is None:
288
- raise ValueError("Token type ids must be provided during training")
289
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
290
- token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
291
- )
292
-
293
- # Then apply padding mask (will mask pad tokens)
294
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
295
- padding_mask = padding_mask == 0
296
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
297
- padding_mask, min_dtype
298
- )
299
-
300
- return causal_mask
301
-
302
- def get_image_features(self, pixel_values: torch.FloatTensor):
303
- """
304
- Obtains image last hidden states from the vision tower and apply multimodal projection.
305
-
306
- Args:
307
- pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
308
- The tensors corresponding to the input images.
309
- Returns:
310
- image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
311
- """
312
- image_outputs = self.vision_tower(pixel_values)
313
- selected_image_feature = image_outputs.last_hidden_state
314
- image_features = self.multi_modal_projector(selected_image_feature)
315
- image_features = image_features / (self.config.text_config.hidden_size**0.5)
316
- return image_features
317
-
318
- @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
319
- @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING)
320
- @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
321
- def forward(
322
- self,
323
- input_ids: torch.LongTensor = None,
324
- pixel_values: torch.FloatTensor = None,
325
- attention_mask: Optional[torch.Tensor] = None,
326
- position_ids: Optional[torch.LongTensor] = None,
327
- past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
328
- token_type_ids: Optional[torch.LongTensor] = None,
329
- cache_position: Optional[torch.LongTensor] = None,
330
- inputs_embeds: Optional[torch.FloatTensor] = None,
331
- labels: Optional[torch.LongTensor] = None,
332
- use_cache: Optional[bool] = None,
333
- output_attentions: Optional[bool] = None,
334
- output_hidden_states: Optional[bool] = None,
335
- return_dict: Optional[bool] = None,
336
- logits_to_keep: Union[int, torch.Tensor] = 0,
337
- **lm_kwargs,
338
- ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
339
- r"""
340
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
341
- Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
342
- config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
343
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
344
-
345
- logits_to_keep (`int` or `torch.Tensor`, *optional*):
346
- If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
347
- `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
348
- token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
349
- If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
350
- This is useful when using packed tensor format (single dimension for batch and sequence length).
351
-
352
- Returns:
353
-
354
- Example:
355
-
356
- ```python
357
- >>> from PIL import Image
358
- >>> import requests
359
- >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
360
-
361
- >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
362
- >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
363
-
364
- >>> prompt = "Where is the cat standing?"
365
- >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
366
- >>> image = Image.open(requests.get(url, stream=True).raw)
367
-
368
- >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
369
-
370
- >>> # Generate
371
- >>> generate_ids = model.generate(**inputs,)
372
- >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
373
- "Where is the cat standing?\nsnow"
374
- ```"""
375
- #save the original dtype before switching to 4bit when quantization
376
- dtype = self.get_dtype()
377
-
378
- if (input_ids is None) ^ (inputs_embeds is not None):
379
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
380
-
381
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
382
- output_hidden_states = (
383
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
384
- )
385
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
386
-
387
- is_training = token_type_ids is not None and labels is not None
388
-
389
- # Replace image id woth PAD if the image token if OOV, to avoid index-errors
390
- if input_ids is not None and self.config.image_token_index >= self.vocab_size:
391
- special_image_mask = input_ids == self.config.image_token_index
392
- llm_input_ids = input_ids.clone()
393
- llm_input_ids[special_image_mask] = 0
394
- else:
395
- llm_input_ids = input_ids
396
-
397
- if inputs_embeds is None:
398
- inputs_embeds = self.get_input_embeddings()(llm_input_ids)
399
-
400
- if cache_position is None:
401
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
402
- cache_position = torch.arange(
403
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
404
- )
405
-
406
- if position_ids is None:
407
- position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
408
-
409
- # Merge text and images
410
- if pixel_values is not None:
411
- image_features = self.get_image_features(pixel_values)
412
-
413
- if input_ids is None:
414
- special_image_mask = inputs_embeds == self.get_input_embeddings()(
415
- torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
416
- )
417
- else:
418
- special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
419
- special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
420
-
421
- if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
422
- image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
423
- raise ValueError(
424
- f"Number of images does not match number of special image tokens in the input text. "
425
- f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
426
- "tokens from image embeddings."
427
- )
428
- image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
429
- inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
430
-
431
- # mask out pad-token-ids in labels for BC
432
- if labels is not None and self.pad_token_id in labels:
433
- logger.warning_once(
434
- "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
435
- "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
436
- )
437
- labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
438
-
439
- causal_mask = self._update_causal_mask(
440
- attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training,dtype=dtype
441
- )
442
- outputs = self.language_model(
443
- attention_mask=causal_mask,
444
- position_ids=position_ids,
445
- past_key_values=past_key_values,
446
- inputs_embeds=inputs_embeds,
447
- use_cache=use_cache,
448
- output_attentions=output_attentions,
449
- output_hidden_states=output_hidden_states,
450
- return_dict=return_dict,
451
- cache_position=cache_position,
452
- logits_to_keep=logits_to_keep,
453
- **lm_kwargs,
454
- )
455
-
456
- logits = outputs[0]
457
- loss = None
458
- if labels is not None:
459
- # Upcast to float if we need to compute the loss to avoid potential precision issues
460
- shift_logits = logits[..., :-1, :]
461
- shift_labels = labels[..., 1:]
462
-
463
- if attention_mask is not None:
464
- # we use the input attention mask to shift the logits and labels, because it is 2D.
465
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
466
- shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
467
- shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
468
- shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
469
- else:
470
- shift_logits = shift_logits.contiguous()
471
- shift_labels = shift_labels.contiguous()
472
- # Flatten the tokens
473
- loss_fct = nn.CrossEntropyLoss()
474
-
475
- flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
476
- flat_labels = shift_labels.view(-1).to(shift_logits.device)
477
-
478
- valid_mask = flat_labels != -100
479
-
480
- flat_labels = flat_labels[valid_mask]
481
- flat_logits = flat_logits[valid_mask]
482
-
483
- loss = loss_fct(flat_logits, flat_labels)
484
- if not return_dict:
485
- output = (logits,) + outputs[1:]
486
- return (loss,) + output if loss is not None else output
487
-
488
- return PaliGemmaCausalLMOutputWithPast(
489
- loss=loss,
490
- logits=logits,
491
- past_key_values=outputs.past_key_values,
492
- hidden_states=outputs.hidden_states,
493
- attentions=outputs.attentions,
494
- image_hidden_states=image_features if pixel_values is not None else None,
495
- )
496
-
497
- def prepare_inputs_for_generation(
498
- self,
499
- input_ids,
500
- past_key_values=None,
501
- inputs_embeds=None,
502
- cache_position=None,
503
- position_ids=None,
504
- pixel_values=None,
505
- attention_mask=None,
506
- token_type_ids=None,
507
- use_cache=True,
508
- logits_to_keep=None,
509
- labels=None,
510
- **kwargs,
511
- ):
512
- # Overwritten -- custom `position_ids` and `pixel_values` handling
513
- model_inputs = self.language_model.prepare_inputs_for_generation(
514
- input_ids,
515
- past_key_values=past_key_values,
516
- inputs_embeds=inputs_embeds,
517
- attention_mask=attention_mask,
518
- position_ids=position_ids,
519
- cache_position=cache_position,
520
- use_cache=use_cache,
521
- logits_to_keep=logits_to_keep,
522
- token_type_ids=token_type_ids,
523
- **kwargs,
524
- )
525
-
526
- # position_ids in Paligemma are 1-indexed
527
- if model_inputs.get("position_ids") is not None:
528
- model_inputs["position_ids"] += 1
529
- # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
530
- # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
531
- if cache_position[0] == 0:
532
- model_inputs["pixel_values"] = pixel_values
533
- is_training = token_type_ids is not None and labels is not None
534
- if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
535
- input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
536
- causal_mask = self._update_causal_mask(
537
- attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
538
- )
539
- model_inputs["attention_mask"] = causal_mask
540
-
541
  return model_inputs
 
1
+ import sys
2
+ from pathlib import Path
3
+ parent_root = Path().resolve().parent.parent
4
+ sys.path.append(str(parent_root))
5
+
6
+
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.utils.checkpoint
11
+ import torch.nn.functional as F
12
+
13
+
14
+ from transformers import Cache, HybridCache, StaticCache
15
+ from transformers.modeling_outputs import BaseModelOutput
16
+ from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, is_torchdynamo_compiling, replace_return_docstrings
17
+ from transformers.utils.deprecation import deprecate_kwarg
18
+ from transformers import PreTrainedModel, AutoConfig, PaliGemmaPreTrainedModel,AutoModelForCausalLM,GenerationMixin
19
+ from transformers.models.paligemma.modeling_paligemma import PaliGemmaMultiModalProjector, PaliGemmaCausalLMOutputWithPast
20
+ from transformers.models.paligemma.configuration_paligemma import PaliGemmaConfig
21
+ from transformers.models.donut.modeling_donut_swin import DonutSwinModel
22
+
23
+
24
+ from .config_divedoc import SwinPamVisionEncoderConfig, SiglipPAMVisionEncoderConfig, DIVEdocConfig
25
+ from typing import List, Optional, Tuple, Union
26
+ from dataclasses import dataclass
27
+
28
+
29
+ class PAM(nn.Module):
30
+ def __init__(
31
+ self,
32
+ sequence_mapping_layer_type: Literal["linear_projection","bilinear","bicubic","nearest-exact"] = "bilinear",
33
+ student_fmap_dim: Tuple[int,int]=(80,60),
34
+ student_embedding_dim: int = 1024,
35
+ teacher_fmap_dim: Tuple[int,int] = (64,64),
36
+ teacher_embedding_dim: int = 1152
37
+ ):
38
+ super().__init__()
39
+ self.sequence_mapping_layer_type = sequence_mapping_layer_type
40
+ self.sequence_mapping_layer = nn.Linear(student_fmap_dim[0]*student_fmap_dim[1],teacher_fmap_dim[0]*teacher_fmap_dim[1]) if sequence_mapping_layer_type == "linear_projection" else None
41
+ self.embedding_projection_layer = nn.Sequential(
42
+ nn.Linear(student_embedding_dim,teacher_embedding_dim),
43
+ nn.LayerNorm((teacher_embedding_dim,),eps=1e-06))
44
+
45
+ self.student_fmap_dim = student_fmap_dim
46
+ self.student_embedding_dim = student_embedding_dim
47
+ self.teacher_fmap_dim = teacher_fmap_dim
48
+ self.teacher_embedding_dim = teacher_embedding_dim
49
+
50
+ print(self.student_fmap_dim)
51
+ #take input x of shape (Batch, Nb_token, Dim_embedding)
52
+ def forward(self,x:Tensor) -> Tensor:
53
+ #
54
+ '''
55
+ if x.shape[1] != self.student_fmap_dim[0] * self.student_fmap_dim[1] or x.shape[2] != self.student_embedding_dim:
56
+ raise ValueError(f"Expected input shape (*, {self.student_fmap_dim[0] * self.student_fmap_dim[1],self.student_embedding_dim}), "
57
+ f"but got {x.shape}")
58
+ '''
59
+
60
+ if x.shape[1]!=(self.teacher_fmap_dim[0]*self.teacher_fmap_dim[1]):
61
+ print(x.shape[1])
62
+ print(self.teacher_fmap_dim[0]*self.teacher_fmap_dim[1])
63
+ print("Resizing")
64
+ if self.sequence_mapping_layer_type == "linear_projection":
65
+ x = torch.permute(x,(0,2,1))
66
+ x = self.sequence_mapping_layer(x)
67
+ x = torch.permute(x,(0,2,1))
68
+
69
+ elif self.sequence_mapping_layer_type in ["bilinear","bicubic","nearest-exact"]:
70
+ batch_size,_,embedding_size = x.size()
71
+ x = x.view(batch_size,self.student_fmap_dim[0],self.student_fmap_dim[1],embedding_size).permute(0,3, 1, 2)
72
+ x = F.interpolate(x,size=self.teacher_fmap_dim,mode=self.sequence_mapping_layer_type) # Shape: (1, D, target_height, target_width)
73
+ x = x.permute(0,2, 3, 1).reshape(batch_size,-1, embedding_size)
74
+
75
+ x = self.embedding_projection_layer(x)
76
+ return x
77
+
78
+ class SwinPam(nn.Module):
79
+ def __init__(
80
+ self,
81
+ encoder_config: AutoConfig,
82
+ pam_sequence_mapping_layer_type: Literal["linear_projection","bilinear","bicubic","nearest-exact"] = "bilinear",
83
+ pam_student_fmap_dim: Tuple[int,int] = (80,60),
84
+ pam_student_embedding_dim: int = 1024,
85
+ pam_teacher_fmap_dim: Tuple[int,int] = (64,64),
86
+ pam_teacher_embedding_dim: int = 1152
87
+ ):
88
+ super().__init__()
89
+ self.encoder_model = DonutSwinModel(encoder_config)
90
+ print(pam_student_fmap_dim)
91
+ self.pam = PAM(
92
+ sequence_mapping_layer_type = pam_sequence_mapping_layer_type,
93
+ student_fmap_dim = pam_student_fmap_dim,
94
+ student_embedding_dim = pam_student_embedding_dim,
95
+ teacher_fmap_dim = pam_teacher_fmap_dim,
96
+ teacher_embedding_dim = pam_teacher_embedding_dim)
97
+
98
+ def forward(self,x):
99
+ x = self.encoder_model(x).last_hidden_state
100
+ x = self.pam(x)
101
+ return x
102
+
103
+
104
+
105
+ @dataclass
106
+ class SwinPamVisionEncoderOutput(ModelOutput):
107
+ """
108
+ Base class for PaliGemmacausal language model (or autoregressive) outputs.
109
+
110
+ Args:
111
+ last_hidden_states (`torch.FloatTensor`, *optional*):
112
+ A `torch.FloatTensor` of size `(batch_size, sequence_length, hidden_size)`.
113
+ image_hidden_states of the model produced by the vision encoder after projecting last hidden state.
114
+ """
115
+
116
+ last_hidden_states: Optional[torch.FloatTensor] = None
117
+
118
+ class SwinPamVisionEncoder(PreTrainedModel):
119
+ config_class = SwinPamVisionEncoderConfig
120
+ keys_to_ignore_at_inference = ["past_key_values"]
121
+
122
+ def __init__(self, config):
123
+ super().__init__(config)
124
+ self.model = SwinPam(
125
+ config.encoder_config,
126
+ config.pam_config.sequence_mapping_layer_type,
127
+ config.pam_config.student_fmap_dim,
128
+ config.pam_config.student_embedding_dim,
129
+ config.pam_config.teacher_fmap_dim,
130
+ config.pam_config.teacher_embedding_dim,
131
+ )
132
+ def forward(self,x):
133
+ x = self.model(x)
134
+ return BaseModelOutput(last_hidden_state=x)
135
+
136
+ class SiglipPAMVisionEncoder(PreTrainedModel):
137
+ config_class = SiglipPAMVisionEncoderConfig
138
+ keys_to_ignore_at_inference = ["past_key_values"]
139
+
140
+ def __init__(self, config):
141
+ super().__init__(config)
142
+ self.model = SiglipPAM(
143
+ config.encoder_config,
144
+ config.pam_config.sequence_mapping_layer_type,
145
+ config.pam_config.student_fmap_dim,
146
+ config.pam_config.student_embedding_dim,
147
+ config.pam_config.teacher_fmap_dim,
148
+ config.pam_config.teacher_embedding_dim,
149
+ )
150
+ def forward(self,x):
151
+ x = self.model(x)
152
+ return BaseModelOutput(last_hidden_state=x)
153
+
154
+
155
+ class PaliGemmaMultiModalProjector(nn.Module):
156
+ def __init__(self, config: PaliGemmaConfig):
157
+ super().__init__()
158
+ self.linear = nn.Linear(config.vision_config.pam_config.teacher_embedding_dim, config.vision_config.projection_dim, bias=True)
159
+
160
+ def forward(self, image_features):
161
+ hidden_states = self.linear(image_features)
162
+
163
+ return hidden_states
164
+
165
+
166
+
167
+ _CONFIG_FOR_DOC = "DIVEdocConfig"
168
+ class DIVEdoc(PaliGemmaPreTrainedModel, GenerationMixin):
169
+ config_class = DIVEdocConfig
170
+ def __init__(self, config: DIVEdocConfig):
171
+ super().__init__(config)
172
+
173
+ print(f"Vision config in end-to-end model: {config.vision_config.model_type}")
174
+ if config.vision_config.model_type == "swinpam":
175
+ self.vision_tower = SwinPamVisionEncoder(config=config.vision_config)
176
+
177
+ elif config.vision_config.model_type == "siglippam":
178
+ self.vision_tower = SiglipPAMVisionEncoder(config=config.vision_config)
179
+
180
+ else:
181
+ raise ValueError("Unknown model_type in vision_config")
182
+
183
+ self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
184
+ self.vocab_size = config.text_config.vocab_size
185
+
186
+ language_model = AutoModelForCausalLM.from_config(config=config.text_config)
187
+
188
+ if language_model._tied_weights_keys is not None:
189
+ self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
190
+ self.language_model = language_model
191
+
192
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
193
+ self.post_init()
194
+
195
    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_input_embeddings with Llava->PaliGemma
    def get_input_embeddings(self):
        """Delegate to the wrapped language model's input embedding table."""
        return self.language_model.get_input_embeddings()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_input_embeddings with Llava->PaliGemma
    def set_input_embeddings(self, value):
        """Replace the wrapped language model's input embedding table."""
        self.language_model.set_input_embeddings(value)

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_output_embeddings with Llava->PaliGemma
    def get_output_embeddings(self):
        """Delegate to the wrapped language model's output (LM head) embeddings."""
        return self.language_model.get_output_embeddings()

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_output_embeddings with Llava->PaliGemma
    def set_output_embeddings(self, new_embeddings):
        """Replace the wrapped language model's output (LM head) embeddings."""
        self.language_model.set_output_embeddings(new_embeddings)

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.set_decoder with Llava->PaliGemma
    def set_decoder(self, decoder):
        """Replace the decoder of the wrapped language model."""
        self.language_model.set_decoder(decoder)

    # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.get_decoder with Llava->PaliGemma
    def get_decoder(self):
        """Return the decoder of the wrapped language model."""
        return self.language_model.get_decoder()

    def get_dtype(self):
        """Parameter dtype of the full model; used to pick mask fill values
        (captured before quantized submodules switch to 4/8-bit)."""
        return self.dtype
220
+
221
+ def _update_causal_mask(
222
+ self,
223
+ attention_mask,
224
+ token_type_ids=None,
225
+ past_key_values=None,
226
+ cache_position=None,
227
+ input_tensor=None,
228
+ is_training: bool = None,
229
+ dtype=None, #to handle quantized finetuning issue when model switch between 4 or 8bit and float
230
+ ):
231
+ if self.config.text_config._attn_implementation == "flash_attention_2":
232
+ if attention_mask is not None and 0.0 in attention_mask:
233
+ return attention_mask
234
+ return None
235
+ is_training = is_training if is_training is not None else self.training
236
+ using_static_cache = isinstance(past_key_values, StaticCache)
237
+
238
+ # Handle the case when the model is quantized in 4 or 8 bit
239
+
240
+ if dtype is not None:
241
+ min_dtype = torch.finfo(dtype).min
242
+ else:
243
+ min_dtype = torch.finfo(self.get_dtype()).min
244
+
245
+
246
+ if input_tensor is None:
247
+ input_tensor = attention_mask
248
+
249
+ inputs_lead_dim, sequence_length = input_tensor.shape[:2]
250
+ if using_static_cache:
251
+ target_length = past_key_values.get_max_cache_shape()
252
+ elif isinstance(past_key_values, HybridCache):
253
+ target_length = past_key_values.get_max_cache_shape()
254
+ else:
255
+ target_length = (
256
+ attention_mask.shape[-1]
257
+ if isinstance(attention_mask, torch.Tensor)
258
+ else cache_position[0] + sequence_length + 1
259
+ )
260
+
261
+ if attention_mask is not None and attention_mask.dim() == 4:
262
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
263
+ return attention_mask
264
+ ''' initial line but changed for quantization processing
265
+ causal_mask = torch.full(
266
+ (sequence_length, target_length), fill_value=min_dtype, dtype=self.dtype, device=cache_position.device
267
+ )
268
+ '''
269
+ causal_mask = torch.full(
270
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
271
+ )
272
+ # Causal diagonal mask only if training, otherwise attend to the whole prefix. Training-specific attn for prefix is handled below
273
+ if sequence_length != 1:
274
+ if is_training:
275
+ causal_mask = torch.triu(causal_mask, diagonal=1)
276
+ else:
277
+ causal_mask[:, :sequence_length] = 0.0
278
+
279
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
280
+ causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1)
281
+ if attention_mask is not None:
282
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
283
+ mask_length = attention_mask.shape[-1]
284
+
285
+ # First unmask prefix tokens during training
286
+ if is_training:
287
+ if token_type_ids is None:
288
+ raise ValueError("Token type ids must be provided during training")
289
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
290
+ token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
291
+ )
292
+
293
+ # Then apply padding mask (will mask pad tokens)
294
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
295
+ padding_mask = padding_mask == 0
296
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
297
+ padding_mask, min_dtype
298
+ )
299
+
300
+ return causal_mask
301
+
302
+ def get_image_features(self, pixel_values: torch.FloatTensor):
303
+ """
304
+ Obtains image last hidden states from the vision tower and apply multimodal projection.
305
+
306
+ Args:
307
+ pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
308
+ The tensors corresponding to the input images.
309
+ Returns:
310
+ image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
311
+ """
312
+ image_outputs = self.vision_tower(pixel_values)
313
+ selected_image_feature = image_outputs.last_hidden_state
314
+ image_features = self.multi_modal_projector(selected_image_feature)
315
+ image_features = image_features / (self.config.text_config.hidden_size**0.5)
316
+ return image_features
317
+
318
+ @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
319
+ @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
320
+ def forward(
321
+ self,
322
+ input_ids: torch.LongTensor = None,
323
+ pixel_values: torch.FloatTensor = None,
324
+ attention_mask: Optional[torch.Tensor] = None,
325
+ position_ids: Optional[torch.LongTensor] = None,
326
+ past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
327
+ token_type_ids: Optional[torch.LongTensor] = None,
328
+ cache_position: Optional[torch.LongTensor] = None,
329
+ inputs_embeds: Optional[torch.FloatTensor] = None,
330
+ labels: Optional[torch.LongTensor] = None,
331
+ use_cache: Optional[bool] = None,
332
+ output_attentions: Optional[bool] = None,
333
+ output_hidden_states: Optional[bool] = None,
334
+ return_dict: Optional[bool] = None,
335
+ logits_to_keep: Union[int, torch.Tensor] = 0,
336
+ **lm_kwargs,
337
+ ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
338
+ r"""
339
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
340
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
341
+ config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
342
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`.
343
+
344
+ logits_to_keep (`int` or `torch.Tensor`, *optional*):
345
+ If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
346
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
347
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
348
+ If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
349
+ This is useful when using packed tensor format (single dimension for batch and sequence length).
350
+
351
+ Returns:
352
+
353
+ Example:
354
+
355
+ ```python
356
+ >>> from PIL import Image
357
+ >>> import requests
358
+ >>> from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
359
+
360
+ >>> model = PaliGemmaForConditionalGeneration.from_pretrained("google/paligemma2-3b-mix-224")
361
+ >>> processor = AutoProcessor.from_pretrained("google/paligemma2-3b-mix-224")
362
+
363
+ >>> prompt = "Where is the cat standing?"
364
+ >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"
365
+ >>> image = Image.open(requests.get(url, stream=True).raw)
366
+
367
+ >>> inputs = processor(images=image, text=prompt, return_tensors="pt")
368
+
369
+ >>> # Generate
370
+ >>> generate_ids = model.generate(**inputs,)
371
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
372
+ "Where is the cat standing?\nsnow"
373
+ ```"""
374
+ #save the original dtype before switching to 4bit when quantization
375
+ dtype = self.get_dtype()
376
+
377
+ if (input_ids is None) ^ (inputs_embeds is not None):
378
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
379
+
380
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
381
+ output_hidden_states = (
382
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
383
+ )
384
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
385
+
386
+ is_training = token_type_ids is not None and labels is not None
387
+
388
+ # Replace image id woth PAD if the image token if OOV, to avoid index-errors
389
+ if input_ids is not None and self.config.image_token_index >= self.vocab_size:
390
+ special_image_mask = input_ids == self.config.image_token_index
391
+ llm_input_ids = input_ids.clone()
392
+ llm_input_ids[special_image_mask] = 0
393
+ else:
394
+ llm_input_ids = input_ids
395
+
396
+ if inputs_embeds is None:
397
+ inputs_embeds = self.get_input_embeddings()(llm_input_ids)
398
+
399
+ if cache_position is None:
400
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
401
+ cache_position = torch.arange(
402
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
403
+ )
404
+
405
+ if position_ids is None:
406
+ position_ids = cache_position.unsqueeze(0) + 1 # Paligemma positions are 1-indexed
407
+
408
+ # Merge text and images
409
+ if pixel_values is not None:
410
+ image_features = self.get_image_features(pixel_values)
411
+
412
+ if input_ids is None:
413
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
414
+ torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device)
415
+ )
416
+ else:
417
+ special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
418
+ special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
419
+
420
+ if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel():
421
+ image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0]
422
+ raise ValueError(
423
+ f"Number of images does not match number of special image tokens in the input text. "
424
+ f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
425
+ "tokens from image embeddings."
426
+ )
427
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
428
+ inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
429
+
430
+ # mask out pad-token-ids in labels for BC
431
+ if labels is not None and self.pad_token_id in labels:
432
+ logger.warning_once(
433
+ "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. "
434
+ "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.",
435
+ )
436
+ labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
437
+
438
+ causal_mask = self._update_causal_mask(
439
+ attention_mask, token_type_ids, past_key_values, cache_position, inputs_embeds, is_training,dtype=dtype
440
+ )
441
+ outputs = self.language_model(
442
+ attention_mask=causal_mask,
443
+ position_ids=position_ids,
444
+ past_key_values=past_key_values,
445
+ inputs_embeds=inputs_embeds,
446
+ use_cache=use_cache,
447
+ output_attentions=output_attentions,
448
+ output_hidden_states=output_hidden_states,
449
+ return_dict=return_dict,
450
+ cache_position=cache_position,
451
+ logits_to_keep=logits_to_keep,
452
+ **lm_kwargs,
453
+ )
454
+
455
+ logits = outputs[0]
456
+ loss = None
457
+ if labels is not None:
458
+ # Upcast to float if we need to compute the loss to avoid potential precision issues
459
+ shift_logits = logits[..., :-1, :]
460
+ shift_labels = labels[..., 1:]
461
+
462
+ if attention_mask is not None:
463
+ # we use the input attention mask to shift the logits and labels, because it is 2D.
464
+ # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft
465
+ shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to(logits.device)
466
+ shift_logits = shift_logits[shift_attention_mask.to(logits.device) != 0].contiguous()
467
+ shift_labels = shift_labels[shift_attention_mask.to(shift_labels.device) != 0].contiguous()
468
+ else:
469
+ shift_logits = shift_logits.contiguous()
470
+ shift_labels = shift_labels.contiguous()
471
+ # Flatten the tokens
472
+ loss_fct = nn.CrossEntropyLoss()
473
+
474
+ flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
475
+ flat_labels = shift_labels.view(-1).to(shift_logits.device)
476
+
477
+ valid_mask = flat_labels != -100
478
+
479
+ flat_labels = flat_labels[valid_mask]
480
+ flat_logits = flat_logits[valid_mask]
481
+
482
+ loss = loss_fct(flat_logits, flat_labels)
483
+ if not return_dict:
484
+ output = (logits,) + outputs[1:]
485
+ return (loss,) + output if loss is not None else output
486
+
487
+ return PaliGemmaCausalLMOutputWithPast(
488
+ loss=loss,
489
+ logits=logits,
490
+ past_key_values=outputs.past_key_values,
491
+ hidden_states=outputs.hidden_states,
492
+ attentions=outputs.attentions,
493
+ image_hidden_states=image_features if pixel_values is not None else None,
494
+ )
495
+
496
    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        pixel_values=None,
        attention_mask=None,
        token_type_ids=None,
        use_cache=True,
        logits_to_keep=None,
        labels=None,
        **kwargs,
    ):
        """Assemble per-step generation inputs.

        Delegates to the language model, then applies PaliGemma-specific
        handling: 1-indexed positions, `pixel_values` only on the prefill
        step, and a precomputed 4D mask when a `HybridCache` is used.
        """
        # Overwritten -- custom `position_ids` and `pixel_values` handling
        model_inputs = self.language_model.prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            use_cache=use_cache,
            logits_to_keep=logits_to_keep,
            token_type_ids=token_type_ids,
            **kwargs,
        )

        # position_ids in Paligemma are 1-indexed
        if model_inputs.get("position_ids") is not None:
            model_inputs["position_ids"] += 1
        # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
        # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
        if cache_position[0] == 0:
            model_inputs["pixel_values"] = pixel_values
        is_training = token_type_ids is not None and labels is not None
        # Prefill with a HybridCache: build the full 4D mask up-front.
        # NOTE(review): unlike the call in `forward`, `dtype` is not forwarded
        # to `_update_causal_mask` here — confirm this is intended for
        # quantized models.
        if cache_position[0] == 0 and isinstance(past_key_values, HybridCache):
            input_tensor = inputs_embeds if inputs_embeds is not None else input_ids
            causal_mask = self._update_causal_mask(
                attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training
            )
            model_inputs["attention_mask"] = causal_mask

        return model_inputs