# Provenance: commit 3545eca by tuankg1028 - "Adds database visualization feature"
import gradio as gr
import logging
import tempfile
import os
from .db_connector import DBConnector
from .schema_inspector import SchemaInspector
from .merge_operations import MergeOperations
from .db_visualizer import DatabaseVisualizer
from .config import get_config
# Set up logging
# basicConfig runs at import time; every record carries a timestamp, the logger
# name and the level, and anything below INFO is dropped.
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Initialize config but don't connect to DB yet
config = get_config()
# Module-level connection state. All four names start as None and are assigned
# by connect_to_database(); the handlers in this module test them for
# truthiness to decide whether a connection is available.
db = None
inspector = None
merge_ops = None
visualizer = None
def connect_to_database(db_url):
    """Connect to the database at ``db_url`` and rebuild the module-level helpers.

    The module globals ``db``, ``inspector``, ``merge_ops`` and ``visualizer``
    are replaced together, and only once the connection and all helper objects
    have been created. On any failure all four are left as ``None`` so the
    other handlers consistently report "not connected".

    Args:
        db_url: Connection string entered in the UI.

    Returns:
        A ``(status_message, gr.Dropdown)`` pair; the dropdown carries the
        table names found, or no choices on failure.
    """
    global db, inspector, merge_ops, visualizer

    # Drop the previous state up front: a failed attempt must never leave `db`
    # truthy while the helpers are missing or still bound to an older connection.
    db = inspector = merge_ops = visualizer = None

    if not db_url or not db_url.strip():
        return "Error: Database URL must be provided.", gr.Dropdown(choices=[])

    try:
        # Build everything in locals first; publish to the globals only on success.
        connector = DBConnector(db_url.strip())
        if not connector.connect():
            return "Failed to connect to database. Check the connection string and ensure the required database driver is installed.", gr.Dropdown(choices=[])
        new_inspector = SchemaInspector(connector)
        new_merge_ops = MergeOperations(connector, new_inspector)
        new_visualizer = DatabaseVisualizer(connector, new_inspector)
    except ImportError as e:
        error_msg = f"Database driver error: {str(e)}"
        logger.error(error_msg)
        return error_msg, gr.Dropdown(choices=[])
    except Exception as e:
        logger.error(f"Database connection error: {str(e)}")
        return f"Error connecting to database: {str(e)}", gr.Dropdown(choices=[])

    db = connector
    inspector = new_inspector
    merge_ops = new_merge_ops
    visualizer = new_visualizer

    # Test connection by fetching tables (get_db_tables reads the global `db`).
    tables = get_db_tables()
    if tables:
        return f"Successfully connected to database. Found {len(tables)} tables.", gr.Dropdown(choices=tables)
    return "Connected to database but found no tables.", gr.Dropdown(choices=[])
def handle_merge(action, table, column, from_values, target_value, preview_only=True):
    """Gradio handler that previews or runs a value merge on one column.

    Args:
        action: Selected action name; only "Merge Values" is implemented.
        table: Name of the table to operate on.
        column: Name of the column whose values are merged.
        from_values: Comma-separated source values typed by the user.
        target_value: Value the source values are merged into.
        preview_only: When true, only the preview is produced and nothing is
            changed.

    Returns:
        A 3-tuple ``(log_text, schema_summary, mermaid_diagram)``. The last two
        are empty strings unless a real merge succeeded and the schema
        visualization could be regenerated.
    """
    # `merge_ops` is checked as well as `db`: both are needed below.
    if not db or merge_ops is None:
        return "Error: Not connected to database. Please connect first.", "", ""
    if not table or not column:
        return "Error: Table and column must be specified", "", ""
    if action != "Merge Values":
        return "Action not implemented yet", "", ""

    # Parse from_values as comma-separated list. Blank entries (empty input,
    # trailing commas) are dropped so they are never treated as a value to merge.
    from_values_list = [v.strip() for v in (from_values or "").split(',') if v.strip()]
    if not from_values_list:
        return "Error: At least one source value must be specified", "", ""

    if preview_only:
        result = merge_ops.preview_merge(table, column, from_values_list, target_value)
        return result["preview"], "", ""

    result = merge_ops.run_merge(table, column, from_values_list, target_value)
    # Auto-generate visualization after successful operation
    if result.get("success", False) and visualizer:
        try:
            text_summary = visualizer.generate_table_summary()
            mermaid_diagram = visualizer.generate_mermaid_diagram()
            return result["log"], text_summary, mermaid_diagram
        except Exception as e:
            # The merge itself succeeded; only the refresh of the diagram failed.
            logger.error(f"Error generating visualization after merge: {str(e)}")
            return result["log"], "Error generating visualization", ""
    return result["log"], "", ""
