Traves committed on
Commit
bc742a1
·
verified ·
1 Parent(s): e63a57d

Upload folder using huggingface_hub

Browse files
Files changed (9) hide show
  1. .gitattributes +1 -0
  2. Dockerfile +38 -0
  3. README.md +31 -7
  4. cmd/server/main.go +250 -0
  5. go.mod +38 -0
  6. go.sum +66 -0
  7. models/latest_checkpoint.json +3 -0
  8. pkg/model/model.go +375 -0
  9. pkg/model/tokenizer.go +124 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ models/latest_checkpoint.json filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Build stage
2
+ FROM golang:1.25-bullseye AS builder
3
+
4
+ WORKDIR /app
5
+
6
+ # Copy go.mod and go.sum first for caching
7
+ COPY go.mod go.sum ./
8
+ RUN go mod download
9
+
10
+ # Copy the rest of the code
11
+ COPY . .
12
+
13
+ # Build the server binary
14
+ RUN CGO_ENABLED=0 GOOS=linux go build -o server ./cmd/server/main.go
15
+
16
+ # Final stage
17
+ FROM debian:bullseye-slim
18
+
19
+ WORKDIR /app
20
+
21
+ # Install CA certificates for external downloads if needed
22
+ RUN apt-get update && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/*
23
+
24
+ # Copy the binary from builder
25
+ COPY --from=builder /app/server .
26
+
27
+ # Copy the models directory for weights
28
+ COPY ./models ./models
29
+
30
+ # Set environment variables
31
+ ENV PORT=7860
32
+ ENV MODEL_PATH=models/latest_checkpoint.json
33
+
34
+ # Expose the port
35
+ EXPOSE 7860
36
+
37
+ # Run the server
38
+ CMD ["./server"]
README.md CHANGED
@@ -1,12 +1,36 @@
1
  ---
2
- title: MicroGPT API
3
- emoji: 👁
4
- colorFrom: pink
5
  colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 6.6.0
8
- app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: MicroGPT OpenAI API
3
+ emoji: 🚀
4
+ colorFrom: blue
5
  colorTo: indigo
6
+ sdk: docker
7
+ app_port: 7860
 
8
  pinned: false
9
+ license: mit
10
  ---
11
 
12
+ # MicroGPT OpenAI-Compatible API
13
+
14
+ This Space hosts a Go-based inference server for the **MicroGPT** model, providing an OpenAI-compatible API.
15
+
16
+ ## API Endpoints
17
+
18
+ - **`POST /v1/chat/completions`**: standard OpenAI chat format.
19
+ - **`GET /v1/models`**: returns model metadata.
20
+
21
+ ## Local Test
22
+
23
+ ```bash
24
+ curl -X POST http://localhost:7860/v1/chat/completions \
25
+ -H "Content-Type: application/json" \
26
+ -d '{
27
+ "model": "microgpt",
28
+ "messages": [
29
+ {"role": "user", "content": "Help me prioritize my day"}
30
+ ]
31
+ }'
32
+ ```
33
+
34
+ ## Repository
35
+
36
+ Built with [MicroGPT Go Edition](https://github.com/Traves-Theberge/microgpt-tui-go).
cmd/server/main.go ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "encoding/json"
5
+ "fmt"
6
+ "log"
7
+ "net/http"
8
+ "os"
9
+ "strings"
10
+ "time"
11
+
12
+ "microgpt-go/pkg/model"
13
+ )
14
+
15
// ChatMessage is a single turn in an OpenAI-style conversation.
type ChatMessage struct {
	Role    string `json:"role"`    // e.g. "user" or "assistant"; handleChat treats anything but "assistant" as the user
	Content string `json:"content"` // plain-text message body
}

// ChatCompletionRequest mirrors the OpenAI /v1/chat/completions request
// body (only the fields this server reads).
type ChatCompletionRequest struct {
	Model       string        `json:"model"`
	Messages    []ChatMessage `json:"messages"`
	Temperature float64       `json:"temperature"` // <= 0 is replaced with 0.5 by handleChat
	MaxTokens   int           `json:"max_tokens"`  // <= 0 is replaced with 128 by handleChat
	TopP        float64       `json:"top_p"`       // <= 0 is replaced with 0.9 by handleChat
	Stream      bool          `json:"stream"`      // parsed but never acted on — TODO confirm streaming is intentionally unsupported
}
28
+
29
// ChatCompletionResponse mirrors the OpenAI chat-completion response wire
// format. Choices and Usage use anonymous structs so the shape is defined
// in exactly one place.
type ChatCompletionResponse struct {
	ID      string `json:"id"`
	Object  string `json:"object"`  // always "chat.completion"
	Created int64  `json:"created"` // unix seconds
	Model   string `json:"model"`
	Choices []struct {
		Message      ChatMessage `json:"message"`
		Index        int         `json:"index"`
		FinishReason string      `json:"finish_reason"`
	} `json:"choices"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens     int `json:"total_tokens"`
	} `json:"usage"`
}
45
+
46
// Process-wide model state, populated once by initModel before the HTTP
// server starts and not reassigned afterwards. Per-request KV caches are
// allocated inside handleChat, so concurrent requests share only these
// values — NOTE(review): the gpt closure appears to read the weights
// without mutating them; confirm before relying on concurrent requests.
var (
	gpt       func(tokenID, posID int, keys, values [][][]*model.Value) []*model.Value // single-token forward pass with KV cache
	tokenizer model.TokenizerRuntime
	config    model.TrainingCheckpointConfig
	state     map[string][][]*model.Value // named weight matrices from the checkpoint
)
52
+
53
// initModel loads the checkpoint named by the MODEL_PATH environment
// variable (default "models/latest_checkpoint.json"), builds the tokenizer
// and the GPT forward function, and fills in the package-level globals.
// Any failure is fatal: the server cannot serve requests without a model.
func initModel() {
	ckptPath := os.Getenv("MODEL_PATH")
	if ckptPath == "" {
		ckptPath = "models/latest_checkpoint.json"
	}
	log.Printf("Loading model from %s...", ckptPath)
	ckpt, err := model.LoadCheckpoint(ckptPath)
	if err != nil {
		log.Fatalf("Failed to load checkpoint: %v", err)
	}

	tokenizer, err = model.TokenizerFromCheckpoint(ckpt)
	if err != nil {
		log.Fatalf("Failed to load tokenizer: %v", err)
	}

	// Convert raw float weights into autograd Values and bind the forward
	// pass closure over them.
	state = model.ImportState(ckpt.State)
	config = ckpt.Config
	gpt = model.BuildGPT(state, config.NLayer, config.NEmbd, config.NHead)
	log.Println("Model loaded successfully.")
}
74
+
75
// handleChat implements POST /v1/chat/completions.
//
// It flattens the chat history into a plain "Role: text" transcript,
// pre-fills the per-request KV cache with the prompt, then samples tokens
// one at a time until max_tokens, the block size, a BOS token, or a stop
// sequence ("\nUser:" / "\nAssistant:") is hit. The reply is returned in
// OpenAI chat-completion JSON. Streaming is not supported; req.Stream is
// ignored.
func handleChat(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req ChatCompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	// Clamp unset/invalid sampling knobs to sane defaults.
	if req.Temperature <= 0 {
		req.Temperature = 0.5
	}
	if req.TopP <= 0 {
		req.TopP = 0.9
	}
	if req.MaxTokens <= 0 {
		req.MaxTokens = 128
	}

	// Simple prompt construction from messages: any non-assistant role is
	// rendered as "User".
	var promptBuilder strings.Builder
	for _, msg := range req.Messages {
		role := "User"
		if msg.Role == "assistant" {
			role = "Assistant"
		}
		fmt.Fprintf(&promptBuilder, "%s: %s\n", role, msg.Content)
	}
	promptBuilder.WriteString("Assistant: ")
	promptText := promptBuilder.String()

	// Keep only the tail of the prompt that fits the context window
	// (one slot is reserved for the BOS token fed first).
	promptTokens := tokenizer.EncodeDoc(promptText)
	if len(promptTokens) > config.BlockSize-1 {
		promptTokens = promptTokens[len(promptTokens)-(config.BlockSize-1):]
	}

	// Per-request KV cache: one slice of cached keys/values per layer.
	keys := make([][][]*model.Value, config.NLayer)
	values := make([][][]*model.Value, config.NLayer)
	tokenID := tokenizer.BosID
	pos := 0

	// Process prompt tokens (pre-fill KV cache); logits are discarded here.
	for _, nextID := range promptTokens {
		if pos >= config.BlockSize {
			break
		}
		_ = gpt(tokenID, pos, keys, values)
		tokenID = nextID
		pos++
	}

	// Generate response tokens one at a time.
	completionTokens := 0
	outTokens := make([]int, 0, req.MaxTokens)
	recent := make([]int, 0, 64) // sliding window fed to the repetition penalty
	stopSeqs := []string{"\nUser:", "\nAssistant:"}

	for pos < config.BlockSize && completionTokens < req.MaxTokens {
		logits := gpt(tokenID, pos, keys, values)
		recentSet := map[int]bool{}
		for _, id := range recent {
			recentSet[id] = true
		}
		// top-k fixed at 40, repetition penalty fixed at 1.1.
		weights := model.NextTokenWeights(logits, req.Temperature, 40, req.TopP, recentSet, 1.1)
		tokenID = model.SampleWeighted(weights)

		// BOS doubles as an end-of-text marker.
		if tokenID == tokenizer.BosID {
			break
		}

		outTokens = append(outTokens, tokenID)
		recent = append(recent, tokenID)
		if len(recent) > 64 {
			recent = recent[len(recent)-64:]
		}
		completionTokens++
		pos++

		// Check for stop sequences in decoded text. NOTE(review): decoding
		// the whole output every step is O(n^2) over the reply length —
		// acceptable at MaxTokens<=128, but worth revisiting.
		fullText := tokenizer.DecodeTokens(outTokens)
		stopFound := false
		for _, stop := range stopSeqs {
			if strings.Contains(fullText, stop) {
				stopFound = true
				break
			}
		}
		if stopFound {
			break
		}
	}

	responseText := strings.TrimSpace(tokenizer.DecodeTokens(outTokens))
	// Clean up any trailing stop sequence markers left in the decoded text.
	for _, stop := range stopSeqs {
		if idx := strings.Index(responseText, strings.TrimSpace(stop)); idx >= 0 {
			responseText = responseText[:idx]
		}
	}

	resp := ChatCompletionResponse{
		ID:      fmt.Sprintf("chatcmpl-%d", time.Now().Unix()),
		Object:  "chat.completion",
		Created: time.Now().Unix(),
		Model:   "microgpt",
		Choices: []struct {
			Message      ChatMessage `json:"message"`
			Index        int         `json:"index"`
			FinishReason string      `json:"finish_reason"`
		}{
			{
				Message: ChatMessage{
					Role:    "assistant",
					Content: strings.TrimSpace(responseText),
				},
				Index:        0,
				FinishReason: "stop",
			},
		},
	}
	resp.Usage.PromptTokens = len(promptTokens)
	resp.Usage.CompletionTokens = completionTokens
	resp.Usage.TotalTokens = resp.Usage.PromptTokens + resp.Usage.CompletionTokens

	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(resp)
}
205
+
206
+ func handleModels(w http.ResponseWriter, r *http.Request) {
207
+ resp := struct {
208
+ Object string `json:"object"`
209
+ Data []struct {
210
+ ID string `json:"id"`
211
+ Object string `json:"object"`
212
+ Created int64 `json:"created"`
213
+ OwnedBy string `json:"owned_by"`
214
+ } `json:"data"`
215
+ }{
216
+ Object: "list",
217
+ Data: []struct {
218
+ ID string `json:"id"`
219
+ Object string `json:"object"`
220
+ Created int64 `json:"created"`
221
+ OwnedBy string `json:"owned_by"`
222
+ }{
223
+ {
224
+ ID: "microgpt",
225
+ Object: "model",
226
+ Created: time.Now().Unix(),
227
+ OwnedBy: "microgpt",
228
+ },
229
+ },
230
+ }
231
+ w.Header().Set("Content-Type", "application/json")
232
+ json.NewEncoder(w).Encode(resp)
233
+ }
234
+
235
+ func main() {
236
+ initModel()
237
+
238
+ http.HandleFunc("/v1/chat/completions", handleChat)
239
+ http.HandleFunc("/v1/models", handleModels)
240
+
241
+ port := os.Getenv("PORT")
242
+ if port == "" {
243
+ port = "7860" // Standard port for HF Spaces
244
+ }
245
+
246
+ log.Printf("Starting OpenAI-compatible server on port %s...", port)
247
+ if err := http.ListenAndServe(":"+port, nil); err != nil {
248
+ log.Fatalf("Failed to start server: %v", err)
249
+ }
250
+ }
go.mod ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ module microgpt-go
2
+
3
+ go 1.25
4
+
5
+ require (
6
+ github.com/charmbracelet/bubbles v1.0.0
7
+ github.com/charmbracelet/bubbletea v1.3.10
8
+ github.com/charmbracelet/harmonica v0.2.0
9
+ github.com/charmbracelet/lipgloss v1.1.0
10
+ )
11
+
12
+ require (
13
+ github.com/atotto/clipboard v0.1.4 // indirect
14
+ github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
15
+ github.com/charmbracelet/colorprofile v0.4.1 // indirect
16
+ github.com/charmbracelet/x/ansi v0.11.6 // indirect
17
+ github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
18
+ github.com/charmbracelet/x/term v0.2.2 // indirect
19
+ github.com/clipperhouse/displaywidth v0.9.0 // indirect
20
+ github.com/clipperhouse/stringish v0.1.1 // indirect
21
+ github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
22
+ github.com/dlclark/regexp2 v1.10.0 // indirect
23
+ github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
24
+ github.com/google/uuid v1.3.0 // indirect
25
+ github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
26
+ github.com/mattn/go-isatty v0.0.20 // indirect
27
+ github.com/mattn/go-localereader v0.0.1 // indirect
28
+ github.com/mattn/go-runewidth v0.0.19 // indirect
29
+ github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
30
+ github.com/muesli/cancelreader v0.2.2 // indirect
31
+ github.com/muesli/termenv v0.16.0 // indirect
32
+ github.com/pkoukk/tiktoken-go v0.1.8 // indirect
33
+ github.com/pkoukk/tiktoken-go-loader v0.0.2 // indirect
34
+ github.com/rivo/uniseg v0.4.7 // indirect
35
+ github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
36
+ golang.org/x/sys v0.38.0 // indirect
37
+ golang.org/x/text v0.3.8 // indirect
38
+ )
go.sum ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
2
+ github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
3
+ github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
4
+ github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
5
+ github.com/aymanbagabas/go-udiff v0.3.1 h1:LV+qyBQ2pqe0u42ZsUEtPiCaUoqgA9gYRDs3vj1nolY=
6
+ github.com/aymanbagabas/go-udiff v0.3.1/go.mod h1:G0fsKmG+P6ylD0r6N/KgQD/nWzgfnl8ZBcNLgcbrw8E=
7
+ github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
8
+ github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
9
+ github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
10
+ github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
11
+ github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
12
+ github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
13
+ github.com/charmbracelet/harmonica v0.2.0 h1:8NxJWRWg/bzKqqEaaeFNipOu77YR5t8aSwG4pgaUBiQ=
14
+ github.com/charmbracelet/harmonica v0.2.0/go.mod h1:KSri/1RMQOZLbw7AHqgcBycp8pgJnQMYYT8QZRqZ1Ao=
15
+ github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
16
+ github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
17
+ github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
18
+ github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
19
+ github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
20
+ github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
21
+ github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91 h1:payRxjMjKgx2PaCWLZ4p3ro9y97+TVLZNaRZgJwSVDQ=
22
+ github.com/charmbracelet/x/exp/golden v0.0.0-20241011142426-46044092ad91/go.mod h1:wDlXFlCrmJ8J+swcL/MnGUuYnqgQdW9rhSD61oNMb6U=
23
+ github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
24
+ github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
25
+ github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
26
+ github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
27
+ github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
28
+ github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
29
+ github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
30
+ github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
31
+ github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
32
+ github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
33
+ github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
34
+ github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
35
+ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
36
+ github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
37
+ github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
38
+ github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
39
+ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
40
+ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
41
+ github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
42
+ github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
43
+ github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
44
+ github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
45
+ github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
46
+ github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
47
+ github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
48
+ github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
49
+ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
50
+ github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
51
+ github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo=
52
+ github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
53
+ github.com/pkoukk/tiktoken-go-loader v0.0.2 h1:LUKws63GV3pVHwH1srkBplBv+7URgmOmhSkRxsIvsK4=
54
+ github.com/pkoukk/tiktoken-go-loader v0.0.2/go.mod h1:4mIkYyZooFlnenDlormIo6cd5wrlUKNr97wp9nGgEKo=
55
+ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
56
+ github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
57
+ github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
58
+ github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
59
+ golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
60
+ golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
61
+ golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
62
+ golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
63
+ golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
64
+ golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
65
+ golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY=
66
+ golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
models/latest_checkpoint.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:964f39971833a67b2ec3a3cdd1376586aa3d3cc2b55cb11f8dc581c27a304720
3
+ size 19575802
pkg/model/model.go ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package model
2
+
3
+ import (
4
+ "encoding/json"
5
+ "fmt"
6
+ "math"
7
+ "math/rand"
8
+ "os"
9
+ "sort"
10
+ )
11
+
12
+ // Value represents a scalar for autograd
13
+ type Value struct {
14
+ Data float64
15
+ Grad float64
16
+ Children []*Value
17
+ LocalGrads []float64
18
+ }
19
+
20
+ func V(x float64) *Value {
21
+ return &Value{Data: x}
22
+ }
23
+
24
+ func Add(a, b *Value) *Value {
25
+ out := &Value{Data: a.Data + b.Data, Children: []*Value{a, b}, LocalGrads: []float64{1, 1}}
26
+ return out
27
+ }
28
+
29
+ func Sub(a, b *Value) *Value {
30
+ out := &Value{Data: a.Data - b.Data, Children: []*Value{a, b}, LocalGrads: []float64{1, -1}}
31
+ return out
32
+ }
33
+
34
+ func Mul(a, b *Value) *Value {
35
+ out := &Value{Data: a.Data * b.Data, Children: []*Value{a, b}, LocalGrads: []float64{b.Data, a.Data}}
36
+ return out
37
+ }
38
+
39
+ func Pow(a *Value, p float64) *Value {
40
+ out := &Value{Data: math.Pow(a.Data, p), Children: []*Value{a}, LocalGrads: []float64{p * math.Pow(a.Data, p-1)}}
41
+ return out
42
+ }
43
+
44
+ func Div(a, b *Value) *Value {
45
+ return Mul(a, Pow(b, -1))
46
+ }
47
+
48
+ func Neg(a *Value) *Value {
49
+ return Mul(a, V(-1))
50
+ }
51
+
52
+ func Log(a *Value) *Value {
53
+ out := &Value{Data: math.Log(a.Data), Children: []*Value{a}, LocalGrads: []float64{1 / a.Data}}
54
+ return out
55
+ }
56
+
57
+ func Exp(a *Value) *Value {
58
+ out := &Value{Data: math.Exp(a.Data), Children: []*Value{a}, LocalGrads: []float64{math.Exp(a.Data)}}
59
+ return out
60
+ }
61
+
62
+ func ReLU(a *Value) *Value {
63
+ val := 0.0
64
+ grad := 0.0
65
+ if a.Data > 0 {
66
+ val = a.Data
67
+ grad = 1
68
+ }
69
+ out := &Value{Data: val, Children: []*Value{a}, LocalGrads: []float64{grad}}
70
+ return out
71
+ }
72
+
73
+ func Backward(out *Value) {
74
+ topo := make([]*Value, 0)
75
+ visited := make(map[*Value]bool)
76
+ var buildTopo func(*Value)
77
+ buildTopo = func(v *Value) {
78
+ if !visited[v] {
79
+ visited[v] = true
80
+ for _, child := range v.Children {
81
+ buildTopo(child)
82
+ }
83
+ topo = append(topo, v)
84
+ }
85
+ }
86
+ buildTopo(out)
87
+
88
+ for _, v := range topo {
89
+ v.Grad = 0
90
+ }
91
+ out.Grad = 1
92
+ for i := len(topo) - 1; i >= 0; i-- {
93
+ v := topo[i]
94
+ for j, child := range v.Children {
95
+ child.Grad += v.LocalGrads[j] * v.Grad
96
+ }
97
+ }
98
+ }
99
+
100
+ func linear(x []*Value, w [][]*Value) []*Value {
101
+ nout := len(w)
102
+ nin := len(x)
103
+ out := make([]*Value, nout)
104
+ for i := 0; i < nout; i++ {
105
+ s := V(0)
106
+ for j := 0; j < nin; j++ {
107
+ s = Add(s, Mul(x[j], w[i][j]))
108
+ }
109
+ out[i] = s
110
+ }
111
+ return out
112
+ }
113
+
114
+ func softmax(logits []*Value) []*Value {
115
+ maxVal := -math.MaxFloat64
116
+ for _, l := range logits {
117
+ if l.Data > maxVal {
118
+ maxVal = l.Data
119
+ }
120
+ }
121
+ exps := make([]*Value, len(logits))
122
+ sumExp := V(0)
123
+ for i, l := range logits {
124
+ exps[i] = Exp(Sub(l, V(maxVal)))
125
+ sumExp = Add(sumExp, exps[i])
126
+ }
127
+ out := make([]*Value, len(logits))
128
+ invSum := Div(V(1), sumExp)
129
+ for i := range exps {
130
+ out[i] = Mul(exps[i], invSum)
131
+ }
132
+ return out
133
+ }
134
+
135
+ func rmsnorm(x []*Value) []*Value {
136
+ meanSq := V(0)
137
+ for _, v := range x {
138
+ meanSq = Add(meanSq, Pow(v, 2))
139
+ }
140
+ meanSq = Mul(V(1/float64(len(x))), meanSq)
141
+ invStd := Div(V(1), Pow(Add(meanSq, V(1e-6)), 0.5))
142
+ out := make([]*Value, len(x))
143
+ for i, v := range x {
144
+ out[i] = Mul(v, invStd)
145
+ }
146
+ return out
147
+ }
148
+
149
// TrainingCheckpoint is the on-disk JSON checkpoint format: model
// hyperparameters, tokenizer data (either a BPE token-id subset or a
// character vocabulary), and the named weight matrices.
type TrainingCheckpoint struct {
	Version      int                      `json:"version"`
	CreatedAt    string                   `json:"created_at"`
	Config       TrainingCheckpointConfig `json:"config"`
	Tokenization string                   `json:"tokenization,omitempty"` // "bpe_cl100k" selects BPE mode
	BPEEncoding  string                   `json:"bpe_encoding,omitempty"` // tiktoken encoding name; empty defaults to cl100k_base
	BPETokenIDs  []int                    `json:"bpe_token_ids,omitempty"` // raw tiktoken ids used by the model, indexed by local id
	Vocab        []string                 `json:"vocab,omitempty"`         // char mode: one single-rune string per token
	State        map[string][][]float64   `json:"state"`                   // weight matrices keyed by layer/parameter name
}

// TrainingCheckpointConfig holds the transformer hyperparameters needed to
// rebuild the forward pass.
type TrainingCheckpointConfig struct {
	NLayer    int `json:"n_layer"`
	NEmbd     int `json:"n_embd"` // must be divisible by NHead (enforced in LoadCheckpoint)
	NHead     int `json:"n_head"`
	BlockSize int `json:"block_size"` // context window length
}
167
+
168
+ func ImportState(src map[string][][]float64) map[string][][]*Value {
169
+ out := make(map[string][][]*Value, len(src))
170
+ for name, mat := range src {
171
+ rows := make([][]*Value, len(mat))
172
+ for i, row := range mat {
173
+ r := make([]*Value, len(row))
174
+ for j, v := range row {
175
+ r[j] = V(v)
176
+ }
177
+ rows[i] = r
178
+ }
179
+ out[name] = rows
180
+ }
181
+ return out
182
+ }
183
+
184
+ func LoadCheckpoint(path string) (TrainingCheckpoint, error) {
185
+ b, err := os.ReadFile(path)
186
+ if err != nil {
187
+ return TrainingCheckpoint{}, err
188
+ }
189
+ var ckpt TrainingCheckpoint
190
+ if err := json.Unmarshal(b, &ckpt); err != nil {
191
+ return TrainingCheckpoint{}, err
192
+ }
193
+ if ckpt.Config.NLayer < 1 || ckpt.Config.NEmbd < 1 || ckpt.Config.NHead < 1 || ckpt.Config.BlockSize < 2 {
194
+ return TrainingCheckpoint{}, fmt.Errorf("invalid checkpoint config")
195
+ }
196
+ if ckpt.Config.NEmbd%ckpt.Config.NHead != 0 {
197
+ return TrainingCheckpoint{}, fmt.Errorf("invalid checkpoint: n_embd must be divisible by n_head")
198
+ }
199
+ return ckpt, nil
200
+ }
201
+
202
// BuildGPT returns a single-token forward pass over the given weights.
//
// The returned closure consumes one token id at position posID, appends
// that position's attention keys/values to the caller-owned per-layer
// caches, and returns the vocabulary logits for the next token. Callers
// must feed tokens in position order and reuse the same keys/values slices
// across calls (incremental decoding with a KV cache).
//
// Architecture per layer (pre-norm): RMSNorm -> multi-head attention with
// per-head scaled dot product over all cached positions -> output
// projection + residual, then RMSNorm -> fc1 -> ReLU -> fc2 + residual.
// Weight matrices are looked up by name ("wte", "wpe",
// "layerN.attn_wq/wk/wv/wo", "layerN.mlp_fc1/fc2", "lm_head").
func BuildGPT(state map[string][][]*Value, nLayer, nEmbd, nHead int) func(tokenID, posID int, keys, values [][][]*Value) []*Value {
	headDim := nEmbd / nHead
	return func(tokenID, posID int, keys, values [][][]*Value) []*Value {
		// Token + positional embedding, normalized before the first layer.
		tokEmb := state["wte"][tokenID]
		posEmb := state["wpe"][posID]
		x := make([]*Value, len(tokEmb))
		for i := range tokEmb {
			x[i] = Add(tokEmb[i], posEmb[i])
		}
		x = rmsnorm(x)

		for li := 0; li < nLayer; li++ {
			xResidual := x
			x = rmsnorm(x)
			q := linear(x, state[fmt.Sprintf("layer%d.attn_wq", li)])
			k := linear(x, state[fmt.Sprintf("layer%d.attn_wk", li)])
			v := linear(x, state[fmt.Sprintf("layer%d.attn_wv", li)])
			// Cache this position's K/V; attention below sees every cached
			// position, so causality holds because future tokens are simply
			// not present yet.
			keys[li] = append(keys[li], k)
			values[li] = append(values[li], v)

			xAttn := make([]*Value, 0, nEmbd)
			for h := 0; h < nHead; h++ {
				hs := h * headDim
				qH := q[hs : hs+headDim]

				// Per-head slices of every cached key/value vector.
				kH := make([][]*Value, len(keys[li]))
				vH := make([][]*Value, len(values[li]))
				for t := 0; t < len(keys[li]); t++ {
					kH[t] = keys[li][t][hs : hs+headDim]
					vH[t] = values[li][t][hs : hs+headDim]
				}

				// Scaled dot-product attention scores q·k / sqrt(headDim).
				attnLogits := make([]*Value, len(kH))
				for t := 0; t < len(kH); t++ {
					score := V(0)
					for j := 0; j < headDim; j++ {
						score = Add(score, Mul(qH[j], kH[t][j]))
					}
					attnLogits[t] = Div(score, V(math.Sqrt(float64(headDim))))
				}
				attnWeights := softmax(attnLogits)

				// Weighted sum of cached values.
				headOut := make([]*Value, headDim)
				for j := 0; j < headDim; j++ {
					s := V(0)
					for t := 0; t < len(vH); t++ {
						s = Add(s, Mul(attnWeights[t], vH[t][j]))
					}
					headOut[j] = s
				}
				xAttn = append(xAttn, headOut...)
			}

			// Output projection + residual connection.
			x = linear(xAttn, state[fmt.Sprintf("layer%d.attn_wo", li)])
			for i := range x {
				x[i] = Add(x[i], xResidual[i])
			}

			// MLP block: norm -> fc1 -> ReLU -> fc2 + residual.
			xResidual = x
			x = rmsnorm(x)
			x = linear(x, state[fmt.Sprintf("layer%d.mlp_fc1", li)])
			for i := range x {
				x[i] = ReLU(x[i])
			}
			x = linear(x, state[fmt.Sprintf("layer%d.mlp_fc2", li)])
			for i := range x {
				x[i] = Add(x[i], xResidual[i])
			}
		}

		// Final projection to vocabulary logits. NOTE(review): no final
		// RMSNorm before lm_head — confirm this matches training.
		return linear(x, state["lm_head"])
	}
}
275
+
276
// SampleWeighted draws a random index with probability proportional to the
// (unnormalized, non-negative) weights. If floating-point rounding leaves
// the cursor past the running total, the last index is returned.
func SampleWeighted(weights []float64) int {
	total := 0.0
	for _, w := range weights {
		total += w
	}
	target := rand.Float64() * total
	cum := 0.0
	for i, w := range weights {
		cum += w
		if target <= cum {
			return i
		}
	}
	return len(weights) - 1
}
292
+
293
+ func SoftmaxFloat(logits []float64) []float64 {
294
+ maxLogit := -math.MaxFloat64
295
+ for _, l := range logits {
296
+ if l > maxLogit {
297
+ maxLogit = l
298
+ }
299
+ }
300
+ sum := 0.0
301
+ out := make([]float64, len(logits))
302
+ for i, l := range logits {
303
+ out[i] = math.Exp(l - maxLogit)
304
+ sum += out[i]
305
+ }
306
+ for i := range out {
307
+ out[i] /= sum
308
+ }
309
+ return out
310
+ }
311
+
312
+ func NextTokenWeights(logits []*Value, temperature float64, topK int, topP float64, recent map[int]bool, repetitionPenalty float64) []float64 {
313
+ l := make([]float64, len(logits))
314
+ for i, v := range logits {
315
+ l[i] = v.Data
316
+ if recent[i] {
317
+ if l[i] >= 0 {
318
+ l[i] /= repetitionPenalty
319
+ } else {
320
+ l[i] *= repetitionPenalty
321
+ }
322
+ }
323
+ l[i] /= temperature
324
+ }
325
+ w := SoftmaxFloat(l)
326
+ if topK > 0 {
327
+ w = ApplyTopK(w, topK)
328
+ }
329
+ if topP > 0 && topP < 1.0 {
330
+ w = ApplyTopP(w, topP)
331
+ }
332
+ return w
333
+ }
334
+
335
// ApplyTopK zeroes every weight outside the k largest. Survivors keep
// their original (unnormalized) values and positions; if k covers the
// whole slice the input is returned untouched.
func ApplyTopK(weights []float64, k int) []float64 {
	if k >= len(weights) {
		return weights
	}
	idx := make([]int, len(weights))
	for i := range idx {
		idx[i] = i
	}
	// Sort indices by descending weight instead of copying into pairs.
	sort.Slice(idx, func(a, b int) bool { return weights[idx[a]] > weights[idx[b]] })
	out := make([]float64, len(weights))
	for _, i := range idx[:k] {
		out[i] = weights[i]
	}
	return out
}
354
+
355
// ApplyTopP performs nucleus sampling truncation: it keeps the smallest
// set of highest-weight entries whose cumulative mass reaches p and zeroes
// the rest. Kept entries retain their original values and positions.
func ApplyTopP(weights []float64, p float64) []float64 {
	idx := make([]int, len(weights))
	for i := range idx {
		idx[i] = i
	}
	sort.Slice(idx, func(a, b int) bool { return weights[idx[a]] > weights[idx[b]] })
	out := make([]float64, len(weights))
	cum := 0.0
	for _, i := range idx {
		cum += weights[i]
		out[i] = weights[i]
		if cum >= p {
			break
		}
	}
	return out
}
pkg/model/tokenizer.go ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package model
2
+
3
+ import (
4
+ "fmt"
5
+ "strings"
6
+
7
+ tiktoken "github.com/pkoukk/tiktoken-go"
8
+ )
9
+
10
// TokenizerRuntime is the inference-time tokenizer. Exactly one backend is
// populated, selected by Mode: "bpe_cl100k" (tiktoken-backed sub-word
// vocabulary remapped to a compact local id space) or char mode (one local
// id per rune).
type TokenizerRuntime struct {
	Mode        string
	CharToLocal map[rune]int // char mode: rune -> local id
	LocalToChar []rune       // char mode: local id -> rune
	BpeEncoding string       // BPE mode: tiktoken encoding name, e.g. "cl100k_base"
	Bpe         *tiktoken.Tiktoken
	BpeToLocal  map[int]int // BPE mode: raw tiktoken id -> local id
	LocalToBPE  []int       // BPE mode: local id -> raw tiktoken id
	UnkID       int         // BPE mode only: id assigned to out-of-vocab tokens (== len(LocalToBPE))
	BosID       int         // beginning-of-sequence id, last id in the vocabulary
}
21
+
22
+ func (t TokenizerRuntime) VocabSize() int {
23
+ if t.Mode == "bpe_cl100k" {
24
+ return len(t.LocalToBPE) + 2
25
+ }
26
+ return len(t.LocalToChar) + 1
27
+ }
28
+
29
// EncodeDoc tokenizes doc into local vocabulary ids.
//
// BPE mode: the text is encoded with tiktoken's EncodeOrdinary (no special
// tokens) and each raw id is remapped through BpeToLocal; raw ids outside
// the trained subset map to UnkID. Char mode: each rune is looked up in
// CharToLocal; unknown runes are silently dropped (char mode has no UNK).
func (t TokenizerRuntime) EncodeDoc(doc string) []int {
	if t.Mode == "bpe_cl100k" {
		raw := t.Bpe.EncodeOrdinary(doc)
		out := make([]int, 0, len(raw))
		for _, id := range raw {
			if local, ok := t.BpeToLocal[id]; ok {
				out = append(out, local)
			} else {
				out = append(out, t.UnkID)
			}
		}
		return out
	}
	out := make([]int, 0, len(doc))
	for _, r := range doc {
		if id, ok := t.CharToLocal[r]; ok {
			out = append(out, id)
		}
	}
	return out
}
50
+
51
// DecodeTokens converts local token ids back to text. Ids outside the base
// vocabulary range — including the synthetic UNK/BOS ids — are skipped
// rather than rendered.
func (t TokenizerRuntime) DecodeTokens(tokens []int) string {
	if t.Mode == "bpe_cl100k" {
		// Remap local ids back to raw tiktoken ids, then let tiktoken decode.
		raw := make([]int, 0, len(tokens))
		for _, local := range tokens {
			if local >= 0 && local < len(t.LocalToBPE) {
				raw = append(raw, t.LocalToBPE[local])
			}
		}
		return t.Bpe.Decode(raw)
	}
	out := make([]rune, 0, len(tokens))
	for _, id := range tokens {
		if id >= 0 && id < len(t.LocalToChar) {
			out = append(out, t.LocalToChar[id])
		}
	}
	return string(out)
}
69
+
70
// TokenizerFromCheckpoint reconstructs the tokenizer described by a
// checkpoint.
//
// BPE mode is selected when the checkpoint declares tokenization
// "bpe_cl100k" or carries BPE token ids: the tiktoken encoding (default
// cl100k_base) is loaded and a raw<->local id mapping is built, with UNK
// and BOS assigned to the two ids after the base vocabulary. Otherwise a
// character tokenizer is built from the single-rune Vocab entries, with
// BOS appended after the last character id.
func TokenizerFromCheckpoint(ckpt TrainingCheckpoint) (TokenizerRuntime, error) {
	if ckpt.Tokenization == "bpe_cl100k" || len(ckpt.BPETokenIDs) > 0 {
		encName := strings.TrimSpace(ckpt.BPEEncoding)
		if encName == "" {
			encName = "cl100k_base"
		}
		enc, err := tiktoken.GetEncoding(encName)
		if err != nil {
			return TokenizerRuntime{}, err
		}
		// Copy the id table so the runtime does not alias checkpoint data.
		localToBPE := append([]int(nil), ckpt.BPETokenIDs...)
		bpeToLocal := make(map[int]int, len(localToBPE))
		for i, id := range localToBPE {
			bpeToLocal[id] = i
		}
		return TokenizerRuntime{
			Mode:        "bpe_cl100k",
			BpeEncoding: encName,
			Bpe:         enc,
			BpeToLocal:  bpeToLocal,
			LocalToBPE:  localToBPE,
			UnkID:       len(localToBPE),
			BosID:       len(localToBPE) + 1,
		}, nil
	}
	uchars, err := stringsToRunes(ckpt.Vocab)
	if err != nil {
		return TokenizerRuntime{}, err
	}
	if len(uchars) == 0 {
		return TokenizerRuntime{}, fmt.Errorf("checkpoint has empty character vocab")
	}
	charToLocal := make(map[rune]int, len(uchars))
	for i, r := range uchars {
		charToLocal[r] = i
	}
	return TokenizerRuntime{
		Mode:        "char",
		CharToLocal: charToLocal,
		LocalToChar: uchars,
		BosID:       len(uchars),
	}, nil
}
113
+
114
// stringsToRunes converts a character vocabulary (a slice of single-rune
// strings) into []rune, rejecting any entry that is not exactly one rune.
func stringsToRunes(ss []string) ([]rune, error) {
	out := make([]rune, 0, len(ss))
	for _, s := range ss {
		runes := []rune(s)
		if len(runes) != 1 {
			return nil, fmt.Errorf("invalid vocab token %q: expected one rune", s)
		}
		out = append(out, runes[0])
	}
	return out, nil
}