Commit
·
d09f022
1
Parent(s):
c2cd532
Upload TEVR Explanation.ipynb
Browse files- TEVR Explanation.ipynb +186 -0
TEVR Explanation.ipynb
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "89c94977",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
# Pull the TEVR token-entropy predictor snapshot from the Hugging Face Hub;
# `data_folder` is the local cache directory holding the model + tokenizer files.
from huggingface_hub import snapshot_download

data_folder = snapshot_download("fxtentacle/tevr-token-entropy-predictor-de")
|
| 12 |
+
]
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"cell_type": "code",
|
| 16 |
+
"execution_count": 2,
|
| 17 |
+
"id": "a48a49d6",
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"outputs": [],
|
| 20 |
+
"source": [
|
| 21 |
+
# Load the byte-level T5 entropy model from the downloaded snapshot and
# put it on the GPU in inference mode.
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained(data_folder)
model.to('cuda')
model.eval()
# Trailing None suppresses the notebook's echo of `model.eval()`'s return value.
None
|
| 26 |
+
]
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"cell_type": "code",
|
| 30 |
+
"execution_count": 3,
|
| 31 |
+
"id": "eed8bfc3",
|
| 32 |
+
"metadata": {},
|
| 33 |
+
"outputs": [],
|
| 34 |
+
"source": [
|
| 35 |
+
import torch

def text_to_cross_entropy(text):
    """Return the per-byte cross-entropy of `text` under the global T5 model.

    The text is UTF-8 encoded and fed to the decoder shifted right by one,
    with a leading 0 as the decoder start token; the encoder input is a dummy
    single token. Returns a 1-D numpy array with one loss value per byte of
    the encoded text.

    NOTE(review): relies on the module-level `model` already living on 'cuda'.
    """
    # Decoder input: start token 0 followed by the raw UTF-8 bytes of `text`.
    ttext = torch.tensor([[0] + list(text.encode('UTF-8'))], dtype=torch.int64).to('cuda')
    # Dummy one-token encoder input, reused as the attention mask.
    tone = torch.tensor([[1]], dtype=torch.int32).to('cuda')
    # no_grad() skips autograd graph construction entirely — cheaper than
    # running forward with gradients enabled and detaching afterwards.
    with torch.no_grad():
        logits = model(input_ids=tone, attention_mask=tone,
                       decoder_input_ids=ttext, return_dict=False)[0]
    # Position i predicts byte i+1; reduction='none' keeps per-byte losses.
    cross_entropy = torch.nn.functional.cross_entropy(
        input=logits[0][:-1], target=ttext[0][1:], reduction='none'
    ).cpu().numpy()
    return cross_entropy
|
| 43 |
+
]
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"cell_type": "code",
|
| 47 |
+
"execution_count": 4,
|
| 48 |
+
"id": "8ec8cf8d",
|
| 49 |
+
"metadata": {},
|
| 50 |
+
"outputs": [],
|
| 51 |
+
"source": [
|
| 52 |
+
import sys
import os

# The snapshot ships its own tokenizer module; make it importable by
# putting the download directory on the module search path.
sys.path.append(data_folder)
from text_tokenizer import HajoTextTokenizer
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"cell_type": "code",
|
| 60 |
+
"execution_count": 5,
|
| 61 |
+
"id": "37165805",
|
| 62 |
+
"metadata": {},
|
| 63 |
+
"outputs": [],
|
| 64 |
+
"source": [
|
| 65 |
+
# Instantiate the 4M-sentence German TEVR tokenizer bundled with the snapshot.
tokenizer_file = 'text-tokenizer-de-4m.txt'
text_tokenizer = HajoTextTokenizer(f'{data_folder}/{tokenizer_file}')
|
| 67 |
+
]
|
| 68 |
+
},
|
| 69 |
+
{
|
| 70 |
+
"cell_type": "code",
|
| 71 |
+
"execution_count": 6,
|
| 72 |
+
"id": "73e55343",
|
| 73 |
+
"metadata": {},
|
| 74 |
+
"outputs": [
|
| 75 |
+
{
|
| 76 |
+
"name": "stdout",
|
| 77 |
+
"output_type": "stream",
|
| 78 |
+
"text": [
|
| 79 |
+
"['die', ' ', 'k', 'at', 'ze', ' ', 'ist', ' ', 'n', 'ied', 'lich']\n",
|
| 80 |
+
"[3.3762913048267365, 3.3762913048267365, 3.3762913048267365, 0.29695791006088257, 4.193424224853516, 2.3430762887001038, 2.3430762887001038, 2.8417416363954544, 2.8417416363954544, 1.1227068901062012, 2.017452405144771, 2.017452405144771, 2.017452405144771, 0.0016304069431498647, 2.580254554748535, 2.3091587026913962, 2.3091587026913962, 2.3091587026913962, 1.0126478232632508, 1.0126478232632508, 1.0126478232632508, 1.0126478232632508]\n"
|
| 81 |
+
]
|
| 82 |
+
}
|
| 83 |
+
],
|
| 84 |
+
"source": [
|
| 85 |
+
text = "die katze ist niedlich"
cross_entropy = text_to_cross_entropy(text)

# Tokenize the sentence and map token ids back to their string forms.
tokens = text_tokenizer.encode(text)
tokens = [text_tokenizer.all_tokens[t] for t in tokens]
print(tokens)

# token_sums:  each token's total entropy spread evenly over its characters
#              (one value per character of `text`).
# token_sums2: one summed entropy value per token.
token_sums = []
token_sums2 = []
for tok in tokens:
    offset = len(token_sums)  # running character position within `text`
    total = sum(cross_entropy[offset:offset + len(tok)])
    token_sums.extend([total / len(tok)] * len(tok))
    token_sums2.append(total)
print(token_sums)
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "code",
|
| 102 |
+
"execution_count": 7,
|
| 103 |
+
"id": "e61e00aa",
|
| 104 |
+
"metadata": {},
|
| 105 |
+
"outputs": [
|
| 106 |
+
{
|
| 107 |
+
"data": {
|
| 108 |
+
"text/html": [
|
| 109 |
+
"<table style=\"font-size: 20px; font-family: Roboto\"><tr><td><b>(1)</b></td><td style=\"text-align:left\">d</td><td style=\"text-align:left\">i</td><td style=\"text-align:left\">e</td><td style=\"text-align:left\"> </td><td style=\"text-align:left\">k</td><td style=\"text-align:left\">a</td><td style=\"text-align:left\">t</td><td style=\"text-align:left\">z</td><td style=\"text-align:left\">e</td><td style=\"text-align:left\"> </td><td style=\"text-align:left\">i</td><td style=\"text-align:left\">s</td><td style=\"text-align:left\">t</td><td style=\"text-align:left\"> </td><td style=\"text-align:left\">n</td><td style=\"text-align:left\">i</td><td style=\"text-align:left\">e</td><td style=\"text-align:left\">d</td><td style=\"text-align:left\">l</td><td style=\"text-align:left\">i</td><td style=\"text-align:left\">c</td><td style=\"text-align:left\">h</td></tr><tr><td><b>(2)</b></td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>σ²=0.0</td></tr><tr><td><b>(3)</b></td><td>8.9</td><td>1.0</td><td>0.2</td><td>0.3</td><td>4.2</td><td>1.6</td><td>3.1</td><td>5.4</td><td>0.3</td><td>1.1</td><td>3.0</td><td>3.0</td><td>0.0</td><td>0.0</td><td>2.6</td><td>0.6</td><td>4.4</td><td>1.9</td><td>4.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>σ²=5.0</td></tr><tr><td><b>(4)</b></td><td style=\"text-align:center\" colspan=3>die</td><td style=\"text-align:center\" colspan=1> </td><td style=\"text-align:center\" colspan=1>k</td><td style=\"text-align:center\" colspan=2>at</td><td style=\"text-align:center\" colspan=2>ze</td><td style=\"text-align:center\" colspan=1> </td><td style=\"text-align:center\" colspan=3>ist</td><td style=\"text-align:center\" colspan=1> </td><td style=\"text-align:center\" colspan=1>n</td><td style=\"text-align:center\" 
colspan=3>ied</td><td style=\"text-align:center\" colspan=4>lich</td></tr><tr><td><b>(5)</b></td><td style=\"text-align:center\" colspan=3>10.1</td><td style=\"text-align:center\" colspan=1>0.3</td><td style=\"text-align:center\" colspan=1>4.2</td><td style=\"text-align:center\" colspan=2>4.7</td><td style=\"text-align:center\" colspan=2>5.7</td><td style=\"text-align:center\" colspan=1>1.1</td><td style=\"text-align:center\" colspan=3>6.1</td><td style=\"text-align:center\" colspan=1>0.0</td><td style=\"text-align:center\" colspan=1>2.6</td><td style=\"text-align:center\" colspan=3>6.9</td><td style=\"text-align:center\" colspan=4>4.1</td></tr><tr><td><b>(6)</b></td><td>3.4</td><td>3.4</td><td>3.4</td><td>0.3</td><td>4.2</td><td>2.3</td><td>2.3</td><td>2.8</td><td>2.8</td><td>1.1</td><td>2.0</td><td>2.0</td><td>2.0</td><td>0.0</td><td>2.6</td><td>2.3</td><td>2.3</td><td>2.3</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>σ²=1.1</td></tr></table>"
|
| 110 |
+
],
|
| 111 |
+
"text/plain": [
|
| 112 |
+
"<IPython.core.display.HTML object>"
|
| 113 |
+
]
|
| 114 |
+
},
|
| 115 |
+
"execution_count": 7,
|
| 116 |
+
"metadata": {},
|
| 117 |
+
"output_type": "execute_result"
|
| 118 |
+
}
|
| 119 |
+
],
|
| 120 |
+
"source": [
|
| 121 |
+
import numpy as np
import IPython

# Render a comparison table: (1) characters, (2) uniform baseline,
# (3) raw per-byte entropy, (4) tokens, (5) per-token sums,
# (6) per-character token-averaged entropy — with variance per row.
rows = []
rows.append('<tr><td><b>(1)</b></td>'
            + ''.join(f'<td style="text-align:left">{ch}</td>' for ch in text)
            + '</tr>')
uniform = [1.0 for _ in cross_entropy]
rows.append('<tr><td><b>(2)</b></td>'
            + '<td>1.0</td>' * len(cross_entropy)
            + '<td>σ²={:3.1f}</td>'.format(np.var(uniform))
            + '</tr>')
rows.append('<tr><td><b>(3)</b></td>'
            + ''.join(f'<td>{v:3.1f}</td>' for v in cross_entropy)
            + '<td>σ²={:3.1f}</td>'.format(np.var(cross_entropy))
            + '</tr>')
rows.append('<tr><td><b>(4)</b></td>'
            + ''.join(f'<td style="text-align:center" colspan={len(t)}>{t}</td>' for t in tokens)
            + '</tr>')
rows.append('<tr><td><b>(5)</b></td>'
            + ''.join(f'<td style="text-align:center" colspan={len(t)}>{s:3.1f}</td>'
                      for t, s in zip(tokens, token_sums2))
            + '</tr>')
rows.append('<tr><td><b>(6)</b></td>'
            + ''.join(f'<td>{v:3.1f}</td>' for v in token_sums)
            + '<td>σ²={:3.1f}</td>'.format(np.var(token_sums))
            + '</tr>')
html = '<table style="font-size: 20px; font-family: Roboto">' + ''.join(rows) + '</table>'

IPython.display.HTML(html)
|
| 133 |
+
]
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
"cell_type": "code",
|
| 137 |
+
"execution_count": 8,
|
| 138 |
+
"id": "dcafdcab",
|
| 139 |
+
"metadata": {},
|
| 140 |
+
"outputs": [
|
| 141 |
+
{
|
| 142 |
+
"name": "stdout",
|
| 143 |
+
"output_type": "stream",
|
| 144 |
+
"text": [
|
| 145 |
+
"<pad>, <eos>, , chen, sche, lich, isch, icht, iche, eine, rden, tion, urde, haft, eich, rung, chte, ssen, chaf, nder, tlic, tung, eite, iert, sich, ngen, erde, scha, nden, unge, lung, mmen, eren, ende, inde, erun, sten, iese, igen, erte, iner, tsch, keit, der, die, ter, und, ein, ist, den, ten, ber, ver, sch, ung, ste, ent, ach, nte, auf, ben, eit, des, ers, aus, das, von, ren, gen, nen, lle, hre, mit, iel, uch, lte, ann, lie, men, dem, and, ind, als, sta, elt, ges, tte, ern, wir, ell, war, ere, rch, abe, len, ige, ied, ger, nnt, wei, ele, och, sse, end, all, ahr, bei, sie, ede, ion, ieg, ege, auc, che, rie, eis, vor, her, ang, für, ass, uss, tel, er, in, ge, en, st, ie, an, te, be, re, zu, ar, es, ra, al, or, ch, et, ei, un, le, rt, se, is, ha, we, at, me, ne, ur, he, au, ro, ti, li, ri, eh, im, ma, tr, ig, el, um, la, am, de, so, ol, tz, il, on, it, sc, sp, ko, na, pr, ni, si, fe, wi, ns, ke, ut, da, gr, eu, mi, hr, ze, hi, ta, ss, ng, sa, us, ba, ck, em, kt, ka, ve, fr, bi, wa, ah, gt, di, ab, fo, to, rk, as, ag, gi, hn, s, t, n, m, r, l, f, e, a, b, d, h, k, g, o, i, u, w, p, z, ä, ü, v, ö, j, c, y, x, q, á, í, ō, ó, š, é, č, ?\n"
|
| 146 |
+
]
|
| 147 |
+
}
|
| 148 |
+
],
|
| 149 |
+
"source": [
|
| 150 |
+
# Re-create the tokenizer and print its complete token vocabulary.
from text_tokenizer import HajoTextTokenizer

text_tokenizer = HajoTextTokenizer(data_folder + '/' + tokenizer_file)
vocabulary = text_tokenizer.all_tokens
print(', '.join(vocabulary))
|
| 154 |
+
]
|
| 155 |
+
},
|
| 156 |
+
{
|
| 157 |
+
"cell_type": "code",
|
| 158 |
+
"execution_count": null,
|
| 159 |
+
"id": "b87b7fd0",
|
| 160 |
+
"metadata": {},
|
| 161 |
+
"outputs": [],
|
| 162 |
+
"source": []
|
| 163 |
+
}
|
| 164 |
+
],
|
| 165 |
+
"metadata": {
|
| 166 |
+
"kernelspec": {
|
| 167 |
+
"display_name": "Python 3 (ipykernel)",
|
| 168 |
+
"language": "python",
|
| 169 |
+
"name": "python3"
|
| 170 |
+
},
|
| 171 |
+
"language_info": {
|
| 172 |
+
"codemirror_mode": {
|
| 173 |
+
"name": "ipython",
|
| 174 |
+
"version": 3
|
| 175 |
+
},
|
| 176 |
+
"file_extension": ".py",
|
| 177 |
+
"mimetype": "text/x-python",
|
| 178 |
+
"name": "python",
|
| 179 |
+
"nbconvert_exporter": "python",
|
| 180 |
+
"pygments_lexer": "ipython3",
|
| 181 |
+
"version": "3.7.5"
|
| 182 |
+
}
|
| 183 |
+
},
|
| 184 |
+
"nbformat": 4,
|
| 185 |
+
"nbformat_minor": 5
|
| 186 |
+
}
|