Spaces:
Runtime error
Runtime error
Upload 21 files
Browse files- .gitattributes +5 -0
- open_oasis_master/.gitattributes +1 -0
- open_oasis_master/LICENSE +21 -0
- open_oasis_master/README.md +37 -0
- open_oasis_master/attention.py +137 -0
- open_oasis_master/dit.py +310 -0
- open_oasis_master/embeddings.py +103 -0
- open_oasis_master/generate.py +119 -0
- open_oasis_master/media/arch.png +0 -0
- open_oasis_master/media/sample_0.gif +3 -0
- open_oasis_master/media/sample_1.gif +3 -0
- open_oasis_master/media/thumb.png +0 -0
- open_oasis_master/requirements.txt +31 -0
- open_oasis_master/rotary_embedding_torch.py +316 -0
- open_oasis_master/sample_data/Player729-f153ac423f61-20210806-224813.chunk_000.actions.pt +3 -0
- open_oasis_master/sample_data/Player729-f153ac423f61-20210806-224813.chunk_000.mp4 +3 -0
- open_oasis_master/sample_data/snippy-chartreuse-mastiff-f79998db196d-20220401-224517.chunk_001.actions.pt +3 -0
- open_oasis_master/sample_data/snippy-chartreuse-mastiff-f79998db196d-20220401-224517.chunk_001.mp4 +3 -0
- open_oasis_master/sample_data/treechop-f153ac423f61-20210916-183423.chunk_000.actions.pt +3 -0
- open_oasis_master/sample_data/treechop-f153ac423f61-20210916-183423.chunk_000.mp4 +3 -0
- open_oasis_master/utils.py +82 -0
- open_oasis_master/vae.py +381 -0
.gitattributes
CHANGED
|
@@ -38,3 +38,8 @@ open-oasis-master/media/sample_1.gif filter=lfs diff=lfs merge=lfs -text
|
|
| 38 |
open-oasis-master/sample_data/Player729-f153ac423f61-20210806-224813.chunk_000.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
open-oasis-master/sample_data/snippy-chartreuse-mastiff-f79998db196d-20220401-224517.chunk_001.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 40 |
open-oasis-master/sample_data/treechop-f153ac423f61-20210916-183423.chunk_000.mp4 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
open-oasis-master/sample_data/Player729-f153ac423f61-20210806-224813.chunk_000.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
open-oasis-master/sample_data/snippy-chartreuse-mastiff-f79998db196d-20220401-224517.chunk_001.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 40 |
open-oasis-master/sample_data/treechop-f153ac423f61-20210916-183423.chunk_000.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
open_oasis_master/media/sample_0.gif filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
open_oasis_master/media/sample_1.gif filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
open_oasis_master/sample_data/Player729-f153ac423f61-20210806-224813.chunk_000.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
open_oasis_master/sample_data/snippy-chartreuse-mastiff-f79998db196d-20220401-224517.chunk_001.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
open_oasis_master/sample_data/treechop-f153ac423f61-20210916-183423.chunk_000.mp4 filter=lfs diff=lfs merge=lfs -text
|
open_oasis_master/.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
video.mp4 filter=lfs diff=lfs merge=lfs -text
|
open_oasis_master/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Etched & Decart
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
open_oasis_master/README.md
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Oasis 500M
|
| 2 |
+
|
| 3 |
+

|
| 4 |
+
|
| 5 |
+

