"""
Inspired from https://github.com/karpathy/minGPT
"""

from typing import Optional
from einops import rearrange

import torch
import torch.nn as nn

from .kv_caching import KeysValues, KVCache


class TransformerEncoder(nn.Module):
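    """Pre-LN Transformer encoder with causal or block-causal self-attention
    and optional per-layer KV caching for incremental decoding."""
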
    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config
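        # Total sequence length: each of the `max_blocks` blocks contributes
        # `tokens_per_block` tokens.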
        self.config["max_tokens"] = config["tokens_per_block"] * config["max_blocks"]
        self.pos_emb = nn.Embedding(config["max_tokens"], config["embed_dim"])
        self.emb_drop = nn.Dropout(config["embed_pdrop"])
        self.ln = nn.LayerNorm(config["embed_dim"])

        assert config["attention"] in ('causal', 'block_causal')
        k, m = config["tokens_per_block"], config["max_blocks"]
        mask_sa = torch.tril(torch.ones(k * m, k * m))
        if config["attention"] == 'block_causal':
            mask_sa = torch.max(mask_sa, torch.block_diag(*[torch.ones(k, k) for _ in range(m)]))
        mask_sa = mask_sa.bool()
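        # e.g. tokens_per_block=2, max_blocks=2 with 'block_causal' gives
        #   1 1 0 0
        #   1 1 0 0
        #   1 1 1 1
        #   1 1 1 1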

        self.blocks = nn.ModuleList([EncoderLayer(config, mask_sa) for _ in range(config["num_layers"])])
        self.keys_values = None   # allocated lazily by reset_kv_cache()

    @property
    def num_blocks_left_in_kv_cache(self) -> float:
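        """Number of whole blocks that still fit in the KV cache (possibly fractional)."""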
        assert self.keys_values is not None
        return (self.config["max_tokens"] - self.keys_values.size) / self.config["tokens_per_block"]

    def reset_kv_cache(self, n: int) -> None:
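        """Allocate a fresh, empty cache for a batch of size n."""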
        device = self.ln.weight.device
        self.keys_values = KeysValues(n, self.config["max_tokens"], self.config["embed_dim"], self.config["num_layers"], device)

    def forward(self, x: torch.FloatTensor, use_kv_cache: bool = False) -> torch.FloatTensor:
        assert x.ndim == 3 and x.size(2) == self.config["embed_dim"]   # (B, TK, E)

        if use_kv_cache:
            assert self.keys_values is not None, "call reset_kv_cache() before decoding with the KV cache"
        # Offset positions by the number of tokens already stored in the cache.
        prev_steps = self.keys_values.size if use_kv_cache else 0
        inputs = x + self.pos_emb(prev_steps + torch.arange(x.size(1), device=x.device))

        y = self.emb_drop(inputs)
        for i, block in enumerate(self.blocks):
            y = block(y, self.keys_values[i] if use_kv_cache else None)
        y = self.ln(y)

        return y


class EncoderLayer(nn.Module):
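    """Pre-LN Transformer block: masked self-attention followed by an MLP,
    each wrapped in a residual connection."""
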
    def __init__(self, config: dict, mask_sa: torch.BoolTensor) -> None:
        super().__init__()
        self.sa = SelfAttentionLayer(config, mask=mask_sa)
        self.mlp = MLPLayer(config)

    def forward(self, x: torch.FloatTensor, kv_cache: Optional[KVCache] = None) -> torch.FloatTensor:
        return self.mlp(self.sa(x, kv_cache))


class MLPLayer(nn.Module):
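    """Position-wise feed-forward sub-layer with a residual connection."""
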
    def __init__(self, config: dict) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(config["embed_dim"])
        self.mlp = nn.Sequential(
            nn.Linear(config["embed_dim"], 4 * config["embed_dim"]),
            nn.GELU(),
            nn.Linear(4 * config["embed_dim"], config["embed_dim"]),
            nn.Dropout(config["resid_pdrop"]),
        )

    def forward(self, inputs: torch.FloatTensor) -> torch.FloatTensor:
        return inputs + self.mlp(self.ln(inputs))


class SelfAttentionLayer(nn.Module):
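    """Multi-head self-attention sub-layer with a residual connection and
    optional KV caching."""
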
    def __init__(self, config: dict, mask: torch.BoolTensor) -> None:
        super().__init__()
        self.register_buffer('mask', mask)
        self.ln = nn.LayerNorm(config["embed_dim"])
        self.query = nn.Linear(config["embed_dim"], config["embed_dim"])
        self.key = nn.Linear(config["embed_dim"], config["embed_dim"])
        self.value = nn.Linear(config["embed_dim"], config["embed_dim"])
        self.attention = Attention(config)

    def forward(self, inputs: torch.FloatTensor, kv_cache: Optional[KVCache] = None) -> torch.FloatTensor:
        B, T, C = inputs.size()
        if kv_cache is not None:
            b, L, c = kv_cache.shape   # L = number of tokens already cached
            assert b == B and c == C
        else:
            L = 0

        x = self.ln(inputs)

        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        if kv_cache is not None:
            # Append the new keys/values, then attend over the full cached sequence.
            kv_cache.update(k, v)
            k, v = kv_cache.get()

        # Mask rows select the T new query positions (offset by the L cached
        # tokens); columns cover all L + T visible key positions.
        y = inputs + self.attention(q, k, v, self.mask[L:L + T, :L + T])

        return y


class Attention(nn.Module):
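    """Scaled dot-product attention over `num_heads` heads with an output projection."""
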
    def __init__(self, config: dict) -> None:
        super().__init__()
        assert config["embed_dim"] % config["num_heads"] == 0
        self.num_heads = config["num_heads"]
        self.attn_pdrop = config["attn_pdrop"]
        self.resid_drop = nn.Dropout(config["resid_pdrop"])
        self.proj = nn.Linear(config["embed_dim"], config["embed_dim"])

    def forward(self, q: torch.FloatTensor, k: torch.FloatTensor, v: torch.FloatTensor, mask: torch.BoolTensor) -> torch.FloatTensor:
        assert mask.size(0) == q.size(1) and mask.size(1) == k.size(1)

        q = rearrange(q, 'b q (h e) -> b h q e', h=self.num_heads)
        k = rearrange(k, 'b k (h e) -> b h k e', h=self.num_heads)
        v = rearrange(v, 'b k (h d) -> b h k d', h=self.num_heads)

        # F.scaled_dot_product_attention applies dropout unconditionally, so only
        # pass a non-zero rate while training. The empty-query guard covers
        # decoding steps that contribute zero new tokens.
        if q.size(2) == 0:
            y = q.new_empty(*q.shape[:-1], v.size(-1))
        else:
            dropout_p = self.attn_pdrop if self.training else 0.0
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout_p, is_causal=False)

        y = rearrange(y, 'b h q d -> b q (h d)')
        y = self.resid_drop(self.proj(y))

        return y
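

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the module proper. The config keys
    # and tensor shapes are inferred from the classes above; the cached-decoding
    # path assumes the KeysValues/KVCache interface from .kv_caching behaves as
    # used in TransformerEncoder.forward. Run with `python -m <package>.<module>`
    # so the relative import resolves.
    config = {
        "tokens_per_block": 4,
        "max_blocks": 8,
        "embed_dim": 64,
        "num_layers": 2,
        "num_heads": 4,
        "embed_pdrop": 0.1,
        "resid_pdrop": 0.1,
        "attn_pdrop": 0.1,
        "attention": "block_causal",
    }
    encoder = TransformerEncoder(config).eval()

    # Full-sequence forward pass without the cache.
    x = torch.randn(2, 12, config["embed_dim"])   # (B, TK, E), TK <= max_tokens
    with torch.no_grad():
        y = encoder(x)
    assert y.shape == x.shape

    # Incremental decoding: allocate the cache, then feed one block at a time.
    encoder.reset_kv_cache(n=2)
    with torch.no_grad():
        for _ in range(3):
            block = torch.randn(2, config["tokens_per_block"], config["embed_dim"])
            y_block = encoder(block, use_kv_cache=True)   # (B, tokens_per_block, E)
    print(encoder.num_blocks_left_in_kv_cache)   # 8 - 3 = 5.0 blocks remaining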