// Package main implements a small OpenAI-compatible gateway in front of the
// NVIDIA Integrate API. It exposes /v1/chat/completions and /v1/models,
// maps short model aliases to full upstream IDs, injects per-model system
// prompts, forces streaming upstream, and re-assembles fragmented streamed
// tool calls into a single complete chunk before relaying them to the client.
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"sort"
	"strings"
	"time"
)

const (
	// NvidiaBaseURL is the upstream OpenAI-compatible API root.
	NvidiaBaseURL = "https://integrate.api.nvidia.com/v1"
	// SECURITY NOTE(review): credentials are hard-coded in source. This key
	// must be considered leaked — rotate it and load both keys from the
	// environment (e.g. NVIDIA_API_KEY / GATEWAY_API_KEY) instead.
	NvidiaAPIKey  = "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"
	GatewayAPIKey = "connect"
)

// modelAliases maps the short client-facing model names to the full
// upstream NVIDIA model identifiers.
var modelAliases = map[string]string{
	"Bielik-11b":      "speakleash/bielik-11b-v2.6-instruct",
	"GLM-4.7":         "z-ai/glm4.7",
	"Mistral-Small-4": "mistralai/mistral-small-4-119b-2603",
	"DeepSeek-V3.1":   "deepseek-ai/deepseek-v3.1",
	"Kimi-K2":         "moonshotai/kimi-k2-instruct",
}

// upstreamClient is shared across requests so the transport can pool
// TCP/TLS connections. The generous timeout accommodates long SSE streams.
var upstreamClient = &http.Client{Timeout: 300 * time.Second}

// Message is a single OpenAI-style chat message. Content is interface{}
// because clients may send either a string or a content-part array.
type Message struct {
	Role       string      `json:"role"`
	Content    interface{} `json:"content"`
	ToolCallID string      `json:"tool_call_id,omitempty"`
	ToolCalls  interface{} `json:"tool_calls,omitempty"`
	Name       string      `json:"name,omitempty"`
}

// ChatRequest is the body accepted from clients on /v1/chat/completions.
type ChatRequest struct {
	Model       string        `json:"model"`
	Messages    []Message     `json:"messages"`
	Stream      *bool         `json:"stream,omitempty"`
	Tools       []interface{} `json:"tools,omitempty"`
	ToolChoice  interface{}   `json:"tool_choice,omitempty"`
	Temperature *float64      `json:"temperature,omitempty"`
	MaxTokens   *int          `json:"max_tokens,omitempty"`
	TopP        *float64      `json:"top_p,omitempty"`
	Stop        interface{}   `json:"stop,omitempty"`
}

// UpstreamRequest is the body sent to NVIDIA. Stream is a plain bool
// because the gateway always streams upstream; ExtraBody carries
// vendor-specific knobs (e.g. chat_template_kwargs).
type UpstreamRequest struct {
	Model       string                 `json:"model"`
	Messages    []Message              `json:"messages"`
	Stream      bool                   `json:"stream"`
	Tools       []interface{}          `json:"tools,omitempty"`
	ToolChoice  interface{}            `json:"tool_choice,omitempty"`
	Temperature *float64               `json:"temperature,omitempty"`
	MaxTokens   *int                   `json:"max_tokens,omitempty"`
	TopP        *float64               `json:"top_p,omitempty"`
	Stop        interface{}            `json:"stop,omitempty"`
	ExtraBody   map[string]interface{} `json:"extra_body,omitempty"`
}

// StreamChoice is one choice entry of an SSE completion chunk.
type StreamChoice struct {
	Index        int         `json:"index"`
	Delta        StreamDelta `json:"delta"`
	FinishReason *string     `json:"finish_reason"`
}

// StreamDelta is the incremental payload inside a streamed choice.
type StreamDelta struct {
	Role      string          `json:"role,omitempty"`
	Content   *string         `json:"content,omitempty"`
	ToolCalls []ToolCallChunk `json:"tool_calls,omitempty"`
}

