File size: 5,384 Bytes
8652b14 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 | """
Adapted from https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/utils.py
Action format derived from VPT https://github.com/openai/Video-Pre-Training
Adapted from https://github.com/etched-ai/open-oasis/blob/master/utils.py
"""
import math
import os
from typing import Mapping, Sequence

import torch
from torch import nn
from torchvision.io import read_image, read_video
from torchvision.transforms.functional import resize
from einops import rearrange
from einops import rearrange, parse_shape
def exists(val):
    """Return ``True`` unless *val* is ``None``.

    Falsy values such as ``0``, ``""`` or an empty list still count as existing.
    """
    if val is None:
        return False
    return True
def default(val, d):
    """Return *val* when it is not ``None``; otherwise fall back to *d*.

    If *d* is callable it is called with no arguments, and only on the fallback
    path. This lets callers defer building an expensive default.
    """
    if not exists(val):
        if callable(d):
            return d()
        return d
    return val
def extract(a, t, x_shape):
    """Gather ``a[t]`` and reshape it so it broadcasts against a tensor of shape *x_shape*.

    Args:
        a: lookup table indexed by *t*. Presumably a 1-D schedule (e.g. betas);
            confirm against callers.
        t: 2-D index tensor. Its two dimensions become the leading dimensions of
            the result (named ``f``/``b`` originally, presumably frames/batch).
        x_shape: shape of the target tensor. One singleton axis is appended for
            every dimension of *x_shape* beyond the first two.

    Returns:
        ``a[t]`` reshaped to ``(t.shape[0], t.shape[1], 1, ..., 1)``.
    """
    frames, batch = t.shape
    trailing_ones = [1] * (len(x_shape) - 2)
    gathered = a[t]
    return gathered.reshape(frames, batch, *trailing_ones)
def linear_beta_schedule(timesteps):
    """
    Linear beta schedule, as proposed in the original DDPM paper.

    The reference endpoints (1e-4 .. 0.02) are defined for 1000 steps and are
    scaled by ``1000 / timesteps`` for other step counts. Returns a float64
    tensor of ``timesteps`` evenly spaced betas.
    """
    scale = 1000 / timesteps
    return torch.linspace(scale * 0.0001, scale * 0.02, timesteps, dtype=torch.float64)
def cosine_beta_schedule(timesteps, s=0.008):
    """
    Cosine beta schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ

    ``alpha_bar(t) = cos^2(((t + s) / (1 + s)) * pi / 2)`` is evaluated on
    ``timesteps + 1`` points of [0, 1] and normalized so ``alpha_bar(0) == 1``.
    Each beta is ``1 - alpha_bar[i] / alpha_bar[i - 1]``. Betas are clipped to
    [0, 0.999], so the final step (where alpha_bar is ~0) becomes 0.999.
    Returns a float64 tensor of length ``timesteps``.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64) / timesteps
    curve = torch.cos((grid + s) / (1 + s) * math.pi * 0.5) ** 2
    alpha_bar = curve / curve[0]
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
    return betas.clip(0, 0.999)
def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
    """
    Sigmoid beta schedule, proposed in https://arxiv.org/abs/2212.11972 - Figure 8.
    Described there as better for images > 64x64 when used during training.

    A sigmoid is swept from ``start`` to ``end`` (both divided by ``tau``) over
    ``timesteps + 1`` points. The sweep is flipped and rescaled into alpha_bar,
    normalized so ``alpha_bar[0] == 1``, and converted to betas clipped to
    [0, 0.999]. Returns a float64 tensor of length ``timesteps``.

    NOTE(review): ``clamp_min`` is accepted but not used by this function.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64) / timesteps
    # Endpoint sigmoid values, used to map the sweep onto [0, 1].
    v_start = torch.tensor(start / tau).sigmoid()
    v_end = torch.tensor(end / tau).sigmoid()
    sweep = ((grid * (end - start) + start) / tau).sigmoid()
    alpha_bar = (v_end - sweep) / (v_end - v_start)
    alpha_bar = alpha_bar / alpha_bar[0]
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
    return betas.clip(0, 0.999)
