oki692 commited on
Commit
56c3217
·
verified ·
1 Parent(s): bf9e75f

Update main.go

Browse files
Files changed (1) hide show
  1. main.go +257 -1
main.go CHANGED
@@ -9,7 +9,6 @@ import (
9
  "log"
10
  "net/http"
11
  "os"
12
- "strings"
13
  "time"
14
  )
15
 
@@ -19,6 +18,263 @@ const (
19
  GatewayAPIKey = "connect"
20
  )
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  var modelAliases = map[string]string{
23
  "Bielik-11b": "speakleash/bielik-11b-v2.6-instruct",
24
  "GLM-4.7": "z-ai/glm4.7",
 
9
  "log"
10
  "net/http"
11
  "os"
 
12
  "time"
13
  )
14
 
 
18
  GatewayAPIKey = "connect"
19
  )
20
 
21
+ var modelAliases = map[string]string{
22
+ "Bielik-11b": "speakleash/bielik-11b-v2.6-instruct",
23
+ "GLM-4.7": "z-ai/glm4.7",
24
+ "Mistral-Small-4": "mistralai/mistral-small-4-119b-2603",
25
+ "DeepSeek-V3.1": "deepseek-ai/deepseek-v3.1",
26
+ "Kimi-K2": "moonshotai/kimi-k2-instruct",
27
+ }
28
+
29
+
30
+ var thinkingModels = map[string]bool{
31
+ "z-ai/glm4.7": true,
32
+ }
33
+
34
+ type Message struct {
35
+ Role string `json:"role"`
36
+ Content interface{} `json:"content"`
37
+ ToolCallID string `json:"tool_call_id,omitempty"`
38
+ ToolCalls interface{} `json:"tool_calls,omitempty"`
39
+ Name string `json:"name,omitempty"`
40
+ }
41
+
42
+ type ChatRequest struct {
43
+ Model string `json:"model"`
44
+ Messages []Message `json:"messages"`
45
+ Stream *bool `json:"stream,omitempty"`
46
+ Tools []interface{} `json:"tools,omitempty"`
47
+ ToolChoice interface{} `json:"tool_choice,omitempty"`
48
+ Temperature *float64 `json:"temperature,omitempty"`
49
+ MaxTokens *int `json:"max_tokens,omitempty"`
50
+ TopP *float64 `json:"top_p,omitempty"`
51
+ Stop interface{} `json:"stop,omitempty"`
52
+ }
53
+
54
+ type UpstreamRequest struct {
55
+ Model string `json:"model"`
56
+ Messages []Message `json:"messages"`
57
+ Stream bool `json:"stream"`
58
+ Tools []interface{} `json:"tools,omitempty"`
59
+ ToolChoice interface{} `json:"tool_choice,omitempty"`
60
+ Temperature *float64 `json:"temperature,omitempty"`
61
+ MaxTokens *int `json:"max_tokens,omitempty"`
62
+ TopP *float64 `json:"top_p,omitempty"`
63
+ Stop interface{} `json:"stop,omitempty"`
64
+ ExtraBody *ExtraBody `json:"extra_body,omitempty"`
65
+ }
66
+
67
+ type ExtraBody struct {
68
+ ChatTemplateKwargs map[string]interface{} `json:"chat_template_kwargs,omitempty"`
69
+ }
70
+
71
+ func resolveModel(requested string) string {
72
+ if full, ok := modelAliases[requested]; ok {
73
+ return full
74
+ }
75
+ for _, full := range modelAliases {
76
+ if full == requested {
77
+ return requested
78
+ }
79
+ }
80
+ return requested
81
+ }
82
+
83
+ func injectSystemPrompt(messages []Message, modelID string) []Message {
84
+ // Strip all system messages from client — gateway prompt is the only one
85
+ filtered := make([]Message, 0, len(messages))
86
+ for _, m := range messages {
87
+ if m.Role != "system" {
88
+ filtered = append(filtered, m)
89
+ }
90
+ }
91
+
92
+ prompt, ok := systemPrompts[modelID]
93
+ if !ok || prompt == "" {
94
+ return filtered
95
+ }
96
+
97
+ return append([]Message{{Role: "system", Content: prompt}}, filtered...)
98
+ }
99
+
100
+ func authenticate(r *http.Request) bool {
101
+ auth := r.Header.Get("Authorization")
102
+ if len(auth) > 7 && auth[:7] == "Bearer " {
103
+ if auth[7:] == GatewayAPIKey {
104
+ return true
105
+ }
106
+ }
107
+ if r.Header.Get("x-api-key") == GatewayAPIKey {
108
+ return true
109
+ }
110
+ return false
111
+ }
112
+
113
+ func handleModels(w http.ResponseWriter, r *http.Request) {
114
+ if !authenticate(r) {
115
+ http.Error(w, `{"error":{"message":"Unauthorized","type":"auth_error"}}`, http.StatusUnauthorized)
116
+ return
117
+ }
118
+
119
+ type ModelObj struct {
120
+ ID string `json:"id"`
121
+ Object string `json:"object"`
122
+ Created int64 `json:"created"`
123
+ OwnedBy string `json:"owned_by"`
124
+ }
125
+ type ModelsResponse struct {
126
+ Object string `json:"object"`
127
+ Data []ModelObj `json:"data"`
128
+ }
129
+
130
+ models := ModelsResponse{Object: "list"}
131
+ now := time.Now().Unix()
132
+ for alias := range modelAliases {
133
+ models.Data = append(models.Data, ModelObj{
134
+ ID: alias,
135
+ Object: "model",
136
+ Created: now,
137
+ OwnedBy: "nvidia",
138
+ })
139
+ }
140
+
141
+ w.Header().Set("Content-Type", "application/json")
142
+ json.NewEncoder(w).Encode(models)
143
+ }
144
+
145
+ func handleChat(w http.ResponseWriter, r *http.Request) {
146
+ if !authenticate(r) {
147
+ http.Error(w, `{"error":{"message":"Unauthorized","type":"auth_error"}}`, http.StatusUnauthorized)
148
+ return
149
+ }
150
+
151
+ if r.Method != http.MethodPost {
152
+ http.Error(w, `{"error":{"message":"Method not allowed"}}`, http.StatusMethodNotAllowed)
153
+ return
154
+ }
155
+
156
+ var req ChatRequest
157
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
158
+ http.Error(w, `{"error":{"message":"Invalid request body"}}`, http.StatusBadRequest)
159
+ return
160
+ }
161
+
162
+ modelID := resolveModel(req.Model)
163
+ req.Messages = injectSystemPrompt(req.Messages, modelID)
164
+
165
+ upstream := UpstreamRequest{
166
+ Model: modelID,
167
+ Messages: req.Messages,
168
+ Stream: true,
169
+ Tools: req.Tools,
170
+ ToolChoice: req.ToolChoice,
171
+ Temperature: req.Temperature,
172
+ MaxTokens: req.MaxTokens,
173
+ TopP: req.TopP,
174
+ Stop: req.Stop,
175
+ }
176
+
177
+ if thinkingModels[modelID] {
178
+ upstream.ExtraBody = &ExtraBody{
179
+ ChatTemplateKwargs: map[string]interface{}{
180
+ "enable_thinking": false,
181
+ "clear_thinking": true,
182
+ },
183
+ }
184
+ }
185
+
186
+ body, err := json.Marshal(upstream)
187
+ if err != nil {
188
+ http.Error(w, `{"error":{"message":"Failed to marshal request"}}`, http.StatusInternalServerError)
189
+ return
190
+ }
191
+
192
+ upstreamReq, err := http.NewRequest(http.MethodPost, NvidiaBaseURL+"/chat/completions", bytes.NewReader(body))
193
+ if err != nil {
194
+ http.Error(w, `{"error":{"message":"Failed to create upstream request"}}`, http.StatusInternalServerError)
195
+ return
196
+ }
197
+ upstreamReq.Header.Set("Content-Type", "application/json")
198
+ upstreamReq.Header.Set("Authorization", "Bearer "+NvidiaAPIKey)
199
+ upstreamReq.Header.Set("Accept", "text/event-stream")
200
+
201
+ client := &http.Client{Timeout: 300 * time.Second}
202
+ resp, err := client.Do(upstreamReq)
203
+ if err != nil {
204
+ http.Error(w, fmt.Sprintf(`{"error":{"message":"Upstream error: %s"}}`, err.Error()), http.StatusBadGateway)
205
+ return
206
+ }
207
+ defer resp.Body.Close()
208
+
209
+ if resp.StatusCode != http.StatusOK {
210
+ upstreamBody, _ := io.ReadAll(resp.Body)
211
+ w.Header().Set("Content-Type", "application/json")
212
+ w.WriteHeader(resp.StatusCode)
213
+ w.Write(upstreamBody)
214
+ return
215
+ }
216
+
217
+ w.Header().Set("Content-Type", "text/event-stream")
218
+ w.Header().Set("Cache-Control", "no-cache")
219
+ w.Header().Set("Connection", "keep-alive")
220
+ w.Header().Set("X-Accel-Buffering", "no")
221
+ w.WriteHeader(http.StatusOK)
222
+
223
+ flusher, canFlush := w.(http.Flusher)
224
+ scanner := bufio.NewScanner(resp.Body)
225
+ scanner.Buffer(make([]byte, 1024*1024), 1024*1024)
226
+
227
+ for scanner.Scan() {
228
+ fmt.Fprintf(w, "%s\n", scanner.Text())
229
+ if canFlush {
230
+ flusher.Flush()
231
+ }
232
+ }
233
+ }
234
+
235
+ func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc {
236
+ return func(w http.ResponseWriter, r *http.Request) {
237
+ start := time.Now()
238
+ log.Printf("[%s] %s %s", r.Method, r.URL.Path, r.RemoteAddr)
239
+ next(w, r)
240
+ log.Printf("[%s] %s done in %s", r.Method, r.URL.Path, time.Since(start))
241
+ }
242
+ }
243
+
244
+ func corsMiddleware(next http.HandlerFunc) http.HandlerFunc {
245
+ return func(w http.ResponseWriter, r *http.Request) {
246
+ w.Header().Set("Access-Control-Allow-Origin", "*")
247
+ w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
248
+ w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, x-api-key")
249
+ if r.Method == http.MethodOptions {
250
+ w.WriteHeader(http.StatusNoContent)
251
+ return
252
+ }
253
+ next(w, r)
254
+ }
255
+ }
256
+
257
+ func main() {
258
+ port := os.Getenv("PORT")
259
+ if port == "" {
260
+ port = "7860"
261
+ }
262
+
263
+ mux := http.NewServeMux()
264
+ mux.HandleFunc("/v1/chat/completions", corsMiddleware(loggingMiddleware(handleChat)))
265
+ mux.HandleFunc("/v1/models", corsMiddleware(loggingMiddleware(handleModels)))
266
+ mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
267
+ w.WriteHeader(http.StatusOK)
268
+ w.Write([]byte(`{"status":"ok"}`))
269
+ })
270
+
271
+ log.Printf("Gateway starting on :%s", port)
272
+ if err := http.ListenAndServe(":"+port, mux); err != nil {
273
+ log.Fatal(err)
274
+ }
275
+ }
276
+
277
+
278
  var modelAliases = map[string]string{
279
  "Bielik-11b": "speakleash/bielik-11b-v2.6-instruct",
280
  "GLM-4.7": "z-ai/glm4.7",