Upload folder using huggingface_hub
Browse files- config.json +26 -0
- configuration_cloverlm.py +31 -0
- exp_mlp.py +168 -0
- exp_transformer.py +696 -0
- fake_quartet.py +348 -0
- model.safetensors +3 -0
- modeling_cloverlm.py +237 -0
- tokenization_cloverlm.py +68 -0
- tokenizer_config.json +10 -0
config.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"CloverLMForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attn_backend": "flash2",
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_cloverlm.CloverLMConfig",
|
| 8 |
+
"AutoModelForCausalLM": "modeling_cloverlm.CloverLMForCausalLM",
|
| 9 |
+
"AutoTokenizer": [
|
| 10 |
+
"tokenization_cloverlm.CloverLMTokenizer",
|
| 11 |
+
null
|
| 12 |
+
]
|
| 13 |
+
},
|
| 14 |
+
"d_head": 128,
|
| 15 |
+
"heads": 28,
|
| 16 |
+
"max_context": 1024,
|
| 17 |
+
"model_type": "cloverlm",
|
| 18 |
+
"num_blocks": 29,
|
| 19 |
+
"num_hidden_layers": 29,
|
| 20 |
+
"quartet_2_impl": "pseudoquant",
|
| 21 |
+
"ratio": 4,
|
| 22 |
+
"scale_type": "1/sqrt(d)",
|
| 23 |
+
"transformers_version": "5.3.0",
|
| 24 |
+
"vocab_size": 32000,
|
| 25 |
+
"weight_tying": true
|
| 26 |
+
}
|
configuration_cloverlm.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class CloverLMConfig(PretrainedConfig):
    """HuggingFace configuration for CloverLM causal language models.

    Holds the architecture hyperparameters (block count, head geometry,
    GQA ratio, attention scale type/backend, context length, weight tying)
    and delegates vocabulary handling to ``PretrainedConfig``.
    """

    model_type = "cloverlm"

    def __init__(
        self,
        vocab_size=32000,
        num_blocks=4,
        heads=6,
        d_head=128,
        ratio=3,
        scale_type="1/sqrt(d)",
        max_context=1024,
        quartet_2_impl="pseudoquant",
        weight_tying=True,
        attn_backend="pytorch",
        **kwargs,
    ):
        hyperparams = dict(
            num_blocks=num_blocks,
            # Mirror num_blocks under the canonical HF attribute name.
            num_hidden_layers=num_blocks,
            heads=heads,
            d_head=d_head,
            ratio=ratio,
            scale_type=scale_type,
            max_context=max_context,
            quartet_2_impl=quartet_2_impl,
            weight_tying=weight_tying,
            attn_backend=attn_backend,
        )
        for attr, value in hyperparams.items():
            setattr(self, attr, value)
        super().__init__(vocab_size=vocab_size, **kwargs)
|
exp_mlp.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
# Normalizes on the hypersphere along dim
# (s1*...*)s-1
def sphere_norm(X, dim=-1):
    """L2-normalize X along `dim`, projecting vectors onto the unit sphere."""
    unit = torch.nn.functional.normalize(X, dim=dim)
    return unit
|
| 7 |
+
|
| 8 |
+
class SphereNorm(torch.nn.Module):
    """Module that L2-normalizes its input along a fixed dimension."""

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, X):
        # Inlined equivalent of sphere_norm(X, dim=self.dim).
        return torch.nn.functional.normalize(X, dim=self.dim)
|
| 18 |
+
|
| 19 |
+
def get_norm(enable, norm_type, d, bias):
|
| 20 |
+
if enable:
|
| 21 |
+
if norm_type=="layer":
|
| 22 |
+
norm = torch.nn.LayerNorm(d, bias=bias)
|
| 23 |
+
elif norm_type=="rms_learned":
|
| 24 |
+
norm = torch.nn.RMSNorm(d, elementwise_affine=True)
|
| 25 |
+
elif norm_type=="rms_const":
|
| 26 |
+
norm = torch.nn.RMSNorm(d, elementwise_affine=False)
|
| 27 |
+
elif norm_type=="sphere":
|
| 28 |
+
norm = SphereNorm(dim=-1)
|
| 29 |
+
else:
|
| 30 |
+
norm = None
|
| 31 |
+
|
| 32 |
+
return norm
|
| 33 |
+
|
| 34 |
+
class ReLU2(torch.nn.Module):
    """Squared-ReLU activation: max(x, 0) ** 2."""

    def forward(self, x):
        clipped = x.clamp_min(0)
        return clipped * clipped
|
| 39 |
+
|
| 40 |
+
class Abs(torch.nn.Module):
    """Elementwise absolute-value activation."""

    def forward(self, x):
        return torch.abs(x)
|
| 45 |
+
|
| 46 |
+
class GLU(torch.nn.Module):
    """Gated linear unit: y = act(gate(x)) * proj(x), both mapping d0 -> d1.

    The `quartet` path depended on the external `quartet2` package, which is
    unavailable in HF-export mode; the exporter left a dangling reference that
    crashed with NameError. Selecting it now raises a clear error instead.
    """

    def __init__(self, d0, d1, bias=True, act=torch.nn.ReLU(), quartet=True, fake_quartet=False):
        super().__init__()

        self.d0 = d0
        self.d1 = d1
        self.bias = bias
        self.act = act
        self.quartet = quartet
        self.fake_quartet = fake_quartet

        if quartet:
            # The original build used quartet2.linear.Quartet_II_linear here.
            raise NotImplementedError("quartet2 is not available in HF mode; pass quartet=False")
        elif fake_quartet:
            from . import fake_quartet as fq
            self.gate = torch.nn.Sequential(fq.FakeQuartetLinear(d0, d1, bias), act)
            self.proj = fq.FakeQuartetLinear(d0, d1, bias)
        else:
            self.gate = torch.nn.Sequential(torch.nn.Linear(d0, d1, bias), act)
            self.proj = torch.nn.Linear(d0, d1, bias)

    def forward(self, x):
        # Elementwise gating of the projection branch.
        y = self.gate(x) * self.proj(x)

        return y
|
| 76 |
+
|
| 77 |
+
class MLP2L(torch.nn.Module):
    """Two-layer MLP: l1 (linear-or-GLU + act) -> optional norm -> dropout -> l2.

    The `quartet` path depended on the external `quartet2` package, which is
    unavailable in HF-export mode; the exporter left dangling references that
    crashed with NameError. Selecting it now raises a clear error. An unknown
    `l1_type` also raises instead of leaving ``self.l1`` undefined.
    """

    def __init__(self, d0, d1, d2, bias=True, act=torch.nn.ReLU(), dropout=0, l1_type="linear", norm_type="rms_learned", norm=False, quartet=True, fake_quartet=False):
        super().__init__()

        self.d0 = d0
        self.d1 = d1
        self.d2 = d2
        self.bias = bias
        self.act = act
        self.dropout = dropout
        self.l1_type = l1_type
        self.norm_type = norm_type

        if quartet:
            # The original build used quartet2.linear.Quartet_II_linear for l1 and l2.
            raise NotImplementedError("quartet2 is not available in HF mode; pass quartet=False")

        if l1_type == "linear":
            if fake_quartet:
                from . import fake_quartet as fq
                self.l1 = torch.nn.Sequential(fq.FakeQuartetLinear(d0, d1, bias), act)
            else:
                self.l1 = torch.nn.Sequential(torch.nn.Linear(d0, d1, bias), act)
        elif l1_type == "glu":
            self.l1 = GLU(d0, d1, bias, act, quartet, fake_quartet)
        else:
            raise ValueError(f"unknown l1_type: {l1_type!r}")

        # Optional normalization between the two layers (None when disabled).
        self.norm = get_norm(norm, norm_type, d1, bias)

        if fake_quartet:
            from . import fake_quartet as fq
            self.l2 = fq.FakeQuartetLinear(d1, d2, bias)
        else:
            self.l2 = torch.nn.Linear(d1, d2, bias)

    def forward(self, x):
        a1 = self.l1(x)
        if self.norm: a1 = self.norm(a1)
        a1 = torch.nn.functional.dropout(a1, p=self.dropout, training=self.training)

        y = self.l2(a1)

        return y
|
| 121 |
+
|
| 122 |
+
class MLP3L(torch.nn.Module):
    """Three-layer MLP; activation + dropout follow each hidden layer."""

    def __init__(self, d0, d1, d2, d3, bias=True, act=torch.nn.ReLU(), dropout=0):
        super().__init__()

        self.d0 = d0
        self.d1 = d1
        self.d2 = d2
        self.d3 = d3
        self.bias = bias
        self.act = act
        self.dropout = dropout

        self.l1 = torch.nn.Linear(d0, d1, bias)
        self.l2 = torch.nn.Linear(d1, d2, bias)
        self.l3 = torch.nn.Linear(d2, d3, bias)

    def forward(self, x):
        hidden = x
        # Two hidden layers, each followed by activation and dropout.
        for layer in (self.l1, self.l2):
            hidden = self.act(layer(hidden))
            hidden = torch.nn.functional.dropout(hidden, p=self.dropout, training=self.training)

        # Final projection has no activation.
        return self.l3(hidden)
|
| 150 |
+
|
| 151 |
+
class MLP3L_image(torch.nn.Module):
    """Image classifier: flattens the (C, H, W) trailing dims, then MLP3L."""

    def __init__(self, res=28, d1=16, d2=16, dropout=0, classes=10):
        super().__init__()

        self.res = res
        self.d1 = d1
        self.d2 = d2
        self.dropout = dropout
        self.classes = classes

        self.mlp = MLP3L(res*res, d1, d2, classes, dropout=dropout)

    def forward(self, x):
        # Collapse channel and spatial dims into one feature vector.
        flat = x.flatten(start_dim=-3, end_dim=-1)
        return self.mlp(flat)
|
exp_transformer.py
ADDED
|
@@ -0,0 +1,696 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from . import exp_mlp as mlp
|
| 3 |
+
from math import sqrt
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
# Supported option strings for the transformer hyperparameters defined below.
SCALE_TYPES = ["1/sqrt(d)", "1/d"]
POS_TYPES = ["learned", "sinusoidal", "rope", "alibi"]
BACKENDS = ["pytorch", "flash2", "flash3", "flash4", "flex", "cudnn"]
NORM_TYPES = ["layer", "rms_learned", "rms_const", "sphere"]
|
| 10 |
+
|
| 11 |
+
def get_causal(context):
    """Return a (context, context) bool mask, True on and below the diagonal."""
    keep = torch.ones((context, context), dtype=torch.bool)
    return torch.tril(keep)
