# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
import random
from typing import Any, Dict, List, Optional
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

try:
    from ..models import BRIDGE_TYPES, MATERIALS, BridgeForgeAction, BridgeForgeObservation
except (ImportError, ValueError):
    from models import BRIDGE_TYPES, MATERIALS, BridgeForgeAction, BridgeForgeObservation

try:
    from .scenarios import SCENARIOS, get_scenario, get_visible_constraints
    from .simulation import run_simulation
    from .reward import compute_reward
except (ImportError, ValueError):
    from server.scenarios import SCENARIOS, get_scenario, get_visible_constraints
    from server.simulation import run_simulation
    from server.reward import compute_reward


def _to_finite_float(value: Any) -> Optional[float]:
    """Convert an agent-supplied value to a finite float.

    Returns ``None`` instead of raising when ``value`` cannot be converted by
    ``float()`` (which raises ``TypeError``/``ValueError``) or is NaN/infinite,
    so handlers can answer malformed input with an error observation instead
    of crashing the step.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        return None
    return number if math.isfinite(number) else None


class BridgeForgeEnvironment(Environment):
    """Step-by-step bridge design environment.

    ``reset`` binds the episode to one scenario. The agent then picks a bridge
    type, adds nodes, members, supports and loads, analyses the structure with
    ``simulate`` and ends the episode with ``submit``. Rewards are produced by
    ``compute_reward``; observations expose only the constraints returned by
    ``get_visible_constraints``. Invalid actions are reported through the
    observation's ``message`` field.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._bridge_type: Optional[str] = None
        self._nodes: List[Dict] = []
        self._members: List[Dict] = []
        self._supports: List[Dict] = []
        self._loads: List[Dict] = []
        # Result of the last simulate() for the *current* design; cleared by any
        # structural edit so submit() never scores an outdated analysis.
        self._simulation_result: Optional[Dict] = None
        self._scenario: Optional[Dict] = None
        self._done = False

    def reset(self, seed: Optional[int] = None, **kwargs) -> BridgeForgeObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: If given, ``random`` is seeded with it before a scenario is
                drawn, making the random choice reproducible.
            **kwargs: ``scenario_id`` loads that scenario via ``get_scenario``;
                when absent or empty, one is picked at random from ``SCENARIOS``.
        """
        if seed is not None:
            # NOTE(review): this reseeds the process-global RNG, which every
            # session shares even though SUPPORTS_CONCURRENT_SESSIONS is True.
            # Kept as-is because other modules may rely on the global seed.
            random.seed(seed)

        scenario_id = kwargs.get("scenario_id")
        if scenario_id:
            self._scenario = get_scenario(scenario_id)
        else:
            self._scenario = random.choice(SCENARIOS)

        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._bridge_type = None
        self._nodes = []
        self._members = []
        self._supports = []
        self._loads = []
        self._simulation_result = None
        self._done = False
        return self._make_observation(
            message="Environment reset. Choose a bridge type to begin."
        )

    def step(self, action: BridgeForgeAction, **kwargs) -> BridgeForgeObservation:
        """Apply one design action and return the resulting observation.

        ``action.action_type`` selects the handler and ``action.params`` supplies
        its arguments. The step counter is incremented for every action taken in
        an active episode, including invalid ones.
        """
        if self._done:
            # Keep reporting done=True: clients that step after submit() must not
            # be told the episode is live again.
            return self._make_observation(
                message="Episode is done. Call reset() to start a new episode.",
                done=True,
            )
        if self._scenario is None:
            return self._make_observation(message="No scenario loaded. Call reset() first.")

        self._state.step_count += 1

        handlers = {
            "select_type": self._handle_select_type,
            "add_node": self._handle_add_node,
            "add_member": self._handle_add_member,
            "add_support": self._handle_add_support,
            "add_load": self._handle_add_load,
            "simulate": self._handle_simulate,
            "submit": self._handle_submit,
        }
        action_type = action.action_type
        handler = handlers.get(action_type)
        if handler is None:
            # The valid-type list is derived from the table so it cannot drift.
            return self._make_observation(
                message=f"Unknown action_type: {action_type}. "
                f"Valid types: {', '.join(handlers)}"
            )
        # Missing params are treated as empty so handlers report what is
        # required instead of failing on ``None.get``.
        return handler(action.params or {})

    def _mark_design_changed(self) -> None:
        """Discard the previous simulation because the structure it analysed changed."""
        self._simulation_result = None

    def _handle_select_type(self, params: Dict) -> BridgeForgeObservation:
        """Set the bridge type; ``params["bridge_type"]`` must be in ``BRIDGE_TYPES``."""
        bridge_type = params.get("bridge_type", "")
        if bridge_type not in BRIDGE_TYPES:
            return self._make_observation(
                message=f"Invalid bridge type '{bridge_type}'. Choose from: {BRIDGE_TYPES}"
            )
        self._bridge_type = bridge_type
        return self._make_observation(message=f"Bridge type set to '{bridge_type}'.")

    def _handle_add_node(self, params: Dict) -> BridgeForgeObservation:
        """Add a node at (``x``, ``y``) under a new ``node_id``.

        Node ids are stored as strings; coordinates must be finite numbers.
        """
        node_id = params.get("node_id")
        x = params.get("x")
        y = params.get("y")
        if node_id is None or x is None or y is None:
            return self._make_observation(message="add_node requires: node_id, x, y")

        node_key = str(node_id)
        # Compare against the stored (stringified) id; comparing the raw value
        # let e.g. an int id be added twice.
        if any(n["node_id"] == node_key for n in self._nodes):
            return self._make_observation(message=f"Node '{node_id}' already exists.")

        x_val = _to_finite_float(x)
        y_val = _to_finite_float(y)
        if x_val is None or y_val is None:
            return self._make_observation(message="add_node: x and y must be finite numbers.")

        self._nodes.append({"node_id": node_key, "x": x_val, "y": y_val})
        self._mark_design_changed()
        return self._make_observation(message=f"Node '{node_id}' added at ({x}, {y}).")

    def _handle_add_member(self, params: Dict) -> BridgeForgeObservation:
        """Connect two existing, non-coincident nodes with a new member.

        Required params: ``member_id``, ``node_start``, ``node_end``. Optional:
        ``material`` (default ``"steel"``, must be a key of ``MATERIALS``) and
        ``section_area`` (default ``0.01``, must be a positive finite number).
        """
        member_id = params.get("member_id")
        node_start = params.get("node_start")
        node_end = params.get("node_end")
        material = params.get("material", "steel")
        section_area = params.get("section_area", 0.01)
        if member_id is None or node_start is None or node_end is None:
            return self._make_observation(
                message="add_member requires: member_id, node_start, node_end"
            )
        if material not in MATERIALS:
            return self._make_observation(
                message=f"Invalid material '{material}'. Choose from: {list(MATERIALS.keys())}"
            )

        nodes_by_id = {n["node_id"]: n for n in self._nodes}
        start = nodes_by_id.get(str(node_start))
        end = nodes_by_id.get(str(node_end))
        if start is None:
            return self._make_observation(message=f"Node '{node_start}' not found.")
        if end is None:
            return self._make_observation(message=f"Node '{node_end}' not found.")

        member_key = str(member_id)
        # Same stringified-id comparison as for nodes.
        if any(m["member_id"] == member_key for m in self._members):
            return self._make_observation(message=f"Member '{member_id}' already exists.")

        area = _to_finite_float(section_area)
        if area is None or area <= 0:
            return self._make_observation(
                message="add_member: section_area must be a positive finite number."
            )
        # Coincident endpoints (the same node, or two nodes at the same
        # coordinates) give a member with no length or direction.
        if (start["x"], start["y"]) == (end["x"], end["y"]):
            return self._make_observation(
                message=f"Member '{member_id}' would have zero length: "
                f"nodes '{node_start}' and '{node_end}' coincide."
            )

        self._members.append({
            "member_id": member_key,
            "node_start": str(node_start),
            "node_end": str(node_end),
            "material": material,
            "section_area": area,
        })
        self._mark_design_changed()
        return self._make_observation(
            message=f"Member '{member_id}' added: {node_start} -> {node_end} ({material}, A={section_area})."
        )

    def _handle_add_support(self, params: Dict) -> BridgeForgeObservation:
        """Attach a ``pin`` or ``roller`` support to an existing node (one per node)."""
        node_id = params.get("node_id")
        support_type = params.get("support_type")
        if node_id is None or support_type is None:
            return self._make_observation(
                message="add_support requires: node_id, support_type (pin|roller)"
            )
        if support_type not in ("pin", "roller"):
            return self._make_observation(message="support_type must be 'pin' or 'roller'.")

        node_key = str(node_id)
        if node_key not in {n["node_id"] for n in self._nodes}:
            return self._make_observation(message=f"Node '{node_id}' not found.")
        if any(s["node_id"] == node_key for s in self._supports):
            return self._make_observation(
                message=f"Support already exists at node '{node_id}'."
            )

        self._supports.append({"node_id": node_key, "support_type": support_type})
        self._mark_design_changed()
        return self._make_observation(
            message=f"Support ({support_type}) added at node '{node_id}'."
        )

    def _handle_add_load(self, params: Dict) -> BridgeForgeObservation:
        """Apply a point load (``Fx``, ``Fy``; each defaults to 0.0) at an existing node.

        Several loads may target the same node; each is recorded separately.
        """
        node_id = params.get("node_id")
        Fx = params.get("Fx", 0.0)
        Fy = params.get("Fy", 0.0)
        if node_id is None:
            return self._make_observation(message="add_load requires: node_id")

        node_key = str(node_id)
        if node_key not in {n["node_id"] for n in self._nodes}:
            return self._make_observation(message=f"Node '{node_id}' not found.")

        fx_val = _to_finite_float(Fx)
        fy_val = _to_finite_float(Fy)
        if fx_val is None or fy_val is None:
            return self._make_observation(message="add_load: Fx and Fy must be finite numbers.")

        self._loads.append({"node_id": node_key, "Fx": fx_val, "Fy": fy_val})
        self._mark_design_changed()
        return self._make_observation(
            message=f"Load applied at node '{node_id}': Fx={Fx} kN, Fy={Fy} kN."
        )

    def _handle_simulate(self, params: Dict) -> BridgeForgeObservation:
        """Run ``run_simulation`` on the current design and return a non-final reward.

        Requires at least one node, member, support and load. The result is
        kept so ``submit`` can score it, until the structure changes again.
        """
        if not self._nodes or not self._members:
            return self._make_observation(message="Cannot simulate: add nodes and members first.")
        if not self._supports:
            return self._make_observation(message="Cannot simulate: add at least one support.")
        if not self._loads:
            return self._make_observation(message="Cannot simulate: add at least one load.")

        constraints = self._scenario["constraints"]
        result = run_simulation(
            nodes=self._nodes,
            members=self._members,
            supports=self._supports,
            loads=self._loads,
            constraints=constraints,
        )
        self._simulation_result = result

        reward = compute_reward(
            bridge_type=self._bridge_type or "",
            nodes=self._nodes,
            simulation_result=result,
            constraints=constraints,
            is_submit=False,
        )
        return self._make_observation(message="Simulation complete.", reward=reward)

    def _handle_submit(self, params: Dict) -> BridgeForgeObservation:
        """Score the simulated design as final and end the episode.

        Requires a simulation of the current design; any structural edit after
        ``simulate`` clears it, so the submitted design is the one analysed.
        """
        if self._simulation_result is None:
            return self._make_observation(message="Run simulate() before submitting.")

        reward = compute_reward(
            bridge_type=self._bridge_type or "",
            nodes=self._nodes,
            simulation_result=self._simulation_result,
            constraints=self._scenario["constraints"],
            is_submit=True,
        )
        self._done = True
        return self._make_observation(
            message=f"Design submitted. Final reward: {reward}",
            reward=reward,
            done=True,
        )

    def _make_observation(
        self,
        message: str = "",
        reward: float = 0.0,
        done: bool = False,
    ) -> BridgeForgeObservation:
        """Snapshot the current design into an observation.

        Lists are shallow-copied so the observation does not alias the
        environment's own lists; only the visible subset of the scenario's
        constraints is included.
        """
        visible_constraints = {}
        if self._scenario:
            visible_constraints = get_visible_constraints(self._scenario["constraints"])

        return BridgeForgeObservation(
            scenario=self._scenario["scenario"] if self._scenario else "",
            bridge_type=self._bridge_type,
            nodes=list(self._nodes),
            members=list(self._members),
            supports=list(self._supports),
            loads=list(self._loads),
            simulation_result=self._simulation_result,
            constraints=visible_constraints,
            step_count=self._state.step_count,
            done=done,
            reward=reward,
            message=message,
        )

    @property
    def state(self) -> State:
        """Current episode state (episode id and step count)."""
        return self._state