from typing import Optional

from tensorrt_llm.functional import Tensor, silu
from tensorrt_llm.layers import ColumnLinear
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.module import Module, ModuleList

from ..._utils import str_dtype_to_trt


class ResBlock(Module):
    """Residual MLP block used by the drafter: x -> x + silu(Linear(x))."""

    def __init__(self,
                 exit_dim: int,
                 dtype: Optional[str],
                 mapping: Mapping = Mapping()):
        super().__init__()
        self.linear = ColumnLinear(
            exit_dim,
            exit_dim,
            bias=True,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )

    def forward(self, x: Tensor) -> Tensor:
        return x + silu(self.linear(x))


class Drafter(Module):
    """ReDrafter draft head: a stack of ResBlocks feeding an LM head, with an
    optional pair of linear layers for recurrent (RNN) token embeddings."""

    def __init__(
            self,
            num_layers: int,
            hidden_size: int,
            exit_dim: int,
            vocab_size: int,
            dtype: Optional[str] = None,
            is_rnn: bool = False,
            mapping: Mapping = Mapping(),
    ):
        super().__init__()
        self.num_layers = num_layers
        self.is_rnn = is_rnn
        self.dtype = str_dtype_to_trt(dtype)

        # The drafter consumes the base-model hidden state concatenated with
        # the current token embedding, hence the doubled input width.
        input_dim = 2 * hidden_size
        # Project the concatenated input to exit_dim unless it already matches.
        self.input_proj = (None if input_dim == exit_dim else ColumnLinear(
            input_dim,
            exit_dim,
            bias=True,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        ))

        self.layers = ModuleList([
            ResBlock(exit_dim, dtype, mapping) for _ in range(self.num_layers)
        ])
        self.lm_head = ColumnLinear(
            exit_dim,
            vocab_size,
            bias=False,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )

        if is_rnn:
            # Recurrent embedding path: rnn_w maps the current token embedding
            # and rnn_u folds in the previous state (see rnn_embed below).
            self.rnn_u = ColumnLinear(
                hidden_size,
                hidden_size,
                bias=True,
                dtype=dtype,
                tp_group=mapping.tp_group,
                tp_size=mapping.tp_size,
                gather_output=True,
            )
            self.rnn_w = ColumnLinear(
                hidden_size,
                hidden_size,
                bias=False,
                dtype=dtype,
                tp_group=mapping.tp_group,
                tp_size=mapping.tp_size,
                gather_output=True,
            )

    @classmethod
    def from_config(cls, config, vocab_size_padded):
        kwargs = {
            "num_layers": config.redrafter_num_layers,
            "hidden_size": config.redrafter_hidden_size,
            "exit_dim": config.redrafter_exit_dim,
            "vocab_size": vocab_size_padded,
            "dtype": config.dtype,
            "is_rnn": config.redrafter_is_rnn,
            "mapping": config.mapping,
        }
        return cls(**kwargs)

    def forward(self, x: Tensor) -> Tensor:
        # x is the concatenated [hidden state ; token embedding] input;
        # project to exit_dim if needed, then refine with the residual stack.
        hidden_states = self.input_proj(x) if self.input_proj is not None else x
        for layer in self.layers:
            hidden_states = layer(hidden_states)

        return self.lm_head(hidden_states)

    def rnn_embed(self, x: Tensor, prev: Optional[Tensor] = None) -> Tensor:
        """One recurrent embedding step: W @ x, plus U @ prev after the first step."""
        assert self.is_rnn, "This function should not be called when redrafter_is_rnn is false."
        w_embd = self.rnn_w(x)
        return w_embd if prev is None else w_embd + self.rnn_u(prev)
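
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): the values
# below are assumed for demonstration; real ones come from the checkpoint
# config via Drafter.from_config. Note that forward/rnn_embed must run inside
# a tensorrt_llm network-building context, since the functional ops build
# graph nodes rather than executing eagerly.
#
#   drafter = Drafter(
#       num_layers=2,        # assumed redrafter_num_layers
#       hidden_size=4096,    # assumed redrafter_hidden_size
#       exit_dim=4096,       # input_proj then maps 2 * 4096 -> 4096
#       vocab_size=32128,    # padded vocab size
#       dtype="float16",
#       is_rnn=True,
#   )
#
# One draft step consumes the base model's hidden state concatenated with a
# token embedding:
#
#   x = concat([hidden_state, token_embed], dim=-1)   # [..., 2 * hidden_size]
#   logits = drafter(x)                               # [..., vocab_size]
#
# and, when is_rnn is set, embeddings chain recurrently across draft steps:
#
#   e0 = drafter.rnn_embed(embed_0)                   # first step: W @ e
#   e1 = drafter.rnn_embed(embed_1, prev=e0)          # later: W @ e + U @ prev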