import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import (
    CausalLMOutput,
    CausalLMOutputWithCrossAttentions,
)


class BVVConfig(PretrainedConfig):
    """Hyper-parameters for the BVV decoder-only language model.

    Args:
        vocab_size: Number of rows in the token embedding table and width of
            the output logits.
        n_embed: Width of the raw token embedding vectors.
        d_model: Hidden width of the transformer. Must be a multiple of
            ``n_embed`` because embeddings are widened by repeating each
            component ``d_model // n_embed`` times.
        n_head: Number of attention heads per layer.
        n_layer: Number of transformer blocks.
        block_size: Side length of the causal mask built by each attention
            layer, i.e. the longest sequence a forward pass accepts.
        pad_id: Id of the padding token.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``n_embed`` is not positive or does not divide
            ``d_model``.
    """

    model_type = "model_16_float"

    def __init__(
        self,
        vocab_size=65536,
        n_embed=16,
        d_model=1024,
        n_head=32,
        n_layer=16,
        block_size=1024,
        pad_id=57344,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Fail early: a floored ratio would produce embeddings narrower than
        # d_model and only surface later as a shape error inside the model.
        if n_embed <= 0 or d_model % n_embed != 0:
            raise ValueError(
                f"d_model ({d_model}) must be a positive multiple of "
                f"n_embed ({n_embed})"
            )
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_embed = n_embed
        self.d_model = d_model
        self.n_layer = n_layer
        self.n_head = n_head
        self.pad_id = pad_id
        # Repetition factor that widens an n_embed vector to d_model.
        self.scale = d_model // n_embed


class RotaryEmbedding(nn.Module):
    """Produces the rotation angles used by rotary position embeddings.

    Args:
        dim: Size of the vectors that will be rotated. The attention layer in
            this file passes the per-head dimension (``head_dim``), not
            ``d_model``.
    """

    def __init__(self, dim):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # Non-persistent: recomputed on construction, never stored in
        # checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seq_len, device):
        """Return a ``(seq_len, dim)`` tensor of angles.

        Row ``t`` holds ``t * inv_freq`` concatenated with itself along the
        last dimension.
        """
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        return torch.cat([freqs, freqs], dim=-1)  # (seq_len, dim)


def apply_rotary_emb(x, rot_emb):
    """Rotate adjacent feature pairs of ``x`` by position-dependent angles.

    Args:
        x: Tensor of shape ``(B, n_head, seq_len, head_dim)``.
        rot_emb: Angles of shape ``(>= seq_len, head_dim)`` as returned by
            ``RotaryEmbedding.forward``.

    Returns:
        Tensor with the same shape as ``x``.

    NOTE: ``x`` is split into interleaved pairs ``(x[2i], x[2i+1])`` while the
    angle for pair ``i`` is taken from the even positions of ``rot_emb``,
    which is ``cat([freqs, freqs])``. Pair ``i`` therefore rotates with
    ``inv_freq[(2 * i) % (head_dim // 2)]``; when ``head_dim // 2`` is even
    only the even-indexed frequencies are used, each by two pairs. This is
    kept exactly as is because changing the angle assignment would change
    the outputs of any weights trained with this code.
    """
    seq_len = x.shape[-2]
    # Even-indexed angle columns, one per feature pair: (seq_len, head_dim/2).
    angles = rot_emb[:seq_len, 0::2]
    cos = torch.cos(angles).unsqueeze(0).unsqueeze(0)  # (1, 1, T, head_dim/2)
    sin = torch.sin(angles).unsqueeze(0).unsqueeze(0)

    x_shape = x.shape
    pairs = x.reshape(*x_shape[:-1], -1, 2)  # (..., head_dim/2, 2)
    x1 = pairs[..., 0]
    x2 = pairs[..., 1]

    # Standard 2-D rotation applied to every (x1, x2) pair.
    x1_rot = x1 * cos - x2 * sin
    x2_rot = x1 * sin + x2 * cos
    return torch.stack([x1_rot, x2_rot], dim=-1).reshape(x_shape)


class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    Args:
        d_model: Input and output feature width; must be divisible by
            ``n_head``.
        n_head: Number of attention heads.
        block_size: Maximum sequence length; sets the size of the
            precomputed lower-triangular causal mask.
    """

    def __init__(self, d_model, n_head, block_size):
        super().__init__()
        assert d_model % n_head == 0
        self.d_model = d_model
        self.n_head = n_head
        self.head_dim = d_model // n_head

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.rotary_emb = RotaryEmbedding(self.head_dim)
        self.dropout = nn.Dropout(0.0)
        self.register_buffer(
            "tril",
            torch.tril(torch.ones(block_size, block_size)),
            persistent=False,
        )

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(B, T, d_model)``.

        Returns:
            Tensor of shape ``(B, T, d_model)``.

        Raises:
            ValueError: If ``T`` exceeds the ``block_size`` the layer was
                built with (the causal mask does not cover it).
        """
        B, T, C = x.shape
        max_len = self.tril.size(0)
        if T > max_len:
            # Without this check the mask slice below is too small and the
            # failure shows up as an opaque broadcasting error.
            raise ValueError(
                f"sequence length {T} exceeds block_size {max_len}"
            )

        # Project, then split the feature axis into heads:
        # (B, T, d_model) -> (B, n_head, T, head_dim)
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Rotary embeddings are applied to queries and keys only.
        rot_emb = self.rotary_emb(seq_len=T, device=x.device)  # (T, head_dim)
        q = apply_rotary_emb(q, rot_emb)
        k = apply_rotary_emb(k, rot_emb)

        # Scaled dot-product attention: (B, n_head, T, T)
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        # Block attention to future positions.
        attn_scores = attn_scores.masked_fill(
            self.tril[:T, :T] == 0, float("-inf")
        )
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.dropout(attn_probs)

        out = torch.matmul(attn_probs, v)  # (B, n_head, T, head_dim)
        out = out.transpose(1, 2).contiguous().view(B, T, C)  # (B, T, d_model)
        return self.o_proj(out)


class TransformerMLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    The hidden layer is four times as wide as ``d_model``.
    """

    def __init__(self, d_model):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(0.0),
        )

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``."""
        return self.net(x)


class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention and MLP, each residual."""

    def __init__(self, d_model, n_head, block_size):
        super().__init__()
        self.self_attn = MultiHeadSelfAttention(d_model, n_head, block_size)
        self.mlp = TransformerMLP(d_model)
        self.input_layernorm = nn.LayerNorm(d_model)
        self.post_attention_layernorm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        x = x + self.self_attn(self.input_layernorm(x))
        x = x + self.mlp(self.post_attention_layernorm(x))
        return x
class BVVForCausalLM(PreTrainedModel):
    """Decoder-only language model built from ``TransformerBlock`` layers.

    Token embeddings have width ``config.n_embed`` and are widened to
    ``config.d_model`` by repeating each component ``config.scale`` times
    before entering the transformer stack.
    """

    config_class = BVVConfig

    def __init__(self, config):
        super().__init__(config)
        self.token_embeddings = nn.Embedding(config.vocab_size, config.n_embed)
        self.scale = config.scale
        self.transformer_layers = nn.Sequential(*[
            TransformerBlock(
                config.d_model,
                n_head=config.n_head,
                block_size=config.block_size,
            )
            for _ in range(config.n_layer)
        ])
        self.final_layernorm = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02), biases to 0."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: Integer tensor of token ids, shape ``(B, T)``.
            targets: Optional integer tensor with the same number of elements
                as ``idx``. Logits and targets are compared position by
                position; no shifting is applied here, so the caller must
                supply targets already aligned with the positions they
                label. Positions equal to ``config.pad_id`` are ignored.

        Returns:
            ``CausalLMOutput`` with ``logits`` of shape
            ``(B, T, vocab_size)`` and ``loss`` (``None`` when ``targets``
            is not given).
        """
        token_emb = self.token_embeddings(idx)  # (B, T, n_embed)
        # Widen to d_model by repeating each embedding component.
        x = token_emb.repeat_interleave(self.scale, dim=-1)  # (B, T, d_model)
        x = self.transformer_layers(x)
        x = self.final_layernorm(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous inputs are accepted.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                # Use the configured padding id rather than a literal so a
                # non-default pad_id is honoured.
                ignore_index=self.config.pad_id,
            )

        return CausalLMOutput(
            logits=logits,
            loss=loss,
        )

    @staticmethod
    def _filter_logits(logits, top_k=None, top_p=None):
        """Mask logits outside the top-k / nucleus (top-p) set with -inf.

        ``logits`` has shape ``(B, vocab)``; a new tensor is returned. A
        ``top_k`` of ``None`` or ``<= 0`` disables top-k filtering.
        """
        if top_k is not None and top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_value = torch.topk(logits, k).values[:, [-1]]
            logits = logits.masked_fill(logits < kth_value, float("-inf"))

        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(
                F.softmax(sorted_logits, dim=-1), dim=-1
            )
            remove_sorted = cumulative_probs > top_p
            # Shift right so the first token that crosses top_p is kept, and
            # always keep the single most likely token.
            remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
            remove_sorted[..., 0] = False
            # Map the mask from sorted order back to vocabulary order.
            remove = remove_sorted.scatter(1, sorted_indices, remove_sorted)
            logits = logits.masked_fill(remove, float("-inf"))

        return logits

    def generate(
        self,
        input_ids=None,
        max_new_tokens=None,
        max_length=None,
        temperature=1.0,
        top_k=None,
        top_p=None,
        do_sample=True,
        pad_token_id=None,
        eos_token_id=None,
        **kwargs,
    ):
        """Autoregressively extend ``input_ids``.

        Args:
            input_ids: Integer tensor of prompt token ids, shape ``(B, T)``.
            max_new_tokens: Number of tokens to append. When ``None`` it is
                derived from ``max_length`` (total length) or defaults to 50.
            max_length: Total target length, used only when
                ``max_new_tokens`` is ``None``.
            temperature: Softmax temperature for sampling; must be positive
                when ``do_sample`` is true. Ignored for greedy decoding.
            top_k: Keep only the ``top_k`` most likely tokens when sampling;
                ``None`` or ``<= 0`` disables the filter.
            top_p: Nucleus sampling threshold; ``None`` disables the filter.
            do_sample: Sample from the filtered distribution when true,
                otherwise pick the arg-max token.
            pad_token_id: Token written to rows that have already produced
                ``eos_token_id``; falls back to ``eos_token_id``.
            eos_token_id: Integer id that ends a row. Generation stops once
                every row has emitted it.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``(B, T + n)`` with the generated continuation
            appended, where ``n <= max_new_tokens``.

        Raises:
            ValueError: If ``input_ids`` is ``None`` or if sampling is
                requested with a non-positive ``temperature``.
        """
        if input_ids is None:
            raise ValueError("Input_ids must be provided")
        if do_sample and temperature <= 0:
            # Dividing by zero or a negative value gives inf/NaN or an
            # inverted distribution.
            raise ValueError("temperature must be positive when sampling")

        idx = input_ids
        if max_new_tokens is None:
            if max_length is not None:
                max_new_tokens = max_length - idx.shape[1]
            else:
                max_new_tokens = 50

        fill_id = pad_token_id if pad_token_id is not None else eos_token_id
        # Per-row flag: True once the row has produced eos_token_id.
        finished = torch.zeros(idx.shape[0], dtype=torch.bool, device=idx.device)

        with torch.no_grad():
            for _ in range(max_new_tokens):
                # The model only accepts up to block_size positions.
                idx_cond = idx[:, -self.config.block_size:]
                logits = self(idx_cond).logits[:, -1, :]

                if do_sample:
                    logits = self._filter_logits(
                        logits / temperature, top_k=top_k, top_p=top_p
                    )
                    probs = F.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)
                else:
                    # Positive temperature and top-k/top-p never change the
                    # arg-max, so greedy decoding uses the raw logits.
                    idx_next = torch.argmax(logits, dim=-1, keepdim=True)

                if eos_token_id is not None:
                    # Rows that already finished keep emitting padding
                    # instead of cutting the other rows short.
                    idx_next = torch.where(
                        finished.unsqueeze(1),
                        torch.full_like(idx_next, fill_id),
                        idx_next,
                    )
                    finished = finished | (idx_next.squeeze(1) == eos_token_id)

                idx = torch.cat((idx, idx_next), dim=1)

                if eos_token_id is not None and bool(finished.all()):
                    break
        return idx