ethansunqing
Support authenticated Gemini account routing
04b9d27
Raw
History Blame Contribute Delete
8.91 kB
"""Gemini StreamGenerate protocol implementation with httpx streaming."""
import json
import time
import uuid
import re
import urllib.request
import urllib.parse
import ssl
import os
import hashlib
try:
import httpx
HAS_HTTPX = True
except ImportError:
HAS_HTTPX = False
from .config import CONFIG
# Lazily created SSL context; populated on first call to _get_ssl_ctx().
_ssl_ctx = None
# Last cookie read by load_cookie(): the cookie string, the SAPISID value
# (or None) and the cookie file's mtime at the time it was read.
_cookie_cache = {"str": "", "sapisid": None, "mtime": 0}
# Lazily created shared httpx client; populated by _get_httpx_client().
_httpx_client = None
def log(msg: str):
    """Write a timestamped message to stderr when CONFIG["log_requests"] is set."""
    if not CONFIG["log_requests"]:
        return
    import sys
    stamp = time.strftime('%H:%M:%S')
    sys.stderr.write(f"[{stamp}] {msg}\n")
    sys.stderr.flush()
def _get_ssl_ctx():
    """Return the module-wide default SSL context, creating it on first use."""
    global _ssl_ctx
    if _ssl_ctx is not None:
        return _ssl_ctx
    _ssl_ctx = ssl.create_default_context()
    return _ssl_ctx
def _get_httpx_client():
    """Return the shared httpx client, building it on first use.

    The client is created only when httpx imported successfully; otherwise
    the cached value (initially None) is returned unchanged. When
    CONFIG["proxy"] is set, requests go through an HTTPTransport bound to
    that proxy.
    """
    global _httpx_client
    if _httpx_client is not None or not HAS_HTTPX:
        return _httpx_client
    proxy_url = CONFIG.get("proxy")
    if proxy_url:
        transport = httpx.HTTPTransport(proxy=proxy_url)
    else:
        transport = None
    _httpx_client = httpx.Client(
        transport=transport,
        timeout=CONFIG["request_timeout_sec"],
        verify=True,
    )
    return _httpx_client
def load_cookie() -> tuple:
    """Load the cookie string and SAPISID value from CONFIG["cookie_file"].

    The file may contain either a JSON object with "cookie" and "sapisid"
    keys, or a raw Cookie header string from which SAPISID is extracted.
    The parsed result is cached and reused while the file's mtime is
    unchanged.

    Returns:
        A ``(cookie_str, sapisid)`` tuple; ``sapisid`` is None when no value
        was found. ``("", None)`` is returned when no cookie file is
        configured or the file does not exist. If reading or parsing fails,
        the error is logged and the last cached values are returned.
    """
    cookie_file = CONFIG.get("cookie_file")
    if not cookie_file or not os.path.exists(cookie_file):
        return "", None
    try:
        mtime = os.path.getmtime(cookie_file)
        if mtime == _cookie_cache["mtime"] and _cookie_cache["str"]:
            return _cookie_cache["str"], _cookie_cache["sapisid"]
        # "utf-8-sig" reads plain UTF-8 and also drops a leading BOM, which
        # would otherwise hide the "{" that marks the JSON format.
        with open(cookie_file, "r", encoding="utf-8-sig") as f:
            content = f.read().strip()
        if content.startswith("{"):
            data = json.loads(content)
            cookie_str = data.get("cookie", "")
            sapisid = data.get("sapisid", "")
        else:
            cookie_str = content
            # Split on bare ";" and strip each piece so that both
            # "a=1; b=2" and "a=1;b=2" are understood.
            pairs = {}
            for part in cookie_str.split(";"):
                name, sep, value = part.strip().partition("=")
                if sep:
                    pairs[name] = value
            sapisid = pairs.get("SAPISID", "")
        _cookie_cache.update({"str": cookie_str, "sapisid": sapisid or None, "mtime": mtime})
        return cookie_str, sapisid if sapisid else None
    except Exception as e:
        # Best effort: keep serving the previously cached cookie.
        log(f"Cookie load error: {e}")
        return _cookie_cache["str"], _cookie_cache["sapisid"]
