Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Constraint Env Environment Implementation. | |
| Evaluates an LLM's ability to convert natural-language scheduling | |
| constraints into a JSON AST DSL. | |
| Reward breakdown per step: | |
| +0.125 valid JSON (1 / 8 of max) | |
| +0.250 correct top-level structure (2 / 8 of max) | |
| +0.625 exact match with target AST (5 / 8 of max) | |
| ────── | |
| 1.000 total maximum reward | |
| Penalties: | |
| -0.250 bad_structure (structure wrong but JSON parsed) | |
| -0.250 invalid_json (cannot parse at all, replaces reward=0) | |
| """ | |
| import re | |
| import json | |
| from uuid import uuid4 | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import random | |
| from openenv.core.env_server.interfaces import Environment | |
| try: | |
| from ..models import ConstraintAction, ConstraintObservation, ConstraintState | |
| except ImportError: | |
| import sys, os | |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) | |
| from models import ConstraintAction, ConstraintObservation, ConstraintState | |
| # --------------------------------------------------------------------------- | |
| # Domain knowledge | |
| # --------------------------------------------------------------------------- | |
# Domain names that a `forall` or `sum ... over` declaration may range over.
VALID_DOMAINS = set(("teachers", "subjects", "branches", "days", "slots"))

# Known function names mapped to the exact number of arguments (arity) each
# one requires in a {"target": ..., "args": [...]} node.
VALID_FUNCTIONS: Dict[str, int] = dict(
    subject_type=2,
    schedule=4,
    occupies=4,
    occupies_teacher=3,
    teaches=2,
    SUM=1,
    COUNT=1,
)
| try: | |
| from .graders import calculate_step_reward | |
| except ImportError: | |
| from graders import calculate_step_reward | |
| # --------------------------------------------------------------------------- | |
| # Environment | |
| # --------------------------------------------------------------------------- | |
class ConstraintEnvironment(Environment):
    """
    OpenEnv environment for natural-language → constraint-AST translation.

    Each episode serves one prompt. The agent submits JSON ASTs via ``step``
    until one logically matches the sample's ``target_ast`` or ``max_steps``
    attempts have been used.

    Args:
        dataset: Dict with keys "easy", "medium", "hard", each a list of
                 {"prompt": str, "target_ast": dict} entries.
                 If omitted, the built-in dataset_example is used.
    """

    # Top-level AST keys compared by _logic_match. "name" is deliberately
    # excluded: it is a free-form identifier with no semantic effect.
    _LOGIC_KEYS = frozenset({"type", "forall", "where", "assert", "minimize"})

    def __init__(self, dataset: Optional[Dict[str, List[Dict]]] = None):
        if dataset is None:
            try:
                from dataset_example import dataset as _ds
            except ImportError:
                from constraint_env.dataset_example import dataset as _ds  # type: ignore
            dataset = _ds
        self._dataset = dataset
        self._difficulty: str = "easy"
        # Per-difficulty sample cursor. NOTE(review): reset() does not currently
        # read or advance it; every episode uses the first sample of its pool.
        self._indexes: Dict[str, int] = {k: 0 for k in dataset}
        self._current_sample: Optional[Dict] = None
        self._state = ConstraintState(episode_id=None)

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------
    def reset(self, task_id: Optional[str] = None):
        """
        Reset the environment for a new episode.

        Args:
            task_id: A difficulty key present in the dataset ("easy",
                "medium" or "hard"). When None, one of those three is chosen
                at random. Any other value (empty string, unknown key) keeps
                the difficulty used by the previous episode.

        Returns:
            ConstraintObservation with the sample prompt, ``done=False``,
            ``reward=0.0`` and no feedback messages.
        """
        if task_id and task_id in self._dataset:
            self._difficulty = task_id
        elif task_id is None:
            self._difficulty = random.choice(["easy", "medium", "hard"])
        pool = self._dataset[self._difficulty]
        # NOTE: always the first sample of the pool, so every episode of a
        # given difficulty presents the same prompt.
        self._current_sample = pool[0]
        self._state = ConstraintState(
            episode_id=str(uuid4()),
            step_count=0,
            max_steps=5,
        )
        return ConstraintObservation(
            prompt=self._current_sample["prompt"],
            done=False,
            reward=0.0,
            info={"difficulty": self._difficulty},
            messages=[],
        )

    def step(self, action: ConstraintAction):
        """
        Evaluate the agent's AST output and return a scored observation.

        Grading stages:
          1. parse ``action.ast_output`` as a JSON object (a pre-parsed dict
             and double-encoded JSON strings are accepted);
          2. compare it logically against the sample's ``target_ast``
             (ignoring "name");
          3. if it did not match, validate its structure to choose feedback.

        ``info["error"]`` is set to "invalid_json", "logic_mismatch" or
        "bad_structure" on failure. Feedback for the agent goes in
        ``messages``. The episode is done on an exact match or once
        ``max_steps`` steps have been taken. The reward comes from
        ``calculate_step_reward``.
        """
        if self._current_sample is None:
            self.reset()
        self._state.step_count += 1
        info: Dict[str, Any] = {"difficulty": self._difficulty}
        messages: List[Dict[str, str]] = []
        is_valid_json = False
        is_valid_structure = False
        is_exact_match = False
        ast = None

        # ── 1. Parse JSON ────────────────────────────────────────────
        try:
            raw = action.ast_output
            # The Gradio/WebUI may pass the action already parsed as a dict.
            if isinstance(raw, dict):
                ast = raw
            else:
                ast = json.loads(raw)
            # Handle double-encoded JSON (a JSON string that contains JSON).
            if isinstance(ast, str):
                ast = json.loads(ast)
            if not isinstance(ast, dict):
                raise TypeError(f"Expected a JSON object, got {type(ast).__name__}")
            is_valid_json = True
        except (json.JSONDecodeError, TypeError) as exc:
            info["error"] = "invalid_json"
            messages.extend(self._feedback_messages(
                str(action.ast_output),
                f"Compiler Error: Syntax Error. Invalid JSON — {exc}",
            ))

        # ── 2. Logic match (ignores "name") ──────────────────────────
        if is_valid_json and "target_ast" in self._current_sample:
            is_exact_match = self._logic_match(ast, self._current_sample["target_ast"])
            info["exact_match"] = is_exact_match

        # ── 3. Validate structure ─────────────────────────────────────
        if is_exact_match:
            is_valid_structure = True
        elif is_valid_json:
            # Re-serialise the parsed AST so the echoed submission is always a
            # plain string, even when the action arrived as a dict.
            submitted = json.dumps(ast)
            valid, msg = self._validate_structure(ast)
            if valid:
                # Well-formed but logically different from the target.
                is_valid_structure = True
                info["error"] = "logic_mismatch"
                feedback = (
                    "Compiler Error: Syntax is valid, but the logic does not match "
                    "the prompt's target constraint. Please adjust your logical "
                    "conditions and resubmit."
                )
            else:
                info["error"] = "bad_structure"
                feedback = f"Compiler Error: {msg}"
            messages.extend(self._feedback_messages(submitted, feedback))

        done = is_exact_match or self._state.step_count >= self._state.max_steps
        reward = calculate_step_reward(is_valid_json, is_valid_structure, is_exact_match, self._state.step_count)
        return ConstraintObservation(
            prompt=self._current_sample["prompt"],
            done=done,
            reward=reward,
            info=info,
            messages=messages,
        )

    def state(self):
        """Return the current episode's ConstraintState."""
        # NOTE(review): some OpenEnv environments expose `state` as a @property;
        # confirm which form the Environment base class expects.
        return self._state

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _feedback_messages(submitted: str, error_text: str) -> List[Dict[str, str]]:
        """Build the assistant messages that echo a submission and its error."""
        return [
            {"role": "assistant", "content": "Your last submitted AST:"},
            {"role": "assistant", "content": submitted},
            {"role": "assistant", "content": error_text},
        ]

    @staticmethod
    def _safe_in(container: Any, item: Any) -> bool:
        """Membership test that treats unhashable items as absent.

        ``item in some_set`` / ``item in some_dict`` raise TypeError when
        *item* is a list or dict. A malformed AST can put such values where a
        name is expected, and that must be reported as a structure error
        instead of crashing ``step``.
        """
        try:
            return item in container
        except TypeError:
            return False

    @staticmethod
    def _extract_domain(domain: Any) -> Any:
        """Return the domain name of a forall/over declaration value.

        The value is either the domain name itself (e.g. ``"teachers"``) or a
        dict whose first key is the domain name (e.g. ``{"subjects": "b"}``).
        An empty dict yields None instead of raising IndexError.
        """
        if isinstance(domain, dict):
            return next(iter(domain), None)
        return domain

    @staticmethod
    def _forall_signature(decls: Any) -> frozenset:
        """Return the order-independent set of (variable, domain) pairs.

        Entries that are not single-key dicts are ignored.

        Raises:
            TypeError: if *decls* is not iterable or a domain is unhashable.
        """
        return frozenset(
            (var, ConstraintEnvironment._extract_domain(domain))
            for decl in (decls or [])
            if isinstance(decl, dict) and len(decl) == 1
            for var, domain in decl.items()
        )

    # ------------------------------------------------------------------
    # Validation helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _logic_match(ast: Any, target: Dict[str, Any]) -> bool:
        """
        Compare two ASTs on every logically meaningful field, ignoring "name".

        Fields compared:
          • type     – hard / soft
          • forall   – same variable declarations (order-independent)
          • where    – optional guard expression (equality)
          • assert   – constraint body  } exactly one
          • minimize – objective body   }

        "name" is intentionally excluded — it is a free-form user-chosen
        snake_case identifier that has no effect on the constraint's semantics.
        """
        if not isinstance(ast, dict) or not isinstance(target, dict):
            return False
        # Only logic keys that appear in at least one of the two ASTs matter.
        keys = (set(ast) | set(target)) & ConstraintEnvironment._LOGIC_KEYS
        for key in keys:
            a_val = ast.get(key)
            t_val = target.get(key)
            if key == "forall":
                try:
                    same = (ConstraintEnvironment._forall_signature(a_val)
                            == ConstraintEnvironment._forall_signature(t_val))
                except TypeError:
                    # Non-iterable 'forall' or an unhashable domain value.
                    return False
                if not same:
                    return False
            elif a_val != t_val:
                return False
        return True

    def _validate_structure(self, ast: Dict[str, Any]) -> Tuple[bool, str]:
        """Return (True, "") if AST follows the expected schema, else (False, error_msg).

        Checks, in order: 'type' is "hard"/"soft"; 'forall' is a non-empty
        list of single-key declarations over VALID_DOMAINS; the optional
        'where' clause; and exactly one payload, 'assert' (preferred when both
        are present) or 'minimize'.
        """
        # 1. Top-level type field must be exactly "hard" or "soft".
        if not self._safe_in({"hard", "soft"}, ast.get("type")):
            return False, "Structure Error: 'type' must be 'hard' or 'soft'."

        # 2. Must have a non-empty forall list.
        if "forall" not in ast:
            return False, "Structure Error: Missing 'forall' array."
        forall = ast["forall"]
        if not isinstance(forall, list) or not forall:
            return False, "Structure Error: 'forall' must be a non-empty list of variable declarations."

        # 3. Build the variable scope from the forall declarations.
        scope: Dict[str, str] = {}
        for var_decl in forall:
            if not isinstance(var_decl, dict) or len(var_decl) != 1:
                return False, "Structure Error: Each item in 'forall' must be a dict with a single key-value pair."
            var, domain = next(iter(var_decl.items()))
            # e.g. {"t": "teachers"} or {"sub": {"subjects": "b"}}
            domain_val = self._extract_domain(domain)
            if not self._safe_in(VALID_DOMAINS, domain_val):
                # Sorted so the feedback text is identical across processes.
                return False, f"Structure Error: Unknown domain '{domain_val}'. Valid domains: {sorted(VALID_DOMAINS)}"
            scope[var] = domain_val

        # 4. Validate the optional WHERE clause.
        if "where" in ast:
            valid, msg = self._validate_nested_expr(ast["where"], scope)
            if not valid:
                return False, f"Structure Error in 'where' clause: {msg}"

        # 5. Validate the payload: assert OR minimize (one required).
        if "assert" in ast:
            valid, msg = self._validate_nested_expr(ast["assert"], scope)
            if not valid:
                return False, f"Structure Error in 'assert' clause: {msg}"
            return True, ""
        if "minimize" in ast:
            valid, msg = self._validate_nested_expr(ast["minimize"], scope)
            if not valid:
                return False, f"Structure Error in 'minimize' clause: {msg}"
            return True, ""
        return False, "Structure Error: Must provide either 'assert' or 'minimize'."

    def _validate_nested_expr(self, expr: Any, scope: Dict[str, str]) -> Tuple[bool, str]:
        """Recursively validate a nested expression JSON AST against the scope.

        Accepted nodes:
          • numbers/bools; digit strings, quoted strings and the literals
            "CS", "online", "practical"; or a variable name present in *scope*
          • {"name": var}                     – variable reference
          • {"operator": "sum", "over": [...], "expression": e}
                                              – extends the scope with 'over'
          • {"operator": op, "left": l, "right": r} – both sides optional
          • {"target": fn, "args": [...]}     – fn in VALID_FUNCTIONS with its
                                                exact arity

        Returns (True, "") on success, else (False, error_msg).
        """
        if isinstance(expr, (int, float, bool)):
            return True, ""
        if isinstance(expr, str):
            # Known domain literals or numbers passed as strings.
            if expr.isdigit():
                return True, ""
            if expr.startswith("'") or expr.startswith('"'):
                return True, ""
            if expr in ["CS", "online", "practical"]:
                return True, ""
            if expr not in scope:
                return False, f"Unknown identifier '{expr}' not declared in scope."
            return True, ""
        if not isinstance(expr, dict):
            return False, f"Expression node must be dict or literal, got {type(expr).__name__}."

        if "name" in expr:
            var_name = expr["name"]
            if not self._safe_in(scope, var_name):
                return False, f"Unknown identifier '{var_name}' not declared in scope."
            return True, ""

        if "operator" in expr:
            if expr["operator"] == "sum":
                if "over" not in expr or not isinstance(expr["over"], list):
                    return False, "Sum operator must have an 'over' array."
                if "expression" not in expr:
                    return False, "Sum operator must have an 'expression' to evaluate."
                # Variables bound by 'over' are visible only inside the sum.
                local_scope = dict(scope)
                for decl in expr["over"]:
                    if not isinstance(decl, dict) or len(decl) != 1:
                        return False, "Each item in 'over' must be a dict with 1 key-value pair."
                    var, domain = next(iter(decl.items()))
                    domain_val = self._extract_domain(domain)
                    if not self._safe_in(VALID_DOMAINS, domain_val):
                        return False, f"Unknown domain '{domain_val}' in sum 'over'."
                    local_scope[var] = domain_val
                return self._validate_nested_expr(expr["expression"], local_scope)
            for side in ("left", "right"):
                if side in expr:
                    valid, msg = self._validate_nested_expr(expr[side], scope)
                    if not valid:
                        return False, msg
            return True, ""

        if "target" in expr:
            target = expr["target"]
            if not self._safe_in(VALID_FUNCTIONS, target):
                return False, f"Unknown function '{target}'."
            args = expr.get("args", [])
            if not isinstance(args, list):
                return False, f"Function '{target}' args must be a list."
            if len(args) != VALID_FUNCTIONS[target]:
                return False, f"Function '{target}' expects {VALID_FUNCTIONS[target]} args, got {len(args)}."
            for arg in args:
                valid, msg = self._validate_nested_expr(arg, scope)
                if not valid:
                    return False, msg
            return True, ""

        return False, f"Unrecognized AST node structure: {list(expr.keys())}"
| # --------------------------------------------------------------------------- | |
| # Quick smoke-test | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Quick smoke test: submit the target AST, then malformed JSON, for each
    # difficulty and print the resulting reward/info.
    import json as _json
    import os
    import sys

    # None makes ConstraintEnvironment fall back to its own dataset loading
    # if the import below fails (previously `_ds` was left unbound -> NameError).
    _ds = None
    try:
        # If run directly, make the package's parent directory importable.
        sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
        from dataset_example import dataset as _ds
    except ImportError:
        pass

    env = ConstraintEnvironment(_ds)
    for difficulty in ("easy", "medium", "hard"):
        obs = env.reset(task_id=difficulty)
        print(f"\n[{difficulty.upper()}] prompt: {obs.prompt}")

        # Send the perfect answer.
        target = env._current_sample["target_ast"]
        action = ConstraintAction(ast_output=_json.dumps(target))
        result = env.step(action)
        print(f" reward={result.reward} done={result.done} info={result.info}")

        # Send bad JSON on a fresh episode.
        env.reset(task_id=difficulty)
        bad = ConstraintAction(ast_output="this is not json")
        res2 = env.step(bad)
        print(f" [bad JSON] reward={res2.reward} info={res2.info}")