chenhaojun commited on
Commit
885b6c5
·
verified ·
1 Parent(s): 3ad2bdb

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. VRL3/LICENSE +21 -0
  2. VRL3/src/cfgs_adroit/task/door.yaml +4 -0
  3. VRL3/src/cfgs_adroit/task/relocate.yaml +4 -0
  4. VRL3/src/logger.py +182 -0
  5. VRL3/src/replay_buffer.py +222 -0
  6. VRL3/src/rrl_local/__pycache__/rrl_multicam.cpython-38.pyc +0 -0
  7. VRL3/src/stage1_models.py +318 -0
  8. VRL3/src/train_stage1.py +493 -0
  9. VRL3/src/utils.py +149 -0
  10. VRL3/src/vrl3_agent.py +632 -0
  11. gym-0.21.0/.github/stale.yml +62 -0
  12. gym-0.21.0/CONTRIBUTING.md +18 -0
  13. gym-0.21.0/README.md +57 -0
  14. gym-0.21.0/docs/toy_text/blackjack.md +60 -0
  15. gym-0.21.0/docs/toy_text/taxi.md +92 -0
  16. gym-0.21.0/scripts/generate_json.py +119 -0
  17. gym-0.21.0/setup.py +76 -0
  18. mujoco-py-2.1.2.14/.gitignore +55 -0
  19. mujoco-py-2.1.2.14/docs/_static/.gitkeep +0 -0
  20. mujoco-py-2.1.2.14/docs/build/doctrees/reference.doctree +0 -0
  21. mujoco-py-2.1.2.14/mujoco_py.egg-info/SOURCES.txt +67 -0
  22. mujoco-py-2.1.2.14/mujoco_py/__pycache__/builder.cpython-38.pyc +0 -0
  23. mujoco-py-2.1.2.14/mujoco_py/__pycache__/mjviewer.cpython-38.pyc +0 -0
  24. mujoco-py-2.1.2.14/mujoco_py/builder.py +518 -0
  25. mujoco-py-2.1.2.14/mujoco_py/gl/eglplatform.h +125 -0
  26. mujoco-py-2.1.2.14/mujoco_py/gl/glshim.h +30 -0
  27. mujoco-py-2.1.2.14/mujoco_py/gl/khrplatform.h +285 -0
  28. mujoco-py-2.1.2.14/mujoco_py/gl/osmesashim.c +75 -0
  29. mujoco-py-2.1.2.14/mujoco_py/mjbatchrenderer.pyx +301 -0
  30. mujoco-py-2.1.2.14/mujoco_py/mjrendercontext.pyx +329 -0
  31. mujoco-py-2.1.2.14/mujoco_py/mjrenderpool.py +241 -0
  32. mujoco-py-2.1.2.14/mujoco_py/mjsim.pyx +439 -0
  33. mujoco-py-2.1.2.14/mujoco_py/pxd/__init__.py +0 -0
  34. mujoco-py-2.1.2.14/mujoco_py/pxd/mjdata.pxd +312 -0
  35. mujoco-py-2.1.2.14/mujoco_py/pxd/mjmodel.pxd +834 -0
  36. mujoco-py-2.1.2.14/mujoco_py/pxd/mjrender.pxd +115 -0
  37. mujoco-py-2.1.2.14/mujoco_py/pxd/mujoco.pxd +1083 -0
  38. mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_materials.premod.png +0 -0
  39. mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_multiple_sims.loop0_1.png +0 -0
  40. mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_multiple_sims.loop1_0.png +0 -0
  41. mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_multiple_sims.loop1_1.png +0 -0
  42. mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_multiple_sims.loop2_1.png +0 -0
  43. mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_render_pool.mp_test_states.2.png +0 -0
  44. mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_render_pool.mp_test_states.3.png +0 -0
  45. mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_rendering.camera1.png +0 -0
  46. mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_rendering.freecam.depth-darwin.png +0 -0
  47. mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_rendering.freecam.depth.png +0 -0
  48. mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_resetting.loop1_1.png +0 -0
  49. mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_textures.rgb.png +0 -0
  50. mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_textures.variety.png +0 -0
VRL3/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) Microsoft Corporation.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
VRL3/src/cfgs_adroit/task/door.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ num_train_frames: 4100000
2
+ task_name: door-v0
3
+ agent:
4
+ encoder_lr_scale: 1
VRL3/src/cfgs_adroit/task/relocate.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ num_train_frames: 4100000
2
+ task_name: relocate-v0
3
+ agent:
4
+ encoder_lr_scale: 0.01
VRL3/src/logger.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import csv
6
+ import datetime
7
+ from collections import defaultdict
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torchvision
12
+ from termcolor import colored
13
+ from torch.utils.tensorboard import SummaryWriter
14
+
15
+ COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
16
+ ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
17
+ ('episode_reward', 'R', 'float'),
18
+ ('buffer_size', 'BS', 'int'), ('fps', 'FPS', 'float'),
19
+ ('total_time', 'T', 'time')]
20
+
21
+ COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
22
+ ('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
23
+ ('episode_reward', 'R', 'float'),
24
+ ('total_time', 'T', 'time')]
25
+
26
+
27
class AverageMeter(object):
    """Running average: accumulate values with update(), read the mean with value()."""

    def __init__(self):
        self._total = 0
        self._n = 0

    def update(self, value, n=1):
        # `n` lets one call stand in for several samples (e.g. a summed batch).
        self._total = self._total + value
        self._n = self._n + n

    def value(self):
        # Guard against division by zero when nothing has been logged yet.
        denom = self._n if self._n > 0 else 1
        return self._total / denom
38
+
39
+
40
class MetersGroup(object):
    """A named group of AverageMeters that can be dumped to CSV and the console.

    One instance is used per phase ('train' / 'eval'); logged keys carry the
    phase prefix, which is stripped before the values are written out.
    """

    def __init__(self, csv_file_name, formating):
        # `csv_file_name` is a pathlib.Path; `formating` is a list of
        # (key, display_key, type) tuples such as COMMON_TRAIN_FORMAT.
        self._csv_file_name = csv_file_name
        self._formating = formating
        self._meters = defaultdict(AverageMeter)
        self._csv_file = None
        self._csv_writer = None

    def log(self, key, value, n=1):
        self._meters[key].update(value, n)

    def _prime_meters(self):
        # Collapse meters to plain averages, stripping the 'train'/'eval'
        # prefix and flattening '/' separators so keys are valid CSV columns.
        data = dict()
        for key, meter in self._meters.items():
            if key.startswith('train'):
                key = key[len('train') + 1:]
            else:
                key = key[len('eval') + 1:]
            key = key.replace('/', '_')
            data[key] = meter.value()
        return data

    def _remove_old_entries(self, data):
        # When resuming a run, drop rows at or past the current episode so the
        # CSV does not end up with duplicate/stale entries.
        rows = []
        with self._csv_file_name.open('r') as f:
            reader = csv.DictReader(f)
            for row in reader:
                if float(row['episode']) >= data['episode']:
                    break
                rows.append(row)
        with self._csv_file_name.open('w') as f:
            writer = csv.DictWriter(f,
                                    fieldnames=sorted(data.keys()),
                                    restval=0.0)
            writer.writeheader()
            for row in rows:
                writer.writerow(row)

    def _dump_to_csv(self, data):
        # Lazily open the CSV writer; only a fresh file gets a header row.
        if self._csv_writer is None:
            should_write_header = True
            if self._csv_file_name.exists():
                self._remove_old_entries(data)
                should_write_header = False

            self._csv_file = self._csv_file_name.open('a')
            self._csv_writer = csv.DictWriter(self._csv_file,
                                              fieldnames=sorted(data.keys()),
                                              restval=0.0)
            if should_write_header:
                self._csv_writer.writeheader()

        self._csv_writer.writerow(data)
        self._csv_file.flush()

    def _format(self, key, value, ty):
        """Render a single key/value pair according to its declared type.

        Raises ValueError for an unknown type tag.
        """
        if ty == 'int':
            value = int(value)
            return f'{key}: {value}'
        elif ty == 'float':
            return f'{key}: {value:.04f}'
        elif ty == 'time':
            value = str(datetime.timedelta(seconds=int(value)))
            return f'{key}: {value}'
        else:
            # Fix: the original did `raise f'...'` — raising a plain string is
            # itself a TypeError in Python 3. Raise a real exception instead.
            raise ValueError(f'invalid format type: {ty}')

    def _dump_to_console(self, data, prefix):
        prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
        pieces = [f'| {prefix: <14}']
        for key, disp_key, ty in self._formating:
            value = data.get(key, 0)
            pieces.append(self._format(disp_key, value, ty))
        print(' | '.join(pieces))

    def dump(self, step, prefix):
        """Write the current averages to CSV and console, then reset the meters."""
        if len(self._meters) == 0:
            return
        data = self._prime_meters()
        data['frame'] = step
        self._dump_to_csv(data)
        self._dump_to_console(data, prefix)
        self._meters.clear()
123
+
124
+
125
class Logger(object):
    """Logging facade: routes metrics to train/eval MetersGroups and,
    optionally, TensorBoard.

    `stage2_logger` selects distinct CSV file names so stage-2 training does
    not clobber the default logs written to the same directory.
    """

    def __init__(self, log_dir, use_tb, stage2_logger=False):
        self._log_dir = log_dir
        if not stage2_logger:
            self._train_mg = MetersGroup(log_dir / 'train.csv',
                                         formating=COMMON_TRAIN_FORMAT)
            self._eval_mg = MetersGroup(log_dir / 'eval.csv',
                                        formating=COMMON_EVAL_FORMAT)
        else:
            self._train_mg = MetersGroup(log_dir / 'train_stage2.csv',
                                         formating=COMMON_TRAIN_FORMAT)
            self._eval_mg = MetersGroup(log_dir / 'eval_stage2.csv',
                                        formating=COMMON_EVAL_FORMAT)
        if use_tb:
            self._sw = SummaryWriter(str(log_dir / 'tb'))
        else:
            self._sw = None

    def _try_sw_log(self, key, value, step):
        # No-op when TensorBoard logging is disabled.
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def log(self, key, value, step):
        """Log a single scalar; `key` must be prefixed 'train...' or 'eval...'."""
        assert key.startswith('train') or key.startswith('eval')
        # Fix: use isinstance instead of an exact type comparison so tensor
        # subclasses (e.g. nn.Parameter) are also unwrapped to Python scalars.
        if isinstance(value, torch.Tensor):
            value = value.item()
        self._try_sw_log(key, value, step)
        mg = self._train_mg if key.startswith('train') else self._eval_mg
        mg.log(key, value)

    def log_metrics(self, metrics, step, ty):
        """Log a dict of metrics under the `ty` ('train'/'eval') prefix."""
        for key, value in metrics.items():
            self.log(f'{ty}/{key}', value, step)

    def dump(self, step, ty=None):
        # ty=None dumps both groups.
        if ty is None or ty == 'eval':
            self._eval_mg.dump(step, 'eval')
        if ty is None or ty == 'train':
            self._train_mg.dump(step, 'train')

    def log_and_dump_ctx(self, step, ty):
        """Context manager that logs via __call__ and dumps on exit."""
        return LogAndDumpCtx(self, step, ty)
167
+
168
+
169
class LogAndDumpCtx:
    """With-statement helper: log scalars under a fixed prefix, dump on exit."""

    def __init__(self, logger, step, ty):
        self._logger = logger
        self._step = step
        self._ty = ty

    def __enter__(self):
        return self

    def __call__(self, key, value):
        # Prefix every key with the phase ('train'/'eval') before delegating.
        full_key = f'{self._ty}/{key}'
        self._logger.log(full_key, value, self._step)

    def __exit__(self, *args):
        # Flush the accumulated metrics for this phase when the block ends.
        self._logger.dump(self._step, self._ty)
VRL3/src/replay_buffer.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import datetime
6
+ import io
7
+ import random
8
+ import traceback
9
+ from collections import defaultdict
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.utils.data import IterableDataset
15
+
16
+
17
def episode_len(episode):
    """Number of real transitions in an episode dict.

    Every array in the episode has a leading dummy first transition,
    hence the -1.
    """
    any_field = next(iter(episode.values()))
    return any_field.shape[0] - 1
20
+
21
+
22
def save_episode(episode, fn):
    """Serialize an episode dict of arrays to `fn` as a compressed .npz file.

    The archive is built in memory first so the on-disk write is a single
    atomic-ish pass.
    """
    buffer = io.BytesIO()
    np.savez_compressed(buffer, **episode)
    payload = buffer.getvalue()
    with fn.open('wb') as f:
        f.write(payload)
28
+
29
+
30
def load_episode(fn):
    """Load an episode saved by `save_episode`, returning a plain dict of arrays."""
    with fn.open('rb') as f:
        archive = np.load(f)
        # Materialize into a regular dict so the data outlives the file handle.
        return {key: archive[key] for key in archive.keys()}
35
+
36
+
37
class ReplayBufferStorage:
    """Writes environment transitions to disk as per-episode .npz files.

    Transitions are buffered in memory until a terminal time step arrives,
    then the episode is stacked and saved to `replay_dir` under the name
    `<timestamp>_<episode_idx>_<episode_len>.npz`.
    """

    def __init__(self, data_specs, replay_dir):
        # `data_specs` are dm_env-style array specs (name/shape/dtype) —
        # assumed from usage below; confirm against the caller.
        self._data_specs = data_specs
        self._replay_dir = replay_dir
        replay_dir.mkdir(exist_ok=True)
        self._current_episode = defaultdict(list)
        # Recover counters from episodes left on disk by a previous run.
        self._preload()

    def __len__(self):
        # Total number of stored transitions across all episodes.
        return self._num_transitions

    def add(self, time_step):
        """Append one time step; flush the whole episode to disk when it ends."""
        for spec in self._data_specs:
            value = time_step[spec.name]
            # Broadcast scalar values (e.g. reward) to the declared spec shape.
            if np.isscalar(value):
                value = np.full(spec.shape, value, spec.dtype)
            # print(spec.name, spec.shape, spec.dtype, value.shape, value.dtype)
            assert spec.shape == value.shape and spec.dtype == value.dtype
            self._current_episode[spec.name].append(value)
        if time_step.last():
            # Stack the buffered per-step values into one array per field.
            episode = dict()
            for spec in self._data_specs:
                value = self._current_episode[spec.name]
                episode[spec.name] = np.array(value, spec.dtype)
            self._current_episode = defaultdict(list)
            self._store_episode(episode)

    def _preload(self):
        # File names encode the episode length as the last '_'-separated field.
        self._num_episodes = 0
        self._num_transitions = 0
        for fn in self._replay_dir.glob('*.npz'):
            _, _, eps_len = fn.stem.split('_')
            self._num_episodes += 1
            self._num_transitions += int(eps_len)

    def _store_episode(self, episode):
        # Persist one finished episode and advance the counters.
        eps_idx = self._num_episodes
        eps_len = episode_len(episode)
        self._num_episodes += 1
        self._num_transitions += eps_len
        ts = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
        eps_fn = f'{ts}_{eps_idx}_{eps_len}.npz'
        save_episode(episode, self._replay_dir / eps_fn)
80
+
81
+
82
class ReplayBuffer(IterableDataset):
    """Streams n-step transition tuples from on-disk episode files.

    Designed to be wrapped in a torch DataLoader: each worker lazily fetches
    its own shard of episode files (episode index modulo worker count) every
    `fetch_every` samples, evicting the oldest episodes once `max_size`
    transitions are held in memory.
    """

    def __init__(self, replay_dir, max_size, num_workers, nstep, discount,
                 fetch_every, save_snapshot, is_adroit=False, return_next_action=False):
        self._replay_dir = replay_dir
        self._size = 0
        self._max_size = max_size
        self._num_workers = max(1, num_workers)
        self._episode_fns = []
        self._episodes = dict()
        self._nstep = nstep
        self._discount = discount
        self._fetch_every = fetch_every
        # Start at fetch_every so the very first sample triggers a fetch.
        self._samples_since_last_fetch = fetch_every
        self._save_snapshot = save_snapshot
        self._is_adroit = is_adroit
        self._return_next_action = return_next_action

    def set_nstep(self, nstep):
        self._nstep = nstep

    def _sample_episode(self):
        # Uniform over currently-loaded episodes (not over transitions).
        eps_fn = random.choice(self._episode_fns)
        return self._episodes[eps_fn]

    def _store_episode(self, eps_fn):
        """Load one episode file into memory, evicting old ones if needed.

        Returns False when the file cannot be read (e.g. it is still being
        written by the storage process); the caller then stops fetching.
        """
        try:
            episode = load_episode(eps_fn)
        except Exception:
            # Fix: was a bare `except:`, which would also swallow
            # KeyboardInterrupt/SystemExit. A partially-written file is an
            # expected, non-fatal condition here.
            return False
        eps_len = episode_len(episode)
        # Evict oldest episodes until the new one fits in the budget.
        while eps_len + self._size > self._max_size:
            early_eps_fn = self._episode_fns.pop(0)
            early_eps = self._episodes.pop(early_eps_fn)
            self._size -= episode_len(early_eps)
            early_eps_fn.unlink(missing_ok=True)
        self._episode_fns.append(eps_fn)
        self._episode_fns.sort()
        self._episodes[eps_fn] = episode
        self._size += eps_len

        if not self._save_snapshot:
            eps_fn.unlink(missing_ok=True)
        return True

    def _try_fetch(self):
        if self._samples_since_last_fetch < self._fetch_every:
            return
        self._samples_since_last_fetch = 0
        # Fix: get_worker_info() returns None in the main process — handle
        # that documented case explicitly instead of via a bare `except:`.
        worker_info = torch.utils.data.get_worker_info()
        worker_id = worker_info.id if worker_info is not None else 0
        # Newest episodes first.
        eps_fns = sorted(self._replay_dir.glob('*.npz'), reverse=True)
        fetched_size = 0
        for eps_fn in eps_fns:
            eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]]
            if eps_idx % self._num_workers != worker_id:
                continue  # this episode belongs to another worker
            if eps_fn in self._episodes.keys():
                break  # everything older is already in memory
            if fetched_size + eps_len > self._max_size:
                break
            fetched_size += eps_len
            if not self._store_episode(eps_fn):
                break

    def _sample(self):
        try:
            self._try_fetch()
        except Exception:
            # Fix: narrowed from bare `except:`. Fetching is best-effort, but
            # the traceback is still printed for debugging.
            traceback.print_exc()
        self._samples_since_last_fetch += 1
        episode = self._sample_episode()
        # add +1 for the first dummy transition
        idx = np.random.randint(0, episode_len(episode) - self._nstep + 1) + 1
        obs = episode['observation'][idx - 1]
        action = episode['action'][idx]
        next_obs = episode['observation'][idx + self._nstep - 1]
        # Accumulate the n-step discounted reward and the combined discount.
        reward = np.zeros_like(episode['reward'][idx])
        discount = np.ones_like(episode['discount'][idx])
        for i in range(self._nstep):
            step_reward = episode['reward'][idx + i]
            reward += discount * step_reward
            discount *= episode['discount'][idx + i] * self._discount

        if self._return_next_action:
            next_action = episode['action'][idx + self._nstep - 1]

        if not self._is_adroit:
            if self._return_next_action:
                return (obs, action, reward, discount, next_obs, next_action)
            else:
                return (obs, action, reward, discount, next_obs)
        else:
            # Adroit tasks additionally carry proprioceptive sensor readings.
            obs_sensor = episode['observation_sensor'][idx - 1]
            obs_sensor_next = episode['observation_sensor'][idx + self._nstep - 1]
            if self._return_next_action:
                return (obs, action, reward, discount, next_obs, obs_sensor, obs_sensor_next, next_action)
            else:
                return (obs, action, reward, discount, next_obs, obs_sensor, obs_sensor_next)

    def __iter__(self):
        while True:
            yield self._sample()
186
+
187
+
188
+ def _worker_init_fn(worker_id):
189
+ seed = np.random.get_state()[1][0] + worker_id
190
+ np.random.seed(seed)
191
+ random.seed(seed)
192
+
193
+
194
def make_replay_loader(replay_dir, max_size, batch_size, num_workers,
                       save_snapshot, nstep, discount, fetch_every=1000, is_adroit=False, return_next_action=False):
    """Build a DataLoader over a ReplayBuffer, splitting capacity per worker."""
    # Each worker holds its own shard, so divide the memory budget among them.
    per_worker_capacity = max_size // max(1, num_workers)

    dataset = ReplayBuffer(replay_dir,
                           per_worker_capacity,
                           num_workers,
                           nstep,
                           discount,
                           fetch_every=fetch_every,
                           save_snapshot=save_snapshot,
                           is_adroit=is_adroit,
                           return_next_action=return_next_action)

    return torch.utils.data.DataLoader(dataset,
                                       batch_size=batch_size,
                                       num_workers=num_workers,
                                       pin_memory=True,
                                       worker_init_fn=_worker_init_fn)
214
+
215
def reinit_data_loader(data_loader, batch_size, num_workers):
    """Rebuild a DataLoader around the same dataset with a new batch size."""
    return torch.utils.data.DataLoader(data_loader.dataset,
                                       batch_size=batch_size,
                                       num_workers=num_workers,
                                       pin_memory=True,
                                       worker_init_fn=_worker_init_fn)
VRL3/src/rrl_local/__pycache__/rrl_multicam.cpython-38.pyc ADDED
Binary file (10.4 kB). View file
 
VRL3/src/stage1_models.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ from numpy import identity
5
+ import torch.nn as nn
6
+ import numpy as np
7
+
8
+ """
9
+ most code here are modified from the TORCHVISION.MODELS.RESNET
10
+ """
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+
15
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution; padding tracks dilation so spatial size is preserved at stride 1."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        dilation=dilation,
        bias=False,
    )
19
+
20
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 (pointwise) convolution, typically used for channel projection."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
23
+
24
class BasicBlock(nn.Module):
    """Standard two-conv residual block (ResNet v1), from torchvision.

    Computes relu(bn2(conv2(relu(bn1(conv1(x))))) + identity), where identity
    is `x` itself or `downsample(x)` when shape/channels change.
    """
    # Output channels = planes * expansion.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Keep the residual input; it may be projected by `downsample` below.
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
62
+
63
class OneLayerBlock(nn.Module):
    """Shallow stand-in for BasicBlock: a single conv-BN-ReLU, no residual path.

    Used by ResNet84 when a stage is configured with zero blocks, allowing a
    smaller network while keeping the stage interface identical.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(OneLayerBlock, self).__init__()
        batch_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # A stride != 1 downsamples here, mirroring BasicBlock's first conv.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = batch_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

    def forward(self, x):
        # Single conv -> norm -> activation; no skip connection.
        return self.relu(self.bn1(self.conv1(x)))
87
+
88
class Bottleneck(nn.Module):
    """Three-conv bottleneck residual block (1x1 reduce, 3x3, 1x1 expand).

    From torchvision; retained so ResNet84 can be configured with bottleneck
    stages.
    """
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Keep the residual input; it may be projected by `downsample` below.
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
135
+
136
def drq_weight_init(m):
    """Weight init scheme used in DrQ-v2: orthogonal weights, zero biases.

    Intended to be passed to nn.Module.apply(); it touches Linear, Conv2d and
    ConvTranspose2d modules and ignores everything else.
    """
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if hasattr(m.bias, 'data'):  # bias may be None (bias=False layers)
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        relu_gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(m.weight.data, relu_gain)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
147
+
148
class Stage3ShallowEncoder(nn.Module):
    """Convolutional encoder from DrQ-v2: four 3x3 convs, flattened output.

    repr_dim = n_channel * 35 * 35 corresponds to the spatial size produced
    by the strided first conv plus three stride-1 convs on 84x84 inputs.
    """

    def __init__(self, obs_shape, n_channel):
        super().__init__()
        assert len(obs_shape) == 3
        self.repr_dim = n_channel * 35 * 35
        self.conv1 = nn.Conv2d(obs_shape[0], n_channel, 3, stride=2)
        self.conv2 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
        self.conv3 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
        self.conv4 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
        self.relu = nn.ReLU(inplace=True)
        self.apply(drq_weight_init)

    def _forward_impl(self, x):
        # Each conv is followed by an in-place ReLU.
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = self.relu(conv(x))
        return x

    def forward(self, obs):
        features = self._forward_impl(obs)
        # Flatten everything but the batch dimension.
        return features.view(features.shape[0], -1)
175
+
176
class ResNet84(nn.Module):
    """
    default stage 1 encoder used by VRL3, this is modified from the PyTorch standard ResNet class
    but is more lightweight and this is much faster with 84x84 input size
    use "layers" to specify how deep the network is
    use "start_num_channel" to control how wide it is
    """
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None, start_num_channel=32):
        # block: residual block class (BasicBlock / Bottleneck); layers: number
        # of blocks per stage (0 allowed — see _make_layer).
        super(ResNet84, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.start_num_channel = start_num_channel
        self.inplanes = start_num_channel
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem: 3x3 stride-2 conv instead of the standard 7x7, suited to 84x84 inputs.
        self.conv1 = nn.Conv2d(3, self.inplanes, 3, stride=2)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Channel width doubles at each stage: c, 2c, 4c, 8c.
        self.layer1 = self._make_layer(block, start_num_channel, layers[0])
        self.layer2 = self._make_layer(block, start_num_channel * 2, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, start_num_channel * 4, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, start_num_channel * 8, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(start_num_channel * 8 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of `blocks` residual blocks (first may downsample)."""
        # vrl3: if block is 0, allows a smaller network size
        if blocks == 0:
            block = OneLayerBlock

        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        # Project the skip connection when shape or channel count changes.
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # Stem -> 4 stages -> global average pool -> classification head.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def get_feature_size(self):
        """Size of the get_features() output: 256 per 32 starting channels."""
        assert self.start_num_channel % 32 == 0
        multiplier = self.start_num_channel // 32
        size = 256 * multiplier
        return size

    def forward(self, x):
        return self._forward_impl(x)

    def get_features(self, x):
        """Like forward() but stops before self.fc, returning pooled features."""
        x = self.conv1(x)
        # print("0", x.shape) # 32 x 41 x 41 = 53792
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        # print("1", x.shape) # 32 x 41 x 41= 53792

        x = self.layer2(x)
        # print("2", x.shape) # 64 x 21 x 21 = 28224

        x = self.layer3(x)
        # print("3", x.shape) # 128 x 11 x 11 = 15488

        x = self.layer4(x)
        # print("4", x.shape) # 256 x 6 x 6 = 9216

        x = self.avgpool(x)
        # print("pool", x.shape) # 256 x 1 x 1

        final_out = torch.flatten(x, 1)
        # print("flatten", x.shape) # 256
        return final_out
311
+
312
class Identity(nn.Module):
    """No-op module that returns its input unchanged.

    Accepts (and ignores) one constructor argument so it can drop in for
    layers that are built with a single positional argument.
    """

    def __init__(self, input_placeholder=None):
        super().__init__()

    def forward(self, x):
        return x
318
+
VRL3/src/train_stage1.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ # this file is modified from the pytorch official tutorial
5
+ # NOTE: stage 1 training code is currently being cleaned up
6
+
7
+ import argparse
8
+ import os
9
+ import random
10
+ import shutil
11
+ import time
12
+ import warnings
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.parallel
17
+ import torch.backends.cudnn as cudnn
18
+ import torch.distributed as dist
19
+ import torch.optim
20
+ import torch.multiprocessing as mp
21
+ import torch.utils.data
22
+ import torch.utils.data.distributed
23
+ import torchvision.transforms as transforms
24
+ import torchvision.datasets as datasets
25
+ import torchvision.models as models
26
+
27
+ # TODO use another config file to indicate the location of the training data and also where to save models...
28
+ # should also add an option to just test the accuracy of models...
29
+ # we probably can test this locally ....
30
+
31
+ from stage1_models import BasicBlock, ResNet84
32
+
33
# Custom small-ResNet architectures for 84x84 RL observations (stage 1
# pretraining); the constructors live in stage1_models.py.
rl_model_names = ['resnet6_32channel', 'resnet10_32channel', 'resnet18_32channel',
                  'resnet6_64channel', 'resnet10_64channel', 'resnet18_64channel',]
# Lowercase callable torchvision model constructors plus the RL models above.
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name])) + rl_model_names

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet10_32channel',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')
parser.add_argument('--debug', default=0, type=int,
                    help='1 for debug mode, 2 for super fast debug mode')

best_acc1 = 0    # best top-1 validation accuracy so far; updated in main_worker
INPUT_SIZE = 84  # train/eval crop size (RL observations are 84x84)
VAL_RESIZE = 100 # validation resize applied before the center crop
97
def main():
    """Parse CLI arguments and launch training.

    With --multiprocessing-distributed, spawns one main_worker process per
    GPU on this node; otherwise calls main_worker directly in this process.
    """
    print(model_names)

    args = parser.parse_args()

    if args.seed is not None:
        # Seed python and torch RNGs and force deterministic cuDNN kernels.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    # Distributed if more than one node, or if multiprocessing is requested.
    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)
132
+
133
+
134
def main_worker(gpu, ngpus_per_node, args):
    """Build model/data/optimizer (optionally distributed) and run the
    train/validate loop for args.epochs epochs.

    gpu: GPU index assigned to this process (None means CPU or DataParallel).
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # create model
    if args.debug > 0:
        # Debug path: build a tiny model, run one forward pass, print it, exit.
        print("=> creating model for debug 2")
        # model = ResNet84(BasicBlock, [1, 1, 1, 1], num_classes=5) # 1, 1, 1, 1 will make a resnet10
        model = ResNet84(BasicBlock, [0, 0, 0, 0], num_classes=5) # 0, 0, 0, 0 make a convnet6 (5 conv layers in total lol)
        x = torch.rand((1, 3, 84, 84)).float()
        out = model(x)
        print(model)
        quit()

        # model = ResNetTest2(BasicBlock, [2, 2, 2, 2])
        #model = Drq4Encoder((3, 84, 84), n_channel, 200)
    else:
        if args.pretrained:
            print("=> using pre-trained model '{}'".format(args.arch))
            model = models.__dict__[args.arch](pretrained=True)
        else:
            print("=> creating model '{}'".format(args.arch))
            if args.arch in rl_model_names:
                # Custom small ResNets for 84x84 input (see stage1_models.py).
                if args.arch == 'resnet18_32channel':
                    model = ResNet84(BasicBlock, [2, 2, 2, 2], start_num_channel=32) # 1, 1, 1, 1 will make a resnet10
                elif args.arch == 'resnet10_32channel':
                    model = ResNet84(BasicBlock, [1, 1, 1, 1], start_num_channel=32) # 1, 1, 1, 1 will make a resnet10
                elif args.arch == 'resnet6_32channel':
                    model = ResNet84(BasicBlock, [0, 0, 0, 0], start_num_channel=32) # resnet 6 (actually not even resnet because no skip connection)
                elif args.arch == 'resnet18_64channel':
                    model = ResNet84(BasicBlock, [2, 2, 2, 2], start_num_channel=64)
                elif args.arch == 'resnet10_64channel':
                    model = ResNet84(BasicBlock, [1, 1, 1, 1], start_num_channel=64)
                elif args.arch == 'resnet6_64channel':
                    model = ResNet84(BasicBlock, [0, 0, 0, 0], start_num_channel=64)
                else:
                    print("specialized model not yet implemented")
                    quit()
            else:
                # Fall back to the torchvision model zoo constructors.
                model = models.__dict__[args.arch]()

    # Place the model: CPU, DDP (single or all GPUs), one GPU, or DataParallel.
    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        print("distributed")
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        print("use gpu:", args.gpu)
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        print("data parallel")
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    # ImageNet channel statistics for input normalization.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    print("train directory is:", traindir)
    print("val directory is:", valdir)

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(INPUT_SIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    print("data set ready")

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            # transforms.Resize(256),
            transforms.Resize(VAL_RESIZE),
            transforms.CenterCrop(INPUT_SIZE),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        # Evaluation-only mode: run validation once and stop.
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        print(epoch)
        epoch_start_time = time.time()

        if args.distributed:
            # Reshuffle the distributed sampler differently each epoch.
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        # Only one process per node (rank 0) writes checkpoints.
        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
            }, is_best,
                save_name_prefix=args.arch)

        epoch_end_time = time.time() - epoch_start_time
        print("epoch finished in %.3f hour" % (epoch_end_time/3600))
327
+
328
def train(train_loader, model, criterion, optimizer, epoch, args):
    """Run one epoch of supervised training over train_loader."""
    time_meter = AverageMeter('Time', ':6.3f')
    load_meter = AverageMeter('Data', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.4e')
    acc1_meter = AverageMeter('Acc@1', ':6.2f')
    acc5_meter = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [time_meter, load_meter, loss_meter, acc1_meter, acc5_meter],
        prefix="Epoch: [{}]".format(epoch))

    # Enable dropout / batch-norm statistic updates.
    model.train()

    tick = time.time()
    for step, (inputs, labels) in enumerate(train_loader):
        # Time spent waiting on the data loader.
        load_meter.update(time.time() - tick)

        if args.gpu is not None:
            inputs = inputs.cuda(args.gpu, non_blocking=True)
        if torch.cuda.is_available():
            labels = labels.cuda(args.gpu, non_blocking=True)

        # Forward pass and loss.
        logits = model(inputs)
        loss = criterion(logits, labels)

        # Bookkeeping: loss and top-1 / top-5 accuracy on this batch.
        top1, top5 = accuracy(logits, labels, topk=(1, 5))
        n = inputs.size(0)
        loss_meter.update(loss.item(), n)
        acc1_meter.update(top1[0], n)
        acc5_meter.update(top5[0], n)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Wall-clock time for the whole step.
        time_meter.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
373
+
374
+
375
def validate(val_loader, model, criterion, args):
    """Evaluate the model on val_loader and return the average top-1 accuracy."""
    time_meter = AverageMeter('Time', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.4e')
    acc1_meter = AverageMeter('Acc@1', ':6.2f')
    acc5_meter = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [time_meter, loss_meter, acc1_meter, acc5_meter],
        prefix='Test: ')

    # Freeze batch-norm statistics, disable dropout.
    model.eval()

    with torch.no_grad():
        tick = time.time()
        for step, (inputs, labels) in enumerate(val_loader):
            if args.gpu is not None:
                inputs = inputs.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                labels = labels.cuda(args.gpu, non_blocking=True)

            # Forward pass and loss.
            logits = model(inputs)
            loss = criterion(logits, labels)

            # Record loss and top-1 / top-5 accuracy for this batch.
            top1, top5 = accuracy(logits, labels, topk=(1, 5))
            n = inputs.size(0)
            loss_meter.update(loss.item(), n)
            acc1_meter.update(top1[0], n)
            acc5_meter.update(top5[0], n)

            # Wall-clock time per batch.
            time_meter.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                progress.display(step)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=acc1_meter, top5=acc5_meter))

    return acc1_meter.avg
418
+
419
+
420
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', save_name_prefix = ''):
    """Serialize `state` to '<prefix>_<filename>'; when `is_best` is set,
    also copy it to '<prefix>_model_best.pth.tar'."""
    save_name = '_'.join([save_name_prefix, filename])
    torch.save(state, save_name)
    if is_best:
        # Keep a separate copy of the best-so-far checkpoint.
        shutil.copyfile(save_name, '_'.join([save_name_prefix, 'model_best.pth.tar']))
426
+
427
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt  # format spec used by __str__, e.g. ':6.2f'
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record observation `val` with weight `n` and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # e.g. 'Loss 0.1234 (0.5678)' -> current value then running average.
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**self.__dict__)
449
+
450
+
451
class ProgressMeter(object):
    """Prints one progress line per call: '<prefix>[batch/total]  meter  ...'."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print prefix, the current batch index, and every meter, tab-separated."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        """Build e.g. '[{:3d}/100]' so batch indices right-align with the total."""
        # Fixed: original computed len(str(num_batches // 1)) — the '// 1' was a no-op.
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
466
+
467
+
468
def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    decayed = args.lr * 0.1 ** (epoch // 30)
    for group in optimizer.param_groups:
        group['lr'] = decayed
473
+
474
+
475
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Top-maxk class indices per sample; after the transpose each column
        # corresponds to one sample.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        hits = pred.eq(target.view(1, -1).expand_as(pred))

        results = []
        for k in topk:
            # A sample counts as correct@k if the target appears in its top k.
            num_correct = hits[:k].reshape(-1).float().sum(0, keepdim=True)
            results.append(num_correct.mul_(100.0 / batch_size))
        return results
490
+
491
+
492
if __name__ == '__main__':
    # Entry point: parse CLI args and launch (possibly distributed) training.
    main()
VRL3/src/utils.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import random
6
+ import re
7
+ import time
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from omegaconf import OmegaConf
14
+ from torch import distributions as pyd
15
+ from torch.distributions.utils import _standard_normal
16
+
17
+
18
class eval_mode:
    """Context manager that switches the given nn.Modules to eval mode and
    restores each module's previous training flag on exit."""

    def __init__(self, *models):
        self.models = models

    def __enter__(self):
        # Remember each module's current mode, then switch all to eval.
        self.prev_states = [m.training for m in self.models]
        for m in self.models:
            m.train(False)

    def __exit__(self, *args):
        # Restore the saved modes; never suppress exceptions.
        for m, was_training in zip(self.models, self.prev_states):
            m.train(was_training)
        return False
32
+
33
+
34
def set_seed_everywhere(seed):
    """Seed every RNG the codebase uses: python, numpy, and torch
    (including all CUDA devices when available)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
40
+
41
+
42
def soft_update_params(net, target_net, tau):
    """Polyak-average the online network into the target network in place:
    target_param <- tau * param + (1 - tau) * target_param."""
    for src, dst in zip(net.parameters(), target_net.parameters()):
        dst.data.copy_(tau * src.data + (1 - tau) * dst.data)
46
+
47
+
48
def to_torch(xs, device):
    """Convert a sequence of array-likes into a tuple of torch tensors on `device`."""
    converted = [torch.as_tensor(x, device=device) for x in xs]
    return tuple(converted)
50
+
51
+
52
def weight_init(m):
    """Orthogonal initialization for Linear / Conv layers, intended for
    nn.Module.apply; biases are zeroed when present."""
    if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight.data)
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # ReLU gain rescales the orthogonal init for ReLU networks.
        nn.init.orthogonal_(m.weight.data, nn.init.calculate_gain('relu'))
        if hasattr(m.bias, 'data'):
            m.bias.data.fill_(0.0)
62
+
63
+
64
class Until:
    """Callable predicate: truthy while `step` is below the limit scaled
    down by `action_repeat`; a limit of None means 'forever'."""

    def __init__(self, until, action_repeat=1):
        self._until = until
        self._action_repeat = action_repeat

    def __call__(self, step):
        if self._until is None:
            return True
        return step < self._until // self._action_repeat
74
+
75
+
76
class Every:
    """Callable predicate: truthy every `every // action_repeat` steps
    (including step 0); an interval of None means 'never'."""

    def __init__(self, every, action_repeat=1):
        self._every = every
        self._action_repeat = action_repeat

    def __call__(self, step):
        if self._every is None:
            return False
        interval = self._every // self._action_repeat
        return step % interval == 0
88
+
89
+
90
class Timer:
    """Wall-clock timer tracking total elapsed time and time since reset()."""

    def __init__(self):
        self._start_time = time.time()
        self._last_time = time.time()

    def reset(self):
        """Return (seconds since last reset, seconds since construction)
        and restart the per-interval clock."""
        now = time.time()
        elapsed_time = now - self._last_time
        self._last_time = now
        return elapsed_time, now - self._start_time

    def total_time(self):
        """Seconds since this Timer was created."""
        return time.time() - self._start_time
103
+
104
+
105
class TruncatedNormal(pyd.Normal):
    """Normal distribution whose samples are clamped to [low, high].

    The clamp is straight-through: the forward value is clamped into the
    open interval (low+eps, high-eps) but gradients flow as if unclamped.
    `sample(clip=c)` additionally limits the noise magnitude to c before
    adding it to the mean.
    """

    def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
        super().__init__(loc, scale, validate_args=False)
        self.low = low
        self.high = high
        self.eps = eps

    def _clamp(self, x):
        # Straight-through clamp: clamped value forward, identity backward.
        clamped = torch.clamp(x, self.low + self.eps, self.high - self.eps)
        return x - x.detach() + clamped.detach()

    def sample(self, clip=None, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        noise = _standard_normal(shape,
                                 dtype=self.loc.dtype,
                                 device=self.loc.device)
        noise *= self.scale
        if clip is not None:
            # Bound the noise magnitude before adding it to the mean.
            noise = torch.clamp(noise, -clip, clip)
        return self._clamp(self.loc + noise)
127
+
128
+
129
def schedule(schdl, step):
    """Evaluate a schedule spec string at `step`.

    Supported specs:
      - a plain number, e.g. '0.1' (constant)
      - 'linear(init,final,duration)'
      - 'step_linear(init,final1,duration1,final2,duration2)' — two linear
        segments, the second starting after duration1 steps.
    """
    try:
        return float(schdl)  # constant schedule
    except ValueError:
        pass
    m = re.match(r'linear\((.+),(.+),(.+)\)', schdl)
    if m:
        init, final, duration = (float(g) for g in m.groups())
        mix = np.clip(step / duration, 0.0, 1.0)
        return (1.0 - mix) * init + mix * final
    m = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl)
    if m:
        init, final1, duration1, final2, duration2 = (float(g)
                                                      for g in m.groups())
        if step <= duration1:
            # First segment: init -> final1 over duration1 steps.
            mix = np.clip(step / duration1, 0.0, 1.0)
            return (1.0 - mix) * init + mix * final1
        # Second segment: final1 -> final2 over duration2 steps.
        mix = np.clip((step - duration1) / duration2, 0.0, 1.0)
        return (1.0 - mix) * final1 + mix * final2
    raise NotImplementedError(schdl)
VRL3/src/vrl3_agent.py ADDED
@@ -0,0 +1,632 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torchvision import datasets, models, transforms
9
+ from transfer_util import initialize_model
10
+ from stage1_models import BasicBlock, ResNet84
11
+ import os
12
+ import copy
13
+ from PIL import Image
14
+ import platform
15
+ from numbers import Number
16
+ import utils
17
+
18
class RandomShiftsAug(nn.Module):
    """Random-shift image augmentation: pad each square image by `pad`
    pixels with edge replication, then bilinearly resample it at a
    per-sample random integer translation of up to `pad` pixels."""

    def __init__(self, pad):
        super().__init__()
        self.pad = pad

    def forward(self, x):
        n, c, h, w = x.size()
        assert h == w
        # Replicate-pad all four sides.
        x = F.pad(x, (self.pad,) * 4, 'replicate')

        # Identity sampling grid over the padded image, restricted to the
        # top-left h x w window (grid coords are in [-1, 1]).
        eps = 1.0 / (h + 2 * self.pad)
        coords = torch.linspace(-1.0 + eps,
                                1.0 - eps,
                                h + 2 * self.pad,
                                device=x.device,
                                dtype=x.dtype)[:h]
        row = coords.unsqueeze(0).repeat(h, 1).unsqueeze(2)
        base_grid = torch.cat([row, row.transpose(1, 0)], dim=2)
        base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)

        # Per-sample integer pixel shift in [0, 2*pad], converted to grid units.
        shift = torch.randint(0,
                              2 * self.pad + 1,
                              size=(n, 1, 1, 2),
                              device=x.device,
                              dtype=x.dtype)
        shift *= 2.0 / (h + 2 * self.pad)

        return F.grid_sample(x,
                             base_grid + shift,
                             padding_mode='zeros',
                             align_corners=False)
50
+
51
class Identity(nn.Module):
    """Pass-through module: forward returns its input untouched. The
    constructor argument is ignored (signature-compatibility only)."""

    def __init__(self, input_placeholder=None):
        super(Identity, self).__init__()

    def forward(self, x):
        # Nothing to do — hand the input back.
        return x
57
+
58
class RLEncoder(nn.Module):
    """Wraps a ResNet84 backbone (see stage1_models) as an RL image encoder.

    The observation is a stack of `n_images` RGB frames, so the channel
    count must be a multiple of 3. Until `expand_first_layer()` is called
    the wrapped model still expects 3-channel input (`channel_mismatch`).
    """

    def __init__(self, obs_shape, model_name, device):
        super().__init__()
        # a wrapper over a non-RL encoder model
        self.device = device
        assert len(obs_shape) == 3
        self.n_input_channel = obs_shape[0]
        assert self.n_input_channel % 3 == 0
        self.n_images = self.n_input_channel // 3
        self.model = self.init_model(model_name)
        # Drop the classification head; get_features() output is the repr.
        self.model.fc = Identity()
        self.repr_dim = self.model.get_feature_size()

        # ImageNet channel statistics (3-channel until expand_first_layer()).
        self.normalize_op = transforms.Normalize((0.485, 0.456, 0.406),
                                                 (0.229, 0.224, 0.225))
        self.channel_mismatch = True

    def init_model(self, model_name):
        # model name is e.g. resnet6_32channel
        n_layer_string, n_channel_string = model_name.split('_')
        layer_string_to_layer_list = {
            'resnet6': [0, 0, 0, 0],
            'resnet10': [1, 1, 1, 1],
            'resnet18': [2, 2, 2, 2],
        }
        channel_string_to_n_channel = {
            '32channel': 32,
            '64channel': 64,
        }
        layer_list = layer_string_to_layer_list[n_layer_string]
        start_num_channel = channel_string_to_n_channel[n_channel_string]
        return ResNet84(BasicBlock, layer_list, start_num_channel=start_num_channel).to(self.device)

    def expand_first_layer(self):
        # convolutional channel expansion to deal with input mismatch:
        # tile the 3-channel conv1 weights across the frame stack and divide
        # by the multiplier so the initial conv output magnitude is preserved.
        multiplier = self.n_images
        self.model.conv1.weight.data = self.model.conv1.weight.data.repeat(1,multiplier,1,1) / multiplier
        # Repeat the per-channel normalization stats to match the stack.
        means = (0.485, 0.456, 0.406) * multiplier
        stds = (0.229, 0.224, 0.225) * multiplier
        self.normalize_op = transforms.Normalize(means, stds)
        self.channel_mismatch = False

    def freeze_bn(self):
        # freeze batch norm layers (VRL3 ablation shows modifying how
        # batch norm is trained does not affect performance)
        for module in self.model.modules():
            if isinstance(module, nn.BatchNorm2d):
                if hasattr(module, 'weight'):
                    module.weight.requires_grad_(False)
                if hasattr(module, 'bias'):
                    module.bias.requires_grad_(False)
                module.eval()

    def get_parameters_that_require_grad(self):
        # Collect only still-trainable parameters (e.g. after freeze_bn()).
        params = []
        for name, param in self.named_parameters():
            if param.requires_grad == True:
                params.append(param)
        return params

    def transform_obs_tensor_batch(self, obs):
        # transform obs batch before put into the pretrained resnet:
        # uint8 [0, 255] -> float [0, 1] -> channel-normalized
        new_obs = self.normalize_op(obs.float()/255)
        return new_obs

    def _forward_impl(self, x):
        # Backbone features only (classifier head was replaced by Identity).
        x = self.model.get_features(x)
        return x

    def forward(self, obs):
        o = self.transform_obs_tensor_batch(obs)
        h = self._forward_impl(o)
        return h
131
+
132
class Stage3ShallowEncoder(nn.Module):
    """Shallow 4-conv encoder with a compression head and a linear
    prediction layer used for unsupervised (UL) updates."""

    def __init__(self, obs_shape, n_channel):
        super().__init__()

        assert len(obs_shape) == 3
        # 84x84 input through one stride-2 and three stride-1 3x3 convs
        # leaves a 35x35 spatial map, hence the flattened repr size.
        self.repr_dim = n_channel * 35 * 35

        self.n_input_channel = obs_shape[0]
        self.conv1 = nn.Conv2d(obs_shape[0], n_channel, 3, stride=2)
        self.conv2 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
        self.conv3 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
        self.conv4 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
        self.relu = nn.ReLU(inplace=True)

        # TODO here add prediction head so we can do contrastive learning...

        self.apply(utils.weight_init)
        # ImageNet stats repeated for a 3-frame stack; currently unused by
        # transform_obs_tensor_batch (normalization line is commented out).
        self.normalize_op = transforms.Normalize((0.485, 0.456, 0.406, 0.485, 0.456, 0.406, 0.485, 0.456, 0.406),
                                                 (0.229, 0.224, 0.225, 0.229, 0.224, 0.225, 0.229, 0.224, 0.225))

        # Compression to a 50-d latent plus a linear prediction head (UL).
        self.compress = nn.Sequential(nn.Linear(self.repr_dim, 50), nn.LayerNorm(50), nn.Tanh())
        self.pred_layer = nn.Linear(50, 50, bias=False)

    def transform_obs_tensor_batch(self, obs):
        # transform obs batch before put into the pretrained resnet
        # correct order might be first augment, then resize, then normalize
        # obs = F.interpolate(obs, size=self.pretrained_model_input_size)
        new_obs = obs / 255.0 - 0.5
        # new_obs = self.normalize_op(new_obs)
        return new_obs

    def _forward_impl(self, x):
        # Four conv+ReLU stages; no pooling, spatial dims shrink via strides.
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        return x

    def forward(self, obs):
        # Scale pixels, run the conv stack, and flatten to (N, repr_dim).
        o = self.transform_obs_tensor_batch(obs)
        h = self._forward_impl(o)
        h = h.view(h.shape[0], -1)
        return h

    def get_anchor_output(self, obs, actions=None):
        # typically go through conv and then compression layer and then a mlp
        # used for UL update
        conv_out = self.forward(obs)
        compressed = self.compress(conv_out)
        pred = self.pred_layer(compressed)
        return pred, conv_out

    def get_positive_output(self, obs):
        # typically go through conv, compression
        # used for UL update
        conv_out = self.forward(obs)
        compressed = self.compress(conv_out)
        return compressed
190
+
191
class Encoder(nn.Module):
    """Simple 4-layer conv encoder; forward returns a flat feature vector
    of size repr_dim = n_channel * 35 * 35."""

    def __init__(self, obs_shape, n_channel):
        super().__init__()

        assert len(obs_shape) == 3
        self.repr_dim = n_channel * 35 * 35

        # One stride-2 conv followed by three stride-1 convs, ReLU after each.
        layers = [nn.Conv2d(obs_shape[0], n_channel, 3, stride=2), nn.ReLU()]
        for _ in range(3):
            layers += [nn.Conv2d(n_channel, n_channel, 3, stride=1), nn.ReLU()]
        self.convnet = nn.Sequential(*layers)

        self.apply(utils.weight_init)

    def forward(self, obs):
        # Map uint8 pixels to [-0.5, 0.5], run the conv stack, flatten.
        scaled = obs / 255.0 - 0.5
        features = self.convnet(scaled)
        return features.view(features.shape[0], -1)
211
+
212
class IdentityEncoder(nn.Module):
    """Pass-through encoder used for 1-D (state-based) observations."""

    def __init__(self, obs_shape):
        super().__init__()

        assert len(obs_shape) == 1
        # the "representation" is just the raw observation vector
        self.repr_dim = obs_shape[0]

    def forward(self, obs):
        return obs
221
+
222
class Actor(nn.Module):
    """Policy network producing a truncated-normal action distribution.

    A linear+LayerNorm+Tanh trunk projects encoder features down to
    ``feature_dim``; an MLP head outputs the pre-tanh action mean.
    """

    def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
        super().__init__()

        self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
                                   nn.LayerNorm(feature_dim), nn.Tanh())

        head = [nn.Linear(feature_dim, hidden_dim), nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, action_shape[0])]
        self.policy = nn.Sequential(*head)

        # affine map applied after tanh; identity by default
        self.action_shift = 0
        self.action_scale = 1
        self.apply(utils.weight_init)

    def forward(self, obs, std):
        """Return the action distribution with the given (scalar) std."""
        dist, _ = self.forward_with_pretanh(obs, std)
        return dist

    def forward_with_pretanh(self, obs, std):
        """Like forward(), but also return the pre-tanh mean.

        The pre-tanh value is used by the actor loss to penalize saturation.
        """
        pretanh = self.policy(self.trunk(obs))
        mu = torch.tanh(pretanh) * self.action_scale + self.action_shift
        dist = utils.TruncatedNormal(mu, torch.ones_like(mu) * std)
        return dist, pretanh
261
+
262
class Critic(nn.Module):
    """Twin Q-networks (clipped double-Q) over (trunk(obs), action)."""

    def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
        super().__init__()

        self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
                                   nn.LayerNorm(feature_dim), nn.Tanh())

        def make_q_head():
            # one Q head: (feature, action) -> scalar value
            return nn.Sequential(
                nn.Linear(feature_dim + action_shape[0], hidden_dim),
                nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1))

        self.Q1 = make_q_head()
        self.Q2 = make_q_head()

        self.apply(utils.weight_init)

    def forward(self, obs, action):
        """Return (Q1, Q2) estimates for the given obs/action batch."""
        joint = torch.cat([self.trunk(obs), action], dim=-1)
        return self.Q1(joint), self.Q2(joint)
288
+
289
class VRL3Agent:
    """VRL3 agent (stages 2/3 of the VRL3 framework).

    Owns the visual encoder, actor, twin critics (plus critic target) and
    their optimizers. Stage 2 performs offline RL on demonstration data with
    a conservative Q-learning (CQL) term and a BC term; stage 3 performs
    online fine-tuning with a geometrically decaying BC weight. Both stages
    share update()/update_critic_vrl3()/update_actor_vrl3(), only the
    hyperparameters differ.

    Fix vs. original: ``force_action_std == None`` replaced by the idiomatic
    (and robust) ``force_action_std is None`` in act().
    """

    def __init__(self, obs_shape, action_shape, device, use_sensor, lr, feature_dim,
                 hidden_dim, critic_target_tau, num_expl_steps,
                 update_every_steps, stddev_clip, use_tb, use_data_aug, encoder_lr_scale,
                 stage1_model_name, safe_q_target_factor, safe_q_threshold, pretanh_penalty, pretanh_threshold,
                 stage2_update_encoder, cql_weight, cql_temp, cql_n_random, stage2_std, stage2_bc_weight,
                 stage3_update_encoder, std0, std1, std_n_decay,
                 stage3_bc_lam0, stage3_bc_lam1):
        self.device = device
        self.critic_target_tau = critic_target_tau
        self.update_every_steps = update_every_steps
        self.use_tb = use_tb
        self.num_expl_steps = num_expl_steps

        self.stage2_std = stage2_std
        self.stage2_update_encoder = stage2_update_encoder

        # stage 3 exploration-noise schedule: linear decay from std0 to std1
        # over std_n_decay steps; clamp so the schedule never increases
        if std1 > std0:
            std1 = std0
        self.stddev_schedule = "linear(%s,%s,%s)" % (str(std0), str(std1), str(std_n_decay))

        self.stddev_clip = stddev_clip
        self.use_data_aug = use_data_aug
        # "safe Q target" technique: soften targets above q_threshold
        self.safe_q_target_factor = safe_q_target_factor
        self.q_threshold = safe_q_threshold
        self.pretanh_penalty = pretanh_penalty

        # CQL hyperparameters (stage 2 only; weight is 0 in stage 3)
        self.cql_temp = cql_temp
        self.cql_weight = cql_weight
        self.cql_n_random = cql_n_random

        self.pretanh_threshold = pretanh_threshold

        # BC weights: fixed in stage 2, decayed as lam0 * lam1**i in stage 3
        self.stage2_bc_weight = stage2_bc_weight
        self.stage3_bc_lam0 = stage3_bc_lam0
        self.stage3_bc_lam1 = stage3_bc_lam1

        # only update the encoder in stage 3 when requested, the encoder lr
        # scale is positive, and the observation is image-based
        if stage3_update_encoder and encoder_lr_scale > 0 and len(obs_shape) > 1:
            self.stage3_update_encoder = True
        else:
            self.stage3_update_encoder = False

        self.encoder = RLEncoder(obs_shape, stage1_model_name, device).to(device)

        self.act_dim = action_shape[0]

        # 24 extra input dims when proprioceptive sensor data is used
        # (presumably the Adroit sensor vector size -- TODO confirm)
        if use_sensor:
            downstream_input_dim = self.encoder.repr_dim + 24
        else:
            downstream_input_dim = self.encoder.repr_dim

        self.actor = Actor(downstream_input_dim, action_shape, feature_dim,
                           hidden_dim).to(device)
        self.critic = Critic(downstream_input_dim, action_shape, feature_dim,
                             hidden_dim).to(device)
        self.critic_target = Critic(downstream_input_dim, action_shape,
                                    feature_dim, hidden_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        # optimizers
        self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)

        encoder_lr = lr * encoder_lr_scale
        """ set up encoder optimizer """
        self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=encoder_lr)
        # data augmentation
        self.aug = RandomShiftsAug(pad=4)
        self.train()
        self.critic_target.train()

    def load_pretrained_encoder(self, model_path, verbose=True):
        """Load stage-1 pretrained weights into the encoder's backbone.

        Strips a leading ``module.`` prefix (left by DistributedDataParallel)
        from checkpoint keys; loads with strict=False so partial matches work.
        """
        if verbose:
            print("Trying to load pretrained model from:", model_path)
        checkpoint = torch.load(model_path, map_location=torch.device(self.device))
        state_dict = checkpoint['state_dict']

        pretrained_dict = {}
        # remove `module.` if model was pretrained with distributed mode
        for k, v in state_dict.items():
            if 'module.' in k:
                name = k[7:]
            else:
                name = k
            pretrained_dict[name] = v
        self.encoder.model.load_state_dict(pretrained_dict, strict=False)
        if verbose:
            print("Pretrained model loaded!")

    def switch_to_RL_stages(self, verbose=True):
        """Prepare the stage-1 encoder for RL: expand the first conv layer's
        input channels so it can take stacked frames."""
        # run convolutional channel expansion to match input shape
        self.encoder.expand_first_layer()
        if verbose:
            print("Convolutional channel expansion finished: now can take in %d images as input." % self.encoder.n_images)

    def train(self, training=True):
        """Set train/eval mode on encoder, actor and critic (not the target)."""
        self.training = training
        self.encoder.train(training)
        self.actor.train(training)
        self.critic.train(training)

    def act(self, obs, step, eval_mode, obs_sensor=None, is_tensor_input=False, force_action_std=None):
        """Select an action for a single observation.

        obs: 3x84x84, uint8, [0,255]
        eval_mode should be False when taking an exploration action in stage 3
        and True when evaluating agent performance. force_action_std, when
        given, overrides the std schedule (and skips the initial uniform
        exploration phase).
        """
        if force_action_std is None:
            stddev = utils.schedule(self.stddev_schedule, step)
            if step < self.num_expl_steps and not eval_mode:
                # NOTE(review): this uniform range is [0, 1) while the
                # post-sample branch below uses [-1, 1] -- confirm intended
                action = np.random.uniform(0, 1, (self.act_dim,)).astype(np.float32)
                return action
        else:
            stddev = force_action_std

        if is_tensor_input:
            obs = self.encoder(obs)
        else:
            obs = torch.as_tensor(obs, device=self.device)
            obs = self.encoder(obs.unsqueeze(0))

        # append proprioceptive sensor reading when available
        if obs_sensor is not None:
            obs_sensor = torch.as_tensor(obs_sensor, device=self.device)
            obs_sensor = obs_sensor.unsqueeze(0)
            obs_combined = torch.cat([obs, obs_sensor], dim=1)
        else:
            obs_combined = obs

        dist = self.actor(obs_combined, stddev)
        if eval_mode:
            action = dist.mean
        else:
            action = dist.sample(clip=None)
            if step < self.num_expl_steps:
                action.uniform_(-1.0, 1.0)
        return action.cpu().numpy()[0]

    def update(self, replay_iter, step, stage, use_sensor):
        """One gradient update for stage 2 (offline) or stage 3 (online).

        Selects the stage-specific hyperparameters (encoder update flag, action
        std, CQL weight, BC weight), draws a batch, optionally augments and
        encodes it, then updates critic, actor and the critic target.
        Returns a dict of logging metrics.
        """
        # for stage 2 and 3, we use the same functions but with different hyperparameters
        assert stage in (2, 3)
        metrics = dict()

        if stage == 2:
            update_encoder = self.stage2_update_encoder
            stddev = self.stage2_std
            conservative_loss_weight = self.cql_weight
            bc_weight = self.stage2_bc_weight

        if stage == 3:
            if step % self.update_every_steps != 0:
                return metrics
            update_encoder = self.stage3_update_encoder

            stddev = utils.schedule(self.stddev_schedule, step)
            conservative_loss_weight = 0

            # compute stage 3 BC weight: geometric decay per 40000-step "iter"
            bc_data_per_iter = 40000
            i_iter = step // bc_data_per_iter
            bc_weight = self.stage3_bc_lam0 * self.stage3_bc_lam1 ** i_iter

        # batch data
        batch = next(replay_iter)
        if use_sensor:  # TODO might want to...?
            obs, action, reward, discount, next_obs, obs_sensor, obs_sensor_next = utils.to_torch(batch, self.device)
        else:
            obs, action, reward, discount, next_obs = utils.to_torch(batch, self.device)
            obs_sensor, obs_sensor_next = None, None

        # augment (random shifts) -- images only
        if self.use_data_aug:
            obs = self.aug(obs.float())
            next_obs = self.aug(next_obs.float())
        else:
            obs = obs.float()
            next_obs = next_obs.float()

        # encode; skip encoder gradients when the encoder is frozen
        if update_encoder:
            obs = self.encoder(obs)
        else:
            with torch.no_grad():
                obs = self.encoder(obs)

        # next-state encoding never needs gradients (target computation only)
        with torch.no_grad():
            next_obs = self.encoder(next_obs)

        # concatenate obs with additional sensor observation if needed
        obs_combined = torch.cat([obs, obs_sensor], dim=1) if obs_sensor is not None else obs
        obs_next_combined = torch.cat([next_obs, obs_sensor_next], dim=1) if obs_sensor_next is not None else next_obs

        # update critic
        metrics.update(self.update_critic_vrl3(obs_combined, action, reward, discount, obs_next_combined,
                                               stddev, update_encoder, conservative_loss_weight))

        # update actor, following previous works, we do not use actor gradient for encoder update
        metrics.update(self.update_actor_vrl3(obs_combined.detach(), action, stddev, bc_weight,
                                              self.pretanh_penalty, self.pretanh_threshold))

        metrics['batch_reward'] = reward.mean().item()

        # update critic target networks
        utils.soft_update_params(self.critic, self.critic_target, self.critic_target_tau)
        return metrics

    def update_critic_vrl3(self, obs, action, reward, discount, next_obs, stddev, update_encoder, conservative_loss_weight):
        """Critic update: standard TD loss (+ safe-Q target), plus an optional
        CQL conservative term when conservative_loss_weight > 0.

        When update_encoder is True, the critic loss also backprops into the
        encoder. Returns logging metrics.
        """
        metrics = dict()
        batch_size = obs.shape[0]

        """
        STANDARD Q LOSS COMPUTATION:
        - get standard Q loss first, this is the same as in any other online RL methods
        - except for the safe Q technique, which controls how large the Q value can be
        """
        with torch.no_grad():
            dist = self.actor(next_obs, stddev)
            next_action = dist.sample(clip=self.stddev_clip)
            target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
            target_V = torch.min(target_Q1, target_Q2)
            target_Q = reward + (discount * target_V)

            # safe Q: compress targets beyond the threshold with a power < 1
            if self.safe_q_target_factor < 1:
                target_Q[target_Q > (self.q_threshold + 1)] = self.q_threshold + (target_Q[target_Q > (self.q_threshold+1)] - self.q_threshold) ** self.safe_q_target_factor

        Q1, Q2 = self.critic(obs, action)
        critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)

        """
        CONSERVATIVE Q LOSS COMPUTATION:
        - sample random actions, actions from policy and next actions from policy, as done in CQL authors' code
          (though this detail is not really discussed in the CQL paper)
        - only compute this loss when conservative loss weight > 0
        """
        if conservative_loss_weight > 0:
            # uniform random actions in [-1, 1]
            random_actions = (torch.rand((batch_size * self.cql_n_random, self.act_dim), device=self.device) - 0.5) * 2

            dist = self.actor(obs, stddev)
            current_actions = dist.sample(clip=self.stddev_clip)

            dist = self.actor(next_obs, stddev)
            next_current_actions = dist.sample(clip=self.stddev_clip)

            # now get Q values for all these actions (for both Q networks)
            obs_repeat = obs.unsqueeze(1).repeat(1, self.cql_n_random, 1).view(obs.shape[0] * self.cql_n_random,
                                                                               obs.shape[1])

            Q1_rand, Q2_rand = self.critic(obs_repeat,
                                           random_actions)  # TODO might want to double check the logic here see if the repeat is correct
            Q1_rand = Q1_rand.view(obs.shape[0], self.cql_n_random)
            Q2_rand = Q2_rand.view(obs.shape[0], self.cql_n_random)

            Q1_curr, Q2_curr = self.critic(obs, current_actions)
            Q1_curr_next, Q2_curr_next = self.critic(obs, next_current_actions)

            # now concat all these Q values together
            Q1_cat = torch.cat([Q1_rand, Q1, Q1_curr, Q1_curr_next], 1)
            Q2_cat = torch.cat([Q2_rand, Q2, Q2_curr, Q2_curr_next], 1)

            # temperature-scaled logsumexp over the candidate actions
            cql_min_q1_loss = torch.logsumexp(Q1_cat / self.cql_temp,
                                              dim=1, ).mean() * conservative_loss_weight * self.cql_temp
            cql_min_q2_loss = torch.logsumexp(Q2_cat / self.cql_temp,
                                              dim=1, ).mean() * conservative_loss_weight * self.cql_temp

            """Subtract the log likelihood of data"""
            conservative_q_loss = cql_min_q1_loss + cql_min_q2_loss - (Q1.mean() + Q2.mean()) * conservative_loss_weight
            critic_loss_combined = critic_loss + conservative_q_loss
        else:
            critic_loss_combined = critic_loss

        # logging
        metrics['critic_target_q'] = target_Q.mean().item()
        metrics['critic_q1'] = Q1.mean().item()
        metrics['critic_q2'] = Q2.mean().item()
        metrics['critic_loss'] = critic_loss.item()

        # if needed, also update encoder with critic loss
        if update_encoder:
            self.encoder_opt.zero_grad(set_to_none=True)
        self.critic_opt.zero_grad(set_to_none=True)
        critic_loss_combined.backward()
        self.critic_opt.step()
        if update_encoder:
            self.encoder_opt.step()

        return metrics

    def update_actor_vrl3(self, obs, action, stddev, bc_weight, pretanh_penalty, pretanh_threshold):
        """Actor update: -Q term, optional BC term toward the data actions,
        and an optional penalty on pre-tanh magnitudes beyond a threshold.
        `obs` is expected to be detached (encoder does not get actor grads).
        Returns logging metrics.
        """
        metrics = dict()

        """
        get standard actor loss
        """
        dist, pretanh = self.actor.forward_with_pretanh(obs, stddev)
        current_action = dist.sample(clip=self.stddev_clip)
        log_prob = dist.log_prob(current_action).sum(-1, keepdim=True)
        Q1, Q2 = self.critic(obs, current_action)
        Q = torch.min(Q1, Q2)
        actor_loss = -Q.mean()

        """
        add BC loss
        """
        if bc_weight > 0:
            # get mean action with no action noise (though this might not be necessary)
            stddev_bc = 0
            dist_bc = self.actor(obs, stddev_bc)
            current_mean_action = dist_bc.sample(clip=self.stddev_clip)
            actor_loss_bc = F.mse_loss(current_mean_action, action) * bc_weight
        else:
            actor_loss_bc = torch.FloatTensor([0]).to(self.device)

        """
        add pretanh penalty (might not be necessary for Adroit)
        """
        pretanh_loss = 0
        if pretanh_penalty > 0:
            # penalize only the part of |pretanh| above the threshold
            pretanh_loss = pretanh.abs() - pretanh_threshold
            pretanh_loss[pretanh_loss < 0] = 0
            pretanh_loss = (pretanh_loss ** 2).mean() * pretanh_penalty

        """
        combine actor losses and optimize
        """
        actor_loss_combined = actor_loss + actor_loss_bc + pretanh_loss

        self.actor_opt.zero_grad(set_to_none=True)
        actor_loss_combined.backward()
        self.actor_opt.step()

        metrics['actor_loss'] = actor_loss.item()
        metrics['actor_loss_bc'] = actor_loss_bc.item()
        metrics['actor_logprob'] = log_prob.mean().item()
        metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item()
        metrics['abs_pretanh'] = pretanh.abs().mean().item()
        metrics['max_abs_pretanh'] = pretanh.abs().max().item()

        return metrics

    def to(self, device):
        """Move all networks to `device` and remember it for later tensors."""
        self.actor.to(device)
        self.critic.to(device)
        self.encoder.to(device)
        self.device = device
gym-0.21.0/.github/stale.yml ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration for probot-stale - https://github.com/probot/stale
2
+
3
+ # Number of days of inactivity before an Issue or Pull Request becomes stale
4
+ daysUntilStale: 60
5
+
6
+ # Number of days of inactivity before an Issue or Pull Request with the stale label is closed.
7
+ # Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale.
8
+ daysUntilClose: 14
9
+
10
 + # Only issues or pull requests with all of these labels are checked if stale. Defaults to `[]` (disabled)
11
+ onlyLabels:
12
+ - more-information-needed
13
+
14
+ # Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
15
+ exemptLabels:
16
+ - pinned
17
+ - security
18
+ - "[Status] Maybe Later"
19
+
20
+ # Set to true to ignore issues in a project (defaults to false)
21
+ exemptProjects: true
22
+
23
+ # Set to true to ignore issues in a milestone (defaults to false)
24
+ exemptMilestones: true
25
+
26
+ # Set to true to ignore issues with an assignee (defaults to false)
27
+ exemptAssignees: true
28
+
29
+ # Label to use when marking as stale
30
+ staleLabel: stale
31
+
32
+ # Comment to post when marking as stale. Set to `false` to disable
33
+ markComment: >
34
+ This issue has been automatically marked as stale because it has not had
35
+ recent activity. It will be closed if no further activity occurs. Thank you
36
+ for your contributions.
37
+
38
+ # Comment to post when removing the stale label.
39
+ # unmarkComment: >
40
+ # Your comment here.
41
+
42
+ # Comment to post when closing a stale Issue or Pull Request.
43
+ # closeComment: >
44
+ # Your comment here.
45
+
46
+ # Limit the number of actions per hour, from 1-30. Default is 30
47
+ limitPerRun: 30
48
+
49
+ # Limit to only `issues` or `pulls`
50
+ only: issues
51
+
52
+ # Optionally, specify configuration settings that are specific to just 'issues' or 'pulls':
53
+ # pulls:
54
+ # daysUntilStale: 30
55
+ # markComment: >
56
+ # This pull request has been automatically marked as stale because it has not had
57
+ # recent activity. It will be closed if no further activity occurs. Thank you
58
+ # for your contributions.
59
+
60
+ # issues:
61
+ # exemptLabels:
62
+ # - confirmed
gym-0.21.0/CONTRIBUTING.md ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Gym Contribution Guidelines
2
+
3
+ At this time we are currently accepting the current forms of contributions:
4
+
5
+ - Bug reports (keep in mind that changing environment behavior should be minimized as that requires releasing a new version of the environment and makes results hard to compare across versions)
6
+ - Pull requests for bug fixes
7
+ - Documentation improvements
8
+
9
+ Notably, we are not accepting these forms of contributions:
10
+
11
+ - New environments
12
+ - New features
13
+
14
+ This may change in the future.
15
+ If you wish to make a Gym environment, follow the instructions in [Creating Environments](https://github.com/openai/gym/blob/master/docs/creating-environments.md). When your environment works, you can make a PR to add it to the bottom of the [List of Environments](https://github.com/openai/gym/blob/master/docs/environments.md).
16
+
17
+
18
+ Edit July 27, 2021: Please see https://github.com/openai/gym/issues/2259 for new contributing standards
gym-0.21.0/README.md ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Gym
2
+
3
+ Gym is an open source Python library for developing and comparing reinforcement learning algorithms by providing a standard API to communicate between learning algorithms and environments, as well as a standard set of environments compliant with that API. Since its release, Gym's API has become the field standard for doing this.
4
+
5
+ Gym currently has two pieces of documentation: the [documentation website](http://gym.openai.com) and the [FAQ](https://github.com/openai/gym/wiki/FAQ). A new and more comprehensive documentation website is in the works.
6
+
7
+ ## Installation
8
+
9
+ To install the base Gym library, use `pip install gym`.
10
+
11
+ This does not include dependencies for all families of environments (there's a massive number, and some can be problematic to install on certain systems). You can install these dependencies for one family like `pip install gym[atari]` or use `pip install gym[all]` to install all dependencies.
12
+
13
+ We support Python 3.6, 3.7, 3.8 and 3.9 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
14
+
15
+ ## API
16
+
17
 + The Gym API models environments as simple Python `env` classes. Creating environment instances and interacting with them is very simple- here's an example using the "CartPole-v1" environment:
18
+
19
+ ```python
20
+ import gym
21
+ env = gym.make('CartPole-v1')
22
+
23
+ # env is created, now we can use it:
24
+ for episode in range(10):
25
+ obs = env.reset()
26
+ for step in range(50):
27
+ action = env.action_space.sample() # or given a custom model, action = policy(observation)
28
+ nobs, reward, done, info = env.step(action)
29
+ ```
30
+
31
+ ## Notable Related Libraries
32
+
33
+ * [Stable Baselines 3](https://github.com/DLR-RM/stable-baselines3) is a learning library based on the Gym API. It is our recommendation for beginners who want to start learning things quickly.
34
+ * [RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) builds upon SB3, containing optimal hyperparameters for Gym environments as well as code to easily find new ones. Such tuning is almost always required.
35
+ * The [Autonomous Learning Library](https://github.com/cpnota/autonomous-learning-library) and [Tianshou](https://github.com/thu-ml/tianshou) are two reinforcement learning libraries I like that are generally geared towards more experienced users.
36
+ * [PettingZoo](https://github.com/PettingZoo-Team/PettingZoo) is like Gym, but for environments with multiple agents.
37
+
38
+ ## Environment Versioning
39
+
40
+ Gym keeps strict versioning for reproducibility reasons. All environments end in a suffix like "\_v0". When changes are made to environments that might impact learning results, the number is increased by one to prevent potential confusion.
41
+
42
+ ## Citation
43
+
44
+ A whitepaper from when OpenAI Gym just came out is available https://arxiv.org/pdf/1606.01540, and can be cited with the following bibtex entry:
45
+
46
+ ```
47
+ @misc{1606.01540,
48
+ Author = {Greg Brockman and Vicki Cheung and Ludwig Pettersson and Jonas Schneider and John Schulman and Jie Tang and Wojciech Zaremba},
49
+ Title = {OpenAI Gym},
50
+ Year = {2016},
51
+ Eprint = {arXiv:1606.01540},
52
+ }
53
+ ```
54
+
55
+ ## Release Notes
56
+
57
+ There used to be release notes for all the new Gym versions here. New release notes are being moved to [releases page](https://github.com/openai/gym/releases) on GitHub, like most other libraries do. Old notes can be viewed [here](https://github.com/openai/gym/blob/31be35ecd460f670f0c4b653a14c9996b7facc6c/README.rst).
gym-0.21.0/docs/toy_text/blackjack.md ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Blackjack
2
+ ---
3
+ |Title|Action Type|Action Shape|Action Values|Observation Shape|Observation Values|Average Total Reward|Import|
4
+ | ----------- | -----------| ----------- | -----------| ----------- | -----------| ----------- | -----------|
5
+ |Blackjack|Discrete|(1,)|(0,1)|(3,)|[(0,31),(0,10),(0,1)]| |from gym.envs.toy_text import blackjack|
6
+ ---
7
+
8
+ Blackjack is a card game where the goal is to obtain cards that sum to as near as possible to 21 without going over. They're playing against a fixed dealer.
9
+
10
+ Card Values:
11
+
12
+ - Face cards (Jack, Queen, King) have point value 10.
13
+ - Aces can either count as 11 or 1, and it's called 'usable ace' at 11.
14
+ - Numerical cards (2-9) have value of their number.
15
+
16
+ This game is placed with an infinite deck (or with replacement).
17
+ The game starts with dealer having one face up and one face down card, while player having two face up cards.
18
+
19
+ The player can request additional cards (hit, action=1) until they decide to stop
20
+ (stick, action=0) or exceed 21 (bust).
21
+ After the player sticks, the dealer reveals their facedown card, and draws
22
+ until their sum is 17 or greater. If the dealer goes bust the player wins.
23
+ If neither player nor dealer busts, the outcome (win, lose, draw) is
24
+ decided by whose sum is closer to 21.
25
+
26
+ The agent take a 1-element vector for actions.
27
+ The action space is `(action)`, where:
28
+ - `action` is used to decide stick/hit for values (0,1).
29
+
30
+ The observation of a 3-tuple of: the players current sum,
31
+ the dealer's one showing card (1-10 where 1 is ace), and whether or not the player holds a usable ace (0 or 1).
32
+
33
+ This environment corresponds to the version of the blackjack problem
34
+ described in Example 5.1 in Reinforcement Learning: An Introduction
35
+ by Sutton and Barto.
36
+ http://incompleteideas.net/book/the-book-2nd.html
37
+
38
+ **Rewards:**
39
+
40
+ Reward schedule:
41
+ - win game: +1
42
+ - lose game: -1
43
+ - draw game: 0
44
+ - win game with natural blackjack:
45
+
46
+ +1.5 (if <a href="#nat">natural</a> is True.)
47
+
48
+ +1 (if <a href="#nat">natural</a> is False.)
49
+
50
+ ### Arguments
51
+
52
+ ```
53
+ gym.make('Blackjack-v0', natural=False)
54
+ ```
55
+
56
+ <a id="nat">`natural`</a>: Whether to give an additional reward for starting with a natural blackjack, i.e. starting with an ace and ten (sum is 21).
57
+
58
+ ### Version History
59
+
60
+ * v0: Initial versions release (1.0.0)
gym-0.21.0/docs/toy_text/taxi.md ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Taxi
2
+ ---
3
+ |Title|Action Type|Action Shape|Action Values|Observation Shape|Observation Values|Average Total Reward|Import|
4
+ | ----------- | -----------| ----------- | -----------| ----------- | -----------| ----------- | -----------|
5
+ |Taxi|Discrete|(1,)|(0,5)|(1,)|(0,499)| |from gym.envs.toy_text import taxi|
6
+ ---
7
+
8
+
9
+ The Taxi Problem
10
+ from "Hierarchical Reinforcement Learning with the MAXQ Value Function Decomposition"
11
+
12
+ by Tom Dietterich
13
+
14
+
15
+
16
+ Description:
17
+
18
+ There are four designated locations in the grid world indicated by R(ed), G(reen), Y(ellow), and B(lue). When the episode starts, the taxi starts off at a random square and the passenger is at a random location. The taxi drives to the passenger's location, picks up the passenger, drives to the passenger's destination (another one of the four specified locations), and then drops off the passenger. Once the passenger is dropped off, the episode ends.
19
+
20
+ MAP:
21
+
22
+ +---------+
23
+ |R: | : :G|
24
+ | : | : : |
25
+ | : : : : |
26
+ | | : | : |
27
+ |Y| : |B: |
28
+ +---------+
29
+
30
+ Actions:
31
+
32
+ There are 6 discrete deterministic actions:
33
+ - 0: move south
34
+ - 1: move north
35
+ - 2: move east
36
+ - 3: move west
37
+ - 4: pickup passenger
38
+ - 5: drop off passenger
39
+
40
+ Observations:
41
+
42
+ There are 500 discrete states since there are 25 taxi positions, 5 possible locations of the passenger (including the case when the passenger is in the taxi), and 4 destination locations.
43
+
44
+ Note that there are 400 states that can actually be reached during an episode. The missing states correspond to situations in which the passenger is at the same location as their destination, as this typically signals the end of an episode.
45
 + Four additional states can be observed right after a successful episode, when both the passenger and the taxi are at the destination.
46
+ This gives a total of 404 reachable discrete states.
47
+
48
+ Passenger locations:
49
+ - 0: R(ed)
50
+ - 1: G(reen)
51
+ - 2: Y(ellow)
52
+ - 3: B(lue)
53
+ - 4: in taxi
54
+
55
+ Destinations:
56
+ - 0: R(ed)
57
+ - 1: G(reen)
58
+ - 2: Y(ellow)
59
+ - 3: B(lue)
60
+
61
+
62
+
63
+ **Rewards:**
64
+
65
+ - -1 per step reward unless other reward is triggered.
66
+ - +20 delivering passenger.
67
+ - -10 executing "pickup" and "drop-off" actions illegally.
68
+
69
+
70
+ Rendering:
71
+ - blue: passenger
72
+ - magenta: destination
73
+ - yellow: empty taxi
74
+ - green: full taxi
75
+ - other letters (R, G, Y and B): locations for passengers and destinations
76
+ state space is represented by:
77
+ (taxi_row, taxi_col, passenger_location, destination)
78
+
79
+ ### Arguments
80
+
81
+ ```
82
+ gym.make('Taxi-v3')
83
+ ```
84
+
85
+
86
+
87
+ ### Version History
88
+
89
+ * v3: Map Correction + Cleaner Domain Description
90
+ * v2: Disallow Taxi start location = goal location, Update Taxi observations in the rollout, Update Taxi reward threshold.
91
+ * v1: Remove (3,2) from locs, add passidx<4 check
92
+ * v0: Initial versions release
gym-0.21.0/scripts/generate_json.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gym import envs, logger
2
+ import json
3
+ import os
4
+ import sys
5
+ import argparse
6
+
7
+ from tests.envs.spec_list import should_skip_env_spec_for_tests
8
+ from tests import generate_rollout_hash
9
+
10
# Directory holding the cached rollout hashes, relative to this script.
DATA_DIR = os.path.join(os.path.dirname(__file__), os.pardir, "gym", "envs", "tests")
ROLLOUT_STEPS = 100
# Aliases of ROLLOUT_STEPS -- presumably read elsewhere; TODO confirm they are used.
episodes = ROLLOUT_STEPS
steps = ROLLOUT_STEPS

ROLLOUT_FILE = os.path.join(DATA_DIR, "rollout.json")

# Bootstrap an empty rollout file on first run so later json.load calls succeed.
if not os.path.isfile(ROLLOUT_FILE):
    logger.info(
        "No rollout file found. Writing empty json file to {}".format(ROLLOUT_FILE)
    )
    with open(ROLLOUT_FILE, "w") as outfile:
        json.dump({}, outfile, indent=2)
23
+
24
+
25
def update_rollout_dict(spec, rollout_dict):
    """
    Takes as input the environment spec for which the rollout is to be generated,
    and the existing dictionary of rollouts. Returns True iff the dictionary was
    modified.
    """
    # Skip platform-dependent
    if should_skip_env_spec_for_tests(spec):
        logger.info("Skipping tests for {}".format(spec.id))
        return False

    # Skip environments that are nondeterministic
    if spec.nondeterministic:
        logger.info("Skipping tests for nondeterministic env {}".format(spec.id))
        return False

    logger.info("Generating rollout for {}".format(spec.id))

    try:
        (
            observations_hash,
            actions_hash,
            rewards_hash,
            dones_hash,
        ) = generate_rollout_hash(spec)
    # Fix: was a bare `except:`, which also swallowed KeyboardInterrupt and
    # SystemExit. `except Exception` keeps the intended best-effort behavior
    # (skip envs whose rollout generation fails) without hiding interrupts.
    except Exception:
        # If running the env generates an exception, don't write to the rollout file
        logger.warn(
            "Exception {} thrown while generating rollout for {}. Rollout not added.".format(
                sys.exc_info()[0], spec.id
            )
        )
        return False

    rollout = {}
    rollout["observations"] = observations_hash
    rollout["actions"] = actions_hash
    rollout["rewards"] = rewards_hash
    rollout["dones"] = dones_hash

    # Only overwrite an existing entry when at least one hash actually changed.
    existing = rollout_dict.get(spec.id)
    if existing:
        differs = False
        for key, new_hash in rollout.items():
            differs = differs or existing[key] != new_hash
        if not differs:
            logger.debug("Hashes match with existing for {}".format(spec.id))
            return False
        else:
            logger.warn("Got new hash for {}. Overwriting.".format(spec.id))

    rollout_dict[spec.id] = rollout
    return True
78
+
79
+
80
def add_new_rollouts(spec_ids, overwrite):
    """Generate/update rollout hashes for the requested env specs.

    With an empty `spec_ids`, every registered spec with an entry point is
    considered. Existing entries are skipped unless `overwrite` is set.
    The rollout file is rewritten only if something actually changed.
    """
    environments = [
        spec for spec in envs.registry.all() if spec.entry_point is not None
    ]
    if spec_ids:
        environments = [spec for spec in environments if spec.id in spec_ids]
        assert len(environments) == len(spec_ids), "Some specs not found"

    with open(ROLLOUT_FILE) as data_file:
        rollout_dict = json.load(data_file)

    modified = False
    for spec in environments:
        already_present = spec.id in rollout_dict
        if already_present and not overwrite:
            logger.debug("Rollout already exists for {}. Skipping.".format(spec.id))
            continue
        # keep `modified` sticky once any spec changes the dict
        modified = update_rollout_dict(spec, rollout_dict) or modified

    if not modified:
        logger.info("No modifications needed.")
        return

    logger.info("Writing new rollout file to {}".format(ROLLOUT_FILE))
    with open(ROLLOUT_FILE, "w") as outfile:
        json.dump(rollout_dict, outfile, indent=2, sort_keys=True)
102
+
103
+
104
if __name__ == "__main__":
    # CLI: optionally restrict to specific spec ids; -f overwrites stale hashes.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f",
        "--force",
        action="store_true",
        help="Overwrite " + "existing rollouts if hashes differ.",
    )
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument(
        "specs", nargs="*", help="ids of env specs to check (default: all)"
    )

    cli_args = parser.parse_args()
    if cli_args.verbose:
        logger.set_level(logger.INFO)

    add_new_rollouts(cli_args.specs, cli_args.force)
gym-0.21.0/setup.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Packaging script for gym 0.21.0 (setuptools).
import os.path
import sys
import itertools

from setuptools import find_packages, setup

# Don't import gym module here, since deps may not be installed
# Instead, put gym/ on sys.path so gym/version.py can be imported directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "gym"))
from version import VERSION

# Environment-specific dependencies.
# Each key becomes a `pip install gym[<key>]` extra.
extras = {
    "atari": ["ale-py~=0.7.1"],
    "accept-rom-license": ["autorom[accept-rom-license]~=0.4.2"],
    "box2d": ["box2d-py==2.3.5", "pyglet>=1.4.0"],
    "classic_control": ["pyglet>=1.4.0"],
    "mujoco": ["mujoco_py>=1.50, <2.0"],
    "robotics": ["mujoco_py>=1.50, <2.0"],
    "toy_text": ["scipy>=1.4.1"],
    "other": ["lz4>=3.1.0", "opencv-python>=3.0"],
}

# Meta dependency groups.
# "nomujoco" = every extra except the MuJoCo-based ones and the ROM license.
nomujoco_blacklist = set(["mujoco", "robotics", "accept-rom-license"])
nomujoco_groups = set(extras.keys()) - nomujoco_blacklist

extras["nomujoco"] = list(
    itertools.chain.from_iterable(map(lambda group: extras[group], nomujoco_groups))
)


# "all" = every extra except the ROM license acceptance shim.
all_blacklist = set(["accept-rom-license"])
all_groups = set(extras.keys()) - all_blacklist

extras["all"] = list(
    itertools.chain.from_iterable(map(lambda group: extras[group], all_groups))
)

setup(
    name="gym",
    version=VERSION,
    description="Gym: A universal API for reinforcement learning environments.",
    url="https://github.com/openai/gym",
    author="OpenAI",
    author_email="jkterry@umd.edu",
    license="",
    packages=[package for package in find_packages() if package.startswith("gym")],
    zip_safe=False,
    install_requires=[
        "numpy>=1.18.0",
        "cloudpickle>=1.2.0",
        "importlib_metadata>=4.8.1; python_version < '3.8'",
    ],
    extras_require=extras,
    # Non-Python assets that must ship inside the wheel/sdist.
    package_data={
        "gym": [
            "envs/mujoco/assets/*.xml",
            "envs/classic_control/assets/*.png",
            "envs/robotics/assets/LICENSE.md",
            "envs/robotics/assets/fetch/*.xml",
            "envs/robotics/assets/hand/*.xml",
            "envs/robotics/assets/stls/fetch/*.stl",
            "envs/robotics/assets/stls/hand/*.stl",
            "envs/robotics/assets/textures/*.png",
        ]
    },
    tests_require=["pytest", "mock"],
    python_requires=">=3.6",
    classifiers=[
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
    ],
)
mujoco-py-2.1.2.14/.gitignore ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ mujoco-py-**
2
+ mjkey.txt
3
+ mujoco_py/generated/cymj*
4
+ _pyxbld*
5
+ dist
6
+ cache
7
+ .idea/*
8
+ *~
9
+ .*~
10
+ *#*#
11
+ *.o
12
+ *.dat
13
+ *.prof
14
+ *.lprof
15
+ *.local
16
+ .realsync
17
+ .DS_Store
18
+ **/*.egg-info
19
+ .cache
20
+ *.ckpt
21
+ *.log
22
+ .ipynb_checkpoints
23
+ venv/
24
+ .vimrc
25
+ *.settings
26
+ *.svn
27
+ .project
28
+ .pydevproject
29
+ tags
30
+ *sublime-project
31
+ *sublime-workspace
32
+ # Intermediate outputs
33
+ __pycache__
34
+ **/__pycache__
35
+ *.pb.*
36
+ *.pyc
37
+ *.swp
38
+ *.swo
39
+ # generated data
40
+ *.rdb
41
+ *.db
42
+ *.avi
43
+ # mujoco outputs
44
+ MUJOCO_LOG.TXT
45
+ model.txt
46
+ .window_data
47
+ .idea/*.xml
48
+ outputfile
49
+ tmp*
50
+ cymj.c
51
+ **/.git
52
+ .eggs/
53
+ *.so
54
+ .python-version
55
+ /build
mujoco-py-2.1.2.14/docs/_static/.gitkeep ADDED
File without changes
mujoco-py-2.1.2.14/docs/build/doctrees/reference.doctree ADDED
Binary file (193 kB). View file
 
mujoco-py-2.1.2.14/mujoco_py.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LICENSE.md
2
+ MANIFEST.in
3
+ README.md
4
+ pyproject.toml
5
+ requirements.dev.txt
6
+ requirements.txt
7
+ setup.py
8
+ mujoco_py/__init__.py
9
+ mujoco_py/builder.py
10
+ mujoco_py/cymj.pyx
11
+ mujoco_py/mjbatchrenderer.pyx
12
+ mujoco_py/mjpid.pyx
13
+ mujoco_py/mjrendercontext.pyx
14
+ mujoco_py/mjrenderpool.py
15
+ mujoco_py/mjsim.pyx
16
+ mujoco_py/mjsimstate.pyx
17
+ mujoco_py/mjviewer.py
18
+ mujoco_py/modder.py
19
+ mujoco_py/opengl_context.pyx
20
+ mujoco_py/utils.py
21
+ mujoco_py/version.py
22
+ mujoco_py.egg-info/PKG-INFO
23
+ mujoco_py.egg-info/SOURCES.txt
24
+ mujoco_py.egg-info/dependency_links.txt
25
+ mujoco_py.egg-info/requires.txt
26
+ mujoco_py.egg-info/top_level.txt
27
+ mujoco_py/generated/__init__.py
28
+ mujoco_py/generated/const.py
29
+ mujoco_py/generated/wrappers.pxi
30
+ mujoco_py/gl/__init__.py
31
+ mujoco_py/gl/dummyshim.c
32
+ mujoco_py/gl/egl.h
33
+ mujoco_py/gl/eglext.h
34
+ mujoco_py/gl/eglplatform.h
35
+ mujoco_py/gl/eglshim.c
36
+ mujoco_py/gl/glshim.h
37
+ mujoco_py/gl/khrplatform.h
38
+ mujoco_py/gl/osmesashim.c
39
+ mujoco_py/pxd/__init__.py
40
+ mujoco_py/pxd/mjdata.pxd
41
+ mujoco_py/pxd/mjmodel.pxd
42
+ mujoco_py/pxd/mjrender.pxd
43
+ mujoco_py/pxd/mjui.pxd
44
+ mujoco_py/pxd/mjvisualize.pxd
45
+ mujoco_py/pxd/mujoco.pxd
46
+ mujoco_py/tests/__init__.py
47
+ mujoco_py/tests/include.xml
48
+ mujoco_py/tests/test.xml
49
+ mujoco_py/tests/test_composite.py
50
+ mujoco_py/tests/test_cymj.py
51
+ mujoco_py/tests/test_examples.py
52
+ mujoco_py/tests/test_gen_wrappers.py
53
+ mujoco_py/tests/test_modder.py
54
+ mujoco_py/tests/test_opengl_context.py
55
+ mujoco_py/tests/test_pid.py
56
+ mujoco_py/tests/test_render_pool.py
57
+ mujoco_py/tests/test_substep.py
58
+ mujoco_py/tests/test_vfs.py
59
+ mujoco_py/tests/test_viewer.py
60
+ mujoco_py/tests/utils.py
61
+ xmls/claw.xml
62
+ xmls/door.xml
63
+ xmls/juggler.xml
64
+ xmls/key.xml
65
+ xmls/shelf.xml
66
+ xmls/slider.xml
67
+ xmls/tosser.xml
mujoco-py-2.1.2.14/mujoco_py/__pycache__/builder.cpython-38.pyc ADDED
Binary file (16.9 kB). View file
 
mujoco-py-2.1.2.14/mujoco_py/__pycache__/mjviewer.cpython-38.pyc ADDED
Binary file (10.9 kB). View file
 
mujoco-py-2.1.2.14/mujoco_py/builder.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import distutils
2
+ import glob
3
+ import os
4
+ import shutil
5
+ import subprocess
6
+ import sys
7
+ from distutils.core import Extension
8
+ from distutils.dist import Distribution
9
+ from distutils.sysconfig import customize_compiler
10
+ from importlib.machinery import ExtensionFileLoader
11
+ from os.path import abspath, dirname, exists, join, getmtime
12
+ from random import choice
13
+ from shutil import move
14
+ from string import ascii_lowercase
15
+
16
+ import fasteners
17
+ import numpy as np
18
+ from Cython.Build import cythonize
19
+ from Cython.Distutils.old_build_ext import old_build_ext as build_ext
20
+ from cffi import FFI
21
+
22
+ from mujoco_py.utils import discover_mujoco
23
+ from mujoco_py.version import get_version
24
+
25
+
26
def get_nvidia_lib_dir():
    """Locate the directory containing the NVIDIA driver libraries.

    Returns the first existing candidate directory, or None when no NVIDIA
    driver is detected (nvidia-smi absent) or no library directory exists.
    """
    # `type` is a shell builtin, so this probe must run through a shell.
    have_nvidia_smi = subprocess.call(
        "type nvidia-smi", shell=True,
        stdout=subprocess.PIPE, stderr=subprocess.PIPE) == 0
    if not have_nvidia_smi:
        return None

    # Well-known fixed locations, checked in priority order
    # (docker image layout first, then the plain distro path).
    for candidate in ('/usr/local/nvidia/lib64', '/usr/lib/nvidia'):
        if exists(candidate):
            return candidate

    # Fall back to versioned driver directories; pick the newest.
    versioned = sorted(glob.glob('/usr/lib/nvidia-[0-9][0-9][0-9]'))
    if not versioned:
        return None
    if len(versioned) > 1:
        print("Choosing the latest nvidia driver: %s, among %s" % (versioned[-1], str(versioned)))
    return versioned[-1]
47
+
48
+
49
def load_cython_ext(mujoco_path):
    """
    Loads the cymj Cython extension. This is safe to be called from
    multiple processes running on the same machine.

    Cython only gives us back the raw path, regardless of whether
    it found a cached version or actually compiled. Since we do
    non-idempotent postprocessing of the DLL, be extra careful
    to only do that once and then atomically move to the final
    location.
    """
    # MuJoCo bundles its own GLFW; a previously-imported system glfw module
    # can shadow it and break rendering, so warn early.
    if ('glfw' in sys.modules and
            'mujoco' in abspath(sys.modules["glfw"].__file__)):
        print('''
WARNING: Existing glfw python module detected!

MuJoCo comes with its own version of GLFW, so it's preferable to use that one.

The easy solution is to `import mujoco_py` _before_ `import glfw`.
''')

    lib_path = os.path.join(mujoco_path, "bin")
    # Pick the platform-specific builder class.
    if sys.platform == 'darwin':
        Builder = MacExtensionBuilder
    elif sys.platform == 'linux':
        _ensure_set_env_var("LD_LIBRARY_PATH", lib_path)
        if os.getenv('MUJOCO_PY_FORCE_CPU') is None and get_nvidia_lib_dir() is not None:
            _ensure_set_env_var("LD_LIBRARY_PATH", get_nvidia_lib_dir())
            Builder = LinuxGPUExtensionBuilder
        else:
            # NOTE(review): upstream mujoco-py selects LinuxCPUExtensionBuilder
            # here; this vendored copy uses the GPU (EGL) builder in both
            # branches — presumably a deliberate VRL3 patch to force GPU
            # rendering. Confirm before "fixing".
            Builder = LinuxGPUExtensionBuilder
    elif sys.platform.startswith("win"):
        var = "PATH"
        if var not in os.environ or lib_path not in os.environ[var].split(";"):
            raise Exception("Please add mujoco library to your PATH:\n"
                            "set %s=%s;%%%s%%" % (var, lib_path, var))
        Builder = WindowsExtensionBuilder
    else:
        raise RuntimeError("Unsupported platform %s" % sys.platform)

    builder = Builder(mujoco_path)
    cext_so_path = builder.get_so_file_path()

    # Cross-process lock: several workers may try to build simultaneously.
    lockpath = os.path.join(os.path.dirname(cext_so_path), 'mujocopy-buildlock')

    with fasteners.InterProcessLock(lockpath):
        mod = None
        force_rebuild = os.environ.get('MUJOCO_PY_FORCE_REBUILD')
        if force_rebuild:
            # Try to remove the old file, ignore errors if it doesn't exist
            print("Removing old mujoco_py cext", cext_so_path)
            try:
                os.remove(cext_so_path)
            except OSError:
                pass
        if exists(cext_so_path):
            try:
                mod = load_dynamic_ext('cymj', cext_so_path)
            except ImportError:
                # Cached .so is stale/broken (e.g. built for another ABI).
                print("Import error. Trying to rebuild mujoco_py.")
        if mod is None:
            cext_so_path = builder.build()
            mod = load_dynamic_ext('cymj', cext_so_path)

    return mod
114
+
115
+
116
def _ensure_set_env_var(var_name, lib_path):
    """Raise unless *lib_path* already appears in the env var *var_name*.

    The variable is treated as a ':'-separated path list; entries are
    normalized with abspath before the membership check.
    """
    current_value = os.environ.get(var_name, "")
    entries = {os.path.abspath(p) for p in current_value.split(":")}
    if lib_path in entries:
        return
    raise Exception("\nMissing path to your environment variable. \n"
                    "Current values %s=%s\n"
                    "Please add following line to .bashrc:\n"
                    "export %s=$%s:%s" % (var_name, current_value,
                                          var_name, var_name, lib_path))
125
+
126
+
127
def load_dynamic_ext(name, path):
    """Load a compiled shared object from *path* and return it as a module.

    Fix: the previous implementation used ``Loader.load_module()``, which is
    deprecated since Python 3.4 and removed in 3.12. This version goes
    through the spec-based import machinery while preserving the old
    behavior of registering the module in ``sys.modules`` under *name*.
    """
    import importlib.util  # local import: keeps this fix self-contained

    loader = ExtensionFileLoader(name, path)
    spec = importlib.util.spec_from_file_location(name, path, loader=loader)
    # For extension modules, module creation performs the actual dlopen and
    # raises ImportError on failure, just like load_module() used to.
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module  # load_module() published the module too
    try:
        loader.exec_module(module)
    except BaseException:
        # Don't leave a half-initialized module behind on failure.
        sys.modules.pop(name, None)
        raise
    return module
131
+
132
+
133
class custom_build_ext(build_ext):
    """
    Custom build_ext to suppress the "-Wstrict-prototypes" warning.
    It arises from the fact that we're using C++. This seems to be
    the cleanest way to get rid of the extra flag.

    See http://stackoverflow.com/a/36293331/248400
    """

    def build_extensions(self):
        # Apply distutils' environment-driven compiler customizations
        # (CC/CFLAGS etc.) before editing the flag list.
        customize_compiler(self.compiler)

        try:
            # compiler_so may not exist on every compiler class
            # (AttributeError) and the flag may be absent (ValueError);
            # either way there is nothing to remove.
            self.compiler.compiler_so.remove("-Wstrict-prototypes")
        except (AttributeError, ValueError):
            pass
        build_ext.build_extensions(self)
150
+
151
+
152
def fix_shared_library(so_file, name, library_path):
    """Rewrite the dynamic dependencies of *so_file* on Linux.

    Strips the rpath with patchelf, then — if *name* appears in the ldd
    output — replaces that DT_NEEDED entry with *library_path*.
    """
    patch = subprocess.check_call
    patch(['patchelf', '--remove-rpath', so_file])

    dependencies = subprocess.check_output(['ldd', so_file]).decode('utf-8')
    if name not in dependencies:
        return
    patch(['patchelf', '--remove-needed', name, so_file])
    patch(['patchelf', '--add-needed', library_path, so_file])
160
+
161
+
162
def manually_link_libraries(mujoco_path, raw_cext_dll_path):
    """ Used to fix mujoco library linking on Mac.

    Rewrites the install names of libmujoco210.dylib and libglfw.3.dylib
    inside a copy of *raw_cext_dll_path* so they resolve to the MuJoCo
    bin directory, then atomically publishes the result as
    ``<root>_final<ext>`` and returns that path.
    """
    root, ext = os.path.splitext(raw_cext_dll_path)
    final_cext_dll_path = root + '_final' + ext

    # If someone else already built the final DLL, don't bother
    # recreating it here, even though this should still be idempotent.
    if (exists(final_cext_dll_path) and
            getmtime(final_cext_dll_path) >= getmtime(raw_cext_dll_path)):
        return final_cext_dll_path

    # Work on a temp copy; only rename into place once fully patched,
    # so readers never observe a half-patched dylib.
    tmp_final_cext_dll_path = final_cext_dll_path + '~'
    shutil.copyfile(raw_cext_dll_path, tmp_final_cext_dll_path)

    mj_bin_path = join(mujoco_path, 'bin')

    # Fix the rpath of the generated library -- i lost the Stackoverflow
    # reference here
    from_mujoco_path = '@executable_path/libmujoco210.dylib'
    to_mujoco_path = '%s/libmujoco210.dylib' % mj_bin_path
    subprocess.check_call(['install_name_tool',
                           '-change',
                           from_mujoco_path,
                           to_mujoco_path,
                           tmp_final_cext_dll_path])

    # Same treatment for the GLFW bundled with MuJoCo.
    from_glfw_path = 'libglfw.3.dylib'
    to_glfw_path = os.path.join(mj_bin_path, 'libglfw.3.dylib')
    subprocess.check_call(['install_name_tool',
                           '-change',
                           from_glfw_path,
                           to_glfw_path,
                           tmp_final_cext_dll_path])

    # os.rename is atomic on the same filesystem.
    os.rename(tmp_final_cext_dll_path, final_cext_dll_path)
    return final_cext_dll_path
198
+
199
+
200
class MujocoExtensionBuilder():
    """Base class for building the cymj Cython extension.

    Subclasses customize the Extension (sources, libraries, macros) per
    platform/rendering backend and may post-process the built .so in
    ``_build_impl``.
    """

    # Directory containing cymj.pyx (this package's directory).
    CYMJ_DIR_PATH = abspath(dirname(__file__))

    def __init__(self, mujoco_path):
        self.mujoco_path = mujoco_path
        python_version = str(sys.version_info.major) + str(sys.version_info.minor)
        # Version string keys the build cache: mujoco_py version, python
        # version, and builder-class name (so CPU/GPU builds don't collide).
        self.version = '%s_%s_%s' % (get_version(), python_version, self.build_base())
        self.extension = Extension(
            'mujoco_py.cymj',
            sources=[join(self.CYMJ_DIR_PATH, "cymj.pyx")],
            include_dirs=[
                self.CYMJ_DIR_PATH,
                join(mujoco_path, 'include'),
                np.get_include(),
            ],
            libraries=['mujoco210'],
            library_dirs=[join(mujoco_path, 'bin')],
            extra_compile_args=[
                '-fopenmp',  # needed for OpenMP
                '-w',  # suppress numpy compilation warnings
            ],
            extra_link_args=['-fopenmp'],
            language='c')

    def build(self):
        """Build the extension and move it to its canonical cached path."""
        built_so_file_path = self._build_impl()
        new_so_file_path = self.get_so_file_path()
        move(built_so_file_path, new_so_file_path)
        return new_so_file_path

    def build_base(self):
        # Lower-cased subclass name, used as the build-cache discriminator.
        return self.__class__.__name__.lower()

    def _build_impl(self):
        """Run cythonize + build_ext in-process and return the built .so path."""
        dist = Distribution({
            "script_name": None,
            "script_args": ["build_ext"]
        })
        dist.ext_modules = cythonize([self.extension])
        dist.include_dirs = []
        dist.cmdclass = {'build_ext': custom_build_ext}
        build = dist.get_command_obj('build')
        # following the convention of cython's pyxbuild and naming
        # base directory "_pyxbld"
        build.build_base = join(self.CYMJ_DIR_PATH, 'generated',
                                '_pyxbld_%s' % (self.version))
        dist.parse_command_line()
        obj_build_ext = dist.get_command_obj("build_ext")
        dist.run_commands()
        built_so_file_path, = obj_build_ext.get_outputs()
        return built_so_file_path

    def get_so_file_path(self):
        """Canonical cached location of the built extension.

        NOTE(review): self.version already embeds python_version, so the
        filename ends up containing it twice — presumably harmless but
        confirm before changing (existing caches use this naming).
        """
        dir_path = abspath(dirname(__file__))
        python_version = str(sys.version_info.major) + str(sys.version_info.minor)
        return join(dir_path, "generated", "cymj_{}_{}.so".format(self.version, python_version))
257
+
258
+
259
class WindowsExtensionBuilder(MujocoExtensionBuilder):
    """Extension builder for Windows: adds the dummy GL shim and puts the
    MuJoCo bin directory on PATH so the DLLs resolve during the build."""

    def __init__(self, mujoco_path):
        super().__init__(mujoco_path)
        bin_dir = join(mujoco_path, "bin")
        os.environ["PATH"] = os.environ["PATH"] + ";" + bin_dir
        self.extension.sources.append(self.CYMJ_DIR_PATH + "/gl/dummyshim.c")
265
+
266
+
267
class LinuxCPUExtensionBuilder(MujocoExtensionBuilder):
    """Linux builder for software rendering via OSMesa.

    Compiles the OSMesa GL shim and patches the built .so so its library
    dependencies are plain sonames (resolved dynamically at load time).
    """

    def __init__(self, mujoco_path):
        super().__init__(mujoco_path)

        self.extension.sources.append(
            join(self.CYMJ_DIR_PATH, "gl", "osmesashim.c"))
        self.extension.libraries.extend(['glewosmesa', 'OSMesa', 'GL'])
        self.extension.runtime_library_dirs = [join(mujoco_path, 'bin')]

    def _build_impl(self):
        so_file_path = super()._build_impl()
        # Removes absolute paths to libraries. Allows for dynamic loading.
        fix_shared_library(so_file_path, 'libmujoco210.so', 'libmujoco210.so')
        fix_shared_library(so_file_path, 'libglewosmesa.so', 'libglewosmesa.so')
        return so_file_path
283
+
284
+
285
class LinuxGPUExtensionBuilder(MujocoExtensionBuilder):
    """Linux builder for GPU rendering via EGL.

    Compiles the EGL GL shim (with the vendored EGL headers) and rewrites
    the built .so's dependencies to versioned sonames so the dynamic
    loader resolves the system GL/EGL libraries.
    """

    def __init__(self, mujoco_path):
        super().__init__(mujoco_path)

        self.extension.sources.append(self.CYMJ_DIR_PATH + "/gl/eglshim.c")
        self.extension.include_dirs.append(self.CYMJ_DIR_PATH + '/vendor/egl')
        self.extension.libraries.extend(['glewegl'])
        self.extension.runtime_library_dirs = [join(mujoco_path, 'bin')]

    def _build_impl(self):
        so_file_path = super()._build_impl()
        # Replace link-time names with the versioned runtime sonames.
        fix_shared_library(so_file_path, 'libOpenGL.so', 'libOpenGL.so.0')
        fix_shared_library(so_file_path, 'libEGL.so', 'libEGL.so.1')
        fix_shared_library(so_file_path, 'libmujoco210.so', 'libmujoco210.so')
        fix_shared_library(so_file_path, 'libglewegl.so', 'libglewegl.so')
        return so_file_path
302
+
303
+
304
class MacExtensionBuilder(MujocoExtensionBuilder):
    """Extension builder for macOS.

    Links MuJoCo's bundled GLFW, defines ONMAC, and fixes the dylib
    install names of the built extension via manually_link_libraries.
    """

    def __init__(self, mujoco_path):
        super().__init__(mujoco_path)

        self.extension.sources.append(self.CYMJ_DIR_PATH + "/gl/dummyshim.c")
        self.extension.libraries.extend(['glfw.3'])
        self.extension.define_macros = [('ONMAC', None)]
        self.extension.runtime_library_dirs = [join(mujoco_path, 'bin')]

    def _build_impl(self):
        """Build the extension, auto-selecting a GCC when $CC is unset.

        Returns the path of the install-name-fixed dylib.
        """
        if not os.environ.get('CC'):
            # Known-working versions of GCC on mac (prefer latest one)
            c_compilers = [
                '/usr/local/bin/gcc-9',
                '/usr/local/bin/gcc-8',
                '/usr/local/bin/gcc-7',
                '/usr/local/bin/gcc-6',
                '/opt/local/bin/gcc-mp-9',
                '/opt/local/bin/gcc-mp-8',
                '/opt/local/bin/gcc-mp-7',
                '/opt/local/bin/gcc-mp-6',
            ]
            available_c_compiler = None
            for c_compiler in c_compilers:
                if distutils.spawn.find_executable(c_compiler) is not None:
                    available_c_compiler = c_compiler
                    break
            if available_c_compiler is None:
                raise RuntimeError(
                    'Could not find supported GCC executable.\n\n'
                    'HINT: On OS X, install GCC 9.x with '
                    '`brew install gcc@9`. or '
                    '`port install gcc9`.')
            os.environ['CC'] = available_c_compiler

            # BUGFIX: previously `del os.environ['CC']` ran only on success,
            # so a failed build leaked the temporary CC into the process
            # environment. try/finally guarantees the cleanup.
            try:
                so_file_path = super()._build_impl()
            finally:
                del os.environ['CC']
        else:  # User-directed c compiler
            so_file_path = super()._build_impl()
        return manually_link_libraries(self.mujoco_path, so_file_path)
345
+
346
+
347
class MujocoException(Exception):
    """Raised when a MuJoCo warning is escalated to an error."""
    pass


def user_warning_raise_exception(warn_bytes):
    '''
    User-defined warning callback, which is called by mujoco on warnings.
    Here we have two primary jobs:
        - Detect known warnings and suggest fixes (with code)
        - Decide whether to raise an Exception and raise if needed
    More cases should be added as we find new failures.
    '''
    # TODO: look through test output to see MuJoCo warnings to catch
    # and recommend. Also fix those tests
    message = warn_bytes.decode()  # Convert bytes to string
    known_warnings = (
        ('Pre-allocated constraint buffer is full', 'Increase njmax in mujoco XML'),
        ('Pre-allocated contact buffer is full', 'Increase njconmax in mujoco XML'),
        # This unhelpfully-named warning is what you get if you feed MuJoCo NaNs
        ('Unknown warning type', 'Check for NaN in simulation.'),
    )
    for needle, hint in known_warnings:
        if needle in message:
            raise MujocoException(message + hint)
    raise MujocoException('Got MuJoCo Warning: {}'.format(message))


def user_warning_ignore_exception(warn_bytes):
    """Warning callback that silently discards MuJoCo warnings."""
    pass
374
+
375
+
376
class ignore_mujoco_warnings:
    """
    Class to turn off mujoco warning exceptions within a scope. Useful for
    large, vectorized rollouts.

    Swaps cymj's warning callback for the no-op one on entry and restores
    whatever callback was previously installed on exit.
    """

    def __enter__(self):
        # Remember the currently installed callback so nesting works.
        self.prev_user_warning = cymj.get_warning_callback()
        cymj.set_warning_callback(user_warning_ignore_exception)
        return self

    def __exit__(self, type, value, traceback):
        # Restore unconditionally; exceptions (if any) propagate.
        cymj.set_warning_callback(self.prev_user_warning)
389
+
390
+
391
def build_fn_cleanup(name):
    '''
    Cleanup files generated by building callback.
    Set the MUJOCO_PY_DEBUG_FN_BUILDER environment variable to disable cleanup.
    '''
    if os.environ.get('MUJOCO_PY_DEBUG_FN_BUILDER', False):
        return
    for artifact in glob.glob(name + '*'):
        try:
            os.remove(artifact)
        except PermissionError as e:
            # This happens trying to remove libraries on appveyor
            print('Error removing {}, continuing anyway: {}'.format(artifact, e))
403
+
404
+
405
def build_callback_fn(function_string, userdata_names=()):
    '''
    Builds a C callback function and returns a function pointer int.

    function_string : str
        This is a string of the C function to be compiled
    userdata_names : list or tuple
        This is an optional list to define convenience names
        (defaults to an immutable empty tuple; was previously a mutable
        list default, which is a Python anti-pattern)

    We compile and link and load the function, and return a function pointer.
    See `MjSim.set_substep_callback()` for an example use of these callbacks.

    The callback function should match the signature:
        void fun(const mjModel *m, mjData *d);

    Here's an example function_string:
        ```
        """
        #include <stdio.h>
        void fun(const mjModel* m, mjData* d) {
            printf("hello");
        }
        """
        ```

    Input and output for the function pass through userdata in the data struct:
        ```
        """
        void fun(const mjModel* m, mjData* d) {
            d->userdata[0] += 1;
        }
        """
        ```

    `userdata_names` is expected to match the model where the callback is used.
    These can be set on a model with:
        `model.set_userdata_names([...])`

    If `userdata_names` is supplied, convenience `#define`s are added for each.
    For example:
        `userdata_names = ['my_sum']`
    Will get generated into the extra line:
        `#define my_sum d->userdata[0]`
    And prepended to the top of the function before compilation.
    Here's an example that takes advantage of this:
        ```
        """
        void fun(const mjModel* m, mjData* d) {
            for (int i = 0; i < m->nu; i++) {
                my_sum += d->ctrl[i];
            }
        }
        """
        ```
    Note these are just C `#define`s and are limited in how they can be used.

    After compilation, the built library containing the function is loaded
    into memory and all of the files (including the library) are deleted.
    To retain these for debugging set the `MUJOCO_PY_DEBUG_FN_BUILDER` envvar.

    To save time compiling, these function pointers may be re-used by many
    different consumers.  They are thread-safe and don't acquire the GIL.

    See the file `tests/test_substep.py` for additional examples,
    including an example which iterates over contacts to compute penetrations.
    '''
    assert isinstance(userdata_names, (list, tuple)), \
        'invalid userdata_names: {}'.format(userdata_names)
    ffibuilder = FFI()
    ffibuilder.cdef('extern uintptr_t __fun;')
    # Random module name so repeated builds in one process never collide.
    name = '_fn_' + ''.join(choice(ascii_lowercase) for _ in range(15))
    source_string = '#include <mujoco.h>\n'
    # Add defines for each userdata to make setting them easier
    for i, data_name in enumerate(userdata_names):
        source_string += '#define {} d->userdata[{}]\n'.format(data_name, i)
    source_string += function_string
    source_string += '\nuintptr_t __fun = (uintptr_t) fun;'
    # Link against mujoco so we can call mujoco functions from within callback
    ffibuilder.set_source(name, source_string,
                          include_dirs=[join(mujoco_path, 'include')],
                          library_dirs=[join(mujoco_path, 'bin')],
                          libraries=['mujoco210'])
    # Catch compilation exceptions so we can cleanup partial files in that case
    try:
        library_path = ffibuilder.compile(verbose=True)
    except Exception:
        build_fn_cleanup(name)
        raise  # bare raise preserves the original traceback (was `raise e`)
    # On Mac the MuJoCo library is linked strangely, so we have to fix it here
    if sys.platform == 'darwin':
        fixed_library_path = manually_link_libraries(mujoco_path, library_path)
        move(fixed_library_path, library_path)  # Overwrite with fixed library
    module = load_dynamic_ext(name, library_path)
    # Now that the module is loaded into memory, we can actually delete it
    build_fn_cleanup(name)
    return module.lib.__fun
501
+
502
+
503
# Import-time initialization: resolve the MuJoCo install location and
# build/load the cymj extension. `cymj` is the compiled core that the
# rest of mujoco_py wraps.
# NOTE(review): discover_mujoco() is assumed to return the MuJoCo
# directory path — confirm against mujoco_py.utils.
mujoco_path = discover_mujoco()
cymj = load_cython_ext(mujoco_path)


# Trick to expose all mj* functions from mujoco in mujoco_py.*
class dict2(object):
    # Plain attribute bag; attributes are attached dynamically below.
    pass


# Re-export every cymj attribute whose name starts with "_mj" on a plain
# namespace object, dropping the leading underscore (e.g. _mj_step -> mj_step).
functions = dict2()
for func_name in dir(cymj):
    if func_name.startswith("_mj"):
        setattr(functions, func_name[1:], getattr(cymj, func_name))

# Set user-defined callbacks that raise assertion with message
cymj.set_warning_callback(user_warning_raise_exception)
mujoco-py-2.1.2.14/mujoco_py/gl/eglplatform.h ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#ifndef __eglplatform_h_
#define __eglplatform_h_

/* NOTE(review): vendored verbatim from the Khronos EGL registry; change
 * only in sync with upstream, not locally. */

/*
** Copyright (c) 2007-2013 The Khronos Group Inc.
**
** Permission is hereby granted, free of charge, to any person obtaining a
** copy of this software and/or associated documentation files (the
** "Materials"), to deal in the Materials without restriction, including
** without limitation the rights to use, copy, modify, merge, publish,
** distribute, sublicense, and/or sell copies of the Materials, and to
** permit persons to whom the Materials are furnished to do so, subject to
** the following conditions:
**
** The above copyright notice and this permission notice shall be included
** in all copies or substantial portions of the Materials.
**
** THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
** EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
** MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
** IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
** CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
** TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
** MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
*/

/* Platform-specific types and definitions for egl.h
 * $Revision: 30994 $ on $Date: 2015-04-30 13:36:48 -0700 (Thu, 30 Apr 2015) $
 *
 * Adopters may modify khrplatform.h and this file to suit their platform.
 * You are encouraged to submit all modifications to the Khronos group so that
 * they can be included in future versions of this file.  Please submit changes
 * by sending them to the public Khronos Bugzilla (http://khronos.org/bugzilla)
 * by filing a bug against product "EGL" component "Registry".
 */

#include "khrplatform.h"

/* Macros used in EGL function prototype declarations.
 *
 * EGL functions should be prototyped as:
 *
 * EGLAPI return-type EGLAPIENTRY eglFunction(arguments);
 * typedef return-type (EXPAPIENTRYP PFNEGLFUNCTIONPROC) (arguments);
 *
 * KHRONOS_APICALL and KHRONOS_APIENTRY are defined in KHR/khrplatform.h
 */

#ifndef EGLAPI
#define EGLAPI KHRONOS_APICALL
#endif

#ifndef EGLAPIENTRY
#define EGLAPIENTRY  KHRONOS_APIENTRY
#endif
#define EGLAPIENTRYP EGLAPIENTRY*

/* The types NativeDisplayType, NativeWindowType, and NativePixmapType
 * are aliases of window-system-dependent types, such as X Display * or
 * Windows Device Context. They must be defined in platform-specific
 * code below. The EGL-prefixed versions of Native*Type are the same
 * types, renamed in EGL 1.3 so all types in the API start with "EGL".
 *
 * Khronos STRONGLY RECOMMENDS that you use the default definitions
 * provided below, since these changes affect both binary and source
 * portability of applications using EGL running on different EGL
 * implementations.
 */

#if defined(_WIN32) || defined(__VC32__) && !defined(__CYGWIN__) && !defined(__SCITECH_SNAP__) /* Win32 and WinCE */
#ifndef WIN32_LEAN_AND_MEAN
#define WIN32_LEAN_AND_MEAN 1
#endif
#include <windows.h>

typedef HDC     EGLNativeDisplayType;
typedef HBITMAP EGLNativePixmapType;
typedef HWND    EGLNativeWindowType;

#elif defined(__APPLE__) || defined(__WINSCW__) || defined(__SYMBIAN32__)  /* Symbian */

typedef int   EGLNativeDisplayType;
typedef void *EGLNativeWindowType;
typedef void *EGLNativePixmapType;

#elif defined(__ANDROID__) || defined(ANDROID)

#include <android/native_window.h>

struct egl_native_pixmap_t;

typedef struct ANativeWindow*           EGLNativeWindowType;
typedef struct egl_native_pixmap_t*     EGLNativePixmapType;
typedef void*                           EGLNativeDisplayType;

#elif defined(__unix__)

/* X11 (tentative)  */
#include <X11/Xlib.h>
#include <X11/Xutil.h>

typedef Display *EGLNativeDisplayType;
typedef Pixmap   EGLNativePixmapType;
typedef Window   EGLNativeWindowType;

#else
#error "Platform not recognized"
#endif

/* EGL 1.2 types, renamed for consistency in EGL 1.3 */
typedef EGLNativeDisplayType NativeDisplayType;
typedef EGLNativePixmapType  NativePixmapType;
typedef EGLNativeWindowType  NativeWindowType;


/* Define EGLint. This must be a signed integral type large enough to contain
 * all legal attribute names and values passed into and out of EGL, whether
 * their type is boolean, bitmask, enumerant (symbolic constant), integer,
 * handle, or other.  While in general a 32-bit integer will suffice, if
 * handles are 64 bit types, then EGLint should be defined as a signed 64-bit
 * integer type.
 */
typedef khronos_int32_t EGLint;

#endif /* __eglplatform_h */
mujoco-py-2.1.2.14/mujoco_py/gl/glshim.h ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#ifndef __GLSHIM_H__
#define __GLSHIM_H__

#include "mujoco.h"
#include "mjrender.h"

#ifdef __cplusplus
extern "C" {
#endif

/* Backend-neutral GL shim interface. The implementation comes from either
 * eglshim.c (GPU/EGL) or osmesashim.c (software/OSMesa), chosen at build
 * time by the platform extension builders in builder.py. */

/* Presumably nonzero iff the EGL backend is compiled in — confirm against
 * eglshim.c / osmesashim.c. */
int usingEGL();
/* Context lifecycle per device; int returns look like status codes —
 * verify exact success/failure convention in the shim implementations. */
int initOpenGL(int device_id);
void closeOpenGL();
int makeOpenGLContextCurrent(int device_id);
int setOpenGLBufferSize(int device_id, int width, int height);

/* Pixel-buffer-object helpers used for batched offscreen readback
 * (see mjbatchrenderer.pyx). */
unsigned int createPBO(int width, int height, int batchSize, int use_short);
void freePBO(unsigned int pixelBuffer);
void copyFBOToPBO(mjrContext* con,
                  unsigned int pbo_rgb, unsigned int pbo_depth,
                  mjrRect viewport, int bufferOffset);
void readPBO(unsigned char *buffer_rgb, unsigned short *buffer_depth,
             unsigned int pbo_rgb, unsigned int pbo_depth,
             int width, int height, int batchSize);

#ifdef __cplusplus
}  // extern "C"
#endif

#endif
mujoco-py-2.1.2.14/mujoco_py/gl/khrplatform.h ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #ifndef __khrplatform_h_
2
+ #define __khrplatform_h_
3
+
4
+ /*
5
+ ** Copyright (c) 2008-2009 The Khronos Group Inc.
6
+ **
7
+ ** Permission is hereby granted, free of charge, to any person obtaining a
8
+ ** copy of this software and/or associated documentation files (the
9
+ ** "Materials"), to deal in the Materials without restriction, including
10
+ ** without limitation the rights to use, copy, modify, merge, publish,
11
+ ** distribute, sublicense, and/or sell copies of the Materials, and to
12
+ ** permit persons to whom the Materials are furnished to do so, subject to
13
+ ** the following conditions:
14
+ **
15
+ ** The above copyright notice and this permission notice shall be included
16
+ ** in all copies or substantial portions of the Materials.
17
+ **
18
+ ** THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
19
+ ** EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
20
+ ** MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
21
+ ** IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22
+ ** CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
23
+ ** TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
24
+ ** MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
25
+ */
26
+
27
+ /* Khronos platform-specific types and definitions.
28
+ *
29
+ * $Revision: 32517 $ on $Date: 2016-03-11 02:41:19 -0800 (Fri, 11 Mar 2016) $
30
+ *
31
+ * Adopters may modify this file to suit their platform. Adopters are
32
+ * encouraged to submit platform specific modifications to the Khronos
33
+ * group so that they can be included in future versions of this file.
34
+ * Please submit changes by sending them to the public Khronos Bugzilla
35
+ * (http://khronos.org/bugzilla) by filing a bug against product
36
+ * "Khronos (general)" component "Registry".
37
+ *
38
+ * A predefined template which fills in some of the bug fields can be
39
+ * reached using http://tinyurl.com/khrplatform-h-bugreport, but you
40
+ * must create a Bugzilla login first.
41
+ *
42
+ *
43
+ * See the Implementer's Guidelines for information about where this file
44
+ * should be located on your system and for more details of its use:
45
+ * http://www.khronos.org/registry/implementers_guide.pdf
46
+ *
47
+ * This file should be included as
48
+ * #include <KHR/khrplatform.h>
49
+ * by Khronos client API header files that use its types and defines.
50
+ *
51
+ * The types in khrplatform.h should only be used to define API-specific types.
52
+ *
53
+ * Types defined in khrplatform.h:
54
+ * khronos_int8_t signed 8 bit
55
+ * khronos_uint8_t unsigned 8 bit
56
+ * khronos_int16_t signed 16 bit
57
+ * khronos_uint16_t unsigned 16 bit
58
+ * khronos_int32_t signed 32 bit
59
+ * khronos_uint32_t unsigned 32 bit
60
+ * khronos_int64_t signed 64 bit
61
+ * khronos_uint64_t unsigned 64 bit
62
+ * khronos_intptr_t signed same number of bits as a pointer
63
+ * khronos_uintptr_t unsigned same number of bits as a pointer
64
+ * khronos_ssize_t signed size
65
+ * khronos_usize_t unsigned size
66
+ * khronos_float_t signed 32 bit floating point
67
+ * khronos_time_ns_t unsigned 64 bit time in nanoseconds
68
+ * khronos_utime_nanoseconds_t unsigned time interval or absolute time in
69
+ * nanoseconds
70
+ * khronos_stime_nanoseconds_t signed time interval in nanoseconds
71
+ * khronos_boolean_enum_t enumerated boolean type. This should
72
+ * only be used as a base type when a client API's boolean type is
73
+ * an enum. Client APIs which use an integer or other type for
74
+ * booleans cannot use this as the base type for their boolean.
75
+ *
76
+ * Tokens defined in khrplatform.h:
77
+ *
78
+ * KHRONOS_FALSE, KHRONOS_TRUE Enumerated boolean false/true values.
79
+ *
80
+ * KHRONOS_SUPPORT_INT64 is 1 if 64 bit integers are supported; otherwise 0.
81
+ * KHRONOS_SUPPORT_FLOAT is 1 if floats are supported; otherwise 0.
82
+ *
83
+ * Calling convention macros defined in this file:
84
+ * KHRONOS_APICALL
85
+ * KHRONOS_APIENTRY
86
+ * KHRONOS_APIATTRIBUTES
87
+ *
88
+ * These may be used in function prototypes as:
89
+ *
90
+ * KHRONOS_APICALL void KHRONOS_APIENTRY funcname(
91
+ * int arg1,
92
+ * int arg2) KHRONOS_APIATTRIBUTES;
93
+ */
94
+
95
+ /*-------------------------------------------------------------------------
96
+ * Definition of KHRONOS_APICALL
97
+ *-------------------------------------------------------------------------
98
+ * This precedes the return type of the function in the function prototype.
99
+ */
100
+ #if defined(_WIN32) && !defined(__SCITECH_SNAP__)
101
+ # define KHRONOS_APICALL __declspec(dllimport)
102
+ #elif defined (__SYMBIAN32__)
103
+ # define KHRONOS_APICALL IMPORT_C
104
+ #elif defined(__ANDROID__)
105
+ # include <sys/cdefs.h>
106
+ # define KHRONOS_APICALL __attribute__((visibility("default"))) __NDK_FPABI__
107
+ #else
108
+ # define KHRONOS_APICALL
109
+ #endif
110
+
111
+ /*-------------------------------------------------------------------------
112
+ * Definition of KHRONOS_APIENTRY
113
+ *-------------------------------------------------------------------------
114
+ * This follows the return type of the function and precedes the function
115
+ * name in the function prototype.
116
+ */
117
+ #if defined(_WIN32) && !defined(_WIN32_WCE) && !defined(__SCITECH_SNAP__)
118
+ /* Win32 but not WinCE */
119
+ # define KHRONOS_APIENTRY __stdcall
120
+ #else
121
+ # define KHRONOS_APIENTRY
122
+ #endif
123
+
124
+ /*-------------------------------------------------------------------------
125
+ * Definition of KHRONOS_APIATTRIBUTES
126
+ *-------------------------------------------------------------------------
127
+ * This follows the closing parenthesis of the function prototype arguments.
128
+ */
129
+ #if defined (__ARMCC_2__)
130
+ #define KHRONOS_APIATTRIBUTES __softfp
131
+ #else
132
+ #define KHRONOS_APIATTRIBUTES
133
+ #endif
134
+
135
+ /*-------------------------------------------------------------------------
136
+ * basic type definitions
137
+ *-----------------------------------------------------------------------*/
138
+ #if (defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L) || defined(__GNUC__) || defined(__SCO__) || defined(__USLC__)
139
+
140
+
141
+ /*
142
+ * Using <stdint.h>
143
+ */
144
+ #include <stdint.h>
145
+ typedef int32_t khronos_int32_t;
146
+ typedef uint32_t khronos_uint32_t;
147
+ typedef int64_t khronos_int64_t;
148
+ typedef uint64_t khronos_uint64_t;
149
+ #define KHRONOS_SUPPORT_INT64 1
150
+ #define KHRONOS_SUPPORT_FLOAT 1
151
+
152
+ #elif defined(__VMS ) || defined(__sgi)
153
+
154
+ /*
155
+ * Using <inttypes.h>
156
+ */
157
+ #include <inttypes.h>
158
+ typedef int32_t khronos_int32_t;
159
+ typedef uint32_t khronos_uint32_t;
160
+ typedef int64_t khronos_int64_t;
161
+ typedef uint64_t khronos_uint64_t;
162
+ #define KHRONOS_SUPPORT_INT64 1
163
+ #define KHRONOS_SUPPORT_FLOAT 1
164
+
165
+ #elif defined(_WIN32) && !defined(__SCITECH_SNAP__)
166
+
167
+ /*
168
+ * Win32
169
+ */
170
+ typedef __int32 khronos_int32_t;
171
+ typedef unsigned __int32 khronos_uint32_t;
172
+ typedef __int64 khronos_int64_t;
173
+ typedef unsigned __int64 khronos_uint64_t;
174
+ #define KHRONOS_SUPPORT_INT64 1
175
+ #define KHRONOS_SUPPORT_FLOAT 1
176
+
177
+ #elif defined(__sun__) || defined(__digital__)
178
+
179
+ /*
180
+ * Sun or Digital
181
+ */
182
+ typedef int khronos_int32_t;
183
+ typedef unsigned int khronos_uint32_t;
184
+ #if defined(__arch64__) || defined(_LP64)
185
+ typedef long int khronos_int64_t;
186
+ typedef unsigned long int khronos_uint64_t;
187
+ #else
188
+ typedef long long int khronos_int64_t;
189
+ typedef unsigned long long int khronos_uint64_t;
190
+ #endif /* __arch64__ */
191
+ #define KHRONOS_SUPPORT_INT64 1
192
+ #define KHRONOS_SUPPORT_FLOAT 1
193
+
194
+ #elif 0
195
+
196
+ /*
197
+ * Hypothetical platform with no float or int64 support
198
+ */
199
+ typedef int khronos_int32_t;
200
+ typedef unsigned int khronos_uint32_t;
201
+ #define KHRONOS_SUPPORT_INT64 0
202
+ #define KHRONOS_SUPPORT_FLOAT 0
203
+
204
+ #else
205
+
206
+ /*
207
+ * Generic fallback
208
+ */
209
+ #include <stdint.h>
210
+ typedef int32_t khronos_int32_t;
211
+ typedef uint32_t khronos_uint32_t;
212
+ typedef int64_t khronos_int64_t;
213
+ typedef uint64_t khronos_uint64_t;
214
+ #define KHRONOS_SUPPORT_INT64 1
215
+ #define KHRONOS_SUPPORT_FLOAT 1
216
+
217
+ #endif
218
+
219
+
220
+ /*
221
+ * Types that are (so far) the same on all platforms
222
+ */
223
+ typedef signed char khronos_int8_t;
224
+ typedef unsigned char khronos_uint8_t;
225
+ typedef signed short int khronos_int16_t;
226
+ typedef unsigned short int khronos_uint16_t;
227
+
228
+ /*
229
+ * Types that differ between LLP64 and LP64 architectures - in LLP64,
230
+ * pointers are 64 bits, but 'long' is still 32 bits. Win64 appears
231
+ * to be the only LLP64 architecture in current use.
232
+ */
233
+ #ifdef _WIN64
234
+ typedef signed long long int khronos_intptr_t;
235
+ typedef unsigned long long int khronos_uintptr_t;
236
+ typedef signed long long int khronos_ssize_t;
237
+ typedef unsigned long long int khronos_usize_t;
238
+ #else
239
+ typedef signed long int khronos_intptr_t;
240
+ typedef unsigned long int khronos_uintptr_t;
241
+ typedef signed long int khronos_ssize_t;
242
+ typedef unsigned long int khronos_usize_t;
243
+ #endif
244
+
245
+ #if KHRONOS_SUPPORT_FLOAT
246
+ /*
247
+ * Float type
248
+ */
249
+ typedef float khronos_float_t;
250
+ #endif
251
+
252
+ #if KHRONOS_SUPPORT_INT64
253
+ /* Time types
254
+ *
255
+ * These types can be used to represent a time interval in nanoseconds or
256
+ * an absolute Unadjusted System Time. Unadjusted System Time is the number
257
+ * of nanoseconds since some arbitrary system event (e.g. since the last
258
+ * time the system booted). The Unadjusted System Time is an unsigned
259
+ * 64 bit value that wraps back to 0 every 584 years. Time intervals
260
+ * may be either signed or unsigned.
261
+ */
262
+ typedef khronos_uint64_t khronos_utime_nanoseconds_t;
263
+ typedef khronos_int64_t khronos_stime_nanoseconds_t;
264
+ #endif
265
+
266
+ /*
267
+ * Dummy value used to pad enum types to 32 bits.
268
+ */
269
+ #ifndef KHRONOS_MAX_ENUM
270
+ #define KHRONOS_MAX_ENUM 0x7FFFFFFF
271
+ #endif
272
+
273
+ /*
274
+ * Enumerated boolean type
275
+ *
276
+ * Values other than zero should be considered to be true. Therefore
277
+ * comparisons should not be made against KHRONOS_TRUE.
278
+ */
279
+ typedef enum {
280
+ KHRONOS_FALSE = 0,
281
+ KHRONOS_TRUE = 1,
282
+ KHRONOS_BOOLEAN_ENUM_FORCE_SIZE = KHRONOS_MAX_ENUM
283
+ } khronos_boolean_enum_t;
284
+
285
+ #endif /* __khrplatform_h_ */
mujoco-py-2.1.2.14/mujoco_py/gl/osmesashim.c ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <GL/osmesa.h>
2
+ #include "glshim.h"
3
+
4
+ OSMesaContext ctx;
5
+
6
+ // this size was picked pretty arbitrarily
7
+ int BUFFER_WIDTH = 1024;
8
+ int BUFFER_HEIGHT = 1024;
9
+ // 4 channels for RGBA
10
+ unsigned char buffer[1024 * 1024 * 4];
11
+
12
+ int is_initialized = 0;
13
+
14
+ int usingEGL() {
15
+ return 0;
16
+ }
17
+
18
+ int initOpenGL(int device_id) {
19
+ if (is_initialized)
20
+ return 1;
21
+
22
+ // note: device id not used here
23
+ ctx = OSMesaCreateContextExt(GL_RGBA, 24, 8, 8, 0);
24
+ if( !ctx ) {
25
+ printf("OSMesa context creation failed\n");
26
+ return -1;
27
+ }
28
+
29
+ if( !OSMesaMakeCurrent(ctx, buffer, GL_UNSIGNED_BYTE, BUFFER_WIDTH, BUFFER_HEIGHT) ) {
30
+ printf("OSMesa make current failed\n");
31
+ return -1;
32
+ }
33
+
34
+ is_initialized = 1;
35
+
36
+ return 1;
37
+ }
38
+
39
+ int makeOpenGLContextCurrent(int device_id) {
40
+ // Don't need to make context current here, causes issues with large tests
41
+ return 1;
42
+ }
43
+
44
+ int setOpenGLBufferSize(int device_id, int width, int height) {
45
+ if (width > BUFFER_WIDTH || height > BUFFER_HEIGHT) {
46
+ printf("Buffer size too big\n");
47
+ return -1;
48
+ }
49
+ // Noop since we don't support changing the actual buffer
50
+ return 1;
51
+ }
52
+
53
+ void closeOpenGL() {
54
+ if (is_initialized) {
55
+ OSMesaDestroyContext(ctx);
56
+ is_initialized = 0;
57
+ }
58
+ }
59
+
60
+ unsigned int createPBO(int width, int height, int batchSize, int use_short) {
61
+ return 0;
62
+ }
63
+
64
+ void freePBO(unsigned int pixelBuffer) {
65
+ }
66
+
67
+ void copyFBOToPBO(mjrContext* con,
68
+ unsigned int pbo_rgb, unsigned int pbo_depth,
69
+ mjrRect viewport, int bufferOffset) {
70
+ }
71
+
72
+ void readPBO(unsigned char *buffer_rgb, unsigned short *buffer_depth,
73
+ unsigned int pbo_rgb, unsigned int pbo_depth,
74
+ int width, int height, int batchSize) {
75
+ }
mujoco-py-2.1.2.14/mujoco_py/mjbatchrenderer.pyx ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ try:
2
+ import pycuda.driver as drv
3
+ except ImportError:
4
+ drv = None
5
+
6
+
7
+ class MjBatchRendererException(Exception):
8
+ pass
9
+
10
+
11
+ class MjBatchRendererNotSupported(MjBatchRendererException):
12
+ pass
13
+
14
+
15
+ class CudaNotEnabledError(MjBatchRendererException):
16
+ pass
17
+
18
+
19
+ class CudaBufferNotMappedError(MjBatchRendererException):
20
+ pass
21
+
22
+
23
+ class CudaBufferMappedError(MjBatchRendererException):
24
+ pass
25
+
26
+
27
+ class MjBatchRenderer(object):
28
+ """
29
+ Utility class for rendering into OpenGL Pixel Buffer Objects (PBOs),
30
+ which allows for accessing multiple rendered images in batch.
31
+
32
+ If used with CUDA (i.e. initialized with use_cuda=True), you need
33
+ to call map/unmap when accessing CUDA buffer pointer. This is to
34
+ ensure that all OpenGL instructions have completed:
35
+
36
+ renderer = MjBatchRenderer(100, 100, use_cuda=True)
37
+ renderer.render(sim)
38
+
39
+ renderer.map()
40
+ image = renderer.read()
41
+ renderer.unmap()
42
+ """
43
+
44
+ def __init__(self, width, height, batch_size=1, device_id=0,
45
+ depth=False, use_cuda=False):
46
+ """
47
+ Args:
48
+ - width (int): Image width.
49
+ - height (int): Image height.
50
+ - batch_size (int): Size of batch to render into. Memory is
51
+ allocated once upon initialization of object.
52
+ - device_id (int): Device to use for storing the batch.
53
+ - depth (bool): if True, render depth in addition to RGB.
54
+ - use_cuda (bool): if True, use OpenGL-CUDA interop to map
55
+ the PBO onto a CUDA buffer.
56
+ """
57
+ # Early initialization to prevent failure in __del__
58
+ self._use_cuda = False
59
+ self.pbo_depth, self.pbo_depth = 0, 0
60
+
61
+ if not usingEGL():
62
+ raise MjBatchRendererNotSupported(
63
+ "MjBatchRenderer currently only supported with EGL-backed"
64
+ "rendering context.")
65
+
66
+ # Make sure OpenGL Context is available before creating PBOs
67
+ initOpenGL(device_id)
68
+ makeOpenGLContextCurrent(device_id)
69
+
70
+ self.pbo_rgb = createPBO(width, height, batch_size, 0)
71
+ self.pbo_depth = createPBO(width, height, batch_size, 1) if depth else 0
72
+
73
+ self._depth = depth
74
+ self._device_id = device_id
75
+ self._width = width
76
+ self._height = height
77
+ self._batch_size = batch_size
78
+ self._current_batch_offset = 0
79
+
80
+ self._use_cuda = use_cuda
81
+ self._cuda_buffers_are_mapped = False
82
+ self._cuda_rgb_ptr, self._cuda_depth_ptr = None, None
83
+ if use_cuda:
84
+ self._init_cuda()
85
+
86
+ def _init_cuda(self):
87
+ if drv is None:
88
+ raise ImportError("Failed to import pycuda.")
89
+ # Use local imports so that we don't have to make pycuda
90
+ # opengl interop a requirement
91
+ from pycuda.gl import RegisteredBuffer
92
+
93
+ drv.init()
94
+ device = drv.Device(self._device_id)
95
+ self._cuda_context = device.make_context()
96
+ self._cuda_context.push()
97
+
98
+ self._cuda_rgb_pbo = RegisteredBuffer(self.pbo_rgb)
99
+ if self._depth:
100
+ self._cuda_depth_pbo = RegisteredBuffer(self.pbo_depth)
101
+
102
+ def map(self):
103
+ """ Map OpenGL buffer to CUDA for reading. """
104
+ if not self._use_cuda:
105
+ raise CudaNotEnabledError()
106
+ elif self._cuda_buffers_are_mapped:
107
+ return # just make it a no-op
108
+
109
+ self._cuda_context.push()
110
+ self._cuda_rgb_mapping = self._cuda_rgb_pbo.map()
111
+ ptr, self._cuda_rgb_buf_size = (
112
+ self._cuda_rgb_mapping.device_ptr_and_size())
113
+ assert ptr is not None and self._cuda_rgb_buf_size > 0
114
+ if self._cuda_rgb_ptr is None:
115
+ self._cuda_rgb_ptr = ptr
116
+
117
+ # There doesn't seem to be a guarantee from the API that the
118
+ # pointer will be the same between mappings, but empirically
119
+ # this has been true. If this isn't true, we need to modify
120
+ # the interface to MjBatchRenderer to make this clearer to user.
121
+ # So, hopefully we won't hit this assert.
122
+ assert self._cuda_rgb_ptr == ptr, (
123
+ "Mapped CUDA rgb buffer pointer %d doesn't match old pointer %d" %
124
+ (ptr, self._cuda_rgb_ptr))
125
+
126
+ if self._depth:
127
+ self._cuda_depth_mapping = self._cuda_depth_pbo.map()
128
+ ptr, self._cuda_depth_buf_size = (
129
+ self._cuda_depth_mapping.device_ptr_and_size())
130
+ assert ptr is not None and self._cuda_depth_buf_size > 0
131
+ if self._cuda_depth_ptr is None:
132
+ self._cuda_depth_ptr = ptr
133
+ assert self._cuda_depth_ptr == ptr, (
134
+ "Mapped CUDA depth buffer pointer %d doesn't match old pointer %d" %
135
+ (ptr, self._cuda_depth_ptr))
136
+
137
+ self._cuda_buffers_are_mapped = True
138
+
139
+ def unmap(self):
140
+ """ Unmap OpenGL buffer from CUDA so that it can be rendered into. """
141
+ if not self._use_cuda:
142
+ raise CudaNotEnabledError()
143
+ elif not self._cuda_buffers_are_mapped:
144
+ return # just make it a no-op
145
+
146
+ self._cuda_context.push()
147
+ self._cuda_rgb_mapping.unmap()
148
+ self._cuda_rgb_mapping = None
149
+ self._cuda_rgb_ptr = None
150
+ if self._depth:
151
+ self._cuda_depth_mapping.unmap()
152
+ self._cuda_depth_mapping = None
153
+ self._cuda_depth_ptr = None
154
+
155
+ self._cuda_buffers_are_mapped = False
156
+
157
+ def prepare_render_context(self, sim):
158
+ """
159
+ Set up the rendering context for an MjSim. Also happens automatically
160
+ on `.render()`.
161
+ """
162
+ for c in sim.render_contexts:
163
+ if (c.offscreen and
164
+ isinstance(c.opengl_context, OffscreenOpenGLContext) and
165
+ c.opengl_context.device_id == self._device_id):
166
+ return c
167
+
168
+ return MjRenderContext(sim, device_id=self._device_id)
169
+
170
+ def render(self, sim, camera_id=None, batch_offset=None):
171
+ """
172
+ Render current scene from the MjSim into the buffer. By
173
+ default the batch offset is automatically incremented with
174
+ each call. It can be reset with the batch_offset parameter.
175
+
176
+ This method doesn't return anything. Use the `.read` method
177
+ to read the buffer, or access the buffer pointer directly with
178
+ e.g. `.cuda_rgb_buffer_pointer` accessor.
179
+
180
+ Args:
181
+ - sim (MjSim): The simulator to use for rendering.
182
+ - camera_id (int): MuJoCo id for the camera, from
183
+ `sim.model.camera_name2id()`.
184
+ - batch_offset (int): offset in batch to render to.
185
+ """
186
+ if self._use_cuda and self._cuda_buffers_are_mapped:
187
+ raise CudaBufferMappedError(
188
+ "CUDA buffers must be unmapped before calling render.")
189
+
190
+ if batch_offset is not None:
191
+ if batch_offset < 0 or batch_offset >= self._batch_size:
192
+ raise ValueError("batch_offset out of range")
193
+ self._current_batch_offset = batch_offset
194
+
195
+ # Ensure the correct device context is used (this takes ~1 µs)
196
+ makeOpenGLContextCurrent(self._device_id)
197
+
198
+ render_context = self.prepare_render_context(sim)
199
+ render_context.update_offscreen_size(self._width, self._height)
200
+ render_context.render(self._width, self._height, camera_id=camera_id)
201
+
202
+ cdef mjrRect viewport
203
+ viewport.left = 0
204
+ viewport.bottom = 0
205
+ viewport.width = self._width
206
+ viewport.height = self._height
207
+
208
+ cdef PyMjrContext con = <PyMjrContext> render_context.con
209
+ copyFBOToPBO(con.ptr, self.pbo_rgb, self.pbo_depth,
210
+ viewport, self._current_batch_offset)
211
+
212
+ self._current_batch_offset = (self._current_batch_offset + 1) % self._batch_size
213
+
214
+ def read(self):
215
+ """
216
+ Transfer a copy of the buffer from the GPU to the CPU as a numpy array.
217
+
218
+ Returns:
219
+ - rgb_batch (numpy array): batch of rgb images in uint8 NHWC format
220
+ - depth_batch (numpy array): batch of depth images in uint16 NHWC format
221
+ """
222
+ if self._use_cuda:
223
+ return self._read_cuda()
224
+ else:
225
+ return self._read_nocuda()
226
+
227
+ def _read_cuda(self):
228
+ if not self._cuda_buffers_are_mapped:
229
+ raise CudaBufferNotMappedError(
230
+ "CUDA buffers must be mapped before reading")
231
+
232
+ rgb_arr = drv.from_device(
233
+ self._cuda_rgb_ptr,
234
+ shape=(self._batch_size, self._height, self._width, 3),
235
+ dtype=np.uint8)
236
+
237
+ if self._depth:
238
+ depth_arr = drv.from_device(
239
+ self._cuda_depth_ptr,
240
+ shape=(self._batch_size, self._height, self._width),
241
+ dtype=np.uint16)
242
+ else:
243
+ depth_arr = None
244
+
245
+ return rgb_arr, depth_arr
246
+
247
+ def _read_nocuda(self):
248
+ rgb_arr = np.zeros(3 * self._width * self._height * self._batch_size, dtype=np.uint8)
249
+ cdef unsigned char[::view.contiguous] rgb_view = rgb_arr
250
+ depth_arr = np.zeros(self._width * self._height * self._batch_size, dtype=np.uint16)
251
+ cdef unsigned short[::view.contiguous] depth_view = depth_arr
252
+
253
+ if self._depth:
254
+ readPBO(&rgb_view[0], &depth_view[0], self.pbo_rgb, self.pbo_depth,
255
+ self._width, self._height, self._batch_size)
256
+ depth_arr = depth_arr.reshape(self._batch_size, self._height, self._width)
257
+ else:
258
+ readPBO(&rgb_view[0], NULL, self.pbo_rgb, 0,
259
+ self._width, self._height, self._batch_size)
260
+ # Fine to throw aray depth_arr above since malloc/free is cheap
261
+ depth_arr = None
262
+
263
+ rgb_arr = rgb_arr.reshape(self._batch_size, self._height, self._width, 3)
264
+ return rgb_arr, depth_arr
265
+
266
+ @property
267
+ def cuda_rgb_buffer_pointer(self):
268
+ """ Pointer to CUDA buffer for RGB batch. """
269
+ if not self._use_cuda:
270
+ raise CudaNotEnabledError()
271
+ elif not self._cuda_buffers_are_mapped:
272
+ raise CudaBufferNotMappedError()
273
+ return self._cuda_rgb_ptr
274
+
275
+ @property
276
+ def cuda_depth_buffer_pointer(self):
277
+ """ Pointer to CUDA buffer for depth batch. """
278
+ if not self._use_cuda:
279
+ raise CudaNotEnabledError()
280
+ elif not self._cuda_buffers_are_mapped:
281
+ raise CudaBufferNotMappedError()
282
+ if not self._depth:
283
+ raise RuntimeError("Depth not enabled. Use depth=True on initialization.")
284
+ return self._cuda_depth_ptr
285
+
286
+ def __del__(self):
287
+ if self._use_cuda:
288
+ self._cuda_context.push()
289
+ self.unmap()
290
+ self._cuda_rgb_pbo.unregister()
291
+ if self._depth:
292
+ self._cuda_depth_pbo.unregister()
293
+
294
+ # Clean up context
295
+ drv.Context.pop()
296
+ self._cuda_context.detach()
297
+
298
+ if self.pbo_depth:
299
+ freePBO(self.pbo_rgb)
300
+ if self.pbo_depth:
301
+ freePBO(self.pbo_depth)
mujoco-py-2.1.2.14/mujoco_py/mjrendercontext.pyx ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from threading import Lock
2
+ from mujoco_py.generated import const
3
+ import numpy as np
4
+ cimport numpy as np
5
+
6
+ cdef class MjRenderContext(object):
7
+ """
8
+ Class that encapsulates rendering functionality for a
9
+ MuJoCo simulation.
10
+ """
11
+
12
+ cdef mjModel *_model_ptr
13
+ cdef mjData *_data_ptr
14
+
15
+ cdef mjvScene _scn
16
+ cdef mjvCamera _cam
17
+ cdef mjvOption _vopt
18
+ cdef mjvPerturb _pert
19
+ cdef mjrContext _con
20
+
21
+ # Public wrappers
22
+ cdef readonly PyMjvScene scn
23
+ cdef readonly PyMjvCamera cam
24
+ cdef readonly PyMjvOption vopt
25
+ cdef readonly PyMjvPerturb pert
26
+ cdef readonly PyMjrContext con
27
+
28
+ cdef readonly object opengl_context
29
+ cdef readonly int _visible
30
+ cdef readonly list _markers
31
+ cdef readonly dict _overlay
32
+
33
+ cdef readonly bint offscreen
34
+ cdef public object sim
35
+
36
+ def __cinit__(self):
37
+ maxgeom = 1000
38
+ mjv_makeScene(self._model_ptr, &self._scn, maxgeom)
39
+ mjv_defaultCamera(&self._cam)
40
+ mjv_defaultPerturb(&self._pert)
41
+ mjv_defaultOption(&self._vopt)
42
+ mjr_defaultContext(&self._con)
43
+
44
+ def __init__(self, MjSim sim, bint offscreen=True, int device_id=-1, opengl_backend=None, quiet=False):
45
+ self.sim = sim
46
+ self._setup_opengl_context(offscreen, device_id, opengl_backend, quiet=quiet)
47
+ self.offscreen = offscreen
48
+
49
+ # Ensure the model data has been updated so that there
50
+ # is something to render
51
+ sim.forward()
52
+
53
+ sim.add_render_context(self)
54
+
55
+ self._model_ptr = sim.model.ptr
56
+ self._data_ptr = sim.data.ptr
57
+ self.scn = WrapMjvScene(&self._scn)
58
+ self.cam = WrapMjvCamera(&self._cam)
59
+ self.vopt = WrapMjvOption(&self._vopt)
60
+ self.con = WrapMjrContext(&self._con)
61
+
62
+ self._pert.active = 0
63
+ self._pert.select = 0
64
+ self._pert.skinselect = -1
65
+
66
+ self.pert = WrapMjvPerturb(&self._pert)
67
+
68
+ self._markers = []
69
+ self._overlay = {}
70
+
71
+ self._init_camera(sim)
72
+ self._set_mujoco_buffers()
73
+
74
+ def update_sim(self, MjSim new_sim):
75
+ if new_sim == self.sim:
76
+ return
77
+ self._model_ptr = new_sim.model.ptr
78
+ self._data_ptr = new_sim.data.ptr
79
+ self._set_mujoco_buffers()
80
+ for render_context in self.sim.render_contexts:
81
+ new_sim.add_render_context(render_context)
82
+ self.sim = new_sim
83
+
84
+ def _set_mujoco_buffers(self):
85
+ mjr_makeContext(self._model_ptr, &self._con, mjFONTSCALE_150)
86
+ if self.offscreen:
87
+ mjr_setBuffer(mjFB_OFFSCREEN, &self._con);
88
+ if self._con.currentBuffer != mjFB_OFFSCREEN:
89
+ raise RuntimeError('Offscreen rendering not supported')
90
+ else:
91
+ mjr_setBuffer(mjFB_WINDOW, &self._con);
92
+ if self._con.currentBuffer != mjFB_WINDOW:
93
+ raise RuntimeError('Window rendering not supported')
94
+ self.con = WrapMjrContext(&self._con)
95
+
96
+ def _setup_opengl_context(self, offscreen, device_id, opengl_backend, quiet=False):
97
+ if opengl_backend is None and (not offscreen or sys.platform == 'darwin'):
98
+ # default to glfw for onscreen viewing or mac (both offscreen/onscreen)
99
+ opengl_backend = 'glfw'
100
+
101
+ if opengl_backend == 'glfw':
102
+ self.opengl_context = GlfwContext(offscreen=offscreen, quiet=quiet)
103
+ else:
104
+ if device_id < 0:
105
+ if "GPUS" in os.environ:
106
+ device_id = os.environ["GPUS"]
107
+ else:
108
+ device_id = os.getenv('CUDA_VISIBLE_DEVICES', '')
109
+ if len(device_id) > 0:
110
+ device_id = int(device_id.split(',')[0])
111
+ else:
112
+ # Sometimes env variable is an empty string.
113
+ device_id = 0
114
+ self.opengl_context = OffscreenOpenGLContext(device_id)
115
+
116
+ def _init_camera(self, sim):
117
+ # Make the free camera look at the scene
118
+ self.cam.type = const.CAMERA_FREE
119
+ self.cam.fixedcamid = -1
120
+ for i in range(3):
121
+ self.cam.lookat[i] = np.median(sim.data.geom_xpos[:, i])
122
+ self.cam.distance = sim.model.stat.extent
123
+
124
+ def update_offscreen_size(self, width, height):
125
+ if width != self._con.offWidth or height != self._con.offHeight:
126
+ self._model_ptr.vis.global_.offwidth = width
127
+ self._model_ptr.vis.global_.offheight = height
128
+ mjr_freeContext(&self._con)
129
+ self._set_mujoco_buffers()
130
+
131
+ def render(self, width, height, camera_id=None, segmentation=False):
132
+ cdef mjrRect rect
133
+ rect.left = 0
134
+ rect.bottom = 0
135
+ rect.width = width
136
+ rect.height = height
137
+
138
+ if self.sim.render_callback is not None:
139
+ self.sim.render_callback(self.sim, self)
140
+
141
+ # Sometimes buffers are too small.
142
+ if width > self._con.offWidth or height > self._con.offHeight:
143
+ new_width = max(width, self._model_ptr.vis.global_.offwidth)
144
+ new_height = max(height, self._model_ptr.vis.global_.offheight)
145
+ self.update_offscreen_size(new_width, new_height)
146
+
147
+ if camera_id is not None:
148
+ if camera_id == -1:
149
+ self.cam.type = const.CAMERA_FREE
150
+ else:
151
+ self.cam.type = const.CAMERA_FIXED
152
+ self.cam.fixedcamid = camera_id
153
+
154
+ # This doesn't really do anything else rather than checking for the size of buffer
155
+ # need to investigate further whi is that a no-op
156
+ # self.opengl_context.set_buffer_size(width, height)
157
+
158
+ mjv_updateScene(self._model_ptr, self._data_ptr, &self._vopt,
159
+ &self._pert, &self._cam, mjCAT_ALL, &self._scn)
160
+
161
+ if segmentation:
162
+ self._scn.flags[const.RND_SEGMENT] = 1
163
+ self._scn.flags[const.RND_IDCOLOR] = 1
164
+
165
+ for marker_params in self._markers:
166
+ self._add_marker_to_scene(marker_params)
167
+
168
+ mjr_render(rect, &self._scn, &self._con)
169
+ for gridpos, (text1, text2) in self._overlay.items():
170
+ mjr_overlay(const.FONTSCALE_150, gridpos, rect, text1.encode(), text2.encode(), &self._con)
171
+
172
+ if segmentation:
173
+ self._scn.flags[const.RND_SEGMENT] = 0
174
+ self._scn.flags[const.RND_IDCOLOR] = 0
175
+
176
    def read_pixels(self, width, height, depth=True, segmentation=False):
        """Read the rendered image back from the current OpenGL buffer.

        Args:
        - width (int): width of the region to read, in pixels.
        - height (int): height of the region to read, in pixels.
        - depth (bool): if True, also return the depth buffer.
        - segmentation (bool): if True, decode the rendered RGB values
          into per-pixel (objtype, objid) pairs instead of a color image.
          Assumes the scene was rendered with RND_SEGMENT/RND_IDCOLOR set
          (see render()).

        Returns:
        - (H, W, 3) uint8 RGB image, or an (H, W, 2) int32 segmentation
          image when segmentation=True; plus an (H, W) float32 depth
          image when depth=True.
        """
        cdef mjrRect rect
        rect.left = 0
        rect.bottom = 0
        rect.width = width
        rect.height = height

        # Flat staging buffers; mjr_readPixels fills them row-major.
        rgb_arr = np.zeros(3 * rect.width * rect.height, dtype=np.uint8)
        depth_arr = np.zeros(rect.width * rect.height, dtype=np.float32)

        cdef unsigned char[::view.contiguous] rgb_view = rgb_arr
        cdef float[::view.contiguous] depth_view = depth_arr
        mjr_readPixels(&rgb_view[0], &depth_view[0], rect, &self._con)
        rgb_img = rgb_arr.reshape(rect.height, rect.width, 3)
        cdef np.ndarray[np.npy_uint32, ndim=2] seg_img
        cdef np.ndarray[np.npy_int32, ndim=2] seg_ids

        ret_img = rgb_img
        if segmentation:
            # Geom ids are packed into the RGB channels base-256
            # (R least significant); decode back into a single integer id.
            seg_img = (rgb_img[:, :, 0] + rgb_img[:, :, 1] * (2**8) + rgb_img[:, :, 2] * (2 ** 16))
            # Ids beyond the current geom count are background/noise; zero them.
            seg_img[seg_img >= (self._scn.ngeom + 1)] = 0
            # Row 0 (-1, -1) is the background sentinel; row segid+1 maps
            # to that geom's (objtype, objid).
            seg_ids = np.full((self._scn.ngeom + 1, 2), fill_value=-1, dtype=np.int32)

            for i in range(self._scn.ngeom):
                geom = self._scn.geoms[i]
                if geom.segid != -1:
                    seg_ids[geom.segid + 1, 0] = geom.objtype
                    seg_ids[geom.segid + 1, 1] = geom.objid
            ret_img = seg_ids[seg_img]

        if depth:
            depth_img = depth_arr.reshape(rect.height, rect.width)
            return (ret_img, depth_img)
        else:
            return ret_img
211
+
212
    def read_pixels_depth(self, np.ndarray[np.float32_t, mode="c", ndim=2] buffer):
        ''' Read depth pixels into a preallocated buffer.

        The read region is taken from the buffer's own shape
        (rows = height, columns = width), so the caller controls the size.
        '''
        cdef mjrRect rect
        rect.left = 0
        rect.bottom = 0
        rect.width = buffer.shape[1]
        rect.height = buffer.shape[0]

        # Passing NULL for the color pointer skips the RGB read entirely.
        cdef float[::view.contiguous] buffer_view = buffer.ravel()
        mjr_readPixels(NULL, &buffer_view[0], rect, &self._con)
222
+
223
    def upload_texture(self, int tex_id):
        """ Uploads given texture to the GPU.

        Makes this context current first so the upload targets the
        correct OpenGL context.
        """
        self.opengl_context.make_context_current()
        mjr_uploadTexture(self._model_ptr, &self._con, tex_id)
227
+
228
    def draw_pixels(self, np.ndarray[np.uint8_t, ndim=3] image, int left, int bottom):
        """Draw an image into the OpenGL buffer.

        Args:
        - image: (H, W, 3) uint8 array of pixels.
        - left, bottom: viewport origin in window coordinates
          (OpenGL convention: y grows upward from the bottom).
        """
        cdef unsigned char[::view.contiguous] image_view = image.ravel()
        cdef mjrRect viewport
        viewport.left = left
        viewport.bottom = bottom
        viewport.width = image.shape[1]
        viewport.height = image.shape[0]
        # NULL depth pointer: only the color buffer is written.
        mjr_drawPixels(&image_view[0], NULL, viewport, &self._con)
237
+
238
    def move_camera(self, int action, double reldx, double reldy):
        """ Moves the camera based on mouse movements. Action is one of mjMOUSE_*.

        reldx/reldy are displacements relative to the window size, as
        expected by mjv_moveCamera.
        """
        mjv_moveCamera(self._model_ptr, action, reldx, reldy, &self._scn, &self._cam)
241
+
242
    def add_overlay(self, int gridpos, str text1, str text2):
        """ Overlays text on the scene.

        gridpos selects the screen corner (a GRID_* constant). Repeated
        calls for the same gridpos append lines; the accumulated text is
        drawn by render() and the buffer is managed by self._overlay.
        """
        if gridpos not in self._overlay:
            self._overlay[gridpos] = ["", ""]
        self._overlay[gridpos][0] += text1 + "\n"
        self._overlay[gridpos][1] += text2 + "\n"
248
+
249
    def add_marker(self, **marker_params):
        # Queue a decorative marker; it is added to the scene on the next
        # render() call via _add_marker_to_scene.
        self._markers.append(marker_params)
251
+
252
+ def _add_marker_to_scene(self, marker_params):
253
+ """ Adds marker to scene, and returns the corresponding object. """
254
+ if self._scn.ngeom >= self._scn.maxgeom:
255
+ raise RuntimeError('Ran out of geoms. maxgeom: %d' % self._scn.maxgeom)
256
+
257
+ cdef mjvGeom *g = self._scn.geoms + self._scn.ngeom
258
+
259
+ # default values.
260
+ g.dataid = -1
261
+ g.objtype = const.OBJ_UNKNOWN
262
+ g.objid = -1
263
+ g.category = const.CAT_DECOR
264
+ g.texid = -1
265
+ g.texuniform = 0
266
+ g.texrepeat[0] = 1
267
+ g.texrepeat[1] = 1
268
+ g.emission = 0
269
+ g.specular = 0.5
270
+ g.shininess = 0.5
271
+ g.reflectance = 0
272
+ g.type = const.GEOM_BOX
273
+ g.size[:] = np.ones(3) * 0.1
274
+ g.mat[:] = np.eye(3).flatten()
275
+ g.rgba[:] = np.ones(4)
276
+ wrapped = WrapMjvGeom(g)
277
+
278
+ for key, value in marker_params.items():
279
+ if isinstance(value, (int, float)):
280
+ setattr(wrapped, key, value)
281
+ elif isinstance(value, (tuple, list, np.ndarray)):
282
+ attr = getattr(wrapped, key)
283
+ attr[:] = np.asarray(value).reshape(attr.shape)
284
+ elif isinstance(value, str):
285
+ assert key == "label", "Only label is a string in mjvGeom."
286
+ if value == None:
287
+ g.label[0] = 0
288
+ else:
289
+ strncpy(g.label, value.encode(), 100)
290
+ elif hasattr(wrapped, key):
291
+ raise ValueError("mjvGeom has attr {} but type {} is invalid".format(key, type(value)))
292
+ else:
293
+ raise ValueError("mjvGeom doesn't have field %s" % key)
294
+
295
+ self._scn.ngeom += 1
296
+
297
+
298
    def __dealloc__(self):
        # Release native rendering resources owned by this context.
        mjr_freeContext(&self._con)
        mjv_freeScene(&self._scn)
301
+
302
+
303
class MjRenderContextOffscreen(MjRenderContext):
    """Render context backed by an offscreen framebuffer (no window)."""

    def __cinit__(self, MjSim sim, int device_id):
        super().__init__(sim, offscreen=True, device_id=device_id)
307
+
308
class MjRenderContextWindow(MjRenderContext):
    """Render context that draws to an on-screen GLFW window."""

    def __init__(self, MjSim sim):
        super().__init__(sim, offscreen=False)
        # Optional hook invoked after rendering, right before buffer swap.
        self.render_swap_callback = None

        assert isinstance(self.opengl_context, GlfwContext), (
            "Only GlfwContext supported for windowed rendering")

    @property
    def window(self):
        # The underlying GLFW window handle owned by the OpenGL context.
        return self.opengl_context.window

    def render(self):
        """Render the current scene to the window, if it is still open."""
        if self.window is None or glfw.window_should_close(self.window):
            return

        glfw.make_context_current(self.window)
        # Render at the window's current framebuffer resolution.
        super().render(*glfw.get_framebuffer_size(self.window))
        if self.render_swap_callback is not None:
            self.render_swap_callback()
        glfw.swap_buffers(self.window)
mujoco-py-2.1.2.14/mujoco_py/mjrenderpool.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ctypes
2
+ import inspect
3
+
4
+ from multiprocessing import Array, get_start_method, Pool, Value
5
+
6
+ import numpy as np
7
+
8
+
9
class RenderPoolStorage:
    """Per-worker holder for the globals each render-pool process needs.

    Slotted so workers cannot accidentally attach extra state.
    """

    __slots__ = (
        'shared_rgbs_array',
        'shared_depths_array',
        'device_id',
        'sim',
        'modder',
    )
19
+
20
+
21
class MjRenderPool:
    """
    Utilizes a process pool to render a MuJoCo simulation across
    multiple GPU devices. This can scale the throughput linearly
    with the number of available GPUs. Throughput can also be
    slightly increased by using more than one worker per GPU.
    """

    DEFAULT_MAX_IMAGE_SIZE = 512 * 512  # in pixels

    def __init__(self, model, device_ids=1, n_workers=None,
                 max_batch_size=None, max_image_size=DEFAULT_MAX_IMAGE_SIZE,
                 modder=None):
        """
        Args:
        - model (PyMjModel): MuJoCo model to use for rendering
        - device_ids (int/list): list of device ids to use for rendering.
            One or more workers will be assigned to each device, depending
            on how many workers are requested.
        - n_workers (int): number of parallel processes in the pool. Defaults
            to the number of device ids.
        - max_batch_size (int): maximum number of states that can be rendered
            in batch using .render(). Defaults to the number of workers.
        - max_image_size (int): maximum number pixels in images requested
            by .render()
        - modder (Modder): modder class to use for domain randomization
            (instantiated per worker).

        Raises:
        - ValueError if modder is not a class.
        - RuntimeError if the multiprocessing start method is not 'spawn'.
        """
        # Set early so __del__/close are safe even if __init__ fails below.
        self._closed, self.pool = False, None

        if not (modder is None or inspect.isclass(modder)):
            raise ValueError("modder must be a class")

        if isinstance(device_ids, int):
            device_ids = list(range(device_ids))
        else:
            assert isinstance(device_ids, list), (
                "device_ids must be list of integer")

        n_workers = n_workers or 1
        self._max_batch_size = max_batch_size or (len(device_ids) * n_workers)
        self._max_image_size = max_image_size

        # Shared-memory buffers the workers write rendered frames into;
        # sized for the worst case (max batch of max-size images).
        array_size = self._max_image_size * self._max_batch_size

        self._shared_rgbs = Array(ctypes.c_uint8, array_size * 3)
        self._shared_depths = Array(ctypes.c_float, array_size)

        self._shared_rgbs_array = np.frombuffer(
            self._shared_rgbs.get_obj(), dtype=ctypes.c_uint8)
        assert self._shared_rgbs_array.size == (array_size * 3), (
            "Array size is %d, expected %d" % (
                self._shared_rgbs_array.size, array_size * 3))
        self._shared_depths_array = np.frombuffer(
            self._shared_depths.get_obj(), dtype=ctypes.c_float)
        assert self._shared_depths_array.size == array_size, (
            "Array size is %d, expected %d" % (
                self._shared_depths_array.size, array_size))

        # Shared counter used to hand each worker a unique id at init time.
        worker_id = Value(ctypes.c_int)
        worker_id.value = 0

        if get_start_method() != "spawn":
            raise RuntimeError(
                "Start method must be set to 'spawn' for the "
                "render pool to work. That is, you must add the "
                "following to the _TOP_ of your main script, "
                "before any other imports (since they might be "
                "setting it otherwise):\n"
                "  import multiprocessing as mp\n"
                "  if __name__ == '__main__':\n"
                "      mp.set_start_method('spawn')\n")

        self.pool = Pool(
            processes=len(device_ids) * n_workers,
            initializer=MjRenderPool._worker_init,
            initargs=(
                model.get_mjb(),
                worker_id,
                device_ids,
                self._shared_rgbs,
                self._shared_depths,
                modder))

    @staticmethod
    def _worker_init(mjb_bytes, worker_id, device_ids,
                     shared_rgbs, shared_depths, modder):
        """
        Initializes the global state for the workers.
        """
        s = RenderPoolStorage()

        # Claim a unique worker id and round-robin assign a GPU device.
        with worker_id.get_lock():
            proc_worker_id = worker_id.value
            worker_id.value += 1
        s.device_id = device_ids[proc_worker_id % len(device_ids)]

        s.shared_rgbs_array = np.frombuffer(
            shared_rgbs.get_obj(), dtype=ctypes.c_uint8)
        s.shared_depths_array = np.frombuffer(
            shared_depths.get_obj(), dtype=ctypes.c_float)

        # avoid a circular import
        from mujoco_py import load_model_from_mjb, MjRenderContext, MjSim
        s.sim = MjSim(load_model_from_mjb(mjb_bytes))
        # attach a render context to the sim (needs to happen before
        # modder is called, since it might need to upload textures
        # to the GPU).
        MjRenderContext(s.sim, device_id=s.device_id)

        if modder is not None:
            s.modder = modder(s.sim, random_state=proc_worker_id)
            s.modder.whiten_materials()
        else:
            s.modder = None

        # Stash on a module-level global so _worker_render can reach it.
        global _render_pool_storage
        _render_pool_storage = s

    @staticmethod
    def _worker_render(worker_id, state, width, height,
                       camera_name, randomize):
        """
        Main target function for the workers: renders one frame into
        this worker's slice of the shared rgb/depth buffers.
        """
        s = _render_pool_storage

        forward = False
        if state is not None:
            s.sim.set_state(state)
            forward = True
        if randomize and s.modder is not None:
            s.modder.randomize()
            forward = True
        if forward:
            s.sim.forward()

        # Each worker writes into a disjoint slice indexed by worker_id.
        rgb_block = width * height * 3
        rgb_offset = rgb_block * worker_id
        rgb = s.shared_rgbs_array[rgb_offset:rgb_offset + rgb_block]
        rgb = rgb.reshape(height, width, 3)

        depth_block = width * height
        depth_offset = depth_block * worker_id
        depth = s.shared_depths_array[depth_offset:depth_offset + depth_block]
        depth = depth.reshape(height, width)

        rgb[:], depth[:] = s.sim.render(
            width, height, camera_name=camera_name, depth=True,
            device_id=s.device_id)

    def render(self, width, height, states=None, camera_name=None,
               depth=False, randomize=False, copy=True):
        """
        Renders the simulations in batch. If no states are provided,
        the max_batch_size will be used.

        Args:
        - width (int): width of image to render.
        - height (int): height of image to render.
        - states (list): list of MjSimStates; updates the states before
            rendering. Batch size will be number of states supplied.
        - camera_name (str): name of camera to render from.
        - depth (bool): if True, also return depth.
        - randomize (bool): calls modder.rand_all() before rendering.
        - copy (bool): return a copy rather than a view into the shared
            buffers (views are overwritten by the next render call).

        Returns:
        - rgbs: NxHxWx3 numpy array of N images in batch of width W
            and height H.
        - depth: NxHxW numpy array of N images in batch of width W
            and height H. Only returned if depth=True.
        """
        if self._closed:
            raise RuntimeError("The pool has been closed.")

        if (width * height) > self._max_image_size:
            raise ValueError(
                "Requested image larger than maximum image size. Create "
                "a new RenderPool with a larger maximum image size.")
        if states is None:
            batch_size = self._max_batch_size
            states = [None] * batch_size
        else:
            batch_size = len(states)

        if batch_size > self._max_batch_size:
            raise ValueError(
                "Requested batch size larger than max batch size. Create "
                "a new RenderPool with a larger max batch size.")

        self.pool.starmap(
            MjRenderPool._worker_render,
            [(i, state, width, height, camera_name, randomize)
             for i, state in enumerate(states)])

        rgbs = self._shared_rgbs_array[:width * height * 3 * batch_size]
        rgbs = rgbs.reshape(batch_size, height, width, 3)
        if copy:
            rgbs = rgbs.copy()

        if depth:
            depths = self._shared_depths_array[:width * height * batch_size]
            # BUG FIX: the depth buffer used to be copied unconditionally
            # here and then copied a second time when copy=True; now it
            # honors copy=False (returns a view) just like the rgb path.
            depths = depths.reshape(batch_size, height, width)
            if copy:
                depths = depths.copy()
            return rgbs, depths
        else:
            return rgbs

    def close(self):
        """
        Closes the pool and terminates child processes. Idempotent.
        """
        if not self._closed:
            if self.pool is not None:
                self.pool.close()
                self.pool.join()
            self._closed = True

    def __del__(self):
        self.close()
mujoco-py-2.1.2.14/mujoco_py/mjsim.pyx ADDED
@@ -0,0 +1,439 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from xml.dom import minidom
2
+ from mujoco_py.utils import remove_empty_lines
3
+ from mujoco_py.builder import build_callback_fn
4
+ from threading import Lock
5
+
6
+ _MjSim_render_lock = Lock()
7
+
8
+ ctypedef void (*substep_udd_t)(const mjModel* m, mjData* d)
9
+
10
+
11
+ cdef class MjSim(object):
12
+ """MjSim represents a running simulation including its state.
13
+
14
+ Similar to Gym's ``MujocoEnv``, it internally wraps a :class:`.PyMjModel`
15
+ and a :class:`.PyMjData`.
16
+
17
+ Parameters
18
+ ----------
19
+ model : :class:`.PyMjModel`
20
+ The model to simulate.
21
+ data : :class:`.PyMjData`
22
+ Optional container for the simulation state. Will be created if ``None``.
23
+ nsubsteps : int
24
+ Optional number of MuJoCo steps to run for every call to :meth:`.step`.
25
+ Buffers will be swapped only once per step.
26
+ udd_callback : fn(:class:`.MjSim`) -> dict
27
+ Optional callback for user-defined dynamics. At every call to
28
+ :meth:`.step`, it receives an MjSim object ``sim`` containing the
29
+ current user-defined dynamics state in ``sim.udd_state``, and returns the
30
+ next ``udd_state`` after applying the user-defined dynamics. This is
31
+ useful e.g. for reward functions that operate over functions of historical
32
+ state.
33
+ substep_callback : str or int or None
34
+ This uses a compiled C function as user-defined dynamics in substeps.
35
+ If given as a string, it's compiled as a C function and set as pointer.
36
+ If given as int, it's interpreted as a function pointer.
37
+ See :meth:`.set_substep_callback` for detailed info.
38
+ userdata_names : list of strings or None
39
+ This is a convenience parameter which is just set on the model.
40
+ Equivalent to calling ``model.set_userdata_names``
41
+ render_callback : callback for rendering.
42
+ """
43
+ # MjRenderContext for rendering camera views.
44
+ cdef readonly list render_contexts
45
+ cdef readonly object _render_context_window
46
+ cdef readonly object _render_context_offscreen
47
+
48
+ # MuJoCo model
49
+ cdef readonly PyMjModel model
50
+ # MuJoCo data
51
+ """
52
+ Wrapper around the simulation's mjData state container.
53
+ """
54
+ cdef readonly PyMjData data
55
+ # Number of substeps when calling .step
56
+ cdef public int nsubsteps
57
+ # User defined state.
58
+ cdef public dict udd_state
59
+ # User defined dynamics callback
60
+ cdef readonly object _udd_callback
61
+ # Allows to store extra information in MjSim.
62
+ cdef readonly dict extras
63
+ # Function pointer for substep callback, stored as uintptr
64
+ cdef readonly uintptr_t substep_callback_ptr
65
+ # Callback executed before rendering.
66
+ cdef public object render_callback
67
+
68
+ def __cinit__(self, PyMjModel model, PyMjData data=None, int nsubsteps=1,
69
+ udd_callback=None, substep_callback=None, userdata_names=None,
70
+ render_callback=None):
71
+ self.nsubsteps = nsubsteps
72
+ self.model = model
73
+ if data is None:
74
+ with wrap_mujoco_warning():
75
+ _data = mj_makeData(self.model.ptr)
76
+ if _data == NULL:
77
+ raise Exception('mj_makeData failed!')
78
+ self.data = WrapMjData(_data, self.model)
79
+ else:
80
+ self.data = data
81
+
82
+ self.render_contexts = []
83
+ self._render_context_offscreen = None
84
+ self._render_context_window = None
85
+ self.udd_state = None
86
+ self.udd_callback = udd_callback
87
+ self.render_callback = render_callback
88
+ self.extras = {}
89
+ self.set_substep_callback(substep_callback, userdata_names)
90
+
91
+ def reset(self):
92
+ """
93
+ Resets the simulation data and clears buffers.
94
+ """
95
+ with wrap_mujoco_warning():
96
+ mj_resetData(self.model.ptr, self.data.ptr)
97
+
98
+ self.udd_state = None
99
+ self.step_udd()
100
+
101
+ def forward(self):
102
+ """
103
+ Computes the forward kinematics. Calls ``mj_forward`` internally.
104
+ """
105
+ with wrap_mujoco_warning():
106
+ mj_forward(self.model.ptr, self.data.ptr)
107
+
108
+ def set_constants(self):
109
+ """
110
+ Set constant fields of mjModel, corresponding to qpos0 configuration.
111
+ """
112
+ with wrap_mujoco_warning():
113
+ mj_setConst(self.model.ptr, self.data.ptr)
114
+
115
+ def step(self, with_udd=True):
116
+ """
117
+ Advances the simulation by calling ``mj_step``.
118
+
119
+ If ``qpos`` or ``qvel`` have been modified directly, the user is required to call
120
+ :meth:`.forward` before :meth:`.step` if their ``udd_callback`` requires access to MuJoCo state
121
+ set during the forward dynamics.
122
+ """
123
+ if with_udd:
124
+ self.step_udd()
125
+
126
+ with wrap_mujoco_warning():
127
+ for _ in range(self.nsubsteps):
128
+ self.substep_callback()
129
+ mj_step(self.model.ptr, self.data.ptr)
130
+
131
+ def render(self, width=None, height=None, *, camera_name=None, depth=False,
132
+ mode='offscreen', device_id=-1, segmentation=False):
133
+ """
134
+ Renders view from a camera and returns image as an `numpy.ndarray`.
135
+
136
+ Args:
137
+ - width (int): desired image width.
138
+ - height (int): desired image height.
139
+ - camera_name (str): name of camera in model. If None, the free
140
+ camera will be used.
141
+ - depth (bool): if True, also return depth buffer
142
+ - device (int): device to use for rendering (only for GPU-backed
143
+ rendering).
144
+
145
+ Returns:
146
+ - rgb (uint8 array): image buffer from camera
147
+ - depth (float array): depth buffer from camera (only returned
148
+ if depth=True)
149
+ """
150
+ if camera_name is None:
151
+ camera_id = None
152
+ else:
153
+ camera_id = self.model.camera_name2id(camera_name)
154
+
155
+ if mode == 'offscreen':
156
+ with _MjSim_render_lock:
157
+ if self._render_context_offscreen is None:
158
+ render_context = MjRenderContextOffscreen(
159
+ self, device_id=device_id)
160
+ else:
161
+ render_context = self._render_context_offscreen
162
+
163
+ render_context.render(
164
+ width=width, height=height, camera_id=camera_id, segmentation=segmentation)
165
+ return render_context.read_pixels(
166
+ width, height, depth=depth, segmentation=segmentation)
167
+ elif mode == 'window':
168
+ if self._render_context_window is None:
169
+ from mujoco_py.mjviewer import MjViewer
170
+ render_context = MjViewer(self)
171
+ else:
172
+ render_context = self._render_context_window
173
+
174
+ render_context.render()
175
+
176
+ else:
177
+ raise ValueError("Mode must be either 'window' or 'offscreen'.")
178
+
179
+ def add_render_context(self, render_context):
180
+ self.render_contexts.append(render_context)
181
+ if render_context.offscreen and self._render_context_offscreen is None:
182
+ self._render_context_offscreen = render_context
183
+ elif not render_context.offscreen and self._render_context_window is None:
184
+ self._render_context_window = render_context
185
+
186
+ @property
187
+ def udd_callback(self):
188
+ return self._udd_callback
189
+
190
+ @udd_callback.setter
191
+ def udd_callback(self, value):
192
+ self._udd_callback = value
193
+ self.udd_state = None
194
+ self.step_udd()
195
+
196
+ cpdef substep_callback(self):
197
+ if self.substep_callback_ptr:
198
+ (<mjfGeneric>self.substep_callback_ptr)(self.model.ptr, self.data.ptr)
199
+
200
+ def set_substep_callback(self, substep_callback, userdata_names=None):
201
+ '''
202
+ Set a substep callback function.
203
+
204
+ Parameters :
205
+ substep_callback : str or int or None
206
+ If `substep_callback` is a string, compile to function pointer and set.
207
+ See `builder.build_callback_fn()` for documentation.
208
+ If `substep_callback` is an int, we interpret it as a function pointer.
209
+ If `substep_callback` is None, we disable substep_callbacks.
210
+ userdata_names : list of strings or None
211
+ This is a convenience parameter, if not None, this is passed
212
+ onto ``model.set_userdata_names()``.
213
+ '''
214
+ if userdata_names is not None:
215
+ self.model.set_userdata_names(userdata_names)
216
+ if substep_callback is None:
217
+ self.substep_callback_ptr = 0
218
+ elif isinstance(substep_callback, int):
219
+ self.substep_callback_ptr = substep_callback
220
+ elif isinstance(substep_callback, str):
221
+ self.substep_callback_ptr = build_callback_fn(substep_callback,
222
+ self.model.userdata_names)
223
+ else:
224
+ raise TypeError('invalid: {}'.format(type(substep_callback)))
225
+
226
+ def step_udd(self):
227
+ if self._udd_callback is None:
228
+ self.udd_state = {}
229
+ else:
230
+ schema_example = self.udd_state
231
+ self.udd_state = self._udd_callback(self)
232
+ # Check to make sure the udd_state has consistent keys and dimension across steps
233
+ if schema_example is not None:
234
+ keys = set(schema_example.keys()) | set(self.udd_state.keys())
235
+ for key in keys:
236
+ assert key in schema_example, "Keys cannot be added to udd_state between steps."
237
+ assert key in self.udd_state, "Keys cannot be dropped from udd_state between steps."
238
+ if isinstance(schema_example[key], Number):
239
+ assert isinstance(self.udd_state[key], Number), \
240
+ "Every value in udd_state must be either a number or a numpy array"
241
+ else:
242
+ assert isinstance(self.udd_state[key], np.ndarray), \
243
+ "Every value in udd_state must be either a number or a numpy array"
244
+ assert self.udd_state[key].shape == schema_example[key].shape, \
245
+ "Numpy array values in udd_state must keep the same dimension across steps."
246
+
247
+ def get_state(self):
248
+ """ Returns a copy of the simulator state. """
249
+ qpos = np.copy(self.data.qpos)
250
+ qvel = np.copy(self.data.qvel)
251
+ if self.model.na == 0:
252
+ act = None
253
+ else:
254
+ act = np.copy(self.data.act)
255
+ udd_state = copy.deepcopy(self.udd_state)
256
+
257
+ return MjSimState(self.data.time, qpos, qvel, act, udd_state)
258
+
259
+ def set_state(self, value):
260
+ """
261
+ Sets the state from an MjSimState.
262
+ If the MjSimState was previously unflattened from a numpy array, consider
263
+ set_state_from_flattened, as the defensive copy is a substantial overhead
264
+ in an inner loop.
265
+
266
+ Args:
267
+ - value (MjSimState): the desired state.
268
+ Note: ``forward()`` is not called automatically; call ``sim.forward()``
+ after setting the state if derived quantities must be recomputed.
270
+ """
271
+ self.data.time = value.time
272
+ self.data.qpos[:] = np.copy(value.qpos)
273
+ self.data.qvel[:] = np.copy(value.qvel)
274
+ if self.model.na != 0:
275
+ self.data.act[:] = np.copy(value.act)
276
+ self.udd_state = copy.deepcopy(value.udd_state)
277
+
278
+ def set_state_from_flattened(self, value):
279
+ """ This helper method sets the state from an array without requiring a defensive copy."""
280
+ state = MjSimState.from_flattened(value, self)
281
+
282
+ self.data.time = state.time
283
+ self.data.qpos[:] = state.qpos
284
+ self.data.qvel[:] = state.qvel
285
+ if self.model.na != 0:
286
+ self.data.act[:] = state.act
287
+ self.udd_state = state.udd_state
288
+
289
+ def save(self, file, format='xml', keep_inertials=False):
290
+ """
291
+ Saves the simulator model and state to a file as either
292
+ a MuJoCo XML or MJB file. The current state is saved as
293
+ a keyframe in the model file. This is useful for debugging
294
+ using MuJoCo's `simulate` utility.
295
+
296
+ Note that this doesn't save the UDD-state which is
297
+ part of MjSimState, since that's not supported natively
298
+ by MuJoCo. If you want to save the model together with
299
+ the UDD-state, you should use the `get_xml` or `get_mjb`
300
+ methods on `MjModel` together with `MjSim.get_state` and
301
+ save them with e.g. pickle.
302
+
303
+ Args:
304
+ - file (IO stream): stream to write model to.
305
+ - format: format to use (either 'xml' or 'mjb')
306
+ - keep_inertials (bool): if False, removes all <inertial>
307
+ properties derived automatically for geoms by MuJoco. Note
308
+ that this removes ones that were provided by the user
309
+ as well.
310
+ """
311
+ xml_str = self.model.get_xml()
312
+ dom = minidom.parseString(xml_str)
313
+
314
+ mujoco_node = dom.childNodes[0]
315
+ assert mujoco_node.tagName == 'mujoco'
316
+
317
+ keyframe_el = dom.createElement('keyframe')
318
+ key_el = dom.createElement('key')
319
+ keyframe_el.appendChild(key_el)
320
+ mujoco_node.appendChild(keyframe_el)
321
+
322
+ def str_array(arr):
323
+ return " ".join(map(str, arr))
324
+
325
+ key_el.setAttribute('time', str(self.data.time))
326
+ key_el.setAttribute('qpos', str_array(self.data.qpos))
327
+ key_el.setAttribute('qvel', str_array(self.data.qvel))
328
+ if self.data.act is not None:
329
+ key_el.setAttribute('act', str_array(self.data.act))
330
+
331
+ if not keep_inertials:
332
+ for element in dom.getElementsByTagName('inertial'):
333
+ element.parentNode.removeChild(element)
334
+
335
+ result_xml = remove_empty_lines(dom.toprettyxml(indent=" " * 4))
336
+
337
+ if format == 'xml':
338
+ file.write(result_xml)
339
+ elif format == 'mjb':
340
+ new_model = load_model_from_xml(result_xml)
341
+ file.write(new_model.get_mjb())
342
+ else:
343
+ raise ValueError("Unsupported format. Valid ones are 'xml' and 'mjb'")
344
+
345
+ def ray(self, pnt, vec, include_static_geoms=True, exclude_body=-1, group_filter=None):
346
+ """
347
+ Cast a ray into the scene, and return the first valid geom it intersects.
348
+ pnt - origin point of the ray in world coordinates (X Y Z)
349
+ vec - direction of the ray in world coordinates (X Y Z)
350
+ include_static_geoms - if False, we exclude geoms that are children of worldbody.
351
+ exclude_body - if this is a body ID, we exclude all children geoms of this body.
352
+ group_filter - a vector of booleans of length const.NGROUP
353
+ which specifies what geom groups (stored in model.geom_group)
354
+ to enable or disable. If none, all groups are used
355
+ Returns (distance, geomid) where
356
+ distance - distance along ray until first collision with geom
357
+ geomid - id of the geom the ray collided with
358
+ If no collision was found in the scene, return (-1, None)
359
+
360
+ NOTE: sometimes self.forward() needs to be called before self.ray().
361
+
362
+ See self.ray_fast_group() and self.ray_fast_nogroup() for versions of this call
363
+ with more stringent type requirements.
364
+ """
365
+ cdef mjtNum distance
366
+ cdef mjtNum[::view.contiguous] pnt_view = pnt
367
+ cdef mjtNum[::view.contiguous] vec_view = vec
368
+
369
+ if group_filter is None:
370
+ return self.ray_fast_nogroup(
371
+ np.asarray(pnt, dtype=np.float64),
372
+ np.asarray(vec, dtype=np.float64),
373
+ 1 if include_static_geoms else 0,
374
+ exclude_body)
375
+ else:
376
+ return self.ray_fast_group(
377
+ np.asarray(pnt, dtype=np.float64),
378
+ np.asarray(vec, dtype=np.float64),
379
+ np.asarray(group_filter, dtype=np.uint8),
380
+ 1 if include_static_geoms else 0,
381
+ exclude_body)
382
+
383
+ def ray_fast_group(self,
384
+ np.ndarray[np.float64_t, mode="c", ndim=1] pnt,
385
+ np.ndarray[np.float64_t, mode="c", ndim=1] vec,
386
+ np.ndarray[np.uint8_t, mode="c", ndim=1] geomgroup,
387
+ mjtByte flg_static=1,
388
+ int bodyexclude=-1):
389
+ """
390
+ Faster version of sim.ray(), which avoids extra copies,
391
+ but needs to be given all the correct type arrays.
392
+
393
+ See self.ray() for explanation of arguments
394
+ """
395
+ cdef int geomid
396
+ cdef mjtNum distance
397
+ cdef mjtNum[::view.contiguous] pnt_view = pnt
398
+ cdef mjtNum[::view.contiguous] vec_view = vec
399
+ cdef mjtByte[::view.contiguous] geomgroup_view = geomgroup
400
+
401
+ distance = mj_ray(self.model.ptr,
402
+ self.data.ptr,
403
+ &pnt_view[0],
404
+ &vec_view[0],
405
+ &geomgroup_view[0],
406
+ flg_static,
407
+ bodyexclude,
408
+ &geomid)
409
+ return (distance, geomid)
410
+
411
+
412
+ def ray_fast_nogroup(self,
413
+ np.ndarray[np.float64_t, mode="c", ndim=1] pnt,
414
+ np.ndarray[np.float64_t, mode="c", ndim=1] vec,
415
+ mjtByte flg_static=1,
416
+ int bodyexclude=-1):
417
+ """
418
+ Faster version of sim.ray(), which avoids extra copies,
419
+ but needs to be given all the correct type arrays.
420
+
421
+ This version hardcodes the geomgroup to NULL.
422
+ (Can't easily express a signature that is "numpy array of specific type or None")
423
+
424
+ See self.ray() for explanation of arguments
425
+ """
426
+ cdef int geomid
427
+ cdef mjtNum distance
428
+ cdef mjtNum[::view.contiguous] pnt_view = pnt
429
+ cdef mjtNum[::view.contiguous] vec_view = vec
430
+
431
+ distance = mj_ray(self.model.ptr,
432
+ self.data.ptr,
433
+ &pnt_view[0],
434
+ &vec_view[0],
435
+ NULL,
436
+ flg_static,
437
+ bodyexclude,
438
+ &geomid)
439
+ return (distance, geomid)
mujoco-py-2.1.2.14/mujoco_py/pxd/__init__.py ADDED
File without changes
mujoco-py-2.1.2.14/mujoco_py/pxd/mjdata.pxd ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cdef extern from "mjdata.h" nogil:
2
+
3
+ #---------------------------- primitive types (mjt) ------------------------------------
4
+
5
+ ctypedef enum mjtWarning: # warning types
6
+ mjWARN_INERTIA = 0, # (near) singular inertia matrix
7
+ mjWARN_CONTACTFULL, # too many contacts in contact list
8
+ mjWARN_CNSTRFULL, # too many constraints
9
+ mjWARN_VGEOMFULL, # too many visual geoms
10
+ mjWARN_BADQPOS, # bad number in qpos
11
+ mjWARN_BADQVEL, # bad number in qvel
12
+ mjWARN_BADQACC, # bad number in qacc
13
+ mjWARN_BADCTRL, # bad number in ctrl
14
+
15
+ enum: mjNWARNING # number of warnings
16
+
17
+
18
+ ctypedef enum mjtTimer:
19
+ # main api
20
+ mjTIMER_STEP = 0, # step
21
+ mjTIMER_FORWARD, # forward
22
+ mjTIMER_INVERSE, # inverse
23
+
24
+ # breakdown of step/forward
25
+ mjTIMER_POSITION, # fwdPosition
26
+ mjTIMER_VELOCITY, # fwdVelocity
27
+ mjTIMER_ACTUATION, # fwdActuation
28
+ mjTIMER_ACCELERATION, # fwdAcceleration
29
+ mjTIMER_CONSTRAINT, # fwdConstraint
30
+
31
+ # breakdown of fwdPosition
32
+ mjTIMER_POS_KINEMATICS, # kinematics, com, tendon, transmission
33
+ mjTIMER_POS_INERTIA, # inertia computations
34
+ mjTIMER_POS_COLLISION, # collision detection
35
+ mjTIMER_POS_MAKE, # make constraints
36
+ mjTIMER_POS_PROJECT, # project constraints
37
+
38
+ enum: mjNTIMER # number of timers
39
+
40
+
41
+ #------------------------------ mjContact ----------------------------------------------
42
+
43
+ ctypedef struct mjContact: # result of collision detection functions
44
+ # contact parameters set by geom-specific collision detector
45
+ mjtNum dist # distance between nearest points; neg: penetration
46
+ mjtNum pos[3] # position of contact point: midpoint between geoms
47
+ mjtNum frame[9] # normal is in [0-2]
48
+
49
+ # contact parameters set by mj_collideGeoms
50
+ mjtNum includemargin # include if dist<includemargin=margin-gap
51
+ mjtNum friction[5] # tangent1, 2, spin, roll1, 2
52
+ mjtNum solref[mjNREF] # constraint solver reference
53
+ mjtNum solimp[mjNIMP] # constraint solver impedance
54
+
55
+ # storage used internally by constraint solver
56
+ mjtNum mu # friction of regularized cone
57
+ mjtNum H[36] # cone Hessian, set by mj_updateConstraint
58
+
59
+ # contact descriptors set by mj_collideGeoms
60
+ int dim # contact space dimensionality: 1, 3, 4 or 6
61
+ int geom1 # id of geom 1
62
+ int geom2 # id of geom 2
63
+
64
+ # flag set by mj_fuseContact or mj_instantianteEquality
65
+ int exclude # 0: include, 1: in gap, 2: fused, 3: equality
66
+
67
+ # address computed by mj_instantiateContact
68
+ int efc_address # address in efc; -1: not included, -2-i: distance constraint i ???
69
+
70
+ #------------------------------ diagnostics --------------------------------------------
71
+
72
+ ctypedef struct mjWarningStat: # warning statistics
73
+ int lastinfo # info from last warning
74
+ int number # how many times was warning raised
75
+
76
+
77
+ ctypedef struct mjTimerStat: # timer statistics
78
+ mjtNum duration # cumulative duration
79
+ int number # how many times was timer called
80
+
81
+
82
+ ctypedef struct mjSolverStat: # per-iteration solver statistics
83
+ mjtNum improvement # cost reduction, scaled by 1/trace(M(qpos0))
84
+ mjtNum gradient # gradient norm (primal only, scaled)
85
+ mjtNum lineslope # slope in linesearch
86
+ int nactive # number of active constraints
87
+ int nchange # number of constraint state changes
88
+ int neval # number of cost evaluations in line search
89
+ int nupdate # number of Cholesky updates in line search
90
+
91
+
92
+
93
+ #---------------------------------- mjData ---------------------------------------------
94
+ ctypedef struct mjData:
95
+ # constant sizes
96
+ int nstack # number of mjtNums that can fit in stack
97
+ int nbuffer # size of main buffer in bytes
98
+
99
+ # stack pointer
100
+ int pstack # first available mjtNum address in stack
101
+
102
+ # memory utilization stats
103
+ int maxuse_stack # maximum stack allocation
104
+ int maxuse_con # maximum number of contacts
105
+ int maxuse_efc # maximum number of scalar constraints
106
+
107
+ # diagnostics
108
+ mjWarningStat warning[mjNWARNING] # warning statistics
109
+ mjTimerStat timer[mjNTIMER] # timer statistics
110
+ mjSolverStat solver[mjNSOLVER] # solver statistics per iteration
111
+ int solver_iter # number of solver iterations
112
+ int solver_nnz # number of non-zeros in Hessian or efc_AR
113
+ mjtNum solver_fwdinv[2] # forward-inverse comparison: qfrc, efc
114
+
115
+ # variable sizes
116
+ int ne # number of equality constraints
117
+ int nf # number of friction constraints
118
+ int nefc # number of constraints
119
+ int ncon # number of detected contacts
120
+
121
+ # global properties
122
+ mjtNum time # simulation time
123
+ mjtNum energy[2] # potential, kinetic energy
124
+
125
+ #-------------------------------- end of info header
126
+
127
+ # buffers
128
+ void* buffer # main buffer; all pointers point in it (nbuffer bytes)
129
+ mjtNum* stack # stack buffer (nstack mjtNums)
130
+
131
+ #-------------------------------- main inputs and outputs of the computation
132
+
133
+ # state
134
+ mjtNum* qpos # position (nq x 1)
135
+ mjtNum* qvel # velocity (nv x 1)
136
+ mjtNum* act # actuator activation (na x 1)
137
+ mjtNum* qacc_warmstart # acceleration used for warmstart (nv x 1)
138
+
139
+ # control
140
+ mjtNum* ctrl # control (nu x 1)
141
+ mjtNum* qfrc_applied # applied generalized force (nv x 1)
142
+ mjtNum* xfrc_applied # applied Cartesian force/torque (nbody x 6)
143
+
144
+ # dynamics
145
+ mjtNum* qacc # acceleration (nv x 1)
146
+ mjtNum* act_dot # time-derivative of actuator activation (na x 1)
147
+
148
+ # mocap data
149
+ mjtNum* mocap_pos # positions of mocap bodies (nmocap x 3)
150
+ mjtNum* mocap_quat # orientations of mocap bodies (nmocap x 4)
151
+
152
+ # user data
153
+ mjtNum* userdata # user data, not touched by engine (nuserdata x 1)
154
+
155
+ # sensors
156
+ mjtNum* sensordata # sensor data array (nsensordata x 1)
157
+
158
+ #-------------------------------- POSITION dependent
159
+
160
+ # computed by mj_fwdPosition/mj_kinematics
161
+ mjtNum* xpos # Cartesian position of body frame (nbody x 3)
162
+ mjtNum* xquat # Cartesian orientation of body frame (nbody x 4)
163
+ mjtNum* xmat # Cartesian orientation of body frame (nbody x 9)
164
+ mjtNum* xipos # Cartesian position of body com (nbody x 3)
165
+ mjtNum* ximat # Cartesian orientation of body inertia (nbody x 9)
166
+ mjtNum* xanchor # Cartesian position of joint anchor (njnt x 3)
167
+ mjtNum* xaxis # Cartesian joint axis (njnt x 3)
168
+ mjtNum* geom_xpos # Cartesian geom position (ngeom x 3)
169
+ mjtNum* geom_xmat # Cartesian geom orientation (ngeom x 9)
170
+ mjtNum* site_xpos # Cartesian site position (nsite x 3)
171
+ mjtNum* site_xmat # Cartesian site orientation (nsite x 9)
172
+ mjtNum* cam_xpos # Cartesian camera position (ncam x 3)
173
+ mjtNum* cam_xmat # Cartesian camera orientation (ncam x 9)
174
+ mjtNum* light_xpos # Cartesian light position (nlight x 3)
175
+ mjtNum* light_xdir # Cartesian light direction (nlight x 3)
176
+
177
+ # computed by mj_fwdPosition/mj_comPos
178
+ mjtNum* subtree_com # center of mass of each subtree (nbody x 3)
179
+ mjtNum* cdof # com-based motion axis of each dof (nv x 6)
180
+ mjtNum* cinert # com-based body inertia and mass (nbody x 10)
181
+
182
+ # computed by mj_fwdPosition/mj_tendon
183
+ int* ten_wrapadr # start address of tendon's path (ntendon x 1)
184
+ int* ten_wrapnum # number of wrap points in path (ntendon x 1)
185
+ int* ten_J_rownnz # number of non-zeros in Jacobian row (ntendon x 1)
186
+ int* ten_J_rowadr # row start address in colind array (ntendon x 1)
187
+ int* ten_J_colind # column indices in sparse Jacobian (ntendon x nv)
188
+ mjtNum* ten_length # tendon lengths (ntendon x 1)
189
+ mjtNum* ten_J # tendon Jacobian (ntendon x nv)
190
+ int* wrap_obj # geom id; -1: site; -2: pulley (nwrap*2 x 1)
191
+ mjtNum* wrap_xpos # Cartesian 3D points in all path (nwrap*2 x 3)
192
+
193
+ # computed by mj_fwdPosition/mj_transmission
194
+ mjtNum* actuator_length # actuator lengths (nu x 1)
195
+ mjtNum* actuator_moment # actuator moment arms (nu x nv)
196
+
197
+ # computed by mj_fwdPosition/mj_crb
198
+ mjtNum* crb # com-based composite inertia and mass (nbody x 10)
199
+ mjtNum* qM # total inertia (nM x 1)
200
+
201
+ # computed by mj_fwdPosition/mj_factorM
202
+ mjtNum* qLD # L'*D*L factorization of M (nM x 1)
203
+ mjtNum* qLDiagInv # 1/diag(D) (nv x 1)
204
+ mjtNum* qLDiagSqrtInv # 1/sqrt(diag(D)) (nv x 1)
205
+
206
+ # computed by mj_fwdPosition/mj_collision
207
+ mjContact* contact # list of all detected contacts (nconmax x 1)
208
+
209
+ # computed by mj_fwdPosition/mj_makeConstraint
210
+ int* efc_type # constraint type (mjtConstraint) (njmax x 1)
211
+ int* efc_id # id of object of specified type (njmax x 1)
212
+ int* efc_J_rownnz # number of non-zeros in Jacobian row (njmax x 1)
213
+ int* efc_J_rowadr # row start address in colind array (njmax x 1)
214
+ int* efc_J_rowsuper # number of subsequent rows in supernode (njmax x 1)
215
+ int* efc_J_colind # column indices in sparse Jacobian (njmax x nv)
216
+ int* efc_JT_rownnz # number of non-zeros in Jacobian row T (nv x 1)
217
+ int* efc_JT_rowadr # row start address in colind array T (nv x 1)
218
+ int* efc_JT_rowsuper # number of subsequent rows in supernode T (nv x 1)
219
+ int* efc_JT_colind # column indices in sparse Jacobian T (nv x njmax)
220
+ mjtNum* efc_solref # constraint solver reference (njmax x mjNREF)
221
+ mjtNum* efc_solimp # constraint solver impedance (njmax x mjNIMP)
222
+ mjtNum* efc_J # constraint Jacobian (njmax x nv)
223
+ mjtNum* efc_JT # sparse constraint Jacobian transposed (nv x njmax)
224
+ mjtNum* efc_pos # constraint position (equality, contact) (njmax x 1)
225
+ mjtNum* efc_margin # inclusion margin (contact) (njmax x 1)
226
+ mjtNum* efc_frictionloss # frictionloss (friction) (njmax x 1)
227
+ mjtNum* efc_diagApprox # approximation to diagonal of A (njmax x 1)
228
+ mjtNum* efc_KBIP # stiffness, damping, impedance, imp' (njmax x 4)
229
+ mjtNum* efc_D # constraint mass (njmax x 1)
230
+ mjtNum* efc_R # inverse constraint mass (njmax x 1)
231
+
232
+ # computed by mj_fwdPosition/mj_projectConstraint
233
+ int* efc_AR_rownnz # number of non-zeros in AR (njmax x 1)
234
+ int* efc_AR_rowadr # row start address in colind array (njmax x 1)
235
+ int* efc_AR_colind # column indices in sparse AR (njmax x njmax)
236
+ mjtNum* efc_AR # J*inv(M)*J' + R (njmax x njmax)
237
+
238
+ #-------------------------------- POSITION, VELOCITY dependent
239
+
240
+ # computed by mj_fwdVelocity
241
+ mjtNum* ten_velocity # tendon velocities (ntendon x 1)
242
+ mjtNum* actuator_velocity # actuator velocities (nu x 1)
243
+
244
+ # computed by mj_fwdVelocity/mj_comVel
245
+ mjtNum* cvel # com-based velocity [3D rot; 3D tran] (nbody x 6)
246
+ mjtNum* cdof_dot # time-derivative of cdof (nv x 6)
247
+
248
+ # computed by mj_fwdVelocity/mj_rne (without acceleration)
249
+ mjtNum* qfrc_bias # C(qpos,qvel) (nv x 1)
250
+
251
+ # computed by mj_fwdVelocity/mj_passive
252
+ mjtNum* qfrc_passive # passive force (nv x 1)
253
+
254
+ # computed by mj_fwdVelocity/mj_referenceConstraint
255
+ mjtNum* efc_vel # velocity in constraint space: J*qvel (njmax x 1)
256
+ mjtNum* efc_aref # reference pseudo-acceleration (njmax x 1)
257
+
258
+ # computed by mj_sensorVel
259
+ mjtNum* subtree_linvel # linear velocity of subtree com (nbody x 3)
260
+ mjtNum* subtree_angmom # angular momentum about subtree com (nbody x 3)
261
+
262
+ #-------------------------------- POSITION, VELOCITY, CONTROL/ACCELERATION dependent
263
+
264
+ # computed by mj_fwdActuation
265
+ mjtNum* actuator_force # actuator force in actuation space (nu x 1)
266
+ mjtNum* qfrc_actuator # actuator force (nv x 1)
267
+
268
+ # computed by mj_fwdAcceleration
269
+ mjtNum* qfrc_unc # net unconstrained force (nv x 1)
270
+ mjtNum* qacc_unc # unconstrained acceleration (nv x 1)
271
+
272
+ # computed by mj_fwdConstraint/mj_inverse
273
+ mjtNum* efc_b # linear cost term: J*qacc_unc - aref (njmax x 1)
274
+ mjtNum* efc_force # constraint force in constraint space (njmax x 1)
275
+ int* efc_state # constraint state (mjtConstraintState) (njmax x 1)
276
+ mjtNum* qfrc_constraint # constraint force (nv x 1)
277
+
278
+ # computed by mj_inverse
279
+ mjtNum* qfrc_inverse # net external force; should equal: (nv x 1)
280
+ # qfrc_applied + J'*xfrc_applied + qfrc_actuator
281
+
282
+ # computed by mj_sensorAcc/mj_rnePostConstraint; rotation:translation format
283
+ mjtNum* cacc # com-based acceleration (nbody x 6)
284
+ mjtNum* cfrc_int # com-based interaction force with parent (nbody x 6)
285
+ mjtNum* cfrc_ext # com-based external force on body (nbody x 6)
286
+
287
+
288
+ #---------------------------------- callback function types ----------------------------
289
+
290
+ # generic MuJoCo function
291
+ ctypedef void (*mjfGeneric)(const mjModel* m, mjData* d)
292
+
293
+ # sensor simulation
294
+ ctypedef void (*mjfSensor)(const mjModel* m, mjData* d, int stage)
295
+
296
+ # timer
297
+ ctypedef long long int (*mjfTime)();
298
+
299
+ # actuator dynamics, gain, bias
300
+ ctypedef mjtNum (*mjfAct)(const mjModel* m, const mjData* d, int id);
301
+
302
+ # solver impedance
303
+ ctypedef mjtNum (*mjfSolImp)(const mjModel* m, const mjData* d, int id,
304
+ mjtNum distance, mjtNum* constimp);
305
+
306
+ # solver reference
307
+ ctypedef void (*mjfSolRef)(const mjModel* m, const mjData* d, int id,
308
+ mjtNum constimp, mjtNum imp, int dim, mjtNum* ref);
309
+
310
+ # collision detection
311
+ ctypedef int (*mjfCollision)(const mjModel* m, const mjData* d,
312
+ mjContact* con, int g1, int g2, mjtNum margin);
mujoco-py-2.1.2.14/mujoco_py/pxd/mjmodel.pxd ADDED
@@ -0,0 +1,834 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cdef struct mjVisual_global_: # global parameters
2
+ float fovy # y-field of view (deg) for free camera
3
+ float ipd # inter-pupilary distance for free camera
4
+ float linewidth # line width for wireframe rendering
5
+ float glow # glow coefficient for selected body
6
+ int offwidth # width of offscreen buffer
7
+ int offheight # height of offscreen buffer
8
+
9
+ cdef struct mjVisual_quality: # rendering quality
10
+ int shadowsize # size of shadowmap texture
11
+ int offsamples # number of multisamples for offscreen rendering
12
+ int numslices # number of slices for Glu drawing
13
+ int numstacks # number of stacks for Glu drawing
14
+ int numquads # number of quads for box rendering
15
+
16
+ cdef struct mjVisual_headlight: # head light
17
+ float ambient[3] # ambient rgb (alpha=1)
18
+ float diffuse[3] # diffuse rgb (alpha=1)
19
+ float specular[3] # specular rgb (alpha=1)
20
+ int active # is headlight active
21
+
22
+ cdef struct mjVisual_map: # mapping
23
+ float stiffness # mouse perturbation stiffness (space->force)
24
+ float stiffnessrot # mouse perturbation stiffness (space->torque)
25
+ float force # from force units to space units
26
+ float torque # from torque units to space units
27
+ float alpha # scale geom alphas when transparency is enabled
28
+ float fogstart # OpenGL fog starts at fogstart * mjModel.stat.extent
29
+ float fogend # OpenGL fog ends at fogend * mjModel.stat.extent
30
+ float znear # near clipping plane = znear * mjModel.stat.extent
31
+ float zfar # far clipping plane = zfar * mjModel.stat.extent
32
+ float haze # haze ratio
33
+ float shadowclip # directional light: shadowclip * mjModel.stat.extent
34
+ float shadowscale # spot light: shadowscale * light.cutoff
35
+ float actuatortendon # scale tendon width
36
+
37
+ cdef struct mjVisual_scale: # scale of decor elements relative to mean body size
38
+ float forcewidth # width of force arrow
39
+ float contactwidth # contact width
40
+ float contactheight # contact height
41
+ float connect # autoconnect capsule width
42
+ float com # com radius
43
+ float camera # camera object
44
+ float light # light object
45
+ float selectpoint # selection point
46
+ float jointlength # joint length
47
+ float jointwidth # joint width
48
+ float actuatorlength # actuator length
49
+ float actuatorwidth # actuator width
50
+ float framelength # bodyframe axis length
51
+ float framewidth # bodyframe axis width
52
+ float constraint # constraint width
53
+ float slidercrank # slidercrank width
54
+
55
+ cdef struct mjVisual_rgba: # color of decor elements
56
+ float fog[4] # external force
57
+ float haze[4] # haze
58
+ float force[4] # external force
59
+ float inertia[4] # inertia box
60
+ float joint[4] # joint
61
+ float actuator[4] # actuator
62
+ float actuatornegative[4] # actuator, negative limit
63
+ float actuatorpositive[4] # actuator, positive limit
64
+ float com[4] # center of mass
65
+ float camera[4] # camera object
66
+ float light[4] # light object
67
+ float selectpoint[4] # selection point
68
+ float connect[4] # auto connect
69
+ float contactpoint[4] # contact point
70
+ float contactforce[4] # contact force
71
+ float contactfriction[4] # contact friction force
72
+ float contacttorque[4] # contact torque
73
+ float contactgap[4] # contact point in gap
74
+ float rangefinder[4] # rangefinder ray
75
+ float constraint[4] # constraint
76
+ float slidercrank[4] # slidercrank
77
+ float crankbroken[4] # used when crank must be stretched/broken
78
+
79
+ cdef extern from "mjmodel.h" nogil:
80
+ # ---------------------------- floating-point definitions -------------------------------
81
+ ctypedef double mjtNum
82
+
83
+ # global constants
84
+ enum: mjPI
85
+ enum: mjMAXVAL
86
+ enum: mjMINMU
87
+ enum: mjMINIMP
88
+ enum: mjMAXIMP
89
+ enum: mjMAXCONPAIR
90
+ enum: mjMAXVFS
91
+ enum: mjMAXVFSNAME
92
+
93
+ # ---------------------------- sizes ----------------------------------------------------
94
+ enum: mjNEQDATA
95
+ enum: mjNDYN
96
+ enum: mjNGAIN
97
+ enum: mjNBIAS
98
+ enum: mjNREF
99
+ enum: mjNIMP
100
+ enum: mjNSOLVER
101
+
102
+ # ---------------------------- primitive types (mjt) ------------------------------------
103
+ ctypedef unsigned char mjtByte # used for true/false
104
+
105
+
106
+ ctypedef enum mjtDisableBit: # disable default feature bitflags
107
+ mjDSBL_CONSTRAINT # entire constraint solver
108
+ mjDSBL_EQUALITY # equality constraints
109
+ mjDSBL_FRICTIONLOSS # joint and tendon frictionloss constraints
110
+ mjDSBL_LIMIT # joint and tendon limit constraints
111
+ mjDSBL_CONTACT # contact constraints
112
+ mjDSBL_PASSIVE # passive forces
113
+ mjDSBL_GRAVITY # gravitational forces
114
+ mjDSBL_CLAMPCTRL # clamp control to specified range
115
+ mjDSBL_WARMSTART # warmstart constraint solver
116
+ mjDSBL_FILTERPARENT # remove collisions with parent body
117
+ mjDSBL_ACTUATION # apply actuation forces
118
+ mjDSBL_REFSAFE # integrator safety: make ref[0]>=2*timestep
119
+ enum: mjNDISABLE # number of disable flags
120
+
121
+ ctypedef enum mjtEnableBit: # enable optional feature bitflags
122
+ mjENBL_OVERRIDE # override contact parameters
123
+ mjENBL_ENERGY # energy computation
124
+ mjENBL_FWDINV # record solver statistics
125
+ mjENBL_SENSORNOISE # add noise to sensor data
126
+ enum: mjNENABLE # number of enable flags
127
+
128
+ ctypedef enum mjtJoint: # type of degree of freedom
129
+ mjJNT_FREE = 0, # global position and orientation (quat) (7)
130
+ mjJNT_BALL, # orientation (quat) relative to parent (4)
131
+ mjJNT_SLIDE, # sliding distance along body-fixed axis (1)
132
+ mjJNT_HINGE # rotation angle (rad) around body-fixed axis (1)
133
+
134
+ ctypedef enum mjtGeom: # type of geometric shape
135
+ # regular geom types
136
+ mjGEOM_PLANE = 0, # plane
137
+ mjGEOM_HFIELD, # height field
138
+ mjGEOM_SPHERE, # sphere
139
+ mjGEOM_CAPSULE, # capsule
140
+ mjGEOM_ELLIPSOID, # ellipsoid
141
+ mjGEOM_CYLINDER, # cylinder
142
+ mjGEOM_BOX, # box
143
+ mjGEOM_MESH, # mesh
144
+
145
+ mjNGEOMTYPES, # number of regular geom types
146
+
147
+ # rendering-only geom types: not used in mjModel, not counted in mjNGEOMTYPES
148
+ mjGEOM_ARROW = 100, # arrow
149
+ mjGEOM_ARROW1, # arrow without wedges
150
+ mjGEOM_ARROW2, # arrow in both directions
151
+ mjGEOM_LABEL, # text label
152
+
153
+ mjGEOM_NONE = 1001 # missing geom type
154
+
155
+
156
+ ctypedef enum mjtCamLight: # tracking mode for camera and light
157
+ mjCAMLIGHT_FIXED = 0, # pos and rot fixed in body
158
+ mjCAMLIGHT_TRACK, # pos tracks body, rot fixed in global
159
+ mjCAMLIGHT_TRACKCOM, # pos tracks subtree com, rot fixed in body
160
+ mjCAMLIGHT_TARGETBODY, # pos fixed in body, rot tracks target body
161
+ mjCAMLIGHT_TARGETBODYCOM # pos fixed in body, rot tracks target subtree com
162
+
163
+
164
+ ctypedef enum mjtTexture: # type of texture
165
+ mjTEXTURE_2D = 0, # 2d texture, suitable for planes and hfields
166
+ mjTEXTURE_CUBE, # cube texture, suitable for all other geom types
167
+ mjTEXTURE_SKYBOX # cube texture used as skybox
168
+
169
+
170
+ ctypedef enum mjtIntegrator: # integrator mode
171
+ mjINT_EULER = 0, # semi-implicit Euler
172
+ mjINT_RK4 # 4th-order Runge Kutta
173
+
174
+
175
+ ctypedef enum mjtCollision: # collision mode for selecting geom pairs
176
+ mjCOL_ALL = 0, # test precomputed and dynamic pairs
177
+ mjCOL_PAIR, # test predefined pairs only
178
+ mjCOL_DYNAMIC # test dynamic pairs only
179
+
180
+
181
+ ctypedef enum mjtCone: # type of friction cone
182
+ mjCONE_PYRAMIDAL = 0, # pyramidal
183
+ mjCONE_ELLIPTIC # elliptic
184
+
185
+
186
+ ctypedef enum mjtJacobian: # type of constraint Jacobian
187
+ mjJAC_DENSE = 0, # dense
188
+ mjJAC_SPARSE, # sparse
189
+ mjJAC_AUTO # dense if nv<60, sparse otherwise
190
+
191
+
192
+ ctypedef enum mjtSolver: # constraint solver algorithm
193
+ mjSOL_PGS = 0, # PGS (dual)
194
+ mjSOL_CG, # CG (primal)
195
+ mjSOL_NEWTON # Newton (primal)
196
+
197
+
198
+ ctypedef enum mjtImp: # how to interpret solimp parameters
199
+ mjIMP_CONSTANT = 0, # constant solimp[1]
200
+ mjIMP_SIGMOID, # sigmoid from solimp[0] to solimp[1], width solimp[2]
201
+ mjIMP_LINEAR, # piece-wise linear sigmoid
202
+ mjIMP_USER # impedance computed by callback
203
+
204
+
205
+ ctypedef enum mjtRef: # how to interpret solref parameters
206
+ mjREF_SPRING = 0, # spring-damper: timeconst=solref[0], dampratio=solref[1]
207
+ mjREF_USER # reference computed by callback
208
+
209
+
210
+ ctypedef enum mjtEq: # type of equality constraint
211
+ mjEQ_CONNECT = 0, # connect two bodies at a point (ball joint)
212
+ mjEQ_WELD, # fix relative position and orientation of two bodies
213
+ mjEQ_JOINT, # couple the values of two scalar joints with cubic
214
+ mjEQ_TENDON, # couple the lengths of two tendons with cubic
215
+ mjEQ_DISTANCE # fix the contact distance betweent two geoms
216
+
217
+
218
+ ctypedef enum mjtWrap: # type of tendon wrap object
219
+ mjWRAP_NONE = 0, # null object
220
+ mjWRAP_JOINT, # constant moment arm
221
+ mjWRAP_PULLEY, # pulley used to split tendon
222
+ mjWRAP_SITE, # pass through site
223
+ mjWRAP_SPHERE, # wrap around sphere
224
+ mjWRAP_CYLINDER # wrap around (infinite) cylinder
225
+
226
+
227
+ ctypedef enum mjtTrn: # type of actuator transmission
228
+ mjTRN_JOINT = 0, # force on joint
229
+ mjTRN_JOINTINPARENT, # force on joint, expressed in parent frame
230
+ mjTRN_SLIDERCRANK, # force via slider-crank linkage
231
+ mjTRN_TENDON, # force on tendon
232
+ mjTRN_SITE, # force on site
233
+
234
+ mjTRN_UNDEFINED = 1000 # undefined transmission type
235
+
236
+
237
+ ctypedef enum mjtDyn: # type of actuator dynamics
238
+ mjDYN_NONE = 0, # no internal dynamics; ctrl specifies force
239
+ mjDYN_INTEGRATOR, # integrator: da/dt = u
240
+ mjDYN_FILTER, # linear filter: da/dt = (u-a) / tau
241
+ mjDYN_USER # user-defined dynamics type
242
+
243
+
244
+ ctypedef enum mjtGain: # type of actuator gain
245
+ mjGAIN_FIXED = 0, # fixed gain
246
+ mjGAIN_USER # user-defined gain type
247
+
248
+
249
+ ctypedef enum mjtBias: # type of actuator bias
250
+ mjBIAS_NONE = 0, # no bias
251
+ mjBIAS_AFFINE, # const + kp*length + kv*velocity
252
+ mjBIAS_USER # user-defined bias type
253
+
254
+
255
+ ctypedef enum mjtObj: # type of MujoCo object
256
+ mjOBJ_UNKNOWN = 0, # unknown object type
257
+ mjOBJ_BODY, # body
258
+ mjOBJ_XBODY, # body, used to access regular frame instead of i-frame
259
+ mjOBJ_JOINT, # joint
260
+ mjOBJ_DOF, # dof
261
+ mjOBJ_GEOM, # geom
262
+ mjOBJ_SITE, # site
263
+ mjOBJ_CAMERA, # camera
264
+ mjOBJ_LIGHT, # light
265
+ mjOBJ_MESH, # mesh
266
+ mjOBJ_HFIELD, # heightfield
267
+ mjOBJ_TEXTURE, # texture
268
+ mjOBJ_MATERIAL, # material for rendering
269
+ mjOBJ_PAIR, # geom pair to include
270
+ mjOBJ_EXCLUDE, # body pair to exclude
271
+ mjOBJ_EQUALITY, # equality constraint
272
+ mjOBJ_TENDON, # tendon
273
+ mjOBJ_ACTUATOR, # actuator
274
+ mjOBJ_SENSOR, # sensor
275
+ mjOBJ_NUMERIC, # numeric
276
+ mjOBJ_TEXT, # text
277
+ mjOBJ_TUPLE, # tuple
278
+ mjOBJ_KEY # keyframe
279
+
280
+
281
+ ctypedef enum mjtConstraint: # type of constraint
282
+ mjCNSTR_EQUALITY = 0, # equality constraint
283
+ mjCNSTR_FRICTION_DOF, # dof friction
284
+ mjCNSTR_FRICTION_TENDON, # tendon friction
285
+ mjCNSTR_LIMIT_JOINT, # joint limit
286
+ mjCNSTR_LIMIT_TENDON, # tendon limit
287
+ mjCNSTR_CONTACT_FRICTIONLESS, # frictionless contact
288
+ mjCNSTR_CONTACT_PYRAMIDAL, # frictional contact, pyramidal friction cone
289
+ mjCNSTR_CONTACT_ELLIPTIC # frictional contact, elliptic friction cone
290
+
291
+
292
+ ctypedef enum mjtConstraintState: # constraint state
293
+ mjCNSTRSTATE_SATISFIED = 0, # constraint satisfied, zero cost (limit, contact)
294
+ mjCNSTRSTATE_QUADRATIC, # quadratic cost (equality, friction, limit, contact)
295
+ mjCNSTRSTATE_LINEARNEG, # linear cost, negative side (friction)
296
+ mjCNSTRSTATE_LINEARPOS, # linear cost, positive side (friction)
297
+ mjCNSTRSTATE_CONE # squared distance to cone cost (elliptic contact)
298
+
299
+
300
+
301
+ ctypedef enum mjtSensor: # type of sensor
302
+ # common robotic sensors, attached to a site
303
+ mjSENS_TOUCH = 0, # scalar contact normal forces summed over sensor zone
304
+ mjSENS_ACCELEROMETER, # 3D linear acceleration, in local frame
305
+ mjSENS_VELOCIMETER, # 3D linear velocity, in local frame
306
+ mjSENS_GYRO, # 3D angular velocity, in local frame
307
+ mjSENS_FORCE, # 3D force between site's body and its parent body
308
+ mjSENS_TORQUE, # 3D torque between site's body and its parent body
309
+ mjSENS_MAGNETOMETER, # 3D magnetometer
310
+ mjSENS_RANGEFINDER, # scalar distance to nearest geom or site along z-axis
311
+
312
+ # sensors related to scalar joints, tendons, actuators
313
+ mjSENS_JOINTPOS, # scalar joint position (hinge and slide only)
314
+ mjSENS_JOINTVEL, # scalar joint velocity (hinge and slide only)
315
+ mjSENS_TENDONPOS, # scalar tendon position
316
+ mjSENS_TENDONVEL, # scalar tendon velocity
317
+ mjSENS_ACTUATORPOS, # scalar actuator position
318
+ mjSENS_ACTUATORVEL, # scalar actuator velocity
319
+ mjSENS_ACTUATORFRC, # scalar actuator force
320
+
321
+ # sensors related to ball joints
322
+ mjSENS_BALLQUAT, # 4D ball joint quaterion
323
+ mjSENS_BALLANGVEL, # 3D ball joint angular velocity
324
+
325
+ # sensors attached to an object with spatial frame: (x)body, geom, site, camera
326
+ mjSENS_FRAMEPOS, # 3D position
327
+ mjSENS_FRAMEQUAT, # 4D unit quaternion orientation
328
+ mjSENS_FRAMEXAXIS, # 3D unit vector: x-axis of object's frame
329
+ mjSENS_FRAMEYAXIS, # 3D unit vector: y-axis of object's frame
330
+ mjSENS_FRAMEZAXIS, # 3D unit vector: z-axis of object's frame
331
+ mjSENS_FRAMELINVEL, # 3D linear velocity
332
+ mjSENS_FRAMEANGVEL, # 3D angular velocity
333
+ mjSENS_FRAMELINACC, # 3D linear acceleration
334
+ mjSENS_FRAMEANGACC, # 3D angular acceleration
335
+
336
+ # sensors related to kinematic subtrees; attached to a body (which is the subtree root)
337
+ mjSENS_SUBTREECOM, # 3D center of mass of subtree
338
+ mjSENS_SUBTREELINVEL, # 3D linear velocity of subtree
339
+ mjSENS_SUBTREEANGMOM, # 3D angular momentum of subtree
340
+
341
+ # user-defined sensor
342
+ mjSENS_USER # sensor data provided by mjcb_sensor callback
343
+
344
+
345
+ ctypedef enum mjtStage: # computation stage
346
+ mjSTAGE_NONE = 0, # no computations
347
+ mjSTAGE_POS, # position-dependent computations
348
+ mjSTAGE_VEL, # velocity-dependent computations
349
+ mjSTAGE_ACC # acceleration/force-dependent computations
350
+
351
+
352
+ ctypedef enum mjtDataType: # data type for sensors
353
+ mjDATATYPE_REAL = 0, # real values, no constraints
354
+ mjDATATYPE_POSITIVE, # positive values; 0 or negative: inactive
355
+ mjDATATYPE_AXIS, # 3D unit vector
356
+ mjDATATYPE_QUAT # unit quaternion
357
+
358
+ #------------------------------ mjVFS --------------------------------------------------
359
+
360
+ ctypedef struct mjVFS: # virtual file system for loading from memory
361
+ int nfile # number of files present
362
+ char filename[mjMAXVFS][mjMAXVFSNAME] # file name without path
363
+ int filesize[mjMAXVFS] # file size in bytes
364
+ void* filedata[mjMAXVFS] # buffer with file data
365
+
366
+
367
+
368
+ #------------------------------ mjOption -----------------------------------------------
369
+
370
+ ctypedef struct mjOption: # physics options
371
+ # timing parameters
372
+ mjtNum timestep # timestep
373
+ mjtNum apirate # update rate for remote API (Hz)
374
+
375
+ # solver parameters
376
+ mjtNum impratio # ratio of friction-to-normal contact impedance
377
+ mjtNum tolerance # solver convergence tolerance
378
+ mjtNum noslip_tolerance # noslip solver tolerance
379
+ mjtNum mpr_tolerance # MPR solver tolerance
380
+
381
+ # physical constants
382
+ mjtNum gravity[3] # gravitational acceleration
383
+ mjtNum wind[3] # wind (for lift, drag and viscosity)
384
+ mjtNum magnetic[3] # global magnetic flux
385
+ mjtNum density # density of medium
386
+ mjtNum viscosity # viscosity of medium
387
+
388
+ # override contact solver parameters (if enabled)
389
+ mjtNum o_margin # margin
390
+ mjtNum o_solref[mjNREF] # solref
391
+ mjtNum o_solimp[mjNIMP] # solimp
392
+
393
+ # discrete settings
394
+ int integrator # integration mode (mjtIntegrator)
395
+ int collision # collision mode (mjtCollision)
396
+ int cone # type of friction cone (mjtCone)
397
+ int jacobian # type of Jacobian (mjtJacobian)
398
+ int solver # solver mode (mjtSolver)
399
+ int iterations # maximum number of solver iterations
400
+ int noslip_iterations # maximum number of noslip solver iterations
401
+ int mpr_iterations # maximum number of MPR solver iterations
402
+ int disableflags # bit flags for disabling standard features
403
+ int enableflags # bit flags for enabling optional features
404
+
405
+ #------------------------------ mjLROpt ------------------------------------------------
406
+
407
+ ctypedef struct mjLROpt:
408
+ # flags
409
+ int mode # which actuators to process (mjtLRMode)
410
+ int useexisting # use existing length range if available
411
+ int uselimit # use joint and tendon limits if available
412
+
413
+ # algorithm parameters
414
+ mjtNum accel # target acceleration used to compute force
415
+ mjtNum maxforce # maximum force; 0: no limit
416
+ mjtNum timeconst # time constant for velocity reduction; min 0.01
417
+ mjtNum timestep # simulation timestep; 0: use mjOption.timestep
418
+ mjtNum inttotal # total simulation time interval
419
+ mjtNum inteval # evaluation time interval (at the end)
420
+ mjtNum tolrange # convergence tolerance (relative to range)
421
+
422
+ #------------------------------ mjVisual -----------------------------------------------
423
+
424
+
425
+ ctypedef struct mjVisual:
426
+ mjVisual_global_ global_ "global"
427
+ mjVisual_quality quality
428
+ mjVisual_headlight headlight
429
+ mjVisual_map map
430
+ mjVisual_scale scale
431
+ mjVisual_rgba rgba
432
+
433
+ #------------------------------ mjStatistic --------------------------------------------
434
+
435
+ ctypedef struct mjStatistic: # model statistics (in qpos0)
436
+ mjtNum meaninertia # mean diagonal inertia
437
+ mjtNum meanmass # mean body mass
438
+ mjtNum meansize # mean body size
439
+ mjtNum extent # spatial extent
440
+ mjtNum center[3] # center of model
441
+
442
+
443
+ # ---------------------------------- mjModel --------------------------------------------
444
+ ctypedef struct mjModel:
445
+ # ------------------------------- sizes
446
+
447
+ # sizes needed at mjModel construction
448
+ int nq # number of generalized coordinates = dim(qpos)
449
+ int nv # number of degrees of freedom = dim(qvel)
450
+ int nu # number of actuators/controls = dim(ctrl)
451
+ int na # number of activation states = dim(act)
452
+ int nbody # number of bodies
453
+ int njnt # number of joints
454
+ int ngeom # number of geoms
455
+ int nsite # number of sites
456
+ int ncam # number of cameras
457
+ int nlight # number of lights
458
+ int nmesh # number of meshes
459
+ int nmeshvert # number of vertices in all meshes
460
+ int nmeshtexvert; # number of vertices with texcoords in all meshes
461
+ int nmeshface # number of triangular faces in all meshes
462
+ int nmeshgraph # number of ints in mesh auxiliary data
463
+ int nskin # number of skins
464
+ int nskinvert # number of vertices in all skins
465
+ int nskintexvert # number of vertiex with texcoords in all skins
466
+ int nskinface # number of triangular faces in all skins
467
+ int nskinbone # number of bones in all skins
468
+ int nskinbonevert # number of vertices in all skin bones
469
+ int nhfield # number of heightfields
470
+ int nhfielddata # number of data points in all heightfields
471
+ int ntex # number of textures
472
+ int ntexdata # number of bytes in texture rgb data
473
+ int nmat # number of materials
474
+ int npair # number of predefined geom pairs
475
+ int nexclude # number of excluded geom pairs
476
+ int neq # number of equality constraints
477
+ int ntendon # number of tendons
478
+ int nwrap # number of wrap objects in all tendon paths
479
+ int nsensor # number of sensors
480
+ int nnumeric # number of numeric custom fields
481
+ int nnumericdata # number of mjtNums in all numeric fields
482
+ int ntext # number of text custom fields
483
+ int ntextdata # number of mjtBytes in all text fields
484
+ int ntuple # number of tuple custom fields
485
+ int ntupledata # number of objects in all tuple fields
486
+ int nkey # number of keyframes
487
+ int nmocap # number of mocap bodies
488
+ int nuser_body # number of mjtNums in body_user
489
+ int nuser_jnt # number of mjtNums in jnt_user
490
+ int nuser_geom # number of mjtNums in geom_user
491
+ int nuser_site # number of mjtNums in site_user
492
+ int nuser_cam # number of mjtNums in cam_user
493
+ int nuser_tendon # number of mjtNums in tendon_user
494
+ int nuser_actuator # number of mjtNums in actuator_user
495
+ int nuser_sensor # number of mjtNums in sensor_user
496
+ int nnames # number of chars in all names
497
+
498
+ # sizes set after jModel construction (only affect mjData)
499
+ int nM # number of non-zeros in sparse inertia matrix
500
+ int nemax # number of potential equality-constraint rows
501
+ int njmax # number of available rows in constraint Jacobian
502
+ int nconmax # number of potential contacts in contact list
503
+ int nstack # number of fields in mjData stack
504
+ int nuserdata # number of extra fields in mjData
505
+ int nsensordata # number of fields in sensor data vector
506
+
507
+ int nbuffer # number of bytes in buffer
508
+
509
+ # ------------------------------- options and statistics
510
+
511
+ mjOption opt # physics options
512
+ mjVisual vis # visualization options
513
+ mjStatistic stat # model statistics
514
+
515
+ # ------------------------------- buffers
516
+
517
+ # main buffer
518
+ void* buffer # main buffer; all pointers point in it (nbuffer)
519
+
520
+ # default generalized coordinates
521
+ mjtNum* qpos0 # qpos values at default pose (nq x 1)
522
+ mjtNum* qpos_spring # reference pose for springs (nq x 1)
523
+
524
+ # bodies
525
+ int* body_parentid # id of body's parent (nbody x 1)
526
+ int* body_rootid # id of root above body (nbody x 1)
527
+ int* body_weldid # id of body that this body is welded to (nbody x 1)
528
+ int* body_mocapid # id of mocap data; -1: none (nbody x 1)
529
+ int* body_jntnum # number of joints for this body (nbody x 1)
530
+ int* body_jntadr # start addr of joints; -1: no joints (nbody x 1)
531
+ int* body_dofnum # number of motion degrees of freedom (nbody x 1)
532
+ int* body_dofadr # start addr of dofs; -1: no dofs (nbody x 1)
533
+ int* body_geomnum # number of geoms (nbody x 1)
534
+ int* body_geomadr # start addr of geoms; -1: no geoms (nbody x 1)
535
+ mjtByte* body_simple # body is simple (has diagonal M) (nbody x 1)
536
+ mjtByte* body_sameframe # inertial frame is same as body frame (nbody x 1)
537
+ mjtNum* body_pos # position offset rel. to parent body (nbody x 3)
538
+ mjtNum* body_quat # orientation offset rel. to parent body (nbody x 4)
539
+ mjtNum* body_ipos # local position of center of mass (nbody x 3)
540
+ mjtNum* body_iquat # local orientation of inertia ellipsoid (nbody x 4)
541
+ mjtNum* body_mass # mass (nbody x 1)
542
+ mjtNum* body_subtreemass # mass of subtree starting at this body (nbody x 1)
543
+ mjtNum* body_inertia # diagonal inertia in ipos/iquat frame (nbody x 3)
544
+ mjtNum* body_invweight0 # mean inv inert in qpos0 (trn, rot) (nbody x 2)
545
+ mjtNum* body_user # user data (nbody x nuser_body)
546
+
547
+ # joints
548
+ int* jnt_type # type of joint (mjtJoint) (njnt x 1)
549
+ int* jnt_qposadr # start addr in 'qpos' for joint's data (njnt x 1)
550
+ int* jnt_dofadr # start addr in 'qvel' for joint's data (njnt x 1)
551
+ int* jnt_bodyid # id of joint's body (njnt x 1)
552
+ int* jnt_group # group for visibility (njnt x 1)
553
+ mjtByte* jnt_limited # does joint have limits (njnt x 1)
554
+ mjtNum* jnt_solref # constraint solver reference: limit (njnt x mjNREF)
555
+ mjtNum* jnt_solimp # constraint solver impedance: limit (njnt x mjNIMP)
556
+ mjtNum* jnt_pos # local anchor position (njnt x 3)
557
+ mjtNum* jnt_axis # local joint axis (njnt x 3)
558
+ mjtNum* jnt_stiffness # stiffness coefficient (njnt x 1)
559
+ mjtNum* jnt_range # joint limits (njnt x 2)
560
+ mjtNum* jnt_margin # min distance for limit detection (njnt x 1)
561
+ mjtNum* jnt_user # user data (njnt x nuser_jnt)
562
+
563
+ # dofs
564
+ int* dof_bodyid # id of dof's body (nv x 1)
565
+ int* dof_jntid # id of dof's joint (nv x 1)
566
+ int* dof_parentid # id of dof's parent; -1: none (nv x 1)
567
+ int* dof_Madr # dof address in M-diagonal (nv x 1)
568
+ int* dof_simplenum # number of consecutive simple dofs (nv x 1)
569
+ mjtNum* dof_solref # constraint solver reference:frictionloss (nv x mjNREF)
570
+ mjtNum* dof_solimp # constraint solver impedance:frictionloss (nv x mjNIMP)
571
+ mjtNum* dof_frictionloss # dof friction loss (nv x 1)
572
+ mjtNum* dof_armature # dof armature inertia/mass (nv x 1)
573
+ mjtNum* dof_damping # damping coefficient (nv x 1)
574
+ mjtNum* dof_invweight0 # inv. diag. inertia in qpos0 (nv x 1)
575
+ mjtNum* dof_M0 # diag. inertia in qpos0 (nv x 1)
576
+
577
+ # geoms
578
+ int* geom_type # geometric type (mjtGeom) (ngeom x 1)
579
+ int* geom_contype # geom contact type (ngeom x 1)
580
+ int* geom_conaffinity # geom contact affinity (ngeom x 1)
581
+ int* geom_condim # contact dimensionality (1, 3, 4, 6) (ngeom x 1)
582
+ int* geom_bodyid # id of geom's body (ngeom x 1)
583
+ int* geom_dataid # id of geom's mesh/hfield (-1: none) (ngeom x 1)
584
+ int* geom_matid # material id for rendering (ngeom x 1)
585
+ int* geom_group # group for visibility (ngeom x 1)
586
+ int* geom_priority # geom contact priority (ngeom x 1)
587
+ mjtByte* geom_sameframe # same as body frame (1) or iframe (2) (ngeom x 1)
588
+ mjtNum* geom_solmix # mixing coef for solref/imp in geom pair (ngeom x 1)
589
+ mjtNum* geom_solref # constraint solver reference: contact (ngeom x mjNREF)
590
+ mjtNum* geom_solimp # constraint solver impedance: contact (ngeom x mjNIMP)
591
+ mjtNum* geom_size # geom-specific size parameters (ngeom x 3)
592
+ mjtNum* geom_rbound # radius of bounding sphere (ngeom x 1)
593
+ mjtNum* geom_pos # local position offset rel. to body (ngeom x 3)
594
+ mjtNum* geom_quat # local orientation offset rel. to body (ngeom x 4)
595
+ mjtNum* geom_friction # friction for (slide, spin, roll) (ngeom x 3)
596
+ mjtNum* geom_margin # detect contact if dist<margin (ngeom x 1)
597
+ mjtNum* geom_gap # include in solver if dist<margin-gap (ngeom x 1)
598
+ mjtNum* geom_user # user data (ngeom x nuser_geom)
599
+ float* geom_rgba # rgba when material is omitted (ngeom x 4)
600
+
601
+ # sites
602
+ int* site_type # geom type for rendering (mjtGeom) (nsite x 1)
603
+ int* site_bodyid # id of site's body (nsite x 1)
604
+ int* site_matid # material id for rendering (nsite x 1)
605
+ int* site_group # group for visibility (nsite x 1)
606
+ mjtByte* site_sameframe # same as body frame (1) or iframe (2) (nsite x 1)
607
+ mjtNum* site_size # geom size for rendering (nsite x 3)
608
+ mjtNum* site_pos # local position offset rel. to body (nsite x 3)
609
+ mjtNum* site_quat # local orientation offset rel. to body (nsite x 4)
610
+ mjtNum* site_user # user data (nsite x nuser_site)
611
+ float* site_rgba # rgba when material is omitted (nsite x 4)
612
+
613
+ # cameras
614
+ int* cam_mode # camera tracking mode (mjtCamLight) (ncam x 1)
615
+ int* cam_bodyid # id of camera's body (ncam x 1)
616
+ int* cam_targetbodyid # id of targeted body; -1: none (ncam x 1)
617
+ mjtNum* cam_pos # position rel. to body frame (ncam x 3)
618
+ mjtNum* cam_quat # orientation rel. to body frame (ncam x 4)
619
+ mjtNum* cam_poscom0 # global position rel. to sub-com in qpos0 (ncam x 3)
620
+ mjtNum* cam_pos0 # global position rel. to body in qpos0 (ncam x 3)
621
+ mjtNum* cam_mat0 # global orientation in qpos0 (ncam x 9)
622
+ mjtNum* cam_fovy # y-field of view (deg) (ncam x 1)
623
+ mjtNum* cam_ipd # inter-pupilary distance (ncam x 1)
624
+ mjtNum* cam_user # user data (ncam x nuser_cam)
625
+
626
+ # lights
627
+ int* light_mode # light tracking mode (mjtCamLight) (nlight x 1)
628
+ int* light_bodyid # id of light's body (nlight x 1)
629
+ int* light_targetbodyid # id of targeted body; -1: none (nlight x 1)
630
+ mjtByte* light_directional # directional light (nlight x 1)
631
+ mjtByte* light_castshadow # does light cast shadows (nlight x 1)
632
+ mjtByte* light_active # is light on (nlight x 1)
633
+ mjtNum* light_pos # position rel. to body frame (nlight x 3)
634
+ mjtNum* light_dir # direction rel. to body frame (nlight x 3)
635
+ mjtNum* light_poscom0 # global position rel. to sub-com in qpos0 (nlight x 3)
636
+ mjtNum* light_pos0 # global position rel. to body in qpos0 (nlight x 3)
637
+ mjtNum* light_dir0 # global direction in qpos0 (nlight x 3)
638
+ float* light_attenuation # OpenGL attenuation (quadratic model) (nlight x 3)
639
+ float* light_cutoff # OpenGL cutoff (nlight x 1)
640
+ float* light_exponent # OpenGL exponent (nlight x 1)
641
+ float* light_ambient # ambient rgb (alpha=1) (nlight x 3)
642
+ float* light_diffuse # diffuse rgb (alpha=1) (nlight x 3)
643
+ float* light_specular # specular rgb (alpha=1) (nlight x 3)
644
+
645
+ # meshes
646
+ int* mesh_vertadr # first vertex address (nmesh x 1)
647
+ int* mesh_vertnum # number of vertices (nmesh x 1)
648
+ int* mesh_texcoordadr # texcoord data address; -1: no texcoord (nmesh x 1)
649
+ int* mesh_faceadr # first face address (nmesh x 1)
650
+ int* mesh_facenum # number of faces (nmesh x 1)
651
+ int* mesh_graphadr # graph data address; -1: no graph (nmesh x 1)
652
+ float* mesh_vert # vertex data for all meshes (nmeshvert x 3)
653
+ float* mesh_normal # vertex normal data for all meshes (nmeshvert x 3)
654
+ float* mesh_texcoord # vertex texcoords for all meshes (nmeshtexvert x 2)
655
+ int* mesh_face # triangle face data (nmeshface x 3)
656
+ int* mesh_graph # convex graph data (nmeshgraph x 1)
657
+
658
+ # skins
659
+ int* skin_matid # skin material id; -1: none (nskin x 1)
660
+ float* skin_rgba # skin rgba (nskin x 4)
661
+ float* skin_inflate # inflate skin in normal direction (nskin x 1)
662
+ int* skin_vertadr # first vertex address (nskin x 1)
663
+ int* skin_vertnum # number of vertices (nskin x 1)
664
+ int* skin_texcoordadr # texcoord data address; -1: no texcoord (nskin x 1)
665
+ int* skin_faceadr # first face address (nskin x 1)
666
+ int* skin_facenum # number of faces (nskin x 1)
667
+ int* skin_boneadr # first bone in skin (nskin x 1)
668
+ int* skin_bonenum # number of bones in skin (nskin x 1)
669
+ float* skin_vert # vertex positions for all skin meshes (nskinvert x 3)
670
+ float* skin_texcoord # vertex texcoords for all skin meshes (nskintexvert x 2)
671
+ int* skin_face # triangle faces for all skin meshes (nskinface x 3)
672
+ int* skin_bonevertadr # first vertex in each bone (nskinbone x 1)
673
+ int* skin_bonevertnum # number of vertices in each bone (nskinbone x 1)
674
+ float* skin_bonebindpos # bind pos of each bone (nskinbone x 3)
675
+ float* skin_bonebindquat # bind quat of each bone (nskinbone x 4)
676
+ int* skin_bonebodyid # body id of each bone (nskinbone x 1)
677
+ int* skin_bonevertid # mesh ids of vertices in each bone (nskinbonevert x 1)
678
+ float* skin_bonevertweight # weights of vertices in each bone (nskinbonevert x 1)
679
+
680
+ # height fields
681
+ mjtNum* hfield_size # (x, y, z_top, z_bottom) (nhfield x 4)
682
+ int* hfield_nrow # number of rows in grid (nhfield x 1)
683
+ int* hfield_ncol # number of columns in grid (nhfield x 1)
684
+ int* hfield_adr # address in hfield_data (nhfield x 1)
685
+ float* hfield_data # elevation data (nhfielddata x 1)
686
+
687
+ # textures
688
+ int* tex_type # texture type (mjtTexture) (ntex x 1)
689
+ int* tex_height # number of rows in texture image (ntex x 1)
690
+ int* tex_width # number of columns in texture image (ntex x 1)
691
+ int* tex_adr # address in rgb (ntex x 1)
692
+ mjtByte* tex_rgb # rgb (alpha = 1) (ntexdata x 1)
693
+
694
+ # materials
695
+ int* mat_texid # texture id; -1: none (nmat x 1)
696
+ mjtByte* mat_texuniform # make texture cube uniform (nmat x 1)
697
+ float* mat_texrepeat # texture repetition for 2d mapping (nmat x 2)
698
+ float* mat_emission # emission (x rgb) (nmat x 1)
699
+ float* mat_specular # specular (x white) (nmat x 1)
700
+ float* mat_shininess # shininess coef (nmat x 1)
701
+ float* mat_reflectance # reflectance (0: disable) (nmat x 1)
702
+ float* mat_rgba # rgba (nmat x 4)
703
+
704
+ # predefined geom pairs for collision detection; has precedence over exclude
705
+ int* pair_dim # contact dimensionality (npair x 1)
706
+ int* pair_geom1 # id of geom1 (npair x 1)
707
+ int* pair_geom2 # id of geom2 (npair x 1)
708
+ int* pair_signature # (body1+1)<<16 + body2+1 (npair x 1)
709
+ mjtNum* pair_solref # constraint solver reference: contact (npair x mjNREF)
710
+ mjtNum* pair_solimp # constraint solver impedance: contact (npair x mjNIMP)
711
+ mjtNum* pair_margin # detect contact if dist<margin (npair x 1)
712
+ mjtNum* pair_gap # include in solver if dist<margin-gap (npair x 1)
713
+ mjtNum* pair_friction # tangent1, 2, spin, roll1, 2 (npair x 5)
714
+
715
+ # excluded body pairs for collision detection
716
+ int* exclude_signature # (body1+1)<<16 + body2+1 (nexclude x 1)
717
+
718
+ # equality constraints
719
+ int* eq_type # constraint type (mjtEq) (neq x 1)
720
+ int* eq_obj1id # id of object 1 (neq x 1)
721
+ int* eq_obj2id # id of object 2 (neq x 1)
722
+ mjtByte* eq_active # enable/disable constraint (neq x 1)
723
+ mjtNum* eq_solref # constraint solver reference (neq x mjNREF)
724
+ mjtNum* eq_solimp # constraint solver impedance (neq x mjNIMP)
725
+ mjtNum* eq_data # numeric data for constraint (neq x mjNEQDATA)
726
+
727
+ # tendons
728
+ int* tendon_adr # address of first object in tendon's path (ntendon x 1)
729
+ int* tendon_num # number of objects in tendon's path (ntendon x 1)
730
+ int* tendon_matid # material id for rendering (ntendon x 1)
731
+ int* tendon_group # group for visibility (ntendon x 1)
732
+ mjtByte* tendon_limited # does tendon have length limits (ntendon x 1)
733
+ mjtNum* tendon_width # width for rendering (ntendon x 1)
734
+ mjtNum* tendon_solref_lim # constraint solver reference: limit (ntendon x mjNREF)
735
+ mjtNum* tendon_solimp_lim # constraint solver impedance: limit (ntendon x mjNIMP)
736
+ mjtNum* tendon_solref_fri # constraint solver reference: friction (ntendon x mjNREF)
737
+ mjtNum* tendon_solimp_fri # constraint solver impedance: friction (ntendon x mjNIMP)
738
+ mjtNum* tendon_range # tendon length limits (ntendon x 2)
739
+ mjtNum* tendon_margin # min distance for limit detection (ntendon x 1)
740
+ mjtNum* tendon_stiffness # stiffness coefficient (ntendon x 1)
741
+ mjtNum* tendon_damping # damping coefficient (ntendon x 1)
742
+ mjtNum* tendon_frictionloss; # loss due to friction (ntendon x 1)
743
+ mjtNum* tendon_lengthspring; # tendon length in qpos_spring (ntendon x 1)
744
+ mjtNum* tendon_length0 # tendon length in qpos0 (ntendon x 1)
745
+ mjtNum* tendon_invweight0 # inv. weight in qpos0 (ntendon x 1)
746
+ mjtNum* tendon_user # user data (ntendon x nuser_tendon)
747
+ float* tendon_rgba # rgba when material is omitted (ntendon x 4)
748
+
749
+ # list of all wrap objects in tendon paths
750
+ int* wrap_type # wrap object type (mjtWrap) (nwrap x 1)
751
+ int* wrap_objid # object id: geom, site, joint (nwrap x 1)
752
+ mjtNum* wrap_prm # divisor, joint coef, or site id (nwrap x 1)
753
+
754
+ # actuators
755
+ int* actuator_trntype # transmission type (mjtTrn) (nu x 1)
756
+ int* actuator_dyntype # dynamics type (mjtDyn) (nu x 1)
757
+ int* actuator_gaintype # gain type (mjtGain) (nu x 1)
758
+ int* actuator_biastype # bias type (mjtBias) (nu x 1)
759
+ int* actuator_trnid # transmission id: joint, tendon, site (nu x 2)
760
+ int* actuator_group # group for visibility (nu x 1)
761
+ mjtByte* actuator_ctrllimited; # is control limited (nu x 1)
762
+ mjtByte* actuator_forcelimited;# is force limited (nu x 1)
763
+ mjtNum* actuator_dynprm # dynamics parameters (nu x mjNDYN)
764
+ mjtNum* actuator_gainprm # gain parameters (nu x mjNGAIN)
765
+ mjtNum* actuator_biasprm # bias parameters (nu x mjNBIAS)
766
+ mjtNum* actuator_ctrlrange # range of controls (nu x 2)
767
+ mjtNum* actuator_forcerange; # range of forces (nu x 2)
768
+ mjtNum* actuator_gear # scale length and transmitted force (nu x 6)
769
+ mjtNum* actuator_cranklength; # crank length for slider-crank (nu x 1)
770
+ mjtNum* actuator_acc0 # acceleration from unit force in qpos0 (nu x 1)
771
+ mjtNum* actuator_length0 # actuator length in qpos0 (nu x 1)
772
+ mjtNum* actuator_lengthrange # ... not yet implemented ??? (nu x 2)
773
+ mjtNum* actuator_user # user data (nu x nuser_actuator)
774
+
775
+ # sensors
776
+ int* sensor_type # sensor type (mjtSensor) (nsensor x 1)
777
+ int* sensor_datatype # numeric data type (mjtDataType) (nsensor x 1)
778
+ int* sensor_needstage # required compute stage (mjtStage) (nsensor x 1)
779
+ int* sensor_objtype # type of sensorized object (mjtObj) (nsensor x 1)
780
+ int* sensor_objid # id of sensorized object (nsensor x 1)
781
+ int* sensor_dim # number of scalar outputs (nsensor x 1)
782
+ int* sensor_adr # address in sensor array (nsensor x 1)
783
+ mjtNum* sensor_cutoff # cutoff for real and positive; 0: ignore (nsensor x 1)
784
+ mjtNum* sensor_noise # noise standard deviation (nsensor x 1)
785
+ mjtNum* sensor_user # user data (nsensor x nuser_sensor)
786
+
787
+ # custom numeric fields
788
+ int* numeric_adr # address of field in numeric_data (nnumeric x 1)
789
+ int* numeric_size # size of numeric field (nnumeric x 1)
790
+ mjtNum* numeric_data # array of all numeric fields (nnumericdata x 1)
791
+
792
+ # custom text fields
793
+ int* text_adr # address of text in text_data (ntext x 1)
794
+ int* text_size # size of text field (strlen+1) (ntext x 1)
795
+ char* text_data # array of all text fields (0-terminated) (ntextdata x 1)
796
+
797
+ # custom tuple fields
798
+ int* tuple_adr # address of text in text_data (ntuple x 1)
799
+ int* tuple_size # number of objects in tuple (ntuple x 1)
800
+ int* tuple_objtype # array of object types in all tuples (ntupledata x 1)
801
+ int* tuple_objid # array of object ids in all tuples (ntupledata x 1)
802
+ mjtNum* tuple_objprm # array of object params in all tuples (ntupledata x 1)
803
+
804
+ # keyframes
805
+ mjtNum* key_time # key time (nkey x 1)
806
+ mjtNum* key_qpos # key position (nkey x nq)
807
+ mjtNum* key_qvel # key velocity (nkey x nv)
808
+ mjtNum* key_act # key activation (nkey x na)
809
+ mjtNum* key_mpos # key mocap position (nkey x 3*nmocap)
810
+ mjtNum* key_mquat # key mocap quaternion (nkey x 4*nmocap)
811
+
812
+ # names
813
+ int* name_bodyadr # body name pointers (nbody x 1)
814
+ int* name_jntadr # joint name pointers (njnt x 1)
815
+ int* name_geomadr # geom name pointers (ngeom x 1)
816
+ int* name_siteadr # site name pointers (nsite x 1)
817
+ int* name_camadr # camera name pointers (ncam x 1)
818
+ int* name_lightadr # light name pointers (nlight x 1)
819
+ int* name_meshadr # mesh name pointers (nmesh x 1)
820
+ int* name_skinadr # skin name pointers (nskin x 1)
821
+ int* name_hfieldadr # hfield name pointers (nhfield x 1)
822
+ int* name_texadr # texture name pointers (ntex x 1)
823
+ int* name_matadr # material name pointers (nmat x 1)
824
+ int* name_pairadr # geom pair name pointers (npair x 1)
825
+ int* name_excludeadr # exclude name pointers (nexclude x 1)
826
+ int* name_eqadr # equality constraint name pointers (neq x 1)
827
+ int* name_tendonadr # tendon name pointers (ntendon x 1)
828
+ int* name_actuatoradr # actuator name pointers (nu x 1)
829
+ int* name_sensoradr # sensor name pointers (nsensor x 1)
830
+ int* name_numericadr # numeric name pointers (nnumeric x 1)
831
+ int* name_textadr # text name pointers (ntext x 1)
832
+ int* name_tupleadr # tuple name pointers (ntuple x 1)
833
+ int* name_keyadr # keyframe name pointers (nkey x 1)
834
+ char* names # names of all objects, 0-terminated (nnames x 1)
mujoco-py-2.1.2.14/mujoco_py/pxd/mjrender.pxd ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cdef extern from "mjrender.h" nogil:
2
+ # Global constants
3
+ enum: mjNAUX
4
+ enum: mjMAXTEXTURE
5
+
6
+ ctypedef enum mjtGridPos: # grid position for overlay
7
+ mjGRID_TOPLEFT = 0, # top left
8
+ mjGRID_TOPRIGHT, # top right
9
+ mjGRID_BOTTOMLEFT, # bottom left
10
+ mjGRID_BOTTOMRIGHT # bottom right
11
+
12
+
13
+ ctypedef enum mjtFramebuffer: # OpenGL framebuffer option
14
+ mjFB_WINDOW = 0, # default/window buffer
15
+ mjFB_OFFSCREEN # offscreen buffer
16
+
17
+
18
+ ctypedef enum mjtFontScale: # font scale, used at context creation
19
+ mjFONTSCALE_100 = 100, # normal scale, suitable in the absence of DPI scaling
20
+ mjFONTSCALE_150 = 150, # 150% scale
21
+ mjFONTSCALE_200 = 200 # 200% scale
22
+
23
+
24
+ ctypedef enum mjtFont: # font type, used at each text operation
25
+ mjFONT_NORMAL = 0, # normal font
26
+ mjFONT_SHADOW, # normal font with shadow (for higher contrast)
27
+ mjFONT_BIG # big font (for user alerts)
28
+
29
+ ctypedef struct mjrRect: # OpenGL rectangle
30
+ int left # left (usually 0)
31
+ int bottom # bottom (usually 0)
32
+ int width # width (usually buffer width)
33
+ int height # height (usually buffer height)
34
+
35
+
36
+ ctypedef struct mjrContext: # custom OpenGL context
37
+ # parameters copied from mjVisual
38
+ float lineWidth # line width for wireframe rendering
39
+ float shadowClip # clipping radius for directional lights
40
+ float shadowScale # fraction of light cutoff for spot lights
41
+ float fogStart # fog start = stat.extent * vis.map.fogstart
42
+ float fogEnd # fog end = stat.extent * vis.map.fogend
43
+ float fogRGBA[4] # fog rgba
44
+ int shadowSize # size of shadow map texture
45
+ int offWidth # width of offscreen buffer
46
+ int offHeight # height of offscreen buffer
47
+ int offSamples # number of offscreen buffer multisamples
48
+
49
+ # parameters specified at creation
50
+ int fontScale; # font scale
51
+ int auxWidth[mjNAUX] # auxiliary buffer width
52
+ int auxHeight[mjNAUX] # auxiliary buffer height
53
+ int auxSamples[mjNAUX] # auxiliary buffer multisamples
54
+
55
+ # offscreen rendering objects
56
+ unsigned int offFBO # offscreen framebuffer object
57
+ unsigned int offFBO_r # offscreen framebuffer for resolving multisamples
58
+ unsigned int offColor # offscreen color buffer
59
+ unsigned int offColor_r # offscreen color buffer for resolving multisamples
60
+ unsigned int offDepthStencil # offscreen depth and stencil buffer
61
+ unsigned int offDepthStencil_r # offscreen depth and stencil buffer for resolving multisamples
62
+
63
+ # shadow rendering objects
64
+ unsigned int shadowFBO # shadow map framebuffer object
65
+ unsigned int shadowTex # shadow map texture
66
+
67
+ # auxiliary buffers
68
+ unsigned int auxFBO[mjNAUX] # auxiliary framebuffer object
69
+ unsigned int auxFBO_r[mjNAUX] # auxiliary framebuffer object for resolving
70
+ unsigned int auxColor[mjNAUX] # auxiliary color buffer
71
+ unsigned int auxColor_r[mjNAUX] # auxiliary color buffer for resolving
72
+
73
+ # texture objects and info
74
+ int ntexture # number of allocated textures
75
+ int textureType[100] # type of texture (mjtTexture)
76
+ unsigned int texture[100] # texture names
77
+
78
+ # displaylist starting positions
79
+ unsigned int basePlane # all planes from model
80
+ unsigned int baseMesh # all meshes from model
81
+ unsigned int baseHField # all hfields from model
82
+ unsigned int baseBuiltin # all buildin geoms, with quality from model
83
+ unsigned int baseFontNormal # normal font
84
+ unsigned int baseFontShadow # shadow font
85
+ unsigned int baseFontBig # big font
86
+
87
+ # displaylist ranges
88
+ int rangePlane # all planes from model
89
+ int rangeMesh # all meshes from model
90
+ int rangeHField # all hfields from model
91
+ int rangeBuiltin # all builtin geoms, with quality from model
92
+ int rangeFont # all characters in font
93
+
94
+ # skin VBOs
95
+ int nskin # number of skins
96
+ unsigned int* skinvertVBO # skin vertex position VBOs
97
+ unsigned int* skinnormalVBO # skin vertex normal VBOs
98
+ unsigned int* skintexcoordVBO # skin vertex texture coordinate VBOs
99
+ unsigned int* skinfaceVBO # skin face index VBOs
100
+
101
+ # character info
102
+ int charWidth[127] # character widths: normal and shadow
103
+ int charWidthBig[127] # chacarter widths: big
104
+ int charHeight # character heights: normal and shadow
105
+ int charHeightBig # character heights: big
106
+
107
+ # capabilities
108
+ int glewInitialized # is glew initialized
109
+ int windowAvailable # is default/window framebuffer available
110
+ int windowSamples # number of samples for default/window framebuffer
111
+ int windowStereo # is stereo available for default/window framebuffer
112
+ int windowDoublebuffer # is default/window framebuffer double buffered
113
+
114
+ # only field that changes after mjr_makeContext
115
+ int currentBuffer # currently active framebuffer: mjFB_WINDOW or mjFB_OFFSCREEN
mujoco-py-2.1.2.14/mujoco_py/pxd/mujoco.pxd ADDED
@@ -0,0 +1,1083 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ include "mjmodel.pxd"
2
+ include "mjdata.pxd"
3
+ include "mjrender.pxd"
4
+ include "mjui.pxd"
5
+ include "mjvisualize.pxd"
6
+
7
+
8
+ cdef extern from "mujoco.h" nogil:
9
+ # macros
10
+ #define mjMARKSTACK int _mark = d->pstack;
11
+ #define mjFREESTACK d->pstack = _mark;
12
+ #define mjDISABLED(x) (m->opt.disableflags & (x))
13
+ #define mjENABLED(x) (m->opt.enableflags & (x))
14
+
15
+
16
+ # user error and memory handlers
17
+ void (*mju_user_error)(const char*);
18
+ void (*mju_user_warning)(const char*);
19
+ void* (*mju_user_malloc)(size_t);
20
+ void (*mju_user_free)(void*);
21
+
22
+
23
+ # # callbacks extending computation pipeline
24
+ # mjfGeneric mjcb_passive;
25
+ # mjfGeneric mjcb_control;
26
+ # mjfSensor mjcb_sensor;
27
+ # mjfTime mjcb_time;
28
+ # mjfAct mjcb_act_dyn;
29
+ mjfAct mjcb_act_gain;
30
+ mjfAct mjcb_act_bias;
31
+ # mjfSolImp mjcb_sol_imp;
32
+ # mjfSolRef mjcb_sol_ref;
33
+ #
34
+ #
35
+ # # collision function table
36
+ # mjfCollision mjCOLLISIONFUNC[mjNGEOMTYPES][mjNGEOMTYPES];
37
+ #
38
+ #
39
+ # # string names
40
+ const char* mjDISABLESTRING[mjNDISABLE];
41
+ const char* mjENABLESTRING[mjNENABLE];
42
+ const char* mjTIMERSTRING[mjNTIMER];
43
+ const char* mjLABELSTRING[mjNLABEL];
44
+ const char* mjFRAMESTRING[mjNFRAME];
45
+ const char* mjVISSTRING[mjNVISFLAG][3];
46
+ const char* mjRNDSTRING[mjNRNDFLAG][3];
47
+
48
+
49
+ #---------------------- Activation -----------------------------------------------------
50
+
51
+ # activate license, call mju_error on failure; return 1 if ok, 0 if failure
52
+ int mj_activate(const char* filename);
53
+
54
+ # deactivate license, free memory
55
+ void mj_deactivate();
56
+
57
+ #---------------------- Virtual file system --------------------------------------------
58
+
59
+ # Initialize VFS to empty (no deallocation).
60
+ void mj_defaultVFS(mjVFS* vfs);
61
+
62
+ # Add file to VFS, return 0: success, 1: full, 2: repeated name, -1: not found on disk.
63
+ int mj_addFileVFS(mjVFS* vfs, const char* directory, const char* filename);
64
+
65
+ # Make empty file in VFS, return 0: success, 1: full, 2: repeated name.
66
+ int mj_makeEmptyFileVFS(mjVFS* vfs, const char* filename, int filesize);
67
+
68
+ # Return file index in VFS, or -1 if not found in VFS.
69
+ int mj_findFileVFS(const mjVFS* vfs, const char* filename);
70
+
71
+ # Delete file from VFS, return 0: success, -1: not found in VFS.
72
+ int mj_deleteFileVFS(mjVFS* vfs, const char* filename);
73
+
74
+ # Delete all files from VFS.
75
+ void mj_deleteVFS(mjVFS* vfs);
76
+
77
+ #--------------------- Parse and compile ----------------------------------------------
78
+
79
+ # Parse XML file in MJCF or URDF format, compile it, return low-level model.
80
+ # If vfs is not NULL, look up files in vfs before reading from disk.
81
+ # If error is not NULL, it must have size error_sz.
82
+ mjModel* mj_loadXML(const char* filename, const mjVFS* vfs,
83
+ char* error, int error_sz);
84
+
85
+ # Update XML data structures with info from low-level model, save as MJCF.
86
+ # If error is not NULL, it must have size error_sz.
87
+ int mj_saveLastXML(const char* filename, const mjModel* m,
88
+ char* error, int error_sz);
89
+
90
+ # Free last XML model if loaded. Called internally at each load.
91
+ void mj_freeLastXML();
92
+
93
+ # Print internal XML schema as plain text or HTML, with style-padding or &nbsp;.
94
+ int mj_printSchema(const char* filename, char* buffer, int buffer_sz,
95
+ int flg_html, int flg_pad);
96
+
97
+
98
+ #--------------------- Main simulation ------------------------------------------------
99
+
100
+ # Advance simulation, use control callback to obtain external force and control.
101
+ void mj_step(const mjModel* m, mjData* d);
102
+
103
+ # Advance simulation in two steps: before external force and control is set by user.
104
+ void mj_step1(const mjModel* m, mjData* d);
105
+
106
+ # Advance simulation in two steps: after external force and control is set by user.
107
+ void mj_step2(const mjModel* m, mjData* d);
108
+
109
+ # Forward dynamics: same as mj_step but do not integrate in time.
110
+ void mj_forward(const mjModel* m, mjData* d);
111
+
112
+ # Inverse dynamics: qacc must be set before calling.
113
+ void mj_inverse(const mjModel* m, mjData* d);
114
+
115
+ # Forward dynamics with skip; skipstage is mjtStage.
116
+ void mj_forwardSkip(const mjModel* m, mjData* d,
117
+ int skipstage, int skipsensorenergy);
118
+
119
+ # Inverse dynamics with skip; skipstage is mjtStage.
120
+ void mj_inverseSkip(const mjModel* m, mjData* d,
121
+ int skipstage, int skipsensorenergy);
122
+
123
+ # Forward dynamics with skip; skipstage is mjtStage.
124
+ void mj_forwardSkip(const mjModel* m, mjData* d, int skipstage, int skipsensor);
125
+
126
+ # Inverse dynamics with skip; skipstage is mjtStage.
127
+ void mj_inverseSkip(const mjModel* m, mjData* d, int skipstage, int skipsensor);
128
+
129
+
130
+ #--------------------- Initialization -------------------------------------------------
131
+
132
+ # Set default options for length range computation.
133
+ void mj_defaultLROpt(mjLROpt* opt);
134
+
135
+ # Set solver parameters to default values.
136
+ void mj_defaultSolRefImp(mjtNum* solref, mjtNum* solimp);
137
+
138
+ # Set physics options to default values.
139
+ void mj_defaultOption(mjOption* opt);
140
+
141
+ # Set visual options to default values.
142
+ void mj_defaultVisual(mjVisual* vis);
143
+
144
+ # Copy mjModel, allocate new if dest is NULL.
145
+ mjModel* mj_copyModel(mjModel* dest, const mjModel* src);
146
+
147
+ # Save model to binary MJB file or memory buffer; buffer has precedence when given.
148
+ void mj_saveModel(const mjModel* m, const char* filename, void* buffer, int buffer_sz);
149
+
150
+ # Load model from binary MJB file.
151
+ # If vfs is not NULL, look up file in vfs before reading from disk.
152
+ mjModel* mj_loadModel(const char* filename, mjVFS* vfs);
153
+
154
+ # Free memory allocation in model.
155
+ void mj_deleteModel(mjModel* m);
156
+
157
+ # Return size of buffer needed to hold model.
158
+ int mj_sizeModel(const mjModel* m);
159
+
160
+ # Allocate mjData correponding to given model.
161
+ mjData* mj_makeData(const mjModel* m);
162
+
163
+ # Copy mjData.
164
+ mjData* mj_copyData(mjData* dest, const mjModel* m, const mjData* src);
165
+
166
+ # Reset data to defaults.
167
+ void mj_resetData(const mjModel* m, mjData* d);
168
+
169
+ # Reset data to defaults, fill everything else with debug_value.
170
+ void mj_resetDataDebug(const mjModel* m, mjData* d, unsigned char debug_value);
171
+
172
+ # Reset data, set fields from specified keyframe.
173
+ void mj_resetDataKeyframe(const mjModel* m, mjData* d, int key);
174
+
175
+ # Allocate array of specified size on mjData stack. Call mju_error on stack overflow.
176
+ mjtNum* mj_stackAlloc(mjData* d, int size);
177
+
178
+ # Free memory allocation in mjData.
179
+ void mj_deleteData(mjData* d);
180
+
181
+ # Reset all callbacks to NULL pointers (NULL is the default).
182
+ void mj_resetCallbacks();
183
+
184
+ # Set constant fields of mjModel, corresponding to qpos0 configuration.
185
+ void mj_setConst(mjModel* m, mjData* d);
186
+
187
+ # Set actuator_lengthrange for specified actuator; return 1 if ok, 0 if error.
188
+ int mj_setLengthRange(mjModel* m, mjData* d, int index,
189
+ const mjLROpt* opt, char* error, int error_sz);
190
+
191
+ #--------------------- Printing -------------------------------------------------------
192
+
193
+ # Print model to text file.
194
+ void mj_printModel(const mjModel* m, const char* filename);
195
+
196
+ # Print data to text file.
197
+ void mj_printData(const mjModel* m, mjData* d, const char* filename);
198
+
199
+ # Print matrix to screen.
200
+ void mju_printMat(const mjtNum* mat, int nr, int nc);
201
+
202
+ # Print sparse matrix to screen.
203
+ void mju_printMatSparse(const mjtNum* mat, int nr,
204
+ const int* rownnz, const int* rowadr,
205
+ const int* colind);
206
+
207
+
208
+ #--------------------- Components -----------------------------------------------------
209
+
210
+ # Run position-dependent computations.
211
+ void mj_fwdPosition(const mjModel* m, mjData* d);
212
+
213
+ # Run velocity-dependent computations.
214
+ void mj_fwdVelocity(const mjModel* m, mjData* d);
215
+
216
+ # Compute actuator force qfrc_actuation.
217
+ void mj_fwdActuation(const mjModel* m, mjData* d);
218
+
219
+ # Add up all non-constraint forces, compute qacc_unc.
220
+ void mj_fwdAcceleration(const mjModel* m, mjData* d);
221
+
222
+ # Run selected constraint solver.
223
+ void mj_fwdConstraint(const mjModel* m, mjData* d);
224
+
225
+ # Euler integrator, semi-implicit in velocity.
226
+ void mj_Euler(const mjModel* m, mjData* d);
227
+
228
+ # Runge-Kutta explicit order-N integrator.
229
+ void mj_RungeKutta(const mjModel* m, mjData* d, int N);
230
+
231
+ # Run position-dependent computations in inverse dynamics.
232
+ void mj_invPosition(const mjModel* m, mjData* d);
233
+
234
+ # Run velocity-dependent computations in inverse dynamics.
235
+ void mj_invVelocity(const mjModel* m, mjData* d);
236
+
237
+ # Apply the analytical formula for inverse constraint dynamics.
238
+ void mj_invConstraint(const mjModel* m, mjData* d);
239
+
240
+ # Compare forward and inverse dynamics, save results in fwdinv.
241
+ void mj_compareFwdInv(const mjModel* m, mjData* d);
242
+
243
+
244
+ #--------------------- Sub components -------------------------------------------------
245
+
246
+ # Evaluate position-dependent sensors.
247
+ void mj_sensorPos(const mjModel* m, mjData* d);
248
+
249
+ # Evaluate velocity-dependent sensors.
250
+ void mj_sensorVel(const mjModel* m, mjData* d);
251
+
252
+ # Evaluate acceleration and force-dependent sensors.
253
+ void mj_sensorAcc(const mjModel* m, mjData* d);
254
+
255
+ # Evaluate position-dependent energy (potential).
256
+ void mj_energyPos(const mjModel* m, mjData* d);
257
+
258
+ # Evaluate velocity-dependent energy (kinetic).
259
+ void mj_energyVel(const mjModel* m, mjData* d);
260
+
261
+ # Check qpos, reset if any element is too big or nan.
262
+ void mj_checkPos(const mjModel* m, mjData* d);
263
+
264
+ # Check qvel, reset if any element is too big or nan.
265
+ void mj_checkVel(const mjModel* m, mjData* d);
266
+
267
+ # Check qacc, reset if any element is too big or nan.
268
+ void mj_checkAcc(const mjModel* m, mjData* d);
269
+
270
+ # Run forward kinematics.
271
+ void mj_kinematics(const mjModel* m, mjData* d);
272
+
273
+ # Map inertias and motion dofs to global frame centered at CoM.
274
+ void mj_comPos(const mjModel* m, mjData* d);
275
+
276
+ # Compute camera and light positions and orientations.
277
+ void mj_camlight(const mjModel* m, mjData* d);
278
+
279
+ # Compute tendon lengths, velocities and moment arms.
280
+ void mj_tendon(const mjModel* m, mjData* d);
281
+
282
+ # Compute actuator transmission lengths and moments.
283
+ void mj_transmission(const mjModel* m, mjData* d);
284
+
285
+ # Run composite rigid body inertia algorithm (CRB).
286
+ void mj_crb(const mjModel* m, mjData* d);
287
+
288
+ # Compute sparse L'*D*L factorizaton of inertia matrix.
289
+ void mj_factorM(const mjModel* m, mjData* d);
290
+
291
+ # Solve linear system M * x = y using factorization: x = inv(L'*D*L)*y
292
+ void mj_solveM(const mjModel* m, mjData* d, mjtNum* x, const mjtNum* y, int n);
293
+
294
+ # Half of linear solve: x = sqrt(inv(D))*inv(L')*y
295
+ void mj_solveM2(const mjModel* m, mjData* d, mjtNum* x, const mjtNum* y, int n);
296
+
297
+ # Compute cvel, cdof_dot.
298
+ void mj_comVel(const mjModel* m, mjData* d);
299
+
300
+ # Compute qfrc_passive from spring-dampers, viscosity and density.
301
+ void mj_passive(const mjModel* m, mjData* d);
302
+
303
+ # subtree linear velocity and angular momentum
304
+ void mj_subtreeVel(const mjModel* m, mjData* d);
305
+
306
+ # RNE: compute M(qpos)*qacc + C(qpos,qvel); flg_acc=0 removes inertial term.
307
+ void mj_rne(const mjModel* m, mjData* d, int flg_acc, mjtNum* result);
308
+
309
+ # RNE with complete data: compute cacc, cfrc_ext, cfrc_int.
310
+ void mj_rnePostConstraint(const mjModel* m, mjData* d);
311
+
312
+ # Run collision detection.
313
+ void mj_collision(const mjModel* m, mjData* d);
314
+
315
+ # Construct constraints.
316
+ void mj_makeConstraint(const mjModel* m, mjData* d);
317
+
318
+ # Compute inverse constaint inertia efc_AR.
319
+ void mj_projectConstraint(const mjModel* m, mjData* d);
320
+
321
+ # Compute efc_vel, efc_aref.
322
+ void mj_referenceConstraint(const mjModel* m, mjData* d);
323
+
324
+ # Compute efc_state, efc_force, qfrc_constraint, and (optionally) cone Hessians.
325
+ # If cost is not NULL, set *cost = s(jar) where jar = Jac*qacc-aref.
326
+ void mj_constraintUpdate(const mjModel* m, mjData* d, const mjtNum* jar,
327
+ mjtNum* cost, int flg_coneHessian);
328
+
329
+
330
+ #--------------------- Support --------------------------------------------------------
331
+
332
+ # Add contact to d->contact list; return 0 if success; 1 if buffer full.
333
+ int mj_addContact(const mjModel* m, mjData* d, const mjContact* con);
334
+
335
+ # Determine type of friction cone.
336
+ int mj_isPyramidal(const mjModel* m);
337
+
338
+ # Determine type of constraint Jacobian.
339
+ int mj_isSparse(const mjModel* m);
340
+
341
+ # Determine type of solver (PGS is dual, CG and Newton are primal).
342
+ int mj_isDual(const mjModel* m);
343
+
344
+ # Multiply dense or sparse constraint Jacobian by vector.
345
+ void mj_mulJacVec(const mjModel* m, mjData* d,
346
+ mjtNum* res, const mjtNum* vec);
347
+
348
+ # Multiply dense or sparse constraint Jacobian transpose by vector.
349
+ void mj_mulJacTVec(const mjModel* m, mjData* d, mjtNum* res, const mjtNum* vec);
350
+
351
+ # Compute 3/6-by-nv end-effector Jacobian of global point attached to given body.
352
+ void mj_jac(const mjModel* m, const mjData* d,
353
+ mjtNum* jacp, mjtNum* jacr, const mjtNum point[3], int body);
354
+
355
+ # Compute body frame end-effector Jacobian.
356
+ void mj_jacBody(const mjModel* m, const mjData* d,
357
+ mjtNum* jacp, mjtNum* jacr, int body);
358
+
359
+ # Compute body center-of-mass end-effector Jacobian.
360
+ void mj_jacBodyCom(const mjModel* m, const mjData* d,
361
+ mjtNum* jacp, mjtNum* jacr, int body);
362
+
363
+ # Compute geom end-effector Jacobian.
364
+ void mj_jacGeom(const mjModel* m, const mjData* d,
365
+ mjtNum* jacp, mjtNum* jacr, int geom);
366
+
367
+ # Compute site end-effector Jacobian.
368
+ void mj_jacSite(const mjModel* m, const mjData* d,
369
+ mjtNum* jacp, mjtNum* jacr, int site);
370
+
371
+ # Compute translation end-effector Jacobian of point, and rotation Jacobian of axis.
372
+ void mj_jacPointAxis(const mjModel* m, mjData* d,
373
+ mjtNum* jacPoint, mjtNum* jacAxis,
374
+ const mjtNum point[3], const mjtNum axis[3], int body);
375
+
376
+ # Get id of object with specified name, return -1 if not found; type is mjtObj.
377
+ int mj_name2id(const mjModel* m, int type, const char* name);
378
+
379
+ # Get name of object with specified id, return 0 if invalid type or id; type is mjtObj.
380
+ const char* mj_id2name(const mjModel* m, int type, int id);
381
+
382
+ # Convert sparse inertia matrix M into full (i.e. dense) matrix.
383
+ void mj_fullM(const mjModel* m, mjtNum* dst, const mjtNum* M);
384
+
385
+ # Multiply vector by inertia matrix.
386
+ void mj_mulM(const mjModel* m, const mjData* d, mjtNum* res, const mjtNum* vec);
387
+
388
+ # Multiply vector by (inertia matrix)^(1/2).
389
+ void mj_mulM2(const mjModel* m, const mjData* d, mjtNum* res, const mjtNum* vec);
390
+
391
+ # Add inertia matrix to destination matrix.
392
+ # Destination can be sparse uncompressed, or dense when all int* are NULL
393
+ void mj_addM(const mjModel* m, mjData* d, mjtNum* dst,
394
+ int* rownnz, int* rowadr, int* colind);
395
+
396
+ # Apply cartesian force and torque (outside xfrc_applied mechanism).
397
+ void mj_applyFT(const mjModel* m, mjData* d,
398
+ const mjtNum* force, const mjtNum* torque,
399
+ const mjtNum* point, int body, mjtNum* qfrc_target);
400
+
401
+ # Compute object 6D velocity in object-centered frame, world/local orientation.
402
+ void mj_objectVelocity(const mjModel* m, const mjData* d,
403
+ int objtype, int objid, mjtNum* res, int flg_local);
404
+
405
+ # Compute object 6D acceleration in object-centered frame, world/local orientation.
406
+ void mj_objectAcceleration(const mjModel* m, const mjData* d,
407
+ int objtype, int objid, mjtNum* res, int flg_local);
408
+
409
+ # Extract 6D force:torque for one contact, in contact frame.
410
+ void mj_contactForce(const mjModel* m, const mjData* d, int id, mjtNum* result);
411
+
412
+ # Compute velocity by finite-differencing two positions.
413
+ void mj_differentiatePos(const mjModel* m, mjtNum* qvel, mjtNum dt,
414
+ const mjtNum* qpos1, const mjtNum* qpos2);
415
+
416
+ # Integrate position with given velocity.
417
+ void mj_integratePos(const mjModel* m, mjtNum* qpos, const mjtNum* qvel, mjtNum dt);
418
+
419
+ # Normalize all quaterions in qpos-type vector.
420
+ void mj_normalizeQuat(const mjModel* m, mjtNum* qpos);
421
+
422
+ # Map from body local to global Cartesian coordinates.
423
+ void mj_local2Global(mjData* d, mjtNum* xpos, mjtNum* xmat, const mjtNum* pos, const mjtNum* quat,
424
+ int body, mjtByte sameframe);
425
+
426
+ # Sum all body masses.
427
+ mjtNum mj_getTotalmass(const mjModel* m);
428
+
429
+ # Scale body masses and inertias to achieve specified total mass.
430
+ void mj_setTotalmass(mjModel* m, mjtNum newmass);
431
+
432
+ # Return version number: 1.0.2 is encoded as 102.
433
+ int mj_version();
434
+
435
+
436
+ #--------------------- Ray collisions -------------------------------------------------
437
+
438
+ # Intersect ray (pnt+x*vec, x>=0) with visible geoms, except geoms in bodyexclude.
439
+ # Return geomid and distance (x) to nearest surface, or -1 if no intersection.
440
+ # geomgroup, flg_static are as in mjvOption; geomgroup==NULL skips group exclusion.
441
+ mjtNum mj_ray(const mjModel* m, const mjData* d, const mjtNum* pnt, const mjtNum* vec,
442
+ const mjtByte* geomgroup, mjtByte flg_static, int bodyexclude,
443
+ int* geomid);
444
+
445
+ # Interect ray with hfield, return nearest distance or -1 if no intersection.
446
+ mjtNum mj_rayHfield(const mjModel* m, const mjData* d, int geomid,
447
+ const mjtNum* pnt, const mjtNum* vec);
448
+
449
+ # Interect ray with mesh, return nearest distance or -1 if no intersection.
450
+ mjtNum mj_rayMesh(const mjModel* m, const mjData* d, int geomid,
451
+ const mjtNum* pnt, const mjtNum* vec);
452
+
453
+ # Interect ray with pure geom, return nearest distance or -1 if no intersection.
454
+ mjtNum mju_rayGeom(const mjtNum* pos, const mjtNum* mat, const mjtNum* size,
455
+ const mjtNum* pnt, const mjtNum* vec, int geomtype);
456
+
457
+
458
+ #--------------------- Interaction ----------------------------------------------------
459
+
460
+ # Set default camera.
461
+ void mjv_defaultCamera(mjvCamera* cam);
462
+
463
+ # Set default perturbation.
464
+ void mjv_defaultPerturb(mjvPerturb* pert);
465
+
466
+ # Transform pose from room to model space.
467
+ void mjv_room2model(mjtNum* modelpos, mjtNum* modelquat, const mjtNum* roompos,
468
+ const mjtNum* roomquat, const mjvScene* scn);
469
+
470
+ # Transform pose from model to room space.
471
+ void mjv_model2room(mjtNum* roompos, mjtNum* roomquat, const mjtNum* modelpos,
472
+ const mjtNum* modelquat, const mjvScene* scn);
473
+
474
+ # Get camera info in model space; average left and right OpenGL cameras.
475
+ void mjv_cameraInModel(mjtNum* headpos, mjtNum* forward, mjtNum* up,
476
+ const mjvScene* scn);
477
+
478
+ # Get camera info in room space; average left and right OpenGL cameras.
479
+ void mjv_cameraInRoom(mjtNum* headpos, mjtNum* forward, mjtNum* up,
480
+ const mjvScene* scn);
481
+
482
+ # Get frustum height at unit distance from camera; average left and right OpenGL cameras.
483
+ mjtNum mjv_frustumHeight(const mjvScene* scn);
484
+
485
+ # Rotate 3D vec in horizontal plane by angle between (0,1) and (forward_x,forward_y).
486
+ void mjv_alignToCamera(mjtNum* res, const mjtNum* vec, const mjtNum* forward);
487
+
488
+ # Move camera with mouse; action is mjtMouse.
489
+ void mjv_moveCamera(const mjModel* m, int action, mjtNum reldx, mjtNum reldy,
490
+ const mjvScene* scn, mjvCamera* cam);
491
+
492
+ # Move perturb object with mouse; action is mjtMouse.
493
+ void mjv_movePerturb(const mjModel* m, const mjData* d, int action, mjtNum reldx,
494
+ mjtNum reldy, const mjvScene* scn, mjvPerturb* pert);
495
+
496
+ # Move model with mouse; action is mjtMouse.
497
+ void mjv_moveModel(const mjModel* m, int action, mjtNum reldx, mjtNum reldy,
498
+ const mjtNum* roomup, mjvScene* scn);
499
+
500
+ # Copy perturb pos,quat from selected body; set scale for perturbation.
501
+ void mjv_initPerturb(const mjModel* m, const mjData* d,
502
+ const mjvScene* scn, mjvPerturb* pert);
503
+
504
+ # Set perturb pos,quat in d->mocap when selected body is mocap, and in d->qpos otherwise.
505
+ # Write d->qpos only if flg_paused and subtree root for selected body has free joint.
506
+ void mjv_applyPerturbPose(const mjModel* m, mjData* d, const mjvPerturb* pert,
507
+ int flg_paused);
508
+
509
+ # Set perturb force,torque in d->xfrc_applied, if selected body is dynamic.
510
+ void mjv_applyPerturbForce(const mjModel* m, mjData* d, const mjvPerturb* pert);
511
+
512
+ # Return the average of two OpenGL cameras.
513
+ mjvGLCamera mjv_averageCamera(const mjvGLCamera* cam1, const mjvGLCamera* cam2);
514
+
515
+ # Select geom or skin with mouse, return bodyid; -1: none selected.
516
+ int mjv_select(const mjModel* m, const mjData* d, const mjvOption* vopt,
517
+ mjtNum aspectratio, mjtNum relx, mjtNum rely,
518
+ const mjvScene* scn, mjtNum* selpnt, int* geomid, int* skinid);
519
+
520
+ #--------------------- Visualization --------------------------------------------------
521
+
522
+ # Set default visualization options.
523
+ void mjv_defaultOption(mjvOption* opt);
524
+
525
+ # Set default figure.
526
+ void mjv_defaultFigure(mjvFigure* fig);
527
+
528
+ # Initialize given geom fields when not NULL, set the rest to their default values.
529
+ void mjv_initGeom(mjvGeom* geom, int type, const mjtNum* size,
530
+ const mjtNum* pos, const mjtNum* mat, const float* rgba);
531
+
532
+ # Set (type, size, pos, mat) for connector-type geom between given points.
533
+ # Assume that mjv_initGeom was already called to set all other properties.
534
+ void mjv_makeConnector(mjvGeom* geom, int type, mjtNum width,
535
+ mjtNum a0, mjtNum a1, mjtNum a2,
536
+ mjtNum b0, mjtNum b1, mjtNum b2);
537
+
538
+ # Set default abstract scene.
539
+ void mjv_defaultScene(mjvScene* scn);
540
+
541
+ # Allocate resources in abstract scene.
542
+ void mjv_makeScene(const mjModel* m, mjvScene* scn, int maxgeom);
543
+
544
+ # Free abstract scene.
545
+ void mjv_freeScene(mjvScene* scn);
546
+
547
+ # Update entire scene given model state.
548
+ void mjv_updateScene(const mjModel* m, mjData* d, const mjvOption* opt,
549
+ const mjvPerturb* pert, mjvCamera* cam, int catmask, mjvScene* scn);
550
+
551
+ # Add geoms from selected categories to existing scene.
552
+ void mjv_addGeoms(const mjModel* m, mjData* d, const mjvOption* opt,
553
+ const mjvPerturb* pert, int catmask, mjvScene* scn);
554
+
555
+ # Make list of lights.
556
+ void mjv_makeLights(const mjModel* m, mjData* d, mjvScene* scn);
557
+
558
+ # Update camera only.
559
+ void mjv_updateCamera(const mjModel* m, mjData* d, mjvCamera* cam, mjvScene* scn);
560
+
561
+ # Update skins.
562
+ void mjv_updateSkin(const mjModel* m, mjData* d, mjvScene* scn);
563
+
564
+ #--------------------- OpenGL rendering -----------------------------------------------
565
+
566
+ # Set default mjrContext.
567
+ void mjr_defaultContext(mjrContext* con);
568
+
569
+ # Allocate resources in custom OpenGL context; fontscale is mjtFontScale.
570
+ void mjr_makeContext(const mjModel* m, mjrContext* con, int fontscale);
571
+
572
+ # Change font of existing context.
573
+ void mjr_changeFont(int fontscale, mjrContext* con);
574
+
575
+ # Add Aux buffer with given index to context; free previous Aux buffer.
576
+ void mjr_addAux(int index, int width, int height, int samples, mjrContext* con);
577
+
578
+ # Free resources in custom OpenGL context, set to default.
579
+ void mjr_freeContext(mjrContext* con);
580
+
581
+ # Upload texture to GPU, overwriting previous upload if any.
582
+ void mjr_uploadTexture(const mjModel* m, const mjrContext* con, int texid);
583
+
584
+ # Upload mesh to GPU, overwriting previous upload if any.
585
+ void mjr_uploadMesh(const mjModel* m, const mjrContext* con, int meshid);
586
+
587
+ # Upload height field to GPU, overwriting previous upload if any.
588
+ void mjr_uploadHField(const mjModel* m, const mjrContext* con, int hfieldid);
589
+
590
+ # Make con->currentBuffer current again.
591
+ void mjr_restoreBuffer(const mjrContext* con);
592
+
593
+ # Set OpenGL framebuffer for rendering: mjFB_WINDOW or mjFB_OFFSCREEN.
594
+ # If only one buffer is available, set that buffer and ignore framebuffer argument.
595
+ void mjr_setBuffer(int framebuffer, mjrContext* con);
596
+
597
+ # Read pixels from current OpenGL framebuffer to client buffer.
598
+ # Viewport is in OpenGL framebuffer; client buffer starts at (0,0).
599
+ void mjr_readPixels(unsigned char* rgb, float* depth,
600
+ mjrRect viewport, const mjrContext* con);
601
+
602
+ # Draw pixels from client buffer to current OpenGL framebuffer.
603
+ # Viewport is in OpenGL framebuffer; client buffer starts at (0,0).
604
+ void mjr_drawPixels(const unsigned char* rgb, const float* depth,
605
+ mjrRect viewport, const mjrContext* con);
606
+
607
+ # Blit from src viewpoint in current framebuffer to dst viewport in other framebuffer.
608
+ # If src, dst have different size and flg_depth==0, color is interpolated with GL_LINEAR.
609
+ void mjr_blitBuffer(mjrRect src, mjrRect dst, int flg_color, int flg_depth, const mjrContext* con);
610
+
611
+ # Set Aux buffer for custom OpenGL rendering (call restoreBuffer when done).
612
+ void mjr_setAux(int index, const mjrContext* con);
613
+
614
+ # Blit from Aux buffer to con->currentBuffer.
615
+ void mjr_blitAux(int index, mjrRect src, int left, int bottom, const mjrContext* con);
616
+
617
+ # Draw text at (x,y) in relative coordinates; font is mjtFont.
618
+ void mjr_text(int font, const char* txt, const mjrContext* con,
619
+ float x, float y, float r, float g, float b);
620
+
621
+ # Draw text overlay; font is mjtFont; gridpos is mjtGridPos.
622
+ void mjr_overlay(int font, int gridpos, mjrRect viewport,
623
+ const char* overlay, const char* overlay2, const mjrContext* con);
624
+
625
+ # Get maximum viewport for active buffer.
626
+ mjrRect mjr_maxViewport(const mjrContext* con);
627
+
628
+ # Draw rectangle.
629
+ void mjr_rectangle(mjrRect viewport, float r, float g, float b, float a);
630
+
631
+ # Draw rectangle with centered text.
632
+ void mjr_label(mjrRect viewport, int font, const char* txt,
633
+ float r, float g, float b, float a, float rt, float gt, float bt,
634
+ const mjrContext* con);
635
+
636
+ # Draw 2D figure.
637
+ void mjr_figure(mjrRect viewport, const mjvFigure* fig, const mjrContext* con);
638
+
639
+ # Render 3D scene.
640
+ void mjr_render(mjrRect viewport, mjvScene* scn, const mjrContext* con);
641
+
642
+ # Call glFinish.
643
+ void mjr_finish();
644
+
645
+ # Call glGetError and return result.
646
+ int mjr_getError();
647
+
648
+ # Find first rectangle containing mouse, -1: not found.
649
+ int mjr_findRect(int x, int y, int nrect, const mjrRect* rect);
650
+
651
+ #---------------------- UI framework ---------------------------------------------------
652
+
653
+ # Add definitions to UI.
654
+ void mjui_add(mjUI* ui, const mjuiDef* _def);
655
+
656
+ # Add definitions to UI section.
657
+ void mjui_addToSection(mjUI* ui, int sect, const mjuiDef* _def);
658
+
659
+
660
+ # Compute UI sizes.
661
+ void mjui_resize(mjUI* ui, const mjrContext* con);
662
+
663
+ # Update specific section/item; -1: update all.
664
+ void mjui_update(int section, int item, const mjUI* ui, const mjuiState* state, const mjrContext* con);
665
+
666
+ # Handle UI event, return pointer to changed item, NULL if no change.
667
+ mjuiItem* mjui_event(mjUI* ui, mjuiState* state, const mjrContext* con);
668
+
669
+ # Copy UI image to current buffer.
670
+ void mjui_render(mjUI* ui, const mjuiState* state, const mjrContext* con);
671
+
672
+
673
+ #--------------------- Error and memory -----------------------------------------------
674
+
675
+ # Main error function; does not return to caller.
676
+ void mju_error(const char* msg);
677
+
678
+ # Error function with int argument; msg is a printf format string.
679
+ void mju_error_i(const char* msg, int i);
680
+
681
+ # Error function with string argument.
682
+ void mju_error_s(const char* msg, const char* text);
683
+
684
+ # Main warning function; returns to caller.
685
+ void mju_warning(const char* msg);
686
+
687
+ # Warning function with int argument.
688
+ void mju_warning_i(const char* msg, int i);
689
+
690
+ # Warning function with string argument.
691
+ void mju_warning_s(const char* msg, const char* text);
692
+
693
+ # Clear user error and memory handlers.
694
+ void mju_clearHandlers();
695
+
696
+ # Allocate memory; byte-align on 8; pad size to multiple of 8.
697
+ void* mju_malloc(size_t size);
698
+
699
+ # Free memory, using free() by default.
700
+ void mju_free(void* ptr);
701
+
702
+ # High-level warning function: count warnings in mjData, print only the first.
703
+ void mj_warning(mjData* d, int warning, int info);
704
+
705
+ # Write [datetime, type: message] to MUJOCO_LOG.TXT.
706
+ void mju_writeLog(const char* type, const char* msg);
707
+
708
+
709
+ #--------------------- Standard math --------------------------------------------------
710
+
711
+ #define mjMAX(a,b) (((a) > (b)) ? (a) : (b))
712
+ #define mjMIN(a,b) (((a) < (b)) ? (a) : (b))
713
+
714
+ #ifdef mjUSEDOUBLE
715
+ #define mju_sqrt sqrt
716
+ #define mju_exp exp
717
+ #define mju_sin sin
718
+ #define mju_cos cos
719
+ #define mju_tan tan
720
+ #define mju_asin asin
721
+ #define mju_acos acos
722
+ #define mju_atan2 atan2
723
+ #define mju_tanh tanh
724
+ #define mju_pow pow
725
+ #define mju_abs fabs
726
+ #define mju_log log
727
+ #define mju_log10 log10
728
+ #define mju_floor floor
729
+ #define mju_ceil ceil
730
+
731
+ #else
732
+ #define mju_sqrt sqrtf
733
+ #define mju_exp expf
734
+ #define mju_sin sinf
735
+ #define mju_cos cosf
736
+ #define mju_tan tanf
737
+ #define mju_asin asinf
738
+ #define mju_acos acosf
739
+ #define mju_atan2 atan2f
740
+ #define mju_tanh tanhf
741
+ #define mju_pow powf
742
+ #define mju_abs fabsf
743
+ #define mju_log logf
744
+ #define mju_log10 log10f
745
+ #define mju_floor floorf
746
+ #define mju_ceil ceilf
747
+ #endif
748
+
749
+
750
+ #----------------------------- Vector math --------------------------------------------
751
+
752
+ # Set res = 0.
753
+ void mju_zero3(mjtNum res[3]);
754
+
755
+ # Set res = vec.
756
+ void mju_copy3(mjtNum res[3], const mjtNum data[3]);
757
+
758
+ # Set res = vec*scl.
759
+ void mju_scl3(mjtNum res[3], const mjtNum vec[3], mjtNum scl);
760
+
761
+ # Set res = vec1 + vec2.
762
+ void mju_add3(mjtNum res[3], const mjtNum vec1[3], const mjtNum vec2[3]);
763
+
764
+ # Set res = vec1 - vec2.
765
+ void mju_sub3(mjtNum res[3], const mjtNum vec1[3], const mjtNum vec2[3]);
766
+
767
+ # Set res = res + vec.
768
+ void mju_addTo3(mjtNum res[3], const mjtNum vec[3]);
769
+
770
+ # Set res = res - vec.
771
+ void mju_subFrom3(mjtNum res[3], const mjtNum vec[3]);
772
+
773
+ # Set res = res + vec*scl.
774
+ void mju_addToScl3(mjtNum res[3], const mjtNum vec[3], mjtNum scl);
775
+
776
+ # Set res = vec1 + vec2*scl.
777
+ void mju_addScl3(mjtNum res[3], const mjtNum vec1[3], const mjtNum vec2[3], mjtNum scl);
778
+
779
+ # Normalize vector, return length before normalization.
780
+ mjtNum mju_normalize3(mjtNum res[3]);
781
+
782
+ # Return vector length (without normalizing the vector).
783
+ mjtNum mju_norm3(const mjtNum vec[3]);
784
+
785
+ # Return dot-product of vec1 and vec2.
786
+ mjtNum mju_dot3(const mjtNum vec1[3], const mjtNum vec2[3]);
787
+
788
+ # Return Cartesian distance between 3D vectors pos1 and pos2.
789
+ mjtNum mju_dist3(const mjtNum pos1[3], const mjtNum pos2[3]);
790
+
791
+ # Multiply vector by 3D rotation matrix: res = mat * vec.
792
+ void mju_rotVecMat(mjtNum res[3], const mjtNum vec[3], const mjtNum mat[9]);
793
+
794
+ # Multiply vector by transposed 3D rotation matrix: res = mat' * vec.
795
+ void mju_rotVecMatT(mjtNum res[3], const mjtNum vec[3], const mjtNum mat[9]);
796
+
797
+ # Compute cross-product: res = cross(a, b).
798
+ void mju_cross(mjtNum res[3], const mjtNum a[3], const mjtNum b[3]);
799
+
800
+ # Set res = 0.
801
+ void mju_zero4(mjtNum res[4]);
802
+
803
+ # Set res = (1,0,0,0).
804
+ void mju_unit4(mjtNum res[4]);
805
+
806
+ # Set res = vec.
807
+ void mju_copy4(mjtNum res[4], const mjtNum data[4]);
808
+
809
+ # Normalize vector, return length before normalization.
810
+ mjtNum mju_normalize4(mjtNum res[4]);
811
+
812
+ # Set res = 0.
813
+ void mju_zero(mjtNum* res, int n);
814
+
815
+ # Set res = vec.
816
+ void mju_copy(mjtNum* res, const mjtNum* data, int n);
817
+
818
+ # Return sum(vec).
819
+ mjtNum mju_sum(const mjtNum* vec, int n);
820
+
821
+ # Return L1 norm: sum(abs(vec)).
822
+ mjtNum mju_L1(const mjtNum* vec, int n);
823
+
824
+ # Set res = vec*scl.
825
+ void mju_scl(mjtNum* res, const mjtNum* vec, mjtNum scl, int n);
826
+
827
+ # Set res = vec1 + vec2.
828
+ void mju_add(mjtNum* res, const mjtNum* vec1, const mjtNum* vec2, int n);
829
+
830
+ # Set res = vec1 - vec2.
831
+ void mju_sub(mjtNum* res, const mjtNum* vec1, const mjtNum* vec2, int n);
832
+
833
+ # Set res = res + vec.
834
+ void mju_addTo(mjtNum* res, const mjtNum* vec, int n);
835
+
836
+ # Set res = res - vec.
837
+ void mju_subFrom(mjtNum* res, const mjtNum* vec, int n);
838
+
839
+ # Set res = res + vec*scl.
840
+ void mju_addToScl(mjtNum* res, const mjtNum* vec, mjtNum scl, int n);
841
+
842
+ # Set res = vec1 + vec2*scl.
843
+ void mju_addScl(mjtNum* res, const mjtNum* vec1, const mjtNum* vec2, mjtNum scl, int n);
844
+
845
+ # Normalize vector, return length before normalization.
846
+ mjtNum mju_normalize(mjtNum* res, int n);
847
+
848
+ # Return vector length (without normalizing vector).
849
+ mjtNum mju_norm(const mjtNum* res, int n);
850
+
851
+ # Return dot-product of vec1 and vec2.
852
+ mjtNum mju_dot(const mjtNum* vec1, const mjtNum* vec2, const int n);
853
+
854
+ # Multiply matrix and vector: res = mat * vec.
855
+ void mju_mulMatVec(mjtNum* res, const mjtNum* mat, const mjtNum* vec,
856
+ int nr, int nc);
857
+
858
+ # Multiply transposed matrix and vector: res = mat' * vec.
859
+ void mju_mulMatTVec(mjtNum* res, const mjtNum* mat, const mjtNum* vec,
860
+ int nr, int nc);
861
+
862
+ # Transpose matrix: res = mat'.
863
+ void mju_transpose(mjtNum* res, const mjtNum* mat, int nr, int nc);
864
+
865
+ # Multiply matrices: res = mat1 * mat2.
866
+ void mju_mulMatMat(mjtNum* res, const mjtNum* mat1, const mjtNum* mat2,
867
+ int r1, int c1, int c2);
868
+
869
+ # Multiply matrices, second argument transposed: res = mat1 * mat2'.
870
+ void mju_mulMatMatT(mjtNum* res, const mjtNum* mat1, const mjtNum* mat2,
871
+ int r1, int c1, int r2);
872
+
873
+ # Multiply matrices, first argument transposed: res = mat1' * mat2.
874
+ void mju_mulMatTMat(mjtNum* res, const mjtNum* mat1, const mjtNum* mat2,
875
+ int r1, int c1, int c2);
876
+
877
+ # Set res = mat' * diag * mat if diag is not NULL, and res = mat' * mat otherwise.
878
+ void mju_sqrMatTD(mjtNum* res, const mjtNum* mat, const mjtNum* diag, int nr, int nc);
879
+
880
+ # Coordinate transform of 6D motion or force vector in rotation:translation format.
881
+ # rotnew2old is 3-by-3, NULL means no rotation; flg_force specifies force or motion type.
882
+ void mju_transformSpatial(mjtNum res[6], const mjtNum vec[6], int flg_force,
883
+ const mjtNum newpos[3], const mjtNum oldpos[3],
884
+ const mjtNum rotnew2old[9]);
885
+
886
+
887
+ #--------------------- Sparse math ----------------------------------------------------
888
+
889
+ # Return dot-product of vec1 and vec2, where vec1 is sparse.
890
+ mjtNum mju_dotSparse(const mjtNum* vec1, const mjtNum* vec2,
891
+ const int nnz1, const int* ind1);
892
+
893
+ # Return dot-product of vec1 and vec2, where both vectors are sparse.
894
+ mjtNum mju_dotSparse2(const mjtNum* vec1, const mjtNum* vec2,
895
+ const int nnz1, const int* ind1,
896
+ const int nnz2, const int* ind2);
897
+
898
+ # Convert matrix from dense to sparse format.
899
+ void mju_dense2sparse(mjtNum* res, const mjtNum* mat, int nr, int nc,
900
+ int* rownnz, int* rowadr, int* colind);
901
+
902
+ # Convert matrix from sparse to dense format.
903
+ void mju_sparse2dense(mjtNum* res, const mjtNum* mat, int nr, int nc,
904
+ const int* rownnz, const int* rowadr, const int* colind);
905
+
906
+ # Multiply sparse matrix and dense vector: res = mat * vec.
907
+ void mju_mulMatVecSparse(mjtNum* res, const mjtNum* mat, const mjtNum* vec, int nr,
908
+ const int* rownnz, const int* rowadr, const int* colind);
909
+
910
+ # Compress layout of sparse matrix.
911
+ void mju_compressSparse(mjtNum* mat, int nr, int nc,
912
+ int* rownnz, int* rowadr, int* colind);
913
+
914
+ # Set dst = a*dst + b*src, return nnz of result, modify dst sparsity pattern as needed.
915
+ # Both vectors are sparse. The required scratch space is 2*n.
916
+ int mju_combineSparse(mjtNum* dst, const mjtNum* src, int n, mjtNum a, mjtNum b,
917
+ int dst_nnz, int src_nnz, int* dst_ind, const int* src_ind,
918
+ mjtNum* scratch, int nscratch);
919
+
920
+ # Set res = matT * diag * mat if diag is not NULL, and res = matT * mat otherwise.
921
+ # The required scratch space is 3*nc. The result has uncompressed layout.
922
+ void mju_sqrMatTDSparse(mjtNum* res, const mjtNum* mat, const mjtNum* matT,
923
+ const mjtNum* diag, int nr, int nc,
924
+ int* res_rownnz, int* res_rowadr, int* res_colind,
925
+ const int* rownnz, const int* rowadr, const int* colind,
926
+ const int* rownnzT, const int* rowadrT, const int* colindT,
927
+ mjtNum* scratch, int nscratch);
928
+
929
+ # Transpose sparse matrix.
930
+ void mju_transposeSparse(mjtNum* res, const mjtNum* mat, int nr, int nc,
931
+ int* res_rownnz, int* res_rowadr, int* res_colind,
932
+ const int* rownnz, const int* rowadr, const int* colind);
933
+
934
+
935
+ #--------------------- Quaternions ----------------------------------------------------
936
+
937
+ # Rotate vector by quaternion.
938
+ void mju_rotVecQuat(mjtNum res[3], const mjtNum vec[3], const mjtNum quat[4]);
939
+
940
+ # Negate quaternion.
941
+ void mju_negQuat(mjtNum res[4], const mjtNum quat[4]);
942
+
943
+ # Muiltiply quaternions.
944
+ void mju_mulQuat(mjtNum res[4], const mjtNum quat1[4], const mjtNum quat2[4]);
945
+
946
+ # Muiltiply quaternion and axis.
947
+ void mju_mulQuatAxis(mjtNum res[4], const mjtNum quat[4], const mjtNum axis[3]);
948
+
949
+ # Convert axisAngle to quaternion.
950
+ void mju_axisAngle2Quat(mjtNum res[4], const mjtNum axis[3], mjtNum angle);
951
+
952
+ # Convert quaternion (corresponding to orientation difference) to 3D velocity.
953
+ void mju_quat2Vel(mjtNum res[3], const mjtNum quat[4], mjtNum dt);
954
+
955
+ # Subtract quaternions, express as 3D velocity: qb*quat(res) = qa.
956
+ void mju_subQuat(mjtNum res[3], const mjtNum qa[4], const mjtNum qb[4]);
957
+
958
+ # Convert quaternion to 3D rotation matrix.
959
+ void mju_quat2Mat(mjtNum res[9], const mjtNum quat[4]);
960
+
961
+ # Convert 3D rotation matrix to quaterion.
962
+ void mju_mat2Quat(mjtNum quat[4], const mjtNum mat[9]);
963
+
964
+ # Compute time-derivative of quaternion, given 3D rotational velocity.
965
+ void mju_derivQuat(mjtNum res[4], const mjtNum quat[4], const mjtNum vel[3]);
966
+
967
+ # Integrate quaterion given 3D angular velocity.
968
+ void mju_quatIntegrate(mjtNum quat[4], const mjtNum vel[3], mjtNum scale);
969
+
970
+ # Construct quaternion performing rotation from z-axis to given vector.
971
+ void mju_quatZ2Vec(mjtNum quat[4], const mjtNum vec[3]);
972
+
973
+
974
+ #--------------------- Poses ----------------------------------------------------------
975
+
976
+ # Multiply two poses.
977
+ void mju_mulPose(mjtNum posres[3], mjtNum quatres[4],
978
+ const mjtNum pos1[3], const mjtNum quat1[4],
979
+ const mjtNum pos2[3], const mjtNum quat2[4]);
980
+
981
+ # Negate pose.
982
+ void mju_negPose(mjtNum posres[3], mjtNum quatres[4],
983
+ const mjtNum pos[3], const mjtNum quat[4]);
984
+
985
+ # Transform vector by pose.
986
+ void mju_trnVecPose(mjtNum res[3], const mjtNum pos[3], const mjtNum quat[4],
987
+ const mjtNum vec[3]);
988
+
989
+
990
+ #--------------------- Decompositions --------------------------------------------------
991
+
992
+ # Cholesky decomposition: mat = L*L'; return rank.
993
+ int mju_cholFactor(mjtNum* mat, int n, mjtNum mindiag);
994
+
995
+ # Solve mat * res = vec, where mat is Cholesky-factorized
996
+ void mju_cholSolve(mjtNum* res, const mjtNum* mat, const mjtNum* vec, int n);
997
+
998
+ # Cholesky rank-one update: L*L' +/- x*x'; return rank.
999
+ int mju_cholUpdate(mjtNum* mat, mjtNum* x, int n, int flg_plus);
1000
+
1001
+ # Eigenvalue decomposition of symmetric 3x3 matrix.
1002
+ int mju_eig3(mjtNum* eigval, mjtNum* eigvec, mjtNum* quat, const mjtNum* mat);
1003
+
1004
+
1005
+ #--------------------- Miscellaneous --------------------------------------------------
1006
+
1007
+ # Muscle active force, prm = (range[2], force, scale, lmin, lmax, vmax, fpmax, fvmax).
1008
+ mjtNum mju_muscleGain(mjtNum len, mjtNum vel, const mjtNum lengthrange[2],
1009
+ mjtNum acc0, const mjtNum prm[9]);
1010
+
1011
+ # Muscle passive force, prm = (range[2], force, scale, lmin, lmax, vmax, fpmax, fvmax).
1012
+ mjtNum mju_muscleBias(mjtNum len, const mjtNum lengthrange[2],
1013
+ mjtNum acc0, const mjtNum prm[9]);
1014
+
1015
+ # Muscle activation dynamics, prm = (tau_act, tau_deact).
1016
+ mjtNum mju_muscleDynamics(mjtNum ctrl, mjtNum act, const mjtNum prm[2]);
1017
+
1018
+
1019
+ # Convert contact force to pyramid representation.
1020
+ void mju_encodePyramid(mjtNum* pyramid, const mjtNum* force,
1021
+ const mjtNum* mu, int dim);
1022
+
1023
+ # Convert pyramid representation to contact force.
1024
+ void mju_decodePyramid(mjtNum* force, const mjtNum* pyramid,
1025
+ const mjtNum* mu, int dim);
1026
+
1027
+ # Integrate spring-damper analytically, return pos(dt).
1028
+ mjtNum mju_springDamper(mjtNum pos0, mjtNum vel0, mjtNum Kp, mjtNum Kv, mjtNum dt);
1029
+
1030
+ # Return min(a,b) with single evaluation of a and b.
1031
+ mjtNum mju_min(mjtNum a, mjtNum b);
1032
+
1033
+ # Return max(a,b) with single evaluation of a and b.
1034
+ mjtNum mju_max(mjtNum a, mjtNum b);
1035
+
1036
+ # Return sign of x: +1, -1 or 0.
1037
+ mjtNum mju_sign(mjtNum x);
1038
+
1039
+ # Round x to nearest integer.
1040
+ int mju_round(mjtNum x);
1041
+
1042
+ # Convert type id (mjtObj) to type name.
1043
+ const char* mju_type2Str(int type);
1044
+
1045
+ # Convert type name to type id (mjtObj).
1046
+ int mju_str2Type(const char* str);
1047
+
1048
+ # Construct a warning message given the warning type and info.
1049
+ const char* mju_warningText(int warning, int info);
1050
+
1051
+ # Return 1 if nan or abs(x)>mjMAXVAL, 0 otherwise. Used by check functions.
1052
+ int mju_isBad(mjtNum x);
1053
+
1054
+ # Return 1 if all elements are 0.
1055
+ int mju_isZero(mjtNum* vec, int n);
1056
+
1057
+ # Standard normal random number generator (optional second number).
1058
+ mjtNum mju_standardNormal(mjtNum* num2);
1059
+
1060
+ # Convert from float to mjtNum.
1061
+ void mju_f2n(mjtNum* res, const float* vec, int n);
1062
+
1063
+ # Convert from mjtNum to float.
1064
+ void mju_n2f(float* res, const mjtNum* vec, int n);
1065
+
1066
+ # Convert from double to mjtNum.
1067
+ void mju_d2n(mjtNum* res, const double* vec, int n);
1068
+
1069
+ # Convert from mjtNum to double.
1070
+ void mju_n2d(double* res, const mjtNum* vec, int n);
1071
+
1072
+ # Insertion sort, resulting list is in increasing order.
1073
+ void mju_insertionSort(mjtNum* list, int n);
1074
+
1075
+ # Integer insertion sort, resulting list is in increasing order.
1076
+ void mju_insertionSortInt(int* list, int n);
1077
+
1078
+ # Generate Halton sequence.
1079
+ mjtNum mju_Halton(int index, int base);
1080
+
1081
+ # Sigmoid function over 0<=x<=1 constructed from half-quadratics.
1082
+ mjtNum mju_sigmoid(mjtNum x);
1083
+
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_materials.premod.png ADDED
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_multiple_sims.loop0_1.png ADDED
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_multiple_sims.loop1_0.png ADDED
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_multiple_sims.loop1_1.png ADDED
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_multiple_sims.loop2_1.png ADDED
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_render_pool.mp_test_states.2.png ADDED
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_render_pool.mp_test_states.3.png ADDED
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_rendering.camera1.png ADDED
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_rendering.freecam.depth-darwin.png ADDED
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_rendering.freecam.depth.png ADDED
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_resetting.loop1_1.png ADDED
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_textures.rgb.png ADDED
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_textures.variety.png ADDED