|
|
import copy
|
|
|
import os
|
|
|
from typing import Optional
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from torch import Tensor
|
|
|
import deepspeed
|
|
|
from deepspeed import comm as dist
|
|
|
from deepspeed.utils import groups, log_dist
|
|
|
from deepspeed.utils.timer import SynchronizedWallClockTimer
|
|
|
from deepspeed.moe.sharded_moe import FIRST_ALLTOALL_TIMER, MOE_TIMER, SECOND_ALLTOALL_TIMER, _AllToAll, einsum, gumbel_rsample
|
|
|
from transformers.activations import ACT2FN
|
|
|
|
|
|
def compress_matrix(A: torch.Tensor, mask: torch.Tensor, force_dim: Optional[int] = None, allow_larger_dim=None) -> torch.Tensor:
    """Pack the entries of ``A`` selected by ``mask`` into the leading rows of each column.

    For every column ``e`` the rows ``s`` with ``mask[s, e] == 1`` are gathered to
    the top of the result (in the order produced by ``torch.argsort`` on the
    mask); all remaining slots of the result are zero.

    Args:
        A: Tensor of shape ``(S, E, *trailing)``.
        mask: Tensor of shape ``(S, E)`` containing only 0s and 1s (any dtype).
        force_dim: Number of rows ``X`` of the result. Accepts a Python int or a
            0-d integer tensor. When ``None`` the largest per-column count of
            ones is used. If smaller than a column's count, the surplus selected
            entries of that column are cut off.
        allow_larger_dim: Behaviour when ``X`` exceeds ``S``. ``True`` pads with
            zero rows, ``None`` pads and prints a warning, ``False`` raises.

    Returns:
        Tensor of shape ``(X, E, *trailing)`` with the dtype of ``A``.

    Raises:
        ValueError: If ``mask`` is not 2-D, does not match the first two
            dimensions of ``A``, contains values other than 0/1, or if
            ``force_dim`` is negative.
        AssertionError: If ``X > S`` and ``allow_larger_dim`` is ``False``.
    """
    # Validate the rank first so the shape comparison below is meaningful.
    if mask.ndim != 2:
        raise ValueError("mask must be a 2D tensor.")
    if A.shape[:2] != mask.shape:
        raise ValueError("First two dimensions of A and mask must match.")

    invalid = (mask != 0) & (mask != 1)
    if invalid.any():
        raise ValueError(
            f"mask must only contain 0s and 1s. dtype: {mask.dtype}. "
            f"Invalid elements found at indices: {invalid.nonzero().tolist()} "
            f"with corresponding values: {mask[invalid].tolist()}. "
            f"\nOriginal mask (showing up to first 20 elements if large):\n{mask.flatten()[:20]}{'...' if mask.numel() > 20 else ''}"
        )

    S, E = mask.shape
    trailing_dims_shape = A.shape[2:]
    num_trailing_dims = len(trailing_dims_shape)
    device = A.device

    ones_per_column = mask.sum(dim=0)

    # Normalise the target row count to a plain Python int: a float mask makes
    # ``.item()`` return a float, and callers may pass a 0-d tensor as force_dim;
    # neither is usable directly in slices and shape tuples.
    if force_dim is None:
        X = int(ones_per_column.max().item()) if ones_per_column.numel() > 0 else 0
    else:
        X = int(force_dim)
    if X < 0:
        raise ValueError(f"force_dim must be non-negative, got {X}.")

    if X == 0:
        return torch.empty((0, E, *trailing_dims_shape), dtype=A.dtype, device=device)

    # Sorting the mask in descending order lists, per column, the selected rows first.
    sorted_row_indices_2d = torch.argsort(mask.float(), dim=0, descending=True)
    view_shape_for_indices = (S, E, *((1,) * num_trailing_dims))
    expanded_indices = sorted_row_indices_2d.view(view_shape_for_indices).expand_as(A)

    A_gathered = torch.gather(A, 0, expanded_indices)

    if X <= A_gathered.shape[0]:
        B_candidate = A_gathered[:X, ...]
    elif allow_larger_dim or allow_larger_dim is None:
        if allow_larger_dim is None:
            print(f"[Warning compress_matrix] Target dimension X ({X}) is larger than "
                  f"A's original row count S ({S}). Padding B_candidate with zeros.")
        zeros_shape = [X - A_gathered.shape[0]] + list(A_gathered.shape[1:])
        padding = torch.zeros(zeros_shape, dtype=A_gathered.dtype, device=A_gathered.device)
        B_candidate = torch.cat((A_gathered, padding), dim=0)
    else:
        raise AssertionError(
            f"Target dimension X ({X}) is larger than A's original row count S ({S}) "
            "and allow_larger_dim is False. Padding is disallowed."
        )

    # Zero every slot past a column's count of ones: those rows were gathered
    # from positions where the mask is 0 (or are padding).
    row_indices_for_B = torch.arange(X, device=device).unsqueeze(1)
    b_mask_2d = row_indices_for_B < ones_per_column.unsqueeze(0)
    view_shape_for_b_mask = (X, E, *((1,) * num_trailing_dims))
    B = B_candidate * b_mask_2d.view(view_shape_for_b_mask).to(A.dtype)

    return B