def make_sapisidhash(sapisid: str) -> str:
    """Build a "SAPISIDHASH <ts>_<digest>" value for the given SAPISID.

    ``ts`` is the current Unix time in whole seconds and ``digest`` is the
    SHA-1 hex digest of "<ts> <sapisid> https://gemini.google.com".
    """
    timestamp = int(time.time())
    material = f"{timestamp} {sapisid} https://gemini.google.com"
    digest = hashlib.sha1(material.encode()).hexdigest()
    return f"SAPISIDHASH {timestamp}_{digest}"
def _account_prefix() -> str:
    """Return the "/u/<auth_user>" path prefix, or "" when no auth_user is set.

    CONFIG["auth_user"] counts as unset when it is missing, None or "".
    """
    user = CONFIG.get("auth_user")
    has_user = user is not None and not (user == "")
    return f"/u/{user}" if has_user else ""
def _build_headers() -> dict:
    """Assemble the HTTP request headers for a Gemini call.

    Always sets the form content type, origin, referer, same-domain marker
    and user agent. Adds X-Goog-AuthUser when an account prefix is in use,
    Cookie when a cookie string was loaded, and a SAPISIDHASH Authorization
    header when a SAPISID value is available.
    """
    prefix = _account_prefix()
    result = {
        "Content-Type": "application/x-www-form-urlencoded",
        "Origin": "https://gemini.google.com",
        "Referer": f"https://gemini.google.com{prefix}/app",
        "X-Same-Domain": "1",
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
    }
    if prefix:
        result["X-Goog-AuthUser"] = str(CONFIG["auth_user"])

    cookie_value, sapisid_value = load_cookie()
    if cookie_value:
        result["Cookie"] = cookie_value
    if sapisid_value:
        result["Authorization"] = make_sapisidhash(sapisid_value)
    return result
def _build_payload(prompt: str, model_id: int, think_mode: int, file_refs: list = None, extra_fields: dict = None) -> str:
    """Build the URL-encoded form body for a StreamGenerate request.

    The request is a 102-slot positional list; slots not set here stay None.
    It is JSON-encoded, wrapped as ``[None, <json>]`` and sent as the
    ``f.req`` form field. ``extra_fields`` maps slot index to value and is
    applied last, so it can override any slot set here. The ``at`` field is
    added when CONFIG["xsrf_token"] is set.
    """
    # Each file reference is wrapped as [None, None, ref]; no files -> None.
    refs = [[None, None, ref] for ref in file_refs] if file_refs else None

    fixed_slots = {
        0: [prompt, 0, None, refs, None, None, 0],
        1: ["en"],
        2: ["", "", "", None, None, None, None, None, None, ""],
        6: [0],
        7: 1,
        10: 1,
        11: 0,
        17: [[think_mode]],
        18: 0,
        27: 1,
        30: [4],
        41: [2],
        53: 0,
        59: str(uuid.uuid4()),
        61: [],
        68: 1,
        79: model_id,
    }

    inner = [None] * 102
    for index, value in fixed_slots.items():
        inner[index] = value
    # Caller overrides go last so they win over the fixed slots.
    if extra_fields:
        for index, value in extra_fields.items():
            inner[index] = value

    form = {"f.req": json.dumps([None, json.dumps(inner)])}
    token = CONFIG.get("xsrf_token")
    if token:
        form["at"] = token
    return urllib.parse.urlencode(form)
