Spaces:
Sleeping
Sleeping
File size: 45,780 Bytes
96da58e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 |
"""
Contains torch Modules that help deal with inputs consisting of multiple
modalities. This is extremely common when networks must deal with one or
more observation dictionaries, where each input dictionary can have
observation keys of a certain modality and shape.
As an example, an observation could consist of a flat "robot0_eef_pos" observation key,
and a 3-channel RGB "agentview_image" observation key.
"""
import sys
import numpy as np
import textwrap
from copy import deepcopy
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
from robomimic.utils.python_utils import extract_class_init_kwargs_from_dict
import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils
from robomimic.models.base_nets import Module, Sequential, MLP, RNN_Base, ResNet18Conv, SpatialSoftmax, \
FeatureAggregator
from robomimic.models.obs_core import VisualCore, Randomizer
from robomimic.models.transformers import PositionalEncoding, GPT_Backbone
def obs_encoder_factory(
        obs_shapes,
        feature_activation=nn.ReLU,
        encoder_kwargs=None,
):
    """
    Build an @ObservationEncoder for a set of observation keys from config-style kwargs.

    For each observation key, its modality is looked up, the matching per-modality encoder
    settings are deep-copied, the key's shape is injected as "input_shape" into both the core
    and randomizer kwargs, and the key is registered on the encoder. The encoder is finalized
    with @make before being returned.

    Args:
        obs_shapes (OrderedDict): maps each observation key to the shape expected for it.

        feature_activation: non-linearity applied after each obs net - defaults to ReLU. Pass
            None to apply no activation.

        encoder_kwargs (dict or None): per-modality encoder settings. If None,
            ObsUtils.DEFAULT_ENCODER_KWARGS is used. Otherwise it should be a nested dictionary
            of the form:

                obs_modality1: dict
                    feature_dimension: int
                    core_class: str
                    core_kwargs: dict
                        ...
                        ...
                    obs_randomizer_class: str
                    obs_randomizer_kwargs: dict
                        ...
                        ...
                obs_modality2: dict
                    ...

    Returns:
        ObservationEncoder: the constructed encoder, already locked via @make.
    """
    encoder = ObservationEncoder(feature_activation=feature_activation)

    for obs_key, obs_shape in obs_shapes.items():
        modality = ObsUtils.OBS_KEYS_TO_MODALITIES[obs_key]
        kwargs_source = ObsUtils.DEFAULT_ENCODER_KWARGS if encoder_kwargs is None else encoder_kwargs
        # work on a private copy so the shared defaults / user config are never mutated
        modality_kwargs = deepcopy(kwargs_source[modality])

        registries = (
            ("core", ObsUtils.OBS_ENCODER_CORES),
            ("obs_randomizer", ObsUtils.OBS_RANDOMIZERS),
        )
        for prefix, registry in registries:
            kwargs_key = f"{prefix}_kwargs"
            class_key = f"{prefix}_class"

            # a missing or None kwargs entry is treated as an empty dict
            if modality_kwargs.get(kwargs_key, None) is None:
                modality_kwargs[kwargs_key] = {}

            # both the core and the randomizer are told the raw shape of the observation
            modality_kwargs[kwargs_key]["input_shape"] = obs_shape

            # when a class is named, keep only the kwargs that its constructor accepts
            class_name = modality_kwargs[class_key]
            if class_name is not None:
                modality_kwargs[kwargs_key] = extract_class_init_kwargs_from_dict(
                    cls=registry[class_name],
                    dic=modality_kwargs[kwargs_key],
                    copy=False,
                )

        # the randomizer is instantiated here; the core net is only built later by @make
        randomizer_name = modality_kwargs["obs_randomizer_class"]
        if randomizer_name is None:
            randomizer = None
        else:
            randomizer_cls = ObsUtils.OBS_RANDOMIZERS[randomizer_name]
            randomizer = randomizer_cls(**modality_kwargs["obs_randomizer_kwargs"])

        encoder.register_obs_key(
            name=obs_key,
            shape=obs_shape,
            net_class=modality_kwargs["core_class"],
            net_kwargs=modality_kwargs["core_kwargs"],
            randomizer=randomizer,
        )

    encoder.make()
    return encoder
