from __future__ import annotations

import random
from typing import Literal, NamedTuple, Callable

import numpy as np
import torch
import wandb
from matplotlib import pyplot as plt
from scipy.signal import medfilt
from scipy.signal import savgol_filter, argrelextrema

import uvd.utils as U
from uvd.decomp.kernel_reg import KernelRegression


def linear_random_skip(
    cur_step: int, next_goal_step: int, ratio: float = 0.0, progress_lower: float = 0.3
) -> bool:
    """Randomly decide whether to skip ahead, based on progress toward the next goal.

    Progress is ``cur_step / next_goal_step``. Below ``progress_lower`` the
    temperature is pinned to 1.0, which a ``random.random()`` draw (always < 1)
    can never exceed, so the result is ``False``. From ``progress_lower`` on,
    the temperature ramps linearly from 0 up to ``ratio`` and the function
    returns ``True`` when ``0 < temp < random.random()``.

    NOTE(review): as written the chance of skipping is ``1 - temp`` once the
    progress passes ``progress_lower``; the previous summary described it as
    "P(skip) <= ratio". Confirm which direction is intended before changing it.

    Args:
        cur_step: Current step index.
        next_goal_step: Step index of the next goal; must be >= ``cur_step``.
        ratio: Upper end of the temperature ramp; 0 disables skipping.
        progress_lower: Progress fraction where the ramp starts; 1 disables
            skipping.

    Returns:
        ``True`` if the caller should skip, ``False`` otherwise.
    """
    if ratio == 0.0 or progress_lower == 1:
        return False
    assert cur_step <= next_goal_step, f"{cur_step} > {next_goal_step}"
    # A goal at step 0 is treated as fully reached instead of dividing by zero.
    progress = cur_step / next_goal_step if next_goal_step != 0 else 1.0
    # linearly increase until ratio
    temp = (
        1.0
        if progress < progress_lower
        else ((progress - progress_lower) / (1 - progress_lower)) * ratio
    )
    # Chained comparison: random.random() is only drawn when temp > 0.
    return 0 < temp < random.random()


class DecompMeta(NamedTuple):
    """Bookkeeping returned next to the per-frame milestone embeddings."""

    # Frame indices of the milestones, in trajectory order.
    milestone_indices: list
    # Frame indices where each milestone's change starts (only set by
    # `embed_decomp_no_robot`).
    milestone_starts: list | None = None
    # Smoothed distance curve of every iteration (only set by
    # `embedding_decomp` when `return_intermediate_curves=True`).
    iter_curves: list[np.ndarray] | None = None


def _debug_plt(
    xs: np.ndarray,
    embed_distances: np.ndarray,
    starts: np.ndarray | list,
    ends: np.ndarray | list,
    return_numpy: bool = False,
):
    """Plot the distance curve, its first two derivatives and the segment bounds.

    Args:
        xs: X coordinates of the curve.
        embed_distances: Distance-to-goal curve to plot.
        starts: X positions drawn as dashed vertical lines.
        ends: X positions drawn as dotted vertical lines.
        return_numpy: Return the figure converted by ``U.plt_to_numpy`` instead
            of showing it.

    Returns:
        The converted figure when ``return_numpy`` is set, otherwise ``None``.
    """
    fig = plt.figure()
    for start in starts:
        plt.axvline(x=start, linestyle="--")
    for end in ends:
        plt.axvline(x=end, linestyle="dotted")
    first_derivative = np.gradient(embed_distances)
    plt.plot(xs, first_derivative, linewidth=1.5, label="1st derivative")
    plt.plot(
        xs,
        np.gradient(first_derivative),
        linewidth=1.5,
        label="2nd derivative",
    )
    plt.plot(xs, embed_distances, linewidth=1.5, label="embedding distance")
    plt.legend()
    if return_numpy:
        try:
            return U.plt_to_numpy(fig)
        finally:
            # This runs once per decomposition iteration; without closing,
            # pyplot keeps every figure alive.
            plt.close(fig)
    else:
        plt.show()


