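# NOTE(review): the next three lines appear to be hosting-page metadata that
# was captured along with the file, not source code; they are kept as comments
# so the module parses.
# lsnu's picture
# Add files using upload-large-folder tool
# 0d89eb9 verified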
# Adapted from ARM
# Source: https://github.com/stepjam/ARM
# License: https://github.com/stepjam/ARM/LICENSE
from helpers.preprocess_agent import PreprocessAgent
from agents.peract_bc.perceiver_lang_io import PerceiverVoxelLangEncoder
from agents.peract_bc.qattention_peract_bc_agent import QAttentionPerActBCAgent
from agents.peract_bc.qattention_stack_agent import QAttentionStackAgent
from omegaconf import DictConfig
def create_agent(cfg: DictConfig):
    """Build the PerAct behaviour-cloning agent stack described by ``cfg``.

    One ``PerceiverVoxelLangEncoder`` / ``QAttentionPerActBCAgent`` pair is
    created per entry of ``cfg.method.voxel_sizes``. The resulting agents are
    combined in a ``QAttentionStackAgent``, which is in turn wrapped in a
    ``PreprocessAgent``.

    Args:
        cfg: Configuration object. This function reads values from its
            ``rlbench``, ``method``, ``framework``, ``replay`` and ``ddp``
            sections.

    Returns:
        A ``PreprocessAgent`` whose ``pose_agent`` is the stacked
        Q-attention agent.
    """
    method = cfg.method
    scene_bounds = cfg.rlbench.scene_bounds
    cam_resolution = cfg.rlbench.camera_resolution
    # Number of discrete rotation bins; floor division drops any partial bin.
    num_rotation_classes = int(360.0 // method.rotation_resolution)

    voxel_sizes = method.voxel_sizes
    # Loop-invariant: index of the final layer in the stack.
    last_depth = len(voxel_sizes) - 1

    qattention_agents = []
    for depth, vox_size in enumerate(voxel_sizes):
        # Only the final layer is given non-zero rotation / gripper /
        # collision class counts; earlier layers are built with 0 for each.
        is_last = depth == last_depth

        perceiver_encoder = PerceiverVoxelLangEncoder(
            depth=method.transformer_depth,
            iterations=method.transformer_iterations,
            voxel_size=vox_size,
            # NOTE(review): presumably rgb + xyz + occupancy + position
            # index channels -- confirm against the voxelizer output.
            initial_dim=3 + 3 + 1 + 3,
            low_dim_size=method.low_dim_size,
            layer=depth,
            num_rotation_classes=num_rotation_classes if is_last else 0,
            num_grip_classes=2 if is_last else 0,
            num_collision_classes=2 if is_last else 0,
            input_axis=3,
            num_latents=method.num_latents,
            latent_dim=method.latent_dim,
            cross_heads=method.cross_heads,
            latent_heads=method.latent_heads,
            cross_dim_head=method.cross_dim_head,
            latent_dim_head=method.latent_dim_head,
            weight_tie_layers=False,
            activation=method.activation,
            pos_encoding_with_lang=method.pos_encoding_with_lang,
            input_dropout=method.input_dropout,
            attn_dropout=method.attn_dropout,
            decoder_dropout=method.decoder_dropout,
            lang_fusion_type=method.lang_fusion_type,
            voxel_patch_size=method.voxel_patch_size,
            voxel_patch_stride=method.voxel_patch_stride,
            no_skip_connection=method.no_skip_connection,
            no_perceiver=method.no_perceiver,
            no_language=method.no_language,
            final_dim=method.final_dim,
        )

        augmentation = method.transform_augmentation
        qattention_agent = QAttentionPerActBCAgent(
            layer=depth,
            coordinate_bounds=scene_bounds,
            perceiver_encoder=perceiver_encoder,
            camera_names=cfg.rlbench.cameras,
            voxel_size=vox_size,
            # The first layer gets no offset; layer ``depth`` (> 0) uses
            # entry ``depth - 1`` of the configured offsets.
            bounds_offset=method.bounds_offset[depth - 1] if depth > 0 else None,
            image_crop_size=method.image_crop_size,
            lr=method.lr,
            training_iterations=cfg.framework.training_iterations,
            lr_scheduler=method.lr_scheduler,
            num_warmup_steps=method.num_warmup_steps,
            trans_loss_weight=method.trans_loss_weight,
            rot_loss_weight=method.rot_loss_weight,
            grip_loss_weight=method.grip_loss_weight,
            collision_loss_weight=method.collision_loss_weight,
            include_low_dim_state=True,
            image_resolution=cam_resolution,
            batch_size=cfg.replay.batch_size,
            voxel_feature_size=3,
            lambda_weight_l2=method.lambda_weight_l2,
            num_rotation_classes=num_rotation_classes,
            rotation_resolution=method.rotation_resolution,
            transform_augmentation=augmentation.apply_se3,
            transform_augmentation_xyz=augmentation.aug_xyz,
            transform_augmentation_rpy=augmentation.aug_rpy,
            transform_augmentation_rot_resolution=augmentation.aug_rot_resolution,
            optimizer_type=method.optimizer,
            num_devices=cfg.ddp.num_devices,
            checkpoint_name_prefix=cfg.framework.checkpoint_name_prefix,
        )
        qattention_agents.append(qattention_agent)

    rotation_agent = QAttentionStackAgent(
        qattention_agents=qattention_agents,
        rotation_resolution=method.rotation_resolution,
        camera_names=cfg.rlbench.cameras,
    )
    return PreprocessAgent(pose_agent=rotation_agent)