oki692 commited on
Commit
1595dc3
·
verified ·
1 Parent(s): 4464721

Update main.go

Browse files
Files changed (1) hide show
  1. main.go +245 -60
main.go CHANGED
@@ -9,6 +9,8 @@ import (
9
  "log"
10
  "net/http"
11
  "os"
 
 
12
  "time"
13
  )
14
 
@@ -59,6 +61,48 @@ type UpstreamRequest struct {
59
  ExtraBody map[string]interface{} `json:"extra_body,omitempty"`
60
  }
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  func resolveModel(requested string) string {
63
  if full, ok := modelAliases[requested]; ok {
64
  return full
@@ -126,6 +170,108 @@ func handleBaseURL(w http.ResponseWriter, r *http.Request) {
126
  fmt.Fprintf(w, `{"url":"https://%s/v1"}`, host)
127
  }
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  func handleChat(w http.ResponseWriter, r *http.Request) {
130
  if !authenticate(r) {
131
  http.Error(w, `{"error":{"message":"Unauthorized"}}`, http.StatusUnauthorized)
@@ -142,58 +288,28 @@ func handleChat(w http.ResponseWriter, r *http.Request) {
142
  }
143
 
144
  modelID := resolveModel(req.Model)
145
- req.Messages = injectSystemPrompt(req.Messages, modelID)
146
-
147
- upstream := UpstreamRequest{
148
- Model: modelID,
149
- Messages: req.Messages,
150
- Stream: true,
151
- Tools: req.Tools,
152
- ToolChoice: req.ToolChoice,
153
- Temperature: req.Temperature,
154
- MaxTokens: req.MaxTokens,
155
- TopP: req.TopP,
156
- Stop: req.Stop,
157
- }
158
 
159
- // GLM-4.7 requires thinking disabled via extra_body
160
- if modelID == "z-ai/glm4.7" {
161
- upstream.ExtraBody = map[string]interface{}{
162
- "chat_template_kwargs": map[string]interface{}{
163
- "enable_thinking": false,
164
- },
 
 
 
 
 
165
  }
166
- }
167
-
168
- body, err := json.Marshal(upstream)
169
- if err != nil {
170
- http.Error(w, `{"error":{"message":"Failed to marshal request"}}`, http.StatusInternalServerError)
171
- return
172
- }
173
-
174
- upstreamReq, err := http.NewRequest(http.MethodPost, NvidiaBaseURL+"/chat/completions", bytes.NewReader(body))
175
- if err != nil {
176
- http.Error(w, `{"error":{"message":"Failed to create upstream request"}}`, http.StatusInternalServerError)
177
- return
178
- }
179
- upstreamReq.Header.Set("Content-Type", "application/json")
180
- upstreamReq.Header.Set("Authorization", "Bearer "+NvidiaAPIKey)
181
- upstreamReq.Header.Set("Accept", "text/event-stream")
182
-
183
- client := &http.Client{Timeout: 300 * time.Second}
184
- resp, err := client.Do(upstreamReq)
185
- if err != nil {
186
- http.Error(w, fmt.Sprintf(`{"error":{"message":"Upstream error: %s"}}`, err.Error()), http.StatusBadGateway)
187
- return
188
- }
189
- defer resp.Body.Close()
190
-
191
- if resp.StatusCode != http.StatusOK {
192
- upstreamBody, _ := io.ReadAll(resp.Body)
193
- w.Header().Set("Content-Type", "application/json")
194
- w.WriteHeader(resp.StatusCode)
195
- w.Write(upstreamBody)
196
- return
197
  }
198
 
199
  w.Header().Set("Content-Type", "text/event-stream")
@@ -203,20 +319,89 @@ func handleChat(w http.ResponseWriter, r *http.Request) {
203
  w.WriteHeader(http.StatusOK)
204
 
205
  flusher, canFlush := w.(http.Flusher)
206
- scanner := bufio.NewScanner(resp.Body)
207
- scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
208
-
209
- for scanner.Scan() {
210
- line := scanner.Text()
211
- if line != "" {
212
- fmt.Fprintf(w, "%s\n", line)
213
- } else {
214
- fmt.Fprintf(w, "\n")
215
- }
216
  if canFlush {
217
  flusher.Flush()
218
  }
219
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  }
221
 
222
  func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc {
 
9
  "log"
10
  "net/http"
11
  "os"
12
+ "sort"
13
+ "strings"
14
  "time"
15
  )
16
 
 
61
  ExtraBody map[string]interface{} `json:"extra_body,omitempty"`
62
  }
63
 
64
+ // SSE chunk types for parsing upstream stream
65
+ type RawChunk struct {
66
+ ID string `json:"id"`
67
+ Object string `json:"object"`
68
+ Created int64 `json:"created"`
69
+ Model string `json:"model"`
70
+ Choices []RawChoice `json:"choices"`
71
+ Usage interface{} `json:"usage,omitempty"`
72
+ }
73
+
74
+ type RawChoice struct {
75
+ Index int `json:"index"`
76
+ Delta RawDelta `json:"delta"`
77
+ FinishReason *string `json:"finish_reason"`
78
+ }
79
+
80
+ type RawDelta struct {
81
+ Role string `json:"role,omitempty"`
82
+ Content *string `json:"content,omitempty"`
83
+ ToolCalls []RawToolCall `json:"tool_calls,omitempty"`
84
+ }
85
+
86
+ type RawToolCall struct {
87
+ Index int `json:"index"`
88
+ ID string `json:"id,omitempty"`
89
+ Type string `json:"type,omitempty"`
90
+ Function RawFunction `json:"function"`
91
+ }
92
+
93
+ type RawFunction struct {
94
+ Name string `json:"name,omitempty"`
95
+ Arguments string `json:"arguments,omitempty"`
96
+ }
97
+
98
+ type AccumToolCall struct {
99
+ Index int
100
+ ID string
101
+ Type string
102
+ Name string
103
+ Args string
104
+ }
105
+
106
  func resolveModel(requested string) string {
107
  if full, ok := modelAliases[requested]; ok {
108
  return full
 
170
  fmt.Fprintf(w, `{"url":"https://%s/v1"}`, host)
171
  }
172
 
173
+ func doUpstream(upstream UpstreamRequest) (*http.Response, error) {
174
+ body, err := json.Marshal(upstream)
175
+ if err != nil {
176
+ return nil, err
177
+ }
178
+ req, err := http.NewRequest(http.MethodPost, NvidiaBaseURL+"/chat/completions", bytes.NewReader(body))
179
+ if err != nil {
180
+ return nil, err
181
+ }
182
+ req.Header.Set("Content-Type", "application/json")
183
+ req.Header.Set("Authorization", "Bearer "+NvidiaAPIKey)
184
+ req.Header.Set("Accept", "text/event-stream")
185
+ client := &http.Client{Timeout: 300 * time.Second}
186
+ return client.Do(req)
187
+ }
188
+
189
+ // collectStream reads SSE lines, accumulates tool_calls, returns all chunks + assembled tool calls
190
+ func collectStream(body io.Reader) ([]RawChunk, map[int]*AccumToolCall, error) {
191
+ scanner := bufio.NewScanner(body)
192
+ scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
193
+
194
+ var chunks []RawChunk
195
+ accum := make(map[int]*AccumToolCall)
196
+
197
+ for scanner.Scan() {
198
+ line := scanner.Text()
199
+ if !strings.HasPrefix(line, "data: ") {
200
+ continue
201
+ }
202
+ data := strings.TrimPrefix(line, "data: ")
203
+ if data == "[DONE]" {
204
+ break
205
+ }
206
+ var chunk RawChunk
207
+ if err := json.Unmarshal([]byte(data), &chunk); err != nil {
208
+ continue
209
+ }
210
+ chunks = append(chunks, chunk)
211
+
212
+ for _, choice := range chunk.Choices {
213
+ for _, tc := range choice.Delta.ToolCalls {
214
+ acc, ok := accum[tc.Index]
215
+ if !ok {
216
+ acc = &AccumToolCall{Index: tc.Index}
217
+ accum[tc.Index] = acc
218
+ }
219
+ if tc.ID != "" {
220
+ acc.ID = tc.ID
221
+ }
222
+ if tc.Type != "" {
223
+ acc.Type = tc.Type
224
+ }
225
+ acc.Name += tc.Function.Name
226
+ acc.Args += tc.Function.Arguments
227
+ }
228
+ }
229
+ }
230
+ return chunks, accum, scanner.Err()
231
+ }
232
+
233
+ // hasToolCalls returns true if any chunk contains tool_calls
234
+ func hasToolCallsInChunks(chunks []RawChunk) bool {
235
+ for _, chunk := range chunks {
236
+ for _, choice := range chunk.Choices {
237
+ if len(choice.Delta.ToolCalls) > 0 {
238
+ return true
239
+ }
240
+ if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" {
241
+ return true
242
+ }
243
+ }
244
+ }
245
+ return false
246
+ }
247
+
248
+ func assembleToolCalls(accum map[int]*AccumToolCall) []map[string]interface{} {
249
+ indices := make([]int, 0, len(accum))
250
+ for idx := range accum {
251
+ indices = append(indices, idx)
252
+ }
253
+ sort.Ints(indices)
254
+
255
+ result := make([]map[string]interface{}, 0, len(indices))
256
+ for _, idx := range indices {
257
+ acc := accum[idx]
258
+ tcType := acc.Type
259
+ if tcType == "" {
260
+ tcType = "function"
261
+ }
262
+ result = append(result, map[string]interface{}{
263
+ "index": idx,
264
+ "id": acc.ID,
265
+ "type": tcType,
266
+ "function": map[string]string{
267
+ "name": acc.Name,
268
+ "arguments": acc.Args,
269
+ },
270
+ })
271
+ }
272
+ return result
273
+ }
274
+
275
  func handleChat(w http.ResponseWriter, r *http.Request) {
276
  if !authenticate(r) {
277
  http.Error(w, `{"error":{"message":"Unauthorized"}}`, http.StatusUnauthorized)
 
288
  }
289
 
290
  modelID := resolveModel(req.Model)
291
+ messages := injectSystemPrompt(req.Messages, modelID)
 
 
 
 
 
 
 
 
 
 
 
 
292
 
293
+ buildUpstream := func(msgs []Message) UpstreamRequest {
294
+ u := UpstreamRequest{
295
+ Model: modelID,
296
+ Messages: msgs,
297
+ Stream: true,
298
+ Tools: req.Tools,
299
+ ToolChoice: req.ToolChoice,
300
+ Temperature: req.Temperature,
301
+ MaxTokens: req.MaxTokens,
302
+ TopP: req.TopP,
303
+ Stop: req.Stop,
304
  }
305
+ if modelID == "z-ai/glm4.7" {
306
+ u.ExtraBody = map[string]interface{}{
307
+ "chat_template_kwargs": map[string]interface{}{
308
+ "enable_thinking": false,
309
+ },
310
+ }
311
+ }
312
+ return u
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
  }
314
 
315
  w.Header().Set("Content-Type", "text/event-stream")
 
319
  w.WriteHeader(http.StatusOK)
320
 
321
  flusher, canFlush := w.(http.Flusher)
322
+ emit := func(s string) {
323
+ fmt.Fprint(w, s)
 
 
 
 
 
 
 
 
324
  if canFlush {
325
  flusher.Flush()
326
  }
327
  }
328
+
329
+ // Agentic loop — handle tool_calls rounds
330
+ for {
331
+ resp, err := doUpstream(buildUpstream(messages))
332
+ if err != nil {
333
+ emit(fmt.Sprintf("data: {\"error\":{\"message\":\"%s\"}}\n\n", err.Error()))
334
+ return
335
+ }
336
+
337
+ if resp.StatusCode != http.StatusOK {
338
+ body, _ := io.ReadAll(resp.Body)
339
+ resp.Body.Close()
340
+ emit("data: " + string(body) + "\n\n")
341
+ return
342
+ }
343
+
344
+ chunks, accum, _ := collectStream(resp.Body)
345
+ resp.Body.Close()
346
+
347
+ if len(chunks) == 0 {
348
+ break
349
+ }
350
+
351
+ lastChunk := chunks[len(chunks)-1]
352
+
353
+ if hasToolCallsInChunks(chunks) && len(req.Tools) > 0 {
354
+ // Emit tool_calls chunk to client in OpenAI format
355
+ assembled := assembleToolCalls(accum)
356
+
357
+ fr := "tool_calls"
358
+ toolChunk := map[string]interface{}{
359
+ "id": lastChunk.ID,
360
+ "object": "chat.completion.chunk",
361
+ "created": lastChunk.Created,
362
+ "model": req.Model,
363
+ "choices": []map[string]interface{}{
364
+ {
365
+ "index": 0,
366
+ "delta": map[string]interface{}{
367
+ "role": "assistant",
368
+ "content": nil,
369
+ "tool_calls": assembled,
370
+ },
371
+ "finish_reason": fr,
372
+ },
373
+ },
374
+ }
375
+ out, _ := json.Marshal(toolChunk)
376
+ emit("data: " + string(out) + "\n\n")
377
+
378
+ // Add assistant tool_calls message to history
379
+ messages = append(messages, Message{
380
+ Role: "assistant",
381
+ Content: nil,
382
+ ToolCalls: assembled,
383
+ })
384
+
385
+ // Add placeholder tool results — client must re-call with results
386
+ // For now signal finish so client can handle tool execution
387
+ emit("data: [DONE]\n\n")
388
+ return
389
+ }
390
+
391
+ // No tool_calls — stream content chunks directly to client
392
+ for _, chunk := range chunks {
393
+ // Remap model alias
394
+ chunk.Model = req.Model
395
+ out, err := json.Marshal(chunk)
396
+ if err != nil {
397
+ continue
398
+ }
399
+ emit("data: " + string(out) + "\n\n")
400
+ }
401
+ break
402
+ }
403
+
404
+ emit("data: [DONE]\n\n")
405
  }
406
 
407
  func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc {