oki692 committed on
Commit
20127dd
·
verified ·
1 Parent(s): dfb9d66

Update Dockerfile

Browse files
Files changed (1) hide show
  1. Dockerfile +407 -13
Dockerfile CHANGED
@@ -1,13 +1,407 @@
1
- FROM golang:1.21-alpine AS builder
2
- WORKDIR /app
3
- COPY go.mod ./
4
- COPY *.go ./
5
- RUN go build -o gateway .
6
-
7
- FROM alpine:latest
8
- RUN apk --no-cache add ca-certificates
9
- WORKDIR /app
10
- COPY --from=builder /app/gateway .
11
-
12
- EXPOSE 7860
13
- CMD ["./gateway"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package main
2
+
3
+ import (
4
+ "bufio"
5
+ "bytes"
6
+ "encoding/json"
7
+ "fmt"
8
+ "io"
9
+ "log"
10
+ "net/http"
11
+ "os"
12
+ "sort"
13
+ "strings"
14
+ "time"
15
+ )
16
+
17
// NvidiaBaseURL is the root URL of the upstream NVIDIA chat-completions API.
const NvidiaBaseURL = "https://integrate.api.nvidia.com/v1"

// Credentials are read from the environment so that no secret is stored in
// the repository. A literal NVIDIA key used to be embedded at this spot; it
// must be treated as leaked and rotated.
var (
	// NvidiaAPIKey is sent to the upstream as the bearer token. It is taken
	// from NVIDIA_API_KEY and is empty when that variable is unset
	// (presumably causing the upstream to reject requests — set it in the
	// deployment's secret store).
	NvidiaAPIKey = os.Getenv("NVIDIA_API_KEY")

	// GatewayAPIKey is the key clients must present to this gateway. It is
	// taken from GATEWAY_API_KEY and falls back to the historical default
	// "connect" when that variable is unset or empty.
	GatewayAPIKey = envOrDefault("GATEWAY_API_KEY", "connect")
)

// envOrDefault returns the value of the environment variable name, or
// fallback when the variable is unset or empty.
func envOrDefault(name, fallback string) string {
	if v := os.Getenv(name); v != "" {
		return v
	}
	return fallback
}
22
+
23
// modelAliases maps the short model names exposed by the gateway to the
// full model identifiers sent to the upstream API.
var modelAliases = map[string]string{
	"Kimi-K2":         "moonshotai/kimi-k2-instruct",
	"DeepSeek-V3.1":   "deepseek-ai/deepseek-v3.1",
	"Mistral-Small-4": "mistralai/mistral-small-4-119b-2603",
	"GLM-4.7":         "z-ai/glm4.7",
	"Bielik-11b":      "speakleash/bielik-11b-v2.6-instruct",
}
30
+
31
// Message is one entry of the "messages" array of a chat completion
// request, as received from the client and as forwarded upstream.
type Message struct {
	Role string `json:"role"`

	// Content and ToolCalls are left untyped, so whatever JSON value the
	// client sent is decoded generically and re-encoded when forwarded.
	Content   interface{} `json:"content"`
	ToolCalls interface{} `json:"tool_calls,omitempty"`

	// ToolCallID and Name are omitted from the JSON when empty.
	ToolCallID string `json:"tool_call_id,omitempty"`
	Name       string `json:"name,omitempty"`
}
38
+
39
+ type ChatRequest struct {
40
+ Model string `json:"model"`
41
+ Messages []Message `json:"messages"`
42
+ Stream *bool `json:"stream,omitempty"`
43
+ Tools json.RawMessage `json:"tools,omitempty"`
44
+ ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
45
+ Temperature *float64 `json:"temperature,omitempty"`
46
+ MaxTokens *int `json:"max_tokens,omitempty"`
47
+ TopP *float64 `json:"top_p,omitempty"`
48
+ Stop interface{} `json:"stop,omitempty"`
49
+ }
50
+
51
+ type UpstreamRequest struct {
52
+ Model string `json:"model"`
53
+ Messages []Message `json:"messages"`
54
+ Stream bool `json:"stream"`
55
+ Tools json.RawMessage `json:"tools,omitempty"`
56
+ ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
57
+ Temperature *float64 `json:"temperature,omitempty"`
58
+ MaxTokens *int `json:"max_tokens,omitempty"`
59
+ TopP *float64 `json:"top_p,omitempty"`
60
+ Stop interface{} `json:"stop,omitempty"`
61
+ ExtraBody map[string]interface{} `json:"extra_body,omitempty"`
62
+ }
63
+
64
+ type RawChunk struct {
65
+ ID string `json:"id"`
66
+ Object string `json:"object"`
67
+ Created int64 `json:"created"`
68
+ Model string `json:"model"`
69
+ Choices []RawChoice `json:"choices"`
70
+ Usage interface{} `json:"usage,omitempty"`
71
+ }
72
+
73
+ type RawChoice struct {
74
+ Index int `json:"index"`
75
+ Delta RawDelta `json:"delta"`
76
+ FinishReason *string `json:"finish_reason"`
77
+ }
78
+
79
+ type RawDelta struct {
80
+ Role string `json:"role,omitempty"`
81
+ Content *string `json:"content,omitempty"`
82
+ ToolCalls []RawToolCall `json:"tool_calls,omitempty"`
83
+ }
84
+
85
+ type RawToolCall struct {
86
+ Index int `json:"index"`
87
+ ID string `json:"id,omitempty"`
88
+ Type string `json:"type,omitempty"`
89
+ Function RawFunction `json:"function"`
90
+ }
91
+
92
// RawFunction is the "function" part of a streamed tool-call fragment.
// Both fields are omitted from the JSON when empty.
type RawFunction struct {
	Name      string `json:"name,omitempty"`
	Arguments string `json:"arguments,omitempty"`
}
96
+
97
// AccumToolCall collects the pieces of one tool call while its fragments
// arrive across several streamed chunks.
type AccumToolCall struct {
	Index int    // position of the tool call, used as the accumulation key
	ID    string // tool call id
	Type  string // tool call type
	Name  string // function name, concatenated from fragments
	Args  string // function arguments, concatenated from fragments
}
104
+
105
+ func resolveModel(requested string) string {
106
+ if full, ok := modelAliases[requested]; ok {
107
+ return full
108
+ }
109
+ for _, full := range modelAliases {
110
+ if full == requested {
111
+ return requested
112
+ }
113
+ }
114
+ return requested
115
+ }
116
+
117
+ func injectSystemPrompt(messages []Message, modelID string) []Message {
118
+ filtered := make([]Message, 0, len(messages))
119
+ for _, m := range messages {
120
+ if m.Role != "system" {
121
+ filtered = append(filtered, m)
122
+ }
123
+ }
124
+ prompt, ok := systemPrompts[modelID]
125
+ if !ok || prompt == "" {
126
+ return filtered
127
+ }
128
+ return append([]Message{{Role: "system", Content: prompt}}, filtered...)
129
+ }
130
+
131
+ func authenticate(r *http.Request) bool {
132
+ auth := r.Header.Get("Authorization")
133
+ if len(auth) > 7 && auth[:7] == "Bearer " && auth[7:] == GatewayAPIKey {
134
+ return true
135
+ }
136
+ return r.Header.Get("x-api-key") == GatewayAPIKey
137
+ }
138
+
139
+ func handleModels(w http.ResponseWriter, r *http.Request) {
140
+ if !authenticate(r) {
141
+ http.Error(w, `{"error":{"message":"Unauthorized"}}`, http.StatusUnauthorized)
142
+ return
143
+ }
144
+ type ModelObj struct {
145
+ ID string `json:"id"`
146
+ Object string `json:"object"`
147
+ Created int64 `json:"created"`
148
+ OwnedBy string `json:"owned_by"`
149
+ }
150
+ type ModelsResponse struct {
151
+ Object string `json:"object"`
152
+ Data []ModelObj `json:"data"`
153
+ }
154
+ models := ModelsResponse{Object: "list"}
155
+ now := time.Now().Unix()
156
+ for alias := range modelAliases {
157
+ models.Data = append(models.Data, ModelObj{ID: alias, Object: "model", Created: now, OwnedBy: "nvidia"})
158
+ }
159
+ w.Header().Set("Content-Type", "application/json")
160
+ json.NewEncoder(w).Encode(models)
161
+ }
162
+
163
// handleBaseURL reports the public base URL of the gateway as
// {"url":"https://<host>/v1"}. The host is taken from the SPACE_HOST
// environment variable, falling back to the request's Host when unset.
//
// The body is produced by the JSON encoder rather than by string formatting,
// because the Host value is client-controlled and must be escaped.
func handleBaseURL(w http.ResponseWriter, r *http.Request) {
	host := os.Getenv("SPACE_HOST")
	if host == "" {
		host = r.Host
	}

	payload, err := json.Marshal(struct {
		URL string `json:"url"`
	}{URL: "https://" + host + "/v1"})
	if err != nil {
		http.Error(w, `{"error":{"message":"Failed to encode response"}}`, http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.Write(payload)
}
171
+
172
+ func handleChat(w http.ResponseWriter, r *http.Request) {
173
+ if !authenticate(r) {
174
+ http.Error(w, `{"error":{"message":"Unauthorized"}}`, http.StatusUnauthorized)
175
+ return
176
+ }
177
+ if r.Method != http.MethodPost {
178
+ http.Error(w, `{"error":{"message":"Method not allowed"}}`, http.StatusMethodNotAllowed)
179
+ return
180
+ }
181
+
182
+ var req ChatRequest
183
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
184
+ http.Error(w, `{"error":{"message":"Invalid request body"}}`, http.StatusBadRequest)
185
+ return
186
+ }
187
+
188
+ modelID := resolveModel(req.Model)
189
+ upstream := UpstreamRequest{
190
+ Model: modelID,
191
+ Messages: injectSystemPrompt(req.Messages, modelID),
192
+ Stream: true,
193
+ Tools: req.Tools,
194
+ ToolChoice: req.ToolChoice,
195
+ Temperature: req.Temperature,
196
+ MaxTokens: req.MaxTokens,
197
+ TopP: req.TopP,
198
+ Stop: req.Stop,
199
+ }
200
+ if modelID == "z-ai/glm4.7" {
201
+ upstream.ExtraBody = map[string]interface{}{
202
+ "chat_template_kwargs": map[string]interface{}{
203
+ "enable_thinking": false,
204
+ },
205
+ }
206
+ }
207
+
208
+ body, err := json.Marshal(upstream)
209
+ if err != nil {
210
+ http.Error(w, `{"error":{"message":"Failed to marshal request"}}`, http.StatusInternalServerError)
211
+ return
212
+ }
213
+
214
+ upstreamReq, err := http.NewRequest(http.MethodPost, NvidiaBaseURL+"/chat/completions", bytes.NewReader(body))
215
+ if err != nil {
216
+ http.Error(w, `{"error":{"message":"Failed to create upstream request"}}`, http.StatusInternalServerError)
217
+ return
218
+ }
219
+ upstreamReq.Header.Set("Content-Type", "application/json")
220
+ upstreamReq.Header.Set("Authorization", "Bearer "+NvidiaAPIKey)
221
+ upstreamReq.Header.Set("Accept", "text/event-stream")
222
+
223
+ client := &http.Client{Timeout: 300 * time.Second}
224
+ resp, err := client.Do(upstreamReq)
225
+ if err != nil {
226
+ http.Error(w, fmt.Sprintf(`{"error":{"message":"%s"}}`, err.Error()), http.StatusBadGateway)
227
+ return
228
+ }
229
+ defer resp.Body.Close()
230
+
231
+ if resp.StatusCode != http.StatusOK {
232
+ upstreamBody, _ := io.ReadAll(resp.Body)
233
+ w.Header().Set("Content-Type", "application/json")
234
+ w.WriteHeader(resp.StatusCode)
235
+ w.Write(upstreamBody)
236
+ return
237
+ }
238
+
239
+ w.Header().Set("Content-Type", "text/event-stream")
240
+ w.Header().Set("Cache-Control", "no-cache")
241
+ w.Header().Set("Connection", "keep-alive")
242
+ w.Header().Set("X-Accel-Buffering", "no")
243
+ w.WriteHeader(http.StatusOK)
244
+
245
+ flusher, canFlush := w.(http.Flusher)
246
+ emit := func(s string) {
247
+ fmt.Fprint(w, s)
248
+ if canFlush {
249
+ flusher.Flush()
250
+ }
251
+ }
252
+
253
+ scanner := bufio.NewScanner(resp.Body)
254
+ scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
255
+
256
+ // Accumulate tool_calls across delta chunks, stream content chunks immediately
257
+ accum := make(map[int]*AccumToolCall)
258
+ var lastChunk RawChunk
259
+
260
+ for scanner.Scan() {
261
+ line := scanner.Text()
262
+
263
+ if !strings.HasPrefix(line, "data: ") {
264
+ emit(line + "\n")
265
+ continue
266
+ }
267
+
268
+ data := strings.TrimPrefix(line, "data: ")
269
+
270
+ if data == "[DONE]" {
271
+ emit("data: [DONE]\n\n")
272
+ continue
273
+ }
274
+
275
+ var chunk RawChunk
276
+ if err := json.Unmarshal([]byte(data), &chunk); err != nil {
277
+ emit(line + "\n")
278
+ continue
279
+ }
280
+ lastChunk = chunk
281
+
282
+ isToolChunk := false
283
+ isFinishToolCalls := false
284
+
285
+ for _, choice := range chunk.Choices {
286
+ if len(choice.Delta.ToolCalls) > 0 {
287
+ isToolChunk = true
288
+ for _, tc := range choice.Delta.ToolCalls {
289
+ acc, ok := accum[tc.Index]
290
+ if !ok {
291
+ acc = &AccumToolCall{Index: tc.Index}
292
+ accum[tc.Index] = acc
293
+ }
294
+ if tc.ID != "" {
295
+ acc.ID = tc.ID
296
+ }
297
+ if tc.Type != "" {
298
+ acc.Type = tc.Type
299
+ }
300
+ acc.Name += tc.Function.Name
301
+ acc.Args += tc.Function.Arguments
302
+ }
303
+ }
304
+ if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" {
305
+ isFinishToolCalls = true
306
+ }
307
+ }
308
+
309
+ if isFinishToolCalls {
310
+ // Emit one complete tool_calls chunk with all assembled tool calls
311
+ indices := make([]int, 0, len(accum))
312
+ for idx := range accum {
313
+ indices = append(indices, idx)
314
+ }
315
+ sort.Ints(indices)
316
+
317
+ assembled := make([]map[string]interface{}, 0, len(indices))
318
+ for _, idx := range indices {
319
+ acc := accum[idx]
320
+ tcType := acc.Type
321
+ if tcType == "" {
322
+ tcType = "function"
323
+ }
324
+ assembled = append(assembled, map[string]interface{}{
325
+ "index": idx,
326
+ "id": acc.ID,
327
+ "type": tcType,
328
+ "function": map[string]string{
329
+ "name": acc.Name,
330
+ "arguments": acc.Args,
331
+ },
332
+ })
333
+ }
334
+
335
+ fr := "tool_calls"
336
+ out, _ := json.Marshal(map[string]interface{}{
337
+ "id": lastChunk.ID,
338
+ "object": "chat.completion.chunk",
339
+ "created": lastChunk.Created,
340
+ "model": req.Model,
341
+ "choices": []map[string]interface{}{
342
+ {
343
+ "index": 0,
344
+ "delta": map[string]interface{}{
345
+ "role": "assistant",
346
+ "content": nil,
347
+ "tool_calls": assembled,
348
+ },
349
+ "finish_reason": fr,
350
+ },
351
+ },
352
+ })
353
+ emit("data: " + string(out) + "\n\n")
354
+ accum = make(map[int]*AccumToolCall)
355
+ continue
356
+ }
357
+
358
+ // Skip intermediate tool_call delta chunks (already accumulating)
359
+ if isToolChunk {
360
+ continue
361
+ }
362
+
363
+ // Regular content chunk — stream immediately as-is
364
+ emit("data: " + data + "\n\n")
365
+ }
366
+ }
367
+
368
// loggingMiddleware wraps next so that each request is logged twice: once
// on arrival (method, path, remote address) and once after next returns
// (method, path, elapsed time).
func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc {
	logged := func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		log.Printf("[%s] %s %s", r.Method, r.URL.Path, r.RemoteAddr)

		next(w, r)

		elapsed := time.Since(began)
		log.Printf("[%s] %s done in %s", r.Method, r.URL.Path, elapsed)
	}
	return logged
}
376
+
377
// corsMiddleware adds permissive CORS headers to every response. OPTIONS
// (preflight) requests are answered directly with 204 No Content and never
// reach next; all other requests are handed on to next.
func corsMiddleware(next http.HandlerFunc) http.HandlerFunc {
	corsHeaders := [...][2]string{
		{"Access-Control-Allow-Origin", "*"},
		{"Access-Control-Allow-Methods", "GET, POST, OPTIONS"},
		{"Access-Control-Allow-Headers", "Content-Type, Authorization, x-api-key"},
	}
	return func(w http.ResponseWriter, r *http.Request) {
		header := w.Header()
		for _, kv := range corsHeaders {
			header.Set(kv[0], kv[1])
		}
		if r.Method != http.MethodOptions {
			next(w, r)
			return
		}
		w.WriteHeader(http.StatusNoContent)
	}
}
389
+
390
+ func main() {
391
+ port := os.Getenv("PORT")
392
+ if port == "" {
393
+ port = "7860"
394
+ }
395
+ mux := http.NewServeMux()
396
+ mux.HandleFunc("/v1/chat/completions", corsMiddleware(loggingMiddleware(handleChat)))
397
+ mux.HandleFunc("/v1/models", corsMiddleware(loggingMiddleware(handleModels)))
398
+ mux.HandleFunc("/v1/base-url", corsMiddleware(handleBaseURL))
399
+ mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
400
+ w.WriteHeader(http.StatusOK)
401
+ w.Write([]byte(`{"status":"ok"}`))
402
+ })
403
+ log.Printf("Gateway starting on :%s", port)
404
+ if err := http.ListenAndServe(":"+port, mux); err != nil {
405
+ log.Fatal(err)
406
+ }
407
+ }