from typing import List

import torch
from yarr.agents.agent import Agent, ActResult, Summary

import numpy as np
from helpers import utils
from agents.peract_bc.qattention_peract_bc_agent import QAttentionPerActBCAgent

NAME = "QAttentionStackAgent"


class QAttentionStackAgent(Agent):
    """Agent that chains several Q-attention agents and runs them in order.

    Each wrapped agent is called in sequence; the outputs of one agent are
    merged into the input dict (``replay_sample`` for ``update``,
    ``observation`` for ``act``) before the next agent is called.
    """

    def __init__(
        self,
        qattention_agents: List[QAttentionPerActBCAgent],
        rotation_resolution: float,
        camera_names: List[str],
        rotation_prediction_depth: int = 0,
    ):
        """Store the wrapped agents and the action-decoding settings.

        Args:
            qattention_agents: Agents to run, in the order they are called.
            rotation_resolution: Passed to
                ``utils.discrete_euler_to_quaternion`` when decoding the
                rotation indices in ``act``.
            camera_names: Camera name prefixes used to look up
                ``"<name>_camera_extrinsics"`` / ``"<name>_camera_intrinsics"``
                in the observation.
            rotation_prediction_depth: Stored on the instance; not read by
                any method of this class.
        """
        super(QAttentionStackAgent, self).__init__()
        self._qattention_agents = qattention_agents
        self._rotation_resolution = rotation_resolution
        self._camera_names = camera_names
        self._rotation_prediction_depth = rotation_prediction_depth

    def build(self, training: bool, device=None) -> None:
        """Build every wrapped agent and remember the device for ``act``.

        Args:
            training: Forwarded unchanged to each wrapped agent's ``build``.
            device: Device for tensors created in ``act``; falls back to CPU
                when ``None``.
        """
        self._device = device
        if self._device is None:
            self._device = torch.device("cpu")
        for qa in self._qattention_agents:
            # NOTE(review): the wrapped agents receive the caller's `device`
            # (possibly None), not the CPU fallback stored above -- confirm
            # that they handle None themselves.
            qa.build(training, device)

    def update(self, step: int, replay_sample: dict) -> dict:
        """Run one update on every wrapped agent and sum their losses.

        ``replay_sample`` is mutated in place: the dict returned by each
        agent's ``update`` is merged into it before the next agent runs.

        Args:
            step: Forwarded to each wrapped agent.
            replay_sample: Sample dict passed to (and extended by) each agent.

        Returns:
            ``{"total_losses": <sum of each agent's "total_loss">}``.
        """
        total_losses = 0.0
        for qa in self._qattention_agents:
            update_dict = qa.update(step, replay_sample)
            # Make this agent's outputs visible to the following agents.
            replay_sample.update(update_dict)
            total_losses += update_dict["total_loss"]
        return {
            "total_losses": total_losses,
        }

    def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
        """Query every wrapped agent in order and assemble one action.

        ``observation`` is mutated in place: after each agent acts, its
        ``attention_coordinate``, ``prev_layer_voxel_grid`` and
        ``prev_layer_bounds`` outputs, plus a ``"<camera>_pixel_coord"``
        tensor per camera, are written into it for the next agent.

        Args:
            step: Forwarded to each wrapped agent.
            observation: Observation dict passed to (and extended by) each
                agent.
            deterministic: Forwarded to each wrapped agent.

        Returns:
            An ``ActResult`` whose action is the concatenation of the last
            agent's attention coordinate, the quaternion decoded from the
            rotation indices, the last rot/grip index and the
            ignore-collisions value.
        """
        observation_elements = {}
        translation_results, rot_grip_results, ignore_collisions_results = [], [], []
        infos = {}
        for depth, qagent in enumerate(self._qattention_agents):
            act_results = qagent.act(step, observation, deterministic)
            attention_coordinate = (
                act_results.observation_elements["attention_coordinate"].cpu().numpy()
            )
            # Index 0 selects the first entry of the leading axis
            # (presumably the batch axis -- TODO confirm).
            observation_elements[
                "attention_coordinate_layer_%d" % depth
            ] = attention_coordinate[0]

            translation_idxs, rot_grip_idxs, ignore_collisions_idxs = act_results.action
            translation_results.append(translation_idxs)
            # An agent may return None for these two outputs; skip it then.
            if rot_grip_idxs is not None:
                rot_grip_results.append(rot_grip_idxs)
            if ignore_collisions_idxs is not None:
                ignore_collisions_results.append(ignore_collisions_idxs)

            # Expose this agent's outputs to the next agent in the stack.
            observation["attention_coordinate"] = act_results.observation_elements[
                "attention_coordinate"
            ]
            observation["prev_layer_voxel_grid"] = act_results.observation_elements[
                "prev_layer_voxel_grid"
            ]
            observation["prev_layer_bounds"] = act_results.observation_elements[
                "prev_layer_bounds"
            ]

            for n in self._camera_names:
                px, py = utils.point_to_pixel_index(
                    attention_coordinate[0],
                    observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy(),
                    observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy(),
                )
                # Stored in (py, px) order, both in the tensor and the list.
                pc_t = torch.tensor(
                    [[[py, px]]], dtype=torch.float32, device=self._device
                )
                observation["%s_pixel_coord" % n] = pc_t
                observation_elements["%s_pixel_coord" % n] = [py, px]

            infos.update(act_results.info)

        rgai = torch.cat(rot_grip_results, 1)[0].cpu().numpy()
        # `.item()` instead of `float(ndarray)`: converting an array with
        # ndim > 0 to a scalar via float() is deprecated since NumPy 1.25.
        # Like the original, this requires exactly one element.
        ignore_collisions = float(
            torch.cat(ignore_collisions_results, 1)[0].cpu().numpy().item()
        )
        observation_elements["trans_action_indicies"] = (
            torch.cat(translation_results, 1)[0].cpu().numpy()
        )
        observation_elements["rot_grip_action_indicies"] = rgai
        # `act_results` here is the result of the last agent in the stack.
        continuous_action = np.concatenate(
            [
                act_results.observation_elements["attention_coordinate"]
                .cpu()
                .numpy()[0],
                # The three entries before the last one are the rotation
                # indices handed to the euler->quaternion helper.
                utils.discrete_euler_to_quaternion(
                    rgai[-4:-1], self._rotation_resolution
                ),
                # Last rot/grip index (presumably the gripper state --
                # TODO confirm).
                rgai[-1:],
                [ignore_collisions],
            ]
        )
        return ActResult(
            continuous_action, observation_elements=observation_elements, info=infos
        )

    def update_summaries(self) -> List[Summary]:
        """Return the update summaries of all wrapped agents, concatenated."""
        summaries = []
        for qa in self._qattention_agents:
            summaries.extend(qa.update_summaries())
        return summaries

    def act_summaries(self) -> List[Summary]:
        """Return the act summaries of all wrapped agents, concatenated."""
        s = []
        for qa in self._qattention_agents:
            s.extend(qa.act_summaries())
        return s

    def load_weights(self, savedir: str):
        """Ask every wrapped agent to load its weights from ``savedir``."""
        for qa in self._qattention_agents:
            qa.load_weights(savedir)

    def save_weights(self, savedir: str):
        """Ask every wrapped agent to save its weights to ``savedir``."""
        for qa in self._qattention_agents:
            qa.save_weights(savedir)