File size: 27,707 Bytes
e576ca4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 | import copy
import os
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
import deepspeed
from deepspeed import comm as dist
from deepspeed.utils import groups, log_dist
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.moe.sharded_moe import FIRST_ALLTOALL_TIMER, MOE_TIMER, SECOND_ALLTOALL_TIMER, _AllToAll, einsum, gumbel_rsample
from transformers.activations import ACT2FN
def compress_matrix(A: torch.Tensor, mask: torch.Tensor, force_dim: Optional[int] = None, allow_larger_dim=None) -> torch.Tensor:
    """Pack, per column, the rows of ``A`` selected by ``mask`` into the top rows.

    For every column ``e`` of ``mask`` (shape ``(S, E)``), the rows ``s`` with
    ``mask[s, e] == 1`` are gathered from ``A[:, e, ...]`` and placed in the first
    output rows; every output row beyond that column's count of ones is zero.

    Args:
        A: Tensor of shape ``(S, E, *trailing)``.
        mask: 0/1 tensor of shape ``(S, E)``.
        force_dim: Number of output rows ``X``. Defaults to the largest per-column
            count of ones. A 0-d integer tensor is also accepted.
        allow_larger_dim: What to do when ``X > S``: ``True`` pads with zero rows,
            ``None`` pads and prints a warning, ``False`` raises.

    Returns:
        Tensor of shape ``(X, E, *trailing)`` with ``A``'s dtype and device.

    Raises:
        ValueError: if ``mask`` is not 2-D, does not match ``A.shape[:2]``, or
            contains values other than 0 and 1.
        AssertionError: if ``X > S`` and ``allow_larger_dim`` is False.
    """
    # Check the rank first; otherwise the shape comparison below always fires
    # first and the specific "2D" message can never be reported.
    if mask.ndim != 2:
        raise ValueError("mask must be a 2D tensor.")
    if A.shape[:2] != mask.shape:
        raise ValueError("First two dimensions of A and mask must match.")
    invalid = (mask != 0) & (mask != 1)
    if invalid.any():
        raise ValueError(
            f"mask must only contain 0s and 1s. dtype: {mask.dtype}. "
            f"Invalid elements found at indices: {invalid.nonzero().tolist()} "
            f"with corresponding values: {mask[invalid].tolist()}. "
            f"\nOriginal mask (showing up to first 20 elements if large):\n{mask.flatten()[:20]}{'...' if mask.numel() > 20 else ''}"
        )
    S, E = mask.shape
    trailing_dims_shape = A.shape[2:]
    num_trailing_dims = len(trailing_dims_shape)
    device = A.device
    ones_per_column = mask.sum(dim=0)
    if force_dim is not None:
        X = force_dim
    elif E == 0:
        # ``max()`` of an empty tensor raises; with no columns there is nothing to pack.
        X = 0
    else:
        X = ones_per_column.max().item()
    if X == 0:
        return torch.empty((0, E, *trailing_dims_shape), dtype=A.dtype, device=device)
    # Per column, put the row indices holding a 1 ahead of those holding a 0.
    sorted_row_indices_2d = torch.argsort(mask.float(), dim=0, descending=True)
    view_shape_for_indices = (S, E, *((1,) * num_trailing_dims))
    expanded_indices = sorted_row_indices_2d.view(view_shape_for_indices).expand_as(A)
    A_gathered = torch.gather(A, 0, expanded_indices)
    if X <= A_gathered.shape[0]:
        B_candidate = A_gathered[:X, ...]
    elif allow_larger_dim or allow_larger_dim is None:
        if allow_larger_dim is None:
            print(f"[Warning compress_matrix] Target dimension X ({X}) is larger than "
                  f"A's original row count S ({S}). Padding B_candidate with zeros.")
        B_candidate = A_gathered
        zeros_shape = [X - A_gathered.shape[0]] + list(B_candidate.shape[1:])
        B_candidate = torch.cat((B_candidate, torch.zeros(zeros_shape, dtype=B_candidate.dtype, device=B_candidate.device)), dim=0)  # Shape (X_target_dim, E, ...)
    else:
        raise AssertionError(
            f"Target dimension X ({X}) is larger than A's original row count S ({S}) "
            f"and allow_larger_dim is False. Padding is disallowed."
        )
    # Zero every output row at or beyond the column's count of selected rows
    # (those rows hold unselected or padding data).
    row_indices_for_B = torch.arange(X, device=device).unsqueeze(1)
    b_mask_2d = row_indices_for_B < ones_per_column.unsqueeze(0)
    view_shape_for_b_mask = (X, E, *((1,) * num_trailing_dims))
    B = B_candidate * b_mask_2d.view(view_shape_for_b_mask).to(A.dtype)
    return B
