nepyope commited on
Commit
541ddf1
·
verified ·
1 Parent(s): 1700fd3

Update sim/base_sim.py

Browse files
Files changed (1) hide show
  1. sim/base_sim.py +4 -139
sim/base_sim.py CHANGED
@@ -2,7 +2,7 @@ import argparse
2
  import pathlib
3
  from pathlib import Path
4
  import threading
5
- from threading import Lock, Thread
6
  from typing import Dict
7
 
8
  import mujoco
@@ -18,7 +18,7 @@ from unitree_sdk2py.core.channel import ChannelFactoryInitialize
18
  import yaml
19
  import os
20
  from .image_publish_utils import ImagePublishProcess
21
- from .metric_utils import check_contact, check_height
22
  from .sim_utils import get_subtree_body_names
23
  from .unitree_sdk2py_bridge import ElasticBand, UnitreeSdk2Bridge
24
 
@@ -60,8 +60,6 @@ class DefaultEnv:
60
  if len(camera_configs) > 0:
61
  print(f"✓ DefaultEnv initialized with {len(camera_configs)} camera(s): {list(camera_configs.keys())}")
62
 
63
- # Thread safety lock
64
- self.reward_lock = Lock()
65
 
66
  # Unitree bridge will be initialized by the simulator
67
  self.unitree_bridge = None
@@ -71,7 +69,6 @@ class DefaultEnv:
71
 
72
  # Initialize scene (defined in subclasses)
73
  self.init_scene()
74
- self.last_reward = 0
75
 
76
  # Setup offscreen rendering if needed
77
  self.offscreen = offscreen
@@ -439,16 +436,6 @@ class DefaultEnv:
439
  else:
440
  self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_TRACKING
441
 
442
- def update_reward(self):
443
- """Calculate reward. Should be implemented by subclasses."""
444
- with self.reward_lock:
445
- self.last_reward = 0
446
-
447
- def get_reward(self):
448
- """Thread-safe way to get the last calculated reward."""
449
- with self.reward_lock:
450
- return self.last_reward
451
-
452
  def set_unitree_bridge(self, unitree_bridge):
453
  """Set the unitree bridge from the simulator"""
454
  self.unitree_bridge = unitree_bridge
@@ -509,114 +496,6 @@ class DefaultEnv:
509
  mujoco.mj_resetData(self.mj_model, self.mj_data)
510
 
511
 
512
- class CubeEnv(DefaultEnv):
513
- """Environment with a cube object for pick and place tasks"""
514
-
515
- def __init__(
516
- self,
517
- config: Dict[str, any],
518
- onscreen: bool = False,
519
- offscreen: bool = False,
520
- ):
521
- # Override the robot scene
522
- config = config.copy() # Create a copy to avoid modifying the original
523
- config["ROBOT_SCENE"] = "gr00t_wbc/control/robot_model/model_data/g1/pnp_cube_43dof.xml"
524
- super().__init__(config, "cube", {}, onscreen, offscreen)
525
-
526
- def update_reward(self):
527
- """Calculate reward based on gripper contact with cube and cube height"""
528
- right_hand_body = [
529
- "right_hand_thumb_2_link",
530
- "right_hand_middle_1_link",
531
- "right_hand_index_1_link",
532
- ]
533
- gripper_cube_contact = check_contact(
534
- self.mj_model, self.mj_data, right_hand_body, "cube_body"
535
- )
536
- cube_lifted = check_height(self.mj_model, self.mj_data, "cube", 0.85, 2.0)
537
-
538
- with self.reward_lock:
539
- self.last_reward = gripper_cube_contact & cube_lifted
540
-
541
-
542
- class BoxEnv(DefaultEnv):
543
- """Environment with a box object for manipulation tasks"""
544
-
545
- def __init__(
546
- self,
547
- config: Dict[str, any],
548
- onscreen: bool = False,
549
- offscreen: bool = False,
550
- ):
551
- # Override the robot scene
552
- config = config.copy() # Create a copy to avoid modifying the original
553
- config["ROBOT_SCENE"] = "gr00t_wbc/control/robot_model/model_data/g1/lift_box_43dof.xml"
554
- super().__init__(config, "box", {}, onscreen, offscreen)
555
-
556
- def reward(self):
557
- """Calculate reward based on gripper contact with cube and cube height"""
558
- left_hand_body = [
559
- "left_hand_thumb_2_link",
560
- "left_hand_middle_1_link",
561
- "left_hand_index_1_link",
562
- ]
563
- right_hand_body = [
564
- "right_hand_thumb_2_link",
565
- "right_hand_middle_1_link",
566
- "right_hand_index_1_link",
567
- ]
568
- gripper_box_contact = check_contact(self.mj_model, self.mj_data, left_hand_body, "box_body")
569
- gripper_box_contact &= check_contact(
570
- self.mj_model, self.mj_data, right_hand_body, "box_body"
571
- )
572
- box_lifted = check_height(self.mj_model, self.mj_data, "box", 0.92, 2.0)
573
-
574
- print("gripper_box_contact: ", gripper_box_contact, "box_lifted: ", box_lifted)
575
-
576
- with self.reward_lock:
577
- self.last_reward = gripper_box_contact & box_lifted
578
- return self.last_reward
579
-
580
-
581
- class BottleEnv(DefaultEnv):
582
- """Environment with a cylinder object for manipulation tasks"""
583
-
584
- def __init__(
585
- self,
586
- config: Dict[str, any],
587
- onscreen: bool = False,
588
- offscreen: bool = False,
589
- ):
590
- # Override the robot scene
591
- config = config.copy() # Create a copy to avoid modifying the original
592
- config["ROBOT_SCENE"] = "gr00t_wbc/control/robot_model/model_data/g1/pnp_bottle_43dof.xml"
593
- camera_configs = {
594
- "egoview": {
595
- "height": 400,
596
- "width": 400,
597
- },
598
- }
599
- super().__init__(
600
- config, "cylinder", camera_configs, onscreen, offscreen
601
- )
602
-
603
- self.bottle_body = self.mj_model.body("bottle_body")
604
- self.bottle_geom = self.mj_model.geom("bottle")
605
-
606
- if self.viewer is not None:
607
- self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
608
- self.viewer.cam.fixedcamid = self.mj_model.camera("egoview").id
609
-
610
- def update_reward(self):
611
- """Calculate reward based on gripper contact with cylinder and cylinder height"""
612
- pass
613
-
614
- def get_privileged_obs(self):
615
- obs_pos = self.mj_data.xpos[self.bottle_body.id]
616
- obs_quat = self.mj_data.xquat[self.bottle_body.id]
617
- return {"bottle_pos": obs_pos, "bottle_quat": obs_quat}
618
-
619
-
620
  class BaseSimulator:
