kush5699 commited on
Commit
593f876
·
verified ·
1 Parent(s): 6e90226

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. env/environment.py +1 -1
  2. env/tasks.py +4 -4
  3. inference.py +6 -5
env/environment.py CHANGED
@@ -71,7 +71,7 @@ class DataValidationEnvironment:
71
  self._state.last_actions.append(action_key)
72
 
73
  if is_repeat:
74
- reward = -0.1
75
  message = "Penalty: repeated identical action"
76
  else:
77
  reward, message, fixed = grade_action(
 
71
  self._state.last_actions.append(action_key)
72
 
73
  if is_repeat:
74
+ reward = 0.01
75
  message = "Penalty: repeated identical action"
76
  else:
77
  reward, message, fixed = grade_action(
env/tasks.py CHANGED
@@ -224,10 +224,10 @@ def grade_action(action_type: str, target_field: str, target_row: int,
224
 
225
  if action_type == "validate":
226
  fixed = sum(1 for e in errors if e.get("fixed", False))
227
- return 0.0, f"Validation: {fixed}/{total_errors} errors fixed ({fixed/total_errors*100:.0f}%)", False
228
 
229
  if action_type == "skip":
230
- return 0.0, "Skipped current action", False
231
 
232
  matching_error = None
233
  for e in errors:
@@ -238,7 +238,7 @@ def grade_action(action_type: str, target_field: str, target_row: int,
238
  break
239
 
240
  if matching_error is None:
241
- return -0.05, f"No unfixed error at row {target_row}, field '{target_field}'", False
242
 
243
  action_to_error_map = {
244
  "fix_missing": "missing",
@@ -277,4 +277,4 @@ def grade_action(action_type: str, target_field: str, target_row: int,
277
  reward = 0.9 / total_errors
278
  return reward, f"Fixed: row {target_row}, field '{target_field}' -> '{new_value}'", True
279
  else:
280
- return -0.05, f"Wrong value for row {target_row}, field '{target_field}'. Got '{new_value}', expected something else.", False
 
224
 
225
  if action_type == "validate":
226
  fixed = sum(1 for e in errors if e.get("fixed", False))
227
+ return 0.01, f"Validation: {fixed}/{total_errors} errors fixed ({fixed/total_errors*100:.0f}%)", False
228
 
229
  if action_type == "skip":
230
+ return 0.01, "Skipped current action", False
231
 
232
  matching_error = None
233
  for e in errors:
 
238
  break
239
 
240
  if matching_error is None:
241
+ return 0.01, f"No unfixed error at row {target_row}, field '{target_field}'", False
242
 
243
  action_to_error_map = {
244
  "fix_missing": "missing",
 
277
  reward = 0.9 / total_errors
278
  return reward, f"Fixed: row {target_row}, field '{target_field}' -> '{new_value}'", True
279
  else:
280
+ return 0.01, f"Wrong value for row {target_row}, field '{target_field}'. Got '{new_value}', expected something else.", False
inference.py CHANGED
@@ -175,14 +175,15 @@ def run_episode(task_config: dict) -> None:
175
  error_msg = None
176
  try:
177
  obs = env_step(action)
178
- reward = obs.get("reward", 0.0)
179
  done = obs.get("done", False)
180
  except Exception as e:
181
  error_msg = str(e)
182
- reward = 0.0
183
  done = False
184
 
185
  steps += 1
 
186
  rewards.append(reward)
187
 
188
  print(f"[STEP] step={steps} action={action_str} reward={reward:.2f} done={str(done).lower()} error={error_msg if error_msg else 'null'}")
@@ -196,11 +197,11 @@ def run_episode(task_config: dict) -> None:
196
  except Exception as e:
197
  error_str = str(e)
198
  if steps == 0:
199
- print(f"[STEP] step=1 action=null reward=0.00 done=true error={error_str}")
200
  steps = 1
201
- rewards = [0.0]
202
  finally:
203
- rewards_str = ",".join(f"{r:.2f}" for r in rewards) if rewards else "0.00"
204
  print(f"[END] success={str(success).lower()} steps={steps} rewards={rewards_str}")
205
 
206
 
 
175
  error_msg = None
176
  try:
177
  obs = env_step(action)
178
+ reward = obs.get("reward", 0.01)
179
  done = obs.get("done", False)
180
  except Exception as e:
181
  error_msg = str(e)
182
+ reward = 0.01
183
  done = False
184
 
185
  steps += 1
186
+ reward = max(0.01, min(0.99, reward))
187
  rewards.append(reward)
188
 
189
  print(f"[STEP] step={steps} action={action_str} reward={reward:.2f} done={str(done).lower()} error={error_msg if error_msg else 'null'}")
 
197
  except Exception as e:
198
  error_str = str(e)
199
  if steps == 0:
200
+ print(f"[STEP] step=1 action=null reward=0.01 done=true error={error_str}")
201
  steps = 1
202
+ rewards = [0.01]
203
  finally:
204
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards) if rewards else "0.01"
205
  print(f"[END] success={str(success).lower()} steps={steps} rewards={rewards_str}")
206
 
207