UVD / uvd /decomp /decomp.py
ryanhoangt's picture
Upload folder using huggingface_hub
c456c14 verified
from __future__ import annotations
import random
from typing import Literal, NamedTuple, Callable
import numpy as np
import torch
import wandb
from matplotlib import pyplot as plt
from scipy.signal import medfilt
from scipy.signal import savgol_filter, argrelextrema
import uvd.utils as U
from uvd.decomp.kernel_reg import KernelRegression
def linear_random_skip(
    cur_step: int, next_goal_step: int, ratio: float = 0.0, progress_lower: float = 0.3
) -> bool:
    """Randomly decide whether to skip the current goal, based on progress.

    Progress is ``cur_step / next_goal_step``. While progress is below
    ``progress_lower`` the function never skips. From ``progress_lower`` on, a
    temperature grows linearly from 0 up to ``ratio`` (reached at progress 1),
    and the function returns ``True`` when ``0 < temperature < random.random()``.

    NOTE(review): with this comparison the skip probability is
    ``1 - temperature`` rather than ``temperature``; confirm that this is the
    intended direction before relying on ``ratio`` as an upper bound.

    Args:
        cur_step: Current step in the trajectory.
        next_goal_step: Step at which the next goal is reached.
        ratio: Temperature reached at full progress; ``0.0`` disables skipping.
        progress_lower: Progress below which skipping never happens; ``1``
            disables skipping.

    Returns:
        ``True`` if the caller should skip, ``False`` otherwise.

    Raises:
        AssertionError: If ``cur_step`` is past ``next_goal_step``.
    """
    if ratio == 0.0 or progress_lower == 1:
        return False
    assert cur_step <= next_goal_step, f"{cur_step} > {next_goal_step}"
    if next_goal_step == 0:
        # The goal sits at step 0, so progress would be 0 / 0; there is
        # nothing to skip past, and we avoid a ZeroDivisionError.
        return False
    progress = cur_step / next_goal_step
    if progress < progress_lower:
        # 1.0 is never below random.random() (drawn from [0, 1)), so this
        # branch can never trigger a skip.
        temp = 1.0
    else:
        # Linearly increase from 0 (at progress_lower) until ratio (at 1).
        temp = ((progress - progress_lower) / (1 - progress_lower)) * ratio
    # The chained comparison short-circuits, so no random number is drawn
    # when temp is not positive.
    return bool(0 < temp < random.random())
class DecompMeta(NamedTuple):
    """Metadata returned next to the milestone embeddings by the decomposition functions."""

    # Trajectory step indices selected as milestones (ordered from start to end).
    milestone_indices: list
    # Step indices at which each milestone's change starts; only filled in by
    # embed_decomp_no_robot, otherwise left as None.
    milestone_starts: list | None = None
    # Per-iteration smoothed distance curves; only filled in by
    # embedding_decomp when return_intermediate_curves is set.
    iter_curves: list[np.ndarray] | None = None
def _debug_plt(
    xs: np.ndarray,
    embed_distances: np.ndarray,
    starts: np.ndarray | list,
    ends: np.ndarray | list,
    return_numpy: bool = False,
):
    """Plot an embedding-distance curve with its first and second derivatives.

    Vertical dashed lines mark ``starts`` and dotted lines mark ``ends``.

    Args:
        xs: X coordinates shared by all curves.
        embed_distances: Distance curve to plot.
        starts: X positions drawn as dashed vertical lines.
        ends: X positions drawn as dotted vertical lines.
        return_numpy: If true, return ``U.plt_to_numpy(fig)`` instead of
            showing the figure.

    Returns:
        The result of ``U.plt_to_numpy`` when ``return_numpy`` is true,
        otherwise ``None`` (the figure is shown).
    """
    figure = plt.figure()
    # Vertical markers: starts first, then ends, same order as drawn before.
    for line_style, positions in (("--", starts), ("dotted", ends)):
        for position in positions:
            plt.axvline(x=position, linestyle=line_style)

    first_derivative = np.gradient(embed_distances)
    second_derivative = np.gradient(first_derivative)
    curves = (
        (first_derivative, "1st derivative"),
        (second_derivative, "2nd derivative"),
        (embed_distances, "embedding distance"),
    )
    for ys, curve_label in curves:
        plt.plot(xs, ys, linewidth=1.5, label=curve_label)
    plt.legend()

    if not return_numpy:
        plt.show()
        return None
    return U.plt_to_numpy(figure)
