import enum
from typing import Dict, Optional

import torch
import torch.nn.init as init
from torch import nn

from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import fused_bias_gelu
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults, init_method_normal
from nemo.core.classes import Exportable, NeuralModule
from nemo.core.classes.common import typecheck
from nemo.core.neural_types import ChannelType, NeuralType

try:
    from apex.transformer import parallel_state, tensor_parallel

    HAVE_APEX = True

except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False

    # Fake the missing apex classes so the names still resolve when apex is absent.
    ModelType = AttnMaskType = AttnType = LayerType = ApexGuardDefaults()


__all__ = ["PromptEncoder", "PromptEncoderType"]


class PromptEncoderType(enum.Enum):
    TPMLP = "tpmlp"
    MLP = "mlp"
    LSTM = "lstm"
    EMBEDDING = "embedding"


class PromptEmbedding(NeuralModule, Exportable):
    """Prompt embeddings

    Args:
        hidden_size: hidden size, which should match the LM embedding size
        total_virtual_tokens: number of virtual tokens, i.e. the length of the prompt
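
    Example (an illustrative sketch; the 768-dim hidden size is an assumed
    value, not anything prescribed by this module):

        >>> prompt_emb = PromptEmbedding(hidden_size=768, total_virtual_tokens=10)
        >>> prompt_emb.set_prompt_embedding_weights(torch.randn(10, 768))
        >>> prompt_emb().shape
        torch.Size([10, 768])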
| | """ |
| |
|
| | def __init__( |
| | self, hidden_size, total_virtual_tokens, |
| | ): |
| | super().__init__() |
| |
|
| | self.hidden_size = hidden_size |
| | self.total_virtual_tokens = total_virtual_tokens |
| |
|
| | |
| | self.prompt_embeddings = torch.nn.Embedding(self.total_virtual_tokens, self.hidden_size) |
| | self.prompt_embeddings.weight.data.fill_(0.0) |
| | self.prompt_embeddings.weight.requires_grad = False |
| |
|
| | |
| | self.register_buffer('indices', torch.LongTensor(list(range(self.total_virtual_tokens)))) |
| |
|
| | def clear_prompt_embedding_weights(self,): |
| | """ |
| | Method sets the prompt embedding weights to 0.0 |
| | """ |
| | self.prompt_embeddings.weight.fill_(0.0) |
| |
|
| | def set_prompt_embedding_weights(self, weight: torch.Tensor): |
| | """ |
| | Method sets the prompt embedding weights with a new weight w |
| | """ |
| | self.prompt_embeddings.weight.data = weight.type_as(self.prompt_embeddings.weight.data) |
| |
|
| | def forward(self,): |
| | """ |
| | Does forward pass |
| | """ |
| | return self.prompt_embeddings(self.indices) |


class InferenceTable(NeuralModule, Exportable):
    """
    A wrapper class that holds the output representations of the PromptEncoder model.
    At inference time we do not need to run a forward pass through the full PromptEncoder
    and can reuse the representations cached here instead.
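
    Example (illustrative; the task name and 768-dim size are assumed values):

        >>> table = InferenceTable("taskname", hidden_size=768, total_virtual_tokens=10)
        >>> table.set_prompt_table(torch.randn(10, 768))
        >>> table.is_inference_ready
        True
        >>> table.get_prompt_table().shape
        torch.Size([10, 768])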
| | """ |
| |
|
| | def __init__(self, taskname, hidden_size, total_virtual_tokens, is_inference_ready=False): |
| | super().__init__() |
| | self.taskname = taskname |
| | self.hidden_size = hidden_size |
| | self.total_virtual_tokens = total_virtual_tokens |
| | self.prompt_table = torch.nn.ModuleDict() |
| | self.prompt_table[self.taskname] = PromptEmbedding(self.hidden_size, self.total_virtual_tokens) |
| | self.prompt_table[self.taskname].prompt_embeddings.weight.requires_grad = False |
| | self.prompt_table[self.taskname].clear_prompt_embedding_weights() |
| | self.is_inference_ready = is_inference_ready |
| |
|
| | def set_prompt_table(self, prompt_representation: torch.Tensor): |
| | """ |
| | Method sets the prompt embedding inside self.prompt_table[taskname] with new weights |
| | """ |
| | self.prompt_table[self.taskname].set_prompt_embedding_weights(prompt_representation) |
| | self.is_inference_ready = True |
| |
|
| | def get_prompt_table(self,): |
| | """ |
| | Returns the prompt representation cached in the prompt table |
| | """ |
| | return self.prompt_table[self.taskname].forward() |
| |
|
| | def clear_prompt_table(self,): |
| | """ |
| | Method "clears" the prompt embedding inside self.prompt_table[taskname] by setting it to zero. |
| | """ |
| | self.prompt_table[self.taskname].clear_prompt_embedding_weights() |
| | self.is_inference_ready = False |


class TPMLP(NeuralModule, Exportable):
    """
    The tensor-parallel MLP prompt encoder network that is used to generate the virtual
    token embeddings for p-tuning. It only has two layers.
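
    Example (illustrative sketch; requires apex with an already-initialized
    tensor model parallel state, and all sizes below are assumed values):

        >>> tpmlp = TPMLP(total_virtual_tokens=10, hidden_size=2048, output_size=768, init_std=0.023)
        >>> tpmlp(torch.randn(1, 10, 768)).shape
        torch.Size([1, 10, 768])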
| | """ |
| |
|
| | def __init__( |
| | self, total_virtual_tokens: int, hidden_size: int, output_size: int, init_std: float, |
| | ): |
| | """ |
| | Initializes the Tensor Model parallel MLP PromptEncoderMLP module. |
| | Args: |
| | total_virtual_tokens: the total number of vitural tokens |
| | hidden_size: hidden dimension |
| | output_size: the output dimension |
| | init_std: the MLP init std value |
| | """ |
| | super().__init__() |
| | self.hidden_size = hidden_size |
| | self.output_size = output_size |
| | self.total_virtual_tokens = total_virtual_tokens |
| | self.activation = "gelu" |
| |
|
| | sequence_parallel = False |
| | gradient_accumulation_fusion = False |
| | no_async_tensor_model_parallel_allreduce = ( |
| | parallel_state.get_tensor_model_parallel_world_size() == 1 or sequence_parallel |
| | ) |
| | self.first = tensor_parallel.ColumnParallelLinear( |
| | self.output_size, |
| | self.hidden_size, |
| | gather_output=False, |
| | init_method=init_method_normal(init_std), |
| | skip_bias_add=True, |
| | use_cpu_initialization=False, |
| | bias=True, |
| | sequence_parallel_enabled=sequence_parallel, |
| | no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce, |
| | gradient_accumulation_fusion=gradient_accumulation_fusion, |
| | ) |
| | self.second = tensor_parallel.RowParallelLinear( |
| | self.hidden_size, |
| | self.output_size, |
| | input_is_parallel=True, |
| | init_method=init_method_normal(init_std), |
| | skip_bias_add=True, |
| | use_cpu_initialization=False, |
| | bias=True, |
| | sequence_parallel_enabled=sequence_parallel, |
| | gradient_accumulation_fusion=gradient_accumulation_fusion, |
| | ) |
| |
|
| | def forward(self, input_embeds) -> torch.Tensor: |
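        """Applies the column-parallel projection with a fused bias-GeLU,
        followed by the row-parallel projection whose deferred bias is added
        back in (both layers are built with skip_bias_add=True)."""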
        intermediate_parallel, bias_parallel = self.first(input_embeds)
        # skip_bias_add=True defers the bias, so fuse it into the GeLU here.
        intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel)
        output_embeds, bias_parallel = self.second(intermediate_parallel)
        output_embeds = output_embeds + bias_parallel
        return output_embeds


class PromptEncoder(NeuralModule, Exportable):
    """
    The prompt encoder network that is used to generate the virtual
    token embeddings for p-tuning.
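
    Example (illustrative; every size, and the choice of the MLP encoder, is an
    assumed value rather than a prescribed default):

        >>> encoder = PromptEncoder(
        ...     encoder_type=PromptEncoderType.MLP,
        ...     total_virtual_tokens=10,
        ...     token_dim=768,
        ...     hidden_size=2048,
        ...     lstm_dropout=0.0,
        ...     num_layers=2,
        ...     init_std=0.023,
        ... )
        >>> encoder(batch_size=4, use_cached_reps=False).shape
        torch.Size([4, 10, 768])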
| | """ |
| |
|
| | def __init__( |
| | self, |
| | encoder_type: enum, |
| | total_virtual_tokens: int, |
| | token_dim: int, |
| | hidden_size, |
| | lstm_dropout: float, |
| | num_layers: int, |
| | init_std: float, |
| | taskname: str = "taskname", |
| | ): |
| | """ |
| | Initializes the PromptEncoder module. |
| | Args: |
| | total_virtual_tokens: the total number of vitural tokens |
| | hidden_size: hidden dimension |
| | lstm_dropout: the dropout used for the LSTM |
| | num_layers: number of layers used in the LSTM |
| | init_std: used for TPMLP encoder type to initialize the mlp weights |
| | """ |
| | super().__init__() |
| | self.token_dim = token_dim |
| | self.input_size = token_dim |
| | self.output_size = token_dim |
| | self.hidden_size = hidden_size |
| | self.total_virtual_tokens = total_virtual_tokens |
| | self.encoder_type = encoder_type |
| | self.activation = "gelu" |
| | self.init_std = init_std |
| | self.taskname = taskname |
| |
|
| | |
| | self.register_buffer("indices", torch.LongTensor(list(range(self.total_virtual_tokens)))) |
| |
|
| | |
| | self.embedding = torch.nn.Embedding(self.total_virtual_tokens, self.token_dim) |
| | self.inference_table = InferenceTable(taskname, self.token_dim, self.total_virtual_tokens) |
| |
|
| | if self.encoder_type == PromptEncoderType.EMBEDDING: |
| | init.xavier_normal_(self.embedding.weight) |
| | elif self.encoder_type == PromptEncoderType.LSTM: |
| | |
| | self.lstm_head = torch.nn.LSTM( |
| | input_size=self.input_size, |
| | hidden_size=self.hidden_size, |
| | num_layers=num_layers, |
| | dropout=lstm_dropout, |
| | bidirectional=True, |
| | batch_first=True, |
| | ) |
| |
|
| | self.mlp_head = nn.Sequential( |
| | nn.Linear(self.hidden_size * 2, self.hidden_size * 2), |
| | nn.ReLU(), |
| | nn.Linear(self.hidden_size * 2, self.output_size), |
| | ) |
| |
|
| | elif self.encoder_type == PromptEncoderType.MLP: |
| | if num_layers <= 1: |
| | raise ValueError( |
| | "The MLP prompt encoder must have at least 2 layers, and exactly 2 layers is recommended." |
| | ) |
| |
|
| | layers = [nn.Linear(self.input_size, self.hidden_size), nn.ReLU()] |
| | for _ in range(num_layers - 2): |
| | layers.extend([nn.Linear(self.hidden_size, self.hidden_size), nn.ReLU()]) |
| |
|
| | layers.append(nn.Linear(self.hidden_size, self.output_size)) |
| | self.mlp_head = nn.Sequential(*layers) |
| |
|
| | elif self.encoder_type == PromptEncoderType.TPMLP: |
| | self.tpmlp = TPMLP(self.total_virtual_tokens, self.hidden_size, self.output_size, self.init_std,) |
| | else: |
| | raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.") |

    def set_inference_table(self, prompt_representation: torch.Tensor):
        """
        This method caches the output representation from the encoder and saves it inside `self.inference_table`.
        """
        prompt_representation = prompt_representation.detach().clone()
        self.inference_table.set_prompt_table(prompt_representation)

    def clear_inference_table(self,):
        self.inference_table.clear_prompt_table()

    def get_inference_table(self,):
        return self.inference_table.get_prompt_table()

    def state_dict(self, destination=None, prefix='', keep_vars=False):
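        """Returns a custom state dict keyed by submodule ('prompt_table',
        'embeddings', and the head(s) of the active encoder type)."""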
        _state_dict = {}
        _state_dict['prompt_table'] = self.inference_table.state_dict()
        _state_dict['embeddings'] = self.embedding.state_dict()
        if self.encoder_type == PromptEncoderType.EMBEDDING:
            pass
        elif self.encoder_type == PromptEncoderType.LSTM:
            _state_dict['mlp_head'] = self.mlp_head.state_dict()
            _state_dict['lstm_head'] = self.lstm_head.state_dict()
        elif self.encoder_type == PromptEncoderType.MLP:
            _state_dict['mlp_head'] = self.mlp_head.state_dict()
        elif self.encoder_type == PromptEncoderType.TPMLP:
            _state_dict['tpmlp'] = self.tpmlp.state_dict()
        else:
            raise ValueError(
                "Prompt encoder type not recognized. Please use one of TPMLP, MLP (recommended), LSTM, or EMBEDDING."
            )
        return _state_dict

    def load_state_dict(self, state_dict, strict=True):
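        """Loads a state dict produced by the custom `state_dict` above,
        restoring only the submodules used by the active encoder type."""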
        self.inference_table.load_state_dict(state_dict['prompt_table'])
        self.embedding.load_state_dict(state_dict['embeddings'])
        if self.encoder_type == PromptEncoderType.EMBEDDING:
            pass
        elif self.encoder_type == PromptEncoderType.LSTM:
            self.mlp_head.load_state_dict(state_dict['mlp_head'])
            self.lstm_head.load_state_dict(state_dict['lstm_head'])
        elif self.encoder_type == PromptEncoderType.MLP:
            self.mlp_head.load_state_dict(state_dict['mlp_head'])
        elif self.encoder_type == PromptEncoderType.TPMLP:
            self.tpmlp.load_state_dict(state_dict['tpmlp'])
        else:
            raise ValueError(
                "Prompt encoder type not recognized. Please use one of TPMLP, MLP (recommended), LSTM, or EMBEDDING."
            )
        return

    def _forward(self,):
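        """Encodes the virtual token embeddings through the selected encoder head."""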
        input_embeds = self.embedding(self.indices).unsqueeze(0)
        if self.encoder_type == PromptEncoderType.EMBEDDING:
            output_embeds = input_embeds
        elif self.encoder_type == PromptEncoderType.LSTM:
            output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0])
        elif self.encoder_type == PromptEncoderType.MLP:
            output_embeds = self.mlp_head(input_embeds)
        elif self.encoder_type == PromptEncoderType.TPMLP:
            output_embeds = self.tpmlp(input_embeds)
        else:
            raise ValueError(
                "Prompt encoder type not recognized. Please use one of TPMLP, MLP (recommended), LSTM, or EMBEDDING."
            )
        return output_embeds

    @typecheck()
    def forward(self, batch_size: int, use_cached_reps: bool) -> torch.Tensor:
        """
        Forward pass through the encoder, with caching of prompt representations:
        in training mode any stale cache is cleared and a fresh pass is run, while
        in eval mode the representations are computed once and then reused.
        """
        if use_cached_reps:
            # Serve the cached representations directly.
            output_embeds = self.get_inference_table().unsqueeze(0)
        else:
            if self.training:
                # Training updates the encoder, so any cached representations are stale.
                if self.inference_table.is_inference_ready:
                    self.clear_inference_table()
                output_embeds = self._forward()
            else:
                # Eval mode: populate the cache once, then always read from it.
                if not self.inference_table.is_inference_ready:
                    output_embeds = self._forward()
                    self.set_inference_table(output_embeds.squeeze(0))
                output_embeds = self.get_inference_table().unsqueeze(0)

        output_embeds = output_embeds.expand(batch_size, self.total_virtual_tokens, self.token_dim)
        return output_embeds
|