3v324v23 committed on
Commit
574b833
·
1 Parent(s): 7c3bc96

Fix task_id kwarg in reward function

Browse files
Files changed (1) hide show
  1. grpo_train.py +4 -3
grpo_train.py CHANGED
@@ -47,13 +47,14 @@ def build_dataset():
47
 
48
  # ── REWARD FUNCTION (actually calls the environment) ──────────────────────────
49
 
50
- def reward_environment(prompts, completions, task_ids, **kwargs):
51
  """
52
  This is the real reward — model outputs an action,
53
  we send it to the environment, environment returns the reward.
54
  """
55
  rewards = []
56
- for completion, task_id in zip(completions, task_ids):
 
57
  try:
58
  # Parse model output
59
  content = completion.strip()
@@ -70,7 +71,7 @@ def reward_environment(prompts, completions, task_ids, **kwargs):
70
 
71
  try:
72
  # Fresh episode for each reward calculation
73
- requests.post(f"{ENV_URL}/reset", json={"task_id": task_id})
74
 
75
  # Run a minimal sequence: if model says query_regulations,
76
  # run that then check what reward it generates
 
47
 
48
  # ── REWARD FUNCTION (actually calls the environment) ──────────────────────────
49
 
50
+ def reward_environment(prompts, completions, task_id, **kwargs):
51
  """
52
  This is the real reward — model outputs an action,
53
  we send it to the environment, environment returns the reward.
54
  """
55
  rewards = []
56
+ # Notice we zip with task_id (from the dataset) and use t_id inside the loop
57
+ for completion, t_id in zip(completions, task_id):
58
  try:
59
  # Parse model output
60
  content = completion.strip()
 
71
 
72
  try:
73
  # Fresh episode for each reward calculation
74
+ requests.post(f"{ENV_URL}/reset", json={"task_id": t_id})
75
 
76
  # Run a minimal sequence: if model says query_regulations,
77
  # run that then check what reward it generates