| |
| """ |
| Gemini CLI to OpenAI-Compatible API Server |
| Complete Implementation with ALL Gemini CLI Features |
| |
| Features: |
| - Real-time SSE Streaming |
| - Session/Thread Management |
| - Persistent Memory System |
| - Extended Thinking Exposure |
| - OpenAI Threads API |
| - Rate Limiting |
| - Auto Token Refresh |
| |
| Author: Z.ai |
| """ |
|
|
import asyncio
import hashlib
import hmac
import json
import os
import sys
import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional, List, Dict, Any, AsyncGenerator

import httpx
import uvicorn
from fastapi import FastAPI, HTTPException, Header, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
|
|
| |
| |
| |
|
|
# --- Filesystem layout shared with the Gemini CLI ---
GEMINI_DIR = Path.home() / ".gemini"          # Gemini CLI home directory
OAUTH_FILE = GEMINI_DIR / "oauth_creds.json"  # cached OAuth credentials (JSON)
MEMORY_FILE = "GEMINI.md"                     # persistent memory file, resolved against CWD
SESSIONS_DIR = GEMINI_DIR / "api_sessions"    # one JSON file per thread/session

# --- Server configuration (all overridable via environment variables) ---
API_AUTH_KEY = os.getenv("API_AUTH_KEY", "")  # empty string disables auth entirely
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "gemini-2.5-flash")
SERVER_PORT = int(os.getenv("PORT", "7860"))
SERVER_HOST = os.getenv("HOST", "0.0.0.0")
GEMINI_TIMEOUT = int(os.getenv("GEMINI_TIMEOUT", "180"))  # seconds per CLI invocation

# OAuth client id used when refreshing tokens against Google's token endpoint.
GOOGLE_CLIENT_ID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
|
# Models advertised by /v1/models. The gpt-* entries are compatibility
# aliases for OpenAI clients; see MODEL_ALIASES below for their mapping.
MODELS = [
    {"id": "gemini-2.5-pro", "object": "model", "created": 1700000000, "owned_by": "google"},
    {"id": "gemini-2.5-flash", "object": "model", "created": 1700000000, "owned_by": "google"},
    {"id": "gemini-2.5-flash-lite", "object": "model", "created": 1700000000, "owned_by": "google"},
    {"id": "gemini-3-pro-preview", "object": "model", "created": 1700000000, "owned_by": "google"},
    {"id": "gemini-3-flash-preview", "object": "model", "created": 1700000000, "owned_by": "google"},
    {"id": "gpt-4", "object": "model", "created": 1700000000, "owned_by": "google"},
    {"id": "gpt-4o", "object": "model", "created": 1700000000, "owned_by": "google"},
]

