tobil committed on
Commit
6ca0e08
·
verified ·
1 Parent(s): 37174c2

Upload train_grpo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_grpo.py +2 -1
train_grpo.py CHANGED
@@ -150,6 +150,7 @@ def compute_length_reward(text: str) -> float:
150
 
151
  class QMDRewardFunction:
152
  """Combined reward function for QMD query expansion."""
 
153
 
154
  def __init__(self):
155
  # Load a small embedding model for diversity computation
@@ -272,7 +273,7 @@ def main():
272
  processing_class=tokenizer,
273
  args=config,
274
  train_dataset=dataset,
275
- reward_funcs=reward_fn,
276
  )
277
 
278
  # Train
 
150
 
151
  class QMDRewardFunction:
152
  """Combined reward function for QMD query expansion."""
153
+ __name__ = "qmd_format_diversity_reward"
154
 
155
  def __init__(self):
156
  # Load a small embedding model for diversity computation
 
273
  processing_class=tokenizer,
274
  args=config,
275
  train_dataset=dataset,
276
+ reward_funcs=[reward_fn],
277
  )
278
 
279
  # Train