// ToolCallChunk is one fragment of a streamed tool call; fragments sharing
// an Index belong to the same logical call.
type ToolCallChunk struct {
	Index    int              `json:"index"`
	ID       string           `json:"id,omitempty"`
	Type     string           `json:"type,omitempty"`
	Function ToolCallFunction `json:"function,omitempty"`
}

// ToolCallFunction carries the (possibly partial) function name/arguments.
type ToolCallFunction struct {
	Name      string `json:"name,omitempty"`
	Arguments string `json:"arguments,omitempty"`
}

// StreamChunk is a parsed SSE data frame from the upstream stream.
type StreamChunk struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"`
	Created int64          `json:"created"`
	Model   string         `json:"model"`
	Choices []StreamChoice `json:"choices"`
}

// AccumulatedToolCall collects the fragments of one tool call as they
// arrive across chunks.
type AccumulatedToolCall struct {
	ID   string
	Type string
	Name string
	Args string
}

// resolveModel maps a client-facing alias to its full upstream model ID.
// Unknown names (including full upstream IDs used directly) pass through
// unchanged.
func resolveModel(requested string) string {
	if full, ok := modelAliases[requested]; ok {
		return full
	}
	return requested
}

// injectSystemPrompt strips any client-supplied system messages and, when a
// gateway-managed prompt exists for modelID in systemPrompts (defined
// elsewhere in this package), prepends it as the sole system message.
func injectSystemPrompt(messages []Message, modelID string) []Message {
	filtered := make([]Message, 0, len(messages))
	for _, m := range messages {
		if m.Role != "system" {
			filtered = append(filtered, m)
		}
	}
	prompt, ok := systemPrompts[modelID]
	if !ok || prompt == "" {
		return filtered
	}
	return append([]Message{{Role: "system", Content: prompt}}, filtered...)
}

// authenticate accepts either "Authorization: Bearer <key>" or an
// "x-api-key" header matching GatewayAPIKey.
func authenticate(r *http.Request) bool {
	auth := r.Header.Get("Authorization")
	if token := strings.TrimPrefix(auth, "Bearer "); token != auth && token == GatewayAPIKey {
		return true
	}
	return r.Header.Get("x-api-key") == GatewayAPIKey
}

// jsonError writes an OpenAI-style JSON error body with the given status.
// (Unlike http.Error, it sets the correct application/json content type.)
func jsonError(w http.ResponseWriter, message string, code int) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	fmt.Fprintf(w, `{"error":{"message":"%s"}}`, message)
}

// handleModels lists the configured aliases in OpenAI /v1/models format.
// Aliases are sorted so the listing is deterministic across requests.
func handleModels(w http.ResponseWriter, r *http.Request) {
	if !authenticate(r) {
		jsonError(w, "Unauthorized", http.StatusUnauthorized)
		return
	}
	type ModelObj struct {
		ID      string `json:"id"`
		Object  string `json:"object"`
		Created int64  `json:"created"`
		OwnedBy string `json:"owned_by"`
	}
	type ModelsResponse struct {
		Object string     `json:"object"`
		Data   []ModelObj `json:"data"`
	}

	aliases := make([]string, 0, len(modelAliases))
	for alias := range modelAliases {
		aliases = append(aliases, alias)
	}
	sort.Strings(aliases)

	models := ModelsResponse{Object: "list"}
	now := time.Now().Unix()
	for _, alias := range aliases {
		models.Data = append(models.Data, ModelObj{ID: alias, Object: "model", Created: now, OwnedBy: "nvidia"})
	}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(models); err != nil {
		log.Printf("encode models response: %v", err)
	}
}

// handleBaseURL reports the public base URL of this gateway, preferring the
// SPACE_HOST environment variable over the request Host header.
// NOTE(review): intentionally unauthenticated — confirm this is desired.
func handleBaseURL(w http.ResponseWriter, r *http.Request) {
	host := os.Getenv("SPACE_HOST")
	if host == "" {
		host = r.Host
	}
	w.Header().Set("Content-Type", "application/json")
	fmt.Fprintf(w, `{"url":"https://%s/v1"}`, host)
}

