import hashlib
import json
import logging
import os
import shutil
import sys
import tarfile
import tempfile
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import requests
from huggingface_hub import HfApi
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import RepositoryNotFoundError

# Log line layout: one JSON-shaped object per record. The message is inserted
# verbatim between double quotes, so callers must keep it JSON-string safe.
LOG_FORMAT = (
    '{"timestamp": "%(asctime)s", "level": "%(levelname)s", '
    '"module": "atomic-restore", "message": "%(message)s"}'
)

logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)

# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)


class AtomicDatasetRestorer:
    """Restore local state files from a Hugging Face dataset repository.

    Files located under ``dataset_path`` at a given commit are downloaded,
    checked against the checksum stored in ``metadata.json`` and then copied
    into a target directory. The existing target directory is copied to a
    sibling backup directory before anything is overwritten.
    """

    def __init__(self, repo_id: str, dataset_path: str = "state"):
        """Store the repository coordinates and create the Hub API client.

        Args:
            repo_id: Identifier of the dataset repository to restore from.
            dataset_path: Folder inside the repository that holds the state.
        """
        self.repo_id = repo_id
        self.dataset_path = Path(dataset_path)
        self.api = HfApi()
        # NOTE(review): max_retries and base_delay are stored here but no
        # method of this class reads them; downloads are attempted once.
        self.max_retries = 3
        self.base_delay = 1.0

        self._log(logging.INFO, "init", {
            "repo_id": repo_id,
            "dataset_path": dataset_path,
            "max_retries": self.max_retries,
        })

    @staticmethod
    def _format_event(event: str, details: Optional[Dict[str, Any]] = None) -> str:
        """Render an event name plus its details as one JSON-string-safe message.

        The details are serialized to JSON and the combined text is escaped so
        that it can sit between double quotes in a JSON log line.
        """
        if not details:
            return event
        try:
            # default=str is deliberate: this is log output, so values such as
            # Path objects are simply stringified.
            payload = json.dumps(details, default=str)
        except (TypeError, ValueError):
            payload = repr(details)
        # json.dumps on the full text escapes quotes/backslashes/newlines;
        # [1:-1] drops the surrounding quotes it adds.
        return json.dumps(f"{event} {payload}")[1:-1]

    @staticmethod
    def _log(level: int, event: str, details: Optional[Dict[str, Any]] = None) -> None:
        """Emit ``event`` with its details on the module logger.

        Passing the dict directly as a logging argument would drop it, because
        a single mapping argument is used for %-formatting and the event names
        contain no placeholders.
        """
        logger.log(level, "%s", AtomicDatasetRestorer._format_event(event, details))

    def _failure(
        self,
        event: str,
        message: str,
        operation_id: Optional[str] = None,
        extra: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Build, log and return a failed-operation result dictionary."""
        result: Dict[str, Any] = {"success": False}
        if operation_id is not None:
            result["operation_id"] = operation_id
        result["error"] = message
        result["timestamp"] = datetime.now().isoformat()
        if extra:
            result.update(extra)
        self._log(logging.ERROR, event, result)
        return result

    def _relative_state_path(self, repo_file: str) -> Optional[str]:
        """Return ``repo_file`` relative to ``dataset_path``, or None if outside it.

        Matching is done on whole path components, so ``state_backup/x`` is not
        considered part of ``state``.
        """
        prefix = self.dataset_path.as_posix().strip("/")
        if prefix in ("", "."):
            return repo_file
        if repo_file == prefix:
            # A file named exactly like the prefix keeps its own base name.
            return repo_file.rsplit("/", 1)[-1]
        if repo_file.startswith(prefix + "/"):
            return repo_file[len(prefix) + 1:]
        return None

    def calculate_checksum(self, file_path: Path) -> str:
        """Return the hexadecimal SHA-256 digest of the file at ``file_path``."""
        digest = hashlib.sha256()
        with open(file_path, "rb") as handle:
            # Read in chunks so large files are never loaded whole.
            for chunk in iter(lambda: handle.read(65536), b""):
                digest.update(chunk)
        return digest.hexdigest()

    def validate_integrity(self, metadata: Dict[str, Any], state_files: List[Path]) -> bool:
        """Validate data integrity using the checksum stored in the metadata.

        The SHA-256 of ``metadata["state_data"]`` (JSON with sorted keys) is
        compared with ``metadata["checksum"]``.

        Args:
            metadata: Parsed content of ``metadata.json``.
            state_files: Downloaded files; accepted for interface
                compatibility, their contents are not inspected here.

        Returns:
            True when the checksums match or when the metadata carries no
            checksum; False on mismatch or on any error during validation.
        """
        try:
            if "checksum" not in metadata:
                self._log(logging.WARNING, "no_checksum_in_metadata", {"action": "skipping_validation"})
                return True

            state_data = metadata.get("state_data", {})
            calculated_checksum = hashlib.sha256(
                json.dumps(state_data, sort_keys=True).encode()
            ).hexdigest()
            expected_checksum = metadata["checksum"]
            is_valid = calculated_checksum == expected_checksum

            self._log(logging.INFO, "integrity_check", {
                "expected": expected_checksum,
                "calculated": calculated_checksum,
                "valid": is_valid,
            })
            return is_valid

        except Exception as e:
            self._log(logging.ERROR, "integrity_validation_failed", {"error": str(e)})
            return False

    def create_backup_before_restore(self, target_dir: Path) -> Optional[Path]:
        """Copy ``target_dir`` to a timestamped sibling directory.

        Returns:
            The backup directory, or None when ``target_dir`` does not exist
            or the copy failed (the failure is logged).
        """
        try:
            if not target_dir.exists():
                return None

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            backup_dir = target_dir.parent / f"state_backup_{timestamp}"
            # Two backups within the same second would collide and make
            # copytree fail; pick a free name instead.
            attempt = 1
            while backup_dir.exists():
                backup_dir = target_dir.parent / f"state_backup_{timestamp}_{attempt}"
                attempt += 1

            self._log(logging.INFO, "creating_local_backup", {
                "source": str(target_dir),
                "backup": str(backup_dir),
            })

            shutil.copytree(target_dir, backup_dir)
            return backup_dir

        except Exception as e:
            self._log(logging.ERROR, "local_backup_failed", {"error": str(e)})
            return None

    def restore_from_commit(self, commit_sha: str, target_dir: Path, force: bool = False) -> Dict[str, Any]:
        """
        Restore state from specific commit

        All state files are downloaded and validated before the target
        directory is touched; any download failure aborts the restore.

        Args:
            commit_sha: Git commit hash to restore from
            target_dir: Directory to restore state to
            force: Accepted for interface compatibility; it is only logged

        Returns:
            Dictionary with operation result. ``success`` is always present;
            failures carry ``error``, successes carry ``restored_files``,
            ``metadata`` and ``backup_dir``.
        """
        operation_id = f"restore_{int(time.time())}"

        self._log(logging.INFO, "starting_atomic_restore", {
            "operation_id": operation_id,
            "commit_sha": commit_sha,
            "target_dir": str(target_dir),
            "force": force,
        })

        backup_dir: Optional[Path] = None
        try:
            target_dir = Path(target_dir)

            # Make sure the revision exists before doing any work.
            try:
                self.api.repo_info(
                    repo_id=self.repo_id,
                    repo_type="dataset",
                    revision=commit_sha,
                )
            except Exception as e:
                return self._failure(
                    "commit_validation_failed", f"Invalid commit: {str(e)}", operation_id
                )
            self._log(logging.INFO, "commit_validated", {"commit": commit_sha})

            repo_files = self.api.list_repo_files(
                repo_id=self.repo_id,
                repo_type="dataset",
                revision=commit_sha,
            )

            # Pair every state file with its path relative to dataset_path.
            state_files = []
            for repo_file in repo_files:
                relative = self._relative_state_path(repo_file)
                if relative is not None:
                    state_files.append((repo_file, relative))

            if not state_files:
                return self._failure(
                    "no_state_files", "No state files found in commit", operation_id
                )

            downloaded = []  # (relative path, local path) pairs
            metadata = None

            for repo_file, relative in state_files:
                # File names come from the remote repository; never let one
                # escape the target directory.
                if ".." in relative.split("/"):
                    return self._failure(
                        "atomic_restore_failed",
                        f"Unsafe path in repository: {repo_file}",
                        operation_id,
                    )
                try:
                    local_path = hf_hub_download(
                        repo_id=self.repo_id,
                        repo_type="dataset",
                        filename=repo_file,
                        revision=commit_sha,
                        local_files_only=False,
                    )
                    if not local_path:
                        raise FileNotFoundError(repo_file)

                    if relative.rsplit("/", 1)[-1] == "metadata.json":
                        with open(local_path, "r", encoding="utf-8") as f:
                            metadata = json.load(f)
                except Exception as e:
                    # A partial restore must not be reported as success.
                    self._log(logging.ERROR, "file_download_failed", {"file": repo_file, "error": str(e)})
                    return self._failure(
                        "atomic_restore_failed",
                        f"Failed to download {repo_file}: {str(e)}",
                        operation_id,
                    )

                downloaded.append((relative, Path(local_path)))

            if not metadata:
                return self._failure(
                    "metadata_not_found", "Metadata not found in state files", operation_id
                )

            if not self.validate_integrity(metadata, [path for _, path in downloaded]):
                return self._failure(
                    "integrity_validation_failed", "Data integrity validation failed", operation_id
                )

            # Everything is downloaded and validated: only now back up and
            # modify the target directory.
            backup_dir = self.create_backup_before_restore(target_dir)
            target_dir.mkdir(parents=True, exist_ok=True)

            restored_files = []
            for relative, local_path in downloaded:
                parts = relative.split("/")
                if parts[-1] == "metadata.json":
                    continue
                # Keep the layout below dataset_path instead of flattening it.
                dest_path = target_dir.joinpath(*parts)
                dest_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(local_path, dest_path)
                restored_files.append(str(dest_path))

                self._log(logging.INFO, "file_restored", {
                    "source": str(local_path),
                    "destination": str(dest_path),
                })

            result = {
                "success": True,
                "operation_id": operation_id,
                "commit_sha": commit_sha,
                "backup_dir": str(backup_dir) if backup_dir else None,
                "timestamp": datetime.now().isoformat(),
                "restored_files": restored_files,
                "metadata": metadata,
            }

            self._log(logging.INFO, "atomic_restore_completed", result)
            return result

        except Exception as e:
            # Tell the caller where the pre-restore copy is if one was made.
            extra = {"backup_dir": str(backup_dir)} if backup_dir is not None else None
            return self._failure("atomic_restore_failed", str(e), operation_id, extra)

    def restore_latest(self, target_dir: Path, force: bool = False) -> Dict[str, Any]:
        """Restore from the latest commit

        Args:
            target_dir: Directory to restore state to
            force: Passed through to ``restore_from_commit``

        Returns:
            Dictionary with operation result (see ``restore_from_commit``).
        """
        try:
            repo_info = self.api.repo_info(
                repo_id=self.repo_id,
                repo_type="dataset",
            )

            if not repo_info.sha:
                return self._failure("no_commit_found", "No commit found in repository")

            return self.restore_from_commit(repo_info.sha, target_dir, force)

        except Exception as e:
            return self._failure(
                "latest_commit_failed", f"Failed to get latest commit: {str(e)}"
            )


def main():
    """Main function for command line usage

    Expects ``<repo_id> <target_dir> [--force]`` on the command line, restores
    the latest commit into ``target_dir`` and prints the result as JSON.
    Exits with status 1 on usage errors, exceptions or a failed restore.
    """
    cli_args = sys.argv[1:]
    force = "--force" in cli_args
    # Strip the flag first so it can never be mistaken for repo_id/target_dir.
    positional = [arg for arg in cli_args if arg != "--force"]

    if len(positional) < 2:
        print(json.dumps({
            "error": "Usage: python restore_from_dataset_atomic.py <repo_id> <target_dir> [--force]",
            "status": "error"
        }, indent=2))
        sys.exit(1)

    repo_id, target_dir = positional[0], positional[1]

    try:
        restorer = AtomicDatasetRestorer(repo_id)
        result = restorer.restore_latest(Path(target_dir), force)
        print(json.dumps(result, indent=2))
    except Exception as e:
        print(json.dumps({
            "error": str(e),
            "status": "error"
        }, indent=2))
        sys.exit(1)

    # Reported outside the try block: the result was already printed above.
    if not result.get("success", False):
        sys.exit(1)


# Run the command line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()