|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
from typing import Optional, Tuple, Union, List |
|
|
from dataclasses import dataclass |
|
|
|
|
|
from transformers import PretrainedConfig |
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast |
|
|
from transformers.models.llama.modeling_llama import ( |
|
|
LlamaRMSNorm, |
|
|
LlamaMLP, |
|
|
LlamaAttention, |
|
|
LlamaRotaryEmbedding, |
|
|
apply_rotary_pos_emb, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NanoHammerConfig(PretrainedConfig):
    """Configuration for the NanoHammer architecture.

    Extends the Llama-style hyper-parameter set with the holographic state
    branch: ``num_state_heads`` and ``state_hidden_size`` describe the
    multi-head state-update cell that runs alongside standard attention.

    Args (non-Llama additions):
        num_state_heads: number of heads in the state-update cell.
        state_hidden_size: width of the state stream; defaults to
            ``hidden_size // 4`` when not given.
    """

    model_type = "nanohammer"

    def __init__(
        self,
        vocab_size=128256,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=24,
        num_attention_heads=32,
        num_key_value_heads=8,
        num_state_heads=32,
        state_hidden_size=None,
        max_position_embeddings=131072,
        rms_norm_eps=1e-5,
        initializer_range=0.02,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        hidden_act="silu",
        bos_token_id=128000,
        eos_token_id=128009,
        pad_token_id=None,
        **kwargs
    ):
        # Register auto_map so AutoConfig / AutoModelForCausalLM can resolve
        # this custom architecture when loaded with trust_remote_code=True.
        if "auto_map" not in kwargs:
            kwargs["auto_map"] = {
                "AutoConfig": "NanoHammerForCausalLM.NanoHammerConfig",
                "AutoModelForCausalLM": "NanoHammerForCausalLM.NanoHammerForCausalLM",
            }

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.num_state_heads = num_state_heads

        # Default the state width to a quarter of the hidden size.
        # Bug fix: the original used true division (`/`), yielding a float
        # that later breaks integer head-splitting arithmetic and nn.Linear
        # sizes (e.g. `state_hidden_size // num_state_heads` stays a float).
        self.state_hidden_size = state_hidden_size if state_hidden_size is not None else hidden_size // 4

        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        self.hidden_act = hidden_act

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            **kwargs
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HolographicRotaryEmbedding(nn.Module):
    """
    Holographic rotary position encoding — injects temporal features into the
    integral state.

    Core idea:
      - each position i is rotated in the complex plane: x_i * e^(i*theta_k);
      - after integration, S_t = sum_i(x_i * e^(i*theta_k)) acts as a
        "polynomial-coefficient container";
      - the inverse rotation R_{-t} converts the state to a relative frame,
        giving translation invariance.

    Uses absolute ``position_ids`` (not a relative ``seq_len``), so cached
    incremental positions stay consistent.

    Bug fixes vs. the original implementation:
      - ``apply_inverse_rotation`` negated ``sin`` AND flipped the formula
        signs; the double negation made it identical to ``forward``, so the
        "inverse" never inverted anything.
      - cos/sin were duplicated via ``torch.cat([c, c], dim=-1)`` and then
        stride-2 sliced, pairing each (x1, x2) with scrambled frequencies;
        the transform was therefore not a proper per-pair rotation.
    With both fixed, ``apply_inverse_rotation(forward(x, p), p) == x``.
    """

    def __init__(self, dim, max_position_embeddings=131072, base=10000):
        super().__init__()
        self.dim = dim  # must be even: channels are rotated in (x1, x2) pairs
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Standard RoPE inverse frequencies, one per channel pair: (dim/2,)
        inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _angles(self, x, position_ids):
        """Return per-pair (cos, sin) of shape (B, T, dim/2) for absolute positions."""
        freqs = torch.einsum("bt,d->btd", position_ids.to(x.dtype), self.inv_freq.to(x.dtype))
        return freqs.cos(), freqs.sin()

    def forward(self, x, position_ids):
        """
        Apply the forward rotation (absolute positions).

        Args:
            x: (B, T, D) input tensor (D even, interleaved pairs).
            position_ids: (B, T) absolute position indices.
        Returns:
            (B, T, D) rotated tensor.
        """
        cos, sin = self._angles(x, position_ids)

        # Interleaved pair layout: (x1, x2) = (even, odd) channels.
        x1 = x[..., 0::2]
        x2 = x[..., 1::2]

        # Proper 2-D rotation by +theta for each pair.
        x1_rotated = x1 * cos - x2 * sin
        x2_rotated = x1 * sin + x2 * cos

        # Re-interleave pairs back into the last dimension.
        return torch.stack([x1_rotated, x2_rotated], dim=-1).flatten(-2)

    def apply_inverse_rotation(self, x, position_ids):
        """
        Apply the inverse rotation (rotate by -theta), converting the
        integral state into the relative frame: S_t' = S_t * e^(-t*theta).

        Args:
            x: (B, T, D) integral-state tensor.
            position_ids: (B, T) absolute position indices.
        Returns:
            (B, T, D) state in the relative coordinate frame.
        """
        cos, sin = self._angles(x, position_ids)

        x1 = x[..., 0::2]
        x2 = x[..., 1::2]

        # Rotation by -theta: cos(-t) = cos(t), sin(-t) = -sin(t).
        x1_relative = x1 * cos + x2 * sin
        x2_relative = -x1 * sin + x2 * cos

        return torch.stack([x1_relative, x2_relative], dim=-1).flatten(-2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StateUpdateCell(nn.Module):
    """
    Multi-head state-update cell — one explicit-Euler step on the
    holographic integral state:

        S_{t+1} = S_t + alpha * f(S_t)

    where f is a pre-normed SiLU MLP, each head iterates in its own
    sub-space, and alpha is a learnable per-head step size.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.state_hidden_size
        self.num_heads = config.num_state_heads

        # Validate BEFORE deriving head_dim. The original used a bare
        # `assert`, which is stripped under `python -O`; raise explicitly.
        if config.state_hidden_size % config.num_state_heads != 0:
            raise ValueError(
                f"state_hidden_size ({config.state_hidden_size}) must be divisible "
                f"by num_state_heads ({config.num_state_heads})"
            )
        self.head_dim = config.state_hidden_size // config.num_state_heads

        # Pre-norm -> 4x-expansion SiLU MLP -> post-norm, mirroring a
        # transformer feed-forward block on the state stream.
        self.pre_norm = LlamaRMSNorm(config.state_hidden_size, eps=config.rms_norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(config.state_hidden_size, config.state_hidden_size * 4, bias=False),
            nn.SiLU(),
            nn.Linear(config.state_hidden_size * 4, config.state_hidden_size, bias=False)
        )
        self.post_norm = LlamaRMSNorm(config.state_hidden_size, eps=config.rms_norm_eps)

        # Learnable per-head Euler step size, initialised small (0.1) for
        # stable early training.
        self.step_size = nn.Parameter(torch.ones(self.num_heads) * 0.1)

    def forward(self, state):
        """
        Euler update: S <- post_norm(S + alpha * f(pre_norm(S))).

        Args:
            state: (B, T, state_hidden_size)
        Returns:
            (B, T, state_hidden_size)
        """
        batch_size, seq_len, _ = state.shape

        state_normed = self.pre_norm(state)
        delta = self.mlp(state_normed)

        # Split into heads so each head applies its own learnable step size.
        state_heads = state.view(batch_size, seq_len, self.num_heads, self.head_dim)
        delta_heads = delta.view(batch_size, seq_len, self.num_heads, self.head_dim)

        step_size = self.step_size.view(1, 1, self.num_heads, 1)
        state_heads = state_heads + step_size * delta_heads

        state = state_heads.view(batch_size, seq_len, self.hidden_size)

        state = self.post_norm(state)
        return state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StateTokenProjection(nn.Module):
    """
    State -> token projection: maps the holographic integral state into a
    hidden_size-dimensional "state token".

    The resulting token is prepended to the sequence and takes part in the
    standard Llama attention alongside ordinary hidden states.
    """

    def __init__(self, config):
        super().__init__()
        self.state_to_hidden = nn.Linear(config.state_hidden_size, config.hidden_size, bias=False)
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """
        Args:
            state: (B, T, state_hidden_size) — holographic integral state.
        Returns:
            (B, T, hidden_size) — normalised state token.
        """
        # Project up to model width, then RMS-normalise in one pass.
        return self.norm(self.state_to_hidden(state))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HybridNanoHammerDecoderLayer(nn.Module):
    """
    NanoHammer decoder layer: causal state-token prefix + standard Llama
    attention + MLP.

    Flow:
      1. State update: non-linear evolution of the holographic integral state.
      2. Causal state-token generation:
         - for position i, build state_token_i (the accumulated state up to
           position i-1);
         - all state tokens are used as a prefix:
           [state_0, ..., state_{T-1}, hidden_0, ..., hidden_{T-1}];
         - the attended sequence length becomes 2T.
      3. Special attention mask:
         - hidden_i may attend to state_j (j <= i) and hidden_j (j < i);
         - this preserves causality: position i only sees history up to i.
      4. Self-attention: standard Llama attention over the 2T sequence.
      5. MLP: feed-forward network.
      6. State tokens are dropped: only the hidden-state half is returned.

    Key design points:
      - state tokens act as a "global prefix", one per position's history;
      - a purpose-built attention mask enforces causality;
      - preserves the "state token first" architectural idea.

    Incremental generation mode:
      - processes only one new token + one accumulated state token;
      - the state already contains all history (cumulative holographic
        integral).
    """

    def __init__(self, config, layer_idx: int, holographic_rope: HolographicRotaryEmbedding):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx
        # Shared (model-level) holographic rotary module.
        self.holographic_rope = holographic_rope

        # Non-linear evolution of the state stream (Euler step).
        self.state_cell = StateUpdateCell(config)

        # Projects the state stream into hidden_size "state tokens".
        self.state_projection = StateTokenProjection(config)

        # Standard Llama attention; its q/k/v/o projections are driven
        # manually below rather than via LlamaAttention.forward.
        self.self_attn = LlamaAttention(config, layer_idx)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Per-layer rotary embedding for the token stream.
        self.rotary_emb = LlamaRotaryEmbedding(config=config)

        self.mlp = LlamaMLP(config)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        state: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_state: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple]] = None,
        use_cache: bool = False,
        **kwargs
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple]]]:
        """
        Dispatch between prefill and incremental (single-token) paths.

        Args:
            hidden_states: (B, T, hidden_size)
            state: (B, T_state, state_hidden_size)
            position_ids: (B, T)
            attention_mask: (B, 1, T, T)
            past_state: tuple (state_last, state_tokens, position_ids, kv_cache) or None
            use_cache: bool
        Returns:
            hidden_states, state, present_state
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Incremental only when a cache exists AND exactly one new token.
        is_incremental = past_state is not None and seq_len == 1

        if is_incremental:
            state_last, state_tokens_cache, position_ids_cache, kv_cache = past_state
            return self._forward_incremental(
                hidden_states, state, position_ids,
                state_last, state_tokens_cache, position_ids_cache, kv_cache
            )
        else:
            return self._forward_prefill(
                hidden_states, state, position_ids, attention_mask, use_cache
            )

    def _forward_incremental(
        self,
        hidden_states: torch.Tensor,
        state: torch.Tensor,
        position_ids: torch.LongTensor,
        past_state: torch.Tensor,
        past_state_tokens: torch.Tensor,
        past_position_ids: torch.LongTensor,
        past_kv_cache: Tuple[torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple]]:
        """
        Incremental generation step using the KV cache.

        Cache structure: (state_last, state_tokens, position_ids, kv_cache).
        Per-step cost: O(T) for attention; QKV projections are O(1).

        NOTE(review): the mask built below assumes the cached KV layout is
        [T_past state keys | T_past hidden keys]. After this step appends the
        two NEW keys at the tail (dim=2 concat), the layout for the NEXT step
        becomes [states | hiddens | new_state | new_hidden], which no longer
        matches that assumption — verify multi-step generation correctness.
        """
        batch_size = hidden_states.shape[0]
        device = hidden_states.device
        dtype = hidden_states.dtype
        T_past = past_state_tokens.shape[1]

        # Evolve the incoming cumulative state (Euler step).
        state_updated = self.state_cell(state)

        # Build the new state token from the PREVIOUS step's state (mirrors
        # the shift-by-one used in prefill), mapped to the relative frame.
        state_relative = self.holographic_rope.apply_inverse_rotation(past_state, position_ids)
        new_state_token = self.state_projection(state_relative)

        # Extend the cached state-token prefix and its positions.
        all_state_tokens = torch.cat([past_state_tokens, new_state_token], dim=1)
        all_position_ids = torch.cat([past_position_ids, position_ids], dim=1)

        # Queries this step: [new_state_token, new_hidden_token] (length 2).
        query_input = torch.cat([new_state_token, hidden_states], dim=1)

        # Both queries share the same absolute position.
        query_position_ids = torch.cat([position_ids, position_ids], dim=1)

        # Additive mask over keys: cached 2*T_past entries + the 2 new ones.
        kv_len = 2 * T_past + 2
        attention_mask = torch.full(
            (batch_size, 1, 2, kv_len),
            torch.finfo(dtype).min,
            dtype=dtype,
            device=device
        )

        # Row 0 (new state token): sees cached state tokens + itself.
        attention_mask[:, :, 0, :T_past] = 0
        attention_mask[:, :, 0, 2*T_past] = 0

        # Row 1 (new hidden token): sees cached state tokens, cached hidden
        # tokens, the new state token, and itself.
        # NOTE(review): in prefill, hidden_i does NOT attend to itself;
        # here it does (index 2*T_past+1) — confirm this asymmetry is intended.
        attention_mask[:, :, 1, :T_past] = 0
        attention_mask[:, :, 1, T_past:2*T_past] = 0
        attention_mask[:, :, 1, 2*T_past] = 0
        attention_mask[:, :, 1, 2*T_past+1] = 0

        # --- Manual Llama attention over the 2-query block ---
        residual = query_input
        query_normed = self.input_layernorm(query_input)

        head_dim = self.self_attn.head_dim
        num_heads = self.self_attn.config.num_attention_heads
        num_kv_heads = self.self_attn.config.num_key_value_heads
        num_kv_groups = num_heads // num_kv_heads

        query_states = self.self_attn.q_proj(query_normed)
        key_states = self.self_attn.k_proj(query_normed)
        value_states = self.self_attn.v_proj(query_normed)

        # (B, 2, H*D) -> (B, H, 2, D)
        query_states = query_states.view(batch_size, 2, num_heads, head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, 2, num_kv_heads, head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, 2, num_kv_heads, head_dim).transpose(1, 2)

        # Token-stream RoPE at the absolute position of the new token.
        position_embeddings = self.rotary_emb(query_normed, query_position_ids)
        cos, sin = position_embeddings
        from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Append the new K/V to the cache along the sequence axis.
        past_keys, past_values = past_kv_cache
        key_states = torch.cat([past_keys, key_states], dim=2)
        value_states = torch.cat([past_values, value_states], dim=2)

        # GQA: replicate each KV head across its query-head group.
        key_states_expanded = key_states.repeat_interleave(num_kv_groups, dim=1)
        value_states_expanded = value_states.repeat_interleave(num_kv_groups, dim=1)

        attn_weights = torch.matmul(query_states, key_states_expanded.transpose(2, 3)) * (head_dim ** -0.5)
        attn_weights = attn_weights + attention_mask
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states_expanded)

        # (B, H, 2, D) -> (B, 2, H*D) -> output projection.
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, 2, -1)
        attn_output = self.self_attn.o_proj(attn_output)

        combined_output = residual + attn_output

        # Standard post-attention MLP block with residual.
        residual = combined_output
        combined_output = self.post_attention_layernorm(combined_output)
        combined_output = self.mlp(combined_output)
        combined_output = residual + combined_output

        # Keep only the hidden-token half (drop the state-token slot 0).
        hidden_output = combined_output[:, 1:, :]

        # Assemble the per-layer cache for the next step.
        new_kv_cache = (key_states, value_states)
        present_state = (state_updated, all_state_tokens, all_position_ids, new_kv_cache)

        return hidden_output, state_updated, present_state

    def _forward_prefill(
        self,
        hidden_states: torch.Tensor,
        state: torch.Tensor,
        position_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor],
        use_cache: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple]]]:
        """
        Prefill mode: process the full sequence at once.

        Returns present_state = (state_last, state_tokens, position_ids, kv_cache).

        NOTE(review): the `attention_mask` parameter is never consulted here —
        the 2T x 2T mask is rebuilt from scratch, so any padding information in
        the caller-supplied mask is silently dropped. Verify padded batches.
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Evolve the holographic state for every position.
        state = self.state_cell(state)

        if position_ids is None:
            raise ValueError("position_ids is required for inverse rotation")

        # Shift by one: state_token_i must only contain history up to i-1
        # (position 0 gets a zero state).
        state_shifted = torch.cat([
            torch.zeros_like(state[:, :1, :]),
            state[:, :-1, :]
        ], dim=1)

        # Convert accumulated states into the relative coordinate frame.
        state_relative = self.holographic_rope.apply_inverse_rotation(state_shifted, position_ids)

        # Project states into hidden_size "state tokens".
        state_tokens = self.state_projection(state_relative)

        # Defensive trim if the state stream is longer than the token stream.
        if state_tokens.shape[1] > seq_len:
            state_tokens = state_tokens[:, -seq_len:, :]

        # Full attention input: [state_0..state_{T-1}, hidden_0..hidden_{T-1}].
        combined_input = torch.cat([state_tokens, hidden_states], dim=1)

        # Align position ids with the current window.
        if position_ids.shape[1] > seq_len:
            current_position_ids = position_ids[:, -seq_len:]
        else:
            current_position_ids = position_ids

        # state_token_i shares hidden_i's absolute position.
        extended_position_ids = torch.cat([current_position_ids, current_position_ids], dim=1)

        # Build the 2T x 2T causal mask (additive, -inf = blocked).
        extended_mask = torch.full(
            (batch_size, 1, 2 * seq_len, 2 * seq_len),
            torch.finfo(combined_input.dtype).min,
            dtype=combined_input.dtype,
            device=combined_input.device
        )

        # State-token queries: state_i attends state_j for j <= i.
        for i in range(seq_len):
            extended_mask[:, :, i, :i+1] = 0

        # Hidden queries: hidden_i attends state_j (j <= i) and hidden_j
        # (j < i). Note hidden_i does NOT attend to itself — per the class
        # docstring this is intentional (state_i carries its own info).
        for i in range(seq_len):
            extended_mask[:, :, seq_len + i, :i+1] = 0
            if i > 0:
                extended_mask[:, :, seq_len + i, seq_len:seq_len+i] = 0

        # --- Manual Llama attention over the 2T sequence ---
        residual = combined_input
        hidden_states_normed = self.input_layernorm(combined_input)
        position_embeddings = self.rotary_emb(hidden_states_normed, extended_position_ids)

        head_dim = self.self_attn.head_dim
        num_heads = self.self_attn.config.num_attention_heads
        num_kv_heads = self.self_attn.config.num_key_value_heads
        num_kv_groups = num_heads // num_kv_heads

        query_states = self.self_attn.q_proj(hidden_states_normed)
        key_states = self.self_attn.k_proj(hidden_states_normed)
        value_states = self.self_attn.v_proj(hidden_states_normed)

        # (B, 2T, H*D) -> (B, H, 2T, D)
        query_states = query_states.view(batch_size, 2 * seq_len, num_heads, head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, 2 * seq_len, num_kv_heads, head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, 2 * seq_len, num_kv_heads, head_dim).transpose(1, 2)

        # Token-stream RoPE over the doubled sequence.
        cos, sin = position_embeddings
        from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # GQA: replicate each KV head across its query-head group.
        key_states_expanded = key_states.repeat_interleave(num_kv_groups, dim=1)
        value_states_expanded = value_states.repeat_interleave(num_kv_groups, dim=1)

        attn_weights = torch.matmul(query_states, key_states_expanded.transpose(2, 3)) * (head_dim ** -0.5)
        attn_weights = attn_weights + extended_mask
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states_expanded)

        # (B, H, 2T, D) -> (B, 2T, H*D) -> output projection.
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, 2 * seq_len, -1)
        attn_output = self.self_attn.o_proj(attn_output)

        combined_output = residual + attn_output

        # Standard post-attention MLP block with residual.
        residual = combined_output
        combined_output = self.post_attention_layernorm(combined_output)
        combined_output = self.mlp(combined_output)
        combined_output = residual + combined_output

        # Return only the hidden-state half (drop the state-token prefix).
        hidden_states_output = combined_output[:, seq_len:, :]

        # Cache: last evolved state, all state tokens, positions, and the
        # full-2T KV tensors (layout [states | hiddens]).
        if use_cache:
            kv_cache = (key_states, value_states)
            present_state = (state[:, -1:, :], state_tokens, current_position_ids, kv_cache)
        else:
            present_state = None

        return hidden_states_output, state, present_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class NanoHammerPreTrainedModel(PreTrainedModel):
    """Base class wiring NanoHammer into the HuggingFace PreTrainedModel API."""

    config_class = NanoHammerConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["HybridNanoHammerDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _is_stateful = True

    @classmethod
    def _supports_default_dynamic_cache(cls) -> bool:
        """
        Return False to opt out of HuggingFace's DynamicCache machinery.
        NanoHammer manages its own state-token cache structure.
        """
        return False

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, initializer_range)."""
        std = self.config.initializer_range
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            # Keep the padding row at exactly zero.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
|
|
|
|
|
|
|
|
class NanoHammerModel(NanoHammerPreTrainedModel):
    """
    NanoHammer backbone.

    Core architecture:
      1. Token embedding + state initialisation (holographic integral via
         rotate-then-cumsum).
      2. N x HybridNanoHammerDecoderLayer (state update + self-attn + MLP).
      3. Final RMS norm.

    Incremental generation:
      - `past_state_absolute` caches the running cumulative rotated state
        plus per-layer caches;
      - each step processes one token + one state token;
      - the cumulative-state cache is O(num_layers * state_hidden_size).
    """

    def __init__(self, config: NanoHammerConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        # int(...) guards against a non-integer state_hidden_size in config.
        self.state_hidden_size = int(config.state_hidden_size)

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)

        # Projects token embeddings into the (narrower) state space.
        self.token_to_state = nn.Linear(config.hidden_size, self.state_hidden_size, bias=False)

        # One holographic rotary module shared by every decoder layer.
        self.holographic_rope = HolographicRotaryEmbedding(
            self.state_hidden_size,
            max_position_embeddings=config.max_position_embeddings
        )

        self.layers = nn.ModuleList([
            HybridNanoHammerDecoderLayer(config, layer_idx, self.holographic_rope)
            for layer_idx in range(config.num_hidden_layers)
        ])

        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_state_absolute: Optional[Tuple] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Args:
            past_state_absolute: tuple (cumsum_cache, layer_caches)
                - cumsum_cache: (B, 1, state_hidden_size) — running sum of
                  rotated state inputs up to the previous position
                - layer_caches: list of per-layer tuples
                  (state_last, state_tokens, position_ids, kv_cache)
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_len, _ = inputs_embeds.shape
        device = inputs_embeds.device

        # Incremental only when a cache exists AND exactly one new token.
        is_incremental = past_state_absolute is not None and seq_len == 1

        # In incremental mode the absolute position cannot be inferred, so
        # it must be supplied by the caller (prepare_inputs_for_generation).
        if position_ids is None:
            if is_incremental:
                raise ValueError("position_ids must be provided in incremental mode")
            full_position_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
            full_position_ids = full_position_ids.unsqueeze(0).expand(batch_size, -1)
        else:
            full_position_ids = position_ids

        # Holographic encoding: project tokens to state space, then rotate
        # each position by its absolute-position phase.
        state_input = self.token_to_state(inputs_embeds)
        state_input_rotated = self.holographic_rope(state_input, full_position_ids)

        if is_incremental:
            cumsum_cache, layer_caches = past_state_absolute

            # O(1) prefix-sum extension: add the single new rotated state.
            new_cumsum = cumsum_cache + state_input_rotated
            state = new_cumsum

            hidden_states = inputs_embeds
            all_hidden_states = () if output_hidden_states else None
            new_layer_caches = []

            for layer_idx, decoder_layer in enumerate(self.layers):
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                past_state_layer = layer_caches[layer_idx]

                # Each layer threads (hidden_states, state) forward and
                # returns its refreshed cache tuple.
                hidden_states, state, present_state = decoder_layer(
                    hidden_states,
                    state,
                    full_position_ids,
                    attention_mask=None,
                    past_state=past_state_layer,
                    use_cache=True,
                )

                new_layer_caches.append(present_state)

            new_past_state_absolute = (new_cumsum, new_layer_caches)

        else:
            # Prefill: the holographic integral is a cumulative sum of the
            # rotated per-position states over time.
            cumsum_state = torch.cumsum(state_input_rotated, dim=1)
            state = cumsum_state

            hidden_states = inputs_embeds
            all_hidden_states = () if output_hidden_states else None
            layer_caches = []

            # Build a (B, 1, T, T) additive causal mask; merge a 2-D padding
            # mask into it if one was supplied.
            if attention_mask is None:
                causal_attention_mask = torch.zeros(
                    (batch_size, 1, seq_len, seq_len),
                    dtype=inputs_embeds.dtype,
                    device=device
                )
                causal_mask = torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=1).bool()
                causal_attention_mask[:, :, causal_mask] = torch.finfo(inputs_embeds.dtype).min
            elif attention_mask.dim() == 2:
                expanded_mask = torch.zeros(
                    (batch_size, 1, seq_len, seq_len),
                    dtype=inputs_embeds.dtype,
                    device=device
                )
                causal_mask = torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=1).bool()
                expanded_mask[:, :, causal_mask] = torch.finfo(inputs_embeds.dtype).min
                # NOTE(perf): nested Python loops are O(B*T); this could be
                # vectorised with a broadcast over the key axis.
                for b in range(batch_size):
                    for i in range(seq_len):
                        if attention_mask[b, i] == 0:
                            expanded_mask[b, 0, :, i] = torch.finfo(inputs_embeds.dtype).min
                causal_attention_mask = expanded_mask
            else:
                # Already a 4-D additive mask: pass through unchanged.
                causal_attention_mask = attention_mask

            for decoder_layer in self.layers:
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                if self.gradient_checkpointing and self.training:
                    # Positional args must match the layer signature:
                    # (hidden_states, state, position_ids, attention_mask,
                    #  past_state, use_cache).
                    hidden_states, state, present_state = self._gradient_checkpointing_func(
                        decoder_layer.__call__,
                        hidden_states,
                        state,
                        full_position_ids,
                        causal_attention_mask,
                        None,
                        use_cache,
                    )
                else:
                    hidden_states, state, present_state = decoder_layer(
                        hidden_states,
                        state,
                        full_position_ids,
                        causal_attention_mask,
                        past_state=None,
                        use_cache=use_cache,
                    )

                if use_cache and present_state is not None:
                    layer_caches.append(present_state)

            # Cache the running sum at the last position for O(1) extension.
            if use_cache and layer_caches:
                cumsum_cache = cumsum_state[:, -1:, :]
                new_past_state_absolute = (cumsum_cache, layer_caches)
            else:
                new_past_state_absolute = None

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            outputs = (hidden_states,)
            if output_hidden_states:
                outputs += (all_hidden_states,)
            if use_cache:
                outputs += (new_past_state_absolute,)
            return outputs

        # The custom state cache rides in the past_key_values slot so it
        # flows through HF generation utilities untouched.
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=new_past_state_absolute if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=None,
        )
