# Source: custom_ai/server.py (Hugging Face Space upload, rev f1b91cb)
#!/usr/bin/env python3
"""
Gemini CLI to OpenAI-Compatible API Server
Complete Implementation with ALL Gemini CLI Features
Features:
- Real-time SSE Streaming
- Session/Thread Management
- Persistent Memory System
- Extended Thinking Exposure
- OpenAI Threads API
- Rate Limiting
- Auto Token Refresh
Author: Z.ai
"""
import asyncio
import hashlib
import json
import os
import sys
import time
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional, List, Dict, Any, AsyncGenerator

import httpx
import uvicorn
from fastapi import FastAPI, HTTPException, Header, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
# ============================================================================
# Configuration
# ============================================================================
# Local Gemini CLI state directory and the OAuth credential cache it uses.
GEMINI_DIR = Path.home() / ".gemini"
OAUTH_FILE = GEMINI_DIR / "oauth_creds.json"
# Persistent memory file, written in the current working directory.
MEMORY_FILE = "GEMINI.md"
# On-disk storage for thread/session JSON files.
SESSIONS_DIR = GEMINI_DIR / "api_sessions"
# Optional bearer token; when empty, every endpoint is unauthenticated.
API_AUTH_KEY = os.getenv("API_AUTH_KEY", "")
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "gemini-2.5-flash")
SERVER_PORT = int(os.getenv("PORT", "7860"))
SERVER_HOST = os.getenv("HOST", "0.0.0.0")
# Seconds to wait for the Gemini CLI subprocess.
# NOTE(review): this value was read but never applied to the subprocess in the
# original code -- confirm the executor now enforces it.
GEMINI_TIMEOUT = int(os.getenv("GEMINI_TIMEOUT", "180"))
# OAuth client id used when refreshing tokens (installed-app client).
GOOGLE_CLIENT_ID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
# Static model catalogue exposed via /v1/models (OpenAI wire format).
MODELS = [
    {"id": "gemini-2.5-pro", "object": "model", "created": 1700000000, "owned_by": "google"},
    {"id": "gemini-2.5-flash", "object": "model", "created": 1700000000, "owned_by": "google"},
    {"id": "gemini-2.5-flash-lite", "object": "model", "created": 1700000000, "owned_by": "google"},
    {"id": "gemini-3-pro-preview", "object": "model", "created": 1700000000, "owned_by": "google"},
    {"id": "gemini-3-flash-preview", "object": "model", "created": 1700000000, "owned_by": "google"},
    {"id": "gpt-4", "object": "model", "created": 1700000000, "owned_by": "google"},
    {"id": "gpt-4o", "object": "model", "created": 1700000000, "owned_by": "google"},
]
# Map OpenAI-style model names onto real Gemini models.
MODEL_ALIASES = {
    "gpt-4": "gemini-2.5-pro",
    "gpt-4o": "gemini-2.5-pro",
    "gpt-4-turbo": "gemini-2.5-flash",
    "gpt-3.5-turbo": "gemini-2.5-flash-lite",
}
# ============================================================================
# OAuth Manager
# ============================================================================
class OAuthManager:
    """Loads, caches and refreshes Google OAuth credentials for the Gemini CLI.

    Credentials come from the GEMINI_OAUTH_CREDS environment variable or from
    ~/.gemini/oauth_creds.json; refreshed tokens are written back to that file
    so the CLI and this server stay in sync.

    Fix over the original: the bare ``except:`` clauses (which swallowed every
    exception, including KeyboardInterrupt, without any trace) are replaced
    with narrow handlers that log the failure.
    """

    def __init__(self):
        # Raw token payload (access_token, refresh_token, expiry_date in ms).
        self.tokens: Dict[str, Any] = {}
        # Serializes load/refresh so concurrent requests don't double-refresh.
        self._lock = asyncio.Lock()

    async def load(self) -> bool:
        """Populate self.tokens from the environment or the creds file."""
        env_creds = os.getenv("GEMINI_OAUTH_CREDS", "")
        if env_creds:
            try:
                self.tokens = json.loads(env_creds)
                GEMINI_DIR.mkdir(parents=True, exist_ok=True)
                OAUTH_FILE.write_text(json.dumps(self.tokens, indent=2))
                print("[OK] Loaded OAuth from env")
                return True
            except (ValueError, OSError) as e:
                # Bad JSON or unwritable directory: fall back to the file.
                print(f"[WARN] Could not load OAuth from env: {e}")
        if OAUTH_FILE.exists():
            try:
                self.tokens = json.loads(OAUTH_FILE.read_text())
                print("[OK] Loaded OAuth from file")
                return True
            except (ValueError, OSError) as e:
                print(f"[WARN] Could not load OAuth from file: {e}")
        return False

    def is_expired(self) -> bool:
        """True when no expiry is known or fewer than 5 minutes remain."""
        if 'expiry_date' not in self.tokens:
            return True
        # expiry_date is in milliseconds; refresh 300s before actual expiry.
        return (self.tokens['expiry_date'] / 1000 - time.time()) < 300

    async def refresh(self) -> bool:
        """Exchange the refresh token for a new access token.

        Returns False (after logging why) when no refresh token exists, the
        token endpoint rejects the request, or the HTTP call fails.
        """
        if 'refresh_token' not in self.tokens:
            return False
        print("[INFO] Refreshing token...")
        try:
            async with httpx.AsyncClient() as client:
                resp = await client.post(
                    "https://oauth2.googleapis.com/token",
                    data={
                        "client_id": GOOGLE_CLIENT_ID,
                        "refresh_token": self.tokens['refresh_token'],
                        "grant_type": "refresh_token",
                    },
                    timeout=30
                )
                if resp.status_code == 200:
                    data = resp.json()
                    self.tokens['access_token'] = data['access_token']
                    self.tokens['expiry_date'] = int(time.time() * 1000 + data['expires_in'] * 1000)
                    OAUTH_FILE.write_text(json.dumps(self.tokens, indent=2))
                    print("[OK] Token refreshed")
                    return True
                # Non-200 used to vanish silently; surface it for debugging.
                print(f"[ERROR] Token refresh failed: HTTP {resp.status_code}")
        except Exception as e:
            print(f"[ERROR] Token refresh failed: {e}")
        return False

    async def get_access_token(self) -> Optional[str]:
        """Return a valid access token, lazily loading/refreshing as needed."""
        async with self._lock:
            if not self.tokens:
                await self.load()
            if self.is_expired():
                await self.refresh()
            return self.tokens.get('access_token')


