gcli2api / server.go
smgc's picture
Update server.go
948c3a6 verified
package server
import (
"context"
"crypto/subtle"
"encoding/json"
"fmt"
"net/http"
"regexp"
"strings"
"time"
"gcli2api/internal/codeassist"
"gcli2api/internal/config"
"gcli2api/internal/gemini"
// "gcli2api/internal/utils"
"github.com/sirupsen/logrus"
"github.com/tiktoken-go/tokenizer"
)
// Route patterns for the two model-scoped endpoints served below
// /v1beta/models/. Both patterns are anchored at start and end and capture
// the model name (one or more non-slash characters) as submatch 1.
var (
// modelPathUnary matches /v1beta/models/{model}:generateContent.
modelPathUnary = regexp.MustCompile(`^/v1beta/models/([^/]+):generateContent$`)
// modelPathStream matches /v1beta/models/{model}:streamGenerateContent.
modelPathStream = regexp.MustCompile(`^/v1beta/models/([^/]+):streamGenerateContent$`)
)
// CodeAssist abstracts the upstream client so the HTTP handlers can be
// exercised with a test double instead of the real client.
type CodeAssist interface {
// GenerateContent performs a single request/response call for model.
// The handlers in this file always pass an empty project string.
GenerateContent(ctx context.Context, model, project string, req gemini.GeminiRequest) (*gemini.GeminiAPIResponse, error)
// GenerateContentStream starts a streaming call and returns a channel of
// response chunks and a channel of errors. The streaming handler in this
// file treats closure of the chunk channel as end of stream, and treats a
// closed error channel or a nil error as "no error".
GenerateContentStream(ctx context.Context, model, project string, req gemini.GeminiRequest) (<-chan gemini.GeminiAPIResponse, <-chan error)
}
// Server bundles the configuration and collaborators shared by all HTTP
// handlers in this file.
type Server struct {
// cfg is the configuration with zero-valued tunables replaced by defaults
// in the constructors.
cfg config.Config
// httpCli is the HTTP client handed to New; NewWithCAClient leaves it nil.
httpCli *http.Client
// caClient is the upstream client used by the generate handlers.
caClient CodeAssist
// sem is a simple semaphore for concurrency limiting; it is created with
// capacity cfg.MaxConcurrentRequests.
sem chan struct{}
}
// applyConfigDefaults returns a copy of cfg in which unset (zero) tunables are
// replaced by the server's built-in defaults:
//
//   - RequestMaxRetries:      3
//   - RequestBaseDelayMillis: 1000
//   - RequestMaxBodyBytes:    16 MiB
//   - MaxConcurrentRequests:  64
//
// The defaults are intended to mirror config.LoadConfig (not visible from this
// file — keep the two in sync). A non-positive MaxConcurrentRequests is also
// treated as unset, because it is used as a channel capacity and a negative
// capacity would panic in make.
func applyConfigDefaults(cfg config.Config) config.Config {
	if cfg.RequestMaxRetries == 0 {
		cfg.RequestMaxRetries = 3
	}
	if cfg.RequestBaseDelayMillis == 0 {
		cfg.RequestBaseDelayMillis = 1000
	}
	if cfg.RequestMaxBodyBytes == 0 {
		cfg.RequestMaxBodyBytes = 16 * 1024 * 1024
	}
	if cfg.MaxConcurrentRequests <= 0 {
		cfg.MaxConcurrentRequests = 64
	}
	return cfg
}

// New builds a Server that talks to the real upstream through httpCli.
//
// Zero-valued tunables in cfg are replaced by defaults (see
// applyConfigDefaults) before the upstream client and the concurrency
// semaphore are created.
func New(cfg config.Config, httpCli *http.Client) *Server {
	cfg = applyConfigDefaults(cfg)
	baseDelay := time.Duration(cfg.RequestBaseDelayMillis) * time.Millisecond
	return &Server{
		cfg:      cfg,
		httpCli:  httpCli,
		caClient: codeassist.NewCaClient(httpCli, cfg.RequestMaxRetries, baseDelay),
		sem:      make(chan struct{}, cfg.MaxConcurrentRequests),
	}
}

