| # """ | |
| # AuriStream sequence model definition. | |
| # """ | |
| # import math | |
| # import inspect | |
| # import random | |
| # import torch | |
| # import torch.nn as nn | |
| # from torch.nn import functional as F | |
| # import numpy as np | |
| # from huggingface_hub import PyTorchModelHubMixin | |
| # from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput | |
| # from transformers import PreTrainedModel | |
| # from .configuration_auristream import AuriStreamConfig | |
| # class AuriStream(PreTrainedModel): | |
| # config_class = AuriStreamConfig | |
| # def __init__(self, config): | |
| # super().__init__(config) | |
| # self.config = config | |
| # # if use_rope is in the config and false, initialize a wpe layer in transformer | |
| # if hasattr(config, 'use_rope') and not config.use_rope: | |
| # self.transformer = nn.ModuleDict(dict( | |
| # wte = nn.Embedding(config.vocab_size, config.n_embd), | |
| # wpe = nn.Embedding(config.seq_len, config.n_embd), | |
| # drop = nn.Dropout(config.dropout), | |
| # h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), | |
| # ln_f = RMSNorm(config.n_embd, bias=config.bias), | |
| # )) | |
| # else: | |
| # self.transformer = nn.ModuleDict(dict( | |
| # wte = nn.Embedding(config.vocab_size, config.n_embd), | |
| # drop = nn.Dropout(config.dropout), | |
| # h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), | |
| # ln_f = RMSNorm(config.n_embd, bias=config.bias), | |
| # )) | |
| # # check if n_pred_steps is defined in the config, this is the number of linear heads for prediction | |
| # if hasattr(config, 'n_pred_steps'): | |
| # self.future_heads = nn.ModuleList([nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in range(config.n_pred_steps - 1)]) | |
| # else: | |
| # self.future_heads = None | |
| # self.coch_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) | |
| # # init all weights | |
| # self.apply(self._init_weights) | |
| # # apply special scaled init to the residual projections, per GPT-2 paper | |
| # for pn, p in self.named_parameters(): | |
| # if pn.endswith('c_proj.weight'): | |
| # torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) | |
| # def get_num_params(self, non_embedding=True): | |
| # """ | |
| # Return the number of parameters in the model. | |
| # For non-embedding count (default), the position embeddings get subtracted. | |
| # The token embeddings would too, except due to the parameter sharing these | |
| # params are actually used as weights in the final layer, so we include them. | |
| # """ | |
| # n_params = sum(p.numel() for p in self.parameters()) | |
| # return n_params | |
| # def _init_weights(self, module): | |
| # if isinstance(module, nn.Linear): | |
| # torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) | |
| # if module.bias is not None: | |
| # torch.nn.init.zeros_(module.bias) | |
| # elif isinstance(module, nn.Embedding): | |
| # torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) | |
| # def forward(self, seq, tgt=None, output_hidden_states=False, return_dict=False, up_until_layer=None): | |
| # """ | |
| # Input: coch: torch.Tensor of shape (b, t) | |
| # tgt_coch: torch.Tensor of shape (b, t) or None | |
| # """ | |
| # # forward the GPT model itself | |
| # tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd) | |
| # # if wpe exists in self.transformer apply leanred positional embedding | |
| # if hasattr(self.transformer, 'wpe'): | |
| # pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device) | |
| # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) | |
| # x = self.transformer.drop(tok_emb + pos_emb) | |
| # else: | |
| # x = self.transformer.drop(tok_emb) | |
| # all_hidden_states = [] | |
| # for block_idx, block in enumerate(self.transformer.h): | |
| # # Forward the block | |
| # all_hidden_states.append(x) | |
| # if up_until_layer is not None and block_idx == up_until_layer: | |
| # break | |
| # x = block(x) | |
| # # append the last hidden state if we did not exit early | |
| # if up_until_layer is None or block_idx == len(self.transformer.h) - 1: | |
| # all_hidden_states.append(x) | |
| # if output_hidden_states: | |
| # model_output = BaseModelOutput( | |
| # last_hidden_state=x, | |
| # hidden_states=all_hidden_states, | |
| # ) | |
| # return model_output | |
| # x = self.transformer.ln_f(x) | |
| # logits = self.coch_head(x) | |
| # if tgt is not None: | |
| # loss = F.cross_entropy( | |
| # logits.reshape(-1, self.config.vocab_size), tgt.reshape(-1), | |
| # ) | |
| # # If we have more than one future head, compute the loss for each head | |
| # if self.future_heads is not None: | |
| # for i, head in enumerate(self.future_heads): | |
| # future_logits = head(x[:, :-(i+1)]) | |
| # loss += F.cross_entropy( | |
| # future_logits.reshape(-1, self.config.vocab_size), tgt[:, (i+1):].reshape(-1), | |
| # ) | |
| # # divide loss by number of future heads | |
| # loss = loss / (len(self.future_heads) + 1) | |
| # if return_dict: | |
| # model_output = CausalLMOutput( | |
| # loss=loss, | |
| # logits=logits, | |
| # ) | |
| # return model_output | |
| # return logits, loss | |
| # return logits, None | |
| # def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9, | |
| # top_k: int = 500, top_p: float = 0.5) -> torch.LongTensor: | |
| # """ | |
| # Samples an integer from the distribution of logits | |
| # Parameters: | |
| # logits (torch.FloatTensor): The logits of the distribution | |
| # temp (float): The temperature of the sampling, if 0.0, then argmax is used | |
| # top_k (int): The number of top k tokens to consider during sampling | |
| # top_p (float): The cumulative probability threshold for nucleus (top-p) sampling | |
| # Returns: | |
| # torch.LongTensor: The sampled integer | |
| # """ | |
| # # If temperature is 0.0, use argmax | |
| # if temperature == 0.0: | |
| # return torch.argmax(logits, dim=-1) | |
| # # Apply temperature | |
| # logits = logits / temperature | |
| # # Apply top-k filtering if specified | |
| # if top_k is not None: | |
| # v, _ = torch.topk(logits, min(top_k, logits.size(-1))) | |
| # logits[logits < v[..., [-1]]] = -float('Inf') | |
| # # Apply top-p (nucleus) filtering if specified | |
| # if top_p is not None: | |
| # # Sort the logits in descending order | |
| # sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) | |
| # # Compute the sorted softmax probabilities | |
| # sorted_probs = F.softmax(sorted_logits, dim=-1) | |
| # # Compute the cumulative probabilities | |
| # cumulative_probs = torch.cumsum(sorted_probs, dim=-1) | |
| # # Create a mask for tokens to remove | |
| # sorted_indices_to_remove = cumulative_probs > top_p | |
| # # Shift the mask right to keep at least one token | |
| # sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() | |
| # sorted_indices_to_remove[..., 0] = 0 | |
| # # Scatter the mask back to the original indices | |
| # indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove) | |
| # logits[indices_to_remove] = -float('Inf') | |
| # # Compute softmax probabilities | |
| # probs = F.softmax(logits, dim=-1) | |
| # # Flatten probabilities to (batch_size * sequence_length, vocab_size) | |
| # flat_probs = probs.view(-1, probs.size(-1)) | |
| # # Sample from the distribution | |
| # sampled = torch.multinomial(flat_probs, num_samples=1) | |
| # # Reshape to original shape except for the last dimension | |
| # sampled = sampled.view(*logits.shape[:-1]) | |
| # return sampled | |
| # @torch.no_grad() | |
| # def generate(self, seq: torch.Tensor, n_tokens: int = 1, temp=1.0, | |
| # top_k=500, top_p=0.5, seed=None): | |
| # """ | |
| # Parameters: | |
| # seq: torch.Tensor of shape (b, t, n_freq_bins) | |
| # Input cochleagram to use for generation | |
| # n_tokens: int | |
| # Number of time bins to predict | |
| # temp: float | |
| # Temperature for sampling logits | |
| # seed: int | |
| # Random seed for sampling | |
| # Returns: | |
| # pred_coch: torch.Tensor of shape (b, t, n_freq_bins) | |
| # The predicted cochleagram | |
| # all_logits: (optional if return_logits is True) torch.Tensor of shape (b, n_tokens, n_freq_bins) | |
| # The logits for each time step | |
| # all_embs: (optional if return_embs is not None) list of torch.Tensor | |
| # The embeddings for each transformer block | |
| # """ | |
| # # Set seed if provided | |
| # if seed is not None: | |
| # random.seed(seed) | |
| # np.random.seed(seed) | |
| # torch.manual_seed(seed) | |
| # # make a list of logits to return | |
| # all_logits = [] | |
| # device = seq.device | |
| # # grab shape of the cochleagram | |
| # b, t = seq.size() | |
| # # TODO: double check this works then delete the block bellow: | |
| # # pass the given input through the model to get the predictions and cache | |
| # # the k and v values for each transformer block in the process | |
| # # pos = torch.arange(0, t, dtype=torch.long, device=device) | |
| # # tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd) | |
| # # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) | |
| # # x = self.transformer.drop(tok_emb + pos_emb) | |
| # #### Embed conditioning sequence into KV cache | |
| # tok_emb = self.transformer.wte(seq) # token embeddings of shape (b, t, n_embd) | |
| # # if wpe exists in self.transformer apply leanred positional embedding | |
| # if hasattr(self.transformer, 'wpe'): | |
| # pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device) | |
| # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) | |
| # x = self.transformer.drop(tok_emb + pos_emb) | |
| # else: | |
| # x = self.transformer.drop(tok_emb) | |
| # # Initialize list to store k and v for each transformer block | |
| # k_list = [] | |
| # v_list = [] | |
| # for block_idx, block in enumerate(self.transformer.h): | |
| # # Pass through the transformer block, and store k and v | |
| # x, k, v = block(x, pos=pos, return_kv=True) | |
| # k_list.append(k) | |
| # v_list.append(v) | |
| # # k_cache and v_cache have shape (n_layer, b, n_head, t, n_embd//n_head) | |
| # k_cache = torch.stack(k_list, dim=0) | |
| # v_cache = torch.stack(v_list, dim=0) | |
| # # Pass through the final layer norm | |
| # x = self.transformer.ln_f(x) | |
| # # First prediction of the model is the decoding of the last time bin | |
| # logits = self.coch_head(x[:, [-1]]) | |
| # predictions = [self.sample_logits(logits, temperature=temp)] | |
| # all_logits.append(logits) | |
| # ### Predict future tokens | |
| # # Now we pass the last time bin through the model to predict the next time bin | |
| # # we subtract 1 from max_new_tokens because we already predicted the first time bin | |
| # # using the last embedding of the input | |
| # for i in range(n_tokens-1): | |
| # # TODO: double check this works then delete the block bellow: | |
| # # # Get the emb and pos embedding of just the last token | |
| # # pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device) # shape (t) | |
| # # tok_emb = self.transformer.wte(predictions[-1]) # token embeddings of shape (b, t, n_embd) | |
| # # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) | |
| # # x = self.transformer.drop(tok_emb + pos_emb) | |
| # # Get the emb and pos embedding of just the last token | |
| # tok_emb = self.transformer.wte(predictions[-1]) # token embeddings of shape (b, t, n_embd) | |
| # # if wpe exists in self.transformer apply leanred positional embedding | |
| # if hasattr(self.transformer, 'wpe'): | |
| # pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device) # shape (t) | |
| # pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) | |
| # x = self.transformer.drop(tok_emb + pos_emb) | |
| # else: | |
| # x = self.transformer.drop(tok_emb) | |
| # # Pass through transformer block | |
| # k_list = [] | |
| # v_list = [] | |
| # for block_idx, block in enumerate(self.transformer.h): | |
| # x, k, v = block(x, pos=pos, k_cache=k_cache[block_idx], v_cache=v_cache[block_idx]) | |
| # k_list.append(k) | |
| # v_list.append(v) | |
| # x = self.transformer.ln_f(x) | |
| # # create the cache with the new embeddings | |
| # k_cache = torch.stack(k_list, dim=0) | |
| # v_cache = torch.stack(v_list, dim=0) | |
| # # predict next time bin | |
| # logits = self.coch_head(x) | |
| # predictions.append(self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p)) | |
| # print(f"logits {logits.argmax()}") | |
| # lk | |
| # all_logits.append(logits) | |
| # pred_coch = torch.cat(predictions, dim=1) | |
| # all_logits = torch.cat(all_logits, dim=1) | |
| # return pred_coch, all_logits | |
| # def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): | |
| # # start with all of the candidate parameters | |
| # param_dict = {pn: p for pn, p in self.named_parameters()} | |
| # # filter out those that do not require grad | |
| # param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} | |
| # # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no. | |
| # # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. | |
| # decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] | |
| # nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] | |
| # optim_groups = [ | |
| # {'params': decay_params, 'weight_decay': weight_decay}, | |
| # {'params': nodecay_params, 'weight_decay': 0.0} | |
| # ] | |
| # num_decay_params = sum(p.numel() for p in decay_params) | |
| # num_nodecay_params = sum(p.numel() for p in nodecay_params) | |
| # print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") | |
| # print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") | |
| # # Create AdamW optimizer and use the fused version if it is available | |
| # fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters | |
| # use_fused = fused_available and device_type == 'cuda' | |
| # extra_args = dict(fused=True) if use_fused else dict() | |
| # optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) | |
| # print(f"using fused AdamW: {use_fused}") | |
| # return optimizer | |
| # def estimate_mfu(self, fwdbwd_per_iter, T, dt, gpu_type='A40'): | |
| # """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """ | |
| # # first estimate the number of flops we do per iteration. | |
| # # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 | |
| # N = self.unsharded_param_count | |
| # cfg = self.config | |
| # L, H, Q = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head | |
| # # L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size | |
| # flops_per_token = 6*N + 12*L*H*Q*T | |
| # flops_per_fwdbwd = flops_per_token * T | |
| # flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter | |
| # # express our flops throughput as ratio of A100 bfloat16 peak flops | |
| # flops_achieved = flops_per_iter * (1.0/dt) # per second | |
| # # grab promised flops based on GPU type | |
| # if gpu_type == 'A40': | |
| # flops_promised = 149.7e12 # A40 GPU bfloat16 peak flops is 149.7 TFLOPS | |
| # elif gpu_type == 'A100': | |
| # flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS | |
| # elif gpu_type == 'H100': | |
| # flops_promised = 756e12 # H100 GPU bfloat16 peak flops is 756 TFLOPS | |
| # elif gpu_type == 'TPUv4': | |
| # flops_promised = 275e12 | |
| # elif gpu_type == 'TPUv5e': | |
| # flops_promised = 197e12 | |
| # mfu = flops_achieved / flops_promised | |
| # return mfu | |
| # ######################################################### | |
| # ##### Layer Definitions ##### | |
| # ######################################################### | |
| # class Block(nn.Module): | |
| # def __init__(self, config): | |
| # super().__init__() | |
| # self.attn = CausalSelfAttention(config) | |
| # self.mlp = MLP(config) | |
| # self.attn_scale = 1.0 # (1 / (2 * config.n_layer)**0.5) | |
| # self.norm1 = RMSNorm(config.n_embd, bias=config.bias) | |
| # self.norm2 = RMSNorm(config.n_embd, bias=config.bias) | |
| # def forward(self, x, pos=None, return_kv=False, k_cache=None, v_cache=None): | |
| # # If we are given a key and value cache, we will use the pre-computed values to minimize | |
| # # the computation cost | |
| # if k_cache is not None and v_cache is not None: | |
| # # Pass the key and value cache to the attention layer, obtain new key and value caches | |
| # x_attn, k, v = self.attn.kv_cache_forward(self.norm1(x), pos=pos, k_cache=k_cache, v_cache=v_cache) | |
| # x = x + x_attn | |
| # x = x + self.mlp(self.norm2(x)) | |
| # return x, k, v | |
| # # We might want to encode the caches of a whole block of keys and values at once using the | |
| # # fast flash attention impelmentation while still returning the key and value caches | |
| # elif return_kv: | |
| # # Pass the key and value cache to the attention layer, obtain new key and value caches | |
| # x_attn, k, v = self.attn(self.norm1(x), return_kv=True) | |
| # x = x + x_attn | |
| # x = x + self.mlp(self.norm2(x)) | |
| # return x, k, v | |
| # x = x + self.attn_scale * self.attn(self.norm1(x)) | |
| # x = x + self.mlp(self.norm2(x)) | |
| # return x | |
| # class CausalSelfAttention(nn.Module): | |
| # def __init__(self, config): | |
| # super().__init__() | |
| # self.n_head = config.n_head | |
| # self.n_embd = config.n_embd | |
| # self.head_dim = self.n_embd // self.n_head | |
| # assert self.n_embd % self.n_head == 0 | |
| # # key, query, value projections for all heads, but in a batch | |
| # self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False) | |
| # # output projection | |
| # self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) | |
| # rope_theta = 500000 | |
| # if hasattr(config, 'rope_theta') and config.rope_theta is not None: | |
| # rope_theta = config.rope_theta | |
| # self.rotary = Rotary(self.head_dim, base=rope_theta) | |
| # if hasattr(config, 'use_rope') and not config.use_rope: | |
| # self.rotary = None | |
| # def forward(self, x, return_kv=False, return_attn_maps=False): | |
| # B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) | |
| # # calculate query, key, values for all heads in batch and move head forward to be the batch dim | |
| # qkv = self.c_attn(x) | |
| # q, k, v = qkv.split(self.n_embd, dim=2) | |
| # k = k.view(B, T, self.n_head, self.head_dim) | |
| # q = q.view(B, T, self.n_head, self.head_dim) | |
| # v = v.view(B, T, self.n_head, self.head_dim) | |
| # if self.rotary is not None: | |
| # cos, sin = self.rotary(q) | |
| # q = apply_rotary_emb(q, cos, sin) | |
| # k = apply_rotary_emb(k, cos, sin) | |
| # if not return_kv and not return_attn_maps: | |
| # y = F.scaled_dot_product_attention( | |
| # q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), | |
| # is_causal=True) | |
| # else: | |
| # # manual implementation of attention | |
| # q = q.transpose(1, 2) | |
| # k = k.transpose(1, 2) | |
| # v = v.transpose(1, 2) | |
| # att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1))) | |
| # mask = torch.triu(torch.ones(T, T), diagonal=1).to(dtype=torch.bool).to(x.device) | |
| # mask = mask.view(1, 1, T, T) | |
| # masked_att = att.masked_fill(mask, float('-inf')) | |
| # # upcast to float32 for numerical stability, as per llama implementation | |
| # masked_att = F.softmax(masked_att, dim=-1, dtype=torch.float32).to(q.dtype) | |
| # # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) | |
| # y = torch.einsum('bnsk,bnkh->bnsh', masked_att, v) | |
| # y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side | |
| # # output projection | |
| # y = self.c_proj(y) | |
| # # return attention maps if requested | |
| # if return_attn_maps: | |
| # return y, F.softmax(att, dim=-1) | |
| # # return key and value caches if requested | |
| # if return_kv: | |
| # return y, k, v | |
| # return y | |
| # def kv_cache_forward( | |
| # self, | |
| # x: torch.Tensor, | |
| # pos: torch.Tensor, | |
| # k_cache: torch.Tensor | None = None, | |
| # v_cache: torch.Tensor | None = None, | |
| # return_attn_maps: bool = False, | |
| # ): | |
| # B, T, C = x.size() | |
| # q, k, v = self.c_attn(x).split(self.n_embd, dim=2) | |
| # q = q.view(B, T, self.n_head, self.head_dim) # (B, T, n_head, d) | |
| # k = k.view(B, T, self.n_head, self.head_dim) | |
| # v = v.view(B, T, self.n_head, self.head_dim) | |
| # if self.rotary is not None: | |
| # cos, sin = self.rotary(q, t=pos) # cos/sin match (B, T, n_head, d) | |
| # q = apply_rotary_emb(q, cos, sin) | |
| # k = apply_rotary_emb(k, cos, sin) | |
| # q = q.transpose(1, 2) # (B, n_head, T, d) | |
| # k = k.transpose(1, 2) | |
| # v = v.transpose(1, 2) | |
| # if k_cache is not None: | |
| # k = torch.cat([k_cache, k], dim=2) # time dim grows | |
| # if v_cache is not None: | |
| # v = torch.cat([v_cache, v], dim=2) | |
| # if not return_attn_maps: | |
| # y = F.scaled_dot_product_attention( | |
| # q, k, v, | |
| # is_causal=True) | |
| # else: | |
| # # manual implementation of attention | |
| # att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) | |
| # att = F.softmax(att, dim=-1) | |
| # y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) | |
| # y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side | |
| # # output projection | |
| # y = self.c_proj(y) | |
| # y = y.transpose(1, 2).contiguous().view(B, T, C) | |
| # y = self.c_proj(y) | |
| # return y, k, v | |
| # class MLP(nn.Module): | |
| # def __init__(self, config): | |
| # super().__init__() | |
| # self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) | |
| # self.gelu = nn.GELU() | |
| # self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) | |
| # self.dropout = nn.Dropout(config.dropout) | |
| # def forward(self, x): | |
| # x = self.c_fc(x) | |
| # x = self.gelu(x) | |
| # x = self.c_proj(x) | |
| # x = self.dropout(x) | |
| # return x | |
| # class Rotary(torch.nn.Module): | |
| # def __init__(self, dim, base=500000, learned=True): | |
| # super().__init__() | |
| # # Compute the base inverse frequencies as before. | |
| # inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) | |
| # # If learned is True, register as a parameter; otherwise, as a buffer. | |
| # if learned: | |
| # # Initialize randomly and register as a parameter. | |
| # self.inv_freq = torch.nn.Parameter(inv_freq) | |
| # nn.init.normal_(self.inv_freq, mean=0.0, std=0.02) | |
| # else: | |
| # self.register_buffer("inv_freq", inv_freq) | |
| # self.learned = learned # (optional) Save the flag if needed later | |
| # def forward(self, x, t=None): | |
| # seq_len = x.shape[1] | |
| # if t is None: | |
| # # Create a tensor of positions. | |
| # t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq) | |
| # # Outer product to compute angles; this uses the (possibly learnable) frequencies. | |
| # freqs = torch.outer(t, self.inv_freq).to(x.device) | |
| # cos_cached = freqs.cos() | |
| # sin_cached = freqs.sin() | |
| # return cos_cached[None, :, None, :], sin_cached[None, :, None, :] | |
| # def apply_rotary_emb(x, cos, sin): | |
| # assert x.ndim == 4 # multihead attention expected | |
| # d = x.shape[3] // 2 | |
| # x1 = x[..., :d] | |
| # x2 = x[..., d:] | |
| # y1 = x1 * cos + x2 * sin | |
| # y2 = x1 * (-sin) + x2 * cos | |
| # return torch.cat([y1, y2], dim=3) | |
| # class RMSNorm(nn.Module): | |
| # """ Root Mean Square Normalization """ | |
| # def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-6): | |
| # super().__init__() | |
| # self.eps = eps | |
| # self.weight = nn.Parameter(torch.ones(dim)) if weight else None | |
| # def _norm(self, x): | |
| # return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) | |
| # def forward(self, x): | |
| # output = self._norm(x.float()).type_as(x) | |
| # if self.weight is not None: | |
| # return output * self.weight | |
| # return output | |
| """ | |
| AuriStream sequence model definition. | |
| """ | |
| import math | |
| import inspect | |
| import random | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| import numpy as np | |
| from huggingface_hub import PyTorchModelHubMixin | |
| from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput | |
| from transformers import PreTrainedModel | |
| from .configuration_auristream import AuriStreamConfig | |
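
# NOTE: the configuration fields read in this module are: vocab_size, n_embd,
# n_layer, n_head, seq_len, dropout, bias, and (optionally) use_rope, rope_theta,
# and n_pred_steps. This list is inferred from the code below; see
# configuration_auristream.py for the authoritative definition.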


class AuriStream(PreTrainedModel):
    config_class = AuriStreamConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # if use_rope is present in the config and False, use learned positional embeddings (wpe)
        if hasattr(config, 'use_rope') and not config.use_rope:
            self.transformer = nn.ModuleDict(dict(
                wte = nn.Embedding(config.vocab_size, config.n_embd),
                wpe = nn.Embedding(config.seq_len, config.n_embd),
                drop = nn.Dropout(config.dropout),
                h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f = RMSNorm(config.n_embd, bias=config.bias),
            ))
        else:
            self.transformer = nn.ModuleDict(dict(
                wte = nn.Embedding(config.vocab_size, config.n_embd),
                drop = nn.Dropout(config.dropout),
                h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f = RMSNorm(config.n_embd, bias=config.bias),
            ))
        # n_pred_steps, if present, is the total number of prediction steps;
        # one extra linear head is created for each step beyond the next-token head
        if hasattr(config, 'n_pred_steps'):
            self.future_heads = nn.ModuleList([nn.Linear(config.n_embd, config.vocab_size, bias=False) for _ in range(config.n_pred_steps - 1)])
        else:
            self.future_heads = None
        self.coch_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For the non-embedding count (default), the learned position embeddings
        (if present) are subtracted; token embeddings are kept in the count.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding and hasattr(self.transformer, 'wpe'):
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, seq, tgt=None, output_hidden_states=False, return_dict=False, up_until_layer=None):
        """
        Input: seq: torch.Tensor of token indices, shape (b, t)
               tgt: torch.Tensor of target indices, shape (b, t), or None
        """
        # forward the GPT model itself
        tok_emb = self.transformer.wte(seq)  # token embeddings of shape (b, t, n_embd)
        # if wpe exists in self.transformer, apply learned positional embeddings
        if hasattr(self.transformer, 'wpe'):
            pos = torch.arange(0, seq.size(1), dtype=torch.long, device=seq.device)
            pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
            x = self.transformer.drop(tok_emb + pos_emb)
        else:
            x = self.transformer.drop(tok_emb)
        all_hidden_states = []
        ran_all_blocks = True
        for block_idx, block in enumerate(self.transformer.h):
            # store the input to this block (the embedding output for block 0)
            all_hidden_states.append(x)
            # optionally stop before forwarding the block at up_until_layer
            if up_until_layer is not None and block_idx == up_until_layer:
                ran_all_blocks = False
                break
            x = block(x)
        # append the last hidden state if we did not exit early
        if ran_all_blocks:
            all_hidden_states.append(x)
        if output_hidden_states:
            model_output = BaseModelOutput(
                last_hidden_state=x,
                hidden_states=all_hidden_states,
            )
            return model_output
        x = self.transformer.ln_f(x)
        logits = self.coch_head(x)
        if tgt is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size), tgt.reshape(-1),
            )
            # If we have additional future-prediction heads, add their losses as well
            if self.future_heads is not None:
                for i, head in enumerate(self.future_heads):
                    future_logits = head(x[:, :-(i+1)])
                    loss += F.cross_entropy(
                        future_logits.reshape(-1, self.config.vocab_size), tgt[:, (i+1):].reshape(-1),
                    )
                # average the loss over the next-token head plus the future heads
                loss = loss / (len(self.future_heads) + 1)
            if return_dict:
                model_output = CausalLMOutput(
                    loss=loss,
                    logits=logits,
                )
                return model_output
            return logits, loss
        if return_dict:
            return CausalLMOutput(logits=logits)
        return logits, None

    def sample_logits(self, logits: torch.FloatTensor, temperature: float = 0.9,
                      top_k: int = 500, top_p: float = 0.5) -> torch.LongTensor:
        """
        Samples an integer from the distribution defined by the logits.

        Parameters:
            logits (torch.FloatTensor): The logits of the distribution
            temperature (float): The sampling temperature; if 0.0, argmax is used
            top_k (int): The number of top-k tokens to consider during sampling
            top_p (float): The cumulative probability threshold for nucleus (top-p) sampling
        Returns:
            torch.LongTensor: The sampled integers, with the shape of logits minus the last dimension
        """
        # If temperature is 0.0, use argmax
        if temperature == 0.0:
            return torch.argmax(logits, dim=-1)
        # Apply temperature
        logits = logits / temperature
        # Apply top-k filtering if specified
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[..., [-1]]] = -float('Inf')
        # Apply top-p (nucleus) filtering if specified
        if top_p is not None:
            # Sort the logits in descending order
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            # Compute the sorted softmax probabilities
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            # Compute the cumulative probabilities
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            # Create a mask for tokens to remove
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift the mask right to keep at least one token
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            # Scatter the mask back to the original indices
            indices_to_remove = sorted_indices_to_remove.scatter(dim=-1, index=sorted_indices, src=sorted_indices_to_remove)
            logits[indices_to_remove] = -float('Inf')
        # Compute softmax probabilities
        probs = F.softmax(logits, dim=-1)
        # Flatten probabilities to (batch_size * sequence_length, vocab_size)
        flat_probs = probs.view(-1, probs.size(-1))
        # Sample from the distribution
        sampled = torch.multinomial(flat_probs, num_samples=1)
        # Reshape to the original shape minus the last dimension
        sampled = sampled.view(*logits.shape[:-1])
        return sampled

    @torch.no_grad()
    def generate(self, seq: torch.Tensor, n_tokens: int = 1, temp=1.0,
                 top_k=500, top_p=0.5, seed=None):
        """
        Autoregressively sample tokens that continue the given sequence.

        Parameters:
            seq: torch.Tensor of shape (b, t)
                Input token sequence to condition the generation on
            n_tokens: int
                Number of time bins (tokens) to predict
            temp: float
                Temperature for sampling logits
            top_k: int
                Number of top-k tokens to consider during sampling
            top_p: float
                Cumulative probability threshold for nucleus (top-p) sampling
            seed: int
                Random seed for sampling
        Returns:
            pred_coch: torch.Tensor of shape (b, n_tokens)
                The predicted token sequence
            all_logits: torch.Tensor of shape (b, n_tokens, vocab_size)
                The logits for each predicted time step
        """
        # Set seed if provided
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
        # make a list of logits to return
        all_logits = []
        device = seq.device
        # grab the shape of the input token sequence
        b, t = seq.size()
        #### Embed conditioning sequence into the KV cache
        tok_emb = self.transformer.wte(seq)  # token embeddings of shape (b, t, n_embd)
        # if wpe exists in self.transformer, apply learned positional embeddings
        if hasattr(self.transformer, 'wpe'):
            pos = torch.arange(0, t, dtype=torch.long, device=device)
            pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd)
            x = self.transformer.drop(tok_emb + pos_emb)
        else:
            x = self.transformer.drop(tok_emb)
        # Initialize lists to store k and v for each transformer block
        k_list = []
        v_list = []
        for block_idx, block in enumerate(self.transformer.h):
            # Pass through the transformer block, and store k and v
            x, k, v = block(x, return_kv=True)
            k_list.append(k)
            v_list.append(v)
        # k_cache and v_cache have shape (n_layer, b, n_head, t, n_embd//n_head)
        k_cache = torch.stack(k_list, dim=0)
        v_cache = torch.stack(v_list, dim=0)
        # Pass through the final layer norm
        x = self.transformer.ln_f(x)
        # The first prediction of the model is decoded from the last time bin of the input
        logits = self.coch_head(x[:, [-1]])
        predictions = [self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p)]
        all_logits.append(logits)
        ### Predict future tokens
        # Each newly sampled token is fed back through the model to predict the next one.
        # The loop runs n_tokens - 1 times because the first token was already predicted
        # from the last embedding of the input above.
        for i in range(n_tokens-1):
            # Get the token embedding (and position) of just the last sampled token
            tok_emb = self.transformer.wte(predictions[-1])  # token embeddings of shape (b, 1, n_embd)
            # absolute position of the new token (used for wpe and/or rotary embeddings)
            pos = torch.arange(t+i, t+i+1, dtype=torch.long, device=device)  # shape (1,)
            # if wpe exists in self.transformer, apply learned positional embeddings
            if hasattr(self.transformer, 'wpe'):
                pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, n_embd)
                x = self.transformer.drop(tok_emb + pos_emb)
            else:
                x = self.transformer.drop(tok_emb)
            # Pass through the transformer blocks, reusing and extending the KV cache
            k_list = []
            v_list = []
            for block_idx, block in enumerate(self.transformer.h):
                x, k, v = block(x, pos=pos, k_cache=k_cache[block_idx], v_cache=v_cache[block_idx])
                k_list.append(k)
                v_list.append(v)
            x = self.transformer.ln_f(x)
            # rebuild the cache with the new keys and values appended
            k_cache = torch.stack(k_list, dim=0)
            v_cache = torch.stack(v_list, dim=0)
            # predict the next time bin
            logits = self.coch_head(x)
            predictions.append(self.sample_logits(logits, temperature=temp, top_k=top_k, top_p=top_p))
            all_logits.append(logits)
        pred_coch = torch.cat(predictions, dim=1)
        all_logits = torch.cat(all_logits, dim=1)
        return pred_coch, all_logits

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter that is 2D will be weight decayed, otherwise not.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create an AdamW optimizer, using the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")
        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, T, dt, gpu_type='A40'):
        """ estimate model flops utilization (MFU) as a fraction of the selected accelerator's bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = self.config
        L, H, Q = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as a ratio of the accelerator's bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt)  # per second
        # grab promised flops based on GPU type
        if gpu_type == 'A40':
            flops_promised = 149.7e12  # A40 GPU bfloat16 peak flops is 149.7 TFLOPS
        elif gpu_type == 'A100':
            flops_promised = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
        elif gpu_type == 'H100':
            flops_promised = 756e12  # H100 GPU bfloat16 peak flops is 756 TFLOPS
        elif gpu_type == 'TPUv4':
            flops_promised = 275e12
        elif gpu_type == 'TPUv5e':
            flops_promised = 197e12
        else:
            raise ValueError(f"unknown gpu_type: {gpu_type}")
        mfu = flops_achieved / flops_promised
        return mfu
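
    # Typical training-side usage of the two utilities above (a hedged sketch,
    # not part of the module API; the names and values below are placeholders):
    #
    #   optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=3e-4,
    #                                          betas=(0.9, 0.95), device_type='cuda')
    #   ... one optimization step over a batch of T-token sequences ...
    #   mfu = model.estimate_mfu(fwdbwd_per_iter=grad_accum_steps, T=seq_len,
    #                            dt=iter_time_seconds, gpu_type='A100')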


#########################################################
#####               Layer Definitions               #####
#########################################################

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)
        self.attn_scale = 1.0  # (1 / (2 * config.n_layer)**0.5)
        self.norm1 = RMSNorm(config.n_embd, bias=config.bias)
        self.norm2 = RMSNorm(config.n_embd, bias=config.bias)

    def forward(self, x, pos=None, return_kv=False, k_cache=None, v_cache=None):
        """
        Three modes:
          - plain forward (default): returns x only
          - return_kv=True: full-sequence forward that also returns the key/value cache
          - k_cache/v_cache given: single-step decode that reuses and extends the cache
            (pos gives the absolute position(s) of the new token(s) for rotary embeddings)
        """
        # If we are given a key and value cache, reuse the pre-computed values to minimize
        # the computation cost
        if k_cache is not None and v_cache is not None:
            # Pass the key and value cache to the attention layer, obtain new key and value caches
            x_attn, k, v = self.attn.kv_cache_forward(self.norm1(x), pos=pos, k_cache=k_cache, v_cache=v_cache)
            x = x + x_attn
            x = x + self.mlp(self.norm2(x))
            return x, k, v
        # We might want to encode the caches of a whole block of keys and values at once using the
        # fast flash attention implementation while still returning the key and value caches
        elif return_kv:
            # Forward normally through the attention layer, but also return the key and value caches
            x_attn, k, v = self.attn(self.norm1(x), return_kv=True)
            x = x + x_attn
            x = x + self.mlp(self.norm2(x))
            return x, k, v
        x = x + self.attn_scale * self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        assert self.n_embd % self.n_head == 0
        self.head_dim = self.n_embd // self.n_head
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        rope_theta = 500000
        if hasattr(config, 'rope_theta') and config.rope_theta is not None:
            rope_theta = config.rope_theta
        self.rotary = Rotary(self.head_dim, base=rope_theta)
        if hasattr(config, 'use_rope') and not config.use_rope:
            self.rotary = None

    def forward(self, x, return_kv=False, return_attn_maps=False):
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, self.head_dim)
        q = q.view(B, T, self.n_head, self.head_dim)
        v = v.view(B, T, self.n_head, self.head_dim)
        if self.rotary is not None:
            cos, sin = self.rotary(q)
            q = apply_rotary_emb(q, cos, sin)
            k = apply_rotary_emb(k, cos, sin)
        if not return_kv and not return_attn_maps:
            y = F.scaled_dot_product_attention(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
                is_causal=True)
        else:
            # manual implementation of attention
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            att = torch.einsum('bnsh,bnkh->bnsk', q, k) * (1.0 / math.sqrt(k.size(-1)))
            mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
            mask = mask.view(1, 1, T, T)
            masked_att = att.masked_fill(mask, float('-inf'))
            # upcast to float32 for numerical stability, as per the llama implementation
            masked_att = F.softmax(masked_att, dim=-1, dtype=torch.float32).to(q.dtype)
            # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
            y = torch.einsum('bnsk,bnkh->bnsh', masked_att, v)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        # return the (causally masked) attention maps if requested
        if return_attn_maps:
            return y, masked_att
        # return key and value caches if requested
        if return_kv:
            return y, k, v
        return y

    def kv_cache_forward(self, x, pos=None, k_cache=None, v_cache=None):
        """
        Single-step (decode-time) attention that reuses cached keys and values.
        x has shape (B, T_new, C) (typically T_new == 1); k_cache/v_cache have shape
        (B, n_head, T_past, head_dim). pos gives the absolute position(s) of the new
        token(s), needed to apply the rotary embedding consistently with the positions
        already baked into the cache. Returns the output plus the extended caches.
        """
        B, T, C = x.size()  # batch size, new sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(B, T, self.n_head, self.head_dim)
        k = k.view(B, T, self.n_head, self.head_dim)
        v = v.view(B, T, self.n_head, self.head_dim)
        # apply rotary embeddings at the absolute positions of the new tokens
        if self.rotary is not None:
            cos, sin = self.rotary(q, t=pos)
            q = apply_rotary_emb(q, cos, sin)
            k = apply_rotary_emb(k, cos, sin)
        # move the head dimension forward to be the batch dim: (B, nh, T, hs)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        # append the new keys and values to the cached ones along the time dimension
        if k_cache is not None:
            k = torch.cat((k_cache, k), dim=2)
        if v_cache is not None:
            v = torch.cat((v_cache, v), dim=2)
        # manual implementation of attention (no mask needed: the new query may attend
        # to every cached position and to itself)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = F.softmax(att, dim=-1)
        y = att @ v  # (B, nh, T, T_total) x (B, nh, T_total, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = self.c_proj(y)
        return y, k, v
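
    # Illustrative shapes for one decode step (example values only): with B=2,
    # n_head=4, head_dim=16, a 16-token prompt and a single new token, x is
    # (2, 1, 64), k_cache/v_cache are (2, 4, 16, 16), and the returned k and v
    # are (2, 4, 17, 16); reusing them as the next step's caches grows the time
    # dimension by one per generated token.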


class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class Rotary(torch.nn.Module):
    def __init__(self, dim, base=500000, learned=True):
        super().__init__()
        # Compute the standard rotary inverse frequencies.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # If learned is True, register as a parameter; otherwise, as a buffer.
        if learned:
            # Register as a parameter and re-initialize randomly (learned frequencies).
            self.inv_freq = torch.nn.Parameter(inv_freq)
            nn.init.normal_(self.inv_freq, mean=0.0, std=0.02)
        else:
            self.register_buffer("inv_freq", inv_freq)
        self.learned = learned  # save the flag in case it is needed later

    def forward(self, x, t=None):
        # x has shape (B, T, n_head, head_dim); t optionally gives the absolute
        # positions to encode (defaults to 0..T-1).
        seq_len = x.shape[1]
        if t is None:
            # Create a tensor of positions.
            t = torch.arange(seq_len, device=x.device)
        t = t.type_as(self.inv_freq)
        # Outer product to compute angles; this uses the (possibly learnable) frequencies.
        freqs = torch.outer(t, self.inv_freq).to(x.device)
        cos_cached = freqs.cos()
        sin_cached = freqs.sin()
        return cos_cached[None, :, None, :], sin_cached[None, :, None, :]


def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4  # multihead attention expected
    d = x.shape[3] // 2
    x1 = x[..., :d]
    x2 = x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], dim=3)
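
# Note on conventions: apply_rotary_emb uses the "split-halves" rotation, i.e. the
# head dimension is split into two halves (x1, x2) and each pair (x1[i], x2[i]) is
# rotated by the angle t * inv_freq[i]. The cos/sin tensors produced by Rotary have
# shape (1, T, 1, head_dim // 2), so they broadcast over the batch and head dimensions
# of a (B, T, n_head, head_dim) tensor. This differs from formulations that rotate
# interleaved (even, odd) feature pairs, so checkpoints are not interchangeable
# between the two conventions.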


class RMSNorm(nn.Module):
    """ Root Mean Square Normalization """
    def __init__(self, dim: int, weight: bool = True, bias: bool = False, eps: float = 1e-6):
        super().__init__()
        # `bias` is accepted for interface compatibility, but no bias term is applied.
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if weight else None

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        if self.weight is not None:
            return output * self.weight
        return output
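

# ---------------------------------------------------------------------------
# Minimal smoke test: a hedged sketch, not part of the library API. It assumes
# AuriStreamConfig accepts the keyword arguments read throughout this module
# (vocab_size, n_embd, n_layer, n_head, seq_len, dropout, bias, use_rope,
# n_pred_steps); adjust to the real config signature if it differs. Because of
# the relative import at the top, run it as a module, e.g.
# `python -m <package>.modeling_auristream`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = AuriStreamConfig(
        vocab_size=256, n_embd=64, n_layer=2, n_head=4,
        seq_len=128, dropout=0.0, bias=False,
        use_rope=True, n_pred_steps=1,
    )
    model = AuriStream(config)
    tokens = torch.randint(0, config.vocab_size, (2, 16))
    # forward pass with (dummy) targets: returns next-token logits and a cross-entropy loss
    logits, loss = model(tokens, tgt=tokens)
    print("logits:", tuple(logits.shape), "loss:", float(loss))
    # autoregressive continuation of 4 tokens using the KV cache
    pred, step_logits = model.generate(tokens, n_tokens=4, temp=1.0)
    print("generated:", tuple(pred.shape), "step logits:", tuple(step_logits.shape))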