| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from functools import lru_cache |
| | from typing import Optional, Tuple |
| | import torch |
| | from einops import rearrange |
| | from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb |
| | from torch import nn |
| |
|
| | from common.cache import Cache |
| |
|
| |
|
class RotaryEmbeddingBase(nn.Module):
    """Base module wrapping a pixel-style `RotaryEmbedding`.

    Args:
        dim: total head dimension to be covered by the rotary frequencies.
        rope_dim: number of spatial/temporal axes the dim is split across
            (e.g. 3 for T/H/W), so each axis gets ``dim // rope_dim`` freqs.
    """

    def __init__(self, dim: int, rope_dim: int):
        super().__init__()
        self.rope = RotaryEmbedding(
            dim=dim // rope_dim,
            freqs_for="pixel",
            max_freq=256,
        )
        # Re-register `freqs` as a plain (non-trainable) buffer: this keeps it
        # moving with .to()/.cuda() and in state_dict without being treated as
        # a learnable parameter by optimizers.
        freqs = self.rope.freqs
        del self.rope.freqs
        self.rope.register_buffer("freqs", freqs.data)
        # Per-instance memo for get_axial_freqs. The previous
        # @lru_cache(maxsize=128) on the method keyed on `self` and kept every
        # instance alive for the lifetime of the (module-level) cache
        # (ruff B019); a per-instance dict has the same hit behavior without
        # the leak. The key space per instance is tiny (a few dims tuples).
        self._axial_freqs_cache = {}

    def get_axial_freqs(self, *dims):
        """Return the axial rotary frequencies for `dims`, memoized per instance."""
        cached = self._axial_freqs_cache.get(dims)
        if cached is None:
            cached = self.rope.get_axial_freqs(*dims)
            self._axial_freqs_cache[dims] = cached
        return cached