class ObservationEncoder(Module):
    """
    Module that processes inputs by observation key and then concatenates the processed
    observation keys together. Each key is processed with an encoder head network.
    Call @register_obs_key to register observation keys with the encoder and then
    finally call @make to create the encoder networks.
    """
    def __init__(self, feature_activation=nn.ReLU):
        """
        Args:
            feature_activation: non-linearity to apply after each obs net - defaults to ReLU. Pass
                None to apply no activation.
        """
        super(ObservationEncoder, self).__init__()
        # per-key registration info; insertion order is the concatenation order used in @forward
        self.obs_shapes = OrderedDict()
        self.obs_nets_classes = OrderedDict()
        self.obs_nets_kwargs = OrderedDict()
        self.obs_share_mods = OrderedDict()
        # ModuleDicts so that nets / randomizers are registered as torch submodules
        self.obs_nets = nn.ModuleDict()
        self.obs_randomizers = nn.ModuleDict()
        self.feature_activation = feature_activation
        # flipped to True by @make; afterwards no more keys may be registered
        self._locked = False

    def register_obs_key(
        self,
        name,
        shape,
        net_class=None,
        net_kwargs=None,
        net=None,
        randomizer=None,
        share_net_from=None,
    ):
        """
        Register an observation key that this encoder should be responsible for.

        Args:
            name (str): modality name
            shape (int tuple): shape of modality
            net_class (str): name of class in base_nets.py that should be used
                to process this observation key before concatenation. Pass None to flatten
                and concatenate the observation key directly.
            net_kwargs (dict): arguments to pass to @net_class. The dict is deep-copied, so the
                caller's dict is never modified.
            net (Module instance): if provided, use this Module to process the observation key
                instead of creating a different net
            randomizer (Randomizer instance): if provided, use this Module to augment observation keys
                coming in to the encoder, and possibly augment the processed output as well
            share_net_from (str): if provided, use the same instance of @net_class
                as another observation key. This observation key must already exist in this encoder.
                Warning: Note that this does not share the observation key randomizer
        """
        assert not self._locked, "ObservationEncoder: @register_obs_key called after @make"
        assert name not in self.obs_shapes, "ObservationEncoder: modality {} already exists".format(name)

        if net is not None:
            assert isinstance(net, Module), "ObservationEncoder: @net must be instance of Module class"
            assert (net_class is None) and (net_kwargs is None) and (share_net_from is None), \
                "ObservationEncoder: @net provided - ignore other net creation options"

        if share_net_from is not None:
            # share processing with another modality
            assert (net_class is None) and (net_kwargs is None), \
                "ObservationEncoder: @share_net_from provided - ignore other net creation options"
            assert share_net_from in self.obs_shapes, \
                "ObservationEncoder: cannot share net from unregistered key {}".format(share_net_from)

        # private copy (None -> empty dict), so it is always safe to add entries below
        net_kwargs = deepcopy(net_kwargs) if net_kwargs is not None else {}
        if randomizer is not None:
            assert isinstance(randomizer, Randomizer), \
                "ObservationEncoder: @randomizer must be instance of Randomizer class"
            # the net consumes the randomizer's output, so its input shape must match that output
            net_kwargs["input_shape"] = randomizer.output_shape_in(shape)

        self.obs_shapes[name] = shape
        self.obs_nets_classes[name] = net_class
        self.obs_nets_kwargs[name] = net_kwargs
        self.obs_nets[name] = net
        self.obs_randomizers[name] = randomizer
        self.obs_share_mods[name] = share_net_from

    def make(self):
        """
        Creates the encoder networks and locks the encoder so that more modalities cannot be added.
        """
        assert not self._locked, "ObservationEncoder: @make called more than once"
        self._create_layers()
        self._locked = True

    def _create_layers(self):
        """
        Creates all networks and layers required by this encoder using the registered modalities.
        """
        assert not self._locked, "ObservationEncoder: layers have already been created"

        for k in self.obs_shapes:
            if self.obs_nets_classes[k] is not None:
                # create net to process this modality
                self.obs_nets[k] = ObsUtils.OBS_ENCODER_CORES[self.obs_nets_classes[k]](**self.obs_nets_kwargs[k])
            elif self.obs_share_mods[k] is not None:
                # reuse the exact same module instance as another modality (shared weights)
                self.obs_nets[k] = self.obs_nets[self.obs_share_mods[k]]
            # otherwise keep whatever was registered (an explicit @net, or None -> flatten only)

        # a single activation instance is applied after every obs net
        self.activation = None
        if self.feature_activation is not None:
            self.activation = self.feature_activation()

    def forward(self, obs_dict):
        """
        Processes modalities according to the ordering in @self.obs_shapes. For each
        modality, it is processed with a randomizer (if present), an encoder
        network (if present), and again with the randomizer (if present), flattened,
        and then concatenated with the other processed modalities.

        Args:
            obs_dict (OrderedDict): dictionary that maps modalities to torch.Tensor
                batches that agree with @self.obs_shapes. All modalities in
                @self.obs_shapes must be present, but additional modalities
                can also be present.

        Returns:
            feats (torch.Tensor): flat features of shape [B, D]
        """
        assert self._locked, "ObservationEncoder: @make has not been called yet"

        # ensure all modalities that the encoder handles are present
        assert set(self.obs_shapes.keys()).issubset(obs_dict), "ObservationEncoder: {} does not contain all modalities {}".format(
            list(obs_dict.keys()), list(self.obs_shapes.keys())
        )

        # process modalities by order given by @self.obs_shapes
        feats = []
        for k in self.obs_shapes:
            x = obs_dict[k]
            randomizer = self.obs_randomizers[k]
            net = self.obs_nets[k]
            # maybe process encoder input with randomizer
            if randomizer is not None:
                x = randomizer.forward_in(x)
            # maybe process with obs net
            if net is not None:
                x = net(x)
                if self.activation is not None:
                    x = self.activation(x)
            # maybe process encoder output with randomizer
            if randomizer is not None:
                x = randomizer.forward_out(x)
            # flatten to [B, D]
            x = TensorUtils.flatten(x, begin_axis=1)
            feats.append(x)

        # concatenate all features together
        return torch.cat(feats, dim=-1)

    def output_shape(self, input_shape=None):
        """
        Compute the output shape of the encoder. @input_shape is unused; the shape is derived
        from the registered observation shapes.
        """
        feat_dim = 0
        for k in self.obs_shapes:
            feat_shape = self.obs_shapes[k]
            if self.obs_randomizers[k] is not None:
                feat_shape = self.obs_randomizers[k].output_shape_in(feat_shape)
            if self.obs_nets[k] is not None:
                feat_shape = self.obs_nets[k].output_shape(feat_shape)
            if self.obs_randomizers[k] is not None:
                feat_shape = self.obs_randomizers[k].output_shape_out(feat_shape)
            feat_dim += int(np.prod(feat_shape))
        return [feat_dim]

    def __repr__(self):
        """
        Pretty print the encoder.
        """
        header = '{}'.format(str(self.__class__.__name__))
        msg = ''
        for k in self.obs_shapes:
            msg += textwrap.indent('\nKey(\n', ' ' * 4)
            indent = ' ' * 8
            msg += textwrap.indent("name={}\nshape={}\n".format(k, self.obs_shapes[k]), indent)
            # keys registered directly via @register_obs_key may not be in the modality mapping;
            # printing must never raise, so fall back to None
            msg += textwrap.indent("modality={}\n".format(ObsUtils.OBS_KEYS_TO_MODALITIES.get(k)), indent)
            msg += textwrap.indent("randomizer={}\n".format(self.obs_randomizers[k]), indent)
            msg += textwrap.indent("net={}\n".format(self.obs_nets[k]), indent)
            msg += textwrap.indent("sharing_from={}\n".format(self.obs_share_mods[k]), indent)
            msg += textwrap.indent(")", ' ' * 4)
        msg += textwrap.indent("\noutput_shape={}".format(self.output_shape()), ' ' * 4)
        msg = header + '(' + msg + '\n)'
        return msg
