| """Google Drive backed shared-state utilities.""" |
|
|
| from __future__ import annotations |
|
|
| import base64 |
| import json |
| import logging |
| import os |
| import sqlite3 |
| import threading |
| import time |
| from pathlib import Path |
| from typing import Any |
|
|
| import requests |
|
|
|
|
def _safe_name(value: str, default: str = "item") -> str:
    """Return *value* sanitized for use as a Drive file or folder name.

    Leading/trailing whitespace is stripped and a blank (or falsy) *value* is
    replaced by *default*. Every character that is not alphanumeric, ``-``,
    ``_`` or ``.`` becomes ``_``, and the result is capped at 180 characters.
    """
    candidate = str(value or "").strip()
    if not candidate:
        candidate = default
    allowed_punctuation = "-_."
    cleaned = []
    for char in candidate:
        keep = char.isalnum() or char in allowed_punctuation
        cleaned.append(char if keep else "_")
    return "".join(cleaned)[:180]
|
|
|
|
class GoogleDriveStateClient:
    """Persist small JSON documents as files in a Google Drive folder tree.

    Each document is stored as ``<root>/<prefix>/<collection>/<doc_id>.json``,
    where ``<root>`` is :meth:`root_folder_id` and ``<prefix>`` is
    :meth:`prefix`. Configuration is read from environment variables, which an
    optional HTTP bootstrap manifest can populate
    (:meth:`ensure_bootstrap_loaded`).

    The public document methods are best-effort: failures are logged through
    ``self.logger`` and reported as ``{}``, ``[]`` or ``False`` rather than
    raised.
    """

    _FILES_URL = "https://www.googleapis.com/drive/v3/files"
    _UPLOAD_URL = "https://www.googleapis.com/upload/drive/v3/files"
    _FOLDER_MIME = "application/vnd.google-apps.folder"
    _DEFAULT_TOKEN_URI = "https://oauth2.googleapis.com/token"
    _DEFAULT_BOOTSTRAP_TTL_SEC = 300.0
    # The Drive API documents 1000 as the largest accepted files.list pageSize.
    _MAX_PAGE_SIZE = 1000
    # Bootstrap-manifest key -> environment variable it is exported to.
    _BOOTSTRAP_ENV = {
        "shared_state_backend": "KAPO_SHARED_STATE_BACKEND",
        "shared_state_folder_id": "GOOGLE_DRIVE_SHARED_STATE_FOLDER_ID",
        "shared_state_prefix": "GOOGLE_DRIVE_SHARED_STATE_PREFIX",
        "storage_folder_id": "GOOGLE_DRIVE_STORAGE_FOLDER_ID",
        "storage_prefix": "GOOGLE_DRIVE_STORAGE_PREFIX",
        "access_token": "GOOGLE_DRIVE_ACCESS_TOKEN",
        "refresh_token": "GOOGLE_DRIVE_REFRESH_TOKEN",
        "client_secret_json": "GOOGLE_DRIVE_CLIENT_SECRET_JSON",
        "client_secret_json_base64": "GOOGLE_DRIVE_CLIENT_SECRET_JSON_BASE64",
        "client_secret_path": "GOOGLE_DRIVE_CLIENT_SECRET_PATH",
        "token_expires_at": "GOOGLE_DRIVE_TOKEN_EXPIRES_AT",
        "firebase_enabled": "FIREBASE_ENABLED",
        "executor_url": "EXECUTOR_URL",
        "executor_public_url": "EXECUTOR_PUBLIC_URL",
        "ngrok_authtoken": "NGROK_AUTHTOKEN",
        "brain_public_url": "BRAIN_PUBLIC_URL",
        "brain_roles": "BRAIN_ROLES",
        "brain_languages": "BRAIN_LANGUAGES",
    }

    def __init__(self, logger: logging.Logger) -> None:
        self.logger = logger
        # "<parent_id>:<folder name>" -> folder id, filled by _ensure_folder_path.
        self._folder_cache: dict[str, str] = {}
        # "<folder_id>:<file name>" -> file id. Entries are dropped when an
        # operation involving them fails, so the next call re-queries Drive.
        self._file_cache: dict[str, str] = {}
        # Serializes bootstrap-manifest fetches between threads.
        self._bootstrap_lock = threading.Lock()
        self._bootstrap_loaded_at = 0.0
        self._bootstrap_payload: dict[str, Any] = {}

    # ------------------------------------------------------------------
    # Bootstrap manifest
    # ------------------------------------------------------------------

    def _bootstrap_url(self) -> str:
        """Return the manifest URL (Drive-specific variable first), or ``""``."""
        return str(os.getenv("GOOGLE_DRIVE_BOOTSTRAP_URL", "") or os.getenv("KAPO_BOOTSTRAP_URL", "") or "").strip()

    def _bootstrap_ttl(self) -> float:
        """Return how many seconds a fetched manifest is reused before refetching.

        Read from ``GOOGLE_DRIVE_BOOTSTRAP_TTL_SEC``. An empty or unparseable
        value falls back to 300 instead of raising out of :meth:`enabled`.
        """
        raw = str(os.getenv("GOOGLE_DRIVE_BOOTSTRAP_TTL_SEC", "") or "").strip()
        if not raw:
            return self._DEFAULT_BOOTSTRAP_TTL_SEC
        try:
            return float(raw)
        except ValueError:
            self.logger.warning(
                "Invalid GOOGLE_DRIVE_BOOTSTRAP_TTL_SEC=%r; using %s",
                raw,
                self._DEFAULT_BOOTSTRAP_TTL_SEC,
            )
            return self._DEFAULT_BOOTSTRAP_TTL_SEC

    def _apply_bootstrap_payload(self, payload: dict[str, Any]) -> None:
        """Export the known manifest keys into ``os.environ``.

        Missing, ``None`` and empty-string values are skipped so they do not
        clobber existing configuration. JSON objects and arrays are re-encoded
        as JSON. ``str()`` would produce a Python repr, which the
        ``GOOGLE_DRIVE_CLIENT_SECRET_JSON`` reader cannot ``json.loads``.
        Other values are exported with ``str()``.
        """
        for key, env_name in self._BOOTSTRAP_ENV.items():
            value = payload.get(key)
            if value is None or value == "":
                continue
            if isinstance(value, (dict, list)):
                os.environ[env_name] = json.dumps(value)
            else:
                os.environ[env_name] = str(value)

    def _bootstrap_is_fresh(self) -> bool:
        """Return True when a non-empty manifest was loaded within the TTL."""
        if not self._bootstrap_payload:
            return False
        return (time.time() - self._bootstrap_loaded_at) < self._bootstrap_ttl()

    def ensure_bootstrap_loaded(self, force: bool = False) -> dict[str, Any]:
        """Fetch the bootstrap manifest when configured and stale; return a copy.

        Returns ``{}`` when no bootstrap URL is set. A cached manifest younger
        than the TTL is reused unless *force* is true. On a successful fetch the
        manifest is cached and exported via :meth:`_apply_bootstrap_payload`.
        On failure the error is logged and the last successfully loaded
        manifest (possibly ``{}``) is returned. The next call tries again.
        """
        bootstrap_url = self._bootstrap_url()
        if not bootstrap_url:
            return {}
        if not force and self._bootstrap_is_fresh():
            return dict(self._bootstrap_payload)
        with self._bootstrap_lock:
            # Re-check under the lock: another thread may have just refreshed it.
            if not force and self._bootstrap_is_fresh():
                return dict(self._bootstrap_payload)
            try:
                response = requests.get(bootstrap_url, timeout=20)
                response.raise_for_status()
                payload = dict(response.json() or {})
                self._bootstrap_payload = payload
                self._bootstrap_loaded_at = time.time()
                self._apply_bootstrap_payload(payload)
                return dict(payload)
            except Exception:
                self.logger.warning("Failed to load Google Drive bootstrap manifest", exc_info=True)
                return dict(self._bootstrap_payload)

    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------

    def enabled(self) -> bool:
        """Refresh the bootstrap manifest if needed; True when a root folder is set."""
        self.ensure_bootstrap_loaded()
        return bool(self.root_folder_id())

    def root_folder_id(self) -> str:
        """Return the Drive folder id under which all state is stored, or ``""``."""
        return str(
            os.getenv("GOOGLE_DRIVE_SHARED_STATE_FOLDER_ID", "")
            or os.getenv("GOOGLE_DRIVE_STORAGE_FOLDER_ID", "")
            or ""
        ).strip()

    def prefix(self) -> str:
        """Return the sub-path (no leading/trailing ``/``) below the root folder."""
        return str(
            os.getenv("GOOGLE_DRIVE_SHARED_STATE_PREFIX", "")
            or os.getenv("GOOGLE_DRIVE_STORAGE_PREFIX", "")
            or "kapo/shared_state"
        ).strip("/ ")

    # ------------------------------------------------------------------
    # Credentials
    # ------------------------------------------------------------------

    def _control_plane_token(self, label: str) -> str:
        """Look up an enabled token by *label* in a local control-plane SQLite DB.

        Checks ``data/local/control_plane.db``, then
        ``data/drive_cache/control_plane.db``, both relative to the current
        working directory. Returns the first non-empty ``ngrok_tokens.token``
        whose ``label`` matches and ``enabled = 1``, or ``""``. Read errors are
        logged and the next candidate is tried.
        """
        data_dir = Path.cwd().resolve() / "data"
        candidates = (
            data_dir / "local" / "control_plane.db",
            data_dir / "drive_cache" / "control_plane.db",
        )
        for db_path in candidates:
            if not db_path.exists():
                continue
            try:
                conn = sqlite3.connect(db_path)
                try:
                    conn.row_factory = sqlite3.Row
                    row = conn.execute(
                        "SELECT token FROM ngrok_tokens WHERE label = ? AND enabled = 1",
                        (label,),
                    ).fetchone()
                finally:
                    # Close even when the query fails (e.g. the table is missing).
                    conn.close()
                # A NULL token must not become the literal string "None".
                token = str(row["token"] or "").strip() if row else ""
            except Exception:
                self.logger.warning("Failed to read Google Drive token from %s", db_path, exc_info=True)
                continue
            if token:
                return token
        return ""

    def _client_secret_payload(self) -> dict[str, Any] | None:
        """Load the OAuth client configuration, or ``None`` if none is configured.

        Sources are tried in order and the first that parses wins:
        ``GOOGLE_DRIVE_CLIENT_SECRET_JSON`` (raw JSON),
        ``GOOGLE_DRIVE_CLIENT_SECRET_JSON_BASE64``,
        ``GOOGLE_DRIVE_CLIENT_SECRET_PATH`` (a JSON file), and finally
        ``GOOGLE_DRIVE_CLIENT_ID`` + ``GOOGLE_DRIVE_CLIENT_SECRET``, which are
        wrapped as an ``"installed"`` client. Unparseable sources are logged
        and skipped.
        """
        raw = str(os.getenv("GOOGLE_DRIVE_CLIENT_SECRET_JSON", "")).strip()
        raw_b64 = str(os.getenv("GOOGLE_DRIVE_CLIENT_SECRET_JSON_BASE64", "")).strip()
        path = str(os.getenv("GOOGLE_DRIVE_CLIENT_SECRET_PATH", "")).strip()
        if raw:
            try:
                return json.loads(raw)
            except Exception:
                self.logger.warning("Invalid GOOGLE_DRIVE_CLIENT_SECRET_JSON", exc_info=True)
        if raw_b64:
            try:
                return json.loads(base64.b64decode(raw_b64).decode("utf-8"))
            except Exception:
                self.logger.warning("Invalid GOOGLE_DRIVE_CLIENT_SECRET_JSON_BASE64", exc_info=True)
        if path:
            try:
                return json.loads(Path(path).expanduser().read_text(encoding="utf-8"))
            except Exception:
                self.logger.warning("Failed to load Google Drive client secret path %s", path, exc_info=True)
        client_id = str(os.getenv("GOOGLE_DRIVE_CLIENT_ID", "")).strip()
        client_secret = str(os.getenv("GOOGLE_DRIVE_CLIENT_SECRET", "")).strip()
        token_uri = str(os.getenv("GOOGLE_DRIVE_TOKEN_URI", self._DEFAULT_TOKEN_URI)).strip()
        if client_id and client_secret:
            return {
                "installed": {
                    "client_id": client_id,
                    "client_secret": client_secret,
                    "token_uri": token_uri,
                }
            }
        return None

    def _oauth_client(self) -> dict[str, str]:
        """Return ``client_id``/``client_secret``/``token_uri`` (``""`` when unknown).

        Accepts the Google ``{"installed": {...}}`` / ``{"web": {...}}`` layouts
        as well as a flat ``{"client_id": ..., "client_secret": ...}`` object.
        Non-object JSON is treated as "not configured" instead of raising.
        """
        payload = self._client_secret_payload()
        if not isinstance(payload, dict):
            payload = {}
        installed = payload.get("installed") or payload.get("web") or payload
        if not isinstance(installed, dict):
            installed = {}
        return {
            "client_id": str(installed.get("client_id") or "").strip(),
            "client_secret": str(installed.get("client_secret") or "").strip(),
            "token_uri": str(installed.get("token_uri") or self._DEFAULT_TOKEN_URI).strip(),
        }

    def _refresh_token(self) -> str:
        """Return the OAuth refresh token from the environment or the control-plane DB."""
        return str(os.getenv("GOOGLE_DRIVE_REFRESH_TOKEN", "") or self._control_plane_token("google_drive_refresh_token") or "").strip()

    def _access_token(self) -> str:
        """Return a usable access token, refreshing it if missing or about to expire.

        A token with no known expiry (``GOOGLE_DRIVE_TOKEN_EXPIRES_AT`` unset,
        ``0`` or unparseable) is used as-is. Otherwise a refresh is attempted
        when it expires within 60 seconds. If the refresh yields nothing, the
        existing token (possibly ``""``) is returned.
        """
        token = str(os.getenv("GOOGLE_DRIVE_ACCESS_TOKEN", "") or self._control_plane_token("google_drive_access_token") or "").strip()
        raw_expiry = str(os.getenv("GOOGLE_DRIVE_TOKEN_EXPIRES_AT", "0") or "0").strip()
        try:
            expires_at = float(raw_expiry or 0)
        except ValueError:
            self.logger.warning("Ignoring unparseable GOOGLE_DRIVE_TOKEN_EXPIRES_AT=%r", raw_expiry)
            expires_at = 0.0
        if token and (not expires_at or expires_at > (time.time() + 60)):
            return token
        refreshed = self._refresh_access_token()
        return refreshed or token

    def _refresh_access_token(self) -> str:
        """Exchange the refresh token for a new access token.

        On success the new token and its absolute expiry time are exported to
        ``GOOGLE_DRIVE_ACCESS_TOKEN`` / ``GOOGLE_DRIVE_TOKEN_EXPIRES_AT``.
        Returns ``""`` when credentials are incomplete or the exchange fails.
        Failures are logged rather than raised, so callers can fall back to
        the existing token or the original 401.
        """
        refresh_token = self._refresh_token()
        client = self._oauth_client()
        if not refresh_token or not client["client_id"] or not client["client_secret"]:
            return ""
        try:
            response = requests.post(
                client["token_uri"],
                data={
                    "grant_type": "refresh_token",
                    "refresh_token": refresh_token,
                    "client_id": client["client_id"],
                    "client_secret": client["client_secret"],
                },
                timeout=30,
            )
            response.raise_for_status()
            body = response.json()
        except (requests.RequestException, ValueError):
            self.logger.warning("Failed to refresh Google Drive access token", exc_info=True)
            return ""
        if not isinstance(body, dict):
            self.logger.warning("Unexpected Google Drive token response type %s", type(body).__name__)
            return ""
        access_token = str(body.get("access_token") or "").strip()
        try:
            expires_in = float(body.get("expires_in") or 0)
        except (TypeError, ValueError):
            expires_in = 0.0
        if access_token:
            os.environ["GOOGLE_DRIVE_ACCESS_TOKEN"] = access_token
            os.environ["GOOGLE_DRIVE_TOKEN_EXPIRES_AT"] = str(time.time() + expires_in) if expires_in else "0"
        return access_token

    # ------------------------------------------------------------------
    # Low-level Drive calls
    # ------------------------------------------------------------------

    def _request(self, method: str, url: str, *, headers: dict[str, str] | None = None, timeout: int = 60, **kwargs) -> requests.Response:
        """Send an authorized Drive request and return the successful response.

        On a 401 the access token is refreshed once and the request retried.
        Any final non-2xx status raises ``requests.HTTPError``.

        Raises:
            ValueError: no access token is configured.
            requests.HTTPError: the final response is not successful.
        """
        token = self._access_token()
        if not token:
            raise ValueError("Google Drive access token is not configured")
        base_headers = dict(headers or {})
        max_attempts = 2
        for attempt in range(max_attempts):
            response = requests.request(
                method,
                url,
                headers={**base_headers, "Authorization": f"Bearer {token}"},
                timeout=timeout,
                **kwargs,
            )
            # Only a 401 on a non-final attempt warrants a refresh + retry.
            if response.status_code != 401 or attempt == max_attempts - 1:
                break
            token = self._refresh_access_token()
            if not token:
                break
        response.raise_for_status()
        return response

    @staticmethod
    def _escape_query(value: str) -> str:
        """Escape ``\\`` and ``'`` for use inside a quoted Drive search literal."""
        return str(value or "").replace("\\", "\\\\").replace("'", "\\'")

    def _first_match_id(self, query: str) -> str:
        """Return the id of the first file matching Drive search *query*, or ``""``."""
        response = self._request(
            "GET",
            self._FILES_URL,
            params={
                "q": query,
                "fields": "files(id,name)",
                "pageSize": 10,
                "supportsAllDrives": "true",
                # Per the Drive API docs, shared-drive items are excluded from
                # files.list results unless this is set, so lookups under a
                # shared-drive root would never match and duplicates get created.
                "includeItemsFromAllDrives": "true",
                # Oldest first, so every client settles on the same item when
                # concurrent creators have left duplicates behind.
                "orderBy": "createdTime",
            },
        )
        files = response.json().get("files", [])
        return str(files[0].get("id") or "").strip() if files else ""

    def _find_folder(self, name: str, parent_id: str) -> str:
        """Return the id of a non-trashed folder *name* (under *parent_id* if given), or ``""``."""
        query = (
            f"mimeType = '{self._FOLDER_MIME}' and trashed = false "
            f"and name = '{self._escape_query(name)}'"
        )
        if parent_id:
            query += f" and '{self._escape_query(parent_id)}' in parents"
        return self._first_match_id(query)

    def _create_folder(self, name: str, parent_id: str) -> str:
        """Create folder *name* (inside *parent_id* if given); return its id or ``""``."""
        payload: dict[str, Any] = {"name": name, "mimeType": self._FOLDER_MIME}
        if parent_id:
            payload["parents"] = [parent_id]
        response = self._request(
            "POST",
            f"{self._FILES_URL}?supportsAllDrives=true",
            headers={"Content-Type": "application/json; charset=UTF-8"},
            json=payload,
        )
        return str(response.json().get("id") or "").strip()

    def _ensure_folder_path(self, relative_path: str) -> str:
        """Find or create ``<root>/<prefix>/<relative_path>`` and return the leaf folder id.

        Each path segment is sanitized with ``_safe_name`` and resolved via the
        folder cache, then a Drive lookup, then creation.

        Raises:
            RuntimeError: Drive returned no id for a created folder. Continuing
                with an empty parent would place later folders outside the root.
        """
        current_parent = self.root_folder_id()
        segments = f"{self.prefix()}/{relative_path}".replace("\\", "/").split("/")
        for segment in segments:
            if not segment.strip():
                continue
            safe_part = _safe_name(segment, "folder")
            cache_key = f"{current_parent}:{safe_part}"
            folder_id = (
                self._folder_cache.get(cache_key)
                or self._find_folder(safe_part, current_parent)
                or self._create_folder(safe_part, current_parent)
            )
            if not folder_id:
                raise RuntimeError(f"Google Drive returned no id for folder {safe_part!r}")
            self._folder_cache[cache_key] = folder_id
            current_parent = folder_id
        return current_parent

    def _find_file(self, folder_id: str, file_name: str) -> str:
        """Return the id of non-trashed *file_name* inside *folder_id*, or ``""``.

        Positive results are cached. Misses are not, so a file created
        elsewhere is picked up on the next lookup.
        """
        cache_key = f"{folder_id}:{file_name}"
        cached = self._file_cache.get(cache_key)
        if cached:
            return cached
        query = (
            "trashed = false "
            f"and name = '{self._escape_query(file_name)}' "
            f"and '{self._escape_query(folder_id)}' in parents"
        )
        file_id = self._first_match_id(query)
        if file_id:
            self._file_cache[cache_key] = file_id
        return file_id

    def _forget_file(self, folder_id: str, file_name: str) -> None:
        """Drop a cached file id so the next lookup asks Drive again."""
        if folder_id:
            self._file_cache.pop(f"{folder_id}:{file_name}", None)

    def _create_file(self, folder_id: str, file_name: str) -> str:
        """Create an empty *file_name* in *folder_id*; cache and return its id (or ``""``)."""
        response = self._request(
            "POST",
            f"{self._FILES_URL}?supportsAllDrives=true",
            headers={"Content-Type": "application/json; charset=UTF-8"},
            json={"name": file_name, "parents": [folder_id]},
        )
        file_id = str(response.json().get("id") or "").strip()
        if file_id:
            self._file_cache[f"{folder_id}:{file_name}"] = file_id
        return file_id

    def _upload_json(self, file_id: str, payload: dict[str, Any]) -> None:
        """Replace the content of *file_id* with *payload* serialized as UTF-8 JSON."""
        self._request(
            "PATCH",
            f"{self._UPLOAD_URL}/{file_id}?uploadType=media&supportsAllDrives=true",
            headers={"Content-Type": "application/json; charset=UTF-8"},
            data=json.dumps(payload, ensure_ascii=False, sort_keys=True).encode("utf-8"),
        )

    def _download_json(self, file_id: str) -> dict[str, Any]:
        """Download *file_id* and parse its content as a JSON object (``null`` -> ``{}``)."""
        response = self._request(
            "GET",
            f"{self._FILES_URL}/{file_id}?alt=media&supportsAllDrives=true",
        )
        return dict(response.json() or {})

    # ------------------------------------------------------------------
    # Public document API
    # ------------------------------------------------------------------

    def get_document(self, collection: str, doc_id: str) -> dict[str, Any]:
        """Return document *doc_id* from *collection*, or ``{}`` if absent or on error.

        The returned dict always has an ``"id"`` key (the sanitized *doc_id*
        unless the stored document defines one). Lookup, auth and download
        errors are all logged and yield ``{}``, matching the other public
        methods.
        """
        if not self.enabled():
            return {}
        safe_doc = _safe_name(doc_id)
        file_name = f"{safe_doc}.json"
        folder_id = ""
        try:
            folder_id = self._ensure_folder_path(_safe_name(collection, "collection"))
            file_id = self._find_file(folder_id, file_name)
            if not file_id:
                return {}
            payload = self._download_json(file_id)
        except Exception:
            # A cached id may point at a file deleted elsewhere; re-resolve next time.
            self._forget_file(folder_id, file_name)
            self.logger.warning("Failed to download shared-state file %s/%s", collection, doc_id, exc_info=True)
            return {}
        payload.setdefault("id", safe_doc)
        return payload

    def set_document(self, collection: str, doc_id: str, payload: dict[str, Any], *, merge: bool = True) -> bool:
        """Write *payload* as document *doc_id* in *collection*; return True on success.

        With *merge* (default), *payload* keys are layered over the existing
        document. If the existing document cannot be read, the write is
        aborted (returns False). Treating a failed read as an empty document
        would silently discard its fields. Without *merge* the document is
        replaced. An ``"id"`` key is added when absent.
        """
        if not self.enabled():
            return False
        safe_doc = _safe_name(doc_id)
        file_name = f"{safe_doc}.json"
        folder_id = ""
        try:
            folder_id = self._ensure_folder_path(_safe_name(collection, "collection"))
            file_id = self._find_file(folder_id, file_name)
            existing: dict[str, Any] = {}
            if merge and file_id:
                existing = self._download_json(file_id)
            combined = {**existing, **dict(payload or {})}
            combined.setdefault("id", safe_doc)
            if not file_id:
                file_id = self._create_file(folder_id, file_name)
            if not file_id:
                raise RuntimeError(f"Google Drive returned no id for file {file_name!r}")
            self._upload_json(file_id, combined)
            return True
        except Exception:
            # A cached id may point at a file deleted elsewhere; re-resolve next time.
            self._forget_file(folder_id, file_name)
            self.logger.warning("Failed to upload shared-state file %s/%s", collection, safe_doc, exc_info=True)
            return False

    def list_documents(self, collection: str, *, limit: int = 200) -> list[dict[str, Any]]:
        """Return up to *limit* documents of *collection*, most recently modified first.

        Follows Drive pagination until *limit* documents are collected or the
        listing ends. Page size is capped at the Drive maximum. Entries not
        named ``*.json`` are skipped. A document that fails to download is
        logged and skipped rather than discarding the whole listing. Listing
        errors are logged and yield ``[]``.
        """
        if not self.enabled():
            return []
        try:
            wanted = max(1, int(limit))
            folder_id = self._ensure_folder_path(_safe_name(collection, "collection"))
            params: dict[str, Any] = {
                "q": f"trashed = false and '{self._escape_query(folder_id)}' in parents",
                "fields": "nextPageToken,files(id,name,modifiedTime)",
                "pageSize": min(wanted, self._MAX_PAGE_SIZE),
                "supportsAllDrives": "true",
                "includeItemsFromAllDrives": "true",
                "orderBy": "modifiedTime desc",
            }
            items: list[dict[str, Any]] = []
            while len(items) < wanted:
                body = self._request("GET", self._FILES_URL, params=params).json()
                for entry in body.get("files", []):
                    if len(items) >= wanted:
                        break
                    file_id = str(entry.get("id") or "").strip()
                    file_name = str(entry.get("name") or "").strip()
                    if not file_id or not file_name.endswith(".json"):
                        continue
                    self._file_cache[f"{folder_id}:{file_name}"] = file_id
                    try:
                        document = self._download_json(file_id)
                    except Exception:
                        self.logger.warning("Failed to download shared-state file %s/%s", collection, file_name, exc_info=True)
                        continue
                    document.setdefault("id", file_name[:-5])
                    items.append(document)
                # Drive may return short pages before the end; only a missing
                # nextPageToken means the listing is exhausted.
                next_token = str(body.get("nextPageToken") or "")
                if not next_token:
                    break
                params["pageToken"] = next_token
            return items
        except Exception:
            self.logger.warning("Failed to list shared-state collection %s", collection, exc_info=True)
            return []

    def delete_document(self, collection: str, doc_id: str) -> bool:
        """Delete document *doc_id* from *collection*; return True if it is gone.

        A document that does not exist, including a 404 from Drive for a
        stale cached id, counts as successfully deleted. Other errors are
        logged and yield False.
        """
        if not self.enabled():
            return False
        safe_doc = _safe_name(doc_id)
        file_name = f"{safe_doc}.json"
        folder_id = ""
        try:
            folder_id = self._ensure_folder_path(_safe_name(collection, "collection"))
            file_id = self._find_file(folder_id, file_name)
            if file_id:
                try:
                    self._request(
                        "DELETE",
                        f"{self._FILES_URL}/{file_id}?supportsAllDrives=true",
                    )
                except requests.HTTPError as exc:
                    # 404: already gone. Same outcome as the "no file" case.
                    if exc.response is None or exc.response.status_code != 404:
                        raise
            self._forget_file(folder_id, file_name)
            return True
        except Exception:
            self._forget_file(folder_id, file_name)
            self.logger.warning("Failed to delete shared-state file %s/%s", collection, safe_doc, exc_info=True)
            return False
|
|