|
| 6 |
+
|
| 7 |
+
Oasis is an interactive world model developed by [Decart](https://www.decart.ai/) and [Etched](https://www.etched.com/). Based on diffusion transformers, Oasis takes in user keyboard input and generates gameplay in an autoregressive manner. We release the weights for Oasis 500M, a downscaled version of the model, along with inference code for action-conditional frame generation.
|
| 8 |
+
|
| 9 |
+
For more details, see our [joint blog post](https://oasis-model.github.io/) to learn more.
|
| 10 |
+
|
| 11 |
+
And to use the most powerful version of the model, be sure to check out the [live demo](https://oasis.us.decart.ai/) as well!
|
| 12 |
+
|
| 13 |
+
## Setup
|
| 14 |
+
```
|
| 15 |
+
git clone https://github.com/etched-ai/open-oasis.git
|
| 16 |
+
cd open-oasis
|
| 17 |
+
pip install -r requirements.txt
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
## Download the model weights
|
| 21 |
+
```
|
| 22 |
+
huggingface-cli login
|
| 23 |
+
huggingface-cli download Etched/oasis-500m oasis500m.pt # DiT checkpoint
|
| 24 |
+
huggingface-cli download Etched/oasis-500m vit-l-20.pt # ViT VAE checkpoint
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
## Basic Usage
|
| 28 |
+
We include a basic inference script that loads a prompt frame from a video and generates additional frames conditioned on actions.
|
| 29 |
+
```
|
| 30 |
+
python generate.py
|
| 31 |
+
```
|
| 32 |
+
The resulting video will be saved to `video.mp4`. Here's are some examples of a generation from this 500M model!
|
| 33 |
+
|
| 34 |
+

|
| 35 |
+

|
| 36 |
+
|
| 37 |
+
> Hint: try swapping out the `.mp4` input file in the script to try different environments!
|
open_oasis_master/attention.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Based on https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/attention.py
|
| 3 |
+
"""
|
| 4 |
+
from typing import Optional
|
| 5 |
+
from collections import namedtuple
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
from torch.nn import functional as F
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
|
| 11 |
+
from embeddings import TimestepEmbedding, Timesteps, Positions2d
|
| 12 |
+
|
| 13 |
+
class TemporalAxialAttention(nn.Module):
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
dim: int,
|
| 17 |
+
heads: int = 4,
|
| 18 |
+
dim_head: int = 32,
|
| 19 |
+
is_causal: bool = True,
|
| 20 |
+
rotary_emb: Optional[RotaryEmbedding] = None,
|
| 21 |
+
):
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.inner_dim = dim_head * heads
|
| 24 |
+
self.heads = heads
|
| 25 |
+
self.head_dim = dim_head
|
| 26 |
+
self.inner_dim = dim_head * heads
|
| 27 |
+
self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
|
| 28 |
+
self.to_out = nn.Linear(self.inner_dim, dim)
|
| 29 |
+
|
| 30 |
+
self.rotary_emb = rotary_emb
|
| 31 |
+
self.time_pos_embedding = (
|
| 32 |
+
nn.Sequential(
|
| 33 |
+
Timesteps(dim),
|
| 34 |
+
TimestepEmbedding(in_channels=dim, time_embed_dim=dim * 4, out_dim=dim),
|
| 35 |
+
)
|
| 36 |
+
if rotary_emb is None
|
| 37 |
+
else None
|
| 38 |
+
)
|
| 39 |
+
self.is_causal = is_causal
|
| 40 |
+
|
| 41 |
+
def forward(self, x: torch.Tensor):
|
| 42 |
+
B, T, H, W, D = x.shape
|
| 43 |
+
|
| 44 |
+
if self.time_pos_embedding is not None:
|
| 45 |
+
time_emb = self.time_pos_embedding(
|
| 46 |
+
torch.arange(T, device=x.device)
|
| 47 |
+
)
|
| 48 |
+
x = x + rearrange(time_emb, "t d -> 1 t 1 1 d")
|
| 49 |
+
|
| 50 |
+
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
| 51 |
+
|
| 52 |
+
q = rearrange(q, "B T H W (h d) -> (B H W) h T d", h=self.heads)
|
| 53 |
+
k = rearrange(k, "B T H W (h d) -> (B H W) h T d", h=self.heads)
|
| 54 |
+
v = rearrange(v, "B T H W (h d) -> (B H W) h T d", h=self.heads)
|
| 55 |
+
|
| 56 |
+
if self.rotary_emb is not None:
|
| 57 |
+
q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
|
| 58 |
+
k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)
|
| 59 |
+
|
| 60 |
+
q, k, v = map(lambda t: t.contiguous(), (q, k, v))
|
| 61 |
+
|
| 62 |
+
x = F.scaled_dot_product_attention(
|
| 63 |
+
query=q, key=k, value=v, is_causal=self.is_causal
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
|
| 67 |
+
x = x.to(q.dtype)
|
| 68 |
+
|
| 69 |
+
# linear proj
|
| 70 |
+
x = self.to_out(x)
|
| 71 |
+
return x
|
| 72 |
+
|
| 73 |
+
class SpatialAxialAttention(nn.Module):
|
| 74 |
+
def __init__(
|
| 75 |
+
self,
|
| 76 |
+
dim: int,
|
| 77 |
+
heads: int = 4,
|
| 78 |
+
dim_head: int = 32,
|
| 79 |
+
rotary_emb: Optional[RotaryEmbedding] = None,
|
| 80 |
+
):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.inner_dim = dim_head * heads
|
| 83 |
+
self.heads = heads
|
| 84 |
+
self.head_dim = dim_head
|
| 85 |
+
self.inner_dim = dim_head * heads
|
| 86 |
+
self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
|
| 87 |
+
self.to_out = nn.Linear(self.inner_dim, dim)
|
| 88 |
+
|
| 89 |
+
self.rotary_emb = rotary_emb
|
| 90 |
+
self.space_pos_embedding = (
|
| 91 |
+
nn.Sequential(
|
| 92 |
+
Positions2d(dim),
|
| 93 |
+
TimestepEmbedding(in_channels=dim, time_embed_dim=dim * 4, out_dim=dim),
|
| 94 |
+
)
|
| 95 |
+
if rotary_emb is None
|
| 96 |
+
else None
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
def forward(self, x: torch.Tensor):
|
| 100 |
+
B, T, H, W, D = x.shape
|
| 101 |
+
|
| 102 |
+
if self.space_pos_embedding is not None:
|
| 103 |
+
h_steps = torch.arange(H, device=x.device)
|
| 104 |
+
w_steps = torch.arange(W, device=x.device)
|
| 105 |
+
grid = torch.meshgrid(h_steps, w_steps, indexing="ij")
|
| 106 |
+
space_emb = self.space_pos_embedding(grid)
|
| 107 |
+
x = x + rearrange(space_emb, "h w d -> 1 1 h w d")
|
| 108 |
+
|
| 109 |
+
q, k, v = self.to_qkv(x).chunk(3, dim=-1)
|
| 110 |
+
|
| 111 |
+
q = rearrange(q, "B T H W (h d) -> (B T) h H W d", h=self.heads)
|
| 112 |
+
k = rearrange(k, "B T H W (h d) -> (B T) h H W d", h=self.heads)
|
| 113 |
+
v = rearrange(v, "B T H W (h d) -> (B T) h H W d", h=self.heads)
|
| 114 |
+
|
| 115 |
+
if self.rotary_emb is not None:
|
| 116 |
+
freqs = self.rotary_emb.get_axial_freqs(H, W)
|
| 117 |
+
q = apply_rotary_emb(freqs, q)
|
| 118 |
+
k = apply_rotary_emb(freqs, k)
|
| 119 |
+
|
| 120 |
+
# prepare for attn
|
| 121 |
+
q = rearrange(q, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
|
| 122 |
+
k = rearrange(k, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
|
| 123 |
+
v = rearrange(v, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
|
| 124 |
+
|
| 125 |
+
q, k, v = map(lambda t: t.contiguous(), (q, k, v))
|
| 126 |
+
|
| 127 |
+
x = F.scaled_dot_product_attention(
|
| 128 |
+
query=q, key=k, value=v, is_causal=False
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
x = rearrange(x, "(B T) h (H W) d -> B T H W (h d)", B=B, H=H, W=W)
|
| 132 |
+
x = x.to(q.dtype)
|
| 133 |
+
|
| 134 |
+
# linear proj
|
| 135 |
+
x = self.to_out(x)
|
| 136 |
+
return x
|
| 137 |
+
|
open_oasis_master/dit.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
References:
|
| 3 |
+
- DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
|
| 4 |
+
- Diffusion Forcing: https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/unet3d.py
|
| 5 |
+
- Latte: https://github.com/Vchitect/Latte/blob/main/models/latte.py
|
| 6 |
+
"""
|
| 7 |
+
from typing import Optional, Literal
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
from rotary_embedding_torch import RotaryEmbedding
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
from embeddings import Timesteps, TimestepEmbedding
|
| 13 |
+
from attention import SpatialAxialAttention, TemporalAxialAttention
|
| 14 |
+
from timm.models.vision_transformer import Mlp
|
| 15 |
+
from timm.layers.helpers import to_2tuple
|
| 16 |
+
import math
|
| 17 |
+
|
| 18 |
+
def modulate(x, shift, scale):
|
| 19 |
+
fixed_dims = [1] * len(shift.shape[1:])
|
| 20 |
+
shift = shift.repeat(x.shape[0] // shift.shape[0], *fixed_dims)
|
| 21 |
+
scale = scale.repeat(x.shape[0] // scale.shape[0], *fixed_dims)
|
| 22 |
+
while shift.dim() < x.dim():
|
| 23 |
+
shift = shift.unsqueeze(-2)
|
| 24 |
+
scale = scale.unsqueeze(-2)
|
| 25 |
+
return x * (1 + scale) + shift
|
| 26 |
+
|
| 27 |
+
def gate(x, g):
|
| 28 |
+
fixed_dims = [1] * len(g.shape[1:])
|
| 29 |
+
g = g.repeat(x.shape[0] // g.shape[0], *fixed_dims)
|
| 30 |
+
while g.dim() < x.dim():
|
| 31 |
+
g = g.unsqueeze(-2)
|
| 32 |
+
return g * x
|
| 33 |
+
|
| 34 |
+
class PatchEmbed(nn.Module):
|
| 35 |
+
"""2D Image to Patch Embedding"""
|
| 36 |
+
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
img_height=256,
|
| 40 |
+
img_width=256,
|
| 41 |
+
patch_size=16,
|
| 42 |
+
in_chans=3,
|
| 43 |
+
embed_dim=768,
|
| 44 |
+
norm_layer=None,
|
| 45 |
+
flatten=True,
|
| 46 |
+
):
|
| 47 |
+
super().__init__()
|
| 48 |
+
img_size = (img_height, img_width)
|
| 49 |
+
patch_size = to_2tuple(patch_size)
|
| 50 |
+
self.img_size = img_size
|
| 51 |
+
self.patch_size = patch_size
|
| 52 |
+
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
| 53 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
| 54 |
+
self.flatten = flatten
|
| 55 |
+
|
| 56 |
+
self.proj = nn.Conv2d(
|
| 57 |
+
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
|
| 58 |
+
)
|
| 59 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 60 |
+
|
| 61 |
+
def forward(self, x, random_sample=False):
|
| 62 |
+
B, C, H, W = x.shape
|
| 63 |
+
assert random_sample or (
|
| 64 |
+
H == self.img_size[0] and W == self.img_size[1]
|
| 65 |
+
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
| 66 |
+
x = self.proj(x)
|
| 67 |
+
if self.flatten:
|
| 68 |
+
x = rearrange(x, "B C H W -> B (H W) C")
|
| 69 |
+
else:
|
| 70 |
+
x = rearrange(x, "B C H W -> B H W C")
|
| 71 |
+
x = self.norm(x)
|
| 72 |
+
return x
|
| 73 |
+
|
| 74 |
+
class TimestepEmbedder(nn.Module):
|
| 75 |
+
"""
|
| 76 |
+
Embeds scalar timesteps into vector representations.
|
| 77 |
+
"""
|
| 78 |
+
def __init__(self, hidden_size, frequency_embedding_size=256):
|
| 79 |
+
super().__init__()
|
| 80 |
+
self.mlp = nn.Sequential(
|
| 81 |
+
nn.Linear(frequency_embedding_size, hidden_size, bias=True), # hidden_size is diffusion model hidden size
|
| 82 |
+
nn.SiLU(),
|
| 83 |
+
nn.Linear(hidden_size, hidden_size, bias=True),
|
| 84 |
+
)
|
| 85 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 86 |
+
|
| 87 |
+
@staticmethod
|
| 88 |
+
def timestep_embedding(t, dim, max_period=10000):
|
| 89 |
+
"""
|
| 90 |
+
Create sinusoidal timestep embeddings.
|
| 91 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
| 92 |
+
These may be fractional.
|
| 93 |
+
:param dim: the dimension of the output.
|
| 94 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 95 |
+
:return: an (N, D) Tensor of positional embeddings.
|
| 96 |
+
"""
|
| 97 |
+
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
| 98 |
+
half = dim // 2
|
| 99 |
+
freqs = torch.exp(
|
| 100 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
| 101 |
+
).to(device=t.device)
|
| 102 |
+
args = t[:, None].float() * freqs[None]
|
| 103 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 104 |
+
if dim % 2:
|
| 105 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 106 |
+
return embedding
|
| 107 |
+
|
| 108 |
+
def forward(self, t):
|
| 109 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 110 |
+
t_emb = self.mlp(t_freq)
|
| 111 |
+
return t_emb
|
| 112 |
+
|
| 113 |
+
class FinalLayer(nn.Module):
|
| 114 |
+
"""
|
| 115 |
+
The final layer of DiT.
|
| 116 |
+
"""
|
| 117 |
+
def __init__(self, hidden_size, patch_size, out_channels):
|
| 118 |
+
super().__init__()
|
| 119 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 120 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
| 121 |
+
self.adaLN_modulation = nn.Sequential(
|
| 122 |
+
nn.SiLU(),
|
| 123 |
+
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
def forward(self, x, c):
|
| 127 |
+
shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
|
| 128 |
+
x = modulate(self.norm_final(x), shift, scale)
|
| 129 |
+
x = self.linear(x)
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
class SpatioTemporalDiTBlock(nn.Module):
|
| 133 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, is_causal=True, spatial_rotary_emb: Optional[RotaryEmbedding] = None, temporal_rotary_emb: Optional[RotaryEmbedding] = None):
|
| 134 |
+
super().__init__()
|
| 135 |
+
self.is_causal = is_causal
|
| 136 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 137 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
| 138 |
+
|
| 139 |
+
self.s_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 140 |
+
self.s_attn = SpatialAxialAttention(hidden_size, heads=num_heads, dim_head=hidden_size // num_heads, rotary_emb=spatial_rotary_emb)
|
| 141 |
+
self.s_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 142 |
+
self.s_mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
|
| 143 |
+
self.s_adaLN_modulation = nn.Sequential(
|
| 144 |
+
nn.SiLU(),
|
| 145 |
+
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
self.t_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 149 |
+
self.t_attn = TemporalAxialAttention(hidden_size, heads=num_heads, dim_head=hidden_size // num_heads, is_causal=is_causal, rotary_emb=temporal_rotary_emb)
|
| 150 |
+
self.t_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 151 |
+
self.t_mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
|
| 152 |
+
self.t_adaLN_modulation = nn.Sequential(
|
| 153 |
+
nn.SiLU(),
|
| 154 |
+
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
def forward(self, x, c):
|
| 158 |
+
B, T, H, W, D = x.shape
|
| 159 |
+
|
| 160 |
+
# spatial block
|
| 161 |
+
s_shift_msa, s_scale_msa, s_gate_msa, s_shift_mlp, s_scale_mlp, s_gate_mlp = self.s_adaLN_modulation(c).chunk(6, dim=-1)
|
| 162 |
+
x = x + gate(self.s_attn(modulate(self.s_norm1(x), s_shift_msa, s_scale_msa)), s_gate_msa)
|
| 163 |
+
x = x + gate(self.s_mlp(modulate(self.s_norm2(x), s_shift_mlp, s_scale_mlp)), s_gate_mlp)
|
| 164 |
+
|
| 165 |
+
# temporal block
|
| 166 |
+
t_shift_msa, t_scale_msa, t_gate_msa, t_shift_mlp, t_scale_mlp, t_gate_mlp = self.t_adaLN_modulation(c).chunk(6, dim=-1)
|
| 167 |
+
x = x + gate(self.t_attn(modulate(self.t_norm1(x), t_shift_msa, t_scale_msa)), t_gate_msa)
|
| 168 |
+
x = x + gate(self.t_mlp(modulate(self.t_norm2(x), t_shift_mlp, t_scale_mlp)), t_gate_mlp)
|
| 169 |
+
|
| 170 |
+
return x
|
| 171 |
+
|
| 172 |
+
class DiT(nn.Module):
|
| 173 |
+
"""
|
| 174 |
+
Diffusion model with a Transformer backbone.
|
| 175 |
+
"""
|
| 176 |
+
def __init__(
|
| 177 |
+
self,
|
| 178 |
+
input_h=18,
|
| 179 |
+
input_w=32,
|
| 180 |
+
patch_size=2,
|
| 181 |
+
in_channels=16,
|
| 182 |
+
hidden_size=1024,
|
| 183 |
+
depth=12,
|
| 184 |
+
num_heads=16,
|
| 185 |
+
mlp_ratio=4.0,
|
| 186 |
+
external_cond_dim=25,
|
| 187 |
+
max_frames=32,
|
| 188 |
+
):
|
| 189 |
+
super().__init__()
|
| 190 |
+
self.in_channels = in_channels
|
| 191 |
+
self.out_channels = in_channels
|
| 192 |
+
self.patch_size = patch_size
|
| 193 |
+
self.num_heads = num_heads
|
| 194 |
+
self.max_frames = max_frames
|
| 195 |
+
|
| 196 |
+
self.x_embedder = PatchEmbed(input_h, input_w, patch_size, in_channels, hidden_size, flatten=False)
|
| 197 |
+
self.t_embedder = TimestepEmbedder(hidden_size)
|
| 198 |
+
frame_h, frame_w = self.x_embedder.grid_size
|
| 199 |
+
|
| 200 |
+
self.spatial_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads // 2, freqs_for="pixel", max_freq=256)
|
| 201 |
+
self.temporal_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads)
|
| 202 |
+
self.external_cond = nn.Linear(external_cond_dim, hidden_size) if external_cond_dim > 0 else nn.Identity()
|
| 203 |
+
|
| 204 |
+
self.blocks = nn.ModuleList(
|
| 205 |
+
[
|
| 206 |
+
SpatioTemporalDiTBlock(
|
| 207 |
+
hidden_size,
|
| 208 |
+
num_heads,
|
| 209 |
+
mlp_ratio=mlp_ratio,
|
| 210 |
+
is_causal=True,
|
| 211 |
+
spatial_rotary_emb=self.spatial_rotary_emb,
|
| 212 |
+
temporal_rotary_emb=self.temporal_rotary_emb,
|
| 213 |
+
)
|
| 214 |
+
for _ in range(depth)
|
| 215 |
+
]
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
|
| 219 |
+
self.initialize_weights()
|
| 220 |
+
|
| 221 |
+
def initialize_weights(self):
|
| 222 |
+
# Initialize transformer layers:
|
| 223 |
+
def _basic_init(module):
|
| 224 |
+
if isinstance(module, nn.Linear):
|
| 225 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 226 |
+
if module.bias is not None:
|
| 227 |
+
nn.init.constant_(module.bias, 0)
|
| 228 |
+
self.apply(_basic_init)
|
| 229 |
+
|
| 230 |
+
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
| 231 |
+
w = self.x_embedder.proj.weight.data
|
| 232 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 233 |
+
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
| 234 |
+
|
| 235 |
+
# Initialize timestep embedding MLP:
|
| 236 |
+
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
| 237 |
+
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
| 238 |
+
|
| 239 |
+
# Zero-out adaLN modulation layers in DiT blocks:
|
| 240 |
+
for block in self.blocks:
|
| 241 |
+
nn.init.constant_(block.s_adaLN_modulation[-1].weight, 0)
|
| 242 |
+
nn.init.constant_(block.s_adaLN_modulation[-1].bias, 0)
|
| 243 |
+
nn.init.constant_(block.t_adaLN_modulation[-1].weight, 0)
|
| 244 |
+
nn.init.constant_(block.t_adaLN_modulation[-1].bias, 0)
|
| 245 |
+
|
| 246 |
+
# Zero-out output layers:
|
| 247 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
| 248 |
+
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
| 249 |
+
nn.init.constant_(self.final_layer.linear.weight, 0)
|
| 250 |
+
nn.init.constant_(self.final_layer.linear.bias, 0)
|
| 251 |
+
|
| 252 |
+
def unpatchify(self, x):
|
| 253 |
+
"""
|
| 254 |
+
x: (N, H, W, patch_size**2 * C)
|
| 255 |
+
imgs: (N, H, W, C)
|
| 256 |
+
"""
|
| 257 |
+
c = self.out_channels
|
| 258 |
+
p = self.x_embedder.patch_size[0]
|
| 259 |
+
h = x.shape[1]
|
| 260 |
+
w = x.shape[2]
|
| 261 |
+
|
| 262 |
+
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
| 263 |
+
x = torch.einsum('nhwpqc->nchpwq', x)
|
| 264 |
+
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
| 265 |
+
return imgs
|
| 266 |
+
|
| 267 |
+
def forward(self, x, t, external_cond=None):
|
| 268 |
+
"""
|
| 269 |
+
Forward pass of DiT.
|
| 270 |
+
x: (B, T, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
| 271 |
+
t: (B, T,) tensor of diffusion timesteps
|
| 272 |
+
"""
|
| 273 |
+
|
| 274 |
+
B, T, C, H, W = x.shape
|
| 275 |
+
|
| 276 |
+
# add spatial embeddings
|
| 277 |
+
x = rearrange(x, "b t c h w -> (b t) c h w")
|
| 278 |
+
x = self.x_embedder(x) # (B*T, C, H, W) -> (B*T, H/2, W/2, D) , C = 16, D = d_model
|
| 279 |
+
# restore shape
|
| 280 |
+
x = rearrange(x, "(b t) h w d -> b t h w d", t = T)
|
| 281 |
+
# embed noise steps
|
| 282 |
+
t = rearrange(t, "b t -> (b t)")
|
| 283 |
+
c = self.t_embedder(t) # (N, D)
|
| 284 |
+
c = rearrange(c, "(b t) d -> b t d", t = T)
|
| 285 |
+
if torch.is_tensor(external_cond):
|
| 286 |
+
c += self.external_cond(external_cond)
|
| 287 |
+
for block in self.blocks:
|
| 288 |
+
x = block(x, c) # (N, T, H, W, D)
|
| 289 |
+
x = self.final_layer(x, c) # (N, T, H, W, patch_size ** 2 * out_channels)
|
| 290 |
+
# unpatchify
|
| 291 |
+
x = rearrange(x, "b t h w d -> (b t) h w d")
|
| 292 |
+
x = self.unpatchify(x) # (N, out_channels, H, W)
|
| 293 |
+
x = rearrange(x, "(b t) c h w -> b t c h w", t = T)
|
| 294 |
+
|
| 295 |
+
return x
|
| 296 |
+
|
| 297 |
+
def DiT_S_2():
|
| 298 |
+
return DiT(
|
| 299 |
+
patch_size=2,
|
| 300 |
+
hidden_size=1024,
|
| 301 |
+
depth=16,
|
| 302 |
+
num_heads=16,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
DiT_models = {
|
| 306 |
+
"DiT-S/2": DiT_S_2
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
|
open_oasis_master/embeddings.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
import math
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
|
| 10 |
+
# pylint: disable=unused-import
|
| 11 |
+
from diffusers.models.embeddings import TimestepEmbedding
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Timesteps(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
num_channels: int,
|
| 18 |
+
flip_sin_to_cos: bool = True,
|
| 19 |
+
downscale_freq_shift: float = 0,
|
| 20 |
+
):
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.num_channels = num_channels
|
| 23 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
| 24 |
+
self.downscale_freq_shift = downscale_freq_shift
|
| 25 |
+
|
| 26 |
+
def forward(self, timesteps):
|
| 27 |
+
t_emb = get_timestep_embedding(
|
| 28 |
+
timesteps,
|
| 29 |
+
self.num_channels,
|
| 30 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
| 31 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
| 32 |
+
)
|
| 33 |
+
return t_emb
|
| 34 |
+
|
| 35 |
+
class Positions2d(nn.Module):
|
| 36 |
+
def __init__(
|
| 37 |
+
self,
|
| 38 |
+
num_channels: int,
|
| 39 |
+
flip_sin_to_cos: bool = True,
|
| 40 |
+
downscale_freq_shift: float = 0,
|
| 41 |
+
):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.num_channels = num_channels
|
| 44 |
+
self.flip_sin_to_cos = flip_sin_to_cos
|
| 45 |
+
self.downscale_freq_shift = downscale_freq_shift
|
| 46 |
+
|
| 47 |
+
def forward(self, grid):
|
| 48 |
+
h_emb = get_timestep_embedding(
|
| 49 |
+
grid[0],
|
| 50 |
+
self.num_channels // 2,
|
| 51 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
| 52 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
| 53 |
+
)
|
| 54 |
+
w_emb = get_timestep_embedding(
|
| 55 |
+
grid[1],
|
| 56 |
+
self.num_channels // 2,
|
| 57 |
+
flip_sin_to_cos=self.flip_sin_to_cos,
|
| 58 |
+
downscale_freq_shift=self.downscale_freq_shift,
|
| 59 |
+
)
|
| 60 |
+
emb = torch.cat((h_emb, w_emb), dim=-1)
|
| 61 |
+
return emb
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_timestep_embedding(
|
| 65 |
+
timesteps: torch.Tensor,
|
| 66 |
+
embedding_dim: int,
|
| 67 |
+
flip_sin_to_cos: bool = False,
|
| 68 |
+
downscale_freq_shift: float = 1,
|
| 69 |
+
scale: float = 1,
|
| 70 |
+
max_period: int = 10000,
|
| 71 |
+
):
|
| 72 |
+
"""
|
| 73 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
| 74 |
+
|
| 75 |
+
:param timesteps: a 1-D or 2-D Tensor of N indices, one per batch element.
|
| 76 |
+
These may be fractional.
|
| 77 |
+
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
| 78 |
+
embeddings. :return: an [N x dim] or [N x M x dim] Tensor of positional embeddings.
|
| 79 |
+
"""
|
| 80 |
+
if len(timesteps.shape) not in [1, 2]:
|
| 81 |
+
raise ValueError("Timesteps should be a 1D or 2D tensor")
|
| 82 |
+
|
| 83 |
+
half_dim = embedding_dim // 2
|
| 84 |
+
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
|
| 85 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
| 86 |
+
|
| 87 |
+
emb = torch.exp(exponent)
|
| 88 |
+
emb = timesteps[..., None].float() * emb
|
| 89 |
+
|
| 90 |
+
# scale embeddings
|
| 91 |
+
emb = scale * emb
|
| 92 |
+
|
| 93 |
+
# concat sine and cosine embeddings
|
| 94 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
| 95 |
+
|
| 96 |
+
# flip sine and cosine embeddings
|
| 97 |
+
if flip_sin_to_cos:
|
| 98 |
+
emb = torch.cat([emb[..., half_dim:], emb[..., :half_dim]], dim=-1)
|
| 99 |
+
|
| 100 |
+
# zero pad
|
| 101 |
+
if embedding_dim % 2 == 1:
|
| 102 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
| 103 |
+
return emb
|
open_oasis_master/generate.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
References:
|
| 3 |
+
- Diffusion Forcing: https://github.com/buoyancy99/diffusion-forcing
|
| 4 |
+
"""
|
| 5 |
+
import torch
|
| 6 |
+
from dit import DiT_models
|
| 7 |
+
from vae import VAE_models
|
| 8 |
+
from torchvision.io import read_video, write_video
|
| 9 |
+
from utils import one_hot_actions, sigmoid_beta_schedule
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
from torch import autocast
|
| 13 |
+
assert torch.cuda.is_available()
|
| 14 |
+
device = "cuda:0"
|
| 15 |
+
|
| 16 |
+
# load DiT checkpoint
|
| 17 |
+
ckpt = torch.load("oasis500m.pt")
|
| 18 |
+
model = DiT_models["DiT-S/2"]()
|
| 19 |
+
model.load_state_dict(ckpt, strict=False)
|
| 20 |
+
model = model.to(device).eval()
|
| 21 |
+
|
| 22 |
+
# load VAE checkpoint
|
| 23 |
+
vae_ckpt = torch.load("vit-l-20.pt")
|
| 24 |
+
vae = VAE_models["vit-l-20-shallow-encoder"]()
|
| 25 |
+
vae.load_state_dict(vae_ckpt)
|
| 26 |
+
vae = vae.to(device).eval()
|
| 27 |
+
|
| 28 |
+
# sampling params
|
| 29 |
+
B = 1
|
| 30 |
+
total_frames = 32
|
| 31 |
+
max_noise_level = 1000
|
| 32 |
+
ddim_noise_steps = 100
|
| 33 |
+
noise_range = torch.linspace(-1, max_noise_level - 1, ddim_noise_steps + 1)
|
| 34 |
+
noise_abs_max = 20
|
| 35 |
+
ctx_max_noise_idx = ddim_noise_steps // 10 * 3
|
| 36 |
+
|
| 37 |
+
# get input video
|
| 38 |
+
video_id = "snippy-chartreuse-mastiff-f79998db196d-20220401-224517.chunk_001"
|
| 39 |
+
mp4_path = f"sample_data/{video_id}.mp4"
|
| 40 |
+
actions_path = f"sample_data/{video_id}.actions.pt"
|
| 41 |
+
video = read_video(mp4_path, pts_unit="sec")[0].float() / 255
|
| 42 |
+
actions = one_hot_actions(torch.load(actions_path))
|
| 43 |
+
offset = 100
|
| 44 |
+
video = video[offset:offset+total_frames].unsqueeze(0)
|
| 45 |
+
actions = actions[offset:offset+total_frames].unsqueeze(0)
|
| 46 |
+
|
| 47 |
+
# sampling inputs
|
| 48 |
+
n_prompt_frames = 1
|
| 49 |
+
x = video[:, :n_prompt_frames]
|
| 50 |
+
x = x.to(device)
|
| 51 |
+
actions = actions.to(device)
|
| 52 |
+
|
| 53 |
+
# vae encoding
|
| 54 |
+
scaling_factor = 0.07843137255
|
| 55 |
+
x = rearrange(x, "b t h w c -> (b t) c h w")
|
| 56 |
+
H, W = x.shape[-2:]
|
| 57 |
+
with torch.no_grad():
|
| 58 |
+
x = vae.encode(x * 2 - 1).mean * scaling_factor
|
| 59 |
+
x = rearrange(x, "(b t) (h w) c -> b t c h w", t=n_prompt_frames, h=H//vae.patch_size, w=W//vae.patch_size)
|
| 60 |
+
|
| 61 |
+
# get alphas
|
| 62 |
+
betas = sigmoid_beta_schedule(max_noise_level).to(device)
|
| 63 |
+
alphas = 1.0 - betas
|
| 64 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
| 65 |
+
alphas_cumprod = rearrange(alphas_cumprod, "T -> T 1 1 1")
|
| 66 |
+
|
| 67 |
+
# sampling loop
|
| 68 |
+
for i in tqdm(range(n_prompt_frames, total_frames)):
|
| 69 |
+
chunk = torch.randn((B, 1, *x.shape[-3:]), device=device)
|
| 70 |
+
chunk = torch.clamp(chunk, -noise_abs_max, +noise_abs_max)
|
| 71 |
+
x = torch.cat([x, chunk], dim=1)
|
| 72 |
+
start_frame = max(0, i + 1 - model.max_frames)
|
| 73 |
+
|
| 74 |
+
for noise_idx in reversed(range(1, ddim_noise_steps + 1)):
|
| 75 |
+
# set up noise values
|
| 76 |
+
ctx_noise_idx = min(noise_idx, ctx_max_noise_idx)
|
| 77 |
+
t_ctx = torch.full((B, i), noise_range[ctx_noise_idx], dtype=torch.long, device=device)
|
| 78 |
+
t = torch.full((B, 1), noise_range[noise_idx], dtype=torch.long, device=device)
|
| 79 |
+
t_next = torch.full((B, 1), noise_range[noise_idx - 1], dtype=torch.long, device=device)
|
| 80 |
+
t_next = torch.where(t_next < 0, t, t_next)
|
| 81 |
+
t = torch.cat([t_ctx, t], dim=1)
|
| 82 |
+
t_next = torch.cat([t_ctx, t_next], dim=1)
|
| 83 |
+
|
| 84 |
+
# sliding window
|
| 85 |
+
x_curr = x.clone()
|
| 86 |
+
x_curr = x_curr[:, start_frame:]
|
| 87 |
+
t = t[:, start_frame:]
|
| 88 |
+
t_next = t_next[:, start_frame:]
|
| 89 |
+
|
| 90 |
+
# add some noise to the context
|
| 91 |
+
ctx_noise = torch.randn_like(x_curr[:, :-1])
|
| 92 |
+
ctx_noise = torch.clamp(ctx_noise, -noise_abs_max, +noise_abs_max)
|
| 93 |
+
x_curr[:, :-1] = alphas_cumprod[t[:, :-1]].sqrt() * x_curr[:, :-1] + (1 - alphas_cumprod[t[:, :-1]]).sqrt() * ctx_noise
|
| 94 |
+
|
| 95 |
+
# get model predictions
|
| 96 |
+
with torch.no_grad():
|
| 97 |
+
with autocast("cuda", dtype=torch.half):
|
| 98 |
+
v = model(x_curr, t, actions[:, start_frame : i + 1])
|
| 99 |
+
|
| 100 |
+
x_start = alphas_cumprod[t].sqrt() * x_curr - (1 - alphas_cumprod[t]).sqrt() * v
|
| 101 |
+
x_noise = ((1 / alphas_cumprod[t]).sqrt() * x_curr - x_start) \
|
| 102 |
+
/ (1 / alphas_cumprod[t] - 1).sqrt()
|
| 103 |
+
|
| 104 |
+
# get frame prediction
|
| 105 |
+
x_pred = alphas_cumprod[t_next].sqrt() * x_start + x_noise * (1 - alphas_cumprod[t_next]).sqrt()
|
| 106 |
+
x[:, -1:] = x_pred[:, -1:]
|
| 107 |
+
|
| 108 |
+
# vae decoding
|
| 109 |
+
x = rearrange(x, "b t c h w -> (b t) (h w) c")
|
| 110 |
+
with torch.no_grad():
|
| 111 |
+
x = (vae.decode(x / scaling_factor) + 1) / 2
|
| 112 |
+
x = rearrange(x, "(b t) c h w -> b t h w c", t=total_frames)
|
| 113 |
+
|
| 114 |
+
# save video
|
| 115 |
+
x = torch.clamp(x, 0, 1)
|
| 116 |
+
x = (x * 255).byte()
|
| 117 |
+
write_video("video.mp4", x[0], fps=20)
|
| 118 |
+
print("generation saved to video.mp4.")
|
| 119 |
+
|
open_oasis_master/media/arch.png
ADDED
|
open_oasis_master/media/sample_0.gif
ADDED
|
Git LFS Details
|
open_oasis_master/media/sample_1.gif
ADDED
|
Git LFS Details
|
open_oasis_master/media/thumb.png
ADDED
|
|
open_oasis_master/requirements.txt
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
av==13.1.0
|
| 2 |
+
certifi==2024.8.30
|
| 3 |
+
charset-normalizer==3.4.0
|
| 4 |
+
diffusers==0.31.0
|
| 5 |
+
einops==0.8.0
|
| 6 |
+
filelock==3.13.1
|
| 7 |
+
fsspec==2024.2.0
|
| 8 |
+
huggingface-hub==0.26.2
|
| 9 |
+
idna==3.10
|
| 10 |
+
importlib_metadata==8.5.0
|
| 11 |
+
Jinja2==3.1.3
|
| 12 |
+
MarkupSafe==2.1.5
|
| 13 |
+
mpmath==1.3.0
|
| 14 |
+
networkx==3.2.1
|
| 15 |
+
numpy==1.26.3
|
| 16 |
+
packaging==24.1
|
| 17 |
+
pillow==10.2.0
|
| 18 |
+
PyYAML==6.0.2
|
| 19 |
+
regex==2024.9.11
|
| 20 |
+
requests==2.32.3
|
| 21 |
+
safetensors==0.4.5
|
| 22 |
+
sympy==1.13.1
|
| 23 |
+
timm==1.0.11
|
| 24 |
+
torch==2.5.1
|
| 25 |
+
torchaudio==2.5.1
|
| 26 |
+
torchvision==0.20.1
|
| 27 |
+
tqdm==4.66.6
|
| 28 |
+
triton==3.1.0
|
| 29 |
+
typing_extensions==4.9.0
|
| 30 |
+
urllib3==2.2.3
|
| 31 |
+
zipp==3.20.2
|
open_oasis_master/rotary_embedding_torch.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapted from https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
from math import pi, log
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch.nn import Module, ModuleList
|
| 10 |
+
from torch.amp import autocast
|
| 11 |
+
from torch import nn, einsum, broadcast_tensors, Tensor
|
| 12 |
+
|
| 13 |
+
from einops import rearrange, repeat
|
| 14 |
+
|
| 15 |
+
from typing import Literal
|
| 16 |
+
|
| 17 |
+
# helper functions
|
| 18 |
+
|
| 19 |
+
def exists(val):
|
| 20 |
+
return val is not None
|
| 21 |
+
|
| 22 |
+
def default(val, d):
|
| 23 |
+
return val if exists(val) else d
|
| 24 |
+
|
| 25 |
+
# broadcat, as tortoise-tts was using it
|
| 26 |
+
|
| 27 |
+
def broadcat(tensors, dim = -1):
|
| 28 |
+
broadcasted_tensors = broadcast_tensors(*tensors)
|
| 29 |
+
return torch.cat(broadcasted_tensors, dim = dim)
|
| 30 |
+
|
| 31 |
+
# rotary embedding helper functions
|
| 32 |
+
|
| 33 |
+
def rotate_half(x):
|
| 34 |
+
x = rearrange(x, '... (d r) -> ... d r', r = 2)
|
| 35 |
+
x1, x2 = x.unbind(dim = -1)
|
| 36 |
+
x = torch.stack((-x2, x1), dim = -1)
|
| 37 |
+
return rearrange(x, '... d r -> ... (d r)')
|
| 38 |
+
|
| 39 |
+
@autocast('cuda', enabled = False)
|
| 40 |
+
def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2):
|
| 41 |
+
dtype = t.dtype
|
| 42 |
+
|
| 43 |
+
if t.ndim == 3:
|
| 44 |
+
seq_len = t.shape[seq_dim]
|
| 45 |
+
freqs = freqs[-seq_len:]
|
| 46 |
+
|
| 47 |
+
rot_dim = freqs.shape[-1]
|
| 48 |
+
end_index = start_index + rot_dim
|
| 49 |
+
|
| 50 |
+
assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
|
| 51 |
+
|
| 52 |
+
# Split t into three parts: left, middle (to be transformed), and right
|
| 53 |
+
t_left = t[..., :start_index]
|
| 54 |
+
t_middle = t[..., start_index:end_index]
|
| 55 |
+
t_right = t[..., end_index:]
|
| 56 |
+
|
| 57 |
+
# Apply rotary embeddings without modifying t in place
|
| 58 |
+
t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
|
| 59 |
+
|
| 60 |
+
out = torch.cat((t_left, t_transformed, t_right), dim=-1)
|
| 61 |
+
|
| 62 |
+
return out.type(dtype)
|
| 63 |
+
|
| 64 |
+
# learned rotation helpers
|
| 65 |
+
|
| 66 |
+
def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
|
| 67 |
+
if exists(freq_ranges):
|
| 68 |
+
rotations = einsum('..., f -> ... f', rotations, freq_ranges)
|
| 69 |
+
rotations = rearrange(rotations, '... r f -> ... (r f)')
|
| 70 |
+
|
| 71 |
+
rotations = repeat(rotations, '... n -> ... (n r)', r = 2)
|
| 72 |
+
return apply_rotary_emb(rotations, t, start_index = start_index)
|
| 73 |
+
|
| 74 |
+
# classes
|
| 75 |
+
|
| 76 |
+
class RotaryEmbedding(Module):
|
| 77 |
+
def __init__(
|
| 78 |
+
self,
|
| 79 |
+
dim,
|
| 80 |
+
custom_freqs: Tensor | None = None,
|
| 81 |
+
freqs_for: Literal['lang', 'pixel', 'constant'] = 'lang',
|
| 82 |
+
theta = 10000,
|
| 83 |
+
max_freq = 10,
|
| 84 |
+
num_freqs = 1,
|
| 85 |
+
learned_freq = False,
|
| 86 |
+
use_xpos = False,
|
| 87 |
+
xpos_scale_base = 512,
|
| 88 |
+
interpolate_factor = 1.,
|
| 89 |
+
theta_rescale_factor = 1.,
|
| 90 |
+
seq_before_head_dim = False,
|
| 91 |
+
cache_if_possible = True,
|
| 92 |
+
cache_max_seq_len = 8192
|
| 93 |
+
):
|
| 94 |
+
super().__init__()
|
| 95 |
+
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
| 96 |
+
# has some connection to NTK literature
|
| 97 |
+
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
| 98 |
+
|
| 99 |
+
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
| 100 |
+
|
| 101 |
+
self.freqs_for = freqs_for
|
| 102 |
+
|
| 103 |
+
if exists(custom_freqs):
|
| 104 |
+
freqs = custom_freqs
|
| 105 |
+
elif freqs_for == 'lang':
|
| 106 |
+
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
| 107 |
+
elif freqs_for == 'pixel':
|
| 108 |
+
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
|
| 109 |
+
elif freqs_for == 'spacetime':
|
| 110 |
+
time_freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
| 111 |
+
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
|
| 112 |
+
elif freqs_for == 'constant':
|
| 113 |
+
freqs = torch.ones(num_freqs).float()
|
| 114 |
+
|
| 115 |
+
if freqs_for == 'spacetime':
|
| 116 |
+
self.time_freqs = nn.Parameter(time_freqs, requires_grad = learned_freq)
|
| 117 |
+
self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
|
| 118 |
+
|
| 119 |
+
self.cache_if_possible = cache_if_possible
|
| 120 |
+
self.cache_max_seq_len = cache_max_seq_len
|
| 121 |
+
|
| 122 |
+
self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False)
|
| 123 |
+
self.register_buffer('cached_freqs_seq_len', torch.tensor(0), persistent = False)
|
| 124 |
+
|
| 125 |
+
self.learned_freq = learned_freq
|
| 126 |
+
|
| 127 |
+
# dummy for device
|
| 128 |
+
|
| 129 |
+
self.register_buffer('dummy', torch.tensor(0), persistent = False)
|
| 130 |
+
|
| 131 |
+
# default sequence dimension
|
| 132 |
+
|
| 133 |
+
self.seq_before_head_dim = seq_before_head_dim
|
| 134 |
+
self.default_seq_dim = -3 if seq_before_head_dim else -2
|
| 135 |
+
|
| 136 |
+
# interpolation factors
|
| 137 |
+
|
| 138 |
+
assert interpolate_factor >= 1.
|
| 139 |
+
self.interpolate_factor = interpolate_factor
|
| 140 |
+
|
| 141 |
+
# xpos
|
| 142 |
+
|
| 143 |
+
self.use_xpos = use_xpos
|
| 144 |
+
|
| 145 |
+
if not use_xpos:
|
| 146 |
+
return
|
| 147 |
+
|
| 148 |
+
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
| 149 |
+
self.scale_base = xpos_scale_base
|
| 150 |
+
|
| 151 |
+
self.register_buffer('scale', scale, persistent = False)
|
| 152 |
+
self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent = False)
|
| 153 |
+
self.register_buffer('cached_scales_seq_len', torch.tensor(0), persistent = False)
|
| 154 |
+
|
| 155 |
+
# add apply_rotary_emb as static method
|
| 156 |
+
|
| 157 |
+
self.apply_rotary_emb = staticmethod(apply_rotary_emb)
|
| 158 |
+
|
| 159 |
+
@property
|
| 160 |
+
def device(self):
|
| 161 |
+
return self.dummy.device
|
| 162 |
+
|
| 163 |
+
def get_seq_pos(self, seq_len, device, dtype, offset = 0):
|
| 164 |
+
return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor
|
| 165 |
+
|
| 166 |
+
def rotate_queries_or_keys(self, t, freqs, seq_dim = None, offset = 0, scale = None):
|
| 167 |
+
seq_dim = default(seq_dim, self.default_seq_dim)
|
| 168 |
+
|
| 169 |
+
assert not self.use_xpos or exists(scale), 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
|
| 170 |
+
|
| 171 |
+
device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
|
| 172 |
+
|
| 173 |
+
seq = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset)
|
| 174 |
+
|
| 175 |
+
seq_freqs = self.forward(seq, freqs, seq_len = seq_len, offset = offset)
|
| 176 |
+
|
| 177 |
+
if seq_dim == -3:
|
| 178 |
+
seq_freqs = rearrange(seq_freqs, 'n d -> n 1 d')
|
| 179 |
+
|
| 180 |
+
return apply_rotary_emb(seq_freqs, t, scale = default(scale, 1.), seq_dim = seq_dim)
|
| 181 |
+
|
| 182 |
+
def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0):
|
| 183 |
+
dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim)
|
| 184 |
+
|
| 185 |
+
q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
|
| 186 |
+
assert q_len <= k_len
|
| 187 |
+
|
| 188 |
+
q_scale = k_scale = 1.
|
| 189 |
+
|
| 190 |
+
if self.use_xpos:
|
| 191 |
+
seq = self.get_seq_pos(k_len, dtype = dtype, device = device)
|
| 192 |
+
|
| 193 |
+
q_scale = self.get_scale(seq[-q_len:]).type(dtype)
|
| 194 |
+
k_scale = self.get_scale(seq).type(dtype)
|
| 195 |
+
|
| 196 |
+
rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, scale = q_scale, offset = k_len - q_len + offset)
|
| 197 |
+
rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim, scale = k_scale ** -1)
|
| 198 |
+
|
| 199 |
+
rotated_q = rotated_q.type(q.dtype)
|
| 200 |
+
rotated_k = rotated_k.type(k.dtype)
|
| 201 |
+
|
| 202 |
+
return rotated_q, rotated_k
|
| 203 |
+
|
| 204 |
+
def rotate_queries_and_keys(self, q, k, freqs, seq_dim = None):
|
| 205 |
+
seq_dim = default(seq_dim, self.default_seq_dim)
|
| 206 |
+
|
| 207 |
+
assert self.use_xpos
|
| 208 |
+
device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
|
| 209 |
+
|
| 210 |
+
seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
|
| 211 |
+
|
| 212 |
+
seq_freqs = self.forward(seq, freqs, seq_len = seq_len)
|
| 213 |
+
scale = self.get_scale(seq, seq_len = seq_len).to(dtype)
|
| 214 |
+
|
| 215 |
+
if seq_dim == -3:
|
| 216 |
+
seq_freqs = rearrange(seq_freqs, 'n d -> n 1 d')
|
| 217 |
+
scale = rearrange(scale, 'n d -> n 1 d')
|
| 218 |
+
|
| 219 |
+
rotated_q = apply_rotary_emb(seq_freqs, q, scale = scale, seq_dim = seq_dim)
|
| 220 |
+
rotated_k = apply_rotary_emb(seq_freqs, k, scale = scale ** -1, seq_dim = seq_dim)
|
| 221 |
+
|
| 222 |
+
rotated_q = rotated_q.type(q.dtype)
|
| 223 |
+
rotated_k = rotated_k.type(k.dtype)
|
| 224 |
+
|
| 225 |
+
return rotated_q, rotated_k
|
| 226 |
+
|
| 227 |
+
def get_scale(
|
| 228 |
+
self,
|
| 229 |
+
t: Tensor,
|
| 230 |
+
seq_len: int | None = None,
|
| 231 |
+
offset = 0
|
| 232 |
+
):
|
| 233 |
+
assert self.use_xpos
|
| 234 |
+
|
| 235 |
+
should_cache = (
|
| 236 |
+
self.cache_if_possible and
|
| 237 |
+
exists(seq_len) and
|
| 238 |
+
(offset + seq_len) <= self.cache_max_seq_len
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
if (
|
| 242 |
+
should_cache and \
|
| 243 |
+
exists(self.cached_scales) and \
|
| 244 |
+
(seq_len + offset) <= self.cached_scales_seq_len.item()
|
| 245 |
+
):
|
| 246 |
+
return self.cached_scales[offset:(offset + seq_len)]
|
| 247 |
+
|
| 248 |
+
scale = 1.
|
| 249 |
+
if self.use_xpos:
|
| 250 |
+
power = (t - len(t) // 2) / self.scale_base
|
| 251 |
+
scale = self.scale ** rearrange(power, 'n -> n 1')
|
| 252 |
+
scale = repeat(scale, 'n d -> n (d r)', r = 2)
|
| 253 |
+
|
| 254 |
+
if should_cache and offset == 0:
|
| 255 |
+
self.cached_scales[:seq_len] = scale.detach()
|
| 256 |
+
self.cached_scales_seq_len.copy_(seq_len)
|
| 257 |
+
|
| 258 |
+
return scale
|
| 259 |
+
|
| 260 |
+
def get_axial_freqs(self, *dims):
|
| 261 |
+
Colon = slice(None)
|
| 262 |
+
all_freqs = []
|
| 263 |
+
|
| 264 |
+
for ind, dim in enumerate(dims):
|
| 265 |
+
# only allow pixel freqs for last two dimensions
|
| 266 |
+
use_pixel = (self.freqs_for == 'pixel' or self.freqs_for == 'spacetime') and ind >= len(dims) - 2
|
| 267 |
+
if use_pixel:
|
| 268 |
+
pos = torch.linspace(-1, 1, steps = dim, device = self.device)
|
| 269 |
+
else:
|
| 270 |
+
pos = torch.arange(dim, device = self.device)
|
| 271 |
+
|
| 272 |
+
if self.freqs_for == 'spacetime' and not use_pixel:
|
| 273 |
+
seq_freqs = self.forward(pos, self.time_freqs, seq_len = dim)
|
| 274 |
+
else:
|
| 275 |
+
seq_freqs = self.forward(pos, self.freqs, seq_len = dim)
|
| 276 |
+
|
| 277 |
+
all_axis = [None] * len(dims)
|
| 278 |
+
all_axis[ind] = Colon
|
| 279 |
+
|
| 280 |
+
new_axis_slice = (Ellipsis, *all_axis, Colon)
|
| 281 |
+
all_freqs.append(seq_freqs[new_axis_slice])
|
| 282 |
+
|
| 283 |
+
all_freqs = broadcast_tensors(*all_freqs)
|
| 284 |
+
return torch.cat(all_freqs, dim = -1)
|
| 285 |
+
|
| 286 |
+
@autocast('cuda', enabled = False)
|
| 287 |
+
def forward(
|
| 288 |
+
self,
|
| 289 |
+
t: Tensor,
|
| 290 |
+
freqs: Tensor,
|
| 291 |
+
seq_len = None,
|
| 292 |
+
offset = 0
|
| 293 |
+
):
|
| 294 |
+
should_cache = (
|
| 295 |
+
self.cache_if_possible and
|
| 296 |
+
not self.learned_freq and
|
| 297 |
+
exists(seq_len) and
|
| 298 |
+
self.freqs_for != 'pixel' and
|
| 299 |
+
(offset + seq_len) <= self.cache_max_seq_len
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
if (
|
| 303 |
+
should_cache and \
|
| 304 |
+
exists(self.cached_freqs) and \
|
| 305 |
+
(offset + seq_len) <= self.cached_freqs_seq_len.item()
|
| 306 |
+
):
|
| 307 |
+
return self.cached_freqs[offset:(offset + seq_len)].detach()
|
| 308 |
+
|
| 309 |
+
freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
|
| 310 |
+
freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
|
| 311 |
+
|
| 312 |
+
if should_cache and offset == 0:
|
| 313 |
+
self.cached_freqs[:seq_len] = freqs.detach()
|
| 314 |
+
self.cached_freqs_seq_len.copy_(seq_len)
|
| 315 |
+
|
| 316 |
+
return freqs
|
open_oasis_master/sample_data/Player729-f153ac423f61-20210806-224813.chunk_000.actions.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cc3ea8894f87e2c2c2387dd32b193f27a8a95009397c32b5fbaf8a6f23608b0c
|
| 3 |
+
size 230180
|
open_oasis_master/sample_data/Player729-f153ac423f61-20210806-224813.chunk_000.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9fb1cf3a87be9deca2fec2e946427521a85026ee607cf9281aa87f6df447e4ea
|
| 3 |
+
size 6818283
|
open_oasis_master/sample_data/snippy-chartreuse-mastiff-f79998db196d-20220401-224517.chunk_001.actions.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:955929d771293156d3f27d295091a978dcd97fdaa78e3a17395ac90c0403004d
|
| 3 |
+
size 230308
|
open_oasis_master/sample_data/snippy-chartreuse-mastiff-f79998db196d-20220401-224517.chunk_001.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:745b0348a014d943f70ccf6ccba17ad260540caba502b312d972235326003ab0
|
| 3 |
+
size 7109171
|
open_oasis_master/sample_data/treechop-f153ac423f61-20210916-183423.chunk_000.actions.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:46ae60cc9d3a02df949923c707df4c5cd3f49d279aa6500c81f0ef00c14f7747
|
| 3 |
+
size 230176
|
open_oasis_master/sample_data/treechop-f153ac423f61-20210916-183423.chunk_000.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a0ad584df52d7b2636fae5d7a3116f596f25a09ba7d28ff5fc42193105605d92
|
| 3 |
+
size 8716515
|
open_oasis_master/utils.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Adapted from https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/utils.py
|
| 3 |
+
Action format derived from VPT https://github.com/openai/Video-Pre-Training
|
| 4 |
+
"""
|
| 5 |
+
import math
|
| 6 |
+
import torch
|
| 7 |
+
from torch import nn
|
| 8 |
+
from einops import rearrange, parse_shape
|
| 9 |
+
from typing import Mapping, Sequence
|
| 10 |
+
import torch
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
|
| 15 |
+
"""
|
| 16 |
+
sigmoid schedule
|
| 17 |
+
proposed in https://arxiv.org/abs/2212.11972 - Figure 8
|
| 18 |
+
better for images > 64x64, when used during training
|
| 19 |
+
"""
|
| 20 |
+
steps = timesteps + 1
|
| 21 |
+
t = torch.linspace(0, timesteps, steps, dtype=torch.float32) / timesteps
|
| 22 |
+
v_start = torch.tensor(start / tau).sigmoid()
|
| 23 |
+
v_end = torch.tensor(end / tau).sigmoid()
|
| 24 |
+
alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
|
| 25 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
| 26 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
| 27 |
+
return torch.clip(betas, 0, 0.999)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
ACTION_KEYS = [
|
| 31 |
+
"inventory",
|
| 32 |
+
"ESC",
|
| 33 |
+
"hotbar.1",
|
| 34 |
+
"hotbar.2",
|
| 35 |
+
"hotbar.3",
|
| 36 |
+
"hotbar.4",
|
| 37 |
+
"hotbar.5",
|
| 38 |
+
"hotbar.6",
|
| 39 |
+
"hotbar.7",
|
| 40 |
+
"hotbar.8",
|
| 41 |
+
"hotbar.9",
|
| 42 |
+
"forward",
|
| 43 |
+
"back",
|
| 44 |
+
"left",
|
| 45 |
+
"right",
|
| 46 |
+
"cameraX",
|
| 47 |
+
"cameraY",
|
| 48 |
+
"jump",
|
| 49 |
+
"sneak",
|
| 50 |
+
"sprint",
|
| 51 |
+
"swapHands",
|
| 52 |
+
"attack",
|
| 53 |
+
"use",
|
| 54 |
+
"pickItem",
|
| 55 |
+
"drop",
|
| 56 |
+
]
|
| 57 |
+
|
| 58 |
+
def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
|
| 59 |
+
actions_one_hot = torch.zeros(len(actions), len(ACTION_KEYS))
|
| 60 |
+
for i, current_actions in enumerate(actions):
|
| 61 |
+
for j, action_key in enumerate(ACTION_KEYS):
|
| 62 |
+
if action_key.startswith("camera"):
|
| 63 |
+
if action_key == "cameraX":
|
| 64 |
+
value = current_actions["camera"][0]
|
| 65 |
+
elif action_key == "cameraY":
|
| 66 |
+
value = current_actions["camera"][1]
|
| 67 |
+
else:
|
| 68 |
+
raise ValueError(f"Unknown camera action key: {action_key}")
|
| 69 |
+
# NOTE these numbers specific to the camera quantization used in
|
| 70 |
+
# https://github.com/etched-ai/dreamcraft/blob/216e952f795bb3da598639a109bcdba4d2067b69/spark/preprocess_vpt_to_videos_actions.py#L312
|
| 71 |
+
# see method `compress_mouse`
|
| 72 |
+
max_val = 20
|
| 73 |
+
bin_size = 0.5
|
| 74 |
+
num_buckets = int(max_val / bin_size)
|
| 75 |
+
value = (value - num_buckets) / num_buckets
|
| 76 |
+
assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}"
|
| 77 |
+
else:
|
| 78 |
+
value = current_actions[action_key]
|
| 79 |
+
assert 0 <= value <= 1, f"Action value must be in [0, 1] got {value}"
|
| 80 |
+
actions_one_hot[i, j] = value
|
| 81 |
+
|
| 82 |
+
return actions_one_hot
|
open_oasis_master/vae.py
ADDED
|
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
References:
|
| 3 |
+
- VQGAN: https://github.com/CompVis/taming-transformers
|
| 4 |
+
- MAE: https://github.com/facebookresearch/mae
|
| 5 |
+
"""
|
| 6 |
+
import numpy as np
|
| 7 |
+
import math
|
| 8 |
+
import functools
|
| 9 |
+
from collections import namedtuple
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from einops import rearrange
|
| 14 |
+
from timm.models.vision_transformer import Mlp
|
| 15 |
+
from timm.layers.helpers import to_2tuple
|
| 16 |
+
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
|
| 17 |
+
from dit import PatchEmbed
|
| 18 |
+
|
| 19 |
+
class DiagonalGaussianDistribution(object):
|
| 20 |
+
def __init__(self, parameters, deterministic=False, dim=1):
|
| 21 |
+
self.parameters = parameters
|
| 22 |
+
self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
|
| 23 |
+
if dim == 1:
|
| 24 |
+
self.dims = [1, 2, 3]
|
| 25 |
+
elif dim == 2:
|
| 26 |
+
self.dims = [1, 2]
|
| 27 |
+
else:
|
| 28 |
+
raise NotImplementedError
|
| 29 |
+
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
| 30 |
+
self.deterministic = deterministic
|
| 31 |
+
self.std = torch.exp(0.5 * self.logvar)
|
| 32 |
+
self.var = torch.exp(self.logvar)
|
| 33 |
+
if self.deterministic:
|
| 34 |
+
self.var = self.std = torch.zeros_like(self.mean).to(
|
| 35 |
+
device=self.parameters.device
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
def sample(self):
|
| 39 |
+
x = self.mean + self.std * torch.randn(self.mean.shape).to(
|
| 40 |
+
device=self.parameters.device
|
| 41 |
+
)
|
| 42 |
+
return x
|
| 43 |
+
|
| 44 |
+
def mode(self):
|
| 45 |
+
return self.mean
|
| 46 |
+
|
| 47 |
+
class Attention(nn.Module):
|
| 48 |
+
def __init__(
|
| 49 |
+
self,
|
| 50 |
+
dim,
|
| 51 |
+
num_heads,
|
| 52 |
+
frame_height,
|
| 53 |
+
frame_width,
|
| 54 |
+
qkv_bias=False,
|
| 55 |
+
attn_drop=0.0,
|
| 56 |
+
proj_drop=0.0,
|
| 57 |
+
is_causal=False,
|
| 58 |
+
):
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.num_heads = num_heads
|
| 61 |
+
head_dim = dim // num_heads
|
| 62 |
+
self.frame_height = frame_height
|
| 63 |
+
self.frame_width = frame_width
|
| 64 |
+
|
| 65 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 66 |
+
self.attn_drop = attn_drop
|
| 67 |
+
self.proj = nn.Linear(dim, dim)
|
| 68 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 69 |
+
self.is_causal = is_causal
|
| 70 |
+
|
| 71 |
+
rotary_freqs = RotaryEmbedding(
|
| 72 |
+
dim=head_dim // 4,
|
| 73 |
+
freqs_for="pixel",
|
| 74 |
+
max_freq=frame_height*frame_width,
|
| 75 |
+
).get_axial_freqs(frame_height, frame_width)
|
| 76 |
+
self.register_buffer("rotary_freqs", rotary_freqs, persistent=False)
|
| 77 |
+
|
| 78 |
+
def forward(self, x):
|
| 79 |
+
B, N, C = x.shape
|
| 80 |
+
assert N == self.frame_height * self.frame_width
|
| 81 |
+
|
| 82 |
+
qkv = (
|
| 83 |
+
self.qkv(x)
|
| 84 |
+
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 85 |
+
.permute(2, 0, 3, 1, 4)
|
| 86 |
+
)
|
| 87 |
+
q, k, v = (
|
| 88 |
+
qkv[0],
|
| 89 |
+
qkv[1],
|
| 90 |
+
qkv[2],
|
| 91 |
+
) # make torchscript happy (cannot use tensor as tuple)
|
| 92 |
+
|
| 93 |
+
if self.rotary_freqs is not None:
|
| 94 |
+
q = rearrange(q, "b h (H W) d -> b h H W d", H=self.frame_height, W=self.frame_width)
|
| 95 |
+
k = rearrange(k, "b h (H W) d -> b h H W d", H=self.frame_height, W=self.frame_width)
|
| 96 |
+
q = apply_rotary_emb(self.rotary_freqs, q)
|
| 97 |
+
k = apply_rotary_emb(self.rotary_freqs, k)
|
| 98 |
+
q = rearrange(q, "b h H W d -> b h (H W) d")
|
| 99 |
+
k = rearrange(k, "b h H W d -> b h (H W) d")
|
| 100 |
+
|
| 101 |
+
attn = F.scaled_dot_product_attention(
|
| 102 |
+
q,
|
| 103 |
+
k,
|
| 104 |
+
v,
|
| 105 |
+
dropout_p=self.attn_drop,
|
| 106 |
+
is_causal=self.is_causal,
|
| 107 |
+
)
|
| 108 |
+
x = attn.transpose(1, 2).reshape(B, N, C)
|
| 109 |
+
|
| 110 |
+
x = self.proj(x)
|
| 111 |
+
x = self.proj_drop(x)
|
| 112 |
+
return x
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class AttentionBlock(nn.Module):
|
| 116 |
+
def __init__(
|
| 117 |
+
self,
|
| 118 |
+
dim,
|
| 119 |
+
num_heads,
|
| 120 |
+
frame_height,
|
| 121 |
+
frame_width,
|
| 122 |
+
mlp_ratio=4.0,
|
| 123 |
+
qkv_bias=False,
|
| 124 |
+
drop=0.0,
|
| 125 |
+
attn_drop=0.0,
|
| 126 |
+
attn_causal=False,
|
| 127 |
+
drop_path=0.0,
|
| 128 |
+
act_layer=nn.GELU,
|
| 129 |
+
norm_layer=nn.LayerNorm,
|
| 130 |
+
):
|
| 131 |
+
super().__init__()
|
| 132 |
+
self.norm1 = norm_layer(dim)
|
| 133 |
+
self.attn = Attention(
|
| 134 |
+
dim,
|
| 135 |
+
num_heads,
|
| 136 |
+
frame_height,
|
| 137 |
+
frame_width,
|
| 138 |
+
qkv_bias=qkv_bias,
|
| 139 |
+
attn_drop=attn_drop,
|
| 140 |
+
proj_drop=drop,
|
| 141 |
+
is_causal=attn_causal,
|
| 142 |
+
)
|
| 143 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
| 144 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 145 |
+
self.norm2 = norm_layer(dim)
|
| 146 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 147 |
+
self.mlp = Mlp(
|
| 148 |
+
in_features=dim,
|
| 149 |
+
hidden_features=mlp_hidden_dim,
|
| 150 |
+
act_layer=act_layer,
|
| 151 |
+
drop=drop,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
def forward(self, x):
|
| 155 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
| 156 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
| 157 |
+
return x
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class AutoencoderKL(nn.Module):
|
| 161 |
+
def __init__(
|
| 162 |
+
self,
|
| 163 |
+
latent_dim,
|
| 164 |
+
input_height=256,
|
| 165 |
+
input_width=256,
|
| 166 |
+
patch_size=16,
|
| 167 |
+
enc_dim=768,
|
| 168 |
+
enc_depth=6,
|
| 169 |
+
enc_heads=12,
|
| 170 |
+
dec_dim=768,
|
| 171 |
+
dec_depth=6,
|
| 172 |
+
dec_heads=12,
|
| 173 |
+
mlp_ratio=4.0,
|
| 174 |
+
norm_layer=functools.partial(nn.LayerNorm, eps=1e-6),
|
| 175 |
+
use_variational=True,
|
| 176 |
+
**kwargs,
|
| 177 |
+
):
|
| 178 |
+
super().__init__()
|
| 179 |
+
self.input_height = input_height
|
| 180 |
+
self.input_width = input_width
|
| 181 |
+
self.patch_size = patch_size
|
| 182 |
+
self.seq_h = input_height // patch_size
|
| 183 |
+
self.seq_w = input_width // patch_size
|
| 184 |
+
self.seq_len = self.seq_h * self.seq_w
|
| 185 |
+
self.patch_dim = 3 * patch_size**2
|
| 186 |
+
|
| 187 |
+
self.latent_dim = latent_dim
|
| 188 |
+
self.enc_dim = enc_dim
|
| 189 |
+
self.dec_dim = dec_dim
|
| 190 |
+
|
| 191 |
+
# patch
|
| 192 |
+
self.patch_embed = PatchEmbed(input_height, input_width, patch_size, 3, enc_dim)
|
| 193 |
+
|
| 194 |
+
# encoder
|
| 195 |
+
self.encoder = nn.ModuleList(
|
| 196 |
+
[
|
| 197 |
+
AttentionBlock(
|
| 198 |
+
enc_dim,
|
| 199 |
+
enc_heads,
|
| 200 |
+
self.seq_h,
|
| 201 |
+
self.seq_w,
|
| 202 |
+
mlp_ratio,
|
| 203 |
+
qkv_bias=True,
|
| 204 |
+
norm_layer=norm_layer,
|
| 205 |
+
)
|
| 206 |
+
for i in range(enc_depth)
|
| 207 |
+
]
|
| 208 |
+
)
|
| 209 |
+
self.enc_norm = norm_layer(enc_dim)
|
| 210 |
+
|
| 211 |
+
# bottleneck
|
| 212 |
+
self.use_variational = use_variational
|
| 213 |
+
mult = 2 if self.use_variational else 1
|
| 214 |
+
self.quant_conv = nn.Linear(enc_dim, mult * latent_dim)
|
| 215 |
+
self.post_quant_conv = nn.Linear(latent_dim, dec_dim)
|
| 216 |
+
|
| 217 |
+
# decoder
|
| 218 |
+
self.decoder = nn.ModuleList(
|
| 219 |
+
[
|
| 220 |
+
AttentionBlock(
|
| 221 |
+
dec_dim,
|
| 222 |
+
dec_heads,
|
| 223 |
+
self.seq_h,
|
| 224 |
+
self.seq_w,
|
| 225 |
+
mlp_ratio,
|
| 226 |
+
qkv_bias=True,
|
| 227 |
+
norm_layer=norm_layer,
|
| 228 |
+
)
|
| 229 |
+
for i in range(dec_depth)
|
| 230 |
+
]
|
| 231 |
+
)
|
| 232 |
+
self.dec_norm = norm_layer(dec_dim)
|
| 233 |
+
self.predictor = nn.Linear(dec_dim, self.patch_dim) # decoder to patch
|
| 234 |
+
|
| 235 |
+
# initialize this weight first
|
| 236 |
+
self.initialize_weights()
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def initialize_weights(self):
|
| 240 |
+
# initialization
|
| 241 |
+
# initialize nn.Linear and nn.LayerNorm
|
| 242 |
+
self.apply(self._init_weights)
|
| 243 |
+
|
| 244 |
+
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
|
| 245 |
+
w = self.patch_embed.proj.weight.data
|
| 246 |
+
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 247 |
+
|
| 248 |
+
def _init_weights(self, m):
|
| 249 |
+
if isinstance(m, nn.Linear):
|
| 250 |
+
# we use xavier_uniform following official JAX ViT:
|
| 251 |
+
nn.init.xavier_uniform_(m.weight)
|
| 252 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 253 |
+
nn.init.constant_(m.bias, 0.0)
|
| 254 |
+
elif isinstance(m, nn.LayerNorm):
|
| 255 |
+
nn.init.constant_(m.bias, 0.0)
|
| 256 |
+
nn.init.constant_(m.weight, 1.0)
|
| 257 |
+
|
| 258 |
+
def patchify(self, x):
|
| 259 |
+
# patchify
|
| 260 |
+
bsz, _, h, w = x.shape
|
| 261 |
+
x = x.reshape(
|
| 262 |
+
bsz,
|
| 263 |
+
3,
|
| 264 |
+
self.seq_h,
|
| 265 |
+
self.patch_size,
|
| 266 |
+
self.seq_w,
|
| 267 |
+
self.patch_size,
|
| 268 |
+
).permute(
|
| 269 |
+
[0, 1, 3, 5, 2, 4]
|
| 270 |
+
) # [b, c, h, p, w, p] --> [b, c, p, p, h, w]
|
| 271 |
+
x = x.reshape(
|
| 272 |
+
bsz, self.patch_dim, self.seq_h, self.seq_w
|
| 273 |
+
) # --> [b, cxpxp, h, w]
|
| 274 |
+
x = x.permute([0, 2, 3, 1]).reshape(
|
| 275 |
+
bsz, self.seq_len, self.patch_dim
|
| 276 |
+
) # --> [b, hxw, cxpxp]
|
| 277 |
+
return x
|
| 278 |
+
|
| 279 |
+
def unpatchify(self, x):
|
| 280 |
+
bsz = x.shape[0]
|
| 281 |
+
# unpatchify
|
| 282 |
+
x = x.reshape(bsz, self.seq_h, self.seq_w, self.patch_dim).permute(
|
| 283 |
+
[0, 3, 1, 2]
|
| 284 |
+
) # [b, h, w, cxpxp] --> [b, cxpxp, h, w]
|
| 285 |
+
x = x.reshape(
|
| 286 |
+
bsz,
|
| 287 |
+
3,
|
| 288 |
+
self.patch_size,
|
| 289 |
+
self.patch_size,
|
| 290 |
+
self.seq_h,
|
| 291 |
+
self.seq_w,
|
| 292 |
+
).permute(
|
| 293 |
+
[0, 1, 4, 2, 5, 3]
|
| 294 |
+
) # [b, c, p, p, h, w] --> [b, c, h, p, w, p]
|
| 295 |
+
x = x.reshape(
|
| 296 |
+
bsz,
|
| 297 |
+
3,
|
| 298 |
+
self.input_height,
|
| 299 |
+
self.input_width,
|
| 300 |
+
) # [b, c, hxp, wxp]
|
| 301 |
+
return x
|
| 302 |
+
|
| 303 |
+
def encode(self, x):
|
| 304 |
+
# patchify
|
| 305 |
+
x = self.patch_embed(x)
|
| 306 |
+
|
| 307 |
+
# encoder
|
| 308 |
+
for blk in self.encoder:
|
| 309 |
+
x = blk(x)
|
| 310 |
+
x = self.enc_norm(x)
|
| 311 |
+
|
| 312 |
+
# bottleneck
|
| 313 |
+
moments = self.quant_conv(x)
|
| 314 |
+
if not self.use_variational:
|
| 315 |
+
moments = torch.cat((moments, torch.zeros_like(moments)), 2)
|
| 316 |
+
posterior = DiagonalGaussianDistribution(
|
| 317 |
+
moments, deterministic=(not self.use_variational), dim=2
|
| 318 |
+
)
|
| 319 |
+
return posterior
|
| 320 |
+
|
| 321 |
+
def decode(self, z):
|
| 322 |
+
# bottleneck
|
| 323 |
+
z = self.post_quant_conv(z)
|
| 324 |
+
|
| 325 |
+
# decoder
|
| 326 |
+
for blk in self.decoder:
|
| 327 |
+
z = blk(z)
|
| 328 |
+
z = self.dec_norm(z)
|
| 329 |
+
|
| 330 |
+
# predictor
|
| 331 |
+
z = self.predictor(z)
|
| 332 |
+
|
| 333 |
+
# unpatchify
|
| 334 |
+
dec = self.unpatchify(z)
|
| 335 |
+
return dec
|
| 336 |
+
|
| 337 |
+
def autoencode(self, input, sample_posterior=True):
|
| 338 |
+
posterior = self.encode(input)
|
| 339 |
+
if self.use_variational and sample_posterior:
|
| 340 |
+
z = posterior.sample()
|
| 341 |
+
else:
|
| 342 |
+
z = posterior.mode()
|
| 343 |
+
dec = self.decode(z)
|
| 344 |
+
return dec, posterior, z
|
| 345 |
+
|
| 346 |
+
def get_input(self, batch, k):
|
| 347 |
+
x = batch[k]
|
| 348 |
+
if len(x.shape) == 3:
|
| 349 |
+
x = x[..., None]
|
| 350 |
+
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
| 351 |
+
return x
|
| 352 |
+
|
| 353 |
+
def forward(self, inputs, labels, split="train"):
|
| 354 |
+
rec, post, latent = self.autoencode(inputs)
|
| 355 |
+
return rec, post, latent
|
| 356 |
+
|
| 357 |
+
def get_last_layer(self):
|
| 358 |
+
return self.predictor.weight
|
| 359 |
+
|
| 360 |
+
def ViT_L_20_Shallow_Encoder(**kwargs):
|
| 361 |
+
if "latent_dim" in kwargs:
|
| 362 |
+
latent_dim = kwargs.pop("latent_dim")
|
| 363 |
+
else:
|
| 364 |
+
latent_dim = 16
|
| 365 |
+
return AutoencoderKL(
|
| 366 |
+
latent_dim=latent_dim,
|
| 367 |
+
patch_size=20,
|
| 368 |
+
enc_dim=1024,
|
| 369 |
+
enc_depth=6,
|
| 370 |
+
enc_heads=16,
|
| 371 |
+
dec_dim=1024,
|
| 372 |
+
dec_depth=12,
|
| 373 |
+
dec_heads=16,
|
| 374 |
+
input_height=360,
|
| 375 |
+
input_width=640,
|
| 376 |
+
**kwargs,
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
VAE_models = {
|
| 380 |
+
"vit-l-20-shallow-encoder": ViT_L_20_Shallow_Encoder,
|
| 381 |
+
}
|