Update sim/base_sim.py
Browse files- sim/base_sim.py +4 -139
sim/base_sim.py
CHANGED
|
@@ -2,7 +2,7 @@ import argparse
|
|
| 2 |
import pathlib
|
| 3 |
from pathlib import Path
|
| 4 |
import threading
|
| 5 |
-
from threading import
|
| 6 |
from typing import Dict
|
| 7 |
|
| 8 |
import mujoco
|
|
@@ -18,7 +18,7 @@ from unitree_sdk2py.core.channel import ChannelFactoryInitialize
|
|
| 18 |
import yaml
|
| 19 |
import os
|
| 20 |
from .image_publish_utils import ImagePublishProcess
|
| 21 |
-
from .metric_utils import check_contact
|
| 22 |
from .sim_utils import get_subtree_body_names
|
| 23 |
from .unitree_sdk2py_bridge import ElasticBand, UnitreeSdk2Bridge
|
| 24 |
|
|
@@ -60,8 +60,6 @@ class DefaultEnv:
|
|
| 60 |
if len(camera_configs) > 0:
|
| 61 |
print(f"✓ DefaultEnv initialized with {len(camera_configs)} camera(s): {list(camera_configs.keys())}")
|
| 62 |
|
| 63 |
-
# Thread safety lock
|
| 64 |
-
self.reward_lock = Lock()
|
| 65 |
|
| 66 |
# Unitree bridge will be initialized by the simulator
|
| 67 |
self.unitree_bridge = None
|
|
@@ -71,7 +69,6 @@ class DefaultEnv:
|
|
| 71 |
|
| 72 |
# Initialize scene (defined in subclasses)
|
| 73 |
self.init_scene()
|
| 74 |
-
self.last_reward = 0
|
| 75 |
|
| 76 |
# Setup offscreen rendering if needed
|
| 77 |
self.offscreen = offscreen
|
|
@@ -439,16 +436,6 @@ class DefaultEnv:
|
|
| 439 |
else:
|
| 440 |
self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_TRACKING
|
| 441 |
|
| 442 |
-
def update_reward(self):
|
| 443 |
-
"""Calculate reward. Should be implemented by subclasses."""
|
| 444 |
-
with self.reward_lock:
|
| 445 |
-
self.last_reward = 0
|
| 446 |
-
|
| 447 |
-
def get_reward(self):
|
| 448 |
-
"""Thread-safe way to get the last calculated reward."""
|
| 449 |
-
with self.reward_lock:
|
| 450 |
-
return self.last_reward
|
| 451 |
-
|
| 452 |
def set_unitree_bridge(self, unitree_bridge):
|
| 453 |
"""Set the unitree bridge from the simulator"""
|
| 454 |
self.unitree_bridge = unitree_bridge
|
|
@@ -509,114 +496,6 @@ class DefaultEnv:
|
|
| 509 |
mujoco.mj_resetData(self.mj_model, self.mj_data)
|
| 510 |
|
| 511 |
|
| 512 |
-
class CubeEnv(DefaultEnv):
|
| 513 |
-
"""Environment with a cube object for pick and place tasks"""
|
| 514 |
-
|
| 515 |
-
def __init__(
|
| 516 |
-
self,
|
| 517 |
-
config: Dict[str, any],
|
| 518 |
-
onscreen: bool = False,
|
| 519 |
-
offscreen: bool = False,
|
| 520 |
-
):
|
| 521 |
-
# Override the robot scene
|
| 522 |
-
config = config.copy() # Create a copy to avoid modifying the original
|
| 523 |
-
config["ROBOT_SCENE"] = "gr00t_wbc/control/robot_model/model_data/g1/pnp_cube_43dof.xml"
|
| 524 |
-
super().__init__(config, "cube", {}, onscreen, offscreen)
|
| 525 |
-
|
| 526 |
-
def update_reward(self):
|
| 527 |
-
"""Calculate reward based on gripper contact with cube and cube height"""
|
| 528 |
-
right_hand_body = [
|
| 529 |
-
"right_hand_thumb_2_link",
|
| 530 |
-
"right_hand_middle_1_link",
|
| 531 |
-
"right_hand_index_1_link",
|
| 532 |
-
]
|
| 533 |
-
gripper_cube_contact = check_contact(
|
| 534 |
-
self.mj_model, self.mj_data, right_hand_body, "cube_body"
|
| 535 |
-
)
|
| 536 |
-
cube_lifted = check_height(self.mj_model, self.mj_data, "cube", 0.85, 2.0)
|
| 537 |
-
|
| 538 |
-
with self.reward_lock:
|
| 539 |
-
self.last_reward = gripper_cube_contact & cube_lifted
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
class BoxEnv(DefaultEnv):
|
| 543 |
-
"""Environment with a box object for manipulation tasks"""
|
| 544 |
-
|
| 545 |
-
def __init__(
|
| 546 |
-
self,
|
| 547 |
-
config: Dict[str, any],
|
| 548 |
-
onscreen: bool = False,
|
| 549 |
-
offscreen: bool = False,
|
| 550 |
-
):
|
| 551 |
-
# Override the robot scene
|
| 552 |
-
config = config.copy() # Create a copy to avoid modifying the original
|
| 553 |
-
config["ROBOT_SCENE"] = "gr00t_wbc/control/robot_model/model_data/g1/lift_box_43dof.xml"
|
| 554 |
-
super().__init__(config, "box", {}, onscreen, offscreen)
|
| 555 |
-
|
| 556 |
-
def reward(self):
|
| 557 |
-
"""Calculate reward based on gripper contact with cube and cube height"""
|
| 558 |
-
left_hand_body = [
|
| 559 |
-
"left_hand_thumb_2_link",
|
| 560 |
-
"left_hand_middle_1_link",
|
| 561 |
-
"left_hand_index_1_link",
|
| 562 |
-
]
|
| 563 |
-
right_hand_body = [
|
| 564 |
-
"right_hand_thumb_2_link",
|
| 565 |
-
"right_hand_middle_1_link",
|
| 566 |
-
"right_hand_index_1_link",
|
| 567 |
-
]
|
| 568 |
-
gripper_box_contact = check_contact(self.mj_model, self.mj_data, left_hand_body, "box_body")
|
| 569 |
-
gripper_box_contact &= check_contact(
|
| 570 |
-
self.mj_model, self.mj_data, right_hand_body, "box_body"
|
| 571 |
-
)
|
| 572 |
-
box_lifted = check_height(self.mj_model, self.mj_data, "box", 0.92, 2.0)
|
| 573 |
-
|
| 574 |
-
print("gripper_box_contact: ", gripper_box_contact, "box_lifted: ", box_lifted)
|
| 575 |
-
|
| 576 |
-
with self.reward_lock:
|
| 577 |
-
self.last_reward = gripper_box_contact & box_lifted
|
| 578 |
-
return self.last_reward
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
class BottleEnv(DefaultEnv):
|
| 582 |
-
"""Environment with a cylinder object for manipulation tasks"""
|
| 583 |
-
|
| 584 |
-
def __init__(
|
| 585 |
-
self,
|
| 586 |
-
config: Dict[str, any],
|
| 587 |
-
onscreen: bool = False,
|
| 588 |
-
offscreen: bool = False,
|
| 589 |
-
):
|
| 590 |
-
# Override the robot scene
|
| 591 |
-
config = config.copy() # Create a copy to avoid modifying the original
|
| 592 |
-
config["ROBOT_SCENE"] = "gr00t_wbc/control/robot_model/model_data/g1/pnp_bottle_43dof.xml"
|
| 593 |
-
camera_configs = {
|
| 594 |
-
"egoview": {
|
| 595 |
-
"height": 400,
|
| 596 |
-
"width": 400,
|
| 597 |
-
},
|
| 598 |
-
}
|
| 599 |
-
super().__init__(
|
| 600 |
-
config, "cylinder", camera_configs, onscreen, offscreen
|
| 601 |
-
)
|
| 602 |
-
|
| 603 |
-
self.bottle_body = self.mj_model.body("bottle_body")
|
| 604 |
-
self.bottle_geom = self.mj_model.geom("bottle")
|
| 605 |
-
|
| 606 |
-
if self.viewer is not None:
|
| 607 |
-
self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
|
| 608 |
-
self.viewer.cam.fixedcamid = self.mj_model.camera("egoview").id
|
| 609 |
-
|
| 610 |
-
def update_reward(self):
|
| 611 |
-
"""Calculate reward based on gripper contact with cylinder and cylinder height"""
|
| 612 |
-
pass
|
| 613 |
-
|
| 614 |
-
def get_privileged_obs(self):
|
| 615 |
-
obs_pos = self.mj_data.xpos[self.bottle_body.id]
|
| 616 |
-
obs_quat = self.mj_data.xquat[self.bottle_body.id]
|
| 617 |
-
return {"bottle_pos": obs_pos, "bottle_quat": obs_quat}
|
| 618 |
-
|
| 619 |
-
|
| 620 |
class BaseSimulator:
|
| 621 |
"""Base simulator class that handles initialization and running of simulations"""
|
| 622 |
|
|
@@ -641,21 +520,11 @@ class BaseSimulator:
|
|
| 641 |
|
| 642 |
# Set update frequencies
|
| 643 |
self.sim_dt = self.config["SIMULATE_DT"]
|
| 644 |
-
self.reward_dt = self.config.get("REWARD_DT", 0.02)
|
| 645 |
self.image_dt = self.config.get("IMAGE_DT", 0.033333)
|
| 646 |
self.viewer_dt = self.config.get("VIEWER_DT", 0.02)
|
| 647 |
|
| 648 |
-
# Create the
|
| 649 |
-
|
| 650 |
-
self.sim_env = DefaultEnv(config, env_name, **kwargs)
|
| 651 |
-
elif env_name == "pnp_cube":
|
| 652 |
-
self.sim_env = CubeEnv(config, **kwargs)
|
| 653 |
-
elif env_name == "lift_box":
|
| 654 |
-
self.sim_env = BoxEnv(config, **kwargs)
|
| 655 |
-
elif env_name == "pnp_bottle":
|
| 656 |
-
self.sim_env = BottleEnv(config, **kwargs)
|
| 657 |
-
else:
|
| 658 |
-
raise ValueError(f"Invalid environment name: {env_name}")
|
| 659 |
|
| 660 |
# Initialize the DDS communication layer - should be safe to call multiple times
|
| 661 |
|
|
@@ -722,10 +591,6 @@ class BaseSimulator:
|
|
| 722 |
if sim_cnt % int(self.viewer_dt / self.sim_dt) == 0:
|
| 723 |
self.sim_env.update_viewer()
|
| 724 |
|
| 725 |
-
# Calculate reward at reward rate
|
| 726 |
-
if sim_cnt % int(self.reward_dt / self.sim_dt) == 0:
|
| 727 |
-
self.sim_env.update_reward()
|
| 728 |
-
|
| 729 |
# Update render caches at image rate
|
| 730 |
if sim_cnt % int(self.image_dt / self.sim_dt) == 0:
|
| 731 |
self.sim_env.update_render_caches()
|
|
|
|
| 2 |
import pathlib
|
| 3 |
from pathlib import Path
|
| 4 |
import threading
|
| 5 |
+
from threading import Thread
|
| 6 |
from typing import Dict
|
| 7 |
|
| 8 |
import mujoco
|
|
|
|
| 18 |
import yaml
|
| 19 |
import os
|
| 20 |
from .image_publish_utils import ImagePublishProcess
|
| 21 |
+
from .metric_utils import check_contact
|
| 22 |
from .sim_utils import get_subtree_body_names
|
| 23 |
from .unitree_sdk2py_bridge import ElasticBand, UnitreeSdk2Bridge
|
| 24 |
|
|
|
|
| 60 |
if len(camera_configs) > 0:
|
| 61 |
print(f"✓ DefaultEnv initialized with {len(camera_configs)} camera(s): {list(camera_configs.keys())}")
|
| 62 |
|
|
|
|
|
|
|
| 63 |
|
| 64 |
# Unitree bridge will be initialized by the simulator
|
| 65 |
self.unitree_bridge = None
|
|
|
|
| 69 |
|
| 70 |
# Initialize scene (defined in subclasses)
|
| 71 |
self.init_scene()
|
|
|
|
| 72 |
|
| 73 |
# Setup offscreen rendering if needed
|
| 74 |
self.offscreen = offscreen
|
|
|
|
| 436 |
else:
|
| 437 |
self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_TRACKING
|
| 438 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 439 |
def set_unitree_bridge(self, unitree_bridge):
|
| 440 |
"""Set the unitree bridge from the simulator"""
|
| 441 |
self.unitree_bridge = unitree_bridge
|
|
|
|
| 496 |
mujoco.mj_resetData(self.mj_model, self.mj_data)
|
| 497 |
|
| 498 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
class BaseSimulator:
|
| 500 |
"""Base simulator class that handles initialization and running of simulations"""
|
| 501 |
|
|
|
|
| 520 |
|
| 521 |
# Set update frequencies
|
| 522 |
self.sim_dt = self.config["SIMULATE_DT"]
|
|
|
|
| 523 |
self.image_dt = self.config.get("IMAGE_DT", 0.033333)
|
| 524 |
self.viewer_dt = self.config.get("VIEWER_DT", 0.02)
|
| 525 |
|
| 526 |
+
# Create the environment
|
| 527 |
+
self.sim_env = DefaultEnv(config, env_name, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 528 |
|
| 529 |
# Initialize the DDS communication layer - should be safe to call multiple times
|
| 530 |
|
|
|
|
| 591 |
if sim_cnt % int(self.viewer_dt / self.sim_dt) == 0:
|
| 592 |
self.sim_env.update_viewer()
|
| 593 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 594 |
# Update render caches at image rate
|
| 595 |
if sim_cnt % int(self.image_dt / self.sim_dt) == 0:
|
| 596 |
self.sim_env.update_render_caches()
|