// handleChat proxies a chat-completion request upstream (always streaming)
// and relays the SSE response, re-assembling fragmented tool calls.
func handleChat(w http.ResponseWriter, r *http.Request) {
	if !authenticate(r) {
		jsonError(w, "Unauthorized", http.StatusUnauthorized)
		return
	}
	if r.Method != http.MethodPost {
		jsonError(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var req ChatRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		jsonError(w, "Invalid request body", http.StatusBadRequest)
		return
	}

	modelID := resolveModel(req.Model)
	req.Messages = injectSystemPrompt(req.Messages, modelID)

	upstream := UpstreamRequest{
		Model:       modelID,
		Messages:    req.Messages,
		Stream:      true,
		Tools:       req.Tools,
		ToolChoice:  req.ToolChoice,
		Temperature: req.Temperature,
		MaxTokens:   req.MaxTokens,
		TopP:        req.TopP,
		Stop:        req.Stop,
	}
	// GLM-4.7 requires thinking disabled via extra_body.
	if modelID == "z-ai/glm4.7" {
		upstream.ExtraBody = map[string]interface{}{
			"chat_template_kwargs": map[string]interface{}{
				"enable_thinking": false,
			},
		}
	}

	body, err := json.Marshal(upstream)
	if err != nil {
		jsonError(w, "Failed to marshal request", http.StatusInternalServerError)
		return
	}
	upstreamReq, err := http.NewRequest(http.MethodPost, NvidiaBaseURL+"/chat/completions", bytes.NewReader(body))
	if err != nil {
		jsonError(w, "Failed to create upstream request", http.StatusInternalServerError)
		return
	}
	upstreamReq.Header.Set("Content-Type", "application/json")
	upstreamReq.Header.Set("Authorization", "Bearer "+NvidiaAPIKey)
	upstreamReq.Header.Set("Accept", "text/event-stream")

	resp, err := upstreamClient.Do(upstreamReq)
	if err != nil {
		jsonError(w, "Upstream error: "+err.Error(), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Mirror the upstream error payload and status verbatim.
		upstreamBody, _ := io.ReadAll(resp.Body)
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(resp.StatusCode)
		w.Write(upstreamBody)
		return
	}

	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no")
	w.WriteHeader(http.StatusOK)

	relayStream(w, resp.Body)
}

// relayStream copies upstream SSE frames to the client. Chunks whose deltas
// carry tool-call fragments are withheld and accumulated; when the upstream
// signals finish_reason=tool_calls, one synthetic chunk containing the fully
// assembled tool calls is emitted instead.
func relayStream(w http.ResponseWriter, upstream io.Reader) {
	flusher, canFlush := w.(http.Flusher)
	flush := func(s string) {
		fmt.Fprint(w, s)
		if canFlush {
			flusher.Flush()
		}
	}

	scanner := bufio.NewScanner(upstream)
	scanner.Buffer(make([]byte, 1024*1024), 1024*1024)

	// Tool-call fragments accumulated across chunks, keyed by tool-call index.
	accumulated := make(map[int]*AccumulatedToolCall)

	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			// Blank separators, comments, event names: pass through.
			flush(line + "\n")
			continue
		}
		data := strings.TrimPrefix(line, "data: ")
		if data == "[DONE]" {
			flush("data: [DONE]\n\n")
			continue
		}
		var chunk StreamChunk
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			// Not a chunk we can parse; forward untouched.
			flush(line + "\n")
			continue
		}

		// suppress: do not forward the raw chunk (it either carried tool-call
		// fragments we are buffering, or was replaced by a synthetic chunk).
		suppress := false
		for _, choice := range chunk.Choices {
			if len(choice.Delta.ToolCalls) > 0 {
				suppress = true
				for _, tc := range choice.Delta.ToolCalls {
					acc := accumulated[tc.Index]
					if acc == nil {
						acc = &AccumulatedToolCall{}
						accumulated[tc.Index] = acc
					}
					if tc.ID != "" {
						acc.ID = tc.ID
					}
					if tc.Type != "" {
						acc.Type = tc.Type
					}
					if tc.Function.Name != "" {
						acc.Name += tc.Function.Name
					}
					acc.Args += tc.Function.Arguments
				}
			}
			// When finish_reason=tool_calls, emit one complete assembled chunk.
			if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" {
				emitAssembledToolCalls(flush, chunk, choice, accumulated)
				accumulated = make(map[int]*AccumulatedToolCall)
				// BUG FIX: the original cleared its flag here, causing the raw
				// finish chunk to be forwarded as well — the client then saw
				// finish_reason=tool_calls twice. Keep it suppressed.
				suppress = true
			}
		}
		if !suppress {
			// Regular content chunk: forward as-is.
			flush("data: " + data + "\n\n")
		}
	}
	// BUG FIX: the original never checked scanner.Err(), silently dropping
	// upstream read failures (including over-long lines).
	if err := scanner.Err(); err != nil {
		log.Printf("upstream stream read error: %v", err)
	}
}

