abhinavv3 committed
Commit f58ea49 · 1 Parent(s): c735a0e

Changed the attention mechanism to GQA in knn_attention and xl_attention
configs/config.json CHANGED
@@ -5,7 +5,7 @@
     "n_layer": 12,
     "n_head": 12,
     "n_embd": 768,
-    "n_kv_heads": 4
+    "n_kv_head": 4
   },
   "training": {
     "max_steps": 19073,
model_core/attention.py CHANGED
@@ -147,12 +147,17 @@ class XLAttention(nn.Module):
         super().__init__()
         assert config.n_embd % config.n_head == 0
         self.n_head = config.n_head
+        self.n_kv_head = getattr(config, 'n_kv_head', config.n_head)
         self.n_embd = config.n_embd
         self.head_dim = config.n_embd // config.n_head
+        self.kv_head_dim = self.head_dim  # KV heads keep the query head width so q @ k^T lines up
+        self.group_size = self.n_head // self.n_kv_head
         self.dropout = nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.0)
         self.scale = self.head_dim ** -0.5
 
-        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+        self.q_proj = nn.Linear(config.n_embd, config.n_embd)
+        self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.kv_head_dim)
+        self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.kv_head_dim)
         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
         self.c_proj.MEMGPT_SCALE_INIT = 1
 
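The split projections are where GQA saves parameters: Q keeps all 12 heads while K and V shrink to 4. A rough sketch of the saving under the config defaults, assuming (as in the fix above) that each KV head keeps the 64-dim query head width:

import torch.nn as nn

n_embd, n_head, n_kv_head = 768, 12, 4
head_dim = n_embd // n_head  # 64

c_attn = nn.Linear(n_embd, 3 * n_embd)            # old fused Q/K/V projection
q_proj = nn.Linear(n_embd, n_embd)                # Q stays full width
k_proj = nn.Linear(n_embd, n_kv_head * head_dim)  # 768 -> 256
v_proj = nn.Linear(n_embd, n_kv_head * head_dim)

old = sum(p.numel() for p in c_attn.parameters())
new = sum(p.numel() for m in (q_proj, k_proj, v_proj) for p in m.parameters())
print(old, new)  # 1771776 vs 984320: roughly a 44% cut in QKV parameters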
@@ -161,8 +166,9 @@ class XLAttention(nn.Module):
     def forward(self, x, xl_memory=None):
         B, T, C = x.size()
 
-        qkv = self.c_attn(x) # (B,T,3C)
-        q, k, v = qkv.split(self.n_embd, dim=2) # (B,T,C)
+        q = self.q_proj(x)  # (B, T, C)
+        k = self.k_proj(x)  # (B, T, n_kv_head * kv_head_dim)
+        v = self.v_proj(x)  # (B, T, n_kv_head * kv_head_dim)
 
         # Handle XL memory
         if xl_memory is not None:
@@ -172,14 +178,17 @@ class XLAttention(nn.Module):
             xl_seq_len = k_xl.shape[1]
 
         # Reshape for multi-head attention
-        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
-        k = k.view(B, -1, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T+xl, hs)
-        v = v.view(B, -1, self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T+xl, hs)
+        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, n_head, T, head_dim)
+        k = k.view(B, -1, self.n_kv_head, self.kv_head_dim).transpose(1, 2)  # (B, n_kv_head, T+xl, kv_head_dim) # GQA change
+        v = v.view(B, -1, self.n_kv_head, self.kv_head_dim).transpose(1, 2)  # (B, n_kv_head, T+xl, kv_head_dim) # GQA change
 
         # Apply rotary positional encoding
         seq_len = k.shape[2]
         q, k = self.rope.apply_rotary_pos_emb(q, k)
 
+        k = k.repeat_interleave(self.group_size, dim=1)  # (B, n_head, T+xl, kv_head_dim)
+        v = v.repeat_interleave(self.group_size, dim=1)  # (B, n_head, T+xl, kv_head_dim)
+
         # Attention computation
         att = (q @ k.transpose(-2, -1)) * self.scale
 
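repeat_interleave is what lets 4 KV heads serve 12 query heads: each KV head is duplicated group_size times along the head dimension, so consecutive query heads share one KV head. A toy shape check:

import torch

B, T, n_head, n_kv_head, head_dim = 2, 5, 12, 4, 64
group_size = n_head // n_kv_head  # 3

q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_kv_head, T, head_dim)

k_exp = k.repeat_interleave(group_size, dim=1)  # (2, 12, 5, 64)
# Heads come out grouped [kv0, kv0, kv0, kv1, kv1, kv1, ...], so query
# heads 0-2 attend against KV head 0, heads 3-5 against KV head 1, etc.
assert torch.equal(k_exp[:, 0], k_exp[:, 2])

att = (q @ k_exp.transpose(-2, -1)) * head_dim ** -0.5
assert att.shape == (B, n_head, T, T)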
@@ -190,34 +199,41 @@ class XLAttention(nn.Module):
         att = F.softmax(att, dim=-1)
         att = self.dropout(att)
 
-        y = att @ v # (B, nh, T, hs)
+        y = att @ v  # (B, n_head, T, kv_head_dim)
         y = y.transpose(1, 2).contiguous().view(B, T, C) # (B, T, C)
 
         y = self.c_proj(y)
 
-        # Prepare new XL memories
-        k = k.transpose(1, 2).contiguous().view(B, -1, C) #(B,T+xl,C)
-        v = v.transpose(1, 2).contiguous().view(B, -1, C) #(B,T+xl,C)
-        kv_memories = torch.stack((k, v), dim=-2)
+        # Prepare new XL memories - store the unexpanded KV heads
+        k_orig = k[:, ::self.group_size]  # undo repeat_interleave: (B, n_kv_head, T+xl, kv_head_dim)
+        v_orig = v[:, ::self.group_size]
+        k_orig = k_orig.transpose(1, 2).contiguous().view(B, -1, self.n_kv_head * self.kv_head_dim)
+        v_orig = v_orig.transpose(1, 2).contiguous().view(B, -1, self.n_kv_head * self.kv_head_dim)
+        kv_memories = torch.stack((k_orig, v_orig), dim=-2)
 
         if xl_memory is not None:
-            current_kv = kv_memories[:, -xl_seq_len:] #(B,T,C)
+            current_kv = kv_memories[:, -xl_seq_len:]  # (B, T, 2, n_kv_head * kv_head_dim)
         else:
-            current_kv = kv_memories
+            current_kv = kv_memories  # (B, T, 2, n_kv_head * kv_head_dim)
 
-        return y, current_kv
+        return y, current_kv  # (B, T, C), (B, T, 2, n_kv_head * kv_head_dim)
 
 class KNNAttention(nn.Module):
     def __init__(self, config, knn, topk_retrieved_memories=3):
         super().__init__()
         assert config.n_embd % config.n_head == 0
         self.n_head = config.n_head
+        self.n_kv_head = getattr(config, 'n_kv_head', config.n_head)
         self.n_embd = config.n_embd
         self.head_dim = config.n_embd // config.n_head
+        self.kv_head_dim = self.head_dim  # KV heads keep the query head width so q @ k^T lines up
+        self.group_size = self.n_head // self.n_kv_head
         self.dropout = nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.0)
         self.scale = self.head_dim ** -0.5
 
-        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+        self.q_proj = nn.Linear(config.n_embd, config.n_embd)
+        self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.kv_head_dim)
+        self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.kv_head_dim)
         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
         self.c_proj.MEMGPT_SCALE_INIT = 1
 
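Because the forward pass overwrote k and v with their expanded copies, the memory-preparation step strides back with [::self.group_size] to recover the unexpanded KV heads before caching them. A quick check that the slice inverts the interleaved repeat:

import torch

n_kv_head, group_size = 4, 3
k = torch.randn(2, n_kv_head, 5, 64)

k_exp = k.repeat_interleave(group_size, dim=1)  # (2, 12, 5, 64)
k_back = k_exp[:, ::group_size]                 # first copy of each group

assert torch.equal(k_back, k)  # slicing undoes the repeat exactly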
@@ -230,8 +246,9 @@ class KNNAttention(nn.Module):
     def forward(self, x, xl_memory=None):
         B, T, C = x.size()
 
-        qkv = self.c_attn(x)
-        q, k, v = qkv.split(self.n_embd, dim=2)
+        q = self.q_proj(x)  # (B, T, C)
+        k = self.k_proj(x)  # (B, T, n_kv_head * kv_head_dim)
+        v = self.v_proj(x)  # (B, T, n_kv_head * kv_head_dim)
 
         q = F.normalize(q, dim=-1)
         k = F.normalize(k, dim=-1)
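Normalizing q and k to unit length turns the later inner-product memory lookup into a cosine-similarity search; note that it normalizes the full projected vector, not each head separately. For example:

import torch
import torch.nn.functional as F

k = torch.randn(2, 5, 256)
k = F.normalize(k, dim=-1)  # unit norm per token
print(k.norm(dim=-1))       # all ones: dot products are now cosine similarities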
@@ -243,40 +260,46 @@ class KNNAttention(nn.Module):
             v = torch.cat((v_xl, v), dim=1)
             xl_seq_len = k_xl.shape[1]
 
-        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
-        k = k.view(B, -1, self.n_head, self.head_dim).transpose(1, 2)
-        v = v.view(B, -1, self.n_head, self.head_dim).transpose(1, 2)
+        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, n_head, T, head_dim)
+        k = k.view(B, -1, self.n_kv_head, self.kv_head_dim).transpose(1, 2)  # (B, n_kv_head, seq_len, kv_head_dim) # GQA change
+        v = v.view(B, -1, self.n_kv_head, self.kv_head_dim).transpose(1, 2)  # (B, n_kv_head, seq_len, kv_head_dim) # GQA change
 
         seq_len = k.shape[2]
         q, k = self.rope.apply_rotary_pos_emb(q, k)
 
+        k_expanded = k.repeat_interleave(self.group_size, dim=1)  # (B, n_head, seq_len, kv_head_dim)
+        v_expanded = v.repeat_interleave(self.group_size, dim=1)  # (B, n_head, seq_len, kv_head_dim)
+
         # LOCAL ATTENTION
-        att = (q @ k.transpose(-2, -1)) * self.scale
+        att = (q @ k_expanded.transpose(-2, -1)) * self.scale
         mask = torch.tril(torch.ones(T, seq_len, device=x.device, dtype=torch.bool))
         att = att.masked_fill(~mask, float('-inf'))
         att = F.softmax(att, dim=-1)
         att = self.dropout(att)
-        local_out = att @ v
+        local_out = att @ v_expanded
 
-        # KNN ATTENTION ###
+        # KNN ATTENTION
         if self.knn.index.ntotal > 0:
-            q_search = q.transpose(1, 2).contiguous().view(B, T, C)
+            # Assumption: average the query heads in each GQA group so the search
+            # width matches the stored key width (n_kv_head * kv_head_dim)
+            q_search = q.reshape(B, self.n_kv_head, self.group_size, T, self.head_dim).mean(dim=2)
+            q_search = q_search.transpose(1, 2).contiguous().view(B, T, self.n_kv_head * self.kv_head_dim)
             mem_kv = self.knn.search(q_search, topk=self.topk_retrieved_memories)
             mem_k, mem_v = mem_kv.unbind(dim=-2)
 
-            mem_k = mem_k.view(B, T, self.topk_retrieved_memories, self.n_head, self.head_dim)
-            mem_k = mem_k.permute(0, 3, 1, 2, 4) # (B, nh, T, k, hs)
-            mem_v = mem_v.view(B, T, self.topk_retrieved_memories, self.n_head, self.head_dim)
-            mem_v = mem_v.permute(0, 3, 1, 2, 4) # (B, nh, T, k, hs)
+            # Reshape memory K,V according to KV head structure
+            mem_k = mem_k.view(B, T, self.topk_retrieved_memories, self.n_kv_head, self.kv_head_dim)
+            mem_k = mem_k.permute(0, 3, 1, 2, 4)  # (B, n_kv_head, T, k, kv_head_dim)
+            mem_v = mem_v.view(B, T, self.topk_retrieved_memories, self.n_kv_head, self.kv_head_dim)
+            mem_v = mem_v.permute(0, 3, 1, 2, 4)  # (B, n_kv_head, T, k, kv_head_dim)
             mem_k = mem_k.to(q.device)
             mem_v = mem_v.to(q.device)
 
+            # Expand memory K,V to match query heads
+            mem_k_expanded = mem_k.repeat_interleave(self.group_size, dim=1)  # (B, n_head, T, k, kv_head_dim)
+            mem_v_expanded = mem_v.repeat_interleave(self.group_size, dim=1)  # (B, n_head, T, k, kv_head_dim)
 
-            mem_att = (q.unsqueeze(-2) @ mem_k.transpose(-2, -1)).squeeze(-2) * self.scale
+            mem_att = (q.unsqueeze(-2) @ mem_k_expanded.transpose(-2, -1)).squeeze(-2) * self.scale
             mem_att = F.softmax(mem_att, dim=-1)
             mem_att = self.dropout(mem_att)
-            mem_out = (mem_att.unsqueeze(-2) @ mem_v).squeeze(-2)
+            mem_out = (mem_att.unsqueeze(-2) @ mem_v_expanded).squeeze(-2)
 
             # Combine local and memory attention
             y = mem_out * self.gate_bias + local_out * (1 - self.gate_bias)
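The retrieved-memory attention scores each query only against its own top-k memories, hence the unsqueeze/squeeze pairs around two batched matmuls, and the result is gated against the local path. A toy walkthrough of the shapes, with gate_bias as a stand-in scalar for the module's learned gate:

import torch
import torch.nn.functional as F

B, n_head, T, topk, hs = 2, 12, 5, 3, 64
q = torch.randn(B, n_head, T, hs)
mem_k = torch.randn(B, n_head, T, topk, hs)
mem_v = torch.randn(B, n_head, T, topk, hs)

# (B, nh, T, 1, hs) @ (B, nh, T, hs, k) -> (B, nh, T, 1, k) -> (B, nh, T, k)
mem_att = (q.unsqueeze(-2) @ mem_k.transpose(-2, -1)).squeeze(-2) * hs ** -0.5
mem_att = F.softmax(mem_att, dim=-1)

# (B, nh, T, 1, k) @ (B, nh, T, k, hs) -> (B, nh, T, hs)
mem_out = (mem_att.unsqueeze(-2) @ mem_v).squeeze(-2)

gate_bias = 0.5  # stand-in for self.gate_bias
local_out = torch.randn_like(mem_out)
y = mem_out * gate_bias + local_out * (1 - gate_bias)
assert y.shape == (B, n_head, T, hs)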
@@ -286,15 +309,15 @@ class KNNAttention(nn.Module):
         y = y.transpose(1, 2).contiguous().view(B, T, C)
         y = self.c_proj(y) #(B,T,C)
 
-        # Prepare new memories
-        k = k.transpose(1, 2).contiguous().view(B, -1, C)
-        v = v.transpose(1, 2).contiguous().view(B, -1, C)
-        kv_memories = torch.stack((k, v), dim=-2)
+        # Prepare new memories - store the unexpanded KV heads
+        k_orig = k.transpose(1, 2).contiguous().view(B, -1, self.n_kv_head * self.kv_head_dim)
+        v_orig = v.transpose(1, 2).contiguous().view(B, -1, self.n_kv_head * self.kv_head_dim)
+        kv_memories = torch.stack((k_orig, v_orig), dim=-2)
 
         if xl_memory is not None:
-            current_kv = kv_memories[:, -xl_seq_len:] #(B,T,2,C)
+            current_kv = kv_memories[:, -xl_seq_len:]  # (B, T, 2, n_kv_head * kv_head_dim) # GQA change
         else:
-            current_kv = kv_memories #(B,T,2,C)
+            current_kv = kv_memories  # (B, T, 2, n_kv_head * kv_head_dim)
 
         self.knn.add(current_kv)
 
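One knock-on of this change: the records handed to self.knn.add are now n_kv_head * kv_head_dim wide rather than n_embd, so whatever vector index backs self.knn must be built with the new dimensionality. A minimal sketch assuming a FAISS inner-product index (the diff only shows that self.knn exposes index.ntotal, add, and search; the actual wrapper may differ):

import faiss
import numpy as np

n_kv_head, kv_head_dim = 4, 64
d = n_kv_head * kv_head_dim  # stored key width is now 256, not 768

index = faiss.IndexFlatIP(d)  # inner product over L2-normalized keys = cosine
keys = np.random.randn(100, d).astype("float32")
faiss.normalize_L2(keys)
index.add(keys)
print(index.ntotal)  # 100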
model_core/model.py CHANGED
@@ -41,7 +41,7 @@ class GPTConfig:
     n_layer: int = 12
     n_head: int = 12
     n_embd: int = 768
-    n_kv_heads: int = 4
+    n_kv_head: int = 4
     dropout: float = 0.0
     max_knn_memories: int = 81920
     topk_retrieved_memories: int = 3
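With the dataclass field renamed to match the JSON key, getattr(config, 'n_kv_head', config.n_head) in both attention classes resolves to 4, while an older config without the field silently falls back to ordinary multi-head attention. A small illustration, with the config trimmed to the fields the attention modules read:

from dataclasses import dataclass

@dataclass
class GPTConfig:
    n_head: int = 12
    n_embd: int = 768
    n_kv_head: int = 4

@dataclass
class LegacyConfig:  # hypothetical pre-GQA config
    n_head: int = 12
    n_embd: int = 768

print(getattr(GPTConfig(), 'n_kv_head', 12))     # 4 -> GQA, 3 queries per KV head
print(getattr(LegacyConfig(), 'n_kv_head', 12))  # 12 -> plain multi-head attention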