# AiCoder / shared / google_drive_state.py
# Last commit by MrA7A1: "Sync modernized KAPO runtime from control plane" (b3f1931, verified)
"""Google Drive backed shared-state utilities."""
from __future__ import annotations
import base64
import json
import logging
import os
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any
import requests
def _safe_name(value: str, default: str = "item") -> str:
text = str(value or "").strip() or default
return "".join(ch if ch.isalnum() or ch in {"-", "_", "."} else "_" for ch in text)[:180]
class GoogleDriveStateClient:
    """Store JSON documents in Google Drive folders as shared state.

    Each document is stored as ``<doc_id>.json`` inside the folder path
    ``<prefix>/<collection>`` beneath :meth:`root_folder_id`. All names are
    sanitised with ``_safe_name``. Configuration comes from environment
    variables, which :meth:`ensure_bootstrap_loaded` can populate from a
    remote JSON manifest. Drive folder and file IDs are cached in memory. A
    cached file ID is evicted when Drive reports that file as missing (404),
    for example after another process deleted it.
    """

    _DRIVE_FILES_URL = "https://www.googleapis.com/drive/v3/files"
    _DRIVE_UPLOAD_URL = "https://www.googleapis.com/upload/drive/v3/files"
    _FOLDER_MIME_TYPE = "application/vnd.google-apps.folder"
    _JSON_HEADERS = {"Content-Type": "application/json; charset=UTF-8"}

    # Bootstrap manifest key -> environment variable it populates.
    _BOOTSTRAP_ENV_MAP = {
        "shared_state_backend": "KAPO_SHARED_STATE_BACKEND",
        "shared_state_folder_id": "GOOGLE_DRIVE_SHARED_STATE_FOLDER_ID",
        "shared_state_prefix": "GOOGLE_DRIVE_SHARED_STATE_PREFIX",
        "storage_folder_id": "GOOGLE_DRIVE_STORAGE_FOLDER_ID",
        "storage_prefix": "GOOGLE_DRIVE_STORAGE_PREFIX",
        "access_token": "GOOGLE_DRIVE_ACCESS_TOKEN",
        "refresh_token": "GOOGLE_DRIVE_REFRESH_TOKEN",
        "client_secret_json": "GOOGLE_DRIVE_CLIENT_SECRET_JSON",
        "client_secret_json_base64": "GOOGLE_DRIVE_CLIENT_SECRET_JSON_BASE64",
        "client_secret_path": "GOOGLE_DRIVE_CLIENT_SECRET_PATH",
        "token_expires_at": "GOOGLE_DRIVE_TOKEN_EXPIRES_AT",
        "firebase_enabled": "FIREBASE_ENABLED",
        "executor_url": "EXECUTOR_URL",
        "executor_public_url": "EXECUTOR_PUBLIC_URL",
        "ngrok_authtoken": "NGROK_AUTHTOKEN",
        "brain_public_url": "BRAIN_PUBLIC_URL",
        "brain_roles": "BRAIN_ROLES",
        "brain_languages": "BRAIN_LANGUAGES",
    }

    def __init__(self, logger: logging.Logger) -> None:
        self.logger = logger
        # "<parent folder id>:<folder name>" -> Drive folder ID
        self._folder_cache: dict[str, str] = {}
        # "<folder id>:<file name>" -> Drive file ID
        self._file_cache: dict[str, str] = {}
        # Held while (re)downloading the bootstrap manifest.
        self._bootstrap_lock = threading.Lock()
        self._bootstrap_loaded_at = 0.0
        self._bootstrap_payload: dict[str, Any] = {}

    # ------------------------------------------------------------------ config

    def _env_float(self, name: str, default: float) -> float:
        """Read env var *name* as a float.

        Returns *default* when the variable is unset, blank or not numeric.
        A non-numeric value is logged rather than raised, so one bad value
        does not break every Drive operation.
        """
        raw = str(os.getenv(name, "") or "").strip()
        if not raw:
            return default
        try:
            return float(raw)
        except ValueError:
            self.logger.warning("Ignoring non-numeric %s=%r", name, raw)
            return default

    def _bootstrap_url(self) -> str:
        return str(os.getenv("GOOGLE_DRIVE_BOOTSTRAP_URL", "") or os.getenv("KAPO_BOOTSTRAP_URL", "") or "").strip()

    def _bootstrap_ttl(self) -> float:
        """Seconds a loaded manifest stays fresh (``GOOGLE_DRIVE_BOOTSTRAP_TTL_SEC``, default 300)."""
        return self._env_float("GOOGLE_DRIVE_BOOTSTRAP_TTL_SEC", 300.0)

    def _bootstrap_fresh(self) -> bool:
        """True when a non-empty manifest was loaded less than one TTL ago."""
        return bool(self._bootstrap_payload) and (time.time() - self._bootstrap_loaded_at) < self._bootstrap_ttl()

    def _apply_bootstrap_payload(self, payload: dict[str, Any]) -> None:
        """Copy the recognised, non-empty manifest values into ``os.environ``."""
        for key, env_name in self._BOOTSTRAP_ENV_MAP.items():
            value = payload.get(key)
            if value is None or value == "":
                continue
            # Structured values (e.g. an inline client-secret object) must be stored as
            # JSON; str() would give a Python repr that json.loads later rejects.
            os.environ[env_name] = json.dumps(value) if isinstance(value, (dict, list)) else str(value)

    def ensure_bootstrap_loaded(self, force: bool = False) -> dict[str, Any]:
        """Fetch the bootstrap manifest (if configured) and apply it to the environment.

        Args:
            force: Re-download even if the cached manifest is still within its TTL.

        Returns:
            A copy of the current manifest. This is ``{}`` when no bootstrap URL
            is configured. If the download fails, the failure is logged and the
            last successfully loaded manifest is returned.
        """
        bootstrap_url = self._bootstrap_url()
        if not bootstrap_url:
            return {}
        if not force and self._bootstrap_fresh():
            return dict(self._bootstrap_payload)
        with self._bootstrap_lock:
            # Re-check under the lock: another thread may have just refreshed it.
            if not force and self._bootstrap_fresh():
                return dict(self._bootstrap_payload)
            try:
                response = requests.get(bootstrap_url, timeout=20)
                response.raise_for_status()
                payload = dict(response.json() or {})
                self._bootstrap_payload = payload
                self._bootstrap_loaded_at = time.time()
                self._apply_bootstrap_payload(payload)
                return dict(payload)
            except Exception:
                self.logger.warning("Failed to load Google Drive bootstrap manifest", exc_info=True)
                return dict(self._bootstrap_payload)

    def enabled(self) -> bool:
        """True when a root Drive folder is configured (after applying the bootstrap manifest)."""
        self.ensure_bootstrap_loaded()
        return bool(self.root_folder_id())

    def root_folder_id(self) -> str:
        return str(
            os.getenv("GOOGLE_DRIVE_SHARED_STATE_FOLDER_ID", "")
            or os.getenv("GOOGLE_DRIVE_STORAGE_FOLDER_ID", "")
            or ""
        ).strip()

    def prefix(self) -> str:
        return str(
            os.getenv("GOOGLE_DRIVE_SHARED_STATE_PREFIX", "")
            or os.getenv("GOOGLE_DRIVE_STORAGE_PREFIX", "")
            or "kapo/shared_state"
        ).strip("/ ")

    # ------------------------------------------------------------------ credentials

    def _control_plane_token(self, label: str) -> str:
        """Return the enabled ``ngrok_tokens`` row token for *label*, or "".

        Searches ``data/local/control_plane.db`` and then
        ``data/drive_cache/control_plane.db`` under the current working directory.
        """
        data_dir = Path.cwd().resolve() / "data"
        for db_path in (data_dir / "local" / "control_plane.db", data_dir / "drive_cache" / "control_plane.db"):
            if not db_path.exists():
                continue
            try:
                conn = sqlite3.connect(db_path)
                try:
                    row = conn.execute(
                        "SELECT token FROM ngrok_tokens WHERE label = ? AND enabled = 1",
                        (label,),
                    ).fetchone()
                finally:
                    # Close even if the query fails (e.g. missing table) so the handle is not leaked.
                    conn.close()
            except Exception:
                self.logger.warning("Failed to read Google Drive token from %s", db_path, exc_info=True)
                continue
            # A NULL token column must not become the truthy string "None".
            token = str(row[0] or "").strip() if row else ""
            if token:
                return token
        return ""

    def _client_secret_payload(self) -> dict[str, Any] | None:
        """Load the OAuth client configuration.

        Sources are tried in this order: inline JSON, base64-encoded JSON,
        a JSON file path, then discrete CLIENT_ID/CLIENT_SECRET env vars.
        An unreadable source is logged and the next one is tried. Returns
        None when no source provides a configuration.
        """
        raw = str(os.getenv("GOOGLE_DRIVE_CLIENT_SECRET_JSON", "")).strip()
        raw_b64 = str(os.getenv("GOOGLE_DRIVE_CLIENT_SECRET_JSON_BASE64", "")).strip()
        path = str(os.getenv("GOOGLE_DRIVE_CLIENT_SECRET_PATH", "")).strip()
        if raw:
            try:
                return json.loads(raw)
            except Exception:
                self.logger.warning("Invalid GOOGLE_DRIVE_CLIENT_SECRET_JSON", exc_info=True)
        if raw_b64:
            try:
                return json.loads(base64.b64decode(raw_b64).decode("utf-8"))
            except Exception:
                self.logger.warning("Invalid GOOGLE_DRIVE_CLIENT_SECRET_JSON_BASE64", exc_info=True)
        if path:
            try:
                return json.loads(Path(path).expanduser().read_text(encoding="utf-8"))
            except Exception:
                self.logger.warning("Failed to load Google Drive client secret path %s", path, exc_info=True)
        client_id = str(os.getenv("GOOGLE_DRIVE_CLIENT_ID", "")).strip()
        client_secret = str(os.getenv("GOOGLE_DRIVE_CLIENT_SECRET", "")).strip()
        token_uri = str(os.getenv("GOOGLE_DRIVE_TOKEN_URI", "https://oauth2.googleapis.com/token")).strip()
        if client_id and client_secret:
            return {
                "installed": {
                    "client_id": client_id,
                    "client_secret": client_secret,
                    "token_uri": token_uri,
                }
            }
        return None

    def _oauth_client(self) -> dict[str, str]:
        """Return client_id / client_secret / token_uri; missing values become ""."""
        payload = self._client_secret_payload()
        section: Any = {}
        if isinstance(payload, dict):
            section = payload.get("installed") or payload.get("web") or {}
        if not isinstance(section, dict):
            section = {}
        return {
            "client_id": str(section.get("client_id") or "").strip(),
            "client_secret": str(section.get("client_secret") or "").strip(),
            "token_uri": str(section.get("token_uri") or "https://oauth2.googleapis.com/token").strip(),
        }

    def _refresh_token(self) -> str:
        return str(os.getenv("GOOGLE_DRIVE_REFRESH_TOKEN", "") or self._control_plane_token("google_drive_refresh_token") or "").strip()

    def _access_token(self) -> str:
        """Return a usable access token, refreshing it when it is missing or about to expire.

        If the refresh fails but a current token exists, the failure is logged
        and the current token is returned. If there is no token at all, the
        refresh error propagates.
        """
        token = str(os.getenv("GOOGLE_DRIVE_ACCESS_TOKEN", "") or self._control_plane_token("google_drive_access_token") or "").strip()
        expires_at = self._env_float("GOOGLE_DRIVE_TOKEN_EXPIRES_AT", 0.0)
        # Keep 60 s of headroom so the token does not expire mid-request; 0 means "unknown expiry".
        if token and (not expires_at or expires_at > (time.time() + 60)):
            return token
        try:
            refreshed = self._refresh_access_token()
        except Exception:
            if not token:
                raise
            self.logger.warning("Failed to refresh Google Drive access token; using current token", exc_info=True)
            refreshed = ""
        return refreshed or token

    def _refresh_access_token(self) -> str:
        """Exchange the refresh token for a new access token and store it in ``os.environ``.

        Returns "" without any network call when the refresh token or the
        client credentials are missing.

        Raises:
            requests.HTTPError: The token endpoint rejected the request.
        """
        refresh_token = self._refresh_token()
        client = self._oauth_client()
        if not refresh_token or not client["client_id"] or not client["client_secret"]:
            return ""
        response = requests.post(
            client["token_uri"],
            data={
                "grant_type": "refresh_token",
                "refresh_token": refresh_token,
                "client_id": client["client_id"],
                "client_secret": client["client_secret"],
            },
            timeout=30,
        )
        response.raise_for_status()
        payload = response.json()
        access_token = str(payload.get("access_token") or "").strip()
        expires_in = float(payload.get("expires_in") or 0)
        if access_token:
            os.environ["GOOGLE_DRIVE_ACCESS_TOKEN"] = access_token
            os.environ["GOOGLE_DRIVE_TOKEN_EXPIRES_AT"] = str(time.time() + expires_in) if expires_in else "0"
        return access_token

    # ------------------------------------------------------------------ HTTP

    def _request(self, method: str, url: str, *, headers: dict[str, str] | None = None, timeout: int = 60, **kwargs) -> requests.Response:
        """Send an authenticated Drive request; on a 401, refresh the token once and retry.

        Raises:
            ValueError: No access token is available.
            requests.HTTPError: The final response has an error status.
        """
        token = self._access_token()
        if not token:
            raise ValueError("Google Drive access token is not configured")
        base_headers = dict(headers or {})

        def send(bearer: str) -> requests.Response:
            return requests.request(
                method,
                url,
                headers={**base_headers, "Authorization": f"Bearer {bearer}"},
                timeout=timeout,
                **kwargs,
            )

        response = send(token)
        if response.status_code == 401:
            # The token was rejected (revoked or expired early): refresh once and retry.
            refreshed = self._refresh_access_token()
            if refreshed:
                response = send(refreshed)
        response.raise_for_status()
        return response

    @staticmethod
    def _is_not_found(exc: BaseException) -> bool:
        """True if *exc* is an HTTP error whose response status is 404."""
        response = getattr(exc, "response", None)
        return isinstance(exc, requests.HTTPError) and getattr(response, "status_code", None) == 404

    @staticmethod
    def _escape_query(value: str) -> str:
        """Escape backslashes and single quotes for a quoted Drive ``q`` string literal."""
        return str(value or "").replace("\\", "\\\\").replace("'", "\\'")

    @staticmethod
    def _file_key(folder_id: str, file_name: str) -> str:
        """Key for ``_file_cache``."""
        return f"{folder_id}:{file_name}"

    # ------------------------------------------------------------------ folders

    def _find_folder(self, name: str, parent_id: str) -> str:
        """Return the ID of the first non-trashed folder named *name* under *parent_id*, or ""."""
        query = (
            f"mimeType = '{self._FOLDER_MIME_TYPE}' and trashed = false "
            f"and name = '{self._escape_query(name)}'"
        )
        if parent_id:
            query += f" and '{self._escape_query(parent_id)}' in parents"
        response = self._request(
            "GET",
            self._DRIVE_FILES_URL,
            params={"q": query, "fields": "files(id,name)", "pageSize": 10, "supportsAllDrives": "true"},
        )
        files = response.json().get("files", [])
        return str(files[0].get("id") or "").strip() if files else ""

    def _create_folder(self, name: str, parent_id: str) -> str:
        payload: dict[str, Any] = {"name": name, "mimeType": self._FOLDER_MIME_TYPE}
        if parent_id:
            payload["parents"] = [parent_id]
        response = self._request(
            "POST",
            f"{self._DRIVE_FILES_URL}?supportsAllDrives=true",
            headers=self._JSON_HEADERS,
            json=payload,
        )
        return str(response.json().get("id") or "").strip()

    def _ensure_folder_path(self, relative_path: str) -> str:
        """Find or create ``<prefix>/<relative_path>`` under the root folder and return its ID.

        Raises:
            RuntimeError: Drive did not return an ID for a created folder.
        """
        current_parent = self.root_folder_id()
        full_path = f"{self.prefix()}/{relative_path}".replace("\\", "/")
        for part in (p for p in full_path.split("/") if p.strip()):
            safe_part = _safe_name(part, "folder")
            cache_key = f"{current_parent}:{safe_part}"
            folder_id = self._folder_cache.get(cache_key)
            if not folder_id:
                folder_id = self._find_folder(safe_part, current_parent) or self._create_folder(safe_part, current_parent)
                if not folder_id:
                    # Continuing with an empty parent would put the rest of the path outside the root.
                    raise RuntimeError(f"Google Drive returned no folder ID for {safe_part!r}")
                self._folder_cache[cache_key] = folder_id
            current_parent = folder_id
        return current_parent

    # ------------------------------------------------------------------ files

    def _find_file(self, folder_id: str, file_name: str) -> str:
        """Return the (possibly cached) ID of *file_name* in *folder_id*, or ""."""
        cache_key = self._file_key(folder_id, file_name)
        cached = self._file_cache.get(cache_key)
        if cached:
            return cached
        query = (
            "trashed = false "
            f"and name = '{self._escape_query(file_name)}' "
            f"and '{self._escape_query(folder_id)}' in parents"
        )
        response = self._request(
            "GET",
            self._DRIVE_FILES_URL,
            params={"q": query, "fields": "files(id,name)", "pageSize": 10, "supportsAllDrives": "true"},
        )
        files = response.json().get("files", [])
        file_id = str(files[0].get("id") or "").strip() if files else ""
        if file_id:
            self._file_cache[cache_key] = file_id
        return file_id

    def _create_file(self, folder_id: str, file_name: str) -> str:
        response = self._request(
            "POST",
            f"{self._DRIVE_FILES_URL}?supportsAllDrives=true",
            headers=self._JSON_HEADERS,
            json={"name": file_name, "parents": [folder_id]},
        )
        file_id = str(response.json().get("id") or "").strip()
        if file_id:
            self._file_cache[self._file_key(folder_id, file_name)] = file_id
        return file_id

    def _upload_json(self, file_id: str, payload: dict[str, Any]) -> None:
        self._request(
            "PATCH",
            f"{self._DRIVE_UPLOAD_URL}/{file_id}?uploadType=media&supportsAllDrives=true",
            headers=self._JSON_HEADERS,
            data=json.dumps(payload, ensure_ascii=False, sort_keys=True).encode("utf-8"),
        )

    def _download_json(self, file_id: str) -> dict[str, Any]:
        response = self._request(
            "GET",
            f"{self._DRIVE_FILES_URL}/{file_id}?alt=media&supportsAllDrives=true",
        )
        return dict(response.json() or {})

    def _delete_file(self, file_id: str) -> None:
        self._request(
            "DELETE",
            f"{self._DRIVE_FILES_URL}/{file_id}?supportsAllDrives=true",
        )

    def _store_json(self, folder_id: str, file_name: str, payload: dict[str, Any]) -> None:
        """Upload *payload* as *file_name* in *folder_id*, creating the file when needed.

        If the known file ID returns 404 (deleted elsewhere), the cache entry is
        dropped, the file is recreated, and the upload is retried once.
        """
        file_id = self._find_file(folder_id, file_name) or self._create_file(folder_id, file_name)
        if not file_id:
            raise RuntimeError(f"Google Drive returned no file ID for {file_name!r}")
        try:
            self._upload_json(file_id, payload)
        except requests.HTTPError as exc:
            if not self._is_not_found(exc):
                raise
            self._file_cache.pop(self._file_key(folder_id, file_name), None)
            file_id = self._create_file(folder_id, file_name)
            if not file_id:
                raise RuntimeError(f"Google Drive returned no file ID for {file_name!r}") from exc
            self._upload_json(file_id, payload)

    # ------------------------------------------------------------------ public API

    def get_document(self, collection: str, doc_id: str) -> dict[str, Any]:
        """Return the stored document, or ``{}`` if it is missing, disabled, or unreadable.

        Failures are logged, never raised. Looking up the collection creates its
        folder path if it does not exist yet.
        """
        if not self.enabled():
            return {}
        safe_doc = _safe_name(doc_id)
        file_name = f"{safe_doc}.json"
        folder_id = ""
        try:
            folder_id = self._ensure_folder_path(_safe_name(collection, "collection"))
            file_id = self._find_file(folder_id, file_name)
            if not file_id:
                return {}
            payload = self._download_json(file_id)
        except Exception as exc:
            if folder_id and self._is_not_found(exc):
                # Cached ID is stale (file deleted elsewhere); forget it so the next lookup asks Drive.
                self._file_cache.pop(self._file_key(folder_id, file_name), None)
            self.logger.warning("Failed to download shared-state file %s/%s", collection, doc_id, exc_info=True)
            return {}
        payload.setdefault("id", safe_doc)
        return payload

    def set_document(self, collection: str, doc_id: str, payload: dict[str, Any], *, merge: bool = True) -> bool:
        """Write *payload* to the document, merging it over the stored fields when *merge* is set.

        Returns True on success. Returns False if the client is disabled or the
        write fails (the failure is logged).
        """
        if not self.enabled():
            return False
        safe_doc = _safe_name(doc_id)
        try:
            existing = self.get_document(collection, safe_doc) if merge else {}
            combined = {**existing, **dict(payload or {})} if merge else dict(payload or {})
            combined.setdefault("id", safe_doc)
            folder_id = self._ensure_folder_path(_safe_name(collection, "collection"))
            self._store_json(folder_id, f"{safe_doc}.json", combined)
            return True
        except Exception:
            self.logger.warning("Failed to upload shared-state file %s/%s", collection, safe_doc, exc_info=True)
            return False

    def list_documents(self, collection: str, *, limit: int = 200) -> list[dict[str, Any]]:
        """Return up to *limit* documents in *collection*, most recently modified first.

        Only one page of Drive results is requested. A document that cannot be
        downloaded is logged and skipped; a failure to list the collection
        returns ``[]``.
        """
        if not self.enabled():
            return []
        try:
            folder_id = self._ensure_folder_path(_safe_name(collection, "collection"))
            query = f"trashed = false and '{self._escape_query(folder_id)}' in parents"
            response = self._request(
                "GET",
                self._DRIVE_FILES_URL,
                params={
                    "q": query,
                    "fields": "files(id,name,modifiedTime)",
                    "pageSize": max(1, int(limit)),
                    "supportsAllDrives": "true",
                    "orderBy": "modifiedTime desc",
                },
            )
            files = response.json().get("files", [])
        except Exception:
            self.logger.warning("Failed to list shared-state collection %s", collection, exc_info=True)
            return []
        items: list[dict[str, Any]] = []
        for entry in files:
            file_id = str(entry.get("id") or "").strip()
            file_name = str(entry.get("name") or "").strip()
            if not file_id or not file_name.endswith(".json"):
                continue
            self._file_cache[self._file_key(folder_id, file_name)] = file_id
            try:
                payload = self._download_json(file_id)
            except Exception:
                # One unreadable document should not hide the rest of the collection.
                self.logger.warning("Failed to download shared-state file %s/%s", collection, file_name, exc_info=True)
                continue
            payload.setdefault("id", file_name[:-5])
            items.append(payload)
        return items

    def delete_document(self, collection: str, doc_id: str) -> bool:
        """Delete the document. An absent document counts as success.

        Returns False if the client is disabled or the deletion fails (the
        failure is logged).
        """
        if not self.enabled():
            return False
        safe_doc = _safe_name(doc_id)
        file_name = f"{safe_doc}.json"
        try:
            folder_id = self._ensure_folder_path(_safe_name(collection, "collection"))
            cache_key = self._file_key(folder_id, file_name)
            file_id = self._find_file(folder_id, file_name)
            if file_id:
                try:
                    self._delete_file(file_id)
                except requests.HTTPError as exc:
                    if not self._is_not_found(exc):
                        raise
                    # The known ID was stale; look the name up afresh before concluding it is gone.
                    self._file_cache.pop(cache_key, None)
                    file_id = self._find_file(folder_id, file_name)
                    if file_id:
                        self._delete_file(file_id)
            self._file_cache.pop(cache_key, None)
            return True
        except Exception:
            self.logger.warning("Failed to delete shared-state file %s/%s", collection, safe_doc, exc_info=True)
            return False