File size: 12,038 Bytes
96da58e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
"""
Convenience script to tune a robot's joint positions in a mujoco environment.
Allows keyboard presses to move specific robot joints around in the viewer, and
then prints the current joint parameters upon an inputted command

RELEVANT KEY PRESSES:
    '1 - n' : Sets the active robot joint being tuned to this number. Maximum
        is n which is the number of robot joints
    't' : Toggle between robot arms being tuned (only applicable for multi-arm environments)
    'r' : Resets the active joint values to 0
    'UP_ARROW' : Increment the active robot joint position
    'DOWN_ARROW' : Decrement the active robot joint position
    'RIGHT_ARROW' : Increment the delta joint position change per keypress
    'LEFT_ARROW' : Decrement the delta joint position change per keypress

"""

import argparse

import numpy as np
from pynput.keyboard import Controller, Key, Listener

import robosuite
from robosuite.robots import SingleArm


class KeyboardHandler:
    """
    Keyboard-driven joint tuner for a robosuite environment.

    On construction a pynput ``Listener`` thread is started; its callbacks write
    joint positions directly into ``env.sim.data.qpos`` for the active robot
    (or, for bimanual robots, the active arm) and print the result.
    """

    # Characters that directly select a joint (joint numbers are 1-indexed)
    _JOINT_SELECT_CHARS = frozenset("123456789")

    def __init__(self, env, delta=0.05):
        """
        Store internal state and start listening to the keyboard.

        Args:
            env (MujocoEnv): Environment to use
            delta (float): initial joint tuning increment
        """
        self.env = env
        self.delta = delta
        self.num_robots = len(env.robots)
        self.active_robot_num = 0
        self.active_arm_joint = 1
        self.active_arm = "right"  # only relevant for bimanual robots
        # Local copy of the active arm's joint positions, edited by the key callbacks
        self.current_joints_pos = self._read_active_joint_positions()

        # make a thread to listen to keyboard and register our callback functions
        self.listener = Listener(on_press=self.on_press, on_release=self.on_release)

        # start listening
        self.listener.start()

    def on_press(self, key):
        """
        Key handler for key presses.

        Args:
            key (Key or KeyCode): pynput key object for the key that was pressed
        """
        if key == Key.up:
            # Increment the active joint
            self._update_joint_position(self.active_arm_joint, self.delta)
        elif key == Key.down:
            # Decrement the active joint
            self._update_joint_position(self.active_arm_joint, -self.delta)
        elif key == Key.right:
            # Increment the delta value (capped at 1.0)
            self.delta = min(1.0, self.delta + 0.005)
            print("Delta now = {:.3f}".format(self.delta))
        elif key == Key.left:
            # Decrement the delta value (floored at 0)
            self.delta = max(0, self.delta - 0.005)
            print("Delta now = {:.3f}".format(self.delta))
        else:
            # Special keys (shift, ctrl, ...) have no ``char`` attribute and some
            # printable KeyCodes carry char=None; both are simply ignored here.
            char = getattr(key, "char", None)
            if char == "0":
                # Notify user that joint indexes are 1-indexed
                print("Joint Indexes are 1-Indexed. Available joints are 1 - {}".format(self.num_joints))
            elif char in self._JOINT_SELECT_CHARS:
                joint = int(char)
                # Make sure range is valid; if so, update this specific joint
                if self._check_valid_joint(joint):
                    self.active_arm_joint = joint
                    print("New joint being tuned: {}".format(self.active_arm_joint))
            elif char == "t":
                # Toggle active arm
                self._toggle_arm()
            elif char == "r":
                # Reset active arm joint qpos to 0
                self.set_joint_positions(np.zeros(self.num_joints))

    def on_release(self, key):
        """
        Key handler for key releases.

        Args:
            key: [NOT USED]
        """
        pass

    def set_joint_positions(self, qpos):
        """
        Set the active arm's joint positions and write them to the simulator.

        Args:
            qpos (array-like): Joint positions to set; either one value per joint of
                the active arm, or a single value applied to every joint

        Raises:
            ValueError: if @qpos cannot be broadcast to the active arm's joint count
        """
        # Copy (never alias the caller's object, since later key presses mutate this
        # buffer in place) and broadcast so a single value fills every joint.
        self.current_joints_pos = np.array(
            np.broadcast_to(np.asarray(qpos, dtype=float), (self.num_joints,))
        )
        self._update_joint_position(1, 0)

    def _check_valid_joint(self, i):
        """
        Checks to make sure joint number request @i is within valid range

        Args:
            i (int): 1-indexed joint number to validate

        Returns:
            bool: True if index @i is valid, else prints out an error and returns False
        """
        if 1 <= i <= self.num_joints:
            return True
        print("Error: Requested joint {} is out of range; available joints are 1 - {}".format(i, self.num_joints))
        return False

    def _toggle_arm(self):
        """
        Toggle between arms in the environment to set as current active arm.

        Single-arm robots cycle through the environment's robots; a bimanual robot
        switches between its right and left arm.
        """
        if isinstance(self.active_robot, SingleArm):
            self.active_robot_num = (self.active_robot_num + 1) % self.num_robots
            robot = self.active_robot_num
        else:  # Bimanual case
            self.active_arm = "left" if self.active_arm == "right" else "right"
            robot = self.active_arm
        # Reset joint being controlled to 1
        self.active_arm_joint = 1
        # Re-sync the local buffer with the newly active arm; otherwise the next
        # up/down press would write the previous arm's joint values onto this arm.
        self.current_joints_pos = self._read_active_joint_positions()
        # Print out new robot to user
        print("New robot arm being tuned: {}".format(robot))

    def _update_joint_position(self, i, delta):
        """
        Updates specified joint position @i by value @delta from its current position
        Note: assumes @i is already within the valid joint range

        Args:
            i (int): 1-indexed joint number to update
            delta (float): Increment to alter specific joint by
        """
        self.current_joints_pos[i - 1] += delta
        self.env.sim.data.qpos[self._active_joint_indexes()] = self.current_joints_pos
        robot = self.active_robot_num if isinstance(self.active_robot, SingleArm) else self.active_arm
        # Print out current joint positions to user
        print("Robot {} joint qpos: {}".format(robot, self.current_joints_pos))

    def _active_joint_indexes(self):
        """
        Returns:
            qpos indexes of the active arm's joints. For a bimanual robot the joint
            indexes are taken as right arm first, then left arm.
        """
        indexes = self.active_robot._ref_joint_pos_indexes
        n = self.num_joints
        if isinstance(self.active_robot, SingleArm) or self.active_arm == "right":
            return indexes[:n]
        return indexes[n:]

    def _read_active_joint_positions(self):
        """
        Returns:
            np.array: float copy of the active arm's joint positions read from the simulator
        """
        return np.array(self.env.sim.data.qpos[self._active_joint_indexes()], dtype=float)

    @property
    def active_robot(self):
        """
        Returns:
            Robot: active robot arm currently being tuned
        """
        return self.env.robots[self.active_robot_num]

    @property
    def num_joints(self):
        """
        Returns:
            int: number of joints for the current arm
        """
        if isinstance(self.active_robot, SingleArm):
            return len(self.active_robot.torque_limits[0])
        else:  # Bimanual arm case: torque limits cover both arms
            return len(self.active_robot.torque_limits[0]) // 2


