import json
import pickle
from pathlib import Path
from typing import Dict, Optional, Sequence

import numpy as np
import torch


# Mapping: task name -> variation id -> instruction tensor.
Instructions = Dict[str, Dict[int, torch.Tensor]]


def round_floats(o):
    """Recursively round every float found in ``o`` to two decimals.

    Dicts are rebuilt with the same keys and rounded values; lists and
    tuples are both returned as lists of rounded items. Any other object
    (ints, strings, ...) is returned unchanged.
    """
    if isinstance(o, float):
        return round(o, 2)
    if isinstance(o, dict):
        return {key: round_floats(value) for key, value in o.items()}
    if isinstance(o, (list, tuple)):
        # Tuples intentionally come back as lists, as in the original.
        return [round_floats(item) for item in o]
    return o


def normalise_quat(x: torch.Tensor):
    """Scale ``x`` to unit L2 norm along its last dimension.

    The norm is computed as ``sqrt(sum(x ** 2))`` over ``dim=-1``. An
    all-zero vector has norm 0, so the division yields NaN for it.
    """
    norm = x.square().sum(dim=-1).sqrt().unsqueeze(-1)
    return x / norm


def get_gripper_loc_bounds(
    path: str, buffer: float = 0.0, task: Optional[str] = None
) -> np.ndarray:
    """Load gripper location bounds from a JSON file.

    The file is read as a dict mapping a task name to a
    ``[min_bounds, max_bounds]`` pair.

    Args:
        path: Path of the JSON file holding the per-task bounds.
        buffer: Margin subtracted from the minimum and added to the
            maximum bounds.
        task: If given and present in the file, only that task's bounds
            are used. Otherwise (``None`` or unknown task) the bounds
            are the union over all tasks in the file.

    Returns:
        ``np.stack([min_bounds, max_bounds])`` after applying ``buffer``.

    Side effects:
        Prints the resulting workspace size (max - min) to stdout.
    """
    # Read the file once and close it deterministically; the previous
    # implementation leaked the handle and re-read the file in the
    # fallback branch.
    with open(path, "r") as fid:
        all_bounds = json.load(fid)

    if task is not None and task in all_bounds:
        task_bounds = all_bounds[task]
        bounds_min = np.array(task_bounds[0]) - buffer
        bounds_max = np.array(task_bounds[1]) + buffer
    else:
        # Gripper workspace is the union of workspaces for all tasks
        per_task = list(all_bounds.values())
        bounds_min = np.min(np.stack([b[0] for b in per_task]), axis=0) - buffer
        bounds_max = np.max(np.stack([b[1] for b in per_task]), axis=0) + buffer

    gripper_loc_bounds = np.stack([bounds_min, bounds_max])
    print("Gripper workspace size:", bounds_max - bounds_min)
    return gripper_loc_bounds


def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def norm_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Divide ``tensor`` by its L2 norm taken along the last dimension."""
    return tensor / torch.linalg.norm(tensor, ord=2, dim=-1, keepdim=True)


def load_instructions(
    instructions: Optional[Path],
    tasks: Optional[Sequence[str]] = None,
    variations: Optional[Sequence[int]] = None,
) -> Optional[Instructions]:
    """Load pickled instructions, optionally filtered by task and variation.

    Args:
        instructions: Path of the pickle file, or ``None``.
        tasks: If given, keep only entries whose task name is in it.
        variations: If given, keep only variation ids that are in it
            (tasks left with no variation are kept as empty dicts).

    Returns:
        The (filtered) instructions mapping, or ``None`` when
        ``instructions`` is ``None``.
    """
    if instructions is None:
        return None

    # NOTE(security): pickle.load can execute arbitrary code while
    # unpickling; only point this at instruction files you trust.
    with open(instructions, "rb") as fid:
        data: Instructions = pickle.load(fid)

    if tasks is not None:
        data = {
            task: var_instr
            for task, var_instr in data.items()
            if task in tasks
        }

    if variations is not None:
        data = {
            task: {
                var: instr
                for var, instr in var_instr.items()
                if var in variations
            }
            for task, var_instr in data.items()
        }

    return data