class ObservationDecoder(Module):
    """
    Module that can generate observation outputs by modality. Inputs are assumed
    to be flat (usually outputs from some hidden layer). Each observation output
    is generated with a linear layer from these flat inputs. Subclass this
    module in order to implement more complex schemes for generating each
    modality.
    """
    def __init__(
        self,
        decode_shapes,
        input_feat_dim,
    ):
        """
        Args:
            decode_shapes (OrderedDict): a dictionary that maps observation key to
                expected shape. This is used to generate output modalities from the
                input features. Outputs are produced in the insertion order of this dict.

            input_feat_dim (int): flat input dimension size
        """
        super(ObservationDecoder, self).__init__()

        # keys are NOT sorted: they are decoded in @decode_shapes' insertion order, so an
        # OrderedDict is required to make that order explicit and deterministic
        assert isinstance(decode_shapes, OrderedDict)
        self.obs_shapes = OrderedDict()
        for k in decode_shapes:
            self.obs_shapes[k] = decode_shapes[k]

        self.input_feat_dim = input_feat_dim
        self._create_layers()

    def _create_layers(self):
        """
        Create a linear layer to predict each modality.
        """
        self.nets = nn.ModuleDict()
        for k in self.obs_shapes:
            # each head predicts the flattened modality, reshaped back in @forward
            layer_out_dim = int(np.prod(self.obs_shapes[k]))
            self.nets[k] = nn.Linear(self.input_feat_dim, layer_out_dim)

    def output_shape(self, input_shape=None):
        """
        Returns output shape for this module, which is a dictionary instead
        of a list since outputs are dictionaries.
        """
        return { k : list(self.obs_shapes[k]) for k in self.obs_shapes }

    def forward(self, feats):
        """
        Predict each modality from input features, and reshape to each modality's shape.

        Args:
            feats (torch.Tensor): flat input features; the last dimension must be
                @self.input_feat_dim.

        Returns:
            output (dict): maps each key to a tensor of shape [-1, *shape]
        """
        output = {}
        for k in self.obs_shapes:
            out = self.nets[k](feats)
            output[k] = out.reshape(-1, *self.obs_shapes[k])
        return output

    def __repr__(self):
        """Pretty print network."""
        header = '{}'.format(str(self.__class__.__name__))
        msg = ''
        for k in self.obs_shapes:
            msg += textwrap.indent('\nKey(\n', ' ' * 4)
            indent = ' ' * 8
            msg += textwrap.indent("name={}\nshape={}\n".format(k, self.obs_shapes[k]), indent)
            # decoded keys are outputs and need not be registered observation keys;
            # printing must never raise, so fall back to None
            msg += textwrap.indent("modality={}\n".format(ObsUtils.OBS_KEYS_TO_MODALITIES.get(k)), indent)
            msg += textwrap.indent("net=({})\n".format(self.nets[k]), indent)
            msg += textwrap.indent(")", ' ' * 4)
        msg = header + '(' + msg + '\n)'
        return msg
class ObservationGroupEncoder(Module):
    """
    This class allows networks to encode multiple observation dictionaries into a single
    flat, concatenated vector representation. It does this by assigning each observation
    dictionary (observation group) an @ObservationEncoder object.

    The class takes a dictionary of dictionaries, @observation_group_shapes.
    Each key corresponds to a observation group (e.g. 'obs', 'subgoal', 'goal')
    and each OrderedDict should be a map between modalities and
    expected input shapes (e.g. { 'image' : (3, 120, 160) }).
    """
    def __init__(
        self,
        observation_group_shapes,
        feature_activation=nn.ReLU,
        encoder_kwargs=None,
    ):
        """
        Args:
            observation_group_shapes (OrderedDict): a dictionary of dictionaries.
                Each key in this dictionary should specify an observation group, and
                the value should be an OrderedDict that maps modalities to
                expected shapes.

            feature_activation: non-linearity to apply after each obs net - defaults to ReLU. Pass
                None to apply no activation.

            encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
                be nested dictionary containing relevant per-modality information for encoder networks.
                Should be of form:

                obs_modality1: dict
                    feature_dimension: int
                    core_class: str
                    core_kwargs: dict
                        ...
                        ...
                    obs_randomizer_class: str
                    obs_randomizer_kwargs: dict
                        ...
                        ...
                obs_modality2: dict
                    ...
        """
        super(ObservationGroupEncoder, self).__init__()

        # type checking
        assert isinstance(observation_group_shapes, OrderedDict)
        assert all(isinstance(observation_group_shapes[k], OrderedDict) for k in observation_group_shapes)

        self.observation_group_shapes = observation_group_shapes

        # create an observation encoder per observation group
        self.nets = nn.ModuleDict()
        for obs_group in self.observation_group_shapes:
            self.nets[obs_group] = obs_encoder_factory(
                obs_shapes=self.observation_group_shapes[obs_group],
                feature_activation=feature_activation,
                encoder_kwargs=encoder_kwargs,
            )

    def forward(self, **inputs):
        """
        Process each set of inputs in its own observation group.

        Args:
            inputs (dict): dictionary that maps observation groups to observation
                dictionaries of torch.Tensor batches that agree with
                @self.observation_group_shapes. All observation groups in
                @self.observation_group_shapes must be present, but additional
                observation groups can also be present. Note that these are specified
                as kwargs for ease of use with networks that name each observation
                stream in their forward calls.

        Returns:
            outputs (torch.Tensor): flat outputs of shape [B, D]
        """
        # ensure all observation groups we need are present
        assert set(self.observation_group_shapes.keys()).issubset(inputs), "{} does not contain all observation groups {}".format(
            list(inputs.keys()), list(self.observation_group_shapes.keys())
        )

        # Deterministic order since self.observation_group_shapes is OrderedDict.
        # Call the module (not .forward) so that registered torch hooks are run.
        outputs = [
            self.nets[obs_group](inputs[obs_group])
            for obs_group in self.observation_group_shapes
        ]
        return torch.cat(outputs, dim=-1)

    def output_shape(self, input_shape=None):
        """
        Compute the output shape of this encoder.

        Args:
            input_shape: unused; accepted so the signature matches the other modules' @output_shape
                (e.g. ObservationEncoder.output_shape(input_shape=None)).

        Returns:
            list: [D], the sum of the per-group encoder feature dimensions
        """
        feat_dim = 0
        for obs_group in self.observation_group_shapes:
            # get feature dimension of these keys
            feat_dim += self.nets[obs_group].output_shape()[0]
        return [feat_dim]

    def __repr__(self):
        """Pretty print network."""
        header = '{}'.format(str(self.__class__.__name__))
        msg = ''
        for k in self.observation_group_shapes:
            msg += '\n'
            indent = ' ' * 4
            msg += textwrap.indent("group={}\n{}".format(k, self.nets[k]), indent)
        msg = header + '(' + msg + '\n)'
        return msg