def decompress_matrix(B: torch.Tensor, mask: torch.Tensor, allow_larger_dim=None) -> torch.Tensor:
    """Scatter packed rows of ``B`` back to the row positions selected by ``mask``.

    Row ``i`` of column ``e`` in ``B`` is written to the ``i``-th row index of
    column ``e`` in the same descending argsort of ``mask`` that
    ``compress_matrix`` uses. All other output rows are zero.

    Args:
        B: Tensor of shape ``(X, E, *trailing)``.
        mask: 0/1 tensor of shape ``(S, E)``.
        allow_larger_dim: What to do when ``X > S``: ``True`` truncates ``B`` to
            ``S`` rows, ``None`` truncates and prints a warning, ``False`` raises.

    Returns:
        Tensor of shape ``(S, E, *trailing)`` with ``B``'s dtype and device.

    Raises:
        ValueError: if ``mask`` is not 2-D, ``B`` has fewer than 2 dims or a
            different column count, or ``mask`` has values other than 0 and 1.
        AssertionError: if ``X > S`` and ``allow_larger_dim`` is False.
    """
    # Validate ranks before indexing ``shape[1]``; a 1-D mask (or B) would
    # otherwise escape as an IndexError instead of the documented ValueError.
    if mask.ndim != 2:
        raise ValueError("mask must be a 2D tensor.")
    if B.ndim < 2 or B.shape[1] != mask.shape[1]:
        raise ValueError("B's second dimension and mask's second dimension (E) must match.")
    if not ((mask == 0) | (mask == 1)).all():
        raise ValueError("mask must only contain 0s and 1s.")
    S, E = mask.shape
    X = B.shape[0]
    trailing_dims_shape = B.shape[2:]
    num_trailing_dims = len(trailing_dims_shape)
    device = B.device
    if X == 0:
        return torch.zeros((S, E, *trailing_dims_shape), dtype=B.dtype, device=device)
    if X > S:
        if allow_larger_dim or allow_larger_dim is None:
            if allow_larger_dim is None:
                print(f"[Warning decompress_matrix] Input B.shape[0] ({X}) is larger than "
                      f"target A's row count S ({S}). Truncating B to its first {S} rows.")
            B = B[:S, ...]
            X = S
        else:
            raise AssertionError(
                f"Input B.shape[0] ({X}) is larger than target A's row count S ({S}) "
                f"and allow_larger_dim is False. Truncation is disallowed."
            )
    # Same ordering as compress_matrix: per column, rows holding a 1 come first.
    sorted_row_indices_2d = torch.argsort(mask.float(), dim=0, descending=True)
    target_A_row_indices_2d = sorted_row_indices_2d[:X, :]
    A_reconstructed = torch.zeros((S, E, *trailing_dims_shape), dtype=B.dtype, device=device)
    view_shape_for_target_indices = (X, E, *((1,) * num_trailing_dims))
    expanded_target_indices = target_A_row_indices_2d.view(view_shape_for_target_indices).expand_as(B)
    A_reconstructed.scatter_(dim=0, index=expanded_target_indices, src=B)
    return A_reconstructed
