| |
| |
| |
|
|
|
|
| from helpers.preprocess_agent import PreprocessAgent |
| from agents.peract_bc.perceiver_lang_io import PerceiverVoxelLangEncoder |
| from agents.peract_bc.qattention_peract_bc_agent import QAttentionPerActBCAgent |
| from agents.peract_bc.qattention_stack_agent import QAttentionStackAgent |
|
|
| from omegaconf import DictConfig |
|
|
|
|
def create_agent(cfg: DictConfig):
    """Build the PerAct behaviour-cloning agent described by ``cfg``.

    One ``PerceiverVoxelLangEncoder`` / ``QAttentionPerActBCAgent`` pair is
    created for every entry of ``cfg.method.voxel_sizes``. The per-level
    agents are combined in a ``QAttentionStackAgent``, which is finally
    wrapped in a ``PreprocessAgent``.

    Args:
        cfg: Configuration object. The ``method``, ``rlbench``, ``framework``,
            ``replay`` and ``ddp`` sections are read.

    Returns:
        The ``PreprocessAgent`` wrapping the stacked Q-attention agents.
    """
    method_cfg = cfg.method
    scene_bounds = cfg.rlbench.scene_bounds
    cam_resolution = cfg.rlbench.camera_resolution
    camera_names = cfg.rlbench.cameras
    augmentation_cfg = method_cfg.transform_augmentation

    # Number of discrete rotation bins covering a full 360 degree turn.
    num_rotation_classes = int(360.0 // method_cfg.rotation_resolution)

    voxel_sizes = method_cfg.voxel_sizes
    last_depth = len(voxel_sizes) - 1  # loop-invariant, computed once

    qattention_agents = []
    for depth, vox_size in enumerate(voxel_sizes):
        # Only the final level is given rotation / gripper / collision
        # classes; every earlier level is built with zero classes for them.
        last = depth == last_depth

        perceiver_encoder = PerceiverVoxelLangEncoder(
            depth=method_cfg.transformer_depth,
            iterations=method_cfg.transformer_iterations,
            voxel_size=vox_size,
            # NOTE(review): presumably the per-voxel input feature count
            # (sums to 10) -- confirm against the voxeliser.
            initial_dim=3 + 3 + 1 + 3,
            low_dim_size=method_cfg.low_dim_size,
            layer=depth,
            num_rotation_classes=num_rotation_classes if last else 0,
            num_grip_classes=2 if last else 0,
            num_collision_classes=2 if last else 0,
            input_axis=3,
            num_latents=method_cfg.num_latents,
            latent_dim=method_cfg.latent_dim,
            cross_heads=method_cfg.cross_heads,
            latent_heads=method_cfg.latent_heads,
            cross_dim_head=method_cfg.cross_dim_head,
            latent_dim_head=method_cfg.latent_dim_head,
            weight_tie_layers=False,
            activation=method_cfg.activation,
            pos_encoding_with_lang=method_cfg.pos_encoding_with_lang,
            input_dropout=method_cfg.input_dropout,
            attn_dropout=method_cfg.attn_dropout,
            decoder_dropout=method_cfg.decoder_dropout,
            lang_fusion_type=method_cfg.lang_fusion_type,
            voxel_patch_size=method_cfg.voxel_patch_size,
            voxel_patch_stride=method_cfg.voxel_patch_stride,
            no_skip_connection=method_cfg.no_skip_connection,
            no_perceiver=method_cfg.no_perceiver,
            no_language=method_cfg.no_language,
            final_dim=method_cfg.final_dim,
        )

        # The first level has no bounds offset; level ``depth`` (> 0) uses
        # entry ``depth - 1`` of ``cfg.method.bounds_offset``.
        bounds_offset = method_cfg.bounds_offset[depth - 1] if depth > 0 else None

        qattention_agent = QAttentionPerActBCAgent(
            layer=depth,
            coordinate_bounds=scene_bounds,
            perceiver_encoder=perceiver_encoder,
            camera_names=camera_names,
            voxel_size=vox_size,
            bounds_offset=bounds_offset,
            image_crop_size=method_cfg.image_crop_size,
            lr=method_cfg.lr,
            training_iterations=cfg.framework.training_iterations,
            lr_scheduler=method_cfg.lr_scheduler,
            num_warmup_steps=method_cfg.num_warmup_steps,
            trans_loss_weight=method_cfg.trans_loss_weight,
            rot_loss_weight=method_cfg.rot_loss_weight,
            grip_loss_weight=method_cfg.grip_loss_weight,
            collision_loss_weight=method_cfg.collision_loss_weight,
            include_low_dim_state=True,
            image_resolution=cam_resolution,
            batch_size=cfg.replay.batch_size,
            voxel_feature_size=3,
            lambda_weight_l2=method_cfg.lambda_weight_l2,
            num_rotation_classes=num_rotation_classes,
            rotation_resolution=method_cfg.rotation_resolution,
            transform_augmentation=augmentation_cfg.apply_se3,
            transform_augmentation_xyz=augmentation_cfg.aug_xyz,
            transform_augmentation_rpy=augmentation_cfg.aug_rpy,
            transform_augmentation_rot_resolution=augmentation_cfg.aug_rot_resolution,
            optimizer_type=method_cfg.optimizer,
            num_devices=cfg.ddp.num_devices,
            checkpoint_name_prefix=cfg.framework.checkpoint_name_prefix,
        )
        qattention_agents.append(qattention_agent)

    rotation_agent = QAttentionStackAgent(
        qattention_agents=qattention_agents,
        rotation_resolution=method_cfg.rotation_resolution,
        camera_names=camera_names,
    )
    return PreprocessAgent(pose_agent=rotation_agent)
|
|