Dwarf-15M / modeling_dwarf.py
ThingsAI's picture
Upload folder using huggingface_hub
a6a57d5 verified
Raw
History Blame Contribute Delete
5.66 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, GenerationMixin
from .configuration_dwarf import DwarfConfig
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    The mean-square statistic is computed in float32; the normalised values are
    cast back to the input dtype and then multiplied by ``scale``.

    Args:
        dim: size of the last dimension (length of the ``scale`` parameter).
        eps: constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x32 = x.float()
        mean_sq = x32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        normed = (x32 * inv_rms).to(x.dtype)
        return normed * self.scale
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) applied to query and key tensors.

    The cos/sin tables are built lazily on first use. They are rebuilt when a
    longer sequence or a different device is seen. They are plain attributes,
    not registered buffers, so they are not part of the state dict and are not
    moved by ``module.to(...)``; the device check in ``forward`` covers moves.

    Args:
        head_dim: per-head feature size; must be even (features are rotated in
            pairs ``(i, i + head_dim // 2)``).
        max_seq_len: minimum number of positions to precompute.
        theta: RoPE frequency base.

    Raises:
        ValueError: if ``head_dim`` is odd.
    """

    def __init__(self, head_dim, max_seq_len, theta=10000.0):
        super().__init__()
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if head_dim % 2 != 0:
            raise ValueError(f"head_dim must be even for rotary embeddings, got {head_dim}")
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.theta = theta
        self.cos_cache = None
        self.sin_cache = None
        self._max = 0

    def _build_cache(self, seq_len, device):
        """(Re)build float32 cos/sin tables of shape ``(1, 1, seq_len, head_dim)``."""
        inv_freq = 1.0 / (
            self.theta ** (torch.arange(0, self.head_dim, 2, device=device).float() / self.head_dim)
        )
        t = torch.arange(seq_len, device=device).float()
        freqs = torch.outer(t, inv_freq)
        # Both halves share the same angles so that features i and i + head_dim//2
        # form one rotation pair (matching ``_rotate_half``).
        emb = torch.cat([freqs, freqs], dim=-1)
        self.cos_cache = emb.cos()[None, None]
        self.sin_cache = emb.sin()[None, None]
        self._max = seq_len

    @staticmethod
    def _rotate_half(x):
        """Map ``[x1, x2]`` (split on the last dim) to ``[-x2, x1]``."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    def _apply_rotary(self, x, cos, sin):
        """Rotate ``x`` in the promoted dtype, then cast back to ``x.dtype``.

        The cache is float32. Without the cast back, half-precision inputs would
        come out as float32 through type promotion and no longer match the
        value tensor's dtype in attention.
        """
        compute_dtype = torch.promote_types(x.dtype, cos.dtype)
        xc = x.to(compute_dtype)
        return (xc * cos + self._rotate_half(xc) * sin).to(x.dtype)

    def forward(self, q, k):
        """Apply RoPE to ``q`` and ``k``.

        Positions are taken from dim 2 (``q.size(2)``), starting at 0.
        Presumably ``(batch, heads, seq_len, head_dim)`` — the caller in this
        file transposes to that layout. Outputs keep their input dtypes.
        """
        seq_len = q.size(2)
        if self.cos_cache is None or seq_len > self._max or self.cos_cache.device != q.device:
            self._build_cache(max(seq_len, self.max_seq_len), q.device)
        cos = self.cos_cache[:, :, :seq_len, :]
        sin = self.sin_cache[:, :, :seq_len, :]
        return self._apply_rotary(q, cos, sin), self._apply_rotary(k, cos, sin)
class GQAAttention(nn.Module):
    """Causal grouped-query self-attention with rotary position embeddings.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads. Each K/V head
    is repeated ``n_heads // n_kv_heads`` times before attention. Attention is
    always causal over the whole input; no padding mask and no KV cache are used.

    Raises:
        ValueError: if ``n_kv_heads`` is not positive or does not divide ``n_heads``.
    """

    def __init__(self, config):
        super().__init__()
        # Fail early with a clear message. Otherwise a mismatch only shows up
        # later as a ZeroDivisionError or a head-count mismatch inside attention.
        if config.n_kv_heads <= 0 or config.n_heads % config.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({config.n_heads}) must be divisible by n_kv_heads "
                f"({config.n_kv_heads}), which must be positive"
            )
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.n_groups = config.n_heads // config.n_kv_heads
        self.head_dim = config.head_dim
        self.q_proj = nn.Linear(config.d_model, config.n_heads * config.head_dim, bias=True)
        self.k_proj = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=True)
        self.v_proj = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=True)
        self.o_proj = nn.Linear(config.n_heads * config.head_dim, config.d_model, bias=False)
        self.rope = RotaryEmbedding(config.head_dim, config.max_seq_len, config.rope_theta)

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor unpacked as ``(batch, seq_len, d_model)``.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)`` (output of ``o_proj``).
        """
        B, T, _ = x.shape
        # (B, T, heads * head_dim) -> (B, heads, T, head_dim)
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        q, k = self.rope(q, k)
        # The rotary step may upcast q/k (e.g. bf16 activations against a float32
        # cos/sin cache). SDPA requires q, k and v to share one dtype.
        q = q.to(v.dtype)
        k = k.to(v.dtype)
        if self.n_groups > 1:
            # Expand each K/V head to serve its group of query heads.
            k = k.repeat_interleave(self.n_groups, dim=1)
            v = v.repeat_interleave(self.n_groups, dim=1)
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim)
        return self.o_proj(out)
class SwiGLUFFN(nn.Module):
    """Bias-free SwiGLU feed-forward network: ``down(silu(gate(x)) * up(x))``.

    Maps the last dimension ``d_model -> d_ff -> d_model``.
    """

    def __init__(self, config):
        super().__init__()
        d_model, d_ff = config.d_model, config.d_ff
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x))
        hidden = gated * self.up_proj(x)
        return self.down_proj(hidden)
class DwarfBlock(nn.Module):
    """Pre-norm transformer layer.

    Computes ``h = x + attn(norm_attn(x))`` and then ``h + ffn(norm_ffn(h))``.
    """

    def __init__(self, config):
        super().__init__()
        self.norm_attn = RMSNorm(config.d_model, config.norm_eps)
        self.attn = GQAAttention(config)
        self.norm_ffn = RMSNorm(config.d_model, config.norm_eps)
        self.ffn = SwiGLUFFN(config)

    def forward(self, x):
        attn_out = self.attn(self.norm_attn(x))
        hidden = x + attn_out
        ffn_out = self.ffn(self.norm_ffn(hidden))
        return hidden + ffn_out
class DwarfForCausalLM(PreTrainedModel, GenerationMixin):
    """Decoder-only causal language model built from ``DwarfBlock`` layers.

    Pipeline: token embedding -> ``config.n_layers`` blocks -> final RMSNorm ->
    bias-free LM head. ``tie_weights`` always shares the LM head weight with the
    input embedding matrix.

    NOTE(review): ``attention_mask`` is accepted but ignored, and there is no KV
    cache. Batched generation with padding therefore attends to pad tokens, and
    each generation step re-runs the full sequence.
    """

    config_class = DwarfConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([DwarfBlock(config) for _ in range(config.n_layers)])
        self.norm = RMSNorm(config.d_model, config.norm_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    def tie_weights(self, **kwargs):
        """Share the LM head weight with the input embedding weight."""
        self.lm_head.weight = self.embed_tokens.weight

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Run the model and optionally compute the next-token loss.

        Args:
            input_ids: token ids, presumably ``(batch, seq_len)``.
            attention_mask: accepted for API compatibility; currently ignored.
            labels: unshifted target ids aligned with ``input_ids``. They are
                shifted internally so position t predicts token t+1; entries
                equal to -100 are ignored.

        Returns:
            ``CausalLMOutput`` with ``logits`` (in the model's dtype) and
            ``loss`` (None when ``labels`` is None).
        """
        from transformers.modeling_outputs import CausalLMOutput

        hidden = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden = layer(hidden)
        logits = self.lm_head(self.norm(hidden))

        loss = None
        if labels is not None:
            # Compute the loss in float32: a softmax over the full vocabulary in
            # bf16/fp16 loses precision. Returned logits keep their dtype.
            shift_logits = logits[:, :-1].float()
            shift_labels = labels[:, 1:]
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutput(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Feed the full ``input_ids`` each step (no KV cache; other kwargs dropped)."""
        return {"input_ids": input_ids}