| |
|
| |
|
class RotaryEmbedding3d(RotaryEmbeddingBase):
    """3D axial rotary embedding for video tokens flattened as (T*H*W)."""

    def __init__(self, dim: int):
        super().__init__(dim, rope_dim=3)
        self.mm = False  # single-modality (video-only) variant

    def forward(
        self,
        q: torch.FloatTensor,
        k: torch.FloatTensor,
        size: Tuple[int, int, int],
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Rotate `q` and `k` with (T, H, W) axial frequencies.

        Expects tokens flattened over the spatial/temporal grid in the
        sequence axis, layout ``b h (T H W) d``; returns tensors in the same
        layout and dtype.
        """
        T, H, W = size
        freqs = self.get_axial_freqs(T, H, W)
        rotated = []
        for x in (q, k):
            # Unflatten to the grid, rotate in fp32, cast back, re-flatten.
            grid = rearrange(x, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W)
            grid = apply_rotary_emb(freqs, grid.float()).to(x.dtype)
            rotated.append(rearrange(grid, "b h T H W d -> b h (T H W) d"))
        return rotated[0], rotated[1]
| |
|
| |
|
class MMRotaryEmbeddingBase(RotaryEmbeddingBase):
    """Base for multimodal ropes: language-style frequencies (theta=10000).

    NOTE(review): the parent ``__init__`` builds a pixel-style rope that is
    immediately replaced below; it is still called for the ``nn.Module``
    initialization it performs.
    """

    def __init__(self, dim: int, rope_dim: int):
        super().__init__(dim, rope_dim)
        self.rope = RotaryEmbedding(
            dim=dim // rope_dim,
            freqs_for="lang",
            theta=10000,
        )
        # Same buffer re-homing as the parent: keep `freqs` device-synced and
        # serialized without exposing it as a trainable parameter.
        lang_freqs = self.rope.freqs
        del self.rope.freqs
        self.rope.register_buffer("freqs", lang_freqs.data)
        self.mm = True  # marks this rope as multimodal (video + text)
| |
|
| |
|
class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
    """Multimodal 3D rotary embedding for packed (varlen) video + text tokens.

    Video tokens receive axial (frame, height, width) frequencies; text tokens
    receive 1-D language frequencies repeated across the three axes so both
    modalities share the same frequency width. Video frame positions start at
    the sample's text length, so text and video presumably share one
    contiguous 1-D position axis per sample.
    """

    # Extents of the precomputed frequency tables (previously hard-coded
    # inline). Inputs beyond these cannot be encoded and are rejected.
    MAX_POS: int = 1024     # text length, and text-length + frame-count axis
    MAX_HEIGHT: int = 128
    MAX_WIDTH: int = 128

    def __init__(self, dim: int):
        super().__init__(dim, rope_dim=3)

    def forward(
        self,
        vid_q: torch.FloatTensor,
        vid_k: torch.FloatTensor,
        vid_shape: torch.LongTensor,
        txt_q: torch.FloatTensor,
        txt_k: torch.FloatTensor,
        txt_shape: torch.LongTensor,
        cache: Cache,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Apply rotary embedding to packed video and text q/k tensors.

        Args:
            vid_q, vid_k: packed video queries/keys, layout ``L h d``.
            vid_shape: per-sample (frames, height, width) rows.
            txt_q, txt_k: packed text queries/keys, layout ``L h d``.
            txt_shape: per-sample text lengths in column 0.
            cache: memoizes the frequency tensors under "mmrope_freqs_3d".

        Returns:
            ``(vid_q, vid_k, txt_q, txt_k)`` rotated, same layout and dtype
            as the inputs.
        """
        vid_freqs, txt_freqs = cache(
            "mmrope_freqs_3d",
            lambda: self.get_freqs(vid_shape, txt_shape),
        )
        # Rotation is done in fp32 for numerical stability, then cast back.
        vid_q = rearrange(vid_q, "L h d -> h L d")
        vid_k = rearrange(vid_k, "L h d -> h L d")
        vid_q = apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype)
        vid_k = apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype)
        vid_q = rearrange(vid_q, "h L d -> L h d")
        vid_k = rearrange(vid_k, "h L d -> L h d")

        txt_q = rearrange(txt_q, "L h d -> h L d")
        txt_k = rearrange(txt_k, "L h d -> h L d")
        txt_q = apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype)
        txt_k = apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype)
        txt_q = rearrange(txt_q, "h L d -> L h d")
        txt_k = rearrange(txt_k, "h L d -> L h d")
        return vid_q, vid_k, txt_q, txt_k

    def get_freqs(
        self,
        vid_shape: torch.LongTensor,
        txt_shape: torch.LongTensor,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
    ]:
        """Build concatenated per-token frequency tensors for a packed batch.

        Raises:
            ValueError: if a sample exceeds the precomputed table extents.
                Previously the slices below silently truncated out-of-range
                inputs, producing frequency tensors misaligned with the
                token count.
        """
        vid_freqs = self.get_axial_freqs(self.MAX_POS, self.MAX_HEIGHT, self.MAX_WIDTH)
        txt_freqs = self.get_axial_freqs(self.MAX_POS)
        vid_freq_list, txt_freq_list = [], []
        for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
            if l + f > self.MAX_POS or h > self.MAX_HEIGHT or w > self.MAX_WIDTH:
                raise ValueError(
                    f"sample exceeds precomputed rope tables: txt_len={l}, "
                    f"vid=({f}, {h}, {w}); max is "
                    f"({self.MAX_POS}, {self.MAX_HEIGHT}, {self.MAX_WIDTH})"
                )
            # Video frames are positioned after the sample's text tokens.
            vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1))
            # Repeat 1-D text freqs across the 3 axes to match the video
            # frequency width.
            txt_freq = txt_freqs[:l].repeat(1, 3).reshape(-1, vid_freqs.size(-1))
            vid_freq_list.append(vid_freq)
            txt_freq_list.append(txt_freq)
        return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0)
| |
|
| |
|
def get_na_rope(rope_type: Optional[str], dim: int):
    """Factory for native-resolution rotary embeddings.

    Returns ``None`` when `rope_type` is ``None``; raises NotImplementedError
    for any unrecognized type.
    """
    if rope_type is None:
        return None
    if rope_type != "mmrope3d":
        raise NotImplementedError(f"{rope_type} is not supported.")
    return NaMMRotaryEmbedding3d(dim=dim)
| |
|