oki692 committed on
Commit
2c19ea4
·
verified ·
1 Parent(s): 4f73969

Update main.go

Browse files
Files changed (1) hide show
  1. main.go +44 -181
main.go CHANGED
@@ -9,7 +9,6 @@ import (
9
  "log"
10
  "net/http"
11
  "os"
12
- "strings"
13
  "time"
14
  )
15
 
@@ -48,15 +47,22 @@ type ChatRequest struct {
48
  }
49
 
50
  type UpstreamRequest struct {
51
- Model string `json:"model"`
52
- Messages []Message `json:"messages"`
53
- Stream bool `json:"stream"`
54
- Tools []interface{} `json:"tools,omitempty"`
55
- ToolChoice interface{} `json:"tool_choice,omitempty"`
56
- Temperature *float64 `json:"temperature,omitempty"`
57
- MaxTokens *int `json:"max_tokens,omitempty"`
58
- TopP *float64 `json:"top_p,omitempty"`
59
- Stop interface{} `json:"stop,omitempty"`
 
 
 
 
 
 
 
60
  }
61
 
62
  func resolveModel(requested string) string {
@@ -87,20 +93,15 @@ func injectSystemPrompt(messages []Message, modelID string) []Message {
87
 
88
  func authenticate(r *http.Request) bool {
89
  auth := r.Header.Get("Authorization")
90
- if len(auth) > 7 && auth[:7] == "Bearer " {
91
- if auth[7:] == GatewayAPIKey {
92
- return true
93
- }
94
- }
95
- if r.Header.Get("x-api-key") == GatewayAPIKey {
96
  return true
97
  }
98
- return false
99
  }
100
 
101
  func handleModels(w http.ResponseWriter, r *http.Request) {
102
  if !authenticate(r) {
103
- http.Error(w, `{"error":{"message":"Unauthorized","type":"auth_error"}}`, http.StatusUnauthorized)
104
  return
105
  }
106
  type ModelObj struct {
@@ -116,61 +117,24 @@ func handleModels(w http.ResponseWriter, r *http.Request) {
116
  models := ModelsResponse{Object: "list"}
117
  now := time.Now().Unix()
118
  for alias := range modelAliases {
119
- models.Data = append(models.Data, ModelObj{
120
- ID: alias, Object: "model", Created: now, OwnedBy: "nvidia",
121
- })
122
  }
123
  w.Header().Set("Content-Type", "application/json")
124
  json.NewEncoder(w).Encode(models)
125
  }
126
 
127
- // StreamChoice represents a single choice in a streaming chunk
128
- type StreamChoice struct {
129
- Index int `json:"index"`
130
- Delta StreamDelta `json:"delta"`
131
- FinishReason *string `json:"finish_reason"`
132
- }
133
-
134
- // StreamDelta is the delta object inside a streaming chunk
135
- type StreamDelta struct {
136
- Role string `json:"role,omitempty"`
137
- Content *string `json:"content,omitempty"`
138
- ToolCalls []ToolCallChunk `json:"tool_calls,omitempty"`
139
- }
140
-
141
- // ToolCallChunk is a partial tool call in a streaming delta
142
- type ToolCallChunk struct {
143
- Index int `json:"index"`
144
- ID string `json:"id,omitempty"`
145
- Type string `json:"type,omitempty"`
146
- Function ToolCallFunction `json:"function,omitempty"`
147
- }
148
-
149
- type ToolCallFunction struct {
150
- Name string `json:"name,omitempty"`
151
- Arguments string `json:"arguments,omitempty"`
152
- }
153
-
154
- // StreamChunk is a full SSE data chunk from upstream
155
- type StreamChunk struct {
156
- ID string `json:"id"`
157
- Object string `json:"object"`
158
- Created int64 `json:"created"`
159
- Model string `json:"model"`
160
- Choices []StreamChoice `json:"choices"`
161
- }
162
-
163
- // AccumulatedToolCall holds a tool call being assembled across chunks
164
- type AccumulatedToolCall struct {
165
- ID string
166
- Type string
167
- Name string
168
- Args string
169
  }
170
 
171
  func handleChat(w http.ResponseWriter, r *http.Request) {
172
  if !authenticate(r) {
173
- http.Error(w, `{"error":{"message":"Unauthorized","type":"auth_error"}}`, http.StatusUnauthorized)
174
  return
175
  }
176
  if r.Method != http.MethodPost {
@@ -197,11 +161,22 @@ func handleChat(w http.ResponseWriter, r *http.Request) {
197
  TopP: req.TopP,
198
  Stop: req.Stop,
199
  }
 
 
 
 
 
 
 
 
 
 
200
  body, err := json.Marshal(upstream)
201
  if err != nil {
202
  http.Error(w, `{"error":{"message":"Failed to marshal request"}}`, http.StatusInternalServerError)
203
  return
204
  }
 
205
  upstreamReq, err := http.NewRequest(http.MethodPost, NvidiaBaseURL+"/chat/completions", bytes.NewReader(body))
206
  if err != nil {
207
  http.Error(w, `{"error":{"message":"Failed to create upstream request"}}`, http.StatusInternalServerError)
@@ -237,129 +212,17 @@ func handleChat(w http.ResponseWriter, r *http.Request) {
237
  scanner := bufio.NewScanner(resp.Body)
238
  scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
239
 
240
- // Accumulate tool calls across chunks
241
- toolCalls := make(map[int]*AccumulatedToolCall)
242
-
243
  for scanner.Scan() {
244
  line := scanner.Text()
245
-
246
- // Pass through non-data lines (empty lines, comments)
247
- if !strings.HasPrefix(line, "data: ") {
248
  fmt.Fprintf(w, "%s\n", line)
249
- if canFlush {
250
- flusher.Flush()
251
- }
252
- continue
253
  }
254
-
255
- data := strings.TrimPrefix(line, "data: ")
256
-
257
- // Pass through [DONE]
258
- if data == "[DONE]" {
259
- fmt.Fprintf(w, "data: [DONE]\n\n")
260
- if canFlush {
261
- flusher.Flush()
262
- }
263
- continue
264
  }
265
-
266
- var chunk StreamChunk
267
- if err := json.Unmarshal([]byte(data), &chunk); err != nil {
268
- // Can't parse — forward as-is
269
- fmt.Fprintf(w, "%s\n", line)
270
- if canFlush {
271
- flusher.Flush()
272
- }
273
- continue
274
- }
275
-
276
- // Process tool_calls deltas — accumulate per index
277
- for _, choice := range chunk.Choices {
278
- for _, tc := range choice.Delta.ToolCalls {
279
- acc, exists := toolCalls[tc.Index]
280
- if !exists {
281
- acc = &AccumulatedToolCall{}
282
- toolCalls[tc.Index] = acc
283
- }
284
- if tc.ID != "" {
285
- acc.ID = tc.ID
286
- }
287
- if tc.Type != "" {
288
- acc.Type = tc.Type
289
- }
290
- if tc.Function.Name != "" {
291
- acc.Name += tc.Function.Name
292
- }
293
- acc.Args += tc.Function.Arguments
294
- }
295
-
296
- // On finish_reason=tool_calls emit a synthetic complete chunk
297
- if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" {
298
- assembled := make([]map[string]interface{}, 0, len(toolCalls))
299
- for idx, acc := range toolCalls {
300
- assembled = append(assembled, map[string]interface{}{
301
- "index": idx,
302
- "id": acc.ID,
303
- "type": "function",
304
- "function": map[string]string{
305
- "name": acc.Name,
306
- "arguments": acc.Args,
307
- },
308
- })
309
- }
310
- finishReason := "tool_calls"
311
- synthetic := map[string]interface{}{
312
- "id": chunk.ID,
313
- "object": chunk.Object,
314
- "created": chunk.Created,
315
- "model": chunk.Model,
316
- "choices": []map[string]interface{}{
317
- {
318
- "index": choice.Index,
319
- "delta": map[string]interface{}{
320
- "role": "assistant",
321
- "tool_calls": assembled,
322
- },
323
- "finish_reason": finishReason,
324
- },
325
- },
326
- }
327
- out, _ := json.Marshal(synthetic)
328
- fmt.Fprintf(w, "data: %s\n\n", out)
329
- if canFlush {
330
- flusher.Flush()
331
- }
332
- // Reset accumulator for next potential call
333
- toolCalls = make(map[int]*AccumulatedToolCall)
334
- continue
335
- }
336
- }
337
-
338
- // For regular content chunks — forward as-is
339
- hasContent := false
340
- for _, choice := range chunk.Choices {
341
- if choice.Delta.Content != nil || (choice.FinishReason != nil && *choice.FinishReason != "tool_calls") {
342
- hasContent = true
343
- break
344
- }
345
- }
346
- if hasContent {
347
- fmt.Fprintf(w, "data: %s\n\n", data)
348
- if canFlush {
349
- flusher.Flush()
350
- }
351
- }
352
- }
353
- }
354
-
355
-
356
- func handleBaseURL(w http.ResponseWriter, r *http.Request) {
357
- host := os.Getenv("SPACE_HOST")
358
- if host == "" {
359
- host = r.Host
360
  }
361
- w.Header().Set("Content-Type", "application/json")
362
- fmt.Fprintf(w, `{"url":"https://%s/v1"}`, host)
363
  }
364
 
365
  func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc {
 
9
  "log"
10
  "net/http"
11
  "os"
 
12
  "time"
13
  )
14
 
 
47
  }
48
 
49
  type UpstreamRequest struct {
50
+ Model string `json:"model"`
51
+ Messages []Message `json:"messages"`
52
+ Stream bool `json:"stream"`
53
+ Tools []interface{} `json:"tools,omitempty"`
54
+ ToolChoice interface{} `json:"tool_choice,omitempty"`
55
+ Temperature *float64 `json:"temperature,omitempty"`
56
+ MaxTokens *int `json:"max_tokens,omitempty"`
57
+ TopP *float64 `json:"top_p,omitempty"`
58
+ Stop interface{} `json:"stop,omitempty"`
59
+ ExtraBody map[string]interface{} `json:"extra_body,omitempty"`
60
+ }
61
+
62
+ type StreamChoice struct {
63
+ Index int `json:"index"`
64
+ Delta StreamDelta `json:"delta"`
65
+ FinishReason *string `json:"finish_reason"`
66
  }
67
 
68
  func resolveModel(requested string) string {
 
93
 
94
  func authenticate(r *http.Request) bool {
95
  auth := r.Header.Get("Authorization")
96
+ if len(auth) > 7 && auth[:7] == "Bearer " && auth[7:] == GatewayAPIKey {
 
 
 
 
 
97
  return true
98
  }
99
+ return r.Header.Get("x-api-key") == GatewayAPIKey
100
  }
101
 
102
  func handleModels(w http.ResponseWriter, r *http.Request) {
103
  if !authenticate(r) {
104
+ http.Error(w, `{"error":{"message":"Unauthorized"}}`, http.StatusUnauthorized)
105
  return
106
  }
107
  type ModelObj struct {
 
117
  models := ModelsResponse{Object: "list"}
118
  now := time.Now().Unix()
119
  for alias := range modelAliases {
120
+ models.Data = append(models.Data, ModelObj{ID: alias, Object: "model", Created: now, OwnedBy: "nvidia"})
 
 
121
  }
122
  w.Header().Set("Content-Type", "application/json")
123
  json.NewEncoder(w).Encode(models)
124
  }
125
 
126
+ func handleBaseURL(w http.ResponseWriter, r *http.Request) {
127
+ host := os.Getenv("SPACE_HOST")
128
+ if host == "" {
129
+ host = r.Host
130
+ }
131
+ w.Header().Set("Content-Type", "application/json")
132
+ fmt.Fprintf(w, `{"url":"https://%s/v1"}`, host)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  }
134
 
135
  func handleChat(w http.ResponseWriter, r *http.Request) {
136
  if !authenticate(r) {
137
+ http.Error(w, `{"error":{"message":"Unauthorized"}}`, http.StatusUnauthorized)
138
  return
139
  }
140
  if r.Method != http.MethodPost {
 
161
  TopP: req.TopP,
162
  Stop: req.Stop,
163
  }
164
+
165
+ // GLM-4.7 requires thinking disabled via extra_body
166
+ if modelID == "z-ai/glm4.7" {
167
+ upstream.ExtraBody = map[string]interface{}{
168
+ "chat_template_kwargs": map[string]interface{}{
169
+ "enable_thinking": false,
170
+ },
171
+ }
172
+ }
173
+
174
  body, err := json.Marshal(upstream)
175
  if err != nil {
176
  http.Error(w, `{"error":{"message":"Failed to marshal request"}}`, http.StatusInternalServerError)
177
  return
178
  }
179
+
180
  upstreamReq, err := http.NewRequest(http.MethodPost, NvidiaBaseURL+"/chat/completions", bytes.NewReader(body))
181
  if err != nil {
182
  http.Error(w, `{"error":{"message":"Failed to create upstream request"}}`, http.StatusInternalServerError)
 
212
  scanner := bufio.NewScanner(resp.Body)
213
  scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
214
 
 
 
 
215
  for scanner.Scan() {
216
  line := scanner.Text()
217
+ if line != "" {
 
 
218
  fmt.Fprintf(w, "%s\n", line)
219
+ } else {
220
+ fmt.Fprintf(w, "\n")
 
 
221
  }
222
+ if canFlush {
223
+ flusher.Flush()
 
 
 
 
 
 
 
 
224
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  }
 
 
226
  }
227
 
228
  func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc {