import math
from typing import Optional

import torch

from .._utils import set_obj_attrs
from ..functional import embedding, unsqueeze, where
from ..mapping import Mapping
from ..module import Module
from ..parameter import Parameter


class Embedding(Module):
    """
    The embedding layer takes input indices (x) and the embedding lookup
    table (weight) as input, and outputs the embeddings that correspond to
    the input indices. The shape of weight is [num_embeddings, embedding_dim].

    Four parameters (tp_size, tp_group, sharding_dim, tp_rank) control tensor
    parallelism. Tensor parallelism is enabled only when tp_size > 1 and
    tp_group is not None. When sharding_dim == 0, the weight is sharded along
    the vocabulary dimension, and tp_rank must be set. When sharding_dim == 1,
    the weight is sharded along the hidden dimension.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 dtype: Optional[str] = None,
                 tp_size: int = 1,
                 tp_group: Optional[list] = None,
                 sharding_dim: int = 0,
                 tp_rank: Optional[int] = None):
        super().__init__()

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.tp_size = tp_size
        self.tp_group = tp_group
        self.sharding_dim = sharding_dim
        self.tp_rank = tp_rank
        self.dtype = dtype

        # Shard the lookup table: along the hidden dimension when
        # sharding_dim == 1, or along the vocabulary dimension when
        # sharding_dim == 0 (padded up with math.ceil so every rank
        # allocates an equally sized shard).
        if sharding_dim == 1:
            self.weight = Parameter(shape=(self.num_embeddings,
                                           self.embedding_dim // self.tp_size),
                                    dtype=dtype)
        elif sharding_dim == 0:
            self.weight = Parameter(shape=(math.ceil(
                self.num_embeddings / self.tp_size), self.embedding_dim),
                                    dtype=dtype)

        set_obj_attrs(self.weight, {
            "weight_loader": self.weight_loader,
        })
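
    # Worked shard-shape example (illustrative numbers, not part of this
    # module): with num_embeddings=50257, embedding_dim=4096 and tp_size=4,
    # sharding_dim=0 yields a per-rank shard of shape
    # [math.ceil(50257 / 4), 4096] = [12565, 4096], while sharding_dim=1
    # yields [50257, 4096 // 4] = [50257, 1024].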

    def forward(self, x):
        return embedding(x,
                         self.weight.value,
                         tp_size=self.tp_size,
                         tp_group=self.tp_group,
                         sharding_dim=self.sharding_dim,
                         tp_rank=self.tp_rank)

    def weight_loader(self, mapping: Mapping, param: Parameter,
                      loaded_weight: torch.Tensor):
        # Narrow the full checkpoint tensor down to this rank's shard along
        # the sharding dimension before assigning it to the local parameter.
        tp_rank = mapping.tp_rank
        if self.tp_size > 1:
            sharding_dim = self.sharding_dim
            shard_size = param._shape[sharding_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(sharding_dim, start_idx,
                                                 shard_size)
        param.value = loaded_weight
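
    # A minimal sketch of the loader's arithmetic (hypothetical values): for
    # a checkpoint tensor of shape [12, 8] with tp_size=2, sharding_dim=0 and
    # tp_rank=1, param._shape[0] is math.ceil(12 / 2) == 6, so the loader
    # keeps rows 6..11 via loaded_weight.narrow(0, 6, 6).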


class PromptTuningEmbedding(Embedding):
    """
    PromptTuningEmbedding handles fine-tuned prompts with virtual tokens. At
    runtime, a supplementary embedding dictionary is passed. Tokens whose ids
    are >= vocab_size are embedded with that additional dictionary.
    The prompt tuning dictionary holds multiple tasks, and each sequence is
    assigned a given task. Prompt-tuned tokens from a given sequence use the
    task dictionary they are assigned, as defined by the `tasks` input.
    """

    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 vocab_size=None,
                 dtype=None,
                 tp_size=1,
                 tp_group=None,
                 sharding_dim=0,
                 tp_rank=0):
        super().__init__(num_embeddings, embedding_dim, dtype, tp_size,
                         tp_group, sharding_dim, tp_rank)
        if vocab_size is None:
            vocab_size = num_embeddings
        self.vocab_size = vocab_size

    def forward(self, tokens, prompt_embedding_table, tasks, task_vocab_size):
        """
        Passes all tokens through both the normal and the prompt embedding
        tables. Tokens are masked so that the "normal" embedding only sees
        "normal" tokens, and likewise for the "prompt" embedding. After the
        two lookups, the results are combined based on whether each token
        was "normal" or "prompt-tuned".

        Parameters:
            tokens : Tensor
                the ids to embed, size [batch_size, seq_len]

            prompt_embedding_table : Tensor
                the additional embedding table for prompt-tuned tokens,
                size [num_tasks * num_tokens_per_task, hidden_size]

            tasks : Tensor
                the task required by each token, size [batch_size, seq_len]

            task_vocab_size : Tensor
                the number of tokens used for each task, should be equal to
                prompt_embedding_table's num_tokens_per_task, size [1]

        Returns:
            Tokens' embeddings
        """
        # True for ids beyond the normal vocabulary, i.e. prompt-tuned tokens.
        prompt_tokens_mask = tokens > (self.vocab_size - 1)

        # Clamp prompt-tuned ids to a valid index so the normal lookup stays
        # in bounds; those positions are overwritten by the final where().
        normal_tokens = where(prompt_tokens_mask, self.vocab_size - 1, tokens)
        normal_embeddings = embedding(normal_tokens, self.weight.value,
                                      self.tp_size, self.tp_group,
                                      self.sharding_dim, self.tp_rank)

        # Shift prompt-tuned ids into [0, task_vocab_size) and zero out the
        # normal tokens.
        prompt_tokens = where(prompt_tokens_mask, tokens - self.vocab_size, 0)

        # Offset each token by its task's base row in the flattened
        # [num_tasks * num_tokens_per_task, hidden_size] table, then look up.
        tasks = tasks * task_vocab_size
        prompt_tokens = prompt_tokens + tasks
        prompt_embeddings = embedding(prompt_tokens, prompt_embedding_table)

        # Per token, select the prompt embedding or the normal embedding.
        return where(unsqueeze(prompt_tokens_mask, -1), prompt_embeddings,
                     normal_embeddings)
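

# A worked example of the forward() arithmetic above (illustrative numbers,
# not part of this module): with vocab_size=32000 and task_vocab_size=8, a
# token id of 32005 assigned to task 2 is masked as prompt-tuned, shifted to
# 32005 - 32000 = 5, offset by 2 * 8 = 16, and thus looks up row 21 of
# prompt_embedding_table; a token id of 1234 keeps its normal embedding.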