Commit d4ccbaf · 1 Parent(s): 9aeca6d
Pranav Patel committed

fix: sync bug fixes (imports, eval safety, missing fields, task guards, rewards)
harfeast_env/client.py CHANGED
@@ -5,16 +5,10 @@ Connects to HarFeast OpenEnv server via WebSocket/HTTP.
 
 from typing import Any, Dict
 
-try:
-    from openenv.core.client_types import StepResult
-    from openenv.core.env_server.types import State
-    from openenv.core.env_client import EnvClient
-    from harfeast_env.models import HarFeastAction, HarFeastObservation
-except ImportError:
-    from openenv.core.client_types import StepResult
-    from openenv.core.env_server.types import State
-    from openenv.core.env_client import EnvClient
-    from models import HarFeastAction, HarFeastObservation
+from openenv.core.client_types import StepResult
+from openenv.core.env_server.types import State
+from openenv.core.env_client import EnvClient
+from harfeast_env.models import HarFeastAction, HarFeastObservation
 
 
 class HarFeastEnv(EnvClient[HarFeastAction, HarFeastObservation, State]):
harfeast_env/models.py CHANGED
@@ -5,10 +5,7 @@ Actions are JSON-serialized calls: {"action": "files.list", "path": "."}
 
 from pydantic import Field
 
-try:
-    from openenv.core.env_server.types import Action, Observation
-except ImportError:
-    from openenv.core.env_server.types import Action, Observation
+from openenv.core.env_server.types import Action, Observation
 
 
 class HarFeastAction(Action):
@@ -46,3 +43,15 @@ class HarFeastObservation(Observation):
         default="[]",
         description="JSON list of filtered dataset names available for chaining",
     )
+    done: bool = Field(
+        default=False,
+        description="Whether the episode has ended",
+    )
+    reward: float = Field(
+        default=0.0,
+        description="Rubric score (0-100) when done, else 0",
+    )
+    metadata: dict = Field(
+        default_factory=dict,
+        description="Extra info (action_taken, last_error, task_id)",
+    )
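The three new fields let a rollout loop read termination and score straight off the observation, without a side channel. A minimal sketch, assuming the remaining HarFeastObservation fields keep their declared defaults and the base Observation requires nothing else (all values below are illustrative, not from this repo):

from harfeast_env.models import HarFeastObservation

# Hypothetical terminal observation for a finished episode.
obs = HarFeastObservation(
    done=True,
    reward=85.0,  # rubric score on the 0-100 scale; only meaningful when done
    metadata={"action_taken": "answer.submit", "last_error": "", "task_id": "t-001"},
)
assert obs.done and obs.metadata["task_id"] == "t-001"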
harfeast_env/server/app.py CHANGED
@@ -11,14 +11,9 @@ _project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(
 if _project_root not in sys.path:
     sys.path.insert(0, _project_root)
 
-try:
-    from openenv.core.env_server.http_server import create_app
-    from harfeast_env.models import HarFeastAction, HarFeastObservation
-    from harfeast_env.server.harfeast_environment import HarFeastEnvironment
-except ImportError:
-    from openenv.core.env_server.http_server import create_app
-    from models import HarFeastAction, HarFeastObservation
-    from server.harfeast_environment import HarFeastEnvironment
+from openenv.core.env_server.http_server import create_app
+from harfeast_env.models import HarFeastAction, HarFeastObservation
+from harfeast_env.server.harfeast_environment import HarFeastEnvironment
 
 # World path - use env var or default to project harfeast_world
 WORLD_PATH = os.environ.get("HARFEAST_WORLD_PATH") or os.path.join(_project_root, "harfeast_world")
harfeast_env/server/harfeast_environment.py CHANGED
@@ -7,12 +7,8 @@ import json
 import os
 from uuid import uuid4
 
-try:
-    from openenv.core.env_server.interfaces import Environment
-    from openenv.core.env_server.types import State
-except ImportError:
-    from openenv.core.env_server.interfaces import Environment
-    from openenv.core.env_server.types import State
+from openenv.core.env_server.interfaces import Environment
+from openenv.core.env_server.types import State
 
 # Import our core logic - use path relative to project root
 import sys
harfeast_openenv/actions.py CHANGED
@@ -3,6 +3,7 @@
 import ast
 import csv
 import json
+import operator
 import os
 import re
 from collections import defaultdict
@@ -13,6 +14,30 @@ from .schemas import ActionResult
 
 # ── Observation size limits ──────────────────────────────────────
 MAX_TABLE_ROWS = 20
+
+# ── Safe arithmetic evaluator (replaces eval) ────────────────────
+_SAFE_BINOPS = {
+    ast.Add: operator.add, ast.Sub: operator.sub,
+    ast.Mult: operator.mul, ast.Div: operator.truediv,
+}
+
+def _safe_eval_expr(node, namespace=None):
+    """Evaluate an AST node containing only arithmetic on numbers (and optionally named vars)."""
+    if isinstance(node, ast.Expression):
+        return _safe_eval_expr(node.body, namespace)
+    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
+        return node.value
+    if isinstance(node, ast.BinOp) and type(node.op) in _SAFE_BINOPS:
+        left = _safe_eval_expr(node.left, namespace)
+        right = _safe_eval_expr(node.right, namespace)
+        return _SAFE_BINOPS[type(node.op)](left, right)
+    if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
+        return -_safe_eval_expr(node.operand, namespace)
+    if isinstance(node, ast.Name) and namespace is not None:
+        if node.id in namespace:
+            return namespace[node.id]
+        raise ValueError(f"Unknown variable: {node.id}")
+    raise ValueError(f"Unsupported expression element: {ast.dump(node)}")
 MAX_DOCUMENT_CHARS = 2000
 def handle_files_list(world_path: str, path: str = ".") -> ActionResult:
     """
@@ -412,7 +437,7 @@ def handle_data_add_columns(
             v = _try_float(row.get(c, ""))
             ns[c] = v if isinstance(v, (int, float)) else 0
         try:
-            row[new_column] = round(eval(expression, {"__builtins__": {}}, ns), 2)
+            row[new_column] = round(_safe_eval_expr(tree, namespace=ns), 2)
         except Exception:
             row[new_column] = 0
         new_rows.append(row)
@@ -442,7 +467,8 @@ def handle_data_compute(expression: str) -> ActionResult:
             error="Invalid expression",
         )
     try:
-        result = eval(expr)
+        tree = ast.parse(expr, mode="eval")
+        result = _safe_eval_expr(tree)
         if isinstance(result, float) and not result.is_integer():
            return ActionResult(observation=str(round(result, 2)))
         return ActionResult(observation=str(result))
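The whitelist approach means only numeric constants, the four binary operators, unary minus, and known variable names ever evaluate; calls, attributes, and subscripts raise instead. A quick sketch of the behavior, assuming handle_data_add_columns parses the expression once into tree (e.g. ast.parse(expression, mode="eval")) before the row loop, since that parse sits outside this hunk's context:

import ast
from harfeast_openenv.actions import _safe_eval_expr

tree = ast.parse("price * qty - 2", mode="eval")
print(_safe_eval_expr(tree, namespace={"price": 3.5, "qty": 4}))  # 12.0

# A classic eval() escape is now rejected instead of executed:
try:
    _safe_eval_expr(ast.parse("__import__('os').system('ls')", mode="eval"))
except ValueError as exc:
    print("rejected:", exc)  # Unsupported expression element: Call(...)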
harfeast_openenv/environment.py CHANGED
@@ -127,6 +127,15 @@ class HarFeastOpenEnv:
         Execute one action and return the result.
         Action format: {"action": "files.list", "path": "."} or JSON string.
         """
+        if self._task is None:
+            return StepResult(
+                observation="No task loaded. Call reset() before step().",
+                prompt="",
+                step_count=0,
+                done=True,
+                reward=0.0,
+                info={"action_taken": "none", "last_error": "reset() not called"},
+            )
         if self._done:
             return StepResult(
                 observation="Episode already ended. Call reset() to start a new episode.",
@@ -286,7 +295,7 @@ class HarFeastOpenEnv:
 
     def _build_context_summary(self) -> str:
         """Compact summary of the episode so far, prepended to every observation."""
-        if not self._history:
+        if not self._history or not self._task:
             return ""
 
         lines = [f"=== Task: {self._task['task_name']} ==="]
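Together the two guards make the pre-reset state safe: step() short-circuits with a terminal StepResult, and _build_context_summary() no longer dereferences self._task while it is None. A hedged sketch of the guard in action (constructor arguments are not shown in this diff, so they are elided here):

from harfeast_openenv.environment import HarFeastOpenEnv

env = HarFeastOpenEnv(...)  # constructor args elided; not part of this hunk
result = env.step({"action": "files.list", "path": "."})  # step() before reset()
print(result.done)                # True
print(result.info["last_error"])  # "reset() not called"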
harfeast_openenv/rewards.py ADDED
@@ -0,0 +1,102 @@
+"""
+GDPO-style decomposed reward functions for HarFeast GRPO training.
+
+Three independent reward signals, each normalized independently by TRL's
+GRPOTrainer when passed as a list to reward_funcs. This is equivalent to
+NVIDIA's GDPO (Jan 2026) multi-signal normalization.
+
+Signature: reward_func(completions: list[list[dict]], **kwargs) -> list[float]
+  - completions[i] = [{"role": "assistant", "content": "..."}]
+  - kwargs include dataset columns: "rubric" (JSON-serialized list of criteria)
+"""
+
+import json
+import re
+from .rubric import score_answer
+
+
+def _extract_text(completions):
+    """Extract plain text from TRL chat-format completions."""
+    texts = []
+    for comp in completions:
+        if isinstance(comp, list) and comp:
+            texts.append(comp[-1].get("content", ""))
+        elif isinstance(comp, str):
+            texts.append(comp)
+        else:
+            texts.append("")
+    return texts
+
+
+def _extract_answer(text):
+    """Pull the answer portion after 'Answer:' if present."""
+    if "Answer:" in text:
+        return text.split("Answer:")[-1].strip()
+    return text.strip()
+
+
+def reward_correctness(completions, **kwargs):
+    """
+    Signal 1: Rubric correctness (0.0 - 1.0).
+    Scores each completion against task rubric criteria using deterministic
+    substring matching. This is the primary learning signal.
+    """
+    texts = _extract_text(completions)
+    rubric_strs = kwargs.get("rubric", [])
+    rewards = []
+    for i, text in enumerate(texts):
+        answer = _extract_answer(text)
+        try:
+            rubric = json.loads(rubric_strs[i]) if i < len(rubric_strs) else []
+        except (json.JSONDecodeError, TypeError):
+            rubric = []
+        if not rubric:
+            rewards.append(0.0)
+            continue
+        score, _ = score_answer(answer, rubric)
+        rewards.append(score / 100.0)
+    return rewards
+
+
+def reward_format(completions, **kwargs):
+    """
+    Signal 2: Format compliance (0.0, 0.5, or 1.0).
+    Checks that the completion follows the expected output structure:
+    contains 'Answer:', includes at least one number, reasonable length.
+    """
+    texts = _extract_text(completions)
+    rewards = []
+    for text in texts:
+        score = 0.0
+        has_answer_prefix = "Answer:" in text or "answer:" in text.lower()
+        has_number = bool(re.search(r"\d+\.?\d*", text))
+        reasonable_length = 50 <= len(text) <= 3000
+        if has_answer_prefix and has_number and reasonable_length:
+            score = 1.0
+        elif has_number and reasonable_length:
+            score = 0.5
+        rewards.append(score)
+    return rewards
+
+
+def reward_completeness(completions, **kwargs):
+    """
+    Signal 3: Numeric completeness (0.0 - 1.0).
+    Measures how many distinct numeric values appear in the answer relative
+    to the number of rubric criteria. Rewards specificity: an answer with
+    concrete numbers for every criterion scores higher.
+    """
+    texts = _extract_text(completions)
+    rubric_strs = kwargs.get("rubric", [])
+    rewards = []
+    for i, text in enumerate(texts):
+        answer = _extract_answer(text)
+        try:
+            rubric = json.loads(rubric_strs[i]) if i < len(rubric_strs) else []
+        except (json.JSONDecodeError, TypeError):
+            rubric = []
+        n_criteria = max(len(rubric), 1)
+        numbers = set(re.findall(r"\b\d[\d,.]*\d\b|\b\d+\b", answer))
+        ratio = min(len(numbers) / n_criteria, 1.0)
+        rewards.append(round(ratio, 3))
+    return rewards
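To consume these in training, TRL's GRPOTrainer accepts a list for reward_funcs and forwards extra dataset columns (such as "rubric") to each function as keyword arguments. A minimal wiring sketch; the model id and the tiny inline dataset are placeholders, not from this repo:

from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer
from harfeast_openenv.rewards import (
    reward_correctness, reward_format, reward_completeness,
)

# Placeholder one-row dataset; real data would come from the HarFeast tasks.
train_dataset = Dataset.from_dict({
    "prompt": [[{"role": "user", "content": "How many rows are in sales.csv? End with 'Answer: <n>'."}]],
    "rubric": ['["42"]'],  # JSON-serialized criteria, per the module docstring
})

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id
    reward_funcs=[reward_correctness, reward_format, reward_completeness],
    args=GRPOConfig(output_dir="grpo-harfeast"),
    train_dataset=train_dataset,
)
trainer.train()

Recent TRL versions also expose reward_weights in GRPOConfig if the three signals need rebalancing.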