Update engine.py

#1
by qgallouedec HF Staff - opened
Files changed (1) hide show
  1. engine.py +5 -0
engine.py CHANGED
@@ -9,6 +9,7 @@ from functools import partial
9
  from typing import Generator, Optional, List, Dict, Any, Tuple
10
  from datasets import Dataset, load_dataset
11
  from trl import SFTConfig, SFTTrainer
 
12
  from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
13
  from huggingface_hub import HfApi, model_info, metadata_update
14
 
@@ -374,6 +375,10 @@ class FunctionGemmaEngine:
374
  return fig
375
 
376
  def _evaluate_model(self, test_dataset) -> Generator[str, None, None]:
 
 
 
 
377
  results = []
378
  success_count = 0
379
  for idx, item in enumerate(test_dataset):
 
9
  from typing import Generator, Optional, List, Dict, Any, Tuple
10
  from datasets import Dataset, load_dataset
11
  from trl import SFTConfig, SFTTrainer
12
+ from trl.trainer.utils import remove_none_values
13
  from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
14
  from huggingface_hub import HfApi, model_info, metadata_update
15
 
 
375
  return fig
376
 
377
  def _evaluate_model(self, test_dataset) -> Generator[str, None, None]:
378
+ # Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from
379
+ # sampled data.
380
+ test_dataset = test_dataset.with_transform(remove_none_values)
381
+
382
  results = []
383
  success_count = 0
384
  for idx, item in enumerate(test_dataset):