import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Union
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache, DynamicCache


# ── Config ────────────────────────────────────────────────────────────────────
class DotLMConfig(PretrainedConfig):
    """Hyperparameters for the DotLM decoder-only language model.

    The defaults describe a model with grouped-query attention:
    ``n_heads`` query heads share ``n_kv_heads`` key/value heads.
    """

    model_type = "dotlm"

    def __init__(
        self,
        vocab_size=16384,
        d_model=768,
        hidden_dim=2048,        # inner width of the SwiGLU feed-forward
        num_hidden_layers=24,
        n_heads=6,              # number of query heads
        n_kv_heads=2,           # number of key/value heads (GQA groups)
        context_len=4096,       # max sequence length; sizes RoPE/mask tables
        theta_base=10000.0,     # RoPE frequency base
        norm_eps=1e-6,
        initializer_range=0.02,
        tie_word_embeddings=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.hidden_dim = hidden_dim
        self.num_hidden_layers = num_hidden_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.context_len = context_len
        self.theta_base = theta_base
        self.norm_eps = norm_eps
        self.initializer_range = initializer_range
        self.tie_word_embeddings = tie_word_embeddings
        # Read generation-related ids back out of kwargs so these explicit
        # defaults apply even when PretrainedConfig.__init__ set its own.
        self.use_cache = kwargs.get("use_cache", True)
        self.pad_token_id = kwargs.get("pad_token_id", 0)
        self.bos_token_id = kwargs.get("bos_token_id", None)
        self.eos_token_id = kwargs.get("eos_token_id", 3)


# ── Architecture Components ───────────────────────────────────────────────────
def precompute_freqs_cis(dim, context_len, theta_base=10000.0):
    """Precompute RoPE angle tables.

    Returns ``(cos, sin)``, each of shape ``(context_len, dim)``, laid out for
    the rotate-half formulation used by :class:`RoPE` below.
    """
    # Per-frequency inverse wavelengths: theta_i = base^(-2i/dim).
    theta = 1.0 / (theta_base ** (torch.arange(0, dim, 2) / dim))
    seq_ids = torch.arange(context_len, dtype=torch.float32)
    m_theta = torch.outer(seq_ids, theta)  # (context_len, dim // 2)
    # Duplicate the half-table so it broadcasts against the full head_dim.
    m_theta = torch.cat([m_theta, m_theta], dim=-1)
    return torch.cos(m_theta), torch.sin(m_theta)


class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: W2(silu(W(x)) * V(x)), all bias-free."""

    def __init__(self, d_model, hidden_dim):
        super().__init__()
        self.W = nn.Linear(d_model, hidden_dim, bias=False)   # gate projection
        self.V = nn.Linear(d_model, hidden_dim, bias=False)   # value projection
        self.W2 = nn.Linear(hidden_dim, d_model, bias=False)  # output projection
        self.silu = nn.SiLU()

    def forward(self, x):
        return self.W2(self.silu(self.W(x)) * self.V(x))


class RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel scale.

    NOTE(review): normalization runs in the input dtype (no float32 upcast);
    confirm this is intentional for low-precision training.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # x / sqrt(mean(x^2) + eps), then learned rescale.
        x = x * torch.rsqrt(torch.pow(x, 2).mean(dim=-1, keepdim=True) + self.eps)
        return x * self.scale


class RoPE(nn.Module):
    """Rotary position embedding (rotate-half formulation).

    Expects ``x`` of shape (batch, heads, seq, head_dim) and cos/sin tables
    broadcastable against it (sliced per position by the caller).
    """

    def forward(self, x, cos, sin):
        batch_size, num_heads, seq_len, head_dim = x.shape
        # Split channels in half and build the rotated companion (-x2, x1).
        x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2 :]
        x_rotated = torch.cat([-x2, x1], dim=-1)
        return x * cos + x_rotated * sin


class GroupedQueryAttention(nn.Module):
    """GQA self-attention: n_heads query heads over n_kv_groups K/V heads.

    Supports both HF ``Cache`` objects and legacy per-layer ``(k, v)`` tuples
    for incremental decoding. ``layer_idx`` is attached externally by
    :class:`DotLMBlock` and is only needed for the HF ``Cache`` path.
    """

    def __init__(self, d_model, n_heads, head_dim, n_kv_groups):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.n_kv_groups = n_kv_groups
        self.group_size = n_heads // n_kv_groups  # query heads per K/V head
        self.output_dim = n_heads * head_dim
        self.Wq = nn.Linear(d_model, self.output_dim, bias=False)
        self.Wk = nn.Linear(d_model, n_kv_groups * head_dim, bias=False)
        self.Wv = nn.Linear(d_model, n_kv_groups * head_dim, bias=False)
        self.Wo = nn.Linear(self.output_dim, d_model, bias=False)
        # QK-norm: per-head RMSNorm applied before RoPE.
        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)
        self.rope = RoPE()

    def forward(self, x, cos, sin, mask=None, past_key_value=None, use_cache=False):
        """Attend over x (B, S, d_model); returns (output, next_past).

        ``mask`` is a boolean matrix where True marks *disallowed* positions
        (it is inverted before being handed to SDPA). ``next_past`` is the
        HF Cache (updated in place), a legacy (k, v) tuple, or None.
        """
        B, S, _ = x.shape
        q = self.Wq(x).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.Wk(x).view(B, S, self.n_kv_groups, self.head_dim).transpose(1, 2)
        v = self.Wv(x).view(B, S, self.n_kv_groups, self.head_dim).transpose(1, 2)
        q, k = self.q_norm(q), self.k_norm(k)
        # RoPE is applied to the *new* positions only; cached K is already rotated.
        q, k = self.rope(q, cos, sin), self.rope(k, cos, sin)
        next_past = None
        if past_key_value is not None:
            if isinstance(past_key_value, Cache):
                # HF DynamicCache: update in-place and get concatenated K/V back.
                k, v = past_key_value.update(k, v, self.layer_idx)
                next_past = past_key_value
            else:
                # Legacy cache format: (k, v) per layer. Some generation paths
                # may pass placeholders like (None, None) on the first step.
                pk, pv = past_key_value
                if pk is not None:
                    k = torch.cat([pk, k], dim=2)
                    v = torch.cat([pv, v], dim=2)
                next_past = (k, v) if use_cache else None
        # Cache stores grouped K/V (n_kv_groups heads). We only expand for SDPA.
        kv_k, kv_v = k, v
        B, G, S_kv, D = kv_k.shape
        k = kv_k.unsqueeze(2).expand(B, G, self.group_size, S_kv, D).reshape(B, self.n_heads, S_kv, D)
        v = kv_v.unsqueeze(2).expand(B, G, self.group_size, S_kv, D).reshape(B, self.n_heads, S_kv, D)
        # Causal logic for SDPA: if mask is None, we assume causality if prefill
        # But for robustness, we always pass a mask if S > 1
        is_causal = (mask is None and S > 1 and past_key_value is None)
        out = F.scaled_dot_product_attention(
            q, k, v,
            # SDPA's boolean attn_mask is True where attention IS allowed,
            # hence the inversion of our "True = blocked" mask.
            attn_mask=None if (mask is None or is_causal) else ~mask,
            dropout_p=0.0,
            is_causal=is_causal,
        )
        out = out.transpose(1, 2).reshape(B, S, self.output_dim)
        if use_cache and past_key_value is None:
            # If we're not given a cache, return legacy K/V by default.
            next_past = (kv_k, kv_v)
        return self.Wo(out), next_past


class DotLMBlock(nn.Module):
    """Pre-norm transformer block: RMSNorm → GQA → residual, RMSNorm → SwiGLU → residual."""

    def __init__(self, d_model, n_heads, n_kv_heads, hidden_dim, norm_eps=1e-6, layer_idx=None):
        super().__init__()
        head_dim = d_model // n_heads
        self.attention = GroupedQueryAttention(d_model, n_heads, head_dim, n_kv_heads)
        # layer_idx is required by HF Cache.update(); attach it to the attention.
        self.attention.layer_idx = layer_idx
        self.feed_forward = SwiGLU(d_model, hidden_dim)
        self.norm1 = RMSNorm(d_model, norm_eps)
        self.norm2 = RMSNorm(d_model, norm_eps)

    def forward(self, x, cos, sin, mask=None, past_key_value=None, use_cache=False):
        residual = x
        x = self.norm1(x)
        attn_out, next_past = self.attention(x, cos, sin, mask, past_key_value, use_cache)
        x = residual + attn_out
        residual = x
        x = self.norm2(x)
        x = residual + self.feed_forward(x)
        return x, next_past


# ── Flat HF Wrapper ───────────────────────────────────────────────────────────
class DotLMForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = DotLMConfig
    # Let HF know output head is tied to embeddings when enabled.
_tied_weights_keys = {"head.weight": "embeddor.weight"} def __init__(self, config): super().__init__(config) self.config = config self.embeddor = nn.Embedding(config.vocab_size, config.d_model) self.blocks = nn.ModuleList([ DotLMBlock( config.d_model, config.n_heads, config.n_kv_heads, config.hidden_dim, config.norm_eps, layer_idx=i ) for i in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.d_model, config.norm_eps) self.head = nn.Linear(config.d_model, config.vocab_size, bias=False) # Precompute RoPE head_dim = config.d_model // config.n_heads cos, sin = precompute_freqs_cis(head_dim, config.context_len, config.theta_base) self.register_buffer("cos_cache", cos, persistent=False) self.register_buffer("sin_cache", sin, persistent=False) # Causal mask placeholder mask = torch.triu(torch.ones(config.context_len, config.context_len, dtype=torch.bool), diagonal=1) self.register_buffer("causal_mask", mask, persistent=False) self.post_init() def _ensure_rope_and_mask(self): """ `from_pretrained(..., low_cpu_mem_usage=True)` may build the module under meta tensors. In that case, our non-persistent buffers can end up as meta/zero tensors even though they are deterministic. Recompute them on demand. """ need_rope = ( self.cos_cache.device.type == "meta" or self.sin_cache.device.type == "meta" or self.cos_cache.numel() == 0 or self.sin_cache.numel() == 0 or (self.cos_cache.numel() > 0 and float(self.cos_cache.flatten()[0]) == 0.0) ) need_mask = ( self.causal_mask.device.type == "meta" or self.causal_mask.numel() == 0 # causal_mask[0, 1] should be True for an upper-triangular mask. 
or (self.causal_mask.numel() > 1 and bool(self.causal_mask[0, 1]) is False) ) if not (need_rope or need_mask): return head_dim = self.config.d_model // self.config.n_heads cos, sin = precompute_freqs_cis(head_dim, self.config.context_len, self.config.theta_base) self._buffers["cos_cache"] = cos self._buffers["sin_cache"] = sin mask = torch.triu( torch.ones(self.config.context_len, self.config.context_len, dtype=torch.bool), diagonal=1 ) self._buffers["causal_mask"] = mask def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, mean=0.0, std=std) def tie_weights(self, **kwargs): if self.config.tie_word_embeddings: self.head.weight = self.embeddor.weight def get_input_embeddings(self): return self.embeddor def set_input_embeddings(self, value): self.embeddor = value self.tie_weights() def get_output_embeddings(self): return self.head def set_output_embeddings(self, new_embeddings): self.head = new_embeddings self.tie_weights() def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs ) -> Union[Tuple, CausalLMOutputWithPast]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict use_cache = use_cache if use_cache is not None else self.config.use_cache B, S = input_ids.shape self._ensure_rope_and_mask() # Support both HF Cache (v5+) and legacy tuple-of-layer-caches. 
if use_cache and past_key_values is None: past_key_values = DynamicCache() # Positional tracking start_pos = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): start_pos = past_key_values.get_seq_length() else: layer0 = past_key_values[0] if layer0 is not None and layer0[0] is not None: start_pos = layer0[0].shape[2] # Embeddings x = self.embeddor(input_ids) # RoPE slicing cos = self.cos_cache[start_pos : start_pos + S].to(device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) sin = self.sin_cache[start_pos : start_pos + S].to(device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) # Masking mask = None if S > 1: mask = self.causal_mask[start_pos : start_pos + S, : start_pos + S].to(device=x.device) next_past_key_values = [] if (use_cache and not isinstance(past_key_values, Cache)) else None # Blocks for i, block in enumerate(self.blocks): layer_past = None if past_key_values is not None: if isinstance(past_key_values, Cache): layer_past = past_key_values else: layer_past = past_key_values[i] x, new_layer_past = block( x, cos, sin, mask=mask, past_key_value=layer_past, use_cache=use_cache ) if next_past_key_values is not None: next_past_key_values.append(new_layer_past) # Final head logits = self.head(self.norm(x)) if not self.training: # Stability clip logits = torch.nan_to_num(logits, nan=0.0, posinf=1e4, neginf=-1e4) loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) if not return_dict: return (logits, past_key_values) if use_cache else (logits,) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=past_key_values if isinstance(past_key_values, Cache) else (tuple(next_past_key_values) if use_cache else None) ) def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): past_len = 0 if past_key_values is not None: if 
isinstance(past_key_values, Cache): past_len = past_key_values.get_seq_length() else: layer0 = past_key_values[0] if len(past_key_values) > 0 else None if layer0 is not None and layer0[0] is not None: past_len = layer0[0].shape[2] # Only slice for incremental decoding once we truly have cached history. if past_len > 0: input_ids = input_ids[:, -1:] return { "input_ids": input_ids, "past_key_values": past_key_values, "attention_mask": kwargs.get("attention_mask", None), "token_type_ids": kwargs.get("token_type_ids", None), "use_cache": True, } def _reorder_cache(self, past_key_values, beam_idx): if past_key_values is None: return past_key_values if isinstance(past_key_values, Cache): past_key_values.reorder_cache(beam_idx) return past_key_values return tuple( (k.index_select(0, beam_idx), v.index_select(0, beam_idx)) for (k, v) in past_key_values ) @torch.no_grad() def generate(self, input_ids=None, max_new_tokens=256, temperature=1.0, top_k=None, do_sample=True, eos_token_id=None, **kwargs): """Custom autoregressive generate that bypasses GenerationMixin internals.""" self._ensure_rope_and_mask() kv_cache = None curr_ids = input_ids for _ in range(max_new_tokens): if curr_ids.size(1) > self.config.context_len: curr_ids = curr_ids[:, -self.config.context_len:] model_input = curr_ids if kv_cache is None else curr_ids[:, -1:] out = self.forward(model_input, past_key_values=kv_cache, use_cache=True, return_dict=True) kv_cache = out.past_key_values logits = out.logits[:, -1, :] if do_sample: logits = logits / max(temperature, 1e-8) if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float("Inf") probs = F.softmax(logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) else: next_token = logits.argmax(dim=-1, keepdim=True) curr_ids = torch.cat([curr_ids, next_token], dim=1) if eos_token_id is not None and (next_token == eos_token_id).all(): break return curr_ids