klemenk committed on
Commit 7eef20c · verified · 1 Parent(s): f23dd4b

Update modeling_gpt.py

Files changed (1)
  1. modeling_gpt.py +126 -8
modeling_gpt.py CHANGED
@@ -5,7 +5,105 @@ from torch.nn import functional as F
 from transformers import PreTrainedModel
 from .configuration_gpt import GPTConfig
 
-# Include your Block and RMSNorm implementations here:
+
+################################
+### Layers ###
+################################
+
+class Rotary(torch.nn.Module):
+
+    def __init__(self, dim, base=10000):
+        super().__init__()
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
+        self.seq_len_cached = None
+        self.cos_cached = None
+        self.sin_cached = None
+
+    def forward(self, x):
+        seq_len = x.shape[1]
+        if seq_len != self.seq_len_cached:
+            self.seq_len_cached = seq_len
+            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
+            freqs = torch.outer(t, self.inv_freq).to(x.device)
+            self.cos_cached = freqs.cos()
+            self.sin_cached = freqs.sin()
+        return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
+
+def apply_rotary_emb(x, cos, sin):
+    assert x.ndim == 4 # multihead attention
+    d = x.shape[3]//2
+    x1 = x[..., :d]
+    x2 = x[..., d:]
+    y1 = x1 * cos + x2 * sin
+    y2 = x1 * (-sin) + x2 * cos
+    return torch.cat([y1, y2], 3)
+
+def rmsnorm(x0, eps=1e-6):
+    x = x0.float()
+    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+    return x.type_as(x0)
+
+
+class RMSNorm(nn.Module):
+    """ Root Mean Square Normalization """
+    def __init__(self, dim: int, weight: bool = False, bias: bool = False, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+
+        if weight:
+            self.weight = nn.Parameter(torch.ones(dim))
+        else:
+            self.register_parameter("weight", None)
+
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(dim))
+        else:
+            self.register_parameter("bias", None)
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float()).type_as(x)
+        if self.weight is not None:
+            output = output * self.weight
+        if self.bias is not None:
+            output = output + self.bias
+        return output
+
+
+class CausalSelfAttention(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.n_head = config.n_head
+        self.n_embd = config.n_embd
+        self.head_dim = self.n_embd // self.n_head
+        assert self.n_embd % self.n_head == 0
+        # key, query, value projections for all heads, but in a batch
+        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
+        # output projection
+        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
+        self.rotary = Rotary(self.head_dim)
+
+    def forward(self, x):
+        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        qkv = self.c_attn(x)
+        q, k, v = qkv.split(self.n_embd, dim=2)
+        k = k.view(B, T, self.n_head, self.head_dim)
+        q = q.view(B, T, self.n_head, self.head_dim)
+        v = v.view(B, T, self.n_head, self.head_dim)
+        cos, sin = self.rotary(q)
+        q = apply_rotary_emb(q, cos, sin)
+        k = apply_rotary_emb(k, cos, sin)
+        y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
+        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+        # output projection
+        y = self.c_proj(y)
+        return y
+
 class RMSNorm(nn.Module):
     def __init__(self, dim, eps=1e-5):
         super().__init__()
@@ -17,18 +115,38 @@ class RMSNorm(nn.Module):
         return self.weight * x / (norm + self.eps)
 
 class Block(nn.Module):
+
     def __init__(self, config):
         super().__init__()
-        # Define the structure (attention, mlp, etc.) exactly as your existing implementation.
-        pass # Replace with your existing implementation.
+        self.attn = CausalSelfAttention(config)
+        self.mlp = MLP(config)
+        self.attn_scale = (1 / (2 * config.n_layer)**0.5)
 
     def forward(self, x):
-        # Implement exactly as your existing Block forward pass.
-        pass # Replace with your existing implementation.
+        x = x + self.attn_scale * self.attn(rmsnorm(x))
+        x = x + self.mlp(rmsnorm(x))
+        return x
+
+class MLP(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+        self.gelu = nn.GELU()
+        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
+        self.dropout = nn.Dropout(config.dropout)
+
+    def forward(self, x):
+        x = self.c_fc(x)
+        x = self.gelu(x)
+        x = self.c_proj(x)
+        x = self.dropout(x)
+        return x
 
-def rmsnorm(x, eps=1e-5):
-    norm = torch.norm(x, dim=-1, keepdim=True)
-    return x / (norm + eps)
+
+################################
+### Model ###
+################################
 
 class GPT(PreTrainedModel):
     config_class = GPTConfig
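
Not part of the commit: a quick sanity-check sketch for the rotary-embedding pieces added above (`Rotary` and `apply_rotary_emb`). The flat `modeling_gpt` import path is an assumption for illustration; in the repo the file normally loads as package code because of the relative `configuration_gpt` import.

```python
# Illustration only, not from the commit. Assumes Rotary and apply_rotary_emb
# from the diff above are in scope (hypothetical flat import path).
import torch

from modeling_gpt import Rotary, apply_rotary_emb

B, T, n_head, head_dim = 2, 16, 4, 32
q = torch.randn(B, T, n_head, head_dim)

rotary = Rotary(head_dim)
cos, sin = rotary(q)                   # each (1, T, 1, head_dim // 2), cached per seq_len
q_rot = apply_rotary_emb(q, cos, sin)

assert q_rot.shape == q.shape
# Each (x1, x2) pair is rotated by an angle, so per-position norms are unchanged.
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
```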
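
Likewise, a rough usage sketch for the new `Block` / `MLP` path: pre-norm residual updates with the attention branch scaled by `1 / sqrt(2 * n_layer)`. The `SimpleNamespace` config is a hypothetical stand-in for `GPTConfig` that only supplies the fields these layers read (`n_embd`, `n_head`, `n_layer`, `bias`, `dropout`).

```python
# Illustration only, not from the commit. The config object is a stand-in for
# GPTConfig with just the attributes Block, CausalSelfAttention and MLP use.
from types import SimpleNamespace

import torch

from modeling_gpt import Block  # hypothetical flat import path

config = SimpleNamespace(n_embd=128, n_head=4, n_layer=12, bias=False, dropout=0.0)

block = Block(config)
x = torch.randn(2, 16, config.n_embd)   # (batch, seq_len, n_embd)
y = block(x)                            # x + attn_scale * attn(rmsnorm(x)), then x + mlp(rmsnorm(x))
assert y.shape == x.shape
```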