| package main |
|
|
| import ( |
| "bytes" |
| "encoding/json" |
| "fmt" |
| "io" |
| "log" |
| "net/http" |
| "os" |
| ) |
|
|
| const ( |
| ollamaURL = "http://localhost:11434" |
| bearerToken = "connect" |
| defaultModel = "ingu627/exaone4.0:1.2b" |
| ) |
|
|
| type ChatRequest struct { |
| Model string `json:"model"` |
| Messages []ChatMessage `json:"messages"` |
| Stream bool `json:"stream"` |
| Options json.RawMessage `json:"options,omitempty"` |
| } |
|
|
| type ChatMessage struct { |
| Role string `json:"role"` |
| Content string `json:"content"` |
| } |
|
|
| type GenerateRequest struct { |
| Model string `json:"model"` |
| Prompt string `json:"prompt"` |
| Stream bool `json:"stream"` |
| Options json.RawMessage `json:"options,omitempty"` |
| } |
|
|
| func authMiddleware(next http.Handler) http.Handler { |
| return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| auth := r.Header.Get("Authorization") |
| expected := "Bearer " + bearerToken |
| if auth != expected { |
| w.Header().Set("Content-Type", "application/json") |
| http.Error(w, `{"error":"unauthorized"}`, http.StatusUnauthorized) |
| return |
| } |
| next.ServeHTTP(w, r) |
| }) |
| } |
|
|
| func proxyHandler(w http.ResponseWriter, r *http.Request) { |
| if r.Method != http.MethodPost { |
| http.Error(w, `{"error":"method not allowed"}`, http.StatusMethodNotAllowed) |
| return |
| } |
|
|
| body, err := io.ReadAll(r.Body) |
| if err != nil { |
| http.Error(w, `{"error":"failed to read request"}`, http.StatusBadRequest) |
| return |
| } |
| defer r.Body.Close() |
|
|
| var modified []byte |
|
|
| switch r.URL.Path { |
| case "/api/chat", "/v1/chat/completions": |
| var req ChatRequest |
| if err := json.Unmarshal(body, &req); err != nil { |
| http.Error(w, `{"error":"invalid json"}`, http.StatusBadRequest) |
| return |
| } |
| if req.Model == "" { |
| req.Model = defaultModel |
| } |
| req.Stream = true |
| modified, _ = json.Marshal(req) |
|
|
| case "/api/generate": |
| var req GenerateRequest |
| if err := json.Unmarshal(body, &req); err != nil { |
| http.Error(w, `{"error":"invalid json"}`, http.StatusBadRequest) |
| return |
| } |
| if req.Model == "" { |
| req.Model = defaultModel |
| } |
| req.Stream = true |
| modified, _ = json.Marshal(req) |
|
|
| default: |
| modified = body |
| } |
|
|
| target := ollamaURL + r.URL.Path |
| proxyReq, err := http.NewRequest(r.Method, target, bytes.NewReader(modified)) |
| if err != nil { |
| http.Error(w, `{"error":"proxy error"}`, http.StatusInternalServerError) |
| return |
| } |
| proxyReq.Header.Set("Content-Type", "application/json") |
|
|
| client := &http.Client{} |
| resp, err := client.Do(proxyReq) |
| if err != nil { |
| http.Error(w, fmt.Sprintf(`{"error":"ollama unreachable: %s"}`, err.Error()), http.StatusBadGateway) |
| return |
| } |
| defer resp.Body.Close() |
|
|
| w.Header().Set("Content-Type", resp.Header.Get("Content-Type")) |
| w.Header().Set("Transfer-Encoding", "chunked") |
| w.Header().Set("Cache-Control", "no-cache") |
| w.WriteHeader(resp.StatusCode) |
|
|
| flusher, ok := w.(http.Flusher) |
| buf := make([]byte, 4096) |
| for { |
| n, readErr := resp.Body.Read(buf) |
| if n > 0 { |
| w.Write(buf[:n]) |
| if ok { |
| flusher.Flush() |
| } |
| } |
| if readErr == io.EOF { |
| break |
| } |
| if readErr != nil { |
| break |
| } |
| } |
| } |
|
|
| func healthHandler(w http.ResponseWriter, r *http.Request) { |
| w.Header().Set("Content-Type", "application/json") |
| fmt.Fprintf(w, `{"status":"ok","model":"%s"}`, defaultModel) |
| } |
|
|
| func main() { |
| port := os.Getenv("PORT") |
| if port == "" { |
| port = "7860" |
| } |
|
|
| mux := http.NewServeMux() |
| mux.HandleFunc("/health", healthHandler) |
| mux.HandleFunc("/api/chat", proxyHandler) |
| mux.HandleFunc("/api/generate", proxyHandler) |
| mux.HandleFunc("/v1/chat/completions", proxyHandler) |
|
|
| protected := authMiddleware(mux) |
|
|
| log.Printf("Ollama proxy starting on port %s", port) |
| log.Printf("Model: %s | stream: always true | bearer: %s", defaultModel, bearerToken) |
|
|
| if err := http.ListenAndServe(":"+port, protected); err != nil { |
| log.Fatalf("Server failed: %v", err) |
| } |
| } |
|
|