File size: 2,862 Bytes
bc742a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package model

import (
	"fmt"
	"strings"

	tiktoken "github.com/pkoukk/tiktoken-go"
)

// TokenizerRuntime converts between document text and the model's local
// token-ID space. Two modes are supported, selected by Mode:
//
//   - "bpe_cl100k": documents are tokenized by a tiktoken BPE encoder and
//     raw BPE IDs are remapped into a dense local ID space.
//   - any other value (char mode): each rune is one token, mapped via
//     CharToLocal / LocalToChar.
type TokenizerRuntime struct {
	Mode        string             // "bpe_cl100k" selects BPE mode; anything else is char mode
	CharToLocal map[rune]int       // char mode: rune -> local token ID
	LocalToChar []rune             // char mode: local token ID -> rune
	BpeEncoding string             // BPE mode: tiktoken encoding name (e.g. "cl100k_base")
	Bpe         *tiktoken.Tiktoken // BPE mode: underlying tiktoken encoder
	BpeToLocal  map[int]int        // BPE mode: raw BPE ID -> local token ID
	LocalToBPE  []int              // BPE mode: local token ID -> raw BPE ID
	UnkID       int                // BPE mode: local ID emitted for BPE IDs outside the local vocab
	BosID       int                // local ID of the beginning-of-sequence token
}

// VocabSize reports the total number of local token IDs, counting the
// special slots appended past the base vocabulary: UNK and BOS in BPE
// mode, BOS only in char mode.
func (t TokenizerRuntime) VocabSize() int {
	if t.Mode != "bpe_cl100k" {
		return len(t.LocalToChar) + 1 // chars + BOS
	}
	return len(t.LocalToBPE) + 2 // BPE tokens + UNK + BOS
}

// EncodeDoc converts a document into local token IDs. In BPE mode, raw
// tiktoken IDs that are not part of the local vocabulary map to UnkID;
// in char mode, runes missing from the vocabulary are silently dropped.
func (t TokenizerRuntime) EncodeDoc(doc string) []int {
	if t.Mode != "bpe_cl100k" {
		ids := make([]int, 0, len(doc))
		for _, ch := range doc {
			if local, known := t.CharToLocal[ch]; known {
				ids = append(ids, local)
			}
		}
		return ids
	}
	bpeIDs := t.Bpe.EncodeOrdinary(doc)
	ids := make([]int, 0, len(bpeIDs))
	for _, bpeID := range bpeIDs {
		local, known := t.BpeToLocal[bpeID]
		if !known {
			local = t.UnkID
		}
		ids = append(ids, local)
	}
	return ids
}

// DecodeTokens converts local token IDs back into text. IDs outside the
// base vocabulary range (including the UNK/BOS specials) are skipped.
func (t TokenizerRuntime) DecodeTokens(tokens []int) string {
	if t.Mode != "bpe_cl100k" {
		runes := make([]rune, 0, len(tokens))
		for _, local := range tokens {
			if 0 <= local && local < len(t.LocalToChar) {
				runes = append(runes, t.LocalToChar[local])
			}
		}
		return string(runes)
	}
	bpeIDs := make([]int, 0, len(tokens))
	for _, local := range tokens {
		if 0 <= local && local < len(t.LocalToBPE) {
			bpeIDs = append(bpeIDs, t.LocalToBPE[local])
		}
	}
	return t.Bpe.Decode(bpeIDs)
}

// TokenizerFromCheckpoint reconstructs the tokenizer described by a training
// checkpoint. BPE mode is selected when the checkpoint declares
// Tokenization == "bpe_cl100k" or carries BPE token IDs; otherwise a
// character-level tokenizer is built from ckpt.Vocab.
//
// It returns an error when the tiktoken encoding cannot be loaded or when
// the checkpoint's vocabulary (BPE or character) is empty or invalid.
func TokenizerFromCheckpoint(ckpt TrainingCheckpoint) (TokenizerRuntime, error) {
	if ckpt.Tokenization == "bpe_cl100k" || len(ckpt.BPETokenIDs) > 0 {
		// An empty BPE vocab previously produced a degenerate tokenizer
		// (UnkID == 0 collides with real local slot 0, VocabSize == 2).
		// Fail fast instead, mirroring the char-mode empty-vocab check.
		if len(ckpt.BPETokenIDs) == 0 {
			return TokenizerRuntime{}, fmt.Errorf("checkpoint has empty BPE vocab")
		}
		encName := strings.TrimSpace(ckpt.BPEEncoding)
		if encName == "" {
			encName = "cl100k_base" // default encoding for this checkpoint format
		}
		enc, err := tiktoken.GetEncoding(encName)
		if err != nil {
			return TokenizerRuntime{}, err
		}
		// Copy the ID list so the runtime does not alias the checkpoint's slice.
		localToBPE := append([]int(nil), ckpt.BPETokenIDs...)
		bpeToLocal := make(map[int]int, len(localToBPE))
		for i, id := range localToBPE {
			bpeToLocal[id] = i
		}
		return TokenizerRuntime{
			Mode:        "bpe_cl100k",
			BpeEncoding: encName,
			Bpe:         enc,
			BpeToLocal:  bpeToLocal,
			LocalToBPE:  localToBPE,
			UnkID:       len(localToBPE),     // first special slot past the vocab
			BosID:       len(localToBPE) + 1, // second special slot
		}, nil
	}
	uchars, err := stringsToRunes(ckpt.Vocab)
	if err != nil {
		return TokenizerRuntime{}, err
	}
	if len(uchars) == 0 {
		return TokenizerRuntime{}, fmt.Errorf("checkpoint has empty character vocab")
	}
	charToLocal := make(map[rune]int, len(uchars))
	for i, r := range uchars {
		charToLocal[r] = i
	}
	return TokenizerRuntime{
		Mode:        "char",
		CharToLocal: charToLocal,
		LocalToChar: uchars,
		BosID:       len(uchars), // char mode has no UNK; BOS sits right past the vocab
	}, nil
}

// stringsToRunes converts a vocabulary of single-character strings into
// the corresponding rune slice, rejecting any entry that is not exactly
// one rune.
func stringsToRunes(ss []string) ([]rune, error) {
	runes := make([]rune, 0, len(ss))
	for _, tok := range ss {
		decoded := []rune(tok)
		if len(decoded) != 1 {
			return nil, fmt.Errorf("invalid vocab token %q: expected one rune", tok)
		}
		runes = append(runes, decoded[0])
	}
	return runes, nil
}