smgc committed on
Commit
ca586bf
·
verified ·
1 Parent(s): f946611

Create server.go

Browse files
Files changed (1) hide show
  1. server.go +367 -0
server.go ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package server
2
+
3
+ import (
4
+ "context"
5
+ "crypto/subtle"
6
+ "encoding/json"
7
+ "fmt"
8
+ "net/http"
9
+ "regexp"
10
+ "strings"
11
+ "time"
12
+
13
+ "gcli2api/internal/codeassist"
14
+ "gcli2api/internal/config"
15
+ "gcli2api/internal/gemini"
16
+
17
+ // "gcli2api/internal/utils"
18
+
19
+ "github.com/sirupsen/logrus"
20
+ "github.com/tiktoken-go/tokenizer"
21
+ )
22
+
23
// Route patterns for the Gemini-compatible model endpoints. The single
// capture group extracts the model name from the URL path (everything
// between "/v1beta/models/" and the ":generateContent" /
// ":streamGenerateContent" action suffix).
var (
	modelPathUnary  = regexp.MustCompile(`^/v1beta/models/([^/]+):generateContent$`)
	modelPathStream = regexp.MustCompile(`^/v1beta/models/([^/]+):streamGenerateContent$`)
)
27
+
28
// CodeAssist abstracts the client for easier testing.
type CodeAssist interface {
	// GenerateContent performs a single (unary) generation request and
	// returns the complete upstream response.
	GenerateContent(ctx context.Context, model, project string, req gemini.GeminiRequest) (*gemini.GeminiAPIResponse, error)
	// GenerateContentStream starts a streaming generation request.
	// Response chunks arrive on the first channel; errors on the second.
	// NOTE(review): channel close/termination semantics are assumed to be
	// "output channel closes at end of stream" — confirm in the implementation.
	GenerateContentStream(ctx context.Context, model, project string, req gemini.GeminiRequest) (<-chan gemini.GeminiAPIResponse, <-chan error)
}
33
+
34
// Server is the HTTP front end that translates Gemini-style REST calls
// into Code Assist client calls.
type Server struct {
	cfg     config.Config // effective configuration (defaults applied by the constructors)
	httpCli *http.Client  // HTTP client handed to the real Code Assist client; nil when injected via NewWithCAClient
	caClient CodeAssist   // upstream client; an interface so tests can substitute a fake
	// sem is a simple semaphore for concurrency limiting
	sem chan struct{}
}
41
+
42
+ func New(cfg config.Config, httpCli *http.Client) *Server {
43
+ // Apply safe defaults when fields are zero to match config.LoadConfig behavior
44
+ if cfg.RequestMaxRetries == 0 {
45
+ cfg.RequestMaxRetries = 3
46
+ }
47
+ if cfg.RequestBaseDelayMillis == 0 {
48
+ cfg.RequestBaseDelayMillis = 1000
49
+ }
50
+ if cfg.RequestMaxBodyBytes == 0 {
51
+ cfg.RequestMaxBodyBytes = 16 * 1024 * 1024
52
+ }
53
+ if cfg.MaxConcurrentRequests == 0 {
54
+ cfg.MaxConcurrentRequests = 64
55
+ }
56
+ return &Server{
57
+ cfg: cfg,
58
+ httpCli: httpCli,
59
+ caClient: codeassist.NewCaClient(httpCli, cfg.RequestMaxRetries, time.Duration(cfg.RequestBaseDelayMillis)*time.Millisecond),
60
+ sem: make(chan struct{}, cfg.MaxConcurrentRequests),
61
+ }
62
+ }
63
+
64
+ // NewWithCAClient allows injecting a custom CodeAssist client (for tests).
65
+ func NewWithCAClient(cfg config.Config, ca CodeAssist) *Server {
66
+ // Apply same defaults as New to ensure handlers work in tests with zero config
67
+ if cfg.RequestMaxRetries == 0 {
68
+ cfg.RequestMaxRetries = 3
69
+ }
70
+ if cfg.RequestBaseDelayMillis == 0 {
71
+ cfg.RequestBaseDelayMillis = 1000
72
+ }
73
+ if cfg.RequestMaxBodyBytes == 0 {
74
+ cfg.RequestMaxBodyBytes = 16 * 1024 * 1024
75
+ }
76
+ if cfg.MaxConcurrentRequests == 0 {
77
+ cfg.MaxConcurrentRequests = 64
78
+ }
79
+ return &Server{cfg: cfg, caClient: ca, sem: make(chan struct{}, cfg.MaxConcurrentRequests)}
80
+ }
81
+
82
+ func (s *Server) Router() http.Handler {
83
+ mux := http.NewServeMux()
84
+ mux.HandleFunc("/health", s.handleHealth)
85
+ mux.HandleFunc("/v1beta/models", s.handleListModels)
86
+ mux.HandleFunc("/v1beta/models/", s.handleModel)
87
+ // Order: recover (outermost) -> logging -> concurrency limiter -> handlers
88
+ return s.withRecover(s.withLogging(s.withConcurrencyLimit(mux)))
89
+ }
90
+
91
+ func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
92
+ w.Header().Set("Content-Type", "application/json")
93
+ w.WriteHeader(http.StatusOK)
94
+ _ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
95
+ }
96
+
97
+ func (s *Server) authorize(r *http.Request) bool {
98
+ key := s.cfg.AuthKey
99
+ if key == "" {
100
+ return true
101
+ }
102
+
103
+ // 检查 Authorization Header
104
+ if ah := r.Header.Get("Authorization"); ah != "" {
105
+ const p = "Bearer "
106
+ if strings.HasPrefix(ah, p) {
107
+ // Constant-time comparison to mitigate timing attacks
108
+ if 1 == subtle.ConstantTimeCompare([]byte(strings.TrimSpace(ah[len(p):])), []byte(key)) {
109
+ return true
110
+ }
111
+ }
112
+ }
113
+
114
+ // 检查 x-goog-api-key Header
115
+ if h := r.Header.Get("x-goog-api-key"); h != "" {
116
+ if 1 == subtle.ConstantTimeCompare([]byte(h), []byte(key)) {
117
+ return true
118
+ }
119
+ }
120
+
121
+ // ====== 新增改动开始 ======
122
+ // 检查 URL 查询参数中的 key
123
+ if qk := r.URL.Query().Get("key"); qk != "" {
124
+ if 1 == subtle.ConstantTimeCompare([]byte(qk), []byte(key)) {
125
+ return true
126
+ }
127
+ }
128
+ // ====== 新增改动结束 ======
129
+
130
+ return false
131
+ }
132
+
133
+
134
+ func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
135
+ if r.Method != http.MethodGet {
136
+ http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
137
+ return
138
+ }
139
+ if !s.authorize(r) {
140
+ http.Error(w, "unauthorized", http.StatusUnauthorized)
141
+ return
142
+ }
143
+ w.Header().Set("Content-Type", "application/json")
144
+ _ = json.NewEncoder(w).Encode(listModels())
145
+ }
146
+
147
+ func (s *Server) handleModel(w http.ResponseWriter, r *http.Request) {
148
+ if !s.authorize(r) {
149
+ http.Error(w, "unauthorized", http.StatusUnauthorized)
150
+ return
151
+ }
152
+ if r.Method != http.MethodPost {
153
+ http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
154
+ return
155
+ }
156
+ path := r.URL.Path
157
+ if m := modelPathUnary.FindStringSubmatch(path); m != nil {
158
+ model := m[1]
159
+ s.handleGenerateContent(model, w, r)
160
+ return
161
+ }
162
+ if m := modelPathStream.FindStringSubmatch(path); m != nil {
163
+ model := m[1]
164
+ s.handleStreamGenerateContent(model, w, r)
165
+ return
166
+ }
167
+ http.NotFound(w, r)
168
+ }
169
+
170
// validateModel reports whether model names one of the models this proxy
// serves, as defined by the gemini package's supported-model list.
func (s *Server) validateModel(model string) bool {
	return gemini.IsSupportedModel(model)
}
173
+
174
+ func (s *Server) decodeGeminiRequest(r *http.Request) (gemini.GeminiRequest, error) {
175
+ var req gemini.GeminiRequest
176
+ dec := json.NewDecoder(r.Body)
177
+ if err := dec.Decode(&req); err != nil {
178
+ return req, err
179
+ }
180
+ req = gemini.NormalizeGeminiRequest(req)
181
+ return req, nil
182
+ }
183
+
184
+ func (s *Server) handleGenerateContent(model string, w http.ResponseWriter, r *http.Request) {
185
+ if !s.validateModel(model) {
186
+ http.Error(w, "unknown model", http.StatusBadRequest)
187
+ return
188
+ }
189
+ // Limit request body size
190
+ r.Body = http.MaxBytesReader(w, r.Body, s.cfg.RequestMaxBodyBytes)
191
+ req, err := s.decodeGeminiRequest(r)
192
+ if err != nil {
193
+ http.Error(w, fmt.Sprintf("bad request: %v", err), http.StatusBadRequest)
194
+ return
195
+ }
196
+ // Enriched logging: model, thinking config, and total tokens
197
+ var thinking any
198
+ if req.GenerationConfig != nil {
199
+ thinking = req.GenerationConfig.ThinkingConfig
200
+ }
201
+ totalTokens := countRequestTokens(req)
202
+ logrus.WithFields(logrus.Fields{
203
+ "model": model,
204
+ "thinkingConfig": thinking,
205
+ "totalTokens": totalTokens,
206
+ }).Info("sending to upstream")
207
+ ctx, cancel := context.WithTimeout(r.Context(), 5*time.Minute)
208
+ defer cancel()
209
+ resp, err := s.caClient.GenerateContent(ctx, model, "", req)
210
+ if err != nil {
211
+ http.Error(w, err.Error(), httpStatusFromError(err))
212
+ return
213
+ }
214
+ w.Header().Set("Content-Type", "application/json")
215
+ _ = json.NewEncoder(w).Encode(resp)
216
+ }
217
+
218
// handleStreamGenerateContent serves :streamGenerateContent as a
// Server-Sent Events stream: each upstream chunk is emitted as a
// "data: <json>" event, and a terminal upstream error becomes an
// "event: error" frame before the stream ends.
func (s *Server) handleStreamGenerateContent(model string, w http.ResponseWriter, r *http.Request) {
	if !s.validateModel(model) {
		http.Error(w, "unknown model", http.StatusBadRequest)
		return
	}
	// Limit request body size
	r.Body = http.MaxBytesReader(w, r.Body, s.cfg.RequestMaxBodyBytes)
	req, err := s.decodeGeminiRequest(r)
	if err != nil {
		http.Error(w, fmt.Sprintf("bad request: %v", err), http.StatusBadRequest)
		return
	}
	// logrus.Infof("decoded request %s", utils.TruncateLongStringInObject(req, 100))
	// SSE requires incremental writes; without Flusher support we cannot stream.
	flusher, ok := w.(http.Flusher)
	if !ok {
		logrus.Warn("streaming unsupported")
		http.Error(w, "streaming unsupported", http.StatusInternalServerError)
		return
	}
	// SSE headers. X-Accel-Buffering disables proxy buffering (nginx).
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no")

	// Derive a cancellable context so the upstream stream is torn down
	// when this handler returns for any reason.
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()
	out, errs := s.caClient.GenerateContentStream(ctx, model, "", req)

	// Prepare enriched logging: model, thinking config, and total tokens
	var thinking any
	if req.GenerationConfig != nil {
		thinking = req.GenerationConfig.ThinkingConfig
	}
	totalTokens := countRequestTokens(req)
	logrus.WithFields(logrus.Fields{
		"model":          model,
		"thinkingConfig": thinking,
		"totalTokens":    totalTokens,
	}).Info("sending to upstream")
	enc := json.NewEncoder(w)
	// Pump loop: forward chunks until the output channel closes, an
	// upstream error arrives, or the request context is cancelled.
	for {
		select {
		case g, ok := <-out:
			if !ok {
				// Upstream closed the output channel: end of stream.
				return
			}
			// SSE event - send raw response like TypeScript version
			if _, err := fmt.Fprint(w, "data: "); err != nil {
				logrus.Errorf("error writing data prefix: %v", err)
				return
			}
			if err := enc.Encode(g); err != nil {
				return
			}
			// enc.Encode writes a trailing newline; this second newline
			// completes the blank-line event terminator SSE requires.
			if _, err := fmt.Fprint(w, "\n"); err != nil {
				logrus.Errorf("error writing newline: %v", err)
				return
			}
			flusher.Flush()
		case e, ok := <-errs:
			// If the error channel is closed or yields a nil error,
			// treat it as a normal end-of-stream signal but continue
			// draining the output channel until it closes.
			if !ok || e == nil {
				// Disable further selects on errs to avoid busy looping on a closed channel
				// (a nil channel blocks forever in select).
				errs = nil
				continue
			}
			// Non-nil error: emit error event then end
			if _, err := fmt.Fprint(w, "event: error\n"); err != nil {
				logrus.Errorf("error writing error event: %v", err)
				return
			}
			if _, err := fmt.Fprintf(w, "data: {\"error\":{\"message\":%q}}\n\n", e.Error()); err != nil {
				logrus.Errorf("error writing error data: %v", err)
				return
			}
			flusher.Flush()
			return
		case <-ctx.Done():
			// Client disconnected or handler timed out upstream.
			return
		}
	}
}
304
+
305
+ // countRequestTokens approximates the total token count for the request
306
+ // by summing tokens of all text parts in systemInstruction and contents
307
+ // using tiktoken-go/tokenizer. We default to O200kBase encoding.
308
+ func countRequestTokens(req gemini.GeminiRequest) int {
309
+ enc, err := tokenizer.Get(tokenizer.O200kBase)
310
+ if err != nil {
311
+ return 0
312
+ }
313
+ total := 0
314
+ // system instruction ignored for token counting (feature removed)
315
+ // contents
316
+ for _, c := range req.Contents {
317
+ for _, p := range c.Parts {
318
+ if p.Text != "" {
319
+ if n, err := enc.Count(p.Text); err == nil {
320
+ total += n
321
+ }
322
+ }
323
+ }
324
+ }
325
+ return total
326
+ }
327
+
328
+ func httpStatusFromError(err error) int {
329
+ // Simple mapping; upstream errors already include status text sometimes.
330
+ s := err.Error()
331
+ if strings.Contains(s, "status 401") {
332
+ return http.StatusUnauthorized
333
+ }
334
+ if strings.Contains(s, "status 403") {
335
+ return http.StatusForbidden
336
+ }
337
+ if strings.Contains(s, "status 429") {
338
+ return http.StatusTooManyRequests
339
+ }
340
+ if strings.Contains(s, "status 5") {
341
+ return http.StatusBadGateway
342
+ }
343
+ return http.StatusBadRequest
344
+ }
345
+
346
+ func listModels() interface{} {
347
+ type model struct {
348
+ Name string `json:"name"`
349
+ Version string `json:"version"`
350
+ DisplayName string `json:"displayName"`
351
+ Description string `json:"description"`
352
+ SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
353
+ }
354
+ out := struct {
355
+ Models []model `json:"models"`
356
+ }{Models: make([]model, 0, len(gemini.SupportedModels))}
357
+ for _, m := range gemini.SupportedModels {
358
+ out.Models = append(out.Models, model{
359
+ Name: "models/" + m.Name,
360
+ Version: "001",
361
+ DisplayName: m.DisplayName,
362
+ Description: m.Description,
363
+ SupportedGenerationMethods: []string{"generateContent", "streamGenerateContent"},
364
+ })
365
+ }
366
+ return out
367
+ }