|
|
|
|
|
|
|
|
|
def decompress_matrix(B: torch.Tensor, mask: torch.Tensor, allow_larger_dim=None) -> torch.Tensor:
    """Scatter the packed rows of ``B`` back to the row positions described by ``mask``.

    The row order is recomputed with ``torch.argsort(mask.float(), dim=0,
    descending=True)``; row ``x`` of ``B`` in column ``e`` is written to the
    ``x``-th index of that ordering. Positions that receive no row stay zero.
    Note that rows of ``B`` beyond a column's count of ones are written to
    positions where the mask is 0, so they are expected to already be zero.

    Args:
        B: Tensor of shape ``(X, E, *trailing)``.
        mask: Tensor of shape ``(S, E)`` containing only 0s and 1s.
        allow_larger_dim: Behaviour when ``X`` exceeds ``S``. ``True`` truncates
            ``B`` to its first ``S`` rows, ``None`` truncates and prints a
            warning, ``False`` raises.

    Returns:
        Tensor of shape ``(S, E, *trailing)`` with the dtype and device of ``B``.

    Raises:
        ValueError: If ``mask`` is not 2-D, ``B`` has fewer than two dimensions
            or a mismatching second dimension, or ``mask`` holds values other
            than 0/1.
        AssertionError: If ``X > S`` and ``allow_larger_dim`` is ``False``.
    """
    # Check ranks before indexing ``shape[1]`` so malformed inputs raise the
    # documented ValueError instead of an IndexError.
    if mask.ndim != 2:
        raise ValueError("mask must be a 2D tensor.")
    if B.ndim < 2 or B.shape[1] != mask.shape[1]:
        raise ValueError("B's second dimension and mask's second dimension (E) must match.")
    if not ((mask == 0) | (mask == 1)).all():
        raise ValueError("mask must only contain 0s and 1s.")

    S, E = mask.shape
    X = B.shape[0]
    trailing_dims_shape = B.shape[2:]
    num_trailing_dims = len(trailing_dims_shape)
    device = B.device

    if X == 0:
        return torch.zeros((S, E, *trailing_dims_shape), dtype=B.dtype, device=device)

    if X > S:
        if allow_larger_dim or allow_larger_dim is None:
            if allow_larger_dim is None:
                print(f"[Warning decompress_matrix] Input B.shape[0] ({X}) is larger than "
                      f"target A's row count S ({S}). Truncating B to its first {S} rows.")
            B = B[:S, ...]
            X = S
        else:
            raise AssertionError(
                f"Input B.shape[0] ({X}) is larger than target A's row count S ({S}) "
                "and allow_larger_dim is False. Truncation is disallowed."
            )

    # Same ordering that compress_matrix uses: selected rows come first per column.
    sorted_row_indices_2d = torch.argsort(mask.float(), dim=0, descending=True)
    target_A_row_indices_2d = sorted_row_indices_2d[:X, :]

    A_reconstructed = torch.zeros((S, E, *trailing_dims_shape), dtype=B.dtype, device=device)
    view_shape_for_target_indices = (X, E, *((1,) * num_trailing_dims))
    expanded_target_indices = target_A_row_indices_2d.view(view_shape_for_target_indices).expand_as(B)
    A_reconstructed.scatter_(dim=0, index=expanded_target_indices, src=B)

    return A_reconstructed
|
|
|
|
|
|
|
|
|
|
|
|
class AudioSharedExpertMLP(nn.Module):
    """Gated feed-forward expert used as a shared (always applied) expert in UniMoE-Audio.

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))`` with bias-free
    linear layers. The inner width is ``config.shared_intermediate_size`` and the
    activation is looked up in ``ACT2FN`` by ``config.hidden_act``.
    """

    def __init__(self, config):
        super().__init__()
        hidden, inner = config.hidden_size, config.shared_intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        # Creation order (gate, up, down) is kept so random initialisation is unchanged.
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        """Apply the gated MLP to ``hidden_state`` (last dimension ``hidden_size``)."""
        gate = self.act_fn(self.gate_proj(hidden_state))
        value = self.up_proj(hidden_state)
        return self.down_proj(gate * value)
