|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
from typing import Optional, List, Tuple, Union
|
|
|
import math
|
|
|
import torch.utils.checkpoint
|
|
|
from transformers import PreTrainedModel, PretrainedConfig
|
|
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
|
|
|
|
class TernaryConfig(PretrainedConfig):
    """Hyper-parameters for :class:`TernaryTransformer`.

    Args:
        vocab_size: number of token ids (rows of the token embedding).
        hidden_size: model width; must be divisible by ``num_attention_heads``.
        num_hidden_layers: number of transformer blocks; BitLinear layers also
            use it to scale their init std (``0.02 / sqrt(2 * num_hidden_layers)``).
        num_attention_heads: attention heads per block. The per-head width
            ``hidden_size // num_attention_heads`` must be even because rotary
            embeddings rotate pairs of channels.
        intermediate_size: inner width of the SwiGLU feed-forward network.
        max_position_embeddings: number of rows in the precomputed rotary
            table, i.e. how many sequence positions the model can encode.
        rms_norm_eps: epsilon added inside every RMSNorm.
        dropout_rate: dropout probability applied to each residual branch.
        window_size: sliding attention window; the attention layers keep at
            most this many key/value positions in their cache. Must be >= 1.
        **kwargs: forwarded unchanged to ``PretrainedConfig``.

    Raises:
        ValueError: if the head geometry or window size is invalid, so the
            problem surfaces here instead of as a shape error mid-forward.
    """

    model_type = "ternary_transformer"

    def __init__(
        self,
        vocab_size=50257,
        hidden_size=3072,
        num_hidden_layers=24,
        num_attention_heads=32,
        intermediate_size=12288,
        max_position_embeddings=2048,
        rms_norm_eps=1e-6,
        dropout_rate=0.1,
        window_size=512,
        **kwargs
    ):
        if num_attention_heads <= 0 or hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by a positive "
                f"num_attention_heads ({num_attention_heads})."
            )
        head_dim = hidden_size // num_attention_heads
        if head_dim % 2 != 0:
            # Rotary embeddings view each head as head_dim // 2 complex numbers.
            raise ValueError(
                f"hidden_size // num_attention_heads must be even for rotary "
                f"embeddings, got {head_dim}."
            )
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}.")

        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.dropout_rate = dropout_rate
        self.window_size = window_size
|
|
|
|
|
|
class BitLinear(nn.Linear):
    """``nn.Linear`` whose forward pass uses ternary weights and clipped inputs.

    Weights are scaled by their mean absolute value, rounded, and clamped to
    {-1, 0, +1}, then multiplied back by that scale. Inputs are mean-centred
    along the last dimension and clipped to [-1.5, 1.5]. Rounding and clipping
    use a straight-through estimator: the forward value is the quantized one,
    while the gradient flows as if the quantization were the identity.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: add a learnable bias (default False).
        num_layers: model depth; the weight init std is
            ``0.02 / sqrt(2 * num_layers)``.
    """

    def __init__(self, in_features, out_features, bias=False, num_layers=24):
        super().__init__(in_features, out_features, bias)
        init_std = 0.02 / math.sqrt(2 * num_layers)
        nn.init.normal_(self.weight, mean=0.0, std=init_std)

    @staticmethod
    def _straight_through(full, quantized):
        # Forward value equals `quantized`; the detached difference makes the
        # gradient w.r.t. `full` the identity.
        return full + (quantized - full).detach()

    def forward(self, x):
        weight = self.weight
        scale = weight.abs().mean() + 1e-9
        ternary = torch.round(weight / scale).clamp(-1, 1)
        effective_weight = self._straight_through(weight, ternary * scale)

        centred = x - x.mean(dim=-1, keepdim=True)
        clipped = self._straight_through(centred, centred.clamp(-1.5, 1.5))

        return F.linear(clipped, effective_weight, self.bias)
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-channel gain.

    ``y = x / sqrt(mean(x**2, dim=-1) + eps) * weight``

    Args:
        dim: size of the last (normalized) dimension of the input.
        eps: constant added under the square root for numerical stability.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Reduce in float32. In float16, x**2 overflows to inf once |x| exceeds
        # ~256, which drives rsqrt to 0 and zeroes the whole row. bfloat16
        # loses precision in the mean.
        x_fp32 = x.float()
        normed = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + self.eps)
        # Cast back to the input dtype before applying the gain, as the
        # original low-precision path produced a tensor of x's dtype here.
        return normed.type_as(x) * self.weight
