# object-assembler: code/cube3d/model/gpt/dual_stream_roformer.py
# Added by 0xZohar in commit 6d7f949 ("Add code/cube3d/model/gpt/dual_stream_roformer.py").
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from cube3d.model.transformers.cache import Cache
from cube3d.model.transformers.dual_stream_attention import (
DualStreamDecoderLayerWithRotaryEmbedding,
)
from cube3d.model.transformers.norm import LayerNorm
from cube3d.model.transformers.roformer import DecoderLayerWithRotaryEmbedding
from cube3d.model.transformers.rope import precompute_freqs_cis
class DualStreamRoformer(nn.Module):
    """Dual-stream RoFormer backbone with additional per-attribute embedding tables.

    The model runs ``cfg.n_layer`` dual-stream decoder blocks (which take and
    return a main stream ``h`` and a conditioning stream ``c``) followed by
    ``cfg.n_single_layer`` single-stream decoder blocks on ``h`` only. Both kinds
    of block receive rotary position embeddings. ``forward`` returns the final
    normalized hidden states projected through ``ldr_head``.
    """

    @dataclass
    class Config:
        checkpoint_path: str = ""
        n_layer: int = 12  # number of dual-stream blocks
        n_single_layer: int = 0  # number of single-stream blocks after the dual ones
        rope_theta: float = 1000  # RoPE base passed to precompute_freqs_cis
        n_head: int = 16
        n_embd: int = 2048  # per-head dim is n_embd // n_head
        bias: bool = False  # bias in Linears and LayerNorms
        eps: float = 1e-6  # Norm eps
        shape_model_vocab_size: int = 4096  # vocab size before the 3 special tokens
        shape_model_embed_dim: int = 16
        text_model_embed_dim: int = 512
        use_pooled_text_embed: bool = False
        encoder_with_cls_token: bool = True
        use_bbox: bool = False  # when True, __init__ adds a 3 -> n_embd `bbox_proj`
        ldr_in_embed_dim: int = 2048  # input dim of `ldr_proj`
        ldr_out_embed_dim: int = 2048  # output dim of `ldr_head` (and of forward)

    def __init__(self, cfg: Config) -> None:
        """
        Initializes the DualStreamRoFormer model.

        Args:
            cfg (Config): Configuration object containing model parameters.

        Attributes:
            cfg (Config): Stores the configuration object.
            text_proj (nn.Linear): Projects text-model embeddings
                (``text_model_embed_dim``) to ``n_embd``.
            shape_proj (nn.Linear): Projects shape-model embeddings
                (``shape_model_embed_dim``) to ``n_embd``.
            ldr_proj (nn.Linear): Projects ldr embeddings (``ldr_in_embed_dim``)
                to ``n_embd``; used by ``encode_embed``.
            x_num, y_num, z_num, rot_num, dat_num (int): Sizes of the discrete
                attribute vocabularies embedded by ``xte``, ``yte``, ``zte``,
                ``rte`` and ``dte``.
            x, xy, xyz (int): Cumulative sums of those sizes (see inline note).
            dte, rte, xte, yte, zte (nn.Embedding): Attribute embedding tables of
                width ``n_embd``. Each has ``num + 2`` rows (``dat_num + 1`` for
                ``dte``), and row ``num`` is its padding row.
            vocab_size (int): Shape-model vocabulary size including special tokens.
            shape_bos_id (int): Token ID for the beginning-of-sequence (BOS) token.
            shape_eos_id (int): Token ID for the end-of-sequence (EOS) token.
            padding_id (int): Token ID for the padding token.
            transformer (nn.ModuleDict): Contains
                - wte (nn.Embedding): Token embedding for the shape vocabulary.
                - dual_blocks (nn.ModuleList): Dual-stream decoder layers with
                  rotary embeddings; the last one is built with ``cond_pre_only=True``.
                - single_blocks (nn.ModuleList): Single-stream decoder layers with
                  rotary embeddings.
                - ln_f (LayerNorm): Final normalization (no affine parameters).
            lm_head (nn.Linear): Maps ``n_embd`` to ``vocab_size`` (not used by
                ``forward``).
            ldr_head (nn.Linear): Maps ``n_embd`` to ``ldr_out_embed_dim``; this is
                the output head used by ``forward``.
            bbox_proj (nn.Linear): Only present when ``cfg.use_bbox``; maps 3 -> n_embd.
        """
        super().__init__()
        self.cfg = cfg

        # NOTE: construction order is kept identical to the original so that
        # parameter initialisation consumes the RNG in the same sequence.
        self.text_proj = nn.Linear(
            in_features=self.cfg.text_model_embed_dim,
            out_features=self.cfg.n_embd,
            bias=self.cfg.bias,
        )
        self.shape_proj = nn.Linear(self.cfg.shape_model_embed_dim, self.cfg.n_embd)
        self.ldr_proj = nn.Linear(self.cfg.ldr_in_embed_dim, self.cfg.n_embd)

        self.vocab_size = self.cfg.shape_model_vocab_size

        # Hard-coded sizes of the discrete attribute vocabularies.
        self.x_num = 251
        self.y_num = 215
        self.z_num = 525
        self.rot_num = 24
        # NOTE(review): `xy` already includes `rot_num`, which suggests a
        # concatenated index layout of x | y | rot | z -- confirm with the code
        # that consumes these offsets.
        self.x = self.x_num
        self.xy = self.x_num + self.y_num + self.rot_num
        self.xyz = self.x_num + self.y_num + self.z_num + self.rot_num
        # presumably the number of distinct part (.dat) types -- confirm
        self.dat_num = 1217

        n_embd = self.cfg.n_embd
        self.dte = self._index_embedding(self.dat_num, n_embd, extra=1)
        self.rte = self._index_embedding(self.rot_num, n_embd)
        self.xte = self._index_embedding(self.x_num, n_embd)
        self.yte = self._index_embedding(self.y_num, n_embd)
        self.zte = self._index_embedding(self.z_num, n_embd)
        self.is_compute = False

        # Special tokens are appended after the shape-model vocabulary.
        self.shape_bos_id = self.vocab_size
        self.shape_eos_id = self.vocab_size + 1
        self.padding_id = self.vocab_size + 2
        self.vocab_size += 3

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(
                    self.vocab_size,
                    self.cfg.n_embd,
                    padding_idx=self.padding_id,
                ),
                dual_blocks=nn.ModuleList(
                    [
                        DualStreamDecoderLayerWithRotaryEmbedding.from_config(
                            self.cfg, cond_pre_only=(i == self.cfg.n_layer - 1)
                        )
                        for i in range(self.cfg.n_layer)
                    ]
                ),
                single_blocks=nn.ModuleList(
                    [
                        DecoderLayerWithRotaryEmbedding.from_config(self.cfg)
                        for _ in range(self.cfg.n_single_layer)
                    ]
                ),
                ln_f=LayerNorm(
                    self.cfg.n_embd, elementwise_affine=False, eps=self.cfg.eps
                ),
            )
        )
        self.lm_head = nn.Linear(self.cfg.n_embd, self.vocab_size, bias=False)
        self.ldr_head = nn.Linear(
            self.cfg.n_embd, self.cfg.ldr_out_embed_dim, bias=False
        )
        if self.cfg.use_bbox:
            self.bbox_proj = nn.Linear(3, self.cfg.n_embd)

    @staticmethod
    def _index_embedding(num_values: int, embed_dim: int, extra: int = 2) -> nn.Embedding:
        """Build an ``nn.Embedding`` with ``num_values + extra`` rows whose row
        ``num_values`` is the padding row (zero-initialised, no gradient)."""
        return nn.Embedding(num_values + extra, embed_dim, padding_idx=num_values)

    def encode_embed(self, ldr_embed):
        """
        Encodes the given ldr embeddings by projecting them through ``ldr_proj``.

        Args:
            ldr_embed (torch.Tensor): ldr embeddings whose last dimension is
                ``cfg.ldr_in_embed_dim``.

        Returns:
            torch.Tensor: The embeddings projected to ``cfg.n_embd``.
        """
        return self.ldr_proj(ldr_embed)

    def encode_text(self, text_embed):
        """
        Encodes the given text embeddings by projecting them through ``text_proj``.

        Args:
            text_embed (torch.Tensor): Text embeddings whose last dimension is
                ``cfg.text_model_embed_dim``.

        Returns:
            torch.Tensor: The embeddings projected to ``cfg.n_embd``.
        """
        return self.text_proj(text_embed)

    def encode_token(self, tokens):
        """
        Encodes token ids with the transformer's token embedding table ``wte``.

        Args:
            tokens (torch.Tensor): Integer token ids in ``[0, vocab_size)``.

        Returns:
            torch.Tensor: The token embeddings (last dimension ``cfg.n_embd``).
        """
        return self.transformer.wte(tokens)

    def init_kv_cache(
        self,
        batch_size: int,
        cond_len: int,
        max_shape_tokens: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "list[Cache]":
        """
        Pre-allocates zero-filled key/value caches for every transformer block.

        Dual-stream blocks attend over the conditioning and shape tokens together,
        so their caches hold ``cond_len + max_shape_tokens`` positions. Single-stream
        blocks only see shape tokens, so theirs hold ``max_shape_tokens`` positions.

        Args:
            batch_size (int): The batch size for the input data.
            cond_len (int): The length of the conditioning sequence.
            max_shape_tokens (int): The maximum number of tokens in the shape sequence.
            dtype (torch.dtype): The data type for the cache tensors.
            device (torch.device): The device on which the tensors are allocated.

        Returns:
            list[Cache]: One ``Cache`` per block, dual-stream blocks first. Each
            tensor has shape ``(batch_size, n_head, seq_len, n_embd // n_head)``.
        """
        num_heads = self.cfg.n_head
        per_head_dim = self.cfg.n_embd // num_heads

        def _zeros_cache(seq_len: int) -> "Cache":
            # Key and value buffers of identical shape.
            shape = (batch_size, num_heads, seq_len, per_head_dim)
            return Cache(
                key_states=torch.zeros(shape, dtype=dtype, device=device),
                value_states=torch.zeros(shape, dtype=dtype, device=device),
            )

        dual_caches = [
            _zeros_cache(cond_len + max_shape_tokens)
            for _ in range(len(self.transformer.dual_blocks))
        ]
        single_caches = [
            _zeros_cache(max_shape_tokens)
            for _ in range(len(self.transformer.single_blocks))
        ]
        return dual_caches + single_caches

    def forward(
        self,
        embed: torch.Tensor,
        cond: torch.Tensor,
        kv_cache: Optional["list[Cache]"] = None,
        curr_pos_id: Optional[torch.Tensor] = None,
        decode: bool = False,
        **kwargs,
    ):
        """
        Forward pass for the dual-stream RoFormer model.

        Args:
            embed (torch.Tensor): Main-stream input embeddings; dims 0 and 1 are
                read as (batch, sequence length).
            cond (torch.Tensor): Conditioning-stream embeddings; dim 1 is read as
                the conditioning length.
            kv_cache (Optional[list[Cache]]): Per-block key/value caches, as
                returned by ``init_kv_cache``. Default is None.
            curr_pos_id (Optional[torch.Tensor]): Index of the current position(s)
                in the main sequence. Required when decoding with a cache; it is
                then used to select ``embed[:, curr_pos_id, :]``. It is offset by
                the conditioning length for the dual-stream blocks.
            decode (bool): Whether the model is in incremental decoding mode.
            **kwargs: Ignored; accepted for call-site compatibility.

        Returns:
            torch.Tensor: The final hidden states after ``ln_f`` projected through
            ``ldr_head`` (last dimension ``cfg.ldr_out_embed_dim``). These are not
            vocabulary logits from ``lm_head``.

        Raises:
            ValueError: If ``decode`` is True and a ``kv_cache`` is given without
                ``curr_pos_id``.
        """
        b, l = embed.shape[:2]
        s = cond.shape[1]
        device = embed.device
        head_dim = self.cfg.n_embd // self.cfg.n_head

        # All-True mask over the joint [cond, embed] sequence (the causal tril
        # mask was replaced by this). NOTE(review): blocks still receive
        # is_causal=True; how the two interact is decided by the block
        # implementations -- confirm this is intended.
        attn_mask = torch.ones(s + l, s + l, dtype=torch.bool, device=device)

        # RoPE for the single-stream blocks: positions 0..l-1 per batch item.
        position_ids = torch.arange(l, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).expand(b, -1)
        s_freqs_cis = precompute_freqs_cis(
            dim=head_dim,
            t=position_ids,
            theta=self.cfg.rope_theta,
        )

        # RoPE for the dual-stream blocks: conditioning tokens all get
        # position 0, followed by the main-stream positions.
        position_ids = torch.cat(
            [torch.zeros([b, s], dtype=torch.long, device=device), position_ids],
            dim=1,
        )
        d_freqs_cis = precompute_freqs_cis(
            dim=head_dim,
            t=position_ids,
            theta=self.cfg.rope_theta,
        )

        if kv_cache is not None and decode:
            if curr_pos_id is None:
                raise ValueError(
                    "curr_pos_id is required when decode=True and kv_cache is given"
                )
            embed = embed[:, curr_pos_id, :]

        h = embed
        c = cond
        # Dual-stream positions live after the s conditioning tokens.
        dual_pos_id = curr_pos_id + s if curr_pos_id is not None else None

        num_dual = len(self.transformer.dual_blocks)
        for layer_idx, block in enumerate(self.transformer.dual_blocks):
            h, c = block(
                h,
                c=c,
                freqs_cis=d_freqs_cis,
                attn_mask=attn_mask,
                is_causal=True,
                kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
                curr_pos_id=dual_pos_id,
                decode=decode,
            )
        # Single-stream caches follow the dual-stream ones in `kv_cache`.
        for layer_idx, block in enumerate(
            self.transformer.single_blocks, start=num_dual
        ):
            h = block(
                h,
                freqs_cis=s_freqs_cis,
                attn_mask=None,
                is_causal=True,
                kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
                curr_pos_id=curr_pos_id,
                decode=decode,
            )

        h = self.transformer.ln_f(h)
        return self.ldr_head(h)