// miniapi/provider — OpenAI-compatible gateway in front of NVIDIA's inference API.
// Provenance (reconstructed from repository metadata, originally listed as
// "provider.ts" by eranet111): commit f438287,
// "Add function calling support and remove outdated processes".
package main
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"sort"
"strings"
"time"
)
// NvidiaBaseURL is the upstream OpenAI-compatible endpoint.
const NvidiaBaseURL = "https://integrate.api.nvidia.com/v1"

// envOr returns the value of the environment variable key, or fallback when
// the variable is unset or empty.
func envOr(key, fallback string) string {
	if v := os.Getenv(key); v != "" {
		return v
	}
	return fallback
}

// SECURITY(review): both keys used to be hardcoded constants committed to
// source. They are now overridable via environment (NVIDIA_API_KEY /
// GATEWAY_API_KEY); the embedded defaults are kept only for backward
// compatibility and should be rotated and removed.
var (
	// NvidiaAPIKey authenticates this gateway against the NVIDIA API.
	NvidiaAPIKey = envOr("NVIDIA_API_KEY", "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")
	// GatewayAPIKey is the key clients must present to this gateway.
	GatewayAPIKey = envOr("GATEWAY_API_KEY", "connect")
)
// modelAliases maps the short, client-facing model names (advertised by
// /v1/models) to the full NVIDIA model identifiers sent upstream.
var modelAliases = map[string]string{
	"Bielik-11b":      "speakleash/bielik-11b-v2.6-instruct",
	"GLM-4.7":         "z-ai/glm4.7",
	"Mistral-Small-4": "mistralai/mistral-small-4-119b-2603",
	"DeepSeek-V3.1":   "deepseek-ai/deepseek-v3.1",
	"Kimi-K2":         "moonshotai/kimi-k2-instruct",
}
// Message is a single chat message in the OpenAI chat-completions schema.
// Content and ToolCalls are interface{} because clients may send either a
// plain string or structured content parts; both are passed through to the
// upstream unmodified.
type Message struct {
	Role       string      `json:"role"` // "system", "user", "assistant", or "tool"
	Content    interface{} `json:"content"`
	ToolCallID string      `json:"tool_call_id,omitempty"` // set on role "tool" result messages
	ToolCalls  interface{} `json:"tool_calls,omitempty"`
	Name       string      `json:"name,omitempty"`
}
// ChatRequest is the body accepted from clients on /v1/chat/completions.
// Optional fields are pointers or interface{} so absent and zero values can
// be told apart when the request is re-marshalled upstream.
type ChatRequest struct {
	Model       string        `json:"model"`
	Messages    []Message     `json:"messages"`
	Stream      *bool         `json:"stream,omitempty"` // NOTE(review): currently ignored — handleChat always streams upstream; confirm intended
	Tools       []interface{} `json:"tools,omitempty"`
	ToolChoice  interface{}   `json:"tool_choice,omitempty"`
	Temperature *float64      `json:"temperature,omitempty"`
	MaxTokens   *int          `json:"max_tokens,omitempty"`
	TopP        *float64      `json:"top_p,omitempty"`
	Stop        interface{}   `json:"stop,omitempty"`
}
// UpstreamRequest is the body sent to the NVIDIA endpoint. It mirrors
// ChatRequest but carries a concrete Stream flag (the gateway always sets it
// to true) plus an optional extra_body used for model-specific knobs such as
// GLM-4.7's chat_template_kwargs.
type UpstreamRequest struct {
	Model       string                 `json:"model"`
	Messages    []Message              `json:"messages"`
	Stream      bool                   `json:"stream"`
	Tools       []interface{}          `json:"tools,omitempty"`
	ToolChoice  interface{}            `json:"tool_choice,omitempty"`
	Temperature *float64               `json:"temperature,omitempty"`
	MaxTokens   *int                   `json:"max_tokens,omitempty"`
	TopP        *float64               `json:"top_p,omitempty"`
	Stop        interface{}            `json:"stop,omitempty"`
	ExtraBody   map[string]interface{} `json:"extra_body,omitempty"`
}
// StreamChoice is one choice inside a streamed chat-completion chunk.
type StreamChoice struct {
	Index        int         `json:"index"`
	Delta        StreamDelta `json:"delta"`
	FinishReason *string     `json:"finish_reason"` // nil mid-stream; e.g. "stop" or "tool_calls" on the final chunk
}

// StreamDelta is the incremental payload of a streamed choice: role on the
// first chunk, then content text and/or tool-call fragments.
type StreamDelta struct {
	Role      string          `json:"role,omitempty"`
	Content   *string         `json:"content,omitempty"`
	ToolCalls []ToolCallChunk `json:"tool_calls,omitempty"`
}

