| import pickle |
| from typing import Dict, Optional, Sequence |
| from pathlib import Path |
| import json |
| import torch |
| import numpy as np |
|
|
|
|
# Instruction lookup table, as loaded and filtered by ``load_instructions``:
# outer keys are task names, inner keys are integer variation ids, and values
# are torch tensors (presumably pre-computed instruction embeddings — TODO
# confirm the expected shape against the producer of the pickle file).
Instructions = Dict[str, Dict[int, torch.Tensor]]
|
|
|
|
def round_floats(o):
    """Recursively round every float inside ``o`` to two decimal places.

    Dicts are rebuilt with the same keys; lists *and* tuples both come back
    as lists; any value that is not a float, dict, list or tuple is returned
    unchanged.
    """
    if isinstance(o, float):
        rounded = round(o, 2)
    elif isinstance(o, dict):
        rounded = {}
        for key, value in o.items():
            rounded[key] = round_floats(value)
    elif isinstance(o, (list, tuple)):
        # Tuples are intentionally converted to lists here.
        rounded = list(map(round_floats, o))
    else:
        rounded = o
    return rounded
|
|
|
|
def normalise_quat(x: torch.Tensor):
    """Scale ``x`` to unit Euclidean length along its last dimension.

    No epsilon is added to the denominator, so an all-zero vector yields NaNs.
    """
    # keepdim=True keeps a size-1 trailing axis so the norm broadcasts over x.
    squared_norm = torch.square(x).sum(dim=-1, keepdim=True)
    return x / torch.sqrt(squared_norm)
|
|
|
|
def get_gripper_loc_bounds(path: str, buffer: float = 0.0, task: Optional[str] = None):
    """Load gripper location bounds from a JSON file and pad them by ``buffer``.

    The file is expected to hold a JSON object mapping task names to a pair
    ``[lower, upper]`` of per-axis bounds. It is read exactly once.

    Args:
        path: Path to the JSON bounds file.
        buffer: Amount subtracted from every lower bound and added to every
            upper bound.
        task: If given and present in the file, that task's bounds are used.
            Otherwise (``None`` or an unknown task name) the element-wise
            union over all tasks is used: the min of all lower bounds and the
            max of all upper bounds.

    Returns:
        ``np.ndarray`` built by stacking ``[lower, upper]``, so row 0 is the
        padded lower bound and row 1 the padded upper bound.

    Raises:
        ValueError: If the union over all tasks is requested but the file
            contains no tasks.
    """
    # Use a context manager so the file handle is closed, and parse only once.
    with open(path, "r", encoding="utf-8") as fid:
        all_bounds = json.load(fid)

    if task is not None and task in all_bounds:
        task_bounds = all_bounds[task]
        bounds_min = np.array(task_bounds[0]) - buffer
        bounds_max = np.array(task_bounds[1]) + buffer
    else:
        if not all_bounds:
            raise ValueError(f"No gripper location bounds found in {path!r}")
        # Bounding box enclosing every task's workspace.
        lowers = np.stack([bounds[0] for bounds in all_bounds.values()])
        uppers = np.stack([bounds[1] for bounds in all_bounds.values()])
        bounds_min = np.min(lowers, axis=0) - buffer
        bounds_max = np.max(uppers, axis=0) + buffer

    gripper_loc_bounds = np.stack([bounds_min, bounds_max])
    print("Gripper workspace size:", bounds_max - bounds_min)
    return gripper_loc_bounds
|
|
|
|
def count_parameters(model):
    """Return the total number of scalar elements in ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` is False (e.g. frozen layers) are not
    counted. A model with no trainable parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total
|
|
|
|
def norm_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Divide ``tensor`` by its L2 norm taken over the last dimension.

    The norm is kept as a size-1 trailing dimension so it broadcasts back over
    ``tensor``. Zero vectors are not guarded against and produce NaNs.
    """
    l2_norm = torch.linalg.norm(tensor, ord=2, dim=-1, keepdim=True)
    unit = tensor / l2_norm
    return unit
|
|
|
|
def load_instructions(
    instructions: Optional[Path],
    tasks: Optional[Sequence[str]] = None,
    variations: Optional[Sequence[int]] = None,
) -> Optional[Instructions]:
    """Load a pickled ``{task: {variation: instruction}}`` mapping, optionally filtered.

    Args:
        instructions: Pickle file to read. When ``None``, nothing is loaded
            and ``None`` is returned.
        tasks: If given, keep only tasks whose name is contained in this
            sequence.
        variations: If given, keep only variation ids contained in this
            sequence. This applies to every remaining task; a task with no
            matching variation is kept with an empty inner dict.

    Returns:
        The (filtered) nested mapping, or ``None`` if ``instructions`` is None.
    """
    if instructions is None:
        return None

    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point this at trusted, locally produced instruction files.
    with open(instructions, "rb") as handle:
        data: Instructions = pickle.load(handle)

    if tasks is not None:
        data = {name: per_var for name, per_var in data.items() if name in tasks}

    if variations is not None:
        kept = {}
        for name, per_var in data.items():
            kept[name] = {v: instr for v, instr in per_var.items() if v in variations}
        data = kept

    return data
|
|