# optigami / env/environment.py
# (Source: HuggingFace file view. Author: sissississi; commit 19abe39 "iana (#1)"; size 8.58 kB.)
import json
import os
import copy
from pathlib import Path
from typing import Optional
from .paper_state import PaperState
from .rewards import compute_reward, compute_terminal_reward, load_target, target_crease_edges
from .prompts import (
code_as_policy_prompt,
step_level_prompt,
parse_fold_list,
parse_single_fold,
)
from .verifier import check_all_vertices
# Default directory holding the bundled ``*.fold`` target files: a
# ``targets`` folder located next to this module.
TARGETS_DIR = Path(__file__).parent.joinpath('targets')
class OrigamiEnvironment:
    """
    OpenEnv-compatible origami crease pattern environment.

    Supports two modes:
    - code_as_policy: model outputs complete fold sequence, gets terminal reward
    - step: model outputs one fold at a time, gets per-step reward

    Usage: ``obs = env.reset()``, then call
    ``obs, reward, done, info = env.step(action)`` until ``done`` is True.
    """

    # Reward 'total' assigned to an action that cannot be parsed into a fold
    # or is missing one of the 'from' / 'to' / 'assignment' keys.
    FORMAT_PENALTY = -0.1

    def __init__(
        self,
        mode: str = 'code_as_policy',  # 'code_as_policy' or 'step'
        max_steps: int = 8,
        targets_dir: Optional[str] = None,
    ):
        """
        Args:
            mode: 'code_as_policy' (whole fold sequence in one step() call) or
                'step' (one fold per step() call).
            max_steps: fold budget per episode. It is passed to the prompt
                builders and enforced in both modes.
            targets_dir: directory containing ``*.fold`` target files. When
                omitted, the module-level ``TARGETS_DIR`` is used.

        Raises:
            AssertionError: ``mode`` is not one of the two supported modes.
        """
        assert mode in ('code_as_policy', 'step'), f"Unknown mode: {mode}"
        self.mode = mode
        self.max_steps = max_steps
        self.targets_dir = Path(targets_dir) if targets_dir else TARGETS_DIR
        self.paper: Optional[PaperState] = None
        self.target: Optional[dict] = None
        self.target_name: Optional[str] = None
        self.step_count: int = 0
        self.last_reward: Optional[dict] = None
        # Cache all available targets once; clones share this cache.
        self._targets = self._load_all_targets()

    def _load_all_targets(self) -> dict[str, dict]:
        """Load every ``*.fold`` file in ``targets_dir``, keyed by file stem.

        Files are read in sorted order, so the dict order does not depend on
        filesystem listing order. The candidate list that ``reset`` feeds to
        ``random.choice`` inherits this order, which makes it reproducible
        under a fixed seed. A missing or empty directory yields ``{}``.
        """
        targets = {}
        for fold_file in sorted(self.targets_dir.glob('*.fold')):
            # Explicit encoding: the default is platform-dependent.
            with open(fold_file, encoding='utf-8') as f:
                targets[fold_file.stem] = json.load(f)
        return targets

    def available_targets(self) -> list[str]:
        """Return the names (file stems) of all loaded targets, sorted."""
        return sorted(self._targets.keys())

    def reset(self, target_name: Optional[str] = None) -> dict:
        """
        Reset environment to start of a new episode.

        Args:
            target_name: name of target (stem of .fold file). If None, picks a
                random target whose 'level' is 1. A target without a 'level'
                key counts as level 1. If no level-1 target exists, picks from
                all targets.

        Returns:
            observation dict with 'prompt' key containing the LLM prompt string.

        Raises:
            AssertionError: ``target_name`` is not a loaded target.
            RuntimeError: no target is given and no targets were loaded.
        """
        import random
        if target_name:
            assert target_name in self._targets, f"Unknown target: {target_name}"
            self.target_name = target_name
        else:
            if not self._targets:
                raise RuntimeError(
                    f"No *.fold targets found in {self.targets_dir}"
                )
            # Default to level-1 targets
            level1 = [k for k, v in self._targets.items() if v.get('level', 1) == 1]
            self.target_name = random.choice(level1 or list(self._targets))
        self.target = self._targets[self.target_name]
        self.paper = PaperState()
        self.step_count = 0
        self.last_reward = None
        return self._get_observation()

    def step(self, action) -> tuple[dict, dict, bool, dict]:
        """
        Execute an action.

        In code_as_policy mode: action is a string (model completion with <folds> tags)
        OR a list of fold dicts already parsed.
        In step mode: action is a string (single fold JSON) or dict.

        Returns:
            (observation, reward, done, info)

        Raises:
            RuntimeError: called before ``reset()``.
        """
        if self.paper is None:
            raise RuntimeError("step() called before reset()")
        if self.mode == 'code_as_policy':
            return self._step_sequence(action)
        return self._step_single(action)

    @staticmethod
    def _unpack_fold(fold) -> tuple:
        """Return ``(from, to, assignment)`` taken from a fold mapping.

        Raises:
            KeyError: a required key is missing.
            TypeError: ``fold`` cannot be indexed by string keys.
        """
        return fold['from'], fold['to'], fold['assignment']

    def _format_error_reward(self, error: Exception) -> dict:
        """Build the reward dict for an action that is not a usable fold."""
        return {'format': 0.0, 'total': self.FORMAT_PENALTY, 'error': str(error)}

    def _step_sequence(self, action) -> tuple[dict, dict, bool, dict]:
        """Execute a complete fold sequence (code-as-policy mode).

        Folds are applied in order. Execution stops at the first malformed or
        invalid fold, or once ``max_steps`` folds have been applied. Folds
        beyond the budget are ignored, matching the ``max_folds`` limit stated
        in the prompt. Earlier folds keep their effect, so partial credit is
        possible. The episode always ends (``done`` is True).
        """
        if isinstance(action, str):
            try:
                folds = parse_fold_list(action)
            except ValueError as e:
                bad_reward = self._format_error_reward(e)
                self.last_reward = bad_reward
                return self._get_observation(), bad_reward, True, self._info()
        else:
            folds = action  # already a list of dicts

        for fold in folds:
            if self.step_count >= self.max_steps:
                break  # fold budget exhausted
            try:
                p1, p2, assignment = self._unpack_fold(fold)
            except (KeyError, TypeError):
                break  # malformed fold: stop here, score what was applied
            result = self.paper.add_crease(p1, p2, assignment)
            self.step_count += 1
            if not result['valid']:
                break  # stop at first invalid fold, partial credit

        reward = compute_terminal_reward(self.paper, self.target)
        self.last_reward = reward
        return self._get_observation(), reward, True, self._info()

    def _reject_single(self, error: Exception) -> tuple[dict, dict, bool, dict]:
        """Penalise a malformed step-mode action.

        The malformed action still consumes one step of the budget. Without
        this, a policy that never emits parseable output would never reach
        ``done``.
        """
        bad_reward = self._format_error_reward(error)
        self.last_reward = bad_reward
        self.step_count += 1
        done = self.step_count >= self.max_steps
        return self._get_observation(), bad_reward, done, self._info()

    def _step_single(self, action) -> tuple[dict, dict, bool, dict]:
        """Execute a single fold (step mode).

        The episode ends when the step budget is used up, or when the reward's
        'completion' entry is positive.
        """
        if isinstance(action, str):
            try:
                fold = parse_single_fold(action)
            except ValueError as e:
                return self._reject_single(e)
        else:
            fold = action
        try:
            p1, p2, assignment = self._unpack_fold(fold)
        except (KeyError, TypeError) as e:
            return self._reject_single(e)

        result = self.paper.add_crease(p1, p2, assignment)
        self.step_count += 1
        reward = compute_reward(self.paper, result, self.target)
        self.last_reward = reward
        done = (
            self.step_count >= self.max_steps or
            reward.get('completion', 0) > 0
        )
        return self._get_observation(), reward, done, self._info()

    def _get_observation(self) -> dict:
        """Returns observation dict with the LLM prompt and raw state."""
        if self.mode == 'code_as_policy':
            prompt = code_as_policy_prompt(self.target, max_folds=self.max_steps)
        else:
            prompt = step_level_prompt(
                target=self.target,
                paper_state=self.paper,
                step=self.step_count,
                max_steps=self.max_steps,
                last_reward=self.last_reward,
            )
        return {
            'prompt': prompt,
            'target_name': self.target_name,
            'step': self.step_count,
            # Compare with None explicitly: truthiness of a PaperState is not
            # a reliable "has been reset" signal.
            'paper_fold_json': self.paper.graph.edges if self.paper is not None else {},
        }

    def _info(self) -> dict:
        """Returns diagnostic info dict for logging ({} before reset)."""
        if self.paper is None:
            return {}
        graph = self.paper.graph
        interior = graph.interior_vertices()
        vertex_scores = check_all_vertices(graph)
        return {
            'local_foldability': (
                vertex_scores['kawasaki'] == 1.0 and
                vertex_scores['maekawa'] == 1.0
            ),
            'blb_satisfied': vertex_scores['blb'] == 1.0,
            'global_foldability': 'not_checked',  # NP-complete (Bern-Hayes 1996)
            'n_interior_vertices': len(interior),
            'n_creases': len(graph.crease_edges()),
            'target_name': self.target_name,
        }

    def state(self) -> dict:
        """Returns current environment state for logging/inspection.

        'paper' is None before the first ``reset()``. Otherwise it holds the
        vertices, only the edges whose index-2 field is 'M' or 'V', and the
        fold history.
        """
        paper = None
        if self.paper is not None:
            graph = self.paper.graph
            paper = {
                'vertices': dict(graph.vertices),
                'edges': {
                    k: v for k, v in graph.edges.items()
                    if v[2] in ('M', 'V')
                },
                'fold_history': self.paper.fold_history,
            }
        return {
            'paper': paper,
            'target': self.target_name,
            'step': self.step_count,
            'mode': self.mode,
        }

    def close(self):
        """Cleanup (no resources are held)."""
        pass

    def clone(self) -> 'OrigamiEnvironment':
        """Return an independent copy for parallel evaluation (used in GRPO).

        The paper state and last reward are deep-copied, so folding in the
        clone never affects this environment. The loaded-target cache is
        shared rather than re-read from disk, because targets are treated as
        read-only. ``target`` was already shared with the cache previously.
        """
        new_env = copy.copy(self)
        new_env.paper = copy.deepcopy(self.paper)  # deepcopy(None) is None
        new_env.last_reward = copy.deepcopy(self.last_reward)
        return new_env