def embed_decomp_no_robot(
    embeddings: np.ndarray | torch.Tensor,
    no_robot_embeddings: np.ndarray | torch.Tensor,
    window_length: int = 10,
    derivative_order: int = 1,
    derivative_threshold: float = 1e-3,
    threshold_subgoal_passing: float | None = None,
    force_interleave: bool = False,
    debug_plt: bool = False,
    debug_plt_to_wandb: bool = False,
    task_name: str | None = None,
    fill_embeddings: bool = True,
):
    """Decompose a trajectory into subgoals using the ``no_robot_embeddings`` curve.

    Starting from the last frame and walking backwards, each iteration:

    1. computes the distance of every earlier ``no_robot_embeddings`` frame to
       the current subgoal frame, normalised by the distance between frame 0
       and the subgoal frame, and median-filters it;
    2. takes the first or second derivative and keeps the indices where its
       magnitude is at least ``derivative_threshold``;
    3. groups those indices into runs separated by gaps larger than
       ``window_length + 1``;
    4. records the current subgoal and the start of the last run, then moves
       the subgoal to the end of the second-to-last run.

    The search stops when there are fewer than two runs or the next subgoal
    would be before step 15.

    Args:
        embeddings: Per-step embeddings used to build the returned milestone
            embeddings (converted to float32 numpy internally).
        no_robot_embeddings: Per-step embeddings used for the distance curve.
        window_length: Gap (in steps) tolerated inside one run of large slopes.
        derivative_order: 1 or 2; which derivative of the curve is thresholded.
        derivative_threshold: Minimum absolute derivative to count as changing.
        threshold_subgoal_passing: If given (must be in ``(0, 1]``), per-step
            milestones are filled inside the loop and a step switches to the
            later subgoal once its relative distance drops to the threshold.
        force_interleave: If true, log a warning for each start that is not
            strictly before its milestone index.
        debug_plt: Plot each iteration's curve (only when ``U.is_rank_zero()``).
        debug_plt_to_wandb: Log the debug plots to wandb instead of showing.
        task_name: Used in the wandb log key.
        fill_embeddings: If false, only indices are computed and the returned
            embeddings are ``None``.

    Returns:
        ``(milestone_embeddings, DecompMeta)`` where the meta carries
        ``milestone_indices`` and ``milestone_starts`` ordered start to end.

    Raises:
        NotImplementedError: If ``derivative_order`` is not 1 or 2.
    """
    debug_plt = debug_plt and U.is_rank_zero()
    debug_plt_to_wandb = debug_plt and debug_plt_to_wandb
    if threshold_subgoal_passing is not None:
        assert 0 < threshold_subgoal_passing <= 1.0, threshold_subgoal_passing
    # Remember the torch device so the result can be moved back at the end.
    if isinstance(embeddings, torch.Tensor):
        device = embeddings.device
    else:
        device = None
    # clip based preprocessors would have bf16 by default
    embeddings = U.any_to_numpy(embeddings, dtype="float32")
    no_robot_embeddings = U.any_to_numpy(no_robot_embeddings, dtype="float32")
    # L, N (though can be rgb for debugging as well)
    traj_length = embeddings.shape[0]
    debug_plt_figs = []
    if fill_embeddings:
        milestone_embeddings = []
    else:
        milestone_embeddings = None
    # milestone indices, i.e. end of obj changing
    subgoal_indices = []
    # indices when obj changing starts, i.e. milestone indices for hand reaching
    subgoal_starts = []
    cur_subgoal_idx = traj_length - 1
    # back to start
    # NOTE(review): 15 is a hard-coded lower bound on the subgoal index; for
    # trajectories of 16 steps or fewer the loop never runs and both lists
    # stay empty, which the fill code below does not appear to handle.
    while cur_subgoal_idx > 15:
        # Distance of every frame up to the current subgoal to that subgoal.
        unnormalized_embed_distance = np.linalg.norm(
            no_robot_embeddings[: cur_subgoal_idx + 1]
            - no_robot_embeddings[cur_subgoal_idx],
            axis=1,
        )
        # Normalise by the distance between the first frame and the subgoal.
        # NOTE(review): this divides by zero if those two frames are identical.
        cur_embed_distance = unnormalized_embed_distance / np.linalg.norm(
            no_robot_embeddings[0] - no_robot_embeddings[cur_subgoal_idx]
        )
        cur_embed_distance = medfilt(cur_embed_distance, kernel_size=None)  # smooth
        if derivative_order == 1:
            slope = np.gradient(cur_embed_distance)
        elif derivative_order == 2:
            slope = np.gradient(np.gradient(cur_embed_distance))
        else:
            raise NotImplementedError(derivative_order)
        valid_slope_indices = np.where(np.abs(slope) >= derivative_threshold)[0]
        # Find the differences between consecutive valid_slope_indices
        diffs = np.diff(valid_slope_indices)
        # Find indices where the difference is greater than window_length + 1
        break_indices = np.where(diffs > window_length + 1)[0]
        # Extract start and end positions
        # NOTE(review): indexing with [0] fails if valid_slope_indices is empty.
        start_positions = valid_slope_indices[np.concatenate([[0], break_indices + 1])]
        end_positions = valid_slope_indices[
            np.concatenate((break_indices, [len(valid_slope_indices) - 1]))
        ]
        # small tolerance
        start_positions = [max(0, p - 1) for p in start_positions]
        end_positions = [min(traj_length - 1, p + 1) for p in end_positions]
        if debug_plt:
            x = np.arange(0, len(cur_embed_distance))
            fig = _debug_plt(
                x,
                cur_embed_distance,
                start_positions,
                end_positions,
                return_numpy=debug_plt_to_wandb,
            )
            if debug_plt_to_wandb:
                debug_plt_figs.append(fig)
        subgoal_indices.append(cur_subgoal_idx)
        subgoal_starts.append(start_positions[-1])
        # Termination: no earlier run to jump to, or it ends before step 15.
        if len(end_positions) < 2 or end_positions[-2] < 15:
            if threshold_subgoal_passing is not None:
                cur_milestone_dist = None
                if len(subgoal_indices) > 1:
                    # subgoal_indices is built backwards, so [-2] is the
                    # subgoal found in the previous iteration (later in time).
                    cur_milestone_dist = np.linalg.norm(
                        no_robot_embeddings[cur_subgoal_idx]  # cur
                        - no_robot_embeddings[subgoal_indices[-2]]  # prev
                    )  # use the order from start to end
                if fill_embeddings:
                    # Fill every remaining step down to step 0 (in reverse).
                    for step in reversed(range(cur_subgoal_idx + 1)):
                        if (
                            len(subgoal_indices) == 1
                            or unnormalized_embed_distance[step] / cur_milestone_dist
                            > threshold_subgoal_passing
                        ):
                            milestone_embeddings.append(embeddings[cur_subgoal_idx])
                        else:
                            milestone_embeddings.append(embeddings[subgoal_indices[-2]])
            break
        if threshold_subgoal_passing is not None:
            cur_milestone_dist = None
            if len(subgoal_indices) > 1:
                cur_milestone_dist = np.linalg.norm(
                    no_robot_embeddings[cur_subgoal_idx]  # cur
                    - no_robot_embeddings[subgoal_indices[-2]]  # prev
                )
            if fill_embeddings:
                # Fill only the steps between the next subgoal (exclusive) and
                # the current one (inclusive), in reverse order.
                for step in reversed(range(end_positions[-2] + 1, cur_subgoal_idx + 1)):
                    # not pass threshold or 1st iter with only last frame contained
                    if (
                        len(subgoal_indices) == 1
                        or unnormalized_embed_distance[step] / cur_milestone_dist
                        > threshold_subgoal_passing
                    ):
                        # use the real subgoal
                        milestone_embeddings.append(embeddings[cur_subgoal_idx])
                    else:
                        # skip to the next subgoal if passing threshold
                        milestone_embeddings.append(embeddings[subgoal_indices[-2]])
        # Continue the backward search from the end of the previous run.
        cur_subgoal_idx = end_positions[-2]
    # The lists were collected end-to-start; flip them to chronological order.
    subgoal_starts = list(reversed(subgoal_starts))
    subgoal_indices = list(reversed(subgoal_indices))
    if len(debug_plt_figs) > 0 and wandb.run is not None:
        debug_plt_figs = np.concatenate(debug_plt_figs, axis=1)
        wandb.log(
            {
                f"decomp_curves/{task_name}": wandb.Image(
                    debug_plt_figs,
                    caption=f"starts: {subgoal_starts}, ends: {subgoal_indices}",
                )
            }
        )
    assert len(subgoal_starts) == len(
        subgoal_indices
    ), f"{subgoal_starts}, {subgoal_indices}"
    if force_interleave:
        # NOTE(review): _starts/_ends are collected but never used afterwards,
        # so this branch currently only emits warnings — confirm intent.
        _starts, _ends = [], []
        for i in range(len(subgoal_starts)):
            if subgoal_starts[i] < subgoal_indices[i]:
                _starts.append(subgoal_starts[i])
                _ends.append(subgoal_indices[i])
            else:
                U.get_logger().warning(
                    f"{subgoal_starts} & {subgoal_indices} not interleaved"
                )
    if fill_embeddings:
        if threshold_subgoal_passing is not None:
            # Per-step entries were appended in reverse time order in the loop.
            milestone_embeddings = np.stack(list(reversed(milestone_embeddings)))
        else:
            # slightly faster to do once here without threshold checking
            # One row for step 0, then (end - start) copies of each milestone.
            milestone_embeddings = np.concatenate(
                [embeddings[subgoal_indices[0], ...][None]]
                + [
                    np.full((end - start, *embeddings.shape[1:]), embeddings[end, ...])
                    for start, end in zip([0] + subgoal_indices[:-1], subgoal_indices)
                ],
            )
        if device is not None:
            milestone_embeddings = U.any_to_torch_tensor(
                milestone_embeddings, device=device
            )
    return milestone_embeddings, DecompMeta(
        milestone_indices=subgoal_indices, milestone_starts=subgoal_starts
    )