def _get_url() -> str:
    """Return the StreamGenerate endpoint URL for the configured account.

    The query string carries CONFIG["gemini_bl"] as ``bl`` and a ``_reqid``
    taken from the current Unix time modulo 1,000,000.
    """
    reqid = int(time.time()) % 1000000
    origin = f"https://gemini.google.com{_account_prefix()}"
    endpoint = "/_/BardChatUi/data/assistant.lamda.BardFrontendService/StreamGenerate"
    query = f"?bl={CONFIG['gemini_bl']}&hl=en&_reqid={reqid}&rt=c"
    return origin + endpoint + query
def clean_text(text: str) -> str:
    """Strip code-event fenced blocks and card-content links from *text*.

    Removes fenced blocks whose info string is a python/javascript/text
    ``?code_reference`` or ``?code_stdout`` marker, removes
    ``http://googleusercontent.com/card_content/<n>`` links, then trims
    surrounding whitespace.
    """
    code_event_block = re.compile(
        r'```(?:python|javascript|text)\?code_(?:reference|stdout)&code_event_index=\d+\n.*?```\n?',
        re.DOTALL,
    )
    card_link = re.compile(r'http://googleusercontent\.com/card_content/\d+\n?')
    without_blocks = code_event_block.sub('', text)
    without_links = card_link.sub('', without_blocks)
    return without_links.strip()
def _extract_texts_from_line(line: str) -> list:
    """Parse a single wrb.fr line and return the list of text strings found.

    A line is considered only if it contains the ``"wrb.fr"`` marker and is
    at least 200 characters long. It is parsed as JSON, the string at
    ``[0][2]`` is parsed as JSON again, and every non-empty string in the
    second element of each entry of that inner list's index 4 is collected.

    Returns:
        The collected strings, or ``[]`` when the line is filtered out, is
        not valid JSON, or does not have the expected shape.
    """
    if '"wrb.fr"' not in line or len(line) < 200:
        return []
    try:
        arr = json.loads(line)
        inner_str = arr[0][2]
        if not inner_str or len(inner_str) < 50:
            return []
        inner = json.loads(inner_str)
        if not (isinstance(inner, list) and len(inner) > 4 and inner[4]):
            return []
        texts = []
        for part in inner[4]:
            if isinstance(part, list) and len(part) > 1 and part[1] and isinstance(part[1], list):
                for t in part[1]:
                    if isinstance(t, str) and t:
                        texts.append(t)
        return texts
    except (json.JSONDecodeError, LookupError, TypeError):
        # LookupError covers IndexError and also KeyError, which is raised
        # when the line decodes to a JSON object instead of an array.
        return []
def extract_response_text(raw: str) -> str:
    """Return the cleaned longest text found in a complete raw response.

    Every newline-separated line of *raw* is scanned; the longest extracted
    string wins (the first one on ties). An empty string is returned when
    no text is found.
    """
    candidates = (
        text
        for line in raw.split("\n")
        for text in _extract_texts_from_line(line)
    )
    longest = max(candidates, key=len, default="")
    return clean_text(longest)
def generate(prompt: str, model_id: int, think_mode: int, file_refs: list = None, extra_fields: dict = None) -> str:
    """Run a non-streaming generation over urllib, with retries.

    Args:
        prompt: User prompt text.
        model_id: Model identifier placed in the request payload.
        think_mode: Thinking-mode value placed in the request payload.
        file_refs: Optional file references to attach to the prompt.
        extra_fields: Optional payload slot overrides (index -> value).

    Returns:
        The cleaned response text, or "" when the response had no text.

    Raises:
        Exception: The error from the final attempt, once
            CONFIG["retry_attempts"] attempts have failed.
    """
    body = _build_payload(prompt, model_id, think_mode, file_refs, extra_fields).encode()
    url = _get_url()
    headers = _build_headers()
    ctx = _get_ssl_ctx()
    proxy = CONFIG.get("proxy")
    timeout = CONFIG["request_timeout_sec"]
    # Always try at least once; a non-positive setting would otherwise skip
    # the request and fall through to re-raising a non-existent error.
    attempts = max(1, CONFIG["retry_attempts"])

    # The proxy opener does not change between attempts, so build it once.
    opener = None
    if proxy:
        opener = urllib.request.build_opener(
            urllib.request.ProxyHandler({"http": proxy, "https": proxy}),
            urllib.request.HTTPSHandler(context=ctx)
        )

    for attempt in range(attempts):
        try:
            req = urllib.request.Request(url, data=body, headers=headers, method="POST")
            if opener is not None:
                resp = opener.open(req, timeout=timeout)
            else:
                resp = urllib.request.urlopen(req, context=ctx, timeout=timeout)
            # Close the response (and its socket) even if reading fails.
            with resp:
                raw = resp.read().decode("utf-8", errors="replace")
            return extract_response_text(raw)
        except Exception as e:
            if attempt >= attempts - 1:
                raise
            log(f"Retry {attempt+1}/{attempts}: {e}")
            time.sleep(CONFIG["retry_delay_sec"])
