Spaces:
Paused
Paused
| import gradio as gr | |
| import logging | |
| import tempfile | |
| import os | |
| from .db_connector import DBConnector | |
| from .schema_inspector import SchemaInspector | |
| from .merge_operations import MergeOperations | |
| from .db_visualizer import DatabaseVisualizer | |
| from .config import get_config | |
# Logging: timestamped INFO-level output shared by the whole application.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# Configuration is read at import time. The database connection and the
# helpers built on top of it stay None until connect_to_database() runs.
config = get_config()
db = inspector = merge_ops = visualizer = None
def connect_to_database(db_url):
    """Connect to the database at ``db_url`` and rebuild the helper objects.

    Any previous connection state is discarded first. On success the
    module-level ``db``, ``inspector``, ``merge_ops`` and ``visualizer``
    globals all refer to objects bound to the new connection. On any failure
    they are all left as ``None``, so the other handlers report "not
    connected". They never mix a failed connector with stale helpers from an
    earlier connection.

    Args:
        db_url: Connection string typed into the UI.

    Returns:
        ``(status_message, table_dropdown)`` for the Gradio outputs.
    """
    global db, inspector, merge_ops, visualizer

    # Reset up front. Previously a failed connect left ``db`` set (and truthy)
    # while inspector/merge_ops/visualizer still used the old connection.
    db = inspector = merge_ops = visualizer = None

    if not db_url or not db_url.strip():
        return "Error: Please enter a database URL.", gr.Dropdown(choices=[])

    try:
        # Build everything in locals so a failure part-way through never
        # publishes a half-initialised set of objects.
        connector = DBConnector(db_url.strip())
        if not connector.connect():
            return "Failed to connect to database. Check the connection string and ensure the required database driver is installed.", gr.Dropdown(choices=[])
        schema_inspector = SchemaInspector(connector)
        new_merge_ops = MergeOperations(connector, schema_inspector)
        new_visualizer = DatabaseVisualizer(connector, schema_inspector)
    except ImportError as e:
        error_msg = f"Database driver error: {str(e)}"
        logger.error(error_msg)
        return error_msg, gr.Dropdown(choices=[])
    except Exception as e:
        logger.error(f"Database connection error: {str(e)}")
        return f"Error connecting to database: {str(e)}", gr.Dropdown(choices=[])

    # Publish the fully built set of objects together.
    db, inspector, merge_ops, visualizer = connector, schema_inspector, new_merge_ops, new_visualizer

    # Sanity-check the connection by listing tables. get_db_tables() reads the
    # ``db`` global, so this must come after the assignment above.
    tables = get_db_tables()
    if tables:
        return f"Successfully connected to database. Found {len(tables)} tables.", gr.Dropdown(choices=tables)
    return "Connected to database but found no tables.", gr.Dropdown(choices=[])
def handle_merge(action, table, column, from_values, target_value, preview_only=True):
    """Run or preview the schema operation selected in the UI.

    Args:
        action: Operation name from the "Action" dropdown. Only
            ``"Merge Values"`` is implemented.
        table: Table to operate on.
        column: Column whose values are merged.
        from_values: Comma-separated values to be replaced. Blank entries
            (e.g. from a trailing comma) are ignored.
        target_value: Value that ``from_values`` are merged into.
        preview_only: When true, only ``merge_ops.preview_merge`` is called.

    Returns:
        ``(operation_log, schema_summary, mermaid_code)`` for the three output
        components. The last two are empty unless a real merge reported
        success and the visualizer regenerated its output.
    """
    # merge_ops is checked too: db alone does not guarantee the helpers exist.
    if not db or not merge_ops:
        return "Error: Not connected to database. Please connect first.", "", ""
    if not table or not column:
        return "Error: Table and column must be specified", "", ""
    if action != "Merge Values":
        return "Action not implemented yet", "", ""

    # Drop blank entries so "a, b," or "a,,b" doesn't add "" to the merge set.
    from_values_list = [v.strip() for v in (from_values or "").split(",") if v.strip()]
    if not from_values_list:
        return "Error: At least one value to merge from must be specified", "", ""

    try:
        if preview_only:
            result = merge_ops.preview_merge(table, column, from_values_list, target_value)
            return result["preview"], "", ""
        result = merge_ops.run_merge(table, column, from_values_list, target_value)
    except Exception as e:
        # Report in the log pane, like the other handlers, instead of letting
        # the exception escape to Gradio.
        error_msg = f"Error during merge operation: {str(e)}"
        logger.error(error_msg)
        return error_msg, "", ""

    # Regenerate the schema views only after a successful write.
    if result.get("success", False) and visualizer:
        try:
            text_summary = visualizer.generate_table_summary()
            mermaid_diagram = visualizer.generate_mermaid_diagram()
            return result["log"], text_summary, mermaid_diagram
        except Exception as e:
            logger.error(f"Error generating visualization after merge: {str(e)}")
            return result["log"], "Error generating visualization", ""
    return result["log"], "", ""
