NMR / src /models /transformers /llama_ar.py
RayZhao's picture
initial push
4cc0d6c
"""Full definition of a LLaMA Language Model, all of it in this single file.
Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT.
"""
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import Tensor
from typing import Optional
from torch.distributions import Categorical
import torch.nn.functional as F
from mmengine.registry import MODELS
from mmengine.model import BaseModel
from typing import Optional
@dataclass
class LLaMAHFConfig:
block_size: int = 4096
vocab_size: int = 32000
n_layer: int = 32
n_head: int = 32
n_embd: int = 4096
condition_dim: int = 512
@MODELS.register_module()
class LLaMAHF_AR(BaseModel):
def __init__(self,
block_size: int = 4096,
vocab_size: int = 32000,
n_layer: int = 32,
n_head: int = 32,
n_embd: int = 4096,
condition_dim = 512,
**kwargs) -> None:
'''
end_token_idx: vocab size - 2
pad_token_idx: vocab size - 1
'''
super().__init__(**kwargs)
config = LLaMAHFConfig(
block_size=block_size, vocab_size=vocab_size, n_layer=n_layer,
n_head=n_head, n_embd=n_embd, condition_dim=condition_dim)
assert config.vocab_size is not None
assert config.block_size is not None
self.config = config
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.vocab_size, config.n_embd),
h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
ln_f=RMSNorm(config.n_embd),
)
)
@torch.no_grad()
def sample(self, condition_feature, condition_mask, if_categorial=False, sample_cnt=30):
'''
Retuen:
- generated_idx: Tensor, [B, T]
- generated_length: Tensor, int, [B, ], the length before motion_end_idx (inclusive)
'''
# support batch inference
B, L, _ = condition_feature.shape
device = condition_feature.device
generated_idx = []
# right padding -> left padding
flip_condition_mask = condition_mask.flip(1).bool()
position_ids = torch.arange(L, device=device).unsqueeze(0).expand(B, -1) # B, L
condition_feature[flip_condition_mask] = condition_feature[condition_mask.bool()]
position_ids[flip_condition_mask] = position_ids[condition_mask.bool()]
pre_embeddings = condition_feature
past_key_values = [None] * len(self.transformer.h) # type: ignore
for i in range(sample_cnt):
idx, past_key_values = self.forward_sample(
pre_embeddings, flip_condition_mask, past_key_values, position_ids, if_categorial) # B
position_ids = position_ids[:, -1:] + 1
idx = idx.squeeze(-1)
generated_idx.append(idx)
pre_embeddings = self.transformer.wte(idx).unsqueeze(1) # B, 1, C # type: ignore
generated_idx = torch.stack(generated_idx, dim=1) # B, T
return generated_idx
def forward_sample(self, pre_embeddings: Tensor, condition_mask: Tensor,
past_key_values, position_ids, if_categorial=False):
'''
end_flag: Tensor, bool, [B, ]
pre_embeddings: Tensor, [B, L, C]
condition_mask: Tensor, [B, N, C]
'''
x = pre_embeddings.clone()
new_past_key_values = []
for i, block in enumerate(self.transformer.h): # type: ignore
x, present_kv = block(x, condition_mask, past_key_values[i], position_ids)
new_past_key_values.append(present_kv)
x = self.transformer.ln_f(x) # type: ignore
logits = self.lm_head(x)[:, -1, :]
probs = F.softmax(logits, dim=-1)
if if_categorial:
idx = Categorical(probs).sample().unsqueeze(-1)
else:
_, idx = torch.topk(probs, k=1, dim=-1)
return idx, new_past_key_values
def forward(self, idx: Tensor, condition_features, condition_masks): # type: ignore
B, T = idx.size()
assert (T <= self.config.block_size), \
f"Cannot forward sequence of length {T}, block size is only {self.config.block_size}"
# import ipdb; ipdb.set_trace()
# B, T -> B, T, C
x = self.transformer.wte(idx) # type: ignore
condition_length = condition_features.shape[1]
expanded_mask = condition_masks.unsqueeze(-1).expand(-1, -1, x.shape[-1]) # B, L -> B, L, C
result = torch.where(expanded_mask == 1, condition_features, x[:, :condition_length, :])
x = torch.cat((result, x[:, condition_length:]), dim=1)
for block in self.transformer.h: # type: ignore
x, _ = block(x, condition_masks)
x = self.transformer.ln_f(x) # type: ignore
logits = self.lm_head(x) # (b, t, vocab_size)
return logits
class Block(nn.Module):
def __init__(self, config: LLaMAHFConfig) -> None: # , use_qkNorm=False, use_moe=False) -> None:
super().__init__()
self.rms_1 = RMSNorm(config.n_embd)
self.attn = LengthCausalSelfAttention(config) # , use_qkNorm)
self.rms_2 = RMSNorm(config.n_embd)
self.mlp = MLP(config)
def forward(self, x: Tensor, y_mask: Tensor, past_key_values: Optional[Tensor]=None, position_ids: Optional[Tensor]=None):
attn_output, current_key_values = self.attn(self.rms_1(x), y_mask, past_key_values, position_ids)
x = x + attn_output
x = x + self.mlp(self.rms_2(x))
return x, current_key_values
class LengthCausalSelfAttention(nn.Module):
def __init__(self, config: LLaMAHFConfig) -> None: # , use_qkNorm=False) -> None:
super().__init__()
assert config.n_embd % config.n_head == 0
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
# output projection
self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
self.n_head = config.n_head
self.n_embd = config.n_embd
self.block_size = config.block_size
self.rope_cache = None
def forward(self, x: Tensor, y_mask: Tensor, past_key_values: Optional[Tensor], position_ids: Optional[Tensor] = None):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
head_size = C // self.n_head
k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, head_size).transpose(1, 2).contiguous() # (B, nh, T, hs)
if self.rope_cache is None:
# cache for future forward calls
self.rope_cache = build_rope_cache(
seq_len=self.block_size,
n_elem=self.n_embd // self.n_head,
dtype=x.dtype,
device=x.device,
)
q = apply_rope(q, self.rope_cache, position_ids)
k = apply_rope(k, self.rope_cache, position_ids)
if past_key_values is not None:
past_k, past_v = past_key_values
k = torch.cat((past_k, k), dim=2)
v = torch.cat((past_v, v), dim=2)
current_key_values = (k, v)
# create attention mask
_T = k.shape[2]
attn_mask = torch.ones(_T, _T, dtype=torch.bool, device=x.device)
attn_mask = torch.tril(attn_mask)
attn_mask = attn_mask.unsqueeze(0).expand(B, -1, -1)
text_mask = y_mask.unsqueeze(2)*y_mask.unsqueeze(1)
text_mask = F.pad(text_mask, (0, _T-y_mask.shape[1], 0, _T-y_mask.shape[1]), mode='constant', value=0)
attn_mask = torch.logical_or(attn_mask, text_mask)
attn_mask = attn_mask[:, -q.size(2):, :]
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask.unsqueeze(1), dropout_p=0.0, is_causal=False)
y = y.transpose(1, 2).contiguous().view(B, T, C)
y = self.c_proj(y)
return y, current_key_values
class MLP(nn.Module):
def __init__(self, config: LLaMAHFConfig) -> None:
super().__init__()
hidden_dim = 4 * config.n_embd
n_hidden = int(2 * hidden_dim / 3)
N = 256
# ensure n_hidden is multiple of N
n_hidden = ((n_hidden - 1) // N) * N + N
self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)
def forward(self, x: Tensor) -> Tensor:
x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
x = self.c_proj(x)
return x
class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization.
Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
"""
def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
super().__init__()
self.scale = nn.Parameter(torch.ones(size))
self.eps = eps
self.dim = dim
def forward(self, x: Tensor) -> Tensor:
# NOTE: the original RMSNorm paper implementation is not equivalent
# norm_x = x.norm(2, dim=self.dim, keepdim=True)
# rms_x = norm_x * d_x ** (-1. / 2)
# x_normed = x / (rms_x + self.eps)
norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
x_normed = x * torch.rsqrt(norm_x + self.eps)
return self.scale * x_normed
def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000) -> Tensor:
"""Enhanced Transformer with Rotary Position Embedding.
Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
transformers/rope/__init__.py. MIT License:
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
"""
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
# Create position indexes `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
# Calculate the product of position index and $\theta_i$
idx_theta = torch.outer(seq_idx, theta)
# Compute cache. Because polar only takes float32 or float64, we need to cast
# when working with 16 bit floats (float16 or bfloat16)
dtypes_requiring_casting = [torch.float16, torch.bfloat16, torch.int8]
working_dtype = (
torch.float32 if dtype in dtypes_requiring_casting else dtype
)
complex_dtype = (
torch.complex32 if dtype in dtypes_requiring_casting else torch.complex64
)
cache = torch.polar(
torch.ones_like(idx_theta).to(working_dtype), idx_theta.to(working_dtype)
).to(complex_dtype)
return cache
def apply_rope(x: Tensor, rope_cache: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
x = x.transpose(1, 2)
B, T, nh, hs = x.size()
# truncate to support variable sizes
if position_ids is None:
rope_cache = rope_cache[:T] # T, c
rope_cache = rope_cache.unsqueeze(0).unsqueeze(2) # 1, T, 1, c
rope_cache = rope_cache.expand(B, T, nh, -1) # B, T, nh, c
else:
rope_cache = rope_cache[position_ids] # B, T, c
rope_cache = rope_cache.unsqueeze(2).expand(B, T, nh, -1) # B, T, nh, c
# cast because `view_as_complex` does not support 16 bit tensors
xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) # B, T, nh, c/2
x_out = torch.view_as_real(xc * rope_cache).flatten(3)
return x_out.transpose(1, 2).type_as(x)