|
| 17 |
+
|
| 18 |
+
def get_sinusoidal(context, d, base=1024):
    """Sinusoidal position table of shape (context, d).

    Even columns hold sin(pos * ω_j) and odd columns cos(pos * ω_j),
    with ω_j = 1 / base**(2j/d) for j in [0, d//2).
    """
    # (context, 1) positions and (d//2,) angular frequencies.
    positions = torch.arange(0., context).unsqueeze(1)
    freqs = 1 / base ** (2 * torch.arange(0., d // 2) / d)

    # Outer product: phases[pos, j] = pos * ω_j, shape (context, d//2).
    phases = positions * freqs.unsqueeze(0)

    table = torch.empty((context, d))
    table[:, 0::2] = torch.sin(phases)
    table[:, 1::2] = torch.cos(phases)

    return table
|
| 36 |
+
|
| 37 |
+
def get_rope(context, d, *, device, base=1024):
    """Precompute RoPE tables as a (2, context, d) stack of [cos, sin].

    θ_j = 1 / base**(2j/d); each of the d//2 frequency columns is duplicated
    so the tables line up with the even/odd channel pairing in apply_rope.
    """
    steps = torch.arange(0., context, device=device, dtype=torch.float32)
    idx = torch.arange(0., d // 2, device=device, dtype=torch.float32)
    thetas = 1 / base ** (2 * idx / d)

    # Outer product: phases[m, j] = m * θ_j, shape (context, d//2).
    phases = steps.unsqueeze(1) * thetas.unsqueeze(0)

    # Duplicate each frequency column -> (context, d).
    cos = torch.cos(phases).repeat_interleave(repeats=2, dim=1)
    sin = torch.sin(phases).repeat_interleave(repeats=2, dim=1)

    # (2, context, d)
    return torch.stack((cos, sin))
|
| 60 |
+
|
| 61 |
+
# (batches*)context*d
def apply_rope(X, rope):
    """Rotate X with the precomputed (cos, sin) tables from get_rope."""
    cos, sin = rope[0], rope[1]

    # Pairwise 90° rotation: (x_even, x_odd) -> (-x_odd, x_even).
    rotated = torch.empty_like(X)
    rotated[..., 0::2] = -X[..., 1::2]
    rotated[..., 1::2] = X[..., 0::2]

    out = X * cos + rotated * sin

    return out.to(X.dtype)
|
| 74 |
+
|
| 75 |
+
def get_m(heads, base=2, exp=8):
    """Per-head ALiBi slopes: base**(-exp*h/heads) for h = 1..heads."""
    exponents = (-exp / heads) * torch.arange(1, heads + 1)
    return base ** exponents
|
| 79 |
+
|
| 80 |
+
def get_alibi(heads, context):
    """Dense ALiBi bias of shape (heads, context, context): -|i - j| * slope_h."""
    # 1*context*1 query positions and 1*1*context key positions.
    rows = torch.arange(0, context)[None, :, None]
    cols = rows.mT
    # heads*1*1 per-head slopes.
    slopes = get_m(heads)[:, None, None]

    return -torch.abs(rows - cols) * slopes
|
| 91 |
+
|
| 92 |
+
def get_swa(context, window):
    """Boolean sliding-window mask: True where |i - j| <= window."""
    rows = torch.arange(0, context).unsqueeze(-1)   # context*1
    cols = rows.T                                   # 1*context

    return (rows - cols).abs() <= window
|
| 101 |
+
|
| 102 |
+
# (batches*)heads/groups*context*d_head
def sdpa_pytorch(Q, K, V, causal=None, alibi=None, swa=None, scale=None, return_A=False):
    """Reference scaled-dot-product attention in plain PyTorch.

    Q is (batches*)heads*context*d_head; K/V may carry fewer heads (groups)
    for GQA and are expanded to match. `causal`/`swa` are boolean keep-masks,
    `alibi` an additive bias, `scale` a float or broadcastable tensor.
    With return_A=True, also returns the raw (A__), scaled/masked (A_) and
    softmaxed (A) attention matrices.
    """
    if scale is None:
        d_head = Q.shape[-1]
        scale = 1/sqrt(d_head)

    # (batches*)heads*context*d_head
    heads = Q.shape[-3]
    groups = K.shape[-3]
    ratio = heads//groups
    # PyTorch only broadcasts when the operation is not defined otherwise. MM does not involve the batch dimensions, and hence PyTorch does not broadcast them.
    K = K.repeat_interleave(repeats=ratio, dim=-3)
    V = V.repeat_interleave(repeats=ratio, dim=-3)

    # (batches*)heads*context*context
    A__ = Q @ K.mT

    # batches*heads*context*context
    A_ = scale*A__
    # (batches*)heads*context*context
    # Undo any extra batch dim a tensor-valued scale introduced via broadcasting.
    A_ = A_.reshape(A__.shape)

    if alibi is not None:
        A_ = A_ + alibi
    if causal is not None:
        # In-place: positions where the keep-mask is False get -inf before softmax.
        A_.masked_fill_(~causal, -float("inf"))
    if swa is not None:
        A_.masked_fill_(~swa, -float("inf"))

    A = torch.softmax(A_, dim=-1)

    # (batches*)heads*context*d_head
    Y = A @ V

    if not return_A:
        return Y
    else:
        return Y, A__, A_, A
|
| 140 |
+
|
| 141 |
+
# (batches*)heads/groups*context*d_head
def sdpa_flash(Q, K, V, causal=False, alibi=None, swa=None, scale=None, backend="flash2"):
    """SDPA via FlashAttention 2/3/4 CUDA kernels.

    `alibi` holds per-head slopes (flash2 only), `swa` a (left, right)
    window tuple, `scale` a float or a broadcastable tensor (folded into Q
    because FlashAttention only accepts a float softmax scale). Inputs use
    the (batches*)heads*context*d_head layout shared by this file.
    """
    if (alibi is not None) and backend != "flash2":
        # BUG FIX: the warning was missing its f-string prefix, so it printed
        # the literal text "{backend}" instead of the selected backend.
        print(f"\x1b[93;3m[WARNING]: backend={backend} does not support ALiBi. Hence, we force backend=flash2.\x1b[0m")
        backend = "flash2"

    # FlashAttention only supports float scale
    if isinstance(scale, torch.Tensor):
        Q_shape = Q.shape
        # batches*heads*context*d_head
        Q = scale*Q
        # (batches*)heads*context*d_head
        Q = Q.reshape(Q_shape)

        scale = 1

    # FlashAttention2 only supports BF16 and FP16
    if Q.dtype in [torch.bfloat16, torch.float16]:
        dtype = Q.dtype
    else:
        dtype = torch.bfloat16

    heads = Q.shape[-3]
    groups = K.shape[-3]
    context = Q.shape[-2]
    d_head = Q.shape[-1]

    # CAUTION: FlashAttention expects batches*context*heads/groups*d_head
    Q = Q.movedim(-3,-2).reshape(-1,context,heads,d_head)
    K = K.movedim(-3,-2).reshape(-1,context,groups,d_head)
    V = V.movedim(-3,-2).reshape(-1,context,groups,d_head)

    if swa is None:
        swa = (-1,-1)

    if backend=="flash2":
        import flash_attn
        Y = flash_attn.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=causal, alibi_slopes=alibi, window_size=swa, softmax_scale=scale)
    elif backend=="flash3":
        import flash_attn_interface
        Y = flash_attn_interface.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=causal, window_size=swa, softmax_scale=scale)
    elif backend=="flash4":
        import flash_attn.cute
        # FlashAttention4 returns (out, lse)
        Y = flash_attn.cute.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=causal, window_size=swa, softmax_scale=scale)[0]

    Y = Y.to(Q.dtype)

    # Restore the shape to: (batches*)heads*context*d_head
    Y = Y.movedim(-3,-2).squeeze(0)

    return Y
|
| 193 |
+
|
| 194 |
+
# (batches*)heads/groups*context*d_head
def sdpa_flex():
    """FlexAttention backend — not implemented yet.

    Previously this silently returned None, which surfaced as a confusing
    failure at the caller; raising makes the missing backend explicit.
    """
    raise NotImplementedError("flex attention backend is not implemented")
|
| 197 |
+
|
| 198 |
+
# (batches*)heads/groups*context*d_head
def sdpa_cudnn():
    """cuDNN attention backend — not implemented yet.

    Previously this silently returned None, which surfaced as a confusing
    failure at the caller; raising makes the missing backend explicit.
    """
    raise NotImplementedError("cudnn attention backend is not implemented")
|
| 201 |
+
|
| 202 |
+
def sdpa_wrapper(Q, K, V, causal=None, alibi=None, swa=None, scale=None, return_A=False, backend="flash2"):
    """Dispatch scaled-dot-product attention to the selected backend.

    Raises:
        ValueError: `backend` is not recognized. (Previously an unknown
            backend silently returned None.)
    """
    if backend == "pytorch":
        return sdpa_pytorch(Q, K, V, causal, alibi, swa, scale, return_A)
    elif backend in {"flash2", "flash3", "flash4"}:
        # return_A is only supported by the pytorch reference path.
        return sdpa_flash(Q, K, V, causal, alibi, swa, scale, backend)
    elif backend == "flex":
        return sdpa_flex()
    elif backend == "cudnn":
        return sdpa_cudnn()

    raise ValueError(f"unknown attention backend: {backend!r}")