|
|
|
|
|
|
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """Build the complex rotary-embedding table.

    Row ``p`` holds ``exp(i * p * theta**(-2k/dim))`` for ``k = 0, 1, ...``
    over the even channel indices ``0, 2, ..., < dim``.

    Args:
        dim: per-head channel count the table will be applied to.
        seq_len: number of positions (rows) to precompute.
        theta: base of the geometric frequency progression.

    Returns:
        A complex64 tensor of shape ``(seq_len, len(range(0, dim, 2)))`` with
        unit-magnitude entries.
    """
    exponents = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(seq_len).float()
    angles = positions[:, None] * inv_freq[None, :]
    return torch.polar(torch.ones_like(angles), angles)
|
|
|
|
|
|
def apply_rotary_emb(xq, xk, freqs_cis):
    """Rotate query and key channel pairs by the rotary table.

    Consecutive channel pairs ``(2k, 2k+1)`` of the last dimension are treated
    as one complex number and multiplied by ``freqs_cis``. Dimension 2 of
    ``xq`` is the sequence axis used to slice the table, and the same rows
    rotate ``xk``.

    Args:
        xq: queries; the slicing uses ``xq.shape[2]`` as the sequence length.
        xk: keys, rotated with the same table rows as ``xq``.
        freqs_cis: complex table whose first ``xq.shape[2]`` rows are used.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the original shapes and dtypes.
    """
    seq_len = xq.shape[2]
    # Add two leading broadcast axes so the (seq, pairs) table lines up with
    # the last two complex axes of the reshaped inputs.
    rotation = freqs_cis[:seq_len].unsqueeze(0).unsqueeze(0)

    def _rotate(t):
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        rotated = torch.view_as_complex(pairs) * rotation
        return torch.view_as_real(rotated).flatten(3).type_as(t)

    return _rotate(xq), _rotate(xk)
