Changed the attention mechanism to GQA in knn_attention and xl_attention
Browse files- configs/config.json +1 -1
- model_core/attention.py +59 -36
- model_core/model.py +1 -1
configs/config.json
CHANGED
|
@@ -5,7 +5,7 @@
|
|
| 5 |
"n_layer": 12,
|
| 6 |
"n_head": 12,
|
| 7 |
"n_embd": 768,
|
| 8 |
-
"
|
| 9 |
},
|
| 10 |
"training": {
|
| 11 |
"max_steps": 19073,
|
|
|
|
| 5 |
"n_layer": 12,
|
| 6 |
"n_head": 12,
|
| 7 |
"n_embd": 768,
|
| 8 |
+
"n_kv_head": 4
|
| 9 |
},
|
| 10 |
"training": {
|
| 11 |
"max_steps": 19073,
|
model_core/attention.py
CHANGED
|
@@ -147,12 +147,17 @@ class XLAttention(nn.Module):
|
|
| 147 |
super().__init__()
|
| 148 |
assert config.n_embd % config.n_head == 0
|
| 149 |
self.n_head = config.n_head
|
|
|
|
| 150 |
self.n_embd = config.n_embd
|
| 151 |
self.head_dim = config.n_embd // config.n_head
|
|
|
|
|
|
|
| 152 |
self.dropout = nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.0)
|
| 153 |
self.scale = self.head_dim ** -0.5
|
| 154 |
|
| 155 |
-
self.
|
|
|
|
|
|
|
| 156 |
self.c_proj = nn.Linear(config.n_embd, config.n_embd)
|
| 157 |
self.c_proj.MEMGPT_SCALE_INIT = 1
|
| 158 |
|
|
@@ -161,8 +166,9 @@ class XLAttention(nn.Module):
|
|
| 161 |
def forward(self, x, xl_memory=None):
|
| 162 |
B, T, C = x.size()
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
|
|
|
| 166 |
|
| 167 |
# Handle XL memory
|
| 168 |
if xl_memory is not None:
|
|
@@ -172,14 +178,17 @@ class XLAttention(nn.Module):
|
|
| 172 |
xl_seq_len = k_xl.shape[1]
|
| 173 |
|
| 174 |
# Reshape for multi-head attention
|
| 175 |
-
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B,
|
| 176 |
-
k = k.view(B, -1, self.
|
| 177 |
-
v = v.view(B, -1, self.
|
| 178 |
|
| 179 |
# Apply rotary positional encoding
|
| 180 |
seq_len = k.shape[2]
|
| 181 |
q, k = self.rope.apply_rotary_pos_emb(q, k)
|
| 182 |
|
|
|
|
|
|
|
|
|
|
| 183 |
# Attention computation
|
| 184 |
att = (q @ k.transpose(-2, -1)) * self.scale
|
| 185 |
|
|
@@ -190,34 +199,41 @@ class XLAttention(nn.Module):
|
|
| 190 |
att = F.softmax(att, dim=-1)
|
| 191 |
att = self.dropout(att)
|
| 192 |
|
| 193 |
-
y = att @ v # (B,
|
| 194 |
y = y.transpose(1, 2).contiguous().view(B, T, C) # (B, T, C)
|
| 195 |
|
| 196 |
y = self.c_proj(y)
|
| 197 |
|
| 198 |
-
# Prepare new XL memories
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
|
|
|
|
|
|
| 202 |
|
| 203 |
if xl_memory is not None:
|
| 204 |
-
current_kv = kv_memories[:, -xl_seq_len:] #(B,T,C)
|
| 205 |
else:
|
| 206 |
-
current_kv = kv_memories
|
| 207 |
|
| 208 |
-
return y, current_kv
|
| 209 |
|
| 210 |
class KNNAttention(nn.Module):
|
| 211 |
def __init__(self, config, knn, topk_retrieved_memories=3):
|
| 212 |
super().__init__()
|
| 213 |
assert config.n_embd % config.n_head == 0
|
| 214 |
self.n_head = config.n_head
|
|
|
|
| 215 |
self.n_embd = config.n_embd
|
| 216 |
self.head_dim = config.n_embd // config.n_head
|
|
|
|
|
|
|
| 217 |
self.dropout = nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.0)
|
| 218 |
self.scale = self.head_dim ** -0.5
|
| 219 |
|
| 220 |
-
self.
|
|
|
|
|
|
|
| 221 |
self.c_proj = nn.Linear(config.n_embd, config.n_embd)
|
| 222 |
self.c_proj.MEMGPT_SCALE_INIT = 1
|
| 223 |
|
|
@@ -230,8 +246,9 @@ class KNNAttention(nn.Module):
|
|
| 230 |
def forward(self, x, xl_memory=None):
|
| 231 |
B, T, C = x.size()
|
| 232 |
|
| 233 |
-
|
| 234 |
-
|
|
|
|
| 235 |
|
| 236 |
q = F.normalize(q, dim=-1)
|
| 237 |
k = F.normalize(k, dim=-1)
|
|
@@ -243,40 +260,46 @@ class KNNAttention(nn.Module):
|
|
| 243 |
v = torch.cat((v_xl, v), dim=1)
|
| 244 |
xl_seq_len = k_xl.shape[1]
|
| 245 |
|
| 246 |
-
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
|
| 247 |
-
k = k.view(B, -1, self.
|
| 248 |
-
v = v.view(B, -1, self.
|
| 249 |
|
| 250 |
seq_len = k.shape[2]
|
| 251 |
q, k = self.rope.apply_rotary_pos_emb(q, k)
|
| 252 |
|
|
|
|
|
|
|
|
|
|
| 253 |
# LOCAL ATTENTION
|
| 254 |
-
att = (q @
|
| 255 |
mask = torch.tril(torch.ones(T, seq_len, device=x.device, dtype=torch.bool))
|
| 256 |
att = att.masked_fill(~mask, float('-inf'))
|
| 257 |
att = F.softmax(att, dim=-1)
|
| 258 |
att = self.dropout(att)
|
| 259 |
-
local_out = att @
|
| 260 |
|
| 261 |
-
# KNN ATTENTION
|
| 262 |
if self.knn.index.ntotal > 0:
|
| 263 |
q_search = q.transpose(1, 2).contiguous().view(B, T, C)
|
| 264 |
mem_kv = self.knn.search(q_search, topk=self.topk_retrieved_memories)
|
| 265 |
mem_k, mem_v = mem_kv.unbind(dim=-2)
|
| 266 |
|
| 267 |
-
|
| 268 |
-
mem_k = mem_k.
|
| 269 |
-
|
| 270 |
-
mem_v = mem_v.
|
|
|
|
| 271 |
mem_k = mem_k.to(q.device)
|
| 272 |
mem_v = mem_v.to(q.device)
|
| 273 |
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
-
|
| 276 |
-
mem_att = (q.unsqueeze(-2) @ mem_k.transpose(-2, -1)).squeeze(-2) * self.scale
|
| 277 |
mem_att = F.softmax(mem_att, dim=-1)
|
| 278 |
mem_att = self.dropout(mem_att)
|
| 279 |
-
mem_out = (mem_att.unsqueeze(-2) @
|
| 280 |
|
| 281 |
# Combine local and memory attention
|
| 282 |
y = mem_out * self.gate_bias + local_out * (1 - self.gate_bias)
|
|
@@ -286,15 +309,15 @@ class KNNAttention(nn.Module):
|
|
| 286 |
y = y.transpose(1, 2).contiguous().view(B, T, C)
|
| 287 |
y = self.c_proj(y) #(B,T,C)
|
| 288 |
|
| 289 |
-
# Prepare new memories
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
kv_memories = torch.stack((
|
| 293 |
|
| 294 |
if xl_memory is not None:
|
| 295 |
-
current_kv = kv_memories[:, -xl_seq_len:] #(B,T,2,
|
| 296 |
else:
|
| 297 |
-
current_kv = kv_memories #(B,T,2,C)
|
| 298 |
|
| 299 |
self.knn.add(current_kv)
|
| 300 |
|
|
|
|
| 147 |
super().__init__()
|
| 148 |
assert config.n_embd % config.n_head == 0
|
| 149 |
self.n_head = config.n_head
|
| 150 |
+
self.n_kv_head = getattr(config, 'n_kv_head', config.n_head)
|
| 151 |
self.n_embd = config.n_embd
|
| 152 |
self.head_dim = config.n_embd // config.n_head
|
| 153 |
+
self.kv_head_dim = config.n_embd // self.n_kv_head
|
| 154 |
+
self.group_size = self.n_head // self.n_kv_head
|
| 155 |
self.dropout = nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.0)
|
| 156 |
self.scale = self.head_dim ** -0.5
|
| 157 |
|
| 158 |
+
self.q_proj = nn.Linear(config.n_embd, config.n_embd)
|
| 159 |
+
self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.kv_head_dim)
|
| 160 |
+
self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.kv_head_dim)
|
| 161 |
self.c_proj = nn.Linear(config.n_embd, config.n_embd)
|
| 162 |
self.c_proj.MEMGPT_SCALE_INIT = 1
|
| 163 |
|
|
|
|
| 166 |
def forward(self, x, xl_memory=None):
|
| 167 |
B, T, C = x.size()
|
| 168 |
|
| 169 |
+
q = self.q_proj(x) # (B, T, C)
|
| 170 |
+
k = self.k_proj(x) # (B, T, n_kv_head * kv_head_dim)
|
| 171 |
+
v = self.v_proj(x) # (B, T, n_kv_head * kv_head_dim)
|
| 172 |
|
| 173 |
# Handle XL memory
|
| 174 |
if xl_memory is not None:
|
|
|
|
| 178 |
xl_seq_len = k_xl.shape[1]
|
| 179 |
|
| 180 |
# Reshape for multi-head attention
|
| 181 |
+
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, n_head, T, head_dim)
|
| 182 |
+
k = k.view(B, -1, self.n_kv_head, self.kv_head_dim).transpose(1, 2) # (B, n_kv_head, T+xl, kv_head_dim) # GQAchange
|
| 183 |
+
v = v.view(B, -1, self.n_kv_head, self.kv_head_dim).transpose(1, 2) # (B, n_kv_head, T+xl, kv_head_dim) # GQAchange
|
| 184 |
|
| 185 |
# Apply rotary positional encoding
|
| 186 |
seq_len = k.shape[2]
|
| 187 |
q, k = self.rope.apply_rotary_pos_emb(q, k)
|
| 188 |
|
| 189 |
+
k = k.repeat_interleave(self.group_size, dim=1) # (B, n_head, T+xl, kv_head_dim)
|
| 190 |
+
v = v.repeat_interleave(self.group_size, dim=1) # (B, n_head, T+xl, kv_head_dim)
|
| 191 |
+
|
| 192 |
# Attention computation
|
| 193 |
att = (q @ k.transpose(-2, -1)) * self.scale
|
| 194 |
|
|
|
|
| 199 |
att = F.softmax(att, dim=-1)
|
| 200 |
att = self.dropout(att)
|
| 201 |
|
| 202 |
+
y = att @ v # (B, n_head, T, kv_head_dim)
|
| 203 |
y = y.transpose(1, 2).contiguous().view(B, T, C) # (B, T, C)
|
| 204 |
|
| 205 |
y = self.c_proj(y)
|
| 206 |
|
| 207 |
+
# Prepare new XL memories - store original KV dimensions
|
| 208 |
+
k_orig = k[:, ::self.group_size]
|
| 209 |
+
v_orig = v[:, ::self.group_size]
|
| 210 |
+
k_orig = k_orig.transpose(1, 2).contiguous().view(B, -1, self.n_kv_head * self.kv_head_dim)
|
| 211 |
+
v_orig = v_orig.transpose(1, 2).contiguous().view(B, -1, self.n_kv_head * self.kv_head_dim)
|
| 212 |
+
kv_memories = torch.stack((k_orig, v_orig), dim=-2)
|
| 213 |
|
| 214 |
if xl_memory is not None:
|
| 215 |
+
current_kv = kv_memories[:, -xl_seq_len:] #(B,T,2,C)
|
| 216 |
else:
|
| 217 |
+
current_kv = kv_memories #(B,T,2,C)
|
| 218 |
|
| 219 |
+
return y, current_kv #(B,T,C),(B,T,2,C)
|
| 220 |
|
| 221 |
class KNNAttention(nn.Module):
|
| 222 |
def __init__(self, config, knn, topk_retrieved_memories=3):
|
| 223 |
super().__init__()
|
| 224 |
assert config.n_embd % config.n_head == 0
|
| 225 |
self.n_head = config.n_head
|
| 226 |
+
self.n_kv_head = getattr(config, 'n_kv_head', config.n_head)
|
| 227 |
self.n_embd = config.n_embd
|
| 228 |
self.head_dim = config.n_embd // config.n_head
|
| 229 |
+
self.kv_head_dim = config.n_embd // self.n_kv_head
|
| 230 |
+
self.group_size = self.n_head // self.n_kv_head
|
| 231 |
self.dropout = nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.0)
|
| 232 |
self.scale = self.head_dim ** -0.5
|
| 233 |
|
| 234 |
+
self.q_proj = nn.Linear(config.n_embd, config.n_embd)
|
| 235 |
+
self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.kv_head_dim)
|
| 236 |
+
self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.kv_head_dim)
|
| 237 |
self.c_proj = nn.Linear(config.n_embd, config.n_embd)
|
| 238 |
self.c_proj.MEMGPT_SCALE_INIT = 1
|
| 239 |
|
|
|
|
| 246 |
def forward(self, x, xl_memory=None):
|
| 247 |
B, T, C = x.size()
|
| 248 |
|
| 249 |
+
q = self.q_proj(x) # (B, T, C)
|
| 250 |
+
k = self.k_proj(x) # (B, T, n_kv_head * kv_head_dim)
|
| 251 |
+
v = self.v_proj(x) # (B, T, n_kv_head * kv_head_dim)
|
| 252 |
|
| 253 |
q = F.normalize(q, dim=-1)
|
| 254 |
k = F.normalize(k, dim=-1)
|
|
|
|
| 260 |
v = torch.cat((v_xl, v), dim=1)
|
| 261 |
xl_seq_len = k_xl.shape[1]
|
| 262 |
|
| 263 |
+
q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2) # (B, n_head, T, head_dim)
|
| 264 |
+
k = k.view(B, -1, self.n_kv_head, self.kv_head_dim).transpose(1, 2) # (B, n_kv_head, seq_len, kv_head_dim) # GQAchange
|
| 265 |
+
v = v.view(B, -1, self.n_kv_head, self.kv_head_dim).transpose(1, 2) # (B, n_kv_head, seq_len, kv_head_dim) # GQAchange
|
| 266 |
|
| 267 |
seq_len = k.shape[2]
|
| 268 |
q, k = self.rope.apply_rotary_pos_emb(q, k)
|
| 269 |
|
| 270 |
+
k_expanded = k.repeat_interleave(self.group_size, dim=1) # (B, n_head, seq_len, kv_head_dim)
|
| 271 |
+
v_expanded = v.repeat_interleave(self.group_size, dim=1) # (B, n_head, seq_len, kv_head_dim)
|
| 272 |
+
|
| 273 |
# LOCAL ATTENTION
|
| 274 |
+
att = (q @ k_expanded.transpose(-2, -1)) * self.scale
|
| 275 |
mask = torch.tril(torch.ones(T, seq_len, device=x.device, dtype=torch.bool))
|
| 276 |
att = att.masked_fill(~mask, float('-inf'))
|
| 277 |
att = F.softmax(att, dim=-1)
|
| 278 |
att = self.dropout(att)
|
| 279 |
+
local_out = att @ v_expanded
|
| 280 |
|
| 281 |
+
# KNN ATTENTION
|
| 282 |
if self.knn.index.ntotal > 0:
|
| 283 |
q_search = q.transpose(1, 2).contiguous().view(B, T, C)
|
| 284 |
mem_kv = self.knn.search(q_search, topk=self.topk_retrieved_memories)
|
| 285 |
mem_k, mem_v = mem_kv.unbind(dim=-2)
|
| 286 |
|
| 287 |
+
# Reshape memory K,V according to KV head structure
|
| 288 |
+
mem_k = mem_k.view(B, T, self.topk_retrieved_memories, self.n_kv_head, self.kv_head_dim)
|
| 289 |
+
mem_k = mem_k.permute(0, 3, 1, 2, 4) # (B, n_kv_head, T, k, kv_head_dim)
|
| 290 |
+
mem_v = mem_v.view(B, T, self.topk_retrieved_memories, self.n_kv_head, self.kv_head_dim)
|
| 291 |
+
mem_v = mem_v.permute(0, 3, 1, 2, 4) # (B, n_kv_head, T, k, kv_head_dim)
|
| 292 |
mem_k = mem_k.to(q.device)
|
| 293 |
mem_v = mem_v.to(q.device)
|
| 294 |
|
| 295 |
+
# Expand memory K,V to match query heads
|
| 296 |
+
mem_k_expanded = mem_k.repeat_interleave(self.group_size, dim=1) # (B, n_head, T, k, kv_head_dim)
|
| 297 |
+
mem_v_expanded = mem_v.repeat_interleave(self.group_size, dim=1) # (B, n_head, T, k, kv_head_dim)
|
| 298 |
|
| 299 |
+
mem_att = (q.unsqueeze(-2) @ mem_k_expanded.transpose(-2, -1)).squeeze(-2) * self.scale
|
|
|
|
| 300 |
mem_att = F.softmax(mem_att, dim=-1)
|
| 301 |
mem_att = self.dropout(mem_att)
|
| 302 |
+
mem_out = (mem_att.unsqueeze(-2) @ mem_v_expanded).squeeze(-2)
|
| 303 |
|
| 304 |
# Combine local and memory attention
|
| 305 |
y = mem_out * self.gate_bias + local_out * (1 - self.gate_bias)
|
|
|
|
| 309 |
y = y.transpose(1, 2).contiguous().view(B, T, C)
|
| 310 |
y = self.c_proj(y) #(B,T,C)
|
| 311 |
|
| 312 |
+
# Prepare new memories - store original KV dimensions
|
| 313 |
+
k_orig = k.transpose(1, 2).contiguous().view(B, -1, self.n_kv_head * self.kv_head_dim)
|
| 314 |
+
v_orig = v.transpose(1, 2).contiguous().view(B, -1, self.n_kv_head * self.kv_head_dim)
|
| 315 |
+
kv_memories = torch.stack((k_orig, v_orig), dim=-2)
|
| 316 |
|
| 317 |
if xl_memory is not None:
|
| 318 |
+
current_kv = kv_memories[:, -xl_seq_len:] #(B,T,2,n_kv_head * kv_head_dim) # GQAchange
|
| 319 |
else:
|
| 320 |
+
current_kv = kv_memories #(B,T,2,C)
|
| 321 |
|
| 322 |
self.knn.add(current_kv)
|
| 323 |
|
model_core/model.py
CHANGED
|
@@ -41,7 +41,7 @@ class GPTConfig:
|
|
| 41 |
n_layer: int = 12
|
| 42 |
n_head: int = 12
|
| 43 |
n_embd: int = 768
|
| 44 |
-
|
| 45 |
dropout: float = 0.0
|
| 46 |
max_knn_memories: int = 81920
|
| 47 |
topk_retrieved_memories: int = 3
|
|
|
|
| 41 |
n_layer: int = 12
|
| 42 |
n_head: int = 12
|
| 43 |
n_embd: int = 768
|
| 44 |
+
n_kv_head: int = 4
|
| 45 |
dropout: float = 0.0
|
| 46 |
max_knn_memories: int = 81920
|
| 47 |
topk_retrieved_memories: int = 3
|