class MIMO_MLP(Module):
    """
    Multi-input / multi-output MLP.

    Several observation dictionaries (one per observation group) are encoded into a single
    flat vector by an @ObservationGroupEncoder, passed through an @MLP, and finally split into a
    dictionary of output tensors by an @ObservationDecoder. With the default encoder settings,
    visual inputs go through a learned CNN and their flat encodings are concatenated with the
    other flat inputs. By default each output modality (visual ones included) is produced by its
    own linear branch.
    """
    def __init__(
        self,
        input_obs_group_shapes,
        output_shapes,
        layer_dims,
        layer_func=nn.Linear,
        activation=nn.ReLU,
        encoder_kwargs=None,
    ):
        """
        Args:
            input_obs_group_shapes (OrderedDict): maps each observation group name to an
                OrderedDict that maps modalities to expected shapes.

            output_shapes (OrderedDict): maps each output modality to its expected shape.

            layer_dims ([int]): MLP layer widths. All but the last entry are hidden layers; the
                last entry is the feature width handed to the decoder.

            layer_func: layer constructor used by the MLP - defaults to Linear

            activation: non-linearity used by the MLP (also applied to its output) - defaults to ReLU

            encoder_kwargs (dict or None): per-modality encoder settings; None means the defaults.
                Otherwise a nested dictionary of the form:

                obs_modality1: dict
                    feature_dimension: int
                    core_class: str
                    core_kwargs: dict
                        ...
                        ...
                    obs_randomizer_class: str
                    obs_randomizer_kwargs: dict
                        ...
                        ...
                obs_modality2: dict
                    ...
        """
        super(MIMO_MLP, self).__init__()

        assert isinstance(input_obs_group_shapes, OrderedDict)
        assert all(isinstance(group_shapes, OrderedDict) for group_shapes in input_obs_group_shapes.values())
        assert isinstance(output_shapes, OrderedDict)

        self.input_obs_group_shapes = input_obs_group_shapes
        self.output_shapes = output_shapes

        hidden_dims = layer_dims[:-1]

        self.nets = nn.ModuleDict()

        # one shared encoder turns every observation group into a single flat vector
        self.nets["encoder"] = ObservationGroupEncoder(
            observation_group_shapes=input_obs_group_shapes,
            encoder_kwargs=encoder_kwargs,
        )
        encoder_dim = self.nets["encoder"].output_shape()[0]
        feat_dim = layer_dims[-1]

        # the output activation guarantees a non-linearity right before the decoder heads
        self.nets["mlp"] = MLP(
            input_dim=encoder_dim,
            output_dim=feat_dim,
            layer_dims=hidden_dims,
            layer_func=layer_func,
            activation=activation,
            output_activation=activation,
        )

        # one linear head per output modality
        self.nets["decoder"] = ObservationDecoder(
            decode_shapes=self.output_shapes,
            input_feat_dim=feat_dim,
        )

    def output_shape(self, input_shape=None):
        """
        Output shapes keyed by output modality (a dict, because the outputs are a dict).
        @input_shape is ignored.
        """
        return dict((name, list(shape)) for name, shape in self.output_shapes.items())

    def forward(self, **inputs):
        """
        Run encoder -> MLP -> decoder.

        Args:
            inputs (dict): one observation dictionary per observation group, each mapping
                modalities to torch.Tensor batches consistent with @self.input_obs_group_shapes.

        Returns:
            outputs (dict): output torch.Tensors keyed as in @self.output_shapes
        """
        feats = self.nets["encoder"](**inputs)
        for stage in ("mlp", "decoder"):
            feats = self.nets[stage](feats)
        return feats

    def _to_string(self):
        """
        Hook for subclasses: extra network / policy info to include when printing.
        """
        return ''

    def __repr__(self):
        """Pretty print network."""
        pad = ' ' * 4
        pieces = []
        extra_info = self._to_string()
        if extra_info != '':
            pieces.append(textwrap.indent("\n" + extra_info + "\n", pad))
        pieces.append(textwrap.indent("\nencoder={}".format(self.nets["encoder"]), pad))
        pieces.append(textwrap.indent("\n\nmlp={}".format(self.nets["mlp"]), pad))
        pieces.append(textwrap.indent("\n\ndecoder={}".format(self.nets["decoder"]), pad))
        return self.__class__.__name__ + '(' + ''.join(pieces) + '\n)'
