File size: 2,331 Bytes
cd16f07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from __future__ import annotations

import importlib.util
from typing import Optional, Tuple

import torch
import torch.nn as nn


# Cheap presence probe: for a top-level name, find_spec locates the package
# without importing it, so importing this module does not pull in DeepSpeed.
try:
    _HAS_DEEPSPEED = importlib.util.find_spec("deepspeed") is not None
except ValueError:
    # find_spec raises ValueError when sys.modules already holds a "deepspeed"
    # object whose __spec__ is None/unset (e.g. a stub injected by a test).
    # The module object is present, so defer the verdict to the guarded
    # import in _load_deepspeed_moe_layer() instead of failing at import time.
    _HAS_DEEPSPEED = True

# Import cache, populated lazily by _load_deepspeed_moe_layer().
_DEEPSPEED_MOE_LAYER = None  # DeepSpeed's MoE layer class after a successful import
_DEEPSPEED_IMPORT_ATTEMPTED = False  # True once an import attempt has been made
_DEEPSPEED_IMPORT_ERROR: Optional[str] = None  # why the backend is unavailable, if known


def _load_deepspeed_moe_layer():
    """Import DeepSpeed's ``MoE`` layer class once and cache the outcome.

    The first completed attempt is cached in module globals; later calls
    return the cached class (or ``None``) without retrying. When the backend
    is unavailable, a human-readable reason is stored in
    ``_DEEPSPEED_IMPORT_ERROR`` so callers can include it in error messages.

    Returns:
        ``deepspeed.moe.layer.MoE`` if it could be imported, otherwise ``None``.
    """
    global _DEEPSPEED_MOE_LAYER, _DEEPSPEED_IMPORT_ATTEMPTED, _DEEPSPEED_IMPORT_ERROR
    if _DEEPSPEED_IMPORT_ATTEMPTED:
        return _DEEPSPEED_MOE_LAYER
    if not _HAS_DEEPSPEED:
        # Record a reason so the "not available" error is not left bare.
        _DEEPSPEED_IMPORT_ERROR = "the 'deepspeed' package was not found"
        _DEEPSPEED_MOE_LAYER = None
        _DEEPSPEED_IMPORT_ATTEMPTED = True
        return None
    try:
        # Deliberately broad: any failure while importing the optional backend
        # is treated as "backend unavailable" rather than propagating to users
        # of this module (presumably because heavy optional imports can fail in
        # many ways beyond ImportError — confirm if narrowing is desired).
        from deepspeed.moe.layer import MoE as deepspeed_moe_layer
    except Exception as exc:
        # Keep the exception type: str(exc) alone may be empty.
        _DEEPSPEED_IMPORT_ERROR = f"{type(exc).__name__}: {exc}"
        _DEEPSPEED_MOE_LAYER = None
    else:
        _DEEPSPEED_MOE_LAYER = deepspeed_moe_layer
    # Mark the attempt only after it has finished, so an interrupted import
    # (e.g. KeyboardInterrupt) is retried rather than cached as a silent None.
    _DEEPSPEED_IMPORT_ATTEMPTED = True
    return _DEEPSPEED_MOE_LAYER


class DeepSpeedMoEWrapper(nn.Module):
    """``nn.Module`` adapter around DeepSpeed's ``MoE`` layer.

    The wrapped layer's three-element output is reduced to
    ``(output, aux_loss)``, and ``aux_loss`` is always a tensor.

    Args:
        hidden_size: Forwarded to DeepSpeed ``MoE`` as ``hidden_size``.
        expert: Expert module forwarded as ``expert``.
        num_experts: Forwarded as ``num_experts``.
        top_k: Experts selected per token; forwarded as DeepSpeed's ``k``.
        ep_size: Forwarded as ``ep_size`` (defaults to 1).

    Raises:
        RuntimeError: If the DeepSpeed MoE layer class cannot be loaded; the
            recorded import error, when present, is appended to the message.
    """

    def __init__(
        self,
        hidden_size: int,
        expert: nn.Module,
        num_experts: int,
        top_k: int,
        ep_size: int = 1,
    ):
        super().__init__()
        moe_cls = _load_deepspeed_moe_layer()
        if moe_cls is None:
            reason = _DEEPSPEED_IMPORT_ERROR
            suffix = "" if not reason else f": {reason}"
            raise RuntimeError(f"DeepSpeed MoE backend is not available{suffix}")
        layer_kwargs = dict(
            hidden_size=hidden_size,
            expert=expert,
            num_experts=num_experts,
            ep_size=ep_size,
            k=top_k,
            # The residual-MoE variant is always disabled by this wrapper.
            use_residual=False,
        )
        self.layer = moe_cls(**layer_kwargs)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the MoE layer to ``x``.

        Returns:
            ``(output, aux_loss)``. ``aux_loss`` is the layer's second output
            when that is a tensor. Otherwise it is a 0-dim zero created with
            ``x.new_zeros(())``, so it shares ``x``'s dtype and device. The
            layer's third output is discarded.
        """
        output, aux, _ignored = self.layer(x)
        if not isinstance(aux, torch.Tensor):
            aux = x.new_zeros(())
        return output, aux


def build_deepspeed_moe(
    hidden_size: int,
    expert: nn.Module,
    num_experts: int,
    top_k: int,
    ep_size: int = 1,
) -> Optional[DeepSpeedMoEWrapper]:
    """Construct a ``DeepSpeedMoEWrapper`` when the DeepSpeed backend loads.

    Constructing ``DeepSpeedMoEWrapper`` directly raises ``RuntimeError`` if
    the backend is missing. This factory instead returns ``None`` in that case.

    Args:
        hidden_size: Forwarded to ``DeepSpeedMoEWrapper``.
        expert: Expert module forwarded to ``DeepSpeedMoEWrapper``.
        num_experts: Forwarded to ``DeepSpeedMoEWrapper``.
        top_k: Forwarded to ``DeepSpeedMoEWrapper``.
        ep_size: Forwarded to ``DeepSpeedMoEWrapper`` (defaults to 1).

    Returns:
        A new wrapper, or ``None`` if ``_load_deepspeed_moe_layer()`` returns
        ``None``.
    """
    backend_ready = _load_deepspeed_moe_layer() is not None
    if backend_ready:
        return DeepSpeedMoEWrapper(
            hidden_size=hidden_size,
            expert=expert,
            num_experts=num_experts,
            top_k=top_k,
            ep_size=ep_size,
        )
    return None


def has_deepspeed_moe() -> bool:
    """Return whether DeepSpeed's MoE layer class can be loaded.

    Delegates to ``_load_deepspeed_moe_layer()``, so the first call performs
    the cached import attempt.
    """
    moe_layer_cls = _load_deepspeed_moe_layer()
    return moe_layer_cls is not None