Update engine.py
#1
by
qgallouedec
HF Staff
- opened
engine.py
CHANGED
|
@@ -9,6 +9,7 @@ from functools import partial
|
|
| 9 |
from typing import Generator, Optional, List, Dict, Any, Tuple
|
| 10 |
from datasets import Dataset, load_dataset
|
| 11 |
from trl import SFTConfig, SFTTrainer
|
|
|
|
| 12 |
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
|
| 13 |
from huggingface_hub import HfApi, model_info, metadata_update
|
| 14 |
|
|
@@ -374,6 +375,10 @@ class FunctionGemmaEngine:
|
|
| 374 |
return fig
|
| 375 |
|
| 376 |
def _evaluate_model(self, test_dataset) -> Generator[str, None, None]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
results = []
|
| 378 |
success_count = 0
|
| 379 |
for idx, item in enumerate(test_dataset):
|
|
|
|
| 9 |
from typing import Generator, Optional, List, Dict, Any, Tuple
|
| 10 |
from datasets import Dataset, load_dataset
|
| 11 |
from trl import SFTConfig, SFTTrainer
|
| 12 |
+
from trl.trainer.utils import remove_none_values
|
| 13 |
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
|
| 14 |
from huggingface_hub import HfApi, model_info, metadata_update
|
| 15 |
|
|
|
|
| 375 |
return fig
|
| 376 |
|
| 377 |
def _evaluate_model(self, test_dataset) -> Generator[str, None, None]:
|
| 378 |
+
# Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from
|
| 379 |
+
# sampled data.
|
| 380 |
+
test_dataset = test_dataset.with_transform(remove_none_values)
|
| 381 |
+
|
| 382 |
results = []
|
| 383 |
success_count = 0
|
| 384 |
for idx, item in enumerate(test_dataset):
|