class RNN_MIMO_MLP(Module):
    """
    A wrapper class for a multi-step RNN and a per-step MLP and a decoder.

    Structure: [encoder -> rnn -> mlp -> decoder]

    All temporal inputs are processed by a shared @ObservationGroupEncoder,
    followed by an RNN, and then a per-step multi-output MLP.
    """
    def __init__(
        self,
        input_obs_group_shapes,
        output_shapes,
        mlp_layer_dims,
        rnn_hidden_dim,
        rnn_num_layers,
        rnn_type="LSTM", # [LSTM, GRU]
        rnn_kwargs=None,
        mlp_activation=nn.ReLU,
        mlp_layer_func=nn.Linear,
        per_step=True,
        encoder_kwargs=None,
    ):
        """
        Args:
            input_obs_group_shapes (OrderedDict): a dictionary of dictionaries.
                Each key in this dictionary should specify an observation group, and
                the value should be an OrderedDict that maps modalities to
                expected shapes.

            output_shapes (OrderedDict): a dictionary that maps modality to
                expected shapes for outputs.

            mlp_layer_dims ([int]): MLP layer widths applied after the RNN. All but the last entry
                are hidden layers; the last entry is the width fed to the decoder. If empty, the
                decoder is applied directly to the RNN output.

            rnn_hidden_dim (int): RNN hidden dimension

            rnn_num_layers (int): number of RNN layers

            rnn_type (str): [LSTM, GRU]

            rnn_kwargs (dict or None): kwargs for the rnn model. None means no extra kwargs.

            mlp_activation: non-linearity used for the MLP layers (including its output layer)

            mlp_layer_func: layer constructor used by the MLP

            per_step (bool): if True, apply the MLP and observation decoder into @output_shapes
                at every step of the RNN. Otherwise, apply them to the final hidden state of the
                RNN.

            encoder_kwargs (dict or None): If None, results in default encoder_kwargs being applied. Otherwise, should
                be nested dictionary containing relevant per-modality information for encoder networks.
                Should be of form:

                obs_modality1: dict
                    feature_dimension: int
                    core_class: str
                    core_kwargs: dict
                        ...
                        ...
                    obs_randomizer_class: str
                    obs_randomizer_kwargs: dict
                        ...
                        ...
                obs_modality2: dict
                    ...
        """
        super(RNN_MIMO_MLP, self).__init__()
        assert isinstance(input_obs_group_shapes, OrderedDict)
        assert all(isinstance(input_obs_group_shapes[k], OrderedDict) for k in input_obs_group_shapes)
        assert isinstance(output_shapes, OrderedDict)
        self.input_obs_group_shapes = input_obs_group_shapes
        self.output_shapes = output_shapes
        self.per_step = per_step

        self.nets = nn.ModuleDict()

        # Encoder for all observation groups.
        self.nets["encoder"] = ObservationGroupEncoder(
            observation_group_shapes=input_obs_group_shapes,
            encoder_kwargs=encoder_kwargs,
        )

        # flat encoder output dimension
        rnn_input_dim = self.nets["encoder"].output_shape()[0]

        # bidirectional RNNs mean that the output of RNN will be twice the hidden dimension.
        # @rnn_kwargs defaults to None, so guard the lookup instead of calling None.get(...)
        rnn_is_bidirectional = rnn_kwargs.get("bidirectional", False) if rnn_kwargs is not None else False
        num_directions = int(rnn_is_bidirectional) + 1 # 2 if bidirectional, 1 otherwise
        rnn_output_dim = num_directions * rnn_hidden_dim

        per_step_net = None
        self._has_mlp = (len(mlp_layer_dims) > 0)
        if self._has_mlp:
            # use @mlp_activation for hidden layers too, consistent with MIMO_MLP
            self.nets["mlp"] = MLP(
                input_dim=rnn_output_dim,
                output_dim=mlp_layer_dims[-1],
                layer_dims=mlp_layer_dims[:-1],
                activation=mlp_activation,
                output_activation=mlp_activation,
                layer_func=mlp_layer_func
            )
            self.nets["decoder"] = ObservationDecoder(
                decode_shapes=self.output_shapes,
                input_feat_dim=mlp_layer_dims[-1],
            )
            if self.per_step:
                per_step_net = Sequential(self.nets["mlp"], self.nets["decoder"])
        else:
            self.nets["decoder"] = ObservationDecoder(
                decode_shapes=self.output_shapes,
                input_feat_dim=rnn_output_dim,
            )
            if self.per_step:
                per_step_net = self.nets["decoder"]

        # core network; when @per_step, the RNN applies @per_step_net at every timestep itself
        self.nets["rnn"] = RNN_Base(
            input_dim=rnn_input_dim,
            rnn_hidden_dim=rnn_hidden_dim,
            rnn_num_layers=rnn_num_layers,
            rnn_type=rnn_type,
            per_step_net=per_step_net,
            rnn_kwargs=rnn_kwargs
        )

    def get_rnn_init_state(self, batch_size, device):
        """
        Get a default RNN state (zeros)

        Args:
            batch_size (int): batch size dimension

            device: device the hidden state should be sent to.

        Returns:
            hidden_state (torch.Tensor or tuple): returns hidden state tensor or tuple of hidden state tensors
                depending on the RNN type
        """
        return self.nets["rnn"].get_rnn_init_state(batch_size, device=device)

    def output_shape(self, input_shape):
        """
        Returns output shape for this module, which is a dictionary instead
        of a list since outputs are dictionaries.

        Args:
            input_shape (dict): dictionary of dictionaries, where each top-level key
                corresponds to an observation group, and the low-level dictionaries
                specify the shape for each modality in an observation dictionary.
                The first entry of every shape is the temporal dimension T.

        Returns:
            dict: per output key, [T] + shape when @self.per_step, otherwise just the shape
                (the decoder is only applied to the final RNN output, so there is no time dim).
        """
        # infers temporal dimension from the first modality of the first observation group
        obs_group = next(iter(self.input_obs_group_shapes))
        mod = next(iter(self.input_obs_group_shapes[obs_group]))
        T = input_shape[obs_group][mod][0]

        # shapes are plain sequences here, so compare their leading entries directly
        for group in self.input_obs_group_shapes:
            for key in self.input_obs_group_shapes[group]:
                assert input_shape[group][key][0] == T, \
                    "RNN_MIMO_MLP: input_shape inconsistent in temporal dimension"

        if not self.per_step:
            return { k : list(self.output_shapes[k]) for k in self.output_shapes }
        # returns a dictionary instead of list since outputs are dictionaries
        return { k : [T] + list(self.output_shapes[k]) for k in self.output_shapes }

    def forward(self, rnn_init_state=None, return_state=False, **inputs):
        """
        Args:
            inputs (dict): a dictionary of dictionaries with one dictionary per
                observation group. Each observation group's dictionary should map
                modality to torch.Tensor batches. Should be consistent with
                @self.input_obs_group_shapes. First two leading dimensions should
                be batch and time [B, T, ...] for each tensor.

            rnn_init_state: rnn hidden state, initialize to zero state if set to None

            return_state (bool): whether to return hidden state

        Returns:
            outputs (dict): dictionary of output torch.Tensors, that corresponds
                to @self.output_shapes. Leading dimensions will be batch and time [B, T, ...]
                for each tensor when @self.per_step, otherwise just batch [B, ...].

            rnn_state (torch.Tensor or tuple): return the new rnn state (if @return_state)
        """
        for obs_group in self.input_obs_group_shapes:
            for k in self.input_obs_group_shapes[obs_group]:
                # first two dimensions should be [B, T] for inputs
                assert inputs[obs_group][k].ndim - 2 == len(self.input_obs_group_shapes[obs_group][k]), \
                    "RNN_MIMO_MLP: input {}/{} must have leading [B, T] dimensions".format(obs_group, k)

        # use encoder to extract flat rnn inputs
        rnn_inputs = TensorUtils.time_distributed(inputs, self.nets["encoder"], inputs_as_kwargs=True)
        assert rnn_inputs.ndim == 3  # [B, T, D]
        if self.per_step:
            return self.nets["rnn"].forward(inputs=rnn_inputs, rnn_init_state=rnn_init_state, return_state=return_state)

        # apply MLP + decoder to last RNN output
        outputs = self.nets["rnn"].forward(inputs=rnn_inputs, rnn_init_state=rnn_init_state, return_state=return_state)
        if return_state:
            outputs, rnn_state = outputs

        assert outputs.ndim == 3 # [B, T, D]
        if self._has_mlp:
            outputs = self.nets["decoder"](self.nets["mlp"](outputs[:, -1]))
        else:
            outputs = self.nets["decoder"](outputs[:, -1])

        if return_state:
            return outputs, rnn_state
        return outputs

    def forward_step(self, rnn_state, **inputs):
        """
        Unroll network over a single timestep.

        Args:
            inputs (dict): a dictionary of dictionaries with one dictionary per observation
                group, consistent with @self.input_obs_group_shapes, with an additional
                batch dimension (but NOT time), since this is a single time step.

            rnn_state (torch.Tensor): rnn hidden state

        Returns:
            outputs (dict): dictionary of output torch.Tensors, that corresponds
                to @self.output_shapes. Does not contain time dimension.

            rnn_state: return the new rnn state
        """
        # ensure that the only extra dimension is batch dim, not temporal dim
        for obs_group in self.input_obs_group_shapes:
            for k in self.input_obs_group_shapes[obs_group]:
                assert inputs[obs_group][k].ndim - 1 == len(self.input_obs_group_shapes[obs_group][k]), \
                    "RNN_MIMO_MLP: input {}/{} must only have a leading batch dimension".format(obs_group, k)

        # add a time dimension so that @forward sees [B, 1, ...] tensors
        inputs = TensorUtils.to_sequence(inputs)
        # observation groups must be passed as kwargs; passing the dict positionally would
        # bind it to @rnn_init_state
        outputs, rnn_state = self.forward(
            rnn_init_state=rnn_state,
            return_state=True,
            **inputs
        )
        if self.per_step:
            # per-step outputs still carry the singleton time dim; drop it. The decoder
            # produces a dict of tensors, so index each entry rather than the dict itself.
            if isinstance(outputs, dict):
                outputs = {k: v[:, 0] for k, v in outputs.items()}
            else:
                outputs = outputs[:, 0]
        # if outputs are not per-step, the time dimension is already reduced
        return outputs, rnn_state

    def _to_string(self):
        """
        Subclasses should override this method to print out info about network / policy.
        """
        return ''

    def __repr__(self):
        """Pretty print network."""
        header = '{}'.format(str(self.__class__.__name__))
        msg = ''
        indent = ' ' * 4
        msg += textwrap.indent("\n" + self._to_string(), indent)
        msg += textwrap.indent("\n\nencoder={}".format(self.nets["encoder"]), indent)
        msg += textwrap.indent("\n\nrnn={}".format(self.nets["rnn"]), indent)
        if not self.per_step:
            # mlp / decoder are not part of the RNN's per-step net here, so print them explicitly
            if self._has_mlp:
                msg += textwrap.indent("\n\nmlp={}".format(self.nets["mlp"]), indent)
            msg += textwrap.indent("\n\ndecoder={}".format(self.nets["decoder"]), indent)
        msg = header + '(' + msg + '\n)'
        return msg