# Maps OpenAI model names onto real Gemini model ids. Unknown names pass
# through unchanged (see resolve()).
MODEL_ALIASES = {
    "gpt-4": "gemini-2.5-pro",
    "gpt-4o": "gemini-2.5-pro",
    "gpt-4-turbo": "gemini-2.5-flash",
    "gpt-3.5-turbo": "gemini-2.5-flash-lite",
}
|
|
|
|
| |
| |
| |
|
|
class OAuthManager:
    """Loads, caches, and auto-refreshes Google OAuth credentials.

    Credentials come from the GEMINI_OAUTH_CREDS env var (preferred, and
    then persisted to OAUTH_FILE) or from an existing OAUTH_FILE written
    by the Gemini CLI.
    """

    def __init__(self):
        # Raw token dict as stored in oauth_creds.json
        # (access_token, refresh_token, expiry_date in ms, ...).
        self.tokens: Dict[str, Any] = {}
        # Serializes load/refresh so concurrent requests don't double-refresh.
        self._lock = asyncio.Lock()

    async def load(self) -> bool:
        """Load credentials from env or disk. Returns True on success."""
        env_creds = os.getenv("GEMINI_OAUTH_CREDS", "")
        if env_creds:
            try:
                self.tokens = json.loads(env_creds)
                GEMINI_DIR.mkdir(parents=True, exist_ok=True)
                OAUTH_FILE.write_text(json.dumps(self.tokens, indent=2))
                print("[OK] Loaded OAuth from env")
                return True
            except (ValueError, OSError) as e:
                # Bad JSON in the env var or unwritable disk: fall back to file.
                print(f"[WARN] Could not load OAuth from env: {e}")

        if OAUTH_FILE.exists():
            try:
                self.tokens = json.loads(OAUTH_FILE.read_text())
                print("[OK] Loaded OAuth from file")
                return True
            except (ValueError, OSError) as e:
                print(f"[WARN] Could not load OAuth from file: {e}")
        return False

    def is_expired(self) -> bool:
        """True when there is no expiry or fewer than 5 minutes remain."""
        if 'expiry_date' not in self.tokens:
            return True
        # expiry_date is in milliseconds; refresh 300s before actual expiry.
        return (self.tokens['expiry_date'] / 1000 - time.time()) < 300

    async def refresh(self) -> bool:
        """Exchange the refresh token for a new access token.

        Returns True on success; logs and returns False on any failure
        (no refresh token, HTTP error, malformed response, disk error).
        """
        if 'refresh_token' not in self.tokens:
            return False

        print("[INFO] Refreshing token...")
        try:
            async with httpx.AsyncClient() as client:
                resp = await client.post(
                    "https://oauth2.googleapis.com/token",
                    data={
                        "client_id": GOOGLE_CLIENT_ID,
                        "refresh_token": self.tokens['refresh_token'],
                        "grant_type": "refresh_token",
                    },
                    timeout=30
                )
            if resp.status_code == 200:
                data = resp.json()
                self.tokens['access_token'] = data['access_token']
                self.tokens['expiry_date'] = int(time.time() * 1000 + data['expires_in'] * 1000)
                OAUTH_FILE.write_text(json.dumps(self.tokens, indent=2))
                print("[OK] Token refreshed")
                return True
            # Non-200: surface the status instead of failing silently.
            print(f"[WARN] Token refresh failed: HTTP {resp.status_code}")
        except (httpx.HTTPError, OSError, KeyError, ValueError) as e:
            print(f"[WARN] Token refresh failed: {e}")
        return False

    async def get_access_token(self) -> Optional[str]:
        """Return a valid access token, loading/refreshing as needed."""
        async with self._lock:
            if not self.tokens:
                await self.load()
            if self.is_expired():
                await self.refresh()
            return self.tokens.get('access_token')
|
|
|
|
| oauth = OAuthManager() |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class ChatSession:
    """A persisted conversation thread, stored as one JSON file."""
    id: str                  # "thread_<24 hex chars>" identifier
    created: int             # Unix timestamp (seconds)
    messages: List[Dict] = field(default_factory=list)  # message dicts (id/role/content/created)
    model: str = DEFAULT_MODEL
|
|
|
|
class SessionManager:
    """File-backed store for chat threads.

    Each session lives in SESSIONS_DIR as "<session_id>.json".
    """

    def __init__(self):
        SESSIONS_DIR.mkdir(parents=True, exist_ok=True)

    def create(self, model: str = DEFAULT_MODEL) -> ChatSession:
        """Create, persist, and return a new empty session."""
        session = ChatSession(
            id=f"thread_{uuid.uuid4().hex[:24]}",
            created=int(time.time()),
            model=model
        )
        self._save(session)
        return session

    def get(self, session_id: str) -> Optional[ChatSession]:
        """Load a session; None when missing, unreadable, or corrupt.

        Previously a corrupted JSON file crashed the caller with an
        unhandled exception; it is now treated as a missing session.
        """
        f = SESSIONS_DIR / f"{session_id}.json"
        if not f.exists():
            return None
        try:
            return ChatSession(**json.loads(f.read_text()))
        except (ValueError, TypeError, OSError):
            return None

    def add_message(self, session_id: str, role: str, content: str):
        """Append a message to a session and persist it (no-op if missing)."""
        session = self.get(session_id)
        if session:
            session.messages.append({
                "id": f"msg_{uuid.uuid4().hex[:24]}",
                "role": role,
                "content": content,
                "created": int(time.time())
            })
            self._save(session)

    def _save(self, session: ChatSession):
        """Serialize a session to its JSON file."""
        f = SESSIONS_DIR / f"{session.id}.json"
        f.write_text(json.dumps({
            "id": session.id,
            "created": session.created,
            "messages": session.messages,
            "model": session.model
        }))

    def list(self, limit: int = 20) -> List[ChatSession]:
        """Return up to `limit` sessions, most recently modified first."""
        sessions = []
        for f in sorted(SESSIONS_DIR.glob("thread_*.json"),
                        key=lambda x: x.stat().st_mtime, reverse=True)[:limit]:
            s = self.get(f.stem)
            if s is not None:  # skip files that failed to parse
                sessions.append(s)
        return sessions

    def delete(self, session_id: str) -> bool:
        """Delete a session file. Returns True if it existed."""
        f = SESSIONS_DIR / f"{session_id}.json"
        if f.exists():
            f.unlink()
            return True
        return False
