| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional, Tuple, List, Union |
| from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from transformers.cache_utils import Cache, DynamicCache |
|
|
|
|
| |
|
|
class DotLMConfig(PretrainedConfig):
    """Configuration for the DotLM decoder-only language model.

    Args:
        vocab_size: Token vocabulary size.
        d_model: Width of the residual stream.
        hidden_dim: Inner width of the SwiGLU feed-forward blocks.
        num_hidden_layers: Number of transformer blocks.
        n_heads: Number of query heads.
        n_kv_heads: Number of key/value heads (GQA); should divide n_heads.
        context_len: Maximum sequence length (RoPE tables / causal mask size).
        theta_base: RoPE frequency base.
        norm_eps: Epsilon for the RMSNorm layers.
        initializer_range: Std of the normal weight initializer.
        tie_word_embeddings: Share the LM head with the input embedding.
        **kwargs: Extra options forwarded to PretrainedConfig (may include
            use_cache, pad_token_id, bos_token_id, eos_token_id).
    """

    model_type = "dotlm"

    def __init__(
        self,
        vocab_size=16384,
        d_model=768,
        hidden_dim=2048,
        num_hidden_layers=24,
        n_heads=6,
        n_kv_heads=2,
        context_len=4096,
        theta_base=10000.0,
        norm_eps=1e-6,
        initializer_range=0.02,
        tie_word_embeddings=True,
        **kwargs
    ):
        # Bug fix: PretrainedConfig.__init__ *pops* pad/bos/eos_token_id (and,
        # depending on the transformers version, use_cache) out of kwargs, so
        # reading them with kwargs.get() AFTER super().__init__() silently
        # discarded caller-supplied values and reapplied the defaults below.
        # Extract them first and hand them to the parent explicitly.
        use_cache = kwargs.pop("use_cache", True)
        pad_token_id = kwargs.pop("pad_token_id", 0)
        bos_token_id = kwargs.pop("bos_token_id", None)
        eos_token_id = kwargs.pop("eos_token_id", 3)
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.hidden_dim = hidden_dim
        self.num_hidden_layers = num_hidden_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.context_len = context_len
        self.theta_base = theta_base
        self.norm_eps = norm_eps
        self.initializer_range = initializer_range
        self.tie_word_embeddings = tie_word_embeddings
        self.use_cache = use_cache
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
|
|
|
|
| |
|
|
def precompute_freqs_cis(dim, context_len, theta_base=10000.0):
    """Precompute RoPE tables.

    Returns a (cos, sin) pair, each of shape (context_len, dim), where the
    per-position angles for the first dim//2 frequencies are duplicated across
    both halves of the last axis (rotate-half convention).
    """
    exponents = torch.arange(0, dim, 2) / dim
    inv_freq = 1.0 / (theta_base ** exponents)
    positions = torch.arange(context_len, dtype=torch.float32)
    # Outer product: angles[p, f] = p * inv_freq[f]
    angles = positions[:, None] * inv_freq[None, :]
    angles = torch.cat((angles, angles), dim=-1)
    return angles.cos(), angles.sin()
