|
|
| from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer |
| import torch |
| from transformers import FalconH1Config, FalconH1ForCausalLM, FalconH1Model |
| from openrlhf.moe_utils import FalconH1MoEConfig |
| from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1DecoderLayer, FalconH1MLP, compute_mup_vector |
| from torch import nn |
| import random |
| import numpy as np |
| import torch.nn.functional as F |
|
|
class FalconH1MoEModel(FalconH1Model):
    """FalconH1 backbone whose dense decoder layers are replaced by MoE layers."""

    def __init__(self, config: FalconH1MoEConfig):
        super().__init__(config)
        # Swap every decoder layer built by the parent for the MoE variant.
        self.layers = nn.ModuleList(
            FalconH1MoEDecoderLayer(config, layer_idx=idx)
            for idx in range(config.num_hidden_layers)
        )
        # Re-attach the mup vector buffer to the freshly created layers
        # (the buffers registered by the parent lived on the discarded layers).
        mup_vector = compute_mup_vector(config)
        for layer in self.layers:
            layer.mamba.register_buffer("mup_vector", mup_vector, persistent=False)
|
|
| |
|
|
class FalconH1MoEMLP(nn.Module):
    """Mixture-of-Experts feed-forward block for FalconH1.

    Routes each token to its ``topk`` highest-scoring experts (out of
    ``config.expert_num`` dense ``FalconH1MLP`` experts) and combines the
    expert outputs weighted by the renormalized router probabilities.
    """

    def __init__(self, config: "FalconH1MoEConfig"):
        super().__init__()
        self.config = config
        self.num_local_experts = config.expert_num
        self.topk = config.topk

        # Experts: each one is a standard dense FalconH1 MLP.
        self.experts = torch.nn.ModuleList(
            FalconH1MLP(config) for _ in range(self.num_local_experts)
        )

        # Router: linear projection hidden -> expert logits.
        # Stored in fp32; cast down to the activation dtype at use time.
        self.weight = torch.nn.Parameter(
            torch.empty((self.num_local_experts, self.config.hidden_size), dtype=torch.float32)
        )
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        """Apply top-k MoE routing, expert MLPs, and weighted recombination.

        Args:
            x: 3-D hidden states; the first two dims are swapped internally,
               the last dim is ``hidden_size`` (assumes (batch, seq, hidden)
               — TODO confirm against caller).

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x.transpose(0, 1).contiguous()
        inp_shape = x.shape
        num_tokens = inp_shape[0] * inp_shape[1]
        hidden = inp_shape[-1]
        num_experts = self.num_local_experts
        # Flatten to token-major [num_tokens, hidden].
        x = x.view(-1, inp_shape[-1])
        restore_shape = x.shape

        # ---- Routing: per-token expert logits -> top-k probabilities. ----
        # NOTE(review): logits are computed in the activation dtype, not fp32.
        y = torch.mm(x, self.weight.to(x.dtype).t())
        scores, top_indices = torch.topk(y, k=self.topk, dim=1)
        # Softmax over the k selected experts only, so the kept
        # probabilities sum to 1 per token.
        probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(y)

        # Dense [num_tokens, num_experts] views of the sparse decision, e.g.
        #   routing_probs: [[0.0, 0.0, 0.4, 0.6], ...]  (0 where unselected)
        #   routing_map:   [[False, False, True, True], ...]
        routing_probs = torch.zeros_like(y).scatter(1, top_indices, probs)
        routing_map = torch.zeros_like(y).int().scatter(1, top_indices, 1).bool()

        # ---- Dispatch: group (duplicated) tokens by expert. ----
        num_local_tokens_per_expert = routing_map.sum(dim=0).long()
        # Transpose to expert-major [num_experts, num_tokens].
        routing_map = routing_map.T.contiguous()
        token_indices = (
            torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
        )
        # Token ids in expert-major order: expert 0's tokens first, then
        # expert 1's, etc. (masked_select flattens row by row).
        sorted_indices = token_indices.masked_select(routing_map)
        # Matching per-(expert, token) probabilities, in the same order.
        probs = routing_probs.T.contiguous().masked_select(routing_map)
        # Gather the token activations into expert-major order; a token
        # routed to k experts appears k times.
        x = x.index_select(0, sorted_indices)

        # ---- Expert computation: run each expert on its token slice. ----
        tokens_list = torch.split(x, num_local_tokens_per_expert.tolist())
        probs_list = torch.split(probs, num_local_tokens_per_expert.tolist())
        output_local_list = [
            expert(tokens) * prob.unsqueeze(-1)
            for expert, tokens, prob in zip(self.experts, tokens_list, probs_list)
        ]
        permuted_tokens = torch.cat(output_local_list, dim=0)

        # ---- Combine: scatter-add weighted outputs back to token order. ----
        output_tokens = torch.zeros(
            restore_shape, dtype=permuted_tokens.dtype, device=permuted_tokens.device
        )
        output_tokens.scatter_add_(0, sorted_indices.unsqueeze(1).expand(-1, hidden), permuted_tokens)
        output = output_tokens.view(inp_shape).transpose(0, 1)
        return output
| |
|
|
|
|
class FalconH1MoEDecoderLayer(FalconH1DecoderLayer):
    """FalconH1 decoder layer whose feed-forward network is a mixture of experts."""

    def __init__(self, config: FalconH1MoEConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Replace the dense MLP installed by the parent with the MoE variant.
        self.feed_forward = FalconH1MoEMLP(config)
|
|
|
|
|
|
|
|
|
|
class FalconH1MoEForCausalLM(FalconH1ForCausalLM):
    """Causal-LM wrapper around the MoE FalconH1 backbone."""

    def __init__(self, config: FalconH1MoEConfig):
        super().__init__(config)
        # Swap the dense backbone built by the parent for the MoE one.
        self.model = FalconH1MoEModel(config)


__all__ = ["FalconH1MoEForCausalLM"]
|
|