|
|
|
|
|
|
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings and a sliding window.

    A query may attend to a key only if the key is not ahead of it and lies
    fewer than ``window_size`` positions behind it. The rule is enforced by an
    explicit mask in every call, with or without a KV cache. Incremental
    decoding therefore uses the same attention pattern as a single full pass.
    All four projections are ternary ``BitLinear`` layers.
    """

    def __init__(self, config: TernaryConfig):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.q_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.k_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.v_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.out_proj = BitLinear(config.hidden_size, config.hidden_size, num_layers=config.num_hidden_layers)
        self.scale = self.head_dim ** -0.5
        self.window_size = config.window_size

    def forward(self, x, freqs_cis, pos_offset, past_kv=None):
        """Attend over the current tokens plus any cached keys/values.

        Args:
            x: input of shape ``(B, T, D)``.
            freqs_cis: rotary table; rows ``pos_offset : pos_offset + T`` are
                applied to the current queries and keys.
            pos_offset: absolute position of the first token in ``x``.
            past_kv: optional ``(keys, values)`` from a previous call, shaped
                ``(B, n_heads, P, head_dim)`` and covering the P tokens that
                immediately precede ``x``.

        Returns:
            ``(output, new_kv)``. ``output`` has shape ``(B, T, D)``.
            ``new_kv`` holds detached keys/values to pass back as
            ``past_kv``. When a cache was given, it is trimmed to the last
            ``window_size`` positions.
        """
        B, T, D = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        q, k = apply_rotary_emb(q, k, freqs_cis[pos_offset : pos_offset + T])

        if past_kv is not None:
            pk, pv = past_kv
            k = torch.cat([pk, k], dim=2)
            v = torch.cat([pv, v], dim=2)
            # Only the last window_size positions can be attended by future
            # queries, so only those are cached. Attention below still sees
            # the untrimmed keys, so every query in a multi-token chunk keeps
            # its full window.
            new_kv = (k[:, :, -self.window_size:].detach(), v[:, :, -self.window_size:].detach())
        else:
            # Without a cache, return the full keys/values as before. The
            # caller derives the next position offset from this length.
            new_kv = (k.detach(), v.detach())

        total = k.size(2)
        past_len = total - T
        # Query i sits at cache-relative index past_len + i; key j at index j.
        q_idx = torch.arange(past_len, total, device=x.device).unsqueeze(1)
        k_idx = torch.arange(total, device=x.device).unsqueeze(0)
        distance = q_idx - k_idx
        # distance 0 (the token itself) is always visible, so no row is fully
        # masked and softmax cannot produce NaNs.
        visible = (distance >= 0) & (distance < self.window_size)
        mask = torch.zeros(T, total, device=x.device).masked_fill(~visible, float('-inf'))

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = F.softmax((scores + mask).float(), dim=-1).type_as(x)
        out = torch.matmul(attn, v).transpose(1, 2).reshape(B, T, D)
        return self.out_proj(out), new_kv
|
|
|
|
|
|
class SwiGLUFeedForward(nn.Module):
    """Gated feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

    ``w1`` (gate) and ``w3`` (up) project ``hidden_size -> intermediate_size``.
    ``w2`` projects back down. All three are ternary ``BitLinear`` layers.
    """

    def __init__(self, config: TernaryConfig):
        super().__init__()
        width = config.hidden_size
        inner = config.intermediate_size
        depth = config.num_hidden_layers
        # Creation order (w1, w3, w2) is kept so parameter registration and
        # random-init draws match.
        self.w1 = BitLinear(width, inner, num_layers=depth)
        self.w3 = BitLinear(width, inner, num_layers=depth)
        self.w2 = BitLinear(inner, width, num_layers=depth)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm residual block: attention sub-layer followed by SwiGLU sub-layer.

    Each sub-layer reads an RMS-normalized copy of the stream. Its output has
    dropout applied before being added back to the residual stream.
    """

    def __init__(self, config: TernaryConfig):
        super().__init__()
        self.attn = MultiHeadAttention(config)
        self.ffn = SwiGLUFeedForward(config)
        self.norm1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x, freqs_cis, pos_offset, past_kv=None):
        """Return ``(updated_stream, new_kv)``, where ``new_kv`` is the attention cache."""
        attn_out, new_kv = self.attn(self.norm1(x), freqs_cis, pos_offset, past_kv)
        stream = x + self.dropout(attn_out)
        ffn_out = self.ffn(self.norm2(stream))
        return stream + self.dropout(ffn_out), new_kv