def print_command(char, info):
    """
    Print one row of the controls table.

    The command is left-justified to a width of 10 columns (longer commands are
    left as-is), followed by a tab and the description.

    Args:
        char (str): Command entered
        info (str): Any additional info to print
    """
    padded_command = char.ljust(10)
    print(f"{padded_command}\t{info}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", type=str, default="Lift")
    parser.add_argument("--robots", nargs="+", type=str, default="Panda", help="Which robot(s) to use in the env")
    parser.add_argument(
        "--init_qpos", nargs="+", type=float, default=0, help="Initial qpos to use. 0 defaults to all zeros"
    )

    args = parser.parse_args()

    print(
        "\nWelcome to the joint tuning script! You will be able to tune the robot\n"
        "arm joints in the specified environment by using your keyboard. The \n"
        "controls are printed below:"
    )

    print("")
    print_command("Keys", "Command")
    print_command("1-N", "Active Joint being tuned (N=number of joints for the active arm)")
    print_command("t", "Toggle between robot arms in the environment")
    print_command("r", "Reset active arm joints to all 0s")
    print_command("up/down", "incr/decrement the active joint angle")
    print_command("right/left", "incr/decrement the delta joint angle per up/down keypress")
    print("")

    # Setup printing options for numbers
    np.set_printoptions(formatter={"float": lambda x: "{0:0.3f}".format(x)})

    # Define the controller
    controller_config = robosuite.load_controller_config(default_controller="JOINT_POSITION")

    # make the environment
    env = robosuite.make(
        args.env,
        robots=args.robots,
        has_renderer=True,
        has_offscreen_renderer=False,
        ignore_done=True,
        use_camera_obs=False,
        control_freq=20,
        render_camera=None,
        controller_configs=controller_config,
        initialization_noise=None,
    )
    env.reset()

    # register callbacks to handle key presses in the viewer
    key_handler = KeyboardHandler(env=env)

    # Set initial state
    if type(args.init_qpos) == int and args.init_qpos == 0:
        # Default to all zeros
        pass
    else:
        key_handler.set_joint_positions(args.init_qpos)

    # just spin to let user interact with window
    while True:
        action = np.zeros(env.action_dim)
        obs, reward, done, _ = env.step(action)
        env.render()