from typing import List, Optional

import torch
import torchvision.transforms as transforms

from yarr.agents.agent import (
    Agent,
    Summary,
    ActResult,
    ScalarSummary,
    HistogramSummary,
    ImageSummary,
)


class PreprocessAgent(Agent):
    """Agent wrapper that casts inputs to float before delegating.

    ``update`` and ``act`` convert every incoming value with ``.float()``
    (entries whose key contains ``"rgb"`` are rescaled instead when
    ``norm_rgb`` is set) and then forward the result to the wrapped
    ``pose_agent``. All other ``Agent`` methods are forwarded unchanged.
    """

    def __init__(
        self, pose_agent: Agent, norm_rgb: bool = True, norm_type: str = "zero_mean"
    ):
        """Store the wrapped agent and the RGB normalisation settings.

        Args:
            pose_agent: Agent that receives the preprocessed inputs.
            norm_rgb: When true, entries whose key contains ``"rgb"`` are
                rescaled by ``_norm_rgb_`` instead of only being cast.
            norm_type: ``"zero_mean"`` or ``"imagenet"``. The value is only
                checked when an RGB entry is actually normalised.
        """
        self._pose_agent = pose_agent
        self._norm_rgb = norm_rgb
        self._norm_type = norm_type

    def build(self, training: bool, device: Optional[torch.device] = None):
        """Forward ``build`` to the wrapped agent."""
        self._pose_agent.build(training, device)

    def _norm_rgb_(self, x):
        """Rescale an RGB tensor according to ``norm_type``.

        Args:
            x: Tensor to rescale; presumably 8-bit colour values in
                ``[0, 255]`` -- confirm against the data pipeline.

        Returns:
            ``x / 255 * 2 - 1`` for ``"zero_mean"``, ``x / 255`` for
            ``"imagenet"``, both as float tensors.

        Raises:
            NotImplementedError: If ``norm_type`` is any other value.
        """
        if self._norm_type == "zero_mean":
            return (x.float() / 255.0) * 2.0 - 1.0
        if self._norm_type == "imagenet":
            # NOTE(review): despite the name, no per-channel mean/std is
            # applied here; the values are only divided by 255.
            return x.float() / 255.0
        raise NotImplementedError(f"Unsupported norm_type: {self._norm_type!r}")

    def _preprocess(self, key, value):
        """Return ``value`` as float, rescaled when it is a normalised RGB entry."""
        if self._norm_rgb and "rgb" in key:
            return self._norm_rgb_(value)
        return value.float()

    def _rgb_for_summary(self, v):
        """Map a preprocessed RGB tensor to the range used for image logging.

        ``"zero_mean"`` data is mapped from ``[-1, 1]`` back to ``[0, 1]``.
        ``"imagenet"`` data was only divided by 255 in ``_norm_rgb_``, so it
        is returned untouched. When ``norm_rgb`` is off the incoming range is
        not known here, so the historical ``(v + 1) / 2`` mapping is kept.
        """
        if self._norm_rgb and self._norm_type == "imagenet":
            return v
        return (v + 1.0) / 2.0

    def update(self, step: int, replay_sample: dict) -> dict:
        """Preprocess a replay batch and run the wrapped agent's update.

        Values with more than two dimensions and a size-1 second dimension
        have that dimension dropped (``v[:, 0]``) before conversion. The
        processed batch is kept on the instance for ``update_summaries``.

        Args:
            step: Step counter, passed through to the wrapped agent.
            replay_sample: Mapping from name to a value exposing ``.shape``,
                indexing and ``.float()`` (torch tensors, presumably).

        Returns:
            Whatever the wrapped agent's ``update`` returns.
        """
        processed = {}
        for k, v in replay_sample.items():
            if len(v.shape) > 2 and v.shape[1] == 1:
                v = v[:, 0]
            processed[k] = self._preprocess(k, v)
        self._replay_sample = processed
        return self._pose_agent.update(step, processed)

    def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
        """Preprocess an observation and query the wrapped agent.

        The caller's ``observation`` dict is left untouched: a new dict is
        built, so passing the same observation again does not normalise its
        RGB entries a second time.

        Args:
            step: Step counter, passed through to the wrapped agent.
            observation: Mapping from name to a value exposing ``.float()``.
            deterministic: Passed through to the wrapped agent.

        Returns:
            The wrapped agent's ``ActResult`` with ``"demo": False`` added to
            its ``replay_elements``.
        """
        processed = {k: self._preprocess(k, v) for k, v in observation.items()}
        act_res = self._pose_agent.act(step, processed, deterministic)
        act_res.replay_elements.update({"demo": False})
        return act_res

    def update_summaries(self) -> List[Summary]:
        """Build summaries of the batch seen by the most recent ``update``.

        Requires ``update`` to have been called first, since it reads the
        batch stored there.

        Returns:
            Input summaries followed by the wrapped agent's update summaries.
        """
        prefix = "inputs"
        sample = self._replay_sample

        def tile(x):
            # Place the slices along dim 1 side by side on the last dimension,
            # then drop the now size-1 dim 1.
            return torch.squeeze(torch.cat(x.split(1, dim=1), dim=-1), dim=1)

        sums = [
            ScalarSummary(f"{prefix}/demo_proportion", sample["demo"].float().mean()),
            ScalarSummary(f"{prefix}/timeouts", sample["timeout"].float().mean()),
        ]

        # Single-arm batches use unprefixed keys; the prefixed variants are
        # simply skipped when they are absent.
        for robot_prefix in ("", "right_", "left_"):
            state_key = f"{robot_prefix}low_dim_state"
            if state_key not in sample:
                continue
            state = sample[state_key]
            sums.extend(
                [
                    HistogramSummary(f"{prefix}/{state_key}", state),
                    HistogramSummary(
                        f"{prefix}/{state_key}_tp1", sample[f"{state_key}_tp1"]
                    ),
                    ScalarSummary(f"{prefix}/{state_key}_mean", state.mean()),
                    ScalarSummary(f"{prefix}/{state_key}_min", state.min()),
                    ScalarSummary(f"{prefix}/{state_key}_max", state.max()),
                ]
            )

        for k, v in sample.items():
            if "rgb" not in k and "point_cloud" not in k:
                continue
            if "rgb" in k:
                v = self._rgb_for_summary(v)
            # Tensors with more than four dimensions are tiled into one image.
            image = tile(v) if len(v.shape) > 4 else v
            sums.append(ImageSummary(f"{prefix}/{k}", image))

        if "sampling_probabilities" in sample:
            sums.append(
                HistogramSummary("replay/priority", sample["sampling_probabilities"])
            )
        sums.extend(self._pose_agent.update_summaries())
        return sums

    def act_summaries(self) -> List[Summary]:
        """Return the wrapped agent's act summaries."""
        return self._pose_agent.act_summaries()

    def load_weights(self, savedir: str):
        """Forward ``load_weights`` to the wrapped agent."""
        self._pose_agent.load_weights(savedir)

    def save_weights(self, savedir: str):
        """Forward ``save_weights`` to the wrapped agent."""
        self._pose_agent.save_weights(savedir)

    def reset(self) -> None:
        """Forward ``reset`` to the wrapped agent."""
        self._pose_agent.reset()