|
| 211 |
+
|
| 212 |
+
def test_sdpa():
    """Cross-check every SDPA backend against the plain-PyTorch reference.

    Requires a CUDA device ("cuda:0") plus the flash-attn 2/3/4 packages.
    Covers plain, causal, ALiBi, sliding-window, combined-mask and GQA
    configurations; comparisons use torch.testing.assert_close with
    check_dtype=False since flash kernels compute in bf16.
    """
    batches = 32
    heads = 12
    context = 1024
    d_head = 64
    window = 256
    groups = 4
    dtype = torch.bfloat16

    print("\x1b[1mbfloat16\x1b[0m",end="")
    Q = torch.rand((batches, heads, context, d_head)).to("cuda:0", dtype)
    K = torch.rand((batches, heads, context, d_head)).to("cuda:0", dtype)
    V = torch.rand((batches, heads, context, d_head)).to("cuda:0", dtype)
    pytorch = sdpa_wrapper(Q, K, V, backend="pytorch")
    flash2 = sdpa_wrapper(Q, K, V, backend="flash2")
    torch.testing.assert_close(flash2, pytorch, check_dtype=False)
    flash3 = sdpa_wrapper(Q, K, V, backend="flash3")
    torch.testing.assert_close(flash3, pytorch, check_dtype=False)
    flash4 = sdpa_wrapper(Q, K, V, backend="flash4")
    torch.testing.assert_close(flash4, pytorch, check_dtype=False)
    print("\x1b[32m ✔\x1b[0m")

    print("\x1b[1mcausal\x1b[0m",end="")
    # pytorch takes a dense boolean mask; flash backends take causal=True.
    pytorch = sdpa_wrapper(Q, K, V, causal=get_causal(context).to("cuda:0"), backend="pytorch")
    flash2 = sdpa_wrapper(Q, K, V, causal=True, backend="flash2")
    torch.testing.assert_close(flash2, pytorch, check_dtype=False)
    flash3 = sdpa_wrapper(Q, K, V, causal=True, backend="flash3")
    torch.testing.assert_close(flash3, pytorch, check_dtype=False)
    flash4 = sdpa_wrapper(Q, K, V, causal=True, backend="flash4")
    torch.testing.assert_close(flash4, pytorch, check_dtype=False)
    print("\x1b[32m ✔\x1b[0m")

    print("\x1b[1malibi\x1b[0m",end="")
    # pytorch takes the dense (heads, context, context) bias; flash2 takes per-head slopes.
    pytorch = sdpa_wrapper(Q, K, V, alibi=get_alibi(heads,context).to("cuda:0",dtype), backend="pytorch")
    flash2 = sdpa_wrapper(Q, K, V, alibi=get_m(heads).to("cuda:0"), backend="flash2")
    torch.testing.assert_close(flash2, pytorch, check_dtype=False)
    # ALiBi not supported on FlashAttention3/4
    print("\x1b[32m ✔\x1b[0m")

    print("\x1b[1mswa\x1b[0m",end="")
    # pytorch takes a dense window mask; flash backends take a (left, right) tuple.
    pytorch = sdpa_wrapper(Q, K, V, swa=get_swa(context,window).to("cuda:0"), backend="pytorch")
    flash2 = sdpa_wrapper(Q, K, V, swa=(window,window), backend="flash2")
    torch.testing.assert_close(flash2, pytorch, check_dtype=False)
    flash3 = sdpa_wrapper(Q, K, V, swa=(window,window), backend="flash3")
    torch.testing.assert_close(flash3, pytorch, check_dtype=False)
    flash4 = sdpa_wrapper(Q, K, V, swa=(window,window), backend="flash4")
    torch.testing.assert_close(flash4, pytorch, check_dtype=False)
    print("\x1b[32m ✔\x1b[0m")

    print("\x1b[1mcausal+alibi\x1b[0m",end="")
    pytorch = sdpa_wrapper(Q, K, V, causal=get_causal(context).to("cuda:0"), alibi=get_alibi(heads,context).to("cuda:0",dtype), backend="pytorch")
    flash2 = sdpa_wrapper(Q, K, V, causal=True, alibi=get_m(heads).to("cuda:0"), backend="flash2")
    torch.testing.assert_close(flash2, pytorch, check_dtype=False)
    # ALiBi not supported on FlashAttention3/4
    print("\x1b[32m ✔\x1b[0m")

    print("\x1b[1mcausal+swa\x1b[0m",end="")
    pytorch = sdpa_wrapper(Q, K, V, causal=get_causal(context).to("cuda:0"), swa=get_swa(context,window).to("cuda:0"), backend="pytorch")
    flash2 = sdpa_wrapper(Q, K, V, causal=True, swa=(window,window), backend="flash2")
    torch.testing.assert_close(flash2, pytorch, check_dtype=False)
    flash3 = sdpa_wrapper(Q, K, V, causal=True, swa=(window,window), backend="flash3")
    torch.testing.assert_close(flash3, pytorch, check_dtype=False)
    flash4 = sdpa_wrapper(Q, K, V, causal=True, swa=(window,window), backend="flash4")
    torch.testing.assert_close(flash4, pytorch, check_dtype=False)
    print("\x1b[32m ✔\x1b[0m")

    print("\x1b[1mGQA\x1b[0m",end="")
    # K/V carry fewer heads (groups) than Q.
    Q = torch.rand((batches, heads, context, d_head)).to("cuda:0", dtype)
    K = torch.rand((batches, groups, context, d_head)).to("cuda:0", dtype)
    V = torch.rand((batches, groups, context, d_head)).to("cuda:0", dtype)
    pytorch = sdpa_wrapper(Q, K, V, backend="pytorch")
    flash2 = sdpa_wrapper(Q, K, V, backend="flash2")
    torch.testing.assert_close(flash2, pytorch, check_dtype=False)
    flash3 = sdpa_wrapper(Q, K, V, backend="flash3")
    torch.testing.assert_close(flash3, pytorch, check_dtype=False)
    flash4 = sdpa_wrapper(Q, K, V, backend="flash4")
    torch.testing.assert_close(flash4, pytorch, check_dtype=False)
    print("\x1b[32m ✔\x1b[0m")
|
| 290 |
+
|
| 291 |
+
class MHSA(torch.nn.Module):
    """Multi-head self-attention with optional GQA, QK-norm and RoPE.

    `ratio` = heads per KV group (GQA). With qk_norm, Q/K are unit-normalized
    after RoPE and a learned per-head scale (initialized to sqrt(d_head))
    replaces the fixed softmax scale.

    The `quartet` path depended on the external `quartet2` package, which is
    unavailable in HF-export mode; the exporter left dangling references that
    crashed with NameError. Selecting it now raises a clear error instead.
    """

    def __init__(self, heads, d_head, scale_type="1/sqrt(d)", ratio=1, qk_norm=True, quartet=True, fake_quartet=False):
        super().__init__()

        self.heads = heads
        self.d_head = d_head
        self.d = heads * d_head
        self.scale_type = scale_type
        self.ratio = ratio
        self.groups = heads//ratio
        self.d_KV = self.groups * d_head
        self.qk_norm = qk_norm
        if qk_norm:
            # (batches*)heads*context*d_head — learned per-head scale.
            scale = torch.full((1,heads,1,1), sqrt(d_head))
            self.scale = torch.nn.Parameter(scale)
        else:
            if scale_type=="1/sqrt(d)":
                self.scale = 1/sqrt(d_head)
            elif scale_type=="1/d":
                self.scale = 1/d_head
        self.quartet = quartet
        self.fake_quartet = fake_quartet

        # Packing QKV gives negligible speed gains, while not allowing GQA, hurting code clarity and having side effects with μP
        if quartet:
            # The original build used quartet2.linear.Quartet_II_linear here.
            raise NotImplementedError("quartet2 is not available in HF mode; pass quartet=False")
        elif fake_quartet:
            from . import fake_quartet as fq
            self.lq = fq.FakeQuartetLinear(self.d, self.d, bias=False)
            self.lk = fq.FakeQuartetLinear(self.d, self.d_KV, bias=False)
            self.lv = fq.FakeQuartetLinear(self.d, self.d_KV, bias=False)

            self.lo = fq.FakeQuartetLinear(self.d, self.d, bias=False)
        else:
            self.lq = torch.nn.Linear(self.d, self.d, bias=False)
            self.lk = torch.nn.Linear(self.d, self.d_KV, bias=False)
            self.lv = torch.nn.Linear(self.d, self.d_KV, bias=False)

            self.lo = torch.nn.Linear(self.d, self.d, bias=False)

    # (batches*)context*d
    def forward(self, X, causal=None, rope=None, alibi=None, swa=None, return_A=False, backend="flash2"):
        # (batches*)context*d
        Q = self.lq(X)
        # (batches*)context*d_KV
        K = self.lk(X)
        V = self.lv(X)

        # (batches*)context*heads*d_head
        Q = Q.unflatten(dim=-1, sizes=(self.heads, self.d_head))
        # (batches*)context*groups*d_head
        K = K.unflatten(dim=-1, sizes=(self.groups, self.d_head))
        V = V.unflatten(dim=-1, sizes=(self.groups, self.d_head))

        # (batches*)heads*context*d_head
        Q = Q.movedim(-3,-2)
        # (batches*)groups*context*d_head
        K = K.movedim(-3,-2)
        V = V.movedim(-3,-2)

        if rope is not None:
            Q = apply_rope(Q,rope)
            K = apply_rope(K,rope)

        # After RoPE
        if self.qk_norm:
            Q = mlp.sphere_norm(Q)
            K = mlp.sphere_norm(K)

        # (batches*)heads*context*d_head
        if not return_A:
            Y = sdpa_wrapper(Q, K, V, causal, alibi, swa, self.scale, return_A, backend)
        else:
            Y, A__, A_, A = sdpa_wrapper(Q, K, V, causal, alibi, swa, self.scale, return_A, backend)
        # (batches*)context*heads*d_head
        Y = Y.movedim(-3,-2)
        # (batches*)context*d
        Y = Y.flatten(-2,-1)

        Y = self.lo(Y)

        if not return_A:
            return Y
        else:
            return Y, A__, A_, A
