triflix committed on
Commit
35eb612
·
verified ·
1 Parent(s): a0d98a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +346 -328
app.py CHANGED
@@ -1,43 +1,40 @@
1
- """
2
- Gemini CLI β†’ OpenAI-Compatible API Proxy
3
- Ultra-fast, reliable, with full streaming support.
4
- Deploy on HuggingFace Spaces (Docker SDK, port 7860).
5
- """
 
 
6
 
7
  import os
8
  import json
9
- import time
10
  import asyncio
11
  import logging
12
- from uuid import uuid4
13
- from datetime import datetime, timezone
14
- from contextlib import asynccontextmanager
15
- from typing import AsyncIterator, Any
16
-
17
  import httpx
18
- from fastapi import FastAPI, Request, HTTPException, Depends
 
19
  from fastapi.responses import StreamingResponse, JSONResponse
20
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
21
  from fastapi.middleware.cors import CORSMiddleware
22
- from google.oauth2.credentials import Credentials
23
- from google.auth.transport.requests import Request as GoogleAuthRequest
24
-
25
- # ────────────────────── Logging ──────────────────────
26
- logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
27
- log = logging.getLogger("gemini-proxy")
28
-
29
- # ────────────────────── Config ───────────────────────
30
- AUTH_PASSWORD = os.environ.get("GEMINI_AUTH_PASSWORD", "")
31
- CREDS_JSON = os.environ.get("GEMINI_CREDENTIALS", "{}")
32
- CLIENT_ID = os.environ.get("GEMINI_CLIENT_ID", "")
33
- CLIENT_SECRET = os.environ.get("GEMINI_CLIENT_SECRET", "")
34
- API_BASE = os.environ.get("GEMINI_API_BASE", "https://cloudcode-pa.googleapis.com")
35
-
36
- # ────────────────────── Globals ──────────────────────
37
- _http: httpx.AsyncClient | None = None
38
- _creds: Credentials | None = None
39
- _lock = asyncio.Lock()
40
- _sec = HTTPBearer(auto_error=False)
41
 