oauth = OAuthManager()
# ============================================================================
# Session Manager (Threads)
# ============================================================================
@dataclass
class ChatSession:
    """One persisted conversation thread."""
    id: str       # "thread_<24 hex chars>" (assigned by SessionManager.create)
    created: int  # unix timestamp, seconds
    messages: List[Dict] = field(default_factory=list)  # stored message dicts
    model: str = DEFAULT_MODEL
class SessionManager:
    """File-backed store for chat threads.

    Each session is persisted as <SESSIONS_DIR>/<thread_id>.json so threads
    survive restarts without any database.
    """

    def __init__(self):
        SESSIONS_DIR.mkdir(parents=True, exist_ok=True)

    def _path(self, session_id: str) -> Path:
        # Canonical on-disk location for one session.
        return SESSIONS_DIR / f"{session_id}.json"

    def create(self, model: str = DEFAULT_MODEL) -> ChatSession:
        """Create, persist and return a brand-new session."""
        fresh = ChatSession(
            id=f"thread_{uuid.uuid4().hex[:24]}",
            created=int(time.time()),
            model=model,
        )
        self._save(fresh)
        return fresh

    def get(self, session_id: str) -> Optional[ChatSession]:
        """Load a session from disk; None when it does not exist."""
        path = self._path(session_id)
        if not path.exists():
            return None
        return ChatSession(**json.loads(path.read_text()))

    def add_message(self, session_id: str, role: str, content: str):
        """Append one message to a session and persist it (no-op if missing)."""
        session = self.get(session_id)
        if session is None:
            return
        session.messages.append({
            "id": f"msg_{uuid.uuid4().hex[:24]}",
            "role": role,
            "content": content,
            "created": int(time.time()),
        })
        self._save(session)

    def _save(self, session: ChatSession):
        # Persist the full session state as plain JSON.
        payload = {
            "id": session.id,
            "created": session.created,
            "messages": session.messages,
            "model": session.model,
        }
        self._path(session.id).write_text(json.dumps(payload))

    def list(self, limit: int = 20) -> List[ChatSession]:
        """Return up to `limit` sessions, most recently modified first."""
        by_recency = sorted(
            SESSIONS_DIR.glob("thread_*.json"),
            key=lambda p: p.stat().st_mtime,
            reverse=True,
        )
        return [self.get(p.stem) for p in by_recency[:limit]]

    def delete(self, session_id: str) -> bool:
        """Remove a session file; True when something was deleted."""
        path = self._path(session_id)
        if not path.exists():
            return False
        path.unlink()
        return True