def _append_thresholded_goals(
    milestone_embeddings: list,
    embeddings: np.ndarray,
    no_robot_embeddings: np.ndarray,
    unnormalized_embed_distance: np.ndarray,
    subgoal_indices: list,
    first_step: int,
    threshold_subgoal_passing: float,
) -> None:
    """Append one goal embedding per frame of the current segment, last frame first.

    The segment is ``first_step .. subgoal_indices[-1]``. A frame keeps the
    current subgoal while its distance to it, relative to the distance between
    the current and the later subgoal, is above the threshold; otherwise the
    frame is pointed at the later subgoal (``subgoal_indices[-2]``).
    """
    cur_idx = subgoal_indices[-1]
    if len(subgoal_indices) == 1:
        # 1st iteration: no later subgoal exists, every frame targets the final one.
        for _ in range(first_step, cur_idx + 1):
            milestone_embeddings.append(embeddings[cur_idx])
        return
    later_idx = subgoal_indices[-2]
    cur_milestone_dist = np.linalg.norm(
        no_robot_embeddings[cur_idx] - no_robot_embeddings[later_idx]
    )
    for step in reversed(range(first_step, cur_idx + 1)):
        if (
            unnormalized_embed_distance[step] / cur_milestone_dist
            > threshold_subgoal_passing
        ):
            # use the real subgoal
            milestone_embeddings.append(embeddings[cur_idx])
        else:
            # skip to the next subgoal if passing threshold
            milestone_embeddings.append(embeddings[later_idx])


