# Commit ea856a6 (wu981526092): Security & HF Spaces fixes — enable CSRF,
# auth middleware, persistent storage.
#!/usr/bin/env python
"""
Database initialization and reset utility.
Creates or resets the database with the three required tables:
- knowledge_graphs (with status column)
- entities
- relations
"""
import os
import sys
import sqlite3
import argparse
import logging
import shutil
import time
from dotenv import load_dotenv
# Load environment variables from a local .env file before anything reads them
load_dotenv()

# Configure logging (backend/database/init_db.py -> backend/database/ -> backend/ -> project_root/ -> logs/)
LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "logs")
os.makedirs(LOG_DIR, exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Mirror log output to stdout and to the shared monitoring log file
        logging.StreamHandler(sys.stdout),
        logging.FileHandler(os.path.join(LOG_DIR, 'agent_monitoring.log'))
    ]
)
logger = logging.getLogger(__name__)

# Get the absolute path to the project root directory
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Import DB_URI from central config.
# NOTE(review): imported mid-file, after load_dotenv() — presumably so
# utils.config can see the .env variables. Keep this ordering.
from utils.config import DB_URI

# Extract the filesystem path from a sqlite:/// URI
if DB_URI.startswith("sqlite:///"):
    DB_PATH = DB_URI.replace("sqlite:///", "")
else:
    # Fallback for non-sqlite databases
    DB_PATH = os.path.join(ROOT_DIR, 'datasets/db/agent_monitoring.db')
def confirm_reset():
    """Prompt the operator before a destructive reset; return True iff confirmed."""
    print("\nWARNING: This will DELETE ALL data in the database.")
    print("All knowledge graphs, entities, and relations will be permanently lost.")
    answer = input("Are you sure you want to continue? (yes/no): ").lower()
    return answer == "yes" or answer == "y"
def _ensure_db_directory():
    """Create the directory that will hold the SQLite file, if one is needed."""
    db_dir = os.path.dirname(DB_PATH)
    if not db_dir:  # DB_PATH may be a bare filename (current directory)
        return
    try:
        os.makedirs(db_dir, exist_ok=True)
        logger.info(f"Database directory ensured at: {db_dir}")
    except OSError as e:
        logger.warning(f"Could not create database directory {db_dir}: {e}")
        # In HF Spaces, /data might not be available until Persistent Storage is enabled
        if "/data" in db_dir:
            logger.warning("HF Spaces Persistent Storage may not be enabled. "
                           "Database will be stored in ephemeral storage.")


def _backup_and_delete_db():
    """Back up the existing database file next to it, then delete the original."""
    logger.info(f"Found existing database at {DB_PATH}")
    backup_path = f"{DB_PATH}.backup_{int(time.time())}"
    logger.info(f"Creating backup at {backup_path}")
    shutil.copy2(DB_PATH, backup_path)
    logger.info("Deleting database file")
    os.remove(DB_PATH)
    logger.info("Database file deleted")


