#!/usr/bin/env python3
"""Database migration script for the chat agent application.

Each ``*.sql`` file in this script's directory is one migration whose file
stem is its version. Pending migrations are applied in sorted filename order,
each in its own transaction, and recorded in the ``schema_migrations`` table.
"""

import argparse
import os
import sys
from contextlib import closing
from pathlib import Path
from urllib.parse import unquote, urlsplit

import psycopg2
from psycopg2 import errorcodes, sql
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT

# Add the parent directory to the path so we can import config
sys.path.append(str(Path(__file__).parent.parent))

from config import config  # noqa: E402  (depends on the sys.path tweak above)


def _to_libpq_url(url):
    """Return *url* in a form libpq (and therefore psycopg2) accepts.

    SQLAlchemy URIs may name a driver in the scheme, e.g.
    ``postgresql+psycopg2://...``. libpq only understands the bare
    ``postgresql://`` / ``postgres://`` schemes, so a ``+driver`` suffix is
    dropped. Anything else (including key=value DSNs) is returned unchanged.
    """
    scheme, sep, rest = url.partition("://")
    if sep and "+" in scheme:
        return f"{scheme.split('+', 1)[0]}{sep}{rest}"
    return url


class DatabaseMigrator:
    """Handles database migrations for the chat agent."""

    def __init__(self, database_url=None, config_name='development'):
        """Initialize the migrator with database connection settings.

        Args:
            database_url: Explicit database URL; when falsy, the
                ``SQLALCHEMY_DATABASE_URI`` of ``config[config_name]`` is used.
            config_name: Key into the project ``config`` mapping.
        """
        if database_url:
            self.database_url = database_url
        else:
            app_config = config[config_name]
            self.database_url = app_config.SQLALCHEMY_DATABASE_URI
        # Migration .sql files live next to this script.
        self.migrations_dir = Path(__file__).parent

    def get_connection(self, autocommit=True):
        """Open and return a new psycopg2 connection to the target database.

        Args:
            autocommit: When true, put the connection in autocommit mode.

        The caller owns the returned connection and must close it.
        """
        conn = psycopg2.connect(_to_libpq_url(self.database_url))
        if autocommit:
            conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
        return conn

    def create_database_if_not_exists(self):
        """Create the target database if it does not already exist.

        Connects to the ``postgres`` database on the same server, because
        CREATE DATABASE cannot run against the database being created and
        cannot run inside a transaction (hence autocommit).

        Raises:
            psycopg2.Error: if the server is unreachable or creation fails.
        """
        url = _to_libpq_url(self.database_url)
        parsed = urlsplit(url)
        db_name = unquote(parsed.path.lstrip("/")) if "://" in url else ""
        if not db_name:
            print("Could not determine database name from database URL; "
                  "skipping database creation")
            return

        # Rebuild the URL with only the path swapped. A blind str.replace()
        # would also rewrite matching text in the host, credentials or query.
        postgres_url = f"{parsed.scheme}://{parsed.netloc}/postgres"
        if parsed.query:
            postgres_url += f"?{parsed.query}"

        try:
            with closing(psycopg2.connect(postgres_url)) as conn:
                conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)
                with conn.cursor() as cursor:
                    # Check if database exists
                    cursor.execute(
                        "SELECT 1 FROM pg_database WHERE datname = %s",
                        (db_name,),
                    )
                    if cursor.fetchone():
                        print(f"Database {db_name} already exists")
                        return
                    print(f"Creating database: {db_name}")
                    # sql.Identifier quotes and escapes the name correctly.
                    cursor.execute(
                        sql.SQL("CREATE DATABASE {}").format(
                            sql.Identifier(db_name)
                        )
                    )
                    print(f"Database {db_name} created successfully")
        except psycopg2.Error as e:
            print(f"Error creating database: {e}")
            raise

    def create_migrations_table(self):
        """Create the ``schema_migrations`` tracking table if it is missing."""
        with closing(self.get_connection()) as conn, conn.cursor() as cursor:
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS schema_migrations (
                    version VARCHAR(255) PRIMARY KEY,
                    applied_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP
                )
            """)
        print("Migrations table created/verified")

    def get_applied_migrations(self):
        """Return the versions recorded in ``schema_migrations``, sorted.

        Returns an empty list when the tracking table does not exist yet.
        Any other database error propagates, so that a failed query is never
        mistaken for "nothing applied".
        """
        with closing(self.get_connection()) as conn, conn.cursor() as cursor:
            try:
                cursor.execute(
                    "SELECT version FROM schema_migrations ORDER BY version"
                )
            except psycopg2.Error as e:
                if e.pgcode != errorcodes.UNDEFINED_TABLE:
                    raise
                # Table doesn't exist yet
                return []
            return [row[0] for row in cursor.fetchall()]

    def get_available_migrations(self):
        """Return ``(version, path)`` pairs for every migration file.

        The version is the file stem. Pairs are sorted by path, so versions
        are ordered lexicographically (zero-pad numeric prefixes).
        """
        return [
            (path.stem, path)
            for path in sorted(self.migrations_dir.glob("*.sql"))
            if path.is_file()
        ]

    def apply_migration(self, version, file_path):
        """Apply one migration file and record it, in a single transaction.

        Raises:
            OSError: if the file cannot be read (no connection is opened).
            psycopg2.Error: if the SQL fails; the transaction is rolled back.
        """
        print(f"Applying migration: {version}")
        migration_sql = Path(file_path).read_text(encoding="utf-8")

        with closing(self.get_connection(autocommit=False)) as conn:
            try:
                with conn.cursor() as cursor:
                    cursor.execute(migration_sql)
                    # Record the migration as applied
                    cursor.execute(
                        "INSERT INTO schema_migrations (version) VALUES (%s)",
                        (version,),
                    )
                conn.commit()
            except psycopg2.Error as e:
                conn.rollback()
                print(f"Error applying migration {version}: {e}")
                raise
        print(f"Migration {version} applied successfully")

    def migrate(self, target_version=None):
        """Apply all pending migrations, optionally stopping at a target.

        Args:
            target_version: When given, only versions that compare ``<=`` it
                as strings are applied.
        """
        print("Starting database migration...")

        self.create_database_if_not_exists()
        self.create_migrations_table()

        applied = set(self.get_applied_migrations())
        available = self.get_available_migrations()

        if (target_version is not None
                and target_version not in {v for v, _ in available}):
            # Versions are compared as strings, so e.g. "005" sorts *before*
            # "005_add_users". Tell the user their target matched no file.
            print(f"Warning: target {target_version!r} does not match any "
                  "migration file; applying pending migrations whose version "
                  "sorts at or before it")

        to_apply = [
            (version, file_path)
            for version, file_path in available
            if version not in applied
            and (target_version is None or version <= target_version)
        ]

        if not to_apply:
            print("No pending migrations to apply")
            return

        for version, file_path in to_apply:
            self.apply_migration(version, file_path)

        print(f"Migration completed. Applied {len(to_apply)} migrations.")

    def status(self):
        """Print each migration file as APPLIED or PENDING, plus totals.

        Errors are reported on stdout rather than raised.
        """
        try:
            applied = set(self.get_applied_migrations())
            available = self.get_available_migrations()

            print("Migration Status:")
            print("-" * 50)

            for version, _file_path in available:
                status = "APPLIED" if version in applied else "PENDING"
                print(f"{version:<30} {status}")

            pending_count = sum(1 for v, _ in available if v not in applied)
            print(f"\nTotal migrations: {len(available)}")
            print(f"Applied: {len(applied)}")
            print(f"Pending: {pending_count}")

        except Exception as e:
            print(f"Error checking migration status: {e}")


def main():
    """Main CLI interface for migrations."""
    parser = argparse.ArgumentParser(description="Database migration tool")
    parser.add_argument(
        "command",
        choices=["migrate", "status"],
        help="Migration command to run"
    )
    parser.add_argument(
        "--config",
        default="development",
        choices=["development", "production", "testing"],
        help="Configuration environment"
    )
    parser.add_argument(
        "--database-url",
        help="Database URL (overrides config)"
    )
    parser.add_argument(
        "--target",
        help="Target migration version"
    )

    args = parser.parse_args()

    migrator = DatabaseMigrator(
        database_url=args.database_url,
        config_name=args.config
    )

    if args.command == "migrate":
        migrator.migrate(target_version=args.target)
    elif args.command == "status":
        migrator.status()


if __name__ == "__main__":
    main()