# MemPrepMate — src/services/s3_restore.py
# Author: Christian Kniep (update to v2, commit 5d3ee93)
"""
S3 Restore Manager
Handles automatic restore of SQLite database from S3-compatible storage
during webapp startup. Provides validation and fallback logic to ensure
reliable database restoration.
"""
import os
import sqlite3
import hashlib
import logging
from datetime import datetime
from typing import Optional
from enum import Enum
import boto3
from botocore.exceptions import ClientError
from ..utils import s3_logger
from .s3_config import (
S3Config,
S3ConnectionError,
RestoreError,
DatabaseCorruptedError
)
logger = logging.getLogger(__name__)
class RestoreResult(Enum):
    """Possible outcomes of a restore attempt (returned by RestoreManager.restore_from_s3)."""
    RESTORED_FROM_S3 = "restored_from_s3"    # S3 backup downloaded, validated, and installed locally
    LOCAL_NEWER = "local_newer"              # local DB mtime >= latest S3 backup; local file kept
    NO_BACKUP_FOUND = "no_backup_found"      # S3 disabled, or no 'contacts-' objects in the bucket
    VALIDATION_FAILED = "validation_failed"  # checksum or SQLite integrity check failed; local file kept
    NETWORK_ERROR = "network_error"          # S3/network failure (also used for unexpected errors)
class RestoreManager:
    """
    Manages automatic restore of the SQLite database from S3.

    Features:
    - Startup restore with validation
    - Timestamp comparison (local vs S3)
    - Checksum verification (against S3 object metadata)
    - SQLite integrity check (PRAGMA integrity_check)
    - Atomic file replacement (os.replace)
    - Fallback chain for reliability: on any failure the local database
      (or an empty one) is used instead of aborting startup
    """
