from typing import Optional, Union
import torch
import torch.nn as nn
from tokenizers import Tokenizer, decoders, pre_tokenizers
from tokenizers.models import BPE
from transformers import (
GenerationMixin,
PreTrainedConfig,
PreTrainedModel,
TokenizersBackend,
)
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutput
class ZZJRabbit3Config(PreTrainedConfig):
model_type = "zzjrabbit3"
def __init__(
self,
vocab_size: int = 100000,
hidden_size: int = 1024,
num_hidden_layers: int = 12,
num_attention_heads: int = 8,
attention_dropout: float | int = 0.0,
pad_token_id: int | None = None,
eos_token_id: int | list[int] | None = None,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.attention_dropout = attention_dropout
        super().__init__(
            pad_token_id=pad_token_id, eos_token_id=eos_token_id, **kwargs
        )
class ZZJRabbit3RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000):
"""
Rotary Embedding 模块
Args:
dim: 每个 token embedding 的维度
max_position_embeddings: 最大位置数
base: rotary embedding 的频率基底
"""
super().__init__()
self.dim = dim
self.base = base
        # Build the inverse-frequency vector (one frequency per pair of dims)
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
        # Optional: precompute cos/sin for all positions up to max_position_embeddings
t = torch.arange(max_position_embeddings, dtype=torch.float32)
freqs = torch.einsum("i,j->ij", t, inv_freq)
self.register_buffer("cos_cached", freqs.cos())
self.register_buffer("sin_cached", freqs.sin())
def forward(self, position_ids):
"""
position_ids: (batch_size, seq_len)
        Returns:
cos: (batch_size, seq_len, dim)
sin: (batch_size, seq_len, dim)
"""
        # Look up the cached cos/sin values for the requested positions
cos = self.cos_cached[position_ids] # shape (batch, seq_len, dim/2)
sin = self.sin_cached[position_ids]
        # Expand the last dimension from dim/2 to dim.
        # With the rotate_half convention used in apply_rotary_pos_emb (which
        # pairs x_i with x_{i + dim/2}), the two halves must be concatenated,
        # not interleaved, so both elements of a pair share the same frequency.
        cos = torch.cat([cos, cos], dim=-1)
        sin = torch.cat([sin, sin], dim=-1)
return cos, sin
def rotate_half(x):
"""[-x2, x1]"""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, sin, cos):
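    """Apply rotary position embeddings to the query and key tensors.

    Expects q/k of shape (batch, num_heads, seq_len, head_dim) and cos/sin of
    shape (batch, seq_len, head_dim); note the argument order is (sin, cos),
    matching the call in ZZJRabbit3Attention.forward. cos/sin are broadcast
    over the head dimension via unsqueeze(1) below.
    """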
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class ZZJRabbit3Attention(nn.Module):
def __init__(self, config: ZZJRabbit3Config):
super().__init__()
self.config = config
self.head_dim = config.hidden_size // config.num_attention_heads
self.q_proj = nn.Linear(config.hidden_size, config.hidden_size)
self.k_proj = nn.Linear(config.hidden_size, config.hidden_size)
self.v_proj = nn.Linear(config.hidden_size, config.hidden_size)
self.out_proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.attention_dropout)
def forward(
self,
x: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
key_padding_mask: Optional[torch.BoolTensor] = None,
attn_mask: Optional[torch.BoolTensor] = None,
) -> torch.Tensor:
batch_size = x.size(0)
Q = (
self.q_proj(x)
.view(batch_size, -1, self.config.num_attention_heads, self.head_dim)
.transpose(1, 2)
)
K = (
self.k_proj(x)
.view(batch_size, -1, self.config.num_attention_heads, self.head_dim)
.transpose(1, 2)
)
V = (
self.v_proj(x)
.view(batch_size, -1, self.config.num_attention_heads, self.head_dim)
.transpose(1, 2)
)
cos, sin = position_embeddings
Q, K = apply_rotary_pos_emb(Q, K, sin.to(Q.dtype), cos.to(Q.dtype))
scores = torch.matmul(Q, K.transpose(-2, -1)) * (self.head_dim**-0.5)
if key_padding_mask is not None:
scores = scores.masked_fill(
key_padding_mask.view(batch_size, 1, 1, -1), float("-inf")
)
if attn_mask is not None:
scores = scores.masked_fill(attn_mask, float("-inf"))
attn_weights = nn.functional.softmax(scores, dim=-1)
attn_weights = self.dropout(attn_weights)
context = torch.matmul(attn_weights, V)
context = context.transpose(1, 2).contiguous()
context = context.view(batch_size, -1, self.config.hidden_size)
return self.out_proj(context)
class ZZJRabbit3Layer(nn.Module):
def __init__(self, config: ZZJRabbit3Config):
super().__init__()
self.attn = ZZJRabbit3Attention(config)
self.l1 = nn.Linear(config.hidden_size, config.hidden_size)
self.l2 = nn.Linear(config.hidden_size, config.hidden_size)
self.activate = nn.ReLU()
self.norm = nn.RMSNorm(config.hidden_size)
def forward(
self,
x: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
key_padding_mask = None
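        # Boolean upper-triangular mask: True entries (strictly above the
        # diagonal) block attention to future positions, i.e. causal masking.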
attn_mask = torch.gt(
torch.triu(torch.ones(x.size(-2), x.size(-2), device=x.device), 1), 0
)
if attention_mask is not None:
key_padding_mask = torch.lt(attention_mask, 1)
attn = self.attn(
x,
            position_embeddings,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
)
x = self.norm(x + attn)
o = self.l1(x)
o = self.activate(o)
o = self.l2(o)
return self.norm(x + o)
class ZZJRabbit3Model(PreTrainedModel):
config_class = ZZJRabbit3Config
def __init__(self, config: ZZJRabbit3Config, **kwargs):
super().__init__(config, **kwargs)
self.config = config
self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
self.rotary_emb = ZZJRabbit3RotaryEmbedding(
config.hidden_size // config.num_attention_heads
)
self.layers = nn.ModuleList(
[ZZJRabbit3Layer(config) for _ in range(config.num_hidden_layers)]
)
self.post_init()
def forward(
self,
input_ids: torch.Tensor,
return_dict: Optional[bool] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> tuple | BaseModelOutput:
res = self.embedding(input_ids)
batch_size, seq_len = input_ids.shape
position_ids = (
torch.arange(seq_len, device=input_ids.device)
.unsqueeze(0)
.expand(batch_size, -1)
)
position_embeddings = self.rotary_emb(position_ids)
for layer in self.layers:
res = layer(res, position_embeddings, attention_mask)
if not return_dict:
return (res,)
else:
return BaseModelOutput(res)
class ZZJRabbit3ForCausalLM(PreTrainedModel, GenerationMixin):
config_class = ZZJRabbit3Config
def __init__(self, config: ZZJRabbit3Config, **kwargs):
super().__init__(config, **kwargs)
self.model = ZZJRabbit3Model(config, **kwargs)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
self.post_init()
def forward(
self,
input_ids: torch.Tensor,
return_dict: Optional[bool] = None,
labels: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> tuple | CausalLMOutput:
hidden = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
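        # logits_to_keep == 0 keeps logits for every position; a positive int
        # keeps only the last `logits_to_keep` positions (as during generation).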
logits = self.lm_head(
hidden[
:,
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep,
:,
]
)
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
return (loss, logits) if labels is not None else (logits,)
else:
return (
CausalLMOutput(logits=logits, loss=loss)
if labels is not None
else CausalLMOutput(logits=logits)
)
@classmethod
def can_generate(cls):
return True
def prepare_inputs_for_generation(self, input_ids, **kwargs):
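        # No KV cache is implemented, so the full sequence is re-encoded at
        # each generation step and only input_ids need to be forwarded.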
return {"input_ids": input_ids}
class ZZJRabbit3Tokenizer(TokenizersBackend):
model = BPE
def __init__(
self,
vocab=None,
merges=None,
unk_token="<eos>",
eos_token="<eos>",
pad_token="<eos>",
**kwargs,
):
self._vocab = vocab or {
"<eos>": 0,
}
self._merges = merges or []
        self._tokenizer = Tokenizer(
            BPE(
                vocab=self._vocab,
                merges=self._merges,
                unk_token=unk_token,
                fuse_unk=True,
            )
        )
self._tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
self._tokenizer.decoder = decoders.ByteLevel()
super().__init__(
unk_token=unk_token, eos_token=eos_token, pad_token=pad_token, **kwargs
)
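

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the model definition). The tiny
# hyperparameters and random token ids below are illustrative assumptions,
# not the released configuration; it only checks that the forward pass and
# the GenerationMixin wiring run end to end.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = ZZJRabbit3Config(
        vocab_size=32,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        eos_token_id=0,
    )
    model = ZZJRabbit3ForCausalLM(config)
    model.eval()

    # Random token ids stand in for tokenizer output.
    input_ids = torch.randint(1, config.vocab_size, (1, 8))
    with torch.no_grad():
        logits = model(input_ids=input_ids)[0]
    print("logits:", tuple(logits.shape))  # (1, 8, vocab_size)

    # Greedy decoding through GenerationMixin.
    generated = model.generate(input_ids, max_new_tokens=4, do_sample=False)
    print("generated:", tuple(generated.shape))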