def embed_decomp_no_robot(
    embeddings: np.ndarray | torch.Tensor,
    no_robot_embeddings: np.ndarray | torch.Tensor,
    window_length: int = 10,
    derivative_order: int = 1,
    derivative_threshold: float = 1e-3,
    threshold_subgoal_passing: float | None = None,
    force_interleave: bool = False,
    debug_plt: bool = False,
    debug_plt_to_wandb: bool = False,
    task_name: str | None = None,
    fill_embeddings: bool = True,
):
    """Decompose a trajectory from the activity of the robot-free embeddings.

    Walking backwards from the last frame, the distance of every earlier
    ``no_robot_embeddings`` frame to the current subgoal is median-filtered and
    differentiated. Runs of frames whose derivative magnitude reaches
    ``derivative_threshold`` (gaps of up to ``window_length + 1`` are bridged)
    are the "changing" segments; the end of the second-to-last run becomes the
    next subgoal.

    Trajectories with 16 frames or fewer never enter the search loop, so no
    milestone is found and ``fill_embeddings=True`` fails on the empty result.

    Args:
        embeddings: Per-frame embeddings used to build the returned goals.
        no_robot_embeddings: Per-frame embeddings used for the distance curve.
        window_length: Largest gap (plus one) still merged into one segment.
        derivative_order: 1 or 2; which derivative of the curve is thresholded.
        derivative_threshold: Minimum absolute derivative counted as change.
        threshold_subgoal_passing: In (0, 1]; when set, a frame already within
            this relative distance of its subgoal targets the following one.
        force_interleave: Log a warning for each milestone whose start is not
            before its end. NOTE(review): this only warns; nothing is filtered.
        debug_plt: Plot each iteration's curve (rank zero only).
        debug_plt_to_wandb: Collect the plots and log them to wandb.
        task_name: Used in the wandb key of the logged plots.
        fill_embeddings: Build per-frame goal embeddings; otherwise ``None``.

    Returns:
        ``(milestone_embeddings, DecompMeta)`` with ``milestone_indices`` and
        ``milestone_starts`` in trajectory order. The embeddings are converted
        back to a torch tensor when the input was one.
    """
    # Indices at or below this bound are never used as a further subgoal.
    min_subgoal_idx = 15

    debug_plt = debug_plt and U.is_rank_zero()
    debug_plt_to_wandb = debug_plt and debug_plt_to_wandb
    if threshold_subgoal_passing is not None:
        assert 0 < threshold_subgoal_passing <= 1.0, threshold_subgoal_passing

    device = embeddings.device if isinstance(embeddings, torch.Tensor) else None
    # clip based preprocessors would have bf16 by default
    embeddings = U.any_to_numpy(embeddings, dtype="float32")
    no_robot_embeddings = U.any_to_numpy(no_robot_embeddings, dtype="float32")
    # L, N (though can be rgb for debugging as well)
    traj_length = embeddings.shape[0]

    use_threshold_fill = fill_embeddings and threshold_subgoal_passing is not None
    debug_plt_figs = []
    milestone_embeddings = [] if fill_embeddings else None
    # milestone indices, i.e. end of obj changing
    subgoal_indices = []
    # indices when obj changing starts, i.e. milestone indices for hand reaching
    subgoal_starts = []

    cur_subgoal_idx = traj_length - 1
    # back to start
    while cur_subgoal_idx > min_subgoal_idx:
        unnormalized_embed_distance = np.linalg.norm(
            no_robot_embeddings[: cur_subgoal_idx + 1]
            - no_robot_embeddings[cur_subgoal_idx],
            axis=1,
        )
        cur_embed_distance = unnormalized_embed_distance / np.linalg.norm(
            no_robot_embeddings[0] - no_robot_embeddings[cur_subgoal_idx]
        )
        cur_embed_distance = medfilt(cur_embed_distance, kernel_size=None)  # smooth

        if derivative_order == 1:
            slope = np.gradient(cur_embed_distance)
        elif derivative_order == 2:
            slope = np.gradient(np.gradient(cur_embed_distance))
        else:
            raise NotImplementedError(derivative_order)

        valid_slope_indices = np.where(np.abs(slope) >= derivative_threshold)[0]
        # Find the differences between consecutive valid_slope_indices
        diffs = np.diff(valid_slope_indices)
        # Find indices where the difference is greater than window_length + 1
        break_indices = np.where(diffs > window_length + 1)[0]
        # Extract start and end positions
        start_positions = valid_slope_indices[np.concatenate([[0], break_indices + 1])]
        end_positions = valid_slope_indices[
            np.concatenate((break_indices, [len(valid_slope_indices) - 1]))
        ]
        # small tolerance
        start_positions = [max(0, int(p) - 1) for p in start_positions]
        end_positions = [min(traj_length - 1, int(p) + 1) for p in end_positions]

        if debug_plt:
            fig = _debug_plt(
                np.arange(0, len(cur_embed_distance)),
                cur_embed_distance,
                start_positions,
                end_positions,
                return_numpy=debug_plt_to_wandb,
            )
            if debug_plt_to_wandb:
                debug_plt_figs.append(fig)

        subgoal_indices.append(cur_subgoal_idx)
        subgoal_starts.append(start_positions[-1])

        # `<=` mirrors the loop guard: an end exactly at the bound would stop
        # the loop anyway, and the thresholded fill below must still cover the
        # frames from 0 up to the current subgoal.
        if len(end_positions) < 2 or end_positions[-2] <= min_subgoal_idx:
            if use_threshold_fill:
                # use the order from start to end
                _append_thresholded_goals(
                    milestone_embeddings,
                    embeddings,
                    no_robot_embeddings,
                    unnormalized_embed_distance,
                    subgoal_indices,
                    0,
                    threshold_subgoal_passing,
                )
            break

        if use_threshold_fill:
            _append_thresholded_goals(
                milestone_embeddings,
                embeddings,
                no_robot_embeddings,
                unnormalized_embed_distance,
                subgoal_indices,
                end_positions[-2] + 1,
                threshold_subgoal_passing,
            )
        cur_subgoal_idx = end_positions[-2]

    # Collected from the end of the trajectory backwards; flip to time order.
    subgoal_starts = subgoal_starts[::-1]
    subgoal_indices = subgoal_indices[::-1]

    if len(debug_plt_figs) > 0 and wandb.run is not None:
        debug_plt_figs = np.concatenate(debug_plt_figs, axis=1)
        wandb.log(
            {
                f"decomp_curves/{task_name}": wandb.Image(
                    debug_plt_figs,
                    caption=f"starts: {subgoal_starts}, ends: {subgoal_indices}",
                )
            }
        )

    assert len(subgoal_starts) == len(
        subgoal_indices
    ), f"{subgoal_starts}, {subgoal_indices}"
    if force_interleave:
        for start, end in zip(subgoal_starts, subgoal_indices):
            if start >= end:
                U.get_logger().warning(
                    f"{subgoal_starts} & {subgoal_indices} not interleaved"
                )

    if fill_embeddings:
        if threshold_subgoal_passing is not None:
            # Frames were appended last-to-first.
            milestone_embeddings = np.stack(milestone_embeddings[::-1])
        else:
            # slightly faster to do once here without threshold checking
            milestone_embeddings = np.concatenate(
                [embeddings[subgoal_indices[0], ...][None]]
                + [
                    np.full((end - start, *embeddings.shape[1:]), embeddings[end, ...])
                    for start, end in zip([0] + subgoal_indices[:-1], subgoal_indices)
                ],
            )
        if device is not None:
            milestone_embeddings = U.any_to_torch_tensor(
                milestone_embeddings, device=device
            )
    return milestone_embeddings, DecompMeta(
        milestone_indices=subgoal_indices, milestone_starts=subgoal_starts
    )