def _create_tables(cursor):
    """Create every application table (idempotent: CREATE TABLE IF NOT EXISTS)."""
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS knowledge_graphs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            filename VARCHAR(255) UNIQUE,
            creation_timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
            update_timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
            creator VARCHAR(255),
            entity_count INTEGER DEFAULT 0,
            relation_count INTEGER DEFAULT 0,
            namespace VARCHAR(255),
            system_name VARCHAR(255),
            system_summary TEXT,
            graph_data TEXT,
            status VARCHAR(50) DEFAULT 'created',
            trace_id VARCHAR(36),
            window_index INTEGER,
            window_total INTEGER,
            window_start_char INTEGER,
            window_end_char INTEGER,
            processing_run_id VARCHAR(255)
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS traces (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            trace_id VARCHAR(36) UNIQUE,
            filename VARCHAR(255),
            title VARCHAR(255),
            description TEXT,
            content TEXT,
            content_hash VARCHAR(64),
            upload_timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
            update_timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
            uploader VARCHAR(255),
            trace_type VARCHAR(50),
            trace_source VARCHAR(50),
            character_count INTEGER DEFAULT 0,
            turn_count INTEGER DEFAULT 0,
            status VARCHAR(50) DEFAULT 'uploaded',
            processing_method VARCHAR(50),
            tags TEXT,
            trace_metadata TEXT
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS entities (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            graph_id INTEGER,
            entity_id VARCHAR(255),
            type VARCHAR(255),
            name VARCHAR(255),
            properties TEXT,
            knowledge_graph_namespace VARCHAR(255),
            FOREIGN KEY (graph_id) REFERENCES knowledge_graphs(id),
            UNIQUE (graph_id, entity_id)
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS relations (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            graph_id INTEGER,
            relation_id VARCHAR(255),
            type VARCHAR(255),
            source_id INTEGER,
            target_id INTEGER,
            properties TEXT,
            knowledge_graph_namespace VARCHAR(255),
            FOREIGN KEY (graph_id) REFERENCES knowledge_graphs(id),
            FOREIGN KEY (source_id) REFERENCES entities(id),
            FOREIGN KEY (target_id) REFERENCES entities(id),
            UNIQUE (graph_id, relation_id)
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS prompt_reconstructions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            knowledge_graph_id INTEGER,
            relation_id VARCHAR(255),
            reconstructed_prompt TEXT,
            dependencies TEXT,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (knowledge_graph_id) REFERENCES knowledge_graphs(id)
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS perturbation_tests (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            knowledge_graph_id INTEGER NOT NULL,
            prompt_reconstruction_id INTEGER NOT NULL,
            relation_id VARCHAR(255) NOT NULL,
            perturbation_type VARCHAR(50) NOT NULL,
            perturbation_set_id VARCHAR(64),
            test_result JSON,
            perturbation_score FLOAT,
            test_metadata JSON,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (knowledge_graph_id) REFERENCES knowledge_graphs(id),
            FOREIGN KEY (prompt_reconstruction_id) REFERENCES prompt_reconstructions(id)
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS observability_connections (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            connection_id VARCHAR(36) UNIQUE,
            platform VARCHAR(50) NOT NULL,
            public_key TEXT NOT NULL,
            secret_key TEXT,
            host VARCHAR(255),
            projects TEXT,
            status VARCHAR(50) DEFAULT 'connected',
            connected_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            last_sync DATETIME,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS fetched_traces (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            trace_id VARCHAR(255) NOT NULL,
            name VARCHAR(255) NOT NULL,
            platform VARCHAR(50) NOT NULL,
            project_name VARCHAR(255),
            connection_id VARCHAR(36) NOT NULL,
            data TEXT,
            fetched_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            imported BOOLEAN DEFAULT 0,
            imported_at DATETIME,
            imported_trace_id VARCHAR(36),
            FOREIGN KEY (connection_id) REFERENCES observability_connections(connection_id),
            UNIQUE (trace_id, connection_id)
        )
    ''')
    cursor.execute('''
        CREATE TABLE IF NOT EXISTS causal_analyses (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            knowledge_graph_id INTEGER NOT NULL,
            perturbation_set_id VARCHAR(64) NOT NULL,
            analysis_method VARCHAR(50) NOT NULL,
            analysis_result JSON,
            causal_score FLOAT,
            analysis_metadata JSON,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (knowledge_graph_id) REFERENCES knowledge_graphs(id)
        )
    ''')


def _create_indexes(cursor):
    """Create lookup indexes for frequently-queried columns (idempotent)."""
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_knowledge_graphs_filename ON knowledge_graphs(filename)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_knowledge_graphs_namespace ON knowledge_graphs(namespace)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_knowledge_graphs_trace_id ON knowledge_graphs(trace_id)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_entities_entity_id ON entities(entity_id)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_relations_relation_id ON relations(relation_id)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_traces_trace_id ON traces(trace_id)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_traces_content_hash ON traces(content_hash)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_traces_filename ON traces(filename)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_traces_status ON traces(status)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_prompt_reconstructions_kgid ON prompt_reconstructions(knowledge_graph_id)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_prompt_reconstructions_relation_id ON prompt_reconstructions(relation_id)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_perturbation_tests_kgid ON perturbation_tests(knowledge_graph_id)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_perturbation_tests_prid ON perturbation_tests(prompt_reconstruction_id)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_perturbation_tests_relation ON perturbation_tests(relation_id)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_perturbation_tests_type ON perturbation_tests(perturbation_type)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_causal_analyses_kgid ON causal_analyses(knowledge_graph_id)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_causal_analyses_method ON causal_analyses(analysis_method)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_causal_analyses_setid ON causal_analyses(perturbation_set_id)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_observability_connections_connection_id ON observability_connections(connection_id)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_observability_connections_platform ON observability_connections(platform)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_fetched_traces_trace_id ON fetched_traces(trace_id)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_fetched_traces_connection_id ON fetched_traces(connection_id)')
    cursor.execute('CREATE INDEX IF NOT EXISTS idx_fetched_traces_platform ON fetched_traces(platform)')


def _run_migrations(conn, cursor):
    """Apply in-place column migrations for databases created by older versions."""
    # fetched_traces: add project_name column if missing
    cursor.execute("PRAGMA table_info(fetched_traces)")
    fetched_traces_columns = [column[1] for column in cursor.fetchall()]
    if 'project_name' not in fetched_traces_columns:
        logger.info("Adding project_name column to fetched_traces table...")
        cursor.execute("ALTER TABLE fetched_traces ADD COLUMN project_name TEXT")
        cursor.execute('CREATE INDEX IF NOT EXISTS idx_fetched_traces_project_name ON fetched_traces(project_name)')
        conn.commit()
        logger.info("Successfully added project_name column")

    # observability_connections: add projects column if missing
    cursor.execute("PRAGMA table_info(observability_connections)")
    obs_conn_columns = [column[1] for column in cursor.fetchall()]
    if 'projects' not in obs_conn_columns:
        logger.info("Adding projects column to observability_connections table...")
        cursor.execute("ALTER TABLE observability_connections ADD COLUMN projects TEXT")
        conn.commit()
        logger.info("Successfully added projects column")

    # fetched_traces: rename legacy preview_data column to data
    # (RENAME COLUMN requires SQLite >= 3.25)
    if 'preview_data' in fetched_traces_columns and 'data' not in fetched_traces_columns:
        logger.info("Renaming preview_data column to data in fetched_traces table...")
        cursor.execute("ALTER TABLE fetched_traces RENAME COLUMN preview_data TO data")
        conn.commit()
        logger.info("Successfully renamed preview_data column to data")


def _seed_sample_data():
    """Best-effort: insert demo traces/KGs into an empty database via SQLAlchemy."""
    logger.info("Empty database detected, inserting sample data for better user experience...")
    try:
        # Import here to avoid circular imports
        from .sample_data import insert_sample_data
        from . import SessionLocal
        # Use SQLAlchemy session for sample data insertion
        session = SessionLocal()
        try:
            results = insert_sample_data(session, force_insert=True)
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()
        if results["traces_inserted"] > 0 or results["knowledge_graphs_inserted"] > 0:
            logger.info(f"✅ Sample data inserted successfully: {results['traces_inserted']} traces, {results['knowledge_graphs_inserted']} knowledge graphs")
        elif results["skipped"] > 0:
            logger.info(f"Sample data already exists, skipped {results['skipped']} items")
        if results["errors"]:
            logger.warning(f"Some errors occurred during sample data insertion: {results['errors']}")
    except Exception as e:
        # Non-critical: the schema itself is already in place
        logger.warning(f"Failed to insert sample data (non-critical): {str(e)}")
        logger.info("Database initialization completed without sample data")


def _fix_sample_kgs():
    """Best-effort: repair sample KGs stored with window_index=0 instead of None."""
    logger.info("Existing data found, checking for sample data fixes...")
    try:
        from . import SessionLocal
        from .models import KnowledgeGraph, Trace
        session = SessionLocal()
        try:
            # Find sample traces
            sample_traces = session.query(Trace).filter(
                Trace.trace_source == "sample_data"
            ).all()
            if sample_traces:
                # Find knowledge graphs with incorrect window_index
                trace_ids = [trace.trace_id for trace in sample_traces]
                broken_kgs = session.query(KnowledgeGraph).filter(
                    KnowledgeGraph.trace_id.in_(trace_ids),
                    KnowledgeGraph.window_index == 0,  # Should be None for final KGs
                    KnowledgeGraph.window_total == 1
                ).all()
                if broken_kgs:
                    logger.info(f"Found {len(broken_kgs)} sample KGs that need fixing (window_index=0 -> None)")
                    for kg in broken_kgs:
                        kg.window_index = None  # Fix to make it a final KG
                    session.commit()
                    logger.info(f"✅ Fixed {len(broken_kgs)} sample knowledge graphs to display as final KGs")
                else:
                    logger.info("All sample knowledge graphs are correctly configured")
        except Exception as e:
            session.rollback()
            logger.warning(f"Error checking sample data fixes: {str(e)}")
        finally:
            session.close()
    except Exception as e:
        logger.warning(f"Failed to check sample data fixes (non-critical): {str(e)}")


def init_database(reset=False, force=False):
    """
    Initialize the database with the required tables.

    Args:
        reset: If True, drop and recreate the core tables.
        force: If True, back up and delete the database file completely
               before recreating the schema.
    """
    _ensure_db_directory()

    # Check if a non-empty database file already exists
    db_exists = os.path.exists(DB_PATH) and os.path.getsize(DB_PATH) > 0

    # BUG FIX: this previously required `reset and force`, so running the CLI
    # with `--force` alone confirmed with the user and then silently did
    # nothing. `force` on its own now deletes the file, matching its help text.
    if db_exists and force:
        _backup_and_delete_db()
        db_exists = False

    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()

        # Drop core tables if a soft reset was requested
        if reset and db_exists:
            logger.info("Dropping existing tables")
            cursor.execute("DROP TABLE IF EXISTS relations")
            cursor.execute("DROP TABLE IF EXISTS entities")
            cursor.execute("DROP TABLE IF EXISTS knowledge_graphs")

        logger.info("Creating tables")
        _create_tables(cursor)
        _create_indexes(cursor)
        conn.commit()

        # Handle column migrations for existing tables
        logger.info("Checking for column migrations...")
        _run_migrations(conn, cursor)

        # Verify tables were created
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        table_names = [row[0] for row in cursor.fetchall()]
        if reset:
            logger.info(f"Database reset completed. Tables created: {table_names}")
        else:
            logger.info(f"Database initialization completed. Tables created: {table_names}")

        # Report row counts so the operator can see what state the DB is in
        cursor.execute("SELECT count(*) FROM knowledge_graphs")
        kg_count = cursor.fetchone()[0]
        cursor.execute("SELECT count(*) FROM entities")
        entity_count = cursor.fetchone()[0]
        cursor.execute("SELECT count(*) FROM relations")
        relation_count = cursor.fetchone()[0]
        trace_count = 0
        if 'traces' in table_names:
            cursor.execute("SELECT count(*) FROM traces")
            trace_count = cursor.fetchone()[0]
        logger.info(f"Database contains: {kg_count} knowledge graphs, {entity_count} entities, {relation_count} relations, {trace_count} traces")

        # Insert sample data for new databases (when trace count is 0),
        # otherwise repair any previously-seeded sample data
        if trace_count == 0:
            _seed_sample_data()
        else:
            _fix_sample_kgs()
    finally:
        # Close even on failure so the SQLite file is not left locked
        conn.close()
def main():
    """Parse CLI arguments, confirm destructive operations, and run them."""
    parser = argparse.ArgumentParser(description='Initialize or reset the database')
    parser.add_argument('--reset', action='store_true',
                        help='Reset the database by dropping and recreating tables')
    parser.add_argument('--force', action='store_true',
                        help='Force reset by deleting the database file')
    args = parser.parse_args()
    try:
        # At most one confirmation prompt: force takes precedence over reset
        if args.force:
            if not confirm_reset():
                print("Database force reset canceled.")
                return 0
        elif args.reset and not confirm_reset():
            print("Database reset canceled.")
            return 0
        init_database(reset=args.reset, force=args.force)
        print("Database operation completed successfully.")
        return 0
    except Exception as e:
        logger.error(f"Error: {str(e)}")
        return 1


if __name__ == "__main__":
    sys.exit(main())