|
| 382 |
+
|
| 383 |
+
class Block(torch.nn.Module):
    """Transformer block: MHSA + MLP sub-blocks, each wrapped in a residual
    connection with configurable pre-/post-norms and dropout.

    NOTE(review): the default act=mlp.ReLU2() is a mutable default shared by
    every Block constructed without an explicit act — presumably fine because
    the activation is stateless, but worth confirming.
    """

    def __init__(self, heads, d_head, scale_type="1/sqrt(d)", ratio=1, exp_factor=4, dropout=0, norm_type="rms_learned", bias=False, act=mlp.ReLU2(), l1_type="linear", pre_att_norm=False, qk_norm=True, out_att_norm=True, pre_mlp_norm=False, act_norm=False, out_mlp_norm=True, quartet=True, fake_quartet=False):
        super().__init__()

        self.heads = heads
        self.d_head = d_head
        # Model width.
        self.d = heads * d_head
        self.scale_type = scale_type
        self.ratio = ratio
        self.groups = heads//ratio
        self.exp_factor = exp_factor
        # MLP hidden width.
        self.d_hidden = int(exp_factor*self.d)
        self.dropout = dropout
        self.norm_type = norm_type
        self.bias = bias
        self.act = act
        self.l1_type = l1_type

        self.mhsa = MHSA(heads, d_head, scale_type, ratio, qk_norm, quartet, fake_quartet)
        # Optional norms around the attention sub-block (None when disabled).
        self.pre_att_norm = mlp.get_norm(pre_att_norm, norm_type, self.d, bias)
        self.out_att_norm = mlp.get_norm(out_att_norm, norm_type, self.d, bias)

        self.mlp = mlp.MLP2L(self.d, self.d_hidden, self.d, bias, act, dropout, l1_type, norm_type, act_norm, quartet, fake_quartet)
        # Optional norms around the MLP sub-block (None when disabled).
        self.pre_mlp_norm = mlp.get_norm(pre_mlp_norm, norm_type, self.d, bias)
        self.out_mlp_norm = mlp.get_norm(out_mlp_norm, norm_type, self.d, bias)

        self.quartet = quartet
        self.fake_quartet = fake_quartet

    def forward(self, X, causal=None, rope=None, alibi=None, swa=None, return_res=False, return_A=False, backend="flash2"):
        # Attention sub-block; input is pre-normed only when configured.
        mhsa = self.mhsa(self.pre_att_norm(X) if self.pre_att_norm else X, causal, rope, alibi, swa, return_A, backend)
        if not return_A:
            Y = mhsa
        else:
            Y, A__, A_, A = mhsa

        if self.out_att_norm: Y = self.out_att_norm(Y)

        # Residual connection around attention.
        Y_ = torch.nn.functional.dropout(Y, p=self.dropout, training=self.training)
        Y__ = X + Y_

        Z = self.mlp(self.pre_mlp_norm(Y__) if self.pre_mlp_norm else Y__)

        if self.out_mlp_norm: Z = self.out_mlp_norm(Z)

        # Residual connection around the MLP.
        Z_ = torch.nn.functional.dropout(Z, p=self.dropout, training=self.training)
        Z__ = Y__ + Z_

        # return_res additionally exposes the post-attention residual Y__.
        if not return_res:
            if not return_A:
                return Z__
            else:
                return Z__, A__, A_, A
        else:
            if not return_A:
                return Z__, Y__
            else:
                return Z__, Y__, A__, A_, A
|
| 441 |
+
|
| 442 |
+
class Transformer(torch.nn.Module):
    """Decoder-only transformer language model.

    Configurable attention backend ("pytorch" / "flash2" / "flash3" / "flash4" /
    "flex" / "cudnn"), positional encoding ("sinusoidal" / "learned" / "rope" /
    "alibi"), normalization placement, optional sliding-window attention, and
    optional Quartet (FP4) quantized linears.
    """

    def __init__(self, vocab_size=50304, num_blocks=12, heads=12, d_head=64, scale_type="1/sqrt(d)", ratio=1, is_causal=True, window=None, backend="flash2", exp_factor=4, dropout=0, pos_type="rope", max_context=128, norm_type="rms_learned", bias=False, act=mlp.ReLU2(), l1_type="linear", std=0.02, test=False, weight_tying=True, emb_norm=False, pre_att_norm=False, qk_norm=True, out_att_norm=True, pre_mlp_norm=False, act_norm=False, out_mlp_norm=True, out_norm=True, fix_norm=False, quartet=True, fake_quartet=False):
        # NOTE(review): act=mlp.ReLU2() is a mutable default argument shared by
        # all instances — assumed safe because the activation is stateless.
        super().__init__()

        self.vocab_size = vocab_size
        self.num_blocks = num_blocks
        self.heads = heads
        self.d_head = d_head
        # Model width.
        self.d = heads * d_head
        self.scale_type = scale_type
        # Grouped-query attention: `ratio` query heads per kv head.
        self.ratio = ratio
        self.groups = heads//ratio
        self.is_causal = is_causal
        # Sliding-window size; None means full attention.
        self.window = window
        self.backend = backend
        self.exp_factor = exp_factor
        self.dropout = dropout
        self.pos_type = pos_type
        self.max_context = max_context
        self.norm_type = norm_type
        self.bias = bias
        self.act = act
        self.l1_type = l1_type
        self.weight_tying = weight_tying
        self.fix_norm = fix_norm
        self.quartet = quartet
        self.fake_quartet = fake_quartet

        self.emb = torch.nn.Embedding(vocab_size, self.d)

        self.emb_norm = mlp.get_norm(emb_norm, norm_type, self.d, bias)

        if pos_type == "learned":
            # Learned absolute positions; re-initialized to zeros in init() (2-D branch).
            pos = torch.rand((max_context, self.d))
            self.pos = torch.nn.Parameter(pos)

        self.blocks = torch.nn.Sequential(*[Block(heads, d_head, scale_type, ratio, exp_factor, dropout, norm_type, bias, act, l1_type, pre_att_norm, qk_norm, out_att_norm, pre_mlp_norm, act_norm, out_mlp_norm, quartet, fake_quartet) for _ in range(num_blocks)])

        self.out_norm = mlp.get_norm(out_norm, norm_type, self.d, bias)

        # Output head; bias-free so weight tying with the embedding is exact.
        self.linear = torch.nn.Linear(self.d, vocab_size, bias=False)

        if weight_tying: self.emb.weight = self.linear.weight

        self.init(std, test)

        if fake_quartet:
            # Pseudo-quantized training keeps norms and embeddings in bf16.
            for m in self.modules():
                if isinstance(m, (torch.nn.LayerNorm, torch.nn.RMSNorm, torch.nn.Embedding)):
                    m.to(torch.bfloat16)

    def init(self, std=0.02, test=False):
        """Initialize parameters in place.

        Linear/Embedding weights ~ N(0, std); Linear/LayerNorm biases to zero;
        LayerNorm/RMSNorm weights to one. Remaining parameters are dispatched by
        ndim: 2-D (learned positions) to zero, 4-D (per-head scales) to
        sqrt(d_head). With test=True, prints per-parameter mean/std.
        """
        if test: print("\x1b[1m%36.36s %8.8s %8.8s %8.8s\x1b[0m" % ("parameter_name", "suffix", "mean", "std"))
        for parameter_name, parameter in self.named_parameters():
            parent_name, _, suffix = parameter_name.rpartition(".")
            parent = self.get_submodule(parent_name)

            if isinstance(parent, (torch.nn.Linear, torch.nn.Embedding)) and suffix=="weight":
                torch.nn.init.normal_(parameter, 0, std)
            elif isinstance(parent, (torch.nn.Linear, torch.nn.LayerNorm)) and suffix=="bias":
                torch.nn.init.zeros_(parameter)
            elif isinstance(parent, (torch.nn.LayerNorm, torch.nn.RMSNorm)) and suffix=="weight":
                torch.nn.init.ones_(parameter)
            else:
                # pos
                if parameter.ndim == 2:
                    torch.nn.init.zeros_(parameter)
                # scale
                elif parameter.ndim == 4:
                    torch.nn.init.constant_(parameter, sqrt(self.d_head))

            if test:
                print("%36.36s %8.8s %8.8s %8.8s\x1b[0m" % (parameter_name, suffix, "%f" % parameter.mean(), "%f" % parameter.std()))

    # (batches*)context
    def forward(self, ids, return_res=False, return_A=False):
        """Map token ids to vocabulary logits.

        ids: (batches*, context) integer token ids.
        return_res: also return the embedding and per-block residual-stream
            snapshots (res_in, res_att, res_mlp).
        return_A: also return per-block attention tensors (A__, A_, A).
        """
        context = ids.shape[-1]

        if return_A:
            # (batches*)num_blocks*heads*context*context
            # NOTE(review): allocated on the default device; if the model/ids
            # live on GPU the per-block writes below may need these buffers on
            # the same device — confirm before using return_A on CUDA.
            A__ = torch.empty(*ids.shape[:-1], self.num_blocks, self.heads, context, context)
            A_ = torch.empty_like(A__)
            A = torch.empty_like(A__)

        # (batches*)context*d
        X = self.emb(ids)

        if return_res:
            res_in = X
            # (batches*)num_blocks*context*d
            res_att = torch.empty(*ids.shape[:-1], self.num_blocks, context, self.d)
            res_mlp = torch.empty(*ids.shape[:-1], self.num_blocks, context, self.d)

        # Recompute in every batch in case context changes
        # Each backend encodes the causal constraint differently.
        # NOTE(review): with is_causal=True and an unrecognized backend string,
        # `causal` is left unbound (the else pairs with `if self.is_causal`).
        if self.is_causal:
            if self.backend=="pytorch":
                causal = get_causal(context).to(ids.device)
            elif self.backend in {"flash2", "flash3", "flash4"}:
                causal = True
            elif self.backend=="flex":
                causal = causal_mod
            elif self.backend=="cudnn":
                # right_bound
                causal = 0
        else: causal = None

        if self.pos_type == "sinusoidal":
            pos = get_sinusoidal(context, self.d).to(ids.device)
            X = X + pos

        if self.pos_type == "learned":
            X = X + self.pos[:context,:]

        if self.pos_type == "rope":
            # Rotary embeddings are applied inside attention, not added here.
            rope = get_rope(context, self.d_head, device=ids.device)
        else: rope = None

        if self.pos_type == "alibi":
            # ALiBi representation is backend-specific (full bias matrix,
            # per-head slopes, score-mod callable, or a flag).
            if self.backend=="pytorch":
                alibi = get_alibi(self.heads, context).to(ids.device)
            elif self.backend in {"flash2", "flash3", "flash4"}:
                alibi = get_m(self.heads).to(ids.device)
            elif self.backend=="flex":
                alibi = alibi_mod
            elif self.backend=="cudnn":
                alibi = True
        else: alibi = None

        if self.window is not None:
            # Sliding-window attention, again encoded per backend.
            if self.backend=="pytorch":
                swa = get_swa(context, self.window).to(ids.device)
            elif self.backend in {"flash2", "flash3", "flash4"}:
                swa = (self.window, self.window)
            elif self.backend=="flex":
                swa = swa_mod
            elif self.backend=="cudnn":
                # left_bound
                swa = self.window
        else: swa = None

        # After positional encoding
        if self.emb_norm: X = self.emb_norm(X)

        X_ = torch.nn.functional.dropout(X, p=self.dropout, training=self.training)

        Y = X_
        for i, block in enumerate(self.blocks):
            if not return_res:
                if not return_A:
                    Y = block(Y, causal, rope, alibi, swa, return_res, return_A, self.backend)
                else:
                    Y, A__[...,i,:,:,:], A_[...,i,:,:,:], A[...,i,:,:,:] = block(Y, causal, rope, alibi, swa, return_res, return_A, self.backend)
            else:
                if not return_A:
                    Y, res_att[...,i,:,:] = block(Y, causal, rope, alibi, swa, return_res, return_A, self.backend)
                    res_mlp[...,i,:,:]= Y
                else:
                    Y, res_att[...,i,:,:], A__[...,i,:,:,:], A_[...,i,:,:,:], A[...,i,:,:,:] = block(Y, causal, rope, alibi, swa, return_res, return_A, self.backend)
                    res_mlp[...,i,:,:]= Y

        if self.out_norm: Y = self.out_norm(Y)

        # (batches*)context*vocab_size
        if self.fix_norm:
            # FixNorm: project the output-head rows onto the unit sphere.
            Z = torch.nn.functional.linear(Y, mlp.sphere_norm(self.linear.weight))
        else:
            Z = self.linear(Y)

        if not return_res:
            if not return_A:
                return Z
            else:
                return Z, A__, A_, A
        else:
            if not return_A:
                return Z, res_in, res_att, res_mlp
            else:
                return Z, res_in, res_att, res_mlp, A__, A_, A
