Spaces:
Runtime error
Runtime error
import os
import re
import sqlite3
from contextlib import closing
from typing import Dict, Any, Optional, List

import pandas as pd
class DatabaseHandler:
    """Handles all database operations for the Olist database.

    Every method opens its own short-lived SQLite connection and closes it
    before returning, including when an error occurs.
    """

    # Keywords that cause a query to be rejected. They are matched as whole
    # words so that identifiers such as ``updated_at`` or ``created_at`` are
    # not mistaken for UPDATE / CREATE statements.
    _DANGEROUS_KEYWORDS = (
        "DROP", "DELETE", "INSERT", "UPDATE",
        "ALTER", "CREATE", "TRUNCATE", "REPLACE",
    )
    _DANGEROUS_PATTERN = re.compile(
        r"\b(?:" + "|".join(_DANGEROUS_KEYWORDS) + r")\b"
    )

    def __init__(self, db_path: str = "olist.sqlite"):
        """
        Initialize database handler.

        Args:
            db_path: Path to SQLite database file

        Raises:
            FileNotFoundError: If the database file is missing or unreadable
        """
        self.db_path = db_path
        self._verify_database()

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection to the configured database."""
        return sqlite3.connect(self.db_path)

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return ``name`` as a double-quoted SQL identifier.

        Embedded double quotes are doubled, so the result is always parsed
        as a single identifier regardless of its content.
        """
        return '"' + str(name).replace('"', '""') + '"'

    @staticmethod
    def _failure(message: str) -> Dict[str, Any]:
        """Build the result dictionary returned for a failed query."""
        return {
            "success": False,
            "data": None,
            "row_count": 0,
            "error": message,
            "warning": None,
        }

    def _verify_database(self):
        """Verify database exists and is accessible.

        Raises:
            FileNotFoundError: If the file does not exist or SQLite cannot
                read it.
        """
        try:
            # sqlite3.connect() silently creates a missing file, so the
            # existence check has to happen before connecting.
            in_memory = self.db_path == ":memory:"
            if not in_memory and not os.path.isfile(self.db_path):
                raise FileNotFoundError("no such file")
            with closing(self._connect()) as conn:
                # Connecting is lazy; reading the schema table forces SQLite
                # to actually open the file.
                conn.execute("SELECT name FROM sqlite_master LIMIT 1;").fetchall()
        except Exception as e:
            raise FileNotFoundError(
                f"Database not found at {self.db_path}: {str(e)}"
            ) from e

    def get_schema(self) -> str:
        """
        Extract and format database schema.

        Returns:
            Formatted schema string with all tables and columns, or a string
            starting with "Error extracting schema:" if extraction failed
        """
        try:
            schema_parts = []
            with closing(self._connect()) as conn:
                cursor = conn.cursor()
                # Get all table names
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
                table_names = [row[0] for row in cursor.fetchall()]
                for table_name in table_names:
                    # Quote the name so tables with spaces or keyword names
                    # do not break the PRAGMA statement.
                    quoted = self._quote_identifier(table_name)
                    cursor.execute(f"PRAGMA table_info({quoted});")
                    columns = cursor.fetchall()
                    # Format table schema
                    schema_parts.append(f"\nTable: {table_name}")
                    schema_parts.append("Columns:")
                    for col in columns:
                        # table_info rows: index 1 = name, 2 = declared type,
                        # 5 = non-zero when the column is part of the primary key
                        col_name = col[1]
                        col_type = col[2]
                        is_pk = " (PRIMARY KEY)" if col[5] else ""
                        schema_parts.append(f" - {col_name} ({col_type}){is_pk}")
            return "\n".join(schema_parts)
        except Exception as e:
            return f"Error extracting schema: {str(e)}"

    def execute_query(self, sql: str, max_rows: int = 1000) -> Dict[str, Any]:
        """
        Execute SQL query and return results.

        Args:
            sql: SQL query to execute
            max_rows: Maximum number of rows to return

        Returns:
            Dictionary with:
            - success: Boolean indicating success
            - data: Pandas DataFrame with results (None on failure)
            - row_count: Number of rows returned
            - error: Error message if failed, otherwise None
            - warning: Message if the result was truncated, otherwise None
        """
        # Validate query first
        if not self._validate_query(sql):
            return self._failure(
                "Query validation failed: Only SELECT queries are allowed"
            )
        try:
            with closing(self._connect()) as conn:
                df = pd.read_sql_query(sql, conn)
            # Limit rows if needed
            warning = None
            if len(df) > max_rows:
                df = df.head(max_rows)
                warning = f"Results limited to {max_rows} rows"
            return {
                "success": True,
                "data": df,
                "row_count": len(df),
                "error": None,
                "warning": warning,
            }
        except Exception as e:
            return self._failure(f"Query execution error: {str(e)}")

    def _validate_query(self, sql: str) -> bool:
        """
        Validate SQL query for safety.

        A query is accepted only if it starts with SELECT and contains none
        of the blocked keywords as a standalone word.

        Args:
            sql: SQL query to validate

        Returns:
            True if query is safe, False otherwise
        """
        if not isinstance(sql, str):
            return False
        sql_upper = sql.upper().strip()
        # Only allow SELECT queries
        if not sql_upper.startswith("SELECT"):
            return False
        # Block dangerous keywords (whole-word match only)
        return self._DANGEROUS_PATTERN.search(sql_upper) is None

    def get_table_names(self) -> List[str]:
        """
        Get list of all table names in database.

        Returns:
            List of table names (empty if the lookup failed)
        """
        try:
            with closing(self._connect()) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
                return [row[0] for row in cursor.fetchall()]
        except Exception as e:
            print(f"Error getting table names: {e}")
            return []

    def get_table_preview(self, table_name: str, limit: int = 5) -> Optional[pd.DataFrame]:
        """
        Get preview of table data.

        Args:
            table_name: Name of table to preview
            limit: Number of rows to return

        Returns:
            DataFrame with sample data or None if error
        """
        try:
            # The table name is quoted and the limit is bound as a parameter,
            # so neither value can inject additional SQL.
            query = f"SELECT * FROM {self._quote_identifier(table_name)} LIMIT ?;"
            with closing(self._connect()) as conn:
                return pd.read_sql_query(query, conn, params=(int(limit),))
        except Exception as e:
            print(f"Error previewing table {table_name}: {e}")
            return None
# Manual smoke test: run this module directly next to a local olist.sqlite
# file to print the schema, the table names and one sample query result.
if __name__ == "__main__":
    db = DatabaseHandler("olist.sqlite")

    print("=== Database Schema ===")
    print(db.get_schema())

    print("\n=== Table Names ===")
    print(db.get_table_names())

    print("\n=== Test Query ===")
    result = db.execute_query("SELECT COUNT(*) as total_orders FROM orders;")
    print(f"Success: {result['success']}")
    if result['success']:
        print(result['data'])
    else:
        # Show why the query failed instead of ending silently.
        print(f"Error: {result['error']}")