# DotLM-165M — modeling_dotlm.py
# Uploaded with huggingface_hub by tensorfiend (commit b4b0382, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Union
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache, DynamicCache
# ── Config ────────────────────────────────────────────────────────────────────
class DotLMConfig(PretrainedConfig):
    """Configuration for the DotLM decoder-only transformer.

    All constructor arguments are stored as attributes of the same name.
    Generation-related token ids and `use_cache` are additionally read
    back out of `**kwargs` below so they always exist with DotLM defaults.
    """

    model_type = "dotlm"

    def __init__(
        self,
        vocab_size=16384,           # tokenizer vocabulary size
        d_model=768,                # residual stream / embedding width
        hidden_dim=2048,            # SwiGLU inner (gated) dimension
        num_hidden_layers=24,       # number of decoder blocks
        n_heads=6,                  # query attention heads
        n_kv_heads=2,               # key/value head groups (GQA)
        context_len=4096,           # max sequence length for RoPE/mask tables
        theta_base=10000.0,         # RoPE frequency base
        norm_eps=1e-6,              # RMSNorm epsilon
        initializer_range=0.02,     # stddev for normal weight init
        tie_word_embeddings=True,   # share LM head weight with embeddings
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.hidden_dim = hidden_dim
        self.num_hidden_layers = num_hidden_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.context_len = context_len
        self.theta_base = theta_base
        self.norm_eps = norm_eps
        self.initializer_range = initializer_range
        self.tie_word_embeddings = tie_word_embeddings
        # These keys may also be consumed by PretrainedConfig.__init__ above;
        # re-read them from kwargs so the DotLM-specific defaults win when
        # the caller did not pass them explicitly.
        self.use_cache = kwargs.get("use_cache", True)
        self.pad_token_id = kwargs.get("pad_token_id", 0)
        self.bos_token_id = kwargs.get("bos_token_id", None)
        self.eos_token_id = kwargs.get("eos_token_id", 3)
# ── Architecture Components ───────────────────────────────────────────────────
def precompute_freqs_cis(dim, context_len, theta_base=10000.0):
    """Build RoPE cos/sin lookup tables of shape (context_len, dim).

    Frequencies follow the standard RoPE schedule 1/theta_base^(2i/dim).
    The half-dim angle table is duplicated along the last axis so it can
    be applied directly to the [first-half | second-half] rotation layout
    used by the RoPE module.
    """
    inv_freq = (theta_base ** (torch.arange(0, dim, 2) / dim)).reciprocal()
    positions = torch.arange(context_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    angles = torch.cat((angles, angles), dim=-1)
    return angles.cos(), angles.sin()
class SwiGLU(nn.Module):
    """Gated feed-forward block: W2(SiLU(W(x)) * V(x)), all bias-free."""

    def __init__(self, d_model, hidden_dim):
        super().__init__()
        # Attribute names are load-bearing: checkpoints address W / V / W2.
        self.W = nn.Linear(d_model, hidden_dim, bias=False)
        self.V = nn.Linear(d_model, hidden_dim, bias=False)
        self.W2 = nn.Linear(hidden_dim, d_model, bias=False)
        self.silu = nn.SiLU()

    def forward(self, x):
        gate = self.silu(self.W(x))
        return self.W2(gate * self.V(x))
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-channel scale."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Normalize by the RMS over the last dimension, then rescale.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x * inv_rms) * self.scale
class RoPE(nn.Module):
    """Rotary position embedding using the half-split rotation layout."""

    def forward(self, x, cos, sin):
        # x: (batch, heads, seq, head_dim); cos/sin broadcast over the
        # leading dims and cover the full head_dim (angles duplicated).
        half = x.shape[-1] // 2
        lo, hi = x[..., :half], x[..., half:]
        rotated = torch.cat((-hi, lo), dim=-1)
        return x * cos + rotated * sin
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention with RoPE, Q/K RMSNorm, and KV caching.

    Queries use `n_heads` heads; keys/values use only `n_kv_groups` heads,
    each shared by `n_heads // n_kv_groups` query heads. Supports both the
    HF `Cache` API and the legacy per-layer `(k, v)` tuple cache format.
    """

    def __init__(self, d_model, n_heads, head_dim, n_kv_groups):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.n_kv_groups = n_kv_groups
        # How many query heads share each K/V group.
        self.group_size = n_heads // n_kv_groups
        self.output_dim = n_heads * head_dim
        self.Wq = nn.Linear(d_model, self.output_dim, bias=False)
        self.Wk = nn.Linear(d_model, n_kv_groups * head_dim, bias=False)
        self.Wv = nn.Linear(d_model, n_kv_groups * head_dim, bias=False)
        self.Wo = nn.Linear(self.output_dim, d_model, bias=False)
        # Per-head normalization of queries/keys, applied before RoPE.
        self.q_norm = RMSNorm(head_dim)
        self.k_norm = RMSNorm(head_dim)
        self.rope = RoPE()
        # NOTE: `self.layer_idx` (used for Cache.update below) is attached
        # externally by DotLMBlock after construction.

    def forward(self, x, cos, sin, mask=None, past_key_value=None, use_cache=False):
        """Return (attention output, next past K/V).

        `mask` uses True = blocked; it is inverted (`~mask`) before being
        handed to SDPA, whose boolean attn_mask uses True = attend.
        """
        B, S, _ = x.shape
        # Project and reshape to (B, heads, S, head_dim).
        q = self.Wq(x).view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.Wk(x).view(B, S, self.n_kv_groups, self.head_dim).transpose(1, 2)
        v = self.Wv(x).view(B, S, self.n_kv_groups, self.head_dim).transpose(1, 2)
        q, k = self.q_norm(q), self.k_norm(k)
        q, k = self.rope(q, cos, sin), self.rope(k, cos, sin)
        next_past = None
        if past_key_value is not None:
            if isinstance(past_key_value, Cache):
                # HF DynamicCache: update in-place and get concatenated K/V back.
                k, v = past_key_value.update(k, v, self.layer_idx)
                next_past = past_key_value
            else:
                # Legacy cache format: (k, v) per layer. Some generation paths
                # may pass placeholders like (None, None) on the first step.
                pk, pv = past_key_value
                if pk is not None:
                    # Concatenate along the sequence axis (dim=2).
                    k = torch.cat([pk, k], dim=2)
                    v = torch.cat([pv, v], dim=2)
                next_past = (k, v) if use_cache else None
        # Cache stores grouped K/V (n_kv_groups heads). We only expand for SDPA.
        kv_k, kv_v = k, v
        B, G, S_kv, D = kv_k.shape
        k = kv_k.unsqueeze(2).expand(B, G, self.group_size, S_kv, D).reshape(B, self.n_heads, S_kv, D)
        v = kv_v.unsqueeze(2).expand(B, G, self.group_size, S_kv, D).reshape(B, self.n_heads, S_kv, D)
        # Causal logic for SDPA: if mask is None, we assume causality if prefill
        # But for robustness, we always pass a mask if S > 1
        is_causal = (mask is None and S > 1 and past_key_value is None)
        out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None if (mask is None or is_causal) else ~mask,
            dropout_p=0.0,
            is_causal=is_causal,
        )
        # Back to (B, S, n_heads * head_dim) for the output projection.
        out = out.transpose(1, 2).reshape(B, S, self.output_dim)
        if use_cache and past_key_value is None:
            # If we're not given a cache, return legacy K/V by default.
            next_past = (kv_k, kv_v)
        return self.Wo(out), next_past
class DotLMBlock(nn.Module):
    """Pre-norm transformer decoder block: GQA attention then SwiGLU MLP."""

    def __init__(self, d_model, n_heads, n_kv_heads, hidden_dim, norm_eps=1e-6, layer_idx=None):
        super().__init__()
        head_dim = d_model // n_heads
        self.attention = GroupedQueryAttention(d_model, n_heads, head_dim, n_kv_heads)
        # The attention module reads this when writing into an HF Cache.
        self.attention.layer_idx = layer_idx
        self.feed_forward = SwiGLU(d_model, hidden_dim)
        self.norm1 = RMSNorm(d_model, norm_eps)
        self.norm2 = RMSNorm(d_model, norm_eps)

    def forward(self, x, cos, sin, mask=None, past_key_value=None, use_cache=False):
        # Attention sub-layer with pre-norm residual connection.
        attn_out, next_past = self.attention(
            self.norm1(x), cos, sin, mask, past_key_value, use_cache
        )
        hidden = x + attn_out
        # Feed-forward sub-layer with pre-norm residual connection.
        return hidden + self.feed_forward(self.norm2(hidden)), next_past
# ── Flat HF Wrapper ───────────────────────────────────────────────────────────
class DotLMForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal-LM wrapper exposing the DotLM stack through the HF interface.

    Supports both the HF `Cache` API (`DynamicCache`) and the legacy
    per-layer `(k, v)` tuple cache format, and ships a custom `generate`
    that bypasses GenerationMixin's sampling loop.
    """

    config_class = DotLMConfig
    # Let HF know output head is tied to embeddings when enabled.
    _tied_weights_keys = {"head.weight": "embeddor.weight"}

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Token embedding table; doubles as the LM head weight when tied.
        self.embeddor = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([
            DotLMBlock(
                config.d_model, config.n_heads, config.n_kv_heads,
                config.hidden_dim, config.norm_eps, layer_idx=i
            )
            for i in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.d_model, config.norm_eps)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Precompute RoPE
        head_dim = config.d_model // config.n_heads
        cos, sin = precompute_freqs_cis(head_dim, config.context_len, config.theta_base)
        self.register_buffer("cos_cache", cos, persistent=False)
        self.register_buffer("sin_cache", sin, persistent=False)
        # Causal mask placeholder (True above the diagonal = blocked).
        mask = torch.triu(torch.ones(config.context_len, config.context_len, dtype=torch.bool), diagonal=1)
        self.register_buffer("causal_mask", mask, persistent=False)
        self.post_init()

    def _ensure_rope_and_mask(self):
        """
        `from_pretrained(..., low_cpu_mem_usage=True)` may build the module under
        meta tensors. In that case, our non-persistent buffers can end up as
        meta/zero tensors even though they are deterministic. Recompute them on
        demand.
        """
        need_rope = (
            self.cos_cache.device.type == "meta"
            or self.sin_cache.device.type == "meta"
            or self.cos_cache.numel() == 0
            or self.sin_cache.numel() == 0
            # cos(0) must be 1.0, so a leading 0.0 indicates the buffer was
            # materialized zero-filled (e.g. off a meta tensor).
            or (self.cos_cache.numel() > 0 and float(self.cos_cache.flatten()[0]) == 0.0)
        )
        need_mask = (
            self.causal_mask.device.type == "meta"
            or self.causal_mask.numel() == 0
            # causal_mask[0, 1] should be True for an upper-triangular mask.
            or (self.causal_mask.numel() > 1 and bool(self.causal_mask[0, 1]) is False)
        )
        if not (need_rope or need_mask):
            return
        head_dim = self.config.d_model // self.config.n_heads
        cos, sin = precompute_freqs_cis(head_dim, self.config.context_len, self.config.theta_base)
        # Written straight into _buffers rather than via register_buffer /
        # attribute assignment; presumably to keep the existing registration
        # (non-persistent) intact. NOTE(review): this skips moving the new
        # tensors to the module's device/dtype — forward() compensates by
        # calling .to(...) on every slice.
        self._buffers["cos_cache"] = cos
        self._buffers["sin_cache"] = sin
        mask = torch.triu(
            torch.ones(self.config.context_len, self.config.context_len, dtype=torch.bool), diagonal=1
        )
        self._buffers["causal_mask"] = mask

    def _init_weights(self, module):
        # Plain normal init (std from config) for linear/embedding weights.
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)

    def tie_weights(self, **kwargs):
        # Share the embedding matrix with the LM head when configured.
        if self.config.tie_word_embeddings:
            self.head.weight = self.embeddor.weight

    def get_input_embeddings(self):
        return self.embeddor

    def set_input_embeddings(self, value):
        self.embeddor = value
        # Re-tie so the head tracks the new embedding tensor.
        self.tie_weights()

    def get_output_embeddings(self):
        return self.head

    def set_output_embeddings(self, new_embeddings):
        self.head = new_embeddings
        self.tie_weights()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder stack and LM head.

        NOTE(review): `attention_mask`, `token_type_ids`,
        `output_attentions` and `output_hidden_states` are accepted for HF
        API compatibility but never used — padding positions are attended
        to like any other token.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        B, S = input_ids.shape
        self._ensure_rope_and_mask()
        # Support both HF Cache (v5+) and legacy tuple-of-layer-caches.
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()
        # Positional tracking
        start_pos = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                start_pos = past_key_values.get_seq_length()
            else:
                # Legacy K is (B, n_kv_groups, seq, head_dim): dim 2 = length.
                layer0 = past_key_values[0]
                if layer0 is not None and layer0[0] is not None:
                    start_pos = layer0[0].shape[2]
        # Embeddings
        x = self.embeddor(input_ids)
        # RoPE slicing: take angles for absolute positions [start_pos, start_pos+S)
        # and add broadcast dims for (batch, heads).
        cos = self.cos_cache[start_pos : start_pos + S].to(device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0)
        sin = self.sin_cache[start_pos : start_pos + S].to(device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0)
        # Masking: only needed when scoring more than one query position.
        # True = blocked; the attention module inverts it for SDPA.
        mask = None
        if S > 1:
            mask = self.causal_mask[start_pos : start_pos + S, : start_pos + S].to(device=x.device)
        next_past_key_values = [] if (use_cache and not isinstance(past_key_values, Cache)) else None
        # Blocks
        for i, block in enumerate(self.blocks):
            layer_past = None
            if past_key_values is not None:
                if isinstance(past_key_values, Cache):
                    # A Cache object is shared by all layers; each attention
                    # module indexes it with its own layer_idx.
                    layer_past = past_key_values
                else:
                    layer_past = past_key_values[i]
            x, new_layer_past = block(
                x, cos, sin, mask=mask, past_key_value=layer_past, use_cache=use_cache
            )
            if next_past_key_values is not None:
                next_past_key_values.append(new_layer_past)
        # Final head
        logits = self.head(self.norm(x))
        if not self.training:
            # Stability clip
            logits = torch.nan_to_num(logits, nan=0.0, posinf=1e4, neginf=-1e4)
        loss = None
        if labels is not None:
            # Standard next-token objective: logits at t predict labels at t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        if not return_dict:
            # NOTE(review): the tuple path omits `loss` and, for the legacy
            # format, returns the *incoming* cache rather than the updated
            # `next_past_key_values` — confirm no caller relies on tuples.
            return (logits, past_key_values) if use_cache else (logits,)
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values if isinstance(past_key_values, Cache) else (tuple(next_past_key_values) if use_cache else None)
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        """Trim `input_ids` to the last token once a KV cache holds history."""
        past_len = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_len = past_key_values.get_seq_length()
            else:
                layer0 = past_key_values[0] if len(past_key_values) > 0 else None
                if layer0 is not None and layer0[0] is not None:
                    past_len = layer0[0].shape[2]
        # Only slice for incremental decoding once we truly have cached history.
        if past_len > 0:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "attention_mask": kwargs.get("attention_mask", None),
            "token_type_ids": kwargs.get("token_type_ids", None),
            "use_cache": True,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder the KV cache along the batch dimension for beam search."""
        if past_key_values is None:
            return past_key_values
        if isinstance(past_key_values, Cache):
            # NOTE(review): assumes the Cache object exposes `reorder_cache`;
            # confirm this against the installed transformers version.
            past_key_values.reorder_cache(beam_idx)
            return past_key_values
        return tuple(
            (k.index_select(0, beam_idx), v.index_select(0, beam_idx))
            for (k, v) in past_key_values
        )

    @torch.no_grad()
    def generate(self, input_ids=None, max_new_tokens=256, temperature=1.0,
                 top_k=None, do_sample=True, eos_token_id=None, **kwargs):
        """Custom autoregressive generate that bypasses GenerationMixin internals."""
        self._ensure_rope_and_mask()
        kv_cache = None
        curr_ids = input_ids
        for _ in range(max_new_tokens):
            # Keep at most context_len tokens of visible history.
            # NOTE(review): when a KV cache exists the cache itself keeps
            # growing past context_len; positions beyond the table would
            # slice empty cos/sin — confirm long-generation behavior.
            if curr_ids.size(1) > self.config.context_len:
                curr_ids = curr_ids[:, -self.config.context_len:]
            # Full prompt on the first step, then one token at a time.
            model_input = curr_ids if kv_cache is None else curr_ids[:, -1:]
            out = self.forward(model_input, past_key_values=kv_cache, use_cache=True, return_dict=True)
            kv_cache = out.past_key_values
            logits = out.logits[:, -1, :]
            if do_sample:
                logits = logits / max(temperature, 1e-8)
                if top_k is not None:
                    # Mask everything strictly below the k-th largest logit.
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("Inf")
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = logits.argmax(dim=-1, keepdim=True)
            curr_ids = torch.cat([curr_ids, next_token], dim=1)
            # Stop only when every sequence in the batch emitted EOS.
            if eos_token_id is not None and (next_token == eos_token_id).all():
                break
        return curr_ids