# src/models/dit_dforce.py
import torch as t
from torch import nn
import torch.nn.functional as F
from ..nn.attn import AttentionEinOps, KVCache, KVCacheNaive
from ..nn.patch import Patch, UnPatch
from ..nn.geglu import GEGLU
from ..nn.pe import NumericEncoding, RoPE, LearnRoPE, VidRoPE
from jaxtyping import Float, Bool, Int
from torch import Tensor
from typing import Optional, Literal
import math

def modulate(x, shift, scale):
    """Apply per-frame AdaLN shift/scale: parameters of shape (batch, dur, d) are
    broadcast over the tokens belonging to each frame."""
b, s, d = x.shape
toks_per_frame = s // shift.shape[1]
x = x.reshape(b, -1, toks_per_frame, d)
x = x * (1 + scale[:, :, None, :]) + shift[:, :, None, :]
x = x.reshape(b, s, d)
return x

def gate(x, g):
    """Multiply each frame's tokens by that frame's (batch, dur, d) gate vector."""
    b, s, d = x.shape
    toks_per_frame = s // g.shape[1]
    x = x.reshape(b, -1, toks_per_frame, d)
    x = x * g[:, :, None, :]
    x = x.reshape(b, s, d)
    return x
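
# Broadcasting example for modulate/gate (illustrative, not executed): with 2 frames
# of 4 tokens each, x has shape (batch, 8, d) while shift/scale/gate have shape
# (batch, 2, d); every token in a frame shares that frame's parameters, e.g.
#   x = t.randn(1, 8, 16)
#   shift = t.zeros(1, 2, 16); scale = t.zeros(1, 2, 16)
#   modulate(x, shift, scale)   # returns x unchanged, since shift = 0 and scale = 0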

def modulate_deprecated(x, shift, scale):
return x * (1 + scale) + shift

class CausalBlock(nn.Module):
    """Transformer block: causal self-attention + GEGLU MLP, each wrapped in
    per-frame AdaLN shift/scale/gate modulation (DiT style)."""
def __init__(self, layer_idx, d_model, expansion, n_heads, rope=None, ln_first = False):
super().__init__()
self.layer_idx = layer_idx
self.d_model = d_model
self.expansion = expansion
self.n_heads = n_heads
self.norm1 = nn.LayerNorm(d_model)
self.selfattn = AttentionEinOps(d_model, n_heads, rope=rope, ln_first=ln_first)
self.norm2 = nn.LayerNorm(d_model)
self.geglu = GEGLU(d_model, expansion*d_model, d_model)
self.modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(d_model, 6 * d_model, bias=True),
)

    def forward(self, z, cond, mask_self, cached_k=None, cached_v=None):
        # z:    (batch, dur * toks_per_frame, d) flattened frame tokens
        # cond: (batch, dur, d) per-frame conditioning (timestep + action embedding)
        mu1, sigma1, c1, mu2, sigma2, c2 = self.modulation(cond).chunk(6, dim=-1)
residual = z
z = modulate(self.norm1(z), mu1, sigma1)
z, k_new, v_new = self.selfattn(z, z, mask=mask_self, k_cache=cached_k, v_cache=cached_v)
z = residual + gate(z, c1)
residual = z
z = modulate(self.norm2(z), mu2, sigma2)
z = self.geglu(z)
z = residual + gate(z, c2)
return z, k_new, v_new
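
# Standalone usage sketch for CausalBlock (illustrative toy sizes, not executed;
# rope=None assumes AttentionEinOps accepts a missing RoPE):
#   blk = CausalBlock(0, d_model=32, expansion=4, n_heads=4)
#   z = t.randn(2, 3 * 5, 32)    # 3 frames x 5 tokens each, flattened over the sequence dim
#   cond = t.randn(2, 3, 32)     # one conditioning vector per frame
#   z, k_new, v_new = blk(z, cond, mask_self=None)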

class CausalDit(nn.Module):
    """Frame-causal DiT: patchifies each frame, appends register tokens, and stacks
    AdaLN-modulated CausalBlocks with a block-causal attention mask over frames."""
def __init__(self, height, width, n_window, d_model, T=1000, in_channels=3,
patch_size=2, n_heads=8, expansion=4, n_blocks=6,
n_registers=1, n_actions=4, bidirectional=False,
debug=False,
rope_C=10000,
rope_tmax=None,
rope_type: Literal["rope", "learn", "vid"] = "rope",
ln_first: bool = False):
super().__init__()
self.height = height
self.width = width
self.n_window = n_window
self.d_model = d_model
self.n_heads = n_heads
self.d_head = self.d_model // self.n_heads
self.n_blocks = n_blocks
self.expansion = expansion
self.n_registers = n_registers
self.T = T
self.patch_size = patch_size
self.debug = debug
self.bidirectional = bidirectional
self.toks_per_frame = (height//patch_size)*(width//patch_size) + n_registers
self.rope_C = rope_C
if rope_tmax is None:
rope_tmax = self.n_window*self.toks_per_frame
if rope_type == "rope":
self.rope_seq = RoPE(d_model//n_heads, rope_tmax, C=rope_C)
elif rope_type == "learn":
self.rope_seq = LearnRoPE(d_model//n_heads, rope_tmax, C=rope_C)
elif rope_type == "vid":
d_head = d_model//n_heads
d_x = d_y = d_t = d_head // 3
C_x = C_y = C_t = rope_C // 3
ctx_x = width // patch_size
ctx_y = height // patch_size
ctx_t = self.n_window
self.rope_seq = VidRoPE(d_head,
d_x,
d_y,
d_t,
ctx_x,
ctx_y,
ctx_t,
C_x,
C_y,
C_t,
self.toks_per_frame,
                                    n_registers)
        else:
            raise ValueError(f"unknown rope_type: {rope_type}")
self.grid_pe = None
self.rope_tmax = rope_tmax
self.blocks = nn.ModuleList([CausalBlock(lidx, d_model, expansion, n_heads, rope=self.rope_seq, ln_first=ln_first) for lidx in range(n_blocks)])
self.patch = Patch(in_channels=in_channels, out_channels=d_model, patch_size=patch_size)
self.norm = nn.LayerNorm(d_model)
self.unpatch = UnPatch(height, width, in_channels=d_model, out_channels=in_channels, patch_size=patch_size)
self.action_emb = nn.Embedding(n_actions, d_model)
self.registers = nn.Parameter(t.randn(n_registers, d_model) * 1/d_model**0.5)
self.time_emb = NumericEncoding(dim=d_model, n_max=T)
self.time_emb_mixer = nn.Linear(d_model, d_model)
self.modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(d_model, 2 * d_model, bias=True),
)
self.cache = None

    def create_cache(self, batch_size):
        """KV cache sized for n_blocks layers and n_window frames of toks_per_frame tokens."""
        return KVCache(batch_size, self.n_blocks, self.n_heads, self.d_head,
                       self.toks_per_frame, self.n_window, dtype=self.dtype, device=self.device)

    def create_cache2(self, batch_size):
        """Same sizing as create_cache, but backed by KVCacheNaive."""
        return KVCacheNaive(batch_size, self.n_blocks, self.n_heads, self.d_head,
                            self.toks_per_frame, self.n_window, dtype=self.dtype, device=self.device)
def forward(self,
z: Float[Tensor, "batch dur channels height width"],
                actions: Int[Tensor, "batch dur"],
                ts: Float[Tensor, "batch dur"],
cached_k: Optional[Float[Tensor, "layer batch dur seq d"]] = None,
cached_v: Optional[Float[Tensor, "layer batch dur seq d"]] = None):
if ts.shape[1] == 1:
ts = ts.repeat(1, z.shape[1])
a = self.action_emb(actions) # batch dur d
ts_scaled = (ts * self.T).clamp(0, self.T - 1).long()
cond = self.time_emb_mixer(self.time_emb(ts_scaled)) + a
z = self.patch(z) # batch dur seq d
if self.grid_pe is not None:
z = z + self.grid_pe[None, None]
        # registers: (n_registers, d), broadcast to every (batch, dur) frame and appended to its tokens
        zr = t.cat((z, self.registers[None, None].repeat([z.shape[0], z.shape[1], 1, 1])), dim=2)  # z plus registers
if self.bidirectional:
mask_self = None
else:
mask_self = self.causal_mask
batch, durzr, seqzr, d = zr.shape
zr = zr.reshape(batch, -1, d) # batch durseq d
k_update = []
v_update = []
for bidx, block in enumerate(self.blocks):
ks = cached_k[bidx] if cached_k is not None else None
vs = cached_v[bidx] if cached_v is not None else None
zr, k_new, v_new = block(zr, cond, mask_self, cached_k=ks, cached_v=vs)
if k_new is not None:
k_update.append(k_new.unsqueeze(0))
v_update.append(v_new.unsqueeze(0))
if len(k_update) > 0:
k_update = t.cat(k_update, dim=0)
v_update = t.cat(v_update, dim=0)
mu, sigma = self.modulation(cond).chunk(2, dim=-1)
zr = modulate(self.norm(zr), mu, sigma)
zr = zr.reshape(batch, durzr, seqzr, d)
        out = self.unpatch(zr[:, :, :self.toks_per_frame - self.n_registers])  # drop register tokens (safe for n_registers == 0)
return out, k_update, v_update
    @property
    def causal_mask(self):
        # Frame-level lower-triangular mask, expanded so every token of frame i may attend
        # to all tokens of frames <= i. (Subtracting t.tril(..., diagonal=-self.n_window)
        # would give a banded mask, useful only for sequences longer than n_window frames.)
        size = self.n_window
        m_self = t.tril(t.ones((size, size), dtype=t.int8, device=self.device))
        m_self = t.kron(m_self, t.ones((self.toks_per_frame, self.toks_per_frame), dtype=t.int8, device=self.device))
        m_self = m_self.to(bool)
        return ~m_self  # True marks positions that attention must ignore
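
    # Toy illustration (hypothetical sizes): with n_window=2 and toks_per_frame=2,
    #   allowed = t.kron(t.tril(t.ones(2, 2)), t.ones(2, 2))
    # gives a 4x4 pattern where frame 0's two tokens cannot attend to frame 1's tokens,
    # and the property returns ~allowed.bool(), i.e. True at masked positions.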

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype

def get_model(height, width,
n_window=5,
d_model=64,
T=100,
n_blocks=2,
patch_size=2,
n_heads=8,
bidirectional=False,
in_channels=3,
C=10000,
rope_type: Literal["rope", "learn", "vid"] = "rope",
ln_first=False):
return CausalDit(height, width,
n_window,
d_model,
T,
in_channels=in_channels,
n_blocks=n_blocks,
patch_size=patch_size,
n_heads=n_heads,
bidirectional=bidirectional,
rope_C=C,
rope_type=rope_type,
ln_first=ln_first)

if __name__ == "__main__":
    print("running w/o cache")
    dit = CausalDit(20, 20, 100, 64, 5, n_blocks=2)  # positional args: (height, width, n_window, d_model, T)
z = t.rand((2, 6, 3, 20, 20))
actions = t.randint(4, (2, 6))
ts = t.rand((2, 6))
out, _, _ = dit(z, actions, ts)
print(z.shape)
print(out.shape)
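
    # Sketch of a cached follow-up call (illustrative; the exact cache semantics live in
    # AttentionEinOps, so treat this as an assumption about how the k/v updates are reused):
    print("running w/ cache (sketch)")
    out, k_upd, v_upd = dit(z, actions, ts)
    if isinstance(k_upd, t.Tensor):  # updates come back stacked over layers when attention returns them
        z_next = t.rand((2, 1, 3, 20, 20))
        actions_next = t.randint(4, (2, 1))
        ts_next = t.rand((2, 1))
        out_next, _, _ = dit(z_next, actions_next, ts_next, cached_k=k_upd, cached_v=v_upd)
        print(out_next.shape)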