import torch
def sphere_norm(X, dim=-1):
    """Project X onto the unit hypersphere S^{d-1} by L2-normalizing along `dim`."""
    projected = torch.nn.functional.normalize(X, dim=dim)
    return projected
class SphereNorm(torch.nn.Module):
    """Module wrapper that L2-normalizes its input onto the unit hypersphere.

    Parameters
    ----------
    dim : int
        Axis along which to normalize (default: last).
    """
    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, X):
        # Inlined sphere_norm helper: plain L2 normalization along self.dim.
        return torch.nn.functional.normalize(X, dim=self.dim)
def get_norm(enable, norm_type, d, bias):
if enable:
if norm_type=="layer":
norm = torch.nn.LayerNorm(d, bias=bias)
elif norm_type=="rms_learned":
norm = torch.nn.RMSNorm(d, elementwise_affine=True)
elif norm_type=="rms_const":
norm = torch.nn.RMSNorm(d, elementwise_affine=False)
elif norm_type=="sphere":
norm = SphereNorm(dim=-1)
else:
norm = None
return norm
class ReLU2(torch.nn.Module):
    """Squared-ReLU activation: relu(x)**2, elementwise."""
    def forward(self, x):
        rectified = torch.relu(x)
        return rectified * rectified
class Abs(torch.nn.Module):
    """Elementwise absolute value, as an activation module."""
    def forward(self, x):
        return torch.abs(x)
class GLU(torch.nn.Module):
    """Gated linear unit: forward(x) = act(gate(x)) * proj(x).

    Args:
        d0: input feature dimension.
        d1: output feature dimension.
        bias: whether the two linear layers carry bias terms.
        act: activation applied to the gate branch.
        quartet: request quartet2 quantized linears (unavailable here; raises).
        fake_quartet: use the package-local FakeQuartetLinear layers instead.
    """
    def __init__(self, d0, d1, bias=True, act=torch.nn.ReLU(), quartet=True, fake_quartet=False):
        super().__init__()
        self.d0 = d0
        self.d1 = d1
        self.bias = bias
        self.act = act
        self.quartet = quartet
        self.fake_quartet = fake_quartet
        if quartet:
            # Bug fix: the original left the quartet2.* constructor calls in
            # place after the "not available in HF mode" stub, so the default
            # quartet=True path crashed with NameError on the undefined
            # `quartet2`. Fail explicitly with an actionable message instead.
            raise NotImplementedError(
                "quartet2 linears are not available in HF mode; "
                "construct GLU with quartet=False (optionally fake_quartet=True)"
            )
        elif fake_quartet:
            from . import fake_quartet as fq
            self.gate = torch.nn.Sequential(fq.FakeQuartetLinear(d0, d1, bias), act)
            self.proj = fq.FakeQuartetLinear(d0, d1, bias)
        else:
            self.gate = torch.nn.Sequential(torch.nn.Linear(d0, d1, bias), act)
            self.proj = torch.nn.Linear(d0, d1, bias)

    def forward(self, x):
        # Elementwise product of the activated gate branch and the plain projection.
        y = self.gate(x) * self.proj(x)
        return y