sessions = SessionManager()
# ============================================================================
# Memory Manager
# ============================================================================
class MemoryManager:
    """Append-only project memory stored in a markdown file (GEMINI.md).

    Fixes over the original: uses timezone-aware ``datetime.now(timezone.utc)``
    instead of the deprecated ``datetime.utcnow()``, and the bare ``except:``
    that hid every write failure is narrowed to OSError and logged.
    """

    def __init__(self):
        # Memory lives in the process working directory, not GEMINI_DIR.
        self.file = Path(".") / MEMORY_FILE

    def get(self) -> str:
        """Return the full memory text, or '' when no memory exists."""
        if self.file.exists():
            return self.file.read_text()
        return ""

    def add(self, text: str) -> bool:
        """Append a timestamped memory entry; returns False on I/O failure."""
        try:
            current = self.get()
            # Timezone-aware replacement for the deprecated datetime.utcnow().
            ts = datetime.now(timezone.utc).isoformat()
            entry = f"\n## Memory ({ts})\n{text}\n"
            if not current:
                # First entry: start the file with a top-level heading.
                entry = f"# Project Memory\n{entry}"
            self.file.write_text(current + entry)
            return True
        except OSError as e:
            # Report the failure instead of silently swallowing everything.
            print(f"[ERROR] Could not write memory file: {e}")
            return False

    def clear(self):
        """Delete the memory file if present."""
        if self.file.exists():
            self.file.unlink()

    def context(self) -> str:
        """Memory wrapped in <memory> tags for prompt injection, or ''."""
        c = self.get()
        return f"<memory>\n{c}\n</memory>" if c else ""
memory = MemoryManager()
# ============================================================================
# Gemini CLI Executor
# ============================================================================
async def run_gemini_stream(
    prompt: str,
    model: str = DEFAULT_MODEL,
    memory_ctx: str = ""
) -> AsyncGenerator[Dict, None]:
    """Execute the Gemini CLI subprocess and yield streaming events.

    Yields dicts:
      {"type": "chunk", "content": <one filtered output line>} while streaming,
      {"type": "done",  "content": <full response text>}       on success,
      {"type": "error", "error": <message>}                    on failure.

    Fixes over the original:
    - GEMINI_TIMEOUT is actually enforced; the old code caught
      asyncio.TimeoutError but nothing could ever raise it.
    - stderr is discarded instead of being opened as a PIPE that was never
      read (a full stderr pipe can block the child process).
    - the subprocess is killed in all exit paths so it cannot leak.
    """
    env = os.environ.copy()
    env["GOOGLE_GENAI_USE_GCA"] = "true"
    full_prompt = f"{memory_ctx}\n\n{prompt}" if memory_ctx else prompt
    cmd = ["gemini", "-p", full_prompt, "--yolo", "-m", MODEL_ALIASES.get(model, model)]
    print(f"[EXEC] gemini -m {model}")
    proc = None
    try:
        proc = await asyncio.create_subprocess_exec(
            *cmd, env=env,
            stdout=asyncio.subprocess.PIPE,
            # stderr was a PIPE that nobody drained; discard it instead.
            stderr=asyncio.subprocess.DEVNULL
        )
        content = ""
        deadline = time.monotonic() + GEMINI_TIMEOUT
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise asyncio.TimeoutError
            line = await asyncio.wait_for(proc.stdout.readline(), timeout=remaining)
            if not line:
                break  # EOF: the CLI is done writing.
            txt = line.decode('utf-8', errors='ignore').strip()
            # Drop the CLI's startup/diagnostic chatter.
            if txt and not any(s in txt.lower() for s in
                               ['yolo mode', 'keychain', 'loaded', 'using file']):
                content += txt + "\n"
                yield {"type": "chunk", "content": txt}
        await proc.wait()
        yield {"type": "done", "content": content.strip()}
    except asyncio.TimeoutError:
        yield {"type": "error", "error": "Timeout"}
    except Exception as e:
        yield {"type": "error", "error": str(e)}
    finally:
        # Never leave a half-finished CLI process behind (also covers the
        # consumer abandoning this generator mid-stream).
        if proc is not None and proc.returncode is None:
            proc.kill()