def embed_decomp_no_robot_extended(
    embeddings: np.ndarray | torch.Tensor,
    no_robot_embeddings: np.ndarray | torch.Tensor,
    threshold_subgoal_passing: float | None = None,
    **kwargs,
):
    """Decompose using both the starts and the ends found by ``embed_decomp_no_robot``.

    ``embed_decomp_no_robot`` is called for indices only (no filling, no
    threshold). Its milestone starts and milestone indices are merged and
    sorted into one list of hybrid goals; every step up to and including a
    hybrid goal is then assigned an embedding.

    Without ``threshold_subgoal_passing`` each step gets the embedding of its
    own hybrid goal. With a threshold, a step's distance to its goal is divided
    by a reference distance; once that ratio is at or below the threshold (and
    a following goal exists) the rest of the segment targets the following
    goal, and the reference distance is reset to the distance from the
    switching step to that following goal.

    Args:
        embeddings: Per-step embeddings used for distances and as the output.
        no_robot_embeddings: Forwarded to ``embed_decomp_no_robot``.
        threshold_subgoal_passing: Optional relative-distance threshold.
        **kwargs: Forwarded to ``embed_decomp_no_robot``; ``fill_embeddings``
            is always overridden to ``False``.

    Returns:
        ``(milestone_embeddings, DecompMeta)`` with the hybrid indices as
        ``milestone_indices``.
    """
    index_only_kwargs = {**kwargs, "fill_embeddings": False}
    _, decomp_meta = embed_decomp_no_robot(
        embeddings,
        no_robot_embeddings,
        threshold_subgoal_passing=None,
        **index_only_kwargs,
    )
    milestone_indices = decomp_meta.milestone_indices
    milestone_starts = decomp_meta.milestone_starts
    if isinstance(embeddings[0], np.ndarray):
        norm = np.linalg.norm
    else:
        norm = torch.linalg.norm
    assert len(milestone_starts) == len(milestone_indices)

    hybrid_indices = sorted(milestone_starts + milestone_indices)
    last_goal_pos = len(hybrid_indices) - 1
    per_step = []
    segment_begin = 0
    reference_dist = None
    for pos, goal_idx in enumerate(hybrid_indices):
        switched = False
        # Steps of this segment: right after the previous goal up to this goal.
        for step in range(segment_begin, goal_idx + 1):
            if threshold_subgoal_passing is None:
                target = goal_idx
            elif switched:
                target = hybrid_indices[pos + 1]
            else:
                raw_dist = float(norm(embeddings[step] - embeddings[goal_idx]))
                if reference_dist is None:
                    # Only the very first step may initialise the reference.
                    assert step == 0, step
                    relative_dist = 1.0
                    reference_dist = max(raw_dist, 1e-7)
                else:
                    relative_dist = raw_dist / reference_dist
                if relative_dist <= threshold_subgoal_passing and pos < last_goal_pos:
                    following_goal = hybrid_indices[pos + 1]
                    reference_dist = max(
                        float(norm(embeddings[step] - embeddings[following_goal])),
                        1e-7,
                    )
                    switched = True
                    target = following_goal
                else:
                    target = goal_idx
            per_step.append(embeddings[target])
        segment_begin = goal_idx + 1

    milestone_embeddings = U.any_stack(per_step)
    U.assert_(milestone_embeddings.shape, embeddings.shape)
    return milestone_embeddings, DecompMeta(milestone_indices=hybrid_indices)
