Spaces:
Runtime error
Runtime error
| """ | |
| Database module with SQLCipher encryption support. | |
| When DB_ENCRYPTION_KEY is set in the environment (or .env file), the database | |
| will be encrypted using SQLCipher. If no key is provided, a standard SQLite | |
| database is used (for development only). | |
| """ | |
| import os | |
| from sqlalchemy import create_engine, event, text | |
| from sqlalchemy.orm import sessionmaker, DeclarativeBase | |
| from backend.config import settings | |
# Probe for the optional SQLCipher binding. _create_encrypted_engine() refuses
# to build an engine when this flag is False.
try:
    import sqlcipher3
except ImportError:
    SQLCIPHER_AVAILABLE = False
else:
    SQLCIPHER_AVAILABLE = True
def _get_db_path() -> str:
    """Resolve the filesystem location of the SQLite database.

    On Hugging Face Spaces (detected by a writable ``/data`` directory) the
    database is always ``/data/crop_dashboard.db`` and
    ``settings.database_url`` is ignored. Otherwise the path is taken from a
    ``sqlite:///`` URL in ``settings.database_url``, with a leading ``./``
    stripped. For file paths, the parent directory is created if it is
    missing (best effort).

    Returns:
        The database file path, or ``settings.database_url`` unchanged when
        it is not a ``sqlite:///`` URL (e.g. in-memory ``sqlite://`` or a
        non-SQLite backend).
    """
    from pathlib import Path

    sqlite_prefix = "sqlite:///"

    # HF Spaces exposes its persistent storage as a writable /data directory.
    is_hf_spaces = os.path.isdir("/data") and os.access("/data", os.W_OK)
    if is_hf_spaces:
        db_path = "/data/crop_dashboard.db"
    else:
        url = settings.database_url
        if not url.startswith(sqlite_prefix):
            # Not a file-backed SQLite URL: there is no directory to create,
            # and treating the URL as a path would mkdir bogus directories
            # such as "postgresql:/user@host".
            return url
        db_path = url[len(sqlite_prefix):]
        if db_path.startswith("./"):
            db_path = db_path[2:]

    try:
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)
    except PermissionError:
        # Best effort: if the directory cannot be created, SQLite reports the
        # failure when it tries to open the database file.
        pass
    return db_path
def _create_encrypted_engine():
    """Create a SQLAlchemy engine whose connection is opened through SQLCipher.

    The database file comes from ``_get_db_path()`` and the key from
    ``settings.db_encryption_key``. The key is read once, when the engine is
    built.

    Returns:
        A SQLAlchemy ``Engine`` using ``StaticPool``, so every checkout shares
        a single SQLCipher connection.

    Raises:
        RuntimeError: if the ``sqlcipher3`` module could not be imported.
    """
    if not SQLCIPHER_AVAILABLE:
        raise RuntimeError(
            "SQLCipher encryption requested but sqlcipher3 is not installed. "
            "Install with: pip install sqlcipher3"
        )

    from sqlalchemy.pool import StaticPool

    db_path = _get_db_path()
    # PRAGMA statements cannot take bound parameters, so the key has to be
    # spliced in as a SQL string literal. Double any embedded single quotes so
    # a key containing "'" cannot terminate the literal early.
    quoted_key = "'" + settings.db_encryption_key.replace("'", "''") + "'"

    def create_connection():
        """Open a SQLCipher connection and apply the key and cipher settings."""
        conn = sqlcipher3.connect(db_path, check_same_thread=False)
        try:
            # The key is set first, before any other statement on the connection.
            conn.execute(f"PRAGMA key = {quoted_key}")
            # Use SQLCipher 4 defaults for strong encryption.
            conn.execute("PRAGMA cipher_compatibility = 4")
            conn.execute("PRAGMA kdf_iter = 256000")
            conn.execute("PRAGMA cipher_memory_security = ON")
        except Exception:
            # Do not leak the half-configured connection.
            conn.close()
            raise
        return conn

    # "sqlite://" only selects the SQLite dialect; the real connection comes
    # from `creator`. StaticPool is chosen explicitly (it is not a SQLAlchemy
    # default). It keeps exactly one connection, so the keying and
    # kdf_iter-based key setup run once rather than per checkout.
    # NOTE(review): together with check_same_thread=False, this single
    # connection is shared by all threads and sessions. Confirm that
    # concurrent requests cannot interleave transactions on it.
    return create_engine(
        "sqlite://",
        creator=create_connection,
        poolclass=StaticPool,
    )
def _create_standard_engine():
    """Create an unencrypted SQLAlchemy engine from ``settings.database_url``.

    File-backed SQLite URLs (``sqlite:///...``) are re-resolved through
    ``_get_db_path()``. This stores the unencrypted database in the same place
    the encrypted engine and ``is_database_encrypted()`` use, including
    ``/data`` on HF Spaces, and ensures its parent directory exists before
    SQLite opens the file. Any other URL is passed to SQLAlchemy unchanged.
    """
    url = settings.database_url
    if url.startswith("sqlite:///"):
        url = f"sqlite:///{_get_db_path()}"

    # check_same_thread is an option of Python's sqlite3 driver only; do not
    # hand it to other DBAPI drivers.
    connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {}
    return create_engine(url, connect_args=connect_args)
def create_db_engine():
    """Build the application's database engine.

    When ``settings.db_encryption_key`` is truthy, the SQLCipher-backed engine
    from ``_create_encrypted_engine()`` is returned. Otherwise the unencrypted
    engine from ``_create_standard_engine()`` is returned. A one-line status
    message is printed in both cases.
    """
    if not settings.db_encryption_key:
        print("WARNING: Database is NOT encrypted (no DB_ENCRYPTION_KEY set)")
        return _create_standard_engine()

    print("Database encryption enabled (SQLCipher)")
    return _create_encrypted_engine()
# Module-level engine and session factory, built once when this module is imported.
engine = create_db_engine()
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
class Base(DeclarativeBase):
    """Declarative base for the ORM models; their tables collect on ``Base.metadata``."""
def get_db():
    """Yield a new session from ``SessionLocal`` and close it afterwards.

    The session is closed when the generator is resumed, closed, or has an
    exception thrown into it.
    NOTE(review): presumably used as a per-request dependency; confirm with callers.
    """
    with SessionLocal() as session:
        yield session
def init_db():
    """Create the tables for all known models on the module-level ``engine``.

    The model modules are imported only for their side effects. Presumably,
    importing them defines the mapped classes on ``Base``, so that
    ``Base.metadata`` holds every table before ``create_all`` runs. The
    imported names themselves are unused.
    """
    from backend.models import (  # noqa: F401 -- imported for side effects
        user,
        group,
        site,
        sensor_data,
        pipeline,
        box_connection,
    )

    metadata = Base.metadata
    metadata.create_all(bind=engine)
| def is_database_encrypted(db_path: str | None = None) -> bool: | |
| """Check if the database file is encrypted.""" | |
| path = db_path or _get_db_path() | |
| if not os.path.exists(path): | |
| return False | |
| with open(path, 'rb') as f: | |
| header = f.read(16) | |
| # SQLite files start with "SQLite format 3\0" | |
| return header != b'SQLite format 3\x00' | |