| """HF Dataset backup and restore for world state.""" |
|
|
| from __future__ import annotations |
|
|
| import os |
| from pathlib import Path |
|
|
| from world.database import DB_PATH |
|
|
# Hugging Face credentials and the dataset repo that holds the world.db backup.
# Both are read once, at import time; an empty string means "not configured",
# and either one being empty makes backup/restore return False without
# touching the network.
HF_TOKEN = os.getenv("HF_TOKEN", "")
HF_DATASET_REPO = os.getenv("HF_DATASET_REPO", "")
|
|
|
|
def database_is_valid() -> bool:
    """Return True if the local world database exists and holds the core schema.

    The file at ``DB_PATH`` must exist, be at least 512 bytes long, open as an
    SQLite database, and contain a table named ``world_state``. Any exception
    raised while probing the file is swallowed and reported as False.
    """
    if not DB_PATH.exists():
        return False

    try:
        # SQLite's smallest page size is 512 bytes, so a shorter file cannot
        # hold even one page; reject it without opening a connection.
        if DB_PATH.stat().st_size < 512:
            return False

        import sqlite3

        connection = sqlite3.connect(DB_PATH)
        try:
            table_row = connection.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='world_state'"
            ).fetchone()
        finally:
            connection.close()
        return table_row is not None
    except Exception:
        return False
|
|
|
|
def backup_database() -> bool:
    """Upload a consistent snapshot of the local world database to the HF dataset repo.

    Skipped (returns False) when ``AG_SKIP_BACKUP`` is ``"1"``, when
    ``HF_TOKEN`` or ``HF_DATASET_REPO`` is empty, or when the local database
    fails ``database_is_valid()``. The last check matters because the upload
    replaces the remote ``world.db``, which is the only copy
    ``restore_database`` can recover from.

    Returns:
        True if the snapshot was uploaded, False otherwise. Errors are printed
        rather than raised, so callers can treat backup as best-effort.
    """
    if os.environ.get("AG_SKIP_BACKUP") == "1":
        return False

    if not HF_TOKEN or not HF_DATASET_REPO:
        return False

    # Refuse to push a database that restore_database() would itself reject;
    # otherwise an empty/truncated local file clobbers the good remote copy.
    if not database_is_valid():
        return False

    try:
        import sqlite3
        import tempfile
        from contextlib import closing

        from huggingface_hub import HfApi

        api = HfApi(token=HF_TOKEN)

        try:
            api.create_repo(HF_DATASET_REPO, repo_type="dataset", private=True, exist_ok=True)
        except Exception as e:
            # Best-effort: a token may be allowed to write to an existing repo
            # but not to create one. upload_file below reports real failures.
            print(f"Backup: create_repo failed, attempting upload anyway: {e}")

        with tempfile.TemporaryDirectory() as tmp_dir:
            snapshot = Path(tmp_dir) / "world.db"
            # Copy through SQLite's online backup API instead of uploading the
            # live file: the result is a transactionally consistent image that
            # also includes committed changes not yet checkpointed from a WAL.
            with closing(sqlite3.connect(DB_PATH)) as source, closing(
                sqlite3.connect(snapshot)
            ) as target:
                source.backup(target)

            api.upload_file(
                path_or_fileobj=str(snapshot),
                path_in_repo="world.db",
                repo_id=HF_DATASET_REPO,
                repo_type="dataset",
                commit_message="World state backup",
            )
        return True
    except Exception as e:
        print(f"Backup failed: {e}")
        return False
|
|
|
|
def restore_database() -> bool:
    """Download ``world.db`` from the HF dataset repo when the local copy is unusable.

    Does nothing (returns False) when ``HF_TOKEN`` or ``HF_DATASET_REPO`` is
    empty, or when the local database already passes ``database_is_valid()``.
    A valid local database is never overwritten.

    Otherwise the invalid local file and its SQLite sidecar files are removed.
    The backup is then downloaded into ``DB_PATH.parent`` and moved to
    ``DB_PATH`` if it landed under a different name.

    Returns:
        True only if the download succeeded and the restored file passes
        ``database_is_valid()``. Every failure, including errors deleting the
        old files, is printed and reported as False rather than raised.
    """
    if not HF_TOKEN or not HF_DATASET_REPO:
        return False

    if database_is_valid():
        return False

    try:
        # Remove the unusable DB together with any "-wal"/"-shm"/"-journal"
        # sidecars. A stale journal left next to the freshly downloaded file
        # could be paired with it by SQLite and corrupt the restored copy.
        for suffix in ("", "-wal", "-shm", "-journal"):
            DB_PATH.with_name(DB_PATH.name + suffix).unlink(missing_ok=True)

        from huggingface_hub import hf_hub_download

        DB_PATH.parent.mkdir(parents=True, exist_ok=True)
        downloaded = hf_hub_download(
            repo_id=HF_DATASET_REPO,
            filename="world.db",
            repo_type="dataset",
            token=HF_TOKEN,
            local_dir=str(DB_PATH.parent),
        )
        downloaded_path = Path(downloaded)
        # The download is named after the repo file ("world.db"); move it into
        # place when DB_PATH uses a different file name.
        if downloaded_path.resolve() != DB_PATH.resolve():
            downloaded_path.replace(DB_PATH)
        return database_is_valid()
    except Exception as e:
        print(f"Restore failed: {e}")
        return False
|
|