def embed_decomp_no_robot_extended(
    embeddings: np.ndarray | torch.Tensor,
    no_robot_embeddings: np.ndarray | torch.Tensor,
    threshold_subgoal_passing: float | None = None,
    **kwargs,
):
    """Decompose with both the starts and the ends of object changes as goals.

    Runs `embed_decomp_no_robot` for the indices only, merges its starts and
    ends into one sorted goal list and assigns every frame the embedding of its
    next goal. With ``threshold_subgoal_passing`` set, once a frame's distance
    to its goal (relative to the distance at the frame where that goal became
    active) drops to the threshold, the rest of the segment targets the
    following goal instead.

    Args:
        embeddings: Per-frame embeddings; also the source of goal embeddings.
        no_robot_embeddings: Forwarded to `embed_decomp_no_robot`.
        threshold_subgoal_passing: Relative distance at which a goal counts as
            passed; ``None`` disables early switching.
        **kwargs: Forwarded to `embed_decomp_no_robot` (``fill_embeddings`` is
            forced to ``False``).

    Returns:
        ``(milestone_embeddings, DecompMeta)`` where ``milestone_indices`` is
        the merged, sorted list of starts and ends.
    """
    kwargs["fill_embeddings"] = False
    _, decomp_meta = embed_decomp_no_robot(
        embeddings,
        no_robot_embeddings,
        threshold_subgoal_passing=None,
        **kwargs,
    )
    milestone_indices = decomp_meta.milestone_indices
    milestone_starts = decomp_meta.milestone_starts
    assert len(milestone_starts) == len(milestone_indices)

    norm = (
        np.linalg.norm if isinstance(embeddings[0], np.ndarray) else torch.linalg.norm
    )
    hybrid_indices = sorted(milestone_starts + milestone_indices)
    last_goal_pos = len(hybrid_indices) - 1

    milestone_embeddings = []
    prev_idx = -1
    s = -1  # frame counter across all segments
    init_dist = None
    for i, goal_idx in enumerate(hybrid_indices):
        once_passed = False
        for _ in range(goal_idx - prev_idx):
            s += 1
            if threshold_subgoal_passing is None:
                milestone_embeddings.append(embeddings[goal_idx])
                continue
            if once_passed:
                # Goal already passed earlier in this segment.
                milestone_embeddings.append(embeddings[hybrid_indices[i + 1]])
                continue

            raw_cur_dist = float(norm(embeddings[s] - embeddings[goal_idx]))
            if init_dist is None:
                assert s == 0, s
                cur_dist = 1.0
                init_dist = max(raw_cur_dist, 1e-7)
            else:
                cur_dist = raw_cur_dist / init_dist

            if cur_dist <= threshold_subgoal_passing and i < last_goal_pos:
                # Re-reference the distance scale to the following goal.
                init_dist = float(
                    norm(embeddings[s] - embeddings[hybrid_indices[i + 1]])
                )
                init_dist = max(init_dist, 1e-7)
                once_passed = True
                milestone_embeddings.append(embeddings[hybrid_indices[i + 1]])
            else:
                milestone_embeddings.append(embeddings[goal_idx])
        prev_idx = goal_idx

    milestone_embeddings = U.any_stack(milestone_embeddings)
    U.assert_(milestone_embeddings.shape, embeddings.shape)
    return milestone_embeddings, DecompMeta(milestone_indices=hybrid_indices)


