Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files
dataqa_env/server/environment.py
CHANGED
|
@@ -26,6 +26,14 @@ from .tasks import PlantedIssue, Task, get_task, list_tasks
|
|
| 26 |
IDENTIFY_WEIGHT = 0.6
|
| 27 |
FIX_WEIGHT = 0.4
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
def parse_issue_key(raw: str) -> Optional[str]:
|
| 31 |
"""
|
|
@@ -416,7 +424,7 @@ class DataQAEnvironment(Environment):
|
|
| 416 |
num_issues_hint=len(self._current_task.planted_issues),
|
| 417 |
max_steps=self._current_task.max_steps,
|
| 418 |
done=False,
|
| 419 |
-
reward=0.0,
|
| 420 |
)
|
| 421 |
|
| 422 |
def step(
|
|
@@ -596,7 +604,7 @@ class DataQAEnvironment(Environment):
|
|
| 596 |
num_issues_hint=len(self._current_task.planted_issues),
|
| 597 |
max_steps=self._state.max_steps,
|
| 598 |
done=is_done,
|
| 599 |
-
reward=self._best_score,
|
| 600 |
metadata={
|
| 601 |
"identify_f1": identify_f1,
|
| 602 |
"identify_score": identify_score,
|
|
|
|
| 26 |
IDENTIFY_WEIGHT = 0.6
|
| 27 |
FIX_WEIGHT = 0.4
|
| 28 |
|
| 29 |
+
# Clamp reward to strict (0, 1) — validators reject exactly 0.0 and 1.0
|
| 30 |
+
REWARD_MIN = 0.001
|
| 31 |
+
REWARD_MAX = 0.999
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _clamp_reward(r: float) -> float:
|
| 35 |
+
return max(REWARD_MIN, min(REWARD_MAX, r))
|
| 36 |
+
|
| 37 |
|
| 38 |
def parse_issue_key(raw: str) -> Optional[str]:
|
| 39 |
"""
|
|
|
|
| 424 |
num_issues_hint=len(self._current_task.planted_issues),
|
| 425 |
max_steps=self._current_task.max_steps,
|
| 426 |
done=False,
|
| 427 |
+
reward=_clamp_reward(0.0),
|
| 428 |
)
|
| 429 |
|
| 430 |
def step(
|
|
|
|
| 604 |
num_issues_hint=len(self._current_task.planted_issues),
|
| 605 |
max_steps=self._state.max_steps,
|
| 606 |
done=is_done,
|
| 607 |
+
reward=_clamp_reward(self._best_score),
|
| 608 |
metadata={
|
| 609 |
"identify_f1": identify_f1,
|
| 610 |
"identify_score": identify_score,
|
tests/test_environment.py
CHANGED
|
@@ -262,7 +262,7 @@ class TestDataQAEnvironment:
|
|
| 262 |
assert obs.num_issues_hint == 6
|
| 263 |
assert obs.max_steps == 3
|
| 264 |
assert obs.done is False
|
| 265 |
-
assert obs.reward
|
| 266 |
assert "fix" in obs.feedback.lower() # mentions fix phase
|
| 267 |
|
| 268 |
def test_reset_medium(self, env):
|
|
@@ -318,7 +318,7 @@ class TestDataQAEnvironment:
|
|
| 318 |
env.reset(task_id="easy")
|
| 319 |
action = DataQAAction(issues=[], task_id="easy")
|
| 320 |
obs = env.step(action)
|
| 321 |
-
assert obs.reward
|
| 322 |
|
| 323 |
def test_step_exhausts_max_steps(self, env):
|
| 324 |
env.reset(task_id="easy")
|
|
|
|
| 262 |
assert obs.num_issues_hint == 6
|
| 263 |
assert obs.max_steps == 3
|
| 264 |
assert obs.done is False
|
| 265 |
+
assert obs.reward < 0.01 # clamped to 0.001, not exactly 0.0
|
| 266 |
assert "fix" in obs.feedback.lower() # mentions fix phase
|
| 267 |
|
| 268 |
def test_reset_medium(self, env):
|
|
|
|
| 318 |
env.reset(task_id="easy")
|
| 319 |
action = DataQAAction(issues=[], task_id="easy")
|
| 320 |
obs = env.step(action)
|
| 321 |
+
assert obs.reward < 0.01 # clamped to 0.001, not exactly 0.0
|
| 322 |
|
| 323 |
def test_step_exhausts_max_steps(self, env):
|
| 324 |
env.reset(task_id="easy")
|