def get_db_tables():
    """Return the base-table names of the ``public`` schema, for the table dropdown.

    Returns an empty list when no database is connected or when the query
    fails; failures are logged rather than raised.
    """
    if not db:
        logger.warning("Attempted to fetch tables but database not connected")
        return []

    logger.info("Fetching tables from the database...")
    # Only base tables of the public schema are listed.
    table_query = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_type = 'BASE TABLE'"
    try:
        table_names = []
        for record in db.execute_query(table_query):
            table_names.append(record[0])
        logger.info(f"Found {len(table_names)} tables in the database")
    except Exception as exc:
        logger.error(f"Error getting tables: {str(exc)}")
        # Fall back to an empty list so callers never see the exception.
        return []
    return table_names
def get_columns(table):
    """Build the column dropdown for ``table``.

    Returns a ``gr.Dropdown`` listing the table's column names, or one without
    choices when nothing is connected, no table is selected, or the lookup fails.
    """
    if not (db and inspector) or not table:
        return gr.Dropdown(choices=[])

    try:
        column_info = inspector.get_column_info(table)
        logger.info(f"Columns for {table}: {column_info}")
        column_names = []
        for entry in column_info:
            column_names.append(entry['column_name'])
        return gr.Dropdown(choices=column_names)
    except Exception as exc:
        logger.error(f"Error getting columns: {str(exc)}")
        return gr.Dropdown(choices=[])
def refresh_tables():
    """Re-read the table list and return ``(status_message, gr.Dropdown)``."""
    if db:
        found_tables = get_db_tables()
        status = f"Refreshed tables. Found {len(found_tables)} tables."
        return status, gr.Dropdown(choices=found_tables)
    return "Not connected to database", gr.Dropdown(choices=[])
def _split_sql_statements(sql_text):
    """Split ``sql_text`` into statements on semicolons outside quotes and comments.

    Recognises single-quoted strings and double-quoted identifiers (with doubled
    quotes as escapes), ``--`` line comments, ``/* */`` block comments and
    ``$tag$ ... $tag$`` dollar-quoted bodies. Comments that precede the first
    token of a statement are dropped, so each returned statement starts with
    code; fragments that hold only whitespace or comments are skipped.
    Backslash-escaped quotes are not recognised.
    """
    statements = []
    buffer = []
    has_code = False  # True once the current statement holds a non-comment token
    pos = 0
    length = len(sql_text)

    while pos < length:
        char = sql_text[pos]
        pair = sql_text[pos:pos + 2]

        if char == ";":
            if has_code:
                statements.append("".join(buffer).strip())
            buffer = []
            has_code = False
            pos += 1
            continue

        if pair == "--" or pair == "/*":
            if pair == "--":
                stop = sql_text.find("\n", pos)
                stop = length if stop == -1 else stop
            else:
                stop = sql_text.find("*/", pos + 2)
                stop = length if stop == -1 else stop + 2
            # Keep comments that sit inside a statement, drop leading ones.
            if has_code:
                buffer.append(sql_text[pos:stop])
            pos = stop
            continue

        if char in ("'", '"'):
            stop = pos + 1
            while stop < length:
                if sql_text[stop] != char:
                    stop += 1
                elif sql_text[stop + 1:stop + 2] == char:
                    stop += 2  # doubled quote is an escaped quote, not the end
                else:
                    break
            stop = min(stop + 1, length)
        elif char == "$" and not (pos > 0 and (sql_text[pos - 1].isalnum() or sql_text[pos - 1] == "_")):
            tag_end = sql_text.find("$", pos + 1)
            tag = sql_text[pos + 1:tag_end]
            if tag_end != -1 and (not tag or tag.isidentifier()):
                delimiter = sql_text[pos:tag_end + 1]
                closing = sql_text.find(delimiter, tag_end + 1)
                stop = length if closing == -1 else closing + len(delimiter)
            else:
                stop = pos + 1  # a lone '$', e.g. a $1 placeholder
        else:
            stop = pos + 1

        if not char.isspace():
            has_code = True
        buffer.append(sql_text[pos:stop])
        pos = stop

    if has_code:
        statements.append("".join(buffer).strip())
    return statements