def get_hybrid_milestones(
    start_embeddings: np.ndarray,  # w. robot
    end_embeddings: np.ndarray,  # w.o robot
    milestone_starts: list,
    milestone_indices: list,
) -> np.ndarray:
    """Interleave start and end milestone embeddings as start0, end0, start1, ...

    When the embedding arrays have exactly one row per milestone they are used
    as they are; otherwise they are indexed with ``milestone_starts`` and
    ``milestone_indices`` respectively.

    Args:
        start_embeddings: Source of the start milestones.
        end_embeddings: Source of the end milestones; same length as
            ``start_embeddings``.
        milestone_starts: Start index of every milestone.
        milestone_indices: End index of every milestone.

    Returns:
        Array with ``2 * len(milestone_starts)`` rows.
    """
    assert len(milestone_starts) == len(milestone_indices)
    assert len(start_embeddings) == len(end_embeddings)
    num_milestones = len(milestone_starts)
    milestone_only = num_milestones == len(start_embeddings)
    # Sized by the number of milestones, not by the source length, so that the
    # indexed branch below assigns arrays of matching shape.
    hybrid_milestones = np.empty(
        (num_milestones * 2, *start_embeddings.shape[1:]),
        dtype=start_embeddings.dtype,
    )
    hybrid_milestones[::2] = (
        start_embeddings if milestone_only else start_embeddings[milestone_starts]
    )
    hybrid_milestones[1::2] = (
        end_embeddings if milestone_only else end_embeddings[milestone_indices]
    )
    return hybrid_milestones


def embedding_decomp(
    embeddings: np.ndarray | torch.Tensor,
    normalize_curve: bool = True,
    min_interval: int = 18,
    window_length: int | None = None,
    smooth_method: Literal["kernel", "savgol"] = "kernel",
    extrema_comparator: Callable = np.greater,
    fill_embeddings: bool = True,
    return_intermediate_curves: bool = False,
    **kwargs,
) -> tuple[torch.Tensor | np.ndarray | None, DecompMeta]:
    """Decompose a trajectory from extrema of its smoothed distance-to-goal curve.

    Starting with the last frame as goal, the distance of the preceding frames
    to the goal is smoothed and the latest relative extremum that lies more
    than ``min_interval`` frames before the goal (and after frame
    ``min_interval``) becomes the next goal. This repeats until no such
    extremum exists.

    Args:
        embeddings: ``(L, N)`` per-frame embeddings.
        normalize_curve: Divide the curve by the distance of its first frame.
        min_interval: Minimum spacing between goals and from the start.
        window_length: Frames before the goal that are considered; ``None``
            uses everything from frame 0.
        smooth_method: ``"kernel"`` (``KernelRegression``), ``"savgol"``
            (``savgol_filter``) or ``None`` for no smoothing.
        extrema_comparator: Comparator passed to ``argrelextrema``.
        fill_embeddings: Build per-frame goal embeddings; otherwise ``None``.
        return_intermediate_curves: Store each smoothed curve in the meta.
        **kwargs: Overrides for the smoother's keyword arguments.

    Returns:
        ``(milestone_embeddings, DecompMeta)`` with indices in trajectory order.
    """
    if torch.is_tensor(embeddings):
        device = embeddings.device
        embeddings = U.any_to_numpy(embeddings)
    else:
        device = None
    # L, N
    assert embeddings.ndim == 2, embeddings.shape
    traj_length = embeddings.shape[0]

    def window_start(goal_idx: int) -> int:
        # First frame of the window that ends at `goal_idx`.
        return max(0, goal_idx - (window_length or goal_idx))

    cur_goal_idx = traj_length - 1
    goal_indices = [cur_goal_idx]
    cur_embeddings = embeddings[window_start(cur_goal_idx) : cur_goal_idx + 1]
    iter_curves = [] if return_intermediate_curves else None

    while cur_goal_idx > (window_length or min_interval):
        # get goal embedding
        goal_embedding = cur_embeddings[-1]
        distances = np.linalg.norm(cur_embeddings - goal_embedding, axis=1)
        if normalize_curve:
            distances = distances / np.linalg.norm(cur_embeddings[0] - goal_embedding)
        x = np.arange(window_start(cur_goal_idx), cur_goal_idx + 1)

        if smooth_method == "kernel":
            smooth_kwargs = dict(kernel="rbf", gamma=0.08)
            smooth_kwargs.update(kwargs or {})
            kr = KernelRegression(**smooth_kwargs)
            kr.fit(x.reshape(-1, 1), distances)
            distance_smoothed = kr.predict(x.reshape(-1, 1))
        elif smooth_method == "savgol":
            smooth_kwargs = dict(window_length=85, polyorder=2, mode="nearest")
            smooth_kwargs.update(kwargs or {})
            distance_smoothed = savgol_filter(distances, **smooth_kwargs)
        elif smooth_method is None:
            distance_smoothed = distances
        else:
            raise NotImplementedError(smooth_method)

        if iter_curves is not None:
            iter_curves.append(distance_smoothed)

        extrema_indices = argrelextrema(distance_smoothed, extrema_comparator)[0]
        x_extrema = x[extrema_indices]

        update_goal = False
        # Latest extremum first.
        for extremum in x_extrema[::-1]:
            if cur_goal_idx < min_interval:
                break
            if cur_goal_idx - extremum > min_interval and extremum > min_interval:
                # Plain int keeps the returned indices free of numpy scalars.
                cur_goal_idx = int(extremum)
                update_goal = True
                goal_indices.append(cur_goal_idx)
                break

        if not update_goal or cur_goal_idx < min_interval:
            break
        cur_embeddings = embeddings[window_start(cur_goal_idx) : cur_goal_idx + 1]

    goal_indices = goal_indices[::-1]
    if fill_embeddings:
        milestone_embeddings = np.concatenate(
            [embeddings[goal_indices[0], ...][None]]
            + [
                np.full((end - start, *embeddings.shape[1:]), embeddings[end, ...])
                for start, end in zip([0] + goal_indices[:-1], goal_indices)
            ],
        )
        if device is not None:
            milestone_embeddings = U.any_to_torch_tensor(
                milestone_embeddings, device=device
            )
    else:
        milestone_embeddings = None
    return milestone_embeddings, DecompMeta(
        milestone_indices=goal_indices, iter_curves=iter_curves
    )


