Spaces:
Sleeping
Sleeping
"""Gemini StreamGenerate protocol implementation with httpx streaming."""
| import json | |
| import time | |
| import uuid | |
| import re | |
| import urllib.request | |
| import urllib.parse | |
| import ssl | |
| import os | |
| import hashlib | |
| try: | |
| import httpx | |
| HAS_HTTPX = True | |
| except ImportError: | |
| HAS_HTTPX = False | |
| from .config import CONFIG | |
| _ssl_ctx = None | |
| _cookie_cache = {"str": "", "sapisid": None, "mtime": 0} | |
| _httpx_client = None | |
def log(msg: str):
    """Write a timestamped diagnostic line to stderr.

    Nothing is written unless ``CONFIG["log_requests"]`` is truthy.
    """
    if not CONFIG["log_requests"]:
        return
    import sys

    stamp = time.strftime('%H:%M:%S')
    stream = sys.stderr
    stream.write(f"[{stamp}] {msg}\n")
    # Flush immediately so the line is not held back in a buffer.
    stream.flush()
def _get_ssl_ctx():
    """Return the shared default SSL context, creating it on first use."""
    global _ssl_ctx
    if _ssl_ctx is not None:
        return _ssl_ctx
    _ssl_ctx = ssl.create_default_context()
    return _ssl_ctx
def _get_httpx_client():
    """Return the shared ``httpx.Client``, building it on the first call.

    When httpx is not installed nothing is built and the module-level
    ``_httpx_client`` value is returned as it is.
    """
    global _httpx_client
    if _httpx_client is not None or not HAS_HTTPX:
        return _httpx_client
    proxy_url = CONFIG.get("proxy")
    # An explicit transport is only installed when a proxy is configured.
    if proxy_url:
        transport = httpx.HTTPTransport(proxy=proxy_url)
    else:
        transport = None
    _httpx_client = httpx.Client(
        transport=transport,
        timeout=CONFIG["request_timeout_sec"],
        verify=True,
    )
    return _httpx_client
def load_cookie() -> tuple:
    """Load the cookie string and SAPISID value from the configured file.

    The file named by ``CONFIG["cookie_file"]`` holds either a JSON object
    with ``"cookie"`` and ``"sapisid"`` keys, or a raw ``Cookie`` header value
    made of ``name=value`` pairs separated by ``;``. The parsed result is
    cached and reused for as long as the file's mtime is unchanged.

    Returns:
        A ``(cookie_str, sapisid)`` tuple. It is ``("", None)`` when no cookie
        file is configured or the file does not exist. ``sapisid`` is ``None``
        whenever no SAPISID value is found. If reading or parsing fails, the
        error is logged and the previously cached values are returned.
    """
    cookie_file = CONFIG.get("cookie_file")
    if not cookie_file or not os.path.exists(cookie_file):
        return "", None
    try:
        mtime = os.path.getmtime(cookie_file)
        if mtime == _cookie_cache["mtime"] and _cookie_cache["str"]:
            return _cookie_cache["str"], _cookie_cache["sapisid"]
        # utf-8-sig decodes plain UTF-8 and also skips a leading BOM, which
        # would otherwise hide the "{" that marks the JSON format.
        with open(cookie_file, "r", encoding="utf-8-sig") as f:
            content = f.read().strip()
        if content.startswith("{"):
            data = json.loads(content)
            cookie_str = data.get("cookie", "")
            sapisid = data.get("sapisid", "")
        else:
            cookie_str = content
            # Split on ";" alone and trim each pair, so "a=1;b=2" and a
            # trailing ";" parse the same way as the canonical "a=1; b=2".
            pairs = dict(
                p.strip().split("=", 1)
                for p in cookie_str.split(";")
                if "=" in p
            )
            sapisid = pairs.get("SAPISID", "")
        sapisid = sapisid or None
        _cookie_cache.update({"str": cookie_str, "sapisid": sapisid, "mtime": mtime})
        return cookie_str, sapisid
    except Exception as e:
        # Best effort: keep serving the last good cookie if the file is
        # unreadable or malformed right now.
        log(f"Cookie load error: {e}")
        return _cookie_cache["str"], _cookie_cache["sapisid"]
def make_sapisidhash(sapisid: str) -> str:
    """Build a ``SAPISIDHASH`` authorization value for the Gemini origin.

    The digest is the SHA-1 hex of ``"<timestamp> <sapisid> <origin>"``, where
    the timestamp is the current Unix time in whole seconds.
    """
    now = int(time.time())
    signed = f"{now} {sapisid} https://gemini.google.com"
    digest = hashlib.sha1(signed.encode()).hexdigest()
    return f"SAPISIDHASH {now}_{digest}"
def _account_prefix() -> str:
    """Return the ``/u/<auth_user>`` path prefix for the configured account.

    The prefix is empty when ``CONFIG["auth_user"]`` is missing, ``None`` or
    an empty string.
    """
    selected = CONFIG.get("auth_user")
    unset = selected is None or selected == ""
    return "" if unset else f"/u/{selected}"
