File size: 5,208 Bytes
63b7820
 
 
 
 
 
a3b70ca
63b7820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3b70ca
63b7820
8b251ce
63b7820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8b251ce
63b7820
 
 
8b251ce
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from .configuration_gpt2workshop import GPT2WorkshopConfig
from transformers.generation import GenerationMixin


def build_rope_cache(seq_len, head_dim, device, theta=10000.0):
    """Precompute the cosine/sine tables used by rotary position embeddings.

    Args:
        seq_len: number of positions to precompute (rows of each table).
        head_dim: per-head feature size; one frequency is produced for every
            even index ``0, 2, ..., head_dim - 2``.
        device: device on which the tables are created.
        theta: RoPE base; for every frequency but the first, a larger base
            gives a lower rotation frequency.

    Returns:
        ``(cos, sin)``, each a float32 tensor of shape
        ``(seq_len, ceil(head_dim / 2))`` whose entry ``[p, i]`` is the
        cosine / sine of ``p * theta ** (-2i / head_dim)``.
    """
    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    t = torch.arange(seq_len, device=device).float()
    # Outer product via broadcasting: rows are positions, columns are frequencies.
    angles = t.unsqueeze(1) * inv_freq.unsqueeze(0)
    return angles.cos(), angles.sin()


def apply_rotary_embeddings(x, rope_cos, rope_sin):
    """Rotate interleaved feature pairs of ``x`` by position-dependent angles.

    Features ``(2i, 2i + 1)`` of the last dimension are treated as a 2-D point
    and rotated by the angle whose cosine/sine is ``rope_cos[pos, i]`` /
    ``rope_sin[pos, i]``, where ``pos`` indexes dimension 2 of ``x``.

    Args:
        x: tensor expected to be 4-D with the position axis at dimension 2 and
            an even-sized last dimension; in this file it is
            ``(batch, heads, seq_len, head_dim)``.
        rope_cos, rope_sin: tables from ``build_rope_cache`` with at least
            ``x.shape[2]`` rows and ``x.shape[-1] // 2`` columns.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    seq_len = x.shape[2]
    # Broadcast the (seq_len, head_dim / 2) tables over the leading batch/head axes.
    cos = rope_cos[:seq_len].unsqueeze(0).unsqueeze(0)
    sin = rope_sin[:seq_len].unsqueeze(0).unsqueeze(0)
    even, odd = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack((even * cos - odd * sin, even * sin + odd * cos), dim=-1).flatten(-2)
    # The tables may stay float32 while activations are lower precision
    # (e.g. bfloat16). Type promotion would then return float32 here and leave
    # q/k with a different dtype than v in the attention call, so cast back.
    return rotated.to(x.dtype)


def relu_squared(x):
    """Squared-ReLU activation: ``max(x, 0) ** 2`` applied element-wise."""
    positive = F.relu(x)
    return positive * positive


def soft_cap_logits(logits, cap=30.0):
    """Smoothly bound logits to within ``±|cap|``.

    Computes ``cap * tanh(logits / cap)``. This is close to the identity for
    ``|logits|`` much smaller than ``|cap|`` and saturates towards ``±|cap|``
    for large values.

    Args:
        logits: tensor of raw scores.
        cap: saturation value. ``None`` or ``0`` disables capping, and
            ``logits`` is returned unchanged; dividing by ``0`` would
            otherwise produce NaNs.

    Returns:
        Tensor with the same shape as ``logits``.
    """
    if cap is None or cap == 0:
        return logits
    return cap * torch.tanh(logits / cap)


class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    Four bias-free ``hidden_dim -> hidden_dim`` projections produce queries,
    keys, values and the final output. Activations are split into
    ``config.num_heads`` heads of ``config.head_dim`` features, so
    ``num_heads * head_dim`` must equal ``hidden_dim`` for the reshapes to work.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        width = config.hidden_dim

        def make_projection():
            return nn.Linear(width, width, bias=False)

        # This creation order fixes parameter registration and random-init order.
        self.query_projection = make_projection()
        self.key_projection = make_projection()
        self.value_projection = make_projection()
        self.output_projection = make_projection()
        # Dropout probability on the attention weights; used only in training mode.
        self.attn_dropout_rate = config.dropout

    def forward(self, x, rope_cos, rope_sin):
        """Attend causally over ``x`` of shape ``(batch, seq_len, hidden_dim)``.

        ``rope_cos`` / ``rope_sin`` are passed to ``apply_rotary_embeddings``
        for the queries and keys. The result has the same shape as ``x``.
        """
        batch, steps, _ = x.shape
        per_head = (batch, steps, self.num_heads, self.head_dim)

        def split_heads(projection):
            # (batch, steps, hidden) -> (batch, heads, steps, head_dim)
            return projection(x).view(per_head).transpose(1, 2)

        queries = apply_rotary_embeddings(split_heads(self.query_projection), rope_cos, rope_sin)
        keys = apply_rotary_embeddings(split_heads(self.key_projection), rope_cos, rope_sin)
        values = split_heads(self.value_projection)

        drop_prob = self.attn_dropout_rate if self.training else 0.0
        context = F.scaled_dot_product_attention(queries, keys, values, dropout_p=drop_prob, is_causal=True)
        # Merge heads back: (batch, heads, steps, head_dim) -> (batch, steps, hidden)
        merged = context.transpose(1, 2).contiguous().view(batch, steps, -1)
        return self.output_projection(merged)


class FeedForwardNetwork(nn.Module):
    """Position-wise MLP: expand, apply squared ReLU, project back, then dropout.

    The inner width is ``config.hidden_dim * config.ffn_expansion``; both
    projections are bias-free.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_dim
        expanded = hidden * config.ffn_expansion
        # This creation order fixes parameter registration and random-init order.
        self.up_projection = nn.Linear(hidden, expanded, bias=False)
        self.down_projection = nn.Linear(expanded, hidden, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Map ``x`` (last dim ``hidden_dim``) to a tensor of the same shape."""
        h = self.up_projection(x)
        h = relu_squared(h)
        h = self.down_projection(h)
        return self.dropout(h)


class TransformerBlock(nn.Module):
    """Pre-norm decoder layer made of an attention and a feed-forward sub-layer.

    Each sub-layer reads an RMS-normalised copy of the residual stream. Its
    output passes through a dedicated dropout before being added back.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_dim
        # This creation order fixes parameter registration and random-init order.
        self.attention_norm = nn.RMSNorm(width, eps=1e-6)
        self.attention = MultiHeadAttention(config)
        self.ffn_norm = nn.RMSNorm(width, eps=1e-6)
        self.feed_forward = FeedForwardNetwork(config)
        self.attention_residual_dropout = nn.Dropout(config.dropout)
        self.ffn_residual_dropout = nn.Dropout(config.dropout)

    def forward(self, x, rope_cos, rope_sin):
        """Apply attention, then feed-forward, each as a residual update of ``x``."""
        attended = self.attention(self.attention_norm(x), rope_cos, rope_sin)
        hidden = x + self.attention_residual_dropout(attended)
        # NOTE(review): feed_forward already ends with its own nn.Dropout(config.dropout),
        # so in training mode this branch is dropped out twice, while the attention
        # branch gets a single output dropout. Confirm this is intended.
        transformed = self.feed_forward(self.ffn_norm(hidden))
        return hidden + self.ffn_residual_dropout(transformed)


class GPT2WorkshopForCausalLM(PreTrainedModel, GenerationMixin):
    """Decoder-only causal language model exposed as a Hugging Face ``PreTrainedModel``.

    The pipeline runs in this order:
      1. token embedding, then dropout;
      2. ``config.num_layers`` ``TransformerBlock``s;
      3. a final RMSNorm;
      4. logits computed against the token-embedding matrix (input and output
         embeddings share one weight), soft-capped with ``config.logit_soft_cap``.
    """

    config_class = GPT2WorkshopConfig
    # There is no separate lm_head: forward() multiplies by token_embedding.weight
    # directly, so no tied-weight entries are declared.
    _tied_weights_keys = {}

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.embedding_dropout = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
        self.final_norm = nn.RMSNorm(config.hidden_dim, eps=1e-6)
        # The float32 RoPE tables cover exactly config.context_length positions, so
        # longer inputs do not fit them. register_buffer defaults to persistent=True,
        # which puts the tables in state_dict; they also follow .to(device/dtype).
        rope_cos, rope_sin = build_rope_cache(config.context_length, config.head_dim, device="cpu", theta=config.rope_theta)
        self.register_buffer("rope_cos", rope_cos)
        self.register_buffer("rope_sin", rope_sin)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Compute logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            input_ids: integer token ids, presumably ``(batch, seq_len)`` with
                ``seq_len <= config.context_length``. TODO confirm.
            attention_mask: accepted for Hugging Face API compatibility but NOT
                used. Every position attends causally to all earlier positions,
                including any padding.
            labels: optional integer targets with one entry per input position;
                entries equal to -100 are ignored by the loss.
            **kwargs: ignored.

        Returns:
            ``CausalLMOutput``. ``logits`` has shape ``(*input_ids.shape, vocab_size)``;
            ``loss`` is ``None`` when ``labels`` is not given.
        """
        x = self.embedding_dropout(self.token_embedding(input_ids))
        for layer in self.layers:
            x = layer(x, self.rope_cos, self.rope_sin)
        x = self.final_norm(x)
        # Output projection reuses the embedding matrix (weight tying by construction).
        logits = soft_cap_logits(F.linear(x, self.token_embedding.weight), self.config.logit_soft_cap)
        loss = None
        if labels is not None:
            # NOTE(review): labels are compared with logits position-by-position; no
            # shift is applied here, so labels[t] must already be the token following
            # input_ids[t]. Hugging Face causal-LM models usually shift internally
            # (callers pass labels=input_ids). Confirm the training code pre-shifts.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=-100)
        return CausalLMOutput(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Build the forward() kwargs for one ``generate`` step.

        Only ``input_ids`` is forwarded; every other generation kwarg (e.g.
        ``attention_mask`` or cache objects) is dropped. There is no KV cache,
        so each step re-runs the whole prefix. A prompt plus generated tokens
        longer than ``config.context_length`` will not fit the RoPE tables.
        """
        return {"input_ids": input_ids}

    @property
    def all_tied_weights_keys(self):
        # Always empty: this model declares no tied weights.
        # NOTE(review): presumably exists because the installed transformers version
        # looks up this attribute; confirm before removing.
        return {}