// emitAssembledToolCalls writes one synthetic SSE chunk carrying every
// accumulated tool call, sorted by index for deterministic output.
func emitAssembledToolCalls(flush func(string), chunk StreamChunk, choice StreamChoice, accumulated map[int]*AccumulatedToolCall) {
	indices := make([]int, 0, len(accumulated))
	for idx := range accumulated {
		indices = append(indices, idx)
	}
	sort.Ints(indices)

	assembled := make([]map[string]interface{}, 0, len(indices))
	for _, idx := range indices {
		acc := accumulated[idx]
		assembled = append(assembled, map[string]interface{}{
			"index": idx,
			"id":    acc.ID,
			"type":  "function",
			"function": map[string]string{
				"name":      acc.Name,
				"arguments": acc.Args,
			},
		})
	}

	synthetic := map[string]interface{}{
		"id":      chunk.ID,
		"object":  chunk.Object,
		"created": chunk.Created,
		"model":   chunk.Model,
		"choices": []map[string]interface{}{
			{
				"index": choice.Index,
				"delta": map[string]interface{}{
					"role":       "assistant",
					"content":    nil,
					"tool_calls": assembled,
				},
				"finish_reason": "tool_calls",
			},
		},
	}
	out, err := json.Marshal(synthetic)
	if err != nil {
		log.Printf("marshal synthetic tool-call chunk: %v", err)
		return
	}
	flush("data: " + string(out) + "\n\n")
}

// loggingMiddleware logs each request's start and total duration.
func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		log.Printf("[%s] %s %s", r.Method, r.URL.Path, r.RemoteAddr)
		next(w, r)
		log.Printf("[%s] %s done in %s", r.Method, r.URL.Path, time.Since(start))
	}
}

// corsMiddleware applies permissive CORS headers and short-circuits
// preflight OPTIONS requests.
func corsMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Access-Control-Allow-Origin", "*")
		w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
		w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, x-api-key")
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next(w, r)
	}
}

func main() {
	port := os.Getenv("PORT")
	if port == "" {
		port = "7860"
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/v1/chat/completions", corsMiddleware(loggingMiddleware(handleChat)))
	mux.HandleFunc("/v1/models", corsMiddleware(loggingMiddleware(handleModels)))
	mux.HandleFunc("/v1/base-url", corsMiddleware(handleBaseURL))
	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"status":"ok"}`))
	})

	// ReadHeaderTimeout guards against slow-header clients. Write/Idle
	// timeouts are deliberately omitted: responses are long-lived SSE streams.
	server := &http.Server{
		Addr:              ":" + port,
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
	}
	log.Printf("Gateway starting on :%s", port)
	if err := server.ListenAndServe(); err != nil {
		log.Fatal(err)
	}
}