|
| 620 |
+
|
| 621 |
+
def get_attention_header(transformer):
    """Return a space-separated header "block{b}.head{h}" for every
    (block, head) pair of `transformer`."""
    labels = [
        f"block{block_idx}.head{head_idx}"
        for block_idx in range(transformer.num_blocks)
        for head_idx in range(transformer.heads)
    ]
    return " ".join(labels)
|
| 632 |
+
|
| 633 |
+
def get_attention(W):
    """Format a (num_blocks, heads) array of attention scalars as a
    space-separated string of "%.2f" values, row-major."""
    cells = []
    # rows->y, columns->x
    for block_idx in range(W.shape[0]):
        for head_idx in range(W.shape[1]):
            cells.append("%.2f" % W[block_idx, head_idx])
    return " ".join(cells)
|
| 645 |
+
|
| 646 |
+
def get_similarity_header(transformer):
    """Return the header "embedding block0 block1 ..." for similarity logs."""
    columns = ["embedding"]
    columns.extend(f"block{block_idx}" for block_idx in range(transformer.num_blocks))
    return " ".join(columns)
|
| 656 |
+
|
| 657 |
+
def get_similarity(embeddings_x, embeddings_y):
    """Per-row cosine similarity between two (num_rows, d) tensors,
    formatted as a space-separated string of "%.2f" values."""
    scores = []
    for row_idx in range(embeddings_x.shape[0]):
        cosine = torch.nn.functional.cosine_similarity(
            embeddings_x[row_idx, :], embeddings_y[row_idx, :], dim=0
        )
        scores.append("%.2f" % cosine)
    return " ".join(scores)
|
| 667 |
+
|
| 668 |
+
def get_clustering_header(transformer):
    """Return the header of 2-D projection columns: for "embedding" and each
    block, an x/y pair per projection method (random, pca, mds, tsne, umap)."""
    methods = ("random", "pca", "mds", "tsne", "umap")
    prefixes = ["embedding"]
    prefixes.extend(f"block{block_idx}" for block_idx in range(transformer.num_blocks))
    fields = []
    for prefix in prefixes:
        for method in methods:
            fields.append(f"{prefix}.{method}.x")
            fields.append(f"{prefix}.{method}.y")
    return " ".join(fields)
|
| 686 |
+
|
| 687 |
+
def get_clustering(random, pca, mds, tsne, umap):
    """Format per-block 2-D coordinates from the five (num_blocks, 2)
    projections as a space-separated string of "%f" values, block-major."""
    rows = []
    for block_idx in range(random.shape[0]):
        coords = []
        for projection in (random, pca, mds, tsne, umap):
            coords.append("%f" % projection[block_idx, 0])
            coords.append("%f" % projection[block_idx, 1])
        rows.append(" ".join(coords))
    return " ".join(rows)
