Mateusz Mróz commited on
Commit
d4318c2
·
1 Parent(s): d1d4e58

test nie udany

Browse files
Files changed (1) hide show
  1. modeling_florence2.py +34 -29
modeling_florence2.py CHANGED
@@ -2201,22 +2201,16 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
2201
  ):
2202
  # cut decoder_input_ids if past_key_values is used
2203
  if past_key_values is not None:
2204
- # Dodatkowe zabezpieczenie na wypadek, gdyby `past_key_values` nie było krotką
2205
- if not isinstance(past_key_values, tuple):
2206
- past_key_values = tuple(past_key_values)
2207
 
2208
- # Sprawdzamy, czy wewnętrzne elementy nie None, zanim uzyskamy do nich dostęp
2209
- if past_key_values[0] is not None and past_key_values[0][0] is not None:
2210
- past_length = past_key_values[0][0].shape[2]
2211
-
2212
- # Some generation methods already pass only the last input ID
2213
- if decoder_input_ids.shape[1] > past_length:
2214
- remove_prefix_length = past_length
2215
- else:
2216
- # Default to old behavior: keep only final ID
2217
- remove_prefix_length = decoder_input_ids.shape[1] - 1
2218
 
2219
- decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
2220
 
2221
  return {
2222
  "input_ids": None, # encoder_outputs is defined. input_ids not needed
@@ -3067,21 +3061,32 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
3067
  encoder_outputs=None,
3068
  **kwargs,
3069
  ):
3070
- # Ta funkcja powinna delegować wywołanie do `language_model`,
3071
- # a nie zwracać własny słownik.
3072
- # Poprawione wywołanie przekazuje wszystkie argumenty.
3073
- return self.language_model.prepare_inputs_for_generation(
3074
- decoder_input_ids,
3075
- past_key_values=past_key_values,
3076
- attention_mask=attention_mask,
3077
- decoder_attention_mask=decoder_attention_mask,
3078
- head_mask=head_mask,
3079
- decoder_head_mask=decoder_head_mask,
3080
- cross_attn_head_mask=cross_attn_head_mask,
3081
- use_cache=use_cache,
3082
- encoder_outputs=encoder_outputs,
3083
- **kwargs,
3084
- )
 
 
 
 
 
 
 
 
 
 
 
3085
 
3086
  def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
3087
  return self.language_model.shift_tokens_right(labels)
 
2201
  ):
2202
  # cut decoder_input_ids if past_key_values is used
2203
  if past_key_values is not None:
2204
+ past_length = past_key_values[0][0].shape[2]
 
 
2205
 
2206
+ # Some generation methods already pass only the last input ID
2207
+ if decoder_input_ids.shape[1] > past_length:
2208
+ remove_prefix_length = past_length
2209
+ else:
2210
+ # Default to old behavior: keep only final ID
2211
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
 
 
 
 
2212
 
2213
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
2214
 
2215
  return {
2216
  "input_ids": None, # encoder_outputs is defined. input_ids not needed
 
3061
  encoder_outputs=None,
3062
  **kwargs,
3063
  ):
3064
+ # cut decoder_input_ids if past_key_values is used
3065
+ if past_key_values is not None:
3066
+ past_length = past_key_values[0][0].shape[2]
3067
+
3068
+ # Some generation methods already pass only the last input ID
3069
+ if decoder_input_ids.shape[1] > past_length:
3070
+ remove_prefix_length = past_length
3071
+ else:
3072
+ # Default to old behavior: keep only final ID
3073
+ remove_prefix_length = decoder_input_ids.shape[1] - 1
3074
+
3075
+ decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
3076
+
3077
+ return {
3078
+ "input_ids": None, # encoder_outputs is defined. input_ids not needed
3079
+ "encoder_outputs": encoder_outputs,
3080
+ "past_key_values": past_key_values,
3081
+ "decoder_input_ids": decoder_input_ids,
3082
+ "attention_mask": attention_mask,
3083
+ "pixel_values": pixel_values,
3084
+ "decoder_attention_mask": decoder_attention_mask,
3085
+ "head_mask": head_mask,
3086
+ "decoder_head_mask": decoder_head_mask,
3087
+ "cross_attn_head_mask": cross_attn_head_mask,
3088
+ "use_cache": use_cache, # change this to avoid caching (presumably for debugging)
3089
+ }
3090
 
3091
  def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
3092
  return self.language_model.shift_tokens_right(labels)