JacobLinCool commited on
Commit
4c696f3
·
verified ·
1 Parent(s): 4e530e4

fix: strip unused minicpm generation inputs

Browse files
Files changed (1) hide show
  1. hackathon_advisor/model_runtime.py +5 -0
hackathon_advisor/model_runtime.py CHANGED
@@ -140,6 +140,7 @@ class MiniCPMTransformersPlanner:
140
  return_dict=True,
141
  return_tensors="pt",
142
  ).to(next(self._model.parameters()).device)
 
143
  context = self._inference_mode() if self._inference_mode is not None else nullcontext()
144
  with context:
145
  generated = self._model.generate(
@@ -205,6 +206,10 @@ def system_prompt() -> str:
205
  )
206
 
207
 
 
 
 
 
208
  def _json_string(value: str) -> str:
209
  import json
210
 
 
140
  return_dict=True,
141
  return_tensors="pt",
142
  ).to(next(self._model.parameters()).device)
143
+ _strip_unused_generation_inputs(inputs)
144
  context = self._inference_mode() if self._inference_mode is not None else nullcontext()
145
  with context:
146
  generated = self._model.generate(
 
206
  )
207
 
208
 
209
+ def _strip_unused_generation_inputs(inputs: dict[str, Any]) -> None:
210
+ inputs.pop("token_type_ids", None)
211
+
212
+
213
  def _json_string(value: str) -> str:
214
  import json
215