#!/usr/bin/env python3
"""Restore local state files from a Hugging Face dataset repository.

The script looks up a commit of a dataset repository, downloads every file
stored under the configured dataset path, validates the checksum recorded in
``metadata.json`` and copies the files into a target directory, taking a local
backup of that directory first.

Command line usage::

    python restore_from_dataset_atomic.py <repo_id> <target_dir> [--force]
"""

import hashlib
import json
import logging
import os
import shutil
import sys
import tarfile
import tempfile
import time
from datetime import datetime
from pathlib import Path, PurePosixPath
from typing import Any, Dict, List, Optional, Tuple

import requests
from huggingface_hub import HfApi
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import RepositoryNotFoundError


class _JsonLogFormatter(logging.Formatter):
    """Render each log record as a single JSON object.

    The object carries the same ``timestamp``/``level``/``module``/``message``
    keys as the previous hand-written format string, but is produced with
    ``json.dumps`` so that quotes inside a message cannot break the line, and
    adds the structured ``context`` attached by :func:`_log_event`.
    """

    def format(self, record: logging.LogRecord) -> str:
        payload: Dict[str, Any] = {
            "timestamp": self.formatTime(record),
            "level": record.levelname,
            "module": "atomic-restore",
            "message": record.getMessage(),
        }
        context = getattr(record, "context", None)
        if context:
            payload["context"] = context
        if record.exc_info:
            payload["exception"] = self.formatException(record.exc_info)
        # default=str is deliberate: a log call must never fail because a
        # context value (e.g. a Path) is not JSON-serialisable.
        return json.dumps(payload, default=str, ensure_ascii=False)


_log_handler = logging.StreamHandler()
_log_handler.setFormatter(_JsonLogFormatter())
logging.basicConfig(level=logging.INFO, handlers=[_log_handler])
logger = logging.getLogger(__name__)


def _log_event(level: int, event: str,
               context: Optional[Dict[str, Any]] = None) -> None:
    """Log *event* at *level* together with its structured *context*.

    Passing the dict positionally (``logger.info(event, context)``) would hand
    it to %-formatting of a message without placeholders, which silently
    drops it; ``extra`` makes it available to the formatter instead.
    """
    logger.log(level, event, extra={"context": context or {}})


def _failure_result(event: str, error: str,
                    operation_id: Optional[str] = None) -> Dict[str, Any]:
    """Build, log (as *event*) and return a ``success: False`` result dict."""
    result: Dict[str, Any] = {"success": False}
    if operation_id is not None:
        result["operation_id"] = operation_id
    result["error"] = error
    result["timestamp"] = datetime.now().isoformat()
    _log_event(logging.ERROR, event, result)
    return result