def get_db_tables():
    """Return the names of base tables in the ``public`` schema, sorted by name.

    Used to populate the table dropdown. Returns an empty list when no
    database is connected or when the query fails (the error is logged), so
    callers can always pass the result straight to ``gr.Dropdown``.
    """
    if not db:
        logger.warning("Attempted to fetch tables but database not connected")
        return []
    logger.info("Fetching tables from the database...")
    # NOTE: 'public' is PostgreSQL's default schema. Engines without this
    # schema or without information_schema will return nothing or fail into
    # the except branch below. ORDER BY keeps the dropdown order stable.
    query = (
        "SELECT table_name FROM information_schema.tables "
        "WHERE table_schema = 'public' AND table_type = 'BASE TABLE' "
        "ORDER BY table_name"
    )
    try:
        result = db.execute_query(query)
        tables = [row[0] for row in result]
    except Exception as e:
        logger.error(f"Error getting tables: {str(e)}")
        return []
    logger.info(f"Found {len(tables)} tables in the database")
    return tables
def get_columns(table):
    """Build the column dropdown for ``table``.

    Falls back to an empty dropdown when there is no connection or inspector,
    when no table is selected, or when the inspector lookup raises.
    """
    if db and inspector and table:
        try:
            column_info = inspector.get_column_info(table)
            logger.info(f"Columns for {table}: {column_info}")
            names = [info['column_name'] for info in column_info]
            return gr.Dropdown(choices=names)
        except Exception as exc:
            logger.error(f"Error getting columns: {str(exc)}")
    return gr.Dropdown(choices=[])
def refresh_tables():
    """Re-query the table list and return ``(status, table_dropdown)``."""
    if db:
        found = get_db_tables()
        return f"Refreshed tables. Found {len(found)} tables.", gr.Dropdown(choices=found)
    return "Not connected to database", gr.Dropdown(choices=[])
def _read_sql_file(sql_file):
    """Return the text content of an uploaded SQL file.

    ``gr.File`` passes either a path string (Gradio 4+, default
    ``type="filepath"``) or a temp-file wrapper exposing ``.name`` (older
    Gradio). Both are accepted. ``utf-8-sig`` strips a leading byte-order mark
    that would otherwise be glued onto the first statement.
    """
    path = sql_file if isinstance(sql_file, (str, os.PathLike)) else sql_file.name
    with open(path, "r", encoding="utf-8-sig") as f:
        return f.read()


def _split_sql_statements(sql_content):
    """Split a SQL script on ';' into stripped, non-empty statements.

    NOTE: this is a naive split. A ';' inside a string literal, a comment or
    a procedural body (e.g. PostgreSQL ``$$ ... $$``) will cut a statement
    apart.
    """
    return [stmt.strip() for stmt in sql_content.split(";") if stmt.strip()]


def _format_rows(index, rows, max_rows=10):
    """Summarise a row-returning statement, listing at most ``max_rows`` rows."""
    lines = [f"Statement {index}: Returned {len(rows)} rows"]
    lines.extend(f" {row}" for row in rows[:max_rows])
    if len(rows) > max_rows:
        lines.append(f" ... ({len(rows) - max_rows} more rows not shown)")
    return lines


