File size: 3,520 Bytes
72c0672
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright (c) Meta Platforms, Inc. and affiliates.

from typing import Tuple
import torch

from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states
from apps.mamba.component.causal_conv1d_compilable import (
    causal_conv1d_fn,
    causal_conv1d_update,
)

from apps.fastRNN.component.compilable_scan import scan as accelerated_scan

# from accelerated_scan.triton import scan as triton_scan
from accelerated_scan.ref import scan as ref_scan


def conv1d(
    x: torch.Tensor,
    conv_weight: torch.Tensor,
    tok_idx: torch.Tensor,
    cu_seqlens: torch.Tensor,
    impl: str = "parallel",
    cache=None,
) -> torch.Tensor:
    """Apply a causal depthwise 1D convolution with SiLU activation.

    Args:
        x: input activations; presumably (1, dim, total_tokens) packed layout
            given the squeeze(0)/transpose calls — TODO confirm against callers.
        conv_weight: depthwise convolution weight passed through to the kernel.
        tok_idx: per-token sequence index, forwarded as ``seq_idx`` so the
            parallel kernel does not convolve across sequence boundaries.
        cu_seqlens: cumulative sequence lengths of the packed batch.
        impl: "parallel" (full-sequence kernel) or "sequential" (one-step
            cached decoding update).
        cache: conv state buffer; updated in-place when provided.

    Returns:
        The convolved (and SiLU-activated) tensor, same layout as ``x``.

    Raises:
        NotImplementedError: for any ``impl`` other than the two above.
    """
    if impl == "sequential":
        # Single-step decoding: causal_conv1d_update advances the rolling
        # conv state in `cache` in-place and returns one step of output.
        x_step = x.squeeze(0).transpose(0, 1)
        out = causal_conv1d_update(
            x=x_step,
            conv_state=cache,
            weight=conv_weight,
            bias=None,
            activation="silu",
        )
        return out.transpose(0, 1).unsqueeze(0)

    if impl != "parallel":
        raise NotImplementedError(
            f"causal_conv1d implementation {impl} not supported"
        )

    if cache is not None:
        # Snapshot the trailing conv window of every sequence into `cache`
        # so a later sequential step can resume from it.
        tail_states = causal_conv1d_varlen_states(
            x.squeeze(0).transpose(0, 1), cu_seqlens, state_len=cache.shape[-1]
        )
        cache.copy_(tail_states)

    return causal_conv1d_fn(
        x=x,
        weight=conv_weight,
        bias=None,
        seq_idx=tok_idx,
        activation="silu",
    )


def _prepare_for_cache(
    a: torch.Tensor, b: torch.Tensor, cu_seqlen: torch.Tensor, seq_len: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """This function reset the hidden state at the beginning of each sequence in the batch so that the hidden state is not carried over between sequences."""
    num_seq = cu_seqlen.size(0) - 1
    pow_2_seqlen = max(2 ** (seq_len + num_seq - 2).bit_length(), 32)
    _a = torch.zeros(*a.shape[:2], pow_2_seqlen, device=a.device, dtype=a.dtype)
    _b = torch.zeros(*b.shape[:2], pow_2_seqlen, device=b.device, dtype=b.dtype)

    mask = torch.zeros(pow_2_seqlen, dtype=torch.bool, device=a.device)
    offsets = torch.arange(0, num_seq, device=a.device)
    mask[cu_seqlen[1:-1] + offsets[:-1]] = True
    mask[(cu_seqlen[-1] + offsets[-1]) :] = True
    mask = (~mask).nonzero().flatten()

    for tensor_with_reset, tensor in zip((_a, _b), (a, b)):
        tensor_with_reset[..., mask] = tensor

    return _a, _b, cu_seqlen[1:] + offsets - 1, mask


def sequential_step(
    states: torch.Tensor, a: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    """Advance the linear recurrence one step: ``h_new = a * h + b``."""
    return b + states * a


def scan(
    a: torch.Tensor,
    b: torch.Tensor,
    cu_seqlens: torch.Tensor,
    impl: str = "parallel",
    cache=None,
) -> torch.Tensor:
    """Run the linear recurrence ``h[t] = a[t] * h[t-1] + b[t]``.

    Args:
        a: decay coefficients, shape (B, D, seq_len).
        b: input terms, same shape as ``a``.
        cu_seqlens: cumulative sequence lengths of the packed batch; only
            used on the cached parallel path to reset state at boundaries.
        impl: "parallel" (scan over a whole sequence) or "sequential"
            (single recurrence step for decoding).
        cache: hidden-state buffer, updated in-place when provided.

    Returns:
        The hidden states ``h``, same shape as ``a``.

    Raises:
        NotImplementedError: for any ``impl`` other than the two above.
    """
    if impl == "parallel":
        if cache is not None:
            # accelerated_scan gives an illegal memory access when
            # seqlen > ~2048, so the cached path uses the reference scan.
            a, b, last_state_idx, mask = _prepare_for_cache(a, b, cu_seqlens, a.size(2))

            h = ref_scan(
                a.contiguous(),
                b.contiguous(),
            )

            # Save each sequence's final state, then drop the padding /
            # reset slots introduced by _prepare_for_cache.
            cache.copy_(h[:, :, last_state_idx])
            h = h[:, :, mask]
        else:
            h = accelerated_scan(
                a.contiguous(),
                b.contiguous(),
            )

    elif impl == "sequential":
        # One decoding step from the cached state; cache must be provided.
        h = sequential_step(cache, a, b)
        cache.copy_(h)

    else:
        # Mirror conv1d's error handling instead of falling through and
        # raising UnboundLocalError on `h`.
        raise NotImplementedError(
                f"scan implementation {impl} not supported"
            )

    return h