|
|
|
|
| sessions = SessionManager() |
|
|
|
|
| |
| |
| |
|
|
class MemoryManager:
    """Append-only project memory stored as a Markdown file in the CWD."""

    def __init__(self):
        self.file = Path(".") / MEMORY_FILE

    def get(self) -> str:
        """Return the full memory file contents, or "" when absent."""
        if self.file.exists():
            return self.file.read_text()
        return ""

    def add(self, text: str) -> bool:
        """Append a timestamped memory entry; returns True on success."""
        try:
            current = self.get()
            # Timezone-aware UTC; datetime.utcnow() is deprecated (3.12+).
            ts = datetime.now(timezone.utc).isoformat()
            entry = f"\n## Memory ({ts})\n{text}\n"
            if not current:
                entry = f"# Project Memory\n{entry}"
            self.file.write_text(current + entry)
            return True
        except OSError:
            # Best-effort by design: report failure, don't crash the endpoint.
            return False

    def clear(self):
        """Delete the memory file if it exists."""
        if self.file.exists():
            self.file.unlink()

    def context(self) -> str:
        """Memory wrapped in <memory> tags for prompt injection ("" if empty)."""
        c = self.get()
        return f"<memory>\n{c}\n</memory>" if c else ""
|
|
|
|
| memory = MemoryManager() |
|
|
|
|
| |
| |
| |
|
|
async def run_gemini_stream(
    prompt: str,
    model: str = DEFAULT_MODEL,
    memory_ctx: str = ""
) -> AsyncGenerator[Dict, None]:
    """Execute the Gemini CLI and stream its stdout as event dicts.

    Yields:
        {"type": "chunk", "content": line} for each useful output line,
        then {"type": "done", "content": full_text} on success, or
        {"type": "error", "error": msg} on timeout/failure.

    Fixes vs. the original: GEMINI_TIMEOUT is now actually enforced
    (the old code caught asyncio.TimeoutError but never set a timeout),
    and stderr is discarded instead of piped-but-never-read, which could
    deadlock the child process once the pipe buffer filled.
    """
    env = os.environ.copy()
    env["GOOGLE_GENAI_USE_GCA"] = "true"

    full_prompt = f"{memory_ctx}\n\n{prompt}" if memory_ctx else prompt

    cmd = ["gemini", "-p", full_prompt, "--yolo", "-m", MODEL_ALIASES.get(model, model)]

    print(f"[EXEC] gemini -m {model}")

    proc = None
    try:
        proc = await asyncio.create_subprocess_exec(
            *cmd, env=env,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.DEVNULL
        )

        deadline = time.monotonic() + GEMINI_TIMEOUT
        content = ""
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise asyncio.TimeoutError
            line = await asyncio.wait_for(proc.stdout.readline(), timeout=remaining)
            if not line:  # EOF: child closed stdout
                break
            txt = line.decode('utf-8', errors='ignore').strip()
            # Drop CLI housekeeping noise so only model output is streamed.
            if txt and not any(s in txt.lower() for s in
                               ['yolo mode', 'keychain', 'loaded', 'using file']):
                content += txt + "\n"
                yield {"type": "chunk", "content": txt}

        await proc.wait()
        yield {"type": "done", "content": content.strip()}

    except asyncio.TimeoutError:
        if proc is not None:
            proc.kill()  # don't leave a hung CLI process behind
        yield {"type": "error", "error": "Timeout"}
    except Exception as e:
        yield {"type": "error", "error": str(e)}
