File size: 4,924 Bytes
0d89eb9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | from typing import List
import torch
from yarr.agents.agent import Agent, ActResult, Summary
import numpy as np
from helpers import utils
from agents.peract_bc.qattention_peract_bc_agent import QAttentionPerActBCAgent
# Module-level identifier string; it equals the name of the agent class in this file.
# NOTE(review): presumably read elsewhere to identify/look up this agent type —
# its usage is not visible from this file.
NAME = "QAttentionStackAgent"
class QAttentionStackAgent(Agent):
    """Run a stack of Q-attention agents in sequence as a single agent.

    Every method fans out to the wrapped agents in list order. Between layers,
    the outputs of one layer are written into the shared input dict so that the
    next layer can read them. For ``act`` these outputs are the attention
    coordinate, voxel grid, bounds and per-camera pixel coordinates. For
    ``update`` they are the layer's update dict. ``act`` then decodes the
    collected discrete predictions into one continuous action.
    """

    def __init__(
        self,
        qattention_agents: List[QAttentionPerActBCAgent],
        rotation_resolution: float,
        camera_names: List[str],
        rotation_prediction_depth: int = 0,
    ):
        """Store the wrapped agents and action-decoding settings.

        Args:
            qattention_agents: Agents to run, in order, one per layer.
            rotation_resolution: Bin size forwarded to
                ``utils.discrete_euler_to_quaternion`` when decoding the
                discrete rotation indices in ``act``.
            camera_names: Cameras onto which each layer's attention point is
                projected. ``act`` reads ``<name>_camera_extrinsics`` and
                ``<name>_camera_intrinsics`` from the observation.
            rotation_prediction_depth: Stored as an attribute; no method of
                this class reads it.
        """
        super().__init__()
        self._qattention_agents = qattention_agents
        self._rotation_resolution = rotation_resolution
        self._camera_names = camera_names
        self._rotation_prediction_depth = rotation_prediction_depth

    def build(self, training: bool, device=None) -> None:
        """Resolve the device and build every sub-agent on it.

        Args:
            training: Forwarded unchanged to each sub-agent's ``build``.
            device: Torch device to use; ``None`` falls back to CPU.
        """
        self._device = torch.device("cpu") if device is None else device
        # Forward the *resolved* device so the sub-agents and the pixel-coord
        # tensors created in ``act`` (on ``self._device``) always agree.
        for qa in self._qattention_agents:
            qa.build(training, self._device)

    def update(self, step: int, replay_sample: dict) -> dict:
        """Run one training update on every layer, in order.

        Each layer's result dict is merged into ``replay_sample`` (mutating the
        caller's dict), so later layers see what earlier layers produced.

        Args:
            step: Forwarded to each sub-agent's ``update``.
            replay_sample: Batch dict shared across layers (mutated in place).

        Returns:
            ``{"total_losses": x}``, where ``x`` is the sum of every layer's
            ``"total_loss"`` entry (``0.0`` when there are no layers).
        """
        total_losses = 0.0
        for qa in self._qattention_agents:
            update_dict = qa.update(step, replay_sample)
            replay_sample.update(update_dict)
            total_losses += update_dict["total_loss"]
        return {"total_losses": total_losses}

    @staticmethod
    def _concat_layer_indices(per_layer, what):
        """Concatenate per-layer index tensors along dim 1; return item 0 as numpy.

        Raises:
            RuntimeError: if ``per_layer`` is empty. ``torch.cat`` would
                otherwise fail with a less descriptive RuntimeError.
        """
        if not per_layer:
            raise RuntimeError(
                "QAttentionStackAgent: no q-attention layer produced %s indices"
                % what
            )
        return torch.cat(per_layer, 1)[0].cpu().numpy()

    def act(self, step: int, observation: dict, deterministic: bool = False) -> ActResult:
        """Query every layer in order and decode one continuous action.

        ``observation`` is mutated in place. After each layer it receives that
        layer's ``attention_coordinate``, ``prev_layer_voxel_grid`` and
        ``prev_layer_bounds``, plus a ``<camera>_pixel_coord`` tensor for every
        camera. The following layer then consumes these.

        Args:
            step: Forwarded to each sub-agent's ``act``.
            observation: Shared observation dict (see above).
            deterministic: Forwarded to each sub-agent's ``act``.

        Returns:
            An ``ActResult`` whose action concatenates four parts:

            * the last layer's attention coordinate;
            * the quaternion decoded from ``rgai[-4:-1]``;
            * ``rgai[-1:]`` (presumably the gripper index);
            * the last ignore-collisions value as a float.

            Per-layer coordinates, pixel coordinates and raw action indices go
            into ``observation_elements``. The sub-agents' ``info`` dicts are
            merged into ``info``.

        Raises:
            RuntimeError: if no layer produced rotation/gripper,
                ignore-collisions or translation indices (e.g. an empty agent
                list).
        """
        observation_elements = {}
        infos = {}
        translation_results, rot_grip_results, ignore_collisions_results = [], [], []
        final_coordinate = None

        for depth, qagent in enumerate(self._qattention_agents):
            act_results = qagent.act(step, observation, deterministic)
            layer_elements = act_results.observation_elements
            # NOTE(review): index 0 presumably selects the single batch item — confirm.
            coordinate = layer_elements["attention_coordinate"].cpu().numpy()[0]
            final_coordinate = coordinate
            observation_elements["attention_coordinate_layer_%d" % depth] = coordinate

            translation_idxs, rot_grip_idxs, ignore_collisions_idxs = act_results.action
            translation_results.append(translation_idxs)
            # A layer may return None for these outputs; only real ones are kept.
            if rot_grip_idxs is not None:
                rot_grip_results.append(rot_grip_idxs)
            if ignore_collisions_idxs is not None:
                ignore_collisions_results.append(ignore_collisions_idxs)

            # Hand this layer's outputs to the next layer through the observation.
            for key in ("attention_coordinate", "prev_layer_voxel_grid", "prev_layer_bounds"):
                observation[key] = layer_elements[key]

            for cam in self._camera_names:
                px, py = utils.point_to_pixel_index(
                    coordinate,
                    observation["%s_camera_extrinsics" % cam][0, 0].cpu().numpy(),
                    observation["%s_camera_intrinsics" % cam][0, 0].cpu().numpy(),
                )
                observation["%s_pixel_coord" % cam] = torch.tensor(
                    [[[py, px]]], dtype=torch.float32, device=self._device
                )
                observation_elements["%s_pixel_coord" % cam] = [py, px]

            infos.update(act_results.info)

        rgai = self._concat_layer_indices(rot_grip_results, "rotation/gripper")
        # Use only the last layer's ignore-collisions value, matching how only
        # the tail of ``rgai`` is decoded below. Calling float() on the 1-D
        # array directly is deprecated since NumPy 1.25. It also fails outright
        # when more than one layer produced a value.
        ignore_collisions_idxs = self._concat_layer_indices(
            ignore_collisions_results, "ignore-collisions"
        )
        ignore_collisions = float(ignore_collisions_idxs.reshape(-1)[-1])
        observation_elements["trans_action_indicies"] = self._concat_layer_indices(
            translation_results, "translation"
        )
        observation_elements["rot_grip_action_indicies"] = rgai

        continuous_action = np.concatenate(
            [
                final_coordinate,
                utils.discrete_euler_to_quaternion(
                    rgai[-4:-1], self._rotation_resolution
                ),
                rgai[-1:],
                [ignore_collisions],
            ]
        )
        return ActResult(
            continuous_action, observation_elements=observation_elements, info=infos
        )

    def update_summaries(self) -> List[Summary]:
        """Return every layer's training summaries, concatenated in layer order."""
        return [s for qa in self._qattention_agents for s in qa.update_summaries()]

    def act_summaries(self) -> List[Summary]:
        """Return every layer's acting summaries, concatenated in layer order."""
        return [s for qa in self._qattention_agents for s in qa.act_summaries()]

    def load_weights(self, savedir: str):
        """Have every layer load its weights from ``savedir``."""
        for qa in self._qattention_agents:
            qa.load_weights(savedir)

    def save_weights(self, savedir: str):
        """Have every layer save its weights to ``savedir``."""
        for qa in self._qattention_agents:
            qa.save_weights(savedir)
|