|
|
|
|
|
|
|
|
|
class AudioDynamicExpertMLP(nn.Module):
    """Gated feed-forward expert used as a routed (dynamic) expert in UniMoE-Audio.

    Computes ``down_proj(act_fn(gate_proj(x)) * up_proj(x))`` with bias-free
    linear layers. The inner width is ``config.dynamic_intermediate_size`` and
    the activation is looked up in ``ACT2FN`` by ``config.hidden_act``.
    """

    def __init__(self, config):
        super().__init__()
        hidden, inner = config.hidden_size, config.dynamic_intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner
        # Creation order (gate, up, down) is kept so random initialisation is unchanged.
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        """Apply the gated MLP to ``hidden_state`` (last dimension ``hidden_size``)."""
        gate = self.act_fn(self.gate_proj(hidden_state))
        value = self.up_proj(hidden_state)
        return self.down_proj(gate * value)
|
|
|
|
|
|
|
|
|
class AudioNullExpertMLP(nn.Module):
    """Parameter-free "null" expert for UniMoE-Audio.

    Its output is always an all-zero tensor shaped like the input, so a token
    routed here receives no contribution from this expert.
    """

    def __init__(self, config):
        # ``config`` is accepted for signature parity with the other experts; it is not read.
        super().__init__()

    def forward(self, hidden_state):
        """Return zeros with the shape, dtype and device of ``hidden_state``."""
        # zeros_like already inherits dtype and device from its argument.
        return torch.zeros_like(hidden_state)
|
|
|
|
|
|
|
|
|
def audio_sparse_expert_mixer(scores, top_k, jitter_eps, training):
    """Greedily pick ``top_k`` experts per row and compute a gate weight for each pick.

    In every round the highest remaining score is selected. Experts whose
    original score lies more than ``2 * jitter_eps`` (relative distance) below
    that maximum are excluded, a softmax is taken over what is left, and the
    probability of the selected expert becomes its weight. The selected expert
    is then removed from the candidates of the following rounds.

    Args:
        scores: Router logits with the experts on the last dimension.
        top_k: Number of selection rounds.
        jitter_eps: Half-width of the relative tolerance used for masking.
        training: Accepted for interface compatibility; not read by this function.

    Returns:
        Tuple ``(multiplier, selected_experts)``; both have ``top_k`` entries on
        the last dimension, ordered by selection round.
    """
    tolerance = 2 * jitter_eps
    remaining = scores
    round_weights = []
    round_experts = []

    for _ in range(top_k):
        # The sparsity mask is a hard decision and is built without gradient tracking.
        with torch.no_grad():
            best_score, best_expert = remaining.max(dim=-1, keepdim=True)
            scale = scores.abs().clamp(min=best_score.abs())
            too_far = ((best_score - scores) / scale) > tolerance

        gates = torch.softmax(remaining.masked_fill(too_far, float("-inf")), dim=-1)
        round_weights.append(gates.gather(dim=-1, index=best_expert))
        round_experts.append(best_expert)

        # Take the chosen expert out of the running for the next round.
        remaining = remaining.scatter(-1, best_expert, float("-inf"))

    return torch.cat(round_weights, dim=-1), torch.cat(round_experts, dim=-1)
|
|
|
|
|
|
|
|
|
def audio_dynamic_expert_selection(logits, top_p):
    """Choose, per token, how many experts are needed to reach probability mass ``top_p``.

    The logits are turned into probabilities, sorted in descending order and
    accumulated; the result is the length of the shortest prefix whose
    cumulative probability is at least ``top_p``.

    Args:
        logits: Router logits with the experts on the last dimension.
        top_p: Cumulative-probability threshold.

    Returns:
        Integer tensor with the shape of ``logits`` minus its last dimension,
        holding values no greater than the number of experts.
    """
    num_experts = logits.shape[-1]
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)

    # Number of leading experts that still fall short of top_p, plus the one that crosses it.
    dynamic_top_k = (~(cumulative >= top_p)).sum(dim=-1) + 1

    # If the threshold is never reached (top_p >= 1, or rounding keeps the total
    # just below it) the count above is num_experts + 1, which is not a valid
    # expert count; cap it so such tokens simply use every expert.
    return dynamic_top_k.clamp(max=num_experts)
