File size: 2,812 Bytes
8446e1f f355fb3 8446e1f f355fb3 8446e1f c3efe87 8446e1f c3efe87 8446e1f c3efe87 8446e1f c3efe87 8446e1f c3efe87 8446e1f c3efe87 8446e1f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 | import gymnasium as gym
import mujoco
import numpy as np
import os
# --- 1. Custom Gym environment for the SO-101 arm (the one we built) ---
# Path to the MuJoCo scene, resolved relative to the directory that contains
# this file (env.py). abspath() guarantees the result does not depend on the
# current working directory even if __file__ happens to be relative.
MODEL_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "SO-ARM100",
    "Simulation",
    "SO101",
    "scene.xml",
)
class SO101SimEnv(gym.Env):
    """Custom Gym environment wrapping the SO-101 MuJoCo simulation.

    Observation: 12-dim float32 vector made of the first 6 entries of
    ``data.qpos`` followed by the first 6 entries of ``data.qvel``
    (assumed to be the arm's 6 joints -- confirm against scene.xml).
    Action: 6-dim float32 vector in [-1, 1], written directly to
    ``data.ctrl[:6]``; no rescaling to actuator ranges is done here.
    Reward is always 0.0 and episodes never terminate or truncate.
    """

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self, render_mode=None):
        """Load the MuJoCo model and define the action/observation spaces.

        Args:
            render_mode: ``"human"`` to show a passive MuJoCo viewer that is
                refreshed on every ``reset``/``step``; ``None`` for no rendering.

        Raises:
            FileNotFoundError: if the scene file at ``MODEL_PATH`` is missing.
        """
        if not os.path.exists(MODEL_PATH):
            raise FileNotFoundError(
                f"SO-101 모델 파일({MODEL_PATH})을 찾을 수 없습니다. "
                "so100-arm 저장소가 올바르게 복제되었는지 확인하세요."
            )
        self.model = mujoco.MjModel.from_xml_path(MODEL_PATH)
        self.data = mujoco.MjData(self.model)
        self.render_mode = render_mode
        # Passive viewer handle; created lazily by render(), cleared by close().
        self.viewer = None

        # One command per motor (6 motors).
        self.action_space = gym.spaces.Box(
            low=-1.0, high=1.0, shape=(6,), dtype=np.float32
        )
        # 12 state values: 6 joint positions + 6 joint velocities.
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(12,), dtype=np.float32
        )

    def _get_obs(self):
        """Return ``qpos[:6]`` and ``qvel[:6]`` concatenated as float32."""
        return np.concatenate([self.data.qpos[:6], self.data.qvel[:6]]).astype(
            np.float32
        )

    def step(self, action):
        """Write ``action`` to the first 6 actuators and advance one physics step.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` per the Gymnasium
            API; reward is always 0.0 and both flags are always False.
        """
        self.data.ctrl[:6] = action
        mujoco.mj_step(self.model, self.data)
        if self.render_mode == "human":
            self.render()
        obs = self._get_obs()
        return obs, 0.0, False, False, {}

    def reset(self, seed=None, options=None):
        """Reset the simulation to the model's default state.

        Args:
            seed: Forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
            options: Accepted for API compatibility; unused.

        Returns:
            ``(obs, info)`` with an empty info dict.
        """
        super().reset(seed=seed)
        mujoco.mj_resetData(self.model, self.data)
        # Recompute derived quantities (body poses etc.) for the reset state so
        # that the viewer does not show stale data; qpos/qvel are unchanged.
        mujoco.mj_forward(self.model, self.data)
        if self.render_mode == "human":
            self.render()
        return self._get_obs(), {}

    def render(self):
        """Show the current state in a passive MuJoCo viewer.

        The viewer window is opened on the first call and synced on every call.
        """
        if self.viewer is None:
            # Intended for rendering compatibility under WSL.
            # NOTE(review): MuJoCo reads MUJOCO_GL when the `mujoco` package is
            # imported, and it is already imported at module level, so this
            # assignment most likely has no effect here -- confirm.
            os.environ["MUJOCO_GL"] = "egl"
            from mujoco.viewer import launch_passive

            self.viewer = launch_passive(self.model, self.data)
        self.viewer.sync()

    def close(self):
        """Close the viewer if one is open. Safe to call more than once."""
        if self.viewer is not None:
            self.viewer.close()
            # Drop the dead handle so a later render() opens a fresh viewer and
            # a repeated close() does not touch the closed one again.
            self.viewer = None
# --- 2. make_env factory for LeRobot EnvHub ---
def make_env(n_envs: int = 1, use_async_envs: bool = False, render_mode="human"):
    """Factory called by LeRobot EnvHub to build SO-101 environments.

    Args:
        n_envs: Number of sub-environments in the vector env; must be >= 1.
        use_async_envs: If True, use ``gym.vector.AsyncVectorEnv``; otherwise
            ``gym.vector.SyncVectorEnv``.
        render_mode: Passed to every ``SO101SimEnv``. Defaults to ``"human"``
            (the previously hard-coded value), which opens one MuJoCo viewer
            per sub-environment; pass ``None`` for headless runs.

    Returns:
        The vectorized environment wrapping ``n_envs`` ``SO101SimEnv``s.

    Raises:
        ValueError: if ``n_envs`` is less than 1.
    """
    if n_envs < 1:
        raise ValueError(f"n_envs must be >= 1, got {n_envs}")

    def _make_single_env():
        # Return a fresh SO101SimEnv instance for one vector slot.
        return SO101SimEnv(render_mode=render_mode)

    # Standard approach from the LeRobot docs: wrap in a vector env.
    env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
    return env_cls([_make_single_env for _ in range(n_envs)])
|