yhirokawa commited on
Commit
92c75fd
·
verified ·
1 Parent(s): 33c33b6

Properly propagate `model_inputs` (#14)

Browse files

- Properly propagate `model_inputs` (3947136e80ad57d10a80752afe1f2442238750e3)

Files changed (1) hide show
  1. modeling_plamo.py +3 -0
modeling_plamo.py CHANGED
@@ -1663,6 +1663,9 @@ class Plamo2ForCausalLM(Plamo2PreTrainedModel):
1663
  "position_ids": position_ids,
1664
  "past_key_values": past_key_values,
1665
  "use_cache": kwargs.get("use_cache"),
 
 
 
1666
  "attention_mask": attention_mask,
1667
  "image_features": image_features,
1668
  }
 
1663
  "position_ids": position_ids,
1664
  "past_key_values": past_key_values,
1665
  "use_cache": kwargs.get("use_cache"),
1666
+ "output_attentions": kwargs.get("output_attentions"),
1667
+ "output_hidden_states": kwargs.get("output_hidden_states"),
1668
+ "logits_to_keep": kwargs.get("logits_to_keep"),
1669
  "attention_mask": attention_mask,
1670
  "image_features": image_features,
1671
  }