from typing import Optional

from tensorrt_llm.functional import Tensor, silu
from tensorrt_llm.layers import ColumnLinear
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.module import Module, ModuleList

from ..._utils import str_dtype_to_trt


class ResBlock(Module):
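    """Residual block used by the drafter: returns ``x + silu(linear(x))``."""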

    def __init__(self,
                 exit_dim: int,
                 dtype: Optional[str],
                 mapping: Mapping = Mapping()):
        super().__init__()
        self.linear = ColumnLinear(
            exit_dim,
            exit_dim,
            bias=True,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )

    def forward(self, x: Tensor) -> Tensor:
        return x + silu(self.linear(x))


class Drafter(Module):
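    """Draft head that maps drafter inputs to vocabulary logits.

    The input of width ``2 * hidden_size`` is projected to ``exit_dim`` when
    the two widths differ, passed through ``num_layers`` residual SiLU blocks,
    and mapped to logits by ``lm_head``. When ``is_rnn`` is True, the
    ``rnn_embed`` helper combines the current embedding with the previous
    recurrent state.
    """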

    def __init__(
        self,
        num_layers: int,
        hidden_size: int,
        exit_dim: int,
        vocab_size: int,
        dtype: Optional[str] = None,
        is_rnn: bool = False,
        mapping: Mapping = Mapping(),
    ):
        super().__init__()
        self.num_layers = num_layers
        self.is_rnn = is_rnn
        self.dtype = str_dtype_to_trt(dtype)

        # The drafter input width is twice the base hidden size; only add a
        # projection when it does not already match exit_dim.
        input_dim = 2 * hidden_size
        self.input_proj = (None if input_dim == exit_dim else ColumnLinear(
            input_dim,
            exit_dim,
            bias=True,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        ))

        self.layers = ModuleList([
            ResBlock(exit_dim, dtype, mapping) for _ in range(self.num_layers)
        ])
        self.lm_head = ColumnLinear(
            exit_dim,
            vocab_size,
            bias=False,
            dtype=dtype,
            tp_group=mapping.tp_group,
            tp_size=mapping.tp_size,
            gather_output=True,
        )

        # Extra projections used by rnn_embed() in the recurrent variant.
        if is_rnn:
            self.rnn_u = ColumnLinear(
                hidden_size,
                hidden_size,
                bias=True,
                dtype=dtype,
                tp_group=mapping.tp_group,
                tp_size=mapping.tp_size,
                gather_output=True,
            )
            self.rnn_w = ColumnLinear(
                hidden_size,
                hidden_size,
                bias=False,
                dtype=dtype,
                tp_group=mapping.tp_group,
                tp_size=mapping.tp_size,
                gather_output=True,
            )
        return

    @classmethod
    def from_config(cls, config, vocab_size_padded):
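        """Create a Drafter from a ReDrafter model config and a padded vocab size."""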
        kwargs = {
            "num_layers": config.redrafter_num_layers,
            "hidden_size": config.redrafter_hidden_size,
            "exit_dim": config.redrafter_exit_dim,
            "vocab_size": vocab_size_padded,
            "dtype": config.dtype,
            "is_rnn": config.redrafter_is_rnn,
            "mapping": config.mapping,
        }
        return cls(**kwargs)

    def forward(self, x: Tensor) -> Tensor:
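        """Optionally project the input, run the residual blocks, and return logits."""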
        hidden_states = self.input_proj(x) if self.input_proj is not None else x
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return self.lm_head(hidden_states)

    def rnn_embed(self, x: Tensor, prev: Optional[Tensor] = None) -> Tensor:
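        """Embed ``x`` with ``rnn_w`` and, when ``prev`` is given, add ``rnn_u(prev)``."""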
        assert self.is_rnn, "This function should not be called when redrafter_is_rnn is false."
        w_embd = self.rnn_w(x)
        return w_embd if prev is None else w_embd + self.rnn_u(prev)