Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- VRL3/LICENSE +21 -0
- VRL3/src/cfgs_adroit/task/door.yaml +4 -0
- VRL3/src/cfgs_adroit/task/relocate.yaml +4 -0
- VRL3/src/logger.py +182 -0
- VRL3/src/replay_buffer.py +222 -0
- VRL3/src/rrl_local/__pycache__/rrl_multicam.cpython-38.pyc +0 -0
- VRL3/src/stage1_models.py +318 -0
- VRL3/src/train_stage1.py +493 -0
- VRL3/src/utils.py +149 -0
- VRL3/src/vrl3_agent.py +632 -0
- gym-0.21.0/.github/stale.yml +62 -0
- gym-0.21.0/CONTRIBUTING.md +18 -0
- gym-0.21.0/README.md +57 -0
- gym-0.21.0/docs/toy_text/blackjack.md +60 -0
- gym-0.21.0/docs/toy_text/taxi.md +92 -0
- gym-0.21.0/scripts/generate_json.py +119 -0
- gym-0.21.0/setup.py +76 -0
- mujoco-py-2.1.2.14/.gitignore +55 -0
- mujoco-py-2.1.2.14/docs/_static/.gitkeep +0 -0
- mujoco-py-2.1.2.14/docs/build/doctrees/reference.doctree +0 -0
- mujoco-py-2.1.2.14/mujoco_py.egg-info/SOURCES.txt +67 -0
- mujoco-py-2.1.2.14/mujoco_py/__pycache__/builder.cpython-38.pyc +0 -0
- mujoco-py-2.1.2.14/mujoco_py/__pycache__/mjviewer.cpython-38.pyc +0 -0
- mujoco-py-2.1.2.14/mujoco_py/builder.py +518 -0
- mujoco-py-2.1.2.14/mujoco_py/gl/eglplatform.h +125 -0
- mujoco-py-2.1.2.14/mujoco_py/gl/glshim.h +30 -0
- mujoco-py-2.1.2.14/mujoco_py/gl/khrplatform.h +285 -0
- mujoco-py-2.1.2.14/mujoco_py/gl/osmesashim.c +75 -0
- mujoco-py-2.1.2.14/mujoco_py/mjbatchrenderer.pyx +301 -0
- mujoco-py-2.1.2.14/mujoco_py/mjrendercontext.pyx +329 -0
- mujoco-py-2.1.2.14/mujoco_py/mjrenderpool.py +241 -0
- mujoco-py-2.1.2.14/mujoco_py/mjsim.pyx +439 -0
- mujoco-py-2.1.2.14/mujoco_py/pxd/__init__.py +0 -0
- mujoco-py-2.1.2.14/mujoco_py/pxd/mjdata.pxd +312 -0
- mujoco-py-2.1.2.14/mujoco_py/pxd/mjmodel.pxd +834 -0
- mujoco-py-2.1.2.14/mujoco_py/pxd/mjrender.pxd +115 -0
- mujoco-py-2.1.2.14/mujoco_py/pxd/mujoco.pxd +1083 -0
- mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_materials.premod.png +0 -0
- mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_multiple_sims.loop0_1.png +0 -0
- mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_multiple_sims.loop1_0.png +0 -0
- mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_multiple_sims.loop1_1.png +0 -0
- mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_multiple_sims.loop2_1.png +0 -0
- mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_render_pool.mp_test_states.2.png +0 -0
- mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_render_pool.mp_test_states.3.png +0 -0
- mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_rendering.camera1.png +0 -0
- mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_rendering.freecam.depth-darwin.png +0 -0
- mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_rendering.freecam.depth.png +0 -0
- mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_resetting.loop1_1.png +0 -0
- mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_textures.rgb.png +0 -0
- mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_textures.variety.png +0 -0
VRL3/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) Microsoft Corporation.
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE
|
VRL3/src/cfgs_adroit/task/door.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
num_train_frames: 4100000
|
| 2 |
+
task_name: door-v0
|
| 3 |
+
agent:
|
| 4 |
+
encoder_lr_scale: 1
|
VRL3/src/cfgs_adroit/task/relocate.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
num_train_frames: 4100000
|
| 2 |
+
task_name: relocate-v0
|
| 3 |
+
agent:
|
| 4 |
+
encoder_lr_scale: 0.01
|
VRL3/src/logger.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
import csv
|
| 6 |
+
import datetime
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torchvision
|
| 12 |
+
from termcolor import colored
|
| 13 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 14 |
+
|
| 15 |
+
COMMON_TRAIN_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
|
| 16 |
+
('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
|
| 17 |
+
('episode_reward', 'R', 'float'),
|
| 18 |
+
('buffer_size', 'BS', 'int'), ('fps', 'FPS', 'float'),
|
| 19 |
+
('total_time', 'T', 'time')]
|
| 20 |
+
|
| 21 |
+
COMMON_EVAL_FORMAT = [('frame', 'F', 'int'), ('step', 'S', 'int'),
|
| 22 |
+
('episode', 'E', 'int'), ('episode_length', 'L', 'int'),
|
| 23 |
+
('episode_reward', 'R', 'float'),
|
| 24 |
+
('total_time', 'T', 'time')]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class AverageMeter(object):
|
| 28 |
+
def __init__(self):
|
| 29 |
+
self._sum = 0
|
| 30 |
+
self._count = 0
|
| 31 |
+
|
| 32 |
+
def update(self, value, n=1):
|
| 33 |
+
self._sum += value
|
| 34 |
+
self._count += n
|
| 35 |
+
|
| 36 |
+
def value(self):
|
| 37 |
+
return self._sum / max(1, self._count)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class MetersGroup(object):
|
| 41 |
+
def __init__(self, csv_file_name, formating):
|
| 42 |
+
self._csv_file_name = csv_file_name
|
| 43 |
+
self._formating = formating
|
| 44 |
+
self._meters = defaultdict(AverageMeter)
|
| 45 |
+
self._csv_file = None
|
| 46 |
+
self._csv_writer = None
|
| 47 |
+
|
| 48 |
+
def log(self, key, value, n=1):
|
| 49 |
+
self._meters[key].update(value, n)
|
| 50 |
+
|
| 51 |
+
def _prime_meters(self):
|
| 52 |
+
data = dict()
|
| 53 |
+
for key, meter in self._meters.items():
|
| 54 |
+
if key.startswith('train'):
|
| 55 |
+
key = key[len('train') + 1:]
|
| 56 |
+
else:
|
| 57 |
+
key = key[len('eval') + 1:]
|
| 58 |
+
key = key.replace('/', '_')
|
| 59 |
+
data[key] = meter.value()
|
| 60 |
+
return data
|
| 61 |
+
|
| 62 |
+
def _remove_old_entries(self, data):
|
| 63 |
+
rows = []
|
| 64 |
+
with self._csv_file_name.open('r') as f:
|
| 65 |
+
reader = csv.DictReader(f)
|
| 66 |
+
for row in reader:
|
| 67 |
+
if float(row['episode']) >= data['episode']:
|
| 68 |
+
break
|
| 69 |
+
rows.append(row)
|
| 70 |
+
with self._csv_file_name.open('w') as f:
|
| 71 |
+
writer = csv.DictWriter(f,
|
| 72 |
+
fieldnames=sorted(data.keys()),
|
| 73 |
+
restval=0.0)
|
| 74 |
+
writer.writeheader()
|
| 75 |
+
for row in rows:
|
| 76 |
+
writer.writerow(row)
|
| 77 |
+
|
| 78 |
+
def _dump_to_csv(self, data):
|
| 79 |
+
if self._csv_writer is None:
|
| 80 |
+
should_write_header = True
|
| 81 |
+
if self._csv_file_name.exists():
|
| 82 |
+
self._remove_old_entries(data)
|
| 83 |
+
should_write_header = False
|
| 84 |
+
|
| 85 |
+
self._csv_file = self._csv_file_name.open('a')
|
| 86 |
+
self._csv_writer = csv.DictWriter(self._csv_file,
|
| 87 |
+
fieldnames=sorted(data.keys()),
|
| 88 |
+
restval=0.0)
|
| 89 |
+
if should_write_header:
|
| 90 |
+
self._csv_writer.writeheader()
|
| 91 |
+
|
| 92 |
+
self._csv_writer.writerow(data)
|
| 93 |
+
self._csv_file.flush()
|
| 94 |
+
|
| 95 |
+
def _format(self, key, value, ty):
|
| 96 |
+
if ty == 'int':
|
| 97 |
+
value = int(value)
|
| 98 |
+
return f'{key}: {value}'
|
| 99 |
+
elif ty == 'float':
|
| 100 |
+
return f'{key}: {value:.04f}'
|
| 101 |
+
elif ty == 'time':
|
| 102 |
+
value = str(datetime.timedelta(seconds=int(value)))
|
| 103 |
+
return f'{key}: {value}'
|
| 104 |
+
else:
|
| 105 |
+
raise f'invalid format type: {ty}'
|
| 106 |
+
|
| 107 |
+
def _dump_to_console(self, data, prefix):
|
| 108 |
+
prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
|
| 109 |
+
pieces = [f'| {prefix: <14}']
|
| 110 |
+
for key, disp_key, ty in self._formating:
|
| 111 |
+
value = data.get(key, 0)
|
| 112 |
+
pieces.append(self._format(disp_key, value, ty))
|
| 113 |
+
print(' | '.join(pieces))
|
| 114 |
+
|
| 115 |
+
def dump(self, step, prefix):
|
| 116 |
+
if len(self._meters) == 0:
|
| 117 |
+
return
|
| 118 |
+
data = self._prime_meters()
|
| 119 |
+
data['frame'] = step
|
| 120 |
+
self._dump_to_csv(data)
|
| 121 |
+
self._dump_to_console(data, prefix)
|
| 122 |
+
self._meters.clear()
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class Logger(object):
|
| 126 |
+
def __init__(self, log_dir, use_tb, stage2_logger=False):
|
| 127 |
+
self._log_dir = log_dir
|
| 128 |
+
if not stage2_logger:
|
| 129 |
+
self._train_mg = MetersGroup(log_dir / 'train.csv',
|
| 130 |
+
formating=COMMON_TRAIN_FORMAT)
|
| 131 |
+
self._eval_mg = MetersGroup(log_dir / 'eval.csv',
|
| 132 |
+
formating=COMMON_EVAL_FORMAT)
|
| 133 |
+
else:
|
| 134 |
+
self._train_mg = MetersGroup(log_dir / 'train_stage2.csv',
|
| 135 |
+
formating=COMMON_TRAIN_FORMAT)
|
| 136 |
+
self._eval_mg = MetersGroup(log_dir / 'eval_stage2.csv',
|
| 137 |
+
formating=COMMON_EVAL_FORMAT)
|
| 138 |
+
if use_tb:
|
| 139 |
+
self._sw = SummaryWriter(str(log_dir / 'tb'))
|
| 140 |
+
else:
|
| 141 |
+
self._sw = None
|
| 142 |
+
|
| 143 |
+
def _try_sw_log(self, key, value, step):
|
| 144 |
+
if self._sw is not None:
|
| 145 |
+
self._sw.add_scalar(key, value, step)
|
| 146 |
+
|
| 147 |
+
def log(self, key, value, step):
|
| 148 |
+
assert key.startswith('train') or key.startswith('eval')
|
| 149 |
+
if type(value) == torch.Tensor:
|
| 150 |
+
value = value.item()
|
| 151 |
+
self._try_sw_log(key, value, step)
|
| 152 |
+
mg = self._train_mg if key.startswith('train') else self._eval_mg
|
| 153 |
+
mg.log(key, value)
|
| 154 |
+
|
| 155 |
+
def log_metrics(self, metrics, step, ty):
|
| 156 |
+
for key, value in metrics.items():
|
| 157 |
+
self.log(f'{ty}/{key}', value, step)
|
| 158 |
+
|
| 159 |
+
def dump(self, step, ty=None):
|
| 160 |
+
if ty is None or ty == 'eval':
|
| 161 |
+
self._eval_mg.dump(step, 'eval')
|
| 162 |
+
if ty is None or ty == 'train':
|
| 163 |
+
self._train_mg.dump(step, 'train')
|
| 164 |
+
|
| 165 |
+
def log_and_dump_ctx(self, step, ty):
|
| 166 |
+
return LogAndDumpCtx(self, step, ty)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class LogAndDumpCtx:
|
| 170 |
+
def __init__(self, logger, step, ty):
|
| 171 |
+
self._logger = logger
|
| 172 |
+
self._step = step
|
| 173 |
+
self._ty = ty
|
| 174 |
+
|
| 175 |
+
def __enter__(self):
|
| 176 |
+
return self
|
| 177 |
+
|
| 178 |
+
def __call__(self, key, value):
|
| 179 |
+
self._logger.log(f'{self._ty}/{key}', value, self._step)
|
| 180 |
+
|
| 181 |
+
def __exit__(self, *args):
|
| 182 |
+
self._logger.dump(self._step, self._ty)
|
VRL3/src/replay_buffer.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
import datetime
|
| 6 |
+
import io
|
| 7 |
+
import random
|
| 8 |
+
import traceback
|
| 9 |
+
from collections import defaultdict
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch.utils.data import IterableDataset
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def episode_len(episode):
|
| 18 |
+
# subtract -1 because the dummy first transition
|
| 19 |
+
return next(iter(episode.values())).shape[0] - 1
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def save_episode(episode, fn):
|
| 23 |
+
with io.BytesIO() as bs:
|
| 24 |
+
np.savez_compressed(bs, **episode)
|
| 25 |
+
bs.seek(0)
|
| 26 |
+
with fn.open('wb') as f:
|
| 27 |
+
f.write(bs.read())
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_episode(fn):
|
| 31 |
+
with fn.open('rb') as f:
|
| 32 |
+
episode = np.load(f)
|
| 33 |
+
episode = {k: episode[k] for k in episode.keys()}
|
| 34 |
+
return episode
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ReplayBufferStorage:
|
| 38 |
+
def __init__(self, data_specs, replay_dir):
|
| 39 |
+
self._data_specs = data_specs
|
| 40 |
+
self._replay_dir = replay_dir
|
| 41 |
+
replay_dir.mkdir(exist_ok=True)
|
| 42 |
+
self._current_episode = defaultdict(list)
|
| 43 |
+
self._preload()
|
| 44 |
+
|
| 45 |
+
def __len__(self):
|
| 46 |
+
return self._num_transitions
|
| 47 |
+
|
| 48 |
+
def add(self, time_step):
|
| 49 |
+
for spec in self._data_specs:
|
| 50 |
+
value = time_step[spec.name]
|
| 51 |
+
if np.isscalar(value):
|
| 52 |
+
value = np.full(spec.shape, value, spec.dtype)
|
| 53 |
+
# print(spec.name, spec.shape, spec.dtype, value.shape, value.dtype)
|
| 54 |
+
assert spec.shape == value.shape and spec.dtype == value.dtype
|
| 55 |
+
self._current_episode[spec.name].append(value)
|
| 56 |
+
if time_step.last():
|
| 57 |
+
episode = dict()
|
| 58 |
+
for spec in self._data_specs:
|
| 59 |
+
value = self._current_episode[spec.name]
|
| 60 |
+
episode[spec.name] = np.array(value, spec.dtype)
|
| 61 |
+
self._current_episode = defaultdict(list)
|
| 62 |
+
self._store_episode(episode)
|
| 63 |
+
|
| 64 |
+
def _preload(self):
|
| 65 |
+
self._num_episodes = 0
|
| 66 |
+
self._num_transitions = 0
|
| 67 |
+
for fn in self._replay_dir.glob('*.npz'):
|
| 68 |
+
_, _, eps_len = fn.stem.split('_')
|
| 69 |
+
self._num_episodes += 1
|
| 70 |
+
self._num_transitions += int(eps_len)
|
| 71 |
+
|
| 72 |
+
def _store_episode(self, episode):
|
| 73 |
+
eps_idx = self._num_episodes
|
| 74 |
+
eps_len = episode_len(episode)
|
| 75 |
+
self._num_episodes += 1
|
| 76 |
+
self._num_transitions += eps_len
|
| 77 |
+
ts = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
|
| 78 |
+
eps_fn = f'{ts}_{eps_idx}_{eps_len}.npz'
|
| 79 |
+
save_episode(episode, self._replay_dir / eps_fn)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class ReplayBuffer(IterableDataset):
|
| 83 |
+
def __init__(self, replay_dir, max_size, num_workers, nstep, discount,
|
| 84 |
+
fetch_every, save_snapshot, is_adroit=False, return_next_action=False):
|
| 85 |
+
self._replay_dir = replay_dir
|
| 86 |
+
self._size = 0
|
| 87 |
+
self._max_size = max_size
|
| 88 |
+
self._num_workers = max(1, num_workers)
|
| 89 |
+
self._episode_fns = []
|
| 90 |
+
self._episodes = dict()
|
| 91 |
+
self._nstep = nstep
|
| 92 |
+
self._discount = discount
|
| 93 |
+
self._fetch_every = fetch_every
|
| 94 |
+
self._samples_since_last_fetch = fetch_every
|
| 95 |
+
self._save_snapshot = save_snapshot
|
| 96 |
+
self._is_adroit = is_adroit
|
| 97 |
+
self._return_next_action = return_next_action
|
| 98 |
+
|
| 99 |
+
def set_nstep(self, nstep):
|
| 100 |
+
self._nstep = nstep
|
| 101 |
+
|
| 102 |
+
def _sample_episode(self):
|
| 103 |
+
eps_fn = random.choice(self._episode_fns)
|
| 104 |
+
return self._episodes[eps_fn]
|
| 105 |
+
|
| 106 |
+
def _store_episode(self, eps_fn):
|
| 107 |
+
try:
|
| 108 |
+
episode = load_episode(eps_fn)
|
| 109 |
+
except:
|
| 110 |
+
return False
|
| 111 |
+
eps_len = episode_len(episode)
|
| 112 |
+
while eps_len + self._size > self._max_size:
|
| 113 |
+
early_eps_fn = self._episode_fns.pop(0)
|
| 114 |
+
early_eps = self._episodes.pop(early_eps_fn)
|
| 115 |
+
self._size -= episode_len(early_eps)
|
| 116 |
+
early_eps_fn.unlink(missing_ok=True)
|
| 117 |
+
self._episode_fns.append(eps_fn)
|
| 118 |
+
self._episode_fns.sort()
|
| 119 |
+
self._episodes[eps_fn] = episode
|
| 120 |
+
self._size += eps_len
|
| 121 |
+
|
| 122 |
+
if not self._save_snapshot:
|
| 123 |
+
eps_fn.unlink(missing_ok=True)
|
| 124 |
+
return True
|
| 125 |
+
|
| 126 |
+
def _try_fetch(self):
|
| 127 |
+
if self._samples_since_last_fetch < self._fetch_every:
|
| 128 |
+
return
|
| 129 |
+
self._samples_since_last_fetch = 0
|
| 130 |
+
try:
|
| 131 |
+
worker_id = torch.utils.data.get_worker_info().id
|
| 132 |
+
except:
|
| 133 |
+
worker_id = 0
|
| 134 |
+
eps_fns = sorted(self._replay_dir.glob('*.npz'), reverse=True)
|
| 135 |
+
fetched_size = 0
|
| 136 |
+
for eps_fn in eps_fns:
|
| 137 |
+
eps_idx, eps_len = [int(x) for x in eps_fn.stem.split('_')[1:]]
|
| 138 |
+
if eps_idx % self._num_workers != worker_id:
|
| 139 |
+
continue
|
| 140 |
+
if eps_fn in self._episodes.keys():
|
| 141 |
+
break
|
| 142 |
+
if fetched_size + eps_len > self._max_size:
|
| 143 |
+
break
|
| 144 |
+
fetched_size += eps_len
|
| 145 |
+
if not self._store_episode(eps_fn):
|
| 146 |
+
break
|
| 147 |
+
|
| 148 |
+
def _sample(self):
|
| 149 |
+
try:
|
| 150 |
+
self._try_fetch()
|
| 151 |
+
except:
|
| 152 |
+
traceback.print_exc()
|
| 153 |
+
self._samples_since_last_fetch += 1
|
| 154 |
+
episode = self._sample_episode()
|
| 155 |
+
# add +1 for the first dummy transition
|
| 156 |
+
idx = np.random.randint(0, episode_len(episode) - self._nstep + 1) + 1
|
| 157 |
+
obs = episode['observation'][idx - 1]
|
| 158 |
+
action = episode['action'][idx]
|
| 159 |
+
next_obs = episode['observation'][idx + self._nstep - 1]
|
| 160 |
+
reward = np.zeros_like(episode['reward'][idx])
|
| 161 |
+
discount = np.ones_like(episode['discount'][idx])
|
| 162 |
+
for i in range(self._nstep):
|
| 163 |
+
step_reward = episode['reward'][idx + i]
|
| 164 |
+
reward += discount * step_reward
|
| 165 |
+
discount *= episode['discount'][idx + i] * self._discount
|
| 166 |
+
|
| 167 |
+
if self._return_next_action:
|
| 168 |
+
next_action = episode['action'][idx + self._nstep - 1]
|
| 169 |
+
|
| 170 |
+
if not self._is_adroit:
|
| 171 |
+
if self._return_next_action:
|
| 172 |
+
return (obs, action, reward, discount, next_obs, next_action)
|
| 173 |
+
else:
|
| 174 |
+
return (obs, action, reward, discount, next_obs)
|
| 175 |
+
else:
|
| 176 |
+
obs_sensor = episode['observation_sensor'][idx - 1]
|
| 177 |
+
obs_sensor_next = episode['observation_sensor'][idx + self._nstep - 1]
|
| 178 |
+
if self._return_next_action:
|
| 179 |
+
return (obs, action, reward, discount, next_obs, obs_sensor, obs_sensor_next, next_action)
|
| 180 |
+
else:
|
| 181 |
+
return (obs, action, reward, discount, next_obs, obs_sensor, obs_sensor_next)
|
| 182 |
+
|
| 183 |
+
def __iter__(self):
|
| 184 |
+
while True:
|
| 185 |
+
yield self._sample()
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _worker_init_fn(worker_id):
|
| 189 |
+
seed = np.random.get_state()[1][0] + worker_id
|
| 190 |
+
np.random.seed(seed)
|
| 191 |
+
random.seed(seed)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def make_replay_loader(replay_dir, max_size, batch_size, num_workers,
|
| 195 |
+
save_snapshot, nstep, discount, fetch_every=1000, is_adroit=False, return_next_action=False):
|
| 196 |
+
max_size_per_worker = max_size // max(1, num_workers)
|
| 197 |
+
|
| 198 |
+
iterable = ReplayBuffer(replay_dir,
|
| 199 |
+
max_size_per_worker,
|
| 200 |
+
num_workers,
|
| 201 |
+
nstep,
|
| 202 |
+
discount,
|
| 203 |
+
fetch_every=fetch_every,
|
| 204 |
+
save_snapshot=save_snapshot,
|
| 205 |
+
is_adroit=is_adroit,
|
| 206 |
+
return_next_action=return_next_action)
|
| 207 |
+
|
| 208 |
+
loader = torch.utils.data.DataLoader(iterable,
|
| 209 |
+
batch_size=batch_size,
|
| 210 |
+
num_workers=num_workers,
|
| 211 |
+
pin_memory=True,
|
| 212 |
+
worker_init_fn=_worker_init_fn)
|
| 213 |
+
return loader
|
| 214 |
+
|
| 215 |
+
def reinit_data_loader(data_loader, batch_size, num_workers):
|
| 216 |
+
# reinit a data loader with a new batch size
|
| 217 |
+
loader = torch.utils.data.DataLoader(data_loader.dataset,
|
| 218 |
+
batch_size=batch_size,
|
| 219 |
+
num_workers=num_workers,
|
| 220 |
+
pin_memory=True,
|
| 221 |
+
worker_init_fn=_worker_init_fn)
|
| 222 |
+
return loader
|
VRL3/src/rrl_local/__pycache__/rrl_multicam.cpython-38.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
VRL3/src/stage1_models.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT License.
|
| 3 |
+
|
| 4 |
+
from numpy import identity
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
"""
|
| 9 |
+
most code here are modified from the TORCHVISION.MODELS.RESNET
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
| 16 |
+
"""3x3 convolution with padding"""
|
| 17 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
| 18 |
+
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
| 19 |
+
|
| 20 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
| 21 |
+
"""1x1 convolution"""
|
| 22 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
| 23 |
+
|
| 24 |
+
class BasicBlock(nn.Module):
|
| 25 |
+
expansion = 1
|
| 26 |
+
|
| 27 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
| 28 |
+
base_width=64, dilation=1, norm_layer=None):
|
| 29 |
+
super(BasicBlock, self).__init__()
|
| 30 |
+
if norm_layer is None:
|
| 31 |
+
norm_layer = nn.BatchNorm2d
|
| 32 |
+
if groups != 1 or base_width != 64:
|
| 33 |
+
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
| 34 |
+
if dilation > 1:
|
| 35 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
| 36 |
+
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
| 37 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
| 38 |
+
self.bn1 = norm_layer(planes)
|
| 39 |
+
self.relu = nn.ReLU(inplace=True)
|
| 40 |
+
self.conv2 = conv3x3(planes, planes)
|
| 41 |
+
self.bn2 = norm_layer(planes)
|
| 42 |
+
self.downsample = downsample
|
| 43 |
+
self.stride = stride
|
| 44 |
+
|
| 45 |
+
def forward(self, x):
|
| 46 |
+
identity = x
|
| 47 |
+
|
| 48 |
+
out = self.conv1(x)
|
| 49 |
+
out = self.bn1(out)
|
| 50 |
+
out = self.relu(out)
|
| 51 |
+
|
| 52 |
+
out = self.conv2(out)
|
| 53 |
+
out = self.bn2(out)
|
| 54 |
+
|
| 55 |
+
if self.downsample is not None:
|
| 56 |
+
identity = self.downsample(x)
|
| 57 |
+
|
| 58 |
+
out += identity
|
| 59 |
+
out = self.relu(out)
|
| 60 |
+
|
| 61 |
+
return out
|
| 62 |
+
|
| 63 |
+
class OneLayerBlock(nn.Module):
|
| 64 |
+
# similar to BasicBlock, but shallower
|
| 65 |
+
expansion = 1
|
| 66 |
+
|
| 67 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
| 68 |
+
base_width=64, dilation=1, norm_layer=None):
|
| 69 |
+
super(OneLayerBlock, self).__init__()
|
| 70 |
+
if norm_layer is None:
|
| 71 |
+
norm_layer = nn.BatchNorm2d
|
| 72 |
+
if groups != 1 or base_width != 64:
|
| 73 |
+
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
| 74 |
+
if dilation > 1:
|
| 75 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
| 76 |
+
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
| 77 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
| 78 |
+
self.bn1 = norm_layer(planes)
|
| 79 |
+
self.relu = nn.ReLU(inplace=True)
|
| 80 |
+
self.stride = stride
|
| 81 |
+
|
| 82 |
+
def forward(self, x):
|
| 83 |
+
out = self.conv1(x)
|
| 84 |
+
out = self.bn1(out)
|
| 85 |
+
out = self.relu(out)
|
| 86 |
+
return out
|
| 87 |
+
|
| 88 |
+
class Bottleneck(nn.Module):
|
| 89 |
+
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
| 90 |
+
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
| 91 |
+
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
| 92 |
+
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
| 93 |
+
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
| 94 |
+
|
| 95 |
+
expansion = 4
|
| 96 |
+
|
| 97 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
|
| 98 |
+
base_width=64, dilation=1, norm_layer=None):
|
| 99 |
+
super(Bottleneck, self).__init__()
|
| 100 |
+
if norm_layer is None:
|
| 101 |
+
norm_layer = nn.BatchNorm2d
|
| 102 |
+
width = int(planes * (base_width / 64.)) * groups
|
| 103 |
+
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
| 104 |
+
self.conv1 = conv1x1(inplanes, width)
|
| 105 |
+
self.bn1 = norm_layer(width)
|
| 106 |
+
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
| 107 |
+
self.bn2 = norm_layer(width)
|
| 108 |
+
self.conv3 = conv1x1(width, planes * self.expansion)
|
| 109 |
+
self.bn3 = norm_layer(planes * self.expansion)
|
| 110 |
+
self.relu = nn.ReLU(inplace=True)
|
| 111 |
+
self.downsample = downsample
|
| 112 |
+
self.stride = stride
|
| 113 |
+
|
| 114 |
+
def forward(self, x):
|
| 115 |
+
identity = x
|
| 116 |
+
|
| 117 |
+
out = self.conv1(x)
|
| 118 |
+
out = self.bn1(out)
|
| 119 |
+
out = self.relu(out)
|
| 120 |
+
|
| 121 |
+
out = self.conv2(out)
|
| 122 |
+
out = self.bn2(out)
|
| 123 |
+
out = self.relu(out)
|
| 124 |
+
|
| 125 |
+
out = self.conv3(out)
|
| 126 |
+
out = self.bn3(out)
|
| 127 |
+
|
| 128 |
+
if self.downsample is not None:
|
| 129 |
+
identity = self.downsample(x)
|
| 130 |
+
|
| 131 |
+
out += identity
|
| 132 |
+
out = self.relu(out)
|
| 133 |
+
|
| 134 |
+
return out
|
| 135 |
+
|
| 136 |
+
def drq_weight_init(m):
|
| 137 |
+
# weight init scheme used in DrQv2
|
| 138 |
+
if isinstance(m, nn.Linear):
|
| 139 |
+
nn.init.orthogonal_(m.weight.data)
|
| 140 |
+
if hasattr(m.bias, 'data'):
|
| 141 |
+
m.bias.data.fill_(0.0)
|
| 142 |
+
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
|
| 143 |
+
gain = nn.init.calculate_gain('relu')
|
| 144 |
+
nn.init.orthogonal_(m.weight.data, gain)
|
| 145 |
+
if hasattr(m.bias, 'data'):
|
| 146 |
+
m.bias.data.fill_(0.0)
|
| 147 |
+
|
| 148 |
+
class Stage3ShallowEncoder(nn.Module):
|
| 149 |
+
"""
|
| 150 |
+
this is the encoder architecture used in DrQv2
|
| 151 |
+
"""
|
| 152 |
+
def __init__(self, obs_shape, n_channel):
|
| 153 |
+
super().__init__()
|
| 154 |
+
assert len(obs_shape) == 3
|
| 155 |
+
self.repr_dim = n_channel * 35 * 35
|
| 156 |
+
self.conv1 = nn.Conv2d(obs_shape[0], n_channel, 3, stride=2)
|
| 157 |
+
self.conv2 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
|
| 158 |
+
self.conv3 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
|
| 159 |
+
self.conv4 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
|
| 160 |
+
self.relu = nn.ReLU(inplace=True)
|
| 161 |
+
self.apply(drq_weight_init)
|
| 162 |
+
|
| 163 |
+
def _forward_impl(self, x):
|
| 164 |
+
x = self.relu(self.conv1(x))
|
| 165 |
+
x = self.relu(self.conv2(x))
|
| 166 |
+
x = self.relu(self.conv3(x))
|
| 167 |
+
x = self.relu(self.conv4(x))
|
| 168 |
+
return x
|
| 169 |
+
|
| 170 |
+
def forward(self, obs):
|
| 171 |
+
o = obs
|
| 172 |
+
h = self._forward_impl(o)
|
| 173 |
+
h = h.view(h.shape[0], -1)
|
| 174 |
+
return h
|
| 175 |
+
|
| 176 |
+
class ResNet84(nn.Module):
|
| 177 |
+
"""
|
| 178 |
+
default stage 1 encoder used by VRL3, this is modified from the PyTorch standard ResNet class
|
| 179 |
+
but is more lightweight and this is much faster with 84x84 input size
|
| 180 |
+
use "layers" to specify how deep the network is
|
| 181 |
+
use "start_num_channel" to control how wide it is
|
| 182 |
+
"""
|
| 183 |
+
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
|
| 184 |
+
groups=1, width_per_group=64, replace_stride_with_dilation=None,
|
| 185 |
+
norm_layer=None, start_num_channel=32):
|
| 186 |
+
super(ResNet84, self).__init__()
|
| 187 |
+
if norm_layer is None:
|
| 188 |
+
norm_layer = nn.BatchNorm2d
|
| 189 |
+
self._norm_layer = norm_layer
|
| 190 |
+
|
| 191 |
+
self.start_num_channel = start_num_channel
|
| 192 |
+
self.inplanes = start_num_channel
|
| 193 |
+
self.dilation = 1
|
| 194 |
+
if replace_stride_with_dilation is None:
|
| 195 |
+
# each element in the tuple indicates if we should replace
|
| 196 |
+
# the 2x2 stride with a dilated convolution instead
|
| 197 |
+
replace_stride_with_dilation = [False, False, False]
|
| 198 |
+
if len(replace_stride_with_dilation) != 3:
|
| 199 |
+
raise ValueError("replace_stride_with_dilation should be None "
|
| 200 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
| 201 |
+
self.groups = groups
|
| 202 |
+
self.base_width = width_per_group
|
| 203 |
+
self.conv1 = nn.Conv2d(3, self.inplanes, 3, stride=2)
|
| 204 |
+
self.bn1 = norm_layer(self.inplanes)
|
| 205 |
+
self.relu = nn.ReLU(inplace=True)
|
| 206 |
+
# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
| 207 |
+
|
| 208 |
+
self.layer1 = self._make_layer(block, start_num_channel, layers[0])
|
| 209 |
+
self.layer2 = self._make_layer(block, start_num_channel * 2, layers[1], stride=2,
|
| 210 |
+
dilate=replace_stride_with_dilation[0])
|
| 211 |
+
self.layer3 = self._make_layer(block, start_num_channel * 4, layers[2], stride=2,
|
| 212 |
+
dilate=replace_stride_with_dilation[1])
|
| 213 |
+
self.layer4 = self._make_layer(block, start_num_channel * 8, layers[3], stride=2,
|
| 214 |
+
dilate=replace_stride_with_dilation[2])
|
| 215 |
+
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
| 216 |
+
self.fc = nn.Linear(start_num_channel * 8 * block.expansion, num_classes)
|
| 217 |
+
|
| 218 |
+
for m in self.modules():
|
| 219 |
+
if isinstance(m, nn.Conv2d):
|
| 220 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 221 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
| 222 |
+
nn.init.constant_(m.weight, 1)
|
| 223 |
+
nn.init.constant_(m.bias, 0)
|
| 224 |
+
|
| 225 |
+
# Zero-initialize the last BN in each residual branch,
|
| 226 |
+
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
| 227 |
+
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
| 228 |
+
if zero_init_residual:
|
| 229 |
+
for m in self.modules():
|
| 230 |
+
if isinstance(m, Bottleneck):
|
| 231 |
+
nn.init.constant_(m.bn3.weight, 0)
|
| 232 |
+
elif isinstance(m, BasicBlock):
|
| 233 |
+
nn.init.constant_(m.bn2.weight, 0)
|
| 234 |
+
|
| 235 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
| 236 |
+
# vrl3: if block is 0, allows a smaller network size
|
| 237 |
+
if blocks == 0:
|
| 238 |
+
block = OneLayerBlock
|
| 239 |
+
|
| 240 |
+
norm_layer = self._norm_layer
|
| 241 |
+
downsample = None
|
| 242 |
+
previous_dilation = self.dilation
|
| 243 |
+
if dilate:
|
| 244 |
+
self.dilation *= stride
|
| 245 |
+
stride = 1
|
| 246 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 247 |
+
downsample = nn.Sequential(
|
| 248 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
| 249 |
+
norm_layer(planes * block.expansion),
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
layers = []
|
| 253 |
+
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
| 254 |
+
self.base_width, previous_dilation, norm_layer))
|
| 255 |
+
self.inplanes = planes * block.expansion
|
| 256 |
+
for _ in range(1, blocks):
|
| 257 |
+
layers.append(block(self.inplanes, planes, groups=self.groups,
|
| 258 |
+
base_width=self.base_width, dilation=self.dilation,
|
| 259 |
+
norm_layer=norm_layer))
|
| 260 |
+
|
| 261 |
+
return nn.Sequential(*layers)
|
| 262 |
+
|
| 263 |
+
def _forward_impl(self, x):
|
| 264 |
+
x = self.conv1(x)
|
| 265 |
+
x = self.bn1(x)
|
| 266 |
+
x = self.relu(x)
|
| 267 |
+
|
| 268 |
+
x = self.layer1(x)
|
| 269 |
+
x = self.layer2(x)
|
| 270 |
+
x = self.layer3(x)
|
| 271 |
+
x = self.layer4(x)
|
| 272 |
+
x = self.avgpool(x)
|
| 273 |
+
x = torch.flatten(x, 1)
|
| 274 |
+
x = self.fc(x)
|
| 275 |
+
|
| 276 |
+
return x
|
| 277 |
+
|
| 278 |
+
def get_feature_size(self):
|
| 279 |
+
assert self.start_num_channel % 32 == 0
|
| 280 |
+
multiplier = self.start_num_channel // 32
|
| 281 |
+
size = 256 * multiplier
|
| 282 |
+
return size
|
| 283 |
+
|
| 284 |
+
def forward(self, x):
|
| 285 |
+
return self._forward_impl(x)
|
| 286 |
+
|
| 287 |
+
def get_features(self, x):
|
| 288 |
+
x = self.conv1(x)
|
| 289 |
+
# print("0", x.shape) # 32 x 41 x 41 = 53792
|
| 290 |
+
x = self.bn1(x)
|
| 291 |
+
x = self.relu(x)
|
| 292 |
+
|
| 293 |
+
x = self.layer1(x)
|
| 294 |
+
# print("1", x.shape) # 32 x 41 x 41= 53792
|
| 295 |
+
|
| 296 |
+
x = self.layer2(x)
|
| 297 |
+
# print("2", x.shape) # 64 x 21 x 21 = 28224
|
| 298 |
+
|
| 299 |
+
x = self.layer3(x)
|
| 300 |
+
# print("3", x.shape) # 128 x 11 x 11 = 15488
|
| 301 |
+
|
| 302 |
+
x = self.layer4(x)
|
| 303 |
+
# print("4", x.shape) # 256 x 6 x 6 = 9216
|
| 304 |
+
|
| 305 |
+
x = self.avgpool(x)
|
| 306 |
+
# print("pool", x.shape) # 256 x 1 x 1
|
| 307 |
+
|
| 308 |
+
final_out = torch.flatten(x, 1)
|
| 309 |
+
# print("flatten", x.shape) # 256
|
| 310 |
+
return final_out
|
| 311 |
+
|
| 312 |
+
class Identity(nn.Module):
|
| 313 |
+
def __init__(self, input_placeholder=None):
|
| 314 |
+
super(Identity, self).__init__()
|
| 315 |
+
|
| 316 |
+
def forward(self, x):
|
| 317 |
+
return x
|
| 318 |
+
|
VRL3/src/train_stage1.py
ADDED
|
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT License.
|
| 3 |
+
|
| 4 |
+
# this file is modified from the pytorch official tutorial
|
| 5 |
+
# NOTE: stage 1 training code is currently being cleaned up
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import os
|
| 9 |
+
import random
|
| 10 |
+
import shutil
|
| 11 |
+
import time
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.parallel
|
| 17 |
+
import torch.backends.cudnn as cudnn
|
| 18 |
+
import torch.distributed as dist
|
| 19 |
+
import torch.optim
|
| 20 |
+
import torch.multiprocessing as mp
|
| 21 |
+
import torch.utils.data
|
| 22 |
+
import torch.utils.data.distributed
|
| 23 |
+
import torchvision.transforms as transforms
|
| 24 |
+
import torchvision.datasets as datasets
|
| 25 |
+
import torchvision.models as models
|
| 26 |
+
|
| 27 |
+
# TODO use another config file to indicate the location of the training data and also where to save models...
|
| 28 |
+
# should also add an option to just test the accuracy of models...
|
| 29 |
+
# we probably can test this locally ....
|
| 30 |
+
|
| 31 |
+
from stage1_models import BasicBlock, ResNet84
|
| 32 |
+
|
| 33 |
+
rl_model_names = ['resnet6_32channel', 'resnet10_32channel', 'resnet18_32channel',
|
| 34 |
+
'resnet6_64channel', 'resnet10_64channel', 'resnet18_64channel',]
|
| 35 |
+
model_names = sorted(name for name in models.__dict__
|
| 36 |
+
if name.islower() and not name.startswith("__")
|
| 37 |
+
and callable(models.__dict__[name])) + rl_model_names
|
| 38 |
+
|
| 39 |
+
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
|
| 40 |
+
parser.add_argument('data', metavar='DIR',
|
| 41 |
+
help='path to dataset')
|
| 42 |
+
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet10_32channel',
|
| 43 |
+
choices=model_names,
|
| 44 |
+
help='model architecture: ' +
|
| 45 |
+
' | '.join(model_names) +
|
| 46 |
+
' (default: resnet18)')
|
| 47 |
+
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
|
| 48 |
+
help='number of data loading workers (default: 4)')
|
| 49 |
+
parser.add_argument('--epochs', default=90, type=int, metavar='N',
|
| 50 |
+
help='number of total epochs to run')
|
| 51 |
+
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
|
| 52 |
+
help='manual epoch number (useful on restarts)')
|
| 53 |
+
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
| 54 |
+
metavar='N',
|
| 55 |
+
help='mini-batch size (default: 256), this is the total '
|
| 56 |
+
'batch size of all GPUs on the current node when '
|
| 57 |
+
'using Data Parallel or Distributed Data Parallel')
|
| 58 |
+
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
|
| 59 |
+
metavar='LR', help='initial learning rate', dest='lr')
|
| 60 |
+
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
| 61 |
+
help='momentum')
|
| 62 |
+
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
|
| 63 |
+
metavar='W', help='weight decay (default: 1e-4)',
|
| 64 |
+
dest='weight_decay')
|
| 65 |
+
parser.add_argument('-p', '--print-freq', default=10, type=int,
|
| 66 |
+
metavar='N', help='print frequency (default: 10)')
|
| 67 |
+
parser.add_argument('--resume', default='', type=str, metavar='PATH',
|
| 68 |
+
help='path to latest checkpoint (default: none)')
|
| 69 |
+
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
|
| 70 |
+
help='evaluate model on validation set')
|
| 71 |
+
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
|
| 72 |
+
help='use pre-trained model')
|
| 73 |
+
parser.add_argument('--world-size', default=-1, type=int,
|
| 74 |
+
help='number of nodes for distributed training')
|
| 75 |
+
parser.add_argument('--rank', default=-1, type=int,
|
| 76 |
+
help='node rank for distributed training')
|
| 77 |
+
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
|
| 78 |
+
help='url used to set up distributed training')
|
| 79 |
+
parser.add_argument('--dist-backend', default='nccl', type=str,
|
| 80 |
+
help='distributed backend')
|
| 81 |
+
parser.add_argument('--seed', default=None, type=int,
|
| 82 |
+
help='seed for initializing training. ')
|
| 83 |
+
parser.add_argument('--gpu', default=None, type=int,
|
| 84 |
+
help='GPU id to use.')
|
| 85 |
+
parser.add_argument('--multiprocessing-distributed', action='store_true',
|
| 86 |
+
help='Use multi-processing distributed training to launch '
|
| 87 |
+
'N processes per node, which has N GPUs. This is the '
|
| 88 |
+
'fastest way to use PyTorch for either single node or '
|
| 89 |
+
'multi node data parallel training')
|
| 90 |
+
parser.add_argument('--debug', default=0, type=int,
|
| 91 |
+
help='1 for debug mode, 2 for super fast debug mode')
|
| 92 |
+
|
| 93 |
+
best_acc1 = 0
|
| 94 |
+
INPUT_SIZE = 84
|
| 95 |
+
VAL_RESIZE = 100
|
| 96 |
+
|
| 97 |
+
def main():
|
| 98 |
+
print(model_names)
|
| 99 |
+
|
| 100 |
+
args = parser.parse_args()
|
| 101 |
+
|
| 102 |
+
if args.seed is not None:
|
| 103 |
+
random.seed(args.seed)
|
| 104 |
+
torch.manual_seed(args.seed)
|
| 105 |
+
cudnn.deterministic = True
|
| 106 |
+
warnings.warn('You have chosen to seed training. '
|
| 107 |
+
'This will turn on the CUDNN deterministic setting, '
|
| 108 |
+
'which can slow down your training considerably! '
|
| 109 |
+
'You may see unexpected behavior when restarting '
|
| 110 |
+
'from checkpoints.')
|
| 111 |
+
|
| 112 |
+
if args.gpu is not None:
|
| 113 |
+
warnings.warn('You have chosen a specific GPU. This will completely '
|
| 114 |
+
'disable data parallelism.')
|
| 115 |
+
|
| 116 |
+
if args.dist_url == "env://" and args.world_size == -1:
|
| 117 |
+
args.world_size = int(os.environ["WORLD_SIZE"])
|
| 118 |
+
|
| 119 |
+
args.distributed = args.world_size > 1 or args.multiprocessing_distributed
|
| 120 |
+
|
| 121 |
+
ngpus_per_node = torch.cuda.device_count()
|
| 122 |
+
if args.multiprocessing_distributed:
|
| 123 |
+
# Since we have ngpus_per_node processes per node, the total world_size
|
| 124 |
+
# needs to be adjusted accordingly
|
| 125 |
+
args.world_size = ngpus_per_node * args.world_size
|
| 126 |
+
# Use torch.multiprocessing.spawn to launch distributed processes: the
|
| 127 |
+
# main_worker process function
|
| 128 |
+
mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
|
| 129 |
+
else:
|
| 130 |
+
# Simply call main_worker function
|
| 131 |
+
main_worker(args.gpu, ngpus_per_node, args)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def main_worker(gpu, ngpus_per_node, args):
|
| 135 |
+
global best_acc1
|
| 136 |
+
args.gpu = gpu
|
| 137 |
+
|
| 138 |
+
if args.gpu is not None:
|
| 139 |
+
print("Use GPU: {} for training".format(args.gpu))
|
| 140 |
+
|
| 141 |
+
if args.distributed:
|
| 142 |
+
if args.dist_url == "env://" and args.rank == -1:
|
| 143 |
+
args.rank = int(os.environ["RANK"])
|
| 144 |
+
if args.multiprocessing_distributed:
|
| 145 |
+
# For multiprocessing distributed training, rank needs to be the
|
| 146 |
+
# global rank among all the processes
|
| 147 |
+
args.rank = args.rank * ngpus_per_node + gpu
|
| 148 |
+
dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
| 149 |
+
world_size=args.world_size, rank=args.rank)
|
| 150 |
+
# create model
|
| 151 |
+
if args.debug > 0:
|
| 152 |
+
print("=> creating model for debug 2")
|
| 153 |
+
# model = ResNet84(BasicBlock, [1, 1, 1, 1], num_classes=5) # 1, 1, 1, 1 will make a resnet10
|
| 154 |
+
model = ResNet84(BasicBlock, [0, 0, 0, 0], num_classes=5) # 0, 0, 0, 0 make a convnet6 (5 conv layers in total lol)
|
| 155 |
+
x = torch.rand((1, 3, 84, 84)).float()
|
| 156 |
+
out = model(x)
|
| 157 |
+
print(model)
|
| 158 |
+
quit()
|
| 159 |
+
|
| 160 |
+
# model = ResNetTest2(BasicBlock, [2, 2, 2, 2])
|
| 161 |
+
#model = Drq4Encoder((3, 84, 84), n_channel, 200)
|
| 162 |
+
else:
|
| 163 |
+
if args.pretrained:
|
| 164 |
+
print("=> using pre-trained model '{}'".format(args.arch))
|
| 165 |
+
model = models.__dict__[args.arch](pretrained=True)
|
| 166 |
+
else:
|
| 167 |
+
print("=> creating model '{}'".format(args.arch))
|
| 168 |
+
if args.arch in rl_model_names:
|
| 169 |
+
if args.arch == 'resnet18_32channel':
|
| 170 |
+
model = ResNet84(BasicBlock, [2, 2, 2, 2], start_num_channel=32) # 1, 1, 1, 1 will make a resnet10
|
| 171 |
+
elif args.arch == 'resnet10_32channel':
|
| 172 |
+
model = ResNet84(BasicBlock, [1, 1, 1, 1], start_num_channel=32) # 1, 1, 1, 1 will make a resnet10
|
| 173 |
+
elif args.arch == 'resnet6_32channel':
|
| 174 |
+
model = ResNet84(BasicBlock, [0, 0, 0, 0], start_num_channel=32) # resnet 6 (actually not even resnet because no skip connection)
|
| 175 |
+
elif args.arch == 'resnet18_64channel':
|
| 176 |
+
model = ResNet84(BasicBlock, [2, 2, 2, 2], start_num_channel=64)
|
| 177 |
+
elif args.arch == 'resnet10_64channel':
|
| 178 |
+
model = ResNet84(BasicBlock, [1, 1, 1, 1], start_num_channel=64)
|
| 179 |
+
elif args.arch == 'resnet6_64channel':
|
| 180 |
+
model = ResNet84(BasicBlock, [0, 0, 0, 0], start_num_channel=64)
|
| 181 |
+
else:
|
| 182 |
+
print("specialized model not yet implemented")
|
| 183 |
+
quit()
|
| 184 |
+
else:
|
| 185 |
+
model = models.__dict__[args.arch]()
|
| 186 |
+
|
| 187 |
+
if not torch.cuda.is_available():
|
| 188 |
+
print('using CPU, this will be slow')
|
| 189 |
+
elif args.distributed:
|
| 190 |
+
print("distributed")
|
| 191 |
+
# For multiprocessing distributed, DistributedDataParallel constructor
|
| 192 |
+
# should always set the single device scope, otherwise,
|
| 193 |
+
# DistributedDataParallel will use all available devices.
|
| 194 |
+
if args.gpu is not None:
|
| 195 |
+
torch.cuda.set_device(args.gpu)
|
| 196 |
+
model.cuda(args.gpu)
|
| 197 |
+
# When using a single GPU per process and per
|
| 198 |
+
# DistributedDataParallel, we need to divide the batch size
|
| 199 |
+
# ourselves based on the total number of GPUs we have
|
| 200 |
+
args.batch_size = int(args.batch_size / ngpus_per_node)
|
| 201 |
+
args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
|
| 202 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
|
| 203 |
+
else:
|
| 204 |
+
model.cuda()
|
| 205 |
+
# DistributedDataParallel will divide and allocate batch_size to all
|
| 206 |
+
# available GPUs if device_ids are not set
|
| 207 |
+
model = torch.nn.parallel.DistributedDataParallel(model)
|
| 208 |
+
elif args.gpu is not None:
|
| 209 |
+
print("use gpu:", args.gpu)
|
| 210 |
+
torch.cuda.set_device(args.gpu)
|
| 211 |
+
model = model.cuda(args.gpu)
|
| 212 |
+
else:
|
| 213 |
+
print("data parallel")
|
| 214 |
+
# DataParallel will divide and allocate batch_size to all available GPUs
|
| 215 |
+
if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
|
| 216 |
+
model.features = torch.nn.DataParallel(model.features)
|
| 217 |
+
model.cuda()
|
| 218 |
+
else:
|
| 219 |
+
model = torch.nn.DataParallel(model).cuda()
|
| 220 |
+
|
| 221 |
+
# define loss function (criterion) and optimizer
|
| 222 |
+
criterion = nn.CrossEntropyLoss().cuda(args.gpu)
|
| 223 |
+
|
| 224 |
+
optimizer = torch.optim.SGD(model.parameters(), args.lr,
|
| 225 |
+
momentum=args.momentum,
|
| 226 |
+
weight_decay=args.weight_decay)
|
| 227 |
+
|
| 228 |
+
# optionally resume from a checkpoint
|
| 229 |
+
if args.resume:
|
| 230 |
+
if os.path.isfile(args.resume):
|
| 231 |
+
print("=> loading checkpoint '{}'".format(args.resume))
|
| 232 |
+
if args.gpu is None:
|
| 233 |
+
checkpoint = torch.load(args.resume)
|
| 234 |
+
else:
|
| 235 |
+
# Map model to be loaded to specified single gpu.
|
| 236 |
+
loc = 'cuda:{}'.format(args.gpu)
|
| 237 |
+
checkpoint = torch.load(args.resume, map_location=loc)
|
| 238 |
+
args.start_epoch = checkpoint['epoch']
|
| 239 |
+
best_acc1 = checkpoint['best_acc1']
|
| 240 |
+
if args.gpu is not None:
|
| 241 |
+
# best_acc1 may be from a checkpoint from a different GPU
|
| 242 |
+
best_acc1 = best_acc1.to(args.gpu)
|
| 243 |
+
model.load_state_dict(checkpoint['state_dict'])
|
| 244 |
+
optimizer.load_state_dict(checkpoint['optimizer'])
|
| 245 |
+
print("=> loaded checkpoint '{}' (epoch {})"
|
| 246 |
+
.format(args.resume, checkpoint['epoch']))
|
| 247 |
+
else:
|
| 248 |
+
print("=> no checkpoint found at '{}'".format(args.resume))
|
| 249 |
+
|
| 250 |
+
cudnn.benchmark = True
|
| 251 |
+
|
| 252 |
+
# Data loading code
|
| 253 |
+
traindir = os.path.join(args.data, 'train')
|
| 254 |
+
valdir = os.path.join(args.data, 'val')
|
| 255 |
+
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
| 256 |
+
std=[0.229, 0.224, 0.225])
|
| 257 |
+
|
| 258 |
+
print("train directory is:", traindir)
|
| 259 |
+
print("val directory is:", valdir)
|
| 260 |
+
|
| 261 |
+
train_dataset = datasets.ImageFolder(
|
| 262 |
+
traindir,
|
| 263 |
+
transforms.Compose([
|
| 264 |
+
transforms.RandomResizedCrop(INPUT_SIZE),
|
| 265 |
+
transforms.RandomHorizontalFlip(),
|
| 266 |
+
transforms.ToTensor(),
|
| 267 |
+
normalize,
|
| 268 |
+
]))
|
| 269 |
+
|
| 270 |
+
print("data set ready")
|
| 271 |
+
|
| 272 |
+
if args.distributed:
|
| 273 |
+
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
|
| 274 |
+
else:
|
| 275 |
+
train_sampler = None
|
| 276 |
+
|
| 277 |
+
train_loader = torch.utils.data.DataLoader(
|
| 278 |
+
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
|
| 279 |
+
num_workers=args.workers, pin_memory=True, sampler=train_sampler)
|
| 280 |
+
|
| 281 |
+
val_loader = torch.utils.data.DataLoader(
|
| 282 |
+
datasets.ImageFolder(valdir, transforms.Compose([
|
| 283 |
+
# transforms.Resize(256),
|
| 284 |
+
transforms.Resize(VAL_RESIZE),
|
| 285 |
+
transforms.CenterCrop(INPUT_SIZE),
|
| 286 |
+
transforms.ToTensor(),
|
| 287 |
+
normalize,
|
| 288 |
+
])),
|
| 289 |
+
batch_size=args.batch_size, shuffle=False,
|
| 290 |
+
num_workers=args.workers, pin_memory=True)
|
| 291 |
+
|
| 292 |
+
if args.evaluate:
|
| 293 |
+
validate(val_loader, model, criterion, args)
|
| 294 |
+
return
|
| 295 |
+
|
| 296 |
+
for epoch in range(args.start_epoch, args.epochs):
|
| 297 |
+
print(epoch)
|
| 298 |
+
epoch_start_time = time.time()
|
| 299 |
+
|
| 300 |
+
if args.distributed:
|
| 301 |
+
train_sampler.set_epoch(epoch)
|
| 302 |
+
adjust_learning_rate(optimizer, epoch, args)
|
| 303 |
+
|
| 304 |
+
# train for one epoch
|
| 305 |
+
train(train_loader, model, criterion, optimizer, epoch, args)
|
| 306 |
+
|
| 307 |
+
# evaluate on validation set
|
| 308 |
+
acc1 = validate(val_loader, model, criterion, args)
|
| 309 |
+
|
| 310 |
+
# remember best acc@1 and save checkpoint
|
| 311 |
+
is_best = acc1 > best_acc1
|
| 312 |
+
best_acc1 = max(acc1, best_acc1)
|
| 313 |
+
|
| 314 |
+
if not args.multiprocessing_distributed or (args.multiprocessing_distributed
|
| 315 |
+
and args.rank % ngpus_per_node == 0):
|
| 316 |
+
save_checkpoint({
|
| 317 |
+
'epoch': epoch + 1,
|
| 318 |
+
'arch': args.arch,
|
| 319 |
+
'state_dict': model.state_dict(),
|
| 320 |
+
'best_acc1': best_acc1,
|
| 321 |
+
'optimizer' : optimizer.state_dict(),
|
| 322 |
+
}, is_best,
|
| 323 |
+
save_name_prefix=args.arch)
|
| 324 |
+
|
| 325 |
+
epoch_end_time = time.time() - epoch_start_time
|
| 326 |
+
print("epoch finished in %.3f hour" % (epoch_end_time/3600))
|
| 327 |
+
|
| 328 |
+
def train(train_loader, model, criterion, optimizer, epoch, args):
|
| 329 |
+
batch_time = AverageMeter('Time', ':6.3f')
|
| 330 |
+
data_time = AverageMeter('Data', ':6.3f')
|
| 331 |
+
losses = AverageMeter('Loss', ':.4e')
|
| 332 |
+
top1 = AverageMeter('Acc@1', ':6.2f')
|
| 333 |
+
top5 = AverageMeter('Acc@5', ':6.2f')
|
| 334 |
+
progress = ProgressMeter(
|
| 335 |
+
len(train_loader),
|
| 336 |
+
[batch_time, data_time, losses, top1, top5],
|
| 337 |
+
prefix="Epoch: [{}]".format(epoch))
|
| 338 |
+
|
| 339 |
+
# switch to train mode
|
| 340 |
+
model.train()
|
| 341 |
+
|
| 342 |
+
end = time.time()
|
| 343 |
+
for i, (images, target) in enumerate(train_loader):
|
| 344 |
+
# measure data loading time
|
| 345 |
+
data_time.update(time.time() - end)
|
| 346 |
+
|
| 347 |
+
if args.gpu is not None:
|
| 348 |
+
images = images.cuda(args.gpu, non_blocking=True)
|
| 349 |
+
if torch.cuda.is_available():
|
| 350 |
+
target = target.cuda(args.gpu, non_blocking=True)
|
| 351 |
+
|
| 352 |
+
# compute output
|
| 353 |
+
output = model(images)
|
| 354 |
+
loss = criterion(output, target)
|
| 355 |
+
|
| 356 |
+
# measure accuracy and record loss
|
| 357 |
+
acc1, acc5 = accuracy(output, target, topk=(1, 5))
|
| 358 |
+
losses.update(loss.item(), images.size(0))
|
| 359 |
+
top1.update(acc1[0], images.size(0))
|
| 360 |
+
top5.update(acc5[0], images.size(0))
|
| 361 |
+
|
| 362 |
+
# compute gradient and do SGD step
|
| 363 |
+
optimizer.zero_grad()
|
| 364 |
+
loss.backward()
|
| 365 |
+
optimizer.step()
|
| 366 |
+
|
| 367 |
+
# measure elapsed time
|
| 368 |
+
batch_time.update(time.time() - end)
|
| 369 |
+
end = time.time()
|
| 370 |
+
|
| 371 |
+
if i % args.print_freq == 0:
|
| 372 |
+
progress.display(i)
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def validate(val_loader, model, criterion, args):
|
| 376 |
+
batch_time = AverageMeter('Time', ':6.3f')
|
| 377 |
+
losses = AverageMeter('Loss', ':.4e')
|
| 378 |
+
top1 = AverageMeter('Acc@1', ':6.2f')
|
| 379 |
+
top5 = AverageMeter('Acc@5', ':6.2f')
|
| 380 |
+
progress = ProgressMeter(
|
| 381 |
+
len(val_loader),
|
| 382 |
+
[batch_time, losses, top1, top5],
|
| 383 |
+
prefix='Test: ')
|
| 384 |
+
|
| 385 |
+
# switch to evaluate mode
|
| 386 |
+
model.eval()
|
| 387 |
+
|
| 388 |
+
with torch.no_grad():
|
| 389 |
+
end = time.time()
|
| 390 |
+
for i, (images, target) in enumerate(val_loader):
|
| 391 |
+
if args.gpu is not None:
|
| 392 |
+
images = images.cuda(args.gpu, non_blocking=True)
|
| 393 |
+
if torch.cuda.is_available():
|
| 394 |
+
target = target.cuda(args.gpu, non_blocking=True)
|
| 395 |
+
|
| 396 |
+
# compute output
|
| 397 |
+
output = model(images)
|
| 398 |
+
loss = criterion(output, target)
|
| 399 |
+
|
| 400 |
+
# measure accuracy and record loss
|
| 401 |
+
acc1, acc5 = accuracy(output, target, topk=(1, 5))
|
| 402 |
+
losses.update(loss.item(), images.size(0))
|
| 403 |
+
top1.update(acc1[0], images.size(0))
|
| 404 |
+
top5.update(acc5[0], images.size(0))
|
| 405 |
+
|
| 406 |
+
# measure elapsed time
|
| 407 |
+
batch_time.update(time.time() - end)
|
| 408 |
+
end = time.time()
|
| 409 |
+
|
| 410 |
+
if i % args.print_freq == 0:
|
| 411 |
+
progress.display(i)
|
| 412 |
+
|
| 413 |
+
# TODO: this should also be done with the ProgressMeter
|
| 414 |
+
print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
|
| 415 |
+
.format(top1=top1, top5=top5))
|
| 416 |
+
|
| 417 |
+
return top1.avg
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', save_name_prefix = ''):
|
| 421 |
+
save_name = save_name_prefix + '_' + filename
|
| 422 |
+
torch.save(state, save_name)
|
| 423 |
+
if is_best:
|
| 424 |
+
best_model_save_name = save_name_prefix + '_' + 'model_best.pth.tar'
|
| 425 |
+
shutil.copyfile(save_name, best_model_save_name)
|
| 426 |
+
|
| 427 |
+
class AverageMeter(object):
|
| 428 |
+
"""Computes and stores the average and current value"""
|
| 429 |
+
def __init__(self, name, fmt=':f'):
|
| 430 |
+
self.name = name
|
| 431 |
+
self.fmt = fmt
|
| 432 |
+
self.reset()
|
| 433 |
+
|
| 434 |
+
def reset(self):
|
| 435 |
+
self.val = 0
|
| 436 |
+
self.avg = 0
|
| 437 |
+
self.sum = 0
|
| 438 |
+
self.count = 0
|
| 439 |
+
|
| 440 |
+
def update(self, val, n=1):
|
| 441 |
+
self.val = val
|
| 442 |
+
self.sum += val * n
|
| 443 |
+
self.count += n
|
| 444 |
+
self.avg = self.sum / self.count
|
| 445 |
+
|
| 446 |
+
def __str__(self):
|
| 447 |
+
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
|
| 448 |
+
return fmtstr.format(**self.__dict__)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
class ProgressMeter(object):
|
| 452 |
+
def __init__(self, num_batches, meters, prefix=""):
|
| 453 |
+
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
|
| 454 |
+
self.meters = meters
|
| 455 |
+
self.prefix = prefix
|
| 456 |
+
|
| 457 |
+
def display(self, batch):
|
| 458 |
+
entries = [self.prefix + self.batch_fmtstr.format(batch)]
|
| 459 |
+
entries += [str(meter) for meter in self.meters]
|
| 460 |
+
print('\t'.join(entries))
|
| 461 |
+
|
| 462 |
+
def _get_batch_fmtstr(self, num_batches):
|
| 463 |
+
num_digits = len(str(num_batches // 1))
|
| 464 |
+
fmt = '{:' + str(num_digits) + 'd}'
|
| 465 |
+
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def adjust_learning_rate(optimizer, epoch, args):
|
| 469 |
+
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
|
| 470 |
+
lr = args.lr * (0.1 ** (epoch // 30))
|
| 471 |
+
for param_group in optimizer.param_groups:
|
| 472 |
+
param_group['lr'] = lr
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def accuracy(output, target, topk=(1,)):
|
| 476 |
+
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
| 477 |
+
with torch.no_grad():
|
| 478 |
+
maxk = max(topk)
|
| 479 |
+
batch_size = target.size(0)
|
| 480 |
+
|
| 481 |
+
_, pred = output.topk(maxk, 1, True, True)
|
| 482 |
+
pred = pred.t()
|
| 483 |
+
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
| 484 |
+
|
| 485 |
+
res = []
|
| 486 |
+
for k in topk:
|
| 487 |
+
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
| 488 |
+
res.append(correct_k.mul_(100.0 / batch_size))
|
| 489 |
+
return res
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
if __name__ == '__main__':
|
| 493 |
+
main()
|
VRL3/src/utils.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
import random
|
| 6 |
+
import re
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from omegaconf import OmegaConf
|
| 14 |
+
from torch import distributions as pyd
|
| 15 |
+
from torch.distributions.utils import _standard_normal
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class eval_mode:
|
| 19 |
+
def __init__(self, *models):
|
| 20 |
+
self.models = models
|
| 21 |
+
|
| 22 |
+
def __enter__(self):
|
| 23 |
+
self.prev_states = []
|
| 24 |
+
for model in self.models:
|
| 25 |
+
self.prev_states.append(model.training)
|
| 26 |
+
model.train(False)
|
| 27 |
+
|
| 28 |
+
def __exit__(self, *args):
|
| 29 |
+
for model, state in zip(self.models, self.prev_states):
|
| 30 |
+
model.train(state)
|
| 31 |
+
return False
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def set_seed_everywhere(seed):
|
| 35 |
+
torch.manual_seed(seed)
|
| 36 |
+
if torch.cuda.is_available():
|
| 37 |
+
torch.cuda.manual_seed_all(seed)
|
| 38 |
+
np.random.seed(seed)
|
| 39 |
+
random.seed(seed)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def soft_update_params(net, target_net, tau):
|
| 43 |
+
for param, target_param in zip(net.parameters(), target_net.parameters()):
|
| 44 |
+
target_param.data.copy_(tau * param.data +
|
| 45 |
+
(1 - tau) * target_param.data)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def to_torch(xs, device):
|
| 49 |
+
return tuple(torch.as_tensor(x, device=device) for x in xs)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def weight_init(m):
|
| 53 |
+
if isinstance(m, nn.Linear):
|
| 54 |
+
nn.init.orthogonal_(m.weight.data)
|
| 55 |
+
if hasattr(m.bias, 'data'):
|
| 56 |
+
m.bias.data.fill_(0.0)
|
| 57 |
+
elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
|
| 58 |
+
gain = nn.init.calculate_gain('relu')
|
| 59 |
+
nn.init.orthogonal_(m.weight.data, gain)
|
| 60 |
+
if hasattr(m.bias, 'data'):
|
| 61 |
+
m.bias.data.fill_(0.0)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class Until:
|
| 65 |
+
def __init__(self, until, action_repeat=1):
|
| 66 |
+
self._until = until
|
| 67 |
+
self._action_repeat = action_repeat
|
| 68 |
+
|
| 69 |
+
def __call__(self, step):
|
| 70 |
+
if self._until is None:
|
| 71 |
+
return True
|
| 72 |
+
until = self._until // self._action_repeat
|
| 73 |
+
return step < until
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class Every:
|
| 77 |
+
def __init__(self, every, action_repeat=1):
|
| 78 |
+
self._every = every
|
| 79 |
+
self._action_repeat = action_repeat
|
| 80 |
+
|
| 81 |
+
def __call__(self, step):
|
| 82 |
+
if self._every is None:
|
| 83 |
+
return False
|
| 84 |
+
every = self._every // self._action_repeat
|
| 85 |
+
if step % every == 0:
|
| 86 |
+
return True
|
| 87 |
+
return False
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class Timer:
|
| 91 |
+
def __init__(self):
|
| 92 |
+
self._start_time = time.time()
|
| 93 |
+
self._last_time = time.time()
|
| 94 |
+
|
| 95 |
+
def reset(self):
|
| 96 |
+
elapsed_time = time.time() - self._last_time
|
| 97 |
+
self._last_time = time.time()
|
| 98 |
+
total_time = time.time() - self._start_time
|
| 99 |
+
return elapsed_time, total_time
|
| 100 |
+
|
| 101 |
+
def total_time(self):
|
| 102 |
+
return time.time() - self._start_time
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class TruncatedNormal(pyd.Normal):
|
| 106 |
+
def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
|
| 107 |
+
super().__init__(loc, scale, validate_args=False)
|
| 108 |
+
self.low = low
|
| 109 |
+
self.high = high
|
| 110 |
+
self.eps = eps
|
| 111 |
+
|
| 112 |
+
def _clamp(self, x):
|
| 113 |
+
clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
|
| 114 |
+
x = x - x.detach() + clamped_x.detach()
|
| 115 |
+
return x
|
| 116 |
+
|
| 117 |
+
def sample(self, clip=None, sample_shape=torch.Size()):
|
| 118 |
+
shape = self._extended_shape(sample_shape)
|
| 119 |
+
eps = _standard_normal(shape,
|
| 120 |
+
dtype=self.loc.dtype,
|
| 121 |
+
device=self.loc.device)
|
| 122 |
+
eps *= self.scale
|
| 123 |
+
if clip is not None:
|
| 124 |
+
eps = torch.clamp(eps, -clip, clip)
|
| 125 |
+
x = self.loc + eps
|
| 126 |
+
return self._clamp(x)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def schedule(schdl, step):
|
| 130 |
+
try:
|
| 131 |
+
return float(schdl)
|
| 132 |
+
except ValueError:
|
| 133 |
+
match = re.match(r'linear\((.+),(.+),(.+)\)', schdl)
|
| 134 |
+
if match:
|
| 135 |
+
init, final, duration = [float(g) for g in match.groups()]
|
| 136 |
+
mix = np.clip(step / duration, 0.0, 1.0)
|
| 137 |
+
return (1.0 - mix) * init + mix * final
|
| 138 |
+
match = re.match(r'step_linear\((.+),(.+),(.+),(.+),(.+)\)', schdl)
|
| 139 |
+
if match:
|
| 140 |
+
init, final1, duration1, final2, duration2 = [
|
| 141 |
+
float(g) for g in match.groups()
|
| 142 |
+
]
|
| 143 |
+
if step <= duration1:
|
| 144 |
+
mix = np.clip(step / duration1, 0.0, 1.0)
|
| 145 |
+
return (1.0 - mix) * init + mix * final1
|
| 146 |
+
else:
|
| 147 |
+
mix = np.clip((step - duration1) / duration2, 0.0, 1.0)
|
| 148 |
+
return (1.0 - mix) * final1 + mix * final2
|
| 149 |
+
raise NotImplementedError(schdl)
|
VRL3/src/vrl3_agent.py
ADDED
|
@@ -0,0 +1,632 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Microsoft Corporation.
|
| 2 |
+
# Licensed under the MIT License.
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torchvision import datasets, models, transforms
|
| 9 |
+
from transfer_util import initialize_model
|
| 10 |
+
from stage1_models import BasicBlock, ResNet84
|
| 11 |
+
import os
|
| 12 |
+
import copy
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import platform
|
| 15 |
+
from numbers import Number
|
| 16 |
+
import utils
|
| 17 |
+
|
| 18 |
+
class RandomShiftsAug(nn.Module):
|
| 19 |
+
def __init__(self, pad):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.pad = pad
|
| 22 |
+
|
| 23 |
+
def forward(self, x):
|
| 24 |
+
n, c, h, w = x.size()
|
| 25 |
+
assert h == w
|
| 26 |
+
padding = tuple([self.pad] * 4)
|
| 27 |
+
x = F.pad(x, padding, 'replicate')
|
| 28 |
+
eps = 1.0 / (h + 2 * self.pad)
|
| 29 |
+
arange = torch.linspace(-1.0 + eps,
|
| 30 |
+
1.0 - eps,
|
| 31 |
+
h + 2 * self.pad,
|
| 32 |
+
device=x.device,
|
| 33 |
+
dtype=x.dtype)[:h]
|
| 34 |
+
arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2)
|
| 35 |
+
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
|
| 36 |
+
base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1)
|
| 37 |
+
|
| 38 |
+
shift = torch.randint(0,
|
| 39 |
+
2 * self.pad + 1,
|
| 40 |
+
size=(n, 1, 1, 2),
|
| 41 |
+
device=x.device,
|
| 42 |
+
dtype=x.dtype)
|
| 43 |
+
shift *= 2.0 / (h + 2 * self.pad)
|
| 44 |
+
|
| 45 |
+
grid = base_grid + shift
|
| 46 |
+
return F.grid_sample(x,
|
| 47 |
+
grid,
|
| 48 |
+
padding_mode='zeros',
|
| 49 |
+
align_corners=False)
|
| 50 |
+
|
| 51 |
+
class Identity(nn.Module):
|
| 52 |
+
def __init__(self, input_placeholder=None):
|
| 53 |
+
super(Identity, self).__init__()
|
| 54 |
+
|
| 55 |
+
def forward(self, x):
|
| 56 |
+
return x
|
| 57 |
+
|
| 58 |
+
class RLEncoder(nn.Module):
|
| 59 |
+
def __init__(self, obs_shape, model_name, device):
|
| 60 |
+
super().__init__()
|
| 61 |
+
# a wrapper over a non-RL encoder model
|
| 62 |
+
self.device = device
|
| 63 |
+
assert len(obs_shape) == 3
|
| 64 |
+
self.n_input_channel = obs_shape[0]
|
| 65 |
+
assert self.n_input_channel % 3 == 0
|
| 66 |
+
self.n_images = self.n_input_channel // 3
|
| 67 |
+
self.model = self.init_model(model_name)
|
| 68 |
+
self.model.fc = Identity()
|
| 69 |
+
self.repr_dim = self.model.get_feature_size()
|
| 70 |
+
|
| 71 |
+
self.normalize_op = transforms.Normalize((0.485, 0.456, 0.406),
|
| 72 |
+
(0.229, 0.224, 0.225))
|
| 73 |
+
self.channel_mismatch = True
|
| 74 |
+
|
| 75 |
+
def init_model(self, model_name):
|
| 76 |
+
# model name is e.g. resnet6_32channel
|
| 77 |
+
n_layer_string, n_channel_string = model_name.split('_')
|
| 78 |
+
layer_string_to_layer_list = {
|
| 79 |
+
'resnet6': [0, 0, 0, 0],
|
| 80 |
+
'resnet10': [1, 1, 1, 1],
|
| 81 |
+
'resnet18': [2, 2, 2, 2],
|
| 82 |
+
}
|
| 83 |
+
channel_string_to_n_channel = {
|
| 84 |
+
'32channel': 32,
|
| 85 |
+
'64channel': 64,
|
| 86 |
+
}
|
| 87 |
+
layer_list = layer_string_to_layer_list[n_layer_string]
|
| 88 |
+
start_num_channel = channel_string_to_n_channel[n_channel_string]
|
| 89 |
+
return ResNet84(BasicBlock, layer_list, start_num_channel=start_num_channel).to(self.device)
|
| 90 |
+
|
| 91 |
+
def expand_first_layer(self):
|
| 92 |
+
# convolutional channel expansion to deal with input mismatch
|
| 93 |
+
multiplier = self.n_images
|
| 94 |
+
self.model.conv1.weight.data = self.model.conv1.weight.data.repeat(1,multiplier,1,1) / multiplier
|
| 95 |
+
means = (0.485, 0.456, 0.406) * multiplier
|
| 96 |
+
stds = (0.229, 0.224, 0.225) * multiplier
|
| 97 |
+
self.normalize_op = transforms.Normalize(means, stds)
|
| 98 |
+
self.channel_mismatch = False
|
| 99 |
+
|
| 100 |
+
def freeze_bn(self):
|
| 101 |
+
# freeze batch norm layers (VRL3 ablation shows modifying how
|
| 102 |
+
# batch norm is trained does not affect performance)
|
| 103 |
+
for module in self.model.modules():
|
| 104 |
+
if isinstance(module, nn.BatchNorm2d):
|
| 105 |
+
if hasattr(module, 'weight'):
|
| 106 |
+
module.weight.requires_grad_(False)
|
| 107 |
+
if hasattr(module, 'bias'):
|
| 108 |
+
module.bias.requires_grad_(False)
|
| 109 |
+
module.eval()
|
| 110 |
+
|
| 111 |
+
def get_parameters_that_require_grad(self):
|
| 112 |
+
params = []
|
| 113 |
+
for name, param in self.named_parameters():
|
| 114 |
+
if param.requires_grad == True:
|
| 115 |
+
params.append(param)
|
| 116 |
+
return params
|
| 117 |
+
|
| 118 |
+
def transform_obs_tensor_batch(self, obs):
|
| 119 |
+
# transform obs batch before put into the pretrained resnet
|
| 120 |
+
new_obs = self.normalize_op(obs.float()/255)
|
| 121 |
+
return new_obs
|
| 122 |
+
|
| 123 |
+
def _forward_impl(self, x):
|
| 124 |
+
x = self.model.get_features(x)
|
| 125 |
+
return x
|
| 126 |
+
|
| 127 |
+
def forward(self, obs):
|
| 128 |
+
o = self.transform_obs_tensor_batch(obs)
|
| 129 |
+
h = self._forward_impl(o)
|
| 130 |
+
return h
|
| 131 |
+
|
| 132 |
+
class Stage3ShallowEncoder(nn.Module):
|
| 133 |
+
def __init__(self, obs_shape, n_channel):
|
| 134 |
+
super().__init__()
|
| 135 |
+
|
| 136 |
+
assert len(obs_shape) == 3
|
| 137 |
+
self.repr_dim = n_channel * 35 * 35
|
| 138 |
+
|
| 139 |
+
self.n_input_channel = obs_shape[0]
|
| 140 |
+
self.conv1 = nn.Conv2d(obs_shape[0], n_channel, 3, stride=2)
|
| 141 |
+
self.conv2 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
|
| 142 |
+
self.conv3 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
|
| 143 |
+
self.conv4 = nn.Conv2d(n_channel, n_channel, 3, stride=1)
|
| 144 |
+
self.relu = nn.ReLU(inplace=True)
|
| 145 |
+
|
| 146 |
+
# TODO here add prediction head so we can do contrastive learning...
|
| 147 |
+
|
| 148 |
+
self.apply(utils.weight_init)
|
| 149 |
+
self.normalize_op = transforms.Normalize((0.485, 0.456, 0.406, 0.485, 0.456, 0.406, 0.485, 0.456, 0.406),
|
| 150 |
+
(0.229, 0.224, 0.225, 0.229, 0.224, 0.225, 0.229, 0.224, 0.225))
|
| 151 |
+
|
| 152 |
+
self.compress = nn.Sequential(nn.Linear(self.repr_dim, 50), nn.LayerNorm(50), nn.Tanh())
|
| 153 |
+
self.pred_layer = nn.Linear(50, 50, bias=False)
|
| 154 |
+
|
| 155 |
+
def transform_obs_tensor_batch(self, obs):
|
| 156 |
+
# transform obs batch before put into the pretrained resnet
|
| 157 |
+
# correct order might be first augment, then resize, then normalize
|
| 158 |
+
# obs = F.interpolate(obs, size=self.pretrained_model_input_size)
|
| 159 |
+
new_obs = obs / 255.0 - 0.5
|
| 160 |
+
# new_obs = self.normalize_op(new_obs)
|
| 161 |
+
return new_obs
|
| 162 |
+
|
| 163 |
+
def _forward_impl(self, x):
|
| 164 |
+
x = self.relu(self.conv1(x))
|
| 165 |
+
x = self.relu(self.conv2(x))
|
| 166 |
+
x = self.relu(self.conv3(x))
|
| 167 |
+
x = self.relu(self.conv4(x))
|
| 168 |
+
return x
|
| 169 |
+
|
| 170 |
+
def forward(self, obs):
|
| 171 |
+
o = self.transform_obs_tensor_batch(obs)
|
| 172 |
+
h = self._forward_impl(o)
|
| 173 |
+
h = h.view(h.shape[0], -1)
|
| 174 |
+
return h
|
| 175 |
+
|
| 176 |
+
def get_anchor_output(self, obs, actions=None):
|
| 177 |
+
# typically go through conv and then compression layer and then a mlp
|
| 178 |
+
# used for UL update
|
| 179 |
+
conv_out = self.forward(obs)
|
| 180 |
+
compressed = self.compress(conv_out)
|
| 181 |
+
pred = self.pred_layer(compressed)
|
| 182 |
+
return pred, conv_out
|
| 183 |
+
|
| 184 |
+
def get_positive_output(self, obs):
|
| 185 |
+
# typically go through conv, compression
|
| 186 |
+
# used for UL update
|
| 187 |
+
conv_out = self.forward(obs)
|
| 188 |
+
compressed = self.compress(conv_out)
|
| 189 |
+
return compressed
|
| 190 |
+
|
| 191 |
+
class Encoder(nn.Module):
|
| 192 |
+
def __init__(self, obs_shape, n_channel):
|
| 193 |
+
super().__init__()
|
| 194 |
+
|
| 195 |
+
assert len(obs_shape) == 3
|
| 196 |
+
self.repr_dim = n_channel * 35 * 35
|
| 197 |
+
|
| 198 |
+
self.convnet = nn.Sequential(nn.Conv2d(obs_shape[0], n_channel, 3, stride=2),
|
| 199 |
+
nn.ReLU(), nn.Conv2d(n_channel, n_channel, 3, stride=1),
|
| 200 |
+
nn.ReLU(), nn.Conv2d(n_channel, n_channel, 3, stride=1),
|
| 201 |
+
nn.ReLU(), nn.Conv2d(n_channel, n_channel, 3, stride=1),
|
| 202 |
+
nn.ReLU())
|
| 203 |
+
|
| 204 |
+
self.apply(utils.weight_init)
|
| 205 |
+
|
| 206 |
+
def forward(self, obs):
|
| 207 |
+
obs = obs / 255.0 - 0.5
|
| 208 |
+
h = self.convnet(obs)
|
| 209 |
+
h = h.view(h.shape[0], -1)
|
| 210 |
+
return h
|
| 211 |
+
|
| 212 |
+
class IdentityEncoder(nn.Module):
|
| 213 |
+
def __init__(self, obs_shape):
|
| 214 |
+
super().__init__()
|
| 215 |
+
|
| 216 |
+
assert len(obs_shape) == 1
|
| 217 |
+
self.repr_dim = obs_shape[0]
|
| 218 |
+
|
| 219 |
+
def forward(self, obs):
|
| 220 |
+
return obs
|
| 221 |
+
|
| 222 |
+
class Actor(nn.Module):
|
| 223 |
+
def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
|
| 224 |
+
super().__init__()
|
| 225 |
+
|
| 226 |
+
self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
|
| 227 |
+
nn.LayerNorm(feature_dim), nn.Tanh())
|
| 228 |
+
|
| 229 |
+
self.policy = nn.Sequential(nn.Linear(feature_dim, hidden_dim),
|
| 230 |
+
nn.ReLU(inplace=True),
|
| 231 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 232 |
+
nn.ReLU(inplace=True),
|
| 233 |
+
nn.Linear(hidden_dim, action_shape[0]))
|
| 234 |
+
|
| 235 |
+
self.action_shift=0
|
| 236 |
+
self.action_scale=1
|
| 237 |
+
self.apply(utils.weight_init)
|
| 238 |
+
|
| 239 |
+
def forward(self, obs, std):
|
| 240 |
+
h = self.trunk(obs)
|
| 241 |
+
|
| 242 |
+
mu = self.policy(h)
|
| 243 |
+
mu = torch.tanh(mu)
|
| 244 |
+
mu = mu * self.action_scale + self.action_shift
|
| 245 |
+
std = torch.ones_like(mu) * std
|
| 246 |
+
|
| 247 |
+
dist = utils.TruncatedNormal(mu, std)
|
| 248 |
+
return dist
|
| 249 |
+
|
| 250 |
+
def forward_with_pretanh(self, obs, std):
|
| 251 |
+
h = self.trunk(obs)
|
| 252 |
+
|
| 253 |
+
mu = self.policy(h)
|
| 254 |
+
pretanh = mu
|
| 255 |
+
mu = torch.tanh(mu)
|
| 256 |
+
mu = mu * self.action_scale + self.action_shift
|
| 257 |
+
std = torch.ones_like(mu) * std
|
| 258 |
+
|
| 259 |
+
dist = utils.TruncatedNormal(mu, std)
|
| 260 |
+
return dist, pretanh
|
| 261 |
+
|
| 262 |
+
class Critic(nn.Module):
|
| 263 |
+
def __init__(self, repr_dim, action_shape, feature_dim, hidden_dim):
|
| 264 |
+
super().__init__()
|
| 265 |
+
|
| 266 |
+
self.trunk = nn.Sequential(nn.Linear(repr_dim, feature_dim),
|
| 267 |
+
nn.LayerNorm(feature_dim), nn.Tanh())
|
| 268 |
+
|
| 269 |
+
self.Q1 = nn.Sequential(
|
| 270 |
+
nn.Linear(feature_dim + action_shape[0], hidden_dim),
|
| 271 |
+
nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
|
| 272 |
+
nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1))
|
| 273 |
+
|
| 274 |
+
self.Q2 = nn.Sequential(
|
| 275 |
+
nn.Linear(feature_dim + action_shape[0], hidden_dim),
|
| 276 |
+
nn.ReLU(inplace=True), nn.Linear(hidden_dim, hidden_dim),
|
| 277 |
+
nn.ReLU(inplace=True), nn.Linear(hidden_dim, 1))
|
| 278 |
+
|
| 279 |
+
self.apply(utils.weight_init)
|
| 280 |
+
|
| 281 |
+
def forward(self, obs, action):
|
| 282 |
+
h = self.trunk(obs)
|
| 283 |
+
h_action = torch.cat([h, action], dim=-1)
|
| 284 |
+
q1 = self.Q1(h_action)
|
| 285 |
+
q2 = self.Q2(h_action)
|
| 286 |
+
|
| 287 |
+
return q1, q2
|
| 288 |
+
|
| 289 |
+
class VRL3Agent:
|
| 290 |
+
def __init__(self, obs_shape, action_shape, device, use_sensor, lr, feature_dim,
|
| 291 |
+
hidden_dim, critic_target_tau, num_expl_steps,
|
| 292 |
+
update_every_steps, stddev_clip, use_tb, use_data_aug, encoder_lr_scale,
|
| 293 |
+
stage1_model_name, safe_q_target_factor, safe_q_threshold, pretanh_penalty, pretanh_threshold,
|
| 294 |
+
stage2_update_encoder, cql_weight, cql_temp, cql_n_random, stage2_std, stage2_bc_weight,
|
| 295 |
+
stage3_update_encoder, std0, std1, std_n_decay,
|
| 296 |
+
stage3_bc_lam0, stage3_bc_lam1):
|
| 297 |
+
self.device = device
|
| 298 |
+
self.critic_target_tau = critic_target_tau
|
| 299 |
+
self.update_every_steps = update_every_steps
|
| 300 |
+
self.use_tb = use_tb
|
| 301 |
+
self.num_expl_steps = num_expl_steps
|
| 302 |
+
|
| 303 |
+
self.stage2_std = stage2_std
|
| 304 |
+
self.stage2_update_encoder = stage2_update_encoder
|
| 305 |
+
|
| 306 |
+
if std1 > std0:
|
| 307 |
+
std1 = std0
|
| 308 |
+
self.stddev_schedule = "linear(%s,%s,%s)" % (str(std0), str(std1), str(std_n_decay))
|
| 309 |
+
|
| 310 |
+
self.stddev_clip = stddev_clip
|
| 311 |
+
self.use_data_aug = use_data_aug
|
| 312 |
+
self.safe_q_target_factor = safe_q_target_factor
|
| 313 |
+
self.q_threshold = safe_q_threshold
|
| 314 |
+
self.pretanh_penalty = pretanh_penalty
|
| 315 |
+
|
| 316 |
+
self.cql_temp = cql_temp
|
| 317 |
+
self.cql_weight = cql_weight
|
| 318 |
+
self.cql_n_random = cql_n_random
|
| 319 |
+
|
| 320 |
+
self.pretanh_threshold = pretanh_threshold
|
| 321 |
+
|
| 322 |
+
self.stage2_bc_weight = stage2_bc_weight
|
| 323 |
+
self.stage3_bc_lam0 = stage3_bc_lam0
|
| 324 |
+
self.stage3_bc_lam1 = stage3_bc_lam1
|
| 325 |
+
|
| 326 |
+
if stage3_update_encoder and encoder_lr_scale > 0 and len(obs_shape) > 1:
|
| 327 |
+
self.stage3_update_encoder = True
|
| 328 |
+
else:
|
| 329 |
+
self.stage3_update_encoder = False
|
| 330 |
+
|
| 331 |
+
self.encoder = RLEncoder(obs_shape, stage1_model_name, device).to(device)
|
| 332 |
+
|
| 333 |
+
self.act_dim = action_shape[0]
|
| 334 |
+
|
| 335 |
+
if use_sensor:
|
| 336 |
+
downstream_input_dim = self.encoder.repr_dim + 24
|
| 337 |
+
else:
|
| 338 |
+
downstream_input_dim = self.encoder.repr_dim
|
| 339 |
+
|
| 340 |
+
self.actor = Actor(downstream_input_dim, action_shape, feature_dim,
|
| 341 |
+
hidden_dim).to(device)
|
| 342 |
+
self.critic = Critic(downstream_input_dim, action_shape, feature_dim,
|
| 343 |
+
hidden_dim).to(device)
|
| 344 |
+
self.critic_target = Critic(downstream_input_dim, action_shape,
|
| 345 |
+
feature_dim, hidden_dim).to(device)
|
| 346 |
+
self.critic_target.load_state_dict(self.critic.state_dict())
|
| 347 |
+
|
| 348 |
+
# optimizers
|
| 349 |
+
self.actor_opt = torch.optim.Adam(self.actor.parameters(), lr=lr)
|
| 350 |
+
self.critic_opt = torch.optim.Adam(self.critic.parameters(), lr=lr)
|
| 351 |
+
|
| 352 |
+
encoder_lr = lr * encoder_lr_scale
|
| 353 |
+
""" set up encoder optimizer """
|
| 354 |
+
self.encoder_opt = torch.optim.Adam(self.encoder.parameters(), lr=encoder_lr)
|
| 355 |
+
# data augmentation
|
| 356 |
+
self.aug = RandomShiftsAug(pad=4)
|
| 357 |
+
self.train()
|
| 358 |
+
self.critic_target.train()
|
| 359 |
+
|
| 360 |
+
def load_pretrained_encoder(self, model_path, verbose=True):
|
| 361 |
+
if verbose:
|
| 362 |
+
print("Trying to load pretrained model from:", model_path)
|
| 363 |
+
checkpoint = torch.load(model_path, map_location=torch.device(self.device))
|
| 364 |
+
state_dict = checkpoint['state_dict']
|
| 365 |
+
|
| 366 |
+
pretrained_dict = {}
|
| 367 |
+
# remove `module.` if model was pretrained with distributed mode
|
| 368 |
+
for k, v in state_dict.items():
|
| 369 |
+
if 'module.' in k:
|
| 370 |
+
name = k[7:]
|
| 371 |
+
else:
|
| 372 |
+
name = k
|
| 373 |
+
pretrained_dict[name] = v
|
| 374 |
+
self.encoder.model.load_state_dict(pretrained_dict, strict=False)
|
| 375 |
+
if verbose:
|
| 376 |
+
print("Pretrained model loaded!")
|
| 377 |
+
|
| 378 |
+
def switch_to_RL_stages(self, verbose=True):
|
| 379 |
+
# run convolutional channel expansion to match input shape
|
| 380 |
+
self.encoder.expand_first_layer()
|
| 381 |
+
if verbose:
|
| 382 |
+
print("Convolutional channel expansion finished: now can take in %d images as input." % self.encoder.n_images)
|
| 383 |
+
|
| 384 |
+
def train(self, training=True):
|
| 385 |
+
self.training = training
|
| 386 |
+
self.encoder.train(training)
|
| 387 |
+
self.actor.train(training)
|
| 388 |
+
self.critic.train(training)
|
| 389 |
+
|
| 390 |
+
def act(self, obs, step, eval_mode, obs_sensor=None, is_tensor_input=False, force_action_std=None):
|
| 391 |
+
"""
|
| 392 |
+
obs: 3x84x84, uint8, [0,255]
|
| 393 |
+
"""
|
| 394 |
+
# eval_mode should be False when taking an exploration action in stage 3
|
| 395 |
+
# eval_mode should be True when evaluate agent performance
|
| 396 |
+
|
| 397 |
+
if force_action_std == None:
|
| 398 |
+
stddev = utils.schedule(self.stddev_schedule, step)
|
| 399 |
+
if step < self.num_expl_steps and not eval_mode:
|
| 400 |
+
action = np.random.uniform(0, 1, (self.act_dim,)).astype(np.float32)
|
| 401 |
+
return action
|
| 402 |
+
else:
|
| 403 |
+
stddev = force_action_std
|
| 404 |
+
|
| 405 |
+
if is_tensor_input:
|
| 406 |
+
obs = self.encoder(obs)
|
| 407 |
+
else:
|
| 408 |
+
obs = torch.as_tensor(obs, device=self.device)
|
| 409 |
+
obs = self.encoder(obs.unsqueeze(0))
|
| 410 |
+
|
| 411 |
+
if obs_sensor is not None:
|
| 412 |
+
obs_sensor = torch.as_tensor(obs_sensor, device=self.device)
|
| 413 |
+
obs_sensor = obs_sensor.unsqueeze(0)
|
| 414 |
+
obs_combined = torch.cat([obs, obs_sensor], dim=1)
|
| 415 |
+
else:
|
| 416 |
+
obs_combined = obs
|
| 417 |
+
|
| 418 |
+
dist = self.actor(obs_combined, stddev)
|
| 419 |
+
if eval_mode:
|
| 420 |
+
action = dist.mean
|
| 421 |
+
else:
|
| 422 |
+
action = dist.sample(clip=None)
|
| 423 |
+
if step < self.num_expl_steps:
|
| 424 |
+
action.uniform_(-1.0, 1.0)
|
| 425 |
+
return action.cpu().numpy()[0]
|
| 426 |
+
|
| 427 |
+
def update(self, replay_iter, step, stage, use_sensor):
|
| 428 |
+
# for stage 2 and 3, we use the same functions but with different hyperparameters
|
| 429 |
+
assert stage in (2, 3)
|
| 430 |
+
metrics = dict()
|
| 431 |
+
|
| 432 |
+
if stage == 2:
|
| 433 |
+
update_encoder = self.stage2_update_encoder
|
| 434 |
+
stddev = self.stage2_std
|
| 435 |
+
conservative_loss_weight = self.cql_weight
|
| 436 |
+
bc_weight = self.stage2_bc_weight
|
| 437 |
+
|
| 438 |
+
if stage == 3:
|
| 439 |
+
if step % self.update_every_steps != 0:
|
| 440 |
+
return metrics
|
| 441 |
+
update_encoder = self.stage3_update_encoder
|
| 442 |
+
|
| 443 |
+
stddev = utils.schedule(self.stddev_schedule, step)
|
| 444 |
+
conservative_loss_weight = 0
|
| 445 |
+
|
| 446 |
+
# compute stage 3 BC weight
|
| 447 |
+
bc_data_per_iter = 40000
|
| 448 |
+
i_iter = step // bc_data_per_iter
|
| 449 |
+
bc_weight = self.stage3_bc_lam0 * self.stage3_bc_lam1 ** i_iter
|
| 450 |
+
|
| 451 |
+
# batch data
|
| 452 |
+
batch = next(replay_iter)
|
| 453 |
+
if use_sensor: # TODO might want to...?
|
| 454 |
+
obs, action, reward, discount, next_obs, obs_sensor, obs_sensor_next = utils.to_torch(batch, self.device)
|
| 455 |
+
else:
|
| 456 |
+
obs, action, reward, discount, next_obs = utils.to_torch(batch, self.device)
|
| 457 |
+
obs_sensor, obs_sensor_next = None, None
|
| 458 |
+
|
| 459 |
+
# augment
|
| 460 |
+
if self.use_data_aug:
|
| 461 |
+
obs = self.aug(obs.float())
|
| 462 |
+
next_obs = self.aug(next_obs.float())
|
| 463 |
+
else:
|
| 464 |
+
obs = obs.float()
|
| 465 |
+
next_obs = next_obs.float()
|
| 466 |
+
|
| 467 |
+
# encode
|
| 468 |
+
if update_encoder:
|
| 469 |
+
obs = self.encoder(obs)
|
| 470 |
+
else:
|
| 471 |
+
with torch.no_grad():
|
| 472 |
+
obs = self.encoder(obs)
|
| 473 |
+
|
| 474 |
+
with torch.no_grad():
|
| 475 |
+
next_obs = self.encoder(next_obs)
|
| 476 |
+
|
| 477 |
+
# concatenate obs with additional sensor observation if needed
|
| 478 |
+
obs_combined = torch.cat([obs, obs_sensor], dim=1) if obs_sensor is not None else obs
|
| 479 |
+
obs_next_combined = torch.cat([next_obs, obs_sensor_next], dim=1) if obs_sensor_next is not None else next_obs
|
| 480 |
+
|
| 481 |
+
# update critic
|
| 482 |
+
metrics.update(self.update_critic_vrl3(obs_combined, action, reward, discount, obs_next_combined,
|
| 483 |
+
stddev, update_encoder, conservative_loss_weight))
|
| 484 |
+
|
| 485 |
+
# update actor, following previous works, we do not use actor gradient for encoder update
|
| 486 |
+
metrics.update(self.update_actor_vrl3(obs_combined.detach(), action, stddev, bc_weight,
|
| 487 |
+
self.pretanh_penalty, self.pretanh_threshold))
|
| 488 |
+
|
| 489 |
+
metrics['batch_reward'] = reward.mean().item()
|
| 490 |
+
|
| 491 |
+
# update critic target networks
|
| 492 |
+
utils.soft_update_params(self.critic, self.critic_target, self.critic_target_tau)
|
| 493 |
+
return metrics
|
| 494 |
+
|
| 495 |
+
def update_critic_vrl3(self, obs, action, reward, discount, next_obs, stddev, update_encoder, conservative_loss_weight):
|
| 496 |
+
metrics = dict()
|
| 497 |
+
batch_size = obs.shape[0]
|
| 498 |
+
|
| 499 |
+
"""
|
| 500 |
+
STANDARD Q LOSS COMPUTATION:
|
| 501 |
+
- get standard Q loss first, this is the same as in any other online RL methods
|
| 502 |
+
- except for the safe Q technique, which controls how large the Q value can be
|
| 503 |
+
"""
|
| 504 |
+
with torch.no_grad():
|
| 505 |
+
dist = self.actor(next_obs, stddev)
|
| 506 |
+
next_action = dist.sample(clip=self.stddev_clip)
|
| 507 |
+
target_Q1, target_Q2 = self.critic_target(next_obs, next_action)
|
| 508 |
+
target_V = torch.min(target_Q1, target_Q2)
|
| 509 |
+
target_Q = reward + (discount * target_V)
|
| 510 |
+
|
| 511 |
+
if self.safe_q_target_factor < 1:
|
| 512 |
+
target_Q[target_Q > (self.q_threshold + 1)] = self.q_threshold + (target_Q[target_Q > (self.q_threshold+1)] - self.q_threshold) ** self.safe_q_target_factor
|
| 513 |
+
|
| 514 |
+
Q1, Q2 = self.critic(obs, action)
|
| 515 |
+
critic_loss = F.mse_loss(Q1, target_Q) + F.mse_loss(Q2, target_Q)
|
| 516 |
+
|
| 517 |
+
"""
|
| 518 |
+
CONSERVATIVE Q LOSS COMPUTATION:
|
| 519 |
+
- sample random actions, actions from policy and next actions from policy, as done in CQL authors' code
|
| 520 |
+
(though this detail is not really discussed in the CQL paper)
|
| 521 |
+
- only compute this loss when conservative loss weight > 0
|
| 522 |
+
"""
|
| 523 |
+
if conservative_loss_weight > 0:
|
| 524 |
+
random_actions = (torch.rand((batch_size * self.cql_n_random, self.act_dim), device=self.device) - 0.5) * 2
|
| 525 |
+
|
| 526 |
+
dist = self.actor(obs, stddev)
|
| 527 |
+
current_actions = dist.sample(clip=self.stddev_clip)
|
| 528 |
+
|
| 529 |
+
dist = self.actor(next_obs, stddev)
|
| 530 |
+
next_current_actions = dist.sample(clip=self.stddev_clip)
|
| 531 |
+
|
| 532 |
+
# now get Q values for all these actions (for both Q networks)
|
| 533 |
+
obs_repeat = obs.unsqueeze(1).repeat(1, self.cql_n_random, 1).view(obs.shape[0] * self.cql_n_random,
|
| 534 |
+
obs.shape[1])
|
| 535 |
+
|
| 536 |
+
Q1_rand, Q2_rand = self.critic(obs_repeat,
|
| 537 |
+
random_actions) # TODO might want to double check the logic here see if the repeat is correct
|
| 538 |
+
Q1_rand = Q1_rand.view(obs.shape[0], self.cql_n_random)
|
| 539 |
+
Q2_rand = Q2_rand.view(obs.shape[0], self.cql_n_random)
|
| 540 |
+
|
| 541 |
+
Q1_curr, Q2_curr = self.critic(obs, current_actions)
|
| 542 |
+
Q1_curr_next, Q2_curr_next = self.critic(obs, next_current_actions)
|
| 543 |
+
|
| 544 |
+
# now concat all these Q values together
|
| 545 |
+
Q1_cat = torch.cat([Q1_rand, Q1, Q1_curr, Q1_curr_next], 1)
|
| 546 |
+
Q2_cat = torch.cat([Q2_rand, Q2, Q2_curr, Q2_curr_next], 1)
|
| 547 |
+
|
| 548 |
+
cql_min_q1_loss = torch.logsumexp(Q1_cat / self.cql_temp,
|
| 549 |
+
dim=1, ).mean() * conservative_loss_weight * self.cql_temp
|
| 550 |
+
cql_min_q2_loss = torch.logsumexp(Q2_cat / self.cql_temp,
|
| 551 |
+
dim=1, ).mean() * conservative_loss_weight * self.cql_temp
|
| 552 |
+
|
| 553 |
+
"""Subtract the log likelihood of data"""
|
| 554 |
+
conservative_q_loss = cql_min_q1_loss + cql_min_q2_loss - (Q1.mean() + Q2.mean()) * conservative_loss_weight
|
| 555 |
+
critic_loss_combined = critic_loss + conservative_q_loss
|
| 556 |
+
else:
|
| 557 |
+
critic_loss_combined = critic_loss
|
| 558 |
+
|
| 559 |
+
# logging
|
| 560 |
+
metrics['critic_target_q'] = target_Q.mean().item()
|
| 561 |
+
metrics['critic_q1'] = Q1.mean().item()
|
| 562 |
+
metrics['critic_q2'] = Q2.mean().item()
|
| 563 |
+
metrics['critic_loss'] = critic_loss.item()
|
| 564 |
+
|
| 565 |
+
# if needed, also update encoder with critic loss
|
| 566 |
+
if update_encoder:
|
| 567 |
+
self.encoder_opt.zero_grad(set_to_none=True)
|
| 568 |
+
self.critic_opt.zero_grad(set_to_none=True)
|
| 569 |
+
critic_loss_combined.backward()
|
| 570 |
+
self.critic_opt.step()
|
| 571 |
+
if update_encoder:
|
| 572 |
+
self.encoder_opt.step()
|
| 573 |
+
|
| 574 |
+
return metrics
|
| 575 |
+
|
| 576 |
+
def update_actor_vrl3(self, obs, action, stddev, bc_weight, pretanh_penalty, pretanh_threshold):
|
| 577 |
+
metrics = dict()
|
| 578 |
+
|
| 579 |
+
"""
|
| 580 |
+
get standard actor loss
|
| 581 |
+
"""
|
| 582 |
+
dist, pretanh = self.actor.forward_with_pretanh(obs, stddev)
|
| 583 |
+
current_action = dist.sample(clip=self.stddev_clip)
|
| 584 |
+
log_prob = dist.log_prob(current_action).sum(-1, keepdim=True)
|
| 585 |
+
Q1, Q2 = self.critic(obs, current_action)
|
| 586 |
+
Q = torch.min(Q1, Q2)
|
| 587 |
+
actor_loss = -Q.mean()
|
| 588 |
+
|
| 589 |
+
"""
|
| 590 |
+
add BC loss
|
| 591 |
+
"""
|
| 592 |
+
if bc_weight > 0:
|
| 593 |
+
# get mean action with no action noise (though this might not be necessary)
|
| 594 |
+
stddev_bc = 0
|
| 595 |
+
dist_bc = self.actor(obs, stddev_bc)
|
| 596 |
+
current_mean_action = dist_bc.sample(clip=self.stddev_clip)
|
| 597 |
+
actor_loss_bc = F.mse_loss(current_mean_action, action) * bc_weight
|
| 598 |
+
else:
|
| 599 |
+
actor_loss_bc = torch.FloatTensor([0]).to(self.device)
|
| 600 |
+
|
| 601 |
+
"""
|
| 602 |
+
add pretanh penalty (might not be necessary for Adroit)
|
| 603 |
+
"""
|
| 604 |
+
pretanh_loss = 0
|
| 605 |
+
if pretanh_penalty > 0:
|
| 606 |
+
pretanh_loss = pretanh.abs() - pretanh_threshold
|
| 607 |
+
pretanh_loss[pretanh_loss < 0] = 0
|
| 608 |
+
pretanh_loss = (pretanh_loss ** 2).mean() * pretanh_penalty
|
| 609 |
+
|
| 610 |
+
"""
|
| 611 |
+
combine actor losses and optimize
|
| 612 |
+
"""
|
| 613 |
+
actor_loss_combined = actor_loss + actor_loss_bc + pretanh_loss
|
| 614 |
+
|
| 615 |
+
self.actor_opt.zero_grad(set_to_none=True)
|
| 616 |
+
actor_loss_combined.backward()
|
| 617 |
+
self.actor_opt.step()
|
| 618 |
+
|
| 619 |
+
metrics['actor_loss'] = actor_loss.item()
|
| 620 |
+
metrics['actor_loss_bc'] = actor_loss_bc.item()
|
| 621 |
+
metrics['actor_logprob'] = log_prob.mean().item()
|
| 622 |
+
metrics['actor_ent'] = dist.entropy().sum(dim=-1).mean().item()
|
| 623 |
+
metrics['abs_pretanh'] = pretanh.abs().mean().item()
|
| 624 |
+
metrics['max_abs_pretanh'] = pretanh.abs().max().item()
|
| 625 |
+
|
| 626 |
+
return metrics
|
| 627 |
+
|
| 628 |
+
def to(self, device):
|
| 629 |
+
self.actor.to(device)
|
| 630 |
+
self.critic.to(device)
|
| 631 |
+
self.encoder.to(device)
|
| 632 |
+
self.device = device
|
gym-0.21.0/.github/stale.yml
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Configuration for probot-stale - https://github.com/probot/stale
|
| 2 |
+
|
| 3 |
+
# Number of days of inactivity before an Issue or Pull Request becomes stale
|
| 4 |
+
daysUntilStale: 60
|
| 5 |
+
|
| 6 |
+
# Number of days of inactivity before an Issue or Pull Request with the stale label is closed.
|
| 7 |
+
# Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale.
|
| 8 |
+
daysUntilClose: 14
|
| 9 |
+
|
| 10 |
+
# Only issues or pull requests with all of these labels are check if stale. Defaults to `[]` (disabled)
|
| 11 |
+
onlyLabels:
|
| 12 |
+
- more-information-needed
|
| 13 |
+
|
| 14 |
+
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
|
| 15 |
+
exemptLabels:
|
| 16 |
+
- pinned
|
| 17 |
+
- security
|
| 18 |
+
- "[Status] Maybe Later"
|
| 19 |
+
|
| 20 |
+
# Set to true to ignore issues in a project (defaults to false)
|
| 21 |
+
exemptProjects: true
|
| 22 |
+
|
| 23 |
+
# Set to true to ignore issues in a milestone (defaults to false)
|
| 24 |
+
exemptMilestones: true
|
| 25 |
+
|
| 26 |
+
# Set to true to ignore issues with an assignee (defaults to false)
|
| 27 |
+
exemptAssignees: true
|
| 28 |
+
|
| 29 |
+
# Label to use when marking as stale
|
| 30 |
+
staleLabel: stale
|
| 31 |
+
|
| 32 |
+
# Comment to post when marking as stale. Set to `false` to disable
|
| 33 |
+
markComment: >
|
| 34 |
+
This issue has been automatically marked as stale because it has not had
|
| 35 |
+
recent activity. It will be closed if no further activity occurs. Thank you
|
| 36 |
+
for your contributions.
|
| 37 |
+
|
| 38 |
+
# Comment to post when removing the stale label.
|
| 39 |
+
# unmarkComment: >
|
| 40 |
+
# Your comment here.
|
| 41 |
+
|
| 42 |
+
# Comment to post when closing a stale Issue or Pull Request.
|
| 43 |
+
# closeComment: >
|
| 44 |
+
# Your comment here.
|
| 45 |
+
|
| 46 |
+
# Limit the number of actions per hour, from 1-30. Default is 30
|
| 47 |
+
limitPerRun: 30
|
| 48 |
+
|
| 49 |
+
# Limit to only `issues` or `pulls`
|
| 50 |
+
only: issues
|
| 51 |
+
|
| 52 |
+
# Optionally, specify configuration settings that are specific to just 'issues' or 'pulls':
|
| 53 |
+
# pulls:
|
| 54 |
+
# daysUntilStale: 30
|
| 55 |
+
# markComment: >
|
| 56 |
+
# This pull request has been automatically marked as stale because it has not had
|
| 57 |
+
# recent activity. It will be closed if no further activity occurs. Thank you
|
| 58 |
+
# for your contributions.
|
| 59 |
+
|
| 60 |
+
# issues:
|
| 61 |
+
# exemptLabels:
|
| 62 |
+
# - confirmed
|
gym-0.21.0/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Gym Contribution Guidelines
|
| 2 |
+
|
| 3 |
+
At this time we are currently accepting the current forms of contributions:
|
| 4 |
+
|
| 5 |
+
- Bug reports (keep in mind that changing environment behavior should be minimized as that requires releasing a new version of the environment and makes results hard to compare across versions)
|
| 6 |
+
- Pull requests for bug fixes
|
| 7 |
+
- Documentation improvements
|
| 8 |
+
|
| 9 |
+
Notably, we are not accepting these forms of contributions:
|
| 10 |
+
|
| 11 |
+
- New environments
|
| 12 |
+
- New features
|
| 13 |
+
|
| 14 |
+
This may change in the future.
|
| 15 |
+
If you wish to make a Gym environment, follow the instructions in [Creating Environments](https://github.com/openai/gym/blob/master/docs/creating-environments.md). When your environment works, you can make a PR to add it to the bottom of the [List of Environments](https://github.com/openai/gym/blob/master/docs/environments.md).
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
Edit July 27, 2021: Please see https://github.com/openai/gym/issues/2259 for new contributing standards
|
gym-0.21.0/README.md
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Gym
|
| 2 |
+
|
| 3 |
+
Gym is an open source Python library for developing and comparing reinforcement learning algorithms by providing a standard API to communicate between learning algorithms and environments, as well as a standard set of environments compliant with that API. Since its release, Gym's API has become the field standard for doing this.
|
| 4 |
+
|
| 5 |
+
Gym currently has two pieces of documentation: the [documentation website](http://gym.openai.com) and the [FAQ](https://github.com/openai/gym/wiki/FAQ). A new and more comprehensive documentation website is in the works.
|
| 6 |
+
|
| 7 |
+
## Installation
|
| 8 |
+
|
| 9 |
+
To install the base Gym library, use `pip install gym`.
|
| 10 |
+
|
| 11 |
+
This does not include dependencies for all families of environments (there's a massive number, and some can be problematic to install on certain systems). You can install these dependencies for one family like `pip install gym[atari]` or use `pip install gym[all]` to install all dependencies.
|
| 12 |
+
|
| 13 |
+
We support Python 3.6, 3.7, 3.8 and 3.9 on Linux and macOS. We will accept PRs related to Windows, but do not officially support it.
|
| 14 |
+
|
| 15 |
+
## API
|
| 16 |
+
|
| 17 |
+
The Gym API's API models environments as simple Python `env` classes. Creating environment instances and interacting with them is very simple- here's an example using the "CartPole-v1" environment:
|
| 18 |
+
|
| 19 |
+
```python
|
| 20 |
+
import gym
|
| 21 |
+
env = gym.make('CartPole-v1')
|
| 22 |
+
|
| 23 |
+
# env is created, now we can use it:
|
| 24 |
+
for episode in range(10):
|
| 25 |
+
obs = env.reset()
|
| 26 |
+
for step in range(50):
|
| 27 |
+
action = env.action_space.sample() # or given a custom model, action = policy(observation)
|
| 28 |
+
nobs, reward, done, info = env.step(action)
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## Notable Related Libraries
|
| 32 |
+
|
| 33 |
+
* [Stable Baselines 3](https://github.com/DLR-RM/stable-baselines3) is a learning library based on the Gym API. It is our recommendation for beginners who want to start learning things quickly.
|
| 34 |
+
* [RL Baselines3 Zoo](https://github.com/DLR-RM/rl-baselines3-zoo) builds upon SB3, containing optimal hyperparameters for Gym environments as well as code to easily find new ones. Such tuning is almost always required.
|
| 35 |
+
* The [Autonomous Learning Library](https://github.com/cpnota/autonomous-learning-library) and [Tianshou](https://github.com/thu-ml/tianshou) are two reinforcement learning libraries I like that are generally geared towards more experienced users.
|
| 36 |
+
* [PettingZoo](https://github.com/PettingZoo-Team/PettingZoo) is like Gym, but for environments with multiple agents.
|
| 37 |
+
|
| 38 |
+
## Environment Versioning
|
| 39 |
+
|
| 40 |
+
Gym keeps strict versioning for reproducibility reasons. All environments end in a suffix like "\_v0". When changes are made to environments that might impact learning results, the number is increased by one to prevent potential confusion.
|
| 41 |
+
|
| 42 |
+
## Citation
|
| 43 |
+
|
| 44 |
+
A whitepaper from when OpenAI Gym just came out is available https://arxiv.org/pdf/1606.01540, and can be cited with the following bibtex entry:
|
| 45 |
+
|
| 46 |
+
```
|
| 47 |
+
@misc{1606.01540,
|
| 48 |
+
Author = {Greg Brockman and Vicki Cheung and Ludwig Pettersson and Jonas Schneider and John Schulman and Jie Tang and Wojciech Zaremba},
|
| 49 |
+
Title = {OpenAI Gym},
|
| 50 |
+
Year = {2016},
|
| 51 |
+
Eprint = {arXiv:1606.01540},
|
| 52 |
+
}
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
## Release Notes
|
| 56 |
+
|
| 57 |
+
There used to be release notes for all the new Gym versions here. New release notes are being moved to [releases page](https://github.com/openai/gym/releases) on GitHub, like most other libraries do. Old notes can be viewed [here](https://github.com/openai/gym/blob/31be35ecd460f670f0c4b653a14c9996b7facc6c/README.rst).
|
gym-0.21.0/docs/toy_text/blackjack.md
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Blackjack
|
| 2 |
+
---
|
| 3 |
+
|Title|Action Type|Action Shape|Action Values|Observation Shape|Observation Values|Average Total Reward|Import|
|
| 4 |
+
| ----------- | -----------| ----------- | -----------| ----------- | -----------| ----------- | -----------|
|
| 5 |
+
|Blackjack|Discrete|(1,)|(0,1)|(3,)|[(0,31),(0,10),(0,1)]| |from gym.envs.toy_text import blackjack|
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
Blackjack is a card game where the goal is to obtain cards that sum to as near as possible to 21 without going over. They're playing against a fixed dealer.
|
| 9 |
+
|
| 10 |
+
Card Values:
|
| 11 |
+
|
| 12 |
+
- Face cards (Jack, Queen, King) have point value 10.
|
| 13 |
+
- Aces can either count as 11 or 1, and it's called 'usable ace' at 11.
|
| 14 |
+
- Numerical cards (2-9) have value of their number.
|
| 15 |
+
|
| 16 |
+
This game is placed with an infinite deck (or with replacement).
|
| 17 |
+
The game starts with dealer having one face up and one face down card, while player having two face up cards.
|
| 18 |
+
|
| 19 |
+
The player can request additional cards (hit, action=1) until they decide to stop
|
| 20 |
+
(stick, action=0) or exceed 21 (bust).
|
| 21 |
+
After the player sticks, the dealer reveals their facedown card, and draws
|
| 22 |
+
until their sum is 17 or greater. If the dealer goes bust the player wins.
|
| 23 |
+
If neither player nor dealer busts, the outcome (win, lose, draw) is
|
| 24 |
+
decided by whose sum is closer to 21.
|
| 25 |
+
|
| 26 |
+
The agent take a 1-element vector for actions.
|
| 27 |
+
The action space is `(action)`, where:
|
| 28 |
+
- `action` is used to decide stick/hit for values (0,1).
|
| 29 |
+
|
| 30 |
+
The observation of a 3-tuple of: the players current sum,
|
| 31 |
+
the dealer's one showing card (1-10 where 1 is ace), and whether or not the player holds a usable ace (0 or 1).
|
| 32 |
+
|
| 33 |
+
This environment corresponds to the version of the blackjack problem
|
| 34 |
+
described in Example 5.1 in Reinforcement Learning: An Introduction
|
| 35 |
+
by Sutton and Barto.
|
| 36 |
+
http://incompleteideas.net/book/the-book-2nd.html
|
| 37 |
+
|
| 38 |
+
**Rewards:**
|
| 39 |
+
|
| 40 |
+
Reward schedule:
|
| 41 |
+
- win game: +1
|
| 42 |
+
- lose game: -1
|
| 43 |
+
- draw game: 0
|
| 44 |
+
- win game with natural blackjack:
|
| 45 |
+
|
| 46 |
+
+1.5 (if <a href="#nat">natural</a> is True.)
|
| 47 |
+
|
| 48 |
+
+1 (if <a href="#nat">natural</a> is False.)
|
| 49 |
+
|
| 50 |
+
### Arguments
|
| 51 |
+
|
| 52 |
+
```
|
| 53 |
+
gym.make('Blackjack-v0', natural=False)
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
<a id="nat">`natural`</a>: Whether to give an additional reward for starting with a natural blackjack, i.e. starting with an ace and ten (sum is 21).
|
| 57 |
+
|
| 58 |
+
### Version History
|
| 59 |
+
|
| 60 |
+
* v0: Initial versions release (1.0.0)
|
gym-0.21.0/docs/toy_text/taxi.md
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Taxi
|
| 2 |
+
---
|
| 3 |
+
|Title|Action Type|Action Shape|Action Values|Observation Shape|Observation Values|Average Total Reward|Import|
|
| 4 |
+
| ----------- | -----------| ----------- | -----------| ----------- | -----------| ----------- | -----------|
|
| 5 |
+
|Taxi|Discrete|(1,)|(0,5)|(1,)|(0,499)| |from gym.envs.toy_text import taxi|
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
The Taxi Problem
|
| 10 |
+
from "Hierarchical Reinforcement Learning with the MAXQ Value Function Decomposition"
|
| 11 |
+
|
| 12 |
+
by Tom Dietterich
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
Description:
|
| 17 |
+
|
| 18 |
+
There are four designated locations in the grid world indicated by R(ed), G(reen), Y(ellow), and B(lue). When the episode starts, the taxi starts off at a random square and the passenger is at a random location. The taxi drives to the passenger's location, picks up the passenger, drives to the passenger's destination (another one of the four specified locations), and then drops off the passenger. Once the passenger is dropped off, the episode ends.
|
| 19 |
+
|
| 20 |
+
MAP:
|
| 21 |
+
|
| 22 |
+
+---------+
|
| 23 |
+
|R: | : :G|
|
| 24 |
+
| : | : : |
|
| 25 |
+
| : : : : |
|
| 26 |
+
| | : | : |
|
| 27 |
+
|Y| : |B: |
|
| 28 |
+
+---------+
|
| 29 |
+
|
| 30 |
+
Actions:
|
| 31 |
+
|
| 32 |
+
There are 6 discrete deterministic actions:
|
| 33 |
+
- 0: move south
|
| 34 |
+
- 1: move north
|
| 35 |
+
- 2: move east
|
| 36 |
+
- 3: move west
|
| 37 |
+
- 4: pickup passenger
|
| 38 |
+
- 5: drop off passenger
|
| 39 |
+
|
| 40 |
+
Observations:
|
| 41 |
+
|
| 42 |
+
There are 500 discrete states since there are 25 taxi positions, 5 possible locations of the passenger (including the case when the passenger is in the taxi), and 4 destination locations.
|
| 43 |
+
|
| 44 |
+
Note that there are 400 states that can actually be reached during an episode. The missing states correspond to situations in which the passenger is at the same location as their destination, as this typically signals the end of an episode.
|
| 45 |
+
Four additional states can be observed right after a successful episodes, when both the passenger and the taxi are at the destination.
|
| 46 |
+
This gives a total of 404 reachable discrete states.
|
| 47 |
+
|
| 48 |
+
Passenger locations:
|
| 49 |
+
- 0: R(ed)
|
| 50 |
+
- 1: G(reen)
|
| 51 |
+
- 2: Y(ellow)
|
| 52 |
+
- 3: B(lue)
|
| 53 |
+
- 4: in taxi
|
| 54 |
+
|
| 55 |
+
Destinations:
|
| 56 |
+
- 0: R(ed)
|
| 57 |
+
- 1: G(reen)
|
| 58 |
+
- 2: Y(ellow)
|
| 59 |
+
- 3: B(lue)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
**Rewards:**
|
| 64 |
+
|
| 65 |
+
- -1 per step reward unless other reward is triggered.
|
| 66 |
+
- +20 delivering passenger.
|
| 67 |
+
- -10 executing "pickup" and "drop-off" actions illegally.
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
Rendering:
|
| 71 |
+
- blue: passenger
|
| 72 |
+
- magenta: destination
|
| 73 |
+
- yellow: empty taxi
|
| 74 |
+
- green: full taxi
|
| 75 |
+
- other letters (R, G, Y and B): locations for passengers and destinations
|
| 76 |
+
state space is represented by:
|
| 77 |
+
(taxi_row, taxi_col, passenger_location, destination)
|
| 78 |
+
|
| 79 |
+
### Arguments
|
| 80 |
+
|
| 81 |
+
```
|
| 82 |
+
gym.make('Taxi-v3')
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
### Version History
|
| 88 |
+
|
| 89 |
+
* v3: Map Correction + Cleaner Domain Description
|
| 90 |
+
* v2: Disallow Taxi start location = goal location, Update Taxi observations in the rollout, Update Taxi reward threshold.
|
| 91 |
+
* v1: Remove (3,2) from locs, add passidx<4 check
|
| 92 |
+
* v0: Initial versions release
|
gym-0.21.0/scripts/generate_json.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from gym import envs, logger
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
import argparse
|
| 6 |
+
|
| 7 |
+
from tests.envs.spec_list import should_skip_env_spec_for_tests
|
| 8 |
+
from tests import generate_rollout_hash
|
| 9 |
+
|
| 10 |
+
DATA_DIR = os.path.join(os.path.dirname(__file__), os.pardir, "gym", "envs", "tests")
|
| 11 |
+
ROLLOUT_STEPS = 100
|
| 12 |
+
episodes = ROLLOUT_STEPS
|
| 13 |
+
steps = ROLLOUT_STEPS
|
| 14 |
+
|
| 15 |
+
ROLLOUT_FILE = os.path.join(DATA_DIR, "rollout.json")
|
| 16 |
+
|
| 17 |
+
if not os.path.isfile(ROLLOUT_FILE):
|
| 18 |
+
logger.info(
|
| 19 |
+
"No rollout file found. Writing empty json file to {}".format(ROLLOUT_FILE)
|
| 20 |
+
)
|
| 21 |
+
with open(ROLLOUT_FILE, "w") as outfile:
|
| 22 |
+
json.dump({}, outfile, indent=2)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def update_rollout_dict(spec, rollout_dict):
|
| 26 |
+
"""
|
| 27 |
+
Takes as input the environment spec for which the rollout is to be generated,
|
| 28 |
+
and the existing dictionary of rollouts. Returns True iff the dictionary was
|
| 29 |
+
modified.
|
| 30 |
+
"""
|
| 31 |
+
# Skip platform-dependent
|
| 32 |
+
if should_skip_env_spec_for_tests(spec):
|
| 33 |
+
logger.info("Skipping tests for {}".format(spec.id))
|
| 34 |
+
return False
|
| 35 |
+
|
| 36 |
+
# Skip environments that are nondeterministic
|
| 37 |
+
if spec.nondeterministic:
|
| 38 |
+
logger.info("Skipping tests for nondeterministic env {}".format(spec.id))
|
| 39 |
+
return False
|
| 40 |
+
|
| 41 |
+
logger.info("Generating rollout for {}".format(spec.id))
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
(
|
| 45 |
+
observations_hash,
|
| 46 |
+
actions_hash,
|
| 47 |
+
rewards_hash,
|
| 48 |
+
dones_hash,
|
| 49 |
+
) = generate_rollout_hash(spec)
|
| 50 |
+
except:
|
| 51 |
+
# If running the env generates an exception, don't write to the rollout file
|
| 52 |
+
logger.warn(
|
| 53 |
+
"Exception {} thrown while generating rollout for {}. Rollout not added.".format(
|
| 54 |
+
sys.exc_info()[0], spec.id
|
| 55 |
+
)
|
| 56 |
+
)
|
| 57 |
+
return False
|
| 58 |
+
|
| 59 |
+
rollout = {}
|
| 60 |
+
rollout["observations"] = observations_hash
|
| 61 |
+
rollout["actions"] = actions_hash
|
| 62 |
+
rollout["rewards"] = rewards_hash
|
| 63 |
+
rollout["dones"] = dones_hash
|
| 64 |
+
|
| 65 |
+
existing = rollout_dict.get(spec.id)
|
| 66 |
+
if existing:
|
| 67 |
+
differs = False
|
| 68 |
+
for key, new_hash in rollout.items():
|
| 69 |
+
differs = differs or existing[key] != new_hash
|
| 70 |
+
if not differs:
|
| 71 |
+
logger.debug("Hashes match with existing for {}".format(spec.id))
|
| 72 |
+
return False
|
| 73 |
+
else:
|
| 74 |
+
logger.warn("Got new hash for {}. Overwriting.".format(spec.id))
|
| 75 |
+
|
| 76 |
+
rollout_dict[spec.id] = rollout
|
| 77 |
+
return True
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def add_new_rollouts(spec_ids, overwrite):
|
| 81 |
+
environments = [
|
| 82 |
+
spec for spec in envs.registry.all() if spec.entry_point is not None
|
| 83 |
+
]
|
| 84 |
+
if spec_ids:
|
| 85 |
+
environments = [spec for spec in environments if spec.id in spec_ids]
|
| 86 |
+
assert len(environments) == len(spec_ids), "Some specs not found"
|
| 87 |
+
with open(ROLLOUT_FILE) as data_file:
|
| 88 |
+
rollout_dict = json.load(data_file)
|
| 89 |
+
modified = False
|
| 90 |
+
for spec in environments:
|
| 91 |
+
if not overwrite and spec.id in rollout_dict:
|
| 92 |
+
logger.debug("Rollout already exists for {}. Skipping.".format(spec.id))
|
| 93 |
+
else:
|
| 94 |
+
modified = update_rollout_dict(spec, rollout_dict) or modified
|
| 95 |
+
|
| 96 |
+
if modified:
|
| 97 |
+
logger.info("Writing new rollout file to {}".format(ROLLOUT_FILE))
|
| 98 |
+
with open(ROLLOUT_FILE, "w") as outfile:
|
| 99 |
+
json.dump(rollout_dict, outfile, indent=2, sort_keys=True)
|
| 100 |
+
else:
|
| 101 |
+
logger.info("No modifications needed.")
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
if __name__ == "__main__":
|
| 105 |
+
parser = argparse.ArgumentParser()
|
| 106 |
+
parser.add_argument(
|
| 107 |
+
"-f",
|
| 108 |
+
"--force",
|
| 109 |
+
action="store_true",
|
| 110 |
+
help="Overwrite " + "existing rollouts if hashes differ.",
|
| 111 |
+
)
|
| 112 |
+
parser.add_argument("-v", "--verbose", action="store_true")
|
| 113 |
+
parser.add_argument(
|
| 114 |
+
"specs", nargs="*", help="ids of env specs to check (default: all)"
|
| 115 |
+
)
|
| 116 |
+
args = parser.parse_args()
|
| 117 |
+
if args.verbose:
|
| 118 |
+
logger.set_level(logger.INFO)
|
| 119 |
+
add_new_rollouts(args.specs, args.force)
|
gym-0.21.0/setup.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path
|
| 2 |
+
import sys
|
| 3 |
+
import itertools
|
| 4 |
+
|
| 5 |
+
from setuptools import find_packages, setup
|
| 6 |
+
|
| 7 |
+
# Don't import gym module here, since deps may not be installed
|
| 8 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "gym"))
|
| 9 |
+
from version import VERSION
|
| 10 |
+
|
| 11 |
+
# Environment-specific dependencies.
|
| 12 |
+
extras = {
|
| 13 |
+
"atari": ["ale-py~=0.7.1"],
|
| 14 |
+
"accept-rom-license": ["autorom[accept-rom-license]~=0.4.2"],
|
| 15 |
+
"box2d": ["box2d-py==2.3.5", "pyglet>=1.4.0"],
|
| 16 |
+
"classic_control": ["pyglet>=1.4.0"],
|
| 17 |
+
"mujoco": ["mujoco_py>=1.50, <2.0"],
|
| 18 |
+
"robotics": ["mujoco_py>=1.50, <2.0"],
|
| 19 |
+
"toy_text": ["scipy>=1.4.1"],
|
| 20 |
+
"other": ["lz4>=3.1.0", "opencv-python>=3.0"],
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
# Meta dependency groups.
|
| 24 |
+
nomujoco_blacklist = set(["mujoco", "robotics", "accept-rom-license"])
|
| 25 |
+
nomujoco_groups = set(extras.keys()) - nomujoco_blacklist
|
| 26 |
+
|
| 27 |
+
extras["nomujoco"] = list(
|
| 28 |
+
itertools.chain.from_iterable(map(lambda group: extras[group], nomujoco_groups))
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
all_blacklist = set(["accept-rom-license"])
|
| 33 |
+
all_groups = set(extras.keys()) - all_blacklist
|
| 34 |
+
|
| 35 |
+
extras["all"] = list(
|
| 36 |
+
itertools.chain.from_iterable(map(lambda group: extras[group], all_groups))
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
setup(
|
| 40 |
+
name="gym",
|
| 41 |
+
version=VERSION,
|
| 42 |
+
description="Gym: A universal API for reinforcement learning environments.",
|
| 43 |
+
url="https://github.com/openai/gym",
|
| 44 |
+
author="OpenAI",
|
| 45 |
+
author_email="jkterry@umd.edu",
|
| 46 |
+
license="",
|
| 47 |
+
packages=[package for package in find_packages() if package.startswith("gym")],
|
| 48 |
+
zip_safe=False,
|
| 49 |
+
install_requires=[
|
| 50 |
+
"numpy>=1.18.0",
|
| 51 |
+
"cloudpickle>=1.2.0",
|
| 52 |
+
"importlib_metadata>=4.8.1; python_version < '3.8'",
|
| 53 |
+
],
|
| 54 |
+
extras_require=extras,
|
| 55 |
+
package_data={
|
| 56 |
+
"gym": [
|
| 57 |
+
"envs/mujoco/assets/*.xml",
|
| 58 |
+
"envs/classic_control/assets/*.png",
|
| 59 |
+
"envs/robotics/assets/LICENSE.md",
|
| 60 |
+
"envs/robotics/assets/fetch/*.xml",
|
| 61 |
+
"envs/robotics/assets/hand/*.xml",
|
| 62 |
+
"envs/robotics/assets/stls/fetch/*.stl",
|
| 63 |
+
"envs/robotics/assets/stls/hand/*.stl",
|
| 64 |
+
"envs/robotics/assets/textures/*.png",
|
| 65 |
+
]
|
| 66 |
+
},
|
| 67 |
+
tests_require=["pytest", "mock"],
|
| 68 |
+
python_requires=">=3.6",
|
| 69 |
+
classifiers=[
|
| 70 |
+
"Programming Language :: Python :: 3",
|
| 71 |
+
"Programming Language :: Python :: 3.6",
|
| 72 |
+
"Programming Language :: Python :: 3.7",
|
| 73 |
+
"Programming Language :: Python :: 3.8",
|
| 74 |
+
"Programming Language :: Python :: 3.9",
|
| 75 |
+
],
|
| 76 |
+
)
|
mujoco-py-2.1.2.14/.gitignore
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
mujoco-py-**
|
| 2 |
+
mjkey.txt
|
| 3 |
+
mujoco_py/generated/cymj*
|
| 4 |
+
_pyxbld*
|
| 5 |
+
dist
|
| 6 |
+
cache
|
| 7 |
+
.idea/*
|
| 8 |
+
*~
|
| 9 |
+
.*~
|
| 10 |
+
*#*#
|
| 11 |
+
*.o
|
| 12 |
+
*.dat
|
| 13 |
+
*.prof
|
| 14 |
+
*.lprof
|
| 15 |
+
*.local
|
| 16 |
+
.realsync
|
| 17 |
+
.DS_Store
|
| 18 |
+
**/*.egg-info
|
| 19 |
+
.cache
|
| 20 |
+
*.ckpt
|
| 21 |
+
*.log
|
| 22 |
+
.ipynb_checkpoints
|
| 23 |
+
venv/
|
| 24 |
+
.vimrc
|
| 25 |
+
*.settings
|
| 26 |
+
*.svn
|
| 27 |
+
.project
|
| 28 |
+
.pydevproject
|
| 29 |
+
tags
|
| 30 |
+
*sublime-project
|
| 31 |
+
*sublime-workspace
|
| 32 |
+
# Intermediate outputs
|
| 33 |
+
__pycache__
|
| 34 |
+
**/__pycache__
|
| 35 |
+
*.pb.*
|
| 36 |
+
*.pyc
|
| 37 |
+
*.swp
|
| 38 |
+
*.swo
|
| 39 |
+
# generated data
|
| 40 |
+
*.rdb
|
| 41 |
+
*.db
|
| 42 |
+
*.avi
|
| 43 |
+
# mujoco outputs
|
| 44 |
+
MUJOCO_LOG.TXT
|
| 45 |
+
model.txt
|
| 46 |
+
.window_data
|
| 47 |
+
.idea/*.xml
|
| 48 |
+
outputfile
|
| 49 |
+
tmp*
|
| 50 |
+
cymj.c
|
| 51 |
+
**/.git
|
| 52 |
+
.eggs/
|
| 53 |
+
*.so
|
| 54 |
+
.python-version
|
| 55 |
+
/build
|
mujoco-py-2.1.2.14/docs/_static/.gitkeep
ADDED
|
File without changes
|
mujoco-py-2.1.2.14/docs/build/doctrees/reference.doctree
ADDED
|
Binary file (193 kB). View file
|
|
|
mujoco-py-2.1.2.14/mujoco_py.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LICENSE.md
|
| 2 |
+
MANIFEST.in
|
| 3 |
+
README.md
|
| 4 |
+
pyproject.toml
|
| 5 |
+
requirements.dev.txt
|
| 6 |
+
requirements.txt
|
| 7 |
+
setup.py
|
| 8 |
+
mujoco_py/__init__.py
|
| 9 |
+
mujoco_py/builder.py
|
| 10 |
+
mujoco_py/cymj.pyx
|
| 11 |
+
mujoco_py/mjbatchrenderer.pyx
|
| 12 |
+
mujoco_py/mjpid.pyx
|
| 13 |
+
mujoco_py/mjrendercontext.pyx
|
| 14 |
+
mujoco_py/mjrenderpool.py
|
| 15 |
+
mujoco_py/mjsim.pyx
|
| 16 |
+
mujoco_py/mjsimstate.pyx
|
| 17 |
+
mujoco_py/mjviewer.py
|
| 18 |
+
mujoco_py/modder.py
|
| 19 |
+
mujoco_py/opengl_context.pyx
|
| 20 |
+
mujoco_py/utils.py
|
| 21 |
+
mujoco_py/version.py
|
| 22 |
+
mujoco_py.egg-info/PKG-INFO
|
| 23 |
+
mujoco_py.egg-info/SOURCES.txt
|
| 24 |
+
mujoco_py.egg-info/dependency_links.txt
|
| 25 |
+
mujoco_py.egg-info/requires.txt
|
| 26 |
+
mujoco_py.egg-info/top_level.txt
|
| 27 |
+
mujoco_py/generated/__init__.py
|
| 28 |
+
mujoco_py/generated/const.py
|
| 29 |
+
mujoco_py/generated/wrappers.pxi
|
| 30 |
+
mujoco_py/gl/__init__.py
|
| 31 |
+
mujoco_py/gl/dummyshim.c
|
| 32 |
+
mujoco_py/gl/egl.h
|
| 33 |
+
mujoco_py/gl/eglext.h
|
| 34 |
+
mujoco_py/gl/eglplatform.h
|
| 35 |
+
mujoco_py/gl/eglshim.c
|
| 36 |
+
mujoco_py/gl/glshim.h
|
| 37 |
+
mujoco_py/gl/khrplatform.h
|
| 38 |
+
mujoco_py/gl/osmesashim.c
|
| 39 |
+
mujoco_py/pxd/__init__.py
|
| 40 |
+
mujoco_py/pxd/mjdata.pxd
|
| 41 |
+
mujoco_py/pxd/mjmodel.pxd
|
| 42 |
+
mujoco_py/pxd/mjrender.pxd
|
| 43 |
+
mujoco_py/pxd/mjui.pxd
|
| 44 |
+
mujoco_py/pxd/mjvisualize.pxd
|
| 45 |
+
mujoco_py/pxd/mujoco.pxd
|
| 46 |
+
mujoco_py/tests/__init__.py
|
| 47 |
+
mujoco_py/tests/include.xml
|
| 48 |
+
mujoco_py/tests/test.xml
|
| 49 |
+
mujoco_py/tests/test_composite.py
|
| 50 |
+
mujoco_py/tests/test_cymj.py
|
| 51 |
+
mujoco_py/tests/test_examples.py
|
| 52 |
+
mujoco_py/tests/test_gen_wrappers.py
|
| 53 |
+
mujoco_py/tests/test_modder.py
|
| 54 |
+
mujoco_py/tests/test_opengl_context.py
|
| 55 |
+
mujoco_py/tests/test_pid.py
|
| 56 |
+
mujoco_py/tests/test_render_pool.py
|
| 57 |
+
mujoco_py/tests/test_substep.py
|
| 58 |
+
mujoco_py/tests/test_vfs.py
|
| 59 |
+
mujoco_py/tests/test_viewer.py
|
| 60 |
+
mujoco_py/tests/utils.py
|
| 61 |
+
xmls/claw.xml
|
| 62 |
+
xmls/door.xml
|
| 63 |
+
xmls/juggler.xml
|
| 64 |
+
xmls/key.xml
|
| 65 |
+
xmls/shelf.xml
|
| 66 |
+
xmls/slider.xml
|
| 67 |
+
xmls/tosser.xml
|
mujoco-py-2.1.2.14/mujoco_py/__pycache__/builder.cpython-38.pyc
ADDED
|
Binary file (16.9 kB). View file
|
|
|
mujoco-py-2.1.2.14/mujoco_py/__pycache__/mjviewer.cpython-38.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
mujoco-py-2.1.2.14/mujoco_py/builder.py
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import distutils
|
| 2 |
+
import glob
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
import subprocess
|
| 6 |
+
import sys
|
| 7 |
+
from distutils.core import Extension
|
| 8 |
+
from distutils.dist import Distribution
|
| 9 |
+
from distutils.sysconfig import customize_compiler
|
| 10 |
+
from importlib.machinery import ExtensionFileLoader
|
| 11 |
+
from os.path import abspath, dirname, exists, join, getmtime
|
| 12 |
+
from random import choice
|
| 13 |
+
from shutil import move
|
| 14 |
+
from string import ascii_lowercase
|
| 15 |
+
|
| 16 |
+
import fasteners
|
| 17 |
+
import numpy as np
|
| 18 |
+
from Cython.Build import cythonize
|
| 19 |
+
from Cython.Distutils.old_build_ext import old_build_ext as build_ext
|
| 20 |
+
from cffi import FFI
|
| 21 |
+
|
| 22 |
+
from mujoco_py.utils import discover_mujoco
|
| 23 |
+
from mujoco_py.version import get_version
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_nvidia_lib_dir():
|
| 27 |
+
exists_nvidia_smi = subprocess.call("type nvidia-smi", shell=True,
|
| 28 |
+
stdout=subprocess.PIPE, stderr=subprocess.PIPE) == 0
|
| 29 |
+
if not exists_nvidia_smi:
|
| 30 |
+
return None
|
| 31 |
+
docker_path = '/usr/local/nvidia/lib64'
|
| 32 |
+
if exists(docker_path):
|
| 33 |
+
return docker_path
|
| 34 |
+
|
| 35 |
+
nvidia_path = '/usr/lib/nvidia'
|
| 36 |
+
if exists(nvidia_path):
|
| 37 |
+
return nvidia_path
|
| 38 |
+
|
| 39 |
+
paths = glob.glob('/usr/lib/nvidia-[0-9][0-9][0-9]')
|
| 40 |
+
paths = sorted(paths)
|
| 41 |
+
if len(paths) == 0:
|
| 42 |
+
return None
|
| 43 |
+
if len(paths) > 1:
|
| 44 |
+
print("Choosing the latest nvidia driver: %s, among %s" % (paths[-1], str(paths)))
|
| 45 |
+
|
| 46 |
+
return paths[-1]
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def load_cython_ext(mujoco_path):
|
| 50 |
+
"""
|
| 51 |
+
Loads the cymj Cython extension. This is safe to be called from
|
| 52 |
+
multiple processes running on the same machine.
|
| 53 |
+
|
| 54 |
+
Cython only gives us back the raw path, regardless of whether
|
| 55 |
+
it found a cached version or actually compiled. Since we do
|
| 56 |
+
non-idempotent postprocessing of the DLL, be extra careful
|
| 57 |
+
to only do that once and then atomically move to the final
|
| 58 |
+
location.
|
| 59 |
+
"""
|
| 60 |
+
if ('glfw' in sys.modules and
|
| 61 |
+
'mujoco' in abspath(sys.modules["glfw"].__file__)):
|
| 62 |
+
print('''
|
| 63 |
+
WARNING: Existing glfw python module detected!
|
| 64 |
+
|
| 65 |
+
MuJoCo comes with its own version of GLFW, so it's preferable to use that one.
|
| 66 |
+
|
| 67 |
+
The easy solution is to `import mujoco_py` _before_ `import glfw`.
|
| 68 |
+
''')
|
| 69 |
+
|
| 70 |
+
lib_path = os.path.join(mujoco_path, "bin")
|
| 71 |
+
if sys.platform == 'darwin':
|
| 72 |
+
Builder = MacExtensionBuilder
|
| 73 |
+
elif sys.platform == 'linux':
|
| 74 |
+
_ensure_set_env_var("LD_LIBRARY_PATH", lib_path)
|
| 75 |
+
if os.getenv('MUJOCO_PY_FORCE_CPU') is None and get_nvidia_lib_dir() is not None:
|
| 76 |
+
_ensure_set_env_var("LD_LIBRARY_PATH", get_nvidia_lib_dir())
|
| 77 |
+
Builder = LinuxGPUExtensionBuilder
|
| 78 |
+
else:
|
| 79 |
+
Builder = LinuxGPUExtensionBuilder
|
| 80 |
+
elif sys.platform.startswith("win"):
|
| 81 |
+
var = "PATH"
|
| 82 |
+
if var not in os.environ or lib_path not in os.environ[var].split(";"):
|
| 83 |
+
raise Exception("Please add mujoco library to your PATH:\n"
|
| 84 |
+
"set %s=%s;%%%s%%" % (var, lib_path, var))
|
| 85 |
+
Builder = WindowsExtensionBuilder
|
| 86 |
+
else:
|
| 87 |
+
raise RuntimeError("Unsupported platform %s" % sys.platform)
|
| 88 |
+
|
| 89 |
+
builder = Builder(mujoco_path)
|
| 90 |
+
cext_so_path = builder.get_so_file_path()
|
| 91 |
+
|
| 92 |
+
lockpath = os.path.join(os.path.dirname(cext_so_path), 'mujocopy-buildlock')
|
| 93 |
+
|
| 94 |
+
with fasteners.InterProcessLock(lockpath):
|
| 95 |
+
mod = None
|
| 96 |
+
force_rebuild = os.environ.get('MUJOCO_PY_FORCE_REBUILD')
|
| 97 |
+
if force_rebuild:
|
| 98 |
+
# Try to remove the old file, ignore errors if it doesn't exist
|
| 99 |
+
print("Removing old mujoco_py cext", cext_so_path)
|
| 100 |
+
try:
|
| 101 |
+
os.remove(cext_so_path)
|
| 102 |
+
except OSError:
|
| 103 |
+
pass
|
| 104 |
+
if exists(cext_so_path):
|
| 105 |
+
try:
|
| 106 |
+
mod = load_dynamic_ext('cymj', cext_so_path)
|
| 107 |
+
except ImportError:
|
| 108 |
+
print("Import error. Trying to rebuild mujoco_py.")
|
| 109 |
+
if mod is None:
|
| 110 |
+
cext_so_path = builder.build()
|
| 111 |
+
mod = load_dynamic_ext('cymj', cext_so_path)
|
| 112 |
+
|
| 113 |
+
return mod
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _ensure_set_env_var(var_name, lib_path):
|
| 117 |
+
paths = os.environ.get(var_name, "").split(":")
|
| 118 |
+
paths = [os.path.abspath(path) for path in paths]
|
| 119 |
+
if lib_path not in paths:
|
| 120 |
+
raise Exception("\nMissing path to your environment variable. \n"
|
| 121 |
+
"Current values %s=%s\n"
|
| 122 |
+
"Please add following line to .bashrc:\n"
|
| 123 |
+
"export %s=$%s:%s" % (var_name, os.environ.get(var_name, ""),
|
| 124 |
+
var_name, var_name, lib_path))
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def load_dynamic_ext(name, path):
|
| 128 |
+
""" Load compiled shared object and return as python module. """
|
| 129 |
+
loader = ExtensionFileLoader(name, path)
|
| 130 |
+
return loader.load_module()
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class custom_build_ext(build_ext):
|
| 134 |
+
"""
|
| 135 |
+
Custom build_ext to suppress the "-Wstrict-prototypes" warning.
|
| 136 |
+
It arises from the fact that we're using C++. This seems to be
|
| 137 |
+
the cleanest way to get rid of the extra flag.
|
| 138 |
+
|
| 139 |
+
See http://stackoverflow.com/a/36293331/248400
|
| 140 |
+
"""
|
| 141 |
+
|
| 142 |
+
def build_extensions(self):
|
| 143 |
+
customize_compiler(self.compiler)
|
| 144 |
+
|
| 145 |
+
try:
|
| 146 |
+
self.compiler.compiler_so.remove("-Wstrict-prototypes")
|
| 147 |
+
except (AttributeError, ValueError):
|
| 148 |
+
pass
|
| 149 |
+
build_ext.build_extensions(self)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def fix_shared_library(so_file, name, library_path):
|
| 153 |
+
""" Used to fixup shared libraries on Linux """
|
| 154 |
+
subprocess.check_call(['patchelf', '--remove-rpath', so_file])
|
| 155 |
+
ldd_output = subprocess.check_output(['ldd', so_file]).decode('utf-8')
|
| 156 |
+
|
| 157 |
+
if name in ldd_output:
|
| 158 |
+
subprocess.check_call(['patchelf', '--remove-needed', name, so_file])
|
| 159 |
+
subprocess.check_call(['patchelf', '--add-needed', library_path, so_file])
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def manually_link_libraries(mujoco_path, raw_cext_dll_path):
|
| 163 |
+
""" Used to fix mujoco library linking on Mac """
|
| 164 |
+
root, ext = os.path.splitext(raw_cext_dll_path)
|
| 165 |
+
final_cext_dll_path = root + '_final' + ext
|
| 166 |
+
|
| 167 |
+
# If someone else already built the final DLL, don't bother
|
| 168 |
+
# recreating it here, even though this should still be idempotent.
|
| 169 |
+
if (exists(final_cext_dll_path) and
|
| 170 |
+
getmtime(final_cext_dll_path) >= getmtime(raw_cext_dll_path)):
|
| 171 |
+
return final_cext_dll_path
|
| 172 |
+
|
| 173 |
+
tmp_final_cext_dll_path = final_cext_dll_path + '~'
|
| 174 |
+
shutil.copyfile(raw_cext_dll_path, tmp_final_cext_dll_path)
|
| 175 |
+
|
| 176 |
+
mj_bin_path = join(mujoco_path, 'bin')
|
| 177 |
+
|
| 178 |
+
# Fix the rpath of the generated library -- i lost the Stackoverflow
|
| 179 |
+
# reference here
|
| 180 |
+
from_mujoco_path = '@executable_path/libmujoco210.dylib'
|
| 181 |
+
to_mujoco_path = '%s/libmujoco210.dylib' % mj_bin_path
|
| 182 |
+
subprocess.check_call(['install_name_tool',
|
| 183 |
+
'-change',
|
| 184 |
+
from_mujoco_path,
|
| 185 |
+
to_mujoco_path,
|
| 186 |
+
tmp_final_cext_dll_path])
|
| 187 |
+
|
| 188 |
+
from_glfw_path = 'libglfw.3.dylib'
|
| 189 |
+
to_glfw_path = os.path.join(mj_bin_path, 'libglfw.3.dylib')
|
| 190 |
+
subprocess.check_call(['install_name_tool',
|
| 191 |
+
'-change',
|
| 192 |
+
from_glfw_path,
|
| 193 |
+
to_glfw_path,
|
| 194 |
+
tmp_final_cext_dll_path])
|
| 195 |
+
|
| 196 |
+
os.rename(tmp_final_cext_dll_path, final_cext_dll_path)
|
| 197 |
+
return final_cext_dll_path
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class MujocoExtensionBuilder():
|
| 201 |
+
|
| 202 |
+
CYMJ_DIR_PATH = abspath(dirname(__file__))
|
| 203 |
+
|
| 204 |
+
def __init__(self, mujoco_path):
|
| 205 |
+
self.mujoco_path = mujoco_path
|
| 206 |
+
python_version = str(sys.version_info.major) + str(sys.version_info.minor)
|
| 207 |
+
self.version = '%s_%s_%s' % (get_version(), python_version, self.build_base())
|
| 208 |
+
self.extension = Extension(
|
| 209 |
+
'mujoco_py.cymj',
|
| 210 |
+
sources=[join(self.CYMJ_DIR_PATH, "cymj.pyx")],
|
| 211 |
+
include_dirs=[
|
| 212 |
+
self.CYMJ_DIR_PATH,
|
| 213 |
+
join(mujoco_path, 'include'),
|
| 214 |
+
np.get_include(),
|
| 215 |
+
],
|
| 216 |
+
libraries=['mujoco210'],
|
| 217 |
+
library_dirs=[join(mujoco_path, 'bin')],
|
| 218 |
+
extra_compile_args=[
|
| 219 |
+
'-fopenmp', # needed for OpenMP
|
| 220 |
+
'-w', # suppress numpy compilation warnings
|
| 221 |
+
],
|
| 222 |
+
extra_link_args=['-fopenmp'],
|
| 223 |
+
language='c')
|
| 224 |
+
|
| 225 |
+
def build(self):
|
| 226 |
+
built_so_file_path = self._build_impl()
|
| 227 |
+
new_so_file_path = self.get_so_file_path()
|
| 228 |
+
move(built_so_file_path, new_so_file_path)
|
| 229 |
+
return new_so_file_path
|
| 230 |
+
|
| 231 |
+
def build_base(self):
|
| 232 |
+
return self.__class__.__name__.lower()
|
| 233 |
+
|
| 234 |
+
def _build_impl(self):
|
| 235 |
+
dist = Distribution({
|
| 236 |
+
"script_name": None,
|
| 237 |
+
"script_args": ["build_ext"]
|
| 238 |
+
})
|
| 239 |
+
dist.ext_modules = cythonize([self.extension])
|
| 240 |
+
dist.include_dirs = []
|
| 241 |
+
dist.cmdclass = {'build_ext': custom_build_ext}
|
| 242 |
+
build = dist.get_command_obj('build')
|
| 243 |
+
# following the convention of cython's pyxbuild and naming
|
| 244 |
+
# base directory "_pyxbld"
|
| 245 |
+
build.build_base = join(self.CYMJ_DIR_PATH, 'generated',
|
| 246 |
+
'_pyxbld_%s' % (self.version))
|
| 247 |
+
dist.parse_command_line()
|
| 248 |
+
obj_build_ext = dist.get_command_obj("build_ext")
|
| 249 |
+
dist.run_commands()
|
| 250 |
+
built_so_file_path, = obj_build_ext.get_outputs()
|
| 251 |
+
return built_so_file_path
|
| 252 |
+
|
| 253 |
+
def get_so_file_path(self):
|
| 254 |
+
dir_path = abspath(dirname(__file__))
|
| 255 |
+
python_version = str(sys.version_info.major) + str(sys.version_info.minor)
|
| 256 |
+
return join(dir_path, "generated", "cymj_{}_{}.so".format(self.version, python_version))
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class WindowsExtensionBuilder(MujocoExtensionBuilder):
|
| 260 |
+
|
| 261 |
+
def __init__(self, mujoco_path):
|
| 262 |
+
super().__init__(mujoco_path)
|
| 263 |
+
os.environ["PATH"] += ";" + join(mujoco_path, "bin")
|
| 264 |
+
self.extension.sources.append(self.CYMJ_DIR_PATH + "/gl/dummyshim.c")
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class LinuxCPUExtensionBuilder(MujocoExtensionBuilder):
|
| 268 |
+
|
| 269 |
+
def __init__(self, mujoco_path):
|
| 270 |
+
super().__init__(mujoco_path)
|
| 271 |
+
|
| 272 |
+
self.extension.sources.append(
|
| 273 |
+
join(self.CYMJ_DIR_PATH, "gl", "osmesashim.c"))
|
| 274 |
+
self.extension.libraries.extend(['glewosmesa', 'OSMesa', 'GL'])
|
| 275 |
+
self.extension.runtime_library_dirs = [join(mujoco_path, 'bin')]
|
| 276 |
+
|
| 277 |
+
def _build_impl(self):
|
| 278 |
+
so_file_path = super()._build_impl()
|
| 279 |
+
# Removes absolute paths to libraries. Allows for dynamic loading.
|
| 280 |
+
fix_shared_library(so_file_path, 'libmujoco210.so', 'libmujoco210.so')
|
| 281 |
+
fix_shared_library(so_file_path, 'libglewosmesa.so', 'libglewosmesa.so')
|
| 282 |
+
return so_file_path
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
class LinuxGPUExtensionBuilder(MujocoExtensionBuilder):
|
| 286 |
+
|
| 287 |
+
def __init__(self, mujoco_path):
|
| 288 |
+
super().__init__(mujoco_path)
|
| 289 |
+
|
| 290 |
+
self.extension.sources.append(self.CYMJ_DIR_PATH + "/gl/eglshim.c")
|
| 291 |
+
self.extension.include_dirs.append(self.CYMJ_DIR_PATH + '/vendor/egl')
|
| 292 |
+
self.extension.libraries.extend(['glewegl'])
|
| 293 |
+
self.extension.runtime_library_dirs = [join(mujoco_path, 'bin')]
|
| 294 |
+
|
| 295 |
+
def _build_impl(self):
|
| 296 |
+
so_file_path = super()._build_impl()
|
| 297 |
+
fix_shared_library(so_file_path, 'libOpenGL.so', 'libOpenGL.so.0')
|
| 298 |
+
fix_shared_library(so_file_path, 'libEGL.so', 'libEGL.so.1')
|
| 299 |
+
fix_shared_library(so_file_path, 'libmujoco210.so', 'libmujoco210.so')
|
| 300 |
+
fix_shared_library(so_file_path, 'libglewegl.so', 'libglewegl.so')
|
| 301 |
+
return so_file_path
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
class MacExtensionBuilder(MujocoExtensionBuilder):
|
| 305 |
+
|
| 306 |
+
def __init__(self, mujoco_path):
|
| 307 |
+
super().__init__(mujoco_path)
|
| 308 |
+
|
| 309 |
+
self.extension.sources.append(self.CYMJ_DIR_PATH + "/gl/dummyshim.c")
|
| 310 |
+
self.extension.libraries.extend(['glfw.3'])
|
| 311 |
+
self.extension.define_macros = [('ONMAC', None)]
|
| 312 |
+
self.extension.runtime_library_dirs = [join(mujoco_path, 'bin')]
|
| 313 |
+
|
| 314 |
+
def _build_impl(self):
|
| 315 |
+
if not os.environ.get('CC'):
|
| 316 |
+
# Known-working versions of GCC on mac (prefer latest one)
|
| 317 |
+
c_compilers = [
|
| 318 |
+
'/usr/local/bin/gcc-9',
|
| 319 |
+
'/usr/local/bin/gcc-8',
|
| 320 |
+
'/usr/local/bin/gcc-7',
|
| 321 |
+
'/usr/local/bin/gcc-6',
|
| 322 |
+
'/opt/local/bin/gcc-mp-9',
|
| 323 |
+
'/opt/local/bin/gcc-mp-8',
|
| 324 |
+
'/opt/local/bin/gcc-mp-7',
|
| 325 |
+
'/opt/local/bin/gcc-mp-6',
|
| 326 |
+
]
|
| 327 |
+
available_c_compiler = None
|
| 328 |
+
for c_compiler in c_compilers:
|
| 329 |
+
if distutils.spawn.find_executable(c_compiler) is not None:
|
| 330 |
+
available_c_compiler = c_compiler
|
| 331 |
+
break
|
| 332 |
+
if available_c_compiler is None:
|
| 333 |
+
raise RuntimeError(
|
| 334 |
+
'Could not find supported GCC executable.\n\n'
|
| 335 |
+
'HINT: On OS X, install GCC 9.x with '
|
| 336 |
+
'`brew install gcc@9`. or '
|
| 337 |
+
'`port install gcc9`.')
|
| 338 |
+
os.environ['CC'] = available_c_compiler
|
| 339 |
+
|
| 340 |
+
so_file_path = super()._build_impl()
|
| 341 |
+
del os.environ['CC']
|
| 342 |
+
else: # User-directed c compiler
|
| 343 |
+
so_file_path = super()._build_impl()
|
| 344 |
+
return manually_link_libraries(self.mujoco_path, so_file_path)
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
class MujocoException(Exception):
|
| 348 |
+
pass
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def user_warning_raise_exception(warn_bytes):
|
| 352 |
+
'''
|
| 353 |
+
User-defined warning callback, which is called by mujoco on warnings.
|
| 354 |
+
Here we have two primary jobs:
|
| 355 |
+
- Detect known warnings and suggest fixes (with code)
|
| 356 |
+
- Decide whether to raise an Exception and raise if needed
|
| 357 |
+
More cases should be added as we find new failures.
|
| 358 |
+
'''
|
| 359 |
+
# TODO: look through test output to see MuJoCo warnings to catch
|
| 360 |
+
# and recommend. Also fix those tests
|
| 361 |
+
warn = warn_bytes.decode() # Convert bytes to string
|
| 362 |
+
if 'Pre-allocated constraint buffer is full' in warn:
|
| 363 |
+
raise MujocoException(warn + 'Increase njmax in mujoco XML')
|
| 364 |
+
if 'Pre-allocated contact buffer is full' in warn:
|
| 365 |
+
raise MujocoException(warn + 'Increase njconmax in mujoco XML')
|
| 366 |
+
# This unhelpfully-named warning is what you get if you feed MuJoCo NaNs
|
| 367 |
+
if 'Unknown warning type' in warn:
|
| 368 |
+
raise MujocoException(warn + 'Check for NaN in simulation.')
|
| 369 |
+
raise MujocoException('Got MuJoCo Warning: {}'.format(warn))
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def user_warning_ignore_exception(warn_bytes):
|
| 373 |
+
pass
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class ignore_mujoco_warnings:
|
| 377 |
+
"""
|
| 378 |
+
Class to turn off mujoco warning exceptions within a scope. Useful for
|
| 379 |
+
large, vectorized rollouts.
|
| 380 |
+
"""
|
| 381 |
+
|
| 382 |
+
def __enter__(self):
|
| 383 |
+
self.prev_user_warning = cymj.get_warning_callback()
|
| 384 |
+
cymj.set_warning_callback(user_warning_ignore_exception)
|
| 385 |
+
return self
|
| 386 |
+
|
| 387 |
+
def __exit__(self, type, value, traceback):
|
| 388 |
+
cymj.set_warning_callback(self.prev_user_warning)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
def build_fn_cleanup(name):
|
| 392 |
+
'''
|
| 393 |
+
Cleanup files generated by building callback.
|
| 394 |
+
Set the MUJOCO_PY_DEBUG_FN_BUILDER environment variable to disable cleanup.
|
| 395 |
+
'''
|
| 396 |
+
if not os.environ.get('MUJOCO_PY_DEBUG_FN_BUILDER', False):
|
| 397 |
+
for f in glob.glob(name + '*'):
|
| 398 |
+
try:
|
| 399 |
+
os.remove(f)
|
| 400 |
+
except PermissionError as e:
|
| 401 |
+
# This happens trying to remove libraries on appveyor
|
| 402 |
+
print('Error removing {}, continuing anyway: {}'.format(f, e))
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
def build_callback_fn(function_string, userdata_names=[]):
|
| 406 |
+
'''
|
| 407 |
+
Builds a C callback function and returns a function pointer int.
|
| 408 |
+
|
| 409 |
+
function_string : str
|
| 410 |
+
This is a string of the C function to be compiled
|
| 411 |
+
userdata_names : list or tuple
|
| 412 |
+
This is an optional list to defince convenience names
|
| 413 |
+
|
| 414 |
+
We compile and link and load the function, and return a function pointer.
|
| 415 |
+
See `MjSim.set_substep_callback()` for an example use of these callbacks.
|
| 416 |
+
|
| 417 |
+
The callback function should match the signature:
|
| 418 |
+
void fun(const mjModel *m, mjData *d);
|
| 419 |
+
|
| 420 |
+
Here's an example function_string:
|
| 421 |
+
```
|
| 422 |
+
"""
|
| 423 |
+
#include <stdio.h>
|
| 424 |
+
void fun(const mjModel* m, mjData* d) {
|
| 425 |
+
printf("hello");
|
| 426 |
+
}
|
| 427 |
+
"""
|
| 428 |
+
```
|
| 429 |
+
|
| 430 |
+
Input and output for the function pass through userdata in the data struct:
|
| 431 |
+
```
|
| 432 |
+
"""
|
| 433 |
+
void fun(const mjModel* m, mjData* d) {
|
| 434 |
+
d->userdata[0] += 1;
|
| 435 |
+
}
|
| 436 |
+
"""
|
| 437 |
+
```
|
| 438 |
+
|
| 439 |
+
`userdata_names` is expected to match the model where the callback is used.
|
| 440 |
+
These can bet set on a model with:
|
| 441 |
+
`model.set_userdata_names([...])`
|
| 442 |
+
|
| 443 |
+
If `userdata_names` is supplied, convenience `#define`s are added for each.
|
| 444 |
+
For example:
|
| 445 |
+
`userdata_names = ['my_sum']`
|
| 446 |
+
Will get gerenerated into the extra line:
|
| 447 |
+
`#define my_sum d->userdata[0]`
|
| 448 |
+
And prepended to the top of the function before compilation.
|
| 449 |
+
Here's an example that takes advantage of this:
|
| 450 |
+
```
|
| 451 |
+
"""
|
| 452 |
+
void fun(const mjModel* m, mjData* d) {
|
| 453 |
+
for (int i = 0; i < m->nu; i++) {
|
| 454 |
+
my_sum += d->ctrl[i];
|
| 455 |
+
}
|
| 456 |
+
}
|
| 457 |
+
"""
|
| 458 |
+
```
|
| 459 |
+
Note these are just C `#define`s and are limited in how they can be used.
|
| 460 |
+
|
| 461 |
+
After compilation, the built library containing the function is loaded
|
| 462 |
+
into memory and all of the files (including the library) are deleted.
|
| 463 |
+
To retain these for debugging set the `MUJOCO_PY_DEBUG_FN_BUILDER` envvar.
|
| 464 |
+
|
| 465 |
+
To save time compiling, these function pointers may be re-used by many
|
| 466 |
+
different consumers. They are thread-safe and don't acquire the GIL.
|
| 467 |
+
|
| 468 |
+
See the file `tests/test_substep.py` for additional examples,
|
| 469 |
+
including an example which iterates over contacts to compute penetrations.
|
| 470 |
+
'''
|
| 471 |
+
assert isinstance(userdata_names, (list, tuple)), \
|
| 472 |
+
'invalid userdata_names: {}'.format(userdata_names)
|
| 473 |
+
ffibuilder = FFI()
|
| 474 |
+
ffibuilder.cdef('extern uintptr_t __fun;')
|
| 475 |
+
name = '_fn_' + ''.join(choice(ascii_lowercase) for _ in range(15))
|
| 476 |
+
source_string = '#include <mujoco.h>\n'
|
| 477 |
+
# Add defines for each userdata to make setting them easier
|
| 478 |
+
for i, data_name in enumerate(userdata_names):
|
| 479 |
+
source_string += '#define {} d->userdata[{}]\n'.format(data_name, i)
|
| 480 |
+
source_string += function_string
|
| 481 |
+
source_string += '\nuintptr_t __fun = (uintptr_t) fun;'
|
| 482 |
+
# Link against mujoco so we can call mujoco functions from within callback
|
| 483 |
+
ffibuilder.set_source(name, source_string,
|
| 484 |
+
include_dirs=[join(mujoco_path, 'include')],
|
| 485 |
+
library_dirs=[join(mujoco_path, 'bin')],
|
| 486 |
+
libraries=['mujoco210'])
|
| 487 |
+
# Catch compilation exceptions so we can cleanup partial files in that case
|
| 488 |
+
try:
|
| 489 |
+
library_path = ffibuilder.compile(verbose=True)
|
| 490 |
+
except Exception as e:
|
| 491 |
+
build_fn_cleanup(name)
|
| 492 |
+
raise e
|
| 493 |
+
# On Mac the MuJoCo library is linked strangely, so we have to fix it here
|
| 494 |
+
if sys.platform == 'darwin':
|
| 495 |
+
fixed_library_path = manually_link_libraries(mujoco_path, library_path)
|
| 496 |
+
move(fixed_library_path, library_path) # Overwrite with fixed library
|
| 497 |
+
module = load_dynamic_ext(name, library_path)
|
| 498 |
+
# Now that the module is loaded into memory, we can actually delete it
|
| 499 |
+
build_fn_cleanup(name)
|
| 500 |
+
return module.lib.__fun
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
mujoco_path = discover_mujoco()
|
| 504 |
+
cymj = load_cython_ext(mujoco_path)
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
# Trick to expose all mj* functions from mujoco in mujoco_py.*
|
| 508 |
+
class dict2(object):
|
| 509 |
+
pass
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
functions = dict2()
|
| 513 |
+
for func_name in dir(cymj):
|
| 514 |
+
if func_name.startswith("_mj"):
|
| 515 |
+
setattr(functions, func_name[1:], getattr(cymj, func_name))
|
| 516 |
+
|
| 517 |
+
# Set user-defined callbacks that raise assertion with message
|
| 518 |
+
cymj.set_warning_callback(user_warning_raise_exception)
|
mujoco-py-2.1.2.14/mujoco_py/gl/eglplatform.h
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifndef __eglplatform_h_
|
| 2 |
+
#define __eglplatform_h_
|
| 3 |
+
|
| 4 |
+
/*
|
| 5 |
+
** Copyright (c) 2007-2013 The Khronos Group Inc.
|
| 6 |
+
**
|
| 7 |
+
** Permission is hereby granted, free of charge, to any person obtaining a
|
| 8 |
+
** copy of this software and/or associated documentation files (the
|
| 9 |
+
** "Materials"), to deal in the Materials without restriction, including
|
| 10 |
+
** without limitation the rights to use, copy, modify, merge, publish,
|
| 11 |
+
** distribute, sublicense, and/or sell copies of the Materials, and to
|
| 12 |
+
** permit persons to whom the Materials are furnished to do so, subject to
|
| 13 |
+
** the following conditions:
|
| 14 |
+
**
|
| 15 |
+
** The above copyright notice and this permission notice shall be included
|
| 16 |
+
** in all copies or substantial portions of the Materials.
|
| 17 |
+
**
|
| 18 |
+
** THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
| 19 |
+
** EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
| 20 |
+
** MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
| 21 |
+
** IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
| 22 |
+
** CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
| 23 |
+
** TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
| 24 |
+
** MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
|
| 25 |
+
*/
|
| 26 |
+
|
| 27 |
+
/* Platform-specific types and definitions for egl.h
|
| 28 |
+
* $Revision: 30994 $ on $Date: 2015-04-30 13:36:48 -0700 (Thu, 30 Apr 2015) $
|
| 29 |
+
*
|
| 30 |
+
* Adopters may modify khrplatform.h and this file to suit their platform.
|
| 31 |
+
* You are encouraged to submit all modifications to the Khronos group so that
|
| 32 |
+
* they can be included in future versions of this file. Please submit changes
|
| 33 |
+
* by sending them to the public Khronos Bugzilla (http://khronos.org/bugzilla)
|
| 34 |
+
* by filing a bug against product "EGL" component "Registry".
|
| 35 |
+
*/
|
| 36 |
+
|
| 37 |
+
#include "khrplatform.h"
|
| 38 |
+
|
| 39 |
+
/* Macros used in EGL function prototype declarations.
|
| 40 |
+
*
|
| 41 |
+
* EGL functions should be prototyped as:
|
| 42 |
+
*
|
| 43 |
+
* EGLAPI return-type EGLAPIENTRY eglFunction(arguments);
|
| 44 |
+
* typedef return-type (EXPAPIENTRYP PFNEGLFUNCTIONPROC) (arguments);
|
| 45 |
+
*
|
| 46 |
+
* KHRONOS_APICALL and KHRONOS_APIENTRY are defined in KHR/khrplatform.h
|
| 47 |
+
*/
|
| 48 |
+
|
| 49 |
+
#ifndef EGLAPI
|
| 50 |
+
#define EGLAPI KHRONOS_APICALL
|
| 51 |
+
#endif
|
| 52 |
+
|
| 53 |
+
#ifndef EGLAPIENTRY
|
| 54 |
+
#define EGLAPIENTRY KHRONOS_APIENTRY
|
| 55 |
+
#endif
|
| 56 |
+
#define EGLAPIENTRYP EGLAPIENTRY*
|
| 57 |
+
|
| 58 |
+
/* The types NativeDisplayType, NativeWindowType, and NativePixmapType
|
| 59 |
+
* are aliases of window-system-dependent types, such as X Display * or
|
| 60 |
+
* Windows Device Context. They must be defined in platform-specific
|
| 61 |
+
* code below. The EGL-prefixed versions of Native*Type are the same
|
| 62 |
+
* types, renamed in EGL 1.3 so all types in the API start with "EGL".
|
| 63 |
+
*
|
| 64 |
+
* Khronos STRONGLY RECOMMENDS that you use the default definitions
|
| 65 |
+
* provided below, since these changes affect both binary and source
|
| 66 |
+
* portability of applications using EGL running on different EGL
|
| 67 |
+
* implementations.
|
| 68 |
+
*/
|
| 69 |
+
|
| 70 |
+
#if defined(_WIN32) || defined(__VC32__) && !defined(__CYGWIN__) && !defined(__SCITECH_SNAP__) /* Win32 and WinCE */
|
| 71 |
+
#ifndef WIN32_LEAN_AND_MEAN
|
| 72 |
+
#define WIN32_LEAN_AND_MEAN 1
|
| 73 |
+
#endif
|
| 74 |
+
#include <windows.h>
|
| 75 |
+
|
| 76 |
+
typedef HDC EGLNativeDisplayType;
|
| 77 |
+
typedef HBITMAP EGLNativePixmapType;
|
| 78 |
+
typedef HWND EGLNativeWindowType;
|
| 79 |
+
|
| 80 |
+
#elif defined(__APPLE__) || defined(__WINSCW__) || defined(__SYMBIAN32__) /* Symbian */
|
| 81 |
+
|
| 82 |
+
typedef int EGLNativeDisplayType;
|
| 83 |
+
typedef void *EGLNativeWindowType;
|
| 84 |
+
typedef void *EGLNativePixmapType;
|
| 85 |
+
|
| 86 |
+
#elif defined(__ANDROID__) || defined(ANDROID)
|
| 87 |
+
|
| 88 |
+
#include <android/native_window.h>
|
| 89 |
+
|
| 90 |
+
struct egl_native_pixmap_t;
|
| 91 |
+
|
| 92 |
+
typedef struct ANativeWindow* EGLNativeWindowType;
|
| 93 |
+
typedef struct egl_native_pixmap_t* EGLNativePixmapType;
|
| 94 |
+
typedef void* EGLNativeDisplayType;
|
| 95 |
+
|
| 96 |
+
#elif defined(__unix__)
|
| 97 |
+
|
| 98 |
+
/* X11 (tentative) */
|
| 99 |
+
#include <X11/Xlib.h>
|
| 100 |
+
#include <X11/Xutil.h>
|
| 101 |
+
|
| 102 |
+
typedef Display *EGLNativeDisplayType;
|
| 103 |
+
typedef Pixmap EGLNativePixmapType;
|
| 104 |
+
typedef Window EGLNativeWindowType;
|
| 105 |
+
|
| 106 |
+
#else
|
| 107 |
+
#error "Platform not recognized"
|
| 108 |
+
#endif
|
| 109 |
+
|
| 110 |
+
/* EGL 1.2 types, renamed for consistency in EGL 1.3 */
|
| 111 |
+
typedef EGLNativeDisplayType NativeDisplayType;
|
| 112 |
+
typedef EGLNativePixmapType NativePixmapType;
|
| 113 |
+
typedef EGLNativeWindowType NativeWindowType;
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
/* Define EGLint. This must be a signed integral type large enough to contain
|
| 117 |
+
* all legal attribute names and values passed into and out of EGL, whether
|
| 118 |
+
* their type is boolean, bitmask, enumerant (symbolic constant), integer,
|
| 119 |
+
* handle, or other. While in general a 32-bit integer will suffice, if
|
| 120 |
+
* handles are 64 bit types, then EGLint should be defined as a signed 64-bit
|
| 121 |
+
* integer type.
|
| 122 |
+
*/
|
| 123 |
+
typedef khronos_int32_t EGLint;
|
| 124 |
+
|
| 125 |
+
#endif /* __eglplatform_h */
|
mujoco-py-2.1.2.14/mujoco_py/gl/glshim.h
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifndef __GLSHIM_H__
|
| 2 |
+
#define __GLSHIM_H__
|
| 3 |
+
|
| 4 |
+
#include "mujoco.h"
|
| 5 |
+
#include "mjrender.h"
|
| 6 |
+
|
| 7 |
+
#ifdef __cplusplus
|
| 8 |
+
extern "C" {
|
| 9 |
+
#endif
|
| 10 |
+
|
| 11 |
+
int usingEGL();
|
| 12 |
+
int initOpenGL(int device_id);
|
| 13 |
+
void closeOpenGL();
|
| 14 |
+
int makeOpenGLContextCurrent(int device_id);
|
| 15 |
+
int setOpenGLBufferSize(int device_id, int width, int height);
|
| 16 |
+
|
| 17 |
+
unsigned int createPBO(int width, int height, int batchSize, int use_short);
|
| 18 |
+
void freePBO(unsigned int pixelBuffer);
|
| 19 |
+
void copyFBOToPBO(mjrContext* con,
|
| 20 |
+
unsigned int pbo_rgb, unsigned int pbo_depth,
|
| 21 |
+
mjrRect viewport, int bufferOffset);
|
| 22 |
+
void readPBO(unsigned char *buffer_rgb, unsigned short *buffer_depth,
|
| 23 |
+
unsigned int pbo_rgb, unsigned int pbo_depth,
|
| 24 |
+
int width, int height, int batchSize);
|
| 25 |
+
|
| 26 |
+
#ifdef __cplusplus
|
| 27 |
+
} // extern "C"
|
| 28 |
+
#endif
|
| 29 |
+
|
| 30 |
+
#endif
|
mujoco-py-2.1.2.14/mujoco_py/gl/khrplatform.h
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#ifndef __khrplatform_h_
|
| 2 |
+
#define __khrplatform_h_
|
| 3 |
+
|
| 4 |
+
/*
|
| 5 |
+
** Copyright (c) 2008-2009 The Khronos Group Inc.
|
| 6 |
+
**
|
| 7 |
+
** Permission is hereby granted, free of charge, to any person obtaining a
|
| 8 |
+
** copy of this software and/or associated documentation files (the
|
| 9 |
+
** "Materials"), to deal in the Materials without restriction, including
|
| 10 |
+
** without limitation the rights to use, copy, modify, merge, publish,
|
| 11 |
+
** distribute, sublicense, and/or sell copies of the Materials, and to
|
| 12 |
+
** permit persons to whom the Materials are furnished to do so, subject to
|
| 13 |
+
** the following conditions:
|
| 14 |
+
**
|
| 15 |
+
** The above copyright notice and this permission notice shall be included
|
| 16 |
+
** in all copies or substantial portions of the Materials.
|
| 17 |
+
**
|
| 18 |
+
** THE MATERIALS ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
| 19 |
+
** EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
| 20 |
+
** MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
| 21 |
+
** IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
| 22 |
+
** CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
| 23 |
+
** TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
| 24 |
+
** MATERIALS OR THE USE OR OTHER DEALINGS IN THE MATERIALS.
|
| 25 |
+
*/
|
| 26 |
+
|
| 27 |
+
/* Khronos platform-specific types and definitions.
|
| 28 |
+
*
|
| 29 |
+
* $Revision: 32517 $ on $Date: 2016-03-11 02:41:19 -0800 (Fri, 11 Mar 2016) $
|
| 30 |
+
*
|
| 31 |
+
* Adopters may modify this file to suit their platform. Adopters are
|
| 32 |
+
* encouraged to submit platform specific modifications to the Khronos
|
| 33 |
+
* group so that they can be included in future versions of this file.
|
| 34 |
+
* Please submit changes by sending them to the public Khronos Bugzilla
|
| 35 |
+
* (http://khronos.org/bugzilla) by filing a bug against product
|
| 36 |
+
* "Khronos (general)" component "Registry".
|
| 37 |
+
*
|
| 38 |
+
* A predefined template which fills in some of the bug fields can be
|
| 39 |
+
* reached using http://tinyurl.com/khrplatform-h-bugreport, but you
|
| 40 |
+
* must create a Bugzilla login first.
|
| 41 |
+
*
|
| 42 |
+
*
|
| 43 |
+
* See the Implementer's Guidelines for information about where this file
|
| 44 |
+
* should be located on your system and for more details of its use:
|
| 45 |
+
* http://www.khronos.org/registry/implementers_guide.pdf
|
| 46 |
+
*
|
| 47 |
+
* This file should be included as
|
| 48 |
+
* #include <KHR/khrplatform.h>
|
| 49 |
+
* by Khronos client API header files that use its types and defines.
|
| 50 |
+
*
|
| 51 |
+
* The types in khrplatform.h should only be used to define API-specific types.
|
| 52 |
+
*
|
| 53 |
+
* Types defined in khrplatform.h:
|
| 54 |
+
* khronos_int8_t signed 8 bit
|
| 55 |
+
* khronos_uint8_t unsigned 8 bit
|
| 56 |
+
* khronos_int16_t signed 16 bit
|
| 57 |
+
* khronos_uint16_t unsigned 16 bit
|
| 58 |
+
* khronos_int32_t signed 32 bit
|
| 59 |
+
* khronos_uint32_t unsigned 32 bit
|
| 60 |
+
* khronos_int64_t signed 64 bit
|
| 61 |
+
* khronos_uint64_t unsigned 64 bit
|
| 62 |
+
* khronos_intptr_t signed same number of bits as a pointer
|
| 63 |
+
* khronos_uintptr_t unsigned same number of bits as a pointer
|
| 64 |
+
* khronos_ssize_t signed size
|
| 65 |
+
* khronos_usize_t unsigned size
|
| 66 |
+
* khronos_float_t signed 32 bit floating point
|
| 67 |
+
* khronos_time_ns_t unsigned 64 bit time in nanoseconds
|
| 68 |
+
* khronos_utime_nanoseconds_t unsigned time interval or absolute time in
|
| 69 |
+
* nanoseconds
|
| 70 |
+
* khronos_stime_nanoseconds_t signed time interval in nanoseconds
|
| 71 |
+
* khronos_boolean_enum_t enumerated boolean type. This should
|
| 72 |
+
* only be used as a base type when a client API's boolean type is
|
| 73 |
+
* an enum. Client APIs which use an integer or other type for
|
| 74 |
+
* booleans cannot use this as the base type for their boolean.
|
| 75 |
+
*
|
| 76 |
+
* Tokens defined in khrplatform.h:
|
| 77 |
+
*
|
| 78 |
+
* KHRONOS_FALSE, KHRONOS_TRUE Enumerated boolean false/true values.
|
| 79 |
+
*
|
| 80 |
+
* KHRONOS_SUPPORT_INT64 is 1 if 64 bit integers are supported; otherwise 0.
|
| 81 |
+
* KHRONOS_SUPPORT_FLOAT is 1 if floats are supported; otherwise 0.
|
| 82 |
+
*
|
| 83 |
+
* Calling convention macros defined in this file:
|
| 84 |
+
* KHRONOS_APICALL
|
| 85 |
+
* KHRONOS_APIENTRY
|
| 86 |
+
* KHRONOS_APIATTRIBUTES
|
| 87 |
+
*
|
| 88 |
+
* These may be used in function prototypes as:
|
| 89 |
+
*
|
| 90 |
+
* KHRONOS_APICALL void KHRONOS_APIENTRY funcname(
|
| 91 |
+
* int arg1,
|
| 92 |
+
* int arg2) KHRONOS_APIATTRIBUTES;
|
| 93 |
+
*/
|
| 94 |
+
|
| 95 |
+
/*-------------------------------------------------------------------------
|
| 96 |
+
* Definition of KHRONOS_APICALL
|
| 97 |
+
*-------------------------------------------------------------------------
|
| 98 |
+
* This precedes the return type of the function in the function prototype.
|
| 99 |
+
*/
|
| 100 |
+
#if defined(_WIN32) && !defined(__SCITECH_SNAP__)
|
| 101 |
+
# define KHRONOS_APICALL __declspec(dllimport)
|
| 102 |
+
#elif defined (__SYMBIAN32__)
|
| 103 |
+
# define KHRONOS_APICALL IMPORT_C
|
| 104 |
+
#elif defined(__ANDROID__)
|
| 105 |
+
# include <sys/cdefs.h>
|
| 106 |
+
# define KHRONOS_APICALL __attribute__((visibility("default"))) __NDK_FPABI__
|
| 107 |
+
#else
|
| 108 |
+
# define KHRONOS_APICALL
|
| 109 |
+
#endif
|
| 110 |
+
|
| 111 |
+
/*-------------------------------------------------------------------------
|
| 112 |
+
* Definition of KHRONOS_APIENTRY
|
| 113 |
+
*-------------------------------------------------------------------------
|
| 114 |
+
* This follows the return type of the function and precedes the function
|
| 115 |
+
* name in the function prototype.
|
| 116 |
+
*/
|
| 117 |
+
#if defined(_WIN32) && !defined(_WIN32_WCE) && !defined(__SCITECH_SNAP__)
|
| 118 |
+
/* Win32 but not WinCE */
|
| 119 |
+
# define KHRONOS_APIENTRY __stdcall
|
| 120 |
+
#else
|
| 121 |
+
# define KHRONOS_APIENTRY
|
| 122 |
+
#endif
|
| 123 |
+
|
| 124 |
+
/*-------------------------------------------------------------------------
|
| 125 |
+
* Definition of KHRONOS_APIATTRIBUTES
|
| 126 |
+
*-------------------------------------------------------------------------
|
| 127 |
+
* This follows the closing parenthesis of the function prototype arguments.
|
| 128 |
+
*/
|
| 129 |
+
#if defined (__ARMCC_2__)
|
| 130 |
+
#define KHRONOS_APIATTRIBUTES __softfp
|
| 131 |
+
#else
|
| 132 |
+
#define KHRONOS_APIATTRIBUTES
|
| 133 |
+
#endif
|
| 134 |
+
|
| 135 |
+
/*-------------------------------------------------------------------------
|
| 136 |
+
* basic type definitions
|
| 137 |
+
*-----------------------------------------------------------------------*/
|
| 138 |
+
#if (defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L) || defined(__GNUC__) || defined(__SCO__) || defined(__USLC__)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
/*
|
| 142 |
+
* Using <stdint.h>
|
| 143 |
+
*/
|
| 144 |
+
#include <stdint.h>
|
| 145 |
+
typedef int32_t khronos_int32_t;
|
| 146 |
+
typedef uint32_t khronos_uint32_t;
|
| 147 |
+
typedef int64_t khronos_int64_t;
|
| 148 |
+
typedef uint64_t khronos_uint64_t;
|
| 149 |
+
#define KHRONOS_SUPPORT_INT64 1
|
| 150 |
+
#define KHRONOS_SUPPORT_FLOAT 1
|
| 151 |
+
|
| 152 |
+
#elif defined(__VMS ) || defined(__sgi)
|
| 153 |
+
|
| 154 |
+
/*
|
| 155 |
+
* Using <inttypes.h>
|
| 156 |
+
*/
|
| 157 |
+
#include <inttypes.h>
|
| 158 |
+
typedef int32_t khronos_int32_t;
|
| 159 |
+
typedef uint32_t khronos_uint32_t;
|
| 160 |
+
typedef int64_t khronos_int64_t;
|
| 161 |
+
typedef uint64_t khronos_uint64_t;
|
| 162 |
+
#define KHRONOS_SUPPORT_INT64 1
|
| 163 |
+
#define KHRONOS_SUPPORT_FLOAT 1
|
| 164 |
+
|
| 165 |
+
#elif defined(_WIN32) && !defined(__SCITECH_SNAP__)
|
| 166 |
+
|
| 167 |
+
/*
|
| 168 |
+
* Win32
|
| 169 |
+
*/
|
| 170 |
+
typedef __int32 khronos_int32_t;
|
| 171 |
+
typedef unsigned __int32 khronos_uint32_t;
|
| 172 |
+
typedef __int64 khronos_int64_t;
|
| 173 |
+
typedef unsigned __int64 khronos_uint64_t;
|
| 174 |
+
#define KHRONOS_SUPPORT_INT64 1
|
| 175 |
+
#define KHRONOS_SUPPORT_FLOAT 1
|
| 176 |
+
|
| 177 |
+
#elif defined(__sun__) || defined(__digital__)
|
| 178 |
+
|
| 179 |
+
/*
|
| 180 |
+
* Sun or Digital
|
| 181 |
+
*/
|
| 182 |
+
typedef int khronos_int32_t;
|
| 183 |
+
typedef unsigned int khronos_uint32_t;
|
| 184 |
+
#if defined(__arch64__) || defined(_LP64)
|
| 185 |
+
typedef long int khronos_int64_t;
|
| 186 |
+
typedef unsigned long int khronos_uint64_t;
|
| 187 |
+
#else
|
| 188 |
+
typedef long long int khronos_int64_t;
|
| 189 |
+
typedef unsigned long long int khronos_uint64_t;
|
| 190 |
+
#endif /* __arch64__ */
|
| 191 |
+
#define KHRONOS_SUPPORT_INT64 1
|
| 192 |
+
#define KHRONOS_SUPPORT_FLOAT 1
|
| 193 |
+
|
| 194 |
+
#elif 0
|
| 195 |
+
|
| 196 |
+
/*
|
| 197 |
+
* Hypothetical platform with no float or int64 support
|
| 198 |
+
*/
|
| 199 |
+
typedef int khronos_int32_t;
|
| 200 |
+
typedef unsigned int khronos_uint32_t;
|
| 201 |
+
#define KHRONOS_SUPPORT_INT64 0
|
| 202 |
+
#define KHRONOS_SUPPORT_FLOAT 0
|
| 203 |
+
|
| 204 |
+
#else
|
| 205 |
+
|
| 206 |
+
/*
|
| 207 |
+
* Generic fallback
|
| 208 |
+
*/
|
| 209 |
+
#include <stdint.h>
|
| 210 |
+
typedef int32_t khronos_int32_t;
|
| 211 |
+
typedef uint32_t khronos_uint32_t;
|
| 212 |
+
typedef int64_t khronos_int64_t;
|
| 213 |
+
typedef uint64_t khronos_uint64_t;
|
| 214 |
+
#define KHRONOS_SUPPORT_INT64 1
|
| 215 |
+
#define KHRONOS_SUPPORT_FLOAT 1
|
| 216 |
+
|
| 217 |
+
#endif
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
/*
|
| 221 |
+
* Types that are (so far) the same on all platforms
|
| 222 |
+
*/
|
| 223 |
+
typedef signed char khronos_int8_t;
|
| 224 |
+
typedef unsigned char khronos_uint8_t;
|
| 225 |
+
typedef signed short int khronos_int16_t;
|
| 226 |
+
typedef unsigned short int khronos_uint16_t;
|
| 227 |
+
|
| 228 |
+
/*
|
| 229 |
+
* Types that differ between LLP64 and LP64 architectures - in LLP64,
|
| 230 |
+
* pointers are 64 bits, but 'long' is still 32 bits. Win64 appears
|
| 231 |
+
* to be the only LLP64 architecture in current use.
|
| 232 |
+
*/
|
| 233 |
+
#ifdef _WIN64
|
| 234 |
+
typedef signed long long int khronos_intptr_t;
|
| 235 |
+
typedef unsigned long long int khronos_uintptr_t;
|
| 236 |
+
typedef signed long long int khronos_ssize_t;
|
| 237 |
+
typedef unsigned long long int khronos_usize_t;
|
| 238 |
+
#else
|
| 239 |
+
typedef signed long int khronos_intptr_t;
|
| 240 |
+
typedef unsigned long int khronos_uintptr_t;
|
| 241 |
+
typedef signed long int khronos_ssize_t;
|
| 242 |
+
typedef unsigned long int khronos_usize_t;
|
| 243 |
+
#endif
|
| 244 |
+
|
| 245 |
+
#if KHRONOS_SUPPORT_FLOAT
|
| 246 |
+
/*
|
| 247 |
+
* Float type
|
| 248 |
+
*/
|
| 249 |
+
typedef float khronos_float_t;
|
| 250 |
+
#endif
|
| 251 |
+
|
| 252 |
+
#if KHRONOS_SUPPORT_INT64
|
| 253 |
+
/* Time types
|
| 254 |
+
*
|
| 255 |
+
* These types can be used to represent a time interval in nanoseconds or
|
| 256 |
+
* an absolute Unadjusted System Time. Unadjusted System Time is the number
|
| 257 |
+
* of nanoseconds since some arbitrary system event (e.g. since the last
|
| 258 |
+
* time the system booted). The Unadjusted System Time is an unsigned
|
| 259 |
+
* 64 bit value that wraps back to 0 every 584 years. Time intervals
|
| 260 |
+
* may be either signed or unsigned.
|
| 261 |
+
*/
|
| 262 |
+
typedef khronos_uint64_t khronos_utime_nanoseconds_t;
|
| 263 |
+
typedef khronos_int64_t khronos_stime_nanoseconds_t;
|
| 264 |
+
#endif
|
| 265 |
+
|
| 266 |
+
/*
|
| 267 |
+
* Dummy value used to pad enum types to 32 bits.
|
| 268 |
+
*/
|
| 269 |
+
#ifndef KHRONOS_MAX_ENUM
|
| 270 |
+
#define KHRONOS_MAX_ENUM 0x7FFFFFFF
|
| 271 |
+
#endif
|
| 272 |
+
|
| 273 |
+
/*
|
| 274 |
+
* Enumerated boolean type
|
| 275 |
+
*
|
| 276 |
+
* Values other than zero should be considered to be true. Therefore
|
| 277 |
+
* comparisons should not be made against KHRONOS_TRUE.
|
| 278 |
+
*/
|
| 279 |
+
typedef enum {
|
| 280 |
+
KHRONOS_FALSE = 0,
|
| 281 |
+
KHRONOS_TRUE = 1,
|
| 282 |
+
KHRONOS_BOOLEAN_ENUM_FORCE_SIZE = KHRONOS_MAX_ENUM
|
| 283 |
+
} khronos_boolean_enum_t;
|
| 284 |
+
|
| 285 |
+
#endif /* __khrplatform_h_ */
|
mujoco-py-2.1.2.14/mujoco_py/gl/osmesashim.c
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <GL/osmesa.h>
|
| 2 |
+
#include "glshim.h"
|
| 3 |
+
|
| 4 |
+
OSMesaContext ctx;
|
| 5 |
+
|
| 6 |
+
// this size was picked pretty arbitrarily
|
| 7 |
+
int BUFFER_WIDTH = 1024;
|
| 8 |
+
int BUFFER_HEIGHT = 1024;
|
| 9 |
+
// 4 channels for RGBA
|
| 10 |
+
unsigned char buffer[1024 * 1024 * 4];
|
| 11 |
+
|
| 12 |
+
int is_initialized = 0;
|
| 13 |
+
|
| 14 |
+
int usingEGL() {
|
| 15 |
+
return 0;
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
int initOpenGL(int device_id) {
|
| 19 |
+
if (is_initialized)
|
| 20 |
+
return 1;
|
| 21 |
+
|
| 22 |
+
// note: device id not used here
|
| 23 |
+
ctx = OSMesaCreateContextExt(GL_RGBA, 24, 8, 8, 0);
|
| 24 |
+
if( !ctx ) {
|
| 25 |
+
printf("OSMesa context creation failed\n");
|
| 26 |
+
return -1;
|
| 27 |
+
}
|
| 28 |
+
|
| 29 |
+
if( !OSMesaMakeCurrent(ctx, buffer, GL_UNSIGNED_BYTE, BUFFER_WIDTH, BUFFER_HEIGHT) ) {
|
| 30 |
+
printf("OSMesa make current failed\n");
|
| 31 |
+
return -1;
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
is_initialized = 1;
|
| 35 |
+
|
| 36 |
+
return 1;
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
int makeOpenGLContextCurrent(int device_id) {
|
| 40 |
+
// Don't need to make context current here, causes issues with large tests
|
| 41 |
+
return 1;
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
int setOpenGLBufferSize(int device_id, int width, int height) {
|
| 45 |
+
if (width > BUFFER_WIDTH || height > BUFFER_HEIGHT) {
|
| 46 |
+
printf("Buffer size too big\n");
|
| 47 |
+
return -1;
|
| 48 |
+
}
|
| 49 |
+
// Noop since we don't support changing the actual buffer
|
| 50 |
+
return 1;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
void closeOpenGL() {
|
| 54 |
+
if (is_initialized) {
|
| 55 |
+
OSMesaDestroyContext(ctx);
|
| 56 |
+
is_initialized = 0;
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
unsigned int createPBO(int width, int height, int batchSize, int use_short) {
|
| 61 |
+
return 0;
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
void freePBO(unsigned int pixelBuffer) {
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
void copyFBOToPBO(mjrContext* con,
|
| 68 |
+
unsigned int pbo_rgb, unsigned int pbo_depth,
|
| 69 |
+
mjrRect viewport, int bufferOffset) {
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
void readPBO(unsigned char *buffer_rgb, unsigned short *buffer_depth,
|
| 73 |
+
unsigned int pbo_rgb, unsigned int pbo_depth,
|
| 74 |
+
int width, int height, int batchSize) {
|
| 75 |
+
}
|
mujoco-py-2.1.2.14/mujoco_py/mjbatchrenderer.pyx
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
try:
|
| 2 |
+
import pycuda.driver as drv
|
| 3 |
+
except ImportError:
|
| 4 |
+
drv = None
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class MjBatchRendererException(Exception):
|
| 8 |
+
pass
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class MjBatchRendererNotSupported(MjBatchRendererException):
|
| 12 |
+
pass
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class CudaNotEnabledError(MjBatchRendererException):
|
| 16 |
+
pass
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class CudaBufferNotMappedError(MjBatchRendererException):
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CudaBufferMappedError(MjBatchRendererException):
|
| 24 |
+
pass
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class MjBatchRenderer(object):
|
| 28 |
+
"""
|
| 29 |
+
Utility class for rendering into OpenGL Pixel Buffer Objects (PBOs),
|
| 30 |
+
which allows for accessing multiple rendered images in batch.
|
| 31 |
+
|
| 32 |
+
If used with CUDA (i.e. initialized with use_cuda=True), you need
|
| 33 |
+
to call map/unmap when accessing CUDA buffer pointer. This is to
|
| 34 |
+
ensure that all OpenGL instructions have completed:
|
| 35 |
+
|
| 36 |
+
renderer = MjBatchRenderer(100, 100, use_cuda=True)
|
| 37 |
+
renderer.render(sim)
|
| 38 |
+
|
| 39 |
+
renderer.map()
|
| 40 |
+
image = renderer.read()
|
| 41 |
+
renderer.unmap()
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(self, width, height, batch_size=1, device_id=0,
|
| 45 |
+
depth=False, use_cuda=False):
|
| 46 |
+
"""
|
| 47 |
+
Args:
|
| 48 |
+
- width (int): Image width.
|
| 49 |
+
- height (int): Image height.
|
| 50 |
+
- batch_size (int): Size of batch to render into. Memory is
|
| 51 |
+
allocated once upon initialization of object.
|
| 52 |
+
- device_id (int): Device to use for storing the batch.
|
| 53 |
+
- depth (bool): if True, render depth in addition to RGB.
|
| 54 |
+
- use_cuda (bool): if True, use OpenGL-CUDA interop to map
|
| 55 |
+
the PBO onto a CUDA buffer.
|
| 56 |
+
"""
|
| 57 |
+
# Early initialization to prevent failure in __del__
|
| 58 |
+
self._use_cuda = False
|
| 59 |
+
self.pbo_depth, self.pbo_depth = 0, 0
|
| 60 |
+
|
| 61 |
+
if not usingEGL():
|
| 62 |
+
raise MjBatchRendererNotSupported(
|
| 63 |
+
"MjBatchRenderer currently only supported with EGL-backed"
|
| 64 |
+
"rendering context.")
|
| 65 |
+
|
| 66 |
+
# Make sure OpenGL Context is available before creating PBOs
|
| 67 |
+
initOpenGL(device_id)
|
| 68 |
+
makeOpenGLContextCurrent(device_id)
|
| 69 |
+
|
| 70 |
+
self.pbo_rgb = createPBO(width, height, batch_size, 0)
|
| 71 |
+
self.pbo_depth = createPBO(width, height, batch_size, 1) if depth else 0
|
| 72 |
+
|
| 73 |
+
self._depth = depth
|
| 74 |
+
self._device_id = device_id
|
| 75 |
+
self._width = width
|
| 76 |
+
self._height = height
|
| 77 |
+
self._batch_size = batch_size
|
| 78 |
+
self._current_batch_offset = 0
|
| 79 |
+
|
| 80 |
+
self._use_cuda = use_cuda
|
| 81 |
+
self._cuda_buffers_are_mapped = False
|
| 82 |
+
self._cuda_rgb_ptr, self._cuda_depth_ptr = None, None
|
| 83 |
+
if use_cuda:
|
| 84 |
+
self._init_cuda()
|
| 85 |
+
|
| 86 |
+
def _init_cuda(self):
|
| 87 |
+
if drv is None:
|
| 88 |
+
raise ImportError("Failed to import pycuda.")
|
| 89 |
+
# Use local imports so that we don't have to make pycuda
|
| 90 |
+
# opengl interop a requirement
|
| 91 |
+
from pycuda.gl import RegisteredBuffer
|
| 92 |
+
|
| 93 |
+
drv.init()
|
| 94 |
+
device = drv.Device(self._device_id)
|
| 95 |
+
self._cuda_context = device.make_context()
|
| 96 |
+
self._cuda_context.push()
|
| 97 |
+
|
| 98 |
+
self._cuda_rgb_pbo = RegisteredBuffer(self.pbo_rgb)
|
| 99 |
+
if self._depth:
|
| 100 |
+
self._cuda_depth_pbo = RegisteredBuffer(self.pbo_depth)
|
| 101 |
+
|
| 102 |
+
def map(self):
|
| 103 |
+
""" Map OpenGL buffer to CUDA for reading. """
|
| 104 |
+
if not self._use_cuda:
|
| 105 |
+
raise CudaNotEnabledError()
|
| 106 |
+
elif self._cuda_buffers_are_mapped:
|
| 107 |
+
return # just make it a no-op
|
| 108 |
+
|
| 109 |
+
self._cuda_context.push()
|
| 110 |
+
self._cuda_rgb_mapping = self._cuda_rgb_pbo.map()
|
| 111 |
+
ptr, self._cuda_rgb_buf_size = (
|
| 112 |
+
self._cuda_rgb_mapping.device_ptr_and_size())
|
| 113 |
+
assert ptr is not None and self._cuda_rgb_buf_size > 0
|
| 114 |
+
if self._cuda_rgb_ptr is None:
|
| 115 |
+
self._cuda_rgb_ptr = ptr
|
| 116 |
+
|
| 117 |
+
# There doesn't seem to be a guarantee from the API that the
|
| 118 |
+
# pointer will be the same between mappings, but empirically
|
| 119 |
+
# this has been true. If this isn't true, we need to modify
|
| 120 |
+
# the interface to MjBatchRenderer to make this clearer to user.
|
| 121 |
+
# So, hopefully we won't hit this assert.
|
| 122 |
+
assert self._cuda_rgb_ptr == ptr, (
|
| 123 |
+
"Mapped CUDA rgb buffer pointer %d doesn't match old pointer %d" %
|
| 124 |
+
(ptr, self._cuda_rgb_ptr))
|
| 125 |
+
|
| 126 |
+
if self._depth:
|
| 127 |
+
self._cuda_depth_mapping = self._cuda_depth_pbo.map()
|
| 128 |
+
ptr, self._cuda_depth_buf_size = (
|
| 129 |
+
self._cuda_depth_mapping.device_ptr_and_size())
|
| 130 |
+
assert ptr is not None and self._cuda_depth_buf_size > 0
|
| 131 |
+
if self._cuda_depth_ptr is None:
|
| 132 |
+
self._cuda_depth_ptr = ptr
|
| 133 |
+
assert self._cuda_depth_ptr == ptr, (
|
| 134 |
+
"Mapped CUDA depth buffer pointer %d doesn't match old pointer %d" %
|
| 135 |
+
(ptr, self._cuda_depth_ptr))
|
| 136 |
+
|
| 137 |
+
self._cuda_buffers_are_mapped = True
|
| 138 |
+
|
| 139 |
+
def unmap(self):
|
| 140 |
+
""" Unmap OpenGL buffer from CUDA so that it can be rendered into. """
|
| 141 |
+
if not self._use_cuda:
|
| 142 |
+
raise CudaNotEnabledError()
|
| 143 |
+
elif not self._cuda_buffers_are_mapped:
|
| 144 |
+
return # just make it a no-op
|
| 145 |
+
|
| 146 |
+
self._cuda_context.push()
|
| 147 |
+
self._cuda_rgb_mapping.unmap()
|
| 148 |
+
self._cuda_rgb_mapping = None
|
| 149 |
+
self._cuda_rgb_ptr = None
|
| 150 |
+
if self._depth:
|
| 151 |
+
self._cuda_depth_mapping.unmap()
|
| 152 |
+
self._cuda_depth_mapping = None
|
| 153 |
+
self._cuda_depth_ptr = None
|
| 154 |
+
|
| 155 |
+
self._cuda_buffers_are_mapped = False
|
| 156 |
+
|
| 157 |
+
def prepare_render_context(self, sim):
|
| 158 |
+
"""
|
| 159 |
+
Set up the rendering context for an MjSim. Also happens automatically
|
| 160 |
+
on `.render()`.
|
| 161 |
+
"""
|
| 162 |
+
for c in sim.render_contexts:
|
| 163 |
+
if (c.offscreen and
|
| 164 |
+
isinstance(c.opengl_context, OffscreenOpenGLContext) and
|
| 165 |
+
c.opengl_context.device_id == self._device_id):
|
| 166 |
+
return c
|
| 167 |
+
|
| 168 |
+
return MjRenderContext(sim, device_id=self._device_id)
|
| 169 |
+
|
| 170 |
+
def render(self, sim, camera_id=None, batch_offset=None):
|
| 171 |
+
"""
|
| 172 |
+
Render current scene from the MjSim into the buffer. By
|
| 173 |
+
default the batch offset is automatically incremented with
|
| 174 |
+
each call. It can be reset with the batch_offset parameter.
|
| 175 |
+
|
| 176 |
+
This method doesn't return anything. Use the `.read` method
|
| 177 |
+
to read the buffer, or access the buffer pointer directly with
|
| 178 |
+
e.g. `.cuda_rgb_buffer_pointer` accessor.
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
- sim (MjSim): The simulator to use for rendering.
|
| 182 |
+
- camera_id (int): MuJoCo id for the camera, from
|
| 183 |
+
`sim.model.camera_name2id()`.
|
| 184 |
+
- batch_offset (int): offset in batch to render to.
|
| 185 |
+
"""
|
| 186 |
+
if self._use_cuda and self._cuda_buffers_are_mapped:
|
| 187 |
+
raise CudaBufferMappedError(
|
| 188 |
+
"CUDA buffers must be unmapped before calling render.")
|
| 189 |
+
|
| 190 |
+
if batch_offset is not None:
|
| 191 |
+
if batch_offset < 0 or batch_offset >= self._batch_size:
|
| 192 |
+
raise ValueError("batch_offset out of range")
|
| 193 |
+
self._current_batch_offset = batch_offset
|
| 194 |
+
|
| 195 |
+
# Ensure the correct device context is used (this takes ~1 µs)
|
| 196 |
+
makeOpenGLContextCurrent(self._device_id)
|
| 197 |
+
|
| 198 |
+
render_context = self.prepare_render_context(sim)
|
| 199 |
+
render_context.update_offscreen_size(self._width, self._height)
|
| 200 |
+
render_context.render(self._width, self._height, camera_id=camera_id)
|
| 201 |
+
|
| 202 |
+
cdef mjrRect viewport
|
| 203 |
+
viewport.left = 0
|
| 204 |
+
viewport.bottom = 0
|
| 205 |
+
viewport.width = self._width
|
| 206 |
+
viewport.height = self._height
|
| 207 |
+
|
| 208 |
+
cdef PyMjrContext con = <PyMjrContext> render_context.con
|
| 209 |
+
copyFBOToPBO(con.ptr, self.pbo_rgb, self.pbo_depth,
|
| 210 |
+
viewport, self._current_batch_offset)
|
| 211 |
+
|
| 212 |
+
self._current_batch_offset = (self._current_batch_offset + 1) % self._batch_size
|
| 213 |
+
|
| 214 |
+
def read(self):
|
| 215 |
+
"""
|
| 216 |
+
Transfer a copy of the buffer from the GPU to the CPU as a numpy array.
|
| 217 |
+
|
| 218 |
+
Returns:
|
| 219 |
+
- rgb_batch (numpy array): batch of rgb images in uint8 NHWC format
|
| 220 |
+
- depth_batch (numpy array): batch of depth images in uint16 NHWC format
|
| 221 |
+
"""
|
| 222 |
+
if self._use_cuda:
|
| 223 |
+
return self._read_cuda()
|
| 224 |
+
else:
|
| 225 |
+
return self._read_nocuda()
|
| 226 |
+
|
| 227 |
+
def _read_cuda(self):
|
| 228 |
+
if not self._cuda_buffers_are_mapped:
|
| 229 |
+
raise CudaBufferNotMappedError(
|
| 230 |
+
"CUDA buffers must be mapped before reading")
|
| 231 |
+
|
| 232 |
+
rgb_arr = drv.from_device(
|
| 233 |
+
self._cuda_rgb_ptr,
|
| 234 |
+
shape=(self._batch_size, self._height, self._width, 3),
|
| 235 |
+
dtype=np.uint8)
|
| 236 |
+
|
| 237 |
+
if self._depth:
|
| 238 |
+
depth_arr = drv.from_device(
|
| 239 |
+
self._cuda_depth_ptr,
|
| 240 |
+
shape=(self._batch_size, self._height, self._width),
|
| 241 |
+
dtype=np.uint16)
|
| 242 |
+
else:
|
| 243 |
+
depth_arr = None
|
| 244 |
+
|
| 245 |
+
return rgb_arr, depth_arr
|
| 246 |
+
|
| 247 |
+
def _read_nocuda(self):
|
| 248 |
+
rgb_arr = np.zeros(3 * self._width * self._height * self._batch_size, dtype=np.uint8)
|
| 249 |
+
cdef unsigned char[::view.contiguous] rgb_view = rgb_arr
|
| 250 |
+
depth_arr = np.zeros(self._width * self._height * self._batch_size, dtype=np.uint16)
|
| 251 |
+
cdef unsigned short[::view.contiguous] depth_view = depth_arr
|
| 252 |
+
|
| 253 |
+
if self._depth:
|
| 254 |
+
readPBO(&rgb_view[0], &depth_view[0], self.pbo_rgb, self.pbo_depth,
|
| 255 |
+
self._width, self._height, self._batch_size)
|
| 256 |
+
depth_arr = depth_arr.reshape(self._batch_size, self._height, self._width)
|
| 257 |
+
else:
|
| 258 |
+
readPBO(&rgb_view[0], NULL, self.pbo_rgb, 0,
|
| 259 |
+
self._width, self._height, self._batch_size)
|
| 260 |
+
# Fine to throw aray depth_arr above since malloc/free is cheap
|
| 261 |
+
depth_arr = None
|
| 262 |
+
|
| 263 |
+
rgb_arr = rgb_arr.reshape(self._batch_size, self._height, self._width, 3)
|
| 264 |
+
return rgb_arr, depth_arr
|
| 265 |
+
|
| 266 |
+
@property
|
| 267 |
+
def cuda_rgb_buffer_pointer(self):
|
| 268 |
+
""" Pointer to CUDA buffer for RGB batch. """
|
| 269 |
+
if not self._use_cuda:
|
| 270 |
+
raise CudaNotEnabledError()
|
| 271 |
+
elif not self._cuda_buffers_are_mapped:
|
| 272 |
+
raise CudaBufferNotMappedError()
|
| 273 |
+
return self._cuda_rgb_ptr
|
| 274 |
+
|
| 275 |
+
@property
|
| 276 |
+
def cuda_depth_buffer_pointer(self):
|
| 277 |
+
""" Pointer to CUDA buffer for depth batch. """
|
| 278 |
+
if not self._use_cuda:
|
| 279 |
+
raise CudaNotEnabledError()
|
| 280 |
+
elif not self._cuda_buffers_are_mapped:
|
| 281 |
+
raise CudaBufferNotMappedError()
|
| 282 |
+
if not self._depth:
|
| 283 |
+
raise RuntimeError("Depth not enabled. Use depth=True on initialization.")
|
| 284 |
+
return self._cuda_depth_ptr
|
| 285 |
+
|
| 286 |
+
def __del__(self):
|
| 287 |
+
if self._use_cuda:
|
| 288 |
+
self._cuda_context.push()
|
| 289 |
+
self.unmap()
|
| 290 |
+
self._cuda_rgb_pbo.unregister()
|
| 291 |
+
if self._depth:
|
| 292 |
+
self._cuda_depth_pbo.unregister()
|
| 293 |
+
|
| 294 |
+
# Clean up context
|
| 295 |
+
drv.Context.pop()
|
| 296 |
+
self._cuda_context.detach()
|
| 297 |
+
|
| 298 |
+
if self.pbo_depth:
|
| 299 |
+
freePBO(self.pbo_rgb)
|
| 300 |
+
if self.pbo_depth:
|
| 301 |
+
freePBO(self.pbo_depth)
|
mujoco-py-2.1.2.14/mujoco_py/mjrendercontext.pyx
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from threading import Lock
|
| 2 |
+
from mujoco_py.generated import const
|
| 3 |
+
import numpy as np
|
| 4 |
+
cimport numpy as np
|
| 5 |
+
|
| 6 |
+
cdef class MjRenderContext(object):
|
| 7 |
+
"""
|
| 8 |
+
Class that encapsulates rendering functionality for a
|
| 9 |
+
MuJoCo simulation.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
cdef mjModel *_model_ptr
|
| 13 |
+
cdef mjData *_data_ptr
|
| 14 |
+
|
| 15 |
+
cdef mjvScene _scn
|
| 16 |
+
cdef mjvCamera _cam
|
| 17 |
+
cdef mjvOption _vopt
|
| 18 |
+
cdef mjvPerturb _pert
|
| 19 |
+
cdef mjrContext _con
|
| 20 |
+
|
| 21 |
+
# Public wrappers
|
| 22 |
+
cdef readonly PyMjvScene scn
|
| 23 |
+
cdef readonly PyMjvCamera cam
|
| 24 |
+
cdef readonly PyMjvOption vopt
|
| 25 |
+
cdef readonly PyMjvPerturb pert
|
| 26 |
+
cdef readonly PyMjrContext con
|
| 27 |
+
|
| 28 |
+
cdef readonly object opengl_context
|
| 29 |
+
cdef readonly int _visible
|
| 30 |
+
cdef readonly list _markers
|
| 31 |
+
cdef readonly dict _overlay
|
| 32 |
+
|
| 33 |
+
cdef readonly bint offscreen
|
| 34 |
+
cdef public object sim
|
| 35 |
+
|
| 36 |
+
def __cinit__(self):
|
| 37 |
+
maxgeom = 1000
|
| 38 |
+
mjv_makeScene(self._model_ptr, &self._scn, maxgeom)
|
| 39 |
+
mjv_defaultCamera(&self._cam)
|
| 40 |
+
mjv_defaultPerturb(&self._pert)
|
| 41 |
+
mjv_defaultOption(&self._vopt)
|
| 42 |
+
mjr_defaultContext(&self._con)
|
| 43 |
+
|
| 44 |
+
def __init__(self, MjSim sim, bint offscreen=True, int device_id=-1, opengl_backend=None, quiet=False):
|
| 45 |
+
self.sim = sim
|
| 46 |
+
self._setup_opengl_context(offscreen, device_id, opengl_backend, quiet=quiet)
|
| 47 |
+
self.offscreen = offscreen
|
| 48 |
+
|
| 49 |
+
# Ensure the model data has been updated so that there
|
| 50 |
+
# is something to render
|
| 51 |
+
sim.forward()
|
| 52 |
+
|
| 53 |
+
sim.add_render_context(self)
|
| 54 |
+
|
| 55 |
+
self._model_ptr = sim.model.ptr
|
| 56 |
+
self._data_ptr = sim.data.ptr
|
| 57 |
+
self.scn = WrapMjvScene(&self._scn)
|
| 58 |
+
self.cam = WrapMjvCamera(&self._cam)
|
| 59 |
+
self.vopt = WrapMjvOption(&self._vopt)
|
| 60 |
+
self.con = WrapMjrContext(&self._con)
|
| 61 |
+
|
| 62 |
+
self._pert.active = 0
|
| 63 |
+
self._pert.select = 0
|
| 64 |
+
self._pert.skinselect = -1
|
| 65 |
+
|
| 66 |
+
self.pert = WrapMjvPerturb(&self._pert)
|
| 67 |
+
|
| 68 |
+
self._markers = []
|
| 69 |
+
self._overlay = {}
|
| 70 |
+
|
| 71 |
+
self._init_camera(sim)
|
| 72 |
+
self._set_mujoco_buffers()
|
| 73 |
+
|
| 74 |
+
def update_sim(self, MjSim new_sim):
|
| 75 |
+
if new_sim == self.sim:
|
| 76 |
+
return
|
| 77 |
+
self._model_ptr = new_sim.model.ptr
|
| 78 |
+
self._data_ptr = new_sim.data.ptr
|
| 79 |
+
self._set_mujoco_buffers()
|
| 80 |
+
for render_context in self.sim.render_contexts:
|
| 81 |
+
new_sim.add_render_context(render_context)
|
| 82 |
+
self.sim = new_sim
|
| 83 |
+
|
| 84 |
+
def _set_mujoco_buffers(self):
|
| 85 |
+
mjr_makeContext(self._model_ptr, &self._con, mjFONTSCALE_150)
|
| 86 |
+
if self.offscreen:
|
| 87 |
+
mjr_setBuffer(mjFB_OFFSCREEN, &self._con);
|
| 88 |
+
if self._con.currentBuffer != mjFB_OFFSCREEN:
|
| 89 |
+
raise RuntimeError('Offscreen rendering not supported')
|
| 90 |
+
else:
|
| 91 |
+
mjr_setBuffer(mjFB_WINDOW, &self._con);
|
| 92 |
+
if self._con.currentBuffer != mjFB_WINDOW:
|
| 93 |
+
raise RuntimeError('Window rendering not supported')
|
| 94 |
+
self.con = WrapMjrContext(&self._con)
|
| 95 |
+
|
| 96 |
+
def _setup_opengl_context(self, offscreen, device_id, opengl_backend, quiet=False):
|
| 97 |
+
if opengl_backend is None and (not offscreen or sys.platform == 'darwin'):
|
| 98 |
+
# default to glfw for onscreen viewing or mac (both offscreen/onscreen)
|
| 99 |
+
opengl_backend = 'glfw'
|
| 100 |
+
|
| 101 |
+
if opengl_backend == 'glfw':
|
| 102 |
+
self.opengl_context = GlfwContext(offscreen=offscreen, quiet=quiet)
|
| 103 |
+
else:
|
| 104 |
+
if device_id < 0:
|
| 105 |
+
if "GPUS" in os.environ:
|
| 106 |
+
device_id = os.environ["GPUS"]
|
| 107 |
+
else:
|
| 108 |
+
device_id = os.getenv('CUDA_VISIBLE_DEVICES', '')
|
| 109 |
+
if len(device_id) > 0:
|
| 110 |
+
device_id = int(device_id.split(',')[0])
|
| 111 |
+
else:
|
| 112 |
+
# Sometimes env variable is an empty string.
|
| 113 |
+
device_id = 0
|
| 114 |
+
self.opengl_context = OffscreenOpenGLContext(device_id)
|
| 115 |
+
|
| 116 |
+
def _init_camera(self, sim):
|
| 117 |
+
# Make the free camera look at the scene
|
| 118 |
+
self.cam.type = const.CAMERA_FREE
|
| 119 |
+
self.cam.fixedcamid = -1
|
| 120 |
+
for i in range(3):
|
| 121 |
+
self.cam.lookat[i] = np.median(sim.data.geom_xpos[:, i])
|
| 122 |
+
self.cam.distance = sim.model.stat.extent
|
| 123 |
+
|
| 124 |
+
def update_offscreen_size(self, width, height):
|
| 125 |
+
if width != self._con.offWidth or height != self._con.offHeight:
|
| 126 |
+
self._model_ptr.vis.global_.offwidth = width
|
| 127 |
+
self._model_ptr.vis.global_.offheight = height
|
| 128 |
+
mjr_freeContext(&self._con)
|
| 129 |
+
self._set_mujoco_buffers()
|
| 130 |
+
|
| 131 |
+
def render(self, width, height, camera_id=None, segmentation=False):
|
| 132 |
+
cdef mjrRect rect
|
| 133 |
+
rect.left = 0
|
| 134 |
+
rect.bottom = 0
|
| 135 |
+
rect.width = width
|
| 136 |
+
rect.height = height
|
| 137 |
+
|
| 138 |
+
if self.sim.render_callback is not None:
|
| 139 |
+
self.sim.render_callback(self.sim, self)
|
| 140 |
+
|
| 141 |
+
# Sometimes buffers are too small.
|
| 142 |
+
if width > self._con.offWidth or height > self._con.offHeight:
|
| 143 |
+
new_width = max(width, self._model_ptr.vis.global_.offwidth)
|
| 144 |
+
new_height = max(height, self._model_ptr.vis.global_.offheight)
|
| 145 |
+
self.update_offscreen_size(new_width, new_height)
|
| 146 |
+
|
| 147 |
+
if camera_id is not None:
|
| 148 |
+
if camera_id == -1:
|
| 149 |
+
self.cam.type = const.CAMERA_FREE
|
| 150 |
+
else:
|
| 151 |
+
self.cam.type = const.CAMERA_FIXED
|
| 152 |
+
self.cam.fixedcamid = camera_id
|
| 153 |
+
|
| 154 |
+
# This doesn't really do anything else rather than checking for the size of buffer
|
| 155 |
+
# need to investigate further whi is that a no-op
|
| 156 |
+
# self.opengl_context.set_buffer_size(width, height)
|
| 157 |
+
|
| 158 |
+
mjv_updateScene(self._model_ptr, self._data_ptr, &self._vopt,
|
| 159 |
+
&self._pert, &self._cam, mjCAT_ALL, &self._scn)
|
| 160 |
+
|
| 161 |
+
if segmentation:
|
| 162 |
+
self._scn.flags[const.RND_SEGMENT] = 1
|
| 163 |
+
self._scn.flags[const.RND_IDCOLOR] = 1
|
| 164 |
+
|
| 165 |
+
for marker_params in self._markers:
|
| 166 |
+
self._add_marker_to_scene(marker_params)
|
| 167 |
+
|
| 168 |
+
mjr_render(rect, &self._scn, &self._con)
|
| 169 |
+
for gridpos, (text1, text2) in self._overlay.items():
|
| 170 |
+
mjr_overlay(const.FONTSCALE_150, gridpos, rect, text1.encode(), text2.encode(), &self._con)
|
| 171 |
+
|
| 172 |
+
if segmentation:
|
| 173 |
+
self._scn.flags[const.RND_SEGMENT] = 0
|
| 174 |
+
self._scn.flags[const.RND_IDCOLOR] = 0
|
| 175 |
+
|
| 176 |
+
def read_pixels(self, width, height, depth=True, segmentation=False):
|
| 177 |
+
cdef mjrRect rect
|
| 178 |
+
rect.left = 0
|
| 179 |
+
rect.bottom = 0
|
| 180 |
+
rect.width = width
|
| 181 |
+
rect.height = height
|
| 182 |
+
|
| 183 |
+
rgb_arr = np.zeros(3 * rect.width * rect.height, dtype=np.uint8)
|
| 184 |
+
depth_arr = np.zeros(rect.width * rect.height, dtype=np.float32)
|
| 185 |
+
|
| 186 |
+
cdef unsigned char[::view.contiguous] rgb_view = rgb_arr
|
| 187 |
+
cdef float[::view.contiguous] depth_view = depth_arr
|
| 188 |
+
mjr_readPixels(&rgb_view[0], &depth_view[0], rect, &self._con)
|
| 189 |
+
rgb_img = rgb_arr.reshape(rect.height, rect.width, 3)
|
| 190 |
+
cdef np.ndarray[np.npy_uint32, ndim=2] seg_img
|
| 191 |
+
cdef np.ndarray[np.npy_int32, ndim=2] seg_ids
|
| 192 |
+
|
| 193 |
+
ret_img = rgb_img
|
| 194 |
+
if segmentation:
|
| 195 |
+
seg_img = (rgb_img[:, :, 0] + rgb_img[:, :, 1] * (2**8) + rgb_img[:, :, 2] * (2 ** 16))
|
| 196 |
+
seg_img[seg_img >= (self._scn.ngeom + 1)] = 0
|
| 197 |
+
seg_ids = np.full((self._scn.ngeom + 1, 2), fill_value=-1, dtype=np.int32)
|
| 198 |
+
|
| 199 |
+
for i in range(self._scn.ngeom):
|
| 200 |
+
geom = self._scn.geoms[i]
|
| 201 |
+
if geom.segid != -1:
|
| 202 |
+
seg_ids[geom.segid + 1, 0] = geom.objtype
|
| 203 |
+
seg_ids[geom.segid + 1, 1] = geom.objid
|
| 204 |
+
ret_img = seg_ids[seg_img]
|
| 205 |
+
|
| 206 |
+
if depth:
|
| 207 |
+
depth_img = depth_arr.reshape(rect.height, rect.width)
|
| 208 |
+
return (ret_img, depth_img)
|
| 209 |
+
else:
|
| 210 |
+
return ret_img
|
| 211 |
+
|
| 212 |
+
def read_pixels_depth(self, np.ndarray[np.float32_t, mode="c", ndim=2] buffer):
|
| 213 |
+
''' Read depth pixels into a preallocated buffer '''
|
| 214 |
+
cdef mjrRect rect
|
| 215 |
+
rect.left = 0
|
| 216 |
+
rect.bottom = 0
|
| 217 |
+
rect.width = buffer.shape[1]
|
| 218 |
+
rect.height = buffer.shape[0]
|
| 219 |
+
|
| 220 |
+
cdef float[::view.contiguous] buffer_view = buffer.ravel()
|
| 221 |
+
mjr_readPixels(NULL, &buffer_view[0], rect, &self._con)
|
| 222 |
+
|
| 223 |
+
def upload_texture(self, int tex_id):
|
| 224 |
+
""" Uploads given texture to the GPU. """
|
| 225 |
+
self.opengl_context.make_context_current()
|
| 226 |
+
mjr_uploadTexture(self._model_ptr, &self._con, tex_id)
|
| 227 |
+
|
| 228 |
+
def draw_pixels(self, np.ndarray[np.uint8_t, ndim=3] image, int left, int bottom):
|
| 229 |
+
"""Draw an image into the OpenGL buffer."""
|
| 230 |
+
cdef unsigned char[::view.contiguous] image_view = image.ravel()
|
| 231 |
+
cdef mjrRect viewport
|
| 232 |
+
viewport.left = left
|
| 233 |
+
viewport.bottom = bottom
|
| 234 |
+
viewport.width = image.shape[1]
|
| 235 |
+
viewport.height = image.shape[0]
|
| 236 |
+
mjr_drawPixels(&image_view[0], NULL, viewport, &self._con)
|
| 237 |
+
|
| 238 |
+
def move_camera(self, int action, double reldx, double reldy):
|
| 239 |
+
""" Moves the camera based on mouse movements. Action is one of mjMOUSE_*. """
|
| 240 |
+
mjv_moveCamera(self._model_ptr, action, reldx, reldy, &self._scn, &self._cam)
|
| 241 |
+
|
| 242 |
+
def add_overlay(self, int gridpos, str text1, str text2):
|
| 243 |
+
""" Overlays text on the scene. """
|
| 244 |
+
if gridpos not in self._overlay:
|
| 245 |
+
self._overlay[gridpos] = ["", ""]
|
| 246 |
+
self._overlay[gridpos][0] += text1 + "\n"
|
| 247 |
+
self._overlay[gridpos][1] += text2 + "\n"
|
| 248 |
+
|
| 249 |
+
def add_marker(self, **marker_params):
|
| 250 |
+
self._markers.append(marker_params)
|
| 251 |
+
|
| 252 |
+
def _add_marker_to_scene(self, marker_params):
|
| 253 |
+
""" Adds marker to scene, and returns the corresponding object. """
|
| 254 |
+
if self._scn.ngeom >= self._scn.maxgeom:
|
| 255 |
+
raise RuntimeError('Ran out of geoms. maxgeom: %d' % self._scn.maxgeom)
|
| 256 |
+
|
| 257 |
+
cdef mjvGeom *g = self._scn.geoms + self._scn.ngeom
|
| 258 |
+
|
| 259 |
+
# default values.
|
| 260 |
+
g.dataid = -1
|
| 261 |
+
g.objtype = const.OBJ_UNKNOWN
|
| 262 |
+
g.objid = -1
|
| 263 |
+
g.category = const.CAT_DECOR
|
| 264 |
+
g.texid = -1
|
| 265 |
+
g.texuniform = 0
|
| 266 |
+
g.texrepeat[0] = 1
|
| 267 |
+
g.texrepeat[1] = 1
|
| 268 |
+
g.emission = 0
|
| 269 |
+
g.specular = 0.5
|
| 270 |
+
g.shininess = 0.5
|
| 271 |
+
g.reflectance = 0
|
| 272 |
+
g.type = const.GEOM_BOX
|
| 273 |
+
g.size[:] = np.ones(3) * 0.1
|
| 274 |
+
g.mat[:] = np.eye(3).flatten()
|
| 275 |
+
g.rgba[:] = np.ones(4)
|
| 276 |
+
wrapped = WrapMjvGeom(g)
|
| 277 |
+
|
| 278 |
+
for key, value in marker_params.items():
|
| 279 |
+
if isinstance(value, (int, float)):
|
| 280 |
+
setattr(wrapped, key, value)
|
| 281 |
+
elif isinstance(value, (tuple, list, np.ndarray)):
|
| 282 |
+
attr = getattr(wrapped, key)
|
| 283 |
+
attr[:] = np.asarray(value).reshape(attr.shape)
|
| 284 |
+
elif isinstance(value, str):
|
| 285 |
+
assert key == "label", "Only label is a string in mjvGeom."
|
| 286 |
+
if value == None:
|
| 287 |
+
g.label[0] = 0
|
| 288 |
+
else:
|
| 289 |
+
strncpy(g.label, value.encode(), 100)
|
| 290 |
+
elif hasattr(wrapped, key):
|
| 291 |
+
raise ValueError("mjvGeom has attr {} but type {} is invalid".format(key, type(value)))
|
| 292 |
+
else:
|
| 293 |
+
raise ValueError("mjvGeom doesn't have field %s" % key)
|
| 294 |
+
|
| 295 |
+
self._scn.ngeom += 1
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def __dealloc__(self):
|
| 299 |
+
mjr_freeContext(&self._con)
|
| 300 |
+
mjv_freeScene(&self._scn)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
class MjRenderContextOffscreen(MjRenderContext):
|
| 304 |
+
|
| 305 |
+
def __cinit__(self, MjSim sim, int device_id):
|
| 306 |
+
super().__init__(sim, offscreen=True, device_id=device_id)
|
| 307 |
+
|
| 308 |
+
class MjRenderContextWindow(MjRenderContext):
|
| 309 |
+
|
| 310 |
+
def __init__(self, MjSim sim):
|
| 311 |
+
super().__init__(sim, offscreen=False)
|
| 312 |
+
self.render_swap_callback = None
|
| 313 |
+
|
| 314 |
+
assert isinstance(self.opengl_context, GlfwContext), (
|
| 315 |
+
"Only GlfwContext supported for windowed rendering")
|
| 316 |
+
|
| 317 |
+
@property
|
| 318 |
+
def window(self):
|
| 319 |
+
return self.opengl_context.window
|
| 320 |
+
|
| 321 |
+
def render(self):
|
| 322 |
+
if self.window is None or glfw.window_should_close(self.window):
|
| 323 |
+
return
|
| 324 |
+
|
| 325 |
+
glfw.make_context_current(self.window)
|
| 326 |
+
super().render(*glfw.get_framebuffer_size(self.window))
|
| 327 |
+
if self.render_swap_callback is not None:
|
| 328 |
+
self.render_swap_callback()
|
| 329 |
+
glfw.swap_buffers(self.window)
|
mujoco-py-2.1.2.14/mujoco_py/mjrenderpool.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ctypes
|
| 2 |
+
import inspect
|
| 3 |
+
|
| 4 |
+
from multiprocessing import Array, get_start_method, Pool, Value
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class RenderPoolStorage:
|
| 10 |
+
"""
|
| 11 |
+
Helper object used for storing global data for worker processes.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
__slots__ = ['shared_rgbs_array',
|
| 15 |
+
'shared_depths_array',
|
| 16 |
+
'device_id',
|
| 17 |
+
'sim',
|
| 18 |
+
'modder']
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class MjRenderPool:
|
| 22 |
+
"""
|
| 23 |
+
Utilizes a process pool to render a MuJoCo simulation across
|
| 24 |
+
multiple GPU devices. This can scale the throughput linearly
|
| 25 |
+
with the number of available GPUs. Throughput can also be
|
| 26 |
+
slightly increased by using more than one worker per GPU.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
DEFAULT_MAX_IMAGE_SIZE = 512 * 512 # in pixels
|
| 30 |
+
|
| 31 |
+
def __init__(self, model, device_ids=1, n_workers=None,
|
| 32 |
+
max_batch_size=None, max_image_size=DEFAULT_MAX_IMAGE_SIZE,
|
| 33 |
+
modder=None):
|
| 34 |
+
"""
|
| 35 |
+
Args:
|
| 36 |
+
- model (PyMjModel): MuJoCo model to use for rendering
|
| 37 |
+
- device_ids (int/list): list of device ids to use for rendering.
|
| 38 |
+
One or more workers will be assigned to each device, depending
|
| 39 |
+
on how many workers are requested.
|
| 40 |
+
- n_workers (int): number of parallel processes in the pool. Defaults
|
| 41 |
+
to the number of device ids.
|
| 42 |
+
- max_batch_size (int): maximum number of states that can be rendered
|
| 43 |
+
in batch using .render(). Defaults to the number of workers.
|
| 44 |
+
- max_image_size (int): maximum number pixels in images requested
|
| 45 |
+
by .render()
|
| 46 |
+
- modder (Modder): modder to use for domain randomization.
|
| 47 |
+
"""
|
| 48 |
+
self._closed, self.pool = False, None
|
| 49 |
+
|
| 50 |
+
if not (modder is None or inspect.isclass(modder)):
|
| 51 |
+
raise ValueError("modder must be a class")
|
| 52 |
+
|
| 53 |
+
if isinstance(device_ids, int):
|
| 54 |
+
device_ids = list(range(device_ids))
|
| 55 |
+
else:
|
| 56 |
+
assert isinstance(device_ids, list), (
|
| 57 |
+
"device_ids must be list of integer")
|
| 58 |
+
|
| 59 |
+
n_workers = n_workers or 1
|
| 60 |
+
self._max_batch_size = max_batch_size or (len(device_ids) * n_workers)
|
| 61 |
+
self._max_image_size = max_image_size
|
| 62 |
+
|
| 63 |
+
array_size = self._max_image_size * self._max_batch_size
|
| 64 |
+
|
| 65 |
+
self._shared_rgbs = Array(ctypes.c_uint8, array_size * 3)
|
| 66 |
+
self._shared_depths = Array(ctypes.c_float, array_size)
|
| 67 |
+
|
| 68 |
+
self._shared_rgbs_array = np.frombuffer(
|
| 69 |
+
self._shared_rgbs.get_obj(), dtype=ctypes.c_uint8)
|
| 70 |
+
assert self._shared_rgbs_array.size == (array_size * 3), (
|
| 71 |
+
"Array size is %d, expected %d" % (
|
| 72 |
+
self._shared_rgbs_array.size, array_size * 3))
|
| 73 |
+
self._shared_depths_array = np.frombuffer(
|
| 74 |
+
self._shared_depths.get_obj(), dtype=ctypes.c_float)
|
| 75 |
+
assert self._shared_depths_array.size == array_size, (
|
| 76 |
+
"Array size is %d, expected %d" % (
|
| 77 |
+
self._shared_depths_array.size, array_size))
|
| 78 |
+
|
| 79 |
+
worker_id = Value(ctypes.c_int)
|
| 80 |
+
worker_id.value = 0
|
| 81 |
+
|
| 82 |
+
if get_start_method() != "spawn":
|
| 83 |
+
raise RuntimeError(
|
| 84 |
+
"Start method must be set to 'spawn' for the "
|
| 85 |
+
"render pool to work. That is, you must add the "
|
| 86 |
+
"following to the _TOP_ of your main script, "
|
| 87 |
+
"before any other imports (since they might be "
|
| 88 |
+
"setting it otherwise):\n"
|
| 89 |
+
" import multiprocessing as mp\n"
|
| 90 |
+
" if __name__ == '__main__':\n"
|
| 91 |
+
" mp.set_start_method('spawn')\n")
|
| 92 |
+
|
| 93 |
+
self.pool = Pool(
|
| 94 |
+
processes=len(device_ids) * n_workers,
|
| 95 |
+
initializer=MjRenderPool._worker_init,
|
| 96 |
+
initargs=(
|
| 97 |
+
model.get_mjb(),
|
| 98 |
+
worker_id,
|
| 99 |
+
device_ids,
|
| 100 |
+
self._shared_rgbs,
|
| 101 |
+
self._shared_depths,
|
| 102 |
+
modder))
|
| 103 |
+
|
| 104 |
+
@staticmethod
|
| 105 |
+
def _worker_init(mjb_bytes, worker_id, device_ids,
|
| 106 |
+
shared_rgbs, shared_depths, modder):
|
| 107 |
+
"""
|
| 108 |
+
Initializes the global state for the workers.
|
| 109 |
+
"""
|
| 110 |
+
s = RenderPoolStorage()
|
| 111 |
+
|
| 112 |
+
with worker_id.get_lock():
|
| 113 |
+
proc_worker_id = worker_id.value
|
| 114 |
+
worker_id.value += 1
|
| 115 |
+
s.device_id = device_ids[proc_worker_id % len(device_ids)]
|
| 116 |
+
|
| 117 |
+
s.shared_rgbs_array = np.frombuffer(
|
| 118 |
+
shared_rgbs.get_obj(), dtype=ctypes.c_uint8)
|
| 119 |
+
s.shared_depths_array = np.frombuffer(
|
| 120 |
+
shared_depths.get_obj(), dtype=ctypes.c_float)
|
| 121 |
+
|
| 122 |
+
# avoid a circular import
|
| 123 |
+
from mujoco_py import load_model_from_mjb, MjRenderContext, MjSim
|
| 124 |
+
s.sim = MjSim(load_model_from_mjb(mjb_bytes))
|
| 125 |
+
# attach a render context to the sim (needs to happen before
|
| 126 |
+
# modder is called, since it might need to upload textures
|
| 127 |
+
# to the GPU).
|
| 128 |
+
MjRenderContext(s.sim, device_id=s.device_id)
|
| 129 |
+
|
| 130 |
+
if modder is not None:
|
| 131 |
+
s.modder = modder(s.sim, random_state=proc_worker_id)
|
| 132 |
+
s.modder.whiten_materials()
|
| 133 |
+
else:
|
| 134 |
+
s.modder = None
|
| 135 |
+
|
| 136 |
+
global _render_pool_storage
|
| 137 |
+
_render_pool_storage = s
|
| 138 |
+
|
| 139 |
+
@staticmethod
|
| 140 |
+
def _worker_render(worker_id, state, width, height,
|
| 141 |
+
camera_name, randomize):
|
| 142 |
+
"""
|
| 143 |
+
Main target function for the workers.
|
| 144 |
+
"""
|
| 145 |
+
s = _render_pool_storage
|
| 146 |
+
|
| 147 |
+
forward = False
|
| 148 |
+
if state is not None:
|
| 149 |
+
s.sim.set_state(state)
|
| 150 |
+
forward = True
|
| 151 |
+
if randomize and s.modder is not None:
|
| 152 |
+
s.modder.randomize()
|
| 153 |
+
forward = True
|
| 154 |
+
if forward:
|
| 155 |
+
s.sim.forward()
|
| 156 |
+
|
| 157 |
+
rgb_block = width * height * 3
|
| 158 |
+
rgb_offset = rgb_block * worker_id
|
| 159 |
+
rgb = s.shared_rgbs_array[rgb_offset:rgb_offset + rgb_block]
|
| 160 |
+
rgb = rgb.reshape(height, width, 3)
|
| 161 |
+
|
| 162 |
+
depth_block = width * height
|
| 163 |
+
depth_offset = depth_block * worker_id
|
| 164 |
+
depth = s.shared_depths_array[depth_offset:depth_offset + depth_block]
|
| 165 |
+
depth = depth.reshape(height, width)
|
| 166 |
+
|
| 167 |
+
rgb[:], depth[:] = s.sim.render(
|
| 168 |
+
width, height, camera_name=camera_name, depth=True,
|
| 169 |
+
device_id=s.device_id)
|
| 170 |
+
|
| 171 |
+
def render(self, width, height, states=None, camera_name=None,
|
| 172 |
+
depth=False, randomize=False, copy=True):
|
| 173 |
+
"""
|
| 174 |
+
Renders the simulations in batch. If no states are provided,
|
| 175 |
+
the max_batch_size will be used.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
- width (int): width of image to render.
|
| 179 |
+
- height (int): height of image to render.
|
| 180 |
+
- states (list): list of MjSimStates; updates the states before
|
| 181 |
+
rendering. Batch size will be number of states supplied.
|
| 182 |
+
- camera_name (str): name of camera to render from.
|
| 183 |
+
- depth (bool): if True, also return depth.
|
| 184 |
+
- randomize (bool): calls modder.rand_all() before rendering.
|
| 185 |
+
- copy (bool): return a copy rather than a reference
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
- rgbs: NxHxWx3 numpy array of N images in batch of width W
|
| 189 |
+
and height H.
|
| 190 |
+
- depth: NxHxW numpy array of N images in batch of width W
|
| 191 |
+
and height H. Only returned if depth=True.
|
| 192 |
+
"""
|
| 193 |
+
if self._closed:
|
| 194 |
+
raise RuntimeError("The pool has been closed.")
|
| 195 |
+
|
| 196 |
+
if (width * height) > self._max_image_size:
|
| 197 |
+
raise ValueError(
|
| 198 |
+
"Requested image larger than maximum image size. Create "
|
| 199 |
+
"a new RenderPool with a larger maximum image size.")
|
| 200 |
+
if states is None:
|
| 201 |
+
batch_size = self._max_batch_size
|
| 202 |
+
states = [None] * batch_size
|
| 203 |
+
else:
|
| 204 |
+
batch_size = len(states)
|
| 205 |
+
|
| 206 |
+
if batch_size > self._max_batch_size:
|
| 207 |
+
raise ValueError(
|
| 208 |
+
"Requested batch size larger than max batch size. Create "
|
| 209 |
+
"a new RenderPool with a larger max batch size.")
|
| 210 |
+
|
| 211 |
+
self.pool.starmap(
|
| 212 |
+
MjRenderPool._worker_render,
|
| 213 |
+
[(i, state, width, height, camera_name, randomize)
|
| 214 |
+
for i, state in enumerate(states)])
|
| 215 |
+
|
| 216 |
+
rgbs = self._shared_rgbs_array[:width * height * 3 * batch_size]
|
| 217 |
+
rgbs = rgbs.reshape(batch_size, height, width, 3)
|
| 218 |
+
if copy:
|
| 219 |
+
rgbs = rgbs.copy()
|
| 220 |
+
|
| 221 |
+
if depth:
|
| 222 |
+
depths = self._shared_depths_array[:width * height * batch_size]
|
| 223 |
+
depths = depths.reshape(batch_size, height, width).copy()
|
| 224 |
+
if copy:
|
| 225 |
+
depths = depths.copy()
|
| 226 |
+
return rgbs, depths
|
| 227 |
+
else:
|
| 228 |
+
return rgbs
|
| 229 |
+
|
| 230 |
+
def close(self):
|
| 231 |
+
"""
|
| 232 |
+
Closes the pool and terminates child processes.
|
| 233 |
+
"""
|
| 234 |
+
if not self._closed:
|
| 235 |
+
if self.pool is not None:
|
| 236 |
+
self.pool.close()
|
| 237 |
+
self.pool.join()
|
| 238 |
+
self._closed = True
|
| 239 |
+
|
| 240 |
+
def __del__(self):
|
| 241 |
+
self.close()
|
mujoco-py-2.1.2.14/mujoco_py/mjsim.pyx
ADDED
|
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from xml.dom import minidom
|
| 2 |
+
from mujoco_py.utils import remove_empty_lines
|
| 3 |
+
from mujoco_py.builder import build_callback_fn
|
| 4 |
+
from threading import Lock
|
| 5 |
+
|
| 6 |
+
_MjSim_render_lock = Lock()
|
| 7 |
+
|
| 8 |
+
ctypedef void (*substep_udd_t)(const mjModel* m, mjData* d)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
cdef class MjSim(object):
|
| 12 |
+
"""MjSim represents a running simulation including its state.
|
| 13 |
+
|
| 14 |
+
Similar to Gym's ``MujocoEnv``, it internally wraps a :class:`.PyMjModel`
|
| 15 |
+
and a :class:`.PyMjData`.
|
| 16 |
+
|
| 17 |
+
Parameters
|
| 18 |
+
----------
|
| 19 |
+
model : :class:`.PyMjModel`
|
| 20 |
+
The model to simulate.
|
| 21 |
+
data : :class:`.PyMjData`
|
| 22 |
+
Optional container for the simulation state. Will be created if ``None``.
|
| 23 |
+
nsubsteps : int
|
| 24 |
+
Optional number of MuJoCo steps to run for every call to :meth:`.step`.
|
| 25 |
+
Buffers will be swapped only once per step.
|
| 26 |
+
udd_callback : fn(:class:`.MjSim`) -> dict
|
| 27 |
+
Optional callback for user-defined dynamics. At every call to
|
| 28 |
+
:meth:`.step`, it receives an MjSim object ``sim`` containing the
|
| 29 |
+
current user-defined dynamics state in ``sim.udd_state``, and returns the
|
| 30 |
+
next ``udd_state`` after applying the user-defined dynamics. This is
|
| 31 |
+
useful e.g. for reward functions that operate over functions of historical
|
| 32 |
+
state.
|
| 33 |
+
substep_callback : str or int or None
|
| 34 |
+
This uses a compiled C function as user-defined dynamics in substeps.
|
| 35 |
+
If given as a string, it's compiled as a C function and set as pointer.
|
| 36 |
+
If given as int, it's interpreted as a function pointer.
|
| 37 |
+
See :meth:`.set_substep_callback` for detailed info.
|
| 38 |
+
userdata_names : list of strings or None
|
| 39 |
+
This is a convenience parameter which is just set on the model.
|
| 40 |
+
Equivalent to calling ``model.set_userdata_names``
|
| 41 |
+
render_callback : callback for rendering.
|
| 42 |
+
"""
|
| 43 |
+
# MjRenderContext for rendering camera views.
|
| 44 |
+
cdef readonly list render_contexts
|
| 45 |
+
cdef readonly object _render_context_window
|
| 46 |
+
cdef readonly object _render_context_offscreen
|
| 47 |
+
|
| 48 |
+
# MuJoCo model
|
| 49 |
+
cdef readonly PyMjModel model
|
| 50 |
+
# MuJoCo data
|
| 51 |
+
"""
|
| 52 |
+
DATAZ
|
| 53 |
+
"""
|
| 54 |
+
cdef readonly PyMjData data
|
| 55 |
+
# Number of substeps when calling .step
|
| 56 |
+
cdef public int nsubsteps
|
| 57 |
+
# User defined state.
|
| 58 |
+
cdef public dict udd_state
|
| 59 |
+
# User defined dynamics callback
|
| 60 |
+
cdef readonly object _udd_callback
|
| 61 |
+
# Allows to store extra information in MjSim.
|
| 62 |
+
cdef readonly dict extras
|
| 63 |
+
# Function pointer for substep callback, stored as uintptr
|
| 64 |
+
cdef readonly uintptr_t substep_callback_ptr
|
| 65 |
+
# Callback executed before rendering.
|
| 66 |
+
cdef public object render_callback
|
| 67 |
+
|
| 68 |
+
def __cinit__(self, PyMjModel model, PyMjData data=None, int nsubsteps=1,
|
| 69 |
+
udd_callback=None, substep_callback=None, userdata_names=None,
|
| 70 |
+
render_callback=None):
|
| 71 |
+
self.nsubsteps = nsubsteps
|
| 72 |
+
self.model = model
|
| 73 |
+
if data is None:
|
| 74 |
+
with wrap_mujoco_warning():
|
| 75 |
+
_data = mj_makeData(self.model.ptr)
|
| 76 |
+
if _data == NULL:
|
| 77 |
+
raise Exception('mj_makeData failed!')
|
| 78 |
+
self.data = WrapMjData(_data, self.model)
|
| 79 |
+
else:
|
| 80 |
+
self.data = data
|
| 81 |
+
|
| 82 |
+
self.render_contexts = []
|
| 83 |
+
self._render_context_offscreen = None
|
| 84 |
+
self._render_context_window = None
|
| 85 |
+
self.udd_state = None
|
| 86 |
+
self.udd_callback = udd_callback
|
| 87 |
+
self.render_callback = render_callback
|
| 88 |
+
self.extras = {}
|
| 89 |
+
self.set_substep_callback(substep_callback, userdata_names)
|
| 90 |
+
|
| 91 |
+
def reset(self):
|
| 92 |
+
"""
|
| 93 |
+
Resets the simulation data and clears buffers.
|
| 94 |
+
"""
|
| 95 |
+
with wrap_mujoco_warning():
|
| 96 |
+
mj_resetData(self.model.ptr, self.data.ptr)
|
| 97 |
+
|
| 98 |
+
self.udd_state = None
|
| 99 |
+
self.step_udd()
|
| 100 |
+
|
| 101 |
+
def forward(self):
|
| 102 |
+
"""
|
| 103 |
+
Computes the forward kinematics. Calls ``mj_forward`` internally.
|
| 104 |
+
"""
|
| 105 |
+
with wrap_mujoco_warning():
|
| 106 |
+
mj_forward(self.model.ptr, self.data.ptr)
|
| 107 |
+
|
| 108 |
+
def set_constants(self):
|
| 109 |
+
"""
|
| 110 |
+
Set constant fields of mjModel, corresponding to qpos0 configuration.
|
| 111 |
+
"""
|
| 112 |
+
with wrap_mujoco_warning():
|
| 113 |
+
mj_setConst(self.model.ptr, self.data.ptr)
|
| 114 |
+
|
| 115 |
+
def step(self, with_udd=True):
|
| 116 |
+
"""
|
| 117 |
+
Advances the simulation by calling ``mj_step``.
|
| 118 |
+
|
| 119 |
+
If ``qpos`` or ``qvel`` have been modified directly, the user is required to call
|
| 120 |
+
:meth:`.forward` before :meth:`.step` if their ``udd_callback`` requires access to MuJoCo state
|
| 121 |
+
set during the forward dynamics.
|
| 122 |
+
"""
|
| 123 |
+
if with_udd:
|
| 124 |
+
self.step_udd()
|
| 125 |
+
|
| 126 |
+
with wrap_mujoco_warning():
|
| 127 |
+
for _ in range(self.nsubsteps):
|
| 128 |
+
self.substep_callback()
|
| 129 |
+
mj_step(self.model.ptr, self.data.ptr)
|
| 130 |
+
|
| 131 |
+
def render(self, width=None, height=None, *, camera_name=None, depth=False,
|
| 132 |
+
mode='offscreen', device_id=-1, segmentation=False):
|
| 133 |
+
"""
|
| 134 |
+
Renders view from a camera and returns image as an `numpy.ndarray`.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
- width (int): desired image width.
|
| 138 |
+
- height (int): desired image height.
|
| 139 |
+
- camera_name (str): name of camera in model. If None, the free
|
| 140 |
+
camera will be used.
|
| 141 |
+
- depth (bool): if True, also return depth buffer
|
| 142 |
+
- device (int): device to use for rendering (only for GPU-backed
|
| 143 |
+
rendering).
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
- rgb (uint8 array): image buffer from camera
|
| 147 |
+
- depth (float array): depth buffer from camera (only returned
|
| 148 |
+
if depth=True)
|
| 149 |
+
"""
|
| 150 |
+
if camera_name is None:
|
| 151 |
+
camera_id = None
|
| 152 |
+
else:
|
| 153 |
+
camera_id = self.model.camera_name2id(camera_name)
|
| 154 |
+
|
| 155 |
+
if mode == 'offscreen':
|
| 156 |
+
with _MjSim_render_lock:
|
| 157 |
+
if self._render_context_offscreen is None:
|
| 158 |
+
render_context = MjRenderContextOffscreen(
|
| 159 |
+
self, device_id=device_id)
|
| 160 |
+
else:
|
| 161 |
+
render_context = self._render_context_offscreen
|
| 162 |
+
|
| 163 |
+
render_context.render(
|
| 164 |
+
width=width, height=height, camera_id=camera_id, segmentation=segmentation)
|
| 165 |
+
return render_context.read_pixels(
|
| 166 |
+
width, height, depth=depth, segmentation=segmentation)
|
| 167 |
+
elif mode == 'window':
|
| 168 |
+
if self._render_context_window is None:
|
| 169 |
+
from mujoco_py.mjviewer import MjViewer
|
| 170 |
+
render_context = MjViewer(self)
|
| 171 |
+
else:
|
| 172 |
+
render_context = self._render_context_window
|
| 173 |
+
|
| 174 |
+
render_context.render()
|
| 175 |
+
|
| 176 |
+
else:
|
| 177 |
+
raise ValueError("Mode must be either 'window' or 'offscreen'.")
|
| 178 |
+
|
| 179 |
+
def add_render_context(self, render_context):
|
| 180 |
+
self.render_contexts.append(render_context)
|
| 181 |
+
if render_context.offscreen and self._render_context_offscreen is None:
|
| 182 |
+
self._render_context_offscreen = render_context
|
| 183 |
+
elif not render_context.offscreen and self._render_context_window is None:
|
| 184 |
+
self._render_context_window = render_context
|
| 185 |
+
|
| 186 |
+
@property
|
| 187 |
+
def udd_callback(self):
|
| 188 |
+
return self._udd_callback
|
| 189 |
+
|
| 190 |
+
@udd_callback.setter
|
| 191 |
+
def udd_callback(self, value):
|
| 192 |
+
self._udd_callback = value
|
| 193 |
+
self.udd_state = None
|
| 194 |
+
self.step_udd()
|
| 195 |
+
|
| 196 |
+
cpdef substep_callback(self):
|
| 197 |
+
if self.substep_callback_ptr:
|
| 198 |
+
(<mjfGeneric>self.substep_callback_ptr)(self.model.ptr, self.data.ptr)
|
| 199 |
+
|
| 200 |
+
def set_substep_callback(self, substep_callback, userdata_names=None):
|
| 201 |
+
'''
|
| 202 |
+
Set a substep callback function.
|
| 203 |
+
|
| 204 |
+
Parameters :
|
| 205 |
+
substep_callback : str or int or None
|
| 206 |
+
If `substep_callback` is a string, compile to function pointer and set.
|
| 207 |
+
See `builder.build_callback_fn()` for documentation.
|
| 208 |
+
If `substep_callback` is an int, we interpret it as a function pointer.
|
| 209 |
+
If `substep_callback` is None, we disable substep_callbacks.
|
| 210 |
+
userdata_names : list of strings or None
|
| 211 |
+
This is a convenience parameter, if not None, this is passed
|
| 212 |
+
onto ``model.set_userdata_names()``.
|
| 213 |
+
'''
|
| 214 |
+
if userdata_names is not None:
|
| 215 |
+
self.model.set_userdata_names(userdata_names)
|
| 216 |
+
if substep_callback is None:
|
| 217 |
+
self.substep_callback_ptr = 0
|
| 218 |
+
elif isinstance(substep_callback, int):
|
| 219 |
+
self.substep_callback_ptr = substep_callback
|
| 220 |
+
elif isinstance(substep_callback, str):
|
| 221 |
+
self.substep_callback_ptr = build_callback_fn(substep_callback,
|
| 222 |
+
self.model.userdata_names)
|
| 223 |
+
else:
|
| 224 |
+
raise TypeError('invalid: {}'.format(type(substep_callback)))
|
| 225 |
+
|
| 226 |
+
def step_udd(self):
|
| 227 |
+
if self._udd_callback is None:
|
| 228 |
+
self.udd_state = {}
|
| 229 |
+
else:
|
| 230 |
+
schema_example = self.udd_state
|
| 231 |
+
self.udd_state = self._udd_callback(self)
|
| 232 |
+
# Check to make sure the udd_state has consistent keys and dimension across steps
|
| 233 |
+
if schema_example is not None:
|
| 234 |
+
keys = set(schema_example.keys()) | set(self.udd_state.keys())
|
| 235 |
+
for key in keys:
|
| 236 |
+
assert key in schema_example, "Keys cannot be added to udd_state between steps."
|
| 237 |
+
assert key in self.udd_state, "Keys cannot be dropped from udd_state between steps."
|
| 238 |
+
if isinstance(schema_example[key], Number):
|
| 239 |
+
assert isinstance(self.udd_state[key], Number), \
|
| 240 |
+
"Every value in udd_state must be either a number or a numpy array"
|
| 241 |
+
else:
|
| 242 |
+
assert isinstance(self.udd_state[key], np.ndarray), \
|
| 243 |
+
"Every value in udd_state must be either a number or a numpy array"
|
| 244 |
+
assert self.udd_state[key].shape == schema_example[key].shape, \
|
| 245 |
+
"Numpy array values in udd_state must keep the same dimension across steps."
|
| 246 |
+
|
| 247 |
+
def get_state(self):
|
| 248 |
+
""" Returns a copy of the simulator state. """
|
| 249 |
+
qpos = np.copy(self.data.qpos)
|
| 250 |
+
qvel = np.copy(self.data.qvel)
|
| 251 |
+
if self.model.na == 0:
|
| 252 |
+
act = None
|
| 253 |
+
else:
|
| 254 |
+
act = np.copy(self.data.act)
|
| 255 |
+
udd_state = copy.deepcopy(self.udd_state)
|
| 256 |
+
|
| 257 |
+
return MjSimState(self.data.time, qpos, qvel, act, udd_state)
|
| 258 |
+
|
| 259 |
+
def set_state(self, value):
|
| 260 |
+
"""
|
| 261 |
+
Sets the state from an MjSimState.
|
| 262 |
+
If the MjSimState was previously unflattened from a numpy array, consider
|
| 263 |
+
set_state_from_flattened, as the defensive copy is a substantial overhead
|
| 264 |
+
in an inner loop.
|
| 265 |
+
|
| 266 |
+
Args:
|
| 267 |
+
- value (MjSimState): the desired state.
|
| 268 |
+
- call_forward: optionally call sim.forward(). Called by default if
|
| 269 |
+
the udd_callback is set.
|
| 270 |
+
"""
|
| 271 |
+
self.data.time = value.time
|
| 272 |
+
self.data.qpos[:] = np.copy(value.qpos)
|
| 273 |
+
self.data.qvel[:] = np.copy(value.qvel)
|
| 274 |
+
if self.model.na != 0:
|
| 275 |
+
self.data.act[:] = np.copy(value.act)
|
| 276 |
+
self.udd_state = copy.deepcopy(value.udd_state)
|
| 277 |
+
|
| 278 |
+
def set_state_from_flattened(self, value):
|
| 279 |
+
""" This helper method sets the state from an array without requiring a defensive copy."""
|
| 280 |
+
state = MjSimState.from_flattened(value, self)
|
| 281 |
+
|
| 282 |
+
self.data.time = state.time
|
| 283 |
+
self.data.qpos[:] = state.qpos
|
| 284 |
+
self.data.qvel[:] = state.qvel
|
| 285 |
+
if self.model.na != 0:
|
| 286 |
+
self.data.act[:] = state.act
|
| 287 |
+
self.udd_state = state.udd_state
|
| 288 |
+
|
| 289 |
+
def save(self, file, format='xml', keep_inertials=False):
|
| 290 |
+
"""
|
| 291 |
+
Saves the simulator model and state to a file as either
|
| 292 |
+
a MuJoCo XML or MJB file. The current state is saved as
|
| 293 |
+
a keyframe in the model file. This is useful for debugging
|
| 294 |
+
using MuJoCo's `simulate` utility.
|
| 295 |
+
|
| 296 |
+
Note that this doesn't save the UDD-state which is
|
| 297 |
+
part of MjSimState, since that's not supported natively
|
| 298 |
+
by MuJoCo. If you want to save the model together with
|
| 299 |
+
the UDD-state, you should use the `get_xml` or `get_mjb`
|
| 300 |
+
methods on `MjModel` together with `MjSim.get_state` and
|
| 301 |
+
save them with e.g. pickle.
|
| 302 |
+
|
| 303 |
+
Args:
|
| 304 |
+
- file (IO stream): stream to write model to.
|
| 305 |
+
- format: format to use (either 'xml' or 'mjb')
|
| 306 |
+
- keep_inertials (bool): if False, removes all <inertial>
|
| 307 |
+
properties derived automatically for geoms by MuJoco. Note
|
| 308 |
+
that this removes ones that were provided by the user
|
| 309 |
+
as well.
|
| 310 |
+
"""
|
| 311 |
+
xml_str = self.model.get_xml()
|
| 312 |
+
dom = minidom.parseString(xml_str)
|
| 313 |
+
|
| 314 |
+
mujoco_node = dom.childNodes[0]
|
| 315 |
+
assert mujoco_node.tagName == 'mujoco'
|
| 316 |
+
|
| 317 |
+
keyframe_el = dom.createElement('keyframe')
|
| 318 |
+
key_el = dom.createElement('key')
|
| 319 |
+
keyframe_el.appendChild(key_el)
|
| 320 |
+
mujoco_node.appendChild(keyframe_el)
|
| 321 |
+
|
| 322 |
+
def str_array(arr):
|
| 323 |
+
return " ".join(map(str, arr))
|
| 324 |
+
|
| 325 |
+
key_el.setAttribute('time', str(self.data.time))
|
| 326 |
+
key_el.setAttribute('qpos', str_array(self.data.qpos))
|
| 327 |
+
key_el.setAttribute('qvel', str_array(self.data.qvel))
|
| 328 |
+
if self.data.act is not None:
|
| 329 |
+
key_el.setAttribute('act', str_array(self.data.act))
|
| 330 |
+
|
| 331 |
+
if not keep_inertials:
|
| 332 |
+
for element in dom.getElementsByTagName('inertial'):
|
| 333 |
+
element.parentNode.removeChild(element)
|
| 334 |
+
|
| 335 |
+
result_xml = remove_empty_lines(dom.toprettyxml(indent=" " * 4))
|
| 336 |
+
|
| 337 |
+
if format == 'xml':
|
| 338 |
+
file.write(result_xml)
|
| 339 |
+
elif format == 'mjb':
|
| 340 |
+
new_model = load_model_from_xml(result_xml)
|
| 341 |
+
file.write(new_model.get_mjb())
|
| 342 |
+
else:
|
| 343 |
+
raise ValueError("Unsupported format. Valid ones are 'xml' and 'mjb'")
|
| 344 |
+
|
| 345 |
+
def ray(self, pnt, vec, include_static_geoms=True, exclude_body=-1, group_filter=None):
|
| 346 |
+
"""
|
| 347 |
+
Cast a ray into the scene, and return the first valid geom it intersects.
|
| 348 |
+
pnt - origin point of the ray in world coordinates (X Y Z)
|
| 349 |
+
vec - direction of the ray in world coordinates (X Y Z)
|
| 350 |
+
include_static_geoms - if False, we exclude geoms that are children of worldbody.
|
| 351 |
+
exclude_body - if this is a body ID, we exclude all children geoms of this body.
|
| 352 |
+
group_filter - a vector of booleans of length const.NGROUP
|
| 353 |
+
which specifies what geom groups (stored in model.geom_group)
|
| 354 |
+
to enable or disable. If none, all groups are used
|
| 355 |
+
Returns (distance, geomid) where
|
| 356 |
+
distance - distance along ray until first collision with geom
|
| 357 |
+
geomid - id of the geom the ray collided with
|
| 358 |
+
If no collision was found in the scene, return (-1, None)
|
| 359 |
+
|
| 360 |
+
NOTE: sometimes self.forward() needs to be called before self.ray().
|
| 361 |
+
|
| 362 |
+
See self.ray_fast_group() and self.ray_fast_nogroup() for versions of this call
|
| 363 |
+
with more stringent type requirements.
|
| 364 |
+
"""
|
| 365 |
+
cdef mjtNum distance
|
| 366 |
+
cdef mjtNum[::view.contiguous] pnt_view = pnt
|
| 367 |
+
cdef mjtNum[::view.contiguous] vec_view = vec
|
| 368 |
+
|
| 369 |
+
if group_filter is None:
|
| 370 |
+
return self.ray_fast_nogroup(
|
| 371 |
+
np.asarray(pnt, dtype=np.float64),
|
| 372 |
+
np.asarray(vec, dtype=np.float64),
|
| 373 |
+
1 if include_static_geoms else 0,
|
| 374 |
+
exclude_body)
|
| 375 |
+
else:
|
| 376 |
+
return self.ray_fast_group(
|
| 377 |
+
np.asarray(pnt, dtype=np.float64),
|
| 378 |
+
np.asarray(vec, dtype=np.float64),
|
| 379 |
+
np.asarray(group_filter, dtype=np.uint8),
|
| 380 |
+
1 if include_static_geoms else 0,
|
| 381 |
+
exclude_body)
|
| 382 |
+
|
| 383 |
+
def ray_fast_group(self,
|
| 384 |
+
np.ndarray[np.float64_t, mode="c", ndim=1] pnt,
|
| 385 |
+
np.ndarray[np.float64_t, mode="c", ndim=1] vec,
|
| 386 |
+
np.ndarray[np.uint8_t, mode="c", ndim=1] geomgroup,
|
| 387 |
+
mjtByte flg_static=1,
|
| 388 |
+
int bodyexclude=-1):
|
| 389 |
+
"""
|
| 390 |
+
Faster version of sim.ray(), which avoids extra copies,
|
| 391 |
+
but needs to be given all the correct type arrays.
|
| 392 |
+
|
| 393 |
+
See self.ray() for explanation of arguments
|
| 394 |
+
"""
|
| 395 |
+
cdef int geomid
|
| 396 |
+
cdef mjtNum distance
|
| 397 |
+
cdef mjtNum[::view.contiguous] pnt_view = pnt
|
| 398 |
+
cdef mjtNum[::view.contiguous] vec_view = vec
|
| 399 |
+
cdef mjtByte[::view.contiguous] geomgroup_view = geomgroup
|
| 400 |
+
|
| 401 |
+
distance = mj_ray(self.model.ptr,
|
| 402 |
+
self.data.ptr,
|
| 403 |
+
&pnt_view[0],
|
| 404 |
+
&vec_view[0],
|
| 405 |
+
&geomgroup_view[0],
|
| 406 |
+
flg_static,
|
| 407 |
+
bodyexclude,
|
| 408 |
+
&geomid)
|
| 409 |
+
return (distance, geomid)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def ray_fast_nogroup(self,
|
| 413 |
+
np.ndarray[np.float64_t, mode="c", ndim=1] pnt,
|
| 414 |
+
np.ndarray[np.float64_t, mode="c", ndim=1] vec,
|
| 415 |
+
mjtByte flg_static=1,
|
| 416 |
+
int bodyexclude=-1):
|
| 417 |
+
"""
|
| 418 |
+
Faster version of sim.ray(), which avoids extra copies,
|
| 419 |
+
but needs to be given all the correct type arrays.
|
| 420 |
+
|
| 421 |
+
This version hardcodes the geomgroup to NULL.
|
| 422 |
+
(Can't easily express a signature that is "numpy array of specific type or None")
|
| 423 |
+
|
| 424 |
+
See self.ray() for explanation of arguments
|
| 425 |
+
"""
|
| 426 |
+
cdef int geomid
|
| 427 |
+
cdef mjtNum distance
|
| 428 |
+
cdef mjtNum[::view.contiguous] pnt_view = pnt
|
| 429 |
+
cdef mjtNum[::view.contiguous] vec_view = vec
|
| 430 |
+
|
| 431 |
+
distance = mj_ray(self.model.ptr,
|
| 432 |
+
self.data.ptr,
|
| 433 |
+
&pnt_view[0],
|
| 434 |
+
&vec_view[0],
|
| 435 |
+
NULL,
|
| 436 |
+
flg_static,
|
| 437 |
+
bodyexclude,
|
| 438 |
+
&geomid)
|
| 439 |
+
return (distance, geomid)
|
mujoco-py-2.1.2.14/mujoco_py/pxd/__init__.py
ADDED
|
File without changes
|
mujoco-py-2.1.2.14/mujoco_py/pxd/mjdata.pxd
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cdef extern from "mjdata.h" nogil:
|
| 2 |
+
|
| 3 |
+
#---------------------------- primitive types (mjt) ------------------------------------
|
| 4 |
+
|
| 5 |
+
ctypedef enum mjtWarning: # warning types
|
| 6 |
+
mjWARN_INERTIA = 0, # (near) singular inertia matrix
|
| 7 |
+
mjWARN_CONTACTFULL, # too many contacts in contact list
|
| 8 |
+
mjWARN_CNSTRFULL, # too many constraints
|
| 9 |
+
mjWARN_VGEOMFULL, # too many visual geoms
|
| 10 |
+
mjWARN_BADQPOS, # bad number in qpos
|
| 11 |
+
mjWARN_BADQVEL, # bad number in qvel
|
| 12 |
+
mjWARN_BADQACC, # bad number in qacc
|
| 13 |
+
mjWARN_BADCTRL, # bad number in ctrl
|
| 14 |
+
|
| 15 |
+
enum: mjNWARNING # number of warnings
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
ctypedef enum mjtTimer:
|
| 19 |
+
# main api
|
| 20 |
+
mjTIMER_STEP = 0, # step
|
| 21 |
+
mjTIMER_FORWARD, # forward
|
| 22 |
+
mjTIMER_INVERSE, # inverse
|
| 23 |
+
|
| 24 |
+
# breakdown of step/forward
|
| 25 |
+
mjTIMER_POSITION, # fwdPosition
|
| 26 |
+
mjTIMER_VELOCITY, # fwdVelocity
|
| 27 |
+
mjTIMER_ACTUATION, # fwdActuation
|
| 28 |
+
mjTIMER_ACCELERATION, # fwdAcceleration
|
| 29 |
+
mjTIMER_CONSTRAINT, # fwdConstraint
|
| 30 |
+
|
| 31 |
+
# breakdown of fwdPosition
|
| 32 |
+
mjTIMER_POS_KINEMATICS, # kinematics, com, tendon, transmission
|
| 33 |
+
mjTIMER_POS_INERTIA, # inertia computations
|
| 34 |
+
mjTIMER_POS_COLLISION, # collision detection
|
| 35 |
+
mjTIMER_POS_MAKE, # make constraints
|
| 36 |
+
mjTIMER_POS_PROJECT, # project constraints
|
| 37 |
+
|
| 38 |
+
enum: mjNTIMER # number of timers
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
#------------------------------ mjContact ----------------------------------------------
|
| 42 |
+
|
| 43 |
+
ctypedef struct mjContact: # result of collision detection functions
|
| 44 |
+
# contact parameters set by geom-specific collision detector
|
| 45 |
+
mjtNum dist # distance between nearest points; neg: penetration
|
| 46 |
+
mjtNum pos[3] # position of contact point: midpoint between geoms
|
| 47 |
+
mjtNum frame[9] # normal is in [0-2]
|
| 48 |
+
|
| 49 |
+
# contact parameters set by mj_collideGeoms
|
| 50 |
+
mjtNum includemargin # include if dist<includemargin=margin-gap
|
| 51 |
+
mjtNum friction[5] # tangent1, 2, spin, roll1, 2
|
| 52 |
+
mjtNum solref[mjNREF] # constraint solver reference
|
| 53 |
+
mjtNum solimp[mjNIMP] # constraint solver impedance
|
| 54 |
+
|
| 55 |
+
# storage used internally by constraint solver
|
| 56 |
+
mjtNum mu # friction of regularized cone
|
| 57 |
+
mjtNum H[36] # cone Hessian, set by mj_updateConstraint
|
| 58 |
+
|
| 59 |
+
# contact descriptors set by mj_collideGeoms
|
| 60 |
+
int dim # contact space dimensionality: 1, 3, 4 or 6
|
| 61 |
+
int geom1 # id of geom 1
|
| 62 |
+
int geom2 # id of geom 2
|
| 63 |
+
|
| 64 |
+
# flag set by mj_fuseContact or mj_instantianteEquality
|
| 65 |
+
int exclude # 0: include, 1: in gap, 2: fused, 3: equality
|
| 66 |
+
|
| 67 |
+
# address computed by mj_instantiateContact
|
| 68 |
+
int efc_address # address in efc; -1: not included, -2-i: distance constraint i ???
|
| 69 |
+
|
| 70 |
+
#------------------------------ diagnostics --------------------------------------------
|
| 71 |
+
|
| 72 |
+
ctypedef struct mjWarningStat: # warning statistics
|
| 73 |
+
int lastinfo # info from last warning
|
| 74 |
+
int number # how many times was warning raised
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
ctypedef struct mjTimerStat: # timer statistics
|
| 78 |
+
mjtNum duration # cumulative duration
|
| 79 |
+
int number # how many times was timer called
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
ctypedef struct mjSolverStat: # per-iteration solver statistics
|
| 83 |
+
mjtNum improvement # cost reduction, scaled by 1/trace(M(qpos0))
|
| 84 |
+
mjtNum gradient # gradient norm (primal only, scaled)
|
| 85 |
+
mjtNum lineslope # slope in linesearch
|
| 86 |
+
int nactive # number of active constraints
|
| 87 |
+
int nchange # number of constraint state changes
|
| 88 |
+
int neval # number of cost evaluations in line search
|
| 89 |
+
int nupdate # number of Cholesky updates in line search
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#---------------------------------- mjData ---------------------------------------------
|
| 94 |
+
ctypedef struct mjData:
|
| 95 |
+
# constant sizes
|
| 96 |
+
int nstack # number of mjtNums that can fit in stack
|
| 97 |
+
int nbuffer # size of main buffer in bytes
|
| 98 |
+
|
| 99 |
+
# stack pointer
|
| 100 |
+
int pstack # first available mjtNum address in stack
|
| 101 |
+
|
| 102 |
+
# memory utilization stats
|
| 103 |
+
int maxuse_stack # maximum stack allocation
|
| 104 |
+
int maxuse_con # maximum number of contacts
|
| 105 |
+
int maxuse_efc # maximum number of scalar constraints
|
| 106 |
+
|
| 107 |
+
# diagnostics
|
| 108 |
+
mjWarningStat warning[mjNWARNING] # warning statistics
|
| 109 |
+
mjTimerStat timer[mjNTIMER] # timer statistics
|
| 110 |
+
mjSolverStat solver[mjNSOLVER] # solver statistics per iteration
|
| 111 |
+
int solver_iter # number of solver iterations
|
| 112 |
+
int solver_nnz # number of non-zeros in Hessian or efc_AR
|
| 113 |
+
mjtNum solver_fwdinv[2] # forward-inverse comparison: qfrc, efc
|
| 114 |
+
|
| 115 |
+
# variable sizes
|
| 116 |
+
int ne # number of equality constraints
|
| 117 |
+
int nf # number of friction constraints
|
| 118 |
+
int nefc # number of constraints
|
| 119 |
+
int ncon # number of detected contacts
|
| 120 |
+
|
| 121 |
+
# global properties
|
| 122 |
+
mjtNum time # simulation time
|
| 123 |
+
mjtNum energy[2] # potential, kinetic energy
|
| 124 |
+
|
| 125 |
+
#-------------------------------- end of info header
|
| 126 |
+
|
| 127 |
+
# buffers
|
| 128 |
+
void* buffer # main buffer; all pointers point in it (nbuffer bytes)
|
| 129 |
+
mjtNum* stack # stack buffer (nstack mjtNums)
|
| 130 |
+
|
| 131 |
+
#-------------------------------- main inputs and outputs of the computation
|
| 132 |
+
|
| 133 |
+
# state
|
| 134 |
+
mjtNum* qpos # position (nq x 1)
|
| 135 |
+
mjtNum* qvel # velocity (nv x 1)
|
| 136 |
+
mjtNum* act # actuator activation (na x 1)
|
| 137 |
+
mjtNum* qacc_warmstart # acceleration used for warmstart (nv x 1)
|
| 138 |
+
|
| 139 |
+
# control
|
| 140 |
+
mjtNum* ctrl # control (nu x 1)
|
| 141 |
+
mjtNum* qfrc_applied # applied generalized force (nv x 1)
|
| 142 |
+
mjtNum* xfrc_applied # applied Cartesian force/torque (nbody x 6)
|
| 143 |
+
|
| 144 |
+
# dynamics
|
| 145 |
+
mjtNum* qacc # acceleration (nv x 1)
|
| 146 |
+
mjtNum* act_dot # time-derivative of actuator activation (na x 1)
|
| 147 |
+
|
| 148 |
+
# mocap data
|
| 149 |
+
mjtNum* mocap_pos # positions of mocap bodies (nmocap x 3)
|
| 150 |
+
mjtNum* mocap_quat # orientations of mocap bodies (nmocap x 4)
|
| 151 |
+
|
| 152 |
+
# user data
|
| 153 |
+
mjtNum* userdata # user data, not touched by engine (nuserdata x 1)
|
| 154 |
+
|
| 155 |
+
# sensors
|
| 156 |
+
mjtNum* sensordata # sensor data array (nsensordata x 1)
|
| 157 |
+
|
| 158 |
+
#-------------------------------- POSITION dependent
|
| 159 |
+
|
| 160 |
+
# computed by mj_fwdPosition/mj_kinematics
|
| 161 |
+
mjtNum* xpos # Cartesian position of body frame (nbody x 3)
|
| 162 |
+
mjtNum* xquat # Cartesian orientation of body frame (nbody x 4)
|
| 163 |
+
mjtNum* xmat # Cartesian orientation of body frame (nbody x 9)
|
| 164 |
+
mjtNum* xipos # Cartesian position of body com (nbody x 3)
|
| 165 |
+
mjtNum* ximat # Cartesian orientation of body inertia (nbody x 9)
|
| 166 |
+
mjtNum* xanchor # Cartesian position of joint anchor (njnt x 3)
|
| 167 |
+
mjtNum* xaxis # Cartesian joint axis (njnt x 3)
|
| 168 |
+
mjtNum* geom_xpos # Cartesian geom position (ngeom x 3)
|
| 169 |
+
mjtNum* geom_xmat # Cartesian geom orientation (ngeom x 9)
|
| 170 |
+
mjtNum* site_xpos # Cartesian site position (nsite x 3)
|
| 171 |
+
mjtNum* site_xmat # Cartesian site orientation (nsite x 9)
|
| 172 |
+
mjtNum* cam_xpos # Cartesian camera position (ncam x 3)
|
| 173 |
+
mjtNum* cam_xmat # Cartesian camera orientation (ncam x 9)
|
| 174 |
+
mjtNum* light_xpos # Cartesian light position (nlight x 3)
|
| 175 |
+
mjtNum* light_xdir # Cartesian light direction (nlight x 3)
|
| 176 |
+
|
| 177 |
+
# computed by mj_fwdPosition/mj_comPos
|
| 178 |
+
mjtNum* subtree_com # center of mass of each subtree (nbody x 3)
|
| 179 |
+
mjtNum* cdof # com-based motion axis of each dof (nv x 6)
|
| 180 |
+
mjtNum* cinert # com-based body inertia and mass (nbody x 10)
|
| 181 |
+
|
| 182 |
+
# computed by mj_fwdPosition/mj_tendon
|
| 183 |
+
int* ten_wrapadr # start address of tendon's path (ntendon x 1)
|
| 184 |
+
int* ten_wrapnum # number of wrap points in path (ntendon x 1)
|
| 185 |
+
int* ten_J_rownnz # number of non-zeros in Jacobian row (ntendon x 1)
|
| 186 |
+
int* ten_J_rowadr # row start address in colind array (ntendon x 1)
|
| 187 |
+
int* ten_J_colind # column indices in sparse Jacobian (ntendon x nv)
|
| 188 |
+
mjtNum* ten_length # tendon lengths (ntendon x 1)
|
| 189 |
+
mjtNum* ten_J # tendon Jacobian (ntendon x nv)
|
| 190 |
+
int* wrap_obj # geom id; -1: site; -2: pulley (nwrap*2 x 1)
|
| 191 |
+
mjtNum* wrap_xpos # Cartesian 3D points in all path (nwrap*2 x 3)
|
| 192 |
+
|
| 193 |
+
# computed by mj_fwdPosition/mj_transmission
|
| 194 |
+
mjtNum* actuator_length # actuator lengths (nu x 1)
|
| 195 |
+
mjtNum* actuator_moment # actuator moment arms (nu x nv)
|
| 196 |
+
|
| 197 |
+
# computed by mj_fwdPosition/mj_crb
|
| 198 |
+
mjtNum* crb # com-based composite inertia and mass (nbody x 10)
|
| 199 |
+
mjtNum* qM # total inertia (nM x 1)
|
| 200 |
+
|
| 201 |
+
# computed by mj_fwdPosition/mj_factorM
|
| 202 |
+
mjtNum* qLD # L'*D*L factorization of M (nM x 1)
|
| 203 |
+
mjtNum* qLDiagInv # 1/diag(D) (nv x 1)
|
| 204 |
+
mjtNum* qLDiagSqrtInv # 1/sqrt(diag(D)) (nv x 1)
|
| 205 |
+
|
| 206 |
+
# computed by mj_fwdPosition/mj_collision
|
| 207 |
+
mjContact* contact # list of all detected contacts (nconmax x 1)
|
| 208 |
+
|
| 209 |
+
# computed by mj_fwdPosition/mj_makeConstraint
|
| 210 |
+
int* efc_type # constraint type (mjtConstraint) (njmax x 1)
|
| 211 |
+
int* efc_id # id of object of specified type (njmax x 1)
|
| 212 |
+
int* efc_J_rownnz # number of non-zeros in Jacobian row (njmax x 1)
|
| 213 |
+
int* efc_J_rowadr # row start address in colind array (njmax x 1)
|
| 214 |
+
int* efc_J_rowsuper # number of subsequent rows in supernode (njmax x 1)
|
| 215 |
+
int* efc_J_colind # column indices in sparse Jacobian (njmax x nv)
|
| 216 |
+
int* efc_JT_rownnz # number of non-zeros in Jacobian row T (nv x 1)
|
| 217 |
+
int* efc_JT_rowadr # row start address in colind array T (nv x 1)
|
| 218 |
+
int* efc_JT_rowsuper # number of subsequent rows in supernode T (nv x 1)
|
| 219 |
+
int* efc_JT_colind # column indices in sparse Jacobian T (nv x njmax)
|
| 220 |
+
mjtNum* efc_solref # constraint solver reference (njmax x mjNREF)
|
| 221 |
+
mjtNum* efc_solimp # constraint solver impedance (njmax x mjNIMP)
|
| 222 |
+
mjtNum* efc_J # constraint Jacobian (njmax x nv)
|
| 223 |
+
mjtNum* efc_JT # sparse constraint Jacobian transposed (nv x njmax)
|
| 224 |
+
mjtNum* efc_pos # constraint position (equality, contact) (njmax x 1)
|
| 225 |
+
mjtNum* efc_margin # inclusion margin (contact) (njmax x 1)
|
| 226 |
+
mjtNum* efc_frictionloss # frictionloss (friction) (njmax x 1)
|
| 227 |
+
mjtNum* efc_diagApprox # approximation to diagonal of A (njmax x 1)
|
| 228 |
+
mjtNum* efc_KBIP # stiffness, damping, impedance, imp' (njmax x 4)
|
| 229 |
+
mjtNum* efc_D # constraint mass (njmax x 1)
|
| 230 |
+
mjtNum* efc_R # inverse constraint mass (njmax x 1)
|
| 231 |
+
|
| 232 |
+
# computed by mj_fwdPosition/mj_projectConstraint
|
| 233 |
+
int* efc_AR_rownnz # number of non-zeros in AR (njmax x 1)
|
| 234 |
+
int* efc_AR_rowadr # row start address in colind array (njmax x 1)
|
| 235 |
+
int* efc_AR_colind # column indices in sparse AR (njmax x njmax)
|
| 236 |
+
mjtNum* efc_AR # J*inv(M)*J' + R (njmax x njmax)
|
| 237 |
+
|
| 238 |
+
#-------------------------------- POSITION, VELOCITY dependent
|
| 239 |
+
|
| 240 |
+
# computed by mj_fwdVelocity
|
| 241 |
+
mjtNum* ten_velocity # tendon velocities (ntendon x 1)
|
| 242 |
+
mjtNum* actuator_velocity # actuator velocities (nu x 1)
|
| 243 |
+
|
| 244 |
+
# computed by mj_fwdVelocity/mj_comVel
|
| 245 |
+
mjtNum* cvel # com-based velocity [3D rot; 3D tran] (nbody x 6)
|
| 246 |
+
mjtNum* cdof_dot # time-derivative of cdof (nv x 6)
|
| 247 |
+
|
| 248 |
+
# computed by mj_fwdVelocity/mj_rne (without acceleration)
|
| 249 |
+
mjtNum* qfrc_bias # C(qpos,qvel) (nv x 1)
|
| 250 |
+
|
| 251 |
+
# computed by mj_fwdVelocity/mj_passive
|
| 252 |
+
mjtNum* qfrc_passive # passive force (nv x 1)
|
| 253 |
+
|
| 254 |
+
# computed by mj_fwdVelocity/mj_referenceConstraint
|
| 255 |
+
mjtNum* efc_vel # velocity in constraint space: J*qvel (njmax x 1)
|
| 256 |
+
mjtNum* efc_aref # reference pseudo-acceleration (njmax x 1)
|
| 257 |
+
|
| 258 |
+
# computed by mj_sensorVel
|
| 259 |
+
mjtNum* subtree_linvel # linear velocity of subtree com (nbody x 3)
|
| 260 |
+
mjtNum* subtree_angmom # angular momentum about subtree com (nbody x 3)
|
| 261 |
+
|
| 262 |
+
#-------------------------------- POSITION, VELOCITY, CONTROL/ACCELERATION dependent
|
| 263 |
+
|
| 264 |
+
# computed by mj_fwdActuation
|
| 265 |
+
mjtNum* actuator_force # actuator force in actuation space (nu x 1)
|
| 266 |
+
mjtNum* qfrc_actuator # actuator force (nv x 1)
|
| 267 |
+
|
| 268 |
+
# computed by mj_fwdAcceleration
|
| 269 |
+
mjtNum* qfrc_unc # net unconstrained force (nv x 1)
|
| 270 |
+
mjtNum* qacc_unc # unconstrained acceleration (nv x 1)
|
| 271 |
+
|
| 272 |
+
# computed by mj_fwdConstraint/mj_inverse
|
| 273 |
+
mjtNum* efc_b # linear cost term: J*qacc_unc - aref (njmax x 1)
|
| 274 |
+
mjtNum* efc_force # constraint force in constraint space (njmax x 1)
|
| 275 |
+
int* efc_state # constraint state (mjtConstraintState) (njmax x 1)
|
| 276 |
+
mjtNum* qfrc_constraint # constraint force (nv x 1)
|
| 277 |
+
|
| 278 |
+
# computed by mj_inverse
|
| 279 |
+
mjtNum* qfrc_inverse # net external force; should equal: (nv x 1)
|
| 280 |
+
# qfrc_applied + J'*xfrc_applied + qfrc_actuator
|
| 281 |
+
|
| 282 |
+
# computed by mj_sensorAcc/mj_rnePostConstraint; rotation:translation format
|
| 283 |
+
mjtNum* cacc # com-based acceleration (nbody x 6)
|
| 284 |
+
mjtNum* cfrc_int # com-based interaction force with parent (nbody x 6)
|
| 285 |
+
mjtNum* cfrc_ext # com-based external force on body (nbody x 6)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
#---------------------------------- callback function types ----------------------------
|
| 289 |
+
|
| 290 |
+
# generic MuJoCo function
|
| 291 |
+
ctypedef void (*mjfGeneric)(const mjModel* m, mjData* d)
|
| 292 |
+
|
| 293 |
+
# sensor simulation
|
| 294 |
+
ctypedef void (*mjfSensor)(const mjModel* m, mjData* d, int stage)
|
| 295 |
+
|
| 296 |
+
# timer
|
| 297 |
+
ctypedef long long int (*mjfTime)();
|
| 298 |
+
|
| 299 |
+
# actuator dynamics, gain, bias
|
| 300 |
+
ctypedef mjtNum (*mjfAct)(const mjModel* m, const mjData* d, int id);
|
| 301 |
+
|
| 302 |
+
# solver impedance
|
| 303 |
+
ctypedef mjtNum (*mjfSolImp)(const mjModel* m, const mjData* d, int id,
|
| 304 |
+
mjtNum distance, mjtNum* constimp);
|
| 305 |
+
|
| 306 |
+
# solver reference
|
| 307 |
+
ctypedef void (*mjfSolRef)(const mjModel* m, const mjData* d, int id,
|
| 308 |
+
mjtNum constimp, mjtNum imp, int dim, mjtNum* ref);
|
| 309 |
+
|
| 310 |
+
# collision detection
|
| 311 |
+
ctypedef int (*mjfCollision)(const mjModel* m, const mjData* d,
|
| 312 |
+
mjContact* con, int g1, int g2, mjtNum margin);
|
mujoco-py-2.1.2.14/mujoco_py/pxd/mjmodel.pxd
ADDED
|
@@ -0,0 +1,834 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cdef struct mjVisual_global_: # global parameters
|
| 2 |
+
float fovy # y-field of view (deg) for free camera
|
| 3 |
+
float ipd # inter-pupilary distance for free camera
|
| 4 |
+
float linewidth # line width for wireframe rendering
|
| 5 |
+
float glow # glow coefficient for selected body
|
| 6 |
+
int offwidth # width of offscreen buffer
|
| 7 |
+
int offheight # height of offscreen buffer
|
| 8 |
+
|
| 9 |
+
cdef struct mjVisual_quality: # rendering quality
|
| 10 |
+
int shadowsize # size of shadowmap texture
|
| 11 |
+
int offsamples # number of multisamples for offscreen rendering
|
| 12 |
+
int numslices # number of slices for Glu drawing
|
| 13 |
+
int numstacks # number of stacks for Glu drawing
|
| 14 |
+
int numquads # number of quads for box rendering
|
| 15 |
+
|
| 16 |
+
cdef struct mjVisual_headlight: # head light
|
| 17 |
+
float ambient[3] # ambient rgb (alpha=1)
|
| 18 |
+
float diffuse[3] # diffuse rgb (alpha=1)
|
| 19 |
+
float specular[3] # specular rgb (alpha=1)
|
| 20 |
+
int active # is headlight active
|
| 21 |
+
|
| 22 |
+
cdef struct mjVisual_map: # mapping
|
| 23 |
+
float stiffness # mouse perturbation stiffness (space->force)
|
| 24 |
+
float stiffnessrot # mouse perturbation stiffness (space->torque)
|
| 25 |
+
float force # from force units to space units
|
| 26 |
+
float torque # from torque units to space units
|
| 27 |
+
float alpha # scale geom alphas when transparency is enabled
|
| 28 |
+
float fogstart # OpenGL fog starts at fogstart * mjModel.stat.extent
|
| 29 |
+
float fogend # OpenGL fog ends at fogend * mjModel.stat.extent
|
| 30 |
+
float znear # near clipping plane = znear * mjModel.stat.extent
|
| 31 |
+
float zfar # far clipping plane = zfar * mjModel.stat.extent
|
| 32 |
+
float haze # haze ratio
|
| 33 |
+
float shadowclip # directional light: shadowclip * mjModel.stat.extent
|
| 34 |
+
float shadowscale # spot light: shadowscale * light.cutoff
|
| 35 |
+
float actuatortendon # scale tendon width
|
| 36 |
+
|
| 37 |
+
cdef struct mjVisual_scale: # scale of decor elements relative to mean body size
|
| 38 |
+
float forcewidth # width of force arrow
|
| 39 |
+
float contactwidth # contact width
|
| 40 |
+
float contactheight # contact height
|
| 41 |
+
float connect # autoconnect capsule width
|
| 42 |
+
float com # com radius
|
| 43 |
+
float camera # camera object
|
| 44 |
+
float light # light object
|
| 45 |
+
float selectpoint # selection point
|
| 46 |
+
float jointlength # joint length
|
| 47 |
+
float jointwidth # joint width
|
| 48 |
+
float actuatorlength # actuator length
|
| 49 |
+
float actuatorwidth # actuator width
|
| 50 |
+
float framelength # bodyframe axis length
|
| 51 |
+
float framewidth # bodyframe axis width
|
| 52 |
+
float constraint # constraint width
|
| 53 |
+
float slidercrank # slidercrank width
|
| 54 |
+
|
| 55 |
+
cdef struct mjVisual_rgba: # color of decor elements
|
| 56 |
+
float fog[4] # external force
|
| 57 |
+
float haze[4] # haze
|
| 58 |
+
float force[4] # external force
|
| 59 |
+
float inertia[4] # inertia box
|
| 60 |
+
float joint[4] # joint
|
| 61 |
+
float actuator[4] # actuator
|
| 62 |
+
float actuatornegative[4] # actuator, negative limit
|
| 63 |
+
float actuatorpositive[4] # actuator, positive limit
|
| 64 |
+
float com[4] # center of mass
|
| 65 |
+
float camera[4] # camera object
|
| 66 |
+
float light[4] # light object
|
| 67 |
+
float selectpoint[4] # selection point
|
| 68 |
+
float connect[4] # auto connect
|
| 69 |
+
float contactpoint[4] # contact point
|
| 70 |
+
float contactforce[4] # contact force
|
| 71 |
+
float contactfriction[4] # contact friction force
|
| 72 |
+
float contacttorque[4] # contact torque
|
| 73 |
+
float contactgap[4] # contact point in gap
|
| 74 |
+
float rangefinder[4] # rangefinder ray
|
| 75 |
+
float constraint[4] # constraint
|
| 76 |
+
float slidercrank[4] # slidercrank
|
| 77 |
+
float crankbroken[4] # used when crank must be stretched/broken
|
| 78 |
+
|
| 79 |
+
cdef extern from "mjmodel.h" nogil:
|
| 80 |
+
# ---------------------------- floating-point definitions -------------------------------
|
| 81 |
+
ctypedef double mjtNum
|
| 82 |
+
|
| 83 |
+
# global constants
|
| 84 |
+
enum: mjPI
|
| 85 |
+
enum: mjMAXVAL
|
| 86 |
+
enum: mjMINMU
|
| 87 |
+
enum: mjMINIMP
|
| 88 |
+
enum: mjMAXIMP
|
| 89 |
+
enum: mjMAXCONPAIR
|
| 90 |
+
enum: mjMAXVFS
|
| 91 |
+
enum: mjMAXVFSNAME
|
| 92 |
+
|
| 93 |
+
# ---------------------------- sizes ----------------------------------------------------
|
| 94 |
+
enum: mjNEQDATA
|
| 95 |
+
enum: mjNDYN
|
| 96 |
+
enum: mjNGAIN
|
| 97 |
+
enum: mjNBIAS
|
| 98 |
+
enum: mjNREF
|
| 99 |
+
enum: mjNIMP
|
| 100 |
+
enum: mjNSOLVER
|
| 101 |
+
|
| 102 |
+
# ---------------------------- primitive types (mjt) ------------------------------------
|
| 103 |
+
ctypedef unsigned char mjtByte # used for true/false
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
ctypedef enum mjtDisableBit: # disable default feature bitflags
|
| 107 |
+
mjDSBL_CONSTRAINT # entire constraint solver
|
| 108 |
+
mjDSBL_EQUALITY # equality constraints
|
| 109 |
+
mjDSBL_FRICTIONLOSS # joint and tendon frictionloss constraints
|
| 110 |
+
mjDSBL_LIMIT # joint and tendon limit constraints
|
| 111 |
+
mjDSBL_CONTACT # contact constraints
|
| 112 |
+
mjDSBL_PASSIVE # passive forces
|
| 113 |
+
mjDSBL_GRAVITY # gravitational forces
|
| 114 |
+
mjDSBL_CLAMPCTRL # clamp control to specified range
|
| 115 |
+
mjDSBL_WARMSTART # warmstart constraint solver
|
| 116 |
+
mjDSBL_FILTERPARENT # remove collisions with parent body
|
| 117 |
+
mjDSBL_ACTUATION # apply actuation forces
|
| 118 |
+
mjDSBL_REFSAFE # integrator safety: make ref[0]>=2*timestep
|
| 119 |
+
enum: mjNDISABLE # number of disable flags
|
| 120 |
+
|
| 121 |
+
ctypedef enum mjtEnableBit: # enable optional feature bitflags
|
| 122 |
+
mjENBL_OVERRIDE # override contact parameters
|
| 123 |
+
mjENBL_ENERGY # energy computation
|
| 124 |
+
mjENBL_FWDINV # record solver statistics
|
| 125 |
+
mjENBL_SENSORNOISE # add noise to sensor data
|
| 126 |
+
enum: mjNENABLE # number of enable flags
|
| 127 |
+
|
| 128 |
+
ctypedef enum mjtJoint: # type of degree of freedom
|
| 129 |
+
mjJNT_FREE = 0, # global position and orientation (quat) (7)
|
| 130 |
+
mjJNT_BALL, # orientation (quat) relative to parent (4)
|
| 131 |
+
mjJNT_SLIDE, # sliding distance along body-fixed axis (1)
|
| 132 |
+
mjJNT_HINGE # rotation angle (rad) around body-fixed axis (1)
|
| 133 |
+
|
| 134 |
+
ctypedef enum mjtGeom: # type of geometric shape
|
| 135 |
+
# regular geom types
|
| 136 |
+
mjGEOM_PLANE = 0, # plane
|
| 137 |
+
mjGEOM_HFIELD, # height field
|
| 138 |
+
mjGEOM_SPHERE, # sphere
|
| 139 |
+
mjGEOM_CAPSULE, # capsule
|
| 140 |
+
mjGEOM_ELLIPSOID, # ellipsoid
|
| 141 |
+
mjGEOM_CYLINDER, # cylinder
|
| 142 |
+
mjGEOM_BOX, # box
|
| 143 |
+
mjGEOM_MESH, # mesh
|
| 144 |
+
|
| 145 |
+
mjNGEOMTYPES, # number of regular geom types
|
| 146 |
+
|
| 147 |
+
# rendering-only geom types: not used in mjModel, not counted in mjNGEOMTYPES
|
| 148 |
+
mjGEOM_ARROW = 100, # arrow
|
| 149 |
+
mjGEOM_ARROW1, # arrow without wedges
|
| 150 |
+
mjGEOM_ARROW2, # arrow in both directions
|
| 151 |
+
mjGEOM_LABEL, # text label
|
| 152 |
+
|
| 153 |
+
mjGEOM_NONE = 1001 # missing geom type
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
ctypedef enum mjtCamLight: # tracking mode for camera and light
|
| 157 |
+
mjCAMLIGHT_FIXED = 0, # pos and rot fixed in body
|
| 158 |
+
mjCAMLIGHT_TRACK, # pos tracks body, rot fixed in global
|
| 159 |
+
mjCAMLIGHT_TRACKCOM, # pos tracks subtree com, rot fixed in body
|
| 160 |
+
mjCAMLIGHT_TARGETBODY, # pos fixed in body, rot tracks target body
|
| 161 |
+
mjCAMLIGHT_TARGETBODYCOM # pos fixed in body, rot tracks target subtree com
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
ctypedef enum mjtTexture: # type of texture
|
| 165 |
+
mjTEXTURE_2D = 0, # 2d texture, suitable for planes and hfields
|
| 166 |
+
mjTEXTURE_CUBE, # cube texture, suitable for all other geom types
|
| 167 |
+
mjTEXTURE_SKYBOX # cube texture used as skybox
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
ctypedef enum mjtIntegrator: # integrator mode
|
| 171 |
+
mjINT_EULER = 0, # semi-implicit Euler
|
| 172 |
+
mjINT_RK4 # 4th-order Runge Kutta
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
ctypedef enum mjtCollision: # collision mode for selecting geom pairs
|
| 176 |
+
mjCOL_ALL = 0, # test precomputed and dynamic pairs
|
| 177 |
+
mjCOL_PAIR, # test predefined pairs only
|
| 178 |
+
mjCOL_DYNAMIC # test dynamic pairs only
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
ctypedef enum mjtCone: # type of friction cone
|
| 182 |
+
mjCONE_PYRAMIDAL = 0, # pyramidal
|
| 183 |
+
mjCONE_ELLIPTIC # elliptic
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
ctypedef enum mjtJacobian: # type of constraint Jacobian
|
| 187 |
+
mjJAC_DENSE = 0, # dense
|
| 188 |
+
mjJAC_SPARSE, # sparse
|
| 189 |
+
mjJAC_AUTO # dense if nv<60, sparse otherwise
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
ctypedef enum mjtSolver: # constraint solver algorithm
|
| 193 |
+
mjSOL_PGS = 0, # PGS (dual)
|
| 194 |
+
mjSOL_CG, # CG (primal)
|
| 195 |
+
mjSOL_NEWTON # Newton (primal)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
ctypedef enum mjtImp: # how to interpret solimp parameters
|
| 199 |
+
mjIMP_CONSTANT = 0, # constant solimp[1]
|
| 200 |
+
mjIMP_SIGMOID, # sigmoid from solimp[0] to solimp[1], width solimp[2]
|
| 201 |
+
mjIMP_LINEAR, # piece-wise linear sigmoid
|
| 202 |
+
mjIMP_USER # impedance computed by callback
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
ctypedef enum mjtRef: # how to interpret solref parameters
|
| 206 |
+
mjREF_SPRING = 0, # spring-damper: timeconst=solref[0], dampratio=solref[1]
|
| 207 |
+
mjREF_USER # reference computed by callback
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
ctypedef enum mjtEq: # type of equality constraint
|
| 211 |
+
mjEQ_CONNECT = 0, # connect two bodies at a point (ball joint)
|
| 212 |
+
mjEQ_WELD, # fix relative position and orientation of two bodies
|
| 213 |
+
mjEQ_JOINT, # couple the values of two scalar joints with cubic
|
| 214 |
+
mjEQ_TENDON, # couple the lengths of two tendons with cubic
|
| 215 |
+
mjEQ_DISTANCE # fix the contact distance betweent two geoms
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
ctypedef enum mjtWrap: # type of tendon wrap object
|
| 219 |
+
mjWRAP_NONE = 0, # null object
|
| 220 |
+
mjWRAP_JOINT, # constant moment arm
|
| 221 |
+
mjWRAP_PULLEY, # pulley used to split tendon
|
| 222 |
+
mjWRAP_SITE, # pass through site
|
| 223 |
+
mjWRAP_SPHERE, # wrap around sphere
|
| 224 |
+
mjWRAP_CYLINDER # wrap around (infinite) cylinder
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
ctypedef enum mjtTrn: # type of actuator transmission
|
| 228 |
+
mjTRN_JOINT = 0, # force on joint
|
| 229 |
+
mjTRN_JOINTINPARENT, # force on joint, expressed in parent frame
|
| 230 |
+
mjTRN_SLIDERCRANK, # force via slider-crank linkage
|
| 231 |
+
mjTRN_TENDON, # force on tendon
|
| 232 |
+
mjTRN_SITE, # force on site
|
| 233 |
+
|
| 234 |
+
mjTRN_UNDEFINED = 1000 # undefined transmission type
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
ctypedef enum mjtDyn: # type of actuator dynamics
|
| 238 |
+
mjDYN_NONE = 0, # no internal dynamics; ctrl specifies force
|
| 239 |
+
mjDYN_INTEGRATOR, # integrator: da/dt = u
|
| 240 |
+
mjDYN_FILTER, # linear filter: da/dt = (u-a) / tau
|
| 241 |
+
mjDYN_USER # user-defined dynamics type
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
ctypedef enum mjtGain: # type of actuator gain
|
| 245 |
+
mjGAIN_FIXED = 0, # fixed gain
|
| 246 |
+
mjGAIN_USER # user-defined gain type
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
ctypedef enum mjtBias: # type of actuator bias
|
| 250 |
+
mjBIAS_NONE = 0, # no bias
|
| 251 |
+
mjBIAS_AFFINE, # const + kp*length + kv*velocity
|
| 252 |
+
mjBIAS_USER # user-defined bias type
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
ctypedef enum mjtObj: # type of MujoCo object
|
| 256 |
+
mjOBJ_UNKNOWN = 0, # unknown object type
|
| 257 |
+
mjOBJ_BODY, # body
|
| 258 |
+
mjOBJ_XBODY, # body, used to access regular frame instead of i-frame
|
| 259 |
+
mjOBJ_JOINT, # joint
|
| 260 |
+
mjOBJ_DOF, # dof
|
| 261 |
+
mjOBJ_GEOM, # geom
|
| 262 |
+
mjOBJ_SITE, # site
|
| 263 |
+
mjOBJ_CAMERA, # camera
|
| 264 |
+
mjOBJ_LIGHT, # light
|
| 265 |
+
mjOBJ_MESH, # mesh
|
| 266 |
+
mjOBJ_HFIELD, # heightfield
|
| 267 |
+
mjOBJ_TEXTURE, # texture
|
| 268 |
+
mjOBJ_MATERIAL, # material for rendering
|
| 269 |
+
mjOBJ_PAIR, # geom pair to include
|
| 270 |
+
mjOBJ_EXCLUDE, # body pair to exclude
|
| 271 |
+
mjOBJ_EQUALITY, # equality constraint
|
| 272 |
+
mjOBJ_TENDON, # tendon
|
| 273 |
+
mjOBJ_ACTUATOR, # actuator
|
| 274 |
+
mjOBJ_SENSOR, # sensor
|
| 275 |
+
mjOBJ_NUMERIC, # numeric
|
| 276 |
+
mjOBJ_TEXT, # text
|
| 277 |
+
mjOBJ_TUPLE, # tuple
|
| 278 |
+
mjOBJ_KEY # keyframe
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
ctypedef enum mjtConstraint: # type of constraint
|
| 282 |
+
mjCNSTR_EQUALITY = 0, # equality constraint
|
| 283 |
+
mjCNSTR_FRICTION_DOF, # dof friction
|
| 284 |
+
mjCNSTR_FRICTION_TENDON, # tendon friction
|
| 285 |
+
mjCNSTR_LIMIT_JOINT, # joint limit
|
| 286 |
+
mjCNSTR_LIMIT_TENDON, # tendon limit
|
| 287 |
+
mjCNSTR_CONTACT_FRICTIONLESS, # frictionless contact
|
| 288 |
+
mjCNSTR_CONTACT_PYRAMIDAL, # frictional contact, pyramidal friction cone
|
| 289 |
+
mjCNSTR_CONTACT_ELLIPTIC # frictional contact, elliptic friction cone
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
ctypedef enum mjtConstraintState: # constraint state
|
| 293 |
+
mjCNSTRSTATE_SATISFIED = 0, # constraint satisfied, zero cost (limit, contact)
|
| 294 |
+
mjCNSTRSTATE_QUADRATIC, # quadratic cost (equality, friction, limit, contact)
|
| 295 |
+
mjCNSTRSTATE_LINEARNEG, # linear cost, negative side (friction)
|
| 296 |
+
mjCNSTRSTATE_LINEARPOS, # linear cost, positive side (friction)
|
| 297 |
+
mjCNSTRSTATE_CONE # squared distance to cone cost (elliptic contact)
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
ctypedef enum mjtSensor: # type of sensor
|
| 302 |
+
# common robotic sensors, attached to a site
|
| 303 |
+
mjSENS_TOUCH = 0, # scalar contact normal forces summed over sensor zone
|
| 304 |
+
mjSENS_ACCELEROMETER, # 3D linear acceleration, in local frame
|
| 305 |
+
mjSENS_VELOCIMETER, # 3D linear velocity, in local frame
|
| 306 |
+
mjSENS_GYRO, # 3D angular velocity, in local frame
|
| 307 |
+
mjSENS_FORCE, # 3D force between site's body and its parent body
|
| 308 |
+
mjSENS_TORQUE, # 3D torque between site's body and its parent body
|
| 309 |
+
mjSENS_MAGNETOMETER, # 3D magnetometer
|
| 310 |
+
mjSENS_RANGEFINDER, # scalar distance to nearest geom or site along z-axis
|
| 311 |
+
|
| 312 |
+
# sensors related to scalar joints, tendons, actuators
|
| 313 |
+
mjSENS_JOINTPOS, # scalar joint position (hinge and slide only)
|
| 314 |
+
mjSENS_JOINTVEL, # scalar joint velocity (hinge and slide only)
|
| 315 |
+
mjSENS_TENDONPOS, # scalar tendon position
|
| 316 |
+
mjSENS_TENDONVEL, # scalar tendon velocity
|
| 317 |
+
mjSENS_ACTUATORPOS, # scalar actuator position
|
| 318 |
+
mjSENS_ACTUATORVEL, # scalar actuator velocity
|
| 319 |
+
mjSENS_ACTUATORFRC, # scalar actuator force
|
| 320 |
+
|
| 321 |
+
# sensors related to ball joints
|
| 322 |
+
mjSENS_BALLQUAT, # 4D ball joint quaterion
|
| 323 |
+
mjSENS_BALLANGVEL, # 3D ball joint angular velocity
|
| 324 |
+
|
| 325 |
+
# sensors attached to an object with spatial frame: (x)body, geom, site, camera
|
| 326 |
+
mjSENS_FRAMEPOS, # 3D position
|
| 327 |
+
mjSENS_FRAMEQUAT, # 4D unit quaternion orientation
|
| 328 |
+
mjSENS_FRAMEXAXIS, # 3D unit vector: x-axis of object's frame
|
| 329 |
+
mjSENS_FRAMEYAXIS, # 3D unit vector: y-axis of object's frame
|
| 330 |
+
mjSENS_FRAMEZAXIS, # 3D unit vector: z-axis of object's frame
|
| 331 |
+
mjSENS_FRAMELINVEL, # 3D linear velocity
|
| 332 |
+
mjSENS_FRAMEANGVEL, # 3D angular velocity
|
| 333 |
+
mjSENS_FRAMELINACC, # 3D linear acceleration
|
| 334 |
+
mjSENS_FRAMEANGACC, # 3D angular acceleration
|
| 335 |
+
|
| 336 |
+
# sensors related to kinematic subtrees; attached to a body (which is the subtree root)
|
| 337 |
+
mjSENS_SUBTREECOM, # 3D center of mass of subtree
|
| 338 |
+
mjSENS_SUBTREELINVEL, # 3D linear velocity of subtree
|
| 339 |
+
mjSENS_SUBTREEANGMOM, # 3D angular momentum of subtree
|
| 340 |
+
|
| 341 |
+
# user-defined sensor
|
| 342 |
+
mjSENS_USER # sensor data provided by mjcb_sensor callback
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
ctypedef enum mjtStage: # computation stage
|
| 346 |
+
mjSTAGE_NONE = 0, # no computations
|
| 347 |
+
mjSTAGE_POS, # position-dependent computations
|
| 348 |
+
mjSTAGE_VEL, # velocity-dependent computations
|
| 349 |
+
mjSTAGE_ACC # acceleration/force-dependent computations
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
ctypedef enum mjtDataType: # data type for sensors
|
| 353 |
+
mjDATATYPE_REAL = 0, # real values, no constraints
|
| 354 |
+
mjDATATYPE_POSITIVE, # positive values; 0 or negative: inactive
|
| 355 |
+
mjDATATYPE_AXIS, # 3D unit vector
|
| 356 |
+
mjDATATYPE_QUAT # unit quaternion
|
| 357 |
+
|
| 358 |
+
#------------------------------ mjVFS --------------------------------------------------
|
| 359 |
+
|
| 360 |
+
ctypedef struct mjVFS: # virtual file system for loading from memory
|
| 361 |
+
int nfile # number of files present
|
| 362 |
+
char filename[mjMAXVFS][mjMAXVFSNAME] # file name without path
|
| 363 |
+
int filesize[mjMAXVFS] # file size in bytes
|
| 364 |
+
void* filedata[mjMAXVFS] # buffer with file data
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
#------------------------------ mjOption -----------------------------------------------
|
| 369 |
+
|
| 370 |
+
ctypedef struct mjOption: # physics options
|
| 371 |
+
# timing parameters
|
| 372 |
+
mjtNum timestep # timestep
|
| 373 |
+
mjtNum apirate # update rate for remote API (Hz)
|
| 374 |
+
|
| 375 |
+
# solver parameters
|
| 376 |
+
mjtNum impratio # ratio of friction-to-normal contact impedance
|
| 377 |
+
mjtNum tolerance # solver convergence tolerance
|
| 378 |
+
mjtNum noslip_tolerance # noslip solver tolerance
|
| 379 |
+
mjtNum mpr_tolerance # MPR solver tolerance
|
| 380 |
+
|
| 381 |
+
# physical constants
|
| 382 |
+
mjtNum gravity[3] # gravitational acceleration
|
| 383 |
+
mjtNum wind[3] # wind (for lift, drag and viscosity)
|
| 384 |
+
mjtNum magnetic[3] # global magnetic flux
|
| 385 |
+
mjtNum density # density of medium
|
| 386 |
+
mjtNum viscosity # viscosity of medium
|
| 387 |
+
|
| 388 |
+
# override contact solver parameters (if enabled)
|
| 389 |
+
mjtNum o_margin # margin
|
| 390 |
+
mjtNum o_solref[mjNREF] # solref
|
| 391 |
+
mjtNum o_solimp[mjNIMP] # solimp
|
| 392 |
+
|
| 393 |
+
# discrete settings
|
| 394 |
+
int integrator # integration mode (mjtIntegrator)
|
| 395 |
+
int collision # collision mode (mjtCollision)
|
| 396 |
+
int cone # type of friction cone (mjtCone)
|
| 397 |
+
int jacobian # type of Jacobian (mjtJacobian)
|
| 398 |
+
int solver # solver mode (mjtSolver)
|
| 399 |
+
int iterations # maximum number of solver iterations
|
| 400 |
+
int noslip_iterations # maximum number of noslip solver iterations
|
| 401 |
+
int mpr_iterations # maximum number of MPR solver iterations
|
| 402 |
+
int disableflags # bit flags for disabling standard features
|
| 403 |
+
int enableflags # bit flags for enabling optional features
|
| 404 |
+
|
| 405 |
+
#------------------------------ mjLROpt ------------------------------------------------
|
| 406 |
+
|
| 407 |
+
ctypedef struct mjLROpt:
|
| 408 |
+
# flags
|
| 409 |
+
int mode # which actuators to process (mjtLRMode)
|
| 410 |
+
int useexisting # use existing length range if available
|
| 411 |
+
int uselimit # use joint and tendon limits if available
|
| 412 |
+
|
| 413 |
+
# algorithm parameters
|
| 414 |
+
mjtNum accel # target acceleration used to compute force
|
| 415 |
+
mjtNum maxforce # maximum force; 0: no limit
|
| 416 |
+
mjtNum timeconst # time constant for velocity reduction; min 0.01
|
| 417 |
+
mjtNum timestep # simulation timestep; 0: use mjOption.timestep
|
| 418 |
+
mjtNum inttotal # total simulation time interval
|
| 419 |
+
mjtNum inteval # evaluation time interval (at the end)
|
| 420 |
+
mjtNum tolrange # convergence tolerance (relative to range)
|
| 421 |
+
|
| 422 |
+
#------------------------------ mjVisual -----------------------------------------------
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
ctypedef struct mjVisual:
|
| 426 |
+
mjVisual_global_ global_ "global"
|
| 427 |
+
mjVisual_quality quality
|
| 428 |
+
mjVisual_headlight headlight
|
| 429 |
+
mjVisual_map map
|
| 430 |
+
mjVisual_scale scale
|
| 431 |
+
mjVisual_rgba rgba
|
| 432 |
+
|
| 433 |
+
#------------------------------ mjStatistic --------------------------------------------
|
| 434 |
+
|
| 435 |
+
ctypedef struct mjStatistic: # model statistics (in qpos0)
|
| 436 |
+
mjtNum meaninertia # mean diagonal inertia
|
| 437 |
+
mjtNum meanmass # mean body mass
|
| 438 |
+
mjtNum meansize # mean body size
|
| 439 |
+
mjtNum extent # spatial extent
|
| 440 |
+
mjtNum center[3] # center of model
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
# ---------------------------------- mjModel --------------------------------------------
|
| 444 |
+
ctypedef struct mjModel:
|
| 445 |
+
# ------------------------------- sizes
|
| 446 |
+
|
| 447 |
+
# sizes needed at mjModel construction
|
| 448 |
+
int nq # number of generalized coordinates = dim(qpos)
|
| 449 |
+
int nv # number of degrees of freedom = dim(qvel)
|
| 450 |
+
int nu # number of actuators/controls = dim(ctrl)
|
| 451 |
+
int na # number of activation states = dim(act)
|
| 452 |
+
int nbody # number of bodies
|
| 453 |
+
int njnt # number of joints
|
| 454 |
+
int ngeom # number of geoms
|
| 455 |
+
int nsite # number of sites
|
| 456 |
+
int ncam # number of cameras
|
| 457 |
+
int nlight # number of lights
|
| 458 |
+
int nmesh # number of meshes
|
| 459 |
+
int nmeshvert # number of vertices in all meshes
|
| 460 |
+
int nmeshtexvert; # number of vertices with texcoords in all meshes
|
| 461 |
+
int nmeshface # number of triangular faces in all meshes
|
| 462 |
+
int nmeshgraph # number of ints in mesh auxiliary data
|
| 463 |
+
int nskin # number of skins
|
| 464 |
+
int nskinvert # number of vertices in all skins
|
| 465 |
+
int nskintexvert # number of vertiex with texcoords in all skins
|
| 466 |
+
int nskinface # number of triangular faces in all skins
|
| 467 |
+
int nskinbone # number of bones in all skins
|
| 468 |
+
int nskinbonevert # number of vertices in all skin bones
|
| 469 |
+
int nhfield # number of heightfields
|
| 470 |
+
int nhfielddata # number of data points in all heightfields
|
| 471 |
+
int ntex # number of textures
|
| 472 |
+
int ntexdata # number of bytes in texture rgb data
|
| 473 |
+
int nmat # number of materials
|
| 474 |
+
int npair # number of predefined geom pairs
|
| 475 |
+
int nexclude # number of excluded geom pairs
|
| 476 |
+
int neq # number of equality constraints
|
| 477 |
+
int ntendon # number of tendons
|
| 478 |
+
int nwrap # number of wrap objects in all tendon paths
|
| 479 |
+
int nsensor # number of sensors
|
| 480 |
+
int nnumeric # number of numeric custom fields
|
| 481 |
+
int nnumericdata # number of mjtNums in all numeric fields
|
| 482 |
+
int ntext # number of text custom fields
|
| 483 |
+
int ntextdata # number of mjtBytes in all text fields
|
| 484 |
+
int ntuple # number of tuple custom fields
|
| 485 |
+
int ntupledata # number of objects in all tuple fields
|
| 486 |
+
int nkey # number of keyframes
|
| 487 |
+
int nmocap # number of mocap bodies
|
| 488 |
+
int nuser_body # number of mjtNums in body_user
|
| 489 |
+
int nuser_jnt # number of mjtNums in jnt_user
|
| 490 |
+
int nuser_geom # number of mjtNums in geom_user
|
| 491 |
+
int nuser_site # number of mjtNums in site_user
|
| 492 |
+
int nuser_cam # number of mjtNums in cam_user
|
| 493 |
+
int nuser_tendon # number of mjtNums in tendon_user
|
| 494 |
+
int nuser_actuator # number of mjtNums in actuator_user
|
| 495 |
+
int nuser_sensor # number of mjtNums in sensor_user
|
| 496 |
+
int nnames # number of chars in all names
|
| 497 |
+
|
| 498 |
+
# sizes set after jModel construction (only affect mjData)
|
| 499 |
+
int nM # number of non-zeros in sparse inertia matrix
|
| 500 |
+
int nemax # number of potential equality-constraint rows
|
| 501 |
+
int njmax # number of available rows in constraint Jacobian
|
| 502 |
+
int nconmax # number of potential contacts in contact list
|
| 503 |
+
int nstack # number of fields in mjData stack
|
| 504 |
+
int nuserdata # number of extra fields in mjData
|
| 505 |
+
int nsensordata # number of fields in sensor data vector
|
| 506 |
+
|
| 507 |
+
int nbuffer # number of bytes in buffer
|
| 508 |
+
|
| 509 |
+
# ------------------------------- options and statistics
|
| 510 |
+
|
| 511 |
+
mjOption opt # physics options
|
| 512 |
+
mjVisual vis # visualization options
|
| 513 |
+
mjStatistic stat # model statistics
|
| 514 |
+
|
| 515 |
+
# ------------------------------- buffers
|
| 516 |
+
|
| 517 |
+
# main buffer
|
| 518 |
+
void* buffer # main buffer; all pointers point in it (nbuffer)
|
| 519 |
+
|
| 520 |
+
# default generalized coordinates
|
| 521 |
+
mjtNum* qpos0 # qpos values at default pose (nq x 1)
|
| 522 |
+
mjtNum* qpos_spring # reference pose for springs (nq x 1)
|
| 523 |
+
|
| 524 |
+
# bodies
|
| 525 |
+
int* body_parentid # id of body's parent (nbody x 1)
|
| 526 |
+
int* body_rootid # id of root above body (nbody x 1)
|
| 527 |
+
int* body_weldid # id of body that this body is welded to (nbody x 1)
|
| 528 |
+
int* body_mocapid # id of mocap data; -1: none (nbody x 1)
|
| 529 |
+
int* body_jntnum # number of joints for this body (nbody x 1)
|
| 530 |
+
int* body_jntadr # start addr of joints; -1: no joints (nbody x 1)
|
| 531 |
+
int* body_dofnum # number of motion degrees of freedom (nbody x 1)
|
| 532 |
+
int* body_dofadr # start addr of dofs; -1: no dofs (nbody x 1)
|
| 533 |
+
int* body_geomnum # number of geoms (nbody x 1)
|
| 534 |
+
int* body_geomadr # start addr of geoms; -1: no geoms (nbody x 1)
|
| 535 |
+
mjtByte* body_simple # body is simple (has diagonal M) (nbody x 1)
|
| 536 |
+
mjtByte* body_sameframe # inertial frame is same as body frame (nbody x 1)
|
| 537 |
+
mjtNum* body_pos # position offset rel. to parent body (nbody x 3)
|
| 538 |
+
mjtNum* body_quat # orientation offset rel. to parent body (nbody x 4)
|
| 539 |
+
mjtNum* body_ipos # local position of center of mass (nbody x 3)
|
| 540 |
+
mjtNum* body_iquat # local orientation of inertia ellipsoid (nbody x 4)
|
| 541 |
+
mjtNum* body_mass # mass (nbody x 1)
|
| 542 |
+
mjtNum* body_subtreemass # mass of subtree starting at this body (nbody x 1)
|
| 543 |
+
mjtNum* body_inertia # diagonal inertia in ipos/iquat frame (nbody x 3)
|
| 544 |
+
mjtNum* body_invweight0 # mean inv inert in qpos0 (trn, rot) (nbody x 2)
|
| 545 |
+
mjtNum* body_user # user data (nbody x nuser_body)
|
| 546 |
+
|
| 547 |
+
# joints
|
| 548 |
+
int* jnt_type # type of joint (mjtJoint) (njnt x 1)
|
| 549 |
+
int* jnt_qposadr # start addr in 'qpos' for joint's data (njnt x 1)
|
| 550 |
+
int* jnt_dofadr # start addr in 'qvel' for joint's data (njnt x 1)
|
| 551 |
+
int* jnt_bodyid # id of joint's body (njnt x 1)
|
| 552 |
+
int* jnt_group # group for visibility (njnt x 1)
|
| 553 |
+
mjtByte* jnt_limited # does joint have limits (njnt x 1)
|
| 554 |
+
mjtNum* jnt_solref # constraint solver reference: limit (njnt x mjNREF)
|
| 555 |
+
mjtNum* jnt_solimp # constraint solver impedance: limit (njnt x mjNIMP)
|
| 556 |
+
mjtNum* jnt_pos # local anchor position (njnt x 3)
|
| 557 |
+
mjtNum* jnt_axis # local joint axis (njnt x 3)
|
| 558 |
+
mjtNum* jnt_stiffness # stiffness coefficient (njnt x 1)
|
| 559 |
+
mjtNum* jnt_range # joint limits (njnt x 2)
|
| 560 |
+
mjtNum* jnt_margin # min distance for limit detection (njnt x 1)
|
| 561 |
+
mjtNum* jnt_user # user data (njnt x nuser_jnt)
|
| 562 |
+
|
| 563 |
+
# dofs
|
| 564 |
+
int* dof_bodyid # id of dof's body (nv x 1)
|
| 565 |
+
int* dof_jntid # id of dof's joint (nv x 1)
|
| 566 |
+
int* dof_parentid # id of dof's parent; -1: none (nv x 1)
|
| 567 |
+
int* dof_Madr # dof address in M-diagonal (nv x 1)
|
| 568 |
+
int* dof_simplenum # number of consecutive simple dofs (nv x 1)
|
| 569 |
+
mjtNum* dof_solref # constraint solver reference:frictionloss (nv x mjNREF)
|
| 570 |
+
mjtNum* dof_solimp # constraint solver impedance:frictionloss (nv x mjNIMP)
|
| 571 |
+
mjtNum* dof_frictionloss # dof friction loss (nv x 1)
|
| 572 |
+
mjtNum* dof_armature # dof armature inertia/mass (nv x 1)
|
| 573 |
+
mjtNum* dof_damping # damping coefficient (nv x 1)
|
| 574 |
+
mjtNum* dof_invweight0 # inv. diag. inertia in qpos0 (nv x 1)
|
| 575 |
+
mjtNum* dof_M0 # diag. inertia in qpos0 (nv x 1)
|
| 576 |
+
|
| 577 |
+
# geoms
|
| 578 |
+
int* geom_type # geometric type (mjtGeom) (ngeom x 1)
|
| 579 |
+
int* geom_contype # geom contact type (ngeom x 1)
|
| 580 |
+
int* geom_conaffinity # geom contact affinity (ngeom x 1)
|
| 581 |
+
int* geom_condim # contact dimensionality (1, 3, 4, 6) (ngeom x 1)
|
| 582 |
+
int* geom_bodyid # id of geom's body (ngeom x 1)
|
| 583 |
+
int* geom_dataid # id of geom's mesh/hfield (-1: none) (ngeom x 1)
|
| 584 |
+
int* geom_matid # material id for rendering (ngeom x 1)
|
| 585 |
+
int* geom_group # group for visibility (ngeom x 1)
|
| 586 |
+
int* geom_priority # geom contact priority (ngeom x 1)
|
| 587 |
+
mjtByte* geom_sameframe # same as body frame (1) or iframe (2) (ngeom x 1)
|
| 588 |
+
mjtNum* geom_solmix # mixing coef for solref/imp in geom pair (ngeom x 1)
|
| 589 |
+
mjtNum* geom_solref # constraint solver reference: contact (ngeom x mjNREF)
|
| 590 |
+
mjtNum* geom_solimp # constraint solver impedance: contact (ngeom x mjNIMP)
|
| 591 |
+
mjtNum* geom_size # geom-specific size parameters (ngeom x 3)
|
| 592 |
+
mjtNum* geom_rbound # radius of bounding sphere (ngeom x 1)
|
| 593 |
+
mjtNum* geom_pos # local position offset rel. to body (ngeom x 3)
|
| 594 |
+
mjtNum* geom_quat # local orientation offset rel. to body (ngeom x 4)
|
| 595 |
+
mjtNum* geom_friction # friction for (slide, spin, roll) (ngeom x 3)
|
| 596 |
+
mjtNum* geom_margin # detect contact if dist<margin (ngeom x 1)
|
| 597 |
+
mjtNum* geom_gap # include in solver if dist<margin-gap (ngeom x 1)
|
| 598 |
+
mjtNum* geom_user # user data (ngeom x nuser_geom)
|
| 599 |
+
float* geom_rgba # rgba when material is omitted (ngeom x 4)
|
| 600 |
+
|
| 601 |
+
# sites
|
| 602 |
+
int* site_type # geom type for rendering (mjtGeom) (nsite x 1)
|
| 603 |
+
int* site_bodyid # id of site's body (nsite x 1)
|
| 604 |
+
int* site_matid # material id for rendering (nsite x 1)
|
| 605 |
+
int* site_group # group for visibility (nsite x 1)
|
| 606 |
+
mjtByte* site_sameframe # same as body frame (1) or iframe (2) (nsite x 1)
|
| 607 |
+
mjtNum* site_size # geom size for rendering (nsite x 3)
|
| 608 |
+
mjtNum* site_pos # local position offset rel. to body (nsite x 3)
|
| 609 |
+
mjtNum* site_quat # local orientation offset rel. to body (nsite x 4)
|
| 610 |
+
mjtNum* site_user # user data (nsite x nuser_site)
|
| 611 |
+
float* site_rgba # rgba when material is omitted (nsite x 4)
|
| 612 |
+
|
| 613 |
+
# cameras
|
| 614 |
+
int* cam_mode # camera tracking mode (mjtCamLight) (ncam x 1)
|
| 615 |
+
int* cam_bodyid # id of camera's body (ncam x 1)
|
| 616 |
+
int* cam_targetbodyid # id of targeted body; -1: none (ncam x 1)
|
| 617 |
+
mjtNum* cam_pos # position rel. to body frame (ncam x 3)
|
| 618 |
+
mjtNum* cam_quat # orientation rel. to body frame (ncam x 4)
|
| 619 |
+
mjtNum* cam_poscom0 # global position rel. to sub-com in qpos0 (ncam x 3)
|
| 620 |
+
mjtNum* cam_pos0 # global position rel. to body in qpos0 (ncam x 3)
|
| 621 |
+
mjtNum* cam_mat0 # global orientation in qpos0 (ncam x 9)
|
| 622 |
+
mjtNum* cam_fovy # y-field of view (deg) (ncam x 1)
|
| 623 |
+
mjtNum* cam_ipd # inter-pupilary distance (ncam x 1)
|
| 624 |
+
mjtNum* cam_user # user data (ncam x nuser_cam)
|
| 625 |
+
|
| 626 |
+
# lights
|
| 627 |
+
int* light_mode # light tracking mode (mjtCamLight) (nlight x 1)
|
| 628 |
+
int* light_bodyid # id of light's body (nlight x 1)
|
| 629 |
+
int* light_targetbodyid # id of targeted body; -1: none (nlight x 1)
|
| 630 |
+
mjtByte* light_directional # directional light (nlight x 1)
|
| 631 |
+
mjtByte* light_castshadow # does light cast shadows (nlight x 1)
|
| 632 |
+
mjtByte* light_active # is light on (nlight x 1)
|
| 633 |
+
mjtNum* light_pos # position rel. to body frame (nlight x 3)
|
| 634 |
+
mjtNum* light_dir # direction rel. to body frame (nlight x 3)
|
| 635 |
+
mjtNum* light_poscom0 # global position rel. to sub-com in qpos0 (nlight x 3)
|
| 636 |
+
mjtNum* light_pos0 # global position rel. to body in qpos0 (nlight x 3)
|
| 637 |
+
mjtNum* light_dir0 # global direction in qpos0 (nlight x 3)
|
| 638 |
+
float* light_attenuation # OpenGL attenuation (quadratic model) (nlight x 3)
|
| 639 |
+
float* light_cutoff # OpenGL cutoff (nlight x 1)
|
| 640 |
+
float* light_exponent # OpenGL exponent (nlight x 1)
|
| 641 |
+
float* light_ambient # ambient rgb (alpha=1) (nlight x 3)
|
| 642 |
+
float* light_diffuse # diffuse rgb (alpha=1) (nlight x 3)
|
| 643 |
+
float* light_specular # specular rgb (alpha=1) (nlight x 3)
|
| 644 |
+
|
| 645 |
+
# meshes
|
| 646 |
+
int* mesh_vertadr # first vertex address (nmesh x 1)
|
| 647 |
+
int* mesh_vertnum # number of vertices (nmesh x 1)
|
| 648 |
+
int* mesh_texcoordadr # texcoord data address; -1: no texcoord (nmesh x 1)
|
| 649 |
+
int* mesh_faceadr # first face address (nmesh x 1)
|
| 650 |
+
int* mesh_facenum # number of faces (nmesh x 1)
|
| 651 |
+
int* mesh_graphadr # graph data address; -1: no graph (nmesh x 1)
|
| 652 |
+
float* mesh_vert # vertex data for all meshes (nmeshvert x 3)
|
| 653 |
+
float* mesh_normal # vertex normal data for all meshes (nmeshvert x 3)
|
| 654 |
+
float* mesh_texcoord # vertex texcoords for all meshes (nmeshtexvert x 2)
|
| 655 |
+
int* mesh_face # triangle face data (nmeshface x 3)
|
| 656 |
+
int* mesh_graph # convex graph data (nmeshgraph x 1)
|
| 657 |
+
|
| 658 |
+
# skins
|
| 659 |
+
int* skin_matid # skin material id; -1: none (nskin x 1)
|
| 660 |
+
float* skin_rgba # skin rgba (nskin x 4)
|
| 661 |
+
float* skin_inflate # inflate skin in normal direction (nskin x 1)
|
| 662 |
+
int* skin_vertadr # first vertex address (nskin x 1)
|
| 663 |
+
int* skin_vertnum # number of vertices (nskin x 1)
|
| 664 |
+
int* skin_texcoordadr # texcoord data address; -1: no texcoord (nskin x 1)
|
| 665 |
+
int* skin_faceadr # first face address (nskin x 1)
|
| 666 |
+
int* skin_facenum # number of faces (nskin x 1)
|
| 667 |
+
int* skin_boneadr # first bone in skin (nskin x 1)
|
| 668 |
+
int* skin_bonenum # number of bones in skin (nskin x 1)
|
| 669 |
+
float* skin_vert # vertex positions for all skin meshes (nskinvert x 3)
|
| 670 |
+
float* skin_texcoord # vertex texcoords for all skin meshes (nskintexvert x 2)
|
| 671 |
+
int* skin_face # triangle faces for all skin meshes (nskinface x 3)
|
| 672 |
+
int* skin_bonevertadr # first vertex in each bone (nskinbone x 1)
|
| 673 |
+
int* skin_bonevertnum # number of vertices in each bone (nskinbone x 1)
|
| 674 |
+
float* skin_bonebindpos # bind pos of each bone (nskinbone x 3)
|
| 675 |
+
float* skin_bonebindquat # bind quat of each bone (nskinbone x 4)
|
| 676 |
+
int* skin_bonebodyid # body id of each bone (nskinbone x 1)
|
| 677 |
+
int* skin_bonevertid # mesh ids of vertices in each bone (nskinbonevert x 1)
|
| 678 |
+
float* skin_bonevertweight # weights of vertices in each bone (nskinbonevert x 1)
|
| 679 |
+
|
| 680 |
+
# height fields
|
| 681 |
+
mjtNum* hfield_size # (x, y, z_top, z_bottom) (nhfield x 4)
|
| 682 |
+
int* hfield_nrow # number of rows in grid (nhfield x 1)
|
| 683 |
+
int* hfield_ncol # number of columns in grid (nhfield x 1)
|
| 684 |
+
int* hfield_adr # address in hfield_data (nhfield x 1)
|
| 685 |
+
float* hfield_data # elevation data (nhfielddata x 1)
|
| 686 |
+
|
| 687 |
+
# textures
|
| 688 |
+
int* tex_type # texture type (mjtTexture) (ntex x 1)
|
| 689 |
+
int* tex_height # number of rows in texture image (ntex x 1)
|
| 690 |
+
int* tex_width # number of columns in texture image (ntex x 1)
|
| 691 |
+
int* tex_adr # address in rgb (ntex x 1)
|
| 692 |
+
mjtByte* tex_rgb # rgb (alpha = 1) (ntexdata x 1)
|
| 693 |
+
|
| 694 |
+
# materials
|
| 695 |
+
int* mat_texid # texture id; -1: none (nmat x 1)
|
| 696 |
+
mjtByte* mat_texuniform # make texture cube uniform (nmat x 1)
|
| 697 |
+
float* mat_texrepeat # texture repetition for 2d mapping (nmat x 2)
|
| 698 |
+
float* mat_emission # emission (x rgb) (nmat x 1)
|
| 699 |
+
float* mat_specular # specular (x white) (nmat x 1)
|
| 700 |
+
float* mat_shininess # shininess coef (nmat x 1)
|
| 701 |
+
float* mat_reflectance # reflectance (0: disable) (nmat x 1)
|
| 702 |
+
float* mat_rgba # rgba (nmat x 4)
|
| 703 |
+
|
| 704 |
+
# predefined geom pairs for collision detection; has precedence over exclude
|
| 705 |
+
int* pair_dim # contact dimensionality (npair x 1)
|
| 706 |
+
int* pair_geom1 # id of geom1 (npair x 1)
|
| 707 |
+
int* pair_geom2 # id of geom2 (npair x 1)
|
| 708 |
+
int* pair_signature # (body1+1)<<16 + body2+1 (npair x 1)
|
| 709 |
+
mjtNum* pair_solref # constraint solver reference: contact (npair x mjNREF)
|
| 710 |
+
mjtNum* pair_solimp # constraint solver impedance: contact (npair x mjNIMP)
|
| 711 |
+
mjtNum* pair_margin # detect contact if dist<margin (npair x 1)
|
| 712 |
+
mjtNum* pair_gap # include in solver if dist<margin-gap (npair x 1)
|
| 713 |
+
mjtNum* pair_friction # tangent1, 2, spin, roll1, 2 (npair x 5)
|
| 714 |
+
|
| 715 |
+
# excluded body pairs for collision detection
|
| 716 |
+
int* exclude_signature # (body1+1)<<16 + body2+1 (nexclude x 1)
|
| 717 |
+
|
| 718 |
+
# equality constraints
|
| 719 |
+
int* eq_type # constraint type (mjtEq) (neq x 1)
|
| 720 |
+
int* eq_obj1id # id of object 1 (neq x 1)
|
| 721 |
+
int* eq_obj2id # id of object 2 (neq x 1)
|
| 722 |
+
mjtByte* eq_active # enable/disable constraint (neq x 1)
|
| 723 |
+
mjtNum* eq_solref # constraint solver reference (neq x mjNREF)
|
| 724 |
+
mjtNum* eq_solimp # constraint solver impedance (neq x mjNIMP)
|
| 725 |
+
mjtNum* eq_data # numeric data for constraint (neq x mjNEQDATA)
|
| 726 |
+
|
| 727 |
+
# tendons
|
| 728 |
+
int* tendon_adr # address of first object in tendon's path (ntendon x 1)
|
| 729 |
+
int* tendon_num # number of objects in tendon's path (ntendon x 1)
|
| 730 |
+
int* tendon_matid # material id for rendering (ntendon x 1)
|
| 731 |
+
int* tendon_group # group for visibility (ntendon x 1)
|
| 732 |
+
mjtByte* tendon_limited # does tendon have length limits (ntendon x 1)
|
| 733 |
+
mjtNum* tendon_width # width for rendering (ntendon x 1)
|
| 734 |
+
mjtNum* tendon_solref_lim # constraint solver reference: limit (ntendon x mjNREF)
|
| 735 |
+
mjtNum* tendon_solimp_lim # constraint solver impedance: limit (ntendon x mjNIMP)
|
| 736 |
+
mjtNum* tendon_solref_fri # constraint solver reference: friction (ntendon x mjNREF)
|
| 737 |
+
mjtNum* tendon_solimp_fri # constraint solver impedance: friction (ntendon x mjNIMP)
|
| 738 |
+
mjtNum* tendon_range # tendon length limits (ntendon x 2)
|
| 739 |
+
mjtNum* tendon_margin # min distance for limit detection (ntendon x 1)
|
| 740 |
+
mjtNum* tendon_stiffness # stiffness coefficient (ntendon x 1)
|
| 741 |
+
mjtNum* tendon_damping # damping coefficient (ntendon x 1)
|
| 742 |
+
mjtNum* tendon_frictionloss; # loss due to friction (ntendon x 1)
|
| 743 |
+
mjtNum* tendon_lengthspring; # tendon length in qpos_spring (ntendon x 1)
|
| 744 |
+
mjtNum* tendon_length0 # tendon length in qpos0 (ntendon x 1)
|
| 745 |
+
mjtNum* tendon_invweight0 # inv. weight in qpos0 (ntendon x 1)
|
| 746 |
+
mjtNum* tendon_user # user data (ntendon x nuser_tendon)
|
| 747 |
+
float* tendon_rgba # rgba when material is omitted (ntendon x 4)
|
| 748 |
+
|
| 749 |
+
# list of all wrap objects in tendon paths
|
| 750 |
+
int* wrap_type # wrap object type (mjtWrap) (nwrap x 1)
|
| 751 |
+
int* wrap_objid # object id: geom, site, joint (nwrap x 1)
|
| 752 |
+
mjtNum* wrap_prm # divisor, joint coef, or site id (nwrap x 1)
|
| 753 |
+
|
| 754 |
+
# actuators
|
| 755 |
+
int* actuator_trntype # transmission type (mjtTrn) (nu x 1)
|
| 756 |
+
int* actuator_dyntype # dynamics type (mjtDyn) (nu x 1)
|
| 757 |
+
int* actuator_gaintype # gain type (mjtGain) (nu x 1)
|
| 758 |
+
int* actuator_biastype # bias type (mjtBias) (nu x 1)
|
| 759 |
+
int* actuator_trnid # transmission id: joint, tendon, site (nu x 2)
|
| 760 |
+
int* actuator_group # group for visibility (nu x 1)
|
| 761 |
+
mjtByte* actuator_ctrllimited; # is control limited (nu x 1)
|
| 762 |
+
mjtByte* actuator_forcelimited;# is force limited (nu x 1)
|
| 763 |
+
mjtNum* actuator_dynprm # dynamics parameters (nu x mjNDYN)
|
| 764 |
+
mjtNum* actuator_gainprm # gain parameters (nu x mjNGAIN)
|
| 765 |
+
mjtNum* actuator_biasprm # bias parameters (nu x mjNBIAS)
|
| 766 |
+
mjtNum* actuator_ctrlrange # range of controls (nu x 2)
|
| 767 |
+
mjtNum* actuator_forcerange; # range of forces (nu x 2)
|
| 768 |
+
mjtNum* actuator_gear # scale length and transmitted force (nu x 6)
|
| 769 |
+
mjtNum* actuator_cranklength; # crank length for slider-crank (nu x 1)
|
| 770 |
+
mjtNum* actuator_acc0 # acceleration from unit force in qpos0 (nu x 1)
|
| 771 |
+
mjtNum* actuator_length0 # actuator length in qpos0 (nu x 1)
|
| 772 |
+
mjtNum* actuator_lengthrange # ... not yet implemented ??? (nu x 2)
|
| 773 |
+
mjtNum* actuator_user # user data (nu x nuser_actuator)
|
| 774 |
+
|
| 775 |
+
# sensors
|
| 776 |
+
int* sensor_type # sensor type (mjtSensor) (nsensor x 1)
|
| 777 |
+
int* sensor_datatype # numeric data type (mjtDataType) (nsensor x 1)
|
| 778 |
+
int* sensor_needstage # required compute stage (mjtStage) (nsensor x 1)
|
| 779 |
+
int* sensor_objtype # type of sensorized object (mjtObj) (nsensor x 1)
|
| 780 |
+
int* sensor_objid # id of sensorized object (nsensor x 1)
|
| 781 |
+
int* sensor_dim # number of scalar outputs (nsensor x 1)
|
| 782 |
+
int* sensor_adr # address in sensor array (nsensor x 1)
|
| 783 |
+
mjtNum* sensor_cutoff # cutoff for real and positive; 0: ignore (nsensor x 1)
|
| 784 |
+
mjtNum* sensor_noise # noise standard deviation (nsensor x 1)
|
| 785 |
+
mjtNum* sensor_user # user data (nsensor x nuser_sensor)
|
| 786 |
+
|
| 787 |
+
# custom numeric fields
|
| 788 |
+
int* numeric_adr # address of field in numeric_data (nnumeric x 1)
|
| 789 |
+
int* numeric_size # size of numeric field (nnumeric x 1)
|
| 790 |
+
mjtNum* numeric_data # array of all numeric fields (nnumericdata x 1)
|
| 791 |
+
|
| 792 |
+
# custom text fields
|
| 793 |
+
int* text_adr # address of text in text_data (ntext x 1)
|
| 794 |
+
int* text_size # size of text field (strlen+1) (ntext x 1)
|
| 795 |
+
char* text_data # array of all text fields (0-terminated) (ntextdata x 1)
|
| 796 |
+
|
| 797 |
+
# custom tuple fields
|
| 798 |
+
int* tuple_adr # address of text in text_data (ntuple x 1)
|
| 799 |
+
int* tuple_size # number of objects in tuple (ntuple x 1)
|
| 800 |
+
int* tuple_objtype # array of object types in all tuples (ntupledata x 1)
|
| 801 |
+
int* tuple_objid # array of object ids in all tuples (ntupledata x 1)
|
| 802 |
+
mjtNum* tuple_objprm # array of object params in all tuples (ntupledata x 1)
|
| 803 |
+
|
| 804 |
+
# keyframes
|
| 805 |
+
mjtNum* key_time # key time (nkey x 1)
|
| 806 |
+
mjtNum* key_qpos # key position (nkey x nq)
|
| 807 |
+
mjtNum* key_qvel # key velocity (nkey x nv)
|
| 808 |
+
mjtNum* key_act # key activation (nkey x na)
|
| 809 |
+
mjtNum* key_mpos # key mocap position (nkey x 3*nmocap)
|
| 810 |
+
mjtNum* key_mquat # key mocap quaternion (nkey x 4*nmocap)
|
| 811 |
+
|
| 812 |
+
# names
|
| 813 |
+
int* name_bodyadr # body name pointers (nbody x 1)
|
| 814 |
+
int* name_jntadr # joint name pointers (njnt x 1)
|
| 815 |
+
int* name_geomadr # geom name pointers (ngeom x 1)
|
| 816 |
+
int* name_siteadr # site name pointers (nsite x 1)
|
| 817 |
+
int* name_camadr # camera name pointers (ncam x 1)
|
| 818 |
+
int* name_lightadr # light name pointers (nlight x 1)
|
| 819 |
+
int* name_meshadr # mesh name pointers (nmesh x 1)
|
| 820 |
+
int* name_skinadr # skin name pointers (nskin x 1)
|
| 821 |
+
int* name_hfieldadr # hfield name pointers (nhfield x 1)
|
| 822 |
+
int* name_texadr # texture name pointers (ntex x 1)
|
| 823 |
+
int* name_matadr # material name pointers (nmat x 1)
|
| 824 |
+
int* name_pairadr # geom pair name pointers (npair x 1)
|
| 825 |
+
int* name_excludeadr # exclude name pointers (nexclude x 1)
|
| 826 |
+
int* name_eqadr # equality constraint name pointers (neq x 1)
|
| 827 |
+
int* name_tendonadr # tendon name pointers (ntendon x 1)
|
| 828 |
+
int* name_actuatoradr # actuator name pointers (nu x 1)
|
| 829 |
+
int* name_sensoradr # sensor name pointers (nsensor x 1)
|
| 830 |
+
int* name_numericadr # numeric name pointers (nnumeric x 1)
|
| 831 |
+
int* name_textadr # text name pointers (ntext x 1)
|
| 832 |
+
int* name_tupleadr # tuple name pointers (ntuple x 1)
|
| 833 |
+
int* name_keyadr # keyframe name pointers (nkey x 1)
|
| 834 |
+
char* names # names of all objects, 0-terminated (nnames x 1)
|
mujoco-py-2.1.2.14/mujoco_py/pxd/mjrender.pxd
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cdef extern from "mjrender.h" nogil:
|
| 2 |
+
# Global constants
|
| 3 |
+
enum: mjNAUX
|
| 4 |
+
enum: mjMAXTEXTURE
|
| 5 |
+
|
| 6 |
+
ctypedef enum mjtGridPos: # grid position for overlay
|
| 7 |
+
mjGRID_TOPLEFT = 0, # top left
|
| 8 |
+
mjGRID_TOPRIGHT, # top right
|
| 9 |
+
mjGRID_BOTTOMLEFT, # bottom left
|
| 10 |
+
mjGRID_BOTTOMRIGHT # bottom right
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
ctypedef enum mjtFramebuffer: # OpenGL framebuffer option
|
| 14 |
+
mjFB_WINDOW = 0, # default/window buffer
|
| 15 |
+
mjFB_OFFSCREEN # offscreen buffer
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
ctypedef enum mjtFontScale: # font scale, used at context creation
|
| 19 |
+
mjFONTSCALE_100 = 100, # normal scale, suitable in the absence of DPI scaling
|
| 20 |
+
mjFONTSCALE_150 = 150, # 150% scale
|
| 21 |
+
mjFONTSCALE_200 = 200 # 200% scale
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
ctypedef enum mjtFont: # font type, used at each text operation
|
| 25 |
+
mjFONT_NORMAL = 0, # normal font
|
| 26 |
+
mjFONT_SHADOW, # normal font with shadow (for higher contrast)
|
| 27 |
+
mjFONT_BIG # big font (for user alerts)
|
| 28 |
+
|
| 29 |
+
ctypedef struct mjrRect: # OpenGL rectangle
|
| 30 |
+
int left # left (usually 0)
|
| 31 |
+
int bottom # bottom (usually 0)
|
| 32 |
+
int width # width (usually buffer width)
|
| 33 |
+
int height # height (usually buffer height)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
ctypedef struct mjrContext: # custom OpenGL context
|
| 37 |
+
# parameters copied from mjVisual
|
| 38 |
+
float lineWidth # line width for wireframe rendering
|
| 39 |
+
float shadowClip # clipping radius for directional lights
|
| 40 |
+
float shadowScale # fraction of light cutoff for spot lights
|
| 41 |
+
float fogStart # fog start = stat.extent * vis.map.fogstart
|
| 42 |
+
float fogEnd # fog end = stat.extent * vis.map.fogend
|
| 43 |
+
float fogRGBA[4] # fog rgba
|
| 44 |
+
int shadowSize # size of shadow map texture
|
| 45 |
+
int offWidth # width of offscreen buffer
|
| 46 |
+
int offHeight # height of offscreen buffer
|
| 47 |
+
int offSamples # number of offscreen buffer multisamples
|
| 48 |
+
|
| 49 |
+
# parameters specified at creation
|
| 50 |
+
int fontScale; # font scale
|
| 51 |
+
int auxWidth[mjNAUX] # auxiliary buffer width
|
| 52 |
+
int auxHeight[mjNAUX] # auxiliary buffer height
|
| 53 |
+
int auxSamples[mjNAUX] # auxiliary buffer multisamples
|
| 54 |
+
|
| 55 |
+
# offscreen rendering objects
|
| 56 |
+
unsigned int offFBO # offscreen framebuffer object
|
| 57 |
+
unsigned int offFBO_r # offscreen framebuffer for resolving multisamples
|
| 58 |
+
unsigned int offColor # offscreen color buffer
|
| 59 |
+
unsigned int offColor_r # offscreen color buffer for resolving multisamples
|
| 60 |
+
unsigned int offDepthStencil # offscreen depth and stencil buffer
|
| 61 |
+
unsigned int offDepthStencil_r # offscreen depth and stencil buffer for resolving multisamples
|
| 62 |
+
|
| 63 |
+
# shadow rendering objects
|
| 64 |
+
unsigned int shadowFBO # shadow map framebuffer object
|
| 65 |
+
unsigned int shadowTex # shadow map texture
|
| 66 |
+
|
| 67 |
+
# auxiliary buffers
|
| 68 |
+
unsigned int auxFBO[mjNAUX] # auxiliary framebuffer object
|
| 69 |
+
unsigned int auxFBO_r[mjNAUX] # auxiliary framebuffer object for resolving
|
| 70 |
+
unsigned int auxColor[mjNAUX] # auxiliary color buffer
|
| 71 |
+
unsigned int auxColor_r[mjNAUX] # auxiliary color buffer for resolving
|
| 72 |
+
|
| 73 |
+
# texture objects and info
|
| 74 |
+
int ntexture # number of allocated textures
|
| 75 |
+
int textureType[100] # type of texture (mjtTexture)
|
| 76 |
+
unsigned int texture[100] # texture names
|
| 77 |
+
|
| 78 |
+
# displaylist starting positions
|
| 79 |
+
unsigned int basePlane # all planes from model
|
| 80 |
+
unsigned int baseMesh # all meshes from model
|
| 81 |
+
unsigned int baseHField # all hfields from model
|
| 82 |
+
unsigned int baseBuiltin # all buildin geoms, with quality from model
|
| 83 |
+
unsigned int baseFontNormal # normal font
|
| 84 |
+
unsigned int baseFontShadow # shadow font
|
| 85 |
+
unsigned int baseFontBig # big font
|
| 86 |
+
|
| 87 |
+
# displaylist ranges
|
| 88 |
+
int rangePlane # all planes from model
|
| 89 |
+
int rangeMesh # all meshes from model
|
| 90 |
+
int rangeHField # all hfields from model
|
| 91 |
+
int rangeBuiltin # all builtin geoms, with quality from model
|
| 92 |
+
int rangeFont # all characters in font
|
| 93 |
+
|
| 94 |
+
# skin VBOs
|
| 95 |
+
int nskin # number of skins
|
| 96 |
+
unsigned int* skinvertVBO # skin vertex position VBOs
|
| 97 |
+
unsigned int* skinnormalVBO # skin vertex normal VBOs
|
| 98 |
+
unsigned int* skintexcoordVBO # skin vertex texture coordinate VBOs
|
| 99 |
+
unsigned int* skinfaceVBO # skin face index VBOs
|
| 100 |
+
|
| 101 |
+
# character info
|
| 102 |
+
int charWidth[127] # character widths: normal and shadow
|
| 103 |
+
int charWidthBig[127] # chacarter widths: big
|
| 104 |
+
int charHeight # character heights: normal and shadow
|
| 105 |
+
int charHeightBig # character heights: big
|
| 106 |
+
|
| 107 |
+
# capabilities
|
| 108 |
+
int glewInitialized # is glew initialized
|
| 109 |
+
int windowAvailable # is default/window framebuffer available
|
| 110 |
+
int windowSamples # number of samples for default/window framebuffer
|
| 111 |
+
int windowStereo # is stereo available for default/window framebuffer
|
| 112 |
+
int windowDoublebuffer # is default/window framebuffer double buffered
|
| 113 |
+
|
| 114 |
+
# only field that changes after mjr_makeContext
|
| 115 |
+
int currentBuffer # currently active framebuffer: mjFB_WINDOW or mjFB_OFFSCREEN
|
mujoco-py-2.1.2.14/mujoco_py/pxd/mujoco.pxd
ADDED
|
@@ -0,0 +1,1083 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
include "mjmodel.pxd"
|
| 2 |
+
include "mjdata.pxd"
|
| 3 |
+
include "mjrender.pxd"
|
| 4 |
+
include "mjui.pxd"
|
| 5 |
+
include "mjvisualize.pxd"
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
cdef extern from "mujoco.h" nogil:
|
| 9 |
+
# macros
|
| 10 |
+
#define mjMARKSTACK int _mark = d->pstack;
|
| 11 |
+
#define mjFREESTACK d->pstack = _mark;
|
| 12 |
+
#define mjDISABLED(x) (m->opt.disableflags & (x))
|
| 13 |
+
#define mjENABLED(x) (m->opt.enableflags & (x))
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# user error and memory handlers
|
| 17 |
+
void (*mju_user_error)(const char*);
|
| 18 |
+
void (*mju_user_warning)(const char*);
|
| 19 |
+
void* (*mju_user_malloc)(size_t);
|
| 20 |
+
void (*mju_user_free)(void*);
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# # callbacks extending computation pipeline
|
| 24 |
+
# mjfGeneric mjcb_passive;
|
| 25 |
+
# mjfGeneric mjcb_control;
|
| 26 |
+
# mjfSensor mjcb_sensor;
|
| 27 |
+
# mjfTime mjcb_time;
|
| 28 |
+
# mjfAct mjcb_act_dyn;
|
| 29 |
+
mjfAct mjcb_act_gain;
|
| 30 |
+
mjfAct mjcb_act_bias;
|
| 31 |
+
# mjfSolImp mjcb_sol_imp;
|
| 32 |
+
# mjfSolRef mjcb_sol_ref;
|
| 33 |
+
#
|
| 34 |
+
#
|
| 35 |
+
# # collision function table
|
| 36 |
+
# mjfCollision mjCOLLISIONFUNC[mjNGEOMTYPES][mjNGEOMTYPES];
|
| 37 |
+
#
|
| 38 |
+
#
|
| 39 |
+
# # string names
|
| 40 |
+
const char* mjDISABLESTRING[mjNDISABLE];
|
| 41 |
+
const char* mjENABLESTRING[mjNENABLE];
|
| 42 |
+
const char* mjTIMERSTRING[mjNTIMER];
|
| 43 |
+
const char* mjLABELSTRING[mjNLABEL];
|
| 44 |
+
const char* mjFRAMESTRING[mjNFRAME];
|
| 45 |
+
const char* mjVISSTRING[mjNVISFLAG][3];
|
| 46 |
+
const char* mjRNDSTRING[mjNRNDFLAG][3];
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
#---------------------- Activation -----------------------------------------------------
|
| 50 |
+
|
| 51 |
+
# activate license, call mju_error on failure; return 1 if ok, 0 if failure
|
| 52 |
+
int mj_activate(const char* filename);
|
| 53 |
+
|
| 54 |
+
# deactivate license, free memory
|
| 55 |
+
void mj_deactivate();
|
| 56 |
+
|
| 57 |
+
#---------------------- Virtual file system --------------------------------------------
|
| 58 |
+
|
| 59 |
+
# Initialize VFS to empty (no deallocation).
|
| 60 |
+
void mj_defaultVFS(mjVFS* vfs);
|
| 61 |
+
|
| 62 |
+
# Add file to VFS, return 0: success, 1: full, 2: repeated name, -1: not found on disk.
|
| 63 |
+
int mj_addFileVFS(mjVFS* vfs, const char* directory, const char* filename);
|
| 64 |
+
|
| 65 |
+
# Make empty file in VFS, return 0: success, 1: full, 2: repeated name.
|
| 66 |
+
int mj_makeEmptyFileVFS(mjVFS* vfs, const char* filename, int filesize);
|
| 67 |
+
|
| 68 |
+
# Return file index in VFS, or -1 if not found in VFS.
|
| 69 |
+
int mj_findFileVFS(const mjVFS* vfs, const char* filename);
|
| 70 |
+
|
| 71 |
+
# Delete file from VFS, return 0: success, -1: not found in VFS.
|
| 72 |
+
int mj_deleteFileVFS(mjVFS* vfs, const char* filename);
|
| 73 |
+
|
| 74 |
+
# Delete all files from VFS.
|
| 75 |
+
void mj_deleteVFS(mjVFS* vfs);
|
| 76 |
+
|
| 77 |
+
#--------------------- Parse and compile ----------------------------------------------
|
| 78 |
+
|
| 79 |
+
# Parse XML file in MJCF or URDF format, compile it, return low-level model.
|
| 80 |
+
# If vfs is not NULL, look up files in vfs before reading from disk.
|
| 81 |
+
# If error is not NULL, it must have size error_sz.
|
| 82 |
+
mjModel* mj_loadXML(const char* filename, const mjVFS* vfs,
|
| 83 |
+
char* error, int error_sz);
|
| 84 |
+
|
| 85 |
+
# Update XML data structures with info from low-level model, save as MJCF.
|
| 86 |
+
# If error is not NULL, it must have size error_sz.
|
| 87 |
+
int mj_saveLastXML(const char* filename, const mjModel* m,
|
| 88 |
+
char* error, int error_sz);
|
| 89 |
+
|
| 90 |
+
# Free last XML model if loaded. Called internally at each load.
|
| 91 |
+
void mj_freeLastXML();
|
| 92 |
+
|
| 93 |
+
# Print internal XML schema as plain text or HTML, with style-padding or .
|
| 94 |
+
int mj_printSchema(const char* filename, char* buffer, int buffer_sz,
|
| 95 |
+
int flg_html, int flg_pad);
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
#--------------------- Main simulation ------------------------------------------------
|
| 99 |
+
|
| 100 |
+
# Advance simulation, use control callback to obtain external force and control.
|
| 101 |
+
void mj_step(const mjModel* m, mjData* d);
|
| 102 |
+
|
| 103 |
+
# Advance simulation in two steps: before external force and control is set by user.
|
| 104 |
+
void mj_step1(const mjModel* m, mjData* d);
|
| 105 |
+
|
| 106 |
+
# Advance simulation in two steps: after external force and control is set by user.
|
| 107 |
+
void mj_step2(const mjModel* m, mjData* d);
|
| 108 |
+
|
| 109 |
+
# Forward dynamics: same as mj_step but do not integrate in time.
|
| 110 |
+
void mj_forward(const mjModel* m, mjData* d);
|
| 111 |
+
|
| 112 |
+
# Inverse dynamics: qacc must be set before calling.
|
| 113 |
+
void mj_inverse(const mjModel* m, mjData* d);
|
| 114 |
+
|
| 115 |
+
# Forward dynamics with skip; skipstage is mjtStage.
|
| 116 |
+
void mj_forwardSkip(const mjModel* m, mjData* d,
|
| 117 |
+
int skipstage, int skipsensorenergy);
|
| 118 |
+
|
| 119 |
+
# Inverse dynamics with skip; skipstage is mjtStage.
|
| 120 |
+
void mj_inverseSkip(const mjModel* m, mjData* d,
|
| 121 |
+
int skipstage, int skipsensorenergy);
|
| 122 |
+
|
| 123 |
+
# Forward dynamics with skip; skipstage is mjtStage.
|
| 124 |
+
void mj_forwardSkip(const mjModel* m, mjData* d, int skipstage, int skipsensor);
|
| 125 |
+
|
| 126 |
+
# Inverse dynamics with skip; skipstage is mjtStage.
|
| 127 |
+
void mj_inverseSkip(const mjModel* m, mjData* d, int skipstage, int skipsensor);
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
#--------------------- Initialization -------------------------------------------------
|
| 131 |
+
|
| 132 |
+
# Set default options for length range computation.
|
| 133 |
+
void mj_defaultLROpt(mjLROpt* opt);
|
| 134 |
+
|
| 135 |
+
# Set solver parameters to default values.
|
| 136 |
+
void mj_defaultSolRefImp(mjtNum* solref, mjtNum* solimp);
|
| 137 |
+
|
| 138 |
+
# Set physics options to default values.
|
| 139 |
+
void mj_defaultOption(mjOption* opt);
|
| 140 |
+
|
| 141 |
+
# Set visual options to default values.
|
| 142 |
+
void mj_defaultVisual(mjVisual* vis);
|
| 143 |
+
|
| 144 |
+
# Copy mjModel, allocate new if dest is NULL.
|
| 145 |
+
mjModel* mj_copyModel(mjModel* dest, const mjModel* src);
|
| 146 |
+
|
| 147 |
+
# Save model to binary MJB file or memory buffer; buffer has precedence when given.
|
| 148 |
+
void mj_saveModel(const mjModel* m, const char* filename, void* buffer, int buffer_sz);
|
| 149 |
+
|
| 150 |
+
# Load model from binary MJB file.
|
| 151 |
+
# If vfs is not NULL, look up file in vfs before reading from disk.
|
| 152 |
+
mjModel* mj_loadModel(const char* filename, mjVFS* vfs);
|
| 153 |
+
|
| 154 |
+
# Free memory allocation in model.
|
| 155 |
+
void mj_deleteModel(mjModel* m);
|
| 156 |
+
|
| 157 |
+
# Return size of buffer needed to hold model.
|
| 158 |
+
int mj_sizeModel(const mjModel* m);
|
| 159 |
+
|
| 160 |
+
# Allocate mjData correponding to given model.
|
| 161 |
+
mjData* mj_makeData(const mjModel* m);
|
| 162 |
+
|
| 163 |
+
# Copy mjData.
|
| 164 |
+
mjData* mj_copyData(mjData* dest, const mjModel* m, const mjData* src);
|
| 165 |
+
|
| 166 |
+
# Reset data to defaults.
|
| 167 |
+
void mj_resetData(const mjModel* m, mjData* d);
|
| 168 |
+
|
| 169 |
+
# Reset data to defaults, fill everything else with debug_value.
|
| 170 |
+
void mj_resetDataDebug(const mjModel* m, mjData* d, unsigned char debug_value);
|
| 171 |
+
|
| 172 |
+
# Reset data, set fields from specified keyframe.
|
| 173 |
+
void mj_resetDataKeyframe(const mjModel* m, mjData* d, int key);
|
| 174 |
+
|
| 175 |
+
# Allocate array of specified size on mjData stack. Call mju_error on stack overflow.
|
| 176 |
+
mjtNum* mj_stackAlloc(mjData* d, int size);
|
| 177 |
+
|
| 178 |
+
# Free memory allocation in mjData.
|
| 179 |
+
void mj_deleteData(mjData* d);
|
| 180 |
+
|
| 181 |
+
# Reset all callbacks to NULL pointers (NULL is the default).
|
| 182 |
+
void mj_resetCallbacks();
|
| 183 |
+
|
| 184 |
+
# Set constant fields of mjModel, corresponding to qpos0 configuration.
|
| 185 |
+
void mj_setConst(mjModel* m, mjData* d);
|
| 186 |
+
|
| 187 |
+
# Set actuator_lengthrange for specified actuator; return 1 if ok, 0 if error.
|
| 188 |
+
int mj_setLengthRange(mjModel* m, mjData* d, int index,
|
| 189 |
+
const mjLROpt* opt, char* error, int error_sz);
|
| 190 |
+
|
| 191 |
+
#--------------------- Printing -------------------------------------------------------
|
| 192 |
+
|
| 193 |
+
# Print model to text file.
|
| 194 |
+
void mj_printModel(const mjModel* m, const char* filename);
|
| 195 |
+
|
| 196 |
+
# Print data to text file.
|
| 197 |
+
void mj_printData(const mjModel* m, mjData* d, const char* filename);
|
| 198 |
+
|
| 199 |
+
# Print matrix to screen.
|
| 200 |
+
void mju_printMat(const mjtNum* mat, int nr, int nc);
|
| 201 |
+
|
| 202 |
+
# Print sparse matrix to screen.
|
| 203 |
+
void mju_printMatSparse(const mjtNum* mat, int nr,
|
| 204 |
+
const int* rownnz, const int* rowadr,
|
| 205 |
+
const int* colind);
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
#--------------------- Components -----------------------------------------------------
|
| 209 |
+
|
| 210 |
+
# Run position-dependent computations.
|
| 211 |
+
void mj_fwdPosition(const mjModel* m, mjData* d);
|
| 212 |
+
|
| 213 |
+
# Run velocity-dependent computations.
|
| 214 |
+
void mj_fwdVelocity(const mjModel* m, mjData* d);
|
| 215 |
+
|
| 216 |
+
# Compute actuator force qfrc_actuation.
|
| 217 |
+
void mj_fwdActuation(const mjModel* m, mjData* d);
|
| 218 |
+
|
| 219 |
+
# Add up all non-constraint forces, compute qacc_unc.
|
| 220 |
+
void mj_fwdAcceleration(const mjModel* m, mjData* d);
|
| 221 |
+
|
| 222 |
+
# Run selected constraint solver.
|
| 223 |
+
void mj_fwdConstraint(const mjModel* m, mjData* d);
|
| 224 |
+
|
| 225 |
+
# Euler integrator, semi-implicit in velocity.
|
| 226 |
+
void mj_Euler(const mjModel* m, mjData* d);
|
| 227 |
+
|
| 228 |
+
# Runge-Kutta explicit order-N integrator.
|
| 229 |
+
void mj_RungeKutta(const mjModel* m, mjData* d, int N);
|
| 230 |
+
|
| 231 |
+
# Run position-dependent computations in inverse dynamics.
|
| 232 |
+
void mj_invPosition(const mjModel* m, mjData* d);
|
| 233 |
+
|
| 234 |
+
# Run velocity-dependent computations in inverse dynamics.
|
| 235 |
+
void mj_invVelocity(const mjModel* m, mjData* d);
|
| 236 |
+
|
| 237 |
+
# Apply the analytical formula for inverse constraint dynamics.
|
| 238 |
+
void mj_invConstraint(const mjModel* m, mjData* d);
|
| 239 |
+
|
| 240 |
+
# Compare forward and inverse dynamics, save results in fwdinv.
|
| 241 |
+
void mj_compareFwdInv(const mjModel* m, mjData* d);
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
#--------------------- Sub components -------------------------------------------------
|
| 245 |
+
|
| 246 |
+
# Evaluate position-dependent sensors.
|
| 247 |
+
void mj_sensorPos(const mjModel* m, mjData* d);
|
| 248 |
+
|
| 249 |
+
# Evaluate velocity-dependent sensors.
|
| 250 |
+
void mj_sensorVel(const mjModel* m, mjData* d);
|
| 251 |
+
|
| 252 |
+
# Evaluate acceleration and force-dependent sensors.
|
| 253 |
+
void mj_sensorAcc(const mjModel* m, mjData* d);
|
| 254 |
+
|
| 255 |
+
# Evaluate position-dependent energy (potential).
|
| 256 |
+
void mj_energyPos(const mjModel* m, mjData* d);
|
| 257 |
+
|
| 258 |
+
# Evaluate velocity-dependent energy (kinetic).
|
| 259 |
+
void mj_energyVel(const mjModel* m, mjData* d);
|
| 260 |
+
|
| 261 |
+
# Check qpos, reset if any element is too big or nan.
|
| 262 |
+
void mj_checkPos(const mjModel* m, mjData* d);
|
| 263 |
+
|
| 264 |
+
# Check qvel, reset if any element is too big or nan.
|
| 265 |
+
void mj_checkVel(const mjModel* m, mjData* d);
|
| 266 |
+
|
| 267 |
+
# Check qacc, reset if any element is too big or nan.
|
| 268 |
+
void mj_checkAcc(const mjModel* m, mjData* d);
|
| 269 |
+
|
| 270 |
+
# Run forward kinematics.
|
| 271 |
+
void mj_kinematics(const mjModel* m, mjData* d);
|
| 272 |
+
|
| 273 |
+
# Map inertias and motion dofs to global frame centered at CoM.
|
| 274 |
+
void mj_comPos(const mjModel* m, mjData* d);
|
| 275 |
+
|
| 276 |
+
# Compute camera and light positions and orientations.
|
| 277 |
+
void mj_camlight(const mjModel* m, mjData* d);
|
| 278 |
+
|
| 279 |
+
# Compute tendon lengths, velocities and moment arms.
|
| 280 |
+
void mj_tendon(const mjModel* m, mjData* d);
|
| 281 |
+
|
| 282 |
+
# Compute actuator transmission lengths and moments.
|
| 283 |
+
void mj_transmission(const mjModel* m, mjData* d);
|
| 284 |
+
|
| 285 |
+
# Run composite rigid body inertia algorithm (CRB).
|
| 286 |
+
void mj_crb(const mjModel* m, mjData* d);
|
| 287 |
+
|
| 288 |
+
# Compute sparse L'*D*L factorizaton of inertia matrix.
|
| 289 |
+
void mj_factorM(const mjModel* m, mjData* d);
|
| 290 |
+
|
| 291 |
+
# Solve linear system M * x = y using factorization: x = inv(L'*D*L)*y
|
| 292 |
+
void mj_solveM(const mjModel* m, mjData* d, mjtNum* x, const mjtNum* y, int n);
|
| 293 |
+
|
| 294 |
+
# Half of linear solve: x = sqrt(inv(D))*inv(L')*y
|
| 295 |
+
void mj_solveM2(const mjModel* m, mjData* d, mjtNum* x, const mjtNum* y, int n);
|
| 296 |
+
|
| 297 |
+
# Compute cvel, cdof_dot.
|
| 298 |
+
void mj_comVel(const mjModel* m, mjData* d);
|
| 299 |
+
|
| 300 |
+
# Compute qfrc_passive from spring-dampers, viscosity and density.
|
| 301 |
+
void mj_passive(const mjModel* m, mjData* d);
|
| 302 |
+
|
| 303 |
+
# subtree linear velocity and angular momentum
|
| 304 |
+
void mj_subtreeVel(const mjModel* m, mjData* d);
|
| 305 |
+
|
| 306 |
+
# RNE: compute M(qpos)*qacc + C(qpos,qvel); flg_acc=0 removes inertial term.
|
| 307 |
+
void mj_rne(const mjModel* m, mjData* d, int flg_acc, mjtNum* result);
|
| 308 |
+
|
| 309 |
+
# RNE with complete data: compute cacc, cfrc_ext, cfrc_int.
|
| 310 |
+
void mj_rnePostConstraint(const mjModel* m, mjData* d);
|
| 311 |
+
|
| 312 |
+
# Run collision detection.
|
| 313 |
+
void mj_collision(const mjModel* m, mjData* d);
|
| 314 |
+
|
| 315 |
+
# Construct constraints.
|
| 316 |
+
void mj_makeConstraint(const mjModel* m, mjData* d);
|
| 317 |
+
|
| 318 |
+
# Compute inverse constaint inertia efc_AR.
|
| 319 |
+
void mj_projectConstraint(const mjModel* m, mjData* d);
|
| 320 |
+
|
| 321 |
+
# Compute efc_vel, efc_aref.
|
| 322 |
+
void mj_referenceConstraint(const mjModel* m, mjData* d);
|
| 323 |
+
|
| 324 |
+
# Compute efc_state, efc_force, qfrc_constraint, and (optionally) cone Hessians.
|
| 325 |
+
# If cost is not NULL, set *cost = s(jar) where jar = Jac*qacc-aref.
|
| 326 |
+
void mj_constraintUpdate(const mjModel* m, mjData* d, const mjtNum* jar,
|
| 327 |
+
mjtNum* cost, int flg_coneHessian);
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
#--------------------- Support --------------------------------------------------------
|
| 331 |
+
|
| 332 |
+
# Add contact to d->contact list; return 0 if success; 1 if buffer full.
|
| 333 |
+
int mj_addContact(const mjModel* m, mjData* d, const mjContact* con);
|
| 334 |
+
|
| 335 |
+
# Determine type of friction cone.
|
| 336 |
+
int mj_isPyramidal(const mjModel* m);
|
| 337 |
+
|
| 338 |
+
# Determine type of constraint Jacobian.
|
| 339 |
+
int mj_isSparse(const mjModel* m);
|
| 340 |
+
|
| 341 |
+
# Determine type of solver (PGS is dual, CG and Newton are primal).
|
| 342 |
+
int mj_isDual(const mjModel* m);
|
| 343 |
+
|
| 344 |
+
# Multiply dense or sparse constraint Jacobian by vector.
|
| 345 |
+
void mj_mulJacVec(const mjModel* m, mjData* d,
|
| 346 |
+
mjtNum* res, const mjtNum* vec);
|
| 347 |
+
|
| 348 |
+
# Multiply dense or sparse constraint Jacobian transpose by vector.
|
| 349 |
+
void mj_mulJacTVec(const mjModel* m, mjData* d, mjtNum* res, const mjtNum* vec);
|
| 350 |
+
|
| 351 |
+
# Compute 3/6-by-nv end-effector Jacobian of global point attached to given body.
|
| 352 |
+
void mj_jac(const mjModel* m, const mjData* d,
|
| 353 |
+
mjtNum* jacp, mjtNum* jacr, const mjtNum point[3], int body);
|
| 354 |
+
|
| 355 |
+
# Compute body frame end-effector Jacobian.
|
| 356 |
+
void mj_jacBody(const mjModel* m, const mjData* d,
|
| 357 |
+
mjtNum* jacp, mjtNum* jacr, int body);
|
| 358 |
+
|
| 359 |
+
# Compute body center-of-mass end-effector Jacobian.
|
| 360 |
+
void mj_jacBodyCom(const mjModel* m, const mjData* d,
|
| 361 |
+
mjtNum* jacp, mjtNum* jacr, int body);
|
| 362 |
+
|
| 363 |
+
# Compute geom end-effector Jacobian.
|
| 364 |
+
void mj_jacGeom(const mjModel* m, const mjData* d,
|
| 365 |
+
mjtNum* jacp, mjtNum* jacr, int geom);
|
| 366 |
+
|
| 367 |
+
# Compute site end-effector Jacobian.
|
| 368 |
+
void mj_jacSite(const mjModel* m, const mjData* d,
|
| 369 |
+
mjtNum* jacp, mjtNum* jacr, int site);
|
| 370 |
+
|
| 371 |
+
# Compute translation end-effector Jacobian of point, and rotation Jacobian of axis.
|
| 372 |
+
void mj_jacPointAxis(const mjModel* m, mjData* d,
|
| 373 |
+
mjtNum* jacPoint, mjtNum* jacAxis,
|
| 374 |
+
const mjtNum point[3], const mjtNum axis[3], int body);
|
| 375 |
+
|
| 376 |
+
# Get id of object with specified name, return -1 if not found; type is mjtObj.
|
| 377 |
+
int mj_name2id(const mjModel* m, int type, const char* name);
|
| 378 |
+
|
| 379 |
+
# Get name of object with specified id, return 0 if invalid type or id; type is mjtObj.
|
| 380 |
+
const char* mj_id2name(const mjModel* m, int type, int id);
|
| 381 |
+
|
| 382 |
+
# Convert sparse inertia matrix M into full (i.e. dense) matrix.
|
| 383 |
+
void mj_fullM(const mjModel* m, mjtNum* dst, const mjtNum* M);
|
| 384 |
+
|
| 385 |
+
# Multiply vector by inertia matrix.
|
| 386 |
+
void mj_mulM(const mjModel* m, const mjData* d, mjtNum* res, const mjtNum* vec);
|
| 387 |
+
|
| 388 |
+
# Multiply vector by (inertia matrix)^(1/2).
|
| 389 |
+
void mj_mulM2(const mjModel* m, const mjData* d, mjtNum* res, const mjtNum* vec);
|
| 390 |
+
|
| 391 |
+
# Add inertia matrix to destination matrix.
|
| 392 |
+
# Destination can be sparse uncompressed, or dense when all int* are NULL
|
| 393 |
+
void mj_addM(const mjModel* m, mjData* d, mjtNum* dst,
|
| 394 |
+
int* rownnz, int* rowadr, int* colind);
|
| 395 |
+
|
| 396 |
+
# Apply cartesian force and torque (outside xfrc_applied mechanism).
|
| 397 |
+
void mj_applyFT(const mjModel* m, mjData* d,
|
| 398 |
+
const mjtNum* force, const mjtNum* torque,
|
| 399 |
+
const mjtNum* point, int body, mjtNum* qfrc_target);
|
| 400 |
+
|
| 401 |
+
# Compute object 6D velocity in object-centered frame, world/local orientation.
|
| 402 |
+
void mj_objectVelocity(const mjModel* m, const mjData* d,
|
| 403 |
+
int objtype, int objid, mjtNum* res, int flg_local);
|
| 404 |
+
|
| 405 |
+
# Compute object 6D acceleration in object-centered frame, world/local orientation.
|
| 406 |
+
void mj_objectAcceleration(const mjModel* m, const mjData* d,
|
| 407 |
+
int objtype, int objid, mjtNum* res, int flg_local);
|
| 408 |
+
|
| 409 |
+
# Extract 6D force:torque for one contact, in contact frame.
|
| 410 |
+
void mj_contactForce(const mjModel* m, const mjData* d, int id, mjtNum* result);
|
| 411 |
+
|
| 412 |
+
# Compute velocity by finite-differencing two positions.
|
| 413 |
+
void mj_differentiatePos(const mjModel* m, mjtNum* qvel, mjtNum dt,
|
| 414 |
+
const mjtNum* qpos1, const mjtNum* qpos2);
|
| 415 |
+
|
| 416 |
+
# Integrate position with given velocity.
|
| 417 |
+
void mj_integratePos(const mjModel* m, mjtNum* qpos, const mjtNum* qvel, mjtNum dt);
|
| 418 |
+
|
| 419 |
+
# Normalize all quaterions in qpos-type vector.
|
| 420 |
+
void mj_normalizeQuat(const mjModel* m, mjtNum* qpos);
|
| 421 |
+
|
| 422 |
+
# Map from body local to global Cartesian coordinates.
|
| 423 |
+
void mj_local2Global(mjData* d, mjtNum* xpos, mjtNum* xmat, const mjtNum* pos, const mjtNum* quat,
|
| 424 |
+
int body, mjtByte sameframe);
|
| 425 |
+
|
| 426 |
+
# Sum all body masses.
|
| 427 |
+
mjtNum mj_getTotalmass(const mjModel* m);
|
| 428 |
+
|
| 429 |
+
# Scale body masses and inertias to achieve specified total mass.
|
| 430 |
+
void mj_setTotalmass(mjModel* m, mjtNum newmass);
|
| 431 |
+
|
| 432 |
+
# Return version number: 1.0.2 is encoded as 102.
|
| 433 |
+
int mj_version();
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
#--------------------- Ray collisions -------------------------------------------------
|
| 437 |
+
|
| 438 |
+
# Intersect ray (pnt+x*vec, x>=0) with visible geoms, except geoms in bodyexclude.
|
| 439 |
+
# Return geomid and distance (x) to nearest surface, or -1 if no intersection.
|
| 440 |
+
# geomgroup, flg_static are as in mjvOption; geomgroup==NULL skips group exclusion.
|
| 441 |
+
mjtNum mj_ray(const mjModel* m, const mjData* d, const mjtNum* pnt, const mjtNum* vec,
|
| 442 |
+
const mjtByte* geomgroup, mjtByte flg_static, int bodyexclude,
|
| 443 |
+
int* geomid);
|
| 444 |
+
|
| 445 |
+
# Interect ray with hfield, return nearest distance or -1 if no intersection.
|
| 446 |
+
mjtNum mj_rayHfield(const mjModel* m, const mjData* d, int geomid,
|
| 447 |
+
const mjtNum* pnt, const mjtNum* vec);
|
| 448 |
+
|
| 449 |
+
# Interect ray with mesh, return nearest distance or -1 if no intersection.
|
| 450 |
+
mjtNum mj_rayMesh(const mjModel* m, const mjData* d, int geomid,
|
| 451 |
+
const mjtNum* pnt, const mjtNum* vec);
|
| 452 |
+
|
| 453 |
+
# Interect ray with pure geom, return nearest distance or -1 if no intersection.
|
| 454 |
+
mjtNum mju_rayGeom(const mjtNum* pos, const mjtNum* mat, const mjtNum* size,
|
| 455 |
+
const mjtNum* pnt, const mjtNum* vec, int geomtype);
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
#--------------------- Interaction ----------------------------------------------------
|
| 459 |
+
|
| 460 |
+
# Set default camera.
|
| 461 |
+
void mjv_defaultCamera(mjvCamera* cam);
|
| 462 |
+
|
| 463 |
+
# Set default perturbation.
|
| 464 |
+
void mjv_defaultPerturb(mjvPerturb* pert);
|
| 465 |
+
|
| 466 |
+
# Transform pose from room to model space.
|
| 467 |
+
void mjv_room2model(mjtNum* modelpos, mjtNum* modelquat, const mjtNum* roompos,
|
| 468 |
+
const mjtNum* roomquat, const mjvScene* scn);
|
| 469 |
+
|
| 470 |
+
# Transform pose from model to room space.
|
| 471 |
+
void mjv_model2room(mjtNum* roompos, mjtNum* roomquat, const mjtNum* modelpos,
|
| 472 |
+
const mjtNum* modelquat, const mjvScene* scn);
|
| 473 |
+
|
| 474 |
+
# Get camera info in model space; average left and right OpenGL cameras.
|
| 475 |
+
void mjv_cameraInModel(mjtNum* headpos, mjtNum* forward, mjtNum* up,
|
| 476 |
+
const mjvScene* scn);
|
| 477 |
+
|
| 478 |
+
# Get camera info in room space; average left and right OpenGL cameras.
|
| 479 |
+
void mjv_cameraInRoom(mjtNum* headpos, mjtNum* forward, mjtNum* up,
|
| 480 |
+
const mjvScene* scn);
|
| 481 |
+
|
| 482 |
+
# Get frustum height at unit distance from camera; average left and right OpenGL cameras.
|
| 483 |
+
mjtNum mjv_frustumHeight(const mjvScene* scn);
|
| 484 |
+
|
| 485 |
+
# Rotate 3D vec in horizontal plane by angle between (0,1) and (forward_x,forward_y).
|
| 486 |
+
void mjv_alignToCamera(mjtNum* res, const mjtNum* vec, const mjtNum* forward);
|
| 487 |
+
|
| 488 |
+
# Move camera with mouse; action is mjtMouse.
|
| 489 |
+
void mjv_moveCamera(const mjModel* m, int action, mjtNum reldx, mjtNum reldy,
|
| 490 |
+
const mjvScene* scn, mjvCamera* cam);
|
| 491 |
+
|
| 492 |
+
# Move perturb object with mouse; action is mjtMouse.
|
| 493 |
+
void mjv_movePerturb(const mjModel* m, const mjData* d, int action, mjtNum reldx,
|
| 494 |
+
mjtNum reldy, const mjvScene* scn, mjvPerturb* pert);
|
| 495 |
+
|
| 496 |
+
# Move model with mouse; action is mjtMouse.
|
| 497 |
+
void mjv_moveModel(const mjModel* m, int action, mjtNum reldx, mjtNum reldy,
|
| 498 |
+
const mjtNum* roomup, mjvScene* scn);
|
| 499 |
+
|
| 500 |
+
# Copy perturb pos,quat from selected body; set scale for perturbation.
|
| 501 |
+
void mjv_initPerturb(const mjModel* m, const mjData* d,
|
| 502 |
+
const mjvScene* scn, mjvPerturb* pert);
|
| 503 |
+
|
| 504 |
+
# Set perturb pos,quat in d->mocap when selected body is mocap, and in d->qpos otherwise.
|
| 505 |
+
# Write d->qpos only if flg_paused and subtree root for selected body has free joint.
|
| 506 |
+
void mjv_applyPerturbPose(const mjModel* m, mjData* d, const mjvPerturb* pert,
|
| 507 |
+
int flg_paused);
|
| 508 |
+
|
| 509 |
+
# Set perturb force,torque in d->xfrc_applied, if selected body is dynamic.
|
| 510 |
+
void mjv_applyPerturbForce(const mjModel* m, mjData* d, const mjvPerturb* pert);
|
| 511 |
+
|
| 512 |
+
# Return the average of two OpenGL cameras.
|
| 513 |
+
mjvGLCamera mjv_averageCamera(const mjvGLCamera* cam1, const mjvGLCamera* cam2);
|
| 514 |
+
|
| 515 |
+
# Select geom or skin with mouse, return bodyid; -1: none selected.
|
| 516 |
+
int mjv_select(const mjModel* m, const mjData* d, const mjvOption* vopt,
|
| 517 |
+
mjtNum aspectratio, mjtNum relx, mjtNum rely,
|
| 518 |
+
const mjvScene* scn, mjtNum* selpnt, int* geomid, int* skinid);
|
| 519 |
+
|
| 520 |
+
#--------------------- Visualization --------------------------------------------------
|
| 521 |
+
|
| 522 |
+
# Set default visualization options.
|
| 523 |
+
void mjv_defaultOption(mjvOption* opt);
|
| 524 |
+
|
| 525 |
+
# Set default figure.
|
| 526 |
+
void mjv_defaultFigure(mjvFigure* fig);
|
| 527 |
+
|
| 528 |
+
# Initialize given geom fields when not NULL, set the rest to their default values.
|
| 529 |
+
void mjv_initGeom(mjvGeom* geom, int type, const mjtNum* size,
|
| 530 |
+
const mjtNum* pos, const mjtNum* mat, const float* rgba);
|
| 531 |
+
|
| 532 |
+
# Set (type, size, pos, mat) for connector-type geom between given points.
|
| 533 |
+
# Assume that mjv_initGeom was already called to set all other properties.
|
| 534 |
+
void mjv_makeConnector(mjvGeom* geom, int type, mjtNum width,
|
| 535 |
+
mjtNum a0, mjtNum a1, mjtNum a2,
|
| 536 |
+
mjtNum b0, mjtNum b1, mjtNum b2);
|
| 537 |
+
|
| 538 |
+
# Set default abstract scene.
|
| 539 |
+
void mjv_defaultScene(mjvScene* scn);
|
| 540 |
+
|
| 541 |
+
# Allocate resources in abstract scene.
|
| 542 |
+
void mjv_makeScene(const mjModel* m, mjvScene* scn, int maxgeom);
|
| 543 |
+
|
| 544 |
+
# Free abstract scene.
|
| 545 |
+
void mjv_freeScene(mjvScene* scn);
|
| 546 |
+
|
| 547 |
+
# Update entire scene given model state.
|
| 548 |
+
void mjv_updateScene(const mjModel* m, mjData* d, const mjvOption* opt,
|
| 549 |
+
const mjvPerturb* pert, mjvCamera* cam, int catmask, mjvScene* scn);
|
| 550 |
+
|
| 551 |
+
# Add geoms from selected categories to existing scene.
|
| 552 |
+
void mjv_addGeoms(const mjModel* m, mjData* d, const mjvOption* opt,
|
| 553 |
+
const mjvPerturb* pert, int catmask, mjvScene* scn);
|
| 554 |
+
|
| 555 |
+
# Make list of lights.
|
| 556 |
+
void mjv_makeLights(const mjModel* m, mjData* d, mjvScene* scn);
|
| 557 |
+
|
| 558 |
+
# Update camera only.
|
| 559 |
+
void mjv_updateCamera(const mjModel* m, mjData* d, mjvCamera* cam, mjvScene* scn);
|
| 560 |
+
|
| 561 |
+
# Update skins.
|
| 562 |
+
void mjv_updateSkin(const mjModel* m, mjData* d, mjvScene* scn);
|
| 563 |
+
|
| 564 |
+
#--------------------- OpenGL rendering -----------------------------------------------
|
| 565 |
+
|
| 566 |
+
# Set default mjrContext.
|
| 567 |
+
void mjr_defaultContext(mjrContext* con);
|
| 568 |
+
|
| 569 |
+
# Allocate resources in custom OpenGL context; fontscale is mjtFontScale.
|
| 570 |
+
void mjr_makeContext(const mjModel* m, mjrContext* con, int fontscale);
|
| 571 |
+
|
| 572 |
+
# Change font of existing context.
|
| 573 |
+
void mjr_changeFont(int fontscale, mjrContext* con);
|
| 574 |
+
|
| 575 |
+
# Add Aux buffer with given index to context; free previous Aux buffer.
|
| 576 |
+
void mjr_addAux(int index, int width, int height, int samples, mjrContext* con);
|
| 577 |
+
|
| 578 |
+
# Free resources in custom OpenGL context, set to default.
|
| 579 |
+
void mjr_freeContext(mjrContext* con);
|
| 580 |
+
|
| 581 |
+
# Upload texture to GPU, overwriting previous upload if any.
|
| 582 |
+
void mjr_uploadTexture(const mjModel* m, const mjrContext* con, int texid);
|
| 583 |
+
|
| 584 |
+
# Upload mesh to GPU, overwriting previous upload if any.
|
| 585 |
+
void mjr_uploadMesh(const mjModel* m, const mjrContext* con, int meshid);
|
| 586 |
+
|
| 587 |
+
# Upload height field to GPU, overwriting previous upload if any.
|
| 588 |
+
void mjr_uploadHField(const mjModel* m, const mjrContext* con, int hfieldid);
|
| 589 |
+
|
| 590 |
+
# Make con->currentBuffer current again.
|
| 591 |
+
void mjr_restoreBuffer(const mjrContext* con);
|
| 592 |
+
|
| 593 |
+
# Set OpenGL framebuffer for rendering: mjFB_WINDOW or mjFB_OFFSCREEN.
|
| 594 |
+
# If only one buffer is available, set that buffer and ignore framebuffer argument.
|
| 595 |
+
void mjr_setBuffer(int framebuffer, mjrContext* con);
|
| 596 |
+
|
| 597 |
+
# Read pixels from current OpenGL framebuffer to client buffer.
|
| 598 |
+
# Viewport is in OpenGL framebuffer; client buffer starts at (0,0).
|
| 599 |
+
void mjr_readPixels(unsigned char* rgb, float* depth,
|
| 600 |
+
mjrRect viewport, const mjrContext* con);
|
| 601 |
+
|
| 602 |
+
# Draw pixels from client buffer to current OpenGL framebuffer.
|
| 603 |
+
# Viewport is in OpenGL framebuffer; client buffer starts at (0,0).
|
| 604 |
+
void mjr_drawPixels(const unsigned char* rgb, const float* depth,
|
| 605 |
+
mjrRect viewport, const mjrContext* con);
|
| 606 |
+
|
| 607 |
+
# Blit from src viewpoint in current framebuffer to dst viewport in other framebuffer.
|
| 608 |
+
# If src, dst have different size and flg_depth==0, color is interpolated with GL_LINEAR.
|
| 609 |
+
void mjr_blitBuffer(mjrRect src, mjrRect dst, int flg_color, int flg_depth, const mjrContext* con);
|
| 610 |
+
|
| 611 |
+
# Set Aux buffer for custom OpenGL rendering (call restoreBuffer when done).
|
| 612 |
+
void mjr_setAux(int index, const mjrContext* con);
|
| 613 |
+
|
| 614 |
+
# Blit from Aux buffer to con->currentBuffer.
|
| 615 |
+
void mjr_blitAux(int index, mjrRect src, int left, int bottom, const mjrContext* con);
|
| 616 |
+
|
| 617 |
+
# Draw text at (x,y) in relative coordinates; font is mjtFont.
|
| 618 |
+
void mjr_text(int font, const char* txt, const mjrContext* con,
|
| 619 |
+
float x, float y, float r, float g, float b);
|
| 620 |
+
|
| 621 |
+
# Draw text overlay; font is mjtFont; gridpos is mjtGridPos.
|
| 622 |
+
void mjr_overlay(int font, int gridpos, mjrRect viewport,
|
| 623 |
+
const char* overlay, const char* overlay2, const mjrContext* con);
|
| 624 |
+
|
| 625 |
+
# Get maximum viewport for active buffer.
|
| 626 |
+
mjrRect mjr_maxViewport(const mjrContext* con);
|
| 627 |
+
|
| 628 |
+
# Draw rectangle.
|
| 629 |
+
void mjr_rectangle(mjrRect viewport, float r, float g, float b, float a);
|
| 630 |
+
|
| 631 |
+
# Draw rectangle with centered text.
|
| 632 |
+
void mjr_label(mjrRect viewport, int font, const char* txt,
|
| 633 |
+
float r, float g, float b, float a, float rt, float gt, float bt,
|
| 634 |
+
const mjrContext* con);
|
| 635 |
+
|
| 636 |
+
# Draw 2D figure.
|
| 637 |
+
void mjr_figure(mjrRect viewport, const mjvFigure* fig, const mjrContext* con);
|
| 638 |
+
|
| 639 |
+
# Render 3D scene.
|
| 640 |
+
void mjr_render(mjrRect viewport, mjvScene* scn, const mjrContext* con);
|
| 641 |
+
|
| 642 |
+
# Call glFinish.
|
| 643 |
+
void mjr_finish();
|
| 644 |
+
|
| 645 |
+
# Call glGetError and return result.
|
| 646 |
+
int mjr_getError();
|
| 647 |
+
|
| 648 |
+
# Find first rectangle containing mouse, -1: not found.
|
| 649 |
+
int mjr_findRect(int x, int y, int nrect, const mjrRect* rect);
|
| 650 |
+
|
| 651 |
+
#---------------------- UI framework ---------------------------------------------------
|
| 652 |
+
|
| 653 |
+
# Add definitions to UI.
|
| 654 |
+
void mjui_add(mjUI* ui, const mjuiDef* _def);
|
| 655 |
+
|
| 656 |
+
# Add definitions to UI section.
|
| 657 |
+
void mjui_addToSection(mjUI* ui, int sect, const mjuiDef* _def);
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
# Compute UI sizes.
|
| 661 |
+
void mjui_resize(mjUI* ui, const mjrContext* con);
|
| 662 |
+
|
| 663 |
+
# Update specific section/item; -1: update all.
|
| 664 |
+
void mjui_update(int section, int item, const mjUI* ui, const mjuiState* state, const mjrContext* con);
|
| 665 |
+
|
| 666 |
+
# Handle UI event, return pointer to changed item, NULL if no change.
|
| 667 |
+
mjuiItem* mjui_event(mjUI* ui, mjuiState* state, const mjrContext* con);
|
| 668 |
+
|
| 669 |
+
# Copy UI image to current buffer.
|
| 670 |
+
void mjui_render(mjUI* ui, const mjuiState* state, const mjrContext* con);
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
#--------------------- Error and memory -----------------------------------------------
|
| 674 |
+
|
| 675 |
+
# Main error function; does not return to caller.
|
| 676 |
+
void mju_error(const char* msg);
|
| 677 |
+
|
| 678 |
+
# Error function with int argument; msg is a printf format string.
|
| 679 |
+
void mju_error_i(const char* msg, int i);
|
| 680 |
+
|
| 681 |
+
# Error function with string argument.
|
| 682 |
+
void mju_error_s(const char* msg, const char* text);
|
| 683 |
+
|
| 684 |
+
# Main warning function; returns to caller.
|
| 685 |
+
void mju_warning(const char* msg);
|
| 686 |
+
|
| 687 |
+
# Warning function with int argument.
|
| 688 |
+
void mju_warning_i(const char* msg, int i);
|
| 689 |
+
|
| 690 |
+
# Warning function with string argument.
|
| 691 |
+
void mju_warning_s(const char* msg, const char* text);
|
| 692 |
+
|
| 693 |
+
# Clear user error and memory handlers.
|
| 694 |
+
void mju_clearHandlers();
|
| 695 |
+
|
| 696 |
+
# Allocate memory; byte-align on 8; pad size to multiple of 8.
|
| 697 |
+
void* mju_malloc(size_t size);
|
| 698 |
+
|
| 699 |
+
# Free memory, using free() by default.
|
| 700 |
+
void mju_free(void* ptr);
|
| 701 |
+
|
| 702 |
+
# High-level warning function: count warnings in mjData, print only the first.
|
| 703 |
+
void mj_warning(mjData* d, int warning, int info);
|
| 704 |
+
|
| 705 |
+
# Write [datetime, type: message] to MUJOCO_LOG.TXT.
|
| 706 |
+
void mju_writeLog(const char* type, const char* msg);
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
#--------------------- Standard math --------------------------------------------------
|
| 710 |
+
|
| 711 |
+
#define mjMAX(a,b) (((a) > (b)) ? (a) : (b))
|
| 712 |
+
#define mjMIN(a,b) (((a) < (b)) ? (a) : (b))
|
| 713 |
+
|
| 714 |
+
#ifdef mjUSEDOUBLE
|
| 715 |
+
#define mju_sqrt sqrt
|
| 716 |
+
#define mju_exp exp
|
| 717 |
+
#define mju_sin sin
|
| 718 |
+
#define mju_cos cos
|
| 719 |
+
#define mju_tan tan
|
| 720 |
+
#define mju_asin asin
|
| 721 |
+
#define mju_acos acos
|
| 722 |
+
#define mju_atan2 atan2
|
| 723 |
+
#define mju_tanh tanh
|
| 724 |
+
#define mju_pow pow
|
| 725 |
+
#define mju_abs fabs
|
| 726 |
+
#define mju_log log
|
| 727 |
+
#define mju_log10 log10
|
| 728 |
+
#define mju_floor floor
|
| 729 |
+
#define mju_ceil ceil
|
| 730 |
+
|
| 731 |
+
#else
|
| 732 |
+
#define mju_sqrt sqrtf
|
| 733 |
+
#define mju_exp expf
|
| 734 |
+
#define mju_sin sinf
|
| 735 |
+
#define mju_cos cosf
|
| 736 |
+
#define mju_tan tanf
|
| 737 |
+
#define mju_asin asinf
|
| 738 |
+
#define mju_acos acosf
|
| 739 |
+
#define mju_atan2 atan2f
|
| 740 |
+
#define mju_tanh tanhf
|
| 741 |
+
#define mju_pow powf
|
| 742 |
+
#define mju_abs fabsf
|
| 743 |
+
#define mju_log logf
|
| 744 |
+
#define mju_log10 log10f
|
| 745 |
+
#define mju_floor floorf
|
| 746 |
+
#define mju_ceil ceilf
|
| 747 |
+
#endif
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
#----------------------------- Vector math --------------------------------------------
|
| 751 |
+
|
| 752 |
+
# Set res = 0.
|
| 753 |
+
void mju_zero3(mjtNum res[3]);
|
| 754 |
+
|
| 755 |
+
# Set res = vec.
|
| 756 |
+
void mju_copy3(mjtNum res[3], const mjtNum data[3]);
|
| 757 |
+
|
| 758 |
+
# Set res = vec*scl.
|
| 759 |
+
void mju_scl3(mjtNum res[3], const mjtNum vec[3], mjtNum scl);
|
| 760 |
+
|
| 761 |
+
# Set res = vec1 + vec2.
|
| 762 |
+
void mju_add3(mjtNum res[3], const mjtNum vec1[3], const mjtNum vec2[3]);
|
| 763 |
+
|
| 764 |
+
# Set res = vec1 - vec2.
|
| 765 |
+
void mju_sub3(mjtNum res[3], const mjtNum vec1[3], const mjtNum vec2[3]);
|
| 766 |
+
|
| 767 |
+
# Set res = res + vec.
|
| 768 |
+
void mju_addTo3(mjtNum res[3], const mjtNum vec[3]);
|
| 769 |
+
|
| 770 |
+
# Set res = res - vec.
|
| 771 |
+
void mju_subFrom3(mjtNum res[3], const mjtNum vec[3]);
|
| 772 |
+
|
| 773 |
+
# Set res = res + vec*scl.
|
| 774 |
+
void mju_addToScl3(mjtNum res[3], const mjtNum vec[3], mjtNum scl);
|
| 775 |
+
|
| 776 |
+
# Set res = vec1 + vec2*scl.
|
| 777 |
+
void mju_addScl3(mjtNum res[3], const mjtNum vec1[3], const mjtNum vec2[3], mjtNum scl);
|
| 778 |
+
|
| 779 |
+
# Normalize vector, return length before normalization.
|
| 780 |
+
mjtNum mju_normalize3(mjtNum res[3]);
|
| 781 |
+
|
| 782 |
+
# Return vector length (without normalizing the vector).
|
| 783 |
+
mjtNum mju_norm3(const mjtNum vec[3]);
|
| 784 |
+
|
| 785 |
+
# Return dot-product of vec1 and vec2.
|
| 786 |
+
mjtNum mju_dot3(const mjtNum vec1[3], const mjtNum vec2[3]);
|
| 787 |
+
|
| 788 |
+
# Return Cartesian distance between 3D vectors pos1 and pos2.
|
| 789 |
+
mjtNum mju_dist3(const mjtNum pos1[3], const mjtNum pos2[3]);
|
| 790 |
+
|
| 791 |
+
# Multiply vector by 3D rotation matrix: res = mat * vec.
|
| 792 |
+
void mju_rotVecMat(mjtNum res[3], const mjtNum vec[3], const mjtNum mat[9]);
|
| 793 |
+
|
| 794 |
+
# Multiply vector by transposed 3D rotation matrix: res = mat' * vec.
|
| 795 |
+
void mju_rotVecMatT(mjtNum res[3], const mjtNum vec[3], const mjtNum mat[9]);
|
| 796 |
+
|
| 797 |
+
# Compute cross-product: res = cross(a, b).
|
| 798 |
+
void mju_cross(mjtNum res[3], const mjtNum a[3], const mjtNum b[3]);
|
| 799 |
+
|
| 800 |
+
# Set res = 0.
|
| 801 |
+
void mju_zero4(mjtNum res[4]);
|
| 802 |
+
|
| 803 |
+
# Set res = (1,0,0,0).
|
| 804 |
+
void mju_unit4(mjtNum res[4]);
|
| 805 |
+
|
| 806 |
+
# Set res = vec.
|
| 807 |
+
void mju_copy4(mjtNum res[4], const mjtNum data[4]);
|
| 808 |
+
|
| 809 |
+
# Normalize vector, return length before normalization.
|
| 810 |
+
mjtNum mju_normalize4(mjtNum res[4]);
|
| 811 |
+
|
| 812 |
+
# Set res = 0.
|
| 813 |
+
void mju_zero(mjtNum* res, int n);
|
| 814 |
+
|
| 815 |
+
# Set res = vec.
|
| 816 |
+
void mju_copy(mjtNum* res, const mjtNum* data, int n);
|
| 817 |
+
|
| 818 |
+
# Return sum(vec).
|
| 819 |
+
mjtNum mju_sum(const mjtNum* vec, int n);
|
| 820 |
+
|
| 821 |
+
# Return L1 norm: sum(abs(vec)).
|
| 822 |
+
mjtNum mju_L1(const mjtNum* vec, int n);
|
| 823 |
+
|
| 824 |
+
# Set res = vec*scl.
|
| 825 |
+
void mju_scl(mjtNum* res, const mjtNum* vec, mjtNum scl, int n);
|
| 826 |
+
|
| 827 |
+
# Set res = vec1 + vec2.
|
| 828 |
+
void mju_add(mjtNum* res, const mjtNum* vec1, const mjtNum* vec2, int n);
|
| 829 |
+
|
| 830 |
+
# Set res = vec1 - vec2.
|
| 831 |
+
void mju_sub(mjtNum* res, const mjtNum* vec1, const mjtNum* vec2, int n);
|
| 832 |
+
|
| 833 |
+
# Set res = res + vec.
|
| 834 |
+
void mju_addTo(mjtNum* res, const mjtNum* vec, int n);
|
| 835 |
+
|
| 836 |
+
# Set res = res - vec.
|
| 837 |
+
void mju_subFrom(mjtNum* res, const mjtNum* vec, int n);
|
| 838 |
+
|
| 839 |
+
# Set res = res + vec*scl.
|
| 840 |
+
void mju_addToScl(mjtNum* res, const mjtNum* vec, mjtNum scl, int n);
|
| 841 |
+
|
| 842 |
+
# Set res = vec1 + vec2*scl.
|
| 843 |
+
void mju_addScl(mjtNum* res, const mjtNum* vec1, const mjtNum* vec2, mjtNum scl, int n);
|
| 844 |
+
|
| 845 |
+
# Normalize vector, return length before normalization.
|
| 846 |
+
mjtNum mju_normalize(mjtNum* res, int n);
|
| 847 |
+
|
| 848 |
+
# Return vector length (without normalizing vector).
|
| 849 |
+
mjtNum mju_norm(const mjtNum* res, int n);
|
| 850 |
+
|
| 851 |
+
# Return dot-product of vec1 and vec2.
|
| 852 |
+
mjtNum mju_dot(const mjtNum* vec1, const mjtNum* vec2, const int n);
|
| 853 |
+
|
| 854 |
+
# Multiply matrix and vector: res = mat * vec.
|
| 855 |
+
void mju_mulMatVec(mjtNum* res, const mjtNum* mat, const mjtNum* vec,
|
| 856 |
+
int nr, int nc);
|
| 857 |
+
|
| 858 |
+
# Multiply transposed matrix and vector: res = mat' * vec.
|
| 859 |
+
void mju_mulMatTVec(mjtNum* res, const mjtNum* mat, const mjtNum* vec,
|
| 860 |
+
int nr, int nc);
|
| 861 |
+
|
| 862 |
+
# Transpose matrix: res = mat'.
|
| 863 |
+
void mju_transpose(mjtNum* res, const mjtNum* mat, int nr, int nc);
|
| 864 |
+
|
| 865 |
+
# Multiply matrices: res = mat1 * mat2.
|
| 866 |
+
void mju_mulMatMat(mjtNum* res, const mjtNum* mat1, const mjtNum* mat2,
|
| 867 |
+
int r1, int c1, int c2);
|
| 868 |
+
|
| 869 |
+
# Multiply matrices, second argument transposed: res = mat1 * mat2'.
|
| 870 |
+
void mju_mulMatMatT(mjtNum* res, const mjtNum* mat1, const mjtNum* mat2,
|
| 871 |
+
int r1, int c1, int r2);
|
| 872 |
+
|
| 873 |
+
# Multiply matrices, first argument transposed: res = mat1' * mat2.
|
| 874 |
+
void mju_mulMatTMat(mjtNum* res, const mjtNum* mat1, const mjtNum* mat2,
|
| 875 |
+
int r1, int c1, int c2);
|
| 876 |
+
|
| 877 |
+
# Set res = mat' * diag * mat if diag is not NULL, and res = mat' * mat otherwise.
|
| 878 |
+
void mju_sqrMatTD(mjtNum* res, const mjtNum* mat, const mjtNum* diag, int nr, int nc);
|
| 879 |
+
|
| 880 |
+
# Coordinate transform of 6D motion or force vector in rotation:translation format.
|
| 881 |
+
# rotnew2old is 3-by-3, NULL means no rotation; flg_force specifies force or motion type.
|
| 882 |
+
void mju_transformSpatial(mjtNum res[6], const mjtNum vec[6], int flg_force,
|
| 883 |
+
const mjtNum newpos[3], const mjtNum oldpos[3],
|
| 884 |
+
const mjtNum rotnew2old[9]);
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
#--------------------- Sparse math ----------------------------------------------------
|
| 888 |
+
|
| 889 |
+
# Return dot-product of vec1 and vec2, where vec1 is sparse.
|
| 890 |
+
mjtNum mju_dotSparse(const mjtNum* vec1, const mjtNum* vec2,
|
| 891 |
+
const int nnz1, const int* ind1);
|
| 892 |
+
|
| 893 |
+
# Return dot-product of vec1 and vec2, where both vectors are sparse.
|
| 894 |
+
mjtNum mju_dotSparse2(const mjtNum* vec1, const mjtNum* vec2,
|
| 895 |
+
const int nnz1, const int* ind1,
|
| 896 |
+
const int nnz2, const int* ind2);
|
| 897 |
+
|
| 898 |
+
# Convert matrix from dense to sparse format.
|
| 899 |
+
void mju_dense2sparse(mjtNum* res, const mjtNum* mat, int nr, int nc,
|
| 900 |
+
int* rownnz, int* rowadr, int* colind);
|
| 901 |
+
|
| 902 |
+
# Convert matrix from sparse to dense format.
|
| 903 |
+
void mju_sparse2dense(mjtNum* res, const mjtNum* mat, int nr, int nc,
|
| 904 |
+
const int* rownnz, const int* rowadr, const int* colind);
|
| 905 |
+
|
| 906 |
+
# Multiply sparse matrix and dense vector: res = mat * vec.
|
| 907 |
+
void mju_mulMatVecSparse(mjtNum* res, const mjtNum* mat, const mjtNum* vec, int nr,
|
| 908 |
+
const int* rownnz, const int* rowadr, const int* colind);
|
| 909 |
+
|
| 910 |
+
# Compress layout of sparse matrix.
|
| 911 |
+
void mju_compressSparse(mjtNum* mat, int nr, int nc,
|
| 912 |
+
int* rownnz, int* rowadr, int* colind);
|
| 913 |
+
|
| 914 |
+
# Set dst = a*dst + b*src, return nnz of result, modify dst sparsity pattern as needed.
|
| 915 |
+
# Both vectors are sparse. The required scratch space is 2*n.
|
| 916 |
+
int mju_combineSparse(mjtNum* dst, const mjtNum* src, int n, mjtNum a, mjtNum b,
|
| 917 |
+
int dst_nnz, int src_nnz, int* dst_ind, const int* src_ind,
|
| 918 |
+
mjtNum* scratch, int nscratch);
|
| 919 |
+
|
| 920 |
+
# Set res = matT * diag * mat if diag is not NULL, and res = matT * mat otherwise.
|
| 921 |
+
# The required scratch space is 3*nc. The result has uncompressed layout.
|
| 922 |
+
void mju_sqrMatTDSparse(mjtNum* res, const mjtNum* mat, const mjtNum* matT,
|
| 923 |
+
const mjtNum* diag, int nr, int nc,
|
| 924 |
+
int* res_rownnz, int* res_rowadr, int* res_colind,
|
| 925 |
+
const int* rownnz, const int* rowadr, const int* colind,
|
| 926 |
+
const int* rownnzT, const int* rowadrT, const int* colindT,
|
| 927 |
+
mjtNum* scratch, int nscratch);
|
| 928 |
+
|
| 929 |
+
# Transpose sparse matrix.
|
| 930 |
+
void mju_transposeSparse(mjtNum* res, const mjtNum* mat, int nr, int nc,
|
| 931 |
+
int* res_rownnz, int* res_rowadr, int* res_colind,
|
| 932 |
+
const int* rownnz, const int* rowadr, const int* colind);
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
#--------------------- Quaternions ----------------------------------------------------
|
| 936 |
+
|
| 937 |
+
# Rotate vector by quaternion.
|
| 938 |
+
void mju_rotVecQuat(mjtNum res[3], const mjtNum vec[3], const mjtNum quat[4]);
|
| 939 |
+
|
| 940 |
+
# Negate quaternion.
|
| 941 |
+
void mju_negQuat(mjtNum res[4], const mjtNum quat[4]);
|
| 942 |
+
|
| 943 |
+
# Muiltiply quaternions.
|
| 944 |
+
void mju_mulQuat(mjtNum res[4], const mjtNum quat1[4], const mjtNum quat2[4]);
|
| 945 |
+
|
| 946 |
+
# Muiltiply quaternion and axis.
|
| 947 |
+
void mju_mulQuatAxis(mjtNum res[4], const mjtNum quat[4], const mjtNum axis[3]);
|
| 948 |
+
|
| 949 |
+
# Convert axisAngle to quaternion.
|
| 950 |
+
void mju_axisAngle2Quat(mjtNum res[4], const mjtNum axis[3], mjtNum angle);
|
| 951 |
+
|
| 952 |
+
# Convert quaternion (corresponding to orientation difference) to 3D velocity.
|
| 953 |
+
void mju_quat2Vel(mjtNum res[3], const mjtNum quat[4], mjtNum dt);
|
| 954 |
+
|
| 955 |
+
# Subtract quaternions, express as 3D velocity: qb*quat(res) = qa.
|
| 956 |
+
void mju_subQuat(mjtNum res[3], const mjtNum qa[4], const mjtNum qb[4]);
|
| 957 |
+
|
| 958 |
+
# Convert quaternion to 3D rotation matrix.
|
| 959 |
+
void mju_quat2Mat(mjtNum res[9], const mjtNum quat[4]);
|
| 960 |
+
|
| 961 |
+
# Convert 3D rotation matrix to quaterion.
|
| 962 |
+
void mju_mat2Quat(mjtNum quat[4], const mjtNum mat[9]);
|
| 963 |
+
|
| 964 |
+
# Compute time-derivative of quaternion, given 3D rotational velocity.
|
| 965 |
+
void mju_derivQuat(mjtNum res[4], const mjtNum quat[4], const mjtNum vel[3]);
|
| 966 |
+
|
| 967 |
+
# Integrate quaterion given 3D angular velocity.
|
| 968 |
+
void mju_quatIntegrate(mjtNum quat[4], const mjtNum vel[3], mjtNum scale);
|
| 969 |
+
|
| 970 |
+
# Construct quaternion performing rotation from z-axis to given vector.
|
| 971 |
+
void mju_quatZ2Vec(mjtNum quat[4], const mjtNum vec[3]);
|
| 972 |
+
|
| 973 |
+
|
| 974 |
+
#--------------------- Poses ----------------------------------------------------------
|
| 975 |
+
|
| 976 |
+
# Multiply two poses.
|
| 977 |
+
void mju_mulPose(mjtNum posres[3], mjtNum quatres[4],
|
| 978 |
+
const mjtNum pos1[3], const mjtNum quat1[4],
|
| 979 |
+
const mjtNum pos2[3], const mjtNum quat2[4]);
|
| 980 |
+
|
| 981 |
+
# Negate pose.
|
| 982 |
+
void mju_negPose(mjtNum posres[3], mjtNum quatres[4],
|
| 983 |
+
const mjtNum pos[3], const mjtNum quat[4]);
|
| 984 |
+
|
| 985 |
+
# Transform vector by pose.
|
| 986 |
+
void mju_trnVecPose(mjtNum res[3], const mjtNum pos[3], const mjtNum quat[4],
|
| 987 |
+
const mjtNum vec[3]);
|
| 988 |
+
|
| 989 |
+
|
| 990 |
+
#--------------------- Decompositions --------------------------------------------------
|
| 991 |
+
|
| 992 |
+
# Cholesky decomposition: mat = L*L'; return rank.
|
| 993 |
+
int mju_cholFactor(mjtNum* mat, int n, mjtNum mindiag);
|
| 994 |
+
|
| 995 |
+
# Solve mat * res = vec, where mat is Cholesky-factorized
|
| 996 |
+
void mju_cholSolve(mjtNum* res, const mjtNum* mat, const mjtNum* vec, int n);
|
| 997 |
+
|
| 998 |
+
# Cholesky rank-one update: L*L' +/- x*x'; return rank.
|
| 999 |
+
int mju_cholUpdate(mjtNum* mat, mjtNum* x, int n, int flg_plus);
|
| 1000 |
+
|
| 1001 |
+
# Eigenvalue decomposition of symmetric 3x3 matrix.
|
| 1002 |
+
int mju_eig3(mjtNum* eigval, mjtNum* eigvec, mjtNum* quat, const mjtNum* mat);
|
| 1003 |
+
|
| 1004 |
+
|
| 1005 |
+
#--------------------- Miscellaneous --------------------------------------------------
|
| 1006 |
+
|
| 1007 |
+
# Muscle active force, prm = (range[2], force, scale, lmin, lmax, vmax, fpmax, fvmax).
|
| 1008 |
+
mjtNum mju_muscleGain(mjtNum len, mjtNum vel, const mjtNum lengthrange[2],
|
| 1009 |
+
mjtNum acc0, const mjtNum prm[9]);
|
| 1010 |
+
|
| 1011 |
+
# Muscle passive force, prm = (range[2], force, scale, lmin, lmax, vmax, fpmax, fvmax).
|
| 1012 |
+
mjtNum mju_muscleBias(mjtNum len, const mjtNum lengthrange[2],
|
| 1013 |
+
mjtNum acc0, const mjtNum prm[9]);
|
| 1014 |
+
|
| 1015 |
+
# Muscle activation dynamics, prm = (tau_act, tau_deact).
|
| 1016 |
+
mjtNum mju_muscleDynamics(mjtNum ctrl, mjtNum act, const mjtNum prm[2]);
|
| 1017 |
+
|
| 1018 |
+
|
| 1019 |
+
# Convert contact force to pyramid representation.
|
| 1020 |
+
void mju_encodePyramid(mjtNum* pyramid, const mjtNum* force,
|
| 1021 |
+
const mjtNum* mu, int dim);
|
| 1022 |
+
|
| 1023 |
+
# Convert pyramid representation to contact force.
|
| 1024 |
+
void mju_decodePyramid(mjtNum* force, const mjtNum* pyramid,
|
| 1025 |
+
const mjtNum* mu, int dim);
|
| 1026 |
+
|
| 1027 |
+
# Integrate spring-damper analytically, return pos(dt).
|
| 1028 |
+
mjtNum mju_springDamper(mjtNum pos0, mjtNum vel0, mjtNum Kp, mjtNum Kv, mjtNum dt);
|
| 1029 |
+
|
| 1030 |
+
# Return min(a,b) with single evaluation of a and b.
|
| 1031 |
+
mjtNum mju_min(mjtNum a, mjtNum b);
|
| 1032 |
+
|
| 1033 |
+
# Return max(a,b) with single evaluation of a and b.
|
| 1034 |
+
mjtNum mju_max(mjtNum a, mjtNum b);
|
| 1035 |
+
|
| 1036 |
+
# Return sign of x: +1, -1 or 0.
|
| 1037 |
+
mjtNum mju_sign(mjtNum x);
|
| 1038 |
+
|
| 1039 |
+
# Round x to nearest integer.
|
| 1040 |
+
int mju_round(mjtNum x);
|
| 1041 |
+
|
| 1042 |
+
# Convert type id (mjtObj) to type name.
|
| 1043 |
+
const char* mju_type2Str(int type);
|
| 1044 |
+
|
| 1045 |
+
# Convert type name to type id (mjtObj).
|
| 1046 |
+
int mju_str2Type(const char* str);
|
| 1047 |
+
|
| 1048 |
+
# Construct a warning message given the warning type and info.
|
| 1049 |
+
const char* mju_warningText(int warning, int info);
|
| 1050 |
+
|
| 1051 |
+
# Return 1 if nan or abs(x)>mjMAXVAL, 0 otherwise. Used by check functions.
|
| 1052 |
+
int mju_isBad(mjtNum x);
|
| 1053 |
+
|
| 1054 |
+
# Return 1 if all elements are 0.
|
| 1055 |
+
int mju_isZero(mjtNum* vec, int n);
|
| 1056 |
+
|
| 1057 |
+
# Standard normal random number generator (optional second number).
|
| 1058 |
+
mjtNum mju_standardNormal(mjtNum* num2);
|
| 1059 |
+
|
| 1060 |
+
# Convert from float to mjtNum.
|
| 1061 |
+
void mju_f2n(mjtNum* res, const float* vec, int n);
|
| 1062 |
+
|
| 1063 |
+
# Convert from mjtNum to float.
|
| 1064 |
+
void mju_n2f(float* res, const mjtNum* vec, int n);
|
| 1065 |
+
|
| 1066 |
+
# Convert from double to mjtNum.
|
| 1067 |
+
void mju_d2n(mjtNum* res, const double* vec, int n);
|
| 1068 |
+
|
| 1069 |
+
# Convert from mjtNum to double.
|
| 1070 |
+
void mju_n2d(double* res, const mjtNum* vec, int n);
|
| 1071 |
+
|
| 1072 |
+
# Insertion sort, resulting list is in increasing order.
|
| 1073 |
+
void mju_insertionSort(mjtNum* list, int n);
|
| 1074 |
+
|
| 1075 |
+
# Integer insertion sort, resulting list is in increasing order.
|
| 1076 |
+
void mju_insertionSortInt(int* list, int n);
|
| 1077 |
+
|
| 1078 |
+
# Generate Halton sequence.
|
| 1079 |
+
mjtNum mju_Halton(int index, int base);
|
| 1080 |
+
|
| 1081 |
+
# Sigmoid function over 0<=x<=1 constructed from half-quadratics.
|
| 1082 |
+
mjtNum mju_sigmoid(mjtNum x);
|
| 1083 |
+
|
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_materials.premod.png
ADDED
|
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_multiple_sims.loop0_1.png
ADDED
|
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_multiple_sims.loop1_0.png
ADDED
|
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_multiple_sims.loop1_1.png
ADDED
|
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_multiple_sims.loop2_1.png
ADDED
|
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_render_pool.mp_test_states.2.png
ADDED
|
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_render_pool.mp_test_states.3.png
ADDED
|
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_rendering.camera1.png
ADDED
|
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_rendering.freecam.depth-darwin.png
ADDED
|
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_rendering.freecam.depth.png
ADDED
|
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_resetting.loop1_1.png
ADDED
|
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_textures.rgb.png
ADDED
|
mujoco-py-2.1.2.14/mujoco_py/test_imgs/test_textures.variety.png
ADDED
|