varb15 commited on
Commit
c5b540e
·
verified ·
1 Parent(s): 9f1cf04

Upload folder using huggingface_hub

Browse files
dataqa_env/server/environment.py CHANGED
@@ -26,6 +26,14 @@ from .tasks import PlantedIssue, Task, get_task, list_tasks
26
  IDENTIFY_WEIGHT = 0.6
27
  FIX_WEIGHT = 0.4
28
 
 
 
 
 
 
 
 
 
29
 
30
  def parse_issue_key(raw: str) -> Optional[str]:
31
  """
@@ -416,7 +424,7 @@ class DataQAEnvironment(Environment):
416
  num_issues_hint=len(self._current_task.planted_issues),
417
  max_steps=self._current_task.max_steps,
418
  done=False,
419
- reward=0.0,
420
  )
421
 
422
  def step(
@@ -596,7 +604,7 @@ class DataQAEnvironment(Environment):
596
  num_issues_hint=len(self._current_task.planted_issues),
597
  max_steps=self._state.max_steps,
598
  done=is_done,
599
- reward=self._best_score,
600
  metadata={
601
  "identify_f1": identify_f1,
602
  "identify_score": identify_score,
 
26
  IDENTIFY_WEIGHT = 0.6
27
  FIX_WEIGHT = 0.4
28
 
29
+ # Clamp reward to strict (0, 1) — validators reject exactly 0.0 and 1.0
30
+ REWARD_MIN = 0.001
31
+ REWARD_MAX = 0.999
32
+
33
+
34
+ def _clamp_reward(r: float) -> float:
35
+ return max(REWARD_MIN, min(REWARD_MAX, r))
36
+
37
 
38
  def parse_issue_key(raw: str) -> Optional[str]:
39
  """
 
424
  num_issues_hint=len(self._current_task.planted_issues),
425
  max_steps=self._current_task.max_steps,
426
  done=False,
427
+ reward=_clamp_reward(0.0),
428
  )
429
 
430
  def step(
 
604
  num_issues_hint=len(self._current_task.planted_issues),
605
  max_steps=self._state.max_steps,
606
  done=is_done,
607
+ reward=_clamp_reward(self._best_score),
608
  metadata={
609
  "identify_f1": identify_f1,
610
  "identify_score": identify_score,
tests/test_environment.py CHANGED
@@ -262,7 +262,7 @@ class TestDataQAEnvironment:
262
  assert obs.num_issues_hint == 6
263
  assert obs.max_steps == 3
264
  assert obs.done is False
265
- assert obs.reward == 0.0
266
  assert "fix" in obs.feedback.lower() # mentions fix phase
267
 
268
  def test_reset_medium(self, env):
@@ -318,7 +318,7 @@ class TestDataQAEnvironment:
318
  env.reset(task_id="easy")
319
  action = DataQAAction(issues=[], task_id="easy")
320
  obs = env.step(action)
321
- assert obs.reward == 0.0
322
 
323
  def test_step_exhausts_max_steps(self, env):
324
  env.reset(task_id="easy")
 
262
  assert obs.num_issues_hint == 6
263
  assert obs.max_steps == 3
264
  assert obs.done is False
265
+ assert obs.reward < 0.01 # clamped to 0.001, not exactly 0.0
266
  assert "fix" in obs.feedback.lower() # mentions fix phase
267
 
268
  def test_reset_medium(self, env):
 
318
  env.reset(task_id="easy")
319
  action = DataQAAction(issues=[], task_id="easy")
320
  obs = env.step(action)
321
+ assert obs.reward < 0.01 # clamped to 0.001, not exactly 0.0
322
 
323
  def test_step_exhausts_max_steps(self, env):
324
  env.reset(task_id="easy")