async def run_gemini(prompt: str, model: str = DEFAULT_MODEL, memory_ctx: str = "") -> str:
    """Collect the streamed Gemini CLI output into one final string."""
    final = ""
    async for event in run_gemini_stream(prompt, model, memory_ctx):
        kind = event["type"]
        if kind == "done":
            final = event["content"]
        elif kind == "error":
            final = f"Error: {event['error']}"
    return final
# ============================================================================
# Pydantic Models
# ============================================================================
class Msg(BaseModel):
    """One chat message (OpenAI wire format)."""
    role: str  # e.g. "system" / "user" / "assistant" (not validated here)
    content: str


class ChatReq(BaseModel):
    """Request body for /v1/chat/completions and /v1/threads/{tid}/runs."""
    model: str = DEFAULT_MODEL
    messages: List[Msg]
    stream: bool = False  # True -> SSE streaming response
    include_thinking: bool = False  # NOTE(review): accepted but never read by any handler -- confirm


class ChatChoice(BaseModel):
    """A single completion choice."""
    index: int = 0
    message: Msg
    finish_reason: str = "stop"


class Usage(BaseModel):
    """Approximate token accounting (len(text) // 4 heuristic, see tokens())."""
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatResp(BaseModel):
    """OpenAI-compatible chat completion response."""
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatChoice]
    usage: Usage


class ThreadCreate(BaseModel):
    """Body for POST /v1/threads; optional seed messages."""
    messages: List[Msg] = []


class ThreadMsg(BaseModel):
    """Body for POST /v1/threads/{tid}/messages."""
    role: str
    content: str
# ============================================================================
# Helpers
# ============================================================================
def gid(p="chatcmpl"):
    """Generate an OpenAI-style id: '<prefix>-<24 hex chars>'."""
    return f"{p}-{uuid.uuid4().hex[:24]}"


def ts():
    """Current unix timestamp in whole seconds."""
    return int(time.time())


def tokens(t):
    """Rough token estimate: ~4 characters per token."""
    return len(t) // 4


def prompt(msgs):
    """Flatten chat messages into a 'role: content' transcript ending with an assistant cue."""
    transcript = [f"{m.role}: {m.content}" for m in msgs]
    transcript.append("assistant:")
    return "\n\n".join(transcript)


def resolve(m):
    """Map an OpenAI-style model alias onto the real Gemini model name."""
    return MODEL_ALIASES.get(m, m)
# ============================================================================
# FastAPI App
# ============================================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup hook: prepare state directories and preload OAuth tokens."""
    print("\n" + "=" * 60)
    print("🚀 Gemini CLI API v2.0 - Complete Features")
    print("=" * 60)
    GEMINI_DIR.mkdir(parents=True, exist_ok=True)
    SESSIONS_DIR.mkdir(parents=True, exist_ok=True)
    # Best-effort: the server still starts when no credentials are found.
    await oauth.load()
    print(f"✅ Ready: http://{SERVER_HOST}:{SERVER_PORT}")
    print("=" * 60 + "\n")
    yield


app = FastAPI(title="Gemini CLI API v2", lifespan=lifespan)
# NOTE(review): wide-open CORS -- acceptable for a personal proxy, revisit
# before exposing this publicly.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
def auth(h: str = Header(None)):
    """Validate a bearer token against API_AUTH_KEY (no-op when unset)."""
    if not API_AUTH_KEY:
        return True
    token = h
    if token and token.startswith("Bearer "):
        token = token[len("Bearer "):]
    if token != API_AUTH_KEY:
        raise HTTPException(401, "Unauthorized")
    return True