def _build_headers() -> dict:
    """Assemble the HTTP headers for a StreamGenerate request.

    The fixed browser-like headers are always present. ``X-Goog-AuthUser`` is
    added when an account prefix is configured, ``Cookie`` when a cookie
    string is available, and ``Authorization`` when a SAPISID value is known.
    """
    prefix = _account_prefix()
    headers = {
        "Content-Type": "application/x-www-form-urlencoded",
        "Origin": "https://gemini.google.com",
        "Referer": f"https://gemini.google.com{prefix}/app",
        "X-Same-Domain": "1",
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
    }
    # Optional headers are collected in order and merged once at the end.
    optional = []
    if prefix:
        optional.append(("X-Goog-AuthUser", str(CONFIG["auth_user"])))
    cookie_str, sapisid = load_cookie()
    if cookie_str:
        optional.append(("Cookie", cookie_str))
    if sapisid:
        optional.append(("Authorization", make_sapisidhash(sapisid)))
    headers.update(optional)
    return headers
def _build_payload(prompt: str, model_id: int, think_mode: int, file_refs: list = None, extra_fields: dict = None) -> str:
    """Serialise a prompt into the form-encoded StreamGenerate request body.

    The request is a 102-slot positional list. Slot 0 carries the prompt and
    any file references, slot 17 the think mode, slot 59 a fresh UUID and
    slot 79 the model id; the remaining populated slots are fixed values.
    ``extra_fields`` maps slot index to value and is applied last, so it can
    override any default.

    Returns:
        The URL-encoded body with ``f.req`` set to the doubly JSON-encoded
        request, plus ``at`` when ``CONFIG["xsrf_token"]`` is set.
    """
    attachments = [[None, None, ref] for ref in file_refs] if file_refs else None
    slots = {
        0: [prompt, 0, None, attachments, None, None, 0],
        1: ["en"],
        2: ["", "", "", None, None, None, None, None, None, ""],
        6: [0],
        7: 1,
        10: 1,
        11: 0,
        17: [[think_mode]],
        18: 0,
        27: 1,
        30: [4],
        41: [2],
        53: 0,
        59: str(uuid.uuid4()),
        61: [],
        68: 1,
        79: model_id,
    }
    inner = [None] * 102
    for index, value in slots.items():
        inner[index] = value
    # Caller-supplied overrides go in after the defaults.
    for index, value in (extra_fields or {}).items():
        inner[index] = value

    form = {"f.req": json.dumps([None, json.dumps(inner)])}
    token = CONFIG.get("xsrf_token")
    if token:
        form["at"] = token
    return urllib.parse.urlencode(form)
def _get_url() -> str:
    """Build the StreamGenerate endpoint URL for the configured account.

    ``bl`` comes from ``CONFIG["gemini_bl"]`` and ``_reqid`` is the current
    Unix time modulo 1,000,000.
    """
    base = f"https://gemini.google.com{_account_prefix()}/_/BardChatUi/data/"
    endpoint = "assistant.lamda.BardFrontendService/StreamGenerate"
    request_id = int(time.time()) % 1000000
    query = f"?bl={CONFIG['gemini_bl']}&hl=en&_reqid={request_id}&rt=c"
    return base + endpoint + query
def clean_text(text: str) -> str:
    """Remove embedded code-event blocks and card placeholders from text.

    Fenced blocks tagged ``?code_reference`` or ``?code_stdout`` and
    ``googleusercontent.com/card_content/<n>`` URLs are dropped, then
    surrounding whitespace is stripped.
    """
    code_event_fence = re.compile(
        r'```(?:python|javascript|text)\?code_(?:reference|stdout)&code_event_index=\d+\n.*?```\n?',
        re.DOTALL,
    )
    card_placeholder = re.compile(r'http://googleusercontent\.com/card_content/\d+\n?')
    without_fences = code_event_fence.sub('', text)
    return card_placeholder.sub('', without_fences).strip()
| def _extract_texts_from_line(line: str) -> list: | |
| """Parse a single wrb.fr line and return list of text strings found.""" | |
| if '"wrb.fr"' not in line or len(line) < 200: | |
| return [] | |
| try: | |
| arr = json.loads(line) | |
| inner_str = arr[0][2] | |
| if not inner_str or len(inner_str) < 50: | |
| return [] | |
| inner = json.loads(inner_str) | |
| if not (isinstance(inner, list) and len(inner) > 4 and inner[4]): | |
| return [] | |
| texts = [] | |
| for part in inner[4]: | |
| if isinstance(part, list) and len(part) > 1 and part[1] and isinstance(part[1], list): | |
| for t in part[1]: | |
| if isinstance(t, str) and t: | |
| texts.append(t) | |
| return texts | |
| except (json.JSONDecodeError, IndexError, TypeError): | |
| return [] | |
def extract_response_text(raw: str) -> str:
    """Return the cleaned final text from a complete raw response.

    Every line is scanned for candidate texts and the longest one wins (the
    first such candidate on a tie). The result is ``""`` when none is found.
    """
    candidates = (
        text
        for line in raw.split("\n")
        for text in _extract_texts_from_line(line)
    )
    return clean_text(max(candidates, key=len, default=""))