def execute_sql_file(sql_file):
    """Execute every SQL statement in an uploaded file inside one transaction.

    All statements run in a single session that is committed at the end; any
    failure rolls the whole batch back. After a successful commit the schema
    visualization is regenerated when a visualizer is available.

    Args:
        sql_file: Upload value from the file component; either a path-like
            string or an object exposing the path as ``.name``. ``None`` means
            nothing was uploaded.

    Returns:
        A 3-tuple ``(operation_log, schema_summary, mermaid_diagram)``; on error
        the first item is the error message and the other two are empty strings.
    """
    if not db:
        return "Error: Not connected to database. Please connect first.", "", ""
    if sql_file is None:
        return "Error: No SQL file provided.", "", ""

    # Accept both a plain path and an object carrying the path in `.name`.
    sql_path = getattr(sql_file, "name", sql_file)
    try:
        with open(sql_path, 'r', encoding='utf-8') as f:
            sql_content = f.read()
    except (OSError, ValueError, TypeError) as e:
        # ValueError also covers UnicodeDecodeError for non-UTF-8 files.
        logger.error(f"Error reading SQL file: {str(e)}")
        return f"Error reading SQL file: {str(e)}", "", ""

    if not sql_content.strip():
        return "Error: SQL file is empty.", "", ""

    sql_statements = _split_sql_statements(sql_content)
    if not sql_statements:
        return "Error: SQL file contains no executable statements.", "", ""

    max_preview_rows = 10
    results = []
    try:
        session = db.get_session()
    except Exception as e:
        error_msg = f"Error executing SQL: {str(e)}"
        logger.error(error_msg)
        return error_msg, "", ""

    try:
        for i, statement in enumerate(sql_statements, 1):
            logger.info(f"Executing SQL statement {i}: {statement[:100]}...")
            # NOTE(review): if get_session() returns a SQLAlchemy 2.x Session,
            # raw strings must be wrapped in sqlalchemy.text() - confirm against
            # DBConnector.
            result = session.execute(statement)
            # Try to fetch results if it's a SELECT statement
            if statement.upper().startswith('SELECT'):
                rows = result.fetchall()
                results.append(f"Statement {i}: Returned {len(rows)} rows")
                # Show at most the first rows so large results don't flood the log.
                results.extend(f" {row}" for row in rows[:max_preview_rows])
                if len(rows) > max_preview_rows:
                    results.append(f" ... ({len(rows) - max_preview_rows} more rows not shown)")
            else:
                results.append(f"Statement {i}: Executed successfully (affected rows: {result.rowcount})")
        session.commit()
    except Exception as e:
        session.rollback()
        error_msg = f"Error executing SQL: {str(e)}"
        logger.error(error_msg)
        return error_msg, "", ""
    finally:
        session.close()

    results.insert(0, f"Successfully executed {len(sql_statements)} SQL statements.")
    operation_log = "\n".join(results)

    # Auto-generate visualization after successful SQL execution
    if visualizer:
        try:
            text_summary = visualizer.generate_table_summary()
            mermaid_diagram = visualizer.generate_mermaid_diagram()
            return operation_log, text_summary, mermaid_diagram
        except Exception as e:
            # The SQL was already committed; only the diagram refresh failed.
            logger.error(f"Error generating visualization after SQL execution: {str(e)}")
            return operation_log, "Error generating visualization", ""
    return operation_log, "", ""
def generate_database_visualization():
    """Produce the schema summary and Mermaid diagram code for the connected database.

    Returns ``(text_summary, mermaid_diagram)``; when no visualizer exists or
    generation fails, the first item is an error message and the second is "".
    """
    if not visualizer:
        return "Error: Not connected to database. Please connect first.", ""

    try:
        # Diagram first, then the summary.
        diagram_code = visualizer.generate_mermaid_diagram()
        summary_text = visualizer.generate_table_summary()
    except Exception as exc:
        failure = f"Error generating visualization: {str(exc)}"
        logger.error(failure)
        return failure, ""
    else:
        return summary_text, diagram_code
def refresh_visualization():
    """Regenerate the visualization by delegating to generate_database_visualization()."""
    summary_and_diagram = generate_database_visualization()
    return summary_and_diagram