def goal_idx_from_mask(goal_achieved_mask):
    """Return the frame indices right after each change of the mask value.

    The last found index is replaced by the final frame index.

    NOTE(review): a mask that never changes yields no indices, so the
    replacement of the last entry raises ``IndexError``; confirm that callers
    always pass a mask with at least one change.
    """
    diff = np.diff(goal_achieved_mask)
    goal_indices = np.where(diff != 0)[0] + 1
    traj_length = goal_achieved_mask.shape[0]
    goal_indices[-1] = traj_length - 1  # last
    return goal_indices.tolist()


def oracle_decomp(
    embeddings: np.ndarray | torch.Tensor | None,
    goal_achieved_mask: np.ndarray,
    random_skip_ratio: float | None = None,
    linearly_random_skip_lower: float | None = None,
    fill_embeddings: bool = True,
) -> tuple[torch.Tensor | np.ndarray | None, DecompMeta]:
    """Assign every frame its oracle subgoal embedding.

    Note: embeddings here only has the oracle subgoals, not full trajectory.
    The value of ``goal_achieved_mask`` at a frame is used as the row of
    ``embeddings`` that frame should target.

    Args:
        embeddings: One embedding per oracle subgoal; may be ``None`` when
            ``fill_embeddings`` is ``False``.
        goal_achieved_mask: Per-frame subgoal index.
        random_skip_ratio: Ratio for randomly targeting the following subgoal;
            ``None`` disables it.
        linearly_random_skip_lower: When set, skipping is decided by
            `linear_random_skip` with this as ``progress_lower``.
        fill_embeddings: Build per-frame goal embeddings; otherwise ``None``.

    Returns:
        ``(milestone_embeddings, DecompMeta)``.
    """
    goal_indices = goal_idx_from_mask(goal_achieved_mask)
    if not fill_embeddings:
        return None, DecompMeta(milestone_indices=goal_indices)

    traj_length = goal_achieved_mask.shape[0]
    assert embeddings.shape[0] < traj_length, embeddings.shape
    out_shape = (traj_length, *embeddings.shape[1:])
    if isinstance(embeddings, np.ndarray):
        milestone_embeddings = np.empty(out_shape, dtype=embeddings.dtype)
    else:
        milestone_embeddings = torch.empty(
            out_shape, dtype=embeddings.dtype, device=embeddings.device
        )
    assert len(goal_indices) == len(embeddings), goal_indices

    num_goals = embeddings.shape[0]
    last_embedding = embeddings[-1]
    for i, idx in enumerate(goal_achieved_mask):
        if idx >= num_goals:
            # If the index in the mask is greater than the highest index in the
            # embedding, just use the last row of the embedding
            milestone_embeddings[i] = last_embedding
            continue

        skip = False
        if linearly_random_skip_lower is not None:
            skip = linear_random_skip(
                cur_step=i,
                next_goal_step=goal_indices[idx],
                # A missing ratio means "never skip" rather than a TypeError
                # inside the temperature ramp.
                ratio=random_skip_ratio if random_skip_ratio is not None else 0.0,
                progress_lower=linearly_random_skip_lower,
            )
        elif random_skip_ratio is not None and 0 < random_skip_ratio < random.random():
            skip = True

        target = min(idx + 1, len(embeddings) - 1) if skip else idx
        milestone_embeddings[i] = embeddings[target]
    return milestone_embeddings, DecompMeta(milestone_indices=goal_indices)


