| |
| |
| |
| |
| |
|
|
| """Contract Validation Environment Client.""" |
|
|
| from typing import Dict, Any |
|
|
| from openenv.core import EnvClient |
| from openenv.core.client_types import StepResult |
| from openenv.core.env_server.types import State |
|
|
| from models import ContractValidationAction, ContractValidationObservation |
|
|
|
|
class ContractValidationEnv(
    EnvClient[ContractValidationAction, ContractValidationObservation, State]
):
    """Client for the Contract Validation Environment.

    Converts between the typed action/observation/state models and the
    JSON-style dict payloads handled by ``EnvClient``.
    """

    def _step_payload(self, action: ContractValidationAction) -> Dict[str, Any]:
        """Serialize *action* into the dict sent to the server for a step.

        Only ``clause_id``, ``risk_type``, ``submit_final`` and
        ``explanation`` are forwarded; any other attributes on the action
        are not sent.
        """
        return {
            "clause_id": action.clause_id,
            "risk_type": action.risk_type,
            "submit_final": action.submit_final,
            "explanation": action.explanation,
        }

    def _parse_result(
        self, payload: Dict[str, Any]
    ) -> StepResult[ContractValidationObservation]:
        """Build a ``StepResult`` from a server response payload.

        ``reward`` and ``done`` are read from the top level of *payload*
        (``done`` defaults to ``False``, ``reward`` to ``None``). They are
        also copied into the observation when the ``"observation"`` dict does
        not already contain them, so ``result.observation.done`` agrees with
        ``result.done``.
        """
        # A missing *or* null "observation" is treated as empty; the original
        # ``payload.get("observation", {})`` returned None for an explicit
        # null and then crashed on ``**None``. Copy so the caller's payload
        # is never mutated.
        obs_data: Dict[str, Any] = dict(payload.get("observation") or {})
        reward = payload.get("reward")
        done = payload.get("done", False)

        # NOTE(review): assumes ContractValidationObservation declares
        # ``done``/``reward`` fields (as openenv's Observation base does) and
        # that the server reports them at the top level of the payload rather
        # than inside "observation" -- confirm against models.py.
        # setdefault keeps any value the server did put in the observation.
        obs_data.setdefault("done", done)
        obs_data.setdefault("reward", reward)

        observation = ContractValidationObservation(**obs_data)
        return StepResult(observation=observation, reward=reward, done=done)

    def _parse_state(self, payload: Dict[str, Any]) -> State:
        """Build a ``State`` from a state response payload.

        Only ``episode_id`` (default ``None``) and ``step_count``
        (default ``0``) are read; other keys in *payload* are ignored.
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
|
|