|
|
|
|
|
|
class TernaryTransformer(PreTrainedModel):
    """Decoder-only causal language model built from ternary-weight blocks.

    The output projection ``lm_head`` shares its weight with ``token_emb``.
    ``forward`` returns a ``CausalLMOutputWithPast``, or its tuple form when
    ``return_dict`` is false. Its ``past_key_values`` is a list with one
    ``(key, value)`` pair per block.
    """

    config_class = TernaryConfig
    supports_gradient_checkpointing = True
    # lm_head.weight is the very same Parameter as token_emb.weight. Without
    # declaring it as tied, save_pretrained with safetensors refuses to
    # serialize the shared tensor.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: TernaryConfig):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Complex rotary table, one row per position up to
        # max_position_embeddings. It is rebuilt on construction, not saved.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(config.hidden_size // config.num_attention_heads, config.max_position_embeddings),
            persistent=False,
        )
        self.post_init()
        # Tie unconditionally, independent of config.tie_word_embeddings, as
        # this model always has.
        self.lm_head.weight = self.token_emb.weight
        self.gradient_checkpointing = False

    def get_input_embeddings(self):
        return self.token_emb

    def set_input_embeddings(self, value):
        self.token_emb = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _set_gradient_checkpointing(self, module, value=False):
        # Legacy (module, value) hook: the flag lives on the top-level model,
        # whichever block or model instance triggered the call.
        if isinstance(module, (TernaryTransformer, TransformerBlock)):
            self.gradient_checkpointing = value

    def _position_offset(self, seq_len, past_key_values, position_ids):
        """Return the absolute position of the first token of the current input.

        Raises:
            ValueError: if the positions fall outside the rotary table.
        """
        if position_ids is not None:
            # One offset is shared by the whole batch. The largest *last*
            # position is taken because left-padded rows (whose position_ids
            # are shifted down) would otherwise lower it.
            offset = int(position_ids[..., -1].max()) - seq_len + 1
        elif past_key_values and past_key_values[0] is not None:
            # NOTE: the attention layers trim their cache to window_size. Once
            # that happens, this length under-counts the true position; pass
            # position_ids for correct rotary positions on long sequences.
            offset = past_key_values[0][0].size(2)
        else:
            offset = 0
        table_len = self.freqs_cis.size(0)
        if offset < 0 or offset + seq_len > table_len:
            raise ValueError(
                f"Positions {offset}..{offset + seq_len - 1} are outside the rotary table "
                f"of length {table_len} (max_position_embeddings)."
            )
        return offset

    def forward(self, input_ids, labels=None, past_key_values=None, return_dict=True, position_ids=None, **kwargs):
        """Compute logits, and optionally the next-token loss.

        Args:
            input_ids: token ids of shape ``(batch, seq_len)``.
            labels: optional targets aligned with ``input_ids``. The logits at
                position t are scored against label t+1, and
                ``F.cross_entropy``'s default ``ignore_index`` (-100) applies.
            past_key_values: per-block ``(key, value)`` caches from a previous call.
            return_dict: return a ``CausalLMOutputWithPast`` (default) or, when
                false, its tuple form. ``None`` defers to ``config.use_return_dict``.
            position_ids: optional absolute positions of ``input_ids``. When
                given, they determine the rotary offset instead of the cache length.
            **kwargs: ``use_cache`` is honoured when set; other keys (e.g.
                ``attention_mask``) are accepted and ignored.
        """
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        seq_len = input_ids.size(1)
        pos_offset = self._position_offset(seq_len, past_key_values, position_ids)

        use_cache = kwargs.get("use_cache")
        if use_cache is None:
            # Previous default: return caches in eval mode or when continuing one.
            use_cache = (not self.training) or bool(past_key_values)

        x = self.token_emb(input_ids)
        new_kvs = []
        for i, block in enumerate(self.blocks):
            layer_past = past_key_values[i] if past_key_values else None
            if self.gradient_checkpointing and self.training:
                x, kv = torch.utils.checkpoint.checkpoint(
                    block, x, self.freqs_cis, pos_offset, layer_past, use_reentrant=False
                )
            else:
                x, kv = block(x, self.freqs_cis, pos_offset, layer_past)
            if use_cache:
                new_kvs.append(kv)

        logits = self.lm_head(self.ln_f(x))

        loss = None
        if labels is not None:
            # Upcast so the softmax/log in the loss is not evaluated in half precision.
            shift_logits = logits[:, :-1, :].reshape(-1, self.config.vocab_size).float()
            shift_labels = labels[:, 1:].reshape(-1)
            loss = F.cross_entropy(shift_logits, shift_labels)

        output = CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=new_kvs if new_kvs else None)
        return output if return_dict else output.to_tuple()