Mateusz Mróz committed on
Commit
d1d4e58
·
1 Parent(s): df23e97
Files changed (1) hide show
  1. modeling_florence2.py +15 -26
modeling_florence2.py CHANGED
@@ -3067,32 +3067,21 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
3067
  encoder_outputs=None,
3068
  **kwargs,
3069
  ):
3070
- # cut decoder_input_ids if past_key_values is used
3071
- if past_key_values is not None:
3072
- past_length = past_key_values[0][0].shape[2]
3073
-
3074
- # Some generation methods already pass only the last input ID
3075
- if decoder_input_ids.shape[1] > past_length:
3076
- remove_prefix_length = past_length
3077
- else:
3078
- # Default to old behavior: keep only final ID
3079
- remove_prefix_length = decoder_input_ids.shape[1] - 1
3080
-
3081
- decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
3082
-
3083
- return {
3084
- "input_ids": None, # encoder_outputs is defined. input_ids not needed
3085
- "encoder_outputs": encoder_outputs,
3086
- "past_key_values": past_key_values,
3087
- "decoder_input_ids": decoder_input_ids,
3088
- "attention_mask": attention_mask,
3089
- "pixel_values": pixel_values,
3090
- "decoder_attention_mask": decoder_attention_mask,
3091
- "head_mask": head_mask,
3092
- "decoder_head_mask": decoder_head_mask,
3093
- "cross_attn_head_mask": cross_attn_head_mask,
3094
- "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
3095
- }
3096
 
3097
  def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
3098
  return self.language_model.shift_tokens_right(labels)
 
3067
  encoder_outputs=None,
3068
  **kwargs,
3069
  ):
3070
+ # This function should delegate the call to `language_model`
3071
+ # rather than return its own dictionary.
3072
+ # The corrected call passes all arguments through.
3073
+ return self.language_model.prepare_inputs_for_generation(
3074
+ decoder_input_ids,
3075
+ past_key_values=past_key_values,
3076
+ attention_mask=attention_mask,
3077
+ decoder_attention_mask=decoder_attention_mask,
3078
+ head_mask=head_mask,
3079
+ decoder_head_mask=decoder_head_mask,
3080
+ cross_attn_head_mask=cross_attn_head_mask,
3081
+ use_cache=use_cache,
3082
+ encoder_outputs=encoder_outputs,
3083
+ **kwargs,
3084
+ )
 
 
 
 
 
 
 
 
 
 
 
3085
 
3086
  def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
3087
  return self.language_model.shift_tokens_right(labels)