42
  MODELS = [
43
  "gemini-2.5-pro",
@@ -51,338 +48,359 @@ MODELS = [
51
  "gemini-2.5-flash-maxthinking",
52
  ]
53
 
54
- # ════════════════════ APP LIFESPAN ═══════════════════
55
-
56
- @asynccontextmanager
57
- async def lifespan(_app: FastAPI):
58
- global _http
59
- _http = httpx.AsyncClient(
60
- timeout=httpx.Timeout(connect=10, read=300, write=30, pool=10),
61
- limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
62
- http2=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  )
64
- log.info("Proxy ready β€” %d models", len(MODELS))
65
- yield
66
- await _http.aclose()
67
-
68
- app = FastAPI(title="Gemini OpenAI Proxy", lifespan=lifespan)
69
- app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
70
-
71
-
72
- # ════════════════════ AUTH ═══════════════════════════
73
-
74
- def _auth(c: HTTPAuthorizationCredentials = Depends(_sec)):
75
- if not AUTH_PASSWORD:
76
- raise HTTPException(500, "Server missing GEMINI_AUTH_PASSWORD")
77
- if not c or c.credentials != AUTH_PASSWORD:
78
- raise HTTPException(401, "Invalid Bearer token")
79
-
 
 
 
 
 
 
 
 
 
 
80
 
81
- # ════════════════════ TOKEN ══════════════════════════
82
 
83
  async def _token() -> str:
84
  global _creds
85
- async with _lock:
86
- if _creds and _creds.valid and not _creds.expired:
87
- return _creds.token
88
- return await asyncio.to_thread(_refresh)
 
89
 
90
 
91
- def _refresh() -> str:
92
- global _creds
93
- d = json.loads(CREDS_JSON)
94
-
95
- cid = d.get("client_id") or CLIENT_ID
96
- csec = d.get("client_secret") or CLIENT_SECRET
97
- rtok = d.get("refresh_token")
98
- atok = d.get("access_token") or d.get("token")
99
-
100
- missing = []
101
- if not cid: missing.append("GEMINI_CLIENT_ID")
102
- if not csec: missing.append("GEMINI_CLIENT_SECRET")
103
- if not rtok: missing.append("refresh_token")
104
- if missing:
105
- raise HTTPException(500, f"Missing: {', '.join(missing)}")
106
-
107
- exp = None
108
- if "expiry_date" in d:
109
- exp = datetime.fromtimestamp(d["expiry_date"] / 1000, tz=timezone.utc)
110
-
111
- c = Credentials(
112
- token=atok, refresh_token=rtok,
113
- token_uri=d.get("token_uri", "https://oauth2.googleapis.com/token"),
114
- client_id=cid, client_secret=csec, expiry=exp,
115
- )
116
- if not c.valid or c.expired:
117
- c.refresh(GoogleAuthRequest())
118
- log.info("Token refreshed β†’ expires %s", c.expiry)
119
- _creds = c
120
- return c.token
121
 
122
 
123
- # ════════════════════ GEMINI HELPERS ═════════════════
 
 
 
 
 
 
 
 
 
 
124
 
125
- def _parse_model(model: str):
126
- """Returns (base_model, use_search, thinking_budget)."""
127
- search = model.endswith("-search")
128
- no_think = model.endswith("-nothinking")
129
- max_think = model.endswith("-maxthinking")
130
 
131
- base = (model.removesuffix("-search")
132
- .removesuffix("-nothinking")
133
- .removesuffix("-maxthinking"))
 
134
 
135
- budget = None
136
- if no_think: budget = 0
137
- if max_think: budget = 24576
138
 
139
- return base, search, budget
 
 
 
 
 
 
140
 
141
 
142
- def _to_gemini(messages: list, search: bool, budget, **kw) -> dict:
143
- """OpenAI messages β†’ Gemini request body."""
 
 
144
  contents = []
145
- sys_parts = []
146
 
147
- for m in messages:
148
- role, text = m.get("role", "user"), m.get("content", "")
149
- if role == "system":
150
- sys_parts.append({"text": text})
151
- else:
152
- contents.append({
153
- "role": "user" if role == "user" else "model",
154
- "parts": [{"text": text}],
155
- })
156
-
157
- body: dict[str, Any] = {"contents": contents}
158
- if sys_parts:
159
- body["systemInstruction"] = {"parts": sys_parts}
160
-
161
- gc: dict[str, Any] = {}
162
- if kw.get("temperature") is not None: gc["temperature"] = kw["temperature"]
163
- if kw.get("max_tokens"): gc["maxOutputTokens"] = kw["max_tokens"]
164
- if kw.get("top_p") is not None: gc["topP"] = kw["top_p"]
165
- if kw.get("stop"):
166
- gc["stopSequences"] = kw["stop"] if isinstance(kw["stop"], list) else [kw["stop"]]
167
- if budget is not None:
168
- gc["thinkingConfig"] = {"thinkingBudget": budget}
169
- if gc:
170
- body["generationConfig"] = gc
171
-
172
- if search:
173
- body["tools"] = [{"googleSearch": {}}]
174
-
175
- return body
176
-
177
-
178
- # ─── Stream parser: handles both SSE and JSON-array ──
179
-
180
- async def _gemini_stream(url: str, headers: dict, body: dict) -> AsyncIterator[dict]:
181
- """Yields individual Gemini response objects from a stream."""
182
- sse_url = url + "?alt=sse"
183
-
184
- async with _http.stream("POST", sse_url, json=body, headers=headers) as r:
185
- if r.status_code != 200:
186
- err = (await r.aread()).decode(errors="replace")
187
- raise HTTPException(r.status_code, f"Gemini: {err[:500]}")
188
-
189
- ct = r.headers.get("content-type", "")
190
-
191
- if "text/event-stream" in ct:
192
- # ── SSE mode (fast, line-by-line) ──
193
- async for line in r.aiter_lines():
194
- if not line.startswith("data:"):
195
- continue
196
- payload = line[5:].strip()
197
- if not payload or payload == "[DONE]":
198
- continue
199
- try:
200
- yield json.loads(payload)
201
- except json.JSONDecodeError:
202
- continue
203
  else:
204
- # ── JSON-array fallback ──
205
- buf = ""
206
- async for chunk in r.aiter_text():
207
- buf += chunk
208
- while True:
209
- buf = buf.lstrip(" \t\n\r,[")
210
- if not buf or buf[0] != "{":
211
- # also strip trailing ] at end of array
212
- buf = buf.lstrip("]")
213
- break
214
- # find matching }
215
- depth = 0
216
- in_s = 0 # 1 = inside string
217
- esc = 0 # 1 = next char is escaped
218
- found = -1
219
- for i, c in enumerate(buf):
220
- if esc:
221
- esc = 0; continue
222
- if c == "\\" and in_s:
223
- esc = 1; continue
224
- if c == '"':
225
- in_s ^= 1; continue
226
- if in_s:
227
- continue
228
- if c == "{": depth += 1
229
- elif c == "}":
230
- depth -= 1
231
- if depth == 0:
232
- found = i; break
233
- if found < 0:
234
- break # incomplete, need more data
235
- try:
236
- yield json.loads(buf[:found + 1])
237
- except json.JSONDecodeError:
238
- pass
239
- buf = buf[found + 1:]
240
-
241
-
242
- def _text(obj: dict) -> str:
243
- """Extract non-thought text from Gemini response."""
244
- parts = obj.get("candidates", [{}])[0].get("content", {}).get("parts", [])
245
- return "".join(p.get("text", "") for p in parts if not p.get("thought"))
246
-
247
-
248
- def _usage(obj: dict) -> dict:
249
- m = obj.get("usageMetadata", {})
250
- return {
251
- "prompt_tokens": m.get("promptTokenCount", 0),
252
- "completion_tokens": m.get("candidatesTokenCount", 0),
253
- "total_tokens": m.get("totalTokenCount", 0),
254
- }
255
-
256
 
257
- _FINISH_MAP = {"STOP": "stop", "MAX_TOKENS": "length", "SAFETY": "content_filter"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
- def _finish(obj: dict) -> str | None:
260
- r = obj.get("candidates", [{}])[0].get("finishReason")
261
- return _FINISH_MAP.get(r)
 
 
 
 
 
 
 
 
 
262
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
- # ════════════════════ ROUTES ═════════════════════════
265
 
 
266
  @app.get("/")
267
- async def health():
268
- return {"status": "ok", "service": "gemini-openai-proxy", "models": len(MODELS)}
269
 
270
 
271
  @app.get("/v1/models")
272
- async def list_models(_=Depends(_auth)):
273
  return {
274
  "object": "list",
275
- "data": [{"id": m, "object": "model", "owned_by": "google", "created": 0} for m in MODELS],
 
 
 
 
 
 
 
 
276
  }
277
 
278
 
279
  @app.post("/v1/chat/completions")
280
- async def chat(request: Request, _=Depends(_auth)):
281
- body = await request.json()
282
- model = body.get("model", "gemini-2.5-pro")
283
- msgs = body.get("messages", [])
284
- stream = body.get("stream", False)
285
-
286
- if not msgs:
287
- raise HTTPException(400, "messages required")
288
-
289
- base, search, budget = _parse_model(model)
290
- gemini_body = _to_gemini(
291
- msgs, search, budget,
292
- temperature=body.get("temperature"),
293
- max_tokens=body.get("max_tokens") or body.get("max_completion_tokens"),
294
- top_p=body.get("top_p"),
295
- stop=body.get("stop"),
296
- )
297
-
298
  tok = await _token()
299
- hdrs = {"Authorization": f"Bearer {tok}", "Content-Type": "application/json"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
 
301
- if stream:
302
- url = f"{API_BASE}/v1/models/{base}:streamGenerateContent"
303
- return StreamingResponse(
304
- _sse_stream(url, hdrs, gemini_body, model),
305
- media_type="text/event-stream",
306
- headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
307
- )
308
-
309
- # ── Non-streaming ──
310
- url = f"{API_BASE}/v1/models/{base}:generateContent"
311
- return await _non_stream(url, hdrs, gemini_body, model)
312
-
313
-
314
- # ─── Streaming response ─────────────────────────────
315
-
316
- async def _sse_stream(url, hdrs, body, model):
317
- cid = f"chatcmpl-{uuid4().hex[:24]}"
318
- ts = int(time.time())
319
-
320
- # role chunk
321
- yield _sse({"id": cid, "object": "chat.completion.chunk", "created": ts, "model": model,
322
- "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}]})
323
-
324
- try:
325
- async for obj in _gemini_stream(url, hdrs, body):
326
- txt = _text(obj)
327
- fin = _finish(obj)
328
-
329
- if txt:
330
- yield _sse({"id": cid, "object": "chat.completion.chunk", "created": ts, "model": model,
331
- "choices": [{"index": 0, "delta": {"content": txt}, "finish_reason": None}]})
332
- if fin:
333
- yield _sse({"id": cid, "object": "chat.completion.chunk", "created": ts, "model": model,
334
- "choices": [{"index": 0, "delta": {}, "finish_reason": fin}]})
335
- except HTTPException as e:
336
- # Send error as SSE event so client knows what happened
337
- yield _sse({"error": {"message": e.detail, "code": e.status_code}})
338
-
339
- yield "data: [DONE]\n\n"
340
-
341
-
342
- # ─── Non-streaming response ─────────────────────────
343
-
344
- async def _non_stream(url, hdrs, body, model):
345
- # Retry once on 401 (expired token)
346
- for attempt in range(2):
347
- r = await _http.post(url, json=body, headers=hdrs)
348
- if r.status_code == 401 and attempt == 0:
349
- global _creds
350
- async with _lock:
351
- _creds = None
352
- tok = await _token()
353
- hdrs["Authorization"] = f"Bearer {tok}"
354
- continue
355
- break
356
-
357
- if r.status_code != 200:
358
- raise HTTPException(r.status_code, f"Gemini: {r.text[:500]}")
359
-
360
- data = r.json()
361
-
362
- # Handle error in body
363
- if "error" in data:
364
- e = data["error"]
365
- raise HTTPException(e.get("code", 500), e.get("message", "Unknown"))
366
-
367
- # Gemini may return list or dict
368
- if isinstance(data, list):
369
- full_text = "".join(_text(item) for item in data)
370
- usg = next((_usage(i) for i in data if _usage(i).get("total_tokens")), _usage({}))
371
- fin = next((_finish(i) for i in data if _finish(i)), "stop")
372
  else:
373
- full_text = _text(data)
374
- usg = _usage(data)
375
- fin = _finish(data) or "stop"
376
 
377
- return JSONResponse({
378
- "id": f"chatcmpl-{uuid4().hex[:24]}",
379
- "object": "chat.completion",
380
- "created": int(time.time()),
381
- "model": model,
382
- "choices": [{"index": 0, "message": {"role": "assistant", "content": full_text}, "finish_reason": fin}],
383
- "usage": usg,
384
- })
 
 
 
 
 
385
 
386
 
387
- def _sse(obj: dict) -> str:
388
- return f"data: {json.dumps(obj, ensure_ascii=False)}\n\n"
 
 
 
1
# ============================================================
# DATETIME FIX — must run first, before any google.auth import.
# Replaces google.auth._helpers.utcnow with a timezone-aware
# (UTC) clock.
# NOTE(review): google-auth historically expects *naive* UTC
# datetimes here; confirm aware values don't break expiry
# comparisons in the installed google-auth version.
# ============================================================
import datetime as _dt
import google.auth._helpers as _gah
_gah.utcnow = lambda: _dt.datetime.now(_dt.timezone.utc)
# ============================================================

import os
import json
import asyncio
import logging
import time
import uuid

import httpx

from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List, Union
import google.oauth2.credentials
import google.auth.transport.requests

# ── Logging ──────────────────────────────────────────────────
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
logger = logging.getLogger(__name__)

# ── Config from env ──────────────────────────────────────────
AUTH_PASSWORD = os.environ.get("GEMINI_AUTH_PASSWORD", "")  # bearer password clients must present
RAW_CREDS = os.environ.get("GEMINI_CREDENTIALS", "")        # oauth_creds.json contents as a JSON string
PORT = int(os.environ.get("PORT", 7860))                    # HF Spaces serves on 7860 by default

# Cloud Code private API endpoint used by the Gemini CLI.
GEMINI_API_BASE = "https://cloudcode-pa.googleapis.com/v1internal/projects/-/locations/-/endpoints/-"
 
38
 
39
  MODELS = [
40
  "gemini-2.5-pro",
 
48
  "gemini-2.5-flash-maxthinking",
49
  ]
50
 
51
# Thinking budgets per model variant.
# 0 disables thinking for the "-nothinking" variants; 32768 is the large
# budget used by the "-maxthinking" variants.
THINKING_BUDGET = {
    "gemini-2.5-pro-nothinking": 0,
    "gemini-2.5-flash-nothinking": 0,
    "gemini-2.5-pro-maxthinking": 32768,
    "gemini-2.5-flash-maxthinking": 32768,
}

# Variants that get Google Search grounding enabled in the request payload.
SEARCH_MODELS = {"gemini-2.5-pro-search", "gemini-2.5-flash-search"}
61
+
62
# Base model mapping (strip suffix for API call)
def base_model(model: str) -> str:
    """Strip a known variant suffix so the upstream API sees the real model id."""
    for marker in ("-search", "-nothinking", "-maxthinking"):
        if model.endswith(marker):
            return model.removesuffix(marker)
    return model
68
+
69
# ── Credential management ─────────────────────────────────────
# Cached OAuth credentials, built lazily from RAW_CREDS on first use.
_creds: Optional[google.oauth2.credentials.Credentials] = None
# Serializes credential build/refresh across concurrent requests.
_creds_lock = asyncio.Lock()
72
+
73
def _build_creds() -> google.oauth2.credentials.Credentials:
    """Construct OAuth user credentials from the GEMINI_CREDENTIALS JSON blob.

    Raises RuntimeError when the env var is missing entirely.
    """
    if not RAW_CREDS:
        raise RuntimeError("GEMINI_CREDENTIALS env var not set")
    data = json.loads(RAW_CREDS)

    # Derive an absolute UTC expiry from whichever field the blob carries.
    expiry = None
    if "expiry_date" in data:
        # oauth_creds.json stores epoch milliseconds.
        expiry = _dt.datetime.fromtimestamp(data["expiry_date"] / 1000.0, tz=_dt.timezone.utc)
    elif "expiry" in data:
        raw_expiry = data["expiry"]
        if isinstance(raw_expiry, (int, float)):
            expiry = _dt.datetime.fromtimestamp(raw_expiry, tz=_dt.timezone.utc)
        else:
            # ISO-8601 string; treat a naive value as UTC.
            expiry = _dt.datetime.fromisoformat(raw_expiry)
            if expiry.tzinfo is None:
                expiry = expiry.replace(tzinfo=_dt.timezone.utc)

    creds = google.oauth2.credentials.Credentials(
        token=data.get("token") or data.get("access_token"),
        refresh_token=data.get("refresh_token"),
        token_uri=data.get("token_uri", "https://oauth2.googleapis.com/token"),
        client_id=data.get("client_id"),
        client_secret=data.get("client_secret"),
        scopes=data.get("scopes", ["https://www.googleapis.com/auth/cloud-platform"]),
    )
    if expiry:
        creds.expiry = expiry
    return creds
102
+
103
+
104
def _refresh(c: google.oauth2.credentials.Credentials):
    """Synchronously refresh credentials if missing or (nearly) expired.

    Runs in a worker thread (see _token); returns the current —
    possibly refreshed — access token string.
    """
    now = _dt.datetime.now(_dt.timezone.utc)

    # Safely check expiry, handle both aware and naive
    needs_refresh = False
    if c.token is None:
        needs_refresh = True
    elif c.expiry is not None:
        expiry = c.expiry
        if expiry.tzinfo is None:
            # Treat a naive expiry as UTC so the comparison below is valid.
            expiry = expiry.replace(tzinfo=_dt.timezone.utc)
        # refresh 5 minutes early
        needs_refresh = now >= (expiry - _dt.timedelta(minutes=5))
    # NOTE(review): a token present with no expiry is never refreshed here —
    # confirm the credential blob always carries an expiry, or a stale token
    # could be reused indefinitely.

    if needs_refresh:
        logger.info("Refreshing Google OAuth token...")
        request = google.auth.transport.requests.Request()
        # Raises if refresh_token / client id / client secret are missing or invalid.
        c.refresh(request)
        logger.info("Token refreshed successfully.")
    return c.token
125
 
 
126
 
127
async def _token() -> str:
    """Return a valid access token, lazily building and refreshing credentials."""
    global _creds
    async with _creds_lock:
        if _creds is None:
            _creds = _build_creds()
        # Refresh performs blocking network I/O — keep it off the event loop.
        return await asyncio.to_thread(_refresh, _creds)
134
 
135
 
136
# ── FastAPI app ───────────────────────────────────────────────
app = FastAPI(title="geminicli2api", version="1.0.0")

# Wide-open CORS so browser clients on any origin can call the proxy.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
 
147
# ── Auth dependency ───────────────────────────────────────────
async def verify_auth(request: Request):
    """Reject the request (401) unless it carries the configured bearer password."""
    if not AUTH_PASSWORD:
        # No password configured → auth is disabled.
        # NOTE(review): this leaves the proxy fully open; confirm intended.
        return
    header = request.headers.get("Authorization", "")
    # Accept both "Bearer <password>" and a bare password value.
    supplied = header.removeprefix("Bearer ") if header.startswith("Bearer ") else header
    if supplied != AUTH_PASSWORD:
        raise HTTPException(status_code=401, detail="Unauthorized")
158
 
 
 
 
 
 
159
 
160
# ── Pydantic models ───────────────────────────────────────────
class Message(BaseModel):
    """One OpenAI-style chat message."""
    # "system", "user" or "assistant"; other roles are dropped during conversion.
    role: str
    # Plain text, or an OpenAI content-part list (text / image_url dicts).
    content: Union[str, list]
164
 
 
 
 
165
 
166
class ChatRequest(BaseModel):
    """Subset of the OpenAI /v1/chat/completions request body that is honored."""
    model: str = "gemini-2.5-flash"
    messages: List[Message]
    stream: bool = False
    # Generation knobs; None means "leave to upstream defaults".
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
173
 
174
 
175
# ── Conversion helpers ────────────────────────────────────────
def openai_messages_to_gemini(messages: List[Message]):
    """Convert OpenAI messages to Gemini contents format.

    Returns (system_parts, contents): system messages are collected into
    system_parts; user/assistant messages become Gemini "user"/"model"
    turns. Other roles are silently dropped.
    """
    system_parts = []
    contents = []

    for msg in messages:
        role = msg.role
        content = msg.content

        if isinstance(content, str):
            parts = [{"text": content}]
        elif isinstance(content, list):
            parts = []
            for item in content:
                if isinstance(item, dict):
                    if item.get("type") == "text":
                        parts.append({"text": item["text"]})
                    elif item.get("type") == "image_url":
                        url = item["image_url"]["url"]
                        # Only base64 data URLs can be inlined; a bare
                        # "data:" URL without ";base64," used to crash the
                        # split below with ValueError (→ HTTP 500).
                        if url.startswith("data:") and ";base64," in url:
                            mime, b64 = url[5:].split(";base64,", 1)
                            parts.append({
                                "inlineData": {"mimeType": mime, "data": b64}
                            })
                        else:
                            # Remote or non-base64 image: degrade to a text marker.
                            parts.append({"text": f"[Image: {url}]"})
                    else:
                        parts.append({"text": str(item)})
                else:
                    parts.append({"text": str(item)})
        else:
            # Unexpected content payload — stringify rather than fail.
            parts = [{"text": str(content)}]

        if role == "system":
            system_parts.extend(parts)
        elif role == "user":
            contents.append({"role": "user", "parts": parts})
        elif role == "assistant":
            contents.append({"role": "model", "parts": parts})

    return system_parts, contents
215
+
216
+
217
def build_gemini_payload(req: ChatRequest) -> dict:
    """Translate an OpenAI-style ChatRequest into a Gemini generateContent body."""
    system_parts, contents = openai_messages_to_gemini(req.messages)

    payload: dict = {"contents": contents}
    if system_parts:
        payload["systemInstruction"] = {"parts": system_parts}

    gen_config: dict = {}
    if req.max_tokens:
        gen_config["maxOutputTokens"] = req.max_tokens
    if req.temperature is not None:
        gen_config["temperature"] = req.temperature
    if req.top_p is not None:
        gen_config["topP"] = req.top_p

    model = req.model
    budget = THINKING_BUDGET.get(model)
    if budget is not None:
        # Explicit variant: pin the budget; expose thoughts only when thinking.
        gen_config["thinkingConfig"] = {
            "thinkingBudget": budget,
            "includeThoughts": budget > 0,
        }
    elif model not in {"gemini-2.0-flash"} and "flash" not in model:
        # Pro models think by default: dynamic budget (-1), thoughts hidden.
        gen_config["thinkingConfig"] = {
            "thinkingBudget": -1,
            "includeThoughts": False,
        }

    if gen_config:
        payload["generationConfig"] = gen_config
    if model in SEARCH_MODELS:
        payload["tools"] = [{"googleSearch": {}}]
    return payload
253
+
254
+
255
def gemini_response_to_openai(gemini_resp: dict, model: str, stream: bool = False) -> dict:
    """Convert a Gemini response (or stream chunk) to OpenAI chat format.

    Args:
        gemini_resp: raw Gemini JSON (one chunk when streaming).
        model: model id echoed back to the client.
        stream: emit a chat.completion.chunk instead of a chat.completion.
    """
    candidates = gemini_resp.get("candidates", [])
    text = ""
    # BUG FIX: mid-stream chunks carry no finishReason; the old default of
    # "STOP" stamped finish_reason="stop" on every chunk, which makes OpenAI
    # streaming clients terminate after the first one. While streaming,
    # finish_reason stays None until Gemini actually reports a reason.
    finish_reason = None if stream else "stop"

    if candidates:
        candidate = candidates[0]
        parts = candidate.get("content", {}).get("parts", [])
        for part in parts:
            # Skip "thought" parts — internal reasoning is not client output.
            if "text" in part and not part.get("thought", False):
                text += part["text"]
        fr = candidate.get("finishReason")
        if fr is not None:
            finish_reason = {
                "STOP": "stop",
                "MAX_TOKENS": "length",
                "SAFETY": "content_filter",
            }.get(fr, "stop")

    usage = gemini_resp.get("usageMetadata", {})
    prompt_tokens = usage.get("promptTokenCount", 0)
    completion_tokens = usage.get("candidatesTokenCount", 0)

    resp_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
    created = int(time.time())

    if stream:
        return {
            "id": resp_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [{
                "index": 0,
                "delta": {"content": text},
                "finish_reason": finish_reason,
            }],
        }

    return {
        "id": resp_id,
        "object": "chat.completion",
        "created": created,
        "model": model,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": text},
            "finish_reason": finish_reason,
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }
310
 
 
311
 
312
# ── Routes ────────────────────────────────────────────────────
@app.get("/")
async def root():
    """Liveness probe: reports service status and the advertised model ids."""
    payload = {"status": "ok", "models": MODELS}
    return payload
316
 
317
 
318
@app.get("/v1/models")
async def list_models(_=Depends(verify_auth)):
    """Return the model catalogue in OpenAI /v1/models format."""
    def _entry(model_id: str) -> dict:
        # Fixed creation timestamp — these are static catalogue entries.
        return {
            "id": model_id,
            "object": "model",
            "created": 1700000000,
            "owned_by": "google",
        }

    return {
        "object": "list",
        "data": [_entry(m) for m in MODELS],
    }
332
 
333
 
334
@app.post("/v1/chat/completions")
async def chat(req: ChatRequest, _=Depends(verify_auth)):
    """OpenAI-compatible chat completions endpoint (streaming and non-streaming)."""
    tok = await _token()
    model = req.model
    api_model = base_model(model)  # strip -search/-nothinking/-maxthinking suffix
    payload = build_gemini_payload(req)

    headers = {
        "Authorization": f"Bearer {tok}",
        "Content-Type": "application/json",
    }

    if req.stream:
        url = f"{GEMINI_API_BASE}:streamGenerateContent?alt=sse&model={api_model}"

        async def generate():
            # Fresh client per request: the generator may outlive the handler.
            async with httpx.AsyncClient(timeout=120) as client:
                async with client.stream("POST", url, headers=headers, json=payload) as resp:
                    if resp.status_code != 200:
                        body = await resp.aread()
                        err = body.decode(errors="replace")
                        logger.error(f"Gemini API error {resp.status_code}: {err}")
                        # Surface the upstream error as an SSE event, then stop.
                        # NOTE(review): no terminating [DONE] is sent on this
                        # path — confirm clients handle that.
                        yield f"data: {json.dumps({'error': err})}\n\n"
                        return

                    # Re-frame upstream SSE: events are separated by blank lines.
                    buffer = ""
                    async for chunk in resp.aiter_text():
                        buffer += chunk
                        while "\n\n" in buffer:
                            event, buffer = buffer.split("\n\n", 1)
                            for line in event.splitlines():
                                if line.startswith("data: "):
                                    data_str = line[6:]
                                    if data_str.strip() == "[DONE]":
                                        yield "data: [DONE]\n\n"
                                        return
                                    try:
                                        gemini_data = json.loads(data_str)
                                        openai_chunk = gemini_response_to_openai(
                                            gemini_data, model, stream=True
                                        )
                                        yield f"data: {json.dumps(openai_chunk)}\n\n"
                                    except json.JSONDecodeError:
                                        # Ignore partial / non-JSON data lines.
                                        pass

            # Upstream closed without [DONE]; terminate the client stream ourselves.
            yield "data: [DONE]\n\n"

        return StreamingResponse(generate(), media_type="text/event-stream")
    else:
        url = f"{GEMINI_API_BASE}:generateContent?model={api_model}"
        async with httpx.AsyncClient(timeout=120) as client:
            resp = await client.post(url, headers=headers, json=payload)

        if resp.status_code != 200:
            logger.error(f"Gemini API error {resp.status_code}: {resp.text}")
            raise HTTPException(status_code=resp.status_code, detail=resp.text)

        gemini_data = resp.json()
        return gemini_response_to_openai(gemini_data, model)
394
+
395
+
396
# ── Startup ───────────────────────────────────────────────────
@app.on_event("startup")
async def startup():
    """Emit a startup banner and a readiness log line."""
    banner_time = _dt.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print(f"\n===== Application Startup at {banner_time} =====\n")
    logger.info(f"Proxy ready — {len(MODELS)} models")
401
 
402
 
403
# ── Main ──────────────────────────────────────────────────────
if __name__ == "__main__":
    # Local/dev entry point; bind all interfaces — HF Spaces probes $PORT.
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=PORT, log_level="info")