import gymnasium as gym import mujoco import numpy as np import os # --- 1. SO-101 커스텀 Gym 환경 (우리가 만들었던 것) --- # 이 파일(env.py) 기준 상대 경로 MODEL_PATH = os.path.join( os.path.dirname(__file__), "SO-ARM100", "Simulation", "SO101", "scene.xml" ) class SO101SimEnv(gym.Env): """SO-101 MuJoCo 시뮬레이션을 위한 커스텀 Gym 환경""" metadata = {"render_modes": ["human"], "render_fps": 30} def __init__(self, render_mode=None): if not os.path.exists(MODEL_PATH): raise FileNotFoundError( f"SO-101 모델 파일({MODEL_PATH})을 찾을 수 없습니다. " f"so100-arm 저장소가 올바르게 복제되었는지 확인하세요." ) self.model = mujoco.MjModel.from_xml_path(MODEL_PATH) self.data = mujoco.MjData(self.model) self.render_mode = render_mode self.viewer = None # 6개 모터 self.action_space = gym.spaces.Box( low=-1.0, high=1.0, shape=(6,), dtype=np.float32 ) # 12개 상태 (6개 관절 각도 + 6개 관절 속도) self.observation_space = gym.spaces.Box( low=-np.inf, high=np.inf, shape=(12,), dtype=np.float32 ) def _get_obs(self): return np.concatenate([self.data.qpos[:6], self.data.qvel[:6]]).astype( np.float32 ) def step(self, action): self.data.ctrl[:6] = action mujoco.mj_step(self.model, self.data) if self.render_mode == "human": self.render() obs = self._get_obs() return obs, 0.0, False, False, {} # (obs, reward, terminated, truncated, info) def reset(self, seed=None, options=None): super().reset(seed=seed) mujoco.mj_resetData(self.model, self.data) return self._get_obs(), {} def render(self): if self.viewer is None: os.environ["MUJOCO_GL"] = "egl" # WSL 렌더링 호환성 from mujoco.viewer import launch_passive self.viewer = launch_passive(self.model, self.data) self.viewer.sync() def close(self): if self.viewer: self.viewer.close() # --- 2. EnvHub를 위한 make_env 함수 (새로운 부분) --- def make_env(n_envs: int = 1, use_async_envs: bool = False): """ LeRobot EnvHub가 호출할 SO-101 환경 생성 팩토리 함수 """ def _make_single_env(): # 1번에서 만든 우리 커스텀 환경을 리턴 return SO101SimEnv(render_mode="human") # LeRobot 문서의 표준 방식 (벡터 환경으로 래핑) env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv vec_env = env_cls([_make_single_env for _ in range(n_envs)]) return vec_env