import os
import re
import sqlite3
from contextlib import closing
from typing import Dict, Any, Optional, List

import pandas as pd


class DatabaseHandler:
    """Handles all database operations for the Olist database.

    Every public method opens its own short-lived SQLite connection and
    closes it before returning, including when the operation fails.
    """

    # Statement keywords that must not appear as a standalone word in a user
    # query. They are matched on word boundaries so that identifiers and values
    # which merely contain them (e.g. ``last_update`` or the string literal
    # ``'created'``) are not rejected by accident.
    _DANGEROUS_KEYWORDS = (
        "DROP", "DELETE", "INSERT", "UPDATE",
        "ALTER", "CREATE", "TRUNCATE", "REPLACE",
    )
    _DANGEROUS_PATTERN = re.compile(
        r"\b(?:" + "|".join(_DANGEROUS_KEYWORDS) + r")\b"
    )

    def __init__(self, db_path: str = "olist.sqlite"):
        """
        Initialize database handler.

        Args:
            db_path: Path to SQLite database file

        Raises:
            FileNotFoundError: If the file does not exist or cannot be
                opened as an SQLite database.
        """
        self.db_path = db_path
        self._verify_database()

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier (embedded quotes doubled)."""
        return '"' + str(name).replace('"', '""') + '"'

    def _verify_database(self) -> None:
        """Verify database exists and is accessible.

        Raises:
            FileNotFoundError: If the path is not an existing file, or if
                SQLite cannot open/read it.
        """
        # sqlite3.connect() silently creates a new empty database when the
        # file is missing, so existence must be checked explicitly first.
        if self.db_path != ":memory:" and not os.path.isfile(self.db_path):
            raise FileNotFoundError(
                f"Database not found at {self.db_path}: file does not exist"
            )
        try:
            with closing(sqlite3.connect(self.db_path)) as conn:
                # Touch the catalog so a non-SQLite file is detected here
                # rather than on the first real query.
                conn.execute("SELECT name FROM sqlite_master LIMIT 1;").fetchall()
        except Exception as e:
            raise FileNotFoundError(f"Database not found at {self.db_path}: {str(e)}") from e

    def get_schema(self) -> str:
        """
        Extract and format database schema.

        Returns:
            Formatted schema string with all tables and columns, or a string
            starting with "Error extracting schema:" if extraction fails.
        """
        try:
            with closing(sqlite3.connect(self.db_path)) as conn:
                cursor = conn.cursor()

                # Get all table names
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
                table_names = [row[0] for row in cursor.fetchall()]

                schema_parts = []
                for table_name in table_names:
                    # Quote the name so tables with spaces/quotes still work.
                    cursor.execute(
                        f"PRAGMA table_info({self._quote_identifier(table_name)});"
                    )
                    columns = cursor.fetchall()

                    # Format table schema
                    schema_parts.append(f"\nTable: {table_name}")
                    schema_parts.append("Columns:")
                    for col in columns:
                        # table_info rows are (cid, name, type, notnull,
                        # dflt_value, pk); pk is non-zero for key columns.
                        col_name, col_type, pk = col[1], col[2], col[5]
                        is_pk = " (PRIMARY KEY)" if pk else ""
                        schema_parts.append(f" - {col_name} ({col_type}){is_pk}")

            return "\n".join(schema_parts)
        except Exception as e:
            return f"Error extracting schema: {str(e)}"

    def execute_query(self, sql: str, max_rows: int = 1000) -> Dict[str, Any]:
        """
        Execute SQL query and return results.

        Args:
            sql: SQL query to execute
            max_rows: Maximum number of rows to return

        Returns:
            Dictionary with:
                - success: Boolean indicating success
                - data: Pandas DataFrame with results (None on failure)
                - row_count: Number of rows returned
                - error: Error message if failed, else None
                - warning: Truncation notice if rows were dropped, else None
                  (present on every result, including failures)
        """
        # Validate query first
        if not self._validate_query(sql):
            return {
                "success": False,
                "data": None,
                "row_count": 0,
                "error": "Query validation failed: Only SELECT queries are allowed",
                "warning": None,
            }

        try:
            with closing(sqlite3.connect(self.db_path)) as conn:
                df = pd.read_sql_query(sql, conn)
        except Exception as e:
            return {
                "success": False,
                "data": None,
                "row_count": 0,
                "error": f"Query execution error: {str(e)}",
                "warning": None,
            }

        # Limit rows if needed
        warning = None
        if len(df) > max_rows:
            df = df.head(max_rows)
            warning = f"Results limited to {max_rows} rows"

        return {
            "success": True,
            "data": df,
            "row_count": len(df),
            "error": None,
            "warning": warning,
        }

    def _validate_query(self, sql: str) -> bool:
        """
        Validate SQL query for safety.

        Args:
            sql: SQL query to validate

        Returns:
            True if the query starts with SELECT and contains none of the
            blocked keywords as a whole word, False otherwise.
        """
        sql_upper = sql.upper().strip()

        # Only allow SELECT queries
        if not sql_upper.startswith("SELECT"):
            return False

        # Block dangerous keywords (whole words only)
        return self._DANGEROUS_PATTERN.search(sql_upper) is None

    def get_table_names(self) -> List[str]:
        """
        Get list of all table names in database.

        Returns:
            List of table names (empty list if the lookup fails)
        """
        try:
            with closing(sqlite3.connect(self.db_path)) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
                return [row[0] for row in cursor.fetchall()]
        except Exception as e:
            print(f"Error getting table names: {e}")
            return []

    def get_table_preview(self, table_name: str, limit: int = 5) -> Optional[pd.DataFrame]:
        """
        Get preview of table data.

        Args:
            table_name: Name of table to preview (quoted as an identifier,
                so it cannot inject SQL)
            limit: Number of rows to return

        Returns:
            DataFrame with sample data or None if error
        """
        try:
            query = f"SELECT * FROM {self._quote_identifier(table_name)} LIMIT ?;"
            with closing(sqlite3.connect(self.db_path)) as conn:
                return pd.read_sql_query(query, conn, params=(int(limit),))
        except Exception as e:
            print(f"Error previewing table {table_name}: {e}")
            return None


# Test function
if __name__ == "__main__":
    # Quick test
    db = DatabaseHandler("olist.sqlite")

    print("=== Database Schema ===")
    print(db.get_schema())

    print("\n=== Table Names ===")
    print(db.get_table_names())

    print("\n=== Test Query ===")
    result = db.execute_query("SELECT COUNT(*) as total_orders FROM orders;")
    print(f"Success: {result['success']}")
    if result['success']:
        print(result['data'])