File size: 4,263 Bytes
3674b4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""

Database module with SQLCipher encryption support.



When DB_ENCRYPTION_KEY is set in the environment (or .env file), the database

will be encrypted using SQLCipher. If no key is provided, a standard SQLite

database is used (for development only).

"""
import os
from sqlalchemy import create_engine, event, text
from sqlalchemy.orm import sessionmaker, DeclarativeBase
from backend.config import settings

# sqlcipher3 is an optional dependency: encryption can only be used when it
# imports cleanly, so record whether that succeeded.
try:
    import sqlcipher3
except ImportError:
    SQLCIPHER_AVAILABLE = False
else:
    SQLCIPHER_AVAILABLE = True


def _get_db_path() -> str:
    """Return the path of the database file used by this module.

    When a writable ``/data`` directory exists (taken as a sign of running on
    Hugging Face Spaces), the fixed path ``/data/crop_dashboard.db`` is used.
    Otherwise the path is derived from ``settings.database_url``: a leading
    ``sqlite:///`` and then a single leading ``./`` are stripped. Any URL that
    does not start with ``sqlite:///`` is returned unchanged.

    Side effect: for a real filesystem path, the parent directory is created
    if missing (best-effort).

    Returns:
        The database file path, or the raw URL for non-``sqlite:///`` URLs.
    """
    from pathlib import Path

    # A writable /data directory indicates HF Spaces persistent storage.
    if os.path.exists("/data") and os.access("/data", os.W_OK):
        db_path = "/data/crop_dashboard.db"
    else:
        url = settings.database_url
        if url.startswith("sqlite:///"):
            db_path = url.removeprefix("sqlite:///").removeprefix("./")
        else:
            db_path = url

    # Only create directories for actual file paths. In-memory databases and
    # non-SQLite URLs (e.g. "postgresql://host/db") have no meaningful parent
    # directory; treating them as paths would mkdir junk like "postgresql:/host".
    if db_path and db_path != ":memory:" and "://" not in db_path:
        try:
            Path(db_path).parent.mkdir(parents=True, exist_ok=True)
        except PermissionError:
            # Best-effort: exist_ok already covers an existing directory, so this
            # only swallows a failure to create a missing one; whatever opens the
            # database next will surface the problem.
            pass

    return db_path


def _create_encrypted_engine():
    """Create an SQLAlchemy engine backed by an SQLCipher-encrypted database.

    Connections are opened with ``sqlcipher3.connect`` on the path returned by
    ``_get_db_path()``. Each connection is keyed with
    ``settings.db_encryption_key`` and configured for SQLCipher 4 settings.

    Returns:
        An engine using a ``StaticPool`` around the custom connection creator.

    Raises:
        RuntimeError: if the optional ``sqlcipher3`` package is not importable.
    """
    if not SQLCIPHER_AVAILABLE:
        raise RuntimeError(
            "SQLCipher encryption requested but sqlcipher3 is not installed. "
            "Install with: pip install sqlcipher3"
        )

    from sqlalchemy.pool import StaticPool

    db_path = _get_db_path()
    # PRAGMA statements cannot take bound parameters, so the key is embedded as
    # an SQL string literal. Doubling single quotes keeps a key that contains
    # "'" from terminating the literal early and breaking the statement.
    quoted_key = str(settings.db_encryption_key).replace("'", "''")

    def create_connection():
        """Open a new SQLCipher connection and apply the key and cipher settings."""
        conn = sqlcipher3.connect(db_path, check_same_thread=False)
        # Key the connection first, then apply the cipher settings.
        conn.execute(f"PRAGMA key = '{quoted_key}'")
        # Use SQLCipher 4 defaults for strong encryption.
        conn.execute("PRAGMA cipher_compatibility = 4")
        conn.execute("PRAGMA kdf_iter = 256000")
        conn.execute("PRAGMA cipher_memory_security = ON")
        return conn

    # StaticPool is chosen explicitly here: the pool holds a single connection
    # created by `create_connection` and hands it to every checkout. The URL is
    # a placeholder, since the creator supplies the real connection.
    # NOTE(review): with check_same_thread=False that one connection can be
    # shared by concurrent sessions/threads; confirm this is acceptable.
    engine = create_engine(
        "sqlite://",
        creator=create_connection,
        poolclass=StaticPool,
    )

    return engine


def _create_standard_engine():
    """Create an unencrypted engine directly from ``settings.database_url``.

    ``check_same_thread=False`` is passed only for SQLite URLs. It is an option
    of Python's sqlite3 driver, and other DBAPI drivers may reject it as an
    unknown connect argument.

    Returns:
        A SQLAlchemy engine for the configured URL.
    """
    url = settings.database_url
    # For SQLite, allow the connection to be used from threads other than the
    # one that created it.
    connect_args = {"check_same_thread": False} if url.startswith("sqlite") else {}
    engine = create_engine(url, connect_args=connect_args)
    return engine


def create_db_engine():
    """Build the engine appropriate for the current configuration.

    A truthy ``settings.db_encryption_key`` selects the SQLCipher-encrypted
    engine. Otherwise a plain SQLite engine is returned and a warning is
    printed, because the data will be stored unencrypted.
    """
    encryption_requested = bool(settings.db_encryption_key)

    if not encryption_requested:
        print("WARNING: Database is NOT encrypted (no DB_ENCRYPTION_KEY set)")
        return _create_standard_engine()

    print("Database encryption enabled (SQLCipher)")
    return _create_encrypted_engine()


# Module-level engine, built once at import time from the current settings.
engine = create_db_engine()

# Factory for ORM sessions bound to the engine above (explicit transactions,
# no automatic flush before queries).
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)


class Base(DeclarativeBase):
    """Declarative base for the ORM models; ``init_db`` creates tables from its ``metadata``."""


def get_db():
    """Yield a fresh ``SessionLocal`` session and close it afterwards.

    The session is closed when the generator finishes. This also happens when
    it is closed early or has an exception thrown into it at the ``yield``.
    """
    with SessionLocal() as session:
        yield session


def init_db():
    """Create the tables registered on ``Base.metadata`` using the module engine.

    The model modules are imported only for their import-time side effects:
    none of the imported names are used here. Presumably this ensures their ORM
    classes are registered on ``Base.metadata`` before ``create_all`` runs.
    """
    from backend.models import (  # noqa: F401  (imported for side effects)
        user,
        group,
        site,
        sensor_data,
        pipeline,
        box_connection,
    )

    Base.metadata.create_all(bind=engine)


def is_database_encrypted(db_path: str | None = None) -> bool:
    """Check if the database file is encrypted."""
    path = db_path or _get_db_path()
    if not os.path.exists(path):
        return False

    with open(path, 'rb') as f:
        header = f.read(16)

    # SQLite files start with "SQLite format 3\0"
    return header != b'SQLite format 3\x00'