amanwithaplan committed on
Commit
a6f0d21
·
verified ·
1 Parent(s): 55c1fce

Fix: select only required columns for training

Browse files
Files changed (1) hide show
  1. train_reranker.py +10 -5
train_reranker.py CHANGED
@@ -144,11 +144,16 @@ def main():
144
  "score": "label"
145
  })
146
 
147
- # Split for evaluation
148
  eval_size = min(400, int(len(dataset) * 0.15))
149
  splits = dataset.train_test_split(test_size=eval_size, seed=42)
150
- train_dataset = splits["train"]
151
- eval_dataset = splits["test"]
 
 
 
 
 
152
 
153
  trackio.log({
154
  "data/train_size": len(train_dataset),
@@ -212,9 +217,9 @@ def main():
212
  trackio.log({f"final/{key}": value})
213
  logger.info(f"Final metrics: {final_metrics}")
214
 
215
- # Per-type evaluation
216
  logger.info("Evaluating by content type...")
217
- type_metrics = evaluate_by_type(model, eval_dataset)
218
  for key, value in type_metrics.items():
219
  trackio.log({f"final/by_type/{key}": value})
220
  logger.info(f"Per-type metrics: {type_metrics}")
 
144
  "score": "label"
145
  })
146
 
147
+ # Split for evaluation (before removing extra columns so we keep type for eval)
148
  eval_size = min(400, int(len(dataset) * 0.15))
149
  splits = dataset.train_test_split(test_size=eval_size, seed=42)
150
+
151
+ # Keep full eval dataset with type column for per-type evaluation
152
+ eval_dataset_full = splits["test"]
153
+
154
+ # Remove extra columns for training (CrossEncoderTrainer only wants sentence1, sentence2, label)
155
+ train_dataset = splits["train"].select_columns(["sentence1", "sentence2", "label"])
156
+ eval_dataset = splits["test"].select_columns(["sentence1", "sentence2", "label"])
157
 
158
  trackio.log({
159
  "data/train_size": len(train_dataset),
 
217
  trackio.log({f"final/{key}": value})
218
  logger.info(f"Final metrics: {final_metrics}")
219
 
220
+ # Per-type evaluation (use full eval dataset with type column)
221
  logger.info("Evaluating by content type...")
222
+ type_metrics = evaluate_by_type(model, eval_dataset_full)
223
  for key, value in type_metrics.items():
224
  trackio.log({f"final/by_type/{key}": value})
225
  logger.info(f"Per-type metrics: {type_metrics}")