File size: 9,354 Bytes
8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 8746765 8ef5ca9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 |
import torch as t
from torch import nn
import torch.nn.functional as F
from ..nn.attn import AttentionEinOps, KVCache, KVCacheNaive
from ..nn.patch import Patch, UnPatch
from ..nn.geglu import GEGLU
from ..nn.pe import NumericEncoding, RoPE, LearnRoPE, VidRoPE
from jaxtyping import Float, Bool, Int
from torch import Tensor
from typing import Optional, Literal
import math
def modulate(x, shift, scale):
    """Apply per-frame adaptive affine modulation ``x * (1 + scale) + shift``.

    ``x`` is a flattened token sequence unpacked as (batch, seq, d). ``shift``
    and ``scale`` are indexed as (batch, frames, d); each frame's vector is
    broadcast over the ``seq // frames`` consecutive tokens of that frame.
    The output has the same shape as ``x``.
    """
    batch, seq_len, dim = x.shape
    tokens_per_frame = seq_len // shift.shape[1]
    # View the sequence as (batch, frames, tokens_per_frame, dim) so the
    # per-frame vectors broadcast over the token axis.
    framed = x.reshape(batch, -1, tokens_per_frame, dim)
    gain = 1 + scale[:, :, None, :]
    offset = shift[:, :, None, :]
    framed = framed * gain + offset
    return framed.reshape(batch, seq_len, dim)
def gate(x, gate):
    """Multiply each frame's tokens of ``x`` by that frame's gate vector.

    ``x`` is unpacked as (batch, seq, d) and ``gate`` is indexed as
    (batch, frames, d); each gate vector is broadcast over the
    ``seq // frames`` consecutive tokens of its frame. The result has the
    same shape as ``x``.
    """
    # NOTE: the parameter shadows this function's name inside the body; it is
    # kept because callers may pass it by keyword.
    batch, seq_len, dim = x.shape
    n_frames = gate.shape[1]
    framed = x.reshape(batch, -1, seq_len // n_frames, dim)
    gated = framed * gate[:, :, None, :]
    return gated.reshape(batch, seq_len, dim)
def modulate_deprecated(x, shift, scale):
    """Legacy modulation ``x * (1 + scale) + shift`` using plain broadcasting.

    Unlike ``modulate``, no per-frame reshaping is done: ``shift`` and
    ``scale`` must already broadcast against ``x``.
    """
    gain = 1 + scale
    return x * gain + shift
class CausalBlock(nn.Module):
    """Transformer block with per-frame adaptive-LayerNorm modulation and gating.

    The conditioning tensor is projected (SiLU -> Linear) to six chunks:
    shift/scale/gate for the self-attention sub-layer and shift/scale/gate for
    the GEGLU feed-forward sub-layer. Each sub-layer takes a LayerNorm-ed,
    modulated input, and its output is gated before being added back to the
    residual stream.
    """

    def __init__(self, layer_idx, d_model, expansion, n_heads, rope=None, ln_first=False):
        super().__init__()
        self.layer_idx, self.d_model = layer_idx, d_model
        self.expansion, self.n_heads = expansion, n_heads
        # Submodules are created in this exact order so that parameter
        # ordering, state_dict keys and seeded initialisation stay stable.
        self.norm1 = nn.LayerNorm(d_model)
        self.selfattn = AttentionEinOps(d_model, n_heads, rope=rope, ln_first=ln_first)
        self.norm2 = nn.LayerNorm(d_model)
        self.geglu = GEGLU(d_model, expansion * d_model, d_model)
        self.modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, 6 * d_model, bias=True),
        )

    def forward(self, z, cond, mask_self, cached_k=None, cached_v=None):
        """Run attention then feed-forward, each modulated and gated by ``cond``.

        Args:
            z: token sequence unpacked by ``modulate``/``gate`` as (batch, seq, d).
            cond: conditioning; presumably (batch, frames, d_model) so the
                per-frame helpers can broadcast it — TODO confirm at call sites.
            mask_self: attention mask forwarded unchanged to ``self.selfattn``.
            cached_k, cached_v: optional key/value caches forwarded to attention.

        Returns:
            ``(z, k_new, v_new)``: the updated sequence plus whatever keys/values
            the attention layer returned.
        """
        (shift_attn, scale_attn, gate_attn,
         shift_ff, scale_ff, gate_ff) = self.modulation(cond).chunk(6, dim=-1)

        # Attention sub-layer; the modulated sequence is both query and context.
        h = modulate(self.norm1(z), shift_attn, scale_attn)
        h, k_new, v_new = self.selfattn(h, h, mask=mask_self, k_cache=cached_k, v_cache=cached_v)
        z = z + gate(h, gate_attn)

        # Feed-forward sub-layer.
        h = self.geglu(modulate(self.norm2(z), shift_ff, scale_ff))
        z = z + gate(h, gate_ff)
        return z, k_new, v_new
