Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Pranav Patel committed on
Commit ·
d4ccbaf
1
Parent(s): 9aeca6d
fix: sync bug fixes (imports, eval safety, missing fields, task guards, rewards)
Browse files- harfeast_env/client.py +4 -10
- harfeast_env/models.py +13 -4
- harfeast_env/server/app.py +3 -8
- harfeast_env/server/harfeast_environment.py +2 -6
- harfeast_openenv/actions.py +28 -2
- harfeast_openenv/environment.py +10 -1
- harfeast_openenv/rewards.py +102 -0
harfeast_env/client.py
CHANGED
|
@@ -5,16 +5,10 @@ Connects to HarFeast OpenEnv server via WebSocket/HTTP.
|
|
| 5 |
|
| 6 |
from typing import Any, Dict
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
from harfeast_env.models import HarFeastAction, HarFeastObservation
|
| 13 |
-
except ImportError:
|
| 14 |
-
from openenv.core.client_types import StepResult
|
| 15 |
-
from openenv.core.env_server.types import State
|
| 16 |
-
from openenv.core.env_client import EnvClient
|
| 17 |
-
from models import HarFeastAction, HarFeastObservation
|
| 18 |
|
| 19 |
|
| 20 |
class HarFeastEnv(EnvClient[HarFeastAction, HarFeastObservation, State]):
|
|
|
|
| 5 |
|
| 6 |
from typing import Any, Dict
|
| 7 |
|
| 8 |
+
from openenv.core.client_types import StepResult
|
| 9 |
+
from openenv.core.env_server.types import State
|
| 10 |
+
from openenv.core.env_client import EnvClient
|
| 11 |
+
from harfeast_env.models import HarFeastAction, HarFeastObservation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
class HarFeastEnv(EnvClient[HarFeastAction, HarFeastObservation, State]):
|
harfeast_env/models.py
CHANGED
|
@@ -5,10 +5,7 @@ Actions are JSON-serialized calls: {"action": "files.list", "path": "."}
|
|
| 5 |
|
| 6 |
from pydantic import Field
|
| 7 |
|
| 8 |
-
|
| 9 |
-
from openenv.core.env_server.types import Action, Observation
|
| 10 |
-
except ImportError:
|
| 11 |
-
from openenv.core.env_server.types import Action, Observation
|
| 12 |
|
| 13 |
|
| 14 |
class HarFeastAction(Action):
|
|
@@ -46,3 +43,15 @@ class HarFeastObservation(Observation):
|
|
| 46 |
default="[]",
|
| 47 |
description="JSON list of filtered dataset names available for chaining",
|
| 48 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
from pydantic import Field
|
| 7 |
|
| 8 |
+
from openenv.core.env_server.types import Action, Observation
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class HarFeastAction(Action):
|
|
|
|
| 43 |
default="[]",
|
| 44 |
description="JSON list of filtered dataset names available for chaining",
|
| 45 |
)
|
| 46 |
+
done: bool = Field(
|
| 47 |
+
default=False,
|
| 48 |
+
description="Whether the episode has ended",
|
| 49 |
+
)
|
| 50 |
+
reward: float = Field(
|
| 51 |
+
default=0.0,
|
| 52 |
+
description="Rubric score (0-100) when done, else 0",
|
| 53 |
+
)
|
| 54 |
+
metadata: dict = Field(
|
| 55 |
+
default_factory=dict,
|
| 56 |
+
description="Extra info (action_taken, last_error, task_id)",
|
| 57 |
+
)
|
harfeast_env/server/app.py
CHANGED
|
@@ -11,14 +11,9 @@ _project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(
|
|
| 11 |
if _project_root not in sys.path:
|
| 12 |
sys.path.insert(0, _project_root)
|
| 13 |
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
from harfeast_env.server.harfeast_environment import HarFeastEnvironment
|
| 18 |
-
except ImportError:
|
| 19 |
-
from openenv.core.env_server.http_server import create_app
|
| 20 |
-
from models import HarFeastAction, HarFeastObservation
|
| 21 |
-
from server.harfeast_environment import HarFeastEnvironment
|
| 22 |
|
| 23 |
# World path - use env var or default to project harfeast_world
|
| 24 |
WORLD_PATH = os.environ.get("HARFEAST_WORLD_PATH") or os.path.join(_project_root, "harfeast_world")
|
|
|
|
| 11 |
if _project_root not in sys.path:
|
| 12 |
sys.path.insert(0, _project_root)
|
| 13 |
|
| 14 |
+
from openenv.core.env_server.http_server import create_app
|
| 15 |
+
from harfeast_env.models import HarFeastAction, HarFeastObservation
|
| 16 |
+
from harfeast_env.server.harfeast_environment import HarFeastEnvironment
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# World path - use env var or default to project harfeast_world
|
| 19 |
WORLD_PATH = os.environ.get("HARFEAST_WORLD_PATH") or os.path.join(_project_root, "harfeast_world")
|
harfeast_env/server/harfeast_environment.py
CHANGED
|
@@ -7,12 +7,8 @@ import json
|
|
| 7 |
import os
|
| 8 |
from uuid import uuid4
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
from openenv.core.env_server.types import State
|
| 13 |
-
except ImportError:
|
| 14 |
-
from openenv.core.env_server.interfaces import Environment
|
| 15 |
-
from openenv.core.env_server.types import State
|
| 16 |
|
| 17 |
# Import our core logic - use path relative to project root
|
| 18 |
import sys
|
|
|
|
| 7 |
import os
|
| 8 |
from uuid import uuid4
|
| 9 |
|
| 10 |
+
from openenv.core.env_server.interfaces import Environment
|
| 11 |
+
from openenv.core.env_server.types import State
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
# Import our core logic - use path relative to project root
|
| 14 |
import sys
|
harfeast_openenv/actions.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
| 3 |
import ast
|
| 4 |
import csv
|
| 5 |
import json
|
|
|
|
| 6 |
import os
|
| 7 |
import re
|
| 8 |
from collections import defaultdict
|
|
@@ -13,6 +14,30 @@ from .schemas import ActionResult
|
|
| 13 |
|
| 14 |
# ── Observation size limits ──────────────────────────────────────
|
| 15 |
MAX_TABLE_ROWS = 20
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
MAX_DOCUMENT_CHARS = 2000
|
| 17 |
def handle_files_list(world_path: str, path: str = ".") -> ActionResult:
|
| 18 |
"""
|
|
@@ -412,7 +437,7 @@ def handle_data_add_columns(
|
|
| 412 |
v = _try_float(row.get(c, ""))
|
| 413 |
ns[c] = v if isinstance(v, (int, float)) else 0
|
| 414 |
try:
|
| 415 |
-
row[new_column] = round(
|
| 416 |
except Exception:
|
| 417 |
row[new_column] = 0
|
| 418 |
new_rows.append(row)
|
|
@@ -442,7 +467,8 @@ def handle_data_compute(expression: str) -> ActionResult:
|
|
| 442 |
error="Invalid expression",
|
| 443 |
)
|
| 444 |
try:
|
| 445 |
-
|
|
|
|
| 446 |
if isinstance(result, float) and not result.is_integer():
|
| 447 |
return ActionResult(observation=str(round(result, 2)))
|
| 448 |
return ActionResult(observation=str(result))
|
|
|
|
| 3 |
import ast
|
| 4 |
import csv
|
| 5 |
import json
|
| 6 |
+
import operator
|
| 7 |
import os
|
| 8 |
import re
|
| 9 |
from collections import defaultdict
|
|
|
|
| 14 |
|
| 15 |
# ── Observation size limits ──────────────────────────────────────
|
| 16 |
MAX_TABLE_ROWS = 20
|
| 17 |
+
|
| 18 |
+
# ── Safe arithmetic evaluator (replaces eval) ────────────────────
|
| 19 |
+
# Whitelist of binary operators the evaluator may apply; anything else is rejected.
_SAFE_BINOPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
}


def _safe_eval_expr(node, namespace=None):
    """Evaluate an arithmetic AST node without resorting to eval().

    Supported elements: int/float literals, the binary operators in
    ``_SAFE_BINOPS`` (+, -, *, /), unary minus and plus, and -- only when
    ``namespace`` is supplied -- bare names looked up in that mapping.

    Args:
        node: An ``ast`` node, typically the result of
            ``ast.parse(expr, mode="eval")``.
        namespace: Optional mapping of variable name to value. When it is
            ``None`` every name in the expression is rejected.

    Returns:
        The value of the expression.

    Raises:
        ValueError: If the tree contains an unsupported element or a name
            that is missing from ``namespace``.
        ZeroDivisionError: Propagated unchanged from a division by zero.
    """
    if isinstance(node, ast.Expression):
        return _safe_eval_expr(node.body, namespace)

    if isinstance(node, ast.Constant):
        value = node.value
        # bool is a subclass of int; reject it explicitly so that
        # "True + 1" is not silently treated as arithmetic.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            return value

    if isinstance(node, ast.BinOp):
        apply_op = _SAFE_BINOPS.get(type(node.op))
        if apply_op is not None:
            left = _safe_eval_expr(node.left, namespace)
            right = _safe_eval_expr(node.right, namespace)
            return apply_op(left, right)

    if isinstance(node, ast.UnaryOp):
        if isinstance(node.op, ast.USub):
            return -_safe_eval_expr(node.operand, namespace)
        if isinstance(node.op, ast.UAdd):
            return +_safe_eval_expr(node.operand, namespace)

    if isinstance(node, ast.Name) and namespace is not None:
        if node.id in namespace:
            return namespace[node.id]
        raise ValueError(f"Unknown variable: {node.id}")

    raise ValueError(f"Unsupported expression element: {ast.dump(node)}")
|
| 41 |
MAX_DOCUMENT_CHARS = 2000
|
| 42 |
def handle_files_list(world_path: str, path: str = ".") -> ActionResult:
|
| 43 |
"""
|
|
|
|
| 437 |
v = _try_float(row.get(c, ""))
|
| 438 |
ns[c] = v if isinstance(v, (int, float)) else 0
|
| 439 |
try:
|
| 440 |
+
row[new_column] = round(_safe_eval_expr(tree, namespace=ns), 2)
|
| 441 |
except Exception:
|
| 442 |
row[new_column] = 0
|
| 443 |
new_rows.append(row)
|
|
|
|
| 467 |
error="Invalid expression",
|
| 468 |
)
|
| 469 |
try:
|
| 470 |
+
tree = ast.parse(expr, mode="eval")
|
| 471 |
+
result = _safe_eval_expr(tree)
|
| 472 |
if isinstance(result, float) and not result.is_integer():
|
| 473 |
return ActionResult(observation=str(round(result, 2)))
|
| 474 |
return ActionResult(observation=str(result))
|
harfeast_openenv/environment.py
CHANGED
|
@@ -127,6 +127,15 @@ class HarFeastOpenEnv:
|
|
| 127 |
Execute one action and return the result.
|
| 128 |
Action format: {"action": "files.list", "path": "."} or JSON string.
|
| 129 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
if self._done:
|
| 131 |
return StepResult(
|
| 132 |
observation="Episode already ended. Call reset() to start a new episode.",
|
|
@@ -286,7 +295,7 @@ class HarFeastOpenEnv:
|
|
| 286 |
|
| 287 |
def _build_context_summary(self) -> str:
|
| 288 |
"""Compact summary of the episode so far, prepended to every observation."""
|
| 289 |
-
if not self._history:
|
| 290 |
return ""
|
| 291 |
|
| 292 |
lines = [f"=== Task: {self._task['task_name']} ==="]
|
|
|
|
| 127 |
Execute one action and return the result.
|
| 128 |
Action format: {"action": "files.list", "path": "."} or JSON string.
|
| 129 |
"""
|
| 130 |
+
if self._task is None:
|
| 131 |
+
return StepResult(
|
| 132 |
+
observation="No task loaded. Call reset() before step().",
|
| 133 |
+
prompt="",
|
| 134 |
+
step_count=0,
|
| 135 |
+
done=True,
|
| 136 |
+
reward=0.0,
|
| 137 |
+
info={"action_taken": "none", "last_error": "reset() not called"},
|
| 138 |
+
)
|
| 139 |
if self._done:
|
| 140 |
return StepResult(
|
| 141 |
observation="Episode already ended. Call reset() to start a new episode.",
|
|
|
|
| 295 |
|
| 296 |
def _build_context_summary(self) -> str:
|
| 297 |
"""Compact summary of the episode so far, prepended to every observation."""
|
| 298 |
+
if not self._history or not self._task:
|
| 299 |
return ""
|
| 300 |
|
| 301 |
lines = [f"=== Task: {self._task['task_name']} ==="]
|
harfeast_openenv/rewards.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GDPO-style decomposed reward functions for HarFeast GRPO training.
|
| 3 |
+
|
| 4 |
+
Three independent reward signals, each normalized independently by TRL's
|
| 5 |
+
GRPOTrainer when passed as a list to reward_funcs. This is equivalent to
|
| 6 |
+
NVIDIA's GDPO (Jan 2026) multi-signal normalization.
|
| 7 |
+
|
| 8 |
+
Signature: reward_func(completions: list[list[dict]], **kwargs) -> list[float]
|
| 9 |
+
- completions[i] = [{"role": "assistant", "content": "..."}]
|
| 10 |
+
- kwargs include dataset columns: "rubric" (JSON-serialized list of criteria)
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import json
|
| 14 |
+
import re
|
| 15 |
+
from .rubric import score_answer
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _extract_text(completions):
    """Extract plain text from TRL chat-format completions.

    Each completion may be a non-empty list of message dicts (the text of the
    last message is used) or a plain string. Anything else -- an empty list,
    a last message that is not a dict, or ``content`` that is missing or not
    a string -- yields ``""``, so every returned item is a ``str``.

    Args:
        completions: Iterable of completions.

    Returns:
        A list of strings, one per completion, in input order.
    """
    texts = []
    for comp in completions:
        if isinstance(comp, str):
            texts.append(comp)
            continue
        content = ""
        if isinstance(comp, list) and comp:
            last_message = comp[-1]
            if isinstance(last_message, dict):
                content = last_message.get("content", "")
        # "content" can be present but None; downstream code runs substring
        # checks on the result, so never hand back a non-string.
        texts.append(content if isinstance(content, str) else "")
    return texts
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _extract_answer(text):
    """Return the portion of ``text`` after the last 'Answer:' marker.

    The marker is matched case-insensitively, in line with ``reward_format``,
    which accepts 'answer:' in any case. When no marker is present the whole
    text is returned. The result is stripped of surrounding whitespace.

    Args:
        text: Completion text.

    Returns:
        The stripped answer string.
    """
    # re.split returns [text] when there is no match, so [-1] handles both the
    # "marker present" and "marker absent" cases.
    return re.split(r"answer:", text, flags=re.IGNORECASE)[-1].strip()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def reward_correctness(completions, **kwargs):
    """
    Signal 1: rubric correctness.

    For each completion, the answer text is extracted and handed to
    ``score_answer`` together with the rubric decoded from the matching entry
    of the ``rubric`` keyword argument (one JSON string per completion).
    The returned score is divided by 100.

    A completion whose rubric is missing, undecodable, or empty gets 0.0.

    NOTE(review): the 0.0 - 1.0 range assumes ``score_answer`` returns a
    0-100 score -- confirm against harfeast_openenv.rubric.
    """
    rubric_column = kwargs.get("rubric", [])
    scores = []
    for idx, raw_text in enumerate(_extract_text(completions)):
        candidate = _extract_answer(raw_text)
        try:
            if idx < len(rubric_column):
                criteria = json.loads(rubric_column[idx])
            else:
                criteria = []
        except (json.JSONDecodeError, TypeError):
            criteria = []
        if criteria:
            points, _ = score_answer(candidate, criteria)
            scores.append(points / 100.0)
        else:
            # Nothing to score against.
            scores.append(0.0)
    return scores
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def reward_format(completions, **kwargs):
    """
    Signal 2: Format compliance (0.0, 0.5 or 1.0).

    Scoring per completion:
      * 1.0 -- contains 'answer:' (any case), at least one digit, and is
        between 50 and 3000 characters long (inclusive);
      * 0.5 -- has a digit and a reasonable length but no 'answer:' marker;
      * 0.0 -- otherwise.

    Args:
        completions: Completions in the shape accepted by ``_extract_text``.
        **kwargs: Ignored; accepted so the function matches the common
            reward-function signature.

    Returns:
        A list of floats, one per completion.
    """
    rewards = []
    for text in _extract_text(completions):
        # Lower-casing once covers "Answer:", "answer:", "ANSWER:", ...
        has_answer_prefix = "answer:" in text.lower()
        # A single digit anywhere is enough to count as "includes a number".
        has_number = re.search(r"\d", text) is not None
        reasonable_length = 50 <= len(text) <= 3000

        if not (has_number and reasonable_length):
            rewards.append(0.0)
        elif has_answer_prefix:
            rewards.append(1.0)
        else:
            rewards.append(0.5)
    return rewards
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def reward_completeness(completions, **kwargs):
    """
    Signal 3: Numeric completeness (0.0 - 1.0).

    Counts the distinct numeric tokens found in the extracted answer and
    divides by the number of rubric criteria (treated as at least 1), capping
    the ratio at 1.0 and rounding to three decimals. The rubric for each
    completion is decoded from the matching JSON string in the ``rubric``
    keyword argument; a missing or undecodable rubric counts as empty.
    """
    rubric_column = kwargs.get("rubric", [])
    scores = []
    for idx, raw_text in enumerate(_extract_text(completions)):
        candidate = _extract_answer(raw_text)
        try:
            if idx < len(rubric_column):
                criteria = json.loads(rubric_column[idx])
            else:
                criteria = []
        except (json.JSONDecodeError, TypeError):
            criteria = []
        # Avoid dividing by zero when there are no criteria.
        expected = max(len(criteria), 1)
        # Tokens are compared as strings, so "1,000" and "1000" are distinct.
        distinct = len(set(re.findall(r"\b\d[\d,.]*\d\b|\b\d+\b", candidate)))
        scores.append(round(min(distinct / expected, 1.0), 3))
    return scores
|