Spaces:
Running
Running
| """ | |
| OptiQ Database - SQLite database for users, usage tracking, and audit logs. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import sqlite3 | |
| import hashlib | |
| import secrets | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Optional | |
| from contextlib import contextmanager | |
# Database file path: "optiq.db" in the directory one level above the
# directory that contains this module.
DB_PATH = Path(__file__).parent.parent / "optiq.db"
def get_connection():
    """Open a new SQLite connection to the file at ``DB_PATH``.

    The connection's row factory is ``sqlite3.Row``, so result columns can
    be read by name and a row can be converted with ``dict(row)``.
    The caller is responsible for closing the returned connection.
    """
    connection = sqlite3.connect(os.fspath(DB_PATH))
    connection.row_factory = sqlite3.Row
    return connection
@contextmanager
def get_db():
    """Context manager yielding a connection wrapped in commit-or-rollback.

    Intended use is ``with get_db() as conn: ...``. The ``@contextmanager``
    decorator is required for that: without it this function is a plain
    generator and the ``with`` statement fails because a generator object
    has no ``__enter__``/``__exit__``.

    Yields:
        A connection obtained from ``get_connection()``.

    Raises:
        Whatever the ``with`` body (or the commit) raises; the transaction
        is rolled back first. The connection is closed in every case.
    """
    conn = get_connection()
    try:
        yield conn
        # Only reached when the with-body finished without raising.
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
def _hash_password(password: str, salt: str | None = None) -> tuple[str, str]:
    """Return ``(hex_digest, salt)`` for *password*.

    The digest is SHA-256 over the salt immediately followed by the
    password. When *salt* is ``None`` a fresh 16-byte salt (32 hex chars)
    is generated; pass the stored salt to recompute an existing hash.

    NOTE(review): a single SHA-256 round is cheap to brute-force. Moving to
    a dedicated KDF would need a migration path for already-stored hashes.
    """
    effective_salt = secrets.token_hex(16) if salt is None else salt
    salted = f"{effective_salt}{password}"
    digest = hashlib.sha256(salted.encode())
    return digest.hexdigest(), effective_salt
def init_db():
    """Create the users, usage, audit_logs and feeders tables if missing.

    Every statement is ``CREATE TABLE IF NOT EXISTS``, so calling this on an
    already-initialised database leaves existing tables untouched. After the
    tables exist, ``_migrate_add_columns()`` is run to add the password
    columns to databases created before those columns were introduced.
    """
    schema_statements = (
        # Users table
        """
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            firebase_uid TEXT UNIQUE NOT NULL,
            email TEXT UNIQUE NOT NULL,
            display_name TEXT,
            password_hash TEXT,
            password_salt TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            last_login TIMESTAMP,
            total_optimizations INTEGER DEFAULT 0,
            total_energy_saved_kwh REAL DEFAULT 0,
            total_co2_saved_kg REAL DEFAULT 0,
            total_money_saved_usd REAL DEFAULT 0
        )
        """,
        # Usage tracking table
        """
        CREATE TABLE IF NOT EXISTS usage (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id INTEGER,
            firebase_uid TEXT,
            timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            system TEXT,
            method TEXT,
            load_multiplier REAL DEFAULT 1.0,
            baseline_loss_kw REAL,
            optimized_loss_kw REAL,
            energy_saved_kwh REAL,
            co2_saved_kg REAL,
            money_saved_usd REAL,
            computation_time_sec REAL,
            shadow_mode BOOLEAN DEFAULT 0,
            switches_changed TEXT,
            FOREIGN KEY (user_id) REFERENCES users(id)
        )
        """,
        # Audit logs table
        """
        CREATE TABLE IF NOT EXISTS audit_logs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id INTEGER,
            firebase_uid TEXT,
            timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            action TEXT NOT NULL,
            system TEXT,
            method TEXT,
            details TEXT,
            baseline_loss_kw REAL,
            optimized_loss_kw REAL,
            loss_reduction_pct REAL,
            energy_saved_kwh_year REAL,
            co2_saved_tonnes_year REAL,
            cost_saved_usd_year REAL,
            open_lines_before TEXT,
            open_lines_after TEXT,
            FOREIGN KEY (user_id) REFERENCES users(id)
        )
        """,
        # Feeders table (for multi-feeder simulation)
        """
        CREATE TABLE IF NOT EXISTS feeders (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id INTEGER,
            name TEXT NOT NULL,
            system TEXT DEFAULT 'case33bw',
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (user_id) REFERENCES users(id)
        )
        """,
    )

    with get_db() as conn:
        cursor = conn.cursor()
        for ddl in schema_statements:
            cursor.execute(ddl)
        # Commit before migrating: the migration opens its own connection.
        conn.commit()
        # Migrate: add password columns if they don't exist (handles old DBs)
        _migrate_add_columns()
def _migrate_add_columns():
    """Add the password columns to ``users`` on databases that predate them.

    Safe to run repeatedly: the current column list is read with
    ``PRAGMA table_info`` and only missing columns are added. If the
    ``users`` table does not exist there is nothing to migrate and the
    function returns without error.

    Raises:
        sqlite3.OperationalError: for any failure other than the column
            already existing (for example a locked database). Previously
            every OperationalError was silently ignored, which could hide a
            migration that never ran.
    """
    wanted = (("password_hash", "TEXT"), ("password_salt", "TEXT"))
    with get_db() as conn:
        cursor = conn.cursor()
        cursor.execute("PRAGMA table_info(users)")
        # Index 1 of each table_info row is the column name.
        existing = {info[1] for info in cursor.fetchall()}
        if not existing:
            return  # no users table yet
        for col, coltype in wanted:
            if col in existing:
                continue
            try:
                # Names come from the constant tuple above, never from input.
                cursor.execute(f"ALTER TABLE users ADD COLUMN {col} {coltype}")
            except sqlite3.OperationalError as exc:
                # Another process may have added the column between the
                # PRAGMA and the ALTER; that is fine. Anything else is real.
                if "duplicate column name" not in str(exc).lower():
                    raise
| # ── Auth helpers ────────────────────────────────────────────────────────── | |
def register_user(email: str, password: str, display_name: str | None = None) -> dict:
    """Create an email/password account and return the stored user row.

    Args:
        email: Address for the new account; must not already be registered.
        password: Plain-text password; only its salted hash is stored.
        display_name: Optional name; when empty or ``None`` the part of
            *email* before the first ``@`` is used.

    Returns:
        The freshly inserted ``users`` row as a dict (all columns).

    Raises:
        ValueError: if an account with this email already exists, including
            when a concurrent registration wins the race between the
            existence check and the insert.
    """
    duplicate_msg = "An account with this email already exists"
    pw_hash, pw_salt = _hash_password(password)
    uid = f"user_{secrets.token_hex(8)}"
    # One clock read so created_at and last_login match for a new account.
    now = datetime.utcnow().isoformat()

    with get_db() as conn:
        cursor = conn.cursor()
        # Check if email already exists
        cursor.execute("SELECT id FROM users WHERE email = ?", (email,))
        if cursor.fetchone():
            raise ValueError(duplicate_msg)
        try:
            cursor.execute(
                """INSERT INTO users (firebase_uid, email, display_name, password_hash, password_salt, created_at, last_login)
                VALUES (?, ?, ?, ?, ?, ?, ?)""",
                (uid, email, display_name or email.split("@")[0], pw_hash, pw_salt, now, now),
            )
        except sqlite3.IntegrityError as exc:
            # The UNIQUE constraint on users.email can still fire if another
            # connection inserted the same email after the check above.
            if "users.email" in str(exc):
                raise ValueError(duplicate_msg) from exc
            raise
        cursor.execute("SELECT * FROM users WHERE firebase_uid = ?", (uid,))
        return dict(cursor.fetchone())
def authenticate_user(email: str, password: str) -> dict | None:
    """Check an email/password pair against the ``users`` table.

    Args:
        email: Address to look up.
        password: Plain-text password to verify.

    Returns:
        The user row as a dict on success, ``None`` when the email is
        unknown or the password does not match. The dict is read before
        ``last_login`` is updated, so it carries the previous login time.
    """
    with get_db() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM users WHERE email = ?", (email,))
        row = cursor.fetchone()
        if not row:
            return None
        user = dict(row)

        stored_hash = user.get("password_hash")
        if not stored_hash:
            # If user was created before password auth, accept any password.
            # SECURITY NOTE(review): this lets anyone who knows the email of
            # a password-less account (e.g. one created through
            # get_or_create_user) sign in. Kept for backward compatibility;
            # should be replaced by a forced password-set flow.
            return user

        candidate_hash, _ = _hash_password(password, user["password_salt"])
        # Constant-time comparison so response timing does not leak how
        # much of the hash matched.
        if not secrets.compare_digest(candidate_hash.encode(), str(stored_hash).encode()):
            return None

        # Update last login
        cursor.execute(
            "UPDATE users SET last_login = ? WHERE id = ?",
            (datetime.utcnow().isoformat(), user["id"]),
        )
        return user
def get_or_create_user(firebase_uid: str, email: str, display_name: Optional[str] = None) -> dict:
    """Return the user with *firebase_uid*, inserting a new row if needed.

    Args:
        firebase_uid: Identifier used to look the user up.
        email: Stored only when a new row is created.
        display_name: Stored only when a new row is created.

    Returns:
        The user row as a dict. For an existing user the row is read before
        ``last_login`` is refreshed, so the dict holds the previous login
        time; for a new user ``created_at`` and ``last_login`` are equal.
    """
    # Single clock read: a new account gets identical created_at/last_login.
    now = datetime.utcnow().isoformat()

    with get_db() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM users WHERE firebase_uid = ?", (firebase_uid,))
        existing = cursor.fetchone()

        if existing:
            # Update last login
            cursor.execute(
                "UPDATE users SET last_login = ? WHERE firebase_uid = ?",
                (now, firebase_uid),
            )
            return dict(existing)

        cursor.execute(
            """INSERT INTO users (firebase_uid, email, display_name, created_at, last_login)
            VALUES (?, ?, ?, ?, ?)""",
            (firebase_uid, email, display_name, now, now),
        )
        cursor.execute("SELECT * FROM users WHERE firebase_uid = ?", (firebase_uid,))
        return dict(cursor.fetchone())
def log_usage(
    firebase_uid: str,
    system: str,
    method: str,
    baseline_loss_kw: float,
    optimized_loss_kw: float,
    energy_saved_kwh: float,
    co2_saved_kg: float,
    money_saved_usd: float,
    computation_time_sec: float,
    shadow_mode: bool = False,
    switches_changed: Optional[str] = None,
    load_multiplier: float = 1.0,
):
    """Insert one row into ``usage`` and add its savings to the user's totals.

    The ``users`` row is looked up by *firebase_uid*. When no such user
    exists the usage row is still written (with a NULL ``user_id``) and the
    totals update is skipped.
    """
    with get_db() as conn:
        cur = conn.cursor()

        # Get user_id
        cur.execute("SELECT id FROM users WHERE firebase_uid = ?", (firebase_uid,))
        owner = cur.fetchone()
        user_id = owner["id"] if owner else None

        usage_values = (
            user_id, firebase_uid, system, method, load_multiplier,
            baseline_loss_kw, optimized_loss_kw, energy_saved_kwh,
            co2_saved_kg, money_saved_usd, computation_time_sec,
            shadow_mode, switches_changed,
        )
        cur.execute(
            """INSERT INTO usage (
                user_id, firebase_uid, system, method, load_multiplier,
                baseline_loss_kw, optimized_loss_kw, energy_saved_kwh,
                co2_saved_kg, money_saved_usd, computation_time_sec,
                shadow_mode, switches_changed
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            usage_values,
        )

        if not user_id:
            return  # unknown user: nothing to aggregate

        # Update user totals
        totals_delta = (energy_saved_kwh, co2_saved_kg, money_saved_usd, user_id)
        cur.execute(
            """UPDATE users SET
                total_optimizations = total_optimizations + 1,
                total_energy_saved_kwh = total_energy_saved_kwh + ?,
                total_co2_saved_kg = total_co2_saved_kg + ?,
                total_money_saved_usd = total_money_saved_usd + ?
            WHERE id = ?""",
            totals_delta,
        )