class AudioSharedExpertMLP(nn.Module):
    """
    Shared expert MLP for UniMoE-Audio model.

    A bias-free gated feed-forward layer computing
    ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``. Its inner width comes from
    ``config.shared_intermediate_size`` and its activation from
    ``config.hidden_act``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        inner = config.shared_intermediate_size
        self.hidden_size = width
        self.intermediate_size = inner
        self.gate_proj = nn.Linear(width, inner, bias=False)
        self.up_proj = nn.Linear(width, inner, bias=False)
        self.down_proj = nn.Linear(inner, width, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        gate = self.act_fn(self.gate_proj(hidden_state))
        up = self.up_proj(hidden_state)
        return self.down_proj(gate * up)
class AudioDynamicExpertMLP(nn.Module):
    """
    Dynamic expert MLP for UniMoE-Audio model.

    Same bias-free gated feed-forward form as the shared expert,
    ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``, but its inner width is
    ``config.dynamic_intermediate_size``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        inner = config.dynamic_intermediate_size
        self.hidden_size = width
        self.intermediate_size = inner
        self.gate_proj = nn.Linear(width, inner, bias=False)
        self.up_proj = nn.Linear(width, inner, bias=False)
        self.down_proj = nn.Linear(inner, width, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        activated_gate = self.act_fn(self.gate_proj(hidden_state))
        projected = self.up_proj(hidden_state)
        return self.down_proj(activated_gate * projected)
class AudioNullExpertMLP(nn.Module):
    """
    Null expert MLP for UniMoE-Audio model.

    Has no parameters and ignores ``config``; ``forward`` returns a zero tensor
    with the same shape, dtype and device as its input.
    """

    def __init__(self, config):
        super().__init__()

    def forward(self, hidden_state):
        # zeros_like already inherits dtype and device from its argument.
        return torch.zeros_like(hidden_state)
def audio_sparse_expert_mixer(scores, top_k, jitter_eps, training):
    """
    Sparse expert mixing function for UniMoE-Audio.

    Runs ``top_k`` greedy rounds over the last dimension of ``scores``. Each
    round does the following:

    * take the best remaining expert;
    * mask every expert whose relative gap to that best score exceeds
      ``2 * jitter_eps``;
    * record the softmax probability of the chosen expert over the surviving
      scores;
    * remove the chosen expert (set it to ``-inf``) for the next round.

    All work runs under ``torch.no_grad()``, so the returned weights carry no
    gradient. ``training`` is accepted but not used.

    Returns:
        ``(multiplier, selected_experts)``: weights and expert indices, each with
        ``top_k`` entries along the last dimension.
    """
    remaining = scores
    weights, indices = [], []
    with torch.no_grad():
        for _ in range(top_k):
            best_score, best_idx = remaining.max(dim=-1, keepdim=True)
            denom = scores.abs().clamp(min=best_score.abs())
            too_far = ((best_score - scores) / denom) > (2 * jitter_eps)
            probs = torch.softmax(remaining.masked_fill(too_far, float("-inf")), dim=-1)
            weights.append(probs.gather(dim=-1, index=best_idx))
            indices.append(best_idx)
            remaining = torch.scatter(remaining, -1, best_idx, float("-inf"))
    return torch.concat(weights, dim=-1), torch.concat(indices, dim=-1)
def audio_dynamic_expert_selection(logits, top_p):
    """
    Dynamic expert selection for UniMoE-Audio based on cumulative probability threshold.

    For each row, returns how many of the highest-probability experts (after a
    softmax over the last dimension) are needed before the cumulative probability
    reaches ``top_p``, counting the expert that crosses the threshold.

    The result always lies in ``[1, logits.shape[-1]]``. Without the final clamp,
    float round-off (a cumsum ending just below 1.0) or ``top_p > 1`` would
    yield ``num_experts + 1``, a count no caller can satisfy.
    """
    num_experts = logits.shape[-1]
    dynamic_scores = torch.softmax(logits, dim=-1)
    dynamic_scores_sorted, _ = torch.sort(dynamic_scores, dim=-1, descending=True)
    dynamic_scores_cumsum = dynamic_scores_sorted.cumsum(dim=-1)
    # Count the experts strictly before the threshold, then include the crossing one.
    dynamic_top_k = (~(dynamic_scores_cumsum >= top_p)).sum(dim=-1) + 1
    return dynamic_top_k.clamp(max=num_experts)
def _audio_expert_capacity(num_tokens, num_experts, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor:
    """Calculate expert capacity for UniMoE-Audio based on token distribution and capacity factor.

    Returns ``max(ceil(num_tokens / num_experts * capacity_factor), min_capacity)``
    as an int64 tensor. ``capacity_factor`` and ``min_capacity`` may be tensors
    or plain Python numbers. Plain numbers previously crashed in ``torch.ceil``
    and ``.to``.
    """
    capacity_factor = torch.as_tensor(capacity_factor)
    min_capacity = torch.as_tensor(min_capacity)
    capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64)
    if capacity < min_capacity:
        capacity = min_capacity.to(torch.int64)
    return capacity
def calculate_audio_global_routing_weight(
    expert_mask: torch.Tensor,
    full_router_logits: torch.Tensor,
    mlp_dynamic_expert_num: int,
    routing_weights: torch.Tensor,
):
    """
    Calculate global routing weights for UniMoE-Audio combining dynamic and fixed expert weights.

    A softmax is taken over the router logits of the active experts
    (``expert_mask != 0``).

    * The fixed-expert columns (index ``>= mlp_dynamic_expert_num``) keep their
      softmax probabilities.
    * The dynamic columns are replaced by ``routing_weights`` scaled by the total
      softmax mass that fell on the dynamic experts.
    """
    inactive = expert_mask == 0
    probs = torch.softmax(full_router_logits.masked_fill(inactive, float("-inf")), dim=-1)
    dynamic_part = probs[:, :mlp_dynamic_expert_num]
    fixed_part = probs[:, mlp_dynamic_expert_num:]
    # Broadcast the per-token dynamic mass across every dynamic column.
    dynamic_mass = dynamic_part.sum(dim=-1, keepdim=True)
    return torch.cat((routing_weights * dynamic_mass, fixed_part), dim=-1)
class UniMoEAudioSparseMoeBlock(nn.Module):
    """
    UniMoE-Audio Sparse Mixture of Experts block with dynamic routing and expert selection.

    ``self.gate`` scores two groups of experts:

    * Dynamic experts: ``mlp_dynamic_expert_num`` columns, consisting of
      ``mlp_dynamic_real_expert_num`` real experts followed by the null experts.
      Each token uses a variable number of them, chosen by top-p (or a constant
      top-k when ``mlp_dynamic_top_p == 0``).
    * Fixed experts: ``mlp_fixed_expert_num`` columns after the dynamic ones.
      Every fixed expert is applied to every token, weighted by its share of the
      global routing softmax.

    Null-expert columns receive routing weight but no computation, so they
    contribute nothing to the output.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.mlp_dynamic_expert_num = config.mlp_dynamic_expert_num + config.mlp_dynamic_null_expert_num
        self.mlp_dynamic_real_expert_num = config.mlp_dynamic_expert_num
        self.mlp_dynamic_null_expert_num = config.mlp_dynamic_null_expert_num
        self.mlp_dynamic_top_p = config.mlp_dynamic_top_p
        self.mlp_dynamic_top_k = config.mlp_dynamic_top_k
        self.mlp_fixed_expert_num = config.mlp_fixed_expert_num
        self.num_experts = self.mlp_dynamic_expert_num + self.mlp_fixed_expert_num
        if self.mlp_dynamic_top_p == 0:
            print(f"mlp_dynamic_top_p is 0, will use mlp_dynamic_top_k={self.mlp_dynamic_top_k} instead !!!")
        self.ignore_differentiable_router = config.ignore_differentiable_router
        if self.ignore_differentiable_router:
            print("ignore_differentiable_router is True, will not use router_logits !!!")
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.fixed_real_moe = nn.ModuleList([AudioSharedExpertMLP(config) for _ in range(self.mlp_fixed_expert_num)])
        self.dynamic_real_moe = UniMoEAudioMoE(config, AudioDynamicExpertMLP(config), self.mlp_dynamic_real_expert_num, config.ep_size)
        self.router_jitter_noise = config.router_jitter_noise
        self.input_jitter_noise = config.input_jitter_noise
        self.min_capacity = config.min_capacity
        self.capacity_factor = config.capacity_factor
        self.token_drop = config.token_drop
        self.drop_policy = config.drop_policy
        self.avg_hidden_states_last = config.avg_hidden_states_last
        self.drop_token_num_print = config.drop_token_num_print
        self.fp32_gate = config.fp32_gate

    def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, aux_balance_weight: torch.Tensor = None):
        """Route tokens through the dynamic and fixed experts.

        Args:
            hidden_states: ``(batch, seq, hidden)`` activations.
            attention_mask: Optional per-token mask. It is flattened and multiplied
                into the expert mask; this assumes ``batch * seq`` elements with
                1 = keep (TODO confirm against callers).
            aux_balance_weight: Optional per-token weights for the load-balancing loss.

        Returns:
            A tuple ``(final_hidden_states, full_router_logits, dynamic_top_k,
            expert_mask, global_weight, aux_loss)``.

            * ``final_hidden_states`` has shape ``(batch, seq, hidden)``.
            * The routing tensors have one row per flattened token.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        # The experts below consume these un-jittered activations in their original dtype.
        original_hidden_states = hidden_states
        if self.training and self.fp32_gate:
            hidden_states = hidden_states.float()
        if self.training and self.input_jitter_noise > 0:
            # Out-of-place on purpose. ``hidden_states`` can still alias the caller's
            # tensor: always when fp32_gate is off, and also when the input is already
            # float32 (``.float()`` then returns the same tensor). An in-place ``*=``
            # would leak the router noise into ``original_hidden_states`` and mutate
            # the caller's activations.
            hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise)
        hidden_states = hidden_states.view(-1, hidden_dim)
        if self.training and self.fp32_gate:
            full_router_logits = torch.nn.functional.linear(hidden_states, weight=self.gate.weight.float(), bias=None)
        else:
            full_router_logits = self.gate(hidden_states)
        dynamic_router_logits = full_router_logits[:, : self.mlp_dynamic_expert_num]

        # Number of dynamic experts per token: top-p over the dynamic logits, or a constant top-k.
        if self.mlp_dynamic_top_p != 0:
            dynamic_top_k = audio_dynamic_expert_selection(dynamic_router_logits, self.mlp_dynamic_top_p)
        else:
            dynamic_top_k = torch.full((dynamic_router_logits.shape[0],), self.mlp_dynamic_top_k, dtype=torch.int, device=dynamic_router_logits.device)
        # The grouping loop only visits k = 1..mlp_dynamic_expert_num. Larger values
        # (top-p round-off, or mlp_dynamic_top_k > expert count) would silently leave
        # those tokens with no dynamic expert at all.
        dynamic_top_k = dynamic_top_k.clamp(max=self.mlp_dynamic_expert_num)

        expert_mask = torch.zeros((batch_size * sequence_length, self.num_experts), dtype=torch.int, device=hidden_states.device)
        routing_weights = torch.zeros((batch_size * sequence_length, self.mlp_dynamic_expert_num), dtype=hidden_states.dtype, device=hidden_states.device)
        # Batch the mixer per distinct k: every token in a group selects the same number of experts.
        for top_k in range(1, self.mlp_dynamic_expert_num + 1):
            group_idx = torch.nonzero(dynamic_top_k == top_k, as_tuple=True)[0]
            if len(group_idx) == 0:
                continue
            dynamic_group_logits = dynamic_router_logits[group_idx]
            group_routing_weights, group_selected_experts = audio_sparse_expert_mixer(
                dynamic_group_logits,
                top_k=top_k,
                jitter_eps=self.router_jitter_noise,
                training=self.training and not self.ignore_differentiable_router,
            )
            group_expert_mask = torch.nn.functional.one_hot(group_selected_experts, num_classes=self.num_experts)
            group_expert_mask = group_expert_mask.sum(dim=1)
            group_weight = torch.zeros((len(group_idx), self.mlp_dynamic_expert_num), dtype=hidden_states.dtype, device=hidden_states.device)
            group_weight.scatter_(dim=-1, index=group_selected_experts, src=group_routing_weights)
            routing_weights.index_add_(0, group_idx, group_weight)
            expert_mask.index_add_(0, group_idx, group_expert_mask.to(expert_mask.dtype))
        routing_weights = routing_weights / (routing_weights.sum(dim=-1).unsqueeze(-1).expand(-1, routing_weights.shape[-1]) + 1e-6)

        if attention_mask is not None:
            attention_mask = attention_mask.to(expert_mask.dtype).view(-1).unsqueeze(-1).expand(-1, self.num_experts)
            expert_mask = expert_mask * attention_mask
        # Fixed experts are active for every token.
        if self.mlp_dynamic_expert_num < self.num_experts:
            expert_mask[:, self.mlp_dynamic_expert_num :] = 1
        aux_loss = audio_load_balancing_loss_func(
            expert_mask=expert_mask,
            mlp_dynamic_expert_num=self.mlp_dynamic_expert_num,
            global_weight=None,
            full_router_logits=full_router_logits,
            routing_weights=routing_weights,
            aux_balance_weight=aux_balance_weight,
        )

        if self.token_drop:
            expert_mask_dtype = expert_mask.dtype
            capacity = _audio_expert_capacity(batch_size * sequence_length, self.mlp_dynamic_expert_num, torch.tensor(self.capacity_factor), torch.tensor(self.min_capacity))
            if self.drop_policy == "probs":
                if capacity > dynamic_router_logits.shape[0]:
                    print(f"[warning] token capacity({capacity}) > token num({dynamic_router_logits.shape[0]}), setting capacity=token num")
                    capacity = dynamic_router_logits.shape[0]
                dynamic_expert_mask = expert_mask[:, : self.mlp_dynamic_expert_num].bool()
                token_drop_router_logits = torch.masked_fill(dynamic_router_logits, ~dynamic_expert_mask, torch.finfo(dynamic_router_logits.dtype).min)
                # Per dynamic expert, keep the `capacity` highest-logit tokens.
                capacity_probs, capacity_indices = torch.topk(token_drop_router_logits, k=capacity, dim=0, sorted=False)
                capacity_mask = torch.zeros_like(expert_mask).scatter(0, capacity_indices, 1)
                capacity_mask[:, self.mlp_dynamic_expert_num :] = 1
                expert_mask = torch.logical_and(expert_mask, capacity_mask)
                ori_token_num = dynamic_expert_mask.sum().item()
                cur_token_num = expert_mask[:, : self.mlp_dynamic_expert_num].sum().item()
                if self.drop_token_num_print and ("RANK" not in os.environ or int(os.environ["RANK"]) == 0):
                    print(f"drop {ori_token_num - cur_token_num} tokens from total {ori_token_num} tokens")
            elif self.drop_policy == "position":
                # Apply capacity to the dynamic experts only, as the "probs" branch
                # does. The fixed experts are shared by every token. Capping them
                # would give late tokens an all-zero mask row, which produces a NaN
                # softmax in calculate_audio_global_routing_weight.
                dynamic_expert_mask = expert_mask[:, : self.mlp_dynamic_expert_num]
                locations = torch.cumsum(dynamic_expert_mask, dim=0) - 1
                expert_mask[:, : self.mlp_dynamic_expert_num] = dynamic_expert_mask * torch.lt(locations, capacity)
            else:
                raise ValueError(f"Invalid drop_policy: {self.drop_policy}")
            expert_mask = expert_mask.to(expert_mask_dtype)
            routing_weights = routing_weights.masked_fill(~(expert_mask[:, : self.mlp_dynamic_expert_num].bool()), 0.0)
            routing_weights = routing_weights / (routing_weights.sum(dim=-1).unsqueeze(-1).expand(-1, routing_weights.shape[-1]) + 1e-6)

        if self.mlp_dynamic_expert_num < self.num_experts:
            global_weight = calculate_audio_global_routing_weight(expert_mask, full_router_logits, self.mlp_dynamic_expert_num, routing_weights)
        else:
            global_weight = routing_weights

        hidden_states = original_hidden_states.view(-1, hidden_dim)
        final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device)
        global_weight = global_weight.to(hidden_states.dtype)
        # Only the real dynamic experts are computed; null-expert columns add nothing.
        current_hidden_states = self.dynamic_real_moe(hidden_states, expert_mask=expert_mask[:, : self.mlp_dynamic_real_expert_num], router_weight=global_weight[:, : self.mlp_dynamic_real_expert_num])
        final_hidden_states = final_hidden_states + current_hidden_states
        for expert_idx in range(self.mlp_fixed_expert_num):
            expert_layer = self.fixed_real_moe[expert_idx]
            current_state = hidden_states
            current_global_weight = global_weight[:, self.mlp_dynamic_expert_num + expert_idx].unsqueeze(-1)
            current_hidden_states = expert_layer(current_state) * current_global_weight
            final_hidden_states = final_hidden_states + current_hidden_states
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        if not self.training and self.avg_hidden_states_last:
            dist.all_reduce(final_hidden_states, op=dist.ReduceOp.AVG, group=self.dynamic_real_moe.deepspeed_moe.ep_group)
        return final_hidden_states, full_router_logits, dynamic_top_k, expert_mask, global_weight, aux_loss
