# Source: vermind/modeling_vermind.py (uploaded by nev8r, commit 5d054fe: "Upload VerMind model")
# coding=utf-8
"""
Model file for VerMind model - Standalone Version
Contains the complete model implementation; it depends only on torch, transformers and the local configuration_vermind module.
"""
import math
from typing import Optional, Tuple, List, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, GenerationMixin, AutoModelForCausalLM
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_vermind import VerMindConfig
# ==================== Base Module Functions ====================
def precompute_freqs_cis(dim: int, end: int = int(32 * 1024), rope_base: float = 1e6,
                         rope_scaling: Optional[dict] = None):
    """Precompute the cosine/sine tables used by rotary position embeddings (RoPE).

    Args:
        dim: per-head dimension; one inverse frequency is produced for each of the
            ``dim // 2`` channel pairs.
        end: number of positions (rows) to tabulate.
        rope_base: RoPE base (theta) of the inverse-frequency geometric series.
        rope_scaling: optional dict enabling YaRN-style frequency scaling. Recognised
            keys (with defaults): ``original_max_position_embeddings`` (2048),
            ``factor`` (16), ``beta_fast`` (32.0), ``beta_slow`` (1.0) and
            ``attention_factor`` (1.0). The frequency blend is only applied when
            ``end`` exceeds ``original_max_position_embeddings``; the attention factor
            scales both tables whenever ``rope_scaling`` is given.

    Returns:
        ``(freqs_cos, freqs_sin)``, each of shape ``(end, 2 * (dim // 2))``; the
        per-pair values are duplicated into both halves of the last axis.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (rope_base ** exponents)
    table_scale = 1.0

    if rope_scaling is not None:
        orig_max = rope_scaling.get("original_max_position_embeddings", 2048)
        factor = rope_scaling.get("factor", 16)
        beta_fast = rope_scaling.get("beta_fast", 32.0)
        beta_slow = rope_scaling.get("beta_slow", 1.0)
        table_scale = rope_scaling.get("attention_factor", 1.0)

        if end / orig_max > 1.0:
            def correction_dim(num_rotations):
                # Channel index at which the wavelength completes `num_rotations`
                # turns over the original context length.
                return (dim * math.log(orig_max / (num_rotations * 2 * math.pi))) / (2 * math.log(rope_base))

            low = max(math.floor(correction_dim(beta_fast)), 0)
            high = min(math.ceil(correction_dim(beta_slow)), dim // 2 - 1)
            channel = torch.arange(dim // 2, device=inv_freq.device).float()
            # 0 -> keep the original frequency, 1 -> divide it by `factor`.
            ramp = torch.clamp((channel - low) / max(high - low, 0.001), 0, 1)
            inv_freq = inv_freq * (1 - ramp + ramp / factor)

    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    cos_half = torch.cos(angles)
    sin_half = torch.sin(angles)
    freqs_cos = torch.cat([cos_half, cos_half], dim=-1) * table_scale
    freqs_sin = torch.cat([sin_half, sin_half], dim=-1) * table_scale
    return freqs_cos, freqs_sin
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply rotary position embeddings to queries and keys.

    ``q``/``k`` are indexed with the sequence on dim 1, and the cos/sin rows are
    broadcast over dim 2 (presumably the head dim of a ``(bsz, seq, heads, head_dim)``
    layout — TODO confirm for external callers).

    Args:
        q, k: query and key tensors.
        cos, sin: position tables whose rows are indexed by position.
        position_ids: ``None`` to use the first ``q.shape[1]`` rows of the tables; a 1-D
            tensor of positions shared by the whole batch; or a 2-D ``(bsz, seq)`` tensor
            of per-sample positions.
        unsqueeze_dim: accepted for API compatibility; not used.

    Returns:
        ``(q_embed, k_embed)``, both cast back to ``q``'s original dtype.
    """
    def _rotate_half(t):
        mid = t.shape[-1] // 2
        return torch.cat((-t[..., mid:], t[..., :mid]), dim=-1)

    # Remember the input dtype; the tables may promote the math to float32.
    target_dtype = q.dtype

    if position_ids is None:
        n_pos = q.shape[1]
        cos_b = cos[:n_pos].unsqueeze(0).unsqueeze(2)
        sin_b = sin[:n_pos].unsqueeze(0).unsqueeze(2)
    elif position_ids.dim() == 1:
        cos_b = cos[position_ids].unsqueeze(0).unsqueeze(2)
        sin_b = sin[position_ids].unsqueeze(0).unsqueeze(2)
    else:
        cos_b = cos[position_ids].unsqueeze(2)
        sin_b = sin[position_ids].unsqueeze(2)

    q_rot = q * cos_b + _rotate_half(q) * sin_b
    k_rot = k * cos_b + _rotate_half(k) * sin_b

    # Cast back to the original dtype (both outputs use q's dtype).
    return q_rot.to(target_dtype), k_rot.to(target_dtype)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat key/value heads for GQA.

    ``x`` is unpacked as ``(bsz, seq, kv_heads, head_dim)``; each kv head is repeated
    ``n_rep`` times consecutively along dim 2, giving ``(bsz, seq, kv_heads * n_rep,
    head_dim)``. With ``n_rep == 1`` the input tensor itself is returned.
    """
    batch, seq, kv_heads, width = x.shape
    if n_rep == 1:
        return x
    widened = x.unsqueeze(3).expand(batch, seq, kv_heads, n_rep, width)
    return widened.reshape(batch, seq, kv_heads * n_rep, width)
# ==================== Module Classes ====================
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Normalises over the last axis by the root-mean-square (computed in float32), casts
    back to the input dtype, then multiplies by a learned per-channel ``weight``
    initialised to ones.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * inv_rms

    def forward(self, x):
        normed = self._norm(x.float()).type_as(x)
        return self.weight * normed
class FeedForward(nn.Module):
    """Gated feed-forward network (SwiGLU when ``config.hidden_act`` is SiLU).

    Computes ``dropout(down_proj(act(gate_proj(x)) * up_proj(x)))``.

    If ``config.intermediate_size`` is ``None`` it is filled in on the config object
    itself with ``int(hidden_size * 8 / 3)`` rounded up to a multiple of 64, so later
    modules built from the same config see that value.
    """

    def __init__(self, config: VerMindConfig):
        super().__init__()
        if config.intermediate_size is None:
            base = int(config.hidden_size * 8 / 3)
            # Ceil-divide by 64, then scale back up to the next multiple of 64.
            config.intermediate_size = -(-base // 64) * 64
        d_model = config.hidden_size
        d_ff = config.intermediate_size
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        gated = gate * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
class Attention(nn.Module):
    """Grouped Query Attention with RoPE.

    Queries use ``num_attention_heads`` heads. Keys/values use ``num_key_value_heads``
    heads (``num_attention_heads`` when unset), each repeated ``n_rep`` times so they
    line up with the query heads.
    """

    def __init__(self, args: VerMindConfig):
        super().__init__()
        self.num_key_value_heads = args.num_attention_heads if args.num_key_value_heads is None else args.num_key_value_heads
        assert args.num_attention_heads % self.num_key_value_heads == 0
        self.n_local_heads = args.num_attention_heads
        self.n_local_kv_heads = self.num_key_value_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.hidden_size // args.num_attention_heads
        self.q_proj = nn.Linear(args.hidden_size, args.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(args.num_attention_heads * self.head_dim, args.hidden_size, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        # Use PyTorch's fused SDPA kernel only if it exists and the config enables it.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and args.flash_attn

    def forward(self, x, position_embeddings, past_key_value=None, use_cache=False,
                attention_mask=None, position_ids=None, cu_seqlens=None):
        """Run causal multi-head attention over ``x``.

        Args:
            x: hidden states, unpacked as ``(bsz, seq_len, hidden)``.
            position_embeddings: ``(cos, sin)`` tables passed to ``apply_rotary_pos_emb``.
            past_key_value: optional ``(keys, values)`` returned by a previous call; they
                are concatenated in front of the new keys/values along dim 1 (sequence).
            use_cache: if True, return the concatenated ``(keys, values)``.
            attention_mask: optional mask with 1 = attend, 0 = masked.
                * 2-D ``(bsz, kv_len)``: applied to every query and combined with the
                  causal mask.
                * 3-D ``(bsz, q_rows, kv_len)``: applied per query row (``q_rows`` may be
                  1 to broadcast). The causal mask is NOT added in this case.
                  NOTE(review): a 3-D mask is therefore expected to encode causality
                  itself; confirm with callers. If it has more rows than ``seq_len``, the
                  last ``seq_len`` rows are used, matching queries placed at the end of
                  the key sequence.
            position_ids: optional positions passed to ``apply_rotary_pos_emb``.
            cu_seqlens: accepted for interface compatibility; not used here.

        Returns:
            ``(output, past_kv)``: ``output`` is ``(bsz, seq_len, hidden)``; ``past_kv``
            is ``None`` unless ``use_cache``.
        """
        bsz, seq_len, _ = x.shape
        # Run the projections in the dtype of the weights (the dtype the model was loaded in).
        weight_dtype = self.q_proj.weight.dtype
        if x.dtype != weight_dtype:
            x = x.to(weight_dtype)
        xq, xk, xv = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        # Force every projection output back to the weight dtype (guards against the
        # projection layers producing different dtypes, e.g. under autocast).
        xq = xq.to(weight_dtype)
        xk = xk.to(weight_dtype)
        xv = xv.to(weight_dtype)
        xq = xq.view(bsz, seq_len, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seq_len, self.n_local_kv_heads, self.head_dim)
        cos, sin = position_embeddings
        xq, xk = apply_rotary_pos_emb(xq, xk, cos, sin, position_ids=position_ids)

        # KV cache: cached keys/values come first along the sequence dim.
        if past_key_value is not None:
            xk = torch.cat([past_key_value[0], xk], dim=1)
            xv = torch.cat([past_key_value[1], xv], dim=1)
        past_kv = (xk, xv) if use_cache else None

        # -> (bsz, heads, seq, head_dim), with kv heads repeated to match the query heads.
        xq = xq.transpose(1, 2)
        xk = repeat_kv(xk, self.n_rep).transpose(1, 2)
        xv = repeat_kv(xv, self.n_rep).transpose(1, 2)

        is_3d_mask = attention_mask is not None and attention_mask.dim() == 3
        dropout_p = self.dropout if self.training else 0.0

        # SDPA's built-in causal mask is only used on a cache-free pass with more than one
        # query, and only when the mask (if any) is 2-D and masks nothing.
        use_flash = (
            self.flash
            and seq_len > 1
            and past_key_value is None
            and (attention_mask is None or (not is_3d_mask and bool(torch.all(attention_mask == 1))))
        )

        if use_flash:
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                dropout_p=dropout_p,
                is_causal=True
            )
        else:
            scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim)
            if not is_3d_mask:
                # Causal mask over the newest seq_len key columns; earlier (cached) keys
                # remain visible to every new query.
                scores[:, :, :, -seq_len:] += torch.triu(
                    torch.full((seq_len, seq_len), float("-inf"), device=scores.device), diagonal=1
                )
            if attention_mask is not None:
                if is_3d_mask:
                    if attention_mask.size(1) > seq_len:
                        # Rows beyond the current queries belong to earlier positions.
                        attention_mask = attention_mask[:, -seq_len:, :]
                    # (bsz, q_rows, kv_len) -> (bsz, 1, q_rows, kv_len): one row per query.
                    extended_attention_mask = attention_mask.unsqueeze(1)
                else:
                    # (bsz, kv_len) -> (bsz, 1, 1, kv_len): same row for every query.
                    extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
                extended_attention_mask = (1.0 - extended_attention_mask.float()) * -1e9
                scores = scores + extended_attention_mask
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = scores @ xv

        output = output.transpose(1, 2).reshape(bsz, seq_len, -1)
        output = self.resid_dropout(self.o_proj(output))
        return output, past_kv
# ==================== Main Model Classes ====================
class VerMindBlock(nn.Module):
    """Transformer decoder block with pre-normalisation.

    ``h = Attn(RMSNorm(x)) + x`` followed by ``h + MLP(RMSNorm(h))``. Returns the new
    hidden states and whatever key/value cache the attention layer produced.
    """

    def __init__(self, layer_id: int, config: VerMindConfig):
        super().__init__()
        n_heads = config.num_attention_heads
        width = config.hidden_size
        self.num_attention_heads = n_heads
        self.hidden_size = width
        self.head_dim = width // n_heads
        self.self_attn = Attention(config)
        self.layer_id = layer_id
        self.input_layernorm = RMSNorm(width, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(width, eps=config.rms_norm_eps)
        self.mlp = FeedForward(config)

    def forward(self, hidden_states, position_embeddings, past_key_value=None, use_cache=False,
                attention_mask=None, position_ids=None, cu_seqlens=None):
        normed_input = self.input_layernorm(hidden_states)
        attn_out, present_key_value = self.self_attn(
            normed_input,
            position_embeddings,
            past_key_value,
            use_cache,
            attention_mask,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens
        )
        # Residual added in place into the attention output (keeps its dtype).
        attn_out += hidden_states
        mlp_out = self.mlp(self.post_attention_layernorm(attn_out))
        return attn_out + mlp_out, present_key_value
class VerMindModel(nn.Module):
    """VerMind Model (Transformer backbone).

    Token embedding -> dropout -> ``num_hidden_layers`` ``VerMindBlock``s -> final RMSNorm.
    The RoPE cos/sin tables for ``max_position_embeddings`` positions are precomputed once
    and stored as non-persistent buffers.
    """

    def __init__(self, config: VerMindConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)
        self.layers = nn.ModuleList([VerMindBlock(l, config) for l in range(self.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        freqs_cos, freqs_sin = precompute_freqs_cis(
            dim=config.hidden_size // config.num_attention_heads,
            end=config.max_position_embeddings,
            rope_base=config.rope_theta,
            rope_scaling=config.rope_scaling
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, input_ids=None, attention_mask=None, past_key_values=None,
                use_cache=False, position_ids=None, cu_seqlens=None, **kwargs):
        """Run the backbone.

        Args:
            input_ids: token ids; the last axis is the sequence axis.
            attention_mask: forwarded unchanged to every block.
            past_key_values: per-layer ``(keys, values)`` tuples from a previous call, or
                ``None``. An object exposing a ``layers`` attribute (e.g. a transformers
                Cache instance) is discarded and treated as "no cache".
            use_cache: if True, every block returns its concatenated ``(keys, values)``.
            position_ids: explicit RoPE positions. When ``None``, the new tokens are placed
                right after the cached ones: positions ``start_pos .. start_pos+seq_len-1``,
                where ``start_pos`` is the cached sequence length.
            cu_seqlens: forwarded unchanged to every block.

        Returns:
            ``(hidden_states, presents, aux_loss)``: ``presents`` has one entry per layer
            (``None`` entries when ``use_cache`` is False); ``aux_loss`` is always 0 here.

        Raises:
            ValueError: if, without ``position_ids``, the cached plus new tokens exceed the
                precomputed RoPE table length (``max_position_embeddings``).
        """
        if past_key_values is not None and hasattr(past_key_values, 'layers'):
            past_key_values = None
        past_key_values = past_key_values or [None] * len(self.layers)
        # Cached keys are stored as (bsz, seq, kv_heads, head_dim); dim 1 is the cache length.
        start_pos = past_key_values[0][0].shape[1] if past_key_values[0] is not None else 0
        seq_len = input_ids.shape[-1]

        hidden_states = self.dropout(self.embed_tokens(input_ids))

        if position_ids is None:
            # apply_rotary_pos_emb takes the first seq_len rows of the tables when no
            # position_ids are given, so hand it the rows for the current positions.
            # Passing the full tables here would rotate every cached-decoding step as
            # if it were position 0.
            end_pos = start_pos + seq_len
            if end_pos > self.freqs_cos.shape[0]:
                raise ValueError(
                    f"Sequence position {end_pos} exceeds the precomputed RoPE table "
                    f"length ({self.freqs_cos.shape[0]}); increase max_position_embeddings."
                )
            position_embeddings = (self.freqs_cos[start_pos:end_pos], self.freqs_sin[start_pos:end_pos])
        else:
            # Explicit position ids index directly into the full tables.
            position_embeddings = (self.freqs_cos, self.freqs_sin)

        presents = []
        for layer, past_key_value in zip(self.layers, past_key_values):
            hidden_states, present = layer(
                hidden_states,
                position_embeddings,
                past_key_value=past_key_value,
                use_cache=use_cache,
                attention_mask=attention_mask,
                position_ids=position_ids,
                cu_seqlens=cu_seqlens
            )
            presents.append(present)
        hidden_states = self.norm(hidden_states)
        aux_loss = 0
        return hidden_states, presents, aux_loss
class VerMindForCausalLM(PreTrainedModel, GenerationMixin):
    """VerMind Causal Language Model.

    A ``VerMindModel`` backbone followed by a bias-free LM head whose weight is shared
    with the input embedding.
    """
    config_class = VerMindConfig

    def __init__(self, config: VerMindConfig = None):
        self.config = config or VerMindConfig()
        super().__init__(self.config)
        self.model = VerMindModel(self.config)
        self.lm_head = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=False)
        # Weight tying: the embedding and the LM head use the same Parameter.
        self.model.embed_tokens.weight = self.lm_head.weight

    def forward(self, input_ids=None, attention_mask=None, labels=None,
                past_key_values=None, use_cache=False, logits_to_keep=0,
                position_ids=None, cu_seqlens=None, **args):
        """Compute logits and, when ``labels`` is given, the next-token loss.

        Args:
            input_ids, attention_mask, past_key_values, use_cache, position_ids,
            cu_seqlens: forwarded to ``VerMindModel``.
            labels: target ids aligned with ``input_ids``; positions set to -100 are
                ignored. The loss compares the logits at position t with the label at t+1.
            logits_to_keep: int ``n`` keeps the last ``n`` positions (0 keeps all), or an
                explicit index/slice for the sequence axis. Ignored when ``cu_seqlens`` is
                given (logits are computed for every position).

        Returns:
            ``CausalLMOutputWithPast`` with ``loss``, ``logits``, ``past_key_values`` and
            the final normalised hidden states, plus an ``aux_loss`` attribute.
        """
        hidden_states, past_key_values, aux_loss = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_ids=position_ids,
            cu_seqlens=cu_seqlens,
            **args
        )
        is_varlen = cu_seqlens is not None
        if is_varlen:
            logits = self.lm_head(hidden_states)
        else:
            slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
            logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            # Shift along the sequence axis (second-to-last for logits, last for labels).
            # This covers (bsz, seq, vocab) logits as well as flat (tokens, vocab) logits.
            # The previous varlen branch sliced axis 0, which is the batch axis for the
            # 3-D logits the backbone produces.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100
            )
        output = CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=past_key_values, hidden_states=hidden_states)
        output.aux_loss = aux_loss
        return output
# Make AutoModelForCausalLM resolve VerMindConfig instances to VerMindForCausalLM.
AutoModelForCausalLM.register(
    VerMindForCausalLM.config_class,
    VerMindForCausalLM,
)

# Public API of this module.
__all__ = [
    "VerMindForCausalLM",
    "VerMindModel",
    "VerMindBlock",
    "Attention",
    "FeedForward",
    "RMSNorm",
]