class MLP2L(torch.nn.Module):
    """Two-layer MLP: l1 (linear+act or GLU) -> optional norm -> dropout -> l2.

    Args:
        d0, d1, d2: input, hidden, and output feature dimensions.
        bias: whether linear layers carry bias terms.
        act: activation for the first layer.
        dropout: dropout probability applied after the (normalized) hidden layer.
        l1_type: "linear" or "glu" — shape of the first layer.
        norm_type: normalization variant passed to get_norm.
        norm: whether to apply normalization to the hidden activations.
        quartet: request quartet2 quantized linears (unavailable here; raises).
        fake_quartet: use the package-local FakeQuartetLinear layers instead.
    """
    def __init__(self, d0, d1, d2, bias=True, act=torch.nn.ReLU(), dropout=0,
                 l1_type="linear", norm_type="rms_learned", norm=False,
                 quartet=True, fake_quartet=False):
        super().__init__()
        self.d0 = d0
        self.d1 = d1
        self.d2 = d2
        self.bias = bias
        self.act = act
        self.dropout = dropout
        self.l1_type = l1_type
        self.norm_type = norm_type
        if quartet:
            # Bug fix: the original kept quartet2.* constructor calls after the
            # "not available in HF mode" stubs, so quartet=True (the default)
            # crashed with NameError on the undefined `quartet2`. Raise an
            # explicit, actionable error instead.
            raise NotImplementedError(
                "quartet2 linears are not available in HF mode; "
                "construct MLP2L with quartet=False (optionally fake_quartet=True)"
            )
        if l1_type == "linear":
            if fake_quartet:
                from . import fake_quartet as fq
                self.l1 = torch.nn.Sequential(fq.FakeQuartetLinear(d0, d1, bias), act)
            else:
                self.l1 = torch.nn.Sequential(torch.nn.Linear(d0, d1, bias), act)
        elif l1_type == "glu":
            self.l1 = GLU(d0, d1, bias, act, quartet, fake_quartet)
        else:
            # Robustness fix: the original silently left self.l1 unset and
            # failed later in forward(); fail fast with a clear message.
            raise ValueError(f"unknown l1_type: {l1_type!r} (expected 'linear' or 'glu')")
        self.norm = get_norm(norm, norm_type, d1, bias)
        if fake_quartet:
            from . import fake_quartet as fq
            self.l2 = fq.FakeQuartetLinear(d1, d2, bias)
        else:
            self.l2 = torch.nn.Linear(d1, d2, bias)

    def forward(self, x):
        a1 = self.l1(x)
        if self.norm is not None:
            a1 = self.norm(a1)
        a1 = torch.nn.functional.dropout(a1, p=self.dropout, training=self.training)
        y = self.l2(a1)
        return y
class MLP3L(torch.nn.Module):
    """Three-layer MLP: two (Linear -> act -> dropout) stages, then a final Linear.

    Args:
        d0, d1, d2, d3: input, two hidden, and output feature dimensions.
        bias: whether linear layers carry bias terms.
        act: activation applied after each hidden linear layer.
        dropout: dropout probability applied after each hidden activation.
    """
    def __init__(self, d0, d1, d2, d3, bias=True, act=torch.nn.ReLU(), dropout=0):
        super().__init__()
        self.d0 = d0
        self.d1 = d1
        self.d2 = d2
        self.d3 = d3
        self.bias = bias
        self.act = act
        self.dropout = dropout
        # Construction order matches parameter-init order (l1, l2, l3).
        self.l1 = torch.nn.Linear(d0, d1, bias)
        self.l2 = torch.nn.Linear(d1, d2, bias)
        self.l3 = torch.nn.Linear(d2, d3, bias)

    def forward(self, x):
        h = x
        # Both hidden layers get the same activation + dropout treatment.
        for hidden in (self.l1, self.l2):
            h = self.act(hidden(h))
            h = torch.nn.functional.dropout(h, p=self.dropout, training=self.training)
        return self.l3(h)
class MLP3L_image(torch.nn.Module):
    """Classifies image inputs by flattening the last three dims into MLP3L.

    Args:
        res: image resolution; the MLP input width is res*res
             (assumes single-channel res x res images — TODO confirm).
        d1, d2: hidden-layer widths of the underlying MLP3L.
        dropout: dropout probability forwarded to MLP3L.
        classes: number of output classes.
    """
    def __init__(self, res=28, d1=16, d2=16, dropout=0, classes=10):
        super().__init__()
        self.res = res
        self.d1 = d1
        self.d2 = d2
        self.dropout = dropout
        self.classes = classes
        self.mlp = MLP3L(res * res, d1, d2, classes, dropout=dropout)

    def forward(self, x):
        # Collapse the trailing (channel, height, width) axes into one feature axis.
        flat = x.flatten(start_dim=-3, end_dim=-1)
        return self.mlp(flat)