class CausalDit(nn.Module):
    """DiT-style transformer over a window of frames with per-frame conditioning.

    Each frame is patchified, extended with ``n_registers`` learned register
    tokens, and the flattened (frame, token) sequence is processed by
    ``n_blocks`` ``CausalBlock`` layers. Every block is modulated per frame by
    a conditioning vector: a time/noise-level embedding plus an action
    embedding. Unless ``bidirectional`` is set, attention uses the block-causal
    mask from ``causal_mask``.

    Args:
        height, width: frame size in pixels. ``toks_per_frame`` uses floor
            division by ``patch_size`` (presumably both are divisible — confirm).
        n_window: number of frames the causal mask and default RoPE length are
            sized for.
        d_model: model width; ``d_model // n_heads`` is the head size.
        T: number of discrete time/noise levels for the time embedding.
        in_channels: channels of the input frames.
        patch_size: patch side length passed to ``Patch``/``UnPatch``.
        n_heads: attention heads.
        expansion: GEGLU hidden-width multiplier.
        n_blocks: number of transformer blocks.
        n_registers: learned tokens appended to every frame (0 is allowed).
        n_actions: size of the action vocabulary (``nn.Embedding``).
        bidirectional: if True, no attention mask is used.
        debug: stored only; unused in this class.
        rope_C: base constant for the positional encoding. For ``"vid"`` it is
            split as ``rope_C // 3`` per axis.
        rope_tmax: RoPE table length; defaults to ``n_window * toks_per_frame``.
        rope_type: ``"rope"``, ``"learn"`` or ``"vid"``.
        ln_first: forwarded to each block's attention layer.

    Raises:
        ValueError: if ``rope_type`` is not one of the supported values.
    """

    def __init__(self, height, width, n_window, d_model, T=1000, in_channels=3,
                 patch_size=2, n_heads=8, expansion=4, n_blocks=6,
                 n_registers=1, n_actions=4, bidirectional=False,
                 debug=False,
                 rope_C=10000,
                 rope_tmax=None,
                 rope_type: Literal["rope", "learn", "vid"] = "rope",
                 ln_first: bool = False):
        super().__init__()
        self.height = height
        self.width = width
        self.n_window = n_window
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = self.d_model // self.n_heads
        self.n_blocks = n_blocks
        self.expansion = expansion
        self.n_registers = n_registers
        self.T = T
        self.patch_size = patch_size
        self.debug = debug
        self.bidirectional = bidirectional
        self.toks_per_frame = (height // patch_size) * (width // patch_size) + n_registers
        self.rope_C = rope_C
        if rope_tmax is None:
            rope_tmax = self.n_window * self.toks_per_frame
        # NOTE: submodule creation order below is kept stable so parameter
        # ordering, state_dict keys and seeded initialisation do not change.
        if rope_type == "rope":
            self.rope_seq = RoPE(d_model // n_heads, rope_tmax, C=rope_C)
        elif rope_type == "learn":
            self.rope_seq = LearnRoPE(d_model // n_heads, rope_tmax, C=rope_C)
        elif rope_type == "vid":
            d_head = d_model // n_heads
            # Each of the x, y and t axes gets a third (floor) of the head
            # dimension and of the base constant.
            d_x = d_y = d_t = d_head // 3
            C_x = C_y = C_t = rope_C // 3
            ctx_x = width // patch_size
            ctx_y = height // patch_size
            ctx_t = self.n_window
            self.rope_seq = VidRoPE(d_head,
                                    d_x,
                                    d_y,
                                    d_t,
                                    ctx_x,
                                    ctx_y,
                                    ctx_t,
                                    C_x,
                                    C_y,
                                    C_t,
                                    self.toks_per_frame,
                                    n_registers)
        else:
            # Previously this fell through and failed later with an opaque
            # AttributeError on ``self.rope_seq``.
            raise ValueError(
                f"unknown rope_type {rope_type!r}; expected 'rope', 'learn' or 'vid'"
            )
        self.grid_pe = None
        self.rope_tmax = rope_tmax
        # The same positional-encoding module is shared by every block.
        self.blocks = nn.ModuleList([
            CausalBlock(lidx, d_model, expansion, n_heads, rope=self.rope_seq, ln_first=ln_first)
            for lidx in range(n_blocks)
        ])
        self.patch = Patch(in_channels=in_channels, out_channels=d_model, patch_size=patch_size)
        self.norm = nn.LayerNorm(d_model)
        self.unpatch = UnPatch(height, width, in_channels=d_model, out_channels=in_channels, patch_size=patch_size)
        self.action_emb = nn.Embedding(n_actions, d_model)
        self.registers = nn.Parameter(t.randn(n_registers, d_model) * 1/d_model**0.5)
        self.time_emb = NumericEncoding(dim=d_model, n_max=T)
        self.time_emb_mixer = nn.Linear(d_model, d_model)
        self.modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(d_model, 2 * d_model, bias=True),
        )
        self.cache = None

    def create_cache(self, batch_size):
        """Allocate a ``KVCache`` sized from this model's dims, dtype and device."""
        return KVCache(batch_size, self.n_blocks, self.n_heads, self.d_head, self.toks_per_frame, self.n_window, dtype=self.dtype, device=self.device)

    def create_cache2(self, batch_size):
        """Like ``create_cache`` but allocates a ``KVCacheNaive``."""
        return KVCacheNaive(batch_size, self.n_blocks, self.n_heads, self.d_head, self.toks_per_frame, self.n_window, dtype=self.dtype, device=self.device)

    def forward(self,
                z: Float[Tensor, "batch dur channels height width"],
                actions: Int[Tensor, "batch dur"],
                ts: Float[Tensor, "batch dur"],
                cached_k: Optional[Float[Tensor, "layer batch dur seq d"]] = None,
                cached_v: Optional[Float[Tensor, "layer batch dur seq d"]] = None):
        """Denoise/predict a window of frames.

        Args:
            z: input frames.
            actions: integer action ids, one per frame (fed to ``nn.Embedding``).
            ts: time/noise levels, presumably in [0, 1]. They are scaled by
                ``T``, clamped to ``[0, T - 1]`` and cast to long before
                embedding. A (batch, 1) tensor is repeated across all frames.
            cached_k, cached_v: optional per-layer caches. Layer ``i`` receives
                ``cached_k[i]`` / ``cached_v[i]``.

        Returns:
            ``(out, k_update, v_update)``. ``out`` is the ``UnPatch`` output of
            the non-register tokens. ``k_update``/``v_update`` stack the
            per-block keys/values along a new leading layer dim; they stay
            empty lists when no block returns keys.
        """
        if ts.shape[1] == 1:
            ts = ts.repeat(1, z.shape[1])
        a = self.action_emb(actions)  # batch dur d
        ts_scaled = (ts * self.T).clamp(0, self.T - 1).long()
        cond = self.time_emb_mixer(self.time_emb(ts_scaled)) + a
        z = self.patch(z)  # batch dur seq d
        if self.grid_pe is not None:
            z = z + self.grid_pe[None, None]
        # Append the learned registers to every frame; ``expand`` gives the
        # same values/gradients as ``repeat`` without materialising a copy.
        registers = self.registers[None, None].expand(z.shape[0], z.shape[1], -1, -1)
        zr = t.cat((z, registers), dim=2)
        # NOTE(review): the mask is always sized for n_window frames, even when
        # the input has fewer frames or a KV cache is used — presumably
        # AttentionEinOps slices it; confirm.
        mask_self = None if self.bidirectional else self.causal_mask
        batch, durzr, seqzr, d = zr.shape
        zr = zr.reshape(batch, -1, d)  # batch (dur*seq) d
        k_update = []
        v_update = []
        for bidx, block in enumerate(self.blocks):
            ks = cached_k[bidx] if cached_k is not None else None
            vs = cached_v[bidx] if cached_v is not None else None
            zr, k_new, v_new = block(zr, cond, mask_self, cached_k=ks, cached_v=vs)
            if k_new is not None:
                k_update.append(k_new.unsqueeze(0))
                v_update.append(v_new.unsqueeze(0))
        if k_update:
            k_update = t.cat(k_update, dim=0)
            v_update = t.cat(v_update, dim=0)
        mu, sigma = self.modulation(cond).chunk(2, dim=-1)
        zr = modulate(self.norm(zr), mu, sigma)
        zr = zr.reshape(batch, durzr, seqzr, d)
        # Drop the register tokens with an explicit end index: the former
        # ``[:-self.n_registers]`` became ``[:-0]`` (an empty slice) when
        # n_registers == 0.
        out = self.unpatch(zr[:, :, :seqzr - self.n_registers])
        return out, k_update, v_update

    @property
    def causal_mask(self):
        """Block-causal mask over ``n_window`` frames, rebuilt on each access.

        Returns a bool tensor of shape
        ``(n_window * toks_per_frame, n_window * toks_per_frame)``. Entries are
        False where a query token's frame index is >= the key token's frame
        index (all tokens of a frame see each other), and True elsewhere. True
        is intended as "mask out" — the consuming attention layer defines that
        convention (TODO confirm in AttentionEinOps).
        """
        size = self.n_window
        # NOTE: for sequences longer than n_window frames, a sliding window could
        # subtract t.tril(..., diagonal=-self.n_window) from this lower triangle.
        frame_allowed = t.tril(t.ones((size, size), dtype=t.int8, device=self.device))
        # Expand each frame entry into a toks_per_frame x toks_per_frame block.
        token_block = t.ones((self.toks_per_frame, self.toks_per_frame), dtype=t.int8, device=self.device)
        allowed = t.kron(frame_allowed, token_block).to(bool)
        return ~allowed

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first registered parameter."""
        return next(self.parameters()).dtype
def get_model(height, width,
              n_window=5,
              d_model=64,
              T=100,
              n_blocks=2,
              patch_size=2,
              n_heads=8,
              bidirectional=False,
              in_channels=3,
              C=10000,
              rope_type: Literal["rope", "learn", "vid"] = "rope",
              ln_first=False,
              **kwargs):
    """Build a ``CausalDit`` with small defaults suited to quick experiments.

    The named arguments map one-to-one onto ``CausalDit``, except ``C``, which
    is passed as ``rope_C``.

    Args:
        **kwargs: forwarded unchanged to ``CausalDit``. This exposes options
            the factory does not name (e.g. ``expansion``, ``n_registers``,
            ``n_actions``, ``rope_tmax``, ``debug``). Passing ``rope_C`` here
            raises TypeError; use ``C`` instead.

    Returns:
        The constructed ``CausalDit``.
    """
    return CausalDit(height, width,
                     n_window,
                     d_model,
                     T,
                     in_channels=in_channels,
                     n_blocks=n_blocks,
                     patch_size=patch_size,
                     n_heads=n_heads,
                     bidirectional=bidirectional,
                     rope_C=C,
                     rope_type=rope_type,
                     ln_first=ln_first,
                     **kwargs)
if __name__ == "__main__":
    # Smoke test: one uncached forward pass on random frames.
    print("running w/o cache")
    # height=20, width=20, n_window=100, d_model=64, T=5
    model = CausalDit(20, 20, 100, 64, 5, n_blocks=2)
    batch_size, n_frames = 2, 6
    frames = t.rand((batch_size, n_frames, 3, 20, 20))
    action_ids = t.randint(4, (batch_size, n_frames))
    noise_levels = t.rand((batch_size, n_frames))
    prediction, _, _ = model(frames, action_ids, noise_levels)
    print(frames.shape)
    print(prediction.shape)