adityss committed on
Commit
c70e17d
·
1 Parent(s): 8204dc0

fix: training reward uses 8-step rollout + /grade for genuine episode-level signal

Browse files
Files changed (1) hide show
  1. scripts/train_unsloth.py +33 -29
scripts/train_unsloth.py CHANGED
@@ -85,15 +85,16 @@ def reward_has_required_keys(completions, **kwargs):
85
 
86
  def get_reward_env_interaction(env_url):
87
  """Closure to capture the target environment URL for the reward function.
88
-
89
- Uses direct requests calls instead of GenericEnvClient to avoid dependency issues.
 
 
90
  """
91
  def reward_env_interaction(completions, **kwargs):
92
  rewards = []
93
  for completion in completions:
94
  text = completion[0]["content"] if isinstance(completion, list) else completion
95
  try:
96
- # Parse action from LLM output
97
  match = re.search(r'\{.*?\}', text, re.DOTALL)
98
  action = json.loads(match.group()) if match else {}
99
  step_action = {
@@ -103,41 +104,44 @@ def get_reward_env_interaction(env_url):
103
  "load_shed_fraction": float(max(0, min(0.5, action.get("load_shed_fraction", 0.0)))),
104
  "building_id": 0
105
  }
106
-
107
- # Direct HTTP calls to environment instead of GenericEnvClient
108
- # Reset the environment first
109
  reset_resp = requests.post(
110
  f"{env_url}/reset",
111
- json={"task_id": 1, "seed": 42},
112
  timeout=30
113
  )
114
  if reset_resp.status_code != 200:
115
  rewards.append(0.0)
116
  continue
117
-
118
- # Take a step with the proposed action
119
- step_resp = requests.post(
120
- f"{env_url}/step",
121
- json=[step_action],
122
- timeout=30
123
- )
124
- if step_resp.status_code != 200:
125
- rewards.append(0.0)
126
- continue
127
-
128
- result = step_resp.json()
129
- if isinstance(result, list) and len(result) > 0:
130
- step_reward = float(result[0].get("reward", 0.0))
131
- elif isinstance(result, dict) and "results" in result:
132
- step_reward = float(result["results"][0].get("reward", 0.0))
 
 
 
 
 
 
 
 
133
  else:
134
- step_reward = 0.0
135
-
136
- # Normalize reward to 0.0-0.4 range. The Go step reward is usually around [-2.0, 3.0].
137
- # Shift by +2.0 and scale by 0.05 to map to ~0.0-0.4.
138
- val = (step_reward + 2.0) * 0.08
139
  rewards.append(min(0.4, max(0.0, val)))
140
-
141
  except Exception as e:
142
  print(f"Env error: {e}", file=sys.stderr)
143
  rewards.append(0.0)
 
85
 
86
  def get_reward_env_interaction(env_url):
87
  """Closure to capture the target environment URL for the reward function.
88
+
89
+ Uses a SHORT (8-step) rollout to get a more genuine episode-level reward signal.
90
+ The grade endpoint returns the true episode score (0.0-1.0 clamped open interval),
91
+ which is what we use as the reward — not the step-level reward.
92
  """
93
  def reward_env_interaction(completions, **kwargs):
94
  rewards = []
95
  for completion in completions:
96
  text = completion[0]["content"] if isinstance(completion, list) else completion
97
  try:
 
98
  match = re.search(r'\{.*?\}', text, re.DOTALL)
99
  action = json.loads(match.group()) if match else {}
100
  step_action = {
 
104
  "load_shed_fraction": float(max(0, min(0.5, action.get("load_shed_fraction", 0.0)))),
105
  "building_id": 0
106
  }
107
+
 
 
108
  reset_resp = requests.post(
109
  f"{env_url}/reset",
110
+ json={"task_id": 2, "seed": 42},
111
  timeout=30
112
  )
113
  if reset_resp.status_code != 200:
114
  rewards.append(0.0)
115
  continue
116
+
117
+ step_rewards = []
118
+ for _ in range(8):
119
+ step_resp = requests.post(
120
+ f"{env_url}/step",
121
+ json=[step_action],
122
+ timeout=30
123
+ )
124
+ if step_resp.status_code != 200:
125
+ step_rewards.append(0.0)
126
+ continue
127
+ result = step_resp.json()
128
+ if isinstance(result, list) and len(result) > 0:
129
+ r = float(result[0].get("reward", 0.0))
130
+ elif isinstance(result, dict) and "results" in result:
131
+ r = float(result["results"][0].get("reward", 0.0))
132
+ else:
133
+ r = 0.0
134
+ step_rewards.append(r)
135
+
136
+ grade_resp = requests.get(f"{env_url}/grade", timeout=30)
137
+ if grade_resp.status_code == 200:
138
+ episode_score = float(grade_resp.json().get("score", 0.5))
139
+ val = episode_score * 0.4
140
  else:
141
+ mean_step_reward = sum(step_rewards) / len(step_rewards) if step_rewards else 0.0
142
+ val = (mean_step_reward + 2.0) * 0.08
 
 
 
143
  rewards.append(min(0.4, max(0.0, val)))
144
+
145
  except Exception as e:
146
  print(f"Env error: {e}", file=sys.stderr)
147
  rewards.append(0.0)