| package main |
|
|
import (
	"bufio"
	"bytes"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"sort"
	"strings"
	"time"
)
|
|
| const ( |
| NvidiaBaseURL = "https://integrate.api.nvidia.com/v1" |
| NvidiaAPIKey = "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" |
| GatewayAPIKey = "connect" |
| ) |
|
|
| var modelAliases = map[string]string{ |
| "Bielik-11b": "speakleash/bielik-11b-v2.6-instruct", |
| "GLM-4.7": "z-ai/glm4.7", |
| "Mistral-Small-4": "mistralai/mistral-small-4-119b-2603", |
| "DeepSeek-V3.1": "deepseek-ai/deepseek-v3.1", |
| "Kimi-K2": "moonshotai/kimi-k2-instruct", |
| } |
|
|
| type Message struct { |
| Role string `json:"role"` |
| Content interface{} `json:"content"` |
| ToolCallID string `json:"tool_call_id,omitempty"` |
| ToolCalls interface{} `json:"tool_calls,omitempty"` |
| Name string `json:"name,omitempty"` |
| } |
|
|
| type ChatRequest struct { |
| Model string `json:"model"` |
| Messages []Message `json:"messages"` |
| Stream *bool `json:"stream,omitempty"` |
| Tools []interface{} `json:"tools,omitempty"` |
| ToolChoice interface{} `json:"tool_choice,omitempty"` |
| Temperature *float64 `json:"temperature,omitempty"` |
| MaxTokens *int `json:"max_tokens,omitempty"` |
| TopP *float64 `json:"top_p,omitempty"` |
| Stop interface{} `json:"stop,omitempty"` |
| } |
|
|
| type UpstreamRequest struct { |
| Model string `json:"model"` |
| Messages []Message `json:"messages"` |
| Stream bool `json:"stream"` |
| Tools []interface{} `json:"tools,omitempty"` |
| ToolChoice interface{} `json:"tool_choice,omitempty"` |
| Temperature *float64 `json:"temperature,omitempty"` |
| MaxTokens *int `json:"max_tokens,omitempty"` |
| TopP *float64 `json:"top_p,omitempty"` |
| Stop interface{} `json:"stop,omitempty"` |
| ExtraBody map[string]interface{} `json:"extra_body,omitempty"` |
| } |
|
|
| type StreamChoice struct { |
| Index int `json:"index"` |
| Delta StreamDelta `json:"delta"` |
| FinishReason *string `json:"finish_reason"` |
| } |
|
|
| type StreamDelta struct { |
| Role string `json:"role,omitempty"` |
| Content *string `json:"content,omitempty"` |
| ToolCalls []ToolCallChunk `json:"tool_calls,omitempty"` |
| } |
|
|
| type ToolCallChunk struct { |
| Index int `json:"index"` |
| ID string `json:"id,omitempty"` |
| Type string `json:"type,omitempty"` |
| Function ToolCallFunction `json:"function,omitempty"` |
| } |
|
|
| type ToolCallFunction struct { |
| Name string `json:"name,omitempty"` |
| Arguments string `json:"arguments,omitempty"` |
| } |
|
|
| type StreamChunk struct { |
| ID string `json:"id"` |
| Object string `json:"object"` |
| Created int64 `json:"created"` |
| Model string `json:"model"` |
| Choices []StreamChoice `json:"choices"` |
| } |
|
|
| type AccumulatedToolCall struct { |
| ID string |
| Type string |
| Name string |
| Args string |
| } |
|
|
| func resolveModel(requested string) string { |
| if full, ok := modelAliases[requested]; ok { |
| return full |
| } |
| for _, full := range modelAliases { |
| if full == requested { |
| return requested |
| } |
| } |
| return requested |
| } |
|
|
| func injectSystemPrompt(messages []Message, modelID string) []Message { |
| filtered := make([]Message, 0, len(messages)) |
| for _, m := range messages { |
| if m.Role != "system" { |
| filtered = append(filtered, m) |
| } |
| } |
| prompt, ok := systemPrompts[modelID] |
| if !ok || prompt == "" { |
| return filtered |
| } |
| return append([]Message{{Role: "system", Content: prompt}}, filtered...) |
| } |
|
|
| func authenticate(r *http.Request) bool { |
| auth := r.Header.Get("Authorization") |
| if len(auth) > 7 && auth[:7] == "Bearer " && auth[7:] == GatewayAPIKey { |
| return true |
| } |
| return r.Header.Get("x-api-key") == GatewayAPIKey |
| } |
|
|
| func handleModels(w http.ResponseWriter, r *http.Request) { |
| if !authenticate(r) { |
| http.Error(w, `{"error":{"message":"Unauthorized"}}`, http.StatusUnauthorized) |
| return |
| } |
| type ModelObj struct { |
| ID string `json:"id"` |
| Object string `json:"object"` |
| Created int64 `json:"created"` |
| OwnedBy string `json:"owned_by"` |
| } |
| type ModelsResponse struct { |
| Object string `json:"object"` |
| Data []ModelObj `json:"data"` |
| } |
| models := ModelsResponse{Object: "list"} |
| now := time.Now().Unix() |
| for alias := range modelAliases { |
| models.Data = append(models.Data, ModelObj{ID: alias, Object: "model", Created: now, OwnedBy: "nvidia"}) |
| } |
| w.Header().Set("Content-Type", "application/json") |
| json.NewEncoder(w).Encode(models) |
| } |
|
|
| func handleBaseURL(w http.ResponseWriter, r *http.Request) { |
| host := os.Getenv("SPACE_HOST") |
| if host == "" { |
| host = r.Host |
| } |
| w.Header().Set("Content-Type", "application/json") |
| fmt.Fprintf(w, `{"url":"https://%s/v1"}`, host) |
| } |
|
|
| func handleChat(w http.ResponseWriter, r *http.Request) { |
| if !authenticate(r) { |
| http.Error(w, `{"error":{"message":"Unauthorized"}}`, http.StatusUnauthorized) |
| return |
| } |
| if r.Method != http.MethodPost { |
| http.Error(w, `{"error":{"message":"Method not allowed"}}`, http.StatusMethodNotAllowed) |
| return |
| } |
| var req ChatRequest |
| if err := json.NewDecoder(r.Body).Decode(&req); err != nil { |
| http.Error(w, `{"error":{"message":"Invalid request body"}}`, http.StatusBadRequest) |
| return |
| } |
|
|
| modelID := resolveModel(req.Model) |
| req.Messages = injectSystemPrompt(req.Messages, modelID) |
|
|
| upstream := UpstreamRequest{ |
| Model: modelID, |
| Messages: req.Messages, |
| Stream: true, |
| Tools: req.Tools, |
| ToolChoice: req.ToolChoice, |
| Temperature: req.Temperature, |
| MaxTokens: req.MaxTokens, |
| TopP: req.TopP, |
| Stop: req.Stop, |
| } |
|
|
| |
| if modelID == "z-ai/glm4.7" { |
| upstream.ExtraBody = map[string]interface{}{ |
| "chat_template_kwargs": map[string]interface{}{ |
| "enable_thinking": false, |
| }, |
| } |
| } |
|
|
| body, err := json.Marshal(upstream) |
| if err != nil { |
| http.Error(w, `{"error":{"message":"Failed to marshal request"}}`, http.StatusInternalServerError) |
| return |
| } |
|
|
| upstreamReq, err := http.NewRequest(http.MethodPost, NvidiaBaseURL+"/chat/completions", bytes.NewReader(body)) |
| if err != nil { |
| http.Error(w, `{"error":{"message":"Failed to create upstream request"}}`, http.StatusInternalServerError) |
| return |
| } |
| upstreamReq.Header.Set("Content-Type", "application/json") |
| upstreamReq.Header.Set("Authorization", "Bearer "+NvidiaAPIKey) |
| upstreamReq.Header.Set("Accept", "text/event-stream") |
|
|
| client := &http.Client{Timeout: 300 * time.Second} |
| resp, err := client.Do(upstreamReq) |
| if err != nil { |
| http.Error(w, fmt.Sprintf(`{"error":{"message":"Upstream error: %s"}}`, err.Error()), http.StatusBadGateway) |
| return |
| } |
| defer resp.Body.Close() |
|
|
| if resp.StatusCode != http.StatusOK { |
| upstreamBody, _ := io.ReadAll(resp.Body) |
| w.Header().Set("Content-Type", "application/json") |
| w.WriteHeader(resp.StatusCode) |
| w.Write(upstreamBody) |
| return |
| } |
|
|
| w.Header().Set("Content-Type", "text/event-stream") |
| w.Header().Set("Cache-Control", "no-cache") |
| w.Header().Set("Connection", "keep-alive") |
| w.Header().Set("X-Accel-Buffering", "no") |
| w.WriteHeader(http.StatusOK) |
|
|
| flusher, canFlush := w.(http.Flusher) |
| scanner := bufio.NewScanner(resp.Body) |
| scanner.Buffer(make([]byte, 1024*1024), 1024*1024) |
|
|
| |
| accumulated := make(map[int]*AccumulatedToolCall) |
|
|
| flush := func(s string) { |
| fmt.Fprint(w, s) |
| if canFlush { |
| flusher.Flush() |
| } |
| } |
|
|
| for scanner.Scan() { |
| line := scanner.Text() |
|
|
| if !strings.HasPrefix(line, "data: ") { |
| flush(line + "\n") |
| continue |
| } |
|
|
| data := strings.TrimPrefix(line, "data: ") |
|
|
| if data == "[DONE]" { |
| flush("data: [DONE]\n\n") |
| continue |
| } |
|
|
| var chunk StreamChunk |
| if err := json.Unmarshal([]byte(data), &chunk); err != nil { |
| flush(line + "\n") |
| continue |
| } |
|
|
| hasToolCalls := false |
| for _, choice := range chunk.Choices { |
| if len(choice.Delta.ToolCalls) > 0 { |
| hasToolCalls = true |
| for _, tc := range choice.Delta.ToolCalls { |
| acc, ok := accumulated[tc.Index] |
| if !ok { |
| acc = &AccumulatedToolCall{} |
| accumulated[tc.Index] = acc |
| } |
| if tc.ID != "" { |
| acc.ID = tc.ID |
| } |
| if tc.Type != "" { |
| acc.Type = tc.Type |
| } |
| if tc.Function.Name != "" { |
| acc.Name += tc.Function.Name |
| } |
| acc.Args += tc.Function.Arguments |
| } |
| } |
|
|
| |
| if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" { |
| |
| indices := make([]int, 0, len(accumulated)) |
| for idx := range accumulated { |
| indices = append(indices, idx) |
| } |
| sort.Ints(indices) |
|
|
| assembled := make([]map[string]interface{}, 0, len(indices)) |
| for _, idx := range indices { |
| acc := accumulated[idx] |
| assembled = append(assembled, map[string]interface{}{ |
| "index": idx, |
| "id": acc.ID, |
| "type": "function", |
| "function": map[string]string{ |
| "name": acc.Name, |
| "arguments": acc.Args, |
| }, |
| }) |
| } |
|
|
| fr := "tool_calls" |
| synthetic := map[string]interface{}{ |
| "id": chunk.ID, |
| "object": chunk.Object, |
| "created": chunk.Created, |
| "model": chunk.Model, |
| "choices": []map[string]interface{}{ |
| { |
| "index": choice.Index, |
| "delta": map[string]interface{}{ |
| "role": "assistant", |
| "content": nil, |
| "tool_calls": assembled, |
| }, |
| "finish_reason": fr, |
| }, |
| }, |
| } |
| out, _ := json.Marshal(synthetic) |
| flush("data: " + string(out) + "\n\n") |
| accumulated = make(map[int]*AccumulatedToolCall) |
| hasToolCalls = false |
| } |
| } |
|
|
| |
| if !hasToolCalls { |
| flush("data: " + data + "\n\n") |
| } |
| } |
| } |
|
|
| func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc { |
| return func(w http.ResponseWriter, r *http.Request) { |
| start := time.Now() |
| log.Printf("[%s] %s %s", r.Method, r.URL.Path, r.RemoteAddr) |
| next(w, r) |
| log.Printf("[%s] %s done in %s", r.Method, r.URL.Path, time.Since(start)) |
| } |
| } |
|
|
| func corsMiddleware(next http.HandlerFunc) http.HandlerFunc { |
| return func(w http.ResponseWriter, r *http.Request) { |
| w.Header().Set("Access-Control-Allow-Origin", "*") |
| w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") |
| w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, x-api-key") |
| if r.Method == http.MethodOptions { |
| w.WriteHeader(http.StatusNoContent) |
| return |
| } |
| next(w, r) |
| } |
| } |
|
|
| func main() { |
| port := os.Getenv("PORT") |
| if port == "" { |
| port = "7860" |
| } |
| mux := http.NewServeMux() |
| mux.HandleFunc("/v1/chat/completions", corsMiddleware(loggingMiddleware(handleChat))) |
| mux.HandleFunc("/v1/models", corsMiddleware(loggingMiddleware(handleModels))) |
| mux.HandleFunc("/v1/base-url", corsMiddleware(handleBaseURL)) |
| mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { |
| w.WriteHeader(http.StatusOK) |
| w.Write([]byte(`{"status":"ok"}`)) |
| }) |
| log.Printf("Gateway starting on :%s", port) |
| if err := http.ListenAndServe(":"+port, mux); err != nil { |
| log.Fatal(err) |
| } |
| } |
|
|