def create_ui():
    """Create and configure the Gradio UI.

    Builds a two-column layout: connection, SQL import, merge operation and
    visualization controls on the left; the operation log and schema views on
    the right. All buttons are wired to the module-level handlers.

    Returns:
        The configured ``gr.Blocks`` application (not yet launched).
    """

    def _preview_merge(action_name, table_name, column_name, source_values, target):
        """Run handle_merge in preview mode regardless of the checkbox state."""
        return handle_merge(action_name, table_name, column_name, source_values, target, preview_only=True)

    with gr.Blocks(title="SchemaSync") as app:
        gr.Markdown("# SchemaSync - Database Schema Manipulation Tool")
        with gr.Row():
            with gr.Column(scale=1):
                # Database Connection Section
                gr.Markdown("## Database Connection")
                db_url = gr.Textbox(
                    label="Database URL",
                    placeholder="postgresql://username:password@localhost:5432/database",
                    value="",
                    type="password"
                )
                with gr.Row():
                    connect_btn = gr.Button("Connect to Database")
                    refresh_btn = gr.Button("Refresh Tables")
                connection_status = gr.Textbox(
                    label="Connection Status",
                    interactive=False
                )
                # SQL Import Section
                gr.Markdown("## SQL File Import")
                sql_file = gr.File(
                    label="Upload SQL File",
                    file_types=[".sql", ".txt"],
                    file_count="single"
                )
                execute_sql_btn = gr.Button("Execute SQL File")
                # Operations Section
                gr.Markdown("## Schema Operations")
                action = gr.Dropdown(
                    choices=["Merge Values"],
                    label="Action",
                    value="Merge Values"
                )
                table = gr.Dropdown(
                    choices=[],
                    label="Table",
                    interactive=True,
                    allow_custom_value=True
                )
                column = gr.Dropdown(
                    label="Column",
                    interactive=True,
                    allow_custom_value=True
                )
                # Update column dropdown when table changes
                table.change(fn=get_columns, inputs=table, outputs=column)
                from_values = gr.Textbox(
                    label="From Values (comma-separated)",
                    placeholder="value1, value2, value3"
                )
                target_value = gr.Textbox(
                    label="Target Value",
                    placeholder="target_value"
                )
                preview_checkbox = gr.Checkbox(
                    label="Preview Only (no changes will be made)",
                    value=True
                )
                with gr.Row():
                    preview_btn = gr.Button("Preview Changes")
                    run_btn = gr.Button("Run Operation", variant="primary")
                # Database Visualization Section
                gr.Markdown("## Database Visualization")
                visualize_btn = gr.Button("Generate Visualization", variant="secondary")
            with gr.Column(scale=2):
                # Operation Results
                gr.Markdown("## Operation Results")
                output = gr.TextArea(
                    label="Operation Log",
                    placeholder="Operation results will appear here",
                    lines=15
                )
                # Visualization Results
                gr.Markdown("## Database Schema")
                with gr.Tabs():
                    with gr.TabItem("Schema Summary"):
                        schema_summary = gr.Markdown(
                            value="Connect to a database and run operations or click 'Generate Visualization' to see the schema structure."
                        )
                    with gr.TabItem("ER Diagram"):
                        gr.Markdown("Copy the code below and paste it into [Mermaid Live Editor](https://mermaid.live) to view the interactive diagram.")
                        mermaid_code = gr.Code(
                            label="Mermaid Diagram Code",
                            language="markdown",
                            lines=15,
                            value="Connect to database and run operations to see diagram code here."
                        )
        # Connect buttons to handlers
        connect_btn.click(
            fn=connect_to_database,
            inputs=db_url,
            outputs=[connection_status, table]
        )
        refresh_btn.click(
            fn=refresh_tables,
            inputs=None,
            outputs=[connection_status, table]
        )
        execute_sql_btn.click(
            fn=execute_sql_file,
            inputs=sql_file,
            outputs=[output, schema_summary, mermaid_code]
        )
        # "Preview Changes" must never modify data, so it ignores the checkbox
        # and always previews.
        preview_btn.click(
            fn=_preview_merge,
            inputs=[action, table, column, from_values, target_value],
            outputs=[output, schema_summary, mermaid_code]
        )
        # "Run Operation" honours the visible checkbox: while "Preview Only
        # (no changes will be made)" is ticked, nothing is changed.
        run_btn.click(
            fn=handle_merge,
            inputs=[action, table, column, from_values, target_value, preview_checkbox],
            outputs=[output, schema_summary, mermaid_code]
        )
        # Manual visualization generation
        visualize_btn.click(
            fn=generate_database_visualization,
            inputs=None,
            outputs=[schema_summary, mermaid_code]
        )
    return app
def main():
    """Build the UI and serve it on the host and port taken from ``config``."""
    ui = create_ui()
    launch_options = {
        "server_name": config.get('HOST', '0.0.0.0'),
        "server_port": int(config.get('PORT', 7860)),
        "share": False,
        "debug": True,
    }
    ui.launch(**launch_options)
    # Intentionally no return value.