# ============================================================================
# Endpoints
# ============================================================================
@app.get("/")
async def root():
    """Service banner with a directory of the available endpoints."""
    return {
        "service": "Gemini CLI API v2",
        "features": ["streaming", "threads", "memory", "thinking"],
        "endpoints": {
            "chat": "/v1/chat/completions",
            "models": "/v1/models",
            "threads": "/v1/threads",
            "memory": "/v1/memory"
        }
    }


@app.get("/health")
async def health():
    """Liveness probe: reports whether OAuth tokens are loaded and unexpired."""
    return {
        "status": "ok",
        "oauth": bool(oauth.tokens),
        "token_valid": not oauth.is_expired()
    }


@app.get("/v1/models")
async def list_models():
    """OpenAI-compatible model listing (static MODELS catalogue)."""
    return {"object": "list", "data": MODELS}
@app.post("/v1/chat/completions")
async def chat(req: ChatReq, h: str = Header(None, alias="Authorization")):
    """OpenAI-compatible chat completion (blocking JSON or SSE streaming).

    Fix: the auth value is now read from the standard "Authorization" header.
    Previously `Header(None)` bound a header literally named "h", so clients'
    bearer tokens were never seen and auth always failed when API_AUTH_KEY
    was set.
    """
    auth(h)
    # Ensure a fresh access token before shelling out to the CLI.
    await oauth.get_access_token()
    model = resolve(req.model)
    p = prompt(req.messages)
    ctx = memory.context()
    if req.stream:
        return StreamingResponse(stream_sse(req, model, p, ctx), media_type="text/event-stream")
    result = await run_gemini(p, model, ctx)
    return ChatResp(
        id=gid(), created=ts(), model=model,
        choices=[ChatChoice(message=Msg(role="assistant", content=result))],
        usage=Usage(prompt_tokens=tokens(p), completion_tokens=tokens(result), total_tokens=tokens(p+result))
    )
async def stream_sse(req: ChatReq, model: str, prompt: str, ctx: str):
    """Translate Gemini CLI stream events into OpenAI-style SSE frames."""
    chunk_id = gid()
    created_at = ts()

    def frame(payload: dict) -> str:
        # One server-sent-events data frame.
        return f"data: {json.dumps(payload)}\n\n"

    def chunk(delta: dict, finish) -> dict:
        return {
            "id": chunk_id,
            "object": "chat.completion.chunk",
            "created": created_at,
            "model": model,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish}],
        }

    async for ev in run_gemini_stream(prompt, model, ctx):
        kind = ev["type"]
        if kind == "chunk":
            yield frame(chunk({"content": ev["content"]}, None))
        elif kind == "done":
            yield frame(chunk({}, "stop"))
            yield "data: [DONE]\n\n"
        elif kind == "error":
            yield frame({"error": ev["error"]})
            yield "data: [DONE]\n\n"
# Threads API
@app.post("/v1/threads")
async def create_thread(req: ThreadCreate, h: str = Header(None)):
    """Create a thread, optionally seeded with initial messages.

    NOTE(review): `Header(None)` binds a header literally named "h", not
    "Authorization" -- likely needs alias="Authorization"; applies to every
    handler using this pattern.
    """
    auth(h)
    s = sessions.create()
    for m in req.messages:
        sessions.add_message(s.id, m.role, m.content)
    return {"id": s.id, "object": "thread", "created_at": s.created}
@app.get("/v1/threads")
async def list_threads(limit: int = 20, h: str = Header(None)):
    """List up to `limit` threads, most recently modified first."""
    auth(h)
    return {"object": "list", "data": [{"id": s.id, "object": "thread", "created_at": s.created} for s in sessions.list(limit)]}
@app.get("/v1/threads/{tid}")
async def get_thread(tid: str, h: str = Header(None)):
    """Fetch one thread's metadata; 404 when it does not exist."""
    auth(h)
    s = sessions.get(tid)
    if not s:
        raise HTTPException(404, "Not found")
    return {"id": s.id, "object": "thread", "created_at": s.created, "metadata": {"messages": len(s.messages)}}