def get_hybrid_milestones(
    start_embeddings: np.ndarray,  # w. robot
    end_embeddings: np.ndarray,  # w.o robot
    milestone_starts: list,
    milestone_indices: list,
) -> np.ndarray:
    """Interleave milestone start and end embeddings into one array.

    Row ``2 * k`` of the result is the start embedding of milestone ``k`` and
    row ``2 * k + 1`` is its end embedding.

    If the embedding arrays already have one row per milestone they are used
    as-is; otherwise they are treated as full trajectories and indexed with
    ``milestone_starts`` / ``milestone_indices``.

    Args:
        start_embeddings: Embeddings to take the milestone starts from.
        end_embeddings: Embeddings to take the milestone ends from; same
            length as ``start_embeddings``.
        milestone_starts: Step index of each milestone's start.
        milestone_indices: Step index of each milestone's end.

    Returns:
        Array of shape ``(2 * num_milestones, *start_embeddings.shape[1:])``
        with the dtype of ``start_embeddings``.

    Raises:
        AssertionError: If the two index lists or the two embedding arrays
            differ in length.
    """
    assert len(milestone_starts) == len(milestone_indices)
    assert len(start_embeddings) == len(end_embeddings)
    num_milestones = len(milestone_starts)
    milestone_only = num_milestones == len(start_embeddings)
    if milestone_only:
        starts, ends = start_embeddings, end_embeddings
    else:
        starts = start_embeddings[milestone_starts]
        ends = end_embeddings[milestone_indices]
    # Size the output by the number of milestones, not by the trajectory
    # length, so that the indexed (non milestone-only) case fits exactly.
    hybrid_milestones = np.empty(
        (num_milestones * 2, *start_embeddings.shape[1:]),
        dtype=start_embeddings.dtype,
    )
    hybrid_milestones[::2] = starts
    hybrid_milestones[1::2] = ends
    return hybrid_milestones