def random_decomp(
    embeddings: np.ndarray | torch.Tensor,
    num_milestones: int | tuple[int, int],
    fill_embeddings: bool = True,
) -> tuple[torch.Tensor | np.ndarray | None, DecompMeta]:
    """Pick milestones at random frames.

    Args:
        embeddings: Per-frame embeddings.
        num_milestones: Number of milestones, or an inclusive ``(low, high)``
            range to draw the number from. ``random.sample`` raises
            ``ValueError`` if it exceeds the trajectory length.
        fill_embeddings: Build per-frame goal embeddings; otherwise ``None``.

    Returns:
        ``(milestone_embeddings, DecompMeta)`` with sorted milestone indices.
    """
    if not isinstance(num_milestones, int):
        assert len(num_milestones) == 2, num_milestones
        # by randomly sample from lower and higher bound
        num_milestones = random.randint(*num_milestones)
    traj_length = embeddings.shape[0]
    goal_indices = sorted(random.sample(range(traj_length), k=num_milestones))

    if not fill_embeddings:
        return None, DecompMeta(milestone_indices=goal_indices)

    if isinstance(embeddings, np.ndarray):
        milestone_embeddings = np.empty_like(embeddings, dtype=embeddings.dtype)
    else:
        milestone_embeddings = torch.empty_like(
            embeddings, dtype=embeddings.dtype, device=embeddings.device
        )
    segment_start = 0
    for goal_idx in goal_indices:
        milestone_embeddings[segment_start : goal_idx + 1] = embeddings[goal_idx]
        segment_start = goal_idx + 1
    if goal_indices and segment_start < traj_length:
        # Frames after the last sampled milestone keep that milestone as goal;
        # otherwise they would hold uninitialised memory from `empty_like`.
        milestone_embeddings[segment_start:] = embeddings[goal_indices[-1]]
    return milestone_embeddings, DecompMeta(milestone_indices=goal_indices)


def equally_decomp(
    embeddings: np.ndarray | torch.Tensor,
    num_milestones: int | tuple[int, int],
    fill_embeddings: bool = True,
) -> tuple[torch.Tensor | np.ndarray | None, DecompMeta]:
    """Place milestones at equally spaced frames, the last one on the final frame.

    Args:
        embeddings: Per-frame embeddings.
        num_milestones: Number of milestones, or an inclusive ``(low, high)``
            range to draw the number from.
        fill_embeddings: Build per-frame goal embeddings; otherwise ``None``.

    Returns:
        ``(milestone_embeddings, DecompMeta)``.
    """
    if not isinstance(num_milestones, int):
        assert len(num_milestones) == 2, num_milestones
        # by randomly sample from lower and higher bound
        num_milestones = random.randint(*num_milestones)
    traj_length = embeddings.shape[0]
    # num_milestones + 1 boundaries; the first one (frame 0) is not a milestone.
    indices = np.linspace(0, traj_length - 1, num_milestones + 1, dtype=int)
    goal_indices = indices[1:]

    if fill_embeddings:
        if isinstance(embeddings, np.ndarray):
            milestone_embeddings = np.empty_like(embeddings, dtype=embeddings.dtype)
        else:
            milestone_embeddings = torch.empty_like(
                embeddings, dtype=embeddings.dtype, device=embeddings.device
            )
        segment_start = 0
        for goal_idx in goal_indices:
            milestone_embeddings[segment_start : goal_idx + 1] = embeddings[goal_idx]
            segment_start = goal_idx + 1
    else:
        milestone_embeddings = None
    return milestone_embeddings, DecompMeta(milestone_indices=goal_indices.tolist())


