File size: 5,092 Bytes
a291cc7
58bc650
a291cc7
 
 
58bc650
32fd889
4891d33
4481348
32fd889
 
e154616
32fd889
08c99f6
 
32fd889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb3e844
32fd889
a91de1d
32fd889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb3e844
a91de1d
32fd889
f1e56e0
32fd889
 
 
 
cb3e844
32fd889
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb3e844
a91de1d
32fd889
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from pathlib import Path
import sys
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from huggingface_hub import snapshot_download
import yaml
snapshot_download("lerobot/unitree-g1-mujoco")

# Ensure sim module is importable
sys.path.insert(0, str(Path(__file__).parent))

from sim.simulator_factory import SimulatorFactory, init_channel


def make_env(n_envs=1, use_async_envs=False, **kwargs):
    """Create a UnitreeG1 simulation environment.
    
    Args:
        n_envs: Number of environments (currently only 1 supported)
        use_async_envs: Whether to use async envs (not implemented)
        **kwargs: Additional arguments passed to simulator
            - publish_images: bool, whether to publish camera images via ZMQ
            - camera_port: int, ZMQ port for camera images
            - cameras: list of camera names
    """
    repo_dir = Path(__file__).parent
    
    # Load config
    config_path = repo_dir / "config.yaml"
    with open(config_path) as f:
        config = yaml.safe_load(f)
    
    # Configure cameras if requested
    publish_images = kwargs.get("publish_images", True)
    camera_port = kwargs.get("camera_port", 5555)
    cameras = kwargs.get("cameras", ["head_camera"])
    
    enable_offscreen = publish_images or config.get("ENABLE_OFFSCREEN", False)
    camera_configs = {}
    if enable_offscreen:
        for cam_name in cameras:
            camera_configs[cam_name] = {"height": 480, "width": 640}
    
    # Initialize DDS channel
    init_channel(config=config)
    
    # Create simulator (runs in same process)
    simulator = SimulatorFactory.create_simulator(
        config=config,
        env_name="default",
        onscreen=config.get("ENABLE_ONSCREEN", True),
        offscreen=enable_offscreen,
        camera_configs=camera_configs,
    )
    
    class UnitreeG1Env(gym.Env):
        """Gymnasium environment wrapper for Unitree G1 MuJoCo simulation."""
        
        metadata = {"render_modes": ["human"]}

        def __init__(self, sim, cam_configs, pub_images, cam_port):
            super().__init__()
            self.simulator = sim
            self.sim_env = sim.sim_env
            self.step_count = 0
            self.camera_configs = cam_configs
            
            # Get timing from config
            self.sim_dt = config["SIMULATE_DT"]
            self.viewer_dt = config.get("VIEWER_DT", 0.02)
            self.image_dt = config.get("IMAGE_DT", 0.033333)
            
            # Start image publishing subprocess if requested
            if pub_images and len(cam_configs) > 0:
                sim.start_image_publish_subprocess(
                    start_method=config.get("MP_START_METHOD", "spawn"),
                    camera_port=cam_port
                )
                print(f"Camera images publishing on tcp://localhost:{cam_port}")
            
            # Define spaces
            num_joints = config.get("NUM_MOTORS", 29)
            self.action_space = spaces.Box(-np.pi, np.pi, shape=(num_joints,), dtype=np.float32)
            self.observation_space = spaces.Box(-np.inf, np.inf, shape=(num_joints * 3 + 10,), dtype=np.float32)

        def reset(self, seed=None, options=None):
            """Reset the simulation."""
            super().reset(seed=seed, options=options)
            self.simulator.reset()
            self.step_count = 0
            obs = self._get_obs()
            return obs, {}

        def step(self, action=None):
            """Execute one simulation step. Caller handles timing."""
            # Run physics step
            self.sim_env.sim_step()
            
            # Update viewer at viewer rate
            if self.step_count % int(self.viewer_dt / self.sim_dt) == 0:
                self.sim_env.update_viewer()
            
            # Update render caches at image rate (for ZMQ publishing)
            if self.step_count % int(self.image_dt / self.sim_dt) == 0:
                self.sim_env.update_render_caches()
            
            self.step_count += 1
            
            obs = self._get_obs()
            reward = 0.0
            terminated = False
            truncated = False
            info = {}
            return obs, reward, terminated, truncated, info
        
        def _get_obs(self):
            """Get current observation from simulation."""
            obs_dict = self.sim_env.prepare_obs()
            obs = np.concatenate([
                obs_dict.get("body_q", np.zeros(29)),
                obs_dict.get("body_dq", np.zeros(29)),
                obs_dict.get("body_tau_est", np.zeros(29)),
                obs_dict.get("floating_base_pose", np.zeros(7))[:4],
                obs_dict.get("floating_base_vel", np.zeros(6))[:3],
                obs_dict.get("floating_base_acc", np.zeros(3)),
            ]).astype(np.float32)
            return obs

        def close(self):
            """Close the simulation."""
            print("++++closing simulator ++++")
            self.simulator.close()

    return UnitreeG1Env(simulator, camera_configs, publish_images, camera_port)