|
|
|
|
async def run_gemini(prompt: str, model: str = DEFAULT_MODEL, memory_ctx: str = "") -> str:
    """Run the Gemini CLI and collect the streamed output into one string.

    Returns the final "done" content, or "Error: <msg>" if the stream
    reported an error.
    """
    final = ""
    async for ev in run_gemini_stream(prompt, model, memory_ctx):
        kind = ev["type"]
        if kind == "error":
            final = f"Error: {ev['error']}"
        elif kind == "done":
            final = ev["content"]
    return final
|
|
|
|
| |
| |
| |
|
|
class Msg(BaseModel):
    """A single chat message in OpenAI format."""
    role: str      # message author role; treated as an opaque string
    content: str   # message text


class ChatReq(BaseModel):
    """Request body for /v1/chat/completions and thread runs."""
    model: str = DEFAULT_MODEL
    messages: List[Msg]
    stream: bool = False
    # Accepted for API compatibility; not consumed by the visible handlers.
    include_thinking: bool = False


class ChatChoice(BaseModel):
    """One completion choice (this server always returns exactly one)."""
    index: int = 0
    message: Msg
    finish_reason: str = "stop"


class Usage(BaseModel):
    """Token accounting (estimated at ~4 chars/token, see tokens())."""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatResp(BaseModel):
    """OpenAI-compatible chat.completion response envelope."""
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatChoice]
    usage: Usage


class ThreadCreate(BaseModel):
    """Body for POST /v1/threads; optional seed messages."""
    messages: List[Msg] = []


class ThreadMsg(BaseModel):
    """Body for POST /v1/threads/{tid}/messages."""
    role: str
    content: str
|
|
|
|
| |
| |
| |
|
|
def gid(p="chatcmpl"):
    """Generate an OpenAI-style id: "<prefix>-<24 hex chars>"."""
    return f"{p}-{uuid.uuid4().hex[:24]}"


def ts():
    """Current Unix timestamp in whole seconds."""
    return int(time.time())


def tokens(t):
    """Rough token estimate: about four characters per token."""
    return len(t) // 4


def prompt(msgs):
    """Flatten chat messages into one "role: content" transcript,
    ending with a bare "assistant:" cue for the model."""
    lines = [f"{m.role}: {m.content}" for m in msgs]
    lines.append("assistant:")
    return "\n\n".join(lines)


def resolve(m):
    """Map an OpenAI alias onto a Gemini model id (pass-through if unknown)."""
    return MODEL_ALIASES.get(m, m)