def log_audit(
    firebase_uid: str,
    action: str,
    system: Optional[str] = None,
    method: Optional[str] = None,
    details: Optional[str] = None,
    baseline_loss_kw: Optional[float] = None,
    optimized_loss_kw: Optional[float] = None,
    loss_reduction_pct: Optional[float] = None,
    energy_saved_kwh_year: Optional[float] = None,
    co2_saved_tonnes_year: Optional[float] = None,
    cost_saved_usd_year: Optional[float] = None,
    open_lines_before: Optional[str] = None,
    open_lines_after: Optional[str] = None,
):
    """Insert one row into ``audit_logs``.

    The ``users`` row is looked up by *firebase_uid*; when it is missing the
    audit row is written with a NULL ``user_id``. Optional arguments left as
    ``None`` are stored as NULL.
    """
    with get_db() as conn:
        cur = conn.cursor()
        cur.execute("SELECT id FROM users WHERE firebase_uid = ?", (firebase_uid,))
        owner = cur.fetchone()
        user_id = owner["id"] if owner else None

        audit_values = (
            user_id, firebase_uid, action, system, method, details,
            baseline_loss_kw, optimized_loss_kw, loss_reduction_pct,
            energy_saved_kwh_year, co2_saved_tonnes_year, cost_saved_usd_year,
            open_lines_before, open_lines_after,
        )
        cur.execute(
            """INSERT INTO audit_logs (
                user_id, firebase_uid, action, system, method, details,
                baseline_loss_kw, optimized_loss_kw, loss_reduction_pct,
                energy_saved_kwh_year, co2_saved_tonnes_year, cost_saved_usd_year,
                open_lines_before, open_lines_after
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            audit_values,
        )
def get_user_usage(firebase_uid: str, limit: int = 100) -> list:
    """Return up to *limit* ``usage`` rows for *firebase_uid*, newest first.

    Each row is returned as a plain dict; the list is empty when the user
    has no recorded usage.
    """
    query = """SELECT * FROM usage WHERE firebase_uid = ?
        ORDER BY timestamp DESC LIMIT ?"""
    with get_db() as conn:
        rows = conn.execute(query, (firebase_uid, limit)).fetchall()
        return list(map(dict, rows))
def get_user_audit_logs(firebase_uid: str, limit: int = 100) -> list:
    """Return up to *limit* ``audit_logs`` rows for *firebase_uid*, newest first.

    Each row is returned as a plain dict; the list is empty when nothing
    has been logged for the user.
    """
    query = """SELECT * FROM audit_logs WHERE firebase_uid = ?
        ORDER BY timestamp DESC LIMIT ?"""
    with get_db() as conn:
        rows = conn.execute(query, (firebase_uid, limit)).fetchall()
        return list(map(dict, rows))
def get_user_stats(firebase_uid: str) -> dict:
    """Return the running totals stored on the user's ``users`` row.

    The result always has the four ``total_*`` keys. For an unknown
    *firebase_uid* every total is ``0``.
    """
    stat_columns = (
        "total_optimizations",
        "total_energy_saved_kwh",
        "total_co2_saved_kg",
        "total_money_saved_usd",
    )
    with get_db() as conn:
        record = conn.execute(
            "SELECT * FROM users WHERE firebase_uid = ?", (firebase_uid,)
        ).fetchone()
        if not record:
            return dict.fromkeys(stat_columns, 0)
        return {column: record[column] for column in stat_columns}
# Initialize database on import: importing this module runs init_db(), which
# creates any missing tables and then applies the column migration.
init_db()