Spaces:
Sleeping
Sleeping
| """Google Drive backed shared-state utilities.""" | |
| from __future__ import annotations | |
| import base64 | |
| import json | |
| import logging | |
| import os | |
| import sqlite3 | |
| import threading | |
| import time | |
| from pathlib import Path | |
| from typing import Any | |
| import requests | |
| def _safe_name(value: str, default: str = "item") -> str: | |
| text = str(value or "").strip() or default | |
| return "".join(ch if ch.isalnum() or ch in {"-", "_", "."} else "_" for ch in text)[:180] | |
class GoogleDriveStateClient:
    """Store small JSON documents as files in a Google Drive folder tree.

    A document is kept at ``<root folder>/<prefix>/<collection>/<doc_id>.json``.
    Configuration (folder ids, OAuth credentials, tokens) is read from
    environment variables on each call and can optionally be seeded from a
    bootstrap manifest fetched over HTTP.
    """

    # Bootstrap manifest key -> environment variable it is copied into.
    # Order matters: "cloudflare_control_plane_url" comes after
    # "control_plane_url" and therefore wins when both are present.
    _BOOTSTRAP_ENV_MAP: dict[str, str] = {
        "shared_state_backend": "KAPO_SHARED_STATE_BACKEND",
        "shared_state_folder_id": "GOOGLE_DRIVE_SHARED_STATE_FOLDER_ID",
        "shared_state_prefix": "GOOGLE_DRIVE_SHARED_STATE_PREFIX",
        "storage_folder_id": "GOOGLE_DRIVE_STORAGE_FOLDER_ID",
        "storage_prefix": "GOOGLE_DRIVE_STORAGE_PREFIX",
        "access_token": "GOOGLE_DRIVE_ACCESS_TOKEN",
        "refresh_token": "GOOGLE_DRIVE_REFRESH_TOKEN",
        "client_secret_json": "GOOGLE_DRIVE_CLIENT_SECRET_JSON",
        "client_secret_json_base64": "GOOGLE_DRIVE_CLIENT_SECRET_JSON_BASE64",
        "client_secret_path": "GOOGLE_DRIVE_CLIENT_SECRET_PATH",
        "token_expires_at": "GOOGLE_DRIVE_TOKEN_EXPIRES_AT",
        "firebase_enabled": "FIREBASE_ENABLED",
        "executor_url": "EXECUTOR_URL",
        "executor_public_url": "EXECUTOR_PUBLIC_URL",
        "control_plane_url": "KAPO_CONTROL_PLANE_URL",
        "cloudflare_control_plane_url": "KAPO_CONTROL_PLANE_URL",
        "cloudflare_queue_name": "KAPO_CLOUDFLARE_QUEUE_NAME",
        "ngrok_authtoken": "NGROK_AUTHTOKEN",
        "brain_public_url": "BRAIN_PUBLIC_URL",
        "brain_roles": "BRAIN_ROLES",
        "brain_languages": "BRAIN_LANGUAGES",
    }

    # Upper bound applied to the ``pageSize`` sent to the Drive files.list call.
    _MAX_PAGE_SIZE = 1000

    def __init__(self, logger: logging.Logger) -> None:
        """Create a client that reports problems through *logger*."""
        self.logger = logger
        # "<parent id>:<folder name>" -> Drive folder id
        self._folder_cache: dict[str, str] = {}
        # "<folder id>:<file name>" -> Drive file id
        self._file_cache: dict[str, str] = {}
        self._bootstrap_lock = threading.Lock()
        self._bootstrap_loaded_at = 0.0
        self._bootstrap_payload: dict[str, Any] = {}

    # ------------------------------------------------------------------
    # Configuration helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _env_float(name: str, default: float) -> float:
        """Read env var *name* as a float; unset, empty or malformed -> *default*."""
        raw = str(os.getenv(name, "") or "").strip()
        if not raw:
            return default
        try:
            return float(raw)
        except ValueError:
            return default

    def _bootstrap_url(self) -> str:
        """Return the bootstrap manifest URL, or ``""`` when none is configured."""
        return str(
            os.getenv("GOOGLE_DRIVE_BOOTSTRAP_URL", "")
            or os.getenv("KAPO_BOOTSTRAP_URL", "")
            or ""
        ).strip()

    def _bootstrap_ttl(self) -> float:
        """Return how long (seconds) a loaded manifest is reused; default 300.

        A malformed ``GOOGLE_DRIVE_BOOTSTRAP_TTL_SEC`` falls back to the
        default instead of raising out of every public call.
        """
        return self._env_float("GOOGLE_DRIVE_BOOTSTRAP_TTL_SEC", 300.0)

    def _apply_bootstrap_payload(self, payload: dict[str, Any]) -> None:
        """Copy known manifest keys into ``os.environ`` (as strings).

        Keys that are missing, ``None`` or ``""`` leave the current
        environment value untouched.
        """
        for key, env_name in self._BOOTSTRAP_ENV_MAP.items():
            value = payload.get(key)
            if value not in (None, ""):
                os.environ[env_name] = str(value)

    def _cached_bootstrap(self, force: bool) -> dict[str, Any] | None:
        """Return a copy of the cached manifest if it may be reused, else ``None``."""
        if force or not self._bootstrap_payload:
            return None
        age = time.time() - self._bootstrap_loaded_at
        if age < self._bootstrap_ttl():
            return dict(self._bootstrap_payload)
        return None

    def ensure_bootstrap_loaded(self, force: bool = False) -> dict[str, Any]:
        """Fetch the bootstrap manifest if needed and apply it to the environment.

        Returns a copy of the manifest, ``{}`` when no bootstrap URL is
        configured, or the last successfully loaded manifest (possibly
        empty) when the fetch fails.  Fetch failures are logged, not raised.
        """
        bootstrap_url = self._bootstrap_url()
        if not bootstrap_url:
            return {}

        cached = self._cached_bootstrap(force)
        if cached is not None:
            return cached

        with self._bootstrap_lock:
            # Another thread may have refreshed the manifest while this one
            # was waiting for the lock, so check again before fetching.
            cached = self._cached_bootstrap(force)
            if cached is not None:
                return cached
            try:
                response = requests.get(bootstrap_url, timeout=20)
                response.raise_for_status()
                payload = dict(response.json() or {})
                self._bootstrap_payload = payload
                self._bootstrap_loaded_at = time.time()
                self._apply_bootstrap_payload(payload)
                return dict(payload)
            except Exception:
                self.logger.warning("Failed to load Google Drive bootstrap manifest", exc_info=True)
                return dict(self._bootstrap_payload)

    def enabled(self) -> bool:
        """Return True when a root Drive folder id is configured."""
        self.ensure_bootstrap_loaded()
        return bool(self.root_folder_id())

    def root_folder_id(self) -> str:
        """Return the configured root folder id (shared-state first, then storage)."""
        return str(
            os.getenv("GOOGLE_DRIVE_SHARED_STATE_FOLDER_ID", "")
            or os.getenv("GOOGLE_DRIVE_STORAGE_FOLDER_ID", "")
            or ""
        ).strip()

    def prefix(self) -> str:
        """Return the folder prefix below the root; defaults to ``kapo/shared_state``."""
        return str(
            os.getenv("GOOGLE_DRIVE_SHARED_STATE_PREFIX", "")
            or os.getenv("GOOGLE_DRIVE_STORAGE_PREFIX", "")
            or "kapo/shared_state"
        ).strip("/ ")

    # ------------------------------------------------------------------
    # Credentials
    # ------------------------------------------------------------------
    def _control_plane_token(self, label: str) -> str:
        """Look up an enabled token named *label* in a local control-plane DB.

        Two candidate SQLite files below the current working directory are
        tried in order; the first non-empty token wins.  Read errors are
        logged and the next candidate is tried.  Returns ``""`` when nothing
        is found.
        """
        base_dir = Path.cwd().resolve() / "data"
        candidates = [
            base_dir / "local" / "control_plane.db",
            base_dir / "drive_cache" / "control_plane.db",
        ]
        for db_path in candidates:
            if not db_path.exists():
                continue
            try:
                conn = sqlite3.connect(db_path)
                try:
                    conn.row_factory = sqlite3.Row
                    row = conn.execute(
                        "SELECT token FROM ngrok_tokens WHERE label = ? AND enabled = 1",
                        (label,),
                    ).fetchone()
                finally:
                    # Close even when the query raises (e.g. missing table),
                    # otherwise the connection would be leaked.
                    conn.close()
                token = str(row["token"]) if row else ""
                if token:
                    return token
            except Exception:
                self.logger.warning("Failed to read Google Drive token from %s", db_path, exc_info=True)
        return ""

    def _client_secret_payload(self) -> dict[str, Any] | None:
        """Return the OAuth client-secret document, or ``None`` if unavailable.

        Sources are tried in order: inline JSON, base64-encoded JSON, a JSON
        file path, and finally separate client id / secret variables.  A
        source that fails to parse is logged and the next one is tried.
        """
        raw = str(os.getenv("GOOGLE_DRIVE_CLIENT_SECRET_JSON", "")).strip()
        raw_b64 = str(os.getenv("GOOGLE_DRIVE_CLIENT_SECRET_JSON_BASE64", "")).strip()
        path = str(os.getenv("GOOGLE_DRIVE_CLIENT_SECRET_PATH", "")).strip()
        if raw:
            try:
                return json.loads(raw)
            except Exception:
                self.logger.warning("Invalid GOOGLE_DRIVE_CLIENT_SECRET_JSON", exc_info=True)
        if raw_b64:
            try:
                return json.loads(base64.b64decode(raw_b64).decode("utf-8"))
            except Exception:
                self.logger.warning("Invalid GOOGLE_DRIVE_CLIENT_SECRET_JSON_BASE64", exc_info=True)
        if path:
            try:
                return json.loads(Path(path).expanduser().read_text(encoding="utf-8"))
            except Exception:
                self.logger.warning("Failed to load Google Drive client secret path %s", path, exc_info=True)

        client_id = str(os.getenv("GOOGLE_DRIVE_CLIENT_ID", "")).strip()
        client_secret = str(os.getenv("GOOGLE_DRIVE_CLIENT_SECRET", "")).strip()
        token_uri = str(os.getenv("GOOGLE_DRIVE_TOKEN_URI", "https://oauth2.googleapis.com/token")).strip()
        if client_id and client_secret:
            return {
                "installed": {
                    "client_id": client_id,
                    "client_secret": client_secret,
                    "token_uri": token_uri,
                }
            }
        return None

    def _oauth_client(self) -> dict[str, str]:
        """Return ``client_id``, ``client_secret`` and ``token_uri`` as strings.

        Values come from the ``installed`` (or ``web``) section of the
        client-secret document; missing values become ``""`` except the
        token URI, which defaults to Google's token endpoint.
        """
        payload = self._client_secret_payload() or {}
        # The secret is arbitrary JSON supplied via the environment; tolerate
        # documents that are not an object of objects instead of crashing.
        section: Any = {}
        if isinstance(payload, dict):
            section = payload.get("installed") or payload.get("web") or {}
        if not isinstance(section, dict):
            section = {}
        return {
            "client_id": str(section.get("client_id") or "").strip(),
            "client_secret": str(section.get("client_secret") or "").strip(),
            "token_uri": str(section.get("token_uri") or "https://oauth2.googleapis.com/token").strip(),
        }

    def _refresh_token(self) -> str:
        """Return the refresh token from the environment or the control-plane DB."""
        return str(
            os.getenv("GOOGLE_DRIVE_REFRESH_TOKEN", "")
            or self._control_plane_token("google_drive_refresh_token")
            or ""
        ).strip()

    def _access_token(self) -> str:
        """Return a usable access token, refreshing it when it is about to expire.

        A stored token is used as-is when its expiry is unknown (``0`` or an
        unparsable value) or more than 60 seconds away.  Otherwise a refresh
        is attempted; if that yields nothing the stored token is returned.
        Errors raised by the refresh request propagate.
        """
        token = str(
            os.getenv("GOOGLE_DRIVE_ACCESS_TOKEN", "")
            or self._control_plane_token("google_drive_access_token")
            or ""
        ).strip()
        expires_at = self._env_float("GOOGLE_DRIVE_TOKEN_EXPIRES_AT", 0.0)
        if token and (not expires_at or expires_at > (time.time() + 60)):
            return token
        refreshed = self._refresh_access_token()
        return refreshed or token

    def _refresh_access_token(self) -> str:
        """Exchange the refresh token for a new access token.

        Returns ``""`` when the refresh token or OAuth client credentials are
        missing.  On success the new token and its expiry are written back to
        ``os.environ``.  Raises ``requests.HTTPError`` on a non-2xx reply.
        """
        refresh_token = self._refresh_token()
        client = self._oauth_client()
        if not refresh_token or not client["client_id"] or not client["client_secret"]:
            return ""
        response = requests.post(
            client["token_uri"],
            data={
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": client["client_id"],
                "client_secret": client["client_secret"],
            },
            timeout=30,
        )
        response.raise_for_status()
        payload = response.json()
        access_token = str(payload.get("access_token") or "").strip()
        expires_in = float(payload.get("expires_in") or 0)
        if access_token:
            os.environ["GOOGLE_DRIVE_ACCESS_TOKEN"] = access_token
            # "0" marks the expiry as unknown for _access_token().
            os.environ["GOOGLE_DRIVE_TOKEN_EXPIRES_AT"] = str(time.time() + expires_in) if expires_in else "0"
        return access_token

    # ------------------------------------------------------------------
    # HTTP plumbing
    # ------------------------------------------------------------------
    def _request(self, method: str, url: str, *, headers: dict[str, str] | None = None, timeout: int = 60, **kwargs) -> requests.Response:
        """Send an authorized request, retrying once after a 401.

        Raises ``ValueError`` when no access token is available and
        ``requests.HTTPError`` for any non-2xx final response.
        """
        token = self._access_token()
        if not token:
            raise ValueError("Google Drive access token is not configured")
        base_headers = dict(headers or {})

        def send(bearer: str) -> requests.Response:
            return requests.request(
                method,
                url,
                headers={**base_headers, "Authorization": f"Bearer {bearer}"},
                timeout=timeout,
                **kwargs,
            )

        response = send(token)
        if response.status_code == 401:
            # The token was rejected: refresh exactly once and retry.  A
            # second 401 is reported as-is rather than refreshing again.
            refreshed = self._refresh_access_token()
            if refreshed:
                response = send(refreshed)
        response.raise_for_status()
        return response

    @staticmethod
    def _escape_query(value: str) -> str:
        """Escape backslashes and single quotes for a quoted Drive ``q`` literal."""
        return str(value or "").replace("\\", "\\\\").replace("'", "\\'")

    # ------------------------------------------------------------------
    # Folder / file primitives
    # ------------------------------------------------------------------
    def _find_folder(self, name: str, parent_id: str) -> str:
        """Return the id of the first non-trashed folder called *name*, or ``""``.

        The search is restricted to *parent_id* when one is given.
        """
        query = (
            "mimeType = 'application/vnd.google-apps.folder' and trashed = false "
            f"and name = '{self._escape_query(name)}'"
        )
        if parent_id:
            query += f" and '{self._escape_query(parent_id)}' in parents"
        response = self._request(
            "GET",
            "https://www.googleapis.com/drive/v3/files",
            params={"q": query, "fields": "files(id,name)", "pageSize": 10, "supportsAllDrives": "true"},
        )
        files = response.json().get("files", [])
        return str(files[0].get("id") or "").strip() if files else ""

    def _create_folder(self, name: str, parent_id: str) -> str:
        """Create folder *name* (under *parent_id* if given) and return its id."""
        payload: dict[str, Any] = {"name": name, "mimeType": "application/vnd.google-apps.folder"}
        if parent_id:
            payload["parents"] = [parent_id]
        response = self._request(
            "POST",
            "https://www.googleapis.com/drive/v3/files?supportsAllDrives=true",
            headers={"Content-Type": "application/json; charset=UTF-8"},
            json=payload,
        )
        return str(response.json().get("id") or "").strip()

    def _ensure_folder_path(self, relative_path: str) -> str:
        """Return the id of ``<prefix>/<relative_path>``, creating folders as needed.

        Each path segment is sanitized with ``_safe_name`` and resolved below
        the previous one, starting at the root folder.  Resolved ids are
        cached per ``parent:name`` pair.
        """
        current_parent = self.root_folder_id()
        full_path = f"{self.prefix()}/{relative_path}".replace("\\", "/")
        for part in full_path.split("/"):
            if not part.strip():
                continue
            safe_part = _safe_name(part, "folder")
            cache_key = f"{current_parent}:{safe_part}"
            cached = self._folder_cache.get(cache_key)
            if cached:
                current_parent = cached
                continue
            folder_id = self._find_folder(safe_part, current_parent)
            if not folder_id:
                folder_id = self._create_folder(safe_part, current_parent)
            self._folder_cache[cache_key] = folder_id
            current_parent = folder_id
        return current_parent

    def _find_file(self, folder_id: str, file_name: str) -> str:
        """Return the id of *file_name* inside *folder_id*, or ``""`` if absent.

        Found ids are cached; misses are not, so a file created later is
        picked up by the next call.
        """
        cache_key = f"{folder_id}:{file_name}"
        cached = self._file_cache.get(cache_key)
        if cached:
            return cached
        query = (
            "trashed = false "
            f"and name = '{self._escape_query(file_name)}' "
            f"and '{self._escape_query(folder_id)}' in parents"
        )
        response = self._request(
            "GET",
            "https://www.googleapis.com/drive/v3/files",
            params={"q": query, "fields": "files(id,name)", "pageSize": 10, "supportsAllDrives": "true"},
        )
        files = response.json().get("files", [])
        file_id = str(files[0].get("id") or "").strip() if files else ""
        if file_id:
            self._file_cache[cache_key] = file_id
        return file_id

    def _create_file(self, folder_id: str, file_name: str) -> str:
        """Create an empty file *file_name* in *folder_id* and return its id."""
        response = self._request(
            "POST",
            "https://www.googleapis.com/drive/v3/files?supportsAllDrives=true",
            headers={"Content-Type": "application/json; charset=UTF-8"},
            json={"name": file_name, "parents": [folder_id]},
        )
        file_id = str(response.json().get("id") or "").strip()
        if file_id:
            self._file_cache[f"{folder_id}:{file_name}"] = file_id
        return file_id

    def _upload_json(self, file_id: str, payload: dict[str, Any]) -> None:
        """Replace the content of *file_id* with *payload* serialized as UTF-8 JSON."""
        body = json.dumps(payload, ensure_ascii=False, sort_keys=True).encode("utf-8")
        self._request(
            "PATCH",
            f"https://www.googleapis.com/upload/drive/v3/files/{file_id}?uploadType=media&supportsAllDrives=true",
            headers={"Content-Type": "application/json; charset=UTF-8"},
            data=body,
        )

    def _download_json(self, file_id: str) -> dict[str, Any]:
        """Download *file_id* and return its JSON content as a dict."""
        response = self._request(
            "GET",
            f"https://www.googleapis.com/drive/v3/files/{file_id}?alt=media&supportsAllDrives=true",
        )
        return dict(response.json() or {})

    # ------------------------------------------------------------------
    # Public document API
    # ------------------------------------------------------------------
    def get_document(self, collection: str, doc_id: str) -> dict[str, Any]:
        """Return document *doc_id* from *collection*, or ``{}``.

        ``{}`` is returned when the client is disabled, the file does not
        exist, or its download fails (logged).  Errors raised while resolving
        the folder or looking up the file propagate to the caller.  The
        returned dict always carries an ``id`` key.
        """
        if not self.enabled():
            return {}
        folder_id = self._ensure_folder_path(_safe_name(collection, "collection"))
        file_name = f"{_safe_name(doc_id)}.json"
        file_id = self._find_file(folder_id, file_name)
        if not file_id:
            return {}
        try:
            payload = self._download_json(file_id)
            payload.setdefault("id", _safe_name(doc_id))
            return payload
        except Exception:
            self.logger.warning("Failed to download shared-state file %s/%s", collection, doc_id, exc_info=True)
            return {}

    def set_document(self, collection: str, doc_id: str, payload: dict[str, Any], *, merge: bool = True) -> bool:
        """Write *payload* as document *doc_id*; return True on success.

        With ``merge=True`` the existing document (if any) is read first and
        *payload* is layered on top of it; otherwise the document is
        replaced.  Any failure is logged and reported as ``False``.
        """
        if not self.enabled():
            return False
        safe_doc = _safe_name(doc_id)
        try:
            existing = self.get_document(collection, safe_doc) if merge else {}
            combined = {**existing, **dict(payload or {})} if merge else dict(payload or {})
            combined.setdefault("id", safe_doc)
            folder_id = self._ensure_folder_path(_safe_name(collection, "collection"))
            file_name = f"{safe_doc}.json"
            file_id = self._find_file(folder_id, file_name)
            if not file_id:
                file_id = self._create_file(folder_id, file_name)
            self._upload_json(file_id, combined)
            return True
        except Exception:
            self.logger.warning("Failed to upload shared-state file %s/%s", collection, safe_doc, exc_info=True)
            return False

    def list_documents(self, collection: str, *, limit: int = 200) -> list[dict[str, Any]]:
        """Return the ``.json`` documents of *collection*, newest first.

        At most *limit* Drive entries are requested (clamped to 1..1000).
        A file that cannot be downloaded or parsed is logged and skipped so
        that one bad file does not hide the rest of the collection.  A
        failure of the listing itself is logged and yields ``[]``.
        """
        if not self.enabled():
            return []
        try:
            folder_id = self._ensure_folder_path(_safe_name(collection, "collection"))
            query = f"trashed = false and '{self._escape_query(folder_id)}' in parents"
            page_size = min(self._MAX_PAGE_SIZE, max(1, int(limit)))
            response = self._request(
                "GET",
                "https://www.googleapis.com/drive/v3/files",
                params={
                    "q": query,
                    "fields": "files(id,name,modifiedTime)",
                    "pageSize": page_size,
                    "supportsAllDrives": "true",
                    "orderBy": "modifiedTime desc",
                },
            )
            items: list[dict[str, Any]] = []
            for entry in response.json().get("files", []):
                file_id = str(entry.get("id") or "").strip()
                file_name = str(entry.get("name") or "").strip()
                if not file_id or not file_name.endswith(".json"):
                    continue
                self._file_cache[f"{folder_id}:{file_name}"] = file_id
                try:
                    payload = self._download_json(file_id)
                except Exception:
                    self.logger.warning(
                        "Failed to download shared-state file %s/%s", collection, file_name, exc_info=True
                    )
                    continue
                # Strip the ".json" suffix to recover the document id.
                payload.setdefault("id", file_name[:-5])
                items.append(payload)
            return items
        except Exception:
            self.logger.warning("Failed to list shared-state collection %s", collection, exc_info=True)
            return []

    def delete_document(self, collection: str, doc_id: str) -> bool:
        """Delete document *doc_id*; return True if it is gone afterwards.

        A document that does not exist counts as success.  Returns ``False``
        when the client is disabled or the deletion fails (logged).
        """
        if not self.enabled():
            return False
        safe_doc = _safe_name(doc_id)
        try:
            folder_id = self._ensure_folder_path(_safe_name(collection, "collection"))
            file_name = f"{safe_doc}.json"
            file_id = self._find_file(folder_id, file_name)
            if not file_id:
                return True
            self._request(
                "DELETE",
                f"https://www.googleapis.com/drive/v3/files/{file_id}?supportsAllDrives=true",
            )
            self._file_cache.pop(f"{folder_id}:{file_name}", None)
            return True
        except Exception:
            self.logger.warning("Failed to delete shared-state file %s/%s", collection, safe_doc, exc_info=True)
            return False