# Ordered action vocabulary. `one_hot_actions` writes the value for the j-th
# key into column j, so this order defines the encoded action layout.
ACTION_KEYS = (
    ["inventory", "ESC"]
    + [f"hotbar.{slot}" for slot in range(1, 10)]
    + ["forward", "back", "left", "right"]
    + ["cameraX", "cameraY"]
    + ["jump", "sneak", "sprint", "swapHands", "attack", "use", "pickItem", "drop"]
)
def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
    """Encode per-frame action dicts as a dense ``(len(actions), len(ACTION_KEYS))`` tensor.

    Despite the name, the encoding is not strictly one-hot: each column holds
    the value of the matching key in ``ACTION_KEYS``.

    Button keys are copied as-is and must lie in [0, 1]. ``cameraX`` and
    ``cameraY`` are read from ``frame["camera"][0]`` and ``frame["camera"][1]``.
    Each is mapped through ``(v - 40) / 40``, so 0 -> -1, 40 -> 0 and 80 -> +1.
    The result must lie in [-1, 1] (with 1e-3 slack). Presumably the raw camera
    values are quantization bucket indices; confirm against the data pipeline.

    Raises:
        KeyError: a frame is missing a key from ACTION_KEYS (or "camera").
        AssertionError: a value is out of range (only when asserts are enabled).
        ValueError: a "camera*" key other than cameraX/cameraY is encountered.
    """
    # Camera normalization constants (loop-invariant).
    max_val = 20
    bin_size = 0.5
    num_buckets = int(max_val / bin_size)
    camera_axis = {"cameraX": 0, "cameraY": 1}

    encoded = torch.zeros(len(actions), len(ACTION_KEYS))
    for row, frame_actions in enumerate(actions):
        for col, action_key in enumerate(ACTION_KEYS):
            if not action_key.startswith("camera"):
                value = frame_actions[action_key]
                assert 0 <= value <= 1, f"Action value must be in [0, 1] got {value}"
            else:
                if action_key not in camera_axis:
                    raise ValueError(f"Unknown camera action key: {action_key}")
                raw = frame_actions["camera"][camera_axis[action_key]]
                value = (raw - num_buckets) / num_buckets
                assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}"
            encoded[row, col] = value
    return encoded
# Lower-case file extensions (without the leading dot) recognized by load_prompt.
IMAGE_EXTENSIONS = {"jpeg", "jpg", "png"}  # still-image prompts
VIDEO_EXTENSIONS = {"mp4"}  # video prompts
def load_prompt(path, video_offset=None, n_prompt_frames=1):
    """Load an image or video file as a batched, [0, 1]-scaled prompt tensor.

    Args:
        path: str or os.PathLike. The extension, matched case-insensitively,
            must be in IMAGE_EXTENSIONS or VIDEO_EXTENSIONS.
        video_offset: index of the first video frame to keep; ignored for images.
        n_prompt_frames: number of frames to keep from the video. An image
            always yields a single frame, so despite the printed message it only
            passes the frame-count check when ``n_prompt_frames == 1``.

    Returns:
        Float tensor of shape ``(1, n_prompt_frames, C, 360, 640)``. It holds
        the decoded pixel values divided by 255. C is whatever channel count
        the decoder produced.

    Raises:
        ValueError: unrecognized file extension.
        AssertionError: fewer than ``n_prompt_frames`` frames were available
            (only when asserts are enabled).
    """
    path = os.fspath(path)
    extension = path.lower().split(".")[-1]
    if extension in IMAGE_EXTENSIONS:
        print("prompt is image; ignoring video_offset and n_prompt_frames")
        prompt = read_image(path)
        # add frame dimension
        prompt = rearrange(prompt, "c h w -> 1 c h w")
    elif extension in VIDEO_EXTENSIONS:
        # read_video defaults to (T, H, W, C). Request (T, C, H, W) so that
        # resize() below acts on the spatial axes and the "t c h w" pattern holds.
        prompt = read_video(path, pts_unit="sec", output_format="TCHW")[0]
        if video_offset is not None:
            prompt = prompt[video_offset:]
        prompt = prompt[:n_prompt_frames]
    else:
        raise ValueError(f"unrecognized prompt file extension; expected one in {IMAGE_EXTENSIONS} or {VIDEO_EXTENSIONS}")
    assert prompt.shape[0] == n_prompt_frames, f"input prompt {path} had less than n_prompt_frames={n_prompt_frames} frames"
    prompt = resize(prompt, (360, 640))
    # add batch dimension
    prompt = rearrange(prompt, "t c h w -> 1 t c h w")
    prompt = prompt.float() / 255.0
    return prompt
def load_actions(path, action_offset=None):
    """Load an action sequence and return it as a ``(1, T + 1, D)`` tensor.

    Args:
        path: str or os.PathLike.
            - ``*.actions.pt``: a serialized sequence of per-frame action dicts,
              encoded here with ``one_hot_actions``.
            - ``*.one_hot_actions.pt``: an already-encoded 2-D tensor, loaded
              with ``weights_only=True``.
        action_offset: index of the first action row to keep (optional).

    Returns:
        The action rows (after the offset) with one all-zero row prepended on
        the time axis, plus a leading batch dimension. If the offset leaves no
        rows, ``actions[:1]`` is empty and no zero row is added.

    Raises:
        ValueError: unrecognized file suffix.
    """
    path = os.fspath(path)
    if path.endswith(".actions.pt"):
        # NOTE(review): torch.load is called without weights_only here. On
        # PyTorch versions where that defaults to False, this unpickles
        # arbitrary objects, so only load trusted files.
        actions = one_hot_actions(torch.load(path))
    elif path.endswith(".one_hot_actions.pt"):
        actions = torch.load(path, weights_only=True)
    else:
        raise ValueError("unrecognized action file extension; expected '*.actions.pt' or '*.one_hot_actions.pt'")
    if action_offset is not None:
        actions = actions[action_offset:]
    # Prepend an all-zero action row. Presumably this aligns actions with frames
    # (the first/prompt frame has no preceding action); confirm with the caller.
    actions = torch.cat([torch.zeros_like(actions[:1]), actions], dim=0)
    # add batch dimension
    actions = rearrange(actions, "t d -> 1 t d")
    return actions
|