def embedding_decomp(
    embeddings: np.ndarray | torch.Tensor,
    normalize_curve: bool = True,
    min_interval: int = 18,
    window_length: int | None = None,
    smooth_method: Literal["kernel", "savgol"] = "kernel",
    extrema_comparator: Callable = np.greater,
    fill_embeddings: bool = True,
    return_intermediate_curves: bool = False,
    **kwargs,
) -> tuple[torch.Tensor | np.ndarray, DecompMeta]:
    """Find milestones by walking backwards along the embedding-distance curve.

    Starting with the last frame as the goal, each iteration computes the
    distance of the frames in the current window to the goal, smooths the
    curve, and looks for relative extrema (``extrema_comparator``). The latest
    extremum that is more than ``min_interval`` steps before the goal and more
    than ``min_interval`` steps from the start becomes the next goal. The
    search ends when no such extremum exists.

    Args:
        embeddings: ``(L, N)`` embeddings; a tensor is converted to numpy.
        normalize_curve: Divide distances by the distance between the first
            frame of the window and the goal.
        min_interval: Minimum spacing between milestones and from step 0.
        window_length: Optional number of frames to look back from the goal;
            the whole prefix is used when falsy.
        smooth_method: ``"kernel"``, ``"savgol"`` or ``None`` (no smoothing).
        extrema_comparator: Comparator handed to ``argrelextrema``.
        fill_embeddings: If false, return ``None`` instead of embeddings.
        return_intermediate_curves: Collect every smoothed curve in the meta.
        **kwargs: Overrides for the smoother's keyword arguments.

    Returns:
        ``(milestone_embeddings, DecompMeta)``; the embeddings are moved back
        to the input tensor's device when a tensor was given.

    Raises:
        NotImplementedError: For an unknown ``smooth_method``.
    """
    device = None
    if torch.is_tensor(embeddings):
        device = embeddings.device
        embeddings = U.any_to_numpy(embeddings)
    # L, N
    assert embeddings.ndim == 2, embeddings.shape

    def _window_start(goal_idx):
        # First frame of the look-back window ending at goal_idx.
        return max(0, goal_idx - (window_length or goal_idx))

    loop_floor = window_length or min_interval
    cur_goal_idx = embeddings.shape[0] - 1
    goal_indices = [cur_goal_idx]
    iter_curves = [] if return_intermediate_curves else None

    while cur_goal_idx > loop_floor:
        window_begin = _window_start(cur_goal_idx)
        window = embeddings[window_begin : cur_goal_idx + 1]
        goal_embedding = window[-1]
        distances = np.linalg.norm(window - goal_embedding, axis=1)
        if normalize_curve:
            distances = distances / np.linalg.norm(window[0] - goal_embedding)
        x = np.arange(window_begin, cur_goal_idx + 1)

        if smooth_method == "kernel":
            smooth_kwargs = {"kernel": "rbf", "gamma": 0.08, **kwargs}
            regressor = KernelRegression(**smooth_kwargs)
            features = x.reshape(-1, 1)
            regressor.fit(features, distances)
            distance_smoothed = regressor.predict(features)
        elif smooth_method == "savgol":
            smooth_kwargs = {
                "window_length": 85,
                "polyorder": 2,
                "mode": "nearest",
                **kwargs,
            }
            distance_smoothed = savgol_filter(distances, **smooth_kwargs)
        elif smooth_method is None:
            distance_smoothed = distances
        else:
            raise NotImplementedError(smooth_method)
        if iter_curves is not None:
            iter_curves.append(distance_smoothed)

        x_extrema = x[argrelextrema(distance_smoothed, extrema_comparator)[0]]
        # Scan extrema from latest to earliest for the first admissible one.
        next_goal = None
        if not cur_goal_idx < min_interval:
            for candidate in x_extrema[::-1]:
                if cur_goal_idx - candidate > min_interval and candidate > min_interval:
                    next_goal = candidate
                    break
        if next_goal is None:
            break
        cur_goal_idx = next_goal
        goal_indices.append(cur_goal_idx)
        if cur_goal_idx < min_interval:
            break

    goal_indices = goal_indices[::-1]
    if not fill_embeddings:
        milestone_embeddings = None
    else:
        feature_shape = embeddings.shape[1:]
        # One row for step 0, then each milestone repeated over its segment.
        pieces = [embeddings[goal_indices[0], ...][None]]
        for seg_start, seg_end in zip([0] + goal_indices[:-1], goal_indices):
            pieces.append(
                np.full((seg_end - seg_start, *feature_shape), embeddings[seg_end, ...])
            )
        milestone_embeddings = np.concatenate(pieces)
        if device is not None:
            milestone_embeddings = U.any_to_torch_tensor(
                milestone_embeddings, device=device
            )
    return milestone_embeddings, DecompMeta(
        milestone_indices=goal_indices, iter_curves=iter_curves
    )
