# constraint-env / server/constraint_env_environment.py
# Uploaded by DecentSanage using huggingface_hub (commit 9f0026f, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Constraint Env Environment Implementation.
Evaluates an LLM's ability to convert natural-language scheduling
constraints into a JSON AST DSL.
Reward breakdown per step:
+0.125 valid JSON (1 / 8 of max)
+0.250 correct top-level structure (2 / 8 of max)
+0.625 exact match with target AST (5 / 8 of max)
──────
1.000 total maximum reward
Penalties:
-0.250 bad_structure (structure wrong but JSON parsed)
-0.250 invalid_json (cannot parse at all, replaces reward=0)
"""
import re
import json
from uuid import uuid4
from typing import Any, Dict, List, Optional, Tuple
import random
from openenv.core.env_server.interfaces import Environment
try:
from ..models import ConstraintAction, ConstraintObservation, ConstraintState
except ImportError:
import sys, os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from models import ConstraintAction, ConstraintObservation, ConstraintState
# ---------------------------------------------------------------------------
# Domain knowledge
# ---------------------------------------------------------------------------
# Domains a variable may range over, both in a top-level 'forall'
# declaration and in the 'over' list of a 'sum' expression.
VALID_DOMAINS = set("teachers subjects branches days slots".split())

# Functions accepted in expression nodes of the form
# {"target": <name>, "args": [...]}, mapped to the exact number of
# arguments each one must be given.
VALID_FUNCTIONS: Dict[str, int] = dict(
    subject_type=2,
    schedule=4,
    occupies=4,
    occupies_teacher=3,
    teaches=2,
    SUM=1,
    COUNT=1,
)
try:
from .graders import calculate_step_reward
except ImportError:
from graders import calculate_step_reward
# ---------------------------------------------------------------------------
# Environment
# ---------------------------------------------------------------------------
class ConstraintEnvironment(Environment):
    """
    OpenEnv environment for natural-language → constraint-AST translation.

    Each episode serves one prompt from the chosen difficulty tier; every
    step grades one submitted JSON AST and returns compiler-style feedback.

    Args:
        dataset: Dict mapping a difficulty tier (normally "easy", "medium",
            "hard") to a list of {"prompt": str, "target_ast": dict}
            entries. If omitted, the built-in dataset_example is used.
        max_steps: Number of submissions allowed per episode before it is
            forced to end (default 5).
    """

    def __init__(
        self,
        dataset: Optional[Dict[str, List[Dict]]] = None,
        max_steps: int = 5,
    ):
        if dataset is None:
            try:
                from dataset_example import dataset as _ds
            except ImportError:
                from constraint_env.dataset_example import dataset as _ds  # type: ignore
            dataset = _ds
        self._dataset = dataset
        self._max_steps = max_steps
        self._difficulty: str = "easy"
        # Per-tier sample cursor. reset() currently always serves the first
        # sample of a tier, so these entries stay at 0.
        self._indexes: Dict[str, int] = {k: 0 for k in dataset}
        self._current_sample: Optional[Dict] = None
        self._state = ConstraintState(episode_id=None)

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------
    def reset(self, task_id: Optional[str] = None):
        """
        Start a new episode and return its initial observation.

        Args:
            task_id: Difficulty tier to draw the prompt from ("easy",
                "medium", "hard", or any other key of the dataset). If it is
                omitted or is not a key of the dataset, a non-empty tier is
                picked at random.

        Returns:
            ConstraintObservation with the prompt, ``done=False``,
            ``reward=0.0``, ``info={"difficulty": tier}`` and no messages.

        Raises:
            ValueError: if the selected tier (or the whole dataset) has no
                samples.
        """
        if task_id is not None and task_id in self._dataset:
            self._difficulty = task_id
        else:
            # Choose among tiers that actually exist: a custom dataset without
            # one of the standard tiers must not raise KeyError, and an unknown
            # task_id must not silently reuse the previous episode's tier.
            tiers = [name for name, samples in self._dataset.items() if samples]
            if not tiers:
                raise ValueError("Dataset contains no samples.")
            self._difficulty = random.choice(tiers)

        pool = self._dataset[self._difficulty]
        if not pool:
            raise ValueError(f"Difficulty tier '{self._difficulty}' has no samples.")
        # Always the first sample of the tier; per-tier cycling through
        # self._indexes is currently disabled.
        self._current_sample = pool[0]

        self._state = ConstraintState(
            episode_id=str(uuid4()),
            step_count=0,
            max_steps=self._max_steps,
        )
        return ConstraintObservation(
            prompt=self._current_sample["prompt"],
            done=False,
            reward=0.0,
            info={"difficulty": self._difficulty},
            messages=[],
        )

    def step(self, action: ConstraintAction):
        """
        Grade one submitted AST and return the resulting observation.

        The submission goes through three stages:
          1. parse  – ``action.ast_output`` must decode to a JSON object
                      (a dict is accepted as-is; a double-encoded JSON string
                      is unwrapped once);
          2. logic  – compared with the sample's ``target_ast``, ignoring
                      "name" (see _logic_match);
          3. schema – if not an exact match, checked by _validate_structure.

        Every failure records ``info["error"]`` ("invalid_json",
        "logic_mismatch" or "bad_structure") and appends feedback messages
        echoing the submission. The episode ends on an exact match or once
        the step limit is reached. A step before any reset() starts an
        episode first.
        """
        if self._current_sample is None:
            self.reset()
        self._state.step_count += 1
        info: Dict[str, Any] = {"difficulty": self._difficulty}
        messages: List[Dict[str, str]] = []
        is_valid_json = False
        is_valid_structure = False
        is_exact_match = False
        ast: Any = None

        # ── 1. Parse JSON ────────────────────────────────────────────
        try:
            raw = action.ast_output
            # The Gradio/WebUI may pass the action already parsed as a dict.
            ast = raw if isinstance(raw, dict) else json.loads(raw)
            # Unwrap double-encoded JSON (a JSON string that contains JSON).
            if isinstance(ast, str):
                ast = json.loads(ast)
            if not isinstance(ast, dict):
                raise TypeError(f"Expected a JSON object, got {type(ast).__name__}")
            is_valid_json = True
        except (ValueError, TypeError, RecursionError) as exc:
            # ValueError covers json.JSONDecodeError and undecodable bytes
            # (UnicodeDecodeError); RecursionError covers absurdly deep nesting.
            info["error"] = "invalid_json"
            messages.extend(self._feedback(
                str(action.ast_output),
                f"Compiler Error: Syntax Error. Invalid JSON — {exc}",
            ))

        # ── 2. Logic match (ignores "name") ──────────────────────────
        if is_valid_json and "target_ast" in self._current_sample:
            is_exact_match = self._logic_match(ast, self._current_sample["target_ast"])
            info["exact_match"] = is_exact_match

        # ── 3. Validate structure ─────────────────────────────────────
        if is_exact_match:
            is_valid_structure = True
        elif is_valid_json:
            # Observation messages must carry plain strings, not dicts.
            try:
                safe_ast_str = json.dumps(ast)
            except (TypeError, ValueError):
                safe_ast_str = str(action.ast_output)
            try:
                valid, msg = self._validate_structure(ast)
            except RecursionError:
                valid, msg = False, "Structure Error: expression nesting is too deep."
            if valid:
                # Well-formed AST whose logic differs from the target.
                is_valid_structure = True
                info["error"] = "logic_mismatch"
                messages.extend(self._feedback(
                    safe_ast_str,
                    "Compiler Error: Syntax is valid, but the logic does not match the prompt's target constraint. Please adjust your logical conditions and resubmit.",
                ))
            else:
                info["error"] = "bad_structure"
                messages.extend(self._feedback(safe_ast_str, f"Compiler Error: {msg}"))

        done = is_exact_match or self._state.step_count >= self._state.max_steps
        reward = calculate_step_reward(
            is_valid_json, is_valid_structure, is_exact_match, self._state.step_count
        )
        return ConstraintObservation(
            prompt=self._current_sample["prompt"],
            done=done,
            reward=reward,
            info=info,
            messages=messages,
        )

    @property
    def state(self):
        """Current ConstraintState (episode id, step counter, step limit)."""
        return self._state

    # ------------------------------------------------------------------
    # Feedback helper
    # ------------------------------------------------------------------
    @staticmethod
    def _feedback(submitted: str, error_text: str) -> List[Dict[str, str]]:
        """Build the three assistant messages echoing a submission and its error."""
        return [
            {"role": "assistant", "content": "Your last submitted AST:"},
            {"role": "assistant", "content": submitted},
            {"role": "assistant", "content": error_text},
        ]

    # ------------------------------------------------------------------
    # Validation helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _logic_match(ast: Any, target: Dict[str, Any]) -> bool:
        """
        Return True if *ast* and *target* agree on every logic field.

        Fields compared (a missing key and an explicit None are treated alike):
          • type     – hard / soft
          • forall   – variable declarations, order-independent but otherwise
                       exact (domain qualifiers and duplicates count)
          • where    – optional guard expression (plain equality)
          • assert   – constraint body  } exactly one of the two
          • minimize – objective body   }
        "name" is intentionally excluded — it is a free-form user-chosen
        snake_case identifier that has no effect on the constraint's semantics.
        """
        if not isinstance(ast, dict) or not isinstance(target, dict):
            return False
        for key in ("type", "forall", "where", "assert", "minimize"):
            a_val = ast.get(key)
            t_val = target.get(key)
            if key == "forall":
                if not ConstraintEnvironment._same_declarations(a_val, t_val):
                    return False
            elif a_val != t_val:
                return False
        return True

    @staticmethod
    def _same_declarations(a_decls: Any, t_decls: Any) -> bool:
        """
        Compare two 'forall' lists irrespective of declaration order.

        Each declaration is compared in full via a canonical JSON form, so a
        qualified domain such as {"sub": {"subjects": "b"}} only matches the
        same qualifier, and malformed or repeated entries are not discarded.
        Falsy values count as an empty list; non-list values fall back to
        plain equality.
        """
        a_decls = a_decls or []
        t_decls = t_decls or []
        if not (isinstance(a_decls, list) and isinstance(t_decls, list)):
            return a_decls == t_decls
        try:
            a_canon = sorted(json.dumps(d, sort_keys=True) for d in a_decls)
            t_canon = sorted(json.dumps(d, sort_keys=True) for d in t_decls)
        except (TypeError, ValueError, RecursionError):
            return False
        return a_canon == t_canon

    @staticmethod
    def _domain_name(domain: Any) -> Any:
        """
        Return the domain named by a declaration value.

        A qualified domain such as {"subjects": "b"} names its first key
        ("subjects"); any other value — including an empty dict — is returned
        unchanged so it can be reported verbatim and rejected.
        """
        if isinstance(domain, dict) and domain:
            return next(iter(domain))
        return domain

    @staticmethod
    def _is_known_domain(name: Any) -> bool:
        """True if *name* is a string listed in VALID_DOMAINS (never raises on unhashable values)."""
        return isinstance(name, str) and name in VALID_DOMAINS

    def _validate_structure(self, ast: Dict[str, Any]) -> Tuple[bool, str]:
        """
        Check *ast* against the constraint schema.

        Requires "type" in {"hard", "soft"}, a non-empty "forall" list of
        single-entry {var: domain} declarations over known domains, an
        optional "where" guard, and exactly one of "assert" / "minimize".
        Expressions are checked against the variables declared in "forall".

        Returns:
            (True, "") if the AST is well-formed, else (False, error_message).
        """
        # 1. Tuple membership compares with ==, so unhashable values cannot raise.
        if ast.get("type") not in ("hard", "soft"):
            return False, "Structure Error: 'type' must be 'hard' or 'soft'."

        # 2. Must have a non-empty forall list.
        if "forall" not in ast:
            return False, "Structure Error: Missing 'forall' array."
        forall = ast["forall"]
        if not isinstance(forall, list) or not forall:
            return False, "Structure Error: 'forall' must be a non-empty list of variable declarations."

        # 3. Build the variable scope from the forall declarations.
        scope: Dict[str, str] = {}
        for var_decl in forall:
            if not isinstance(var_decl, dict) or len(var_decl) != 1:
                return False, "Structure Error: Each item in 'forall' must be a dict with a single key-value pair."
            var, domain = next(iter(var_decl.items()))
            # e.g. {"sub": {"subjects": "b"}} declares sub over "subjects".
            domain_val = self._domain_name(domain)
            if not self._is_known_domain(domain_val):
                # sorted(): set order varies between processes (hash randomisation).
                return False, f"Structure Error: Unknown domain '{domain_val}'. Valid domains: {sorted(VALID_DOMAINS)}"
            scope[var] = domain_val

        # 4. Optional WHERE guard.
        if "where" in ast:
            valid, msg = self._validate_nested_expr(ast["where"], scope)
            if not valid:
                return False, f"Structure Error in 'where' clause: {msg}"

        # 5. Payload: exactly one of assert / minimize.
        has_assert = "assert" in ast
        has_minimize = "minimize" in ast
        if has_assert and has_minimize:
            return False, "Structure Error: Provide exactly one of 'assert' or 'minimize', not both."
        if not (has_assert or has_minimize):
            return False, "Structure Error: Must provide either 'assert' or 'minimize'."
        clause = "assert" if has_assert else "minimize"
        valid, msg = self._validate_nested_expr(ast[clause], scope)
        if not valid:
            return False, f"Structure Error in '{clause}' clause: {msg}"
        return True, ""

    def _validate_nested_expr(self, expr: Any, scope: Dict[str, str]) -> Tuple[bool, str]:
        """
        Recursively validate one expression node against the variable *scope*.

        Accepted nodes:
          • int / float / bool literals;
          • strings that are digit-only, quoted, one of "CS" / "online" /
            "practical", or a variable name in scope;
          • {"name": var} referencing a variable in scope;
          • {"operator": "sum", "over": [{var: domain}, ...], "expression": e}
            where the "over" variables are added to scope for e;
          • any other {"operator": ...} node, whose optional "left" / "right"
            children are validated;
          • {"target": fn, "args": [...]} with fn in VALID_FUNCTIONS and
            exactly the declared number of args.

        Returns:
            (True, "") if valid, else (False, error_message).
        """
        if isinstance(expr, (int, float, bool)):
            return True, ""

        if isinstance(expr, str):
            # Numbers passed as strings, quoted literals and known constants.
            if expr.isdigit() or expr.startswith(("'", '"')) or expr in ("CS", "online", "practical"):
                return True, ""
            if expr not in scope:
                return False, f"Unknown identifier '{expr}' not declared in scope."
            return True, ""

        if not isinstance(expr, dict):
            return False, f"Expression node must be dict or literal, got {type(expr).__name__}."

        if "name" in expr:
            var_name = expr["name"]
            # Scope keys are strings; the isinstance check also avoids hashing
            # an unhashable value in the membership test.
            if not isinstance(var_name, str) or var_name not in scope:
                return False, f"Unknown identifier '{var_name}' not declared in scope."
            return True, ""

        if "operator" in expr:
            if expr["operator"] == "sum":
                if "over" not in expr or not isinstance(expr["over"], list):
                    return False, "Sum operator must have an 'over' array."
                if "expression" not in expr:
                    return False, "Sum operator must have an 'expression' to evaluate."
                local_scope = dict(scope)
                for decl in expr["over"]:
                    if not isinstance(decl, dict) or len(decl) != 1:
                        return False, "Each item in 'over' must be a dict with 1 key-value pair."
                    var, domain = next(iter(decl.items()))
                    domain_val = self._domain_name(domain)
                    if not self._is_known_domain(domain_val):
                        return False, f"Unknown domain '{domain_val}' in sum 'over'."
                    local_scope[var] = domain_val
                return self._validate_nested_expr(expr["expression"], local_scope)
            for side in ("left", "right"):
                if side in expr:
                    valid, msg = self._validate_nested_expr(expr[side], scope)
                    if not valid:
                        return False, msg
            return True, ""

        if "target" in expr:
            target = expr["target"]
            if not isinstance(target, str) or target not in VALID_FUNCTIONS:
                return False, f"Unknown function '{target}'."
            expected_arity = VALID_FUNCTIONS[target]
            args = expr.get("args", [])
            if not isinstance(args, list):
                return False, f"Function '{target}' args must be a list."
            if len(args) != expected_arity:
                return False, f"Function '{target}' expects {expected_arity} args, got {len(args)}."
            for arg in args:
                valid, msg = self._validate_nested_expr(arg, scope)
                if not valid:
                    return False, msg
            return True, ""

        return False, f"Unrecognized AST node structure: {list(expr.keys())}"
# ---------------------------------------------------------------------------
# Quick smoke-test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Quick smoke-test: submit a perfect answer and a malformed one per tier.
    import os
    import sys

    # Make the package root importable when this file is run directly.
    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
    try:
        from dataset_example import dataset as _ds
    except ImportError:
        # Let the environment fall back to its own default-dataset loader
        # instead of crashing below with a NameError on an unbound _ds.
        _ds = None

    env = ConstraintEnvironment(_ds)
    for difficulty in ("easy", "medium", "hard"):
        obs = env.reset(task_id=difficulty)
        print(f"\n[{difficulty.upper()}] prompt: {obs.prompt}")

        # Send a perfect answer: the sample's own target AST.
        target = env._current_sample["target_ast"]
        result = env.step(ConstraintAction(ast_output=json.dumps(target)))
        print(f" reward={result.reward} done={result.done} info={result.info}")

        # Send bad JSON on a fresh episode.
        env.reset(task_id=difficulty)
        res2 = env.step(ConstraintAction(ast_output="this is not json"))
        print(f" [bad JSON] reward={res2.reward} info={res2.info}")