"""
Credits to https://github.com/karpathy/minGPT
"""

import math
from typing import Optional

from einops import rearrange
import torch
import torch.nn as nn
from torch.nn import functional as F

from .kv_caching import KeysValues, KVCache


class Transformer(nn.Module):
    def __init__(self, config: dict) -> None:
        super().__init__()
        self.config = config
        # Derived setting, written back into the shared config dict so that the
        # SelfAttention modules built below can size their attention masks.
        self.config["max_tokens"] = config["tokens_per_block"] * config["max_blocks"]
        self.drop = nn.Dropout(config["embed_pdrop"])
        self.blocks = nn.ModuleList([Block(config) for _ in range(config["num_layers"])])
        self.ln_f = nn.LayerNorm(config["embed_dim"])

    def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
        device = self.ln_f.weight.device  # Assumption that all submodules are on the same device
        return KeysValues(n, self.config["num_heads"], max_tokens, self.config["embed_dim"], self.config["num_layers"], device)

    def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None) -> torch.Tensor:
        assert past_keys_values is None or len(past_keys_values) == len(self.blocks)
        x = self.drop(sequences)
        for i, block in enumerate(self.blocks):
            x = block(x, None if past_keys_values is None else past_keys_values[i])

        x = self.ln_f(x)
        return x


class Block(nn.Module):
    def __init__(self, config: dict) -> None:
        super().__init__()
        self.ln1 = nn.LayerNorm(config["embed_dim"])
        self.ln2 = nn.LayerNorm(config["embed_dim"])
        self.attn = SelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config["embed_dim"], 4 * config["embed_dim"]),
            nn.GELU(),
            nn.Linear(4 * config["embed_dim"], config["embed_dim"]),
            nn.Dropout(config["resid_pdrop"]),
        )

    def forward(self, x: torch.Tensor, past_keys_values: Optional[KVCache] = None) -> torch.Tensor:
        # Pre-LN residual connections: x + Attn(LN(x)), then x + MLP(LN(x)).
        x_attn = self.attn(self.ln1(x), past_keys_values)
        x = x + x_attn
        x = x + self.mlp(self.ln2(x))
        return x


class SelfAttention(nn.Module):
    def __init__(self, config: dict) -> None:
        super().__init__()
        assert config["embed_dim"] % config["num_heads"] == 0
        assert config["attention"] in ('causal', 'block_causal')
        self.num_heads = config["num_heads"]
        self.key = nn.Linear(config["embed_dim"], config["embed_dim"])
        self.query = nn.Linear(config["embed_dim"], config["embed_dim"])
        self.value = nn.Linear(config["embed_dim"], config["embed_dim"])
        self.attn_drop = nn.Dropout(config["attn_pdrop"])
        self.resid_drop = nn.Dropout(config["resid_pdrop"])
        self.proj = nn.Linear(config["embed_dim"], config["embed_dim"])

        # 'causal': token i attends to tokens 0..i. 'block_causal': in addition,
        # every token attends to all tokens of its own block (full attention
        # within a block, causal across blocks).
        causal_mask = torch.tril(torch.ones(config["max_tokens"], config["max_tokens"]))
        block_causal_mask = torch.max(
            causal_mask,
            torch.block_diag(*[torch.ones(config["tokens_per_block"], config["tokens_per_block"]) for _ in range(config["max_blocks"])]),
        )
        self.register_buffer('mask', causal_mask if config["attention"] == 'causal' else block_causal_mask)
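        # For example, with tokens_per_block=2 and max_blocks=2 (4 tokens), the
        # block-causal mask is the element-wise max of the causal and
        # block-diagonal masks:
        #   1 1 0 0
        #   1 1 0 0
        #   1 1 1 1
        #   1 1 1 1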

    def forward(self, x: torch.Tensor, kv_cache: Optional[KVCache] = None) -> torch.Tensor:
        B, T, C = x.size()
        if kv_cache is not None:
            b, nh, L, c = kv_cache.shape  # L = number of tokens already in the cache
            assert nh == self.num_heads and b == B and c * nh == C
        else:
            L = 0

        q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)   # (B, nh, T, hs)
        k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)     # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2)   # (B, nh, T, hs)

        if kv_cache is not None:
            kv_cache.update(k, v)
            k, v = kv_cache.get()
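            # k and v now cover the cached prefix plus the new tokens:
            # shape (B, num_heads, L + T, head_dim).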

        # Scaled dot-product attention; with a cache, the T query tokens occupy
        # absolute positions L..L+T-1, hence the row slice of the mask.
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[L:L + T, :L + T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, L+T) @ (B, nh, L+T, hs) -> (B, nh, T, hs)
        y = rearrange(y, 'b h t e -> b t (h e)')

        y = self.resid_drop(self.proj(y))

        return y
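

if __name__ == "__main__":
    # Illustrative smoke test: a minimal sketch, not part of the original
    # minGPT-derived module. All config values below are arbitrary choices for
    # the demo. Because of the relative import at the top, run this as a
    # module (e.g. `python -m <package>.transformer`), not as a script.
    config = {
        "tokens_per_block": 4,
        "max_blocks": 8,
        "attention": "block_causal",
        "num_layers": 2,
        "num_heads": 4,
        "embed_dim": 64,
        "embed_pdrop": 0.1,
        "attn_pdrop": 0.1,
        "resid_pdrop": 0.1,
    }
    model = Transformer(config)  # also sets config["max_tokens"] = 32
    model.eval()

    # Full-sequence forward pass: (batch, tokens, embed_dim).
    x = torch.randn(2, config["max_tokens"], config["embed_dim"])
    with torch.no_grad():
        y = model(x)
    assert y.shape == (2, config["max_tokens"], config["embed_dim"])

    # Incremental forward pass: feed one token at a time, reusing the KV cache.
    kv = model.generate_empty_keys_values(n=2, max_tokens=config["max_tokens"])
    with torch.no_grad():
        for t in range(config["max_tokens"]):
            y_t = model(x[:, t:t + 1], past_keys_values=kv)
    assert y_t.shape == (2, 1, config["embed_dim"])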