|
|
|
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: W2(SiLU(W(x)) * V(x)).

    All three projections are bias-free. Submodule names (W, V, W2, silu) are
    part of the state_dict layout and must not change.
    """

    def __init__(self, d_model, hidden_dim):
        super().__init__()
        self.W = nn.Linear(d_model, hidden_dim, bias=False)
        self.V = nn.Linear(d_model, hidden_dim, bias=False)
        self.W2 = nn.Linear(hidden_dim, d_model, bias=False)
        self.silu = nn.SiLU()

    def forward(self, x):
        # Gated-linear-unit structure: SiLU gate modulates the value path.
        gate = self.silu(self.W(x))
        value = self.V(x)
        return self.W2(gate * value)
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel scale.

    Normalizes over the last dimension: x / sqrt(mean(x^2) + eps) * scale.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # rsqrt of the mean square, eps added for numerical stability.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.scale * (x * inv_rms)
|
|
|
|
class RoPE(nn.Module):
    """Apply rotary position embeddings (rotate-half formulation).

    Expects x of shape (batch, heads, seq, head_dim); cos/sin must broadcast
    against that shape and duplicate each frequency across both halves of the
    last axis (see precompute_freqs_cis).
    """

    def forward(self, x, cos, sin):
        # Enforce a 4-D input, as the original contract requires.
        _, _, _, head_dim = x.shape
        half = head_dim // 2
        low, high = x[..., :half], x[..., half:]
        # "Rotate half": (a, b) -> (-b, a) on the last axis.
        rotated = torch.cat((-high, low), dim=-1)
        return x * cos + rotated * sin
|
|
|
|
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention: n_heads query heads share n_kv_groups
    key/value heads, with per-head RMSNorm (QK-norm) and rotary embeddings.

    NOTE(review): `self.layer_idx` (needed for the HF Cache branch) is not set
    in __init__ — it is assigned externally after construction (DotLMBlock
    does `self.attention.layer_idx = layer_idx`).
    """

    def __init__(self, d_model, n_heads, head_dim, n_kv_groups):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.n_kv_groups = n_kv_groups
        # Query heads per kv group; assumes n_kv_groups divides n_heads evenly.
        self.group_size = n_heads // n_kv_groups
        self.output_dim = n_heads * head_dim

        self.Wq = nn.Linear(d_model, self.output_dim, bias=False)
        self.Wk = nn.Linear(d_model, n_kv_groups * head_dim, bias=False)
        self.Wv = nn.Linear(d_model, n_kv_groups * head_dim, bias=False)
        self.Wo = nn.Linear(self.output_dim, d_model, bias=False)
        # QK-norm: normalize each head's q/k vectors before applying RoPE.
        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)
        self.rope = RoPE()

    def forward(self, x, cos, sin, mask=None, past_key_value=None, use_cache=False):
        """Attend over x (B, S, d_model) and return (output, next_cache).

        `cos`/`sin` must already be sliced to the current absolute positions.
        `mask` is boolean with True marking *blocked* positions (inverted
        below for SDPA). `past_key_value` may be an HF `Cache` or a legacy
        (k, v) tuple; `next_cache` mirrors whichever form was supplied.
        """
        B, S, _ = x.shape

        # Project and reshape to (B, heads-or-groups, S, head_dim).
        q = self.Wq(x).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.Wk(x).view(B, S, self.n_kv_groups, self.head_dim).transpose(1, 2)
        v = self.Wv(x).view(B, S, self.n_kv_groups, self.head_dim).transpose(1, 2)

        # Order matters: normalize first, then rotate.
        q, k = self.q_norm(q), self.k_norm(k)
        q, k = self.rope(q, cos, sin), self.rope(k, cos, sin)

        next_past = None
        if past_key_value is not None:
            if isinstance(past_key_value, Cache):
                # HF Cache API: append this step's k/v and get the full history
                # back. Requires self.layer_idx (assigned by DotLMBlock).
                k, v = past_key_value.update(k, v, self.layer_idx)
                next_past = past_key_value
            else:
                # Legacy tuple cache: concatenate along the sequence axis (dim 2).
                pk, pv = past_key_value
                if pk is not None:
                    k = torch.cat([pk, k], dim=2)
                    v = torch.cat([pv, v], dim=2)
                next_past = (k, v) if use_cache else None

        # Keep the un-expanded (grouped) k/v around so the cache stays small.
        kv_k, kv_v = k, v

        B, G, S_kv, D = kv_k.shape
        # Broadcast each kv group across its group_size query heads.
        k = kv_k.unsqueeze(2).expand(B, G, self.group_size, S_kv, D).reshape(B, self.n_heads, S_kv, D)
        v = kv_v.unsqueeze(2).expand(B, G, self.group_size, S_kv, D).reshape(B, self.n_heads, S_kv, D)

        # Fast path: for cache-less prefill with no explicit mask, let SDPA
        # build the causal mask internally.
        is_causal = (mask is None and S > 1 and past_key_value is None)

        out = F.scaled_dot_product_attention(
            q, k, v,
            # SDPA bool masks mean "True = may attend"; ours means
            # "True = blocked", hence the inversion.
            attn_mask=None if (mask is None or is_causal) else ~mask,
            dropout_p=0.0,
            is_causal=is_causal,
        )
        out = out.transpose(1, 2).reshape(B, S, self.output_dim)
        if use_cache and past_key_value is None:
            # Caching requested but no cache supplied: start a legacy tuple
            # cache from this step's grouped k/v.
            next_past = (kv_k, kv_v)
        return self.Wo(out), next_past
|
|
|
|
class DotLMBlock(nn.Module):
    """Pre-norm transformer block.

    Structure: x + GQA(RMSNorm(x)), then h + SwiGLU(RMSNorm(h)). Submodule
    names (attention, feed_forward, norm1, norm2) are part of the state_dict
    layout and must not change.
    """

    def __init__(self, d_model, n_heads, n_kv_heads, hidden_dim, norm_eps=1e-6, layer_idx=None):
        super().__init__()
        head_dim = d_model // n_heads
        self.attention = GroupedQueryAttention(d_model, n_heads, head_dim, n_kv_heads)
        # The attention module needs its layer index for Cache.update().
        self.attention.layer_idx = layer_idx
        self.feed_forward = SwiGLU(d_model, hidden_dim)
        self.norm1 = RMSNorm(d_model, norm_eps)
        self.norm2 = RMSNorm(d_model, norm_eps)

    def forward(self, x, cos, sin, mask=None, past_key_value=None, use_cache=False):
        """Run one block; returns (hidden_states, next_cache)."""
        # Attention sub-layer with pre-normalization and residual add.
        attn_out, next_past = self.attention(
            self.norm1(x), cos, sin, mask, past_key_value, use_cache
        )
        h = x + attn_out
        # Feed-forward sub-layer with pre-normalization and residual add.
        return h + self.feed_forward(self.norm2(h)), next_past
