# NanoHammer-1.5B-Instruct / NanoHammerForCausalLM.py
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
from transformers import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
LlamaRMSNorm,
LlamaMLP,
LlamaAttention,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
)
# ============================================================================
# Configuration
# ============================================================================
class NanoHammerConfig(PretrainedConfig):
model_type = "nanohammer"
def __init__(
self,
vocab_size=128256,
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=24,
num_attention_heads=32,
num_key_value_heads=8, # GQA
num_state_heads=32,
        state_hidden_size=None,  # defaults to hidden_size // 4
max_position_embeddings=131072,
rms_norm_eps=1e-5,
initializer_range=0.02,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
hidden_act="silu",
        # NanoHammer-specific parameters
bos_token_id=128000,
eos_token_id=128009,
pad_token_id=None,
**kwargs
):
        # Set auto_map so the custom classes resolve under trust_remote_code
if "auto_map" not in kwargs:
kwargs["auto_map"] = {
"AutoConfig": "NanoHammerForCausalLM.NanoHammerConfig",
"AutoModelForCausalLM": "NanoHammerForCausalLM.NanoHammerForCausalLM",
}
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.num_state_heads = num_state_heads
        # Integral-state dimension: defaults to hidden_size // 4
        self.state_hidden_size = state_hidden_size if state_hidden_size is not None else hidden_size // 4
self.max_position_embeddings = max_position_embeddings
self.rms_norm_eps = rms_norm_eps
self.initializer_range = initializer_range
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.hidden_act = hidden_act
super().__init__(
tie_word_embeddings=tie_word_embeddings,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
**kwargs
)
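# Usage sketch (hypothetical values, not the shipped checkpoint's config). With the
# default state_hidden_size of hidden_size // 4, a 2048-wide model carries a
# 512-dimensional holographic state:
#
#   cfg = NanoHammerConfig(hidden_size=2048, num_state_heads=32)
#   assert cfg.state_hidden_size == 512  # hidden_size // 4
#   assert cfg.state_hidden_size % cfg.num_state_heads == 0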
# ============================================================================
# Novel component 1: Holographic Rotary Position Embedding
# ============================================================================
class HolographicRotaryEmbedding(nn.Module):
"""
全息旋转位置编码 - 为积分状态注入时间特征
核心思想:
- 对每个位置 i,应用复数域旋转:x_i * e^(i*θ_k)
- 积分后:S_t = Σ(x_i * e^(i*θ_k)),状态成为"多项式系数容器"
- 通过逆旋转 R_{-t} 转换为相对坐标系,实现平移不变性
关键修正:使用绝对 position_ids 而非相对 seq_len
"""
def __init__(self, dim, max_position_embeddings=131072, base=10000):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
        # Frequencies: θ_k = base^(-2k/d)
inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, x, position_ids):
"""
        Apply the rotary position encoding (using absolute positions).
        Args:
            x: (B, T, D) - input tensor
            position_ids: (B, T) - absolute position indices
        Returns:
            x_rotated: (B, T, D) - tensor after the rotary encoding
"""
# position_ids: (B, T) -> (B, T, 1)
# inv_freq: (D/2,) -> (1, 1, D/2)
# freqs: (B, T, D/2)
freqs = torch.einsum("bt,d->btd", position_ids.to(x.dtype), self.inv_freq.to(x.dtype))
        # cos and sin: (B, T, D/2)
cos = freqs.cos()
sin = freqs.sin()
        # Expand to the full dimension: (B, T, D)
cos = torch.cat([cos, cos], dim=-1)
sin = torch.cat([sin, sin], dim=-1)
        # Apply the rotation
x1 = x[..., 0::2]
x2 = x[..., 1::2]
x1_rotated = x1 * cos[..., 0::2] - x2 * sin[..., 0::2]
x2_rotated = x1 * sin[..., 1::2] + x2 * cos[..., 1::2]
x_rotated = torch.stack([x1_rotated, x2_rotated], dim=-1).flatten(-2)
return x_rotated
def apply_inverse_rotation(self, x, position_ids):
"""
        Apply the inverse rotation, converting to relative coordinates (using absolute positions).
        Core: S_t' = S_t * e^(-t*θ), which turns the integral state into a relative view.
        Args:
            x: (B, T, D) - integral-state tensor
            position_ids: (B, T) - absolute position indices
        Returns:
            x_relative: (B, T, D) - the state in relative coordinates
"""
        # Compute the inverse rotation from absolute positions
freqs = torch.einsum("bt,d->btd", position_ids.to(x.dtype), self.inv_freq.to(x.dtype))
        # cos(-θ) and sin(-θ)
cos = freqs.cos()
        sin = -freqs.sin()  # negative sign: the key to the inverse rotation
cos = torch.cat([cos, cos], dim=-1)
sin = torch.cat([sin, sin], dim=-1)
        # Apply the inverse rotation
x1 = x[..., 0::2]
x2 = x[..., 1::2]
x1_relative = x1 * cos[..., 0::2] + x2 * sin[..., 0::2]
x2_relative = -x1 * sin[..., 1::2] + x2 * cos[..., 1::2]
x_relative = torch.stack([x1_relative, x2_relative], dim=-1).flatten(-2)
return x_relative
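# Shape-check sketch for HolographicRotaryEmbedding (illustrative values, not from
# the checkpoint). Rotating per-token states, cumulatively summing them, and then
# applying the inverse rotation at each position yields the relative-coordinate
# state that StateTokenProjection consumes:
#
#   rope = HolographicRotaryEmbedding(dim=64)
#   x = torch.randn(2, 5, 64)                         # (B, T, D)
#   pos = torch.arange(5).unsqueeze(0).expand(2, -1)  # (B, T) absolute positions
#   s = torch.cumsum(rope(x, pos), dim=1)             # S_t = sum_{i<=t} R_i x_i
#   s_rel = rope.apply_inverse_rotation(s, pos)       # R_{-t} S_t, shape (2, 5, 64)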
# ============================================================================
# Novel component 2: Multi-Head State Update Cell
# ============================================================================
class StateUpdateCell(nn.Module):
"""
    Multi-Head State Update Cell: Euler-method fixed-point iteration.
    Performs nonlinear evolution on the holographic integral state:
    - S_{t+1} = S_t + α·f(S_t)
    - Each head iterates in its own subspace
    - Learnable step size α
"""
def __init__(self, config):
super().__init__()
self.hidden_size = config.state_hidden_size
self.num_heads = config.num_state_heads
self.head_dim = config.state_hidden_size // config.num_state_heads
assert config.state_hidden_size % config.num_state_heads == 0
# Pre-Norm
self.pre_norm = LlamaRMSNorm(config.state_hidden_size, eps=config.rms_norm_eps)
        # MLP update function f(·)
self.mlp = nn.Sequential(
nn.Linear(config.state_hidden_size, config.state_hidden_size * 4, bias=False),
nn.SiLU(),
nn.Linear(config.state_hidden_size * 4, config.state_hidden_size, bias=False)
)
# Post-Norm
self.post_norm = LlamaRMSNorm(config.state_hidden_size, eps=config.rms_norm_eps)
        # Euler step size (learnable, independent per head)
self.step_size = nn.Parameter(torch.ones(self.num_heads) * 0.1)
def forward(self, state):
"""
        Euler update: S_{t+1} = S_t + α * f(S_t)
Args:
state: (B, T, state_hidden_size)
Returns:
state: (B, T, state_hidden_size)
"""
batch_size, seq_len, _ = state.shape
# Pre-Norm
state_normed = self.pre_norm(state)
        # MLP computes the increment
delta = self.mlp(state_normed) # (B, T, state_hidden_size)
        # Reshape into heads
state_heads = state.view(batch_size, seq_len, self.num_heads, self.head_dim)
delta_heads = delta.view(batch_size, seq_len, self.num_heads, self.head_dim)
        # Per-head update with independent step sizes
step_size = self.step_size.view(1, 1, self.num_heads, 1)
state_heads = state_heads + step_size * delta_heads
# Merge heads
state = state_heads.view(batch_size, seq_len, self.hidden_size)
# Post-Norm
state = self.post_norm(state)
return state
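# Minimal usage sketch for StateUpdateCell (hypothetical config values). One call
# performs a single Euler step S + α·f(S) with an independent learnable step size
# per state head, followed by an RMSNorm:
#
#   cfg = NanoHammerConfig(state_hidden_size=512, num_state_heads=32)
#   cell = StateUpdateCell(cfg)
#   s = torch.randn(2, 5, 512)   # (B, T, state_hidden_size)
#   s_next = cell(s)             # same shape; one Euler iteration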
# ============================================================================
# Novel component 3: State Token Projection
# ============================================================================
class StateTokenProjection(nn.Module):
"""
    State -> token projection: maps the holographic integral state to hidden_size tokens.
    These state tokens are prepended to the sequence and take part in standard Llama attention.
"""
def __init__(self, config):
super().__init__()
self.state_to_hidden = nn.Linear(config.state_hidden_size, config.hidden_size, bias=False)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(self, state: torch.Tensor) -> torch.Tensor:
"""
Args:
            state: (B, T, state_hidden_size) - holographic integral state
        Returns:
            state_token: (B, T, hidden_size) - projected state token
"""
state_token = self.state_to_hidden(state)
state_token = self.norm(state_token)
return state_token
# ============================================================================
# Novel component 4: Hybrid Decoder Layer (State Token + Llama Attention)
# ============================================================================
class HybridNanoHammerDecoderLayer(nn.Module):
"""
    NanoHammer decoder layer: causal state-token prefix + standard Llama attention + MLP.
    Flow:
    1. State update: nonlinear evolution of the holographic integral state
    2. Causal state-token generation:
       - For position i, emit state_token_i (the cumulative state up to i-1)
       - All state tokens form a prefix: [state_0, state_1, ..., state_{T-1}, hidden_0, ..., hidden_{T-1}]
       - The sequence length becomes 2T
    3. Special attention mask:
       - hidden_i may attend to state_j (j <= i) and hidden_j (j < i)
       - Causality: position i only sees the history before it
    4. Self-attention: standard Llama attention (sequence length = 2T)
    5. MLP: feed-forward network
    6. Strip the state tokens: return only the hidden-states half
    Key design points:
    - State tokens act as a "global prefix", one per position's history
    - A carefully constructed attention mask preserves causality
    - Keeps the "state token first" architectural idea intact
    Incremental generation mode:
    - Each step processes only one new token plus one cumulative state token
    - The state already contains all history (the cumulative nature of holographic integration)
    - Per-step cost: O(1) projections plus attention over the cached sequence, i.e. O(T) per generated token
"""
def __init__(self, config, layer_idx: int, holographic_rope: HolographicRotaryEmbedding):
super().__init__()
self.hidden_size = config.hidden_size
self.layer_idx = layer_idx
self.holographic_rope = holographic_rope
        # 0. State update cell
self.state_cell = StateUpdateCell(config)
        # 1. State-token projection
self.state_projection = StateTokenProjection(config)
        # 2. Self-attention (standard Llama)
self.self_attn = LlamaAttention(config, layer_idx)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Llama RoPE for position embeddings
self.rotary_emb = LlamaRotaryEmbedding(config=config)
# 3. MLP
self.mlp = LlamaMLP(config)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
state: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple]] = None,
use_cache: bool = False,
**kwargs
) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple]]]:
"""
Args:
hidden_states: (B, T, hidden_size)
state: (B, T_state, state_hidden_size)
position_ids: (B, T)
attention_mask: (B, 1, T, T)
            past_state: tuple (state_last, state_tokens, position_ids, kv_cache) or None
use_cache: bool
Returns:
hidden_states, state, present_state
"""
batch_size, seq_len, _ = hidden_states.shape
        # Detect incremental mode: past_state is present and there is exactly one new token
is_incremental = past_state is not None and seq_len == 1
if is_incremental:
state_last, state_tokens_cache, position_ids_cache, kv_cache = past_state
return self._forward_incremental(
hidden_states, state, position_ids,
state_last, state_tokens_cache, position_ids_cache, kv_cache
)
else:
return self._forward_prefill(
hidden_states, state, position_ids, attention_mask, use_cache
)
def _forward_incremental(
self,
hidden_states: torch.Tensor, # (B, 1, hidden_size)
        state: torch.Tensor,  # (B, 1, state_hidden_size) - this layer's state input
position_ids: torch.LongTensor, # (B, 1)
        past_state: torch.Tensor,  # (B, 1, state_hidden_size) - this layer's state_cell output from the previous step
        past_state_tokens: torch.Tensor,  # (B, T_past, hidden_size) - cached state tokens
        past_position_ids: torch.LongTensor,  # (B, T_past) - cached positions
past_kv_cache: Tuple[torch.Tensor, torch.Tensor], # (keys, values) from previous steps
) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple]]:
"""
        Incremental generation using the KV cache.
        Cache structure: (state_last, state_tokens, position_ids, kv_cache)
        Complexity: O(T) per step (attention only; the QKV projections are O(1))
"""
batch_size = hidden_states.shape[0]
device = hidden_states.device
dtype = hidden_states.dtype
T_past = past_state_tokens.shape[1]
        # 1. Nonlinear state update
state_updated = self.state_cell(state) # (B, 1, state_hidden_size)
        # 2. Build the state token for the current position (from past_state, preserving causality)
state_relative = self.holographic_rope.apply_inverse_rotation(past_state, position_ids)
new_state_token = self.state_projection(state_relative) # (B, 1, hidden_size)
        # 3. Extend the state-token cache (used on the next step)
all_state_tokens = torch.cat([past_state_tokens, new_state_token], dim=1)
all_position_ids = torch.cat([past_position_ids, position_ids], dim=1)
        # 4. Build this step's input: just 2 tokens [new_state_token, new_hidden_token]
query_input = torch.cat([new_state_token, hidden_states], dim=1) # (B, 2, hidden_size)
        # 5. Position ids for the query: the state and hidden tokens share one position
query_position_ids = torch.cat([position_ids, position_ids], dim=1) # (B, 2)
        # 6. Build the attention mask
        # KV cache layout: [state_0, ..., state_{T-1}, hidden_0, ..., hidden_{T-1}]
        # Query layout: [state_T, hidden_T]
        # Query 0 (state_T): sees state_0..state_T
        # Query 1 (hidden_T): sees state_0..state_T + hidden_0..hidden_{T-1}
        #   (no self-attention, matching the prefill mask)
kv_len = 2 * T_past + 2 # past states + past hiddens + new state + new hidden
attention_mask = torch.full(
(batch_size, 1, 2, kv_len),
torch.finfo(dtype).min,
dtype=dtype,
device=device
)
        # state_T sees: past states (0..T_past-1) + new state (index 2*T_past, itself)
        attention_mask[:, :, 0, :T_past] = 0  # past states
        attention_mask[:, :, 0, 2*T_past] = 0  # new state (self)
        # hidden_T sees: all states + past hiddens (no self, matching the prefill mask)
        attention_mask[:, :, 1, :T_past] = 0  # past states
        attention_mask[:, :, 1, T_past:2*T_past] = 0  # past hiddens
        attention_mask[:, :, 1, 2*T_past] = 0  # new state
# 7. LayerNorm
residual = query_input
query_normed = self.input_layernorm(query_input)
        # 8. QKV projections
head_dim = self.self_attn.head_dim
num_heads = self.self_attn.config.num_attention_heads
num_kv_heads = self.self_attn.config.num_key_value_heads
num_kv_groups = num_heads // num_kv_heads
query_states = self.self_attn.q_proj(query_normed)
key_states = self.self_attn.k_proj(query_normed)
value_states = self.self_attn.v_proj(query_normed)
# Reshape: (B, 2, H*D) -> (B, H, 2, D)
query_states = query_states.view(batch_size, 2, num_heads, head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, 2, num_kv_heads, head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, 2, num_kv_heads, head_dim).transpose(1, 2)
# 9. RoPE
position_embeddings = self.rotary_emb(query_normed, query_position_ids)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# 10. Concat KV cache
past_keys, past_values = past_kv_cache
key_states = torch.cat([past_keys, key_states], dim=2)
value_states = torch.cat([past_values, value_states], dim=2)
# 11. Attention
# Expand KV for GQA
key_states_expanded = key_states.repeat_interleave(num_kv_groups, dim=1)
value_states_expanded = value_states.repeat_interleave(num_kv_groups, dim=1)
attn_weights = torch.matmul(query_states, key_states_expanded.transpose(2, 3)) * (head_dim ** -0.5)
attn_weights = attn_weights + attention_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states_expanded)
# 12. Output projection
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, 2, -1)
attn_output = self.self_attn.o_proj(attn_output)
combined_output = residual + attn_output
# 13. MLP
residual = combined_output
combined_output = self.post_attention_layernorm(combined_output)
combined_output = self.mlp(combined_output)
combined_output = residual + combined_output
        # 14. Extract the hidden state (second position)
hidden_output = combined_output[:, 1:, :] # (B, 1, hidden_size)
        # 15. Return the caches
new_kv_cache = (key_states, value_states)
present_state = (state_updated, all_state_tokens, all_position_ids, new_kv_cache)
return hidden_output, state_updated, present_state
def _forward_prefill(
self,
hidden_states: torch.Tensor,
state: torch.Tensor,
position_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor],
use_cache: bool,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple]]]:
"""
        Prefill mode: process the full sequence.
        Returned present_state: (state_last, state_tokens, position_ids, kv_cache)
"""
batch_size, seq_len, _ = hidden_states.shape
        # 0. Nonlinear state update
state = self.state_cell(state)
        # 1. Build the causal state per position (shift right so position i only sees history up to i-1)
if position_ids is None:
raise ValueError("position_ids is required for inverse rotation")
state_shifted = torch.cat([
torch.zeros_like(state[:, :1, :]),
state[:, :-1, :]
], dim=1)
        # 2. Inverse rotation
state_relative = self.holographic_rope.apply_inverse_rotation(state_shifted, position_ids)
        # 3. Project to state tokens
state_tokens = self.state_projection(state_relative)
if state_tokens.shape[1] > seq_len:
state_tokens = state_tokens[:, -seq_len:, :]
        # 4. Concatenate [state tokens | hidden states]
combined_input = torch.cat([state_tokens, hidden_states], dim=1)
# 5. position_ids
if position_ids.shape[1] > seq_len:
current_position_ids = position_ids[:, -seq_len:]
else:
current_position_ids = position_ids
extended_position_ids = torch.cat([current_position_ids, current_position_ids], dim=1)
        # 6. Build the causal attention mask (original design)
extended_mask = torch.full(
(batch_size, 1, 2 * seq_len, 2 * seq_len),
torch.finfo(combined_input.dtype).min,
dtype=combined_input.dtype,
device=combined_input.device
)
for i in range(seq_len):
extended_mask[:, :, i, :i+1] = 0
for i in range(seq_len):
extended_mask[:, :, seq_len + i, :i+1] = 0
if i > 0:
extended_mask[:, :, seq_len + i, seq_len:seq_len+i] = 0
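        # An equivalent vectorized construction of the same mask (a sketch; the loops
        # above are kept verbatim). Starting from a zero tensor, tril covers
        # s_i -> s_{j<=i} and h_i -> s_{j<=i}; the strict lower triangle covers
        # h_i -> h_{j<i}:
        #
        #   tri = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=combined_input.device))
        #   allowed = torch.zeros(2 * seq_len, 2 * seq_len, dtype=torch.bool, device=combined_input.device)
        #   allowed[:seq_len, :seq_len] = tri                           # s_i -> s_{j<=i}
        #   allowed[seq_len:, :seq_len] = tri                           # h_i -> s_{j<=i}
        #   allowed[seq_len:, seq_len:] = torch.tril(tri, diagonal=-1)  # h_i -> h_{j<i}
        #   extended_mask.masked_fill_(~allowed, torch.finfo(combined_input.dtype).min)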
        # 7. Self-attention (computed manually so the KV cache can be kept)
residual = combined_input
hidden_states_normed = self.input_layernorm(combined_input)
position_embeddings = self.rotary_emb(hidden_states_normed, extended_position_ids)
        # Manual QKV projections
head_dim = self.self_attn.head_dim
num_heads = self.self_attn.config.num_attention_heads
num_kv_heads = self.self_attn.config.num_key_value_heads
num_kv_groups = num_heads // num_kv_heads
query_states = self.self_attn.q_proj(hidden_states_normed)
key_states = self.self_attn.k_proj(hidden_states_normed)
value_states = self.self_attn.v_proj(hidden_states_normed)
# Reshape
query_states = query_states.view(batch_size, 2 * seq_len, num_heads, head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, 2 * seq_len, num_kv_heads, head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, 2 * seq_len, num_kv_heads, head_dim).transpose(1, 2)
# RoPE
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
# Attention
key_states_expanded = key_states.repeat_interleave(num_kv_groups, dim=1)
value_states_expanded = value_states.repeat_interleave(num_kv_groups, dim=1)
attn_weights = torch.matmul(query_states, key_states_expanded.transpose(2, 3)) * (head_dim ** -0.5)
attn_weights = attn_weights + extended_mask
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states_expanded)
# Output projection
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, 2 * seq_len, -1)
attn_output = self.self_attn.o_proj(attn_output)
combined_output = residual + attn_output
# 8. MLP
residual = combined_output
combined_output = self.post_attention_layernorm(combined_output)
combined_output = self.mlp(combined_output)
combined_output = residual + combined_output
        # 9. Extract the hidden-states half (output)
hidden_states_output = combined_output[:, seq_len:, :]
        # 10. Assemble the cache: (state_last, state_tokens, position_ids, kv_cache)
if use_cache:
kv_cache = (key_states, value_states)
present_state = (state[:, -1:, :], state_tokens, current_position_ids, kv_cache)
else:
present_state = None
return hidden_states_output, state, present_state
# ============================================================================
# Main model
# ============================================================================
class NanoHammerPreTrainedModel(PreTrainedModel):
config_class = NanoHammerConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["HybridNanoHammerDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
    _is_stateful = True  # tells HuggingFace this model manages a custom cache
@classmethod
def _supports_default_dynamic_cache(cls) -> bool:
"""
        Return False to opt out of HuggingFace's DynamicCache system.
        NanoHammer uses its own state-token cache structure.
"""
return False
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
class NanoHammerModel(NanoHammerPreTrainedModel):
"""
    NanoHammer backbone.
    Core architecture:
    1. Token embedding + state initialization (holographic integration)
    2. N × HybridNanoHammerDecoderLayer (state update + self-attention + MLP)
    3. Final norm
    Incremental generation:
    - past_state_absolute caches the running cumulative state plus per-layer caches
    - Each generation step processes one token plus one state token
    - The cumulative state per layer is fixed-size; the state-token and KV caches grow with the sequence
"""
def __init__(self, config: NanoHammerConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
self.state_hidden_size = int(config.state_hidden_size)
# Token Embedding
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        # Token -> state projection
self.token_to_state = nn.Linear(config.hidden_size, self.state_hidden_size, bias=False)
        # Holographic rotary position embedding
self.holographic_rope = HolographicRotaryEmbedding(
self.state_hidden_size,
max_position_embeddings=config.max_position_embeddings
)
# Decoder Layers
self.layers = nn.ModuleList([
HybridNanoHammerDecoderLayer(config, layer_idx, self.holographic_rope)
for layer_idx in range(config.num_hidden_layers)
])
# Final Norm
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
self.post_init()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
        past_state_absolute: Optional[Tuple] = None,  # custom cache structure
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Args:
            past_state_absolute: tuple (cumsum_cache, layer_caches)
                - cumsum_cache: (B, 1, state_hidden_size)
                - layer_caches: list of (state_last, state_tokens, position_ids, kv_cache) per layer
"""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# 1. Token Embedding
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_len, _ = inputs_embeds.shape
device = inputs_embeds.device
        # Detect incremental mode
is_incremental = past_state_absolute is not None and seq_len == 1
# 2. Position IDs
if position_ids is None:
if is_incremental:
raise ValueError("position_ids must be provided in incremental mode")
full_position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
full_position_ids = full_position_ids.unsqueeze(0).expand(batch_size, -1)
else:
full_position_ids = position_ids
        # 3. Token -> state projection + holographic rotation
state_input = self.token_to_state(inputs_embeds)
state_input_rotated = self.holographic_rope(state_input, full_position_ids)
        # 4. Branch on decoding mode
if is_incremental:
            # ============ Incremental generation ============
cumsum_cache, layer_caches = past_state_absolute
            # Update the running cumulative sum
new_cumsum = cumsum_cache + state_input_rotated
state = new_cumsum
hidden_states = inputs_embeds
all_hidden_states = () if output_hidden_states else None
new_layer_caches = []
for layer_idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_state_layer = layer_caches[layer_idx]
hidden_states, state, present_state = decoder_layer(
hidden_states,
state,
full_position_ids,
attention_mask=None,
past_state=past_state_layer,
use_cache=True,
)
new_layer_caches.append(present_state)
new_past_state_absolute = (new_cumsum, new_layer_caches)
else:
            # ============ Prefill mode ============
cumsum_state = torch.cumsum(state_input_rotated, dim=1)
state = cumsum_state
hidden_states = inputs_embeds
all_hidden_states = () if output_hidden_states else None
layer_caches = []
# attention mask
if attention_mask is None:
causal_attention_mask = torch.zeros(
(batch_size, 1, seq_len, seq_len),
dtype=inputs_embeds.dtype,
device=device
)
causal_mask = torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=1).bool()
causal_attention_mask[:, :, causal_mask] = torch.finfo(inputs_embeds.dtype).min
elif attention_mask.dim() == 2:
expanded_mask = torch.zeros(
(batch_size, 1, seq_len, seq_len),
dtype=inputs_embeds.dtype,
device=device
)
causal_mask = torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=1).bool()
expanded_mask[:, :, causal_mask] = torch.finfo(inputs_embeds.dtype).min
                # Mask out padded key positions (vectorized over batch and sequence)
                padding_positions = (attention_mask == 0)[:, None, None, :]  # (B, 1, 1, T)
                expanded_mask = expanded_mask.masked_fill(padding_positions, torch.finfo(inputs_embeds.dtype).min)
causal_attention_mask = expanded_mask
else:
causal_attention_mask = attention_mask
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
hidden_states, state, present_state = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
state,
full_position_ids,
causal_attention_mask,
None,
use_cache,
)
else:
hidden_states, state, present_state = decoder_layer(
hidden_states,
state,
full_position_ids,
causal_attention_mask,
past_state=None,
use_cache=use_cache,
)
if use_cache and present_state is not None:
layer_caches.append(present_state)
            # Assemble the cache
if use_cache and layer_caches:
cumsum_cache = cumsum_state[:, -1:, :]
new_past_state_absolute = (cumsum_cache, layer_caches)
else:
new_past_state_absolute = None
# 5. Final Norm
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
        # 6. Return
if not return_dict:
outputs = (hidden_states,)
if output_hidden_states:
outputs += (all_hidden_states,)
if use_cache:
outputs += (new_past_state_absolute,)
return outputs
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
            past_key_values=new_past_state_absolute if use_cache else None,  # the custom cache rides in past_key_values
hidden_states=all_hidden_states,
attentions=None,
)
class NanoHammerForCausalLM(NanoHammerPreTrainedModel):
"""
    NanoHammer language model.
    Incremental generation: each layer caches its cumulative state, the history of state
    tokens, and the KV cache over the extended [state tokens | hidden tokens] sequence,
    so each decode step reuses past work instead of re-running the prefill.
"""
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = NanoHammerModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple] = None,  # standard HF interface
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs
) -> Union[Tuple, CausalLMOutputWithPast]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
            past_state_absolute=past_key_values,  # forwarded to the backbone
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
return_dict=True,
)
hidden_states = outputs.last_hidden_state
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + (outputs.hidden_states,)
if use_cache:
output = output + (outputs.past_key_values,)
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values, # 使用标准字段
hidden_states=outputs.hidden_states,
attentions=None,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs
):
"""准备生成输入"""
position_ids = kwargs.get("position_ids", None)
if past_key_values is not None:
            # Incremental mode: keep only the last token
input_ids = input_ids[:, -1:]
if position_ids is not None:
position_ids = position_ids[:, -1:]
elif attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1)[:, -1:] - 1
else:
raise ValueError("position_ids or attention_mask required in incremental mode")
inputs_embeds = None
else:
            # Prefill mode
if position_ids is None and attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if inputs_embeds is not None and input_ids.shape[1] != inputs_embeds.shape[1]:
inputs_embeds = None
return {
"input_ids": input_ids if inputs_embeds is None else None,
"inputs_embeds": inputs_embeds,
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache", True),
"attention_mask": attention_mask,
}
def _update_model_kwargs_for_generation(
self,
outputs,
model_kwargs,
is_encoder_decoder=False,
num_new_tokens=1,
):
"""更新 model_kwargs"""
        # Carry the custom cache forward
model_kwargs["past_key_values"] = outputs.past_key_values
        # Grow the attention mask by the newly generated tokens
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_tokens))],
dim=-1,
)
return model_kwargs
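# Minimal smoke test (a sketch with tiny hypothetical dimensions, not the shipped
# 1.5B configuration; assumes a transformers version exposing the Llama modules
# imported above). Runs a prefill forward pass and one cached decode step:
if __name__ == "__main__":
    cfg = NanoHammerConfig(
        vocab_size=256,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        num_state_heads=4,
        max_position_embeddings=128,
    )
    model = NanoHammerForCausalLM(cfg).eval()
    input_ids = torch.randint(0, cfg.vocab_size, (1, 8))
    with torch.no_grad():
        # Prefill: full sequence, cache returned in past_key_values
        out = model(input_ids=input_ids, use_cache=True)
        next_id = out.logits[:, -1:].argmax(dim=-1)
        # One incremental step: a single token plus the cached state
        step = model(
            input_ids=next_id,
            position_ids=torch.tensor([[8]]),
            past_key_values=out.past_key_values,
            use_cache=True,
        )
    print("prefill logits:", tuple(out.logits.shape), "step logits:", tuple(step.logits.shape))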