File size: 5,035 Bytes
# 0d89eb9
from typing import List
import torch
import torchvision.transforms as transforms
from yarr.agents.agent import (
Agent,
Summary,
ActResult,
ScalarSummary,
HistogramSummary,
ImageSummary,
)
class PreprocessAgent(Agent):
    """Agent wrapper that casts inputs to float and normalizes RGB entries.

    Every call is forwarded to the wrapped ``pose_agent``; ``update`` and
    ``act`` first run the incoming tensors through the same preprocessing.
    """

    def __init__(
        self, pose_agent: Agent, norm_rgb: bool = True, norm_type: str = "zero_mean"
    ):
        """Store the wrapped agent and the RGB normalization settings.

        Args:
            pose_agent: Agent that receives the preprocessed inputs.
            norm_rgb: When True, entries whose key contains ``"rgb"`` are
                normalized with ``norm_type``; otherwise they are only cast
                to float.
            norm_type: ``"zero_mean"`` or ``"imagenet"``; any other value
                makes ``_norm_rgb_`` raise ``NotImplementedError``.
        """
        self._pose_agent = pose_agent
        self._norm_rgb = norm_rgb
        self._norm_type = norm_type

    def build(self, training: bool, device: torch.device = None):
        """Build the wrapped agent."""
        self._pose_agent.build(training, device)

    def _norm_rgb_(self, x):
        """Normalize an RGB tensor according to ``self._norm_type``.

        Raises:
            NotImplementedError: If ``self._norm_type`` is not supported.
        """
        if self._norm_type == "zero_mean":
            # x / 255 * 2 - 1; maps 0..255 input onto [-1, 1].
            return (x.float() / 255.0) * 2.0 - 1.0
        if self._norm_type == "imagenet":
            # NOTE: only rescales by 1/255; no ImageNet mean/std
            # standardization is applied in this mode.
            return x.float() / 255.0
        raise NotImplementedError(f"Unsupported norm_type: {self._norm_type!r}")

    def _denorm_rgb_(self, x):
        """Undo the RGB preprocessing so the tensor can be logged as an image.

        Mirrors ``_norm_rgb_``: only the ``"zero_mean"`` mode needs the
        ``(x + 1) / 2`` inverse; ``"imagenet"`` output is left unchanged
        because that mode only divides by 255.
        """
        if not self._norm_rgb:
            # Values were only cast to float. Assumes raw RGB is 0..255, the
            # same range ``_norm_rgb_`` divides by -- TODO confirm.
            return x / 255.0
        if self._norm_type == "zero_mean":
            return (x + 1.0) / 2.0
        return x

    def _preprocess(self, tensors: dict) -> dict:
        """Return a new dict with RGB entries normalized and the rest cast to float.

        A fresh dict is built so the caller's mapping is never modified.
        """
        return {
            key: (
                self._norm_rgb_(value)
                if self._norm_rgb and "rgb" in key
                else value.float()
            )
            for key, value in tensors.items()
        }

    @staticmethod
    def _tile(x):
        """Split ``x`` along dim 1 into size-1 chunks, concatenate them along
        the last dim, and squeeze dim 1 away."""
        return torch.squeeze(torch.cat(x.split(1, dim=1), dim=-1), dim=1)

    def update(self, step: int, replay_sample: dict) -> dict:
        """Preprocess a replay batch, cache it for summaries, and update the wrapped agent.

        Returns:
            Whatever the wrapped agent's ``update`` returns.
        """
        # Samples are (B, N, ...) where N is number of buffers/tasks. This is
        # a single task setup, so 0 index.
        squeezed = {
            key: value[:, 0] if len(value.shape) > 2 and value.shape[1] == 1 else value
            for key, value in replay_sample.items()
        }
        # Kept on the instance because update_summaries() reads it later.
        self._replay_sample = self._preprocess(squeezed)
        return self._pose_agent.update(step, self._replay_sample)

    def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
        """Preprocess an observation and let the wrapped agent act on it.

        The observation dict passed in is left untouched; the wrapped agent
        receives a preprocessed copy. The returned result's
        ``replay_elements`` is tagged with ``"demo": False``.
        """
        act_res = self._pose_agent.act(
            step, self._preprocess(observation), deterministic
        )
        act_res.replay_elements.update({"demo": False})
        return act_res

    def update_summaries(self) -> List[Summary]:
        """Build summaries for the last replay batch plus the wrapped agent's own.

        Reads ``self._replay_sample``, which is only set by ``update``.
        """
        prefix = "inputs"
        sample = self._replay_sample

        sums = [
            ScalarSummary(f"{prefix}/demo_proportion", sample["demo"].float().mean()),
            ScalarSummary(f"{prefix}/timeouts", sample["timeout"].float().mean()),
        ]

        # Low-dim state may be stored unprefixed or once per arm; log each
        # variant that is present in the batch.
        for robot_prefix in ["", "right_", "left_"]:
            state_key = f"{robot_prefix}low_dim_state"
            if state_key not in sample:
                continue
            state = sample[state_key]
            sums.extend(
                [
                    HistogramSummary(f"{prefix}/{state_key}", state),
                    HistogramSummary(
                        f"{prefix}/{state_key}_tp1", sample[f"{state_key}_tp1"]
                    ),
                    ScalarSummary(f"{prefix}/{state_key}_mean", state.mean()),
                    ScalarSummary(f"{prefix}/{state_key}_min", state.min()),
                    ScalarSummary(f"{prefix}/{state_key}_max", state.max()),
                ]
            )

        for key, value in sample.items():
            if "rgb" in key:
                # Convert back to 0 - 1, using the inverse of the
                # normalization that was actually applied.
                value = self._denorm_rgb_(value)
            elif "point_cloud" not in key:
                continue
            sums.append(
                ImageSummary(
                    f"{prefix}/{key}",
                    self._tile(value) if len(value.shape) > 4 else value,
                )
            )

        if "sampling_probabilities" in sample:
            sums.append(
                HistogramSummary("replay/priority", sample["sampling_probabilities"])
            )

        sums.extend(self._pose_agent.update_summaries())
        return sums

    def act_summaries(self) -> List[Summary]:
        """Return the wrapped agent's act summaries."""
        return self._pose_agent.act_summaries()

    def load_weights(self, savedir: str):
        """Load the wrapped agent's weights from ``savedir``."""
        self._pose_agent.load_weights(savedir)

    def save_weights(self, savedir: str):
        """Save the wrapped agent's weights to ``savedir``."""
        self._pose_agent.save_weights(savedir)

    def reset(self) -> None:
        """Reset the wrapped agent."""
        self._pose_agent.reset()
|