|
|
|
|
|
|
|
|
|
def _audio_expert_capacity(num_tokens, num_experts, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
|
|
|
"""Calculate expert capacity for UniMoE-Audio based on token distribution and capacity factor."""
|
|
|
capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
|
|
|
if capacity < min_capacity:
|
|
|
capacity = min_capacity.to(torch.int64)
|
|
|
return capacity
|
|
|
|
|
|
|
|
|
def calculate_audio_global_routing_weight(
    expert_mask: torch.Tensor,
    full_router_logits: torch.Tensor,
    mlp_dynamic_expert_num: int,
    routing_weights: torch.Tensor,
):
    """Merge dynamic-expert routing weights with fixed-expert weights into one matrix.

    A softmax is taken over all router logits with inactive experts
    (``expert_mask == 0``) set to ``-inf``. The fixed experts (columns from
    ``mlp_dynamic_expert_num`` onwards) keep their softmax probability. The
    total probability mass of the dynamic columns is redistributed over the
    dynamic experts proportionally to ``routing_weights``.

    Args:
        expert_mask: ``(tokens, experts)`` activity mask.
        full_router_logits: ``(tokens, experts)`` router logits.
        mlp_dynamic_expert_num: Number of leading columns that are dynamic experts.
        routing_weights: ``(tokens, mlp_dynamic_expert_num)`` per-token weights.

    Returns:
        ``(tokens, experts)`` tensor: dynamic weights followed by fixed weights.
    """
    inactive = expert_mask == 0
    probs = full_router_logits.masked_fill(inactive, float("-inf")).softmax(dim=-1)

    fixed_part = probs[:, mlp_dynamic_expert_num:]
    # keepdim lets the per-token dynamic mass broadcast across the dynamic columns.
    dynamic_mass = probs[:, :mlp_dynamic_expert_num].sum(dim=-1, keepdim=True)
    dynamic_part = routing_weights * dynamic_mass

    return torch.cat((dynamic_part, fixed_part), dim=-1)
|
|
|
|
|
|
|
|
|
class UniMoEAudioSparseMoeBlock(nn.Module):
    """Sparse mixture-of-experts block of UniMoE-Audio.

    A single linear router scores ``num_experts`` experts per token. The first
    ``mlp_dynamic_expert_num`` columns are "dynamic" experts (real experts
    followed by null experts) of which a per-token number is selected, either by
    a cumulative-probability threshold (``mlp_dynamic_top_p``) or by a constant
    ``mlp_dynamic_top_k``. The remaining columns are "fixed" experts that are
    applied to every token. Optionally, tokens exceeding an expert capacity are
    dropped (``token_drop`` / ``drop_policy``).
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        # "dynamic" counts real and null experts; only the real ones own parameters.
        self.mlp_dynamic_expert_num = config.mlp_dynamic_expert_num + config.mlp_dynamic_null_expert_num
        self.mlp_dynamic_real_expert_num = config.mlp_dynamic_expert_num
        self.mlp_dynamic_null_expert_num = config.mlp_dynamic_null_expert_num
        self.mlp_dynamic_top_p = config.mlp_dynamic_top_p
        self.mlp_dynamic_top_k = config.mlp_dynamic_top_k
        self.mlp_fixed_expert_num = config.mlp_fixed_expert_num
        self.num_experts = self.mlp_dynamic_expert_num + self.mlp_fixed_expert_num

        if self.mlp_dynamic_top_p == 0:
            print(f"mlp_dynamic_top_p is 0, will use mlp_dynamic_top_k={self.mlp_dynamic_top_k} instead !!!")

        self.ignore_differentiable_router = config.ignore_differentiable_router
        if self.ignore_differentiable_router:
            print("ignore_differentiable_router is True, will not use router_logits !!!")

        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.fixed_real_moe = nn.ModuleList([AudioSharedExpertMLP(config) for _ in range(self.mlp_fixed_expert_num)])
        self.dynamic_real_moe = UniMoEAudioMoE(config, AudioDynamicExpertMLP(config), self.mlp_dynamic_real_expert_num, config.ep_size)

        self.router_jitter_noise = config.router_jitter_noise
        self.input_jitter_noise = config.input_jitter_noise

        self.min_capacity = config.min_capacity
        self.capacity_factor = config.capacity_factor
        self.token_drop = config.token_drop
        self.drop_policy = config.drop_policy

        self.avg_hidden_states_last = config.avg_hidden_states_last
        self.drop_token_num_print = config.drop_token_num_print
        self.fp32_gate = config.fp32_gate

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, aux_balance_weight: torch.Tensor=None):
        """Route every token through its selected dynamic experts and all fixed experts.

        Args:
            hidden_states: ``(batch, seq, hidden)`` input.
            attention_mask: Optional per-token mask (flattened to
                ``batch * seq``); tokens with 0 are removed from the dynamic
                experts' masks.
            aux_balance_weight: Optional ``(batch, seq)`` weights forwarded to
                the load-balancing loss.

        Returns:
            Tuple ``(final_hidden_states, full_router_logits, dynamic_top_k,
            expert_mask, global_weight, aux_loss)``.

        Raises:
            ValueError: If ``token_drop`` is enabled with an unknown ``drop_policy``.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        num_tokens = batch_size * sequence_length
        original_hidden_states = hidden_states

        if self.training and self.fp32_gate:
            hidden_states = hidden_states.float()

        if self.training and self.input_jitter_noise > 0:
            # NOTE(review): this multiply is in place. When ``.float()`` above did
            # not produce a copy, it also modifies ``original_hidden_states`` and
            # the caller's tensor — confirm this is intended.
            hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise)

        hidden_states = hidden_states.view(-1, hidden_dim)

        if self.training and self.fp32_gate:
            full_router_logits = torch.nn.functional.linear(hidden_states, weight=self.gate.weight.float(), bias=None)
        else:
            full_router_logits = self.gate(hidden_states)
        dynamic_router_logits = full_router_logits[:, : self.mlp_dynamic_expert_num]

        # Per-token number of dynamic experts.
        if self.mlp_dynamic_top_p != 0:
            dynamic_top_k = audio_dynamic_expert_selection(dynamic_router_logits, self.mlp_dynamic_top_p)
        else:
            dynamic_top_k = torch.full((dynamic_router_logits.shape[0],), self.mlp_dynamic_top_k, dtype=torch.int, device=dynamic_router_logits.device)

        expert_mask = torch.zeros((num_tokens, self.num_experts), dtype=torch.int, device=hidden_states.device)
        routing_weights = torch.zeros((num_tokens, self.mlp_dynamic_expert_num), dtype=hidden_states.dtype, device=hidden_states.device)

        # Tokens are processed in groups that share the same expert count so the
        # mixer can run with a static top_k.
        for top_k in range(1, self.mlp_dynamic_expert_num + 1):
            group_idx = torch.nonzero(dynamic_top_k == top_k, as_tuple=True)[0]
            if len(group_idx) == 0:
                continue

            dynamic_group_logits = dynamic_router_logits[group_idx]
            group_routing_weights, group_selected_experts = audio_sparse_expert_mixer(
                dynamic_group_logits,
                top_k=top_k,
                jitter_eps=self.router_jitter_noise,
                training=self.training and not self.ignore_differentiable_router,
            )

            group_expert_mask = torch.nn.functional.one_hot(group_selected_experts, num_classes=self.num_experts)
            group_expert_mask = group_expert_mask.sum(dim=1)

            group_weight = torch.zeros((len(group_idx), self.mlp_dynamic_expert_num), dtype=hidden_states.dtype, device=hidden_states.device)
            group_weight.scatter_(dim=-1, index=group_selected_experts, src=group_routing_weights)
            routing_weights.index_add_(0, group_idx, group_weight)

            expert_mask.index_add_(0, group_idx, group_expert_mask.to(expert_mask.dtype))

        routing_weights = routing_weights / (routing_weights.sum(dim=-1).unsqueeze(-1).expand(-1, routing_weights.shape[-1]) + 1e-6)

        if attention_mask is not None:
            attention_mask = attention_mask.to(expert_mask.dtype).view(-1).unsqueeze(-1).expand(-1, self.num_experts)
            expert_mask = expert_mask * attention_mask

        # Fixed experts are active for every token, including masked ones.
        if self.mlp_dynamic_expert_num < self.num_experts:
            expert_mask[:, self.mlp_dynamic_expert_num :] = 1

        aux_loss = audio_load_balancing_loss_func(
            expert_mask=expert_mask,
            mlp_dynamic_expert_num=self.mlp_dynamic_expert_num,
            global_weight=None,
            full_router_logits=full_router_logits,
            routing_weights=routing_weights,
            aux_balance_weight=aux_balance_weight,
        )

        if self.token_drop:
            expert_mask_dtype = expert_mask.dtype
            capacity = _audio_expert_capacity(num_tokens, self.mlp_dynamic_expert_num, torch.tensor(self.capacity_factor), torch.tensor(self.min_capacity))
            if self.drop_policy == "probs":
                if capacity > dynamic_router_logits.shape[0]:
                    print(f"[warning] token capacity({capacity}) > token num({dynamic_router_logits.shape[0]}), setting capacity=token num")
                    capacity = dynamic_router_logits.shape[0]
                dynamic_expert_mask = expert_mask[:, : self.mlp_dynamic_expert_num].bool()
                token_drop_router_logits = torch.masked_fill(dynamic_router_logits, ~dynamic_expert_mask, torch.finfo(dynamic_router_logits.dtype).min)
                # Keep, per dynamic expert, the `capacity` tokens with the highest logits.
                capacity_probs, capacity_indices = torch.topk(token_drop_router_logits, k=capacity, dim=0, sorted=False)
                capacity_mask = torch.zeros_like(expert_mask).scatter(0, capacity_indices, 1)
                capacity_mask[:, self.mlp_dynamic_expert_num :] = 1
                expert_mask = torch.logical_and(expert_mask, capacity_mask)

                ori_token_num = dynamic_expert_mask.sum().item()
                cur_token_num = expert_mask[:, : self.mlp_dynamic_expert_num].sum().item()
                if self.drop_token_num_print and ("RANK" not in os.environ or int(os.environ["RANK"]) == 0):
                    print(f"drop {ori_token_num - cur_token_num} tokens from total {ori_token_num} tokens")

            elif self.drop_policy == "position":
                # Capacity only constrains the dynamic experts (as in the "probs"
                # policy, which keeps the fixed columns at 1). Applying the cut to
                # the fixed columns as well would switch the always-on experts off
                # for every token past `capacity` and can leave a token with no
                # active expert at all.
                dynamic_columns = expert_mask[:, : self.mlp_dynamic_expert_num]
                locations = torch.cumsum(dynamic_columns, dim=0) - 1
                dynamic_columns *= torch.lt(locations, capacity)
            else:
                raise ValueError(f"Invalid drop_policy: {self.drop_policy}")
            expert_mask = expert_mask.to(expert_mask_dtype)

            # Remove the weights of dropped assignments and renormalise per token.
            routing_weights = routing_weights.masked_fill(~(expert_mask[:, : self.mlp_dynamic_expert_num].bool()), 0.0)
            routing_weights = routing_weights / (routing_weights.sum(dim=-1).unsqueeze(-1).expand(-1, routing_weights.shape[-1]) + 1e-6)

        if self.mlp_dynamic_expert_num < self.num_experts:
            global_weight = calculate_audio_global_routing_weight(expert_mask, full_router_logits, self.mlp_dynamic_expert_num, routing_weights)
        else:
            global_weight = routing_weights

        # Experts consume the hidden states in their original dtype.
        hidden_states = original_hidden_states.view(-1, hidden_dim)

        final_hidden_states = torch.zeros((num_tokens, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device)
        global_weight = global_weight.to(hidden_states.dtype)

        # Only the real dynamic experts are executed; the null-expert columns are not passed on.
        current_hidden_states = self.dynamic_real_moe(hidden_states, expert_mask=expert_mask[:, : self.mlp_dynamic_real_expert_num], router_weight=global_weight[:, : self.mlp_dynamic_real_expert_num])
        final_hidden_states = final_hidden_states + current_hidden_states

        for expert_idx in range(self.mlp_fixed_expert_num):
            expert_layer = self.fixed_real_moe[expert_idx]

            current_state = hidden_states
            current_global_weight = global_weight[:, self.mlp_dynamic_expert_num + expert_idx].unsqueeze(-1)
            current_hidden_states = expert_layer(current_state) * current_global_weight

            final_hidden_states = final_hidden_states + current_hidden_states

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)

        if not self.training and self.avg_hidden_states_last:
            dist.all_reduce(final_hidden_states, op=dist.ReduceOp.AVG, group=self.dynamic_real_moe.deepspeed_moe.ep_group)

        return final_hidden_states, full_router_logits, dynamic_top_k, expert_mask, global_weight, aux_loss