def near_future_decomp(
    embeddings: np.ndarray | torch.Tensor, advance_steps: int, **kwargs
) -> tuple[torch.Tensor | np.ndarray | None, DecompMeta]:
    """Equally spaced milestones roughly ``advance_steps`` frames apart.

    Trajectories shorter than ``advance_steps`` still get one milestone (the
    final frame); with zero milestones `equally_decomp` would return an
    uninitialised buffer.
    """
    num_milestones = max(1, embeddings.shape[0] // advance_steps)
    return equally_decomp(embeddings, num_milestones=num_milestones, **kwargs)


def no_decomp(
    embeddings: np.ndarray | torch.Tensor, fill_embeddings: bool = True
) -> tuple[torch.Tensor | np.ndarray | None, DecompMeta]:
    """Only conditioned on final goal."""
    meta = DecompMeta(milestone_indices=[-1])
    if not fill_embeddings:
        return None, meta
    if isinstance(embeddings, np.ndarray):
        goal_embeddings = np.full(
            embeddings.shape,
            embeddings[-1, ...],
            dtype=embeddings.dtype,
        )
    else:
        goal_embeddings = embeddings[-1, ...].expand_as(embeddings).clone()
    return goal_embeddings, meta


def decomp_trajectories(
    method_name: Literal[
        "embed",
        "embed_no_robot",
        "embed_no_robot_extended",
        "oracle",
        "random",
        "equally",
        "near_future",
    ]
    | None,
    embeddings: np.ndarray | torch.Tensor,
    **kwargs,
) -> tuple[torch.Tensor | np.ndarray, DecompMeta]:
    """Dispatch to a decomposition method by name.

    Args:
        method_name: Key of ``DEFAULT_DECOMP_KWARGS``; ``None`` conditions only
            on the final frame (`no_decomp`).
        embeddings: ``(L, feature_dim)`` embeddings or raw ``(L, H, W, 3)`` rgb.
        **kwargs: Overrides for the method's default keyword arguments.

    Returns:
        Whatever the selected method returns: ``(milestone_embeddings, meta)``.

    Raises:
        NotImplementedError: If the name has defaults but no implementation.
    """
    assert embeddings.ndim == 2 or embeddings.ndim == 4, (
        f"input embedding should be either 2 dimensional, "
        f"with (L, feature_dim), or raw rgb with shape (L, H, W, 3), "
        f"but get {embeddings.shape}"
    )
    if method_name is None:
        return no_decomp(embeddings)
    assert method_name in DEFAULT_DECOMP_KWARGS, method_name
    # Merge into a fresh dict: updating the defaults in place would leak this
    # call's overrides into every later call.
    method_kwargs = {**DEFAULT_DECOMP_KWARGS[method_name], **kwargs}

    methods = {
        "embed": embedding_decomp,
        "embed_no_robot": embed_decomp_no_robot,
        "embed_no_robot_extended": embed_decomp_no_robot_extended,
        "oracle": oracle_decomp,
        "random": random_decomp,
        "equally": equally_decomp,
        "near_future": near_future_decomp,
    }
    decomp_fn = methods.get(method_name)
    if decomp_fn is None:
        raise NotImplementedError(method_name)
    return decomp_fn(embeddings, **method_kwargs)


DEFAULT_DECOMP_KWARGS = dict(
    embed=dict(
        normalize_curve=False,
        min_interval=18,
        smooth_method="kernel",
        gamma=0.08,
    ),
    embed_no_robot=dict(
        window_length=8,
        derivative_order=1,
        derivative_threshold=1e-3,
        threshold_subgoal_passing=None,
    ),
    embed_no_robot_extended=dict(
        window_length=3,
        derivative_order=1,
        derivative_threshold=1e-3,
        threshold_subgoal_passing=None,
    ),
    oracle=dict(),
    random=dict(num_milestones=(3, 6)),
    equally=dict(num_milestones=(3, 6)),
    near_future=dict(advance_steps=5),
)