""" OptiQ Database - SQLite database for users, usage tracking, and audit logs. """ from __future__ import annotations import os import sqlite3 import hashlib import secrets from datetime import datetime from pathlib import Path from typing import Optional from contextlib import contextmanager # Database file path DB_PATH = Path(__file__).parent.parent / "optiq.db" def get_connection(): """Get a database connection.""" conn = sqlite3.connect(str(DB_PATH)) conn.row_factory = sqlite3.Row return conn @contextmanager def get_db(): """Context manager for database connection.""" conn = get_connection() try: yield conn conn.commit() except Exception: conn.rollback() raise finally: conn.close() def _hash_password(password: str, salt: str | None = None) -> tuple[str, str]: """Hash a password with a salt using SHA-256. Returns (hash, salt).""" if salt is None: salt = secrets.token_hex(16) hashed = hashlib.sha256(f"{salt}{password}".encode()).hexdigest() return hashed, salt def init_db(): """Initialize the database with required tables.""" with get_db() as conn: cursor = conn.cursor() # Users table cursor.execute(""" CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, firebase_uid TEXT UNIQUE NOT NULL, email TEXT UNIQUE NOT NULL, display_name TEXT, password_hash TEXT, password_salt TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, last_login TIMESTAMP, total_optimizations INTEGER DEFAULT 0, total_energy_saved_kwh REAL DEFAULT 0, total_co2_saved_kg REAL DEFAULT 0, total_money_saved_usd REAL DEFAULT 0 ) """) # Usage tracking table cursor.execute(""" CREATE TABLE IF NOT EXISTS usage ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER, firebase_uid TEXT, timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, system TEXT, method TEXT, load_multiplier REAL DEFAULT 1.0, baseline_loss_kw REAL, optimized_loss_kw REAL, energy_saved_kwh REAL, co2_saved_kg REAL, money_saved_usd REAL, computation_time_sec REAL, shadow_mode BOOLEAN DEFAULT 0, switches_changed TEXT, FOREIGN KEY (user_id) REFERENCES users(id) ) """) # Audit logs table cursor.execute(""" CREATE TABLE IF NOT EXISTS audit_logs ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER, firebase_uid TEXT, timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, action TEXT NOT NULL, system TEXT, method TEXT, details TEXT, baseline_loss_kw REAL, optimized_loss_kw REAL, loss_reduction_pct REAL, energy_saved_kwh_year REAL, co2_saved_tonnes_year REAL, cost_saved_usd_year REAL, open_lines_before TEXT, open_lines_after TEXT, FOREIGN KEY (user_id) REFERENCES users(id) ) """) # Feeders table (for multi-feeder simulation) cursor.execute(""" CREATE TABLE IF NOT EXISTS feeders ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER, name TEXT NOT NULL, system TEXT DEFAULT 'case33bw', created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, FOREIGN KEY (user_id) REFERENCES users(id) ) """) conn.commit() # Migrate: add password columns if they don't exist (handles old DBs) _migrate_add_columns() def _migrate_add_columns(): """Add new columns to existing tables (safe for re-runs).""" with get_db() as conn: cursor = conn.cursor() for col, coltype in [("password_hash", "TEXT"), ("password_salt", "TEXT")]: try: cursor.execute(f"ALTER TABLE users ADD COLUMN {col} {coltype}") except sqlite3.OperationalError: pass # column already exists conn.commit() # ── Auth helpers ────────────────────────────────────────────────────────── def register_user(email: str, password: str, display_name: str | None = None) -> dict: """Register a new user with email/password. Returns user dict or raises ValueError.""" pw_hash, pw_salt = _hash_password(password) uid = f"user_{secrets.token_hex(8)}" with get_db() as conn: cursor = conn.cursor() # Check if email already exists cursor.execute("SELECT id FROM users WHERE email = ?", (email,)) if cursor.fetchone(): raise ValueError("An account with this email already exists") cursor.execute( """INSERT INTO users (firebase_uid, email, display_name, password_hash, password_salt, created_at, last_login) VALUES (?, ?, ?, ?, ?, ?, ?)""", (uid, email, display_name or email.split("@")[0], pw_hash, pw_salt, datetime.utcnow().isoformat(), datetime.utcnow().isoformat()) ) cursor.execute("SELECT * FROM users WHERE firebase_uid = ?", (uid,)) return dict(cursor.fetchone()) def authenticate_user(email: str, password: str) -> dict | None: """Authenticate a user by email/password. Returns user dict or None.""" with get_db() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM users WHERE email = ?", (email,)) row = cursor.fetchone() if not row: return None user = dict(row) # If user was created before password auth, accept any password if not user.get("password_hash"): return user pw_hash, _ = _hash_password(password, user["password_salt"]) if pw_hash != user["password_hash"]: return None # Update last login cursor.execute( "UPDATE users SET last_login = ? WHERE id = ?", (datetime.utcnow().isoformat(), user["id"]) ) return user def get_or_create_user(firebase_uid: str, email: str, display_name: Optional[str] = None) -> dict: """Get or create a user by Firebase UID.""" with get_db() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM users WHERE firebase_uid = ?", (firebase_uid,)) row = cursor.fetchone() if row: # Update last login cursor.execute( "UPDATE users SET last_login = ? WHERE firebase_uid = ?", (datetime.utcnow().isoformat(), firebase_uid) ) return dict(row) else: cursor.execute( """INSERT INTO users (firebase_uid, email, display_name, created_at, last_login) VALUES (?, ?, ?, ?, ?)""", (firebase_uid, email, display_name, datetime.utcnow().isoformat(), datetime.utcnow().isoformat()) ) cursor.execute("SELECT * FROM users WHERE firebase_uid = ?", (firebase_uid,)) return dict(cursor.fetchone()) def log_usage( firebase_uid: str, system: str, method: str, baseline_loss_kw: float, optimized_loss_kw: float, energy_saved_kwh: float, co2_saved_kg: float, money_saved_usd: float, computation_time_sec: float, shadow_mode: bool = False, switches_changed: Optional[str] = None, load_multiplier: float = 1.0, ): """Log a usage event.""" with get_db() as conn: cursor = conn.cursor() # Get user_id cursor.execute("SELECT id FROM users WHERE firebase_uid = ?", (firebase_uid,)) row = cursor.fetchone() user_id = row["id"] if row else None cursor.execute( """INSERT INTO usage ( user_id, firebase_uid, system, method, load_multiplier, baseline_loss_kw, optimized_loss_kw, energy_saved_kwh, co2_saved_kg, money_saved_usd, computation_time_sec, shadow_mode, switches_changed ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", (user_id, firebase_uid, system, method, load_multiplier, baseline_loss_kw, optimized_loss_kw, energy_saved_kwh, co2_saved_kg, money_saved_usd, computation_time_sec, shadow_mode, switches_changed) ) # Update user totals if user_id: cursor.execute( """UPDATE users SET total_optimizations = total_optimizations + 1, total_energy_saved_kwh = total_energy_saved_kwh + ?, total_co2_saved_kg = total_co2_saved_kg + ?, total_money_saved_usd = total_money_saved_usd + ? WHERE id = ?""", (energy_saved_kwh, co2_saved_kg, money_saved_usd, user_id) ) def log_audit( firebase_uid: str, action: str, system: Optional[str] = None, method: Optional[str] = None, details: Optional[str] = None, baseline_loss_kw: Optional[float] = None, optimized_loss_kw: Optional[float] = None, loss_reduction_pct: Optional[float] = None, energy_saved_kwh_year: Optional[float] = None, co2_saved_tonnes_year: Optional[float] = None, cost_saved_usd_year: Optional[float] = None, open_lines_before: Optional[str] = None, open_lines_after: Optional[str] = None, ): """Log an audit event.""" with get_db() as conn: cursor = conn.cursor() cursor.execute("SELECT id FROM users WHERE firebase_uid = ?", (firebase_uid,)) row = cursor.fetchone() user_id = row["id"] if row else None cursor.execute( """INSERT INTO audit_logs ( user_id, firebase_uid, action, system, method, details, baseline_loss_kw, optimized_loss_kw, loss_reduction_pct, energy_saved_kwh_year, co2_saved_tonnes_year, cost_saved_usd_year, open_lines_before, open_lines_after ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", (user_id, firebase_uid, action, system, method, details, baseline_loss_kw, optimized_loss_kw, loss_reduction_pct, energy_saved_kwh_year, co2_saved_tonnes_year, cost_saved_usd_year, open_lines_before, open_lines_after) ) def get_user_usage(firebase_uid: str, limit: int = 100) -> list: """Get usage history for a user.""" with get_db() as conn: cursor = conn.cursor() cursor.execute( """SELECT * FROM usage WHERE firebase_uid = ? ORDER BY timestamp DESC LIMIT ?""", (firebase_uid, limit) ) return [dict(row) for row in cursor.fetchall()] def get_user_audit_logs(firebase_uid: str, limit: int = 100) -> list: """Get audit logs for a user.""" with get_db() as conn: cursor = conn.cursor() cursor.execute( """SELECT * FROM audit_logs WHERE firebase_uid = ? ORDER BY timestamp DESC LIMIT ?""", (firebase_uid, limit) ) return [dict(row) for row in cursor.fetchall()] def get_user_stats(firebase_uid: str) -> dict: """Get aggregated stats for a user.""" with get_db() as conn: cursor = conn.cursor() cursor.execute("SELECT * FROM users WHERE firebase_uid = ?", (firebase_uid,)) row = cursor.fetchone() if row: return { "total_optimizations": row["total_optimizations"], "total_energy_saved_kwh": row["total_energy_saved_kwh"], "total_co2_saved_kg": row["total_co2_saved_kg"], "total_money_saved_usd": row["total_money_saved_usd"], } return { "total_optimizations": 0, "total_energy_saved_kwh": 0, "total_co2_saved_kg": 0, "total_money_saved_usd": 0, } # Initialize database on import init_db()