class MIMO_Transformer(Module):
    """
    Extension to Transformer (based on GPT architecture) to accept multiple observation
    dictionaries as input and to output dictionaries of tensors. Inputs are specified as
    a dictionary of observation dictionaries, with each key corresponding to an observation group.

    This module utilizes @ObservationGroupEncoder to process the multiple input dictionaries and
    @ObservationDecoder to generate tensor dictionaries. The default behavior
    for encoding the inputs is to process visual inputs with a learned CNN and concatenating
    the flat encodings with the other flat inputs. The default behavior for generating
    outputs is to use a linear layer branch to produce each modality separately
    (including visual outputs).
    """
    def __init__(
        self,
        input_obs_group_shapes,
        output_shapes,
        transformer_embed_dim,
        transformer_num_layers,
        transformer_num_heads,
        transformer_context_length,
        transformer_emb_dropout=0.1,
        transformer_attn_dropout=0.1,
        transformer_block_output_dropout=0.1,
        transformer_sinusoidal_embedding=False,
        transformer_activation="gelu",
        transformer_nn_parameter_for_timesteps=False,
        encoder_kwargs=None,
    ):
        """
        Args:
            input_obs_group_shapes (OrderedDict): a dictionary of dictionaries.
                Each key in this dictionary should specify an observation group, and
                the value should be an OrderedDict that maps modalities to
                expected shapes.

            output_shapes (OrderedDict): a dictionary that maps modality to
                expected shapes for outputs.

            transformer_embed_dim (int): dimension for embeddings used by transformer

            transformer_num_layers (int): number of transformer blocks to stack

            transformer_num_heads (int): number of attention heads for each
                transformer block - must divide @transformer_embed_dim evenly. Self-attention is
                computed over this many partitions of the embedding dimension separately.

            transformer_context_length (int): expected (and, for learned positional
                embeddings, maximum) length of input sequences

            transformer_emb_dropout (float): dropout probability for embedding inputs in transformer

            transformer_attn_dropout (float): dropout probability for attention outputs for each transformer block

            transformer_block_output_dropout (float): dropout probability for final outputs for each transformer block

            transformer_sinusoidal_embedding (bool): if True, positional embeddings come from a
                @PositionalEncoding module instead of a learned table

            transformer_activation: non-linearity for input and output layers used in transformer

            transformer_nn_parameter_for_timesteps (bool): if True, positional embeddings are a
                learned nn.Parameter of shape [1, context_length, embed_dim] (zero-initialized)
                instead of an nn.Embedding lookup. Cannot be combined with
                @transformer_sinusoidal_embedding.

            encoder_kwargs (dict): observation encoder config
        """
        super(MIMO_Transformer, self).__init__()

        assert isinstance(input_obs_group_shapes, OrderedDict)
        assert np.all([isinstance(input_obs_group_shapes[k], OrderedDict) for k in input_obs_group_shapes])
        assert isinstance(output_shapes, OrderedDict)
        # The two positional-embedding options are mutually exclusive. This is checked before
        # anything is built. If both were enabled, a sinusoidal module would be created while
        # embed_timesteps() would look up the never-created learned parameter table.
        assert not (
            transformer_sinusoidal_embedding and transformer_nn_parameter_for_timesteps
        ), "nn.Parameter only works with learned embeddings"

        self.input_obs_group_shapes = input_obs_group_shapes
        self.output_shapes = output_shapes

        self.nets = nn.ModuleDict()
        self.params = nn.ParameterDict()

        # Encoder for all observation groups.
        self.nets["encoder"] = ObservationGroupEncoder(
            observation_group_shapes=input_obs_group_shapes,
            encoder_kwargs=encoder_kwargs,
            feature_activation=None,
        )

        # flat encoder output dimension
        transformer_input_dim = self.nets["encoder"].output_shape()[0]

        # project flat encoder features to the transformer embedding size
        self.nets["embed_encoder"] = nn.Linear(
            transformer_input_dim, transformer_embed_dim
        )

        max_timestep = transformer_context_length

        if transformer_sinusoidal_embedding:
            self.nets["embed_timestep"] = PositionalEncoding(transformer_embed_dim)
        elif transformer_nn_parameter_for_timesteps:
            # learned per-position table, zero-initialized
            self.params["embed_timestep"] = nn.Parameter(
                torch.zeros(1, max_timestep, transformer_embed_dim)
            )
        else:
            self.nets["embed_timestep"] = nn.Embedding(max_timestep, transformer_embed_dim)

        # layer norm for embeddings
        self.nets["embed_ln"] = nn.LayerNorm(transformer_embed_dim)

        # dropout for input embeddings
        self.nets["embed_drop"] = nn.Dropout(transformer_emb_dropout)

        # GPT transformer
        self.nets["transformer"] = GPT_Backbone(
            embed_dim=transformer_embed_dim,
            num_layers=transformer_num_layers,
            num_heads=transformer_num_heads,
            context_length=transformer_context_length,
            attn_dropout=transformer_attn_dropout,
            block_output_dropout=transformer_block_output_dropout,
            activation=transformer_activation,
        )

        # decoder for output modalities
        self.nets["decoder"] = ObservationDecoder(
            decode_shapes=self.output_shapes,
            input_feat_dim=transformer_embed_dim,
        )

        self.transformer_context_length = transformer_context_length
        self.transformer_embed_dim = transformer_embed_dim
        self.transformer_sinusoidal_embedding = transformer_sinusoidal_embedding
        self.transformer_nn_parameter_for_timesteps = transformer_nn_parameter_for_timesteps

    def output_shape(self, input_shape=None):
        """
        Returns output shape for this module, which is a dictionary instead
        of a list since outputs are dictionaries.
        """
        return { k : list(self.output_shapes[k]) for k in self.output_shapes }

    def embed_timesteps(self, embeddings):
        """
        Computes timestep-based embeddings (aka positional embeddings) to add to embeddings.

        Args:
            embeddings (torch.Tensor): embeddings prior to positional embeddings are computed.
                Dims 0 and 1 are read as batch size and sequence length.

        Returns:
            time_embeddings (torch.Tensor): positional embeddings to add to embeddings, with
                the same shape as @embeddings
        """
        batch_size, seq_len = embeddings.shape[0], embeddings.shape[1]

        if self.transformer_nn_parameter_for_timesteps:
            assert seq_len <= self.transformer_context_length, (
                f"sequence length {seq_len} exceeds context length {self.transformer_context_length}"
            )
            # The learned table is [1, context_length, D]. Keep only the first @seq_len positions
            # and broadcast over the batch, so the result matches [B, T, D] like the other branches.
            time_embeddings = self.params["embed_timestep"][:, :seq_len].expand(batch_size, -1, -1)
        else:
            # timestep indices 0..T-1, repeated once per batch element
            timesteps = (
                torch.arange(
                    0,
                    seq_len,
                    dtype=embeddings.dtype,
                    device=embeddings.device,
                )
                .unsqueeze(0)
                .repeat(batch_size, 1)
            )
            if self.transformer_sinusoidal_embedding:
                assert torch.is_floating_point(timesteps), timesteps.dtype
            else:
                # nn.Embedding requires integer indices
                timesteps = timesteps.long()
            # these are NOT fed into transformer, only added to the inputs.
            time_embeddings = self.nets["embed_timestep"](timesteps)

        # compute how many modalities were combined into embeddings, replicate time embeddings that many times
        num_replicates = embeddings.shape[-1] // self.transformer_embed_dim
        time_embeddings = torch.cat([time_embeddings] * num_replicates, -1)
        assert (
            embeddings.shape == time_embeddings.shape
        ), f"{embeddings.shape}, {time_embeddings.shape}"
        return time_embeddings

    def input_embedding(
        self,
        inputs,
    ):
        """
        Process encoded observations into embeddings to pass to transformer.

        Projects inputs with a linear layer, adds timestep-based embeddings
        (aka positional embeddings), then applies layer norm and dropout.

        Args:
            inputs (torch.Tensor): outputs from observation encoder

        Returns:
            embeddings (torch.Tensor): input embeddings to pass to transformer backbone.
        """
        embeddings = self.nets["embed_encoder"](inputs)
        time_embeddings = self.embed_timesteps(embeddings)
        embeddings = embeddings + time_embeddings
        embeddings = self.nets["embed_ln"](embeddings)
        embeddings = self.nets["embed_drop"](embeddings)

        return embeddings

    def forward(self, **inputs):
        """
        Process each set of inputs in its own observation group.

        Args:
            inputs (dict): a dictionary of dictionaries with one dictionary per
                observation group. Each observation group's dictionary should map
                modality to torch.Tensor batches. Should be consistent with
                @self.input_obs_group_shapes. First two leading dimensions should
                be batch and time [B, T, ...] for each tensor. Entries that are None
                are skipped by the shape check.

        Returns:
            outputs (dict): dictionary of output torch.Tensors, that corresponds
                to @self.output_shapes. Leading dimensions will be batch and time [B, T, ...]
                for each tensor. The raw transformer outputs are also included under
                the key "transformer_encoder_outputs".
        """
        for obs_group in self.input_obs_group_shapes:
            for k in self.input_obs_group_shapes[obs_group]:
                # first two dimensions should be [B, T] for inputs
                if inputs[obs_group][k] is None:
                    continue
                assert inputs[obs_group][k].ndim - 2 == len(self.input_obs_group_shapes[obs_group][k])

        # encode every timestep of every observation group into one flat feature vector
        transformer_inputs = TensorUtils.time_distributed(
            inputs, self.nets["encoder"], inputs_as_kwargs=True
        )
        assert transformer_inputs.ndim == 3  # [B, T, D]

        transformer_embeddings = self.input_embedding(transformer_inputs)
        # pass encoded sequences through transformer (call the module so registered hooks run)
        transformer_encoder_outputs = self.nets["transformer"](transformer_embeddings)

        # apply decoder to each timestep of sequence to get a dictionary of outputs
        transformer_outputs = TensorUtils.time_distributed(
            transformer_encoder_outputs, self.nets["decoder"]
        )
        transformer_outputs["transformer_encoder_outputs"] = transformer_encoder_outputs
        return transformer_outputs

    def _to_string(self):
        """
        Subclasses should override this method to print out info about network / policy.
        """
        return ''

    def __repr__(self):
        """Pretty print network."""
        header = '{}'.format(str(self.__class__.__name__))
        msg = ''
        indent = ' ' * 4
        if self._to_string() != '':
            msg += textwrap.indent("\n" + self._to_string() + "\n", indent)
        msg += textwrap.indent("\nencoder={}".format(self.nets["encoder"]), indent)
        msg += textwrap.indent("\n\ntransformer={}".format(self.nets["transformer"]), indent)
        msg += textwrap.indent("\n\ndecoder={}".format(self.nets["decoder"]), indent)
        msg = header + '(' + msg + '\n)'
        return msg