This repo contains the fully trained ByT5 model that was used to estimate per-character entropies. With it, you can also recreate the illustration from the paper.
## Citation
If you use this for research, please cite:
```bibtex
@misc{krabbenhoeft2022tevr,
  doi       = {10.48550/ARXIV.2206.12693},
  url       = {https://arxiv.org/abs/2206.12693},
  author    = {Krabbenhöft, Hajo Nils and Barth, Erhardt},
  keywords  = {Computation and Language (cs.CL), Sound (cs.SD), Audio and Speech Processing (eess.AS), FOS: Computer and information sciences, FOS: Electrical engineering, electronic engineering, information engineering, F.2.1; I.2.6; I.2.7},
  title     = {TEVR: Improving Speech Recognition by Token Entropy Variance Reduction},
  publisher = {arXiv},
  year      = {2022},
  copyright = {Creative Commons Attribution 4.0 International}
}
```
## Generate TEVR Tokenizer from Text Corpus
(copy of `Generate TEVR Tokenizer.ipynb`)
```python
# TODO: load a large text dataset like OSCAR here
# (placeholder corpus: two sentences repeated 1000 times)
all_sentences_de = ["Über vier Jahrzehnte gehörte er zu den führenden Bildhauern Niederbayerns", "die katze ist niedlich"] * 1000
```
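For real experiments, the placeholder above should be replaced by a large German corpus. A minimal sketch of how that could look with the `datasets` library and the OSCAR corpus on the Hugging Face Hub; the dataset name, config, and sample count here are assumptions, not part of the original notebook:

```python
# Hypothetical corpus loading, for illustration only.
from datasets import load_dataset

# Stream the German portion of OSCAR so the full corpus never has to fit
# into memory, then take a fixed number of documents as the sentence list.
oscar_de = load_dataset("oscar", "unshuffled_deduplicated_de",
                        split="train", streaming=True)
all_sentences_de = [row["text"] for _, row in zip(range(100_000), oscar_de)]
```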
```python
from huggingface_hub import snapshot_download
# download the trained ByT5 entropy-prediction model from the Hugging Face Hub
data_folder = snapshot_download("fxtentacle/tevr-token-entropy-predictor-de")
```
```python
from transformers import T5ForConditionalGeneration
model = T5ForConditionalGeneration.from_pretrained(data_folder)
model.to('cuda')
model.eval()
None  # suppress notebook output of model.eval()
```
```python
import torch

def text_to_cross_entropy(text):
    # decoder input: a leading 0 (start/pad token) followed by the raw UTF-8 bytes
    ttext = torch.tensor([[0]+list(text.encode('UTF-8'))],dtype=torch.int64).to('cuda')
    # dummy single-token encoder input with matching attention mask
    tone = torch.tensor([[1]],dtype=torch.int32).to('cuda')
    logits = model.forward(input_ids=tone, attention_mask=tone, decoder_input_ids=ttext, return_dict=False)[0].detach()
    # per-byte cross-entropy: the logits at position i predict the byte at position i+1
    cross_entropy = torch.nn.functional.cross_entropy(input=logits[0][:-1], target=ttext[0][1:], reduction='none').detach().cpu().numpy()
    return cross_entropy
```
```python
text = all_sentences_de[0]
cross_entropy = text_to_cross_entropy(text)
print(text)
# print each character together with its entropy estimate
# (note: the values are per UTF-8 byte, see the remark below the output)
for i in range(len(text)):
    print(text[i], cross_entropy[i])
```
```
Über vier Jahrzehnte gehörte er zu den führenden Bildhauern Niederbayerns
Ü 7.254014
b 0.17521738
e 0.00046933602
r 0.01929327
  0.0003675739
v 0.20927554
i 6.13207
e 0.3896482
r 0.009583538
  2.07364
J 0.02978594
a 2.483246
h 0.1591908
r 0.0045124847
z 0.00028653807
e 4.0242333
h 0.031035878
n 0.028907888
t 0.003264101
e 0.0018929198
  0.05816966
g 1.2782481
e 3.5076692
h 0.694337
ö 0.5319732
r 0.48336726
t 0.0050443523
e 0.0017187123
  0.14511283
e 1.0435015
r 0.18165778
  1.0247636
z 0.3594512
u 0.0077577736
  2.072764
d 0.17377533
e 1.0727838
n 1.2805216
  0.24939628
f 0.27717885
ü 0.012466482
h 4.4356546
r 1.7371752
e 0.051492628
n 2.99407
d 0.009648594
e 0.19667451
n 0.007495021
  0.2529005
B 0.004451485
i 0.024661187
l 0.0028436247
d 2.6620464
h 2.825038
a 0.8215449
u 0.011406565
e 2.9599652
r 0.45834702
n 0.11848967
  0.5955992
N 0.010709903
i 1.5338714
e 0.1834471
d 5.668945
e 2.052247
r 0.7692907
b 0.0675718
a 0.028234791
y 0.0045266068
e 4.1125383
r 1.2630856
n 5.436057
s 0.46446246
```
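Note that `text_to_cross_entropy` returns one value per UTF-8 *byte*, so in the printout above the values drift by one position after each multi-byte character like `Ü` or `ö`. A character-aligned variant could sum the bytes belonging to each character; this is a sketch added here, not part of the original notebook:

```python
def char_entropies(text):
    # sum the per-byte entropies belonging to each character so that
    # multi-byte UTF-8 characters ('Ü', 'ö', 'ü', ...) line up correctly
    per_byte = text_to_cross_entropy(text)
    out, pos = [], 0
    for ch in text:
        n = len(ch.encode('UTF-8'))
        out.append(float(sum(per_byte[pos:pos+n])))
        pos += n
    return out

for ch, ce in zip(text, char_entropies(text)):
    print(ch, ce)
```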
```python
from tqdm import tqdm
# compute per-byte entropies for every sentence in the corpus
sentence_data = all_sentences_de
text_and_entropies = []
for text in tqdm(sentence_data):
    text_and_entropies.append([text, text_to_cross_entropy(text)])
```
```
100%|██████████| 2000/2000 [00:09<00:00, 219.00it/s]
```
```python
from collections import Counter
# token budgets per n-gram length for the three tokenizer sizes:
# 4s
#target_lengths = [1]
#token_budgets = [36]
# 4m
target_lengths = [4,3,2,1]
token_budgets = [40,80,96,36]
# 4l
#target_lengths = [4,3,2,1]
#token_budgets = [384,320,160,36]
ngrams = [Counter() for l in target_lengths]
tokens = []
for tgi,tgl in enumerate(target_lengths):
    for row in tqdm(text_and_entropies[1:]):
        use_text = row[0]
        use_scores = row[1]
        # mask out characters that are already covered by selected tokens
        for t in tokens:
            use_text = use_text.replace(t[0],'#')
        candidates = []
        for i in range(len(use_text)-(tgl-1)):
            part = use_text[i:i+tgl].lower()
            if '#' in part: continue
            if ' ' in part: continue
            if '-' in part: continue
            # candidate score = total cross-entropy of the n-gram
            score = sum(use_scores[i:i+tgl])
            candidates.append([score, part])
        # keep only the 20% of candidates with the lowest total entropy
        candidates.sort(reverse=False)
        candidates = candidates[:max(1,int(len(candidates)/5))]
        ngrams[tgi].update([c[1] for c in candidates])
    # the most frequent low-entropy n-grams become tokens
    new_tokens = ngrams[tgi].most_common(token_budgets[tgi])
    print(new_tokens)
    tokens += new_tokens
```
```
100%|██████████| 1999/1999 [00:00<00:00, 14645.88it/s]
[('lich', 1000), ('hnte', 999), ('rbay', 999), ('örte', 999), ('hört', 999), ('ahrz', 999), ('jahr', 999), ('bild', 999)]
100%|██████████| 1999/1999 [00:00<00:00, 18574.04it/s]
[('ist', 1000), ('den', 999), ('ber', 999), ('aue', 999), ('ern', 999), ('uer', 999)]
100%|██████████| 1999/1999 [00:00<00:00, 20827.32it/s]
[('ni', 1000), ('ge', 999), ('er', 999), ('fü', 999), ('vi', 999)]
100%|██████████| 1999/1999 [00:00<00:00, 19927.45it/s]
[('e', 2999), ('u', 999), ('n', 999), ('h', 999)]
```
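To see what a single pass of the selection loop does, here is a small illustrative check (added here, not in the original notebook) that scores every 4-gram of the second sentence and keeps the lowest-entropy fifth; the survivor is `'lich'`, which matches its count of 1000 in the first output above:

```python
# Illustrative only: reproduce the first-pass candidate selection
# for one sentence, before any '#' masking has happened.
tgl = 4
use_text, use_scores = text_and_entropies[1]   # "die katze ist niedlich"
candidates = sorted(
    [sum(use_scores[i:i+tgl]), use_text[i:i+tgl].lower()]
    for i in range(len(use_text)-(tgl-1))
    if ' ' not in use_text[i:i+tgl]
)
# keep the lowest-entropy 20% -> exactly one candidate here: 'lich'
print(candidates[:max(1,int(len(candidates)/5))])
```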
```python
# special tokens first, then the selected n-grams, then '?' as unknown marker
all_tokens = ['<pad>','<eos>',' ']+[t[0] for t in tokens]+['?']
print(len(all_tokens), all_tokens)
```
```
27 ['<pad>', '<eos>', ' ', 'lich', 'hnte', 'rbay', 'örte', 'hört', 'ahrz', 'jahr', 'bild', 'ist', 'den', 'ber', 'aue', 'ern', 'uer', 'ni', 'ge', 'er', 'fü', 'vi', 'e', 'u', 'n', 'h', '?']
```
```python
import json
# save the vocabulary; HajoTextTokenizer will load it from this file
with open('./tevr-tokenizer.txt','wt') as f:
    json.dump(all_tokens, f)
```
```python
import sys
import os
# import the tokenizer class that ships with the model repo
sys.path.append(data_folder)
from text_tokenizer import HajoTextTokenizer
```
```python
text_tokenizer = HajoTextTokenizer('./tevr-tokenizer.txt')
```
```python
sentence = "gehörte"
print(sentence)
encoded = text_tokenizer.encode(sentence)
print(encoded)
print([text_tokenizer.all_tokens[i] for i in encoded])
print([text_tokenizer.decode(encoded)])
```
```
gehörte
[18, 25, 6]
['ge', 'h', 'örte']
['gehörte']
```
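The actual encoding logic ships as `text_tokenizer.py` inside the model repo. For illustration only, here is a hypothetical re-implementation that reproduces the output above by substituting tokens in vocabulary order (which is why `'örte'` wins over `'hört'`: it appears earlier in the list); this is a sketch, not the real `HajoTextTokenizer`:

```python
def encode_by_substitution(text, all_tokens):
    # Hypothetical sketch: substitute tokens in vocabulary order;
    # any leftover characters map to the unknown token '?'.
    segments = [(None, text)]                  # (token_id, covered text)
    for tid, tok in enumerate(all_tokens):
        if tok in ('<pad>', '<eos>'):
            continue
        out = []
        for sid, payload in segments:
            if sid is not None:                # already tokenized, keep as-is
                out.append((sid, payload))
                continue
            for i, part in enumerate(payload.split(tok)):
                if i > 0:
                    out.append((tid, tok))
                if part:
                    out.append((None, part))
        segments = out
    unk = all_tokens.index('?')
    ids = []
    for sid, payload in segments:
        ids.extend([sid] if sid is not None else [unk] * len(payload))
    return ids

print(encode_by_substitution("gehörte", all_tokens))  # [18, 25, 6], as above
```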
## Testing the Tokenizer File
(copy of `TEVR Explanation.ipynb`)
```python
from huggingface_hub import snapshot_download
data_folder = snapshot_download("fxtentacle/tevr-token-entropy-predictor-de")
```
```python
from transformers import T5ForConditionalGeneration
model = T5ForConditionalGeneration.from_pretrained(data_folder)
model.to('cuda')
model.eval()
None  # suppress notebook output of model.eval()
```
```python
import torch

def text_to_cross_entropy(text):
    # decoder input: a leading 0 (start/pad token) followed by the raw UTF-8 bytes
    ttext = torch.tensor([[0]+list(text.encode('UTF-8'))],dtype=torch.int64).to('cuda')
    # dummy single-token encoder input with matching attention mask
    tone = torch.tensor([[1]],dtype=torch.int32).to('cuda')
    logits = model.forward(input_ids=tone, attention_mask=tone, decoder_input_ids=ttext, return_dict=False)[0].detach()
    # per-byte cross-entropy: the logits at position i predict the byte at position i+1
    cross_entropy = torch.nn.functional.cross_entropy(input=logits[0][:-1], target=ttext[0][1:], reduction='none').detach().cpu().numpy()
    return cross_entropy
```
```python
import sys
import os
# import the tokenizer class that ships with the model repo
sys.path.append(data_folder)
from text_tokenizer import HajoTextTokenizer
```
```python
# load the medium ("4m") tokenizer that ships with the model repo
tokenizer_file = 'text-tokenizer-de-4m.txt'
text_tokenizer = HajoTextTokenizer(data_folder+'/'+tokenizer_file)
```
```python
text = "die katze ist niedlich"
cross_entropy = text_to_cross_entropy(text)
tokens = text_tokenizer.encode(text)
tokens = [text_tokenizer.all_tokens[t] for t in tokens]
print(tokens)
token_sums = []   # per character: the enclosing token's entropy, spread evenly
token_sums2 = []  # per token: the token's total entropy
for t in tokens:
    ce = sum(cross_entropy[len(token_sums):len(token_sums)+len(t)])
    for r in range(len(t)): token_sums.append(ce / len(t))
    token_sums2.append(ce)
print(token_sums)
```
```
['die', ' ', 'k', 'at', 'ze', ' ', 'ist', ' ', 'n', 'ied', 'lich']
[3.3762913048267365, 3.3762913048267365, 3.3762913048267365, 0.29695791006088257, 4.193424224853516, 2.3430762887001038, 2.3430762887001038, 2.8417416363954544, 2.8417416363954544, 1.1227068901062012, 2.017452405144771, 2.017452405144771, 2.017452405144771, 0.0016304069431498647, 2.580254554748535, 2.3091587026913962, 2.3091587026913962, 2.3091587026913962, 1.0126478232632508, 1.0126478232632508, 1.0126478232632508, 1.0126478232632508]
```
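Rows (3) and (6) of the table rendered below make the central TEVR point: distributing each token's total entropy evenly over its characters reduces the per-character variance. You can verify the two variances directly; this small check is added here and is not in the original notebook:

```python
import numpy as np
# variance of the raw per-character entropies vs. the token-averaged ones
print('per-character σ² =', np.var(cross_entropy))  # ≈ 5.0, see row (3)
print('token-averaged σ² =', np.var(token_sums))    # ≈ 1.1, see row (6)
```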
```python
import numpy as np
html = '<table style="font-size: 20px; font-family: Roboto">'
# (1) individual characters
html += '<tr><td><b>(1)</b></td>'+''.join([f'<td style="text-align:left">{c}</td>' for c in list(text)])+'</tr>'
# (2) uniform baseline: every character carries the same weight
html += '<tr><td><b>(2)</b></td>'+''.join(['<td>1.0</td>' for v in cross_entropy])+'<td>σ²={:3.1f}</td>'.format(np.var([1.0 for v in cross_entropy]))+'</tr>'
# (3) per-character cross-entropy
html += '<tr><td><b>(3)</b></td>'+''.join(['<td>{:3.1f}</td>'.format(v) for v in cross_entropy])+'<td>σ²={:3.1f}</td>'.format(np.var(cross_entropy))+'</tr>'
# (4) TEVR tokens spanning their characters
html += '<tr><td><b>(4)</b></td>'+''.join([f'<td style="text-align:center" colspan={len(t)}>{t}</td>' for t in tokens])+'</tr>'
# (5) total entropy per token
html += '<tr><td><b>(5)</b></td>'+''.join([f'<td style="text-align:center" colspan={len(t)}>{"{:3.1f}".format(token_sums2[i])}</td>' for i,t in enumerate(tokens)])+'</tr>'
# (6) per-character entropies after averaging within each token
html += '<tr><td><b>(6)</b></td>'+''.join(['<td>{:3.1f}</td>'.format(v) for v in token_sums])+'<td>σ²={:3.1f}</td>'.format(np.var(token_sums))+'</tr>'
html += '</table>'
import IPython
IPython.display.HTML(html)
```
<table style="font-size: 20px; font-family: Roboto"><tr><td><b>(1)</b></td><td style="text-align:left">d</td><td style="text-align:left">i</td><td style="text-align:left">e</td><td style="text-align:left"> </td><td style="text-align:left">k</td><td style="text-align:left">a</td><td style="text-align:left">t</td><td style="text-align:left">z</td><td style="text-align:left">e</td><td style="text-align:left"> </td><td style="text-align:left">i</td><td style="text-align:left">s</td><td style="text-align:left">t</td><td style="text-align:left"> </td><td style="text-align:left">n</td><td style="text-align:left">i</td><td style="text-align:left">e</td><td style="text-align:left">d</td><td style="text-align:left">l</td><td style="text-align:left">i</td><td style="text-align:left">c</td><td style="text-align:left">h</td></tr><tr><td><b>(2)</b></td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>σ²=0.0</td></tr><tr><td><b>(3)</b></td><td>8.9</td><td>1.0</td><td>0.2</td><td>0.3</td><td>4.2</td><td>1.6</td><td>3.1</td><td>5.4</td><td>0.3</td><td>1.1</td><td>3.0</td><td>3.0</td><td>0.0</td><td>0.0</td><td>2.6</td><td>0.6</td><td>4.4</td><td>1.9</td><td>4.0</td><td>0.0</td><td>0.0</td><td>0.0</td><td>σ²=5.0</td></tr><tr><td><b>(4)</b></td><td style="text-align:center" colspan=3>die</td><td style="text-align:center" colspan=1> </td><td style="text-align:center" colspan=1>k</td><td style="text-align:center" colspan=2>at</td><td style="text-align:center" colspan=2>ze</td><td style="text-align:center" colspan=1> </td><td style="text-align:center" colspan=3>ist</td><td style="text-align:center" colspan=1> </td><td style="text-align:center" colspan=1>n</td><td style="text-align:center" colspan=3>ied</td><td style="text-align:center" colspan=4>lich</td></tr><tr><td><b>(5)</b></td><td style="text-align:center" colspan=3>10.1</td><td style="text-align:center" colspan=1>0.3</td><td style="text-align:center" colspan=1>4.2</td><td style="text-align:center" colspan=2>4.7</td><td style="text-align:center" colspan=2>5.7</td><td style="text-align:center" colspan=1>1.1</td><td style="text-align:center" colspan=3>6.1</td><td style="text-align:center" colspan=1>0.0</td><td style="text-align:center" colspan=1>2.6</td><td style="text-align:center" colspan=3>6.9</td><td style="text-align:center" colspan=4>4.1</td></tr><tr><td><b>(6)</b></td><td>3.4</td><td>3.4</td><td>3.4</td><td>0.3</td><td>4.2</td><td>2.3</td><td>2.3</td><td>2.8</td><td>2.8</td><td>1.1</td><td>2.0</td><td>2.0</td><td>2.0</td><td>0.0</td><td>2.6</td><td>2.3</td><td>2.3</td><td>2.3</td><td>1.0</td><td>1.0</td><td>1.0</td><td>1.0</td><td>σ²=1.1</td></tr></table>
```python
from text_tokenizer import HajoTextTokenizer
text_tokenizer = HajoTextTokenizer(data_folder+'/'+tokenizer_file)
tt = text_tokenizer.all_tokens
# print the complete "4m" vocabulary of the shipped tokenizer
print(', '.join(tt))
```
```
<pad>, <eos>,  , chen, sche, lich, isch, icht, iche, eine, rden, tion, urde, haft, eich, rung, chte, ssen, chaf, nder, tlic, tung, eite, iert, sich, ngen, erde, scha, nden, unge, lung, mmen, eren, ende, inde, erun, sten, iese, igen, erte, iner, tsch, keit, der, die, ter, und, ein, ist, den, ten, ber, ver, sch, ung, ste, ent, ach, nte, auf, ben, eit, des, ers, aus, das, von, ren, gen, nen, lle, hre, mit, iel, uch, lte, ann, lie, men, dem, and, ind, als, sta, elt, ges, tte, ern, wir, ell, war, ere, rch, abe, len, ige, ied, ger, nnt, wei, ele, och, sse, end, all, ahr, bei, sie, ede, ion, ieg, ege, auc, che, rie, eis, vor, her, ang, für, ass, uss, tel, er, in, ge, en, st, ie, an, te, be, re, zu, ar, es, ra, al, or, ch, et, ei, un, le, rt, se, is, ha, we, at, me, ne, ur, he, au, ro, ti, li, ri, eh, im, ma, tr, ig, el, um, la, am, de, so, ol, tz, il, on, it, sc, sp, ko, na, pr, ni, si, fe, wi, ns, ke, ut, da, gr, eu, mi, hr, ze, hi, ta, ss, ng, sa, us, ba, ck, em, kt, ka, ve, fr, bi, wa, ah, gt, di, ab, fo, to, rk, as, ag, gi, hn, s, t, n, m, r, l, f, e, a, b, d, h, k, g, o, i, u, w, p, z, ä, ü, v, ö, j, c, y, x, q, á, í, ō, ó, š, é, č, ?
```