olist-text2sql / database.py
mhdakmal80's picture
Upload 6 files
d60cb1f verified
import os
import re
import sqlite3
from contextlib import closing
from typing import Dict, Any, Optional, List

import pandas as pd
class DatabaseHandler:
    """Handles all database operations for the Olist database.

    Each method opens its own short-lived SQLite connection and closes it
    before returning, including when an error occurs.
    """

    # Statement keywords that cause a query to be rejected by _validate_query.
    _DANGEROUS_KEYWORDS = (
        "DROP", "DELETE", "INSERT", "UPDATE",
        "ALTER", "CREATE", "TRUNCATE", "REPLACE",
    )
    # Whole-word match, so identifiers such as CREATED_AT or LAST_UPDATE
    # (which merely contain a keyword) are not treated as dangerous.
    _DANGEROUS_PATTERN = re.compile(
        r"\b(?:" + "|".join(_DANGEROUS_KEYWORDS) + r")\b"
    )

    def __init__(self, db_path: str = "olist.sqlite"):
        """
        Initialize database handler.

        Args:
            db_path: Path to SQLite database file

        Raises:
            FileNotFoundError: If no file exists at db_path or it cannot
                be opened.
        """
        self.db_path = db_path
        self._verify_database()

    def _connect(self):
        """Return a context manager yielding a connection that is always closed.

        sqlite3's own ``with conn:`` only manages the transaction; wrapping
        the connection in ``closing`` guarantees ``close()`` on every path.
        """
        return closing(sqlite3.connect(self.db_path))

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return name as a double-quoted SQL identifier.

        Embedded double quotes are doubled so the name cannot terminate the
        identifier early.
        """
        return '"' + name.replace('"', '""') + '"'

    def _verify_database(self):
        """Verify database exists and is accessible.

        Raises:
            FileNotFoundError: If the file is missing or cannot be opened.
        """
        # sqlite3.connect() creates an empty database when the file is
        # missing, so existence has to be checked explicitly beforehand.
        if not os.path.isfile(self.db_path):
            raise FileNotFoundError(f"Database not found at {self.db_path}")
        try:
            sqlite3.connect(self.db_path).close()
        except Exception as e:
            raise FileNotFoundError(
                f"Database not found at {self.db_path}: {str(e)}"
            ) from e

    def get_schema(self) -> str:
        """
        Extract and format database schema.

        Returns:
            Formatted schema string with all tables and columns, or a
            string starting with "Error extracting schema:" on failure.
        """
        try:
            with self._connect() as conn:
                cursor = conn.cursor()

                # Get all table names
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
                table_names = [row[0] for row in cursor.fetchall()]

                schema_parts = []
                for table_name in table_names:
                    # Quote the name so tables containing spaces or other
                    # special characters do not break the PRAGMA statement.
                    quoted = self._quote_identifier(table_name)
                    cursor.execute(f"PRAGMA table_info({quoted});")
                    columns = cursor.fetchall()

                    # Format table schema
                    schema_parts.append(f"\nTable: {table_name}")
                    schema_parts.append("Columns:")
                    for col in columns:
                        # table_info rows: (cid, name, type, notnull,
                        # dflt_value, pk); pk is non-zero for key columns.
                        col_name = col[1]
                        col_type = col[2]
                        is_pk = " (PRIMARY KEY)" if col[5] else ""
                        schema_parts.append(f" - {col_name} ({col_type}){is_pk}")

            return "\n".join(schema_parts)
        except Exception as e:
            return f"Error extracting schema: {str(e)}"

    def execute_query(self, sql: str, max_rows: int = 1000) -> Dict[str, Any]:
        """
        Execute SQL query and return results.

        Args:
            sql: SQL query to execute
            max_rows: Maximum number of rows to return

        Returns:
            Dictionary with:
                - success: Boolean indicating success
                - data: Pandas DataFrame with results (None on failure)
                - row_count: Number of rows returned
                - error: Error message if failed, otherwise None
                - warning: Message if the result was truncated to
                  max_rows, otherwise None
        """
        # Validate query first
        if not self._validate_query(sql):
            return {
                "success": False,
                "data": None,
                "row_count": 0,
                "error": "Query validation failed: Only SELECT queries are allowed",
                "warning": None,
            }

        try:
            with self._connect() as conn:
                # Defence in depth: even if a write slipped past validation,
                # SQLite refuses to modify data on this connection.
                conn.execute("PRAGMA query_only = ON")
                df = pd.read_sql_query(sql, conn)

            # Limit rows if needed
            warning = None
            if len(df) > max_rows:
                df = df.head(max_rows)
                warning = f"Results limited to {max_rows} rows"

            return {
                "success": True,
                "data": df,
                "row_count": len(df),
                "error": None,
                "warning": warning,
            }
        except Exception as e:
            return {
                "success": False,
                "data": None,
                "row_count": 0,
                "error": f"Query execution error: {str(e)}",
                "warning": None,
            }

    def _validate_query(self, sql: str) -> bool:
        """
        Validate SQL query for safety.

        Args:
            sql: SQL query to validate

        Returns:
            True if query is safe, False otherwise
        """
        if not isinstance(sql, str):
            return False

        sql_upper = sql.upper().strip()

        # Only allow SELECT queries
        if not sql_upper.startswith("SELECT"):
            return False

        # Block dangerous keywords, matched as whole words only so column
        # names like created_at / updated_at remain usable.
        return self._DANGEROUS_PATTERN.search(sql_upper) is None

    def get_table_names(self) -> List[str]:
        """
        Get list of all table names in database.

        Returns:
            List of table names; an empty list if the lookup fails (the
            error is printed).
        """
        try:
            with self._connect() as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
                return [row[0] for row in cursor.fetchall()]
        except Exception as e:
            print(f"Error getting table names: {e}")
            return []

    def get_table_preview(self, table_name: str, limit: int = 5) -> Optional[pd.DataFrame]:
        """
        Get preview of table data.

        Args:
            table_name: Name of table to preview
            limit: Number of rows to return

        Returns:
            DataFrame with sample data or None if error
        """
        try:
            # The table name is quoted as an identifier and the limit is
            # bound as a parameter, so neither can inject additional SQL.
            quoted = self._quote_identifier(table_name)
            with self._connect() as conn:
                return pd.read_sql_query(
                    f"SELECT * FROM {quoted} LIMIT ?;",
                    conn,
                    params=(int(limit),),
                )
        except Exception as e:
            print(f"Error previewing table {table_name}: {e}")
            return None
# Manual smoke test: runs only when this module is executed directly.
if __name__ == "__main__":
    db = DatabaseHandler("olist.sqlite")

    print("=== Database Schema ===")
    print(db.get_schema())

    print("\n=== Table Names ===")
    print(db.get_table_names())

    print("\n=== Test Query ===")
    result = db.execute_query("SELECT COUNT(*) as total_orders FROM orders;")
    print(f"Success: {result['success']}")
    if result['success']:
        print(result['data'])
    else:
        # Show why the query failed instead of ending silently.
        print(f"Error: {result['error']}")