Spaces:
Running on Zero
Running on Zero
fix: strip unused minicpm generation inputs
Browse files
hackathon_advisor/model_runtime.py
CHANGED
|
@@ -140,6 +140,7 @@ class MiniCPMTransformersPlanner:
|
|
| 140 |
return_dict=True,
|
| 141 |
return_tensors="pt",
|
| 142 |
).to(next(self._model.parameters()).device)
|
|
|
|
| 143 |
context = self._inference_mode() if self._inference_mode is not None else nullcontext()
|
| 144 |
with context:
|
| 145 |
generated = self._model.generate(
|
|
@@ -205,6 +206,10 @@ def system_prompt() -> str:
|
|
| 205 |
)
|
| 206 |
|
| 207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
def _json_string(value: str) -> str:
|
| 209 |
import json
|
| 210 |
|
|
|
|
| 140 |
return_dict=True,
|
| 141 |
return_tensors="pt",
|
| 142 |
).to(next(self._model.parameters()).device)
|
| 143 |
+
_strip_unused_generation_inputs(inputs)
|
| 144 |
context = self._inference_mode() if self._inference_mode is not None else nullcontext()
|
| 145 |
with context:
|
| 146 |
generated = self._model.generate(
|
|
|
|
| 206 |
)
|
| 207 |
|
| 208 |
|
| 209 |
+
def _strip_unused_generation_inputs(inputs: dict[str, Any]) -> None:
|
| 210 |
+
inputs.pop("token_type_ids", None)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
def _json_string(value: str) -> str:
|
| 214 |
import json
|
| 215 |
|