"""Google Drive backed shared-state utilities.""" from __future__ import annotations import base64 import json import logging import os import sqlite3 import threading import time from pathlib import Path from typing import Any import requests def _safe_name(value: str, default: str = "item") -> str: text = str(value or "").strip() or default return "".join(ch if ch.isalnum() or ch in {"-", "_", "."} else "_" for ch in text)[:180] class GoogleDriveStateClient: def __init__(self, logger: logging.Logger) -> None: self.logger = logger self._folder_cache: dict[str, str] = {} self._file_cache: dict[str, str] = {} self._bootstrap_lock = threading.Lock() self._bootstrap_loaded_at = 0.0 self._bootstrap_payload: dict[str, Any] = {} def _bootstrap_url(self) -> str: return str(os.getenv("GOOGLE_DRIVE_BOOTSTRAP_URL", "") or os.getenv("KAPO_BOOTSTRAP_URL", "") or "").strip() def _bootstrap_ttl(self) -> float: return float(os.getenv("GOOGLE_DRIVE_BOOTSTRAP_TTL_SEC", "300") or 300) def _apply_bootstrap_payload(self, payload: dict[str, Any]) -> None: mappings = { "shared_state_backend": "KAPO_SHARED_STATE_BACKEND", "shared_state_folder_id": "GOOGLE_DRIVE_SHARED_STATE_FOLDER_ID", "shared_state_prefix": "GOOGLE_DRIVE_SHARED_STATE_PREFIX", "storage_folder_id": "GOOGLE_DRIVE_STORAGE_FOLDER_ID", "storage_prefix": "GOOGLE_DRIVE_STORAGE_PREFIX", "access_token": "GOOGLE_DRIVE_ACCESS_TOKEN", "refresh_token": "GOOGLE_DRIVE_REFRESH_TOKEN", "client_secret_json": "GOOGLE_DRIVE_CLIENT_SECRET_JSON", "client_secret_json_base64": "GOOGLE_DRIVE_CLIENT_SECRET_JSON_BASE64", "client_secret_path": "GOOGLE_DRIVE_CLIENT_SECRET_PATH", "token_expires_at": "GOOGLE_DRIVE_TOKEN_EXPIRES_AT", "firebase_enabled": "FIREBASE_ENABLED", "executor_url": "EXECUTOR_URL", "executor_public_url": "EXECUTOR_PUBLIC_URL", "ngrok_authtoken": "NGROK_AUTHTOKEN", "brain_public_url": "BRAIN_PUBLIC_URL", "brain_roles": "BRAIN_ROLES", "brain_languages": "BRAIN_LANGUAGES", } for key, env_name in mappings.items(): value = payload.get(key) if value not in (None, ""): os.environ[env_name] = str(value) def ensure_bootstrap_loaded(self, force: bool = False) -> dict[str, Any]: bootstrap_url = self._bootstrap_url() if not bootstrap_url: return {} now = time.time() if not force and self._bootstrap_payload and (now - self._bootstrap_loaded_at) < self._bootstrap_ttl(): return dict(self._bootstrap_payload) with self._bootstrap_lock: now = time.time() if not force and self._bootstrap_payload and (now - self._bootstrap_loaded_at) < self._bootstrap_ttl(): return dict(self._bootstrap_payload) try: response = requests.get(bootstrap_url, timeout=20) response.raise_for_status() payload = dict(response.json() or {}) self._bootstrap_payload = payload self._bootstrap_loaded_at = time.time() self._apply_bootstrap_payload(payload) return dict(payload) except Exception: self.logger.warning("Failed to load Google Drive bootstrap manifest", exc_info=True) return dict(self._bootstrap_payload) def enabled(self) -> bool: self.ensure_bootstrap_loaded() return bool(self.root_folder_id()) def root_folder_id(self) -> str: return str( os.getenv("GOOGLE_DRIVE_SHARED_STATE_FOLDER_ID", "") or os.getenv("GOOGLE_DRIVE_STORAGE_FOLDER_ID", "") or "" ).strip() def prefix(self) -> str: return str( os.getenv("GOOGLE_DRIVE_SHARED_STATE_PREFIX", "") or os.getenv("GOOGLE_DRIVE_STORAGE_PREFIX", "") or "kapo/shared_state" ).strip("/ ") def _control_plane_token(self, label: str) -> str: candidates = [ Path.cwd().resolve() / "data" / "local" / "control_plane.db", Path.cwd().resolve() / "data" / "drive_cache" / "control_plane.db", ] for db_path in candidates: if not db_path.exists(): continue try: conn = sqlite3.connect(db_path) conn.row_factory = sqlite3.Row row = conn.execute( "SELECT token FROM ngrok_tokens WHERE label = ? AND enabled = 1", (label,), ).fetchone() conn.close() token = str(row["token"]) if row else "" if token: return token except Exception: self.logger.warning("Failed to read Google Drive token from %s", db_path, exc_info=True) return "" def _client_secret_payload(self) -> dict[str, Any] | None: raw = str(os.getenv("GOOGLE_DRIVE_CLIENT_SECRET_JSON", "")).strip() raw_b64 = str(os.getenv("GOOGLE_DRIVE_CLIENT_SECRET_JSON_BASE64", "")).strip() path = str(os.getenv("GOOGLE_DRIVE_CLIENT_SECRET_PATH", "")).strip() if raw: try: return json.loads(raw) except Exception: self.logger.warning("Invalid GOOGLE_DRIVE_CLIENT_SECRET_JSON", exc_info=True) if raw_b64: try: return json.loads(base64.b64decode(raw_b64).decode("utf-8")) except Exception: self.logger.warning("Invalid GOOGLE_DRIVE_CLIENT_SECRET_JSON_BASE64", exc_info=True) if path: try: return json.loads(Path(path).expanduser().read_text(encoding="utf-8")) except Exception: self.logger.warning("Failed to load Google Drive client secret path %s", path, exc_info=True) client_id = str(os.getenv("GOOGLE_DRIVE_CLIENT_ID", "")).strip() client_secret = str(os.getenv("GOOGLE_DRIVE_CLIENT_SECRET", "")).strip() token_uri = str(os.getenv("GOOGLE_DRIVE_TOKEN_URI", "https://oauth2.googleapis.com/token")).strip() if client_id and client_secret: return { "installed": { "client_id": client_id, "client_secret": client_secret, "token_uri": token_uri, } } return None def _oauth_client(self) -> dict[str, str]: payload = self._client_secret_payload() or {} installed = payload.get("installed") or payload.get("web") or {} return { "client_id": str(installed.get("client_id") or "").strip(), "client_secret": str(installed.get("client_secret") or "").strip(), "token_uri": str(installed.get("token_uri") or "https://oauth2.googleapis.com/token").strip(), } def _refresh_token(self) -> str: return str(os.getenv("GOOGLE_DRIVE_REFRESH_TOKEN", "") or self._control_plane_token("google_drive_refresh_token") or "").strip() def _access_token(self) -> str: token = str(os.getenv("GOOGLE_DRIVE_ACCESS_TOKEN", "") or self._control_plane_token("google_drive_access_token") or "").strip() expires_at = float(str(os.getenv("GOOGLE_DRIVE_TOKEN_EXPIRES_AT", "0") or "0") or 0) if token and (not expires_at or expires_at > (time.time() + 60)): return token refreshed = self._refresh_access_token() return refreshed or token def _refresh_access_token(self) -> str: refresh_token = self._refresh_token() client = self._oauth_client() if not refresh_token or not client["client_id"] or not client["client_secret"]: return "" response = requests.post( client["token_uri"], data={ "grant_type": "refresh_token", "refresh_token": refresh_token, "client_id": client["client_id"], "client_secret": client["client_secret"], }, timeout=30, ) response.raise_for_status() payload = response.json() access_token = str(payload.get("access_token") or "").strip() expires_in = float(payload.get("expires_in") or 0) if access_token: os.environ["GOOGLE_DRIVE_ACCESS_TOKEN"] = access_token os.environ["GOOGLE_DRIVE_TOKEN_EXPIRES_AT"] = str(time.time() + expires_in) if expires_in else "0" return access_token def _request(self, method: str, url: str, *, headers: dict[str, str] | None = None, timeout: int = 60, **kwargs) -> requests.Response: token = self._access_token() if not token: raise ValueError("Google Drive access token is not configured") base_headers = dict(headers or {}) response: requests.Response | None = None for _ in range(2): response = requests.request( method, url, headers={**base_headers, "Authorization": f"Bearer {token}"}, timeout=timeout, **kwargs, ) if response.status_code != 401: response.raise_for_status() return response token = self._refresh_access_token() if not token: break assert response is not None response.raise_for_status() return response @staticmethod def _escape_query(value: str) -> str: return str(value or "").replace("\\", "\\\\").replace("'", "\\'") def _find_folder(self, name: str, parent_id: str) -> str: query = ( "mimeType = 'application/vnd.google-apps.folder' and trashed = false " f"and name = '{self._escape_query(name)}'" ) if parent_id: query += f" and '{self._escape_query(parent_id)}' in parents" response = self._request( "GET", "https://www.googleapis.com/drive/v3/files", params={"q": query, "fields": "files(id,name)", "pageSize": 10, "supportsAllDrives": "true"}, ) files = response.json().get("files", []) return str(files[0].get("id") or "").strip() if files else "" def _create_folder(self, name: str, parent_id: str) -> str: payload: dict[str, Any] = {"name": name, "mimeType": "application/vnd.google-apps.folder"} if parent_id: payload["parents"] = [parent_id] response = self._request( "POST", "https://www.googleapis.com/drive/v3/files?supportsAllDrives=true", headers={"Content-Type": "application/json; charset=UTF-8"}, json=payload, ) return str(response.json().get("id") or "").strip() def _ensure_folder_path(self, relative_path: str) -> str: current_parent = self.root_folder_id() parts = [part for part in f"{self.prefix()}/{relative_path}".replace("\\", "/").split("/") if part.strip()] for part in parts: safe_part = _safe_name(part, "folder") cache_key = f"{current_parent}:{safe_part}" cached = self._folder_cache.get(cache_key) if cached: current_parent = cached continue folder_id = self._find_folder(safe_part, current_parent) if not folder_id: folder_id = self._create_folder(safe_part, current_parent) self._folder_cache[cache_key] = folder_id current_parent = folder_id return current_parent def _find_file(self, folder_id: str, file_name: str) -> str: cache_key = f"{folder_id}:{file_name}" cached = self._file_cache.get(cache_key) if cached: return cached query = ( "trashed = false " f"and name = '{self._escape_query(file_name)}' " f"and '{self._escape_query(folder_id)}' in parents" ) response = self._request( "GET", "https://www.googleapis.com/drive/v3/files", params={"q": query, "fields": "files(id,name)", "pageSize": 10, "supportsAllDrives": "true"}, ) files = response.json().get("files", []) file_id = str(files[0].get("id") or "").strip() if files else "" if file_id: self._file_cache[cache_key] = file_id return file_id def _create_file(self, folder_id: str, file_name: str) -> str: response = self._request( "POST", "https://www.googleapis.com/drive/v3/files?supportsAllDrives=true", headers={"Content-Type": "application/json; charset=UTF-8"}, json={"name": file_name, "parents": [folder_id]}, ) file_id = str(response.json().get("id") or "").strip() if file_id: self._file_cache[f"{folder_id}:{file_name}"] = file_id return file_id def _upload_json(self, file_id: str, payload: dict[str, Any]) -> None: self._request( "PATCH", f"https://www.googleapis.com/upload/drive/v3/files/{file_id}?uploadType=media&supportsAllDrives=true", headers={"Content-Type": "application/json; charset=UTF-8"}, data=json.dumps(payload, ensure_ascii=False, sort_keys=True).encode("utf-8"), ) def _download_json(self, file_id: str) -> dict[str, Any]: response = self._request( "GET", f"https://www.googleapis.com/drive/v3/files/{file_id}?alt=media&supportsAllDrives=true", ) return dict(response.json() or {}) def get_document(self, collection: str, doc_id: str) -> dict[str, Any]: if not self.enabled(): return {} folder_id = self._ensure_folder_path(_safe_name(collection, "collection")) file_name = f"{_safe_name(doc_id)}.json" file_id = self._find_file(folder_id, file_name) if not file_id: return {} try: payload = self._download_json(file_id) payload.setdefault("id", _safe_name(doc_id)) return payload except Exception: self.logger.warning("Failed to download shared-state file %s/%s", collection, doc_id, exc_info=True) return {} def set_document(self, collection: str, doc_id: str, payload: dict[str, Any], *, merge: bool = True) -> bool: if not self.enabled(): return False safe_doc = _safe_name(doc_id) try: existing = self.get_document(collection, safe_doc) if merge else {} combined = {**existing, **dict(payload or {})} if merge else dict(payload or {}) combined.setdefault("id", safe_doc) folder_id = self._ensure_folder_path(_safe_name(collection, "collection")) file_name = f"{safe_doc}.json" file_id = self._find_file(folder_id, file_name) if not file_id: file_id = self._create_file(folder_id, file_name) self._upload_json(file_id, combined) return True except Exception: self.logger.warning("Failed to upload shared-state file %s/%s", collection, safe_doc, exc_info=True) return False def list_documents(self, collection: str, *, limit: int = 200) -> list[dict[str, Any]]: if not self.enabled(): return [] try: folder_id = self._ensure_folder_path(_safe_name(collection, "collection")) query = f"trashed = false and '{self._escape_query(folder_id)}' in parents" response = self._request( "GET", "https://www.googleapis.com/drive/v3/files", params={ "q": query, "fields": "files(id,name,modifiedTime)", "pageSize": max(1, int(limit)), "supportsAllDrives": "true", "orderBy": "modifiedTime desc", }, ) items: list[dict[str, Any]] = [] for file in response.json().get("files", []): file_id = str(file.get("id") or "").strip() file_name = str(file.get("name") or "").strip() if not file_id or not file_name.endswith(".json"): continue self._file_cache[f"{folder_id}:{file_name}"] = file_id payload = self._download_json(file_id) payload.setdefault("id", file_name[:-5]) items.append(payload) return items except Exception: self.logger.warning("Failed to list shared-state collection %s", collection, exc_info=True) return [] def delete_document(self, collection: str, doc_id: str) -> bool: if not self.enabled(): return False safe_doc = _safe_name(doc_id) try: folder_id = self._ensure_folder_path(_safe_name(collection, "collection")) file_name = f"{safe_doc}.json" file_id = self._find_file(folder_id, file_name) if not file_id: return True self._request( "DELETE", f"https://www.googleapis.com/drive/v3/files/{file_id}?supportsAllDrives=true", ) self._file_cache.pop(f"{folder_id}:{file_name}", None) return True except Exception: self.logger.warning("Failed to delete shared-state file %s/%s", collection, safe_doc, exc_info=True) return False