def goal_idx_from_mask(goal_achieved_mask):
    """Return the step indices at which the mask value changes.

    Each returned index is the first step after a change in
    ``goal_achieved_mask``; the last entry is replaced by the final step
    index of the trajectory.

    Args:
        goal_achieved_mask: Array with one entry per trajectory step.

    Returns:
        List of step indices as plain Python ints.

    Raises:
        IndexError: If the mask never changes (there is no entry to replace).
    """
    change_points = np.nonzero(np.diff(goal_achieved_mask) != 0)[0] + 1
    goal_indices = change_points.tolist()
    # The last change is pinned to the final step of the trajectory.
    goal_indices[-1] = goal_achieved_mask.shape[0] - 1
    return goal_indices
def oracle_decomp(
    embeddings: np.ndarray | torch.Tensor | None,
    goal_achieved_mask: np.ndarray,
    random_skip_ratio: float | None = None,
    linearly_random_skip_lower: float | None = None,
    fill_embeddings: bool = True,
) -> tuple[torch.Tensor | np.ndarray, DecompMeta]:
    """Expand oracle subgoal embeddings to one embedding per trajectory step.

    Note: ``embeddings`` here only has the oracle subgoals, not the full
    trajectory. Each entry of ``goal_achieved_mask`` is used as an index into
    ``embeddings``; entries beyond the last subgoal map to the last subgoal.
    Optionally a step is randomly advanced to the following subgoal.

    Args:
        embeddings: One embedding per oracle subgoal (may be ``None`` only
            when ``fill_embeddings`` is false).
        goal_achieved_mask: One subgoal index per trajectory step.
        random_skip_ratio: Ratio used for random skipping, if any.
        linearly_random_skip_lower: If given, skipping is decided by
            ``linear_random_skip`` with this as ``progress_lower``.
        fill_embeddings: If false, only the goal indices are returned.

    Returns:
        ``(milestone_embeddings, DecompMeta)``; embeddings are ``None`` when
        ``fill_embeddings`` is false.
    """
    goal_indices = goal_idx_from_mask(goal_achieved_mask)
    if not fill_embeddings:
        return None, DecompMeta(milestone_indices=goal_indices)

    traj_length = goal_achieved_mask.shape[0]
    num_subgoals = embeddings.shape[0]
    assert num_subgoals < traj_length, embeddings.shape

    out_shape = (traj_length, *embeddings.shape[1:])
    if isinstance(embeddings, np.ndarray):
        milestone_embeddings = np.empty(out_shape, dtype=embeddings.dtype)
    else:
        milestone_embeddings = torch.empty(
            out_shape, dtype=embeddings.dtype, device=embeddings.device
        )
    assert len(goal_indices) == len(embeddings), goal_indices

    final_subgoal = embeddings[-1]
    for step, subgoal_id in enumerate(goal_achieved_mask):
        if subgoal_id >= num_subgoals:
            # If the index in the mask is greater than the highest index in
            # the embedding, just use the last row of the embedding.
            milestone_embeddings[step] = final_subgoal
            continue
        if linearly_random_skip_lower is not None:
            skip = linear_random_skip(
                cur_step=step,
                next_goal_step=goal_indices[subgoal_id],
                ratio=random_skip_ratio,
                progress_lower=linearly_random_skip_lower,
            )
        else:
            # The chained comparison only draws a random number when the
            # ratio is positive.
            skip = (
                random_skip_ratio is not None
                and 0 < random_skip_ratio < random.random()
            )
        if skip:
            # Jump ahead to the following subgoal, clamped to the last one.
            target = min(subgoal_id + 1, len(embeddings) - 1)
        else:
            target = subgoal_id
        milestone_embeddings[step] = embeddings[target]
    return milestone_embeddings, DecompMeta(milestone_indices=goal_indices)
