Spaces:
Runtime error
Runtime error
| """ | |
| Adapted from https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/utils.py | |
| Action format derived from VPT https://github.com/openai/Video-Pre-Training | |
| """ | |
| import math | |
| import torch | |
| from torch import nn | |
| from einops import rearrange, parse_shape | |
| from typing import Mapping, Sequence | |
| import torch | |
| from einops import rearrange | |
def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
    """
    Sigmoid noise schedule from https://arxiv.org/abs/2212.11972 (Figure 8),
    noted as better for images larger than 64x64 when used during training.

    A cumulative-alpha curve is built from a rescaled sigmoid evaluated on
    ``timesteps + 1`` evenly spaced points in [0, 1]. The curve is normalised
    so its first entry is 1, and consecutive ratios are turned into per-step
    betas.

    Args:
        timesteps: number of diffusion steps; one beta is returned per step.
        start, end: sigmoid input range that the [0, 1] grid is mapped onto.
        tau: temperature dividing the sigmoid inputs.
        clamp_min: accepted for API compatibility; not used by this function.

    Returns:
        1-D float32 tensor of length ``timesteps``, clipped to [0, 0.999].
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float32) / timesteps
    sig_lo, sig_hi = (torch.tensor(bound / tau).sigmoid() for bound in (start, end))
    scaled = (grid * (end - start) + start) / tau
    cumulative = (sig_hi - scaled.sigmoid()) / (sig_hi - sig_lo)
    cumulative = cumulative / cumulative[0]
    step_ratio = cumulative[1:] / cumulative[:-1]
    # The final cumulative value is 0, so its beta would be 1; clipping keeps it at 0.999.
    return (1 - step_ratio).clip(0, 0.999)
# Ordered action vocabulary: entry j names column j of the tensor produced by
# ``one_hot_actions``. Names follow the VPT action format, except that the 2-D
# ``camera`` action is split into separate ``cameraX`` / ``cameraY`` columns.
ACTION_KEYS = [
    "inventory",
    "ESC",
    *(f"hotbar.{slot}" for slot in range(1, 10)),
    "forward",
    "back",
    "left",
    "right",
    "cameraX",
    "cameraY",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]
def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
    """Encode a sequence of VPT-style action mappings as a dense float tensor.

    Row ``i`` holds ``actions[i]``; column ``j`` holds ``ACTION_KEYS[j]``.

    Button keys are copied as-is and must lie in [0, 1]. The ``cameraX`` and
    ``cameraY`` columns are read from ``current_actions["camera"][0]`` and
    ``[1]``, then rescaled with ``(value - 40) / 40``. Here 40 is
    ``int(20 / 0.5)``, and the result must land in [-1, 1] up to a 1e-3
    tolerance.

    Despite the name, rows are not strictly one-hot. Several buttons can be
    set in one row, and the camera columns are continuous.

    Args:
        actions: one mapping per timestep. It must provide every non-camera
            key in ``ACTION_KEYS`` plus a ``"camera"`` entry indexable at 0
            and 1. NOTE(review): the ``int`` value annotation does not cover
            ``"camera"``, which is a pair.

    Returns:
        Tensor of shape ``(len(actions), len(ACTION_KEYS))`` created by
        ``torch.zeros`` (default float dtype).

    Raises:
        ValueError: if a button value is outside [0, 1], if a rescaled camera
            value is outside [-1, 1] (with tolerance), or if ``ACTION_KEYS``
            contains a ``camera*`` key other than ``cameraX`` / ``cameraY``.
    """
    # NOTE these numbers specific to the camera quantization used in
    # https://github.com/etched-ai/dreamcraft/blob/216e952f795bb3da598639a109bcdba4d2067b69/spark/preprocess_vpt_to_videos_actions.py#L312
    # see method `compress_mouse`. They never change, so compute them once
    # instead of on every (row, column) iteration.
    max_val = 20
    bin_size = 0.5
    num_buckets = int(max_val / bin_size)
    # Which component of current_actions["camera"] feeds each camera column.
    camera_axis = {"cameraX": 0, "cameraY": 1}

    actions_one_hot = torch.zeros(len(actions), len(ACTION_KEYS))
    for i, current_actions in enumerate(actions):
        for j, action_key in enumerate(ACTION_KEYS):
            if action_key.startswith("camera"):
                if action_key not in camera_axis:
                    raise ValueError(f"Unknown camera action key: {action_key}")
                value = current_actions["camera"][camera_axis[action_key]]
                # Presumably a quantized bucket index; shifting by num_buckets
                # maps that value to 0 and the scale maps the range to [-1, 1].
                value = (value - num_buckets) / num_buckets
                # Explicit raise rather than assert so validation is not
                # stripped under `python -O`.
                if not (-1 - 1e-3 <= value <= 1 + 1e-3):
                    raise ValueError(f"Camera action value must be in [-1, 1], got {value}")
            else:
                value = current_actions[action_key]
                if not (0 <= value <= 1):
                    raise ValueError(f"Action value must be in [0, 1] got {value}")
            actions_one_hot[i, j] = value
    return actions_one_hot