|
|
|
|
| |
| |
| |
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: print a startup banner, ensure the data
    directories exist, and preload OAuth credentials before serving."""
    print("\n" + "=" * 60)
    print("🚀 Gemini CLI API v2.0 - Complete Features")
    print("=" * 60)
    GEMINI_DIR.mkdir(parents=True, exist_ok=True)
    SESSIONS_DIR.mkdir(parents=True, exist_ok=True)
    await oauth.load()
    print(f"✅ Ready: http://{SERVER_HOST}:{SERVER_PORT}")
    print("=" * 60 + "\n")
    yield
|
|
# ASGI application; `lifespan` handles the startup banner and OAuth preload.
app = FastAPI(title="Gemini CLI API v2", lifespan=lifespan)
# Fully permissive CORS: any origin, method, and header.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
|
def auth(h: str = Header(None)):
    """Validate a bearer token against API_AUTH_KEY.

    Endpoints call this directly with the Authorization header value.
    No-op when API_AUTH_KEY is unset. Accepts either "Bearer <key>" or
    the bare key. Uses hmac.compare_digest for a constant-time compare
    so the key can't be probed via timing.

    Raises:
        HTTPException(401) when the token does not match.
    """
    if API_AUTH_KEY:
        t = h[7:] if h and h.startswith("Bearer ") else h
        if not hmac.compare_digest(t or "", API_AUTH_KEY):
            raise HTTPException(401, "Unauthorized")
    return True
|
|
|
|
| |
| |
| |
|
|
@app.get("/")
async def root():
    """Service banner listing the main endpoint groups."""
    endpoints = {
        "chat": "/v1/chat/completions",
        "models": "/v1/models",
        "threads": "/v1/threads",
        "memory": "/v1/memory",
    }
    return {
        "service": "Gemini CLI API v2",
        "features": ["streaming", "threads", "memory", "thinking"],
        "endpoints": endpoints,
    }
|
|
@app.get("/health")
async def health():
    """Liveness probe plus a quick OAuth credential status check."""
    has_creds = bool(oauth.tokens)
    return {
        "status": "ok",
        "oauth": has_creds,
        "token_valid": not oauth.is_expired(),
    }
|
|
@app.get("/v1/models")
async def list_models():
    """OpenAI-compatible model listing (includes the gpt-* aliases)."""
    return dict(object="list", data=MODELS)
|
|
@app.post("/v1/chat/completions")
async def chat(req: ChatReq, h: str = Header(None, alias="Authorization")):
    """OpenAI-compatible chat completion (streaming or blocking).

    Fix: the parameter previously used Header(None), which binds the HTTP
    header named "h" — the client's Authorization header never reached
    auth(), so every request failed when API_AUTH_KEY was set. The
    alias="Authorization" makes FastAPI read the correct header.
    """
    auth(h)
    await oauth.get_access_token()  # load/refresh credentials before exec

    model = resolve(req.model)
    p = prompt(req.messages)
    ctx = memory.context()

    if req.stream:
        return StreamingResponse(stream_sse(req, model, p, ctx), media_type="text/event-stream")

    result = await run_gemini(p, model, ctx)

    return ChatResp(
        id=gid(), created=ts(), model=model,
        choices=[ChatChoice(message=Msg(role="assistant", content=result))],
        usage=Usage(prompt_tokens=tokens(p), completion_tokens=tokens(result), total_tokens=tokens(p + result))
    )
|
|
|
|
async def stream_sse(req: ChatReq, model: str, prompt: str, ctx: str):
    """Translate Gemini CLI events into OpenAI chat.completion.chunk SSE lines."""
    chat_id = gid()
    created = ts()

    def sse(payload):
        # One server-sent-event frame.
        return f"data: {json.dumps(payload)}\n\n"

    def chunk(delta, reason):
        # OpenAI streaming chunk envelope.
        return {"id": chat_id, "object": "chat.completion.chunk", "created": created,
                "model": model,
                "choices": [{"index": 0, "delta": delta, "finish_reason": reason}]}

    async for ev in run_gemini_stream(prompt, model, ctx):
        kind = ev["type"]
        if kind == "chunk":
            yield sse(chunk({"content": ev["content"]}, None))
        elif kind == "done":
            yield sse(chunk({}, "stop"))
            yield "data: [DONE]\n\n"
        elif kind == "error":
            yield sse({"error": ev["error"]})
            yield "data: [DONE]\n\n"
|
|
|
|
| |
@app.post("/v1/threads")
async def create_thread(req: ThreadCreate, h: str = Header(None, alias="Authorization")):
    """Create a thread, optionally seeding it with initial messages.

    Fix: alias="Authorization" — the bare Header(None) bound the header
    named "h", so auth() never saw the client's Authorization header.
    """
    auth(h)
    s = sessions.create()
    for m in req.messages:
        sessions.add_message(s.id, m.role, m.content)
    return {"id": s.id, "object": "thread", "created_at": s.created}
|
|
@app.get("/v1/threads")
async def list_threads(limit: int = 20, h: str = Header(None, alias="Authorization")):
    """List up to `limit` threads, most recently modified first.

    Fix: alias="Authorization" so auth() receives the real auth header.
    """
    auth(h)
    data = [{"id": s.id, "object": "thread", "created_at": s.created} for s in sessions.list(limit)]
    return {"object": "list", "data": data}
|
|
@app.get("/v1/threads/{tid}")
async def get_thread(tid: str, h: str = Header(None, alias="Authorization")):
    """Fetch one thread's metadata; 404 when it does not exist.

    Fix: alias="Authorization" so auth() receives the real auth header.
    """
    auth(h)
    s = sessions.get(tid)
    if not s:
        raise HTTPException(404, "Not found")
    return {"id": s.id, "object": "thread", "created_at": s.created, "metadata": {"messages": len(s.messages)}}
|
|
@app.post("/v1/threads/{tid}/messages")
async def add_msg(tid: str, req: ThreadMsg, h: str = Header(None, alias="Authorization")):
    """Append a message to a thread; 404 when the thread is missing.

    Fix: alias="Authorization" so auth() receives the real auth header.
    """
    auth(h)
    s = sessions.get(tid)
    if not s:
        raise HTTPException(404, "Not found")
    sessions.add_message(tid, req.role, req.content)
    return {"id": gid("msg"), "object": "thread.message", "thread_id": tid, "role": req.role}
|
|
@app.get("/v1/threads/{tid}/messages")
async def list_msgs(tid: str, h: str = Header(None, alias="Authorization")):
    """List all stored messages of a thread; 404 when it is missing.

    Fix: alias="Authorization" so auth() receives the real auth header.
    """
    auth(h)
    s = sessions.get(tid)
    if not s:
        raise HTTPException(404, "Not found")
    return {"object": "list", "data": s.messages}
|
|
@app.post("/v1/threads/{tid}/runs")
async def run_thread(tid: str, req: ChatReq, h: str = Header(None, alias="Authorization")):
    """Run a completion over a thread's history plus req.messages.

    Only the assistant reply is persisted to the thread (and only on the
    non-streaming path); the incoming req.messages are not saved.

    Fix: alias="Authorization" so auth() receives the real auth header.
    """
    auth(h)
    await oauth.get_access_token()
    s = sessions.get(tid)
    if not s:
        raise HTTPException(404, "Not found")

    # Stored history first, then the new messages from this run request.
    msgs = [Msg(role=m["role"], content=m["content"]) for m in s.messages] + req.messages
    p = prompt(msgs)
    model = resolve(req.model)
    ctx = memory.context()

    if req.stream:
        return StreamingResponse(stream_sse(req, model, p, ctx), media_type="text/event-stream")

    result = await run_gemini(p, model, ctx)
    sessions.add_message(tid, "assistant", result)

    return ChatResp(
        id=gid(), created=ts(), model=model,
        choices=[ChatChoice(message=Msg(role="assistant", content=result))],
        usage=Usage(prompt_tokens=tokens(p), completion_tokens=tokens(result), total_tokens=tokens(p + result))
    )
|
|
@app.delete("/v1/threads/{tid}")
async def del_thread(tid: str, h: str = Header(None, alias="Authorization")):
    """Delete a thread; 404 when it does not exist.

    Fix: alias="Authorization" so auth() receives the real auth header.
    """
    auth(h)
    ok = sessions.delete(tid)
    if not ok:
        raise HTTPException(404, "Not found")
    return {"id": tid, "deleted": True}
|
|
|
|
| |
@app.get("/v1/memory")
async def get_memory(h: str = Header(None, alias="Authorization")):
    """Return the raw project memory contents and their length.

    Fix: alias="Authorization" so auth() receives the real auth header.
    """
    auth(h)
    c = memory.get()
    return {"content": c, "length": len(c)}
|
|
@app.post("/v1/memory")
async def add_memory(text: str, h: str = Header(None, alias="Authorization")):
    """Append a timestamped entry to the project memory file.

    NOTE(review): `text` is a bare str parameter, so FastAPI reads it as a
    query parameter (?text=...), not a JSON body — confirm clients expect that.

    Fix: alias="Authorization" so auth() receives the real auth header.
    """
    auth(h)
    ok = memory.add(text)
    return {"success": ok}
|
|
@app.delete("/v1/memory")
async def clear_memory(h: str = Header(None, alias="Authorization")):
    """Delete the project memory file.

    Fix: alias="Authorization" so auth() receives the real auth header.
    """
    auth(h)
    memory.clear()
    return {"success": True}
|
|
|
|
| |
@app.get("/token/status")
async def token_status():
    """Report remaining lifetime of the cached OAuth access token."""
    if not oauth.tokens:
        return {"status": "no_tokens", "valid": False}
    remaining = oauth.tokens.get('expiry_date', 0) / 1000 - time.time()
    state = "valid" if remaining > 0 else "expired"
    return {"status": state, "expires_in": max(0, int(remaining))}
|
|
@app.post("/token/refresh")
async def token_refresh():
    """Force an OAuth token refresh; reports whether it succeeded."""
    return {"success": await oauth.refresh()}
|
|
|
|
if __name__ == "__main__":
    # Run the ASGI app directly via uvicorn when executed as a script.
    uvicorn.run(app, host=SERVER_HOST, port=SERVER_PORT)
|
|