def random_decomp(
    embeddings: np.ndarray | torch.Tensor,
    num_milestones: int | tuple[int, int],
    fill_embeddings: bool = True,
) -> tuple[torch.Tensor | np.ndarray, DecompMeta]:
    """Pick milestones at random steps of the trajectory.

    ``num_milestones`` distinct step indices are sampled and sorted. The
    largest one is then replaced by the final step, so that every step of the
    trajectory is covered by a milestone (previously the steps after the
    largest sampled index were left as uninitialised memory). If no milestone
    was requested, the final step is used as the only milestone.

    Args:
        embeddings: Per-step embeddings with the trajectory on axis 0.
        num_milestones: Number of milestones, or an inclusive ``(low, high)``
            range to draw the number from.
        fill_embeddings: If false, only the indices are returned.

    Returns:
        ``(milestone_embeddings, DecompMeta)`` where each step holds the
        embedding of the next milestone at or after it; embeddings are
        ``None`` when ``fill_embeddings`` is false.

    Raises:
        AssertionError: If a non-int ``num_milestones`` is not a pair.
        ValueError: If more milestones are requested than there are steps.
    """
    if not isinstance(num_milestones, int):
        assert len(num_milestones) == 2, num_milestones
        # by randomly sample from lower and higher bound
        num_milestones = random.randint(*num_milestones)
    traj_length = embeddings.shape[0]
    goal_indices = sorted(random.sample(range(traj_length), k=num_milestones))
    # Always finish on the final frame so the tail of the trajectory is filled.
    last_step = traj_length - 1
    if goal_indices:
        goal_indices[-1] = last_step
    else:
        goal_indices = [last_step]

    if not fill_embeddings:
        return None, DecompMeta(milestone_indices=goal_indices)

    if isinstance(embeddings, np.ndarray):
        milestone_embeddings = np.empty_like(embeddings, dtype=embeddings.dtype)
    else:
        milestone_embeddings = torch.empty_like(
            embeddings, dtype=embeddings.dtype, device=embeddings.device
        )
    segment_begin = 0
    for goal_idx in goal_indices:
        # Every step of the segment is conditioned on the segment's last frame.
        milestone_embeddings[segment_begin : goal_idx + 1] = embeddings[goal_idx]
        segment_begin = goal_idx + 1
    return milestone_embeddings, DecompMeta(milestone_indices=goal_indices)
def equally_decomp(
    embeddings: np.ndarray | torch.Tensor,
    num_milestones: int | tuple[int, int],
    fill_embeddings: bool = True,
) -> tuple[torch.Tensor | np.ndarray, DecompMeta]:
    """Place milestones at evenly spaced steps of the trajectory.

    ``num_milestones + 1`` integer positions are spread between step 0 and the
    final step; all but the first are the milestones. Each step is assigned
    the embedding of the next milestone at or after it.

    Args:
        embeddings: Per-step embeddings with the trajectory on axis 0.
        num_milestones: Number of milestones, or an inclusive ``(low, high)``
            range to draw the number from.
        fill_embeddings: If false, only the indices are returned.

    Returns:
        ``(milestone_embeddings, DecompMeta)``; embeddings are ``None`` when
        ``fill_embeddings`` is false.
    """
    if not isinstance(num_milestones, int):
        assert len(num_milestones) == 2, num_milestones
        # by randomly sample from lower and higher bound
        num_milestones = random.randint(*num_milestones)
    traj_length = embeddings.shape[0]
    boundaries = np.linspace(0, traj_length - 1, num_milestones + 1, dtype=int)
    goal_steps = boundaries[1:]

    milestone_embeddings = None
    if fill_embeddings:
        if isinstance(embeddings, np.ndarray):
            milestone_embeddings = np.empty_like(embeddings, dtype=embeddings.dtype)
        else:
            milestone_embeddings = torch.empty_like(
                embeddings, dtype=embeddings.dtype, device=embeddings.device
            )
        segment_begin = 0
        for goal_step in goal_steps:
            # Each segment runs from just after the previous milestone up to
            # and including the current one.
            milestone_embeddings[segment_begin : goal_step + 1] = embeddings[goal_step]
            segment_begin = goal_step + 1
    return milestone_embeddings, DecompMeta(milestone_indices=goal_steps.tolist())
