"""
S3 Restore Manager

Handles automatic restore of SQLite database from S3-compatible storage
during webapp startup. Provides validation and fallback logic to ensure
reliable database restoration.
"""
| import os | |
| import sqlite3 | |
| import hashlib | |
| import logging | |
| from datetime import datetime | |
| from typing import Optional | |
| from enum import Enum | |
| import boto3 | |
| from botocore.exceptions import ClientError | |
| from ..utils import s3_logger | |
| from .s3_config import ( | |
| S3Config, | |
| S3ConnectionError, | |
| RestoreError, | |
| DatabaseCorruptedError | |
| ) | |
| logger = logging.getLogger(__name__) | |
class RestoreResult(Enum):
    """Possible outcomes of a restore attempt."""

    RESTORED_FROM_S3 = "restored_from_s3"    # S3 backup downloaded, validated, installed
    LOCAL_NEWER = "local_newer"              # local database at least as recent; kept as-is
    NO_BACKUP_FOUND = "no_backup_found"      # S3 disabled or bucket held no backups
    VALIDATION_FAILED = "validation_failed"  # checksum/integrity check rejected the download
    NETWORK_ERROR = "network_error"          # S3 unreachable or an unexpected failure occurred
class RestoreManager:
    """
    Manages automatic restore of SQLite database from S3.

    Features:
    - Startup restore with validation
    - Timestamp comparison (local vs S3)
    - Checksum verification
    - SQLite integrity check
    - Atomic file replacement
    - Fallback chain for reliability
    """

    def __init__(self, config: S3Config, db_path: str):
        """
        Initialize the restore manager.

        Args:
            config: S3 configuration object
            db_path: Absolute path where database should be restored
        """
        self.config = config
        self.db_path = db_path
        if config.enabled:
            self.s3_client = config.create_s3_client()
            logger.info(f"RestoreManager initialized for {db_path}")
        else:
            # Attribute is always present so later access fails predictably
            # only if S3 operations are attempted while disabled.
            self.s3_client = None
            logger.info("RestoreManager initialized but S3 is disabled")

    def restore_from_s3(self) -> RestoreResult:
        """
        Restore database from S3 with validation and fallback logic.

        Process:
        1. List backups in S3
        2. Find latest by LastModified
        3. Compare timestamps (local vs S3)
        4. Download if S3 is newer
        5. Validate checksum and integrity
        6. Atomic replace local file

        Returns:
            RestoreResult enum indicating the outcome

        Side Effects:
            - May replace local database file atomically
            - Creates temp files during download (cleaned up automatically)
            - Logs all operations with structured logging
        """
        start_time = datetime.now()
        s3_logger.restore_started()

        if not self.config.enabled:
            duration = (datetime.now() - start_time).total_seconds()
            s3_logger.restore_completed(duration, None, None, RestoreResult.NO_BACKUP_FOUND.value)
            logger.info("Restore skipped - S3 disabled")
            return RestoreResult.NO_BACKUP_FOUND

        try:
            # List backups from S3
            backups = self._list_backups()
            if not backups:
                duration = (datetime.now() - start_time).total_seconds()
                s3_logger.restore_completed(duration, None, None, RestoreResult.NO_BACKUP_FOUND.value)
                logger.info("No backups found in S3")
                return RestoreResult.NO_BACKUP_FOUND

            # Find latest backup by its S3 LastModified timestamp
            latest_backup = max(backups, key=lambda x: x['LastModified'])
            s3_key = latest_backup['Key']
            s3_timestamp = latest_backup['LastModified']
            logger.info(f"Latest S3 backup: {s3_key} ({s3_timestamp})")

            # Compare with local file timestamp. Delegated to the helper so the
            # (timezone-sensitive) conversion logic lives in exactly one place.
            if not self._compare_timestamps(self.db_path, s3_timestamp):
                # Local file exists and is at least as recent as the S3 backup.
                local_mtime = datetime.fromtimestamp(
                    os.path.getmtime(self.db_path), tz=s3_timestamp.tzinfo
                )
                duration = (datetime.now() - start_time).total_seconds()
                s3_logger.restore_completed(duration, s3_key, None, RestoreResult.LOCAL_NEWER.value)
                logger.info(f"Local database is newer ({local_mtime} >= {s3_timestamp}), skipping restore")
                return RestoreResult.LOCAL_NEWER

            # Download from S3 into a sibling temp file so the live database is
            # never touched until the download has been fully validated.
            temp_path = f"{self.db_path}.restore"
            download_size = self._download_backup(s3_key, temp_path)

            # Validate backup (checksum + SQLite integrity)
            if not self.validate_backup(temp_path, s3_key):
                # Validation failed - discard the download and use local fallback
                if os.path.exists(temp_path):
                    os.remove(temp_path)
                duration = (datetime.now() - start_time).total_seconds()
                s3_logger.restore_fallback("validation_failed", "using_local_database")
                s3_logger.restore_completed(duration, s3_key, None, RestoreResult.VALIDATION_FAILED.value)
                logger.warning("Backup validation failed, using local database")
                return RestoreResult.VALIDATION_FAILED

            # Atomic replace of the live database with the validated download
            self._atomic_replace(temp_path)
            duration = (datetime.now() - start_time).total_seconds()
            s3_logger.restore_completed(duration, s3_key, download_size, RestoreResult.RESTORED_FROM_S3.value)
            logger.info(f"Restore completed successfully from {s3_key} ({duration:.2f}s)")
            return RestoreResult.RESTORED_FROM_S3

        except S3ConnectionError as e:
            duration = (datetime.now() - start_time).total_seconds()
            s3_logger.restore_fallback("network_error", "using_local_database")
            s3_logger.restore_completed(duration, None, None, RestoreResult.NETWORK_ERROR.value)
            logger.error(f"Network error during restore: {e}")
            return RestoreResult.NETWORK_ERROR
        except Exception as e:
            # Catch-all boundary: startup must never crash because of restore.
            duration = (datetime.now() - start_time).total_seconds()
            s3_logger.restore_fallback(str(e), "using_local_database")
            s3_logger.restore_completed(duration, None, None, RestoreResult.NETWORK_ERROR.value)
            logger.error(f"Unexpected error during restore: {e}", exc_info=True)
            return RestoreResult.NETWORK_ERROR

    def _list_backups(self) -> list:
        """
        List all backup files in S3.

        Returns:
            List of S3 object dictionaries with Key, LastModified, Size

        Raises:
            S3ConnectionError: Network or S3 service error
        """
        try:
            logger.debug(f"Listing backups in bucket: {self.config.bucket}")
            response = self.s3_client.list_objects_v2(
                Bucket=self.config.bucket,
                Prefix='contacts-'
            )
            # 'Contents' is absent (not empty) when the prefix matches nothing.
            if 'Contents' not in response:
                return []
            return response['Contents']
        except ClientError as e:
            error_code = e.response['Error']['Code']
            logger.error(f"S3 error listing backups: {error_code}")
            raise S3ConnectionError(f"S3 error: {error_code}") from e
        except Exception as e:
            logger.error(f"Unexpected error listing backups: {e}")
            raise S3ConnectionError(f"Error listing backups: {e}") from e

    def _compare_timestamps(self, local_path: str, s3_timestamp: datetime) -> bool:
        """
        Compare local file timestamp with S3 backup timestamp.

        Args:
            local_path: Path to local database file
            s3_timestamp: LastModified timestamp from S3 (timezone-aware)

        Returns:
            True if S3 backup is newer, False if local is newer or equal
        """
        if not os.path.exists(local_path):
            return True  # No local file, S3 is "newer"
        # Convert the POSIX mtime (epoch seconds) directly into the S3
        # timestamp's timezone so both datetimes denote the same instant.
        # (Previously the naive local wall-clock time had tzinfo stamped onto
        # it via replace(), which shifted the instant on any non-UTC host and
        # could make a stale local file look newer than the S3 backup.)
        local_mtime = datetime.fromtimestamp(
            os.path.getmtime(local_path), tz=s3_timestamp.tzinfo
        )
        logger.debug(f"Timestamp comparison - Local: {local_mtime}, S3: {s3_timestamp}")
        return s3_timestamp > local_mtime

    def _download_backup(self, s3_key: str, dest_path: str) -> int:
        """
        Download backup file from S3.

        Args:
            s3_key: S3 object key
            dest_path: Local path where file should be saved

        Returns:
            Size of downloaded file in bytes

        Raises:
            S3ConnectionError: Network or S3 service error
        """
        try:
            logger.info(f"Downloading backup from S3: {s3_key}")
            self.s3_client.download_file(
                self.config.bucket,
                s3_key,
                dest_path
            )
            file_size = os.path.getsize(dest_path)
            logger.info(f"Download completed: {file_size} bytes")
            return file_size
        except ClientError as e:
            error_code = e.response['Error']['Code']
            logger.error(f"S3 error downloading backup: {error_code}")
            raise S3ConnectionError(f"S3 error: {error_code}") from e
        except Exception as e:
            logger.error(f"Unexpected error downloading backup: {e}")
            raise S3ConnectionError(f"Error downloading backup: {e}") from e

    def _validate_checksum(self, file_path: str, s3_key: str) -> bool:
        """
        Validate file checksum against S3 metadata.

        Args:
            file_path: Path to downloaded file
            s3_key: S3 object key

        Returns:
            True if checksum matches or no checksum in metadata, False if mismatch
        """
        try:
            # Get S3 object metadata (backup writer stores sha256 there)
            response = self.s3_client.head_object(
                Bucket=self.config.bucket,
                Key=s3_key
            )
            remote_checksum = response.get('Metadata', {}).get('sha256')
            if not remote_checksum:
                # Absent checksum is not an error: older backups may predate it.
                logger.debug("No checksum in S3 metadata, skipping validation")
                return True

            # Calculate local checksum in chunks to bound memory usage
            sha256_hash = hashlib.sha256()
            with open(file_path, 'rb') as f:
                for chunk in iter(lambda: f.read(8192), b''):
                    sha256_hash.update(chunk)
            local_checksum = sha256_hash.hexdigest()

            if local_checksum == remote_checksum:
                logger.debug("Checksum validation passed")
                return True
            else:
                logger.error(f"Checksum mismatch - Local: {local_checksum}, Remote: {remote_checksum}")
                return False
        except Exception as e:
            # Treat any failure to verify as a validation failure (fail closed).
            logger.warning(f"Checksum validation failed: {e}")
            return False

    def _validate_sqlite_integrity(self, file_path: str) -> bool:
        """
        Validate SQLite database integrity using PRAGMA integrity_check.

        Args:
            file_path: Path to database file

        Returns:
            True if database passes integrity check, False otherwise
        """
        try:
            conn = sqlite3.connect(file_path)
            try:
                result = conn.execute("PRAGMA integrity_check").fetchone()[0]
            finally:
                # Always release the handle, even if the check itself raises;
                # previously an exception here leaked the connection.
                conn.close()
            if result == 'ok':
                logger.debug("SQLite integrity check passed")
                return True
            else:
                logger.error(f"SQLite integrity check failed: {result}")
                return False
        except Exception as e:
            logger.error(f"SQLite integrity check failed: {e}")
            return False

    def validate_backup(self, file_path: str, s3_key: Optional[str] = None) -> bool:
        """
        Validate a backup file's integrity.

        Performs both checksum validation (if S3 metadata available) and
        SQLite integrity check.

        Args:
            file_path: Path to backup file to validate
            s3_key: Optional S3 key for checksum validation

        Returns:
            True if file is valid, False otherwise
        """
        logger.info(f"Validating backup: {file_path}")

        # Validate checksum if S3 key provided
        if s3_key and not self._validate_checksum(file_path, s3_key):
            return False

        # Validate SQLite integrity
        if not self._validate_sqlite_integrity(file_path):
            return False

        logger.info("Backup validation passed")
        return True

    def _atomic_replace(self, temp_path: str) -> None:
        """
        Atomically replace local database file with validated backup.

        Uses os.replace() which is atomic on POSIX systems.

        Args:
            temp_path: Path to validated backup file

        Raises:
            OSError: If atomic replace fails
        """
        try:
            logger.info(f"Replacing database: {self.db_path}")
            # Create the parent directory if it doesn't exist; guard against a
            # bare filename whose dirname is "" (makedirs("") raises).
            db_dir = os.path.dirname(self.db_path)
            if db_dir:
                os.makedirs(db_dir, exist_ok=True)
            # Atomic replace
            os.replace(temp_path, self.db_path)
            logger.info("Database replaced successfully")
        except OSError as e:
            logger.error(f"Failed to replace database: {e}")
            raise
def restore_on_startup() -> RestoreResult:
    """
    Convenience function to restore database on webapp startup.

    This function is called from entrypoint.sh before Flask starts.

    Returns:
        RestoreResult enum indicating the outcome
    """
    try:
        # Imported lazily so merely importing this module stays side-effect free.
        from .s3_config import S3Config

        target_db = os.getenv("DATABASE_PATH", "/app/data/contacts.db")
        manager = RestoreManager(S3Config.from_env(), target_db)
        outcome = manager.restore_from_s3()
        logger.info(f"Startup restore completed: {outcome.value}")
        return outcome
    except Exception as exc:
        # Startup must proceed even if restore blows up: log, record the
        # fallback, and report it as a network-class failure.
        logger.error(f"Startup restore failed: {exc}", exc_info=True)
        s3_logger.restore_fallback(str(exc), "using_local_or_empty_database")
        return RestoreResult.NETWORK_ERROR