def _iter_stream_lines(chunks):
    """Yield newline-delimited lines from text chunks, including a final unterminated one."""
    pending = ""
    for chunk in chunks:
        pending += chunk
        while "\n" in pending:
            line, pending = pending.split("\n", 1)
            yield line
    if pending:
        yield pending


def generate_stream(prompt: str, model_id: int, think_mode: int, file_refs: list = None, extra_fields: dict = None):
    """Stream a generation via httpx, yielding text deltas as they arrive.

    When httpx is not installed this falls back to ``generate()`` and yields
    the whole answer once (nothing if it is empty).

    Each response line carries the text produced so far. That cumulative
    text is cleaned with ``clean_text`` and only the part beyond what has
    already been yielded is emitted, so joining the yielded deltas rebuilds
    the cleaned text with its inner whitespace intact.

    A failed attempt is retried up to CONFIG["retry_attempts"] times, but
    only while nothing has been yielded yet: a retry after partial output
    would replay text to the consumer.

    Args:
        prompt: User prompt text.
        model_id: Model identifier placed in the request payload.
        think_mode: Thinking-mode value placed in the request payload.
        file_refs: Optional file references to attach to the prompt.
        extra_fields: Optional payload slot overrides (index -> value).

    Yields:
        Non-empty text deltas.

    Raises:
        Exception: The error from the final attempt, or any error raised
            after output has already been yielded.
    """
    if not HAS_HTTPX:
        text = generate(prompt, model_id, think_mode, file_refs, extra_fields)
        if text:
            yield text
        return
    body = _build_payload(prompt, model_id, think_mode, file_refs, extra_fields)
    url = _get_url()
    headers = _build_headers()
    client = _get_httpx_client()
    # Always try at least once, even with a non-positive retry setting.
    attempts = max(1, CONFIG["retry_attempts"])
    for attempt in range(attempts):
        emitted = False
        try:
            raw_len = 0   # length of the longest raw text seen so far
            sent = ""     # cleaned text already delivered to the consumer
            with client.stream("POST", url, content=body, headers=headers) as resp:
                for line in _iter_stream_lines(resp.iter_text()):
                    for t in _extract_texts_from_line(line):
                        if len(t) <= raw_len:
                            continue
                        raw_len = len(t)
                        # Clean the cumulative text, not the delta: cleaning
                        # strips whitespace, which on a delta would glue
                        # words together at chunk boundaries.
                        cleaned = clean_text(t)
                        if len(cleaned) <= len(sent):
                            continue
                        delta = cleaned[len(sent):]
                        sent = cleaned
                        # Set before yielding so an error raised at the
                        # yield point is never treated as retryable.
                        emitted = True
                        yield delta
            return
        except Exception as e:
            if emitted or attempt >= attempts - 1:
                raise
            log(f"Stream retry {attempt+1}/{attempts}: {e}")
            time.sleep(CONFIG["retry_delay_sec"])