def near_future_decomp(
    embeddings: np.ndarray | torch.Tensor, advance_steps: int, **kwargs
) -> tuple[torch.Tensor | np.ndarray, DecompMeta]:
    """Place evenly spaced milestones roughly ``advance_steps`` apart.

    The number of milestones is ``len(embeddings) // advance_steps``, but at
    least one: for a trajectory shorter than ``advance_steps`` the integer
    division gives 0, which would make ``equally_decomp`` return an
    uninitialised buffer with no milestone at all.

    Args:
        embeddings: Per-step embeddings with the trajectory on axis 0.
        advance_steps: Approximate number of steps between milestones; must
            be positive.
        **kwargs: Forwarded to ``equally_decomp``.

    Returns:
        The result of ``equally_decomp``.

    Raises:
        ValueError: If ``advance_steps`` is not positive.
    """
    if advance_steps <= 0:
        raise ValueError(f"advance_steps must be positive, got {advance_steps}")
    num_milestones = max(1, embeddings.shape[0] // advance_steps)
    return equally_decomp(embeddings, num_milestones=num_milestones, **kwargs)
def no_decomp(
    embeddings: np.ndarray | torch.Tensor, fill_embeddings: bool = True
) -> tuple[torch.Tensor | np.ndarray, DecompMeta]:
    """Only conditioned on final goal.

    Every step is assigned the embedding of the last frame.

    Args:
        embeddings: Per-step embeddings with the trajectory on axis 0.
        fill_embeddings: If false, return ``None`` instead of embeddings.

    Returns:
        ``(milestone_embeddings, DecompMeta)`` with ``[-1]`` as the only
        milestone index.
    """
    if not fill_embeddings:
        return None, DecompMeta(milestone_indices=[-1])

    final_frame = embeddings[-1, ...]
    if isinstance(embeddings, np.ndarray):
        repeated = np.full(embeddings.shape, final_frame, dtype=embeddings.dtype)
    else:
        # expand_as yields a broadcast view; clone makes it a real copy.
        repeated = final_frame.expand_as(embeddings).clone()
    return repeated, DecompMeta(milestone_indices=[-1])
def decomp_trajectories(
    method_name: Literal[
        "embed",
        "embed_no_robot",
        "embed_no_robot_extended",
        "oracle",
        "random",
        "equally",
        "near_future",
    ]
    | None,
    embeddings: np.ndarray | torch.Tensor,
    **kwargs,
) -> tuple[torch.Tensor | np.ndarray, DecompMeta]:
    """Dispatch to the decomposition method selected by ``method_name``.

    The method's defaults from ``DEFAULT_DECOMP_KWARGS`` are combined with the
    caller's ``kwargs`` (caller values win). The combination is done on a
    fresh dict so that the shared defaults are not modified by a call and
    arguments of one call cannot leak into the next.

    Args:
        method_name: Name of the decomposition method, or ``None`` to
            condition on the final goal only (``kwargs`` are ignored then).
        embeddings: ``(L, feature_dim)`` embeddings or raw rgb frames of
            shape ``(L, H, W, 3)``.
        **kwargs: Extra keyword arguments for the selected method.

    Returns:
        ``(milestone_embeddings, DecompMeta)`` as returned by the method.

    Raises:
        AssertionError: If ``embeddings`` is neither 2- nor 4-dimensional, or
            ``method_name`` has no entry in ``DEFAULT_DECOMP_KWARGS``.
        NotImplementedError: If ``method_name`` has defaults but no dispatch.
    """
    assert embeddings.ndim == 2 or embeddings.ndim == 4, (
        f"input embedding should be either 2 dimensional, "
        f"with (L, feature_dim), or raw rgb with shape (L, H, W, 3), "
        f"but get {embeddings.shape}"
    )
    if method_name is None:
        return no_decomp(embeddings)
    assert method_name in DEFAULT_DECOMP_KWARGS, method_name
    # Merge into a new dict; never update the module-level defaults in place.
    method_kwargs = {**DEFAULT_DECOMP_KWARGS[method_name], **kwargs}
    if method_name == "embed":
        return embedding_decomp(embeddings=embeddings, **method_kwargs)
    elif method_name == "embed_no_robot":
        return embed_decomp_no_robot(embeddings=embeddings, **method_kwargs)
    elif method_name == "embed_no_robot_extended":
        return embed_decomp_no_robot_extended(embeddings=embeddings, **method_kwargs)
    elif method_name == "oracle":
        return oracle_decomp(embeddings, **method_kwargs)
    elif method_name == "random":
        return random_decomp(embeddings, **method_kwargs)
    elif method_name == "equally":
        return equally_decomp(embeddings, **method_kwargs)
    elif method_name == "near_future":
        return near_future_decomp(embeddings, **method_kwargs)
    raise NotImplementedError(method_name)
# Default keyword arguments for each decomposition method, keyed by the
# method name that decomp_trajectories dispatches on.
DEFAULT_DECOMP_KWARGS = {
    "embed": {
        "normalize_curve": False,
        "min_interval": 18,
        "smooth_method": "kernel",
        "gamma": 0.08,
    },
    "embed_no_robot": {
        "window_length": 8,
        "derivative_order": 1,
        "derivative_threshold": 1e-3,
        "threshold_subgoal_passing": None,
    },
    "embed_no_robot_extended": {
        "window_length": 3,
        "derivative_order": 1,
        "derivative_threshold": 1e-3,
        "threshold_subgoal_passing": None,
    },
    "oracle": {},
    "random": {"num_milestones": (3, 6)},
    "equally": {"num_milestones": (3, 6)},
    "near_future": {"advance_steps": 5},
}