def audio_load_balancing_loss_func(
    expert_mask: torch.Tensor,
    mlp_dynamic_expert_num: int,
    global_weight: Optional[torch.Tensor] = None,
    full_router_logits: Optional[torch.Tensor] = None,
    routing_weights: Optional[torch.Tensor] = None,
    aux_balance_weight: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Calculate load balancing loss for UniMoE-Audio expert routing to encourage balanced usage.

    Computes ``num_experts * sum_e(f_e * p_e)`` over the first
    ``mlp_dynamic_expert_num`` experts, where:

    * ``f_e`` is the fraction of tokens routed to expert ``e``;
    * ``p_e`` is the mean router probability, from a softmax over the active
      experts' logits.

    When ``aux_balance_weight`` is given, both averages are weighted per token.

    Args:
        expert_mask: ``(tokens, experts)`` 0/1 routing mask.
        mlp_dynamic_expert_num: Number of leading (dynamic) experts to balance.
        global_weight: Ignored; the probabilities are recomputed from
            ``full_router_logits``. Kept for call compatibility.
        full_router_logits: ``(tokens, experts)`` router logits. Required.
        routing_weights: Ignored; kept for call compatibility.
        aux_balance_weight: Optional ``(batch, seq)`` per-token weights.

    Returns:
        A scalar tensor.

    Raises:
        ValueError: if ``full_router_logits`` is None.
    """
    if full_router_logits is None:
        raise ValueError("full_router_logits is required to compute the load balancing loss.")
    # Fill with the dtype minimum rather than -inf so an all-inactive row yields a
    # uniform softmax instead of NaN.
    min_dtype = torch.finfo(full_router_logits.dtype).min
    router_probs = full_router_logits.masked_fill(expert_mask == 0, min_dtype)[:, :mlp_dynamic_expert_num]
    router_probs = torch.softmax(router_probs, dim=-1)
    dynamic_mask = expert_mask[:, :mlp_dynamic_expert_num].float()
    num_experts = dynamic_mask.shape[-1]
    if aux_balance_weight is None:
        tokens_per_expert = torch.mean(dynamic_mask, dim=0)
        router_prob_per_expert = torch.mean(router_probs, dim=0)
    else:
        batch_size, sequence_length = aux_balance_weight.shape
        # Supports logits stacked over several layers along the token dimension.
        num_hidden_layers = router_probs.shape[0] // (batch_size * sequence_length)
        expert_attention_mask = aux_balance_weight[None, :, :, None].expand((num_hidden_layers, batch_size, sequence_length, num_experts)).reshape(-1, num_experts).to(router_probs.device)
        tokens_per_expert = torch.sum(dynamic_mask * expert_attention_mask, dim=0) / torch.sum(expert_attention_mask, dim=0)
        router_prob_per_expert = torch.sum(router_probs * expert_attention_mask, dim=0) / torch.sum(expert_attention_mask, dim=0)
    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert)
    return overall_loss * num_experts
class AudioExperts(deepspeed.moe.experts.Experts):
    """Custom Audio experts class extending DeepSpeed MoE experts with additional functionality.

    Holds ``num_local_experts`` deep copies of ``expert``. Every parameter of
    those copies is tagged with ``allreduce = False`` and
    ``group_name = expert_group_name``. ``forward`` splits its input into one
    chunk per local expert along dim 1 and concatenates the outputs along the
    same dim.
    """

    def __init__(self, expert, num_local_experts=1, expert_group_name=None):
        # Deliberately skips ``Experts.__init__`` by starting the MRO lookup after it.
        super(deepspeed.moe.experts.Experts, self).__init__()
        self.deepspeed_experts = torch.nn.ModuleList(copy.deepcopy(expert) for _ in range(num_local_experts))
        self.num_local_experts = num_local_experts
        for local_expert in self.deepspeed_experts:
            for param in local_expert.parameters():
                param.allreduce = False
                param.group_name = expert_group_name

    def forward(self, inputs):
        outputs = []
        for piece, local_expert in zip(inputs.chunk(self.num_local_experts, dim=1), self.deepspeed_experts):
            result = local_expert(piece)
            # Some expert modules return a tuple; keep only its first element.
            outputs.append(result[0] if type(result) is tuple else result)
        return torch.cat(outputs, dim=1)
class AudioMOELayer(deepspeed.moe.sharded_moe.MOELayer):
    """Custom Audio MoE layer extending DeepSpeed MOELayer with matrix compression optimization.

    ``forward`` takes an explicit 0/1 ``expert_mask`` and per-expert
    ``router_weight`` instead of DeepSpeed gate outputs. It works in three steps:

    1. Pack each expert's tokens into ``C`` rows with ``compress_matrix``.
    2. Run the all-to-all exchange and the local experts.
    3. Restore token positions with ``decompress_matrix`` and combine the
       results.
    """
    def __init__(
        self,
        experts: nn.Module,
        ep_group_name: str,
        ep_size: int,
        num_local_experts: int,
        use_tutel: bool = False,
    ) -> None:
        # Deliberately bypasses ``MOELayer.__init__`` (MRO lookup starts after it) and
        # sets only the attributes this subclass uses. ``use_tutel`` is accepted but unused.
        super(deepspeed.moe.sharded_moe.MOELayer, self).__init__()
        self.experts = experts
        # Expert-parallel process group; stays None until ``_set_ep_group`` is called.
        self.ep_group = None
        self.ep_size = ep_size
        self.ep_group_name = ep_group_name
        self.num_local_experts = num_local_experts
        # Latest timer readings; only updated when ``wall_clock_breakdown`` is True.
        self.time_falltoall = 0.0
        self.time_salltoall = 0.0
        self.time_moe = 0.0
        self.timers = SynchronizedWallClockTimer()
        self.wall_clock_breakdown = False

    def _set_ep_group(self, ep_group):
        """Store the expert-parallel process group used by the collectives in ``forward``."""
        self.ep_group = ep_group

    def forward(self, hidden_states: Tensor, expert_mask: Tensor, router_weight: Tensor) -> Tensor:
        """Dispatch tokens to their experts over the EP group and combine the outputs.

        Args:
            hidden_states: ``(S, d_model)`` token activations.
            expert_mask: ``(S, expert_num)`` 0/1 token-to-expert assignment.
            router_weight: ``(S, expert_num)`` combination weights.

        Returns:
            ``(S, d_model)``: for each token, the ``router_weight``-weighted sum of
            the outputs of the experts it is assigned to.
        """
        # Zero the weights of experts a token is not assigned to.
        router_weight = router_weight * expert_mask
        if self.wall_clock_breakdown:
            self.timers(MOE_TIMER).start()
        d_model = hidden_states.shape[-1]
        seq_len = hidden_states.shape[0]
        expert_num = expert_mask.shape[-1]
        # Capacity C = the largest number of tokens assigned to any single expert.
        capacity = expert_mask.sum(dim=0).max()
        if self.ep_group is not None:
            # Use the largest C in the expert-parallel group so every rank compresses
            # to the same row count (presumably so the all-to-all buffers match).
            dist.all_reduce(capacity, op=dist.ReduceOp.MAX, group=self.ep_group)
        # Give every expert a view of all tokens, then pack each expert's assigned
        # tokens into its first rows; remaining rows up to C are zero.
        compres_hidden_states = hidden_states.unsqueeze(1).expand(seq_len, expert_num, d_model)
        compres_hidden_states = compress_matrix(compres_hidden_states, expert_mask, force_dim=capacity, allow_larger_dim=True) # [C, expert_num, d_model]
        compres_expert_mask = compress_matrix(expert_mask, expert_mask, force_dim=capacity, allow_larger_dim=True)
        # Apply the packed mask and reorder to [expert_num, C, d_model].
        dispatched_input = einsum("ce,cem->ecm", compres_expert_mask, compres_hidden_states)
        if self.wall_clock_breakdown:
            self.timers(FIRST_ALLTOALL_TIMER).start()
        # All-to-all over the expert-parallel group.
        dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input)
        if self.wall_clock_breakdown:
            self.timers(FIRST_ALLTOALL_TIMER).stop()
            self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False)
        # [ep_size, num_local_experts, C, d_model]; assumes expert_num == ep_size * num_local_experts.
        dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model)
        expert_output = self.experts(dispatched_input)
        if self.wall_clock_breakdown:
            self.timers(SECOND_ALLTOALL_TIMER).start()
        # Reverse all-to-all returns each expert's outputs to the originating rank.
        expert_output = _AllToAll.apply(self.ep_group, expert_output)
        if self.wall_clock_breakdown:
            self.timers(SECOND_ALLTOALL_TIMER).stop()
            self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False)
        expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model)
        # [C, expert_num, d_model] -> scatter rows back to token positions: [S, expert_num, d_model].
        expert_output = decompress_matrix(expert_output.transpose(0, 1), expert_mask, allow_larger_dim=True)
        # Weighted sum over experts -> [S, d_model].
        combined_output = einsum("se,sem->sm", router_weight, expert_output)
        if self.wall_clock_breakdown:
            self.timers(MOE_TIMER).stop()
            self.time_moe = self.timers(MOE_TIMER).elapsed(reset=False)
        return combined_output
class UniMoEAudioMoE(deepspeed.moe.layer.MoE):
    """Custom Audio MoE class extending DeepSpeed MoE with configuration and parallelism setup.

    Shards ``num_experts`` copies of ``expert`` over ``ep_size`` expert-parallel
    ranks, with ``num_experts // ep_size`` local copies each. They are wrapped
    in an ``AudioMOELayer`` whose process group is named
    ``f"{moe_name_prefix}_{ep_size}"``.
    """

    def __init__(self, config, expert, num_experts, ep_size, moe_name_prefix="ep_size"):
        # Deliberately bypasses ``MoE.__init__`` (MRO lookup starts after it).
        super(deepspeed.moe.layer.MoE, self).__init__()
        # An uneven split would silently floor num_local_experts below and leave some
        # experts unhosted; DeepSpeed's own MoE enforces the same requirement.
        if ep_size <= 0 or num_experts % ep_size != 0:
            raise ValueError(f"Number of experts ({num_experts}) should be divisible by expert parallel size ({ep_size})")
        self.enable_expert_tensor_parallelism = config.enable_expert_tensor_parallelism
        self.ep_size = ep_size
        self.num_experts = num_experts
        self.expert_group_name = f"{moe_name_prefix}_{self.ep_size}"
        self.num_local_experts = self.num_experts // self.ep_size
        log_dist(f"Creating MoE layer with num_experts: {self.num_experts} | num_local_experts: {self.num_local_experts} | expert_parallel_size: {self.ep_size}", [0])
        experts = AudioExperts(expert, self.num_local_experts, self.expert_group_name)
        self.deepspeed_moe = AudioMOELayer(experts, self.expert_group_name, self.ep_size, self.num_local_experts)

    def set_deepspeed_parallelism(self, use_data_before_expert_parallel_=False):
        """Create (if needed) and attach the expert-parallel process group."""
        self._create_process_groups(use_data_before_expert_parallel_=use_data_before_expert_parallel_)

    def _create_process_groups(self, use_data_before_expert_parallel_=False):
        """Create the expert/data parallel groups if this group name is new, then
        hand the expert-parallel group to the inner ``AudioMOELayer``."""
        if self.expert_group_name not in groups._get_expert_parallel_group_dict():
            print(f"No existing process group found, creating a new group named: {self.expert_group_name}")
            if (groups.mpu is None) or (not self.enable_expert_tensor_parallelism):
                groups._create_expert_and_data_parallel(self.ep_size, use_data_before_expert_parallel_=use_data_before_expert_parallel_)
            else:
                groups._create_expert_data_and_model_parallel(self.ep_size, mpu=groups.mpu, use_data_before_expert_parallel_=use_data_before_expert_parallel_)
        self.deepspeed_moe._set_ep_group(groups._get_expert_parallel_group(self.expert_group_name))

    def forward(self, *input_args, **input_kwargs):
        """Delegate to the inner ``AudioMOELayer``."""
        return self.deepspeed_moe(*input_args, **input_kwargs)
|