621
  """Base simulator class that handles initialization and running of simulations"""
622
 
@@ -641,21 +520,11 @@ class BaseSimulator:
641
 
642
  # Set update frequencies
643
  self.sim_dt = self.config["SIMULATE_DT"]
644
- self.reward_dt = self.config.get("REWARD_DT", 0.02)
645
  self.image_dt = self.config.get("IMAGE_DT", 0.033333)
646
  self.viewer_dt = self.config.get("VIEWER_DT", 0.02)
647
 
648
- # Create the appropriate environment based on name
649
- if env_name == "default":
650
- self.sim_env = DefaultEnv(config, env_name, **kwargs)
651
- elif env_name == "pnp_cube":
652
- self.sim_env = CubeEnv(config, **kwargs)
653
- elif env_name == "lift_box":
654
- self.sim_env = BoxEnv(config, **kwargs)
655
- elif env_name == "pnp_bottle":
656
- self.sim_env = BottleEnv(config, **kwargs)
657
- else:
658
- raise ValueError(f"Invalid environment name: {env_name}")
659
 
660
  # Initialize the DDS communication layer - should be safe to call multiple times
661
 
@@ -722,10 +591,6 @@ class BaseSimulator:
722
  if sim_cnt % int(self.viewer_dt / self.sim_dt) == 0:
723
  self.sim_env.update_viewer()
724
 
725
- # Calculate reward at reward rate
726
- if sim_cnt % int(self.reward_dt / self.sim_dt) == 0:
727
- self.sim_env.update_reward()
728
-
729
  # Update render caches at image rate
730
  if sim_cnt % int(self.image_dt / self.sim_dt) == 0:
731
  self.sim_env.update_render_caches()
 
2
  import pathlib
3
  from pathlib import Path
4
  import threading
5
+ from threading import Thread
6
  from typing import Dict
7
 
8
  import mujoco
 
18
  import yaml
19
  import os
20
  from .image_publish_utils import ImagePublishProcess
21
+ from .metric_utils import check_contact
22
  from .sim_utils import get_subtree_body_names
23
  from .unitree_sdk2py_bridge import ElasticBand, UnitreeSdk2Bridge
24
 
 
60
  if len(camera_configs) > 0:
61
  print(f"✓ DefaultEnv initialized with {len(camera_configs)} camera(s): {list(camera_configs.keys())}")
62
 
 
 
63
 
64
  # Unitree bridge will be initialized by the simulator
65
  self.unitree_bridge = None
 
69
 
70
  # Initialize scene (defined in subclasses)
71
  self.init_scene()
 
72
 
73
  # Setup offscreen rendering if needed
74
  self.offscreen = offscreen
 
436
  else:
437
  self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_TRACKING
438
 
 
 
 
 
 
 
 
 
 
 
439
  def set_unitree_bridge(self, unitree_bridge):
440
  """Set the unitree bridge from the simulator"""
441
  self.unitree_bridge = unitree_bridge
 
496
  mujoco.mj_resetData(self.mj_model, self.mj_data)
497
 
498
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
  class BaseSimulator:
500
  """Base simulator class that handles initialization and running of simulations"""
501
 
 
520
 
521
  # Set update frequencies
522
  self.sim_dt = self.config["SIMULATE_DT"]
 
523
  self.image_dt = self.config.get("IMAGE_DT", 0.033333)
524
  self.viewer_dt = self.config.get("VIEWER_DT", 0.02)
525
 
526
+ # Create the environment
527
+ self.sim_env = DefaultEnv(config, env_name, **kwargs)
 
 
 
 
 
 
 
 
 
528
 
529
  # Initialize the DDS communication layer - should be safe to call multiple times
530
 
 
591
  if sim_cnt % int(self.viewer_dt / self.sim_dt) == 0:
592
  self.sim_env.update_viewer()
593
 
 
 
 
 
594
  # Update render caches at image rate
595
  if sim_cnt % int(self.image_dt / self.sim_dt) == 0:
596
  self.sim_env.update_render_caches()