Momo-336M-sft / modeling_momo.py
dill-dev's picture
Update modeling_momo.py
fecde63 verified
# modeling_momo.py
# 🌸 Momo-336M — HuggingFace compatible model definition
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_momo import MomoConfig
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale.

    The mean-square statistic and the rescaling are computed in float32. The
    normalized values are cast back to the input dtype before being multiplied
    by ``weight``. The output dtype therefore follows normal promotion between
    ``x.dtype`` and ``weight.dtype``.

    Args:
        dim: size of the last dimension; ``weight`` has shape ``(dim,)`` and starts at ones.
        eps: added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        normed = (x32 * inv_rms).to(x.dtype)
        return normed * self.weight
class RotaryEmbedding(nn.Module):
    """Precompute and cache cos/sin tables for rotary position embeddings.

    Buffers (their names appear in the module's state_dict):
        inv_freq: ``1 / theta ** (arange(0, dim, 2) / dim)``.
        cos_c / sin_c: shape ``(1, 1, n, 2 * len(inv_freq))``. Row ``t`` holds
            ``cos``/``sin`` of ``t * inv_freq``, with the frequency vector
            repeated twice along the last axis.

    Args:
        dim: per-head feature size the tables are built for.
        max_seq: number of positions cached up front.
        theta: frequency base.
    """

    def __init__(self, dim, max_seq=512, theta=10000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1.0 / (theta ** exponents))
        self._cache(max_seq)

    def _cache(self, n):
        """(Re)build the cos/sin tables for positions ``0 .. n-1``, replacing the old buffers."""
        positions = torch.arange(n, device=self.inv_freq.device).float()
        angles = torch.outer(positions, self.inv_freq)
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer('cos_c', table.cos()[None, None])
        self.register_buffer('sin_c', table.sin()[None, None])

    def forward(self, x, seq_len):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions, cast to ``x.dtype``.

        The cache is grown in place when ``seq_len`` exceeds the cached length.
        ``x`` is only used for its dtype.
        """
        cached_len = self.cos_c.shape[2]
        if seq_len > cached_len:
            self._cache(seq_len)
        cos = self.cos_c[:, :, :seq_len].to(x.dtype)
        sin = self.sin_c[:, :, :seq_len].to(x.dtype)
        return cos, sin
def rot_half(x):
    """Rotate the two halves of the last dimension: ``[a, b] -> [-b, a]``.

    The split follows ``torch.chunk(x, 2, dim=-1)``, so for an odd width the
    first half is the larger one.
    """
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat((second.neg(), first), dim=-1)
def apply_rope(q, k, cos, sin):
    """Apply the rotary rotation encoded by ``cos``/``sin`` to queries and keys.

    Each tensor ``t`` becomes ``t * cos + rot_half(t) * sin``. ``cos`` and
    ``sin`` must broadcast against both ``q`` and ``k``.

    Returns:
        ``(q_rotated, k_rotated)``.
    """
    def _rotate(t):
        return (t * cos) + (rot_half(t) * sin)

    return _rotate(q), _rotate(k)
class MomoAttention(nn.Module):
    """Causal self-attention with rotary embeddings and shared (grouped) K/V heads.

    ``num_key_value_heads`` key/value projections are repeated to serve
    ``num_attention_heads`` query heads. Attention scores, masking and softmax
    run in float32, and the probabilities are cast back to the value dtype.
    This lets the block run with fp16/bf16 weights: previously the float32
    causal mask promoted the scores to float32, and the following
    ``matmul`` with half-precision values raised a dtype error. For float32
    models the computation is unchanged.
    """

    def __init__(self, cfg: MomoConfig):
        super().__init__()
        self.nh = cfg.num_attention_heads
        self.nkv = cfg.num_key_value_heads
        self.hd = cfg.hidden_size // cfg.num_attention_heads
        self.grp = self.nh // self.nkv
        self.sc = self.hd ** -0.5
        H = cfg.hidden_size
        # Registration order (q, k, v, o, rope) fixes state_dict order and the
        # RNG consumption of weight init, so keep it.
        self.q = nn.Linear(H, self.nh * self.hd, bias=False)
        self.k = nn.Linear(H, self.nkv * self.hd, bias=False)
        self.v = nn.Linear(H, self.nkv * self.hd, bias=False)
        self.o = nn.Linear(self.nh * self.hd, H, bias=False)
        self.rope = RotaryEmbedding(self.hd, cfg.max_position_embeddings, cfg.rope_theta)

    def forward(self, x, mask=None, past=None, use_cache=False):
        """Attend from the ``T`` new positions to all ``S = past_len + T`` positions.

        Args:
            x: ``(B, T, hidden_size)`` input activations.
            mask: optional additive bias broadcastable to ``(B, nh, T, S)``.
                It is added to the scores after the causal mask.
            past: optional ``(k, v)`` pair returned by an earlier call with
                ``use_cache=True``. Each element has shape ``(B, nh, past_len, hd)``,
                i.e. it stores the already-repeated K/V heads.
            use_cache: when True, return the concatenated ``(k, v)`` for reuse.

        Returns:
            ``(out, present)``. ``out`` is ``(B, T, hidden_size)``; ``present`` is
            the ``(k, v)`` pair or ``None`` when ``use_cache`` is False.
        """
        B, T, _ = x.shape
        q = self.q(x).view(B, T, self.nh, self.hd).transpose(1, 2)
        k = self.k(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
        v = self.v(x).view(B, T, self.nkv, self.hd).transpose(1, 2)
        past_len = past[0].shape[2] if past is not None else 0
        # New tokens sit at absolute positions past_len .. past_len + T - 1.
        cos, sin = self.rope(q, past_len + T)
        cos = cos[:, :, past_len:past_len + T]
        sin = sin[:, :, past_len:past_len + T]
        q, k = apply_rope(q, k, cos, sin)
        if self.grp > 1:
            # NOTE: this tiles the KV heads as [kv0..kvN-1, kv0..kvN-1, ...], so
            # query head h reads KV head h % nkv (not h // grp as in HF's
            # repeat_kv). Checkpoints produced with this code presumably rely on
            # that layout, so do not change it.
            k = k[:, None].expand(-1, self.grp, -1, -1, -1).reshape(B, self.nh, T, self.hd)
            v = v[:, None].expand(-1, self.grp, -1, -1, -1).reshape(B, self.nh, T, self.hd)
        if past is not None:
            pk, pv = past
            k = torch.cat([pk, k], 2)
            v = torch.cat([pv, v], 2)
        pres = (k, v) if use_cache else None
        S = k.shape[2]
        # Scores in float32 so the float32 masks below do not cause a dtype
        # mismatch with half-precision v, and so softmax is computed in fp32.
        a = torch.matmul(q, k.transpose(-2, -1)).float() * self.sc
        # Query i (absolute position S - T + i) may see keys j <= S - T + i.
        causal = torch.triu(
            torch.full((T, S), float('-inf'), device=x.device),
            diagonal=S - T + 1
        )
        a = a + causal
        if mask is not None:
            a = a + mask
        a = F.softmax(a, dim=-1).to(v.dtype)
        out = torch.matmul(a, v).transpose(1, 2).reshape(B, T, -1)
        return self.o(out), pres
class MomoFFN(nn.Module):
    """Gated feed-forward layer: ``down(silu(gate(x)) * up(x))``, all without bias.

    ``gate`` and ``up`` map ``hidden_size -> intermediate_size``; ``down`` maps
    back to ``hidden_size``.
    """

    def __init__(self, cfg: MomoConfig):
        super().__init__()
        d_model = cfg.hidden_size
        d_inner = cfg.intermediate_size
        # Keep the gate, up, down order: it fixes state_dict order and init RNG use.
        self.gate = nn.Linear(d_model, d_inner, bias=False)
        self.up = nn.Linear(d_model, d_inner, bias=False)
        self.down = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        activated = F.silu(self.gate(x))
        return self.down(activated * self.up(x))
class MomoBlock(nn.Module):
    """Pre-norm decoder layer: ``x + attn(norm1(x))``, then ``+ ffn(norm2(.))``.

    ``forward`` returns ``(hidden, present)``. ``present`` is whatever the
    attention layer returns for its cache (``None`` unless ``use_cache``).
    """

    def __init__(self, cfg: MomoConfig):
        super().__init__()
        # Keep the attn, ffn, norm1, norm2 order: it fixes state_dict order and init RNG use.
        self.attn = MomoAttention(cfg)
        self.ffn = MomoFFN(cfg)
        width, eps = cfg.hidden_size, cfg.rms_norm_eps
        self.norm1 = RMSNorm(width, eps)
        self.norm2 = RMSNorm(width, eps)

    def forward(self, x, mask=None, past=None, use_cache=False):
        attn_out, present = self.attn(self.norm1(x), mask, past, use_cache)
        h = x + attn_out
        h = h + self.ffn(self.norm2(h))
        return h, present
class MomoForCausalLM(PreTrainedModel):
    """Momo decoder-only language model with an LM head tied to the embeddings.

    Layout: ``nn.Embedding`` -> ``num_hidden_layers`` x ``MomoBlock`` ->
    ``RMSNorm`` -> bias-free linear head that shares its weight with the
    embedding matrix.
    """

    config_class = MomoConfig
    _no_split_modules = ["MomoBlock"]
    _tied_weights_keys = ["lm_head.weight"]
    # HF 4.40+ calls model.all_tied_weights_keys.keys() — must be a dict on the instance
    all_tied_weights_keys = {"lm_head.weight": "embed.weight"}

    def __init__(self, cfg: MomoConfig):
        super().__init__(cfg)
        self.embed = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        self.layers = nn.ModuleList([MomoBlock(cfg) for _ in range(cfg.num_hidden_layers)])
        self.norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        # Tie weights now — HF post-load also calls get_output_embeddings to re-tie
        self.lm_head.weight = self.embed.weight
        self.grad_ckpt = cfg.use_gradient_checkpointing
        self.apply(self._init_weights)

    # HF calls these to re-tie after loading — must be defined
    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, value):
        self.embed = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, value):
        self.lm_head = value

    def _init_weights(self, m):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, std=0.02)

    @staticmethod
    def _to_additive_mask(attention_mask):
        """Turn an HF-style 2-D padding mask into an additive attention bias.

        A ``(B, S)`` mask (1 = attend, 0 = padding) becomes a float32
        ``(B, 1, 1, S)`` tensor. It holds 0 for kept keys and the most negative
        finite float32 for padded keys, and broadcasts over heads and query
        positions. A finite value (instead of -inf) keeps rows whose only
        visible keys are padding from producing NaN after softmax.
        ``None`` and masks of any other rank are returned unchanged; they are
        assumed to already be additive.
        """
        if attention_mask is None or attention_mask.dim() != 2:
            return attention_mask
        keep = attention_mask[:, None, None, :].to(torch.float32)
        return (1.0 - keep) * torch.finfo(torch.float32).min

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        past_key_values=None,
        use_cache=False,
        **kwargs,
    ):
        """Run the decoder and optionally compute the shifted next-token loss.

        Args:
            input_ids: ``(B, T)`` token ids.
            attention_mask: optional ``(B, S)`` padding mask (1 = keep, 0 = pad,
                ``S`` = cached length + ``T``). Alternatively, an already-additive
                mask broadcastable to ``(B, heads, T, S)``, which is passed
                through unchanged.
            labels: optional ``(B, T)`` targets. Position ``t`` is scored against
                ``labels[t + 1]``; entries equal to -100 are ignored.
            past_key_values: optional per-layer ``(k, v)`` pairs from a previous
                call with ``use_cache=True``.
            use_cache: return the per-layer ``(k, v)`` pairs for the next call.
            **kwargs: accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (or None), ``logits``
            ``(B, T, vocab)`` and ``past_key_values`` (a list with one entry per
            layer, or None when ``use_cache`` is False). In the gradient-
            checkpointing training path, the cache and ``past_key_values`` are
            not used and the entries are None.
        """
        x = self.embed(input_ids)
        # Previously the raw 0/1 mask was added to the logits. That barely
        # down-weighted padding, and for B > 1 it broadcast B against T.
        mask = self._to_additive_mask(attention_mask)
        pkvs = past_key_values or [None] * len(self.layers)
        use_ckpt = self.grad_ckpt and self.training
        if use_ckpt:
            from torch.utils.checkpoint import checkpoint as _checkpoint

            def _ckpt_fn(blk):
                # Bind the layer now; the closure is re-run during backward.
                def fn(h):
                    out, _ = blk(h, mask=mask, use_cache=False)
                    return out
                return fn

        cache = []
        for layer, past in zip(self.layers, pkvs):
            if use_ckpt:
                x = _checkpoint(_ckpt_fn(layer), x, use_reentrant=False)
                cache.append(None)
            else:
                x, p = layer(x, mask, past, use_cache)
                cache.append(p)
        x = self.norm(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Shift by one; compute the loss on float32 logits so half-precision
            # models get a float32 loss.
            shift_logits = logits[..., :-1, :].float()
            loss = F.cross_entropy(
                shift_logits.contiguous().view(-1, logits.size(-1)),
                labels[..., 1:].contiguous().view(-1),
                ignore_index=-100,
            )
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=cache if use_cache else None,
        )

    @torch.no_grad()
    def generate(
        self,
        input_ids,
        max_new_tokens=300,
        temperature=0.75,
        top_k=50,
        top_p=0.92,
        rep_penalty=1.1,
        eos_token_id=None,
        pad_token_id=None,
        **kwargs,
    ):
        """Sample up to ``max_new_tokens`` tokens after ``input_ids`` using the KV cache.

        Per step, logits for the last position are processed in this order:
        repetition penalty (positive logits divided by ``rep_penalty``, others
        multiplied, for every token already in that row), then temperature,
        then top-k, then nucleus (top-p) filtering. A token is then sampled
        with ``torch.multinomial``.

        Args:
            input_ids: ``(B, T)`` prompt ids. No attention mask is used, so
                padded prompts attend to their padding.
            eos_token_id: an int, a list of ints, or None. A row that samples
                any of these ids is finished. Afterwards it receives
                ``pad_token_id``, or the first EOS id when ``pad_token_id`` is
                None. Sampling stops once every row is finished.
            pad_token_id: filler for finished rows (see above).
            **kwargs: accepted and ignored.

        Returns:
            ``(B, T + n)`` tensor of prompt plus generated ids, with
            ``n <= max_new_tokens``.

        Side effect: puts the model in eval mode and does not restore the previous mode.
        """
        self.eval()
        gen = input_ids.clone()
        past = None
        eos_ids = None
        finished = None
        fill_id = None
        if eos_token_id is not None:
            eos_ids = torch.as_tensor(eos_token_id, device=gen.device).reshape(-1)
            fill_id = pad_token_id if pad_token_id is not None else int(eos_ids[0])
            finished = torch.zeros(gen.size(0), dtype=torch.bool, device=gen.device)
        for _ in range(max_new_tokens):
            inp = gen if past is None else gen[:, -1:]
            out = self(inp, use_cache=True, past_key_values=past)
            past = out.past_key_values
            logits = out.logits[:, -1, :].float()
            if rep_penalty != 1.0:
                # Vectorized over the whole batch (previously only row 0 was penalized).
                seen = torch.gather(logits, 1, gen)
                seen = torch.where(seen > 0, seen / rep_penalty, seen * rep_penalty)
                logits.scatter_(1, gen, seen)
            logits = logits / max(temperature, 1e-6)
            if top_k > 0:
                kth, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < kth[:, -1:]] = float('-inf')
            if top_p < 1.0:
                sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                cum_probs = torch.cumsum(sorted_probs, dim=-1)
                # Drop a token once the mass strictly before it exceeds top_p,
                # so the most likely token is always kept.
                sorted_logits[cum_probs - sorted_probs > top_p] = float('-inf')
                logits.scatter_(1, sorted_idx, sorted_logits)
            next_tok = torch.multinomial(F.softmax(logits, dim=-1), 1)
            if finished is not None:
                next_tok = next_tok.masked_fill(finished[:, None], fill_id)
                finished |= torch.isin(next_tok[:, 0], eos_ids)
            gen = torch.cat([gen, next_tok], dim=1)
            if finished is not None and bool(finished.all()):
                break
        return gen