# falcon_moe_1.5Base_test / modeling_falcon_h1_moe.py
# Uploaded by wmere: "Upload FalconH1MoEForCausalLM" (commit 4d089ce, verified)
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import torch
from transformers import FalconH1Config, FalconH1ForCausalLM, FalconH1Model
from openrlhf.moe_utils import FalconH1MoEConfig
from transformers.models.falcon_h1.modeling_falcon_h1 import FalconH1DecoderLayer, FalconH1MLP, compute_mup_vector
from torch import nn
import random
import numpy as np
import torch.nn.functional as F
class FalconH1MoEModel(FalconH1Model):
    """FalconH1 backbone whose decoder layers use MoE feed-forward blocks.

    Rebuilds the parent's layer stack with ``FalconH1MoEDecoderLayer`` and
    re-attaches the muP scaling vector to each layer's mamba mixer (the
    parent registered it on the layers it originally built).
    """

    def __init__(self, config: FalconH1MoEConfig):
        super().__init__(config)
        # Replace the dense decoder layers created by the parent __init__.
        self.layers = nn.ModuleList(
            FalconH1MoEDecoderLayer(config, layer_idx=idx)
            for idx in range(config.num_hidden_layers)
        )
        # Non-persistent buffer: recomputed from config, never serialized.
        mup_vector = compute_mup_vector(config)
        for decoder_layer in self.layers:
            decoder_layer.mamba.register_buffer("mup_vector", mup_vector, persistent=False)
class FalconH1MoEMLP(nn.Module):
    """Mixture-of-Experts feed-forward block with top-k token-choice routing.

    Each token is routed to ``config.topk`` of ``config.expert_num``
    ``FalconH1MLP`` experts. The chosen experts' outputs are weighted by the
    softmax over the selected top-k router logits and summed back per token.

    Cleanups vs. the original: removed unused ``permuted_probs`` /
    ``num_out_tokens`` locals, a redundant ``.bool()`` on an already-boolean
    tensor, and a commented-out alternative routing path; translated the
    remaining non-English comment.
    """

    def __init__(self, config: FalconH1MoEConfig):
        super().__init__()
        self.config = config
        self.num_local_experts = config.expert_num
        self.topk = config.topk
        # Independent FalconH1MLP experts (attribute names kept so existing
        # checkpoints' state-dict keys still match).
        self.experts = torch.nn.ModuleList(
            FalconH1MLP(config) for _ in range(self.num_local_experts)
        )
        # Router weight, [num_experts, hidden]; stored in float32.
        self.weight = torch.nn.Parameter(
            torch.empty((self.num_local_experts, self.config.hidden_size), dtype=torch.float32)
        )
        torch.nn.init.xavier_uniform_(self.weight)  # xavier_normal_ would also be reasonable

    def forward(self, x):
        """Route, dispatch, run experts, and combine.

        Args:
            x: hidden states, transposed on entry — assumes [batch, seq,
               hidden] input so it becomes [seq, batch, hidden] internally
               (TODO confirm against caller).

        Returns:
            Tensor of the same shape as ``x`` with expert outputs combined.
        """
        x = x.transpose(0, 1).contiguous()  # -> [seq, batch, hidden]
        inp_shape = x.shape
        num_tokens = inp_shape[0] * inp_shape[1]
        hidden = inp_shape[-1]
        num_experts = self.num_local_experts
        x = x.view(-1, hidden)  # flatten to [num_tokens, hidden]
        restore_shape = x.shape

        # ---- Routing: per-token expert probabilities ------------------------
        # Router logits [num_tokens, num_experts]; router weight is cast to
        # the activation dtype for the matmul.
        logits = torch.mm(x, self.weight.to(x.dtype).t())
        scores, top_indices = torch.topk(logits, k=self.topk, dim=1)
        # Softmax over the selected top-k logits only, in float32 for
        # stability; chosen experts' probabilities sum to 1 per token.
        probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
        # Dense [num_tokens, num_experts] views of the sparse routing choice:
        # routing_probs holds the probability (0 for unchosen experts),
        # routing_map is the boolean token->expert assignment.
        routing_probs = torch.zeros_like(logits).scatter(1, top_indices, probs)
        routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()

        # ---- Dispatch: group token copies by expert -------------------------
        tokens_per_expert = routing_map.sum(dim=0).long()  # [num_experts]
        # Expert-major layout: [num_experts, num_tokens].
        routing_map = routing_map.T.contiguous()
        token_indices = (
            torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
        )
        # Token ids in expert-major order: all of expert 0's tokens, then
        # expert 1's, ... ; length num_tokens * topk.
        sorted_indices = token_indices.masked_select(routing_map)
        # Matching per-(expert, token) routing probabilities, same order.
        probs = routing_probs.T.contiguous().masked_select(routing_map)
        permuted_input = x.index_select(0, sorted_indices)  # [num_tokens * topk, hidden]

        # ---- Compute: run each expert on its token slice --------------------
        tokens_list = torch.split(permuted_input, tokens_per_expert.tolist())
        probs_list = torch.split(probs, tokens_per_expert.tolist())
        output_local_list = [
            expert(tokens) * prob.unsqueeze(-1)
            for expert, tokens, prob in zip(self.experts, tokens_list, probs_list)
        ]
        permuted_tokens = torch.cat(output_local_list, dim=0)

        # ---- Combine: scatter-add weighted outputs back per token -----------
        output_tokens = torch.zeros(
            restore_shape, dtype=permuted_tokens.dtype, device=permuted_tokens.device
        )
        output_tokens.scatter_add_(0, sorted_indices.unsqueeze(1).expand(-1, hidden), permuted_tokens)
        return output_tokens.view(inp_shape).transpose(0, 1)
class FalconH1MoEDecoderLayer(FalconH1DecoderLayer):
    """FalconH1 decoder layer whose dense MLP is replaced by an MoE block."""

    def __init__(self, config: FalconH1MoEConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Swap the feed-forward built by the parent for the MoE variant.
        self.feed_forward = FalconH1MoEMLP(config)
class FalconH1MoEForCausalLM(FalconH1ForCausalLM):
    """Causal-LM wrapper that runs the MoE FalconH1 backbone."""

    def __init__(self, config: FalconH1MoEConfig):
        super().__init__(config)
        # Replace the dense backbone created by the parent __init__ with the
        # MoE model; the rest of the parent's head/setup is kept.
        self.model = FalconH1MoEModel(config)


__all__ = ["FalconH1MoEForCausalLM"]