from dataclasses import dataclass
from typing import Optional, Tuple, List
from collections import namedtuple

import math
import time
import json

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
|
|
try:
    from .configuration_dcpythia import DCPythiaConfig
except ImportError:
    from configuration_dcpythia import DCPythiaConfig
from transformers.modeling_utils import PreTrainedModel
|
|
|
|
| class KVKWCache(nn.Module): |
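    # Decode-time cache: stores per-layer keys/values and, optionally, the per-position
    # dynamic cross-head projection weights ("KW") produced by DynamicWeightProjection.
    # When window_size is set, the buffers act as a ring buffer over the last
    # min(window_size, max_seq_length) positions.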
| def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, window_size=2048, dtype=torch.float16, use_kw_cache=True): |
| super().__init__() |
| self.head_dim = head_dim |
| self.kw_dim = 2 * n_heads |
| self.n_heads = n_heads |
| self.window_size = window_size |
| self.use_kw_cache = use_kw_cache |
| if window_size is None: |
| self.seq_length = max_seq_length |
| else: |
| self.seq_length = min(window_size, max_seq_length) |
| cache_shape = (max_batch_size, n_heads, self.seq_length, head_dim) |
| kw_cache_shape = (max_batch_size, self.seq_length, 2, n_heads, n_heads) |
| self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) |
| self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) |
| if self.use_kw_cache: |
| self.register_buffer('kw_cache', torch.zeros(kw_cache_shape, dtype=dtype)) |
|
|
| def update(self, input_pos, k_val, v_val, kw_val=None): |
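        # Three cases:
        #   * no window: write k/v/kw at the absolute positions in input_pos;
        #   * single-token decode (S == 1): write at input_pos modulo the window (ring buffer);
        #   * multi-token prefill with a window: keep only the last `seq_length` positions.
        # Returns the full cache buffers (kw cache is None when it is disabled).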
| |
| assert input_pos.shape[-1] == k_val.shape[2] |
| B,N,S,D = v_val.shape |
| k_out = self.k_cache |
| v_out = self.v_cache |
| if self.use_kw_cache: |
| kw_out = self.kw_cache |
| else: |
| kw_out = None |
|
|
| if self.window_size is None: |
| k_out[:, :, input_pos] = k_val |
| v_out[:, :, input_pos] = v_val |
| if self.use_kw_cache and kw_val is not None: |
| kw_out[:,input_pos] = kw_val |
| elif S == 1: |
| input_pos = input_pos % self.seq_length |
| v_out[:, :, input_pos] = v_val |
| k_out[:, :, input_pos] = k_val |
| if self.use_kw_cache and kw_val is not None: |
| kw_out[:,input_pos] = kw_val |
| else: |
| start = max(0, input_pos[-1]-self.seq_length+1) |
| input_pos = input_pos[start:] % self.seq_length |
| v_out[:, :, input_pos] = v_val[:,:,start:] |
| k_out[:, :, input_pos] = k_val[:,:,start:] |
| if self.use_kw_cache and kw_val is not None: |
| kw_out[:, input_pos] = kw_val[:,start:] |
| return k_out, v_out, kw_out |
|
|
| class DCPythia(PreTrainedModel): |
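    # Pythia backbone with dynamically composable multi-head attention (DCMHA).
    # Follows the GPT-NeoX layout: LayerNorm, optional parallel residual, and rotary
    # embeddings applied to a fraction (rotary_pct) of each head dimension.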
| config_class=DCPythiaConfig |
|
|
| def __init__(self, config: DCPythiaConfig) -> None: |
| super().__init__(config) |
| self.config = config |
|
|
| self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) |
| self.layers = nn.ModuleList(DCPythiaBlock(config, lidx) for lidx in range(config.n_layer)) |
| self.norm = nn.LayerNorm(config.dim, eps=config.norm_eps) |
| self.output = nn.Linear(config.dim, config.vocab_size, bias=False) |
| self.use_gradient_checkpointing = config.use_gradient_checkpointing |
| self.is_training = config.is_training |
|
|
| self.freqs_cis: Optional[Tensor] = None |
| self.rotary_ndims = int(config.head_dim * config.rotary_pct) |
| self.mask_cache: Optional[Tensor] = None |
| self.window_size = config.window_size |
| self.max_batch_size = -1 |
| self.max_seq_length = -1 |
|
|
| def setup_caches(self, max_batch_size, max_seq_length, set_kv_cache=True): |
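        # Allocate per-layer KV(+KW) caches for inference, merge the dynamic-projection
        # weights for the fast decode path, precompute rotary frequencies, and build the
        # causal (and, if windowed, sliding-window) masks.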
| if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: |
| return |
| head_dim = self.config.dim // self.config.n_head |
| max_seq_length = find_multiple(max_seq_length, 8) |
| self.max_seq_length = max_seq_length |
| self.max_batch_size = max_batch_size |
| if not self.is_training: |
| for b in self.layers: |
| if set_kv_cache: |
                    use_kw_cache = not b.attention.query_wise
| b.attention.kv_cache = KVKWCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, window_size=b.attention.window_size, use_kw_cache=use_kw_cache) |
| b.attention.dyn_w_proj.merge_weights() |
| if not b.attention.use_sw: |
| dtype = b.attention.wo.weight.dtype |
| device = b.attention.wo.weight.device |
| b.attention.dyn_w_proj.sw = b.attention.dyn_w_proj.sw.to(device=device, dtype=dtype) |
| b.attention.dyn_w_proj.pre_proj.w = b.attention.dyn_w_proj.pre_proj.w.to(device=device, dtype=dtype) |
| b.attention.dyn_w_proj.post_proj.w = b.attention.dyn_w_proj.post_proj.w.to(device=device, dtype=dtype) |
| |
| self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.rotary_ndims, self.config.rope_base).to(self.tok_embeddings.weight.device) |
| if self.is_training: |
| self.causal_mask = torch.tril(torch.ones(self.config.block_size, self.config.block_size, dtype=torch.bool, device=self.tok_embeddings.weight.device)) |
| elif self.window_size is None: |
| self.causal_mask = torch.tril(torch.ones(max_seq_length, max_seq_length, dtype=torch.bool, device=self.tok_embeddings.weight.device)) |
| else: |
            self.causal_mask = torch.stack([
                make_window_mask(max_seq_length, self.config.window_size),
                torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)),
            ]).to(self.tok_embeddings.weight.device)
| |
| def generate(self, input_ids, num_tokens_to_generate=10, compiled_decode_one_token=None): |
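        # Greedy decoding: one full forward pass over the prompt, then token-by-token
        # decoding through the KV cache. `compiled_decode_one_token` may be a
        # torch.compile-d wrapper of decode_one_token.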
| batch_size, seq_length = input_ids.shape |
| input_pos = torch.arange(seq_length, device=self.device) |
| generated_ids = torch.zeros(batch_size, seq_length + num_tokens_to_generate, dtype=torch.int, device=self.device) |
| generated_ids[:, :seq_length] = input_ids.to(self.device).to(torch.int) |
| logits = self.forward(input_ids, input_pos=input_pos,return_tensor=True) |
| _next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] |
| next_token = torch.zeros(self.max_batch_size, 1, device=self.device, dtype=torch.int) |
| next_token[:batch_size] = _next_token |
| generated_ids[:, seq_length] = next_token[:batch_size, 0] |
| input_pos = torch.tensor([seq_length], device=self.device) |
| for _ in range(1, num_tokens_to_generate): |
| if compiled_decode_one_token is not None: |
| next_token = compiled_decode_one_token(self, next_token.clone(), input_pos) |
| else: |
| next_token = self.decode_one_token(next_token.clone(), input_pos) |
| generated_ids[:, input_pos+1] = next_token.int()[:batch_size] |
| input_pos += 1 |
| return generated_ids |
| |
| def decode_one_token(self, cur_token, input_pos): |
| logits = self.forward( |
| cur_token, |
| input_pos=input_pos, |
| return_tensor=True, |
| ) |
| new_token = torch.argmax(logits[:, -1], dim=-1)[:,None] |
| return new_token |
|
|
| def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None, return_tensor=False) -> Tensor: |
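        # idx: (batch, seq) token ids; input_pos: absolute positions used to index the
        # rotary table and the mask (defaults to 0..seq-1 for a fresh prompt).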
| assert self.freqs_cis is not None, "Caches must be initialized first" |
| if input_pos is None: |
| input_pos = torch.arange(idx.shape[-1], device=idx.device, dtype=torch.int) |
| if self.window_size is None or self.is_training: |
| mask = self.causal_mask[None, None, input_pos] |
| else: |
| mask = self.causal_mask[None, None,:,input_pos] |
| freqs_cis = self.freqs_cis[input_pos][:idx.shape[-1]] |
| x = self.tok_embeddings(idx) |
| for i, layer in enumerate(self.layers): |
            if self.is_training or self.window_size is None:
                layer_mask = mask
                gen_mask = None
            else:
                layer_mask = mask[:, :, 1] if layer.attention.window_size is None else mask[:, :, 0]
                gen_mask = mask[:, :, 1] if layer.attention.window_size is not None else None
| if self.use_gradient_checkpointing: |
| x = checkpoint(layer, x, input_pos, freqs_cis, layer_mask) |
| else: |
| x = layer(x, input_pos, freqs_cis, layer_mask, gen_mask=gen_mask) |
| x = self.norm(x) |
| logits = self.output(x) |
| if return_tensor: |
| return logits |
| else: |
| CausalLMOutput = namedtuple("CausalLMOutput", ["logits"]) |
| return CausalLMOutput(logits=logits) |
|
|
| class DCPythiaBlock(nn.Module): |
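    # One transformer block: DCMHA attention plus a GELU MLP, with Pythia/GPT-NeoX style
    # parallel residual when config.use_parallel_residual is set.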
| def __init__(self, config: DCPythiaConfig, lidx) -> None: |
| super().__init__() |
| self.lidx = lidx |
| self.attention = DCMHAttention(config, lidx) |
| self.feed_forward = FeedForward(config) |
| self.ffn_norm = nn.LayerNorm(config.dim, eps=config.norm_eps) |
| self.attention_norm = nn.LayerNorm(config.dim, eps=config.norm_eps) |
| self.use_parallel_residual = config.use_parallel_residual |
|
|
| def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor, gen_mask=None) -> Tensor: |
| h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos, fast_infer=True, gen_mask=gen_mask) |
| if self.use_parallel_residual: |
| out = h + self.feed_forward(self.ffn_norm(x)) |
| else: |
| out = h + self.feed_forward(self.ffn_norm(h)) |
| return out |
|
|
| class DynamicWeightProjection(nn.Module): |
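    # Computes the input-dependent cross-head mixing weights used by DCMHA: for every
    # token it produces low-rank factors (w1, w2) and diagonal terms (dd) for the
    # pre-softmax ("pre") and post-softmax ("post") projections, split into query-wise
    # and key-wise parts.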
|
|
| def __init__(self, num_heads=32, num_groups=1, residual=True, query_input_dim=4096, dynamic_squeeze_ratio=16, dynamic_w_hidden_dim=128,dtype=torch.float16,use_sw=False): |
| super().__init__() |
| self.num_heads = num_heads |
| self.num_groups = num_groups |
| self.query_input_dim = query_input_dim |
| self.dynamic_squeeze_ratio = dynamic_squeeze_ratio |
| self.dynamic_w_hidden_dim = dynamic_w_hidden_dim |
| self.dw_hidden_activation = nn.GELU() |
| self.num_heads_per_group = self.num_heads // self.num_groups |
| self.dw_activation = nn.Tanh() |
| self.dw1_norm = RMSnormNoscale(dim=-1) |
| self.use_sw = use_sw |
| self.pre_proj = CrossHeadProjection('pre', num_heads=self.num_heads, use_sw=use_sw) |
| self.post_proj = CrossHeadProjection('post', num_heads=self.num_heads, use_sw=use_sw) |
|
|
| dynamic_hidden_dim = self.num_heads_per_group // self.dynamic_squeeze_ratio |
| self.dynamic_hidden_dim = dynamic_hidden_dim |
| self.dw1 = nn.parameter.Parameter(torch.zeros(self.query_input_dim, self.num_groups, 4, self.dynamic_w_hidden_dim, dtype=dtype)) |
| G, K, M = self.num_groups, self.dynamic_w_hidden_dim, self.num_heads_per_group |
| I = dynamic_hidden_dim * 2 |
| self.qkw = nn.parameter.Parameter(torch.zeros([G, 4, K, I, M], dtype=dtype)) |
| self.dd = nn.parameter.Parameter(torch.zeros(self.query_input_dim, self.num_groups, self.num_heads_per_group * 4, dtype=dtype)) |
|
|
| self.merge_weights() |
|
|
| def merge_weights(self): |
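        # Fuse dw1 and dd into a single matrix (dw_m) and flatten qkw (qkw_m) so the fast
        # decode path can compute all dynamic weights with one matmul; sw holds the static
        # cross-head mixing matrices (identity when use_sw is False).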
| self.dw_m = nn.parameter.Parameter(torch.cat([self.dw1.reshape(self.query_input_dim, -1), self.dd.squeeze(1)], dim=-1)).to(self.dw1.device) |
| self.qkw_m = nn.parameter.Parameter(self.qkw.permute(0,1,2,3,4).reshape(4,self.dynamic_w_hidden_dim,-1)).to(self.dw1.device) |
| if self.use_sw: |
| self.sw = nn.parameter.Parameter(torch.stack([self.pre_proj.w, self.post_proj.w]).squeeze(1) + torch.eye(self.num_heads) ).to(self.dw1.device) |
| else: |
| self.sw = (torch.eye(self.num_heads).expand(2,self.num_heads,self.num_heads)).to(self.dw1.device) |
| |
| def forward(self,query_vec,KW:Optional[torch.Tensor]=None, gen_cache:Optional[bool]=True): |
| dw_hidden = torch.einsum('BTD,DGCK->BTGCK', query_vec, self.dw1) |
| dw_hidden = self.dw_hidden_activation(dw_hidden) |
| w1, w2 = torch.split(torch.einsum('BTGCK,GCKIM->BTGCIM', dw_hidden, self.qkw), self.qkw.shape[-2]//2, dim=-2) |
| w1 = self.dw1_norm(w1) |
| pre_qw1, pre_kw1, post_qw1, post_kw1 = unbind(w1, 4, dim=3) |
| pre_qw2, pre_kw2, post_qw2, post_kw2 = unbind(w2, 4, dim=3) |
| dd = torch.einsum('BTD,DGM->BTGM', query_vec, self.dd) |
| dd = self.dw_activation(dd) |
| pre_qdd, pre_kdd, post_qdd, post_kdd = torch.split(dd, dd.shape[-1] // 4, dim=-1) |
| pre_dw_args = (pre_qw1, pre_qw2, pre_kw1, pre_kw2, pre_qdd, pre_kdd) |
| post_dw_args = (post_qw1, post_qw2, post_kw1, post_kw2, post_qdd, post_kdd) |
| if gen_cache: |
| pre_kw = torch.einsum('BSGIM, BSGIN->BSMN', pre_kw1, pre_kw2) + torch.diag_embed(pre_kdd.squeeze(2)) |
| post_kw = torch.einsum('BSGIM, BSGIN->BSMN', post_kw1, post_kw2) + torch.diag_embed(post_kdd.squeeze(2)) |
| KW = torch.stack((pre_kw, post_kw), dim=-3) |
| return pre_dw_args, post_dw_args, KW |
|
|
|
|
| class RMSnormNoscale(nn.Module): |
| |
| def __init__(self, epsilon=1e-6, dim=-1): |
| super().__init__() |
| self.dim = dim |
| self.epsilon = epsilon |
|
|
| def forward(self, inputs): |
| var = inputs.pow(2).mean(dim=self.dim, keepdim=True) |
| normed_inputs = inputs * torch.rsqrt(var + self.epsilon) |
| return normed_inputs |
|
|
|
|
| class RMSnorm(nn.Module): |
|
|
| def __init__(self, hid_dim=128, epsilon=1e-6, dim=-1): |
| super().__init__() |
| self.dim = dim |
| self.hid_dim = hid_dim |
| self.epsilon = epsilon |
| self.scale = nn.parameter.Parameter(data=torch.ones(self.hid_dim)) |
|
|
| def forward(self, inputs): |
| var = inputs.pow(2).mean(dim=self.dim, keepdim=True) |
| normed_inputs = inputs * torch.rsqrt(var + self.epsilon) |
| normed_inputs = normed_inputs * self.scale |
| return normed_inputs |
|
|
|
|
| class CrossHeadProjection(nn.Module): |
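    # Mixes attention scores across heads: a static per-head matrix (identity unless
    # use_sw) plus low-rank and diagonal dynamic terms supplied per token. Applied once
    # to the logits ("pre") and once to the probabilities ("post").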
|
|
| def __init__(self, mode, num_heads=16, num_groups=1, dtype=torch.float16, use_sw=False): |
| super().__init__() |
| self.mode = mode |
| self.use_sw = use_sw |
| self.num_heads = num_heads |
| self.num_groups = num_groups |
| self.num_heads_per_group = self.num_heads // self.num_groups |
| if self.use_sw: |
| self.w = nn.parameter.Parameter(data=torch.zeros(self.num_groups, self.num_heads_per_group, self.num_heads_per_group, dtype=dtype)) |
| else: |
| self.register_buffer('w', torch.eye(self.num_heads_per_group, dtype=dtype).expand(self.num_groups, self.num_heads_per_group, self.num_heads_per_group)) |
|
|
| def forward(self, inputs, |
| dws:Optional[Tuple[Tensor,Tensor, Tensor,Tensor, Tensor,Tensor]]=None, |
| query_vec=None, key_vec=None, |
| proj_w:Optional[Tensor]=None, |
| fast_infer=True): |
| if proj_w is not None: |
| ret = torch.einsum('BNTS,BSNM->BMTS', inputs, proj_w) |
| else: |
| assert dws is not None |
| qw1, qw2, kw1, kw2, qdd, kdd = dws |
| inputs = inputs.unsqueeze(1) |
| |
| ret = torch.einsum('BGMTS,GMN->BGNTS', inputs, self.w) if self.use_sw else inputs |
| if fast_infer: |
| inputs_label = 'BGMTS' |
| hidden_sym = 'I'; hidden_label = inputs_label.replace('M', 'I') |
| |
| for sym, (w1, w2) in zip(['T', 'S'], [(qw1, qw2), (kw1, kw2)]): |
| dw_label = f'B{sym}G{hidden_sym}M' |
| dynamic_hidden_dim = w1.shape[dw_label.index(hidden_sym)] |
| eqn1 = f'{inputs_label},{dw_label}->{hidden_label}' |
| eqn2 = f'{hidden_label},{dw_label}->{inputs_label}' |
| for i in range(dynamic_hidden_dim): |
| hidden = torch.einsum(eqn1.replace(hidden_sym, ''), inputs, w1[..., i, :]) |
| out = torch.einsum(eqn2.replace(hidden_sym, ''), hidden, w2[..., i, :]) |
| ret = ret + out |
| |
| for sym, dd in zip(['T', 'S'], [qdd, kdd]): |
| dd_label = f'B{sym}GM' |
| dout = torch.einsum(f'{inputs_label},{dd_label}->{inputs_label}', inputs, dd) |
| ret = ret + dout |
| else: |
| |
| x_inter = torch.einsum('BGNTS, BTGIN->BGTSI', inputs, qw1) |
| qw_out = torch.einsum('BGTSI, BTGIN->BGNTS', x_inter, qw2) |
| ret = ret + qw_out |
| x_inter = torch.einsum('BGNTS, BSGIN->BGTSI', inputs, kw1) |
| kw_out = torch.einsum('BGTSI, BSGIN->BGNTS', x_inter, kw2) |
| ret = ret + kw_out |
|
|
| |
| ret = ret + torch.einsum('BGNTS, BTGN->BGNTS', inputs, qdd) |
| ret = ret + torch.einsum('BGNTS, BSGN->BGNTS', inputs, kdd) |
| ret = ret.squeeze(1) |
| return ret |
|
|
|
|
| class DCMHAttention(nn.Module): |
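    # Multi-head attention with dynamic composition: standard QKV attention whose logits
    # and probabilities are additionally projected across heads (CrossHeadProjection)
    # with weights predicted from the input (DynamicWeightProjection). Supports per-layer
    # sliding windows and a merged-weight fast path for single-token decoding.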
| def __init__(self, config: DCPythiaConfig, lidx, use_sw=False): |
| super().__init__() |
| assert config.dim % config.n_head == 0 |
| total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim |
| |
| self.lidx = lidx |
| self.wqkv = nn.Linear(config.dim, total_head_dim, bias=config.use_linear_bias) |
| self.wo = nn.Linear(config.dim, config.dim, bias=config.use_linear_bias) |
| self.kv_cache = None |
|
|
| self.n_head = config.n_head |
| self.head_dim = config.head_dim |
| self.n_local_heads = config.n_local_heads |
| self.is_training = config.is_training |
| self.dim = config.dim |
| self.use_dcmha = config.use_dcmha |
| self.scale_factor = 1 / math.sqrt(self.head_dim) |
| self.q_chunk_size = config.q_chunk_size |
| self.use_sw = use_sw |
| self.dyn_w_proj = DynamicWeightProjection(num_heads=self.n_head, query_input_dim=config.dim, dynamic_squeeze_ratio=self.n_head//2, dynamic_w_hidden_dim=self.n_head*4, use_sw=use_sw) |
| self.use_qk_norm = config.use_qk_norm |
| if self.use_qk_norm: |
| self.q_norm = RMSnorm(hid_dim=self.head_dim) |
| self.k_norm = RMSnorm(hid_dim=self.head_dim) |
|
|
| self.window_types = { |
| "LG":[256, None], |
| "LGLL":[256, None, 256, 256], |
| "LGL6":[256, None, 256, 256, 256, 256, 256, 256], |
| } |
|
|
| self.query_wise = config.query_wise |
| if config.window_type is None: |
| self.window_size = None if self.lidx % 2 == 1 else config.window_size |
| else: |
| window_l = self.window_types[config.window_type] |
| self.window_size = window_l[self.lidx % len(window_l)] |
|
|
| self.rotary_ndims = int(self.head_dim * config.rotary_pct) |
|
|
| if not self.is_training: |
| self._register_load_state_dict_pre_hook(self.load_hook) |
|
|
| def load_hook(self, state_dict, prefix, *args): |
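        # Checkpoint compatibility: fuse separate wq/wk/wv weights (and biases) from
        # checkpoints that store them separately into the packed wqkv projection used here.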
| if prefix + "wq.weight" in state_dict: |
| wq = state_dict.pop(prefix + "wq.weight") |
| wk = state_dict.pop(prefix + "wk.weight") |
| wv = state_dict.pop(prefix + "wv.weight") |
| state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) |
| if prefix + "wq.bias" in state_dict: |
| wq_b = state_dict.pop(prefix + "wq.bias") |
| wk_b = state_dict.pop(prefix + "wk.bias") |
| wv_b = state_dict.pop(prefix + "wv.bias") |
| state_dict[prefix + "wqkv.bias"] = torch.cat([wq_b, wk_b, wv_b]) |
|
|
| def _generate_fast(self, x, input_pos, q, k, v, k_mask): |
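        # Fast single-token decode: compute the dynamic projection weights with the merged
        # matrices (dw_m / qkw_m), update the KV/KW cache, and apply the pre- and
        # post-softmax cross-head mixing directly on the (B, N, 1, S) attention scores.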
| B,T,D = x.shape |
| N,I = self.n_head, self.dyn_w_proj.dynamic_hidden_dim |
| dw_hidden, dd = (x @ self.dyn_w_proj.dw_m).split([2*2*N*(2*I), 2*2*N*1], -1) |
| dw_hidden = dw_hidden.view((B,T,4,-1,1)) |
| dw = (self.dyn_w_proj.dw_hidden_activation(dw_hidden) * self.dyn_w_proj.qkw_m).sum(-2) |
| w1, w2 = dw.view((B,T,2,2,-1,N)).split(I,-2) |
| w1 = self.dyn_w_proj.dw1_norm(w1) |
| qkdd = self.dyn_w_proj.dw_activation(dd.view((B,T,2,2,N))) |
| qkw = torch.einsum('BTKJIN,BTKJIM->BTKJNM', w1, w2) + torch.diag_embed(qkdd) |
| if self.query_wise: |
| qw, _ = qkw.unbind(3) |
| kw_new = None |
| qw = qw + self.dyn_w_proj.sw |
| else: |
| qw, kw_new = qkw.unbind(3) |
| kw_new = kw_new + self.dyn_w_proj.sw |
| if self.kv_cache is not None: |
| k, v, kw_out = self.kv_cache.update(input_pos, k, v, kw_val=kw_new) |
| logits = q @ k.transpose(-2, -1) * self.scale_factor |
| if self.query_wise: |
| w = qw |
| else: |
| w = qw + kw_out |
| wl, w = w.permute(0,2,3,4,1).unbind(1) |
| logits = (logits * wl).sum(1).unsqueeze(2) |
| min_value = torch.finfo(torch.float16).min |
| logits = torch.where(k_mask, logits, min_value) |
| probs = logits.softmax(-1) |
| probs = (probs * w).sum(1).unsqueeze(2) |
| y = probs @ v |
| return y |
|
|
| def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None, fast_infer=True, gen_mask=None) -> Tensor: |
| bsz, seqlen, _ = x.shape |
|
|
| kv_size = self.n_local_heads * self.head_dim |
| q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) |
|
|
| q = q.view(bsz, seqlen, self.n_head, self.head_dim) |
| k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) |
| v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) |
|
|
| if self.use_qk_norm: |
| q, k = self.q_norm(q), self.k_norm(k) |
|
|
| if self.rotary_ndims == self.head_dim: |
| q = apply_rotary_emb(q, freqs_cis) |
| k = apply_rotary_emb(k, freqs_cis) |
| else: |
| q_rot = q[..., : self.rotary_ndims] |
| q_pass = q[..., self.rotary_ndims :] |
| k_rot = k[..., : self.rotary_ndims] |
| k_pass = k[..., self.rotary_ndims :] |
| q_rot = apply_rotary_emb(q_rot, freqs_cis, mode='half') |
| k_rot = apply_rotary_emb(k_rot, freqs_cis, mode='half') |
| q = torch.cat((q_rot, q_pass), dim=-1) |
| k = torch.cat((k_rot, k_pass), dim=-1) |
|
|
        q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v))  # (B, S, N, D) -> (B, N, S, D)
|
|
| if self.is_training: |
            N, D, I = self.n_head, self.head_dim, self.dyn_w_proj.dynamic_hidden_dim
| B,T,E = x.shape |
| if self.use_dcmha: |
| project_logits = True |
| project_probs = True |
| if project_probs: |
| dw_hidden, dd = (x @ self.dyn_w_proj.dw_m).split([2*2*N*(2*I), 2*2*N*1], -1) |
| dw_hidden = self.dyn_w_proj.dw_hidden_activation(dw_hidden) |
| dw_hidden = dw_hidden.view(dw_hidden.shape[:2]+(4,-1)) |
| dw = torch.einsum('B T C K, C K D -> B T C D', dw_hidden, self.dyn_w_proj.qkw_m) |
| shape = (B,T,2*2,-1,N) |
| w1, w2 = dw.view(shape).split(I,-2) |
| w1 = self.dyn_w_proj.dw1_norm(w1) |
| if self.use_sw: |
| pre_sw, post_sw = self.dyn_w_proj.sw.unbind(0) |
| else: |
| pre_sw, post_sw = None, None |
| pre_qw1, pre_kw1, post_qw1, post_kw1 = w1.unbind(2) |
| pre_qw2, pre_kw2, post_qw2, post_kw2 = w2.unbind(2) |
                    qkdd = torch.tanh(dd).squeeze(-1).view(shape[:-2] + (N,))
| pre_qdd, pre_kdd, post_qdd, post_kdd = qkdd.unbind(2) |
|
|
                y = torch.zeros(B, N, T, D, dtype=torch.float16, device=q.device)
                for i in range(T // self.q_chunk_size):
                    start, stop = i * self.q_chunk_size, (i + 1) * self.q_chunk_size
                    # Layers with window_size=None attend globally, so start the KV slice at 0.
                    window = self.window_size if self.window_size is not None else T
                    kv_start = max(0, stop - self.q_chunk_size - window)
| _q = q[:, :, start : stop, :] |
| _k, _v = k[:, :, kv_start : stop, :], v[:, :, kv_start : stop, :] |
| _atten_mask = mask[:, :, start : stop, kv_start : stop] |
| _pre_proj_dw_args = slice_dw(pre_sw, pre_qw1, pre_qw2, pre_kw1, pre_kw2, pre_qdd, pre_kdd, start, stop, kv_start) \ |
| if project_logits else None |
| _post_proj_dw_args = slice_dw(post_sw, post_qw1, post_qw2, post_kw1, post_kw2, post_qdd, post_kdd, start,stop,kv_start) \ |
| if project_probs else None |
| _o = _atten_context(_q, _k, _v, _atten_mask, _pre_proj_dw_args, _post_proj_dw_args) |
| y[:,:,start:stop] = _o |
| else: |
                y = torch.zeros(B, N, T, D, dtype=torch.float16, device=q.device)
                for i in range(T // self.q_chunk_size):
                    start, stop = i * self.q_chunk_size, (i + 1) * self.q_chunk_size
                    # Layers with window_size=None attend globally, so start the KV slice at 0.
                    window = self.window_size if self.window_size is not None else T
                    kv_start = max(0, stop - self.q_chunk_size - window)
| _q = q[:, :, start : stop, :] |
| _k, _v = k[:, :, kv_start : stop, :], v[:, :, kv_start : stop, :] |
| _atten_mask = mask[:, :, start : stop, kv_start : stop] |
| _pre_proj_dw_args, _post_proj_dw_args = None, None |
| _o = _atten_context(_q, _k, _v, _atten_mask, _pre_proj_dw_args, _post_proj_dw_args) |
| y[:,:,start:stop] = _o |
| else: |
| if seqlen == 1: |
| k_mask = mask if self.window_size is None else gen_mask[:, :, :,:self.kv_cache.seq_length] |
| if fast_infer: |
| y = self._generate_fast(x, input_pos, q, k, v, k_mask) |
| else: |
| assert not self.query_wise |
| |
| pre_proj_dw_args, post_proj_dw_args, kw_new = self.dyn_w_proj(x, gen_cache=True) |
| |
| |
| kw_new = kw_new + self.dyn_w_proj.sw |
| if self.kv_cache is not None: |
| k, v, kw_out = self.kv_cache.update(input_pos, k, v, kw_val=kw_new) |
|
|
| logits = q @ k.transpose(-2, -1) * self.scale_factor |
| |
| pre_qw1, pre_qw2, pre_kw1, pre_kw2, pre_qdd, pre_kdd = pre_proj_dw_args |
| pre_qw = torch.einsum('BTGIN, BTGIM->BTNM',pre_qw1, pre_qw2) + torch.diag_embed(pre_qdd.squeeze(2)) |
| pre_w = pre_qw + kw_out[:,:,0] |
| logits = self.dyn_w_proj.pre_proj(logits, proj_w=pre_w.squeeze(1)) |
| |
| logits = torch.where(k_mask, logits, torch.finfo(torch.float16).min) |
| probs = logits.softmax(-1) |
|
|
| |
| post_qw1, post_qw2, post_kw1, post_kw2, post_qdd, post_kdd = post_proj_dw_args |
| post_qw = torch.einsum('BTGIN, BTGIM->BTNM', post_qw1, post_qw2) + torch.diag_embed(post_qdd.squeeze(2)) |
| post_w = post_qw + kw_out[:,:,1] |
| probs = self.dyn_w_proj.post_proj(probs, proj_w=post_w.squeeze(1)) |
|
|
| y = probs @ v |
| else: |
| k_mask = mask[:,:,:,:k.shape[-2]] |
| pre_proj_dw_args, post_proj_dw_args,kw_new = self.dyn_w_proj(x, gen_cache=True) |
| kw_new = kw_new + self.dyn_w_proj.sw |
| if self.kv_cache is not None: |
| self.kv_cache.update(input_pos, k, v, kw_val=kw_new) |
| logits = q @ k.transpose(-2, -1) * self.scale_factor |
| logits = self.dyn_w_proj.pre_proj(logits, dws=pre_proj_dw_args, query_vec=x, key_vec=x, fast_infer=True) |
| logits = torch.where(k_mask, logits, torch.finfo(torch.float16).min) |
| probs = logits.softmax(-1) |
| probs = self.dyn_w_proj.post_proj(probs, dws=post_proj_dw_args, query_vec=x, key_vec=x, fast_infer=True) |
| y = probs @ v |
|
|
| y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) |
| y = self.wo(y) |
| return y |
|
|
|
|
| class FeedForward(nn.Module): |
| def __init__(self, config: DCPythiaConfig) -> None: |
| super().__init__() |
| self.w1 = nn.Linear(config.dim, config.intermediate_size, bias=config.use_linear_bias) |
| self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=config.use_linear_bias) |
|
|
| def forward(self, x: Tensor) -> Tensor: |
| return self.w2(F.gelu(self.w1(x))) |
|
|
| def _atten_context(query, key, value, atten_mask, pre_proj_dw_args, post_proj_dw_args): |
| logits = query @ key.transpose(-2, -1) |
| if pre_proj_dw_args is not None: logits = _cross_head_proj(logits, *pre_proj_dw_args) |
| logits = torch.where(atten_mask, logits, torch.finfo(torch.float16).min) |
| probs = logits.softmax(-1) |
| if post_proj_dw_args is not None: probs = _cross_head_proj(probs, *post_proj_dw_args) |
| o = probs @ value |
| return o |
|
|
| def _cross_head_proj(inputs, sw, qw1, qw2, kw1, kw2, qdd, kdd, loop_over_dynamic_hd=False): |
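    # Training-time cross-head projection on a (B, N, T, S) score tensor: static mixing
    # (sw), two rank-1 query-wise and key-wise updates (the loop assumes a dynamic hidden
    # dim of 2), and diagonal query/key terms.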
    out = (inputs + torch.einsum('BNTS,NM->BMTS', inputs, sw)) if sw is not None else inputs
| for i in range(2): |
| qhidden = (inputs * qw1[..., i, :].transpose(-2, -1).unsqueeze(-1)).sum(1) |
| qout = qhidden.unsqueeze(1) * qw2[..., i, :].transpose(-2, -1).unsqueeze(-1) |
| out = out + qout |
| khidden = (inputs * kw1[..., i, :].transpose(-2, -1).unsqueeze(-2)).sum(1) |
| kout = khidden.unsqueeze(1) * kw2[..., i, :].transpose(-2, -1).unsqueeze(-2) |
| out = out + kout |
| qdout = inputs * qdd.transpose(-2, -1).unsqueeze(-1); out = out + qdout |
| kdout = inputs * kdd.transpose(-2, -1).unsqueeze(-2); out = out + kdout |
| return out |
|
|
| def find_multiple(n: int, k: int) -> int: |
| if n % k == 0: |
| return n |
| return n + k - (n % k) |
|
|
| def make_window_mask(t, window_size): |
| col_idx = torch.tile(torch.arange(t).unsqueeze(0), [t, 1]) |
| row_idx = torch.tile(torch.arange(t).unsqueeze(1), [1, t]) |
| bias_mask = (col_idx + window_size >= row_idx).tril().view(t, t) |
| return bias_mask |
|
|
| def slice_dw(sw, qw1, qw2, kw1, kw2, qdd, kdd, start, stop, kv_start): |
| return (sw, |
| qw1[:, start : stop] if qw1 is not None else None, |
| qw2[:, start : stop] if qw2 is not None else None, |
| kw1[:, kv_start : stop] if kw1 is not None else None, |
| kw2[:, kv_start : stop] if kw2 is not None else None, |
| qdd[:, start : stop] if qdd is not None else None, |
| kdd[:, kv_start : stop] if kdd is not None else None) |
|
|
| def precompute_freqs_cis( |
| seq_len: int, n_elem: int, base: int = 10000 |
| ) -> Tensor: |
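    # Rotary embedding table: returns (seq_len, n_elem // 2, 2) holding cos/sin
    # (real/imag) pairs in float16.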
| freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) |
| t = torch.arange(seq_len, device=freqs.device) |
| freqs = torch.outer(t, freqs) |
| freqs_cis = torch.polar(torch.ones_like(freqs), freqs) |
| cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) |
| return cache.to(dtype=torch.float16) |
|
|
| def unbind(ary, n, dim=0): |
| return [torch.squeeze(a, dim=dim) for a in torch.split(ary, ary.shape[dim] // n, dim=dim)] |
|
|
| def apply_rotary_emb(x: Tensor, freqs_cis: Tensor, mode='half') -> Tensor: |
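    # 'half' rotates pairs formed by splitting the head dim into two halves (GPT-NeoX style);
    # 'alternative' rotates adjacent pairs (GPT-J style interleaving).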
| if mode == 'half': |
| xshaped = x.float().reshape(*x.shape[:-1], 2,-1).transpose(-1,-2) |
| elif mode == 'alternative': |
| xshaped = x.float().reshape(*x.shape[:-1], -1, 2) |
| freqs_cis = freqs_cis.view(-1, xshaped.size(1), 1, xshaped.size(3), 2) |
| x_out2 = torch.stack( |
| [ |
| xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], |
| xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], |
| ], |
| -1, |
| ) |
| x_out2 = x_out2.flatten(3) |
| return x_out2.type_as(x) |
|
|