|
|
package server |
|
|
|
|
|
import (
	"context"
	"crypto/subtle"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"regexp"
	"strings"
	"time"

	"gcli2api/internal/codeassist"
	"gcli2api/internal/config"
	"gcli2api/internal/gemini"

	"github.com/sirupsen/logrus"
	"github.com/tiktoken-go/tokenizer"
)
|
|
|
|
|
// Route patterns for the Gemini v1beta REST surface. Capture group 1 is the
// model name segment between "models/" and the ":method" suffix.
var (
	modelPathUnary  = regexp.MustCompile(`^/v1beta/models/([^/]+):generateContent$`)
	modelPathStream = regexp.MustCompile(`^/v1beta/models/([^/]+):streamGenerateContent$`)
)
|
|
|
|
|
|
|
|
// CodeAssist abstracts the upstream Code Assist generation client so the
// server can be constructed with a fake implementation (see NewWithCAClient).
type CodeAssist interface {
	// GenerateContent performs a single (unary) generation request.
	GenerateContent(ctx context.Context, model, project string, req gemini.GeminiRequest) (*gemini.GeminiAPIResponse, error)
	// GenerateContentStream starts a streaming generation; response chunks
	// arrive on the first channel and errors on the second.
	GenerateContentStream(ctx context.Context, model, project string, req gemini.GeminiRequest) (<-chan gemini.GeminiAPIResponse, <-chan error)
}
|
|
|
|
|
// Server is the HTTP front end that validates, authorizes, and proxies
// Gemini-style requests to the upstream CodeAssist client.
type Server struct {
	cfg      config.Config // effective configuration (constructors fill in zero-value defaults)
	httpCli  *http.Client  // underlying HTTP client; nil when a custom caClient is injected via NewWithCAClient
	caClient CodeAssist    // upstream generation client

	// sem bounds concurrent request handling; capacity is
	// cfg.MaxConcurrentRequests (used by the concurrency-limit middleware).
	sem chan struct{}
}
|
|
|
|
|
func New(cfg config.Config, httpCli *http.Client) *Server { |
|
|
|
|
|
if cfg.RequestMaxRetries == 0 { |
|
|
cfg.RequestMaxRetries = 3 |
|
|
} |
|
|
if cfg.RequestBaseDelayMillis == 0 { |
|
|
cfg.RequestBaseDelayMillis = 1000 |
|
|
} |
|
|
if cfg.RequestMaxBodyBytes == 0 { |
|
|
cfg.RequestMaxBodyBytes = 16 * 1024 * 1024 |
|
|
} |
|
|
if cfg.MaxConcurrentRequests == 0 { |
|
|
cfg.MaxConcurrentRequests = 64 |
|
|
} |
|
|
return &Server{ |
|
|
cfg: cfg, |
|
|
httpCli: httpCli, |
|
|
caClient: codeassist.NewCaClient(httpCli, cfg.RequestMaxRetries, time.Duration(cfg.RequestBaseDelayMillis)*time.Millisecond), |
|
|
sem: make(chan struct{}, cfg.MaxConcurrentRequests), |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
func NewWithCAClient(cfg config.Config, ca CodeAssist) *Server { |
|
|
|
|
|
if cfg.RequestMaxRetries == 0 { |
|
|
cfg.RequestMaxRetries = 3 |
|
|
} |
|
|
if cfg.RequestBaseDelayMillis == 0 { |
|
|
cfg.RequestBaseDelayMillis = 1000 |
|
|
} |
|
|
if cfg.RequestMaxBodyBytes == 0 { |
|
|
cfg.RequestMaxBodyBytes = 16 * 1024 * 1024 |
|
|
} |
|
|
if cfg.MaxConcurrentRequests == 0 { |
|
|
cfg.MaxConcurrentRequests = 64 |
|
|
} |
|
|
return &Server{cfg: cfg, caClient: ca, sem: make(chan struct{}, cfg.MaxConcurrentRequests)} |
|
|
} |
|
|
|
|
|
func (s *Server) Router() http.Handler { |
|
|
mux := http.NewServeMux() |
|
|
mux.HandleFunc("/health", s.handleHealth) |
|
|
mux.HandleFunc("/v1beta/models", s.handleListModels) |
|
|
mux.HandleFunc("/v1beta/models/", s.handleModel) |
|
|
|
|
|
return s.withRecover(s.withLogging(s.withConcurrencyLimit(mux))) |
|
|
} |
|
|
|
|
|
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { |
|
|
w.Header().Set("Content-Type", "application/json") |
|
|
w.WriteHeader(http.StatusOK) |
|
|
_ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) |
|
|
} |
|
|
|
|
|
func (s *Server) authorize(r *http.Request) bool { |
|
|
key := s.cfg.AuthKey |
|
|
if key == "" { |
|
|
return true |
|
|
} |
|
|
if ah := r.Header.Get("Authorization"); ah != "" { |
|
|
const p = "Bearer " |
|
|
if strings.HasPrefix(ah, p) { |
|
|
|
|
|
if 1 == subtle.ConstantTimeCompare([]byte(strings.TrimSpace(ah[len(p):])), []byte(key)) { |
|
|
return true |
|
|
} |
|
|
} |
|
|
} |
|
|
if h := r.Header.Get("x-goog-api-key"); h != "" { |
|
|
if 1 == subtle.ConstantTimeCompare([]byte(h), []byte(key)) { |
|
|
return true |
|
|
} |
|
|
} |
|
|
|
|
|
if qk := r.URL.Query().Get("key"); qk != "" { |
|
|
if 1 == subtle.ConstantTimeCompare([]byte(qk), []byte(key)) { |
|
|
return true |
|
|
} |
|
|
} |
|
|
return false |
|
|
} |
|
|
|
|
|
func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) { |
|
|
if r.Method != http.MethodGet { |
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed) |
|
|
return |
|
|
} |
|
|
if !s.authorize(r) { |
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized) |
|
|
return |
|
|
} |
|
|
w.Header().Set("Content-Type", "application/json") |
|
|
_ = json.NewEncoder(w).Encode(listModels()) |
|
|
} |
|
|
|
|
|
func (s *Server) handleModel(w http.ResponseWriter, r *http.Request) { |
|
|
if !s.authorize(r) { |
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized) |
|
|
return |
|
|
} |
|
|
if r.Method != http.MethodPost { |
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed) |
|
|
return |
|
|
} |
|
|
path := r.URL.Path |
|
|
if m := modelPathUnary.FindStringSubmatch(path); m != nil { |
|
|
model := m[1] |
|
|
s.handleGenerateContent(model, w, r) |
|
|
return |
|
|
} |
|
|
if m := modelPathStream.FindStringSubmatch(path); m != nil { |
|
|
model := m[1] |
|
|
s.handleStreamGenerateContent(model, w, r) |
|
|
return |
|
|
} |
|
|
http.NotFound(w, r) |
|
|
} |
|
|
|
|
|
// validateModel reports whether the requested model name is one this proxy
// supports (delegates to the gemini package's model registry).
func (s *Server) validateModel(model string) bool {
	return gemini.IsSupportedModel(model)
}
|
|
|
|
|
func (s *Server) decodeGeminiRequest(r *http.Request) (gemini.GeminiRequest, error) { |
|
|
var req gemini.GeminiRequest |
|
|
dec := json.NewDecoder(r.Body) |
|
|
if err := dec.Decode(&req); err != nil { |
|
|
return req, err |
|
|
} |
|
|
req = gemini.NormalizeGeminiRequest(req) |
|
|
return req, nil |
|
|
} |
|
|
|
|
|
func (s *Server) handleGenerateContent(model string, w http.ResponseWriter, r *http.Request) { |
|
|
if !s.validateModel(model) { |
|
|
http.Error(w, "unknown model", http.StatusBadRequest) |
|
|
return |
|
|
} |
|
|
|
|
|
r.Body = http.MaxBytesReader(w, r.Body, s.cfg.RequestMaxBodyBytes) |
|
|
req, err := s.decodeGeminiRequest(r) |
|
|
if err != nil { |
|
|
http.Error(w, fmt.Sprintf("bad request: %v", err), http.StatusBadRequest) |
|
|
return |
|
|
} |
|
|
|
|
|
var thinking any |
|
|
if req.GenerationConfig != nil { |
|
|
thinking = req.GenerationConfig.ThinkingConfig |
|
|
} |
|
|
totalTokens := countRequestTokens(req) |
|
|
logrus.WithFields(logrus.Fields{ |
|
|
"model": model, |
|
|
"thinkingConfig": thinking, |
|
|
"totalTokens": totalTokens, |
|
|
}).Info("sending to upstream") |
|
|
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Minute) |
|
|
defer cancel() |
|
|
resp, err := s.caClient.GenerateContent(ctx, model, "", req) |
|
|
if err != nil { |
|
|
http.Error(w, err.Error(), httpStatusFromError(err)) |
|
|
return |
|
|
} |
|
|
w.Header().Set("Content-Type", "application/json") |
|
|
_ = json.NewEncoder(w).Encode(resp) |
|
|
} |
|
|
|
|
|
// handleStreamGenerateContent serves
// /v1beta/models/{model}:streamGenerateContent as Server-Sent Events: each
// upstream chunk is relayed as a "data:" event, and a terminal upstream
// error is reported as an "event: error" before the stream is closed.
func (s *Server) handleStreamGenerateContent(model string, w http.ResponseWriter, r *http.Request) {
	if !s.validateModel(model) {
		http.Error(w, "unknown model", http.StatusBadRequest)
		return
	}

	// Cap the request body before decoding.
	r.Body = http.MaxBytesReader(w, r.Body, s.cfg.RequestMaxBodyBytes)
	req, err := s.decodeGeminiRequest(r)
	if err != nil {
		http.Error(w, fmt.Sprintf("bad request: %v", err), http.StatusBadRequest)
		return
	}

	// SSE requires incremental flushing; bail out if the ResponseWriter
	// cannot flush (e.g. wrapped by a non-flushing middleware).
	flusher, ok := w.(http.Flusher)
	if !ok {
		logrus.Warn("streaming unsupported")
		http.Error(w, "streaming unsupported", http.StatusInternalServerError)
		return
	}

	// Standard SSE headers; X-Accel-Buffering disables proxy buffering.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no")

	// Cancel the upstream stream as soon as this handler returns.
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()
	out, errs := s.caClient.GenerateContentStream(ctx, model, "", req)

	var thinking any
	if req.GenerationConfig != nil {
		thinking = req.GenerationConfig.ThinkingConfig
	}
	totalTokens := countRequestTokens(req)
	logrus.WithFields(logrus.Fields{
		"model":          model,
		"thinkingConfig": thinking,
		"totalTokens":    totalTokens,
	}).Info("sending to upstream")
	enc := json.NewEncoder(w)
	for {
		select {
		case g, ok := <-out:
			if !ok {
				// Upstream finished cleanly; end the response.
				return
			}

			if _, err := fmt.Fprint(w, "data: "); err != nil {
				logrus.Errorf("error writing data prefix: %v", err)
				return
			}
			// Encode appends its own trailing newline after the JSON.
			if err := enc.Encode(g); err != nil {
				return
			}

			// Blank line terminates the SSE event.
			if _, err := fmt.Fprint(w, "\n"); err != nil {
				logrus.Errorf("error writing newline: %v", err)
				return
			}
			flusher.Flush()
		case e, ok := <-errs:
			if !ok || e == nil {
				// Error channel closed (or yielded nil): stop selecting on
				// it (a nil channel never fires) and keep draining `out`.
				errs = nil
				continue
			}

			// Report the upstream failure to the client as an SSE error
			// event, then terminate the stream.
			if _, err := fmt.Fprint(w, "event: error\n"); err != nil {
				logrus.Errorf("error writing error event: %v", err)
				return
			}
			if _, err := fmt.Fprintf(w, "data: {\"error\":{\"message\":%q}}\n\n", e.Error()); err != nil {
				logrus.Errorf("error writing error data: %v", err)
				return
			}
			flusher.Flush()
			return
		case <-ctx.Done():
			// Client disconnected or request context canceled.
			return
		}
	}
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func countRequestTokens(req gemini.GeminiRequest) int { |
|
|
enc, err := tokenizer.Get(tokenizer.O200kBase) |
|
|
if err != nil { |
|
|
return 0 |
|
|
} |
|
|
total := 0 |
|
|
|
|
|
|
|
|
for _, c := range req.Contents { |
|
|
for _, p := range c.Parts { |
|
|
if p.Text != "" { |
|
|
if n, err := enc.Count(p.Text); err == nil { |
|
|
total += n |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
return total |
|
|
} |
|
|
|
|
|
func httpStatusFromError(err error) int { |
|
|
|
|
|
s := err.Error() |
|
|
if strings.Contains(s, "status 401") { |
|
|
return http.StatusUnauthorized |
|
|
} |
|
|
if strings.Contains(s, "status 403") { |
|
|
return http.StatusForbidden |
|
|
} |
|
|
if strings.Contains(s, "status 429") { |
|
|
return http.StatusTooManyRequests |
|
|
} |
|
|
if strings.Contains(s, "status 5") { |
|
|
return http.StatusBadGateway |
|
|
} |
|
|
return http.StatusBadRequest |
|
|
} |
|
|
|
|
|
func listModels() interface{} { |
|
|
type model struct { |
|
|
Name string `json:"name"` |
|
|
Version string `json:"version"` |
|
|
DisplayName string `json:"displayName"` |
|
|
Description string `json:"description"` |
|
|
SupportedGenerationMethods []string `json:"supportedGenerationMethods"` |
|
|
} |
|
|
out := struct { |
|
|
Models []model `json:"models"` |
|
|
}{Models: make([]model, 0, len(gemini.SupportedModels))} |
|
|
for _, m := range gemini.SupportedModels { |
|
|
out.Models = append(out.Models, model{ |
|
|
Name: "models/" + m.Name, |
|
|
Version: "001", |
|
|
DisplayName: m.DisplayName, |
|
|
Description: m.Description, |
|
|
SupportedGenerationMethods: []string{"generateContent", "streamGenerateContent"}, |
|
|
}) |
|
|
} |
|
|
return out |
|
|
} |