|
|
|
|
| |
|
|
class DotLMForCausalLM(PreTrainedModel, GenerationMixin):
    """Decoder-only causal language model built from DotLMBlock layers.

    Caching supports both the HF `Cache` API (a `DynamicCache` is created
    automatically when `use_cache=True` and none is given) and legacy
    per-layer (k, v) tuples. RoPE tables and the boolean causal mask are
    deterministic non-persistent buffers, recomputed on demand
    (see `_ensure_rope_and_mask`).
    """

    config_class = DotLMConfig

    _tied_weights_keys = {"head.weight": "embeddor.weight"}

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddor = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([
            DotLMBlock(
                config.d_model, config.n_heads, config.n_kv_heads,
                config.hidden_dim, config.norm_eps, layer_idx=i
            )
            for i in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.d_model, config.norm_eps)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # RoPE tables of shape (context_len, head_dim). Non-persistent:
        # deterministic, so they are rebuilt rather than checkpointed.
        head_dim = config.d_model // config.n_heads
        cos, sin = precompute_freqs_cis(head_dim, config.context_len, config.theta_base)
        self.register_buffer("cos_cache", cos, persistent=False)
        self.register_buffer("sin_cache", sin, persistent=False)

        # Boolean causal mask: True strictly above the diagonal = blocked.
        mask = torch.triu(torch.ones(config.context_len, config.context_len, dtype=torch.bool), diagonal=1)
        self.register_buffer("causal_mask", mask, persistent=False)

        self.post_init()

    def _ensure_rope_and_mask(self):
        """
        `from_pretrained(..., low_cpu_mem_usage=True)` may build the module under
        meta tensors. In that case, our non-persistent buffers can end up as
        meta/zero tensors even though they are deterministic. Recompute them on
        demand.
        """
        # cos(0) must equal 1.0, so a zero first element means the RoPE buffer
        # was clobbered (e.g. zero-filled during meta-device materialization).
        need_rope = (
            self.cos_cache.device.type == "meta"
            or self.sin_cache.device.type == "meta"
            or self.cos_cache.numel() == 0
            or self.sin_cache.numel() == 0
            or (self.cos_cache.numel() > 0 and float(self.cos_cache.flatten()[0]) == 0.0)
        )
        # Entry [0, 1] is strictly above the diagonal, so it must be True in a
        # valid mask. (Checks are ordered so meta tensors are never indexed.)
        need_mask = (
            self.causal_mask.device.type == "meta"
            or self.causal_mask.numel() == 0
            or (self.causal_mask.numel() > 1 and bool(self.causal_mask[0, 1]) is False)
        )
        if not (need_rope or need_mask):
            return

        head_dim = self.config.d_model // self.config.n_heads
        cos, sin = precompute_freqs_cis(head_dim, self.config.context_len, self.config.theta_base)
        # Write through _buffers directly so meta-device placeholders are replaced.
        self._buffers["cos_cache"] = cos
        self._buffers["sin_cache"] = sin

        mask = torch.triu(
            torch.ones(self.config.context_len, self.config.context_len, dtype=torch.bool), diagonal=1
        )
        self._buffers["causal_mask"] = mask

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear/Embedding weights."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)

    def tie_weights(self, **kwargs):
        # Share the LM head weight with the input embedding when configured.
        if self.config.tie_word_embeddings:
            self.head.weight = self.embeddor.weight

    def get_input_embeddings(self):
        return self.embeddor

    def set_input_embeddings(self, value):
        self.embeddor = value
        self.tie_weights()

    def get_output_embeddings(self):
        return self.head

    def set_output_embeddings(self, new_embeddings):
        self.head = new_embeddings
        self.tie_weights()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder stack and optionally compute the shifted LM loss.

        NOTE(review): `attention_mask`/`token_type_ids` are accepted for API
        compatibility but are not used — padding positions are not masked out.
        `output_attentions`/`output_hidden_states` are likewise ignored.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        B, S = input_ids.shape

        self._ensure_rope_and_mask()

        # Default to the HF Cache API when caching is requested with no cache.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        # Absolute position of the first new token, for RoPE / mask slicing.
        start_pos = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                start_pos = past_key_values.get_seq_length()
            else:
                layer0 = past_key_values[0]
                if layer0 is not None and layer0[0] is not None:
                    start_pos = layer0[0].shape[2]

        x = self.embeddor(input_ids)

        # Slice RoPE tables to the current positions and add broadcast dims
        # for (batch, heads).
        cos = self.cos_cache[start_pos : start_pos + S].to(device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0)
        sin = self.sin_cache[start_pos : start_pos + S].to(device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0)

        # Single-token decode needs no mask; prefill takes the causal slice.
        mask = None
        if S > 1:
            mask = self.causal_mask[start_pos : start_pos + S, : start_pos + S].to(device=x.device)

        # Legacy tuple caches are collected per layer; a Cache object is
        # updated in place by the attention layers.
        next_past_key_values = [] if (use_cache and not isinstance(past_key_values, Cache)) else None

        for i, block in enumerate(self.blocks):
            layer_past = None
            if past_key_values is not None:
                if isinstance(past_key_values, Cache):
                    layer_past = past_key_values
                else:
                    layer_past = past_key_values[i]
            x, new_layer_past = block(
                x, cos, sin, mask=mask, past_key_value=layer_past, use_cache=use_cache
            )
            if next_past_key_values is not None:
                next_past_key_values.append(new_layer_past)

        logits = self.head(self.norm(x))
        if not self.training:
            # Guard downstream sampling against NaN/Inf logits at inference.
            logits = torch.nan_to_num(logits, nan=0.0, posinf=1e4, neginf=-1e4)

        loss = None
        if labels is not None:
            # Shift so position t predicts token t+1; cross_entropy's default
            # ignore_index (-100) skips masked label positions.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        next_cache = (
            past_key_values if isinstance(past_key_values, Cache)
            else (tuple(next_past_key_values) if use_cache else None)
        )

        if not return_dict:
            # Bug fix: the tuple path previously dropped the loss and returned
            # the *input* legacy cache rather than the updated one. HF
            # convention: (loss?, logits, past_key_values?).
            output = (logits,) + ((next_cache,) if use_cache else ())
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=next_cache,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        """Trim the prompt to the last token once a cache holds the prefix."""
        past_len = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_len = past_key_values.get_seq_length()
            else:
                layer0 = past_key_values[0] if len(past_key_values) > 0 else None
                if layer0 is not None and layer0[0] is not None:
                    past_len = layer0[0].shape[2]

        # Only the newest token must be fed once the cache covers the prefix.
        if past_len > 0:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": kwargs.get("attention_mask", None),
            "token_type_ids": kwargs.get("token_type_ids", None),
            "use_cache": True,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder the cache along the batch dimension for beam search."""
        if past_key_values is None:
            return past_key_values
        if isinstance(past_key_values, Cache):
            past_key_values.reorder_cache(beam_idx)
            return past_key_values
        return tuple(
            (k.index_select(0, beam_idx), v.index_select(0, beam_idx))
            for (k, v) in past_key_values
        )

    @torch.no_grad()
    def generate(self, input_ids=None, max_new_tokens=256, temperature=1.0,
                 top_k=None, do_sample=True, eos_token_id=None, **kwargs):
        """Custom autoregressive generate that bypasses GenerationMixin internals.

        Args:
            input_ids: (B, S) prompt token ids.
            max_new_tokens: Maximum number of tokens to append.
            temperature: Softmax temperature (sampling only); clamped >= 1e-8.
            top_k: If set, restrict sampling to the top-k logits.
            do_sample: Sample from the distribution; otherwise greedy argmax.
            eos_token_id: Stop once *every* sequence emitted this token.
        """
        self._ensure_rope_and_mask()
        kv_cache = None
        curr_ids = input_ids

        for _ in range(max_new_tokens):
            if curr_ids.size(1) > self.config.context_len:
                curr_ids = curr_ids[:, -self.config.context_len:]
                # Bug fix: after truncation the cached positions no longer line
                # up with the window, and a growing cache would eventually run
                # past the RoPE tables (start_pos >= context_len -> empty cos
                # slice -> shape error). Drop the cache and re-prefill.
                kv_cache = None

            # Full prefill on the first step (or after truncation); afterwards
            # feed only the newest token and rely on the cache.
            model_input = curr_ids if kv_cache is None else curr_ids[:, -1:]
            out = self.forward(model_input, past_key_values=kv_cache, use_cache=True, return_dict=True)
            kv_cache = out.past_key_values

            logits = out.logits[:, -1, :]
            if do_sample:
                logits = logits / max(temperature, 1e-8)
                if top_k is not None:
                    # Mask out everything below the k-th largest logit.
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("Inf")
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = logits.argmax(dim=-1, keepdim=True)

            curr_ids = torch.cat([curr_ids, next_token], dim=1)
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return curr_ids
|
|