def generate(prompt: str, model_id: int, think_mode: int, file_refs: list = None, extra_fields: dict = None) -> str:
    """Run a non-streaming generation over urllib, retrying on failure.

    The request is POSTed up to ``CONFIG["retry_attempts"]`` times (always at
    least once), sleeping ``CONFIG["retry_delay_sec"]`` between attempts. It
    goes through ``CONFIG["proxy"]`` when one is configured.

    Returns:
        The cleaned final response text (possibly empty).

    Raises:
        Exception: the error from the last failed attempt.
    """
    body = _build_payload(prompt, model_id, think_mode, file_refs, extra_fields).encode()
    url = _get_url()
    headers = _build_headers()
    ctx = _get_ssl_ctx()
    timeout = CONFIG["request_timeout_sec"]
    # Always try once, so a zero or negative setting cannot fall through to
    # "raise None" below.
    attempts = max(1, CONFIG["retry_attempts"])

    # The proxy opener does not change between attempts, so build it once.
    proxy = CONFIG.get("proxy")
    opener = None
    if proxy:
        opener = urllib.request.build_opener(
            urllib.request.ProxyHandler({"http": proxy, "https": proxy}),
            urllib.request.HTTPSHandler(context=ctx),
        )

    last_err = None
    for attempt in range(attempts):
        try:
            req = urllib.request.Request(url, data=body, headers=headers, method="POST")
            if opener is not None:
                resp = opener.open(req, timeout=timeout)
            else:
                resp = urllib.request.urlopen(req, context=ctx, timeout=timeout)
            # Close the response (and its socket) even if reading fails.
            with resp:
                raw = resp.read().decode("utf-8", errors="replace")
            return extract_response_text(raw)
        except Exception as e:
            last_err = e
            if attempt < attempts - 1:
                log(f"Retry {attempt+1}/{attempts}: {e}")
                time.sleep(CONFIG["retry_delay_sec"])
    raise last_err
def _iter_stream_lines(chunks):
    """Yield newline-delimited lines from text chunks, including a final unterminated one."""
    buf = ""
    for chunk in chunks:
        buf += chunk
        *complete, buf = buf.split("\n")
        yield from complete
    if buf:
        yield buf


def generate_stream(prompt: str, model_id: int, think_mode: int, file_refs: list = None, extra_fields: dict = None):
    """Stream the response text as incremental deltas.

    Without httpx this falls back to ``generate`` and yields its whole result
    as one item. With httpx, each time a longer candidate text arrives it is
    cleaned as a whole and only the part beyond what was already yielded is
    emitted. The yielded pieces therefore keep their inner whitespace and
    concatenate to the cleaned text.

    A failed attempt is retried up to ``CONFIG["retry_attempts"]`` times
    (always at least one attempt), but only while nothing has been yielded.
    After partial output the error propagates immediately, because a retry
    would start a new response and repeat text.

    Yields:
        Non-empty ``str`` deltas.

    Raises:
        Exception: the error from the last failed attempt, including the
            error httpx raises for a non-success HTTP status.
    """
    if not HAS_HTTPX:
        text = generate(prompt, model_id, think_mode, file_refs, extra_fields)
        if text:
            yield text
        return

    body = _build_payload(prompt, model_id, think_mode, file_refs, extra_fields)
    url = _get_url()
    headers = _build_headers()
    client = _get_httpx_client()
    # Always try once, so a zero or negative setting cannot end in "raise None".
    attempts = max(1, CONFIG["retry_attempts"])
    last_err = None
    for attempt in range(attempts):
        longest_raw = ""  # longest raw candidate seen in this attempt
        emitted = ""  # cleaned text already handed to the consumer
        try:
            with client.stream("POST", url, content=body, headers=headers) as resp:
                # Surface HTTP errors instead of yielding an empty stream;
                # the urllib path in generate() raises on them as well.
                resp.raise_for_status()
                for line in _iter_stream_lines(resp.iter_text()):
                    for t in _extract_texts_from_line(line):
                        if len(t) <= len(longest_raw):
                            continue
                        longest_raw = t
                        # Clean the full text, not the slice: stripping each
                        # delta would drop the spaces/newlines between chunks.
                        cleaned = clean_text(t)
                        delta = cleaned[len(emitted):]
                        if delta:
                            yield delta
                            emitted = cleaned
            return
        except Exception as e:
            if emitted:
                # Part of the answer is already out; do not restart it.
                raise
            last_err = e
            if attempt < attempts - 1:
                log(f"Stream retry {attempt+1}/{attempts}: {e}")
                time.sleep(CONFIG["retry_delay_sec"])
    raise last_err