// NewWithCAClient allows injecting a custom CodeAssist client (for tests).
//
// It applies the same defaults as New so handlers work with a zero config.
// The returned Server has no HTTP client of its own.
func NewWithCAClient(cfg config.Config, ca CodeAssist) *Server {
	cfg = applyConfigDefaults(cfg)
	return &Server{
		cfg:      cfg,
		caClient: ca,
		sem:      make(chan struct{}, cfg.MaxConcurrentRequests),
	}
}
// Router assembles the HTTP handler tree for the server.
//
// Routes:
//   - /health          -> handleHealth
//   - /v1beta/models   -> handleListModels
//   - /v1beta/models/  -> handleModel (everything below the prefix)
//
// The mux is wrapped so that a request passes through recover first
// (outermost), then logging, then the concurrency limiter, and finally the
// route handler.
func (s *Server) Router() http.Handler {
	routes := http.NewServeMux()
	routes.HandleFunc("/health", s.handleHealth)
	routes.HandleFunc("/v1beta/models", s.handleListModels)
	routes.HandleFunc("/v1beta/models/", s.handleModel)

	// Wrap from the inside out; the last wrapper applied runs first.
	limited := s.withConcurrencyLimit(routes)
	logged := s.withLogging(limited)
	return s.withRecover(logged)
}
// handleHealth replies 200 with the JSON body {"status":"ok"}.
// The handler itself performs no authorization or method check.
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
	body := map[string]string{"status": "ok"}

	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)

	// Best effort: the status line is already sent, so a failed write cannot
	// be reported to the client.
	_ = json.NewEncoder(w).Encode(body)
}
// authorize reports whether r carries the configured API key.
//
// When cfg.AuthKey is empty every request is allowed. Otherwise the key is
// accepted from any of the following, checked in this order:
//
//  1. the Authorization header using the Bearer scheme (the scheme name is
//     matched case-insensitively, as HTTP authentication schemes are
//     case-insensitive; surrounding whitespace on the token is trimmed),
//  2. the x-goog-api-key header,
//  3. the "key" URL query parameter.
//
// Candidates are compared with subtle.ConstantTimeCompare to mitigate timing
// attacks on the key's content.
func (s *Server) authorize(r *http.Request) bool {
	key := s.cfg.AuthKey
	if key == "" {
		return true
	}

	want := []byte(key)
	matches := func(candidate string) bool {
		// An empty candidate can never equal the (non-empty) key.
		return candidate != "" && subtle.ConstantTimeCompare([]byte(candidate), want) == 1
	}

	const scheme = "Bearer "
	if ah := r.Header.Get("Authorization"); len(ah) >= len(scheme) && strings.EqualFold(ah[:len(scheme)], scheme) {
		if matches(strings.TrimSpace(ah[len(scheme):])) {
			return true
		}
	}
	if matches(r.Header.Get("x-goog-api-key")) {
		return true
	}
	// Also accept the key from the URL query string (?key=...).
	return matches(r.URL.Query().Get("key"))
}
// handleListModels serves GET /v1beta/models with the JSON document produced
// by listModels.
//
// The method is checked before authorization: any method other than GET gets
// 405 with an Allow header, and an unauthorized GET gets 401.
func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		// A 405 response must advertise the methods the resource supports.
		w.Header().Set("Allow", http.MethodGet)
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	if !s.authorize(r) {
		http.Error(w, "unauthorized", http.StatusUnauthorized)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	// Best effort: once the body is being written, a failure (typically a
	// disconnected client) cannot be turned into a different response.
	_ = json.NewEncoder(w).Encode(listModels())
}
// handleModel dispatches requests below /v1beta/models/ to the unary or
// streaming generate handler, based on the action suffix in the URL path.
//
// Checks run in this order: authorization (401), method (405 with an Allow
// header; only POST is supported), then path matching. A path that matches
// neither modelPathUnary nor modelPathStream yields 404.
func (s *Server) handleModel(w http.ResponseWriter, r *http.Request) {
	if !s.authorize(r) {
		http.Error(w, "unauthorized", http.StatusUnauthorized)
		return
	}
	if r.Method != http.MethodPost {
		// A 405 response must advertise the methods the resource supports.
		w.Header().Set("Allow", http.MethodPost)
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	path := r.URL.Path
	// Submatch 1 of either pattern is the model name.
	if m := modelPathUnary.FindStringSubmatch(path); m != nil {
		s.handleGenerateContent(m[1], w, r)
		return
	}
	if m := modelPathStream.FindStringSubmatch(path); m != nil {
		s.handleStreamGenerateContent(m[1], w, r)
		return
	}
	http.NotFound(w, r)
}
// validateModel reports whether model is accepted by gemini.IsSupportedModel.
func (s *Server) validateModel(model string) bool {
	supported := gemini.IsSupportedModel(model)
	return supported
}
// decodeGeminiRequest reads one JSON value from r.Body into a
// gemini.GeminiRequest and passes it through gemini.NormalizeGeminiRequest.
//
// On a decode failure it returns the zero GeminiRequest together with the
// decoder's error, so callers never observe a half-populated request. This
// function does not limit the body size itself; that is left to the caller.
func (s *Server) decodeGeminiRequest(r *http.Request) (gemini.GeminiRequest, error) {
	var req gemini.GeminiRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		return gemini.GeminiRequest{}, err
	}
	return gemini.NormalizeGeminiRequest(req), nil
}
// handleGenerateContent serves the unary :generateContent endpoint for model.
//
// Flow: reject unsupported models (400), decode the size-capped JSON body
// (400 on failure), log the model, thinking config and approximate token
// count, then call upstream with a five-minute deadline derived from the
// request context. An upstream error is returned as plain text with the
// status chosen by httpStatusFromError; a successful response is written as
// JSON.
func (s *Server) handleGenerateContent(model string, w http.ResponseWriter, r *http.Request) {
	if supported := s.validateModel(model); !supported {
		http.Error(w, "unknown model", http.StatusBadRequest)
		return
	}

	// Cap the body before decoding so an oversized payload fails during the read.
	r.Body = http.MaxBytesReader(w, r.Body, s.cfg.RequestMaxBodyBytes)
	geminiReq, decodeErr := s.decodeGeminiRequest(r)
	if decodeErr != nil {
		http.Error(w, fmt.Sprintf("bad request: %v", decodeErr), http.StatusBadRequest)
		return
	}

	// Enriched logging: the thinking config stays nil when the request
	// carries no generation config.
	var thinkingCfg any
	if gc := geminiReq.GenerationConfig; gc != nil {
		thinkingCfg = gc.ThinkingConfig
	}
	logEntry := logrus.WithFields(logrus.Fields{
		"model":          model,
		"thinkingConfig": thinkingCfg,
		"totalTokens":    countRequestTokens(geminiReq),
	})
	logEntry.Info("sending to upstream")

	const upstreamTimeout = 5 * time.Minute
	ctx, cancel := context.WithTimeout(r.Context(), upstreamTimeout)
	defer cancel()

	upstreamResp, callErr := s.caClient.GenerateContent(ctx, model, "", geminiReq)
	if callErr != nil {
		http.Error(w, callErr.Error(), httpStatusFromError(callErr))
		return
	}

	w.Header().Set("Content-Type", "application/json")
	// Best effort: the response has started, so a write failure cannot be
	// reported to the client.
	_ = json.NewEncoder(w).Encode(upstreamResp)
}
// handleStreamGenerateContent serves the :streamGenerateContent endpoint for
// model as a Server-Sent Events stream.
//
// Before streaming starts it rejects unsupported models (400), undecodable
// bodies (400) and response writers that cannot flush (500). After that the
// SSE headers are set and every chunk received from upstream is written as
//
//	data: <json>\n\n
//
// and flushed immediately. A non-nil upstream error is reported in-band as
//
//	event: error
//	data: {"error":{"message":"..."}}
//
// and ends the stream. The error payload is produced with json.Marshal so the
// message is always valid JSON regardless of the characters it contains. The
// stream also ends when the chunk channel is closed, when the request context
// is cancelled, or when a write to the client fails.
func (s *Server) handleStreamGenerateContent(model string, w http.ResponseWriter, r *http.Request) {
	if !s.validateModel(model) {
		http.Error(w, "unknown model", http.StatusBadRequest)
		return
	}

	// Limit request body size
	r.Body = http.MaxBytesReader(w, r.Body, s.cfg.RequestMaxBodyBytes)
	req, err := s.decodeGeminiRequest(r)
	if err != nil {
		http.Error(w, fmt.Sprintf("bad request: %v", err), http.StatusBadRequest)
		return
	}
	// logrus.Infof("decoded request %s", utils.TruncateLongStringInObject(req, 100))

	flusher, ok := w.(http.Flusher)
	if !ok {
		logrus.Warn("streaming unsupported")
		http.Error(w, "streaming unsupported", http.StatusInternalServerError)
		return
	}

	// SSE headers
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no")

	// Cancelling on return tells the upstream call to stop once the handler
	// is done, whatever the reason.
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()
	out, errs := s.caClient.GenerateContentStream(ctx, model, "", req)

	// Enriched logging: model, thinking config, and total tokens
	var thinking any
	if req.GenerationConfig != nil {
		thinking = req.GenerationConfig.ThinkingConfig
	}
	totalTokens := countRequestTokens(req)
	logrus.WithFields(logrus.Fields{
		"model":          model,
		"thinkingConfig": thinking,
		"totalTokens":    totalTokens,
	}).Info("sending to upstream")

	enc := json.NewEncoder(w)
	for {
		select {
		case g, ok := <-out:
			if !ok {
				return
			}
			// SSE event: the raw upstream response as the data field.
			if _, err := fmt.Fprint(w, "data: "); err != nil {
				logrus.Errorf("error writing data prefix: %v", err)
				return
			}
			if err := enc.Encode(g); err != nil {
				logrus.Errorf("error encoding stream chunk: %v", err)
				return
			}
			// enc.Encode already wrote one newline; the second one
			// terminates the SSE event.
			if _, err := fmt.Fprint(w, "\n"); err != nil {
				logrus.Errorf("error writing newline: %v", err)
				return
			}
			flusher.Flush()

		case e, ok := <-errs:
			// A closed error channel or a nil error is a normal
			// end-of-stream signal; keep draining the output channel
			// until it closes.
			if !ok || e == nil {
				// A nil channel is never ready, which avoids busy
				// looping on a closed channel.
				errs = nil
				continue
			}
			// Non-nil error: emit an error event, then end the stream.
			payload, mErr := json.Marshal(map[string]map[string]string{
				"error": {"message": e.Error()},
			})
			if mErr != nil {
				logrus.Errorf("error encoding error event: %v", mErr)
				return
			}
			if _, err := fmt.Fprintf(w, "event: error\ndata: %s\n\n", payload); err != nil {
				logrus.Errorf("error writing error event: %v", err)
				return
			}
			flusher.Flush()
			return

		case <-ctx.Done():
			return
		}
	}
}
// countRequestTokens approximates the request's token count by summing the
// tokens of every non-empty text part in req.Contents, using the O200kBase
// encoding from tiktoken-go/tokenizer.
//
// The system instruction is not included in the count. The function returns 0
// when the encoding cannot be obtained, and parts whose text fails to tokenize
// are skipped.
func countRequestTokens(req gemini.GeminiRequest) int {
	codec, err := tokenizer.Get(tokenizer.O200kBase)
	if err != nil {
		return 0
	}

	var sum int
	for _, content := range req.Contents {
		for _, part := range content.Parts {
			if part.Text == "" {
				continue
			}
			n, countErr := codec.Count(part.Text)
			if countErr != nil {
				continue
			}
			sum += n
		}
	}
	return sum
}
// httpStatusFromError picks the HTTP status to return to the client for an
// upstream failure by inspecting the error text.
//
// Mapping, first match wins:
//
//	contains "status 401"                 -> 401 Unauthorized
//	contains "status 403"                 -> 403 Forbidden
//	contains "status 429"                 -> 429 Too Many Requests
//	contains "status 5"                   -> 502 Bad Gateway
//	contains the context deadline message -> 504 Gateway Timeout
//	anything else                         -> 400 Bad Request
//
// The deadline case is checked after the status cases so an error that names
// an upstream status keeps that mapping. It is matched on the message text,
// like the other cases, so it also works when the error was wrapped without
// %w. err must be non-nil.
func httpStatusFromError(err error) int {
	// Simple mapping; upstream errors already include status text sometimes.
	msg := err.Error()
	switch {
	case strings.Contains(msg, "status 401"):
		return http.StatusUnauthorized
	case strings.Contains(msg, "status 403"):
		return http.StatusForbidden
	case strings.Contains(msg, "status 429"):
		return http.StatusTooManyRequests
	case strings.Contains(msg, "status 5"):
		return http.StatusBadGateway
	case strings.Contains(msg, context.DeadlineExceeded.Error()):
		// The upstream call ran out of time; that is not the client's fault.
		return http.StatusGatewayTimeout
	default:
		return http.StatusBadRequest
	}
}
// listModels builds the response document for the model listing endpoint.
//
// The result is a struct that encodes as {"models":[...]} with one entry per
// element of gemini.SupportedModels. Each entry's name is prefixed with
// "models/", its version is fixed to "001", and it advertises both
// generateContent and streamGenerateContent.
func listModels() interface{} {
	type model struct {
		Name                       string   `json:"name"`
		Version                    string   `json:"version"`
		DisplayName                string   `json:"displayName"`
		Description                string   `json:"description"`
		SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
	}

	entries := make([]model, 0, len(gemini.SupportedModels))
	for _, m := range gemini.SupportedModels {
		var entry model
		entry.Name = "models/" + m.Name
		entry.Version = "001"
		entry.DisplayName = m.DisplayName
		entry.Description = m.Description
		// Each entry gets its own slice so entries never share backing storage.
		entry.SupportedGenerationMethods = []string{"generateContent", "streamGenerateContent"}
		entries = append(entries, entry)
	}

	return struct {
		Models []model `json:"models"`
	}{Models: entries}
}