|
|
|
|
|
|
|
|
|
def audio_load_balancing_loss_func(
    expert_mask: torch.Tensor,
    mlp_dynamic_expert_num: int,
    global_weight: Optional[torch.Tensor] = None,
    full_router_logits: Optional[torch.Tensor] = None,
    routing_weights: Optional[torch.Tensor] = None,
    aux_balance_weight: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute the load-balancing auxiliary loss over the dynamic experts.

    The loss is ``num_dynamic_experts * sum_e(f_e * p_e)`` where ``f_e`` is the
    (optionally weighted) fraction of tokens assigned to expert ``e`` and ``p_e``
    the (optionally weighted) mean router probability of expert ``e``. The
    probabilities are a softmax over the dynamic columns of
    ``full_router_logits`` with inactive experts filled with the dtype minimum.

    Args:
        expert_mask: ``(tokens, experts)`` activity mask.
        mlp_dynamic_expert_num: Number of leading columns that are dynamic experts.
        global_weight: Ignored; the weights are recomputed from the logits.
        full_router_logits: ``(tokens, experts)`` router logits. Required.
        routing_weights: Not used by this function.
        aux_balance_weight: Optional ``(batch, seq)`` per-token weights.

    Returns:
        Scalar tensor holding the loss.

    Raises:
        ValueError: If ``full_router_logits`` is ``None``.
    """
    if full_router_logits is None:
        raise ValueError("full_router_logits is required to compute the load balancing loss.")

    min_dtype = torch.finfo(full_router_logits.dtype).min
    dynamic_logits = full_router_logits.masked_fill(expert_mask == 0, min_dtype)[:, :mlp_dynamic_expert_num]
    dynamic_probs = torch.softmax(dynamic_logits, dim=-1)
    dynamic_mask = expert_mask[:, :mlp_dynamic_expert_num]

    num_experts = dynamic_mask.shape[-1]
    if aux_balance_weight is None:
        tokens_per_expert = torch.mean(dynamic_mask.float(), dim=0)
        router_prob_per_expert = torch.mean(dynamic_probs, dim=0)
    else:
        batch_size, sequence_length = aux_balance_weight.shape
        # The token dimension may stack several layers' worth of tokens.
        num_hidden_layers = dynamic_probs.shape[0] // (batch_size * sequence_length)
        expert_attention_mask = (
            aux_balance_weight[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(dynamic_probs.device)
        )
        weight_total = torch.sum(expert_attention_mask, dim=0)
        # An all-zero weight column would give 0/0 = NaN; divide by 1 there so
        # the (zero) numerator yields a zero contribution instead.
        weight_total = torch.where(weight_total == 0, torch.ones_like(weight_total), weight_total)
        tokens_per_expert = torch.sum(dynamic_mask.float() * expert_attention_mask, dim=0) / weight_total
        router_prob_per_expert = torch.sum(dynamic_probs * expert_attention_mask, dim=0) / weight_total

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert)
    return overall_loss * num_experts
|
|
|
|
|
|
|
|
|
class AudioExperts(deepspeed.moe.experts.Experts):
    """Container holding the local copies of one expert module.

    Subclasses DeepSpeed's ``Experts`` but builds its own state: the expert is
    deep-copied ``num_local_experts`` times and every parameter is tagged with
    ``allreduce = False`` and the expert group name.
    """

    def __init__(self, expert, num_local_experts=1, expert_group_name=None):
        # Deliberately skips Experts.__init__ and runs the next initialiser in the MRO.
        super(deepspeed.moe.experts.Experts, self).__init__()

        replicas = [copy.deepcopy(expert) for _ in range(num_local_experts)]
        self.deepspeed_experts = torch.nn.ModuleList(replicas)
        self.num_local_experts = num_local_experts

        for replica in self.deepspeed_experts:
            for param in replica.parameters():
                param.allreduce = False
                param.group_name = expert_group_name

    def forward(self, inputs):
        """Split ``inputs`` along dim 1, run one chunk per local expert, and re-join on dim 1."""
        pieces = inputs.chunk(self.num_local_experts, dim=1)
        results = []
        for piece, local_expert in zip(pieces, self.deepspeed_experts):
            produced = local_expert(piece)
            # Experts returning a tuple contribute only their first element.
            results.append(produced[0] if type(produced) is tuple else produced)

        return torch.cat(results, dim=1)
|
|
|
|
|
|
|
|
|
class AudioMOELayer(deepspeed.moe.sharded_moe.MOELayer):
    """MoE dispatch/combine layer driven by an externally computed expert mask.

    Subclasses DeepSpeed's ``MOELayer`` but takes the routing decision
    (``expert_mask`` and ``router_weight``) as input. Tokens are packed per
    expert with ``compress_matrix``, passed through ``_AllToAll``, processed by
    the experts, passed through ``_AllToAll`` again, unpacked with
    ``decompress_matrix`` and combined with the router weights.
    """

    def __init__(
        self,
        experts: nn.Module,
        ep_group_name,
        ep_size,
        num_local_experts: int,
        use_tutel: bool = False,
    ) -> None:
        # Deliberately skips MOELayer.__init__ and runs the next initialiser in the MRO.
        # ``use_tutel`` is accepted but not read here.
        super(deepspeed.moe.sharded_moe.MOELayer, self).__init__()

        self.experts = experts
        self.ep_group = None
        self.ep_size = ep_size
        self.ep_group_name = ep_group_name
        self.num_local_experts = num_local_experts

        # Timing bookkeeping, only updated when wall_clock_breakdown is switched on.
        self.time_falltoall = 0.0
        self.time_salltoall = 0.0
        self.time_moe = 0.0
        self.timers = SynchronizedWallClockTimer()
        self.wall_clock_breakdown = False

    def _set_ep_group(self, ep_group):
        """Store the expert-parallel process group used for the collectives."""
        self.ep_group = ep_group

    def forward(self, hidden_states: Tensor, expert_mask: Tensor, router_weight: Tensor) -> Tensor:
        """Run the experts on their assigned tokens and return the weighted combination.

        Args:
            hidden_states: ``(tokens, d_model)`` inputs.
            expert_mask: ``(tokens, experts)`` 0/1 assignment mask.
            router_weight: ``(tokens, experts)`` combination weights.

        Returns:
            ``(tokens, d_model)`` combined expert output.
        """
        # Weights of unassigned (token, expert) pairs must not contribute.
        router_weight = router_weight * expert_mask

        if self.wall_clock_breakdown:
            self.timers(MOE_TIMER).start()

        seq_len, d_model = hidden_states.shape[0], hidden_states.shape[-1]
        expert_num = expert_mask.shape[-1]

        # Rows per expert = the busiest expert's token count, made identical on
        # all ranks of the group by the MAX all-reduce.
        capacity = expert_mask.sum(dim=0).max()
        if self.ep_group is not None:
            dist.all_reduce(capacity, op=dist.ReduceOp.MAX, group=self.ep_group)

        per_expert_view = hidden_states.unsqueeze(1).expand(seq_len, expert_num, d_model)
        packed_states = compress_matrix(per_expert_view, expert_mask, force_dim=capacity, allow_larger_dim=True)
        packed_mask = compress_matrix(expert_mask, expert_mask, force_dim=capacity, allow_larger_dim=True)
        dispatched_input = einsum("ce,cem->ecm", packed_mask, packed_states)

        if self.wall_clock_breakdown:
            self.timers(FIRST_ALLTOALL_TIMER).start()

        dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)

        if self.wall_clock_breakdown:
            self.timers(FIRST_ALLTOALL_TIMER).stop()
            self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)

        dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)

        expert_output = self.experts(dispatched_input)

        if self.wall_clock_breakdown:
            self.timers(SECOND_ALLTOALL_TIMER).start()

        expert_output = _AllToAll.apply(self.ep_group, expert_output)

        if self.wall_clock_breakdown:
            self.timers(SECOND_ALLTOALL_TIMER).stop()
            self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False)

        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
        # Back to (capacity, experts, d_model), then to per-token positions.
        unpacked_output = decompress_matrix(expert_output.transpose(0, 1), expert_mask, allow_larger_dim=True)
        combined_output = einsum("se,sem->sm", router_weight, unpacked_output)

        if self.wall_clock_breakdown:
            self.timers(MOE_TIMER).stop()
            self.time_moe = self.timers(MOE_TIMER).elapsed(reset=False)

        return combined_output
|
|
|
|
|
|
|
|
|
class UniMoEAudioMoE(deepspeed.moe.layer.MoE):
    """DeepSpeed ``MoE`` wrapper that wires ``AudioExperts`` into an ``AudioMOELayer``.

    The routing decision is made by the caller; ``forward`` passes its arguments
    straight to the wrapped ``AudioMOELayer``.
    """

    def __init__(self, config, expert, num_experts, ep_size, moe_name_prefix="ep_size"):
        """Create the local experts and the MoE layer.

        Args:
            config: Object providing ``enable_expert_tensor_parallelism``.
            expert: Expert module that is replicated for every local expert.
            num_experts: Total number of experts across the expert-parallel group.
            ep_size: Expert-parallel world size.
            moe_name_prefix: Prefix of the expert group name.

        Raises:
            ValueError: If ``ep_size`` is not positive or does not divide ``num_experts``.
        """
        # Deliberately skips MoE.__init__ and runs the next initialiser in the MRO.
        super(deepspeed.moe.layer.MoE, self).__init__()

        # Floor division below would silently drop experts, and the layer's
        # (ep_size, num_local_experts, ...) reshape would then be inconsistent.
        if ep_size <= 0 or num_experts % ep_size != 0:
            raise ValueError(f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})")

        self.enable_expert_tensor_parallelism = config.enable_expert_tensor_parallelism
        self.ep_size = ep_size
        self.num_experts = num_experts
        self.expert_group_name = f"{moe_name_prefix}_{self.ep_size}"
        self.num_local_experts = self.num_experts // self.ep_size
        log_dist(f"Creating MoE layer with num_experts: {self.num_experts} | num_local_experts: {self.num_local_experts} | expert_parallel_size: {self.ep_size}", [0])

        experts = AudioExperts(expert, self.num_local_experts, self.expert_group_name)
        self.deepspeed_moe = AudioMOELayer(experts, self.expert_group_name, self.ep_size, self.num_local_experts)

    def set_deepspeed_parallelism(self, use_data_before_expert_parallel_=False):
        """Create (if needed) and attach the expert-parallel process group."""
        self._create_process_groups(use_data_before_expert_parallel_=use_data_before_expert_parallel_)

    def _create_process_groups(self, use_data_before_expert_parallel_=False):
        """Ensure the expert group exists in DeepSpeed's registry and hand it to the layer."""
        if self.expert_group_name not in groups._get_expert_parallel_group_dict():
            print(f"No existing process group found, creating a new group named: {self.expert_group_name}")
            if (groups.mpu is None) or (not self.enable_expert_tensor_parallelism):
                # No model-parallel unit, or expert tensor parallelism disabled:
                # expert + data parallel groups only.
                groups._create_expert_and_data_parallel(self.ep_size, use_data_before_expert_parallel_=use_data_before_expert_parallel_)
            else:
                groups._create_expert_data_and_model_parallel(self.ep_size, mpu=groups.mpu, use_data_before_expert_parallel_=use_data_before_expert_parallel_)
        self.deepspeed_moe._set_ep_group(groups._get_expert_parallel_group(self.expert_group_name))

    def forward(self, *input_args, **input_kwargs):
        """Forward all arguments to the wrapped ``AudioMOELayer``."""
        return self.deepspeed_moe(*input_args, **input_kwargs)
|
|
|
|