# CloverLM / exp_transformer.py
# Uploaded by mansaripo to the Hugging Face Hub ("Upload folder using huggingface_hub", commit b0fd683, verified).
import torch
from . import exp_mlp as mlp
from math import sqrt
import math
# Score scaling used by MHSA when qk_norm is disabled.
SCALE_TYPES = [
    "1/sqrt(d)",
    "1/d",
]
# Positional encodings handled by Transformer.
POS_TYPES = [
    "learned",
    "sinusoidal",
    "rope",
    "alibi",
]
# Attention backends dispatched by sdpa_wrapper.
BACKENDS = [
    "pytorch",
    "flash2",
    "flash3",
    "flash4",
    "flex",
    "cudnn",
]
# norm_type values; interpreted by exp_mlp.get_norm (not shown here).
NORM_TYPES = [
    "layer",
    "rms_learned",
    "rms_const",
    "sphere",
]
def get_causal(context):
    """Return a context*context boolean causal mask.

    ``mask[i, j]`` is True iff query position ``i`` may attend to key position
    ``j``, i.e. on and below the diagonal (``j <= i``).
    """
    return torch.ones((context, context), dtype=torch.bool).tril()
def get_sinusoidal(context, d, base=1024):
    """Return a context*d table of sinusoidal absolute positional encodings.

    Even columns hold ``sin(pos * ω_j)`` and odd columns ``cos(pos * ω_j)``, with
    ``ω_j = base**(-2j/d)``. For odd ``d`` the last (even-indexed) column is a
    sine without a matching cosine.

    Args:
        context: number of positions (rows).
        d: encoding width (columns).
        base: frequency base.

    Returns:
        A float32 tensor of shape (context, d) on the default device.
    """
    # [pos=0, pos=1, ...]
    poss = torch.arange(0., context)
    # One frequency per sine column: ceil(d/2) of them (exactly d/2 for even d)
    js = torch.arange(0., (d + 1)//2)
    # [ω0, ω1, ...]
    ωs = 1/base**(2*js/d)
    # context*ceil(d/2): φ[pos, j] = pos*ωj. An elementwise outer product does one
    # float32 multiply per entry (a matmul may run at reduced precision, e.g. TF32).
    φs = torch.outer(poss, ωs)
    # context*d
    sinusoidal = torch.empty((context, d))
    sinusoidal[:, 0::2] = torch.sin(φs)
    # Only d//2 cosine columns exist, so drop the extra frequency when d is odd
    sinusoidal[:, 1::2] = torch.cos(φs[:, :d//2])
    return sinusoidal
def get_rope(context, d, *, device, base=1024):
    """Return the rotary position embedding (RoPE) table for ``context`` positions.

    Returns a float32 tensor of shape 2*context*d on ``device``. ``rope[0]`` holds
    ``cos(m*θ_i)`` and ``rope[1]`` holds ``sin(m*θ_i)``, with ``θ_i = base**(-2i/d)``.
    Each angle is repeated for the two adjacent (interleaved) channels it rotates,
    which is the layout ``apply_rope`` consumes.
    """
    # [m=0, m=1, ...]
    ms = torch.arange(0., context, device=device, dtype=torch.float32)
    # [i=0, i=1, ...]
    js = torch.arange(0., d//2, device=device, dtype=torch.float32)
    # [θ0, θ1, ...]
    θs = 1/base**(2*js/d)
    # context*d/2: φ[m, i] = m*θi. Computed as an elementwise outer product instead of
    # a (context*1)@(1*d/2) matmul: with TF32 matmuls enabled on CUDA, the matmul rounds
    # m and θi to 10 mantissa bits, producing large angle errors at long positions.
    φs = torch.outer(ms, θs)
    # context*d: duplicate every angle for its pair of channels
    φs = φs.repeat_interleave(repeats=2, dim=1)
    # 2*context*d
    return torch.stack((torch.cos(φs), torch.sin(φs)))
# (batches*)context*d
def apply_rope(X, rope):
    """Rotate adjacent channel pairs of ``X`` by the RoPE angles in ``rope``.

    For each pair ``(x0, x1) = (X[..., 2i], X[..., 2i+1])`` the output is
    ``(x0*cos - x1*sin, x1*cos + x0*sin)``. ``cos = rope[0]`` and ``sin = rope[1]``
    must broadcast against ``X`` (``get_rope`` produces context*d tables).
    The result is cast back to ``X``'s dtype.
    """
    cos, sin = rope[0], rope[1]
    # Interleaved "rotate half": every pair (x0, x1) becomes (-x1, x0)
    rotated = torch.stack((-X[..., 1::2], X[..., 0::2]), dim=-1).flatten(-2)
    return (X*cos + rotated*sin).to(X.dtype)
def get_m(heads, base=2, exp=8):
    """Return the per-head ALiBi slopes ``base**(-exp*h/heads)`` for h = 1..heads.

    With the defaults this is the geometric sequence 2**(-8/heads), ..., 2**(-8).
    """
    exponents = torch.arange(1, heads+1) * (-exp/heads)
    return base**exponents
def get_alibi(heads, context):
    """Return the heads*context*context ALiBi bias ``-|i - j| * m_h``.

    The per-head slopes ``m_h`` come from ``get_m``. The bias is symmetric in the
    query position ``i`` and key position ``j``.
    """
    positions = torch.arange(0, context)
    # context*context distances |i - j|
    distance = (positions[:, None] - positions[None, :]).abs()
    # heads*1*1 slopes broadcast against the 1*context*context distances
    slopes = get_m(heads)[:, None, None]
    return -distance[None] * slopes
def get_swa(context, window):
    """Return a context*context boolean sliding-window mask, True where ``|i - j| <= window``."""
    idx = torch.arange(0, context)
    return (idx[:, None] - idx[None, :]).abs() <= window
# (batches*)heads/groups*context*d_head
def sdpa_pytorch(Q, K, V, causal=None, alibi=None, swa=None, scale=None, return_A=False):
    """Reference scaled dot-product attention built from plain PyTorch ops.

    Supports grouped-query attention: K/V may carry fewer heads ("groups") than Q.
    Each K/V head is shared by ``heads // groups`` consecutive query heads.

    Args:
        Q: (batches*)heads*context*d_head queries.
        K, V: (batches*)groups*context*d_head keys and values.
        causal: optional boolean mask (True = keep), e.g. from ``get_causal``.
        alibi: optional additive score bias, e.g. from ``get_alibi``.
        swa: optional boolean sliding-window mask (True = keep), e.g. from ``get_swa``.
        scale: float or tensor multiplying the scores; defaults to 1/sqrt(d_head).
        return_A: also return the raw scores ``A__``, the scaled/biased/masked
            scores ``A_`` and the attention probabilities ``A``.

    Returns:
        ``Y`` of shape (batches*)heads*context*d_head, or ``(Y, A__, A_, A)`` when
        ``return_A`` is set. ``Y`` has V's dtype.
    """
    if scale is None:
        d_head = Q.shape[-1]
        scale = 1/sqrt(d_head)
    # (batches*)heads*context*d_head
    heads = Q.shape[-3]
    groups = K.shape[-3]
    ratio = heads//groups
    # matmul only broadcasts size-1 batch dimensions, so K/V must be expanded to
    # the full head count explicitly (GQA).
    K = K.repeat_interleave(repeats=ratio, dim=-3)
    V = V.repeat_interleave(repeats=ratio, dim=-3)
    # (batches*)heads*context*context
    A__ = Q @ K.mT
    # batches*heads*context*context (a 4-D tensor scale adds a leading batch dim)
    A_ = scale*A__
    # (batches*)heads*context*context
    A_ = A_.reshape(A__.shape)
    if alibi is not None:
        A_ = A_ + alibi
    if causal is not None:
        A_.masked_fill_(~causal, -float("inf"))
    if swa is not None:
        A_.masked_fill_(~swa, -float("inf"))
    A = torch.softmax(A_, dim=-1)
    # matmul does not type-promote: a float32 scale/ALiBi tensor promotes the scores
    # to float32 while V may be bf16/fp16, so align A with V before multiplying.
    # (batches*)heads*context*d_head
    Y = A.to(V.dtype) @ V
    if not return_A:
        return Y
    else:
        return Y, A__, A_, A
# (batches*)heads/groups*context*d_head
def sdpa_flash(Q, K, V, causal=False, alibi=None, swa=None, scale=None, backend="flash2"):
    """Scaled dot-product attention through FlashAttention 2/3/4.

    Inputs use the (batches*)heads/groups*context*d_head layout of ``sdpa_pytorch``.
    Leading batch dimensions (zero or more) are flattened into the
    batches*context*heads*d_head layout FlashAttention expects, and restored on
    output. The kernel runs in bf16/fp16. The output is cast back to Q's dtype,
    taken after any tensor-scale multiplication.

    Args:
        causal: bool; ``None`` is treated as False.
        alibi: per-head ALiBi slopes (e.g. from ``get_m``). Only FlashAttention2
            supports them, so any other backend is switched to flash2 with a warning.
        swa: (left, right) sliding-window sizes; ``None`` means no window (-1, -1).
        scale: float or tensor. A tensor scale is folded into Q and the kernel
            scale is set to 1.
        backend: "flash2", "flash3" or "flash4".

    Raises:
        ValueError: for any other backend.
    """
    if (alibi is not None) and backend != "flash2":
        print(f"\x1b[93;3m[WARNING]: backend={backend} does not support ALiBi. Hence, we force backend=flash2.\x1b[0m")
        backend = "flash2"
    # sdpa_wrapper forwards causal=None when no causal masking is requested
    if causal is None:
        causal = False
    # FlashAttention only supports float scale
    if isinstance(scale, torch.Tensor):
        Q_shape = Q.shape
        # batches*heads*context*d_head
        Q = scale*Q
        # (batches*)heads*context*d_head
        Q = Q.reshape(Q_shape)
        scale = 1
    # FlashAttention2 only supports BF16 and FP16
    if Q.dtype in [torch.bfloat16, torch.float16]:
        dtype = Q.dtype
    else:
        dtype = torch.bfloat16
    # Leading batch dimensions (possibly none), restored on the output
    batch_shape = Q.shape[:-3]
    heads = Q.shape[-3]
    groups = K.shape[-3]
    context = Q.shape[-2]
    d_head = Q.shape[-1]
    # CAUTION: FlashAttention expects batches*context*heads/groups*d_head
    Q = Q.movedim(-3,-2).reshape(-1,context,heads,d_head)
    K = K.movedim(-3,-2).reshape(-1,context,groups,d_head)
    V = V.movedim(-3,-2).reshape(-1,context,groups,d_head)
    if swa is None:
        swa = (-1,-1)
    if backend=="flash2":
        import flash_attn
        Y = flash_attn.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=causal, alibi_slopes=alibi, window_size=swa, softmax_scale=scale)
    elif backend=="flash3":
        import flash_attn_interface
        Y = flash_attn_interface.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=causal, window_size=swa, softmax_scale=scale)
    elif backend=="flash4":
        import flash_attn.cute
        # FlashAttention4 returns (out, lse)
        Y = flash_attn.cute.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=causal, window_size=swa, softmax_scale=scale)[0]
    else:
        raise ValueError(f"sdpa_flash does not support backend={backend}")
    Y = Y.to(Q.dtype)
    # batches*context*heads*d_head -> (batches*)heads*context*d_head. Reshaping to the
    # saved batch shape keeps a size-1 batch dim (squeeze(0) dropped it) and restores
    # multiple batch dims.
    Y = Y.movedim(-3,-2).reshape(*batch_shape, heads, context, d_head)
    return Y
# (batches*)heads/groups*context*d_head
def sdpa_flex(*args, **kwargs):
    """Placeholder for a FlexAttention backend; not implemented.

    Raises:
        NotImplementedError: always. Failing here instead of returning None makes
            callers stop at the source instead of crashing later on a None result.
    """
    raise NotImplementedError("backend=flex is not implemented yet")
# (batches*)heads/groups*context*d_head
def sdpa_cudnn(*args, **kwargs):
    """Placeholder for a cuDNN attention backend; not implemented.

    Raises:
        NotImplementedError: always. Failing here instead of returning None makes
            callers stop at the source instead of crashing later on a None result.
    """
    raise NotImplementedError("backend=cudnn is not implemented yet")
def sdpa_wrapper(Q, K, V, causal=None, alibi=None, swa=None, scale=None, return_A=False, backend="flash2"):
    """Dispatch scaled dot-product attention to ``backend``.

    The mask/bias arguments must already be in the format of the chosen backend:
    boolean masks and a dense ALiBi bias for "pytorch"; a bool ``causal``, ALiBi
    slopes and a (left, right) window for the flash backends.

    Raises:
        ValueError: for an unknown backend, or when ``return_A`` is requested from
            a backend other than "pytorch". Only the PyTorch reference returns the
            attention matrices; callers unpack four values when ``return_A`` is set.
    """
    if return_A and backend != "pytorch":
        raise ValueError(f"return_A=True requires backend=pytorch, got backend={backend}")
    if backend == "pytorch":
        return sdpa_pytorch(Q, K, V, causal, alibi, swa, scale, return_A)
    if backend in {"flash2", "flash3", "flash4"}:
        return sdpa_flash(Q, K, V, causal, alibi, swa, scale, backend)
    if backend == "flex":
        return sdpa_flex()
    if backend == "cudnn":
        return sdpa_cudnn()
    raise ValueError(f"unknown attention backend={backend}")
def test_sdpa():
    """Check the FlashAttention backends against the PyTorch reference implementation.

    Requires a CUDA device (``cuda:0``) and the FlashAttention packages imported by
    ``sdpa_flash``. For every scenario:
    - the header is printed;
    - the reference output is computed with ``backend="pytorch"``;
    - each listed flash backend must match it under ``torch.testing.assert_close``
      (dtype ignored);
    - a green check mark is printed.
    """
    batches, heads, context, d_head = 32, 12, 1024, 64
    window = 256
    groups = 4
    dtype = torch.bfloat16
    device = "cuda:0"
    all_flash = ("flash2", "flash3", "flash4")
    # ALiBi not supported on FlashAttention3/4
    alibi_flash = ("flash2",)

    def random_qkv(kv_heads):
        # Q, then K, then V: keeps the random stream in a fixed order
        q = torch.rand((batches, heads, context, d_head)).to(device, dtype)
        k = torch.rand((batches, kv_heads, context, d_head)).to(device, dtype)
        v = torch.rand((batches, kv_heads, context, d_head)).to(device, dtype)
        return q, k, v

    def causal_mask():
        return get_causal(context).to(device)

    def alibi_bias():
        return get_alibi(heads, context).to(device, dtype)

    def alibi_slopes():
        return get_m(heads).to(device)

    def swa_mask():
        return get_swa(context, window).to(device)

    # (header, K/V head count for freshly drawn inputs or None to keep the current
    #  ones, reference kwargs factory, flash kwargs factory, flash backends to check)
    scenarios = (
        ("\x1b[1mbfloat16\x1b[0m", heads, lambda: {}, lambda: {}, all_flash),
        ("\x1b[1mcausal\x1b[0m", None,
         lambda: {"causal": causal_mask()},
         lambda: {"causal": True}, all_flash),
        ("\x1b[1malibi\x1b[0m", None,
         lambda: {"alibi": alibi_bias()},
         lambda: {"alibi": alibi_slopes()}, alibi_flash),
        ("\x1b[1mswa\x1b[0m", None,
         lambda: {"swa": swa_mask()},
         lambda: {"swa": (window, window)}, all_flash),
        ("\x1b[1mcausal+alibi\x1b[0m", None,
         lambda: {"causal": causal_mask(), "alibi": alibi_bias()},
         lambda: {"causal": True, "alibi": alibi_slopes()}, alibi_flash),
        ("\x1b[1mcausal+swa\x1b[0m", None,
         lambda: {"causal": causal_mask(), "swa": swa_mask()},
         lambda: {"causal": True, "swa": (window, window)}, all_flash),
        ("\x1b[1mGQA\x1b[0m", groups, lambda: {}, lambda: {}, all_flash),
    )

    Q = K = V = None
    for header, kv_heads, reference_kwargs, flash_kwargs, flash_backends in scenarios:
        print(header, end="")
        if kv_heads is not None:
            Q, K, V = random_qkv(kv_heads)
        pytorch = sdpa_wrapper(Q, K, V, backend="pytorch", **reference_kwargs())
        for flash_backend in flash_backends:
            flash = sdpa_wrapper(Q, K, V, backend=flash_backend, **flash_kwargs())
            torch.testing.assert_close(flash, pytorch, check_dtype=False)
        print("\x1b[32m ✔\x1b[0m")
class MHSA(torch.nn.Module):
    """Multi-head self-attention with optional grouped-query attention and QK-norm.

    Args:
        heads: number of query heads.
        d_head: per-head width; the model width is ``heads * d_head``.
        scale_type: fixed score scale used when ``qk_norm`` is off:
            "1/sqrt(d)" or "1/d" (d = d_head).
        ratio: query heads per K/V head; K/V use ``heads // ratio`` groups.
        qk_norm: pass Q and K through ``mlp.sphere_norm`` after RoPE and use a
            learned per-head scale, initialised to sqrt(d_head), instead of a fixed one.
        quartet: use quartet2 projections. quartet2 is not available in HF mode,
            so this raises ImportError.
        fake_quartet: use ``fake_quartet.FakeQuartetLinear`` projections.

    Raises:
        ValueError: unknown ``scale_type`` while ``qk_norm`` is off.
        ImportError: ``quartet=True``.
    """
    def __init__(self, heads, d_head, scale_type="1/sqrt(d)", ratio=1, qk_norm=True, quartet=True, fake_quartet=False):
        super().__init__()
        self.heads = heads
        self.d_head = d_head
        self.d = heads * d_head
        self.scale_type = scale_type
        self.ratio = ratio
        self.groups = heads//ratio
        self.d_KV = self.groups * d_head
        self.qk_norm = qk_norm
        if qk_norm:
            # Learned per-head scale shaped to broadcast over (batches*)heads*context*context scores
            scale = torch.full((1,heads,1,1), sqrt(d_head))
            self.scale = torch.nn.Parameter(scale)
        elif scale_type=="1/sqrt(d)":
            self.scale = 1/sqrt(d_head)
        elif scale_type=="1/d":
            self.scale = 1/d_head
        else:
            # Otherwise self.scale would be missing and forward() would fail later
            raise ValueError(f"unknown scale_type={scale_type!r}; expected '1/sqrt(d)' or '1/d'")
        self.quartet = quartet
        self.fake_quartet = fake_quartet
        # Packing QKV gives negligible speed gains, while not allowing GQA, hurting code clarity and having side effects with μP
        if quartet:
            # Quartet projections (quartet2.linear.Quartet_II_linear) need the quartet2
            # package, which is not imported in HF mode: fail with an explicit error
            # instead of a NameError on the missing module.
            raise ImportError("quartet=True requires the quartet2 package, which is not available in HF mode; pass quartet=False")
        elif fake_quartet:
            from . import fake_quartet as fq
            self.lq = fq.FakeQuartetLinear(self.d, self.d, bias=False)
            self.lk = fq.FakeQuartetLinear(self.d, self.d_KV, bias=False)
            self.lv = fq.FakeQuartetLinear(self.d, self.d_KV, bias=False)
            self.lo = fq.FakeQuartetLinear(self.d, self.d, bias=False)
        else:
            self.lq = torch.nn.Linear(self.d, self.d, bias=False)
            self.lk = torch.nn.Linear(self.d, self.d_KV, bias=False)
            self.lv = torch.nn.Linear(self.d, self.d_KV, bias=False)
            self.lo = torch.nn.Linear(self.d, self.d, bias=False)
    # (batches*)context*d
    def forward(self, X, causal=None, rope=None, alibi=None, swa=None, return_A=False, backend="flash2"):
        """Attend over ``X`` ((batches*)context*d) and return (batches*)context*d.

        ``causal``/``alibi``/``swa`` must be in the format of ``backend`` (see
        ``sdpa_wrapper``). ``rope`` is a table from ``get_rope``. With ``return_A``
        the attention ``(A__, A_, A)`` from ``sdpa_wrapper`` is returned as well.
        """
        # (batches*)context*d
        Q = self.lq(X)
        # (batches*)context*d_KV
        K = self.lk(X)
        V = self.lv(X)
        # (batches*)context*heads*d_head
        Q = Q.unflatten(dim=-1, sizes=(self.heads, self.d_head))
        # (batches*)context*groups*d_head
        K = K.unflatten(dim=-1, sizes=(self.groups, self.d_head))
        V = V.unflatten(dim=-1, sizes=(self.groups, self.d_head))
        # (batches*)heads*context*d_head
        Q = Q.movedim(-3,-2)
        # (batches*)groups*context*d_head
        K = K.movedim(-3,-2)
        V = V.movedim(-3,-2)
        if rope is not None:
            Q = apply_rope(Q,rope)
            K = apply_rope(K,rope)
        # After RoPE
        if self.qk_norm:
            Q = mlp.sphere_norm(Q)
            K = mlp.sphere_norm(K)
        # (batches*)heads*context*d_head
        if not return_A:
            Y = sdpa_wrapper(Q, K, V, causal, alibi, swa, self.scale, return_A, backend)
        else:
            Y, A__, A_, A = sdpa_wrapper(Q, K, V, causal, alibi, swa, self.scale, return_A, backend)
        # (batches*)context*heads*d_head
        Y = Y.movedim(-3,-2)
        # (batches*)context*d
        Y = Y.flatten(-2,-1)
        Y = self.lo(Y)
        if not return_A:
            return Y
        else:
            return Y, A__, A_, A
class Block(torch.nn.Module):
    """One transformer layer: an attention sub-layer followed by an MLP sub-layer.

    Each sub-layer reads the residual stream, through ``pre_*_norm`` when one was
    built, and may normalise its output with ``out_*_norm``. Dropout is then
    applied and the result is added back to the stream. Norms come from
    ``mlp.get_norm`` and are skipped when it returned a falsy value.
    """
    def __init__(self, heads, d_head, scale_type="1/sqrt(d)", ratio=1, exp_factor=4, dropout=0, norm_type="rms_learned", bias=False, act=mlp.ReLU2(), l1_type="linear", pre_att_norm=False, qk_norm=True, out_att_norm=True, pre_mlp_norm=False, act_norm=False, out_mlp_norm=True, quartet=True, fake_quartet=False):
        super().__init__()
        width = heads * d_head
        # Hyper-parameters. `act` may itself be a module, so it keeps its place ahead
        # of the sub-modules below: registration order fixes parameter order.
        self.heads, self.d_head, self.d = heads, d_head, width
        self.scale_type, self.ratio, self.groups = scale_type, ratio, heads//ratio
        self.exp_factor, self.d_hidden = exp_factor, int(exp_factor*width)
        self.dropout, self.norm_type, self.bias = dropout, norm_type, bias
        self.act, self.l1_type = act, l1_type
        # Sub-modules; their registration order below determines parameter order
        self.mhsa = MHSA(heads, d_head, scale_type, ratio, qk_norm, quartet, fake_quartet)
        self.pre_att_norm = mlp.get_norm(pre_att_norm, norm_type, width, bias)
        self.out_att_norm = mlp.get_norm(out_att_norm, norm_type, width, bias)
        self.mlp = mlp.MLP2L(width, self.d_hidden, width, bias, act, dropout, l1_type, norm_type, act_norm, quartet, fake_quartet)
        self.pre_mlp_norm = mlp.get_norm(pre_mlp_norm, norm_type, width, bias)
        self.out_mlp_norm = mlp.get_norm(out_mlp_norm, norm_type, width, bias)
        self.quartet, self.fake_quartet = quartet, fake_quartet
    def forward(self, X, causal=None, rope=None, alibi=None, swa=None, return_res=False, return_A=False, backend="flash2"):
        """Apply the block to the residual stream ``X`` ((batches*)context*d).

        Returns the updated stream. With ``return_res`` the stream after the
        attention sub-layer follows it. With ``return_A`` the attention
        ``(A__, A_, A)`` returned by the MHSA comes last.
        """
        att_in = self.pre_att_norm(X) if self.pre_att_norm else X
        att_out = self.mhsa(att_in, causal, rope, alibi, swa, return_A, backend)
        if return_A:
            att, A__, A_, A = att_out
            maps = [A__, A_, A]
        else:
            att, maps = att_out, []
        if self.out_att_norm:
            att = self.out_att_norm(att)
        # Residual stream after the attention sub-layer
        mid = X + torch.nn.functional.dropout(att, p=self.dropout, training=self.training)
        mlp_out = self.mlp(self.pre_mlp_norm(mid) if self.pre_mlp_norm else mid)
        if self.out_mlp_norm:
            mlp_out = self.out_mlp_norm(mlp_out)
        out = mid + torch.nn.functional.dropout(mlp_out, p=self.dropout, training=self.training)
        results = [out]
        if return_res:
            results.append(mid)
        results.extend(maps)
        return tuple(results) if len(results) > 1 else out
class Transformer(torch.nn.Module):
    """Language model: token embedding, ``num_blocks`` ``Block``s and a vocabulary projection.

    Positional information follows ``pos_type``:
    - "learned" and "sinusoidal" are added to the embeddings;
    - "rope" and "alibi" are applied inside attention.

    Attention masks are rebuilt on every forward pass in the format expected by
    ``backend``. With ``weight_tying`` the embedding shares the output projection's
    weight. With ``fix_norm`` the projection weight is passed through
    ``mlp.sphere_norm`` before use.
    """
    def __init__(self, vocab_size=50304, num_blocks=12, heads=12, d_head=64, scale_type="1/sqrt(d)", ratio=1, is_causal=True, window=None, backend="flash2", exp_factor=4, dropout=0, pos_type="rope", max_context=128, norm_type="rms_learned", bias=False, act=mlp.ReLU2(), l1_type="linear", std=0.02, test=False, weight_tying=True, emb_norm=False, pre_att_norm=False, qk_norm=True, out_att_norm=True, pre_mlp_norm=False, act_norm=False, out_mlp_norm=True, out_norm=True, fix_norm=False, quartet=True, fake_quartet=False):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_blocks = num_blocks
        self.heads = heads
        self.d_head = d_head
        self.d = heads * d_head
        self.scale_type = scale_type
        self.ratio = ratio
        self.groups = heads//ratio
        self.is_causal = is_causal
        self.window = window
        self.backend = backend
        self.exp_factor = exp_factor
        self.dropout = dropout
        self.pos_type = pos_type
        self.max_context = max_context
        self.norm_type = norm_type
        self.bias = bias
        self.act = act
        self.l1_type = l1_type
        self.weight_tying = weight_tying
        self.fix_norm = fix_norm
        self.quartet = quartet
        self.fake_quartet = fake_quartet
        self.emb = torch.nn.Embedding(vocab_size, self.d)
        self.emb_norm = mlp.get_norm(emb_norm, norm_type, self.d, bias)
        if pos_type == "learned":
            # The random values are overwritten (zeroed) by init() below
            pos = torch.rand((max_context, self.d))
            self.pos = torch.nn.Parameter(pos)
        self.blocks = torch.nn.Sequential(*[Block(heads, d_head, scale_type, ratio, exp_factor, dropout, norm_type, bias, act, l1_type, pre_att_norm, qk_norm, out_att_norm, pre_mlp_norm, act_norm, out_mlp_norm, quartet, fake_quartet) for _ in range(num_blocks)])
        self.out_norm = mlp.get_norm(out_norm, norm_type, self.d, bias)
        self.linear = torch.nn.Linear(self.d, vocab_size, bias=False)
        if weight_tying:
            self.emb.weight = self.linear.weight
        self.init(std, test)
        if fake_quartet:
            for m in self.modules():
                if isinstance(m, (torch.nn.LayerNorm, torch.nn.RMSNorm, torch.nn.Embedding)):
                    m.to(torch.bfloat16)
    def init(self, std=0.02, test=False):
        """(Re-)initialise all parameters in place.

        Initialisation rules:
        - Linear/Embedding weights ~ N(0, std);
        - Linear/LayerNorm biases = 0;
        - LayerNorm/RMSNorm weights = 1;
        - any other parameter, by rank: 2-D (e.g. the learned ``pos`` table) is
          zeroed, 4-D (e.g. the MHSA per-head ``scale``) is set to sqrt(d_head),
          and other ranks are left untouched.

        With ``test`` a per-parameter mean/std table is printed.
        """
        if test: print("\x1b[1m%36.36s %8.8s %8.8s %8.8s\x1b[0m" % ("parameter_name", "suffix", "mean", "std"))
        for parameter_name, parameter in self.named_parameters():
            parent_name, _, suffix = parameter_name.rpartition(".")
            parent = self.get_submodule(parent_name)
            if isinstance(parent, (torch.nn.Linear, torch.nn.Embedding)) and suffix=="weight":
                torch.nn.init.normal_(parameter, 0, std)
            elif isinstance(parent, (torch.nn.Linear, torch.nn.LayerNorm)) and suffix=="bias":
                torch.nn.init.zeros_(parameter)
            elif isinstance(parent, (torch.nn.LayerNorm, torch.nn.RMSNorm)) and suffix=="weight":
                torch.nn.init.ones_(parameter)
            else:
                # pos
                if parameter.ndim == 2:
                    torch.nn.init.zeros_(parameter)
                # scale
                elif parameter.ndim == 4:
                    torch.nn.init.constant_(parameter, sqrt(self.d_head))
            if test:
                print("%36.36s %8.8s %8.8s %8.8s\x1b[0m" % (parameter_name, suffix, "%f" % parameter.mean(), "%f" % parameter.std()))
    # (batches*)context
    def forward(self, ids, return_res=False, return_A=False):
        """Return the logits for token ``ids`` of shape (batches*)context.

        Args:
            ids: integer token ids, (batches*)context.
            return_res: also return the residual stream. This is the token
                embeddings (before positional encoding, emb_norm and dropout), then,
                per block, the stream after the attention and after the MLP
                sub-layer.
            return_A: also return per-block attention scores/probabilities (only
                the pytorch backend computes them).

        Returns:
            Z ((batches*)context*vocab_size), optionally followed by
            ``res_in, res_att, res_mlp`` and/or ``A__, A_, A``.

        Raises:
            NotImplementedError: backend="flex" combined with ``is_causal``,
                pos_type="alibi" or a sliding window.
        """
        context = ids.shape[-1]
        if return_A:
            # (batches*)num_blocks*heads*context*context; no device= so these live on
            # the default device
            A__ = torch.empty(*ids.shape[:-1], self.num_blocks, self.heads, context, context)
            A_ = torch.empty_like(A__)
            A = torch.empty_like(A__)
        # (batches*)context*d
        X = self.emb(ids)
        if return_res:
            res_in = X
            # (batches*)num_blocks*context*d
            res_att = torch.empty(*ids.shape[:-1], self.num_blocks, context, self.d)
            res_mlp = torch.empty(*ids.shape[:-1], self.num_blocks, context, self.d)
        # Recompute in every batch in case context changes
        if self.is_causal:
            if self.backend=="pytorch":
                causal = get_causal(context).to(ids.device)
            elif self.backend in {"flash2", "flash3", "flash4"}:
                causal = True
            elif self.backend=="flex":
                # FlexAttention is only a stub (sdpa_flex): fail early with a clear error
                raise NotImplementedError("backend=flex is not implemented (causal mask)")
            elif self.backend=="cudnn":
                # right_bound
                causal = 0
        else:
            causal = None
        if self.pos_type == "sinusoidal":
            pos = get_sinusoidal(context, self.d).to(ids.device)
            X = X + pos
        if self.pos_type == "learned":
            X = X + self.pos[:context,:]
        if self.pos_type == "rope":
            rope = get_rope(context, self.d_head, device=ids.device)
        else:
            rope = None
        if self.pos_type == "alibi":
            if self.backend=="pytorch":
                alibi = get_alibi(self.heads, context).to(ids.device)
            elif self.backend in {"flash2", "flash3", "flash4"}:
                alibi = get_m(self.heads).to(ids.device)
            elif self.backend=="flex":
                raise NotImplementedError("backend=flex is not implemented (ALiBi)")
            elif self.backend=="cudnn":
                alibi = True
        else:
            alibi = None
        if self.window is not None:
            if self.backend=="pytorch":
                swa = get_swa(context, self.window).to(ids.device)
            elif self.backend in {"flash2", "flash3", "flash4"}:
                swa = (self.window, self.window)
            elif self.backend=="flex":
                raise NotImplementedError("backend=flex is not implemented (sliding window)")
            elif self.backend=="cudnn":
                # left_bound
                swa = self.window
        else:
            swa = None
        # After positional encoding
        if self.emb_norm: X = self.emb_norm(X)
        X_ = torch.nn.functional.dropout(X, p=self.dropout, training=self.training)
        Y = X_
        for i, block in enumerate(self.blocks):
            if not return_res:
                if not return_A:
                    Y = block(Y, causal, rope, alibi, swa, return_res, return_A, self.backend)
                else:
                    Y, A__[...,i,:,:,:], A_[...,i,:,:,:], A[...,i,:,:,:] = block(Y, causal, rope, alibi, swa, return_res, return_A, self.backend)
            else:
                if not return_A:
                    Y, res_att[...,i,:,:] = block(Y, causal, rope, alibi, swa, return_res, return_A, self.backend)
                    res_mlp[...,i,:,:] = Y
                else:
                    Y, res_att[...,i,:,:], A__[...,i,:,:,:], A_[...,i,:,:,:], A[...,i,:,:,:] = block(Y, causal, rope, alibi, swa, return_res, return_A, self.backend)
                    res_mlp[...,i,:,:] = Y
        if self.out_norm: Y = self.out_norm(Y)
        # (batches*)context*vocab_size
        if self.fix_norm:
            Z = torch.nn.functional.linear(Y, mlp.sphere_norm(self.linear.weight))
        else:
            Z = self.linear(Y)
        if not return_res:
            if not return_A:
                return Z
            else:
                return Z, A__, A_, A
        else:
            if not return_A:
                return Z, res_in, res_att, res_mlp
            else:
                return Z, res_in, res_att, res_mlp, A__, A_, A
def get_attention_header(transformer):
    """Return space-separated column names ``block{b}.head{h}`` for every block/head pair.

    Blocks vary slowest, the same order in which ``get_attention`` emits values.
    """
    return " ".join(
        f"block{block}.head{head}"
        for block in range(transformer.num_blocks)
        for head in range(transformer.heads)
    )
def get_attention(W):
    """Format the blocks*heads matrix ``W`` as two-decimal values joined by spaces.

    Values are emitted row by row (block-major), matching ``get_attention_header``.
    """
    return " ".join(
        "%.2f" % W[block, head]
        for block in range(W.shape[0])
        for head in range(W.shape[1])
    )
def get_similarity_header(transformer):
    """Return space-separated column names: "embedding", then "block{b}" for every block."""
    columns = ["embedding"]
    columns.extend(f"block{block}" for block in range(transformer.num_blocks))
    return " ".join(columns)
def get_similarity(embeddings_x, embeddings_y):
    """Return the row-wise cosine similarities of two matrices as "%.2f" values joined by spaces.

    The number of rows is taken from ``embeddings_x``.
    """
    cosine = torch.nn.functional.cosine_similarity
    return " ".join(
        "%.2f" % cosine(embeddings_x[row, :], embeddings_y[row, :], dim=0)
        for row in range(embeddings_x.shape[0])
    )
def get_clustering_header(transformer):
    """Return space-separated column names for ``get_clustering``.

    The embedding and then every block each get, for every 2-D projection
    (random, pca, mds, tsne, umap), an ``.x`` and a ``.y`` column.
    """
    methods = ("random", "pca", "mds", "tsne", "umap")
    prefixes = ["embedding"] + [f"block{block}" for block in range(transformer.num_blocks)]
    return " ".join(
        f"{prefix}.{method}.{axis}"
        for prefix in prefixes
        for method in methods
        for axis in ("x", "y")
    )
def get_clustering(random, pca, mds, tsne, umap):
    """Format per-row 2-D coordinates from five projections as "%f" values joined by spaces.

    Row r contributes random[r], pca[r], mds[r], tsne[r], umap[r] (x then y for
    each). The row count is taken from ``random``.
    """
    projections = (random, pca, mds, tsne, umap)
    return " ".join(
        "%f" % projection[row, axis]
        for row in range(random.shape[0])
        for projection in projections
        for axis in (0, 1)
    )