|
|
|
|
|
|
|
|
class NanoHammerForCausalLM(NanoHammerPreTrainedModel):
    """
    NanoHammer language model (backbone + LM head).

    Incremental generation keeps the historical state tokens, costing
    O(T x hidden_size) — roughly 2 x num_layers less memory than a
    conventional KV cache at O(T x 2 x hidden_size x num_layers).
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = NanoHammerModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Optionally share the LM head with the input embedding matrix.
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Standard causal-LM forward: backbone -> LM head -> optional
        shifted cross-entropy loss. `past_key_values` carries the custom
        NanoHammer state cache (see NanoHammerModel.forward).
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_state_absolute=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
        # Compute logits in fp32 for numerically stable loss/softmax.
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so token t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            # NOTE: hidden_states may be None here and is still included.
            output = (logits,) + (outputs.hidden_states,)
            if use_cache:
                output = output + (outputs.past_key_values,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs
    ):
        """Prepare inputs for generation (HF generate() hook)."""
        position_ids = kwargs.get("position_ids", None)

        if past_key_values is not None:
            # Incremental step: feed only the newest token.
            input_ids = input_ids[:, -1:]

            # Absolute position is mandatory in incremental mode; derive it
            # from the padding mask (count of non-pad tokens - 1) if needed.
            if position_ids is not None:
                position_ids = position_ids[:, -1:]
            elif attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1)[:, -1:] - 1
            else:
                raise ValueError("position_ids or attention_mask required in incremental mode")

            # Embeddings were only for the prefill pass.
            inputs_embeds = None
        else:
            # Prefill: build left-padding-aware positions when possible.
            if position_ids is None and attention_mask is not None:
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)

            # Stale embeddings (length mismatch) are discarded.
            if inputs_embeds is not None and input_ids.shape[1] != inputs_embeds.shape[1]:
                inputs_embeds = None

        return {
            "input_ids": input_ids if inputs_embeds is None else None,
            "inputs_embeds": inputs_embeds,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
            "attention_mask": attention_mask,
        }

    def _update_model_kwargs_for_generation(
        self,
        outputs,
        model_kwargs,
        is_encoder_decoder=False,
        num_new_tokens=1,
    ):
        """Update model_kwargs between generation steps."""

        # Carry the custom state cache forward.
        model_kwargs["past_key_values"] = outputs.past_key_values

        # Extend the padding mask by one (per new token generated).
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_tokens))],
                dim=-1,
            )

        return model_kwargs
|
|
|