import copy
import json
import random
from pathlib import Path
from typing import Optional

from .paper_state import PaperState
from .rewards import compute_reward, compute_terminal_reward
from .prompts import (
    code_as_policy_prompt,
    get_semantic_description,
    step_level_prompt,
    parse_fold_list,
    parse_single_fold,
)
from .verifier import check_all_vertices

TARGETS_DIR = Path(__file__).parent / 'targets'


class OrigamiEnvironment:
    """
    OpenEnv-compatible origami crease pattern environment.

    Supports two modes:
    - code_as_policy: model outputs complete fold sequence, gets terminal reward
    - step: model outputs one fold at a time, gets per-step reward
    """

    def __init__(
        self,
        mode: str = 'code_as_policy',  # 'code_as_policy' or 'step'
        max_steps: int = 8,
        targets_dir: Optional[str] = None,
        use_semantic: bool = True,
    ):
        assert mode in ('code_as_policy', 'step'), f"Unknown mode: {mode}"
        self.mode = mode
        self.max_steps = max_steps
        self.targets_dir = Path(targets_dir) if targets_dir else TARGETS_DIR
        self.use_semantic = use_semantic

        self.paper: Optional[PaperState] = None
        self.target: Optional[dict] = None
        self.target_name: Optional[str] = None
        self.step_count: int = 0
        self.last_reward: Optional[dict] = None

        self._targets = self._load_all_targets()

    def _load_all_targets(self) -> dict[str, dict]:
        targets = {}
        for fold_file in self.targets_dir.glob('*.fold'):
            with open(fold_file) as f:
                targets[fold_file.stem] = json.load(f)
        return targets

    def available_targets(self) -> list[str]:
        return sorted(self._targets.keys())

    def reset(self, target_name: Optional[str] = None) -> dict:
        """
        Reset environment to start of a new episode.

        Args:
            target_name: name of target (stem of .fold file). If None, picks level-1 randomly.

        Returns:
            observation dict with 'prompt' key containing the LLM prompt string.
        """
        if target_name:
            assert target_name in self._targets, f"Unknown target: {target_name}"
            self.target_name = target_name
        else:
            # Default to level-1 targets
            level1 = [k for k, v in self._targets.items() if v.get('level', 1) == 1]
            self.target_name = random.choice(level1 if level1 else list(self._targets.keys()))

        self.target = self._targets[self.target_name]
        self.paper = PaperState()
        self.step_count = 0
        self.last_reward = None

        return self._get_observation()

    def step(self, action) -> tuple[dict, dict, bool, dict]:
        """
        Execute an action.

        In code_as_policy mode: action is a string (model completion with tags)
        OR a list of fold dicts already parsed.
        In step mode: action is a string (single fold JSON) or dict.

        Returns:
            (observation, reward, done, info)
        """
        if self.mode == 'code_as_policy':
            return self._step_sequence(action)
        else:
            return self._step_single(action)
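    # Example (sketch): how reset()/step() compose into one episode. This
    # helper is an illustrative addition, not part of the original API; the
    # `policy` callable (prompt string -> action string) is a hypothetical
    # stand-in for an LLM call.
    def run_episode(self, policy, target_name: Optional[str] = None) -> dict:
        """Roll out one episode with `policy`; return the last reward dict."""
        obs = self.reset(target_name)
        reward: dict = {}
        # Bounded loop rather than `while not done`: in step mode a malformed
        # action returns without advancing step_count, so an unconditional
        # loop could spin forever on a policy that keeps emitting bad output.
        for _ in range(self.max_steps * 2):
            obs, reward, done, _ = self.step(policy(obs['prompt']))
            if done:
                break
        return reward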
    def _step_sequence(self, action) -> tuple[dict, dict, bool, dict]:
        """Execute a complete fold sequence (code-as-policy mode)."""
        # Parse action if it's a string
        if isinstance(action, str):
            try:
                folds = parse_fold_list(action)
            except ValueError as e:
                bad_reward = {'format': 0.0, 'total': -0.1, 'error': str(e)}
                return self._get_observation(), bad_reward, True, self._info()
        else:
            folds = action  # already a list of dicts

        # Execute each fold sequentially
        last_result = {'valid': True, 'anchored': True, 'new_vertices': [], 'errors': []}
        for fold in folds:
            try:
                p1 = fold['from']
                p2 = fold['to']
                assignment = fold['assignment']
            except (KeyError, TypeError) as e:
                last_result = {'valid': False, 'anchored': False, 'new_vertices': [], 'errors': [str(e)]}
                break
            last_result = self.paper.add_crease(p1, p2, assignment)
            self.step_count += 1
            if not last_result['valid']:
                break  # stop at first invalid fold, partial credit

        reward = compute_terminal_reward(self.paper, self.target, self.max_steps)
        self.last_reward = reward
        return self._get_observation(), reward, True, self._info()

    def _step_single(self, action) -> tuple[dict, dict, bool, dict]:
        """Execute a single fold (step mode)."""
        if isinstance(action, str):
            try:
                fold = parse_single_fold(action)
            except ValueError as e:
                bad_reward = {'format': 0.0, 'total': -0.1, 'error': str(e)}
                self.last_reward = bad_reward
                done = self.step_count >= self.max_steps
                return self._get_observation(), bad_reward, done, self._info()
        else:
            fold = action

        try:
            p1 = fold['from']
            p2 = fold['to']
            assignment = fold['assignment']
        except (KeyError, TypeError) as e:
            bad_reward = {'format': 0.0, 'total': -0.1, 'error': str(e)}
            self.last_reward = bad_reward
            done = self.step_count >= self.max_steps
            return self._get_observation(), bad_reward, done, self._info()

        prev_state = copy.deepcopy(self.paper)
        result = self.paper.add_crease(p1, p2, assignment)
        self.step_count += 1

        reward = compute_reward(
            prev_state=prev_state,
            action_result=result,
            new_state=self.paper,
            target=self.target,
            step=self.step_count,
            max_steps=self.max_steps,
        )
        self.last_reward = reward

        done = (
            self.step_count >= self.max_steps
            or reward.get('completion', 0) > 0
        )
        return self._get_observation(), reward, done, self._info()

    def _get_observation(self) -> dict:
        """Returns observation dict with the LLM prompt and raw state."""
        desc = None
        if self.use_semantic and self.target_name and self.target:
            desc = get_semantic_description(self.target_name, self.target)

        if self.mode == 'code_as_policy':
            prompt = code_as_policy_prompt(
                self.target,
                max_folds=self.max_steps,
                semantic_description=desc,
            )
        else:
            prompt = step_level_prompt(
                target=self.target,
                paper_state=self.paper,
                step=self.step_count,
                max_steps=self.max_steps,
                last_reward=self.last_reward,
                semantic_description=desc,
            )

        return {
            'prompt': prompt,
            'target_name': self.target_name,
            'step': self.step_count,
            'paper_fold_json': self.paper.graph.edges if self.paper else {},
        }

    def _info(self) -> dict:
        """Returns diagnostic info dict for logging."""
        if self.paper is None:
            return {}
        interior = self.paper.graph.interior_vertices()
        vertex_scores = check_all_vertices(self.paper.graph)
        return {
            'local_foldability': (
                vertex_scores['kawasaki'] == 1.0
                and vertex_scores['maekawa'] == 1.0
            ),
            'blb_satisfied': vertex_scores['blb'] == 1.0,
            'global_foldability': 'not_checked',  # NP-complete (Bern-Hayes 1996)
            'n_interior_vertices': len(interior),
            'n_creases': len(self.paper.graph.crease_edges()),
            'target_name': self.target_name,
        }
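    # Worked example (illustrative) for the local checks reported by _info():
    #
    # Maekawa's theorem: at every interior vertex of a flat-foldable pattern
    # the mountain and valley counts satisfy |M - V| = 2, so a degree-4 vertex
    # needs 3 mountains + 1 valley (or 1 + 3), never 2 + 2.
    #
    # Kawasaki's theorem: the alternating sector angles around an interior
    # vertex must sum to 180 degrees each way. Sectors of 90, 45, 90, 135
    # degrees pass (90 + 90 = 180 and 45 + 135 = 180); sectors of 100, 80,
    # 90, 90 fail (100 + 90 = 190 != 80 + 90 = 170).
    #
    # Both are necessary per-vertex conditions only; passing them does not
    # imply the whole pattern folds flat, which is why 'global_foldability'
    # is reported as 'not_checked' (NP-complete per Bern-Hayes 1996).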
    def state(self) -> dict:
        """Returns current environment state for logging/inspection."""
        if self.paper is None:  # guard against calls before reset(), mirroring _info()
            return {}
        return {
            'paper': {
                'vertices': dict(self.paper.graph.vertices),
                'edges': {
                    k: v for k, v in self.paper.graph.edges.items()
                    if v[2] in ('M', 'V')  # keep only mountain/valley creases
                },
                'fold_history': self.paper.fold_history,
            },
            'target': self.target_name,
            'step': self.step_count,
            'mode': self.mode,
        }

    def close(self):
        """Cleanup."""
        pass

    def clone(self) -> 'OrigamiEnvironment':
        """Return a deep copy for parallel evaluation (used in GRPO)."""
        new_env = OrigamiEnvironment(
            mode=self.mode,
            max_steps=self.max_steps,
            targets_dir=str(self.targets_dir),
            use_semantic=self.use_semantic,
        )
        if self.paper is not None:
            new_env.paper = copy.deepcopy(self.paper)
        new_env.target = self.target
        new_env.target_name = self.target_name
        new_env.step_count = self.step_count
        new_env.last_reward = self.last_reward
        return new_env
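

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions: the bundled targets/ directory contains
# at least one .fold file, and the fold dict below uses the {'from', 'to',
# 'assignment'} schema that step() expects; its coordinates and 'V'
# assignment are made up for illustration and may simply be rejected as an
# invalid crease. Because this module uses relative imports, run it as a
# module (python -m <package>.<module>), not as a plain script.
if __name__ == '__main__':
    env = OrigamiEnvironment(mode='code_as_policy', max_steps=8)
    print('available targets:', env.available_targets())

    obs = env.reset()
    print('prompt preview:', obs['prompt'][:200])

    # A pre-parsed fold list; step() also accepts the raw tagged string.
    folds = [{'from': [0.0, 0.0], 'to': [1.0, 1.0], 'assignment': 'V'}]
    obs, reward, done, info = env.step(folds)
    print('reward:', reward)
    print('info:', info)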