"""Reference PyTorch implementation of the Hermes mobile transformer. This is the *training* model. It is intentionally written with plain, conversion-friendly PyTorch ops (no custom CUDA kernels, no flash-attention calls) so that the same ``state_dict`` can be loaded by the LiteRT builder in ``scripts/convert_to_litertlm.py`` and traced by ``ai_edge_torch``. Architecture: decoder-only, RMSNorm (pre-norm), rotary position embeddings, grouped-query attention, and a SwiGLU feed-forward block — the same family as Gemma / Llama, sized for on-device inference. """ from __future__ import annotations import math from typing import Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from hermes.config import HermesConfig class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6) -> None: super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: dtype = x.dtype x = x.float() x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) return (x.to(dtype)) * self.weight def build_rope_cache( seq_len: int, head_dim: int, theta: float, device: torch.device ) -> Tuple[torch.Tensor, torch.Tensor]: """Precompute cos/sin tables for rotary position embeddings.""" inv_freq = 1.0 / ( theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim) ) t = torch.arange(seq_len, device=device).float() freqs = torch.outer(t, inv_freq) emb = torch.cat((freqs, freqs), dim=-1) return emb.cos(), emb.sin() def rotate_half(x: torch.Tensor) -> torch.Tensor: x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) def apply_rope( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: # q, k: [B, H, T, D]; cos/sin: [T, D] cos = cos.unsqueeze(0).unsqueeze(0) sin = sin.unsqueeze(0).unsqueeze(0) q_out = (q * cos) + (rotate_half(q) * sin) k_out = (k * cos) + (rotate_half(k) * sin) return q_out, k_out class Attention(nn.Module): """Grouped-query attention with an optional incremental KV-cache.""" def __init__(self, config: HermesConfig) -> None: super().__init__() self.num_heads = config.num_heads self.num_kv_heads = config.num_kv_heads self.head_dim = config.head_dim self.num_query_groups = config.num_query_groups self.q_proj = nn.Linear( config.hidden_size, self.num_heads * self.head_dim, bias=False ) self.k_proj = nn.Linear( config.hidden_size, self.num_kv_heads * self.head_dim, bias=False ) self.v_proj = nn.Linear( config.hidden_size, self.num_kv_heads * self.head_dim, bias=False ) self.o_proj = nn.Linear( self.num_heads * self.head_dim, config.hidden_size, bias=False ) def forward( self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, mask: Optional[torch.Tensor], kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: b, t, _ = x.shape q = self.q_proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(b, t, self.num_kv_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(b, t, self.num_kv_heads, self.head_dim).transpose(1, 2) q, k = apply_rope(q, k, cos, sin) if kv_cache is not None: past_k, past_v = kv_cache k = torch.cat([past_k, k], dim=2) v = torch.cat([past_v, v], dim=2) new_cache = (k, v) # Expand KV heads to match query heads (GQA). if self.num_query_groups > 1: k = k.repeat_interleave(self.num_query_groups, dim=1) v = v.repeat_interleave(self.num_query_groups, dim=1) attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) if mask is not None: attn = attn + mask attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype) out = torch.matmul(attn, v) out = out.transpose(1, 2).contiguous().view(b, t, -1) return self.o_proj(out), new_cache class FeedForward(nn.Module): """SwiGLU MLP: down(silu(gate(x)) * up(x)).""" def __init__(self, config: HermesConfig) -> None: super().__init__() self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False) self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) class DecoderBlock(nn.Module): def __init__(self, config: HermesConfig) -> None: super().__init__() self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps) self.self_attn = Attention(config) self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps) self.mlp = FeedForward(config) def forward( self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, mask: Optional[torch.Tensor], kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: h, new_cache = self.self_attn( self.input_layernorm(x), cos, sin, mask, kv_cache ) x = x + h x = x + self.mlp(self.post_attention_layernorm(x)) return x, new_cache class HermesForCausalLM(nn.Module): """Full Hermes decoder-only language model with a causal LM head.""" def __init__(self, config: HermesConfig) -> None: super().__init__() self.config = config self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList( [DecoderBlock(config) for _ in range(config.num_layers)] ) self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) if config.tie_embeddings: self.lm_head.weight = self.embed_tokens.weight cos, sin = build_rope_cache( config.max_seq_len, config.head_dim, config.rope_theta, torch.device("cpu") ) self.register_buffer("rope_cos", cos, persistent=False) self.register_buffer("rope_sin", sin, persistent=False) def forward( self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None, start_pos: int = 0, ) -> dict: b, t = input_ids.shape x = self.embed_tokens(input_ids) cos = self.rope_cos[start_pos : start_pos + t].to(x.device) sin = self.rope_sin[start_pos : start_pos + t].to(x.device) mask = torch.full((t, t), float("-inf"), device=x.device) mask = torch.triu(mask, diagonal=1) for layer in self.layers: x, _ = layer(x, cos, sin, mask) x = self.norm(x) logits = self.lm_head(x) loss = None if labels is not None: shift_logits = logits[:, :-1, :].contiguous() shift_labels = labels[:, 1:].contiguous() loss = F.cross_entropy( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=self.config.pad_token_id, ) return {"logits": logits, "loss": loss} @torch.no_grad() def generate( self, input_ids: torch.Tensor, max_new_tokens: int = 64, temperature: float = 0.8, top_k: int = 50, eos_token_id: Optional[int] = None, ) -> torch.Tensor: """Minimal greedy/sampling loop — sanity check for trained weights.""" self.eval() eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id for _ in range(max_new_tokens): ids = input_ids[:, -self.config.max_seq_len :] logits = self.forward(ids)["logits"][:, -1, :] if temperature > 0: logits = logits / temperature if top_k: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = float("-inf") probs = F.softmax(logits, dim=-1) next_id = torch.multinomial(probs, num_samples=1) else: next_id = logits.argmax(dim=-1, keepdim=True) input_ids = torch.cat([input_ids, next_id], dim=1) if (next_id == eos_token_id).all(): break return input_ids def build_model(config: HermesConfig) -> HermesForCausalLM: return HermesForCausalLM(config)