# VLAdaptorBench / external /peract_bimanual /agents /peract_bc /qattention_stack_agent.py
# lsnu's picture
# Add files using upload-large-folder tool
# 0d89eb9 verified
from typing import List
import torch
from yarr.agents.agent import Agent, ActResult, Summary
import numpy as np
from helpers import utils
from agents.peract_bc.qattention_peract_bc_agent import QAttentionPerActBCAgent
# Name constant for this agent; the string matches the class defined in this module.
# NOTE(review): how this constant is consumed (e.g. a config/registry lookup) is
# not visible in this file -- confirm against callers.
NAME = "QAttentionStackAgent"
class QAttentionStackAgent(Agent):
    """Agent that drives an ordered stack of Q-attention agents as one unit.

    Every call (build, update, act, summaries, weight I/O) is fanned out to the
    wrapped agents in list order. During ``act`` each layer's outputs are
    written back into the shared ``observation`` dict so the next layer can
    read them.
    """

    def __init__(
        self,
        qattention_agents: List[QAttentionPerActBCAgent],
        rotation_resolution: float,
        camera_names: List[str],
        rotation_prediction_depth: int = 0,
    ):
        """Store the wrapped agents and the settings used when acting.

        Args:
            qattention_agents: Agents to run, in the order they are applied.
            rotation_resolution: Passed to
                ``utils.discrete_euler_to_quaternion`` when decoding the
                rotation indices in ``act``.
            camera_names: Prefixes used to look up
                ``"<name>_camera_extrinsics"`` / ``"<name>_camera_intrinsics"``
                in the observation and to write ``"<name>_pixel_coord"``.
            rotation_prediction_depth: Stored on the instance; not read by any
                method of this class.
        """
        super().__init__()
        self._qattention_agents = qattention_agents
        self._rotation_resolution = rotation_resolution
        self._camera_names = camera_names
        self._rotation_prediction_depth = rotation_prediction_depth

    def build(self, training: bool, device=None) -> None:
        """Build every wrapped agent and remember the device used in ``act``.

        Args:
            training: Forwarded unchanged to each wrapped agent's ``build``.
            device: Torch device. When ``None`` this stack falls back to CPU
                for the tensors it creates itself.
        """
        self._device = torch.device("cpu") if device is None else device
        for qa in self._qattention_agents:
            # The caller's original ``device`` (possibly None) is forwarded on
            # purpose so each wrapped agent applies its own default handling.
            qa.build(training, device)

    def update(self, step: int, replay_sample: dict) -> dict:
        """Run one update on every wrapped agent and sum their losses.

        Each agent's returned dict is merged into ``replay_sample`` (mutating
        the caller's dict) before the next agent is updated.

        Args:
            step: Forwarded to each wrapped agent's ``update``.
            replay_sample: Sample dict handed to, and extended by, each agent.

        Returns:
            ``{"total_losses": <sum of each agent's "total_loss">}``.
        """
        total_losses = 0.0
        for qa in self._qattention_agents:
            update_dict = qa.update(step, replay_sample)
            replay_sample.update(update_dict)
            total_losses += update_dict["total_loss"]
        return {"total_losses": total_losses}

    def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
        """Query every layer in order and assemble one continuous action.

        ``observation`` is mutated in place: after each layer it receives that
        layer's ``attention_coordinate``, ``prev_layer_voxel_grid`` and
        ``prev_layer_bounds`` plus one ``"<camera>_pixel_coord"`` tensor per
        camera.

        Args:
            step: Forwarded to each wrapped agent's ``act``.
            observation: Observation dict; must hold the per-camera extrinsics
                and intrinsics tensors for every name in ``camera_names``.
            deterministic: Forwarded to each wrapped agent's ``act``.

        Returns:
            ``ActResult`` whose action is the concatenation of the last layer's
            attention coordinate, the quaternion decoded from the rotation
            indices, the last rot/grip index and the ignore-collisions value.

        Raises:
            RuntimeError: If there are no wrapped agents, or if no layer
                returned rotation/gripper or ignore-collision indices.
        """
        if not self._qattention_agents:
            raise RuntimeError(
                "QAttentionStackAgent.act() needs at least one Q-attention agent."
            )

        observation_elements = {}
        translation_results, rot_grip_results, ignore_collisions_results = [], [], []
        infos = {}

        for depth, qagent in enumerate(self._qattention_agents):
            act_results = qagent.act(step, observation, deterministic)
            layer_elements = act_results.observation_elements
            attention_coordinate = (
                layer_elements["attention_coordinate"].cpu().numpy()
            )
            observation_elements[
                "attention_coordinate_layer_%d" % depth
            ] = attention_coordinate[0]

            translation_idxs, rot_grip_idxs, ignore_collisions_idxs = (
                act_results.action
            )
            translation_results.append(translation_idxs)
            # A layer may return None for these two; only collect real outputs.
            if rot_grip_idxs is not None:
                rot_grip_results.append(rot_grip_idxs)
            if ignore_collisions_idxs is not None:
                ignore_collisions_results.append(ignore_collisions_idxs)

            # Expose this layer's outputs to the next layer via the shared dict.
            observation["attention_coordinate"] = layer_elements[
                "attention_coordinate"
            ]
            observation["prev_layer_voxel_grid"] = layer_elements[
                "prev_layer_voxel_grid"
            ]
            observation["prev_layer_bounds"] = layer_elements["prev_layer_bounds"]

            for n in self._camera_names:
                # NOTE(review): the [0, 0] index presumably selects the first
                # batch element and first timestep -- confirm against the
                # observation layout.
                px, py = utils.point_to_pixel_index(
                    attention_coordinate[0],
                    observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy(),
                    observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy(),
                )
                # Stored in (py, px) order, i.e. swapped relative to the
                # helper's return order.
                pc_t = torch.tensor(
                    [[[py, px]]], dtype=torch.float32, device=self._device
                )
                observation["%s_pixel_coord" % n] = pc_t
                observation_elements["%s_pixel_coord" % n] = [py, px]

            infos.update(act_results.info)

        if not rot_grip_results or not ignore_collisions_results:
            raise RuntimeError(
                "No Q-attention layer returned rotation/gripper and "
                "ignore-collision indices; cannot assemble an action."
            )

        rgai = torch.cat(rot_grip_results, 1)[0].cpu().numpy()
        # Extract the single element explicitly: calling float() directly on an
        # array with ndim > 0 is deprecated since NumPy 1.25.
        ignore_collisions = float(
            torch.cat(ignore_collisions_results, 1)[0].cpu().numpy().item()
        )
        observation_elements["trans_action_indicies"] = (
            torch.cat(translation_results, 1)[0].cpu().numpy()
        )
        observation_elements["rot_grip_action_indicies"] = rgai

        continuous_action = np.concatenate(
            [
                # Attention coordinate of the last layer, first batch element.
                act_results.observation_elements["attention_coordinate"]
                .cpu()
                .numpy()[0],
                # The three entries before the last one are decoded as
                # discrete Euler indices.
                utils.discrete_euler_to_quaternion(
                    rgai[-4:-1], self._rotation_resolution
                ),
                # Last rot/grip index (presumably the gripper state -- confirm).
                rgai[-1:],
                [ignore_collisions],
            ]
        )
        return ActResult(
            continuous_action, observation_elements=observation_elements, info=infos
        )

    def update_summaries(self) -> List[Summary]:
        """Return the update summaries of all wrapped agents, in agent order."""
        summaries = []
        for qa in self._qattention_agents:
            summaries.extend(qa.update_summaries())
        return summaries

    def act_summaries(self) -> List[Summary]:
        """Return the act summaries of all wrapped agents, in agent order."""
        summaries = []
        for qa in self._qattention_agents:
            summaries.extend(qa.act_summaries())
        return summaries

    def load_weights(self, savedir: str):
        """Ask every wrapped agent to load its weights from ``savedir``."""
        for qa in self._qattention_agents:
            qa.load_weights(savedir)

    def save_weights(self, savedir: str):
        """Ask every wrapped agent to save its weights to ``savedir``."""
        for qa in self._qattention_agents:
            qa.save_weights(savedir)