def __init__(self, config: S3Config, db_path: str):
"""
Initialize the restore manager.
Args:
config: S3 configuration object
db_path: Absolute path where database should be restored
"""
self.config = config
self.db_path = db_path
if config.enabled:
self.s3_client = config.create_s3_client()
logger.info(f"RestoreManager initialized for {db_path}")
else:
self.s3_client = None
logger.info("RestoreManager initialized but S3 is disabled")
def restore_from_s3(self) -> RestoreResult:
"""
Restore database from S3 with validation and fallback logic.
Process:
1. List backups in S3
2. Find latest by LastModified
3. Compare timestamps (local vs S3)
4. Download if S3 is newer
5. Validate checksum and integrity
6. Atomic replace local file
Returns:
RestoreResult enum indicating the outcome
Side Effects:
- May replace local database file atomically
- Creates temp files during download (cleaned up automatically)
- Logs all operations with structured logging
"""
start_time = datetime.now()
s3_logger.restore_started()
if not self.config.enabled:
duration = (datetime.now() - start_time).total_seconds()
s3_logger.restore_completed(duration, None, None, RestoreResult.NO_BACKUP_FOUND.value)
logger.info("Restore skipped - S3 disabled")
return RestoreResult.NO_BACKUP_FOUND
try:
# List backups from S3
backups = self._list_backups()
if not backups or len(backups) == 0:
duration = (datetime.now() - start_time).total_seconds()
s3_logger.restore_completed(duration, None, None, RestoreResult.NO_BACKUP_FOUND.value)
logger.info("No backups found in S3")
return RestoreResult.NO_BACKUP_FOUND
# Find latest backup
latest_backup = max(backups, key=lambda x: x['LastModified'])
s3_key = latest_backup['Key']
s3_timestamp = latest_backup['LastModified']
logger.info(f"Latest S3 backup: {s3_key} ({s3_timestamp})")
# Compare with local file timestamp
if os.path.exists(self.db_path):
local_mtime = datetime.fromtimestamp(os.path.getmtime(self.db_path))
local_mtime = local_mtime.replace(tzinfo=s3_timestamp.tzinfo) # Make timezone-aware
if local_mtime >= s3_timestamp:
duration = (datetime.now() - start_time).total_seconds()
s3_logger.restore_completed(duration, s3_key, None, RestoreResult.LOCAL_NEWER.value)
logger.info(f"Local database is newer ({local_mtime} >= {s3_timestamp}), skipping restore")
return RestoreResult.LOCAL_NEWER
# Download from S3
temp_path = f"{self.db_path}.restore"
download_size = self._download_backup(s3_key, temp_path)
# Validate backup
if not self.validate_backup(temp_path, s3_key):
# Validation failed - use local fallback
if os.path.exists(temp_path):
os.remove(temp_path)
duration = (datetime.now() - start_time).total_seconds()
s3_logger.restore_fallback("validation_failed", "using_local_database")
s3_logger.restore_completed(duration, s3_key, None, RestoreResult.VALIDATION_FAILED.value)
logger.warning("Backup validation failed, using local database")
return RestoreResult.VALIDATION_FAILED
# Atomic replace
self._atomic_replace(temp_path)
duration = (datetime.now() - start_time).total_seconds()
s3_logger.restore_completed(duration, s3_key, download_size, RestoreResult.RESTORED_FROM_S3.value)
logger.info(f"Restore completed successfully from {s3_key} ({duration:.2f}s)")
return RestoreResult.RESTORED_FROM_S3
except S3ConnectionError as e:
duration = (datetime.now() - start_time).total_seconds()
s3_logger.restore_fallback("network_error", "using_local_database")
s3_logger.restore_completed(duration, None, None, RestoreResult.NETWORK_ERROR.value)
logger.error(f"Network error during restore: {e}")
return RestoreResult.NETWORK_ERROR
except Exception as e:
duration = (datetime.now() - start_time).total_seconds()
s3_logger.restore_fallback(str(e), "using_local_database")
s3_logger.restore_completed(duration, None, None, RestoreResult.NETWORK_ERROR.value)
logger.error(f"Unexpected error during restore: {e}", exc_info=True)
return RestoreResult.NETWORK_ERROR
def _list_backups(self) -> list:
"""
List all backup files in S3.
Returns:
List of S3 object dictionaries with Key, LastModified, Size
Raises:
S3ConnectionError: Network or S3 service error
"""
try:
logger.debug(f"Listing backups in bucket: {self.config.bucket}")
response = self.s3_client.list_objects_v2(
Bucket=self.config.bucket,
Prefix='contacts-'
)
if 'Contents' not in response:
return []
return response['Contents']
except ClientError as e:
error_code = e.response['Error']['Code']
logger.error(f"S3 error listing backups: {error_code}")
raise S3ConnectionError(f"S3 error: {error_code}") from e
except Exception as e:
logger.error(f"Unexpected error listing backups: {e}")
raise S3ConnectionError(f"Error listing backups: {e}") from e
def _compare_timestamps(self, local_path: str, s3_timestamp: datetime) -> bool:
"""
Compare local file timestamp with S3 backup timestamp.
Args:
local_path: Path to local database file
s3_timestamp: LastModified timestamp from S3
Returns:
True if S3 backup is newer, False if local is newer or equal
"""
if not os.path.exists(local_path):
return True # No local file, S3 is "newer"
local_mtime = datetime.fromtimestamp(os.path.getmtime(local_path))
local_mtime = local_mtime.replace(tzinfo=s3_timestamp.tzinfo)
logger.debug(f"Timestamp comparison - Local: {local_mtime}, S3: {s3_timestamp}")
return s3_timestamp > local_mtime
def _download_backup(self, s3_key: str, dest_path: str) -> int:
"""
Download backup file from S3.
Args:
s3_key: S3 object key
dest_path: Local path where file should be saved
Returns:
Size of downloaded file in bytes
Raises:
S3ConnectionError: Network or S3 service error
"""
try:
logger.info(f"Downloading backup from S3: {s3_key}")
self.s3_client.download_file(
self.config.bucket,
s3_key,
dest_path
)
file_size = os.path.getsize(dest_path)
logger.info(f"Download completed: {file_size} bytes")
return file_size
except ClientError as e:
error_code = e.response['Error']['Code']
logger.error(f"S3 error downloading backup: {error_code}")
raise S3ConnectionError(f"S3 error: {error_code}") from e
except Exception as e:
logger.error(f"Unexpected error downloading backup: {e}")
raise S3ConnectionError(f"Error downloading backup: {e}") from e
def _validate_checksum(self, file_path: str, s3_key: str) -> bool:
"""
Validate file checksum against S3 metadata.
Args:
file_path: Path to downloaded file
s3_key: S3 object key
Returns:
True if checksum matches or no checksum in metadata, False if mismatch
"""
try:
# Get S3 object metadata
response = self.s3_client.head_object(
Bucket=self.config.bucket,
Key=s3_key
)
remote_checksum = response.get('Metadata', {}).get('sha256')
if not remote_checksum:
logger.debug("No checksum in S3 metadata, skipping validation")
return True
# Calculate local checksum
sha256_hash = hashlib.sha256()
with open(file_path, 'rb') as f:
for chunk in iter(lambda: f.read(8192), b''):
sha256_hash.update(chunk)
local_checksum = sha256_hash.hexdigest()
if local_checksum == remote_checksum:
logger.debug("Checksum validation passed")
return True
else:
logger.error(f"Checksum mismatch - Local: {local_checksum}, Remote: {remote_checksum}")
return False
except Exception as e:
logger.warning(f"Checksum validation failed: {e}")
return False
def _validate_sqlite_integrity(self, file_path: str) -> bool:
"""
Validate SQLite database integrity using PRAGMA integrity_check.
Args:
file_path: Path to database file
Returns:
True if database passes integrity check, False otherwise
"""
try:
conn = sqlite3.connect(file_path)
cursor = conn.cursor()
cursor.execute("PRAGMA integrity_check")
result = cursor.fetchone()[0]
conn.close()
if result == 'ok':
logger.debug("SQLite integrity check passed")
return True
else:
logger.error(f"SQLite integrity check failed: {result}")
return False
except Exception as e:
logger.error(f"SQLite integrity check failed: {e}")
return False
def validate_backup(self, file_path: str, s3_key: Optional[str] = None) -> bool:
"""
Validate a backup file's integrity.
Performs both checksum validation (if S3 metadata available) and
SQLite integrity check.
Args:
file_path: Path to backup file to validate
s3_key: Optional S3 key for checksum validation
Returns:
True if file is valid, False otherwise
"""
logger.info(f"Validating backup: {file_path}")
# Validate checksum if S3 key provided
if s3_key and not self._validate_checksum(file_path, s3_key):
return False
# Validate SQLite integrity
if not self._validate_sqlite_integrity(file_path):
return False
logger.info("Backup validation passed")
return True
def _atomic_replace(self, temp_path: str) -> None:
"""
Atomically replace local database file with validated backup.
Uses os.replace() which is atomic on POSIX systems.
Args:
temp_path: Path to validated backup file
Raises:
OSError: If atomic replace fails
"""
try:
logger.info(f"Replacing database: {self.db_path}")
# Create backup directory if it doesn't exist
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
# Atomic replace
os.replace(temp_path, self.db_path)
logger.info("Database replaced successfully")
except OSError as e:
logger.error(f"Failed to replace database: {e}")
raise
def restore_on_startup() -> RestoreResult:
    """
    Convenience function to restore database on webapp startup.

    This function is called from entrypoint.sh before Flask starts.

    Returns:
        RestoreResult enum indicating the outcome (NETWORK_ERROR doubles as
        the generic failure result, matching RestoreManager.restore_from_s3)
    """
    try:
        # Note: S3Config is already imported at module level; the previous
        # function-local re-import was redundant and has been removed.
        config = S3Config.from_env()
        db_path = os.getenv("DATABASE_PATH", "/app/data/contacts.db")
        restore_manager = RestoreManager(config, db_path)
        result = restore_manager.restore_from_s3()
        logger.info(f"Startup restore completed: {result.value}")
        return result
    except Exception as e:
        # Never abort startup: fall back to the local (or empty) database.
        logger.error(f"Startup restore failed: {e}", exc_info=True)
        s3_logger.restore_fallback(str(e), "using_local_or_empty_database")
        return RestoreResult.NETWORK_ERROR