def execute_sql_file(sql_file):
    """Execute every statement of an uploaded SQL file in one transaction.

    All statements run in a single session from ``db.get_session()``. They are
    committed together, or rolled back together if any statement fails. After
    a successful commit, the schema summary and Mermaid diagram are
    regenerated when a visualizer is available.

    Args:
        sql_file: Value of the ``gr.File`` component (a path string or a
            file-like object with ``.name``), or ``None`` if nothing was uploaded.

    Returns:
        ``(operation_log, schema_summary, mermaid_code)`` for the UI outputs.
    """
    if not db:
        return "Error: Not connected to database. Please connect first.", "", ""
    if sql_file is None:
        return "Error: No SQL file provided.", "", ""

    try:
        sql_content = _read_sql_file(sql_file)
    except Exception as e:
        logger.error(f"Error reading SQL file: {str(e)}")
        return f"Error reading SQL file: {str(e)}", "", ""
    if not sql_content.strip():
        return "Error: SQL file is empty.", "", ""

    sql_statements = _split_sql_statements(sql_content)
    results = []
    session = None
    try:
        session = db.get_session()
        for i, statement in enumerate(sql_statements, 1):
            logger.info(f"Executing SQL statement {i}: {statement[:100]}...")
            result = session.execute(statement)
            # Prefer the result's own answer (SQLAlchemy results expose
            # ``returns_rows``). This also catches SELECTs preceded by comments.
            # Fall back to the keyword check otherwise.
            returns_rows = getattr(result, "returns_rows", None)
            if returns_rows is None:
                returns_rows = statement.upper().startswith("SELECT")
            if returns_rows:
                results.extend(_format_rows(i, result.fetchall()))
            else:
                results.append(f"Statement {i}: Executed successfully (affected rows: {result.rowcount})")
        session.commit()
    except Exception as e:
        if session is not None:
            session.rollback()
        error_msg = f"Error executing SQL: {str(e)}"
        logger.error(error_msg)
        return error_msg, "", ""
    finally:
        if session is not None:
            session.close()

    results.insert(0, f"Successfully executed {len(sql_statements)} SQL statements.")
    operation_log = "\n".join(results)

    # Auto-generate the visualization after a successful SQL execution.
    if visualizer:
        try:
            text_summary = visualizer.generate_table_summary()
            mermaid_diagram = visualizer.generate_mermaid_diagram()
            return operation_log, text_summary, mermaid_diagram
        except Exception as e:
            logger.error(f"Error generating visualization after SQL execution: {str(e)}")
            return operation_log, "Error generating visualization", ""
    return operation_log, "", ""
def generate_database_visualization():
    """Produce ``(text_summary, mermaid_diagram)`` for the current database.

    Without a visualizer (not connected), or if generation raises, the
    message goes in the summary slot and the diagram slot is empty.
    """
    if not visualizer:
        return "Error: Not connected to database. Please connect first.", ""
    try:
        diagram = visualizer.generate_mermaid_diagram()
        summary = visualizer.generate_table_summary()
    except Exception as exc:
        message = f"Error generating visualization: {str(exc)}"
        logger.error(message)
        return message, ""
    return summary, diagram
def refresh_visualization():
    """Regenerate the schema visualization; a thin alias of the generator."""
    summary_and_diagram = generate_database_visualization()
    return summary_and_diagram
