Fix task_id kwarg in reward function
Browse files- grpo_train.py +4 -3
grpo_train.py
CHANGED
|
@@ -47,13 +47,14 @@ def build_dataset():
|
|
| 47 |
|
| 48 |
# ββ REWARD FUNCTION (actually calls the environment) ββββββββββββββββββββββββββ
|
| 49 |
|
| 50 |
-
def reward_environment(prompts, completions,
|
| 51 |
"""
|
| 52 |
This is the real reward β model outputs an action,
|
| 53 |
we send it to the environment, environment returns the reward.
|
| 54 |
"""
|
| 55 |
rewards = []
|
| 56 |
-
|
|
|
|
| 57 |
try:
|
| 58 |
# Parse model output
|
| 59 |
content = completion.strip()
|
|
@@ -70,7 +71,7 @@ def reward_environment(prompts, completions, task_ids, **kwargs):
|
|
| 70 |
|
| 71 |
try:
|
| 72 |
# Fresh episode for each reward calculation
|
| 73 |
-
requests.post(f"{ENV_URL}/reset", json={"task_id":
|
| 74 |
|
| 75 |
# Run a minimal sequence: if model says query_regulations,
|
| 76 |
# run that then check what reward it generates
|
|
|
|
| 47 |
|
| 48 |
# ββ REWARD FUNCTION (actually calls the environment) ββββββββββββββββββββββββββ
|
| 49 |
|
| 50 |
+
def reward_environment(prompts, completions, task_id, **kwargs):
|
| 51 |
"""
|
| 52 |
This is the real reward β model outputs an action,
|
| 53 |
we send it to the environment, environment returns the reward.
|
| 54 |
"""
|
| 55 |
rewards = []
|
| 56 |
+
# Notice we zip with task_id (from the dataset) and use t_id inside the loop
|
| 57 |
+
for completion, t_id in zip(completions, task_id):
|
| 58 |
try:
|
| 59 |
# Parse model output
|
| 60 |
content = completion.strip()
|
|
|
|
| 71 |
|
| 72 |
try:
|
| 73 |
# Fresh episode for each reward calculation
|
| 74 |
+
requests.post(f"{ENV_URL}/reset", json={"task_id": t_id})
|
| 75 |
|
| 76 |
# Run a minimal sequence: if model says query_regulations,
|
| 77 |
# run that then check what reward it generates
|