from typing import Optional, Tuple

import rotary_emb
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import ApplyRotaryEmb as LegacyApplyRotaryEmb
from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
from torch import Tensor, nn

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc

from .muse import VQGANModel
from .utils import gather_forward_split_backward, split_forward_gather_backward


class Embedding1DLVM(nn.Module):
    """
    Embedding layer for visual tokens: looks up codebook entries from a
    pretrained (and optionally frozen) VQGAN model, then optionally projects
    them to the target model dimension.

    Args:
        vq_model_path (str): Path to the pretrained VQGAN model.
        embedding_dim (Optional[int]): Output embedding dimension. If None, the
            VQGAN codebook dimension is used unchanged. None by default.
        freeze_vq_model (bool): Whether to freeze the VQGAN weights. True by default.
    """

    def __init__(
        self,
        vq_model_path: str,
        embedding_dim: Optional[int] = None,
        freeze_vq_model: bool = True,
    ):
        super().__init__()

        self.vq_model = VQGANModel.from_pretrained(vq_model_path)
        if freeze_vq_model:
            self.vq_model.requires_grad_(False)
            self.vq_model.eval()

        self.num_embeddings, vq_embed_dim = self.vq_model.quantize.embedding.weight.shape

        if embedding_dim is not None:
            self.embed_proj = nn.Linear(vq_embed_dim, embedding_dim, bias=False)
            self.embedding_dim = embedding_dim
        else:
            self.embed_proj = None
            self.embedding_dim = vq_embed_dim

    def forward(self, input_: Tensor) -> Tensor:
        # Look up codebook entries for the visual token ids.
        output_parallel = self.vq_model.quantize.get_codebook_entry_for_lvm(input_)

        if self.embed_proj is not None:
            output_parallel = self.embed_proj(output_parallel)

        # Gather the hidden dimension across tensor-parallel ranks.
        output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)

        if gpc.config.parallel.sequence_parallel:
            output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)

        return output
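

# A minimal usage sketch for Embedding1DLVM (documentation only, never called).
# The checkpoint path is hypothetical, the shapes are illustrative, and the
# input is assumed to be a (batch, num_tokens) tensor of codebook ids; the
# exact layout depends on get_codebook_entry_for_lvm.
def _example_embedding1dlvm_usage() -> Tensor:
    embed = Embedding1DLVM("path/to/vqgan-checkpoint", embedding_dim=4096)
    image_token_ids = torch.randint(0, embed.num_embeddings, (2, 256))
    return embed(image_token_ids)  # (2, 256, 4096) after the tensor-parallel gather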


class Embedding1D(nn.Module):
    """
    1D Embedding.

    Args:
        num_embeddings (int): The size of the vocabulary.
        embedding_dim (int): The dimension of the model.
        padding_idx (Optional[int]): If specified, the entries at :attr:`padding_idx` do not contribute
            to the gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated
            during training, i.e. it remains as a fixed "pad". None by default.
        dtype (Optional[torch.dtype]): Data type. None by default.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        *args,
        padding_idx: Optional[int] = None,
        dtype: Optional[torch.dtype] = None,
        **kwargs,
    ):
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embed_dim = embedding_dim
        # Each tensor-parallel rank stores a slice of the hidden dimension.
        embed_dim_per_partition = embedding_dim // gpc.tensor_parallel_size

        self.padding_idx = padding_idx
        self.embed_args = args
        self.embed_kwargs = kwargs

        self.weight = nn.Parameter(torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype))

    def forward(self, input_: Tensor) -> Tensor:
        output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

        # Gather the partial embeddings back to the full hidden dimension.
        output = gather_forward_split_backward(output_parallel, ParallelMode.TENSOR, dim=-1)

        if gpc.config.parallel.sequence_parallel:
            output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1)

        return output
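

# A minimal sketch of how Embedding1D shards the hidden dimension across
# tensor-parallel ranks (documentation only, never called; the vocabulary and
# hidden sizes are illustrative, and gpc is assumed to be initialized).
def _example_embedding1d_partitioning() -> None:
    # With embedding_dim=4096 and tensor_parallel_size=8, each rank holds a
    # (vocab_size, 512) weight shard; forward() gathers the shards back to 4096.
    embed = Embedding1D(num_embeddings=50304, embedding_dim=4096)
    assert embed.weight.shape == (50304, 4096 // gpc.tensor_parallel_size)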


class ApplyRotaryEmbQKV_(torch.autograd.Function):
    """
    Apply rotary position embedding inplace to the query and key slices of a
    packed qkv tensor.
    """

    @staticmethod
    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
        """
        qkv: (total, 3, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
        rotary_dim must be <= headdim
        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
        """
        _, three, _, headdim = qkv.shape
        assert three == 3
        rotary_seqlen, rotary_dim = cos.shape
        rotary_dim *= 2
        assert rotary_dim <= headdim
        cos_k = cos if cos_k is None else cos_k
        sin_k = sin if sin_k is None else sin_k
        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
        q1, q2 = qkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
        rotary_emb.apply_rotary(q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), q1, q2, False)
        k1, k2 = qkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
        rotary_emb.apply_rotary(
            k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), k1, k2, False
        )
        ctx.save_for_backward(cos, sin, cos_k, sin_k)
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        cos, sin, cos_k, sin_k = ctx.saved_tensors
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        dq1, dq2 = dqkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
        # The trailing True selects the conjugate rotation, which inverts the forward rotation.
        rotary_emb.apply_rotary(
            dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), dq1, dq2, True
        )
        dk1, dk2 = dqkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
        rotary_emb.apply_rotary(
            dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), dk1, dk2, True
        )
        return dqkv, None, None, None, None


apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
legacy_apply_rotary_embed = LegacyApplyRotaryEmb.apply
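

# For reference, a pure-PyTorch sketch of the rotation that the fused
# `rotary_emb.apply_rotary` kernel applies above (an out-of-place sketch for
# documentation only; the kernel writes its outputs in place on CUDA tensors).
def _reference_apply_rotary(
    x1: Tensor, x2: Tensor, cos: Tensor, sin: Tensor, conjugate: bool = False
) -> Tuple[Tensor, Tensor]:
    # cos/sin broadcast as (seqlen, 1, rotary_dim // 2) over the heads dimension.
    if not conjugate:
        return x1 * cos - x2 * sin, x1 * sin + x2 * cos
    # The conjugate rotation (used in backward) undoes the forward rotation.
    return x1 * cos + x2 * sin, x2 * cos - x1 * sin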


class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et al.).
    A crucial insight from the method is that the query and keys are
    transformed by rotation matrices which depend on the relative positions.

    Other implementations are available in the Rotary Transformer repo_ and in
    GPT-NeoX_; GPT-NeoX was an inspiration for this implementation.

    .. _RoFormer: https://arxiv.org/abs/2104.09864
    .. _repo: https://github.com/ZhuiyiTechnology/roformer
    .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

    If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
    A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(self, dim: int, base=10000, scale_base=0, device=None):
        super().__init__()

        # Inverse frequencies for each pair of rotary dimensions.
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        self.scale_base = scale_base
        # Per-dimension XPos decay factors; None disables XPos scaling.
        self.scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
            if scale_base > 0
            else None
        )

        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None
        self._cos_k_cached = None
        self._sin_k_cached = None

    def _update_cos_sin_cache(self, x, indexes):
        """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
        if not isinstance(indexes, int):
            seqlen = indexes.max().item() + 1
        else:
            seqlen = indexes + 1

        # Regenerate the cache if we're growing past the cached sequence length,
        # or if the device or dtype of x has changed.
        if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)

            # Outer product of positions and inverse frequencies: (seqlen, dim / 2).
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(x.dtype)
                self._sin_cached = torch.sin(freqs).to(x.dtype)
            else:
                # XPos: queries are scaled by scale**power and keys by its inverse.
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")

                self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)

    def forward(self, qkv: torch.Tensor, **kwargs):
        if kwargs.get("indexes", None) is not None:
            return self._forward(qkv, kwargs.pop("indexes"))
        if kwargs.get("inference_params", None) is not None:
            return self._eval_forward(qkv, seqlen_offset=kwargs["inference_params"].sequence_len_offset)
        return self._eval_forward(qkv)

    def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
        self._update_cos_sin_cache(qkv, indexes)
        if self.scale is None:
            return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes])
        else:
            return apply_rotary_emb_qkv_(
                qkv,
                self._cos_cached[indexes],
                self._sin_cached[indexes],
                self._cos_k_cached[indexes],
                self._sin_k_cached[indexes],
            )

    def _eval_forward(self, qkv, seqlen_offset=0):
        """
        seqlen_offset: can be used in generation where the qkv being passed in is only the last
        token in the batch.
        """
        self._update_cos_sin_cache(qkv, seqlen_offset + qkv.shape[1])
        if self.scale is None:
            return legacy_apply_rotary_embed_qkv(
                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]
            )
        else:
            return legacy_apply_rotary_embed_qkv(
                qkv,
                self._cos_cached[seqlen_offset:],
                self._sin_cached[seqlen_offset:],
                self._cos_k_cached[seqlen_offset:],
                self._sin_k_cached[seqlen_offset:],
            )

    def _single_forward(self, x, indexes=0):
        assert self.scale is None
        self._update_cos_sin_cache(x, indexes)
        # Add the batch dim expected by the legacy kernel, then squeeze it back out.
        x = x[None, ...]
        return legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)

    def _single_eval_forward(self, x, seqlen_offset=0):
        assert self.scale is None
        self._update_cos_sin_cache(x, seqlen_offset + x.shape[1])
        return legacy_apply_rotary_embed(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
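

# A minimal usage sketch for RotaryEmbedding on a packed qkv tensor
# (documentation only, never called; shapes, dtype, and device are illustrative,
# and `indexes` holds each token's position id).
def _example_rotary_embedding_usage() -> torch.Tensor:
    rope = RotaryEmbedding(dim=128)
    # (total_tokens, 3, nheads, headdim), as expected by the fused rotary kernel.
    qkv = torch.randn(1024, 3, 32, 128, device="cuda", dtype=torch.float16)
    return rope(qkv, indexes=torch.arange(1024, device="cuda"))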