def create_ui():
    """Create and configure the SchemaSync Gradio UI.

    The left column holds the connection, SQL import, schema operation and
    visualization controls. The right column shows the operation log and the
    schema summary / Mermaid code tabs.

    Merge buttons:
      * "Preview Changes" always previews and never writes.
      * "Run Operation" honours the "Preview Only" checkbox: while it is
        checked (the default) the run is a preview, matching the checkbox's
        "no changes will be made" promise.

    Returns:
        The configured ``gr.Blocks`` app (not yet launched).
    """

    def _preview_merge(action_name, table_name, column_name, from_vals, target):
        # Previously this button passed the checkbox value, so unchecking the
        # box turned "Preview Changes" into a real, writing merge.
        return handle_merge(action_name, table_name, column_name, from_vals, target, preview_only=True)

    with gr.Blocks(title="SchemaSync") as app:
        gr.Markdown("# SchemaSync - Database Schema Manipulation Tool")
        with gr.Row():
            with gr.Column(scale=1):
                # Database Connection Section
                gr.Markdown("## Database Connection")
                db_url = gr.Textbox(
                    label="Database URL",
                    placeholder="postgresql://username:password@localhost:5432/database",
                    value="",
                    type="password"
                )
                with gr.Row():
                    connect_btn = gr.Button("Connect to Database")
                    refresh_btn = gr.Button("Refresh Tables")
                connection_status = gr.Textbox(
                    label="Connection Status",
                    interactive=False
                )

                # SQL Import Section
                gr.Markdown("## SQL File Import")
                sql_file = gr.File(
                    label="Upload SQL File",
                    file_types=[".sql", ".txt"],
                    file_count="single"
                )
                execute_sql_btn = gr.Button("Execute SQL File")

                # Operations Section
                gr.Markdown("## Schema Operations")
                action = gr.Dropdown(
                    choices=["Merge Values"],
                    label="Action",
                    value="Merge Values"
                )
                table = gr.Dropdown(
                    choices=[],
                    label="Table",
                    interactive=True,
                    allow_custom_value=True
                )
                column = gr.Dropdown(
                    label="Column",
                    interactive=True,
                    allow_custom_value=True
                )
                # Update column dropdown when table changes
                table.change(fn=get_columns, inputs=table, outputs=column)
                from_values = gr.Textbox(
                    label="From Values (comma-separated)",
                    placeholder="value1, value2, value3"
                )
                target_value = gr.Textbox(
                    label="Target Value",
                    placeholder="target_value"
                )
                preview_checkbox = gr.Checkbox(
                    label="Preview Only (no changes will be made)",
                    value=True
                )
                with gr.Row():
                    preview_btn = gr.Button("Preview Changes")
                    run_btn = gr.Button("Run Operation", variant="primary")

                # Database Visualization Section
                gr.Markdown("## Database Visualization")
                visualize_btn = gr.Button("Generate Visualization", variant="secondary")

            with gr.Column(scale=2):
                # Operation Results
                gr.Markdown("## Operation Results")
                output = gr.TextArea(
                    label="Operation Log",
                    placeholder="Operation results will appear here",
                    lines=15
                )

                # Visualization Results
                gr.Markdown("## Database Schema")
                with gr.Tabs():
                    with gr.TabItem("Schema Summary"):
                        schema_summary = gr.Markdown(
                            value="Connect to a database and run operations or click 'Generate Visualization' to see the schema structure."
                        )
                    with gr.TabItem("ER Diagram"):
                        gr.Markdown("Copy the code below and paste it into [Mermaid Live Editor](https://mermaid.live) to view the interactive diagram.")
                        mermaid_code = gr.Code(
                            label="Mermaid Diagram Code",
                            language="markdown",
                            lines=15,
                            value="Connect to database and run operations to see diagram code here."
                        )

        # Connect buttons to handlers
        connect_btn.click(
            fn=connect_to_database,
            inputs=db_url,
            outputs=[connection_status, table]
        )
        refresh_btn.click(
            fn=refresh_tables,
            inputs=None,
            outputs=[connection_status, table]
        )
        execute_sql_btn.click(
            fn=execute_sql_file,
            inputs=sql_file,
            outputs=[output, schema_summary, mermaid_code]
        )
        preview_btn.click(
            fn=_preview_merge,
            inputs=[action, table, column, from_values, target_value],
            outputs=[output, schema_summary, mermaid_code]
        )
        # The checkbox is the safety switch: checked means preview only.
        run_btn.click(
            fn=handle_merge,
            inputs=[action, table, column, from_values, target_value, preview_checkbox],
            outputs=[output, schema_summary, mermaid_code]
        )
        # Manual visualization generation
        visualize_btn.click(
            fn=generate_database_visualization,
            inputs=None,
            outputs=[schema_summary, mermaid_code]
        )
    return app
def main():
    """Build the UI and launch the Gradio server.

    Host and port come from ``config`` (keys ``HOST`` / ``PORT``), defaulting
    to ``0.0.0.0`` and ``7860``. Returns ``None``.
    """
    application = create_ui()
    host = config.get('HOST', '0.0.0.0')
    port = int(config.get('PORT', 7860))
    application.launch(server_name=host, server_port=port, share=False, debug=True)