hermes-edge / model.py
bclermo's picture
Upload folder using huggingface_hub
0b1f228 verified
Raw
History Blame Contribute Delete
9.07 kB
"""Reference PyTorch implementation of the Hermes mobile transformer.
This is the *training* model. It is intentionally written with plain,
conversion-friendly PyTorch ops (no custom CUDA kernels, no flash-attention
calls) so that the same ``state_dict`` can be loaded by the LiteRT builder in
``scripts/convert_to_litertlm.py`` and traced by ``ai_edge_torch``.
Architecture: decoder-only, RMSNorm (pre-norm), rotary position embeddings,
grouped-query attention, and a SwiGLU feed-forward block — the same family as
Gemma / Llama, sized for on-device inference.
"""
from __future__ import annotations
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from hermes.config import HermesConfig
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6) -> None:
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
dtype = x.dtype
x = x.float()
x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
return (x.to(dtype)) * self.weight
def build_rope_cache(
seq_len: int, head_dim: int, theta: float, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Precompute cos/sin tables for rotary position embeddings."""
inv_freq = 1.0 / (
theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim)
)
t = torch.arange(seq_len, device=device).float()
freqs = torch.outer(t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos(), emb.sin()
def rotate_half(x: torch.Tensor) -> torch.Tensor:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rope(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# q, k: [B, H, T, D]; cos/sin: [T, D]
cos = cos.unsqueeze(0).unsqueeze(0)
sin = sin.unsqueeze(0).unsqueeze(0)
q_out = (q * cos) + (rotate_half(q) * sin)
k_out = (k * cos) + (rotate_half(k) * sin)
return q_out, k_out
class Attention(nn.Module):
"""Grouped-query attention with an optional incremental KV-cache."""
def __init__(self, config: HermesConfig) -> None:
super().__init__()
self.num_heads = config.num_heads
self.num_kv_heads = config.num_kv_heads
self.head_dim = config.head_dim
self.num_query_groups = config.num_query_groups
self.q_proj = nn.Linear(
config.hidden_size, self.num_heads * self.head_dim, bias=False
)
self.k_proj = nn.Linear(
config.hidden_size, self.num_kv_heads * self.head_dim, bias=False
)
self.v_proj = nn.Linear(
config.hidden_size, self.num_kv_heads * self.head_dim, bias=False
)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, config.hidden_size, bias=False
)
def forward(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
mask: Optional[torch.Tensor],
kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
b, t, _ = x.shape
q = self.q_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(x).view(b, t, self.num_kv_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(x).view(b, t, self.num_kv_heads, self.head_dim).transpose(1, 2)
q, k = apply_rope(q, k, cos, sin)
if kv_cache is not None:
past_k, past_v = kv_cache
k = torch.cat([past_k, k], dim=2)
v = torch.cat([past_v, v], dim=2)
new_cache = (k, v)
# Expand KV heads to match query heads (GQA).
if self.num_query_groups > 1:
k = k.repeat_interleave(self.num_query_groups, dim=1)
v = v.repeat_interleave(self.num_query_groups, dim=1)
attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
if mask is not None:
attn = attn + mask
attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).contiguous().view(b, t, -1)
return self.o_proj(out), new_cache
class FeedForward(nn.Module):
"""SwiGLU MLP: down(silu(gate(x)) * up(x))."""
def __init__(self, config: HermesConfig) -> None:
super().__init__()
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
class DecoderBlock(nn.Module):
def __init__(self, config: HermesConfig) -> None:
super().__init__()
self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
self.self_attn = Attention(config)
self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
self.mlp = FeedForward(config)
def forward(
self,
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
mask: Optional[torch.Tensor],
kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
h, new_cache = self.self_attn(
self.input_layernorm(x), cos, sin, mask, kv_cache
)
x = x + h
x = x + self.mlp(self.post_attention_layernorm(x))
return x, new_cache
class HermesForCausalLM(nn.Module):
"""Full Hermes decoder-only language model with a causal LM head."""
def __init__(self, config: HermesConfig) -> None:
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList(
[DecoderBlock(config) for _ in range(config.num_layers)]
)
self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
if config.tie_embeddings:
self.lm_head.weight = self.embed_tokens.weight
cos, sin = build_rope_cache(
config.max_seq_len, config.head_dim, config.rope_theta, torch.device("cpu")
)
self.register_buffer("rope_cos", cos, persistent=False)
self.register_buffer("rope_sin", sin, persistent=False)
def forward(
self,
input_ids: torch.Tensor,
labels: Optional[torch.Tensor] = None,
start_pos: int = 0,
) -> dict:
b, t = input_ids.shape
x = self.embed_tokens(input_ids)
cos = self.rope_cos[start_pos : start_pos + t].to(x.device)
sin = self.rope_sin[start_pos : start_pos + t].to(x.device)
mask = torch.full((t, t), float("-inf"), device=x.device)
mask = torch.triu(mask, diagonal=1)
for layer in self.layers:
x, _ = layer(x, cos, sin, mask)
x = self.norm(x)
logits = self.lm_head(x)
loss = None
if labels is not None:
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous()
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=self.config.pad_token_id,
)
return {"logits": logits, "loss": loss}
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 64,
temperature: float = 0.8,
top_k: int = 50,
eos_token_id: Optional[int] = None,
) -> torch.Tensor:
"""Minimal greedy/sampling loop — sanity check for trained weights."""
self.eval()
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
for _ in range(max_new_tokens):
ids = input_ids[:, -self.config.max_seq_len :]
logits = self.forward(ids)["logits"][:, -1, :]
if temperature > 0:
logits = logits / temperature
if top_k:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = float("-inf")
probs = F.softmax(logits, dim=-1)
next_id = torch.multinomial(probs, num_samples=1)
else:
next_id = logits.argmax(dim=-1, keepdim=True)
input_ids = torch.cat([input_ids, next_id], dim=1)
if (next_id == eos_token_id).all():
break
return input_ids
def build_model(config: HermesConfig) -> HermesForCausalLM:
return HermesForCausalLM(config)