class AtomicDatasetRestorer:
    """Restore state files from a Hugging Face dataset repository.

    Attributes:
        repo_id: Identifier of the dataset repository.
        dataset_path: Path inside the repository that holds the state files.
        api: ``HfApi`` client used for repository queries.
        max_retries: Maximum number of download attempts per file.
        base_delay: Delay in seconds before the first retry; doubled on each
            further retry.
    """

    def __init__(self, repo_id: str, dataset_path: str = "state"):
        self.repo_id = repo_id
        self.dataset_path = Path(dataset_path)
        self.api = HfApi()
        self.max_retries = 3
        self.base_delay = 1.0

        _log_event(logging.INFO, "init", {
            "repo_id": repo_id,
            "dataset_path": dataset_path,
            "max_retries": self.max_retries,
        })

    def calculate_checksum(self, file_path: Path) -> str:
        """Return the hex SHA-256 digest of the file at *file_path*."""
        sha256_hash = hashlib.sha256()
        with open(file_path, "rb") as f:
            # Read in chunks so large files are never held in memory at once.
            for chunk in iter(lambda: f.read(65536), b""):
                sha256_hash.update(chunk)
        return sha256_hash.hexdigest()

    def validate_integrity(self, metadata: Dict[str, Any],
                           state_files: List[Path]) -> bool:
        """Validate data integrity using the checksum stored in *metadata*.

        The checksum is the SHA-256 of ``metadata["state_data"]`` serialised
        as JSON with sorted keys. Metadata without a ``checksum`` entry is
        accepted (validation is skipped with a warning).

        Args:
            metadata: Parsed content of ``metadata.json``.
            state_files: Downloaded state files. Not used by the check; kept
                for interface compatibility.

        Returns:
            True when the checksum matches or is absent, False on a mismatch
            or when the metadata cannot be processed.
        """
        try:
            if "checksum" not in metadata:
                _log_event(logging.WARNING, "no_checksum_in_metadata",
                           {"action": "skipping_validation"})
                return True

            state_data = metadata.get("state_data", {})
            calculated_checksum = hashlib.sha256(
                json.dumps(state_data, sort_keys=True).encode("utf-8")
            ).hexdigest()
            expected_checksum = metadata["checksum"]
            is_valid = calculated_checksum == expected_checksum

            _log_event(logging.INFO, "integrity_check", {
                "expected": expected_checksum,
                "calculated": calculated_checksum,
                "valid": is_valid,
            })
            return is_valid
        except (AttributeError, TypeError, ValueError) as e:
            # Malformed metadata (wrong container type, unserialisable
            # state_data) counts as a failed validation.
            _log_event(logging.ERROR, "integrity_validation_failed",
                       {"error": str(e)})
            return False

    def create_backup_before_restore(self, target_dir: Path) -> Optional[Path]:
        """Copy *target_dir* to a timestamped sibling directory.

        Returns:
            The backup directory, or None when *target_dir* does not exist or
            the copy failed (the failure is logged).
        """
        try:
            if not target_dir.exists():
                return None

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            backup_dir = target_dir.parent / f"state_backup_{timestamp}"

            _log_event(logging.INFO, "creating_local_backup", {
                "source": str(target_dir),
                "backup": str(backup_dir),
            })
            shutil.copytree(target_dir, backup_dir)
            return backup_dir
        except OSError as e:  # shutil.Error is a subclass of OSError
            _log_event(logging.ERROR, "local_backup_failed", {"error": str(e)})
            return None

    def _dataset_prefix(self) -> Optional[PurePosixPath]:
        """Return the dataset path as a repo-style path, or None for the repo root."""
        # Repository file names always use forward slashes, regardless of OS.
        posix = self.dataset_path.as_posix()
        if posix in ("", "."):
            return None
        return PurePosixPath(posix)

    def _is_state_file(self, repo_path: str) -> bool:
        """Tell whether *repo_path* is the dataset path itself or lies below it."""
        prefix = self._dataset_prefix()
        if prefix is None:
            return True
        path = PurePosixPath(repo_path)
        # Component-wise comparison: "state_backup/x" is NOT under "state".
        return path == prefix or prefix in path.parents

    def _relative_state_path(self, repo_path: str) -> PurePosixPath:
        """Return *repo_path* relative to the dataset path."""
        path = PurePosixPath(repo_path)
        prefix = self._dataset_prefix()
        if prefix is None:
            return path
        if path == prefix:
            return PurePosixPath(path.name)
        return path.relative_to(prefix)

    def _download_with_retry(self, repo_path: str, commit_sha: str) -> Path:
        """Download one file at *commit_sha*, retrying with exponential backoff.

        Makes at most ``self.max_retries`` attempts (at least one) and
        re-raises the last error when all of them fail.
        """
        attempts = max(1, int(self.max_retries))
        for attempt in range(1, attempts + 1):
            try:
                return Path(hf_hub_download(
                    repo_id=self.repo_id,
                    repo_type="dataset",
                    filename=repo_path,
                    revision=commit_sha,
                    local_files_only=False,
                ))
            except Exception as e:
                if attempt == attempts:
                    raise
                delay = self.base_delay * (2 ** (attempt - 1))
                _log_event(logging.WARNING, "file_download_retry", {
                    "file": repo_path,
                    "attempt": attempt,
                    "delay": delay,
                    "error": str(e),
                })
                time.sleep(delay)
        raise RuntimeError("unreachable: retry loop exited without result")

    def restore_from_commit(self, commit_sha: str, target_dir: Path,
                            force: bool = False) -> Dict[str, Any]:
        """Restore state from a specific commit.

        Every state file must download and the metadata checksum must
        validate before *target_dir* is touched; any failure aborts the
        restore and leaves the target unchanged. The copy into *target_dir*
        itself is file-by-file, so a failure during that step can leave a
        partially updated directory; ``backup_dir`` in the result points to
        the pre-restore copy.

        Args:
            commit_sha: Git commit hash to restore from.
            target_dir: Directory to restore state to.
            force: Accepted for interface compatibility and included in the
                start log entry; it does not currently alter the restore.

        Returns:
            Dictionary with the operation result. ``success`` is always
            present; failures carry ``error``, successes carry
            ``commit_sha``, ``backup_dir``, ``restored_files`` and
            ``metadata``.
        """
        operation_id = f"restore_{int(time.time())}"

        _log_event(logging.INFO, "starting_atomic_restore", {
            "operation_id": operation_id,
            "commit_sha": commit_sha,
            "target_dir": str(target_dir),
            "force": force,
        })

        try:
            # Validate commit exists
            try:
                self.api.repo_info(
                    repo_id=self.repo_id,
                    repo_type="dataset",
                    revision=commit_sha,
                )
            except Exception as e:
                return _failure_result("commit_validation_failed",
                                       f"Invalid commit: {str(e)}",
                                       operation_id)
            _log_event(logging.INFO, "commit_validated", {"commit": commit_sha})

            # List files in the commit and keep those under the dataset path
            files = self.api.list_repo_files(
                repo_id=self.repo_id,
                repo_type="dataset",
                revision=commit_sha,
            )
            state_files = [f for f in files if self._is_state_file(f)]
            if not state_files:
                return _failure_result("no_state_files",
                                       "No state files found in commit",
                                       operation_id)

            # Download state files; a single failure aborts the restore so a
            # partial file set is never reported as a successful restore.
            downloaded: List[Tuple[str, Path]] = []
            metadata = None
            for repo_path in state_files:
                try:
                    local_path = self._download_with_retry(repo_path, commit_sha)
                except Exception as e:
                    return _failure_result(
                        "file_download_failed",
                        f"Failed to download {repo_path}: {str(e)}",
                        operation_id,
                    )
                downloaded.append((repo_path, local_path))

                # Same test as the skip below, so only real metadata.json
                # files are parsed (not e.g. "old_metadata.json").
                if PurePosixPath(repo_path).name == "metadata.json":
                    try:
                        with open(local_path, "r", encoding="utf-8") as f:
                            metadata = json.load(f)
                    except (OSError, ValueError) as e:
                        return _failure_result(
                            "metadata_not_found",
                            f"Metadata could not be read: {str(e)}",
                            operation_id,
                        )

            if not metadata:
                return _failure_result("metadata_not_found",
                                       "Metadata not found in state files",
                                       operation_id)

            # Validate data integrity
            local_files = [local for _, local in downloaded]
            if not self.validate_integrity(metadata, local_files):
                return _failure_result("integrity_validation_failed",
                                       "Data integrity validation failed",
                                       operation_id)

            # Back up only now, right before the target is modified, so a
            # failed download/validation leaves no stray backup directory.
            target_existed = target_dir.exists()
            backup_dir = self.create_backup_before_restore(target_dir)
            if target_existed and backup_dir is None:
                _log_event(logging.WARNING, "restore_without_backup",
                           {"target_dir": str(target_dir)})

            # Create target directory
            target_dir.mkdir(parents=True, exist_ok=True)

            # Restore files (except metadata.json which is for reference),
            # keeping their layout relative to the dataset path so equally
            # named files in different sub-directories do not collide.
            restored_files = []
            for repo_path, local_path in downloaded:
                if PurePosixPath(repo_path).name == "metadata.json":
                    continue
                relative = self._relative_state_path(repo_path)
                dest_path = target_dir.joinpath(*relative.parts)
                dest_path.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(local_path, dest_path)
                restored_files.append(str(dest_path))

                _log_event(logging.INFO, "file_restored", {
                    "source": str(local_path),
                    "destination": str(dest_path),
                })

            result = {
                "success": True,
                "operation_id": operation_id,
                "commit_sha": commit_sha,
                "backup_dir": str(backup_dir) if backup_dir else None,
                "timestamp": datetime.now().isoformat(),
                "restored_files": restored_files,
                "metadata": metadata,
            }
            _log_event(logging.INFO, "atomic_restore_completed", result)
            return result

        except Exception as e:
            # Top-level boundary: callers always receive a result dict.
            return _failure_result("atomic_restore_failed", str(e), operation_id)

    def restore_latest(self, target_dir: Path,
                       force: bool = False) -> Dict[str, Any]:
        """Restore from the latest commit of the repository.

        Looks up the repository's current commit and delegates to
        :meth:`restore_from_commit`.

        Returns:
            The result dictionary of :meth:`restore_from_commit`, or a
            ``success: False`` dictionary when the commit lookup fails.
        """
        try:
            repo_info = self.api.repo_info(
                repo_id=self.repo_id,
                repo_type="dataset",
            )
            if not repo_info.sha:
                return _failure_result("no_commit_found",
                                       "No commit found in repository")
            return self.restore_from_commit(repo_info.sha, target_dir, force)
        except Exception as e:
            return _failure_result("latest_commit_failed",
                                   f"Failed to get latest commit: {str(e)}")


def main():
    """Command line entry point.

    Expects ``<repo_id> <target_dir>`` and an optional ``--force`` flag,
    prints the restore result as JSON on stdout and exits with status 1 on
    any failure.
    """
    args = sys.argv[1:]
    force = "--force" in args
    # The flag may appear anywhere; it must not be taken for a positional.
    positional = [arg for arg in args if arg != "--force"]

    if len(positional) < 2:
        print(json.dumps({
            "error": "Usage: python restore_from_dataset_atomic.py "
                     "<repo_id> <target_dir> [--force]",
            "status": "error"
        }, indent=2))
        sys.exit(1)

    repo_id, target_dir = positional[0], positional[1]

    try:
        target_path = Path(target_dir)
        restorer = AtomicDatasetRestorer(repo_id)
        result = restorer.restore_latest(target_path, force)

        print(json.dumps(result, indent=2))

        if not result.get("success", False):
            sys.exit(1)

    except Exception as e:
        print(json.dumps({
            "error": str(e),
            "status": "error"
        }, indent=2))
        sys.exit(1)


if __name__ == "__main__":
    main()