|
fake_quartet.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from random import randint
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import triton
|
| 7 |
+
import triton.language as tl
|
| 8 |
+
from scipy.linalg import hadamard
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device):
    """Build an orthonormal Hadamard matrix of size `group_size`.

    The scipy Hadamard matrix (entries +/-1) is scaled by 1/sqrt(group_size)
    so that H @ H.T == I. Returned as a non-trainable tensor on `device`.
    """
    scaled = hadamard(group_size) * group_size**-0.5
    return torch.tensor(scaled, dtype=dtype, device=device, requires_grad=False)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def rerotate_hadamard(hadamard_matrix):
    """Randomly flip the sign of each column of `hadamard_matrix`.

    Multiplies on the right by a diagonal of random +/-1 values, yielding a
    fresh random rotation while preserving orthogonality.
    """
    n = hadamard_matrix.size(0)
    bits = torch.randint(0, 2, (n,), device=hadamard_matrix.device, dtype=hadamard_matrix.dtype)
    sign_diag = torch.diag(bits * 2 - 1)
    return hadamard_matrix @ sign_diag
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@triton.jit
def _rtn_fp4(x):
    """Round-to-nearest onto the FP4 (E2M1) value grid
    {0, 0.5, 1, 1.5, 2, 3, 4, 6} with sign, elementwise.

    Thresholds are the midpoints between adjacent representable magnitudes,
    so each input maps to its nearest grid value.
    """
    x_abs = tl.abs(x)
    # NOTE: exact zeros get sign -1 here, but -0.0 == 0.0 after multiplication.
    x_sign = tl.where(x > 0, 1, -1)
    x_fp4_abs = tl.where(
        x_abs >= 5, 6,
        tl.where(x_abs >= 3.5, 4,
        tl.where(x_abs >= 2.5, 3,
        tl.where(x_abs >= 1.75, 2,
        tl.where(x_abs >= 1.25, 1.5,
        tl.where(x_abs >= 0.75, 1,
        tl.where(x_abs >= 0.25, 0.5,
        0.0)))))))
    return x_fp4_abs * x_sign
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@triton.jit
def _get_scales(x, amax, val_max, scales_max):
    """Compute two-level quantization scales for groups of `x`.

    s_dec is a global decode scale derived from the tensor-wide amax;
    s_dec_b_e4m3 is the per-group scale, itself quantized to FP8 E4M3.
    Zero groups get scale 1.0 to avoid division by zero downstream.
    Returns (per-group E4M3 scale as fp32, global decode scale).
    """
    s_dec = tl.where(amax == 0.0, 1.0, amax / scales_max / val_max)
    s_dec_b = tl.max(tl.abs(x), axis=-1, keep_dims=True) / val_max
    # Round-trip through float8e4nv to emulate E4M3 scale storage.
    s_dec_b_e4m3 = (s_dec_b / s_dec).to(tl.float8e4nv).to(tl.float32)
    s_dec_b_e4m3 = tl.where(s_dec_b_e4m3 == 0, 1.0, s_dec_b_e4m3)
    return s_dec_b_e4m3, s_dec
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@triton.jit
def _get_alt_scales(x, val_max, s_dec):
    """Alternative per-group scales with the max mapped near 4 instead of 6
    (the 6/4 factor), used for the "four over six" candidate quantization.

    Zero groups get scale 1.0. Returns the per-group E4M3 scale as fp32.
    """
    s_dec_b = tl.max(tl.abs(x), axis=-1, keep_dims=True) / val_max
    # Round-trip through float8e4nv to emulate E4M3 scale storage.
    s_dec_b_e4m3 = (s_dec_b * (6 / 4) / s_dec).to(tl.float8e4nv).to(tl.float32)
    s_dec_b_e4m3 = tl.where(s_dec_b_e4m3 == 0, 1.0, s_dec_b_e4m3)
    return s_dec_b_e4m3
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64 * 32}),
        triton.Config({"BLOCK_SIZE": 128 * 32}),
        triton.Config({"BLOCK_SIZE": 256 * 32}),
        triton.Config({"BLOCK_SIZE": 512 * 32}),
    ],
    key=[],
)
@triton.jit
def _rtn_1x16s_fp4_kernel(
    x_ptr, amax_ptr, output_ptr,
    n_elements: tl.constexpr,
    scale_override: tl.constexpr,
    group_size: tl.constexpr,
    four_over_six: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fake-quantize a flat tensor to FP4 with per-`group_size` E4M3 scales.

    Each program loads BLOCK_SIZE elements, quantizes them group-wise via
    round-to-nearest, and stores the dequantized result (same dtype layout).
    With four_over_six=True, a second candidate scaling (max near 4) is also
    evaluated and the per-group variant with lower squared error is kept.
    """
    pid = tl.program_id(0)
    start_idx = pid * BLOCK_SIZE
    offsets = start_idx + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x_flat = tl.load(x_ptr + offsets, mask=mask, other=0.0)

    x_grouped = tl.reshape(x_flat, (BLOCK_SIZE // group_size, group_size))

    # Wider scale headroom in the 4/6 mode.
    scales_max = 256.00 if four_over_six else 448.00
    val_max = 6.0 / scale_override
    amax = tl.load(amax_ptr)

    s_dec_b_e4m3, s_dec = _get_scales(x_grouped, amax, val_max, scales_max)
    x_scaled = x_grouped / (s_dec_b_e4m3 * s_dec)

    x_fp4 = _rtn_fp4(x_scaled)
    x_dequantized = x_fp4 * (s_dec_b_e4m3 * s_dec)

    if not four_over_six:
        best_x_dequantized = x_dequantized
    else:
        # Candidate with the alternative (6/4-inflated) scales.
        alt_s_dec_b_e4m3 = _get_alt_scales(x_grouped, val_max, s_dec)
        alt_x_scaled = x_grouped / (alt_s_dec_b_e4m3 * s_dec)
        alt_x_fp4 = _rtn_fp4(alt_x_scaled)
        alt_x_dequantized = alt_x_fp4 * (alt_s_dec_b_e4m3 * s_dec)

        # Keep, per group, whichever variant has the smaller squared error.
        error_six = tl.sum((x_grouped - x_dequantized) * (x_grouped - x_dequantized), axis=-1, keep_dims=True)
        error_four = tl.sum((x_grouped - alt_x_dequantized) * (x_grouped - alt_x_dequantized), axis=-1, keep_dims=True)
        best_x_dequantized = tl.where(error_six <= error_four, x_dequantized, alt_x_dequantized)

    x_dequantized_flat = tl.reshape(best_x_dequantized, (BLOCK_SIZE,))
    tl.store(output_ptr + offsets, x_dequantized_flat, mask=mask)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@torch.compiler.disable()
def rtn_1x16s_fp4(x, scale_override: float, group_size: int, four_over_six: bool):
    """Launch the RTN FP4 fake-quantization kernel on `x`.

    Returns a tensor of the same shape/dtype holding the dequantized values.
    The tensor-wide amax (x.abs().max()) is passed as a device scalar so the
    kernel can derive the global decode scale.
    Excluded from torch.compile graphs (triton launch inside).
    """
    x = x.contiguous()
    output = torch.empty_like(x)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _rtn_1x16s_fp4_kernel[grid](
        x_ptr=x, amax_ptr=x.abs().max(), output_ptr=output,
        n_elements=n_elements, scale_override=scale_override,
        group_size=group_size, four_over_six=four_over_six,
    )
    return output
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64 * 32}),
        triton.Config({"BLOCK_SIZE": 128 * 32}),
        triton.Config({"BLOCK_SIZE": 256 * 32}),
        triton.Config({"BLOCK_SIZE": 512 * 32}),
    ],
    key=[],
)
@triton.jit
def _eden_1x16s_fp4_kernel(
    x_ptr, hadamard_matrix_ptr, current_amax_ptr, output_ptr, next_amax_ptr,
    n_elements: tl.constexpr,
    hadamard_dim: tl.constexpr,
    scale_override: tl.constexpr,
    group_size: tl.constexpr,
    seed: int,
    BLOCK_SIZE: tl.constexpr,
):
    """EDEN-style FP4 fake quantization with a Hadamard rotation, unbiased
    scale correction, and stochastic scale rounding (used on backward tensors).

    Pipeline per program: load BLOCK_SIZE elements -> rotate rows of length
    `hadamard_dim` by the Hadamard matrix -> record this pass's amax into
    next_amax (for delayed-amax reuse) -> group-wise FP4 quantize using the
    previously recorded amax -> correct per-row scales so the quantized dot
    product is unbiased -> stochastically round the corrected E4M3 scales ->
    store the dequantized result.
    """
    pid = tl.program_id(0)
    start_idx = pid * BLOCK_SIZE
    offsets = start_idx + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x_flat = tl.load(x_ptr + offsets, mask=mask, other=0.0)

    offsets_hadamard = tl.arange(0, hadamard_dim * hadamard_dim)
    hadamard_matrix = tl.load(hadamard_matrix_ptr + offsets_hadamard).reshape(hadamard_dim, hadamard_dim)
    x = tl.reshape(x_flat, (BLOCK_SIZE // hadamard_dim, hadamard_dim))
    x_had = tl.dot(x, hadamard_matrix)

    # Publish the post-rotation amax for the *next* step (delayed amax).
    tl.atomic_max(next_amax_ptr, tl.max(tl.abs(x_had)).to(tl.float32), sem="relaxed")

    x_grouped = tl.reshape(x_had, (BLOCK_SIZE // group_size, group_size))

    scales_max = 255.99
    val_max = 6.0 / scale_override
    # Quantize against the amax recorded on a previous pass.
    amax = tl.load(current_amax_ptr)
    s_dec = tl.where(amax == 0.0, 1.0, amax / scales_max / val_max)

    s_dec_b = tl.max(tl.abs(x_grouped), axis=-1, keep_dims=True) / val_max
    s_dec_b_e4m3 = (s_dec_b / s_dec).to(tl.float8e4nv).to(tl.float32)
    s_dec_b_e4m3 = tl.where(s_dec_b_e4m3 == 0, 1.0, s_dec_b_e4m3)
    x_scaled = x_grouped / (s_dec_b_e4m3 * s_dec)

    # Inline round-to-nearest onto the FP4 (E2M1) magnitude grid.
    x_scaled_abs = tl.abs(x_scaled)
    x_scaled_sign = tl.where(x_scaled > 0, 1, -1)
    x_fp4 = tl.where(
        x_scaled_abs >= 5, 6,
        tl.where(x_scaled_abs >= 3.5, 4,
        tl.where(x_scaled_abs >= 2.5, 3,
        tl.where(x_scaled_abs >= 1.75, 2,
        tl.where(x_scaled_abs >= 1.25, 1.5,
        tl.where(x_scaled_abs >= 0.75, 1,
        tl.where(x_scaled_abs >= 0.25, 0.5,
        0))))))) * x_scaled_sign

    x_scaled = tl.reshape(x_scaled, (BLOCK_SIZE // hadamard_dim, hadamard_dim))
    x_fp4 = tl.reshape(x_fp4, (BLOCK_SIZE // hadamard_dim, hadamard_dim))

    # Per-row least-squares factor <x,x>/<x,q> making E[q * correction] unbiased.
    num = tl.sum(x_scaled * x_scaled, axis=-1, keep_dims=True)
    denom = tl.sum(x_scaled * x_fp4, axis=-1, keep_dims=True)
    correction = tl.where(denom == 0.0, 1.0, num / denom)

    scales = tl.reshape(s_dec_b_e4m3, (BLOCK_SIZE // hadamard_dim, hadamard_dim // group_size))
    corrected_scales = tl.reshape(scales * correction, (BLOCK_SIZE // group_size, 1))

    # Neighboring representable E4M3 scales via bit-level +/-1 on the encoding.
    bitscales = tl.cast(corrected_scales.to(tl.float8e4nv), tl.uint8, bitcast=True)
    prevscale = tl.cast((bitscales - 1), tl.float8e4nv, bitcast=True).to(tl.float32)
    currscale = tl.cast((bitscales), tl.float8e4nv, bitcast=True).to(tl.float32)
    nextscale = tl.cast((bitscales + 1), tl.float8e4nv, bitcast=True).to(tl.float32)

    up = tl.where(currscale > corrected_scales, currscale, nextscale)
    down = tl.where(currscale > corrected_scales, prevscale, currscale)
    prob_up = (corrected_scales - down) / (up - down)

    # Stochastic rounding of each group's scale between its two neighbors.
    scale_start_idx = pid * (BLOCK_SIZE // group_size)
    scale_offsets = scale_start_idx + tl.arange(0, BLOCK_SIZE // group_size)
    sampled_prob = tl.rand(seed, scale_offsets).reshape(BLOCK_SIZE // group_size, 1)

    scales = tl.where(sampled_prob < prob_up, up, down)
    scales = tl.reshape(scales, (BLOCK_SIZE // group_size, 1))
    x_fp4 = tl.reshape(x_fp4, (BLOCK_SIZE // group_size, group_size))

    x_dequantized = x_fp4 * scales * s_dec
    x_dequantized_flat = tl.reshape(x_dequantized, (BLOCK_SIZE,))
    tl.store(output_ptr + offsets, x_dequantized_flat.to(x_ptr.dtype.element_ty), mask=mask)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
@torch.compiler.disable()
def eden_1x16s_fp4(x, hadamard_matrix, scale_override: float, group_size: int, current_amax):
    """Launch the EDEN FP4 kernel on `x` with the given Hadamard rotation.

    current_amax: device scalar tensor with the amax recorded on a previous
        pass (delayed-amax scheme); used for this pass's decode scale.
    Returns (dequantized tensor, next_amax) where next_amax holds the
    post-rotation amax observed on this pass, for use on the next call.
    Excluded from torch.compile graphs (triton launch inside).
    """
    hadamard_dim = hadamard_matrix.size(0)
    x = x.contiguous()
    # The kernel multiplies by the transpose, loaded row-major.
    hadamard_matrix = hadamard_matrix.T.contiguous()
    output = torch.empty_like(x)
    # Fresh seed per call so stochastic scale rounding differs across steps.
    seed = randint(0, 1_000_000)
    next_amax = torch.zeros_like(current_amax)
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _eden_1x16s_fp4_kernel[grid](
        x_ptr=x, hadamard_matrix_ptr=hadamard_matrix,
        current_amax_ptr=current_amax, output_ptr=output,
        next_amax_ptr=next_amax, n_elements=n_elements,
        hadamard_dim=hadamard_dim, scale_override=scale_override,
        group_size=group_size, seed=seed,
    )
    return output, next_amax
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
class AmaxStorage:
    """Per-layer container for the four delayed absolute-max statistics used
    by FakeQuartetFn's backward pass (all start as None, i.e. "not yet
    recorded"). Slotted to keep instances small."""

    __slots__ = ("e_ht_amax", "weght_tht_amax", "e_tht_amax", "input_tht_amax")

    def __init__(self):
        # Initialize every tracked statistic to the "unrecorded" state.
        for slot_name in self.__slots__:
            setattr(self, slot_name, None)
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class FakeQuartetFn(torch.autograd.Function):
    """Linear layer with FP4 fake quantization on both passes.

    Forward: RTN-quantize input and weight, then F.linear.
    Backward: EDEN-quantize (Hadamard + stochastic scales) the four operands of
    the two gradient matmuls, optionally reusing last step's amax statistics
    stored in an AmaxStorage (delayed amax).
    """

    # Elements per quantization group.
    group_size = 16
    forward_scale_override = 1.0
    backward_scale_override = (17 / 16) * 0.93
    # Shared rotation matrix, installed lazily by FakeQuartetLinear and
    # re-randomized on every backward call.
    hadamard_matrix = None

    # NOTE(review): torch.compile applied over @staticmethod — relies on
    # staticmethod objects being callable (Python 3.10+); confirm with the
    # targeted torch/Python versions.
    @torch.compile(dynamic=False)
    @staticmethod
    def forward(ctx, input, weight, amax_storage, delayed_amax, disable_forward_quant, disable_backward_quant, four_over_six):
        # input is assumed (batch, seq, in_dim); weight is (out_dim, in_dim).
        ctx.batch = input.shape[0]
        ctx.seq = input.shape[1]
        ctx.in_dim = weight.shape[1]
        ctx.out_dim = weight.shape[0]
        ctx.delayed_amax = delayed_amax
        ctx.amax_storage = amax_storage
        ctx.disable_backward_quant = disable_backward_quant

        if disable_forward_quant:
            input_fq = input
            weight_fq = weight
        else:
            # Round-to-nearest FP4 fake quantization of both operands.
            input_fq = rtn_1x16s_fp4(input, FakeQuartetFn.forward_scale_override, FakeQuartetFn.group_size, four_over_six)
            weight_fq = rtn_1x16s_fp4(weight, FakeQuartetFn.forward_scale_override, FakeQuartetFn.group_size, four_over_six)

        # Save the *quantized* operands — backward differentiates through them.
        ctx.save_for_backward(input_fq, weight_fq)
        return F.linear(input_fq, weight_fq)

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for (input, weight) and None for the five
        non-tensor forward arguments."""
        input_fq, weight_fq = ctx.saved_tensors
        dtype = grad_output.dtype
        # Flatten (batch, seq) so both gradient products are plain 2-D matmuls.
        input_fq = input_fq.to(dtype).reshape(ctx.batch * ctx.seq, ctx.in_dim)
        weight_fq = weight_fq.to(dtype)
        grad_output = grad_output.reshape(ctx.batch * ctx.seq, ctx.out_dim)

        # Fresh random sign flip of the shared rotation every backward step.
        FakeQuartetFn.hadamard_matrix = rerotate_hadamard(FakeQuartetFn.hadamard_matrix)

        if ctx.disable_backward_quant:
            grad_input = F.linear(grad_output, weight_fq.T, None).view(ctx.batch, ctx.seq, ctx.in_dim)
            grad_weight = F.linear(grad_output.T, input_fq.T, None)
            return grad_input, grad_weight, None, None, None, None, None

        had = FakeQuartetFn.hadamard_matrix.to(grad_output.dtype)
        bso = FakeQuartetFn.backward_scale_override
        gs = FakeQuartetFn.group_size

        # EW: grad_output @ weight^T
        # When delayed_amax is off (or no stat yet), compute this step's
        # rotated amax up front; otherwise reuse last step's value and let
        # eden_1x16s_fp4 return the refreshed one.
        if ctx.amax_storage.e_ht_amax is None or not ctx.delayed_amax:
            ctx.amax_storage.e_ht_amax = (grad_output.reshape(-1, had.size(0)) @ had.T).abs().max().float()
        e_ht_fp4, ctx.amax_storage.e_ht_amax = eden_1x16s_fp4(grad_output, had, bso, gs, ctx.amax_storage.e_ht_amax)

        if ctx.amax_storage.weght_tht_amax is None or not ctx.delayed_amax:
            ctx.amax_storage.weght_tht_amax = (weight_fq.T.reshape(-1, had.size(0)) @ had.T).abs().max().float()
        weight_tht_fp4, ctx.amax_storage.weght_tht_amax = eden_1x16s_fp4(weight_fq.T, had, bso, gs, ctx.amax_storage.weght_tht_amax)

        grad_input = F.linear(e_ht_fp4, weight_tht_fp4, None).view(ctx.batch, ctx.seq, ctx.in_dim)

        # EtX: grad_output^T @ input
        if ctx.amax_storage.e_tht_amax is None or not ctx.delayed_amax:
            ctx.amax_storage.e_tht_amax = (grad_output.T.reshape(-1, had.size(0)) @ had.T).abs().max().float()
        e_tht_fp4, ctx.amax_storage.e_tht_amax = eden_1x16s_fp4(grad_output.T, had, bso, gs, ctx.amax_storage.e_tht_amax)

        if ctx.amax_storage.input_tht_amax is None or not ctx.delayed_amax:
            ctx.amax_storage.input_tht_amax = (input_fq.T.reshape(-1, had.size(0)) @ had.T).abs().max().float()
        input_tht_fp4, ctx.amax_storage.input_tht_amax = eden_1x16s_fp4(input_fq.T, had, bso, gs, ctx.amax_storage.input_tht_amax)

        grad_weight = F.linear(e_tht_fp4, input_tht_fp4, None)

        return grad_input, grad_weight, None, None, None, None, None
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class FakeQuartetLinear(torch.nn.Linear):
    """nn.Linear drop-in whose forward/backward run through FakeQuartetFn
    (FP4 pseudo-quantized matmuls). Each instance carries its own delayed-amax
    storage; the Hadamard rotation matrix is shared class-wide.
    """

    def __init__(self, *args, hadamard_dim=32, delayed_amax=False,
                 disable_forward_quant=False, disable_backward_quant=False,
                 four_over_six=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.hadamard_dim = hadamard_dim
        # Reuse previous-step amax statistics in backward when True.
        self.delayed_amax = delayed_amax
        self.disable_forward_quant = disable_forward_quant
        self.disable_backward_quant = disable_backward_quant
        # Enable the 4-vs-6 per-group scale selection in forward quantization.
        self.four_over_six = four_over_six
        self.amax_storage = AmaxStorage()

        # Lazily install the shared rotation on first construction.
        # NOTE(review): hard-coded to device "cuda"; the first layer's
        # hadamard_dim wins for all later layers.
        if FakeQuartetFn.hadamard_matrix is None:
            FakeQuartetFn.hadamard_matrix = get_hadamard_matrix(
                self.hadamard_dim, dtype=torch.float32, device="cuda",
            )

    def forward(self, x):
        # NOTE: self.bias (if any) is not applied — FakeQuartetFn only
        # computes the weight product.
        return FakeQuartetFn.apply(
            x, self.weight, self.amax_storage,
            self.delayed_amax, self.disable_forward_quant,
            self.disable_backward_quant, self.four_over_six,
        )
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5802c11b6b024033386dba4cdff8665d48de19850e0e63c31686f44430ca870f
|
| 3 |
+
size 16563661264
|
modeling_cloverlm.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from math import sqrt
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from transformers import PreTrainedModel, GenerationMixin
|
| 8 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 9 |
+
|
| 10 |
+
from .configuration_cloverlm import CloverLMConfig
|
| 11 |
+
from .fake_quartet import FakeQuartetLinear
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _sphere_norm(X, dim=-1):
|
| 16 |
+
return F.normalize(X, dim=dim)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class _ReLU2(nn.Module):
|
| 20 |
+
def forward(self, x):
|
| 21 |
+
return F.relu(x) ** 2
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _make_linear(in_f, out_f, bias, quartet_2_impl):
|
| 25 |
+
if quartet_2_impl == "pseudoquant":
|
| 26 |
+
return FakeQuartetLinear(in_f, out_f, bias)
|
| 27 |
+
elif quartet_2_impl == "quartet2":
|
| 28 |
+
try:
|
| 29 |
+
from quartet2.linear import Quartet_II_linear
|
| 30 |
+
except ImportError as e:
|
| 31 |
+
e.add_note("Quartet_II_linear import failed. Install the latest quartet2 from https://github.com/IST-DASLab/Quartet-II")
|
| 32 |
+
raise e
|
| 33 |
+
|
| 34 |
+
return Quartet_II_linear(in_f, out_f, bias)
|
| 35 |
+
else:
|
| 36 |
+
raise ValueError(f"Unsupported quartet_2_impl: {quartet_2_impl}")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _build_rope(context, d_head, device):
|
| 40 |
+
ms = torch.arange(context, device=device, dtype=torch.float32)
|
| 41 |
+
js = torch.arange(d_head // 2, device=device, dtype=torch.float32)
|
| 42 |
+
theta = 1.0 / (1024.0 ** (2.0 * js / d_head))
|
| 43 |
+
phi = ms[:, None] @ theta[None, :]
|
| 44 |
+
cos = torch.cos(phi).repeat_interleave(2, dim=1)
|
| 45 |
+
sin = torch.sin(phi).repeat_interleave(2, dim=1)
|
| 46 |
+
return torch.stack((cos, sin))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _apply_rope(X, rope):
|
| 50 |
+
X_ = torch.empty_like(X)
|
| 51 |
+
X_[..., 0::2] = -X[..., 1::2]
|
| 52 |
+
X_[..., 1::2] = X[..., 0::2]
|
| 53 |
+
return (X * rope[0] + X_ * rope[1]).to(X.dtype)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class _MLP(nn.Module):
    """Two-layer feed-forward block with a squared-ReLU activation."""

    def __init__(self, d, d_hidden, quartet_2_impl):
        super().__init__()
        # The up-projection is bundled with its activation under "l1" so the
        # checkpoint parameter key stays "l1.0.*".
        up_proj = _make_linear(d, d_hidden, False, quartet_2_impl)
        self.l1 = nn.Sequential(up_proj, _ReLU2())
        self.l2 = _make_linear(d_hidden, d, False, quartet_2_impl)

    def forward(self, x):
        hidden = self.l1(x)
        return self.l2(hidden)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class MHSA(nn.Module):
    """Grouped-query multi-head self-attention with rotary embeddings,
    QK unit-sphere normalization, and a learnable per-head query scale.

    ``attn_backend`` selects the kernel: "pytorch" (SDPA) or one of the
    FlashAttention variants "flash2"/"flash3"/"flash4" (imported lazily).
    """

    def __init__(self, heads, d_head, ratio, quartet_2_impl):
        super().__init__()
        self.heads = heads
        self.d_head = d_head
        self.d = heads * d_head
        # GQA: `ratio` query heads share one key/value head.
        self.groups = heads // ratio
        d_kv = self.groups * d_head

        self.lq = _make_linear(self.d, self.d, False, quartet_2_impl)
        self.lk = _make_linear(self.d, d_kv, False, quartet_2_impl)
        self.lv = _make_linear(self.d, d_kv, False, quartet_2_impl)
        self.lo = _make_linear(self.d, self.d, False, quartet_2_impl)

        # Learnable per-head scale applied to the unit-norm queries;
        # initialized to sqrt(d_head) (same constant used by the init loop
        # in _Transformer for 4-D parameters).
        self.scale = nn.Parameter(torch.full((1, heads, 1, 1), sqrt(d_head)))

    def forward(self, X, rope, attn_backend):
        # NOTE(review): B is computed but never used below.
        B = X.shape[0] if X.dim() == 3 else 1
        ctx = X.shape[-2]

        # Project, split heads, and move to (..., heads, ctx, d_head).
        Q = self.lq(X).unflatten(-1, (self.heads, self.d_head)).movedim(-3, -2)
        K = self.lk(X).unflatten(-1, (self.groups, self.d_head)).movedim(-3, -2)
        V = self.lv(X).unflatten(-1, (self.groups, self.d_head)).movedim(-3, -2)

        # Rotary position embedding, then project Q/K onto the unit sphere.
        Q = _apply_rope(Q, rope)
        K = _apply_rope(K, rope)
        Q = _sphere_norm(Q)
        K = _sphere_norm(K)

        # Multiplying by scale (shape (1, heads, 1, 1)) can broadcast in an
        # extra leading dim when X is unbatched; reshape restores Q's shape.
        Q_shape = Q.shape
        Q = self.scale * Q
        Q = Q.reshape(Q_shape)

        if attn_backend == "pytorch":
            # SDPA needs K/V expanded to the full head count.
            K = K.repeat_interleave(self.heads // self.groups, dim=-3)
            V = V.repeat_interleave(self.heads // self.groups, dim=-3)
            # scale=1.0 because Q already carries the learned scale.
            Y = F.scaled_dot_product_attention(Q, K, V, is_causal=True, scale=1.0)
            Y = Y.movedim(-3, -2).flatten(-2, -1)
        elif attn_backend in ("flash2", "flash3", "flash4"):
            # FlashAttention layout: (batch, ctx, heads, d_head).
            Q = Q.movedim(-3, -2).reshape(-1, ctx, self.heads, self.d_head)
            K = K.movedim(-3, -2).reshape(-1, ctx, self.groups, self.d_head)
            V = V.movedim(-3, -2).reshape(-1, ctx, self.groups, self.d_head)

            # Fall back to bf16 when inputs are not already half precision.
            dtype = Q.dtype if Q.dtype in (torch.bfloat16, torch.float16) else torch.bfloat16
            if attn_backend == "flash2":
                import flash_attn
                Y = flash_attn.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
            elif attn_backend == "flash3":
                import importlib
                _fa3 = importlib.import_module("flash_attn_interface")
                Y = _fa3.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
            elif attn_backend == "flash4":
                import importlib
                _fa4 = importlib.import_module("flash_attn.cute")
                # flash_attn.cute returns a tuple; [0] is the attention output.
                Y = _fa4.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)[0]
            Y = Y.to(Q.dtype).flatten(-2, -1)

        return self.lo(Y)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class _Block(nn.Module):
    """Transformer block: attention and MLP sublayers, each wrapped with a
    post-sublayer RMSNorm and a residual connection."""

    def __init__(self, heads, d_head, ratio, quartet_2_impl):
        super().__init__()
        width = heads * d_head

        self.mhsa = MHSA(heads, d_head, ratio, quartet_2_impl)
        self.out_att_norm = nn.RMSNorm(width, elementwise_affine=True)
        self.mlp = _MLP(width, 4 * width, quartet_2_impl)
        self.out_mlp_norm = nn.RMSNorm(width, elementwise_affine=True)

    def forward(self, X, rope, attn_backend):
        # Residual stream accumulates the normed output of each sublayer.
        X = X + self.out_att_norm(self.mhsa(X, rope, attn_backend))
        X = X + self.out_mlp_norm(self.mlp(X))
        return X
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class _Transformer(nn.Module):
    """Decoder-only transformer backbone: token embedding, a stack of
    `_Block`s, a final RMSNorm, and an (optionally weight-tied) LM head."""

    def __init__(self, vocab_size, num_blocks, heads, d_head, ratio,
                 max_context, std, quartet_2_impl, weight_tying, attn_backend):
        # NOTE(review): max_context is accepted but unused here; RoPE tables
        # are rebuilt per-forward from the actual sequence length.
        super().__init__()
        self.d_head = d_head
        self.attn_backend = attn_backend
        d = heads * d_head

        self.emb = nn.Embedding(vocab_size, d)
        self.blocks = nn.Sequential(*[
            _Block(heads, d_head, ratio, quartet_2_impl) for _ in range(num_blocks)
        ])
        self.out_norm = nn.RMSNorm(d, elementwise_affine=True)
        self.linear = nn.Linear(d, vocab_size, bias=False)

        # Tie before the init loop so the shared weight is initialized once.
        if weight_tying:
            self.emb.weight = self.linear.weight

        # Re-initialize parameters according to their parent module's kind.
        for name, p in self.named_parameters():
            parent_name, _, suffix = name.rpartition(".")
            parent = self.get_submodule(parent_name)
            if isinstance(parent, (nn.Linear, nn.Embedding)) and suffix == "weight":
                nn.init.normal_(p, 0, std)
            elif isinstance(parent, nn.RMSNorm) and suffix == "weight":
                nn.init.ones_(p)
            elif p.ndim == 4:
                # The only 4-D parameters are the MHSA per-head query scales,
                # shaped (1, heads, 1, 1); reset them to sqrt(d_head).
                nn.init.constant_(p, sqrt(d_head))

        # Both supported impl names are truthy strings, so this casts norms
        # and embeddings to bf16 whenever a quantized impl is selected.
        if quartet_2_impl:
            for m in self.modules():
                if isinstance(m, (nn.LayerNorm, nn.RMSNorm, nn.Embedding)):
                    m.to(torch.bfloat16)

    def forward(self, ids):
        # ids: integer token ids, (..., ctx); returns logits (..., ctx, vocab).
        ctx = ids.shape[-1]
        rope = _build_rope(ctx, self.d_head, device=ids.device)

        X = self.emb(ids)
        for block in self.blocks:
            X = block(X, rope, self.attn_backend)
        X = self.out_norm(X)
        return self.linear(X)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class CloverLMForCausalLM(PreTrainedModel, GenerationMixin):
    """Hugging Face causal-LM wrapper around the `_Transformer` backbone.

    No KV cache is implemented: generation re-encodes the full sequence at
    every step (see prepare_inputs_for_generation and the cache opt-out).
    """

    config_class = CloverLMConfig
    supports_gradient_checkpointing = False
    _no_split_modules = ["_Block"]
    # The LM head weight aliases the embedding weight when tying is enabled.
    _tied_weights_keys = ["transformer.linear.weight"]
    _tp_plan = {}

    def __init__(self, config: CloverLMConfig):
        super().__init__(config)
        # Map each tied key to its source parameter for the HF tying machinery.
        self.all_tied_weights_keys = {k: "transformer.emb.weight"
                                      for k in (self._tied_weights_keys or [])}
        self.transformer = _Transformer(
            vocab_size=config.vocab_size,
            num_blocks=config.num_blocks,
            heads=config.heads,
            d_head=config.d_head,
            ratio=config.ratio,
            max_context=config.max_context,
            std=0.02,  # init std for Linear/Embedding weights
            quartet_2_impl=config.quartet_2_impl,
            weight_tying=config.weight_tying,
            attn_backend=config.attn_backend,
        )

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        # attention_mask is accepted for API compatibility but ignored here.
        logits = self.transformer(input_ids)

        loss = None
        if labels is not None:
            # Standard next-token shift: predict token t+1 from position t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # F.cross_entropy's default ignore_index (-100) skips masked labels.
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # No cache: always feed the full prefix back through the model.
        return {"input_ids": input_ids}

    def _supports_default_dynamic_cache(self):
        # Opt out of transformers' DynamicCache; this model has no KV cache.
        return False
|
tokenization_cloverlm.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
import tokenmonster
|
| 4 |
+
from transformers import PreTrainedTokenizer
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Download URL of the tokenmonster vocabulary file; per its filename it is the
# 32000-entry strict/no-capcode English+code vocab with <eot> at id 14199
# ("%3D" is a URL-encoded "=").
TOKENMONSTER_URL = (
    "https://huggingface.co/gvlassis/tokenmonster/resolve/main/"
    "englishcode-32000-strict-nocapcode-v1-eot%3D14199.vocab"
    "?download=true"
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class CloverLMTokenizer(PreTrainedTokenizer):
    """Hugging Face wrapper around a tokenmonster vocabulary.

    Tokens are represented as the *decimal string of their integer id*
    (e.g. "1234"), so token<->id conversion is just int()/str(); decoding
    goes back through tokenmonster. A single <eot> token (id ``eot_id``)
    serves as BOS, EOS, and PAD.
    """

    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, vocab_url: str = TOKENMONSTER_URL,
                 eot_id: int = 14199, **kwargs):
        # Load the vocabulary before super().__init__, which may invoke
        # tokenizer hooks that need self._tm — TODO confirm.
        self._tm = tokenmonster.load(vocab_url)
        self._eot_id = eot_id
        self._vocab_size = 32000

        super().__init__(
            eos_token="<eot>",
            pad_token="<eot>",
            bos_token="<eot>",
            **kwargs,
        )
        # Pin all special-token ids directly to the single <eot> id.
        self.eos_token_id = eot_id
        self.pad_token_id = eot_id
        self.bos_token_id = eot_id

    @property
    def vocab_size(self) -> int:
        # Fixed size; not derived from the loaded tokenmonster vocab.
        return self._vocab_size

    def get_vocab(self):
        # Synthetic surface forms only — real token strings are not exposed.
        return {f"<tok_{i}>": i for i in range(self._vocab_size)}

    def _tokenize(self, text: str, **kwargs) -> List[str]:
        # Tokenize with tokenmonster, then stringify ids to act as "tokens".
        ids = self._tm.tokenize(text).tolist()
        return [str(i) for i in ids]

    def _convert_token_to_id(self, token: str) -> int:
        return int(token)

    def _convert_id_to_token(self, index: int) -> str:
        return str(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        # Decode via tokenmonster from the numeric ids embedded in the tokens.
        ids = [int(t) for t in tokens]
        return self._tm.decode(ids)

    @property
    def all_special_tokens_extended(self):
        # Only one special token exists: <eot>.
        return [self.eos_token]

    @property
    def all_special_tokens(self):
        return [self.eos_token]

    @property
    def all_special_ids(self):
        return [self._eot_id]

    def save_vocabulary(self, save_directory: str,
                        filename_prefix: Optional[str] = None):
        # Nothing is written locally; the vocabulary lives at vocab_url.
        return ()
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"tokenizer_class": "CloverLMTokenizer",
|
| 3 |
+
"auto_map": {
|
| 4 |
+
"AutoTokenizer": [
|
| 5 |
+
"tokenization_cloverlm.CloverLMTokenizer",
|
| 6 |
+
null
|
| 7 |
+
]
|
| 8 |
+
},
|
| 9 |
+
"use_fast": false
|
| 10 |
+
}
|