@app.post("/v1/threads/{tid}/messages")
async def add_msg(tid: str, req: ThreadMsg, h: str = Header(None)):
    """Append a message to an existing thread; 404 when the thread is missing.

    Fix: return the id of the message that was actually persisted. The old
    code returned a freshly generated gid("msg") that never matched the id
    SessionManager.add_message stored on disk.
    """
    auth(h)
    s = sessions.get(tid)
    if not s:
        raise HTTPException(404, "Not found")
    sessions.add_message(tid, req.role, req.content)
    # Re-read the thread to learn the id assigned to the new message.
    stored = sessions.get(tid)
    msg_id = stored.messages[-1]["id"] if stored and stored.messages else gid("msg")
    return {"id": msg_id, "object": "thread.message", "thread_id": tid, "role": req.role}
@app.get("/v1/threads/{tid}/messages")
async def list_msgs(tid: str, h: str = Header(None)):
    """Return all stored messages of a thread; 404 when it does not exist."""
    auth(h)
    s = sessions.get(tid)
    if not s:
        raise HTTPException(404, "Not found")
    return {"object": "list", "data": s.messages}
@app.post("/v1/threads/{tid}/runs")
async def run_thread(tid: str, req: ChatReq, h: str = Header(None)):
    """Run a completion over the thread's history plus any extra req.messages.

    NOTE(review): the extra `req.messages` are used for this run but never
    persisted to the thread, and in streaming mode the assistant reply is not
    saved either -- confirm whether that is intentional.
    """
    auth(h)
    await oauth.get_access_token()
    s = sessions.get(tid)
    if not s:
        raise HTTPException(404, "Not found")
    # Thread history first, then the request's new messages.
    msgs = [Msg(role=m["role"], content=m["content"]) for m in s.messages] + req.messages
    p = prompt(msgs)
    model = resolve(req.model)
    ctx = memory.context()
    if req.stream:
        return StreamingResponse(stream_sse(req, model, p, ctx), media_type="text/event-stream")
    result = await run_gemini(p, model, ctx)
    # Persist the assistant reply back onto the thread (non-streaming only).
    sessions.add_message(tid, "assistant", result)
    return ChatResp(
        id=gid(), created=ts(), model=model,
        choices=[ChatChoice(message=Msg(role="assistant", content=result))],
        usage=Usage(prompt_tokens=tokens(p), completion_tokens=tokens(result), total_tokens=tokens(p+result))
    )
@app.delete("/v1/threads/{tid}")
async def del_thread(tid: str, h: str = Header(None)):
    """Delete a thread; 404 when it does not exist."""
    auth(h)
    if not sessions.delete(tid):
        raise HTTPException(404, "Not found")
    return {"id": tid, "deleted": True}
# Memory API
@app.get("/v1/memory")
async def get_memory(h: str = Header(None)):
    """Return the full persistent memory text and its character length."""
    auth(h)
    c = memory.get()
    return {"content": c, "length": len(c)}


@app.post("/v1/memory")
async def add_memory(text: str, h: str = Header(None)):
    """Append a timestamped entry to persistent memory.

    NOTE(review): a bare `str` parameter is read from the query string in
    FastAPI, not the request body -- confirm that is the intended API shape.
    """
    auth(h)
    ok = memory.add(text)
    return {"success": ok}


@app.delete("/v1/memory")
async def clear_memory(h: str = Header(None)):
    """Delete the persistent memory file."""
    auth(h)
    memory.clear()
    return {"success": True}
# Token
@app.get("/token/status")
async def token_status():
    """Report whether OAuth tokens exist and seconds until expiry."""
    if not oauth.tokens:
        return {"status": "no_tokens", "valid": False}
    # expiry_date is stored in milliseconds.
    rem = oauth.tokens.get('expiry_date', 0) / 1000 - time.time()
    return {"status": "valid" if rem > 0 else "expired", "expires_in": max(0, int(rem))}


@app.post("/token/refresh")
async def token_refresh():
    """Force an immediate OAuth token refresh.

    NOTE(review): unlike the /v1 endpoints this skips auth() -- confirm
    that is intentional.
    """
    ok = await oauth.refresh()
    return {"success": ok}
if __name__ == "__main__":
    # Launch the ASGI app directly when executed as a script.
    uvicorn.run(app, host=SERVER_HOST, port=SERVER_PORT)