init
Browse files- inference.py +254 -0
- vocab_v20230424.txt +0 -0
inference.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
np.set_printoptions(precision=4, suppress=True, linewidth=200)
|
| 4 |
+
import types, torch
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
MyModule = torch.jit.ScriptModule
|
| 8 |
+
MyFunction = torch.jit.script_method
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RWKV_TOKENIZER():
|
| 12 |
+
table: list[list[list[bytes]]]
|
| 13 |
+
good: list[set[int]]
|
| 14 |
+
wlen: list[int]
|
| 15 |
+
|
| 16 |
+
def __init__(self, file_name):
|
| 17 |
+
self.idx2token = {}
|
| 18 |
+
sorted = [] # must be already sorted
|
| 19 |
+
lines = open(file_name, "r", encoding="utf-8").readlines()
|
| 20 |
+
for l in lines:
|
| 21 |
+
idx = int(l[:l.index(' ')])
|
| 22 |
+
x = eval(l[l.index(' '):l.rindex(' ')])
|
| 23 |
+
x = x.encode("utf-8") if isinstance(x, str) else x
|
| 24 |
+
assert isinstance(x, bytes)
|
| 25 |
+
assert len(x) == int(l[l.rindex(' '):])
|
| 26 |
+
sorted += [x]
|
| 27 |
+
self.idx2token[idx] = x
|
| 28 |
+
|
| 29 |
+
self.token2idx = {}
|
| 30 |
+
for k, v in self.idx2token.items():
|
| 31 |
+
self.token2idx[v] = int(k)
|
| 32 |
+
|
| 33 |
+
# precompute some tables for fast matching
|
| 34 |
+
self.table = [[[] for j in range(256)] for i in range(256)]
|
| 35 |
+
self.good = [set() for i in range(256)]
|
| 36 |
+
self.wlen = [0 for i in range(256)]
|
| 37 |
+
|
| 38 |
+
for i in reversed(range(len(sorted))): # reverse order - match longer tokens first
|
| 39 |
+
s = sorted[i]
|
| 40 |
+
if len(s) >= 2:
|
| 41 |
+
s0 = int(s[0])
|
| 42 |
+
s1 = int(s[1])
|
| 43 |
+
self.table[s0][s1] += [s]
|
| 44 |
+
self.wlen[s0] = max(self.wlen[s0], len(s))
|
| 45 |
+
self.good[s0].add(s1)
|
| 46 |
+
|
| 47 |
+
def encodeBytes(self, src: bytes) -> list[int]:
|
| 48 |
+
src_len: int = len(src)
|
| 49 |
+
tokens: list[int] = []
|
| 50 |
+
i: int = 0
|
| 51 |
+
while i < src_len:
|
| 52 |
+
s: bytes = src[i: i + 1]
|
| 53 |
+
|
| 54 |
+
if i < src_len - 1:
|
| 55 |
+
s1: int = int(src[i + 1])
|
| 56 |
+
s0: int = int(src[i])
|
| 57 |
+
if s1 in self.good[s0]:
|
| 58 |
+
sss: bytes = src[i: i + self.wlen[s0]]
|
| 59 |
+
try:
|
| 60 |
+
s = next(filter(sss.startswith, self.table[s0][s1]))
|
| 61 |
+
except:
|
| 62 |
+
pass
|
| 63 |
+
tokens.append(self.token2idx[s])
|
| 64 |
+
i += len(s)
|
| 65 |
+
|
| 66 |
+
return tokens
|
| 67 |
+
|
| 68 |
+
def decodeBytes(self, tokens):
|
| 69 |
+
return b''.join(map(lambda i: self.idx2token[i], tokens))
|
| 70 |
+
|
| 71 |
+
def encode(self, src: str):
|
| 72 |
+
return self.encodeBytes(src.encode("utf-8"))
|
| 73 |
+
|
| 74 |
+
def decode(self, tokens):
|
| 75 |
+
return self.decodeBytes(tokens).decode('utf-8')
|
| 76 |
+
|
| 77 |
+
def printTokens(self, tokens):
|
| 78 |
+
for i in tokens:
|
| 79 |
+
s = self.idx2token[i]
|
| 80 |
+
try:
|
| 81 |
+
s = s.decode('utf-8')
|
| 82 |
+
except:
|
| 83 |
+
pass
|
| 84 |
+
print(f'{repr(s)}{i}', end=' ')
|
| 85 |
+
# print(repr(s), i)
|
| 86 |
+
print()
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
########################################################################################################
|
| 90 |
+
|
| 91 |
+
def sample_logits(out, temperature=1.0, top_p=0.8):
|
| 92 |
+
probs = F.softmax(out, dim=-1).numpy()
|
| 93 |
+
sorted_probs = np.sort(probs)[::-1]
|
| 94 |
+
cumulative_probs = np.cumsum(sorted_probs)
|
| 95 |
+
cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
|
| 96 |
+
probs[probs < cutoff] = 0
|
| 97 |
+
if temperature != 1.0:
|
| 98 |
+
probs = probs.pow(1.0 / temperature)
|
| 99 |
+
probs = probs / np.sum(probs)
|
| 100 |
+
out = np.random.choice(a=len(probs), p=probs)
|
| 101 |
+
return out
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
########################################################################################################
|
| 105 |
+
class RWKV_RNN(MyModule):
|
| 106 |
+
def __init__(self, args):
|
| 107 |
+
super().__init__()
|
| 108 |
+
self.args = args
|
| 109 |
+
self.eval() # set torch to inference mode
|
| 110 |
+
|
| 111 |
+
w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')
|
| 112 |
+
|
| 113 |
+
for k in w.keys():
|
| 114 |
+
w[k] = w[k].float() # convert to f32 type
|
| 115 |
+
if '.time_' in k: w[k] = w[k].squeeze()
|
| 116 |
+
if '.time_faaaa' in k: w[k] = w[k].unsqueeze(-1)
|
| 117 |
+
|
| 118 |
+
self.n_head = w['blocks.0.att.time_faaaa'].shape[0]
|
| 119 |
+
self.head_size = w['blocks.0.ln1.weight'].shape[0] // self.n_head
|
| 120 |
+
|
| 121 |
+
self.w = types.SimpleNamespace() # set self.w from w
|
| 122 |
+
self.w.blocks = {}
|
| 123 |
+
for k in w.keys(): # example: "blocks.0.att.time_first" => self.w.blocks[0].att.time_first
|
| 124 |
+
parts = k.split('.')
|
| 125 |
+
last = parts.pop()
|
| 126 |
+
here = self.w
|
| 127 |
+
for p in parts:
|
| 128 |
+
if p.isdigit():
|
| 129 |
+
p = int(p)
|
| 130 |
+
if p not in here: here[p] = types.SimpleNamespace()
|
| 131 |
+
here = here[p]
|
| 132 |
+
else:
|
| 133 |
+
if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())
|
| 134 |
+
here = getattr(here, p)
|
| 135 |
+
setattr(here, last, w[k])
|
| 136 |
+
|
| 137 |
+
def layer_norm(self, x, w):
|
| 138 |
+
return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)
|
| 139 |
+
|
| 140 |
+
@MyFunction
|
| 141 |
+
def channel_mixing(self, x, state, i: int, time_maa_k, time_maa_r, kw, vw, rw):
|
| 142 |
+
i0 = (2 + self.head_size) * i + 0
|
| 143 |
+
sx = state[i0] - x
|
| 144 |
+
xk = x + sx * time_maa_k
|
| 145 |
+
xr = x + sx * time_maa_r
|
| 146 |
+
state[i0] = x
|
| 147 |
+
r = torch.sigmoid(rw @ xr)
|
| 148 |
+
k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper
|
| 149 |
+
return r * (vw @ k)
|
| 150 |
+
|
| 151 |
+
@MyFunction
|
| 152 |
+
def time_mixing(self, x, state, i: int, x_maa, w_maa, k_maa, v_maa, r_maa, g_maa, tm_w1, tm_w2, td_w1, td_w2,
|
| 153 |
+
time_first, time_decay, kw, vw, rw, gw, ow, ln_w, ln_b):
|
| 154 |
+
H = self.n_head
|
| 155 |
+
S = self.head_size
|
| 156 |
+
|
| 157 |
+
i1 = (2 + S) * i + 1
|
| 158 |
+
sx = state[i1] - x
|
| 159 |
+
state[i1] = x
|
| 160 |
+
xxx = x + sx * x_maa
|
| 161 |
+
xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)
|
| 162 |
+
xxx = torch.bmm(xxx, tm_w2).view(5, -1)
|
| 163 |
+
mw, mk, mv, mr, mg = xxx.unbind(dim=0)
|
| 164 |
+
|
| 165 |
+
xw = x + sx * (w_maa + mw)
|
| 166 |
+
xk = x + sx * (k_maa + mk)
|
| 167 |
+
xv = x + sx * (v_maa + mv)
|
| 168 |
+
xr = x + sx * (r_maa + mr)
|
| 169 |
+
xg = x + sx * (g_maa + mg)
|
| 170 |
+
|
| 171 |
+
w = (time_decay + (torch.tanh(xw @ td_w1) @ td_w2).float()).view(H, S, 1)
|
| 172 |
+
w = torch.exp(-torch.exp(w.float()))
|
| 173 |
+
|
| 174 |
+
r = (rw @ xr).view(H, 1, S)
|
| 175 |
+
k = (kw @ xk).view(H, S, 1)
|
| 176 |
+
v = (vw @ xv).view(H, 1, S)
|
| 177 |
+
g = F.silu(gw @ xg)
|
| 178 |
+
|
| 179 |
+
s = state[(2 + S) * i + 2:(2 + S) * (i + 1), :].reshape(H, S, S)
|
| 180 |
+
|
| 181 |
+
x = torch.zeros(H, S)
|
| 182 |
+
a = k @ v
|
| 183 |
+
x = r @ (time_first * a + s)
|
| 184 |
+
s = a + w * s
|
| 185 |
+
|
| 186 |
+
state[(2 + S) * i + 2:(2 + S) * (i + 1), :] = s.reshape(S, -1)
|
| 187 |
+
x = x.flatten()
|
| 188 |
+
|
| 189 |
+
x = F.group_norm(x.unsqueeze(0), num_groups=H, weight=ln_w, bias=ln_b, eps=64e-5).squeeze(
|
| 190 |
+
0) * g # same as gn(x/8, eps=1e-5)
|
| 191 |
+
return ow @ x
|
| 192 |
+
|
| 193 |
+
def forward(self, token, state):
|
| 194 |
+
with torch.no_grad():
|
| 195 |
+
if state == None:
|
| 196 |
+
state = torch.zeros(self.args.n_layer * (2 + self.head_size), self.args.n_embd)
|
| 197 |
+
|
| 198 |
+
x = self.w.emb.weight[token]
|
| 199 |
+
x = self.layer_norm(x, self.w.blocks[0].ln0)
|
| 200 |
+
for i in range(self.args.n_layer):
|
| 201 |
+
att = self.w.blocks[i].att
|
| 202 |
+
x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i,
|
| 203 |
+
att.time_maa_x, att.time_maa_w, att.time_maa_k, att.time_maa_v, att.time_maa_r,
|
| 204 |
+
att.time_maa_g, att.time_maa_w1, att.time_maa_w2,
|
| 205 |
+
att.time_decay_w1, att.time_decay_w2, att.time_faaaa, att.time_decay,
|
| 206 |
+
att.key.weight, att.value.weight, att.receptance.weight, att.gate.weight,
|
| 207 |
+
att.output.weight,
|
| 208 |
+
att.ln_x.weight, att.ln_x.bias)
|
| 209 |
+
ffn = self.w.blocks[i].ffn
|
| 210 |
+
x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i,
|
| 211 |
+
ffn.time_maa_k, ffn.time_maa_r,
|
| 212 |
+
ffn.key.weight, ffn.value.weight, ffn.receptance.weight)
|
| 213 |
+
|
| 214 |
+
x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)
|
| 215 |
+
return x.float(), state
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
tokenizer = RWKV_TOKENIZER("vocab_v20230424.txt")
|
| 219 |
+
|
| 220 |
+
args = types.SimpleNamespace()
|
| 221 |
+
args.MODEL_NAME = 'rwkv-30'
|
| 222 |
+
args.n_layer = 12
|
| 223 |
+
args.n_embd = 768
|
| 224 |
+
args.vocab_size = 65536
|
| 225 |
+
|
| 226 |
+
context = "Today is a beautiful"
|
| 227 |
+
NUM_TRIALS = 3
|
| 228 |
+
LENGTH_PER_TRIAL = 50
|
| 229 |
+
TEMPERATURE = 1.0
|
| 230 |
+
TOP_P = 0.7
|
| 231 |
+
|
| 232 |
+
print(f'model= {args.MODEL_NAME}\n')
|
| 233 |
+
model = RWKV_RNN(args)
|
| 234 |
+
init_state = None
|
| 235 |
+
for token in tokenizer.encode(context):
|
| 236 |
+
init_out, init_state = model.forward(token, init_state)
|
| 237 |
+
|
| 238 |
+
for TRIAL in range(NUM_TRIALS):
|
| 239 |
+
print(f'Trial {TRIAL + 1}=', context, end="")
|
| 240 |
+
all_tokens = []
|
| 241 |
+
n_sampled = 0
|
| 242 |
+
out, state = init_out.clone(), init_state.clone()
|
| 243 |
+
for i in range(LENGTH_PER_TRIAL):
|
| 244 |
+
token = sample_logits(out, TEMPERATURE, TOP_P)
|
| 245 |
+
all_tokens += [token]
|
| 246 |
+
try:
|
| 247 |
+
tmp = tokenizer.decode(all_tokens[n_sampled:])
|
| 248 |
+
if '\ufffd' not in tmp: # only print when we have a valid utf-8 string
|
| 249 |
+
print(tmp, end="", flush=True)
|
| 250 |
+
n_sampled = i + 1
|
| 251 |
+
except:
|
| 252 |
+
pass
|
| 253 |
+
out, state = model.forward(token, state)
|
| 254 |
+
print("\nSampled tokens=", n_sampled, "out of", LENGTH_PER_TRIAL, "tokens\n")
|
vocab_v20230424.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|