// ToolCallChunk is one fragment of a tool call as streamed by the upstream.
// Index correlates fragments belonging to the same tool call across chunks.
type ToolCallChunk struct {
	Index    int              `json:"index"`
	ID       string           `json:"id,omitempty"`   // typically present only on the first fragment
	Type     string           `json:"type,omitempty"` // typically "function"
	Function ToolCallFunction `json:"function,omitempty"`
}

// ToolCallFunction carries the (possibly partial) function name and the raw
// JSON-encoded arguments string, which accumulates across fragments.
type ToolCallFunction struct {
	Name      string `json:"name,omitempty"`
	Arguments string `json:"arguments,omitempty"`
}

// StreamChunk is one parsed "data:" SSE event from the upstream stream.
type StreamChunk struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"`
	Created int64          `json:"created"`
	Model   string         `json:"model"`
	Choices []StreamChoice `json:"choices"`
}

// AccumulatedToolCall collects the fragments of a single tool call while the
// stream is in flight; Name and Args grow by string concatenation.
type AccumulatedToolCall struct {
	ID   string
	Type string
	Name string
	Args string
}
// resolveModel maps a public alias (e.g. "Kimi-K2") to its full upstream
// model ID. Any other string — including a full model ID passed directly —
// is returned unchanged, so clients may use either form.
//
// BUG FIX: the original also looped over modelAliases looking for a matching
// full ID, but both that branch and the fallthrough returned `requested`,
// making the loop dead code; it has been removed with no behavior change.
func resolveModel(requested string) string {
	if full, ok := modelAliases[requested]; ok {
		return full
	}
	return requested
}
// injectSystemPrompt removes any client-supplied system messages and, when a
// per-model prompt is configured in systemPrompts, prepends it as the single
// system message. With no configured prompt, only the filtering applies.
func injectSystemPrompt(messages []Message, modelID string) []Message {
	nonSystem := make([]Message, 0, len(messages))
	for _, msg := range messages {
		if msg.Role == "system" {
			continue
		}
		nonSystem = append(nonSystem, msg)
	}
	// A missing key yields "", which is treated the same as an empty prompt.
	prompt := systemPrompts[modelID]
	if prompt == "" {
		return nonSystem
	}
	withPrompt := make([]Message, 0, len(nonSystem)+1)
	withPrompt = append(withPrompt, Message{Role: "system", Content: prompt})
	return append(withPrompt, nonSystem...)
}
// authenticate reports whether the request presents the gateway API key,
// either as "Authorization: Bearer <key>" or in the "x-api-key" header.
func authenticate(r *http.Request) bool {
	const bearer = "Bearer "
	auth := r.Header.Get("Authorization")
	if strings.HasPrefix(auth, bearer) && strings.TrimPrefix(auth, bearer) == GatewayAPIKey {
		return true
	}
	return r.Header.Get("x-api-key") == GatewayAPIKey
}
// handleModels serves /v1/models: the list of client-facing model aliases in
// the OpenAI "list" format. Requires gateway authentication.
func handleModels(w http.ResponseWriter, r *http.Request) {
	if !authenticate(r) {
		http.Error(w, `{"error":{"message":"Unauthorized"}}`, http.StatusUnauthorized)
		return
	}
	type ModelObj struct {
		ID      string `json:"id"`
		Object  string `json:"object"`
		Created int64  `json:"created"`
		OwnedBy string `json:"owned_by"`
	}
	type ModelsResponse struct {
		Object string     `json:"object"`
		Data   []ModelObj `json:"data"`
	}
	// BUG FIX: Go map iteration order is randomized, so the model list used
	// to come back in a different order on every call. Sort for determinism.
	aliases := make([]string, 0, len(modelAliases))
	for alias := range modelAliases {
		aliases = append(aliases, alias)
	}
	sort.Strings(aliases)
	models := ModelsResponse{Object: "list"}
	now := time.Now().Unix()
	for _, alias := range aliases {
		models.Data = append(models.Data, ModelObj{ID: alias, Object: "model", Created: now, OwnedBy: "nvidia"})
	}
	w.Header().Set("Content-Type", "application/json")
	// BUG FIX: the Encode error was silently discarded; log it (headers are
	// already written, so logging is the only remaining option).
	if err := json.NewEncoder(w).Encode(models); err != nil {
		log.Printf("handleModels: encode: %v", err)
	}
}
// handleBaseURL reports this gateway's public /v1 base URL as JSON. The host
// is taken from SPACE_HOST when set (the Hugging Face Spaces convention),
// otherwise from the request's Host header. This endpoint is intentionally
// left unauthenticated in the original routing.
func handleBaseURL(w http.ResponseWriter, r *http.Request) {
	host := r.Host
	if env := os.Getenv("SPACE_HOST"); env != "" {
		host = env
	}
	w.Header().Set("Content-Type", "application/json")
	fmt.Fprintf(w, `{"url":"https://%s/v1"}`, host)
}
// handleChat proxies an OpenAI-style /v1/chat/completions request to the
// NVIDIA endpoint and relays the SSE response.
//
// The upstream call is always made with stream=true, regardless of the
// client's "stream" field (NOTE(review): presumably every client of this
// gateway expects SSE — confirm). Chunks are forwarded verbatim with one
// transformation: incremental tool-call fragments are suppressed and
// accumulated, then re-emitted as a single complete tool_calls chunk when
// the upstream reports finish_reason "tool_calls".
func handleChat(w http.ResponseWriter, r *http.Request) {
	if !authenticate(r) {
		http.Error(w, `{"error":{"message":"Unauthorized"}}`, http.StatusUnauthorized)
		return
	}
	if r.Method != http.MethodPost {
		http.Error(w, `{"error":{"message":"Method not allowed"}}`, http.StatusMethodNotAllowed)
		return
	}
	var req ChatRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, `{"error":{"message":"Invalid request body"}}`, http.StatusBadRequest)
		return
	}
	modelID := resolveModel(req.Model)
	req.Messages = injectSystemPrompt(req.Messages, modelID)
	upstream := UpstreamRequest{
		Model:       modelID,
		Messages:    req.Messages,
		Stream:      true,
		Tools:       req.Tools,
		ToolChoice:  req.ToolChoice,
		Temperature: req.Temperature,
		MaxTokens:   req.MaxTokens,
		TopP:        req.TopP,
		Stop:        req.Stop,
	}
	// GLM-4.7 requires thinking disabled via extra_body.
	if modelID == "z-ai/glm4.7" {
		upstream.ExtraBody = map[string]interface{}{
			"chat_template_kwargs": map[string]interface{}{
				"enable_thinking": false,
			},
		}
	}
	body, err := json.Marshal(upstream)
	if err != nil {
		http.Error(w, `{"error":{"message":"Failed to marshal request"}}`, http.StatusInternalServerError)
		return
	}
	upstreamReq, err := http.NewRequest(http.MethodPost, NvidiaBaseURL+"/chat/completions", bytes.NewReader(body))
	if err != nil {
		http.Error(w, `{"error":{"message":"Failed to create upstream request"}}`, http.StatusInternalServerError)
		return
	}
	upstreamReq.Header.Set("Content-Type", "application/json")
	upstreamReq.Header.Set("Authorization", "Bearer "+NvidiaAPIKey)
	upstreamReq.Header.Set("Accept", "text/event-stream")
	client := &http.Client{Timeout: 300 * time.Second}
	resp, err := client.Do(upstreamReq)
	if err != nil {
		http.Error(w, fmt.Sprintf(`{"error":{"message":"Upstream error: %s"}}`, err.Error()), http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()
	// Non-200 upstream responses are relayed verbatim, status included.
	if resp.StatusCode != http.StatusOK {
		upstreamBody, _ := io.ReadAll(resp.Body)
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(resp.StatusCode)
		w.Write(upstreamBody)
		return
	}
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no") // disable reverse-proxy buffering (nginx)
	w.WriteHeader(http.StatusOK)
	flusher, canFlush := w.(http.Flusher)
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 1024*1024), 1024*1024) // single SSE data lines can be large
	// accumulated collects tool-call fragments across chunks, keyed by each
	// tool call's index within the choice.
	accumulated := make(map[int]*AccumulatedToolCall)
	flush := func(s string) {
		fmt.Fprint(w, s)
		if canFlush {
			flusher.Flush()
		}
	}
	for scanner.Scan() {
		line := scanner.Text()
		// Blank separators, comments, and other non-data lines pass through.
		if !strings.HasPrefix(line, "data: ") {
			flush(line + "\n")
			continue
		}
		data := strings.TrimPrefix(line, "data: ")
		if data == "[DONE]" {
			flush("data: [DONE]\n\n")
			continue
		}
		var chunk StreamChunk
		if err := json.Unmarshal([]byte(data), &chunk); err != nil {
			// Unparseable payload — forward verbatim rather than drop it.
			flush(line + "\n")
			continue
		}
		// suppress marks chunks that must not be forwarded raw: those that
		// carry tool-call fragments, and the finish chunk that we replace
		// with a synthetic assembled one.
		suppress := false
		for _, choice := range chunk.Choices {
			if len(choice.Delta.ToolCalls) > 0 {
				suppress = true
				for _, tc := range choice.Delta.ToolCalls {
					acc, ok := accumulated[tc.Index]
					if !ok {
						acc = &AccumulatedToolCall{}
						accumulated[tc.Index] = acc
					}
					// ID and Type typically arrive once on the first
					// fragment; Name and Arguments may be split over many.
					if tc.ID != "" {
						acc.ID = tc.ID
					}
					if tc.Type != "" {
						acc.Type = tc.Type
					}
					if tc.Function.Name != "" {
						acc.Name += tc.Function.Name
					}
					acc.Args += tc.Function.Arguments
				}
			}
			// When finish_reason=tool_calls, emit one complete assembled chunk.
			if choice.FinishReason != nil && *choice.FinishReason == "tool_calls" {
				// Sort by index for deterministic output.
				indices := make([]int, 0, len(accumulated))
				for idx := range accumulated {
					indices = append(indices, idx)
				}
				sort.Ints(indices)
				assembled := make([]map[string]interface{}, 0, len(indices))
				for _, idx := range indices {
					acc := accumulated[idx]
					assembled = append(assembled, map[string]interface{}{
						"index": idx,
						"id":    acc.ID,
						"type":  "function",
						"function": map[string]string{
							"name":      acc.Name,
							"arguments": acc.Args,
						},
					})
				}
				synthetic := map[string]interface{}{
					"id":      chunk.ID,
					"object":  chunk.Object,
					"created": chunk.Created,
					"model":   chunk.Model,
					"choices": []map[string]interface{}{
						{
							"index": choice.Index,
							"delta": map[string]interface{}{
								"role":       "assistant",
								"content":    nil,
								"tool_calls": assembled,
							},
							"finish_reason": "tool_calls",
						},
					},
				}
				// Marshal of plain maps/strings cannot realistically fail.
				out, _ := json.Marshal(synthetic)
				flush("data: " + string(out) + "\n\n")
				accumulated = make(map[int]*AccumulatedToolCall)
				// BUG FIX: the original reset the flag to false here, which
				// caused the raw finish chunk (including any tool-call
				// fragments it carried) to be forwarded AFTER the synthetic
				// chunk, duplicating finish_reason="tool_calls" downstream.
				// Keep the raw chunk suppressed instead.
				suppress = true
			}
		}
		// Forward regular content chunks as-is.
		if !suppress {
			flush("data: " + data + "\n\n")
		}
	}
	// BUG FIX: the original never checked scanner.Err(), so a read error or
	// an over-long line truncated the stream silently. Headers are already
	// sent, so logging is all that can be done at this point.
	if err := scanner.Err(); err != nil {
		log.Printf("handleChat: upstream stream read error: %v", err)
	}
}
// loggingMiddleware logs each request's method, path, and remote address on
// entry, and the wall-clock duration once the wrapped handler returns.
func loggingMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		defer func() {
			log.Printf("[%s] %s done in %s", r.Method, r.URL.Path, time.Since(began))
		}()
		log.Printf("[%s] %s %s", r.Method, r.URL.Path, r.RemoteAddr)
		next(w, r)
	}
}
// corsMiddleware attaches permissive CORS headers to every response and
// short-circuits CORS preflight (OPTIONS) requests with 204 No Content.
func corsMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		headers := w.Header()
		headers.Set("Access-Control-Allow-Origin", "*")
		headers.Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
		headers.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, x-api-key")
		if r.Method != http.MethodOptions {
			next(w, r)
			return
		}
		w.WriteHeader(http.StatusNoContent)
	}
}
// main wires up the routes and starts the HTTP server. The listen port comes
// from PORT (default 7860, the Hugging Face Spaces convention).
func main() {
	port := os.Getenv("PORT")
	if port == "" {
		port = "7860"
	}
	mux := http.NewServeMux()
	mux.HandleFunc("/v1/chat/completions", corsMiddleware(loggingMiddleware(handleChat)))
	mux.HandleFunc("/v1/models", corsMiddleware(loggingMiddleware(handleModels)))
	mux.HandleFunc("/v1/base-url", corsMiddleware(handleBaseURL))
	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"status":"ok"}`))
	})
	// IMPROVEMENT: the original used the bare http.ListenAndServe, which has
	// no timeouts at all. Guard against slow-header (slowloris) clients and
	// reap idle keep-alive connections. Deliberately no WriteTimeout: chat
	// completion streams can legitimately run for minutes.
	server := &http.Server{
		Addr:              ":" + port,
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
		IdleTimeout:       120 * time.Second,
	}
	log.Printf("Gateway starting on :%s", port)
	if err := server.ListenAndServe(); err != nil {
		log.Fatal(err)
	}
}