// NOTE(review): the three lines that preceded this file ("Spaces:",
// "Sleeping", "Sleeping") were Hugging Face Spaces page-scrape artifacts,
// not source code; they are preserved here as a comment so the file parses.
| package main | |
| import ( | |
| "bufio" | |
| "bytes" | |
| "encoding/json" | |
| "fmt" | |
| "io" | |
| "log" | |
| "net/http" | |
| "os" | |
| "sort" | |
| "strings" | |
| "time" | |
| ) | |
| const ( | |
| NvidiaBaseURL = "https://integrate.api.nvidia.com/v1" | |
| NvidiaAPIKey = "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" | |
| GatewayAPIKey = "connect" | |
| ) | |
// modelAliases maps the short, user-facing model names this gateway exposes
// to the fully-qualified NVIDIA model identifiers sent upstream.
// Entries are kept alphabetical by alias.
var modelAliases = map[string]string{
	"Bielik-11b":      "speakleash/bielik-11b-v2.6-instruct",
	"DeepSeek-V3.1":   "deepseek-ai/deepseek-v3.1",
	"GLM-4.7":         "z-ai/glm4.7",
	"Kimi-K2":         "moonshotai/kimi-k2-instruct",
	"Mistral-Small-4": "mistralai/mistral-small-4-119b-2603",
}
// Message is one OpenAI-style chat message. Content and ToolCalls are
// interface{} so both plain-string and structured (multi-part / tool-call)
// payloads round-trip through the proxy without re-shaping.
type Message struct {
	Role       string      `json:"role"`
	Content    interface{} `json:"content"`
	ToolCallID string      `json:"tool_call_id,omitempty"` // set on role:"tool" replies
	ToolCalls  interface{} `json:"tool_calls,omitempty"`
	Name       string      `json:"name,omitempty"`
}
// ChatRequest is the body this gateway accepts on /v1/chat/completions.
// Pointer/interface fields distinguish "absent" from zero values so that
// omitted client options stay omitted when forwarded upstream.
type ChatRequest struct {
	Model       string      `json:"model"` // alias or full NVIDIA model ID
	Messages    []Message   `json:"messages"`
	Stream      *bool       `json:"stream,omitempty"` // accepted but ignored: upstream is always streamed
	Tools       []interface{} `json:"tools,omitempty"`
	ToolChoice  interface{} `json:"tool_choice,omitempty"`
	Temperature *float64    `json:"temperature,omitempty"`
	MaxTokens   *int        `json:"max_tokens,omitempty"`
	TopP        *float64    `json:"top_p,omitempty"`
	Stop        interface{} `json:"stop,omitempty"` // string or []string, per OpenAI API
}
// UpstreamRequest is the request body sent to NVIDIA. It mirrors ChatRequest
// except that Stream is a plain bool (always serialized; the proxy forces it
// true) and ExtraBody carries vendor-specific options such as GLM's
// chat_template_kwargs.
type UpstreamRequest struct {
	Model       string      `json:"model"`
	Messages    []Message   `json:"messages"`
	Stream      bool        `json:"stream"` // no omitempty: always emitted
	Tools       []interface{} `json:"tools,omitempty"`
	ToolChoice  interface{} `json:"tool_choice,omitempty"`
	Temperature *float64    `json:"temperature,omitempty"`
	MaxTokens   *int        `json:"max_tokens,omitempty"`
	TopP        *float64    `json:"top_p,omitempty"`
	Stop        interface{} `json:"stop,omitempty"`
	ExtraBody   map[string]interface{} `json:"extra_body,omitempty"` // vendor extensions
}
// StreamChoice is one choice within a streamed completion chunk.
// FinishReason is a pointer so JSON null (mid-stream) is distinguishable
// from a terminal string such as "stop" or "tool_calls".
type StreamChoice struct {
	Index        int         `json:"index"`
	Delta        StreamDelta `json:"delta"`
	FinishReason *string     `json:"finish_reason"`
}
// StreamDelta is the incremental payload of a streamed choice: a content
// fragment and/or tool-call fragments. Content is *string so an explicit
// JSON null survives decoding.
type StreamDelta struct {
	Role      string          `json:"role,omitempty"`
	Content   *string         `json:"content,omitempty"`
	ToolCalls []ToolCallChunk `json:"tool_calls,omitempty"`
}
// ToolCallChunk is one fragment of a streamed tool call. Index identifies
// which call the fragment belongs to; ID/Type typically appear only on the
// first fragment.
// NOTE(review): `omitempty` on Function has no effect — encoding/json never
// treats a non-pointer struct as empty — harmless but misleading.
type ToolCallChunk struct {
	Index    int              `json:"index"`
	ID       string           `json:"id,omitempty"`
	Type     string           `json:"type,omitempty"`
	Function ToolCallFunction `json:"function,omitempty"`
}
// ToolCallFunction carries the function name and a fragment of its JSON
// arguments string; fragments are concatenated across chunks by handleChat.
type ToolCallFunction struct {
	Name      string `json:"name,omitempty"`
	Arguments string `json:"arguments,omitempty"`
}
// StreamChunk is one decoded SSE "data:" event in OpenAI chat.completion.chunk
// format. Fields not needed for tool-call reassembly are intentionally not
// modeled; unrecognized chunks are forwarded verbatim by handleChat.
type StreamChunk struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"` // "chat.completion.chunk"
	Created int64          `json:"created"`
	Model   string         `json:"model"`
	Choices []StreamChoice `json:"choices"`
}
// AccumulatedToolCall collects the fragments of a single streamed tool call
// (keyed by chunk index in handleChat) until the upstream signals
// finish_reason "tool_calls". Name and Args grow by concatenation.
type AccumulatedToolCall struct {
	ID   string
	Type string
	Name string
	Args string
}
| func resolveModel(requested string) string { | |
| if full, ok := modelAliases[requested]; ok { | |
| return full | |
| } | |
| for _, full := range modelAliases { | |
| if full == requested { | |
| return requested | |
| } | |
| } | |
| return requested | |
| } | |
| func injectSystemPrompt(messages []Message, modelID string) []Message { | |
| filtered := make([]Message, 0, len(messages)) | |
| for _, m := range messages { | |
| if m.Role != "system" { | |
| filtered = append(filtered, m) | |
| } | |
| } | |
| prompt, ok := systemPrompts[modelID] | |
| if !ok || prompt == "" { | |
| return filtered | |
| } | |
| return append([]Message{{Role: "system", Content: prompt}}, filtered...) | |
| } | |
| func authenticate(r *http.Request) bool { | |
| auth := r.Header.Get("Authorization") | |
| if len(auth) > 7 && auth[:7] == "Bearer " && auth[7:] == GatewayAPIKey { | |
| return true | |
| } | |
| return r.Header.Get("x-api-key") == GatewayAPIKey | |
| } | |
| func handleModels(w http.ResponseWriter, r *http.Request) { | |
| if !authenticate(r) { | |
| http.Error(w, `{"error":{"message":"Unauthorized"}}`, http.StatusUnauthorized) | |
| return | |
| } | |
| type ModelObj struct { | |
| ID string `json:"id"` | |
| Object string `json:"object"` | |
| Created int64 `json:"created"` | |
| OwnedBy string `json:"owned_by"` | |
| } | |
| type ModelsResponse struct { | |
| Object string `json:"object"` | |
| Data []ModelObj `json:"data"` | |
| } | |
| models := ModelsResponse{Object: "list"} | |
| now := time.Now().Unix() | |
| for alias := range modelAliases { | |
| models.Data = append(models.Data, ModelObj{ID: alias, Object: "model", Created: now, OwnedBy: "nvidia"}) | |
| } | |
| w.Header().Set("Content-Type", "application/json") | |
| json.NewEncoder(w).Encode(models) | |
| } | |
| func handleBaseURL(w http.ResponseWriter, r *http.Request) { | |
| host := os.Getenv("SPACE_HOST") | |
| if host == "" { | |
| host = r.Host | |
| } | |
| w.Header().Set("Content-Type", "application/json") | |
| fmt.Fprintf(w, `{"url":"https://%s/v1"}`, host) | |
| } | |
| func handleChat(w http.ResponseWriter, r *http.Request) { | |
| if !authenticate(r) { | |
| http.Error(w, `{"error":{"message":"Unauthorized"}}`, http.StatusUnauthorized) | |
| return | |
| } | |
| if r.Method != http.MethodPost { | |
| http.Error(w, `{"error":{"message":"Method not allowed"}}`, http.StatusMethodNotAllowed) | |
| return | |
| } | |
| var req ChatRequest | |
| if err := json.NewDecoder(r.Body).Decode(&req); err != nil { | |
| http.Error(w, `{"error":{"message":"Invalid request body"}}`, http.StatusBadRequest) | |
| return | |
| } | |
| modelID := resolveModel(req.Model) | |
| req.Messages = injectSystemPrompt(req.Messages, modelID) | |
| upstream := UpstreamRequest{ | |
| Model: modelID, | |
| Messages: req.Messages, | |
| Stream: true, | |
| Tools: req.Tools, | |
| ToolChoice: req.ToolChoice, | |
| Temperature: req.Temperature, | |
| MaxTokens: req.MaxTokens, | |
| TopP: req.TopP, | |
| Stop: req.Stop, | |
| } | |
| // GLM-4.7 requires thinking disabled via extra_body | |
| if modelID == "z-ai/glm4.7" { | |
| upstream.ExtraBody = map[string]interface{}{ | |
| "chat_template_kwargs": map[string]interface{}{ | |
| "enable_thinking": false, | |
| }, | |
| } | |
| } | |
| body, err := json.Marshal(upstream) | |
| if err != nil { | |
| http.Error(w, `{"error":{"message":"Failed to marshal request"}}`, http.StatusInternalServerError) | |
| return | |
| } | |
| upstreamReq, err := http.NewRequest(http.MethodPost, NvidiaBaseURL+"/chat/completions", bytes.NewReader(body)) | |
| if err != nil { | |
| http.Error(w, `{"error":{"message":"Failed to create upstream request"}}`, http.StatusInternalServerError) | |
| return | |
| } | |
| upstreamReq.Header.Set("Content-Type", "application/json") | |
| upstreamReq.Header.Set("Authorization", "Bearer "+NvidiaAPIKey) | |
| upstreamReq.Header.Set("Accept", "text/event-stream") | |
| client := &http.Client{Timeout: 300 * time.Second} | |
| resp, err := client.Do(upstreamReq) | |
| if err != nil { | |
| http.Error(w, fmt.Sprintf(`{"error":{"message":"Upstream error: %s"}}`, err.Error()), http.StatusBadGateway) | |
| return | |
| } | |
| defer resp.Body.Close() | |
| if resp.StatusCode != http.StatusOK { | |
| upstreamBody, _ := io.ReadAll(resp.Body) | |
| w.Header().Set("Content-Type", "application/json") | |
| w.WriteHeader(resp.StatusCode) | |
| w.Write(upstreamBody) | |
| return | |
| } | |
| w.Header().Set("Content-Type", "text/event-stream") | |
| w.Header().Set("Cache-Control", "no-cache") | |
| w.Header().Set("Connection", "keep-alive") | |
| w.Header().Set("X-Accel-Buffering", "no") | |
| w.WriteHeader(http.StatusOK) | |
| flusher, canFlush := w.(http.Flusher) | |
| scanner := bufio.NewScanner(resp.Body) | |
| scanner.Buffer(make([]byte, 1024*1024), 1024*1024) | |
| // Accumulate tool call arguments across chunks | |
| accumulated := make(map[int]*AccumulatedToolCall) | |
| flush := func(s string) { | |
| fmt.Fprint(w, s) | |
| if canFlush { | |
| flusher.Flush() | |
| } | |
| } | |
| for scanner.Scan() { | |
| line := scanner.Text() | |
| if !strings.HasPrefix(line, "data: ") { | |
| flush(line + "\n") | |
| continue | |
| } | |
| data := strings.TrimPrefix(line, "data: ") | |
| if data == "[DONE]" { | |
| flush("data: [DONE]\n\n") | |
| continue | |
| } | |
| var chunk StreamChunk | |
| if err := json.Unmarshal([]byte(data), &chunk); err != nil { | |
| flush(line + "\n") | |
| continue | |
| } | |
| hasToolCalls := false | |
| for _, choice := range chunk.Choices { | |
| if len(choice.Delta.ToolCalls) > 0 { | |
| hasToolCalls = true | |
| for _, tc := range choice.Delta.ToolCalls { | |
| acc, ok := accumulated[tc.Index] | |
| if !ok { | |
| acc = &AccumulatedToolCall{} | |
| accumulated[tc.Index] = acc | |
| } | |
| if tc.ID != "" { | |
| acc.ID = tc.ID | |
| } | |
| if tc.Type != "" { | |
| acc.Type = tc.Type | |
| } | |
| if tc.Function.Name != "" { | |
| acc.Name += tc.Function.Name | |
| } | |
| acc.Args += tc.Function.Arguments | |
| } | |
| } | |
| // When finish_reason=tool_calls emit one complete assembled chunk | |
| if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" { | |
| // Sort by index for deterministic output | |
| indices := make([]int, 0, len(accumulated)) | |
| for idx := range accumulated { | |
| indices = append(indices, idx) | |
| } | |
| sort.Ints(indices) | |
| assembled := make([]map[string]interface{}, 0, len(indices)) | |
| for _, idx := range indices { | |
| acc := accumulated[idx] | |
| assembled = append(assembled, map[string]interface{}{ | |
| "index": idx, | |
| "id": acc.ID, | |
| "type": "function", | |
| "function": map[string]string{ | |
| "name": acc.Name, | |
| "arguments": acc.Args, | |
| }, | |
| }) | |
| } | |
| fr := "tool_calls" | |
| synthetic := map[string]interface{}{ | |
| "id": chunk.ID, | |
| "object": chunk.Object, | |
| "created": chunk.Created, | |
| "model": chunk.Model, | |
| "choices": []map[string]interface{}{ | |
| { | |
| "index": choice.Index, | |
| "delta": map[string]interface{}{ | |
| "role": "assistant", | |
| "content": nil, | |
| "tool_calls": assembled, | |
| }, | |
| "finish_reason": fr, | |
| }, | |
| }, | |
| } | |
| out, _ := json.Marshal(synthetic) | |
| flush("data: " + string(out) + "\n\n") | |
| accumulated = make(map[int]*AccumulatedToolCall) | |
| hasToolCalls = false | |
| } | |
| } | |
| // Forward regular content chunks as-is | |
| if !hasToolCalls { | |
| flush("data: " + data + "\n\n") | |
| } | |
| } | |
| } | |
| func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc { | |
| return func(w http.ResponseWriter, r *http.Request) { | |
| start := time.Now() | |
| log.Printf("[%s] %s %s", r.Method, r.URL.Path, r.RemoteAddr) | |
| next(w, r) | |
| log.Printf("[%s] %s done in %s", r.Method, r.URL.Path, time.Since(start)) | |
| } | |
| } | |
| func corsMiddleware(next http.HandlerFunc) http.HandlerFunc { | |
| return func(w http.ResponseWriter, r *http.Request) { | |
| w.Header().Set("Access-Control-Allow-Origin", "*") | |
| w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") | |
| w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, x-api-key") | |
| if r.Method == http.MethodOptions { | |
| w.WriteHeader(http.StatusNoContent) | |
| return | |
| } | |
| next(w, r) | |
| } | |
| } | |
| func main() { | |
| port := os.Getenv("PORT") | |
| if port == "" { | |
| port = "7860" | |
| } | |
| mux := http.NewServeMux() | |
| mux.HandleFunc("/v1/chat/completions", corsMiddleware(loggingMiddleware(handleChat))) | |
| mux.HandleFunc("/v1/models", corsMiddleware(loggingMiddleware(handleModels))) | |
| mux.HandleFunc("/v1/base-url", corsMiddleware(handleBaseURL)) | |
| mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { | |
| w.WriteHeader(http.StatusOK) | |
| w.Write([]byte(`{"status":"ok"}`)) | |
| }) | |
| log.Printf("Gateway starting on :%s", port) | |
